Implementing an RNN from scratch in Python.
The main objective of this post is to implement an RNN from scratch and to explain each step clearly enough to be useful to readers. Implementing a neural network from scratch at least once is a valuable exercise: it helps you understand how neural networks actually work. Here we implement an RNN, which has its own complexities and so gives us a good opportunity to hone our skills.
There are various tutorials that describe the internals of an RNN in great detail; you can find some very useful references at the end of this post. I understood how an RNN works rather quickly, but what troubled me most was going through the back-propagation through time (BPTT) calculations and their implementation. I had to spend some time to understand them and finally put it all together. Without wasting any more time, let us quickly go through the basics of an RNN first.
What is an RNN?
A recurrent neural network is a neural network specialized for processing a sequence of data x(1), . . . , x(τ), with the time step index t ranging from 1 to τ. For tasks that involve sequential inputs, such as speech and language, it is often better to use RNNs. In an NLP problem, if you want to predict the next word in a sentence, it is important to know the words before it. RNNs are called recurrent because they perform the same task for every element of a sequence, with the output depending on the previous computations. Another way to think about RNNs is that they have a “memory” which captures information about what has been calculated so far.
Architecture: Let us briefly go through a basic RNN.
The left side of the above diagram shows a notation of an RNN and on the right side an RNN being unrolled (or unfolded) into a full network. By unrolling we mean that we write out the network for the complete sequence. For example, if the sequence we care about is a sentence of 3 words, the network would be unrolled into a 3-layer neural network, one layer for each word.
Input: x(t) is taken as the input to the network at time step t. For example, x(1) could be a one-hot vector corresponding to a word of a sentence.
Hidden state: h(t) represents the hidden state at time t and acts as the “memory” of the network. h(t) is calculated from the current input and the previous time step’s hidden state: h(t) = f(U x(t) + W h(t−1)). The function f is a non-linear transformation such as tanh or ReLU.
Weights: The RNN has input-to-hidden connections parameterized by a weight matrix U, hidden-to-hidden recurrent connections parameterized by a weight matrix W, and hidden-to-output connections parameterized by a weight matrix V. All these weights (U, V, W) are shared across time.
Output: o(t) is the output of the network at time step t. In the figure I just put an arrow after o(t); in practice the output is often also passed through a non-linearity, especially when the network contains further layers downstream.
The figure does not specify the choice of activation function for the hidden units. Before we proceed, we make a few assumptions: 1) We assume the hyperbolic tangent activation function for the hidden layer. 2) We assume that the output is discrete, as if the RNN is used to predict words or characters. A natural way to represent discrete variables is to regard the output o as giving the un-normalized log probabilities of each possible value of the discrete variable. We can then apply the softmax operation as a post-processing step to obtain a vector ŷ of normalized probabilities over the output.
The RNN forward pass can thus be represented by the following set of equations.
This is an example of a recurrent network that maps an input sequence to an output sequence of the same length. The total loss for a given sequence of x values paired with a sequence of y values is then just the sum of the losses over all the time steps. We assume that the outputs o(t) are used as the argument to the softmax function to obtain the vector ŷ of probabilities over the output. We also assume that the loss L is the negative log-likelihood of the true target y(t) given the input so far.
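Written out under the assumptions above (tanh hidden units, softmax output), the forward pass at each time step looks roughly as follows. The name a(t) for the pre-activation is introduced here only for convenience, and b and c are the hidden and output bias vectors that are updated later in the post.

```latex
\begin{aligned}
a^{(t)} &= b + W h^{(t-1)} + U x^{(t)} \\
h^{(t)} &= \tanh\left(a^{(t)}\right) \\
o^{(t)} &= c + V h^{(t)} \\
\hat{y}^{(t)} &= \operatorname{softmax}\left(o^{(t)}\right)
\end{aligned}
```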
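As a compact way of writing this total loss, assuming y(t) denotes the index of the true class at time step t:

```latex
L = \sum_{t=1}^{\tau} L^{(t)} = -\sum_{t=1}^{\tau} \log \hat{y}^{(t)}_{y^{(t)}}
```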
The gradient computation involves performing a forward propagation pass moving left to right through the graph shown above followed by a backward propagation pass moving right to left through the graph. The runtime is O(τ) and cannot be reduced by parallelization because the forward propagation graph is inherently sequential; each time step may be computed only after the previous one. States computed in the forward pass must be stored until they are reused during the backward pass, so the memory cost is also O(τ). The back-propagation algorithm applied to the unrolled graph with O(τ) cost is called back-propagation through time (BPTT). Because the parameters are shared by all time steps in the network, the gradient at each output depends not only on the calculations of the current time step, but also the previous time steps.
Given our loss function L, we need to calculate the gradients for our three weight matrices U, V, W, and bias terms b, c and update them with a learning rate α. Similar to normal back-propagation, the gradient gives us a sense of how the loss is changing with respect to each weight parameter. We update the weights W to minimize loss with the following equation:
The same is to be done for the other weights U, V, b, c as well.
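For reference, a plain gradient descent step with learning rate α has this form, shown here for W (the other parameters follow the same pattern):

```latex
W \leftarrow W - \alpha \, \frac{\partial L}{\partial W}
```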
Let us now compute the gradients by BPTT for the RNN equations above. The nodes of our computational graph include the parameters U, V, W, b and c as well as the sequence of nodes indexed by t for x(t), h(t), o(t) and L(t). For each node n we need to compute the gradient ∇nL recursively, based on the gradient computed at the nodes that follow it in the graph.
Gradient with respect to output o(t) is calculated assuming the o(t) are used as the argument to the softmax function to obtain the vector ŷ of probabilities over the output. We also assume that the loss is the negative log-likelihood of the true target y(t).
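Under these two assumptions (softmax output and negative log-likelihood loss), the gradient on the outputs takes a well-known simple form: the predicted probability minus the one-hot target. Component i can be written as:

```latex
\left(\nabla_{o^{(t)}} L\right)_i
  = \frac{\partial L}{\partial o_i^{(t)}}
  = \hat{y}_i^{(t)} - \mathbf{1}_{i = y^{(t)}}
```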
Please refer here for deriving the above elegant solution.
Let us now understand how the gradient flows through the hidden state h(t). As the diagram below shows, at time t the hidden state h(t) receives gradient from both the current output and the next hidden state.
We work our way backward, starting from the end of the sequence. At the final time step τ, h(τ) only has o(τ) as a descendant, so its gradient is simple:
We can then iterate backward in time to back-propagate gradients through time, from t = τ−1 down to t = 1, noting that h(t) (for t < τ) has as descendants both o(t) and h(t+1). Its gradient is thus given by:
Once the gradients on the internal nodes of the computational graph are obtained, we can obtain the gradients on the parameter nodes. The gradient calculations using the chain rule for all parameters are:
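Since o(τ) = c + V h(τ), the chain rule gives the following for the last time step:

```latex
\nabla_{h^{(\tau)}} L = V^{\top} \, \nabla_{o^{(\tau)}} L
```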
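With tanh hidden units, the derivative of h(t+1) with respect to its pre-activation is 1 − h(t+1)², applied element-wise. Written with a diagonal matrix, the recursion for t < τ then reads as follows (a reconstruction from the forward equations, with one term per descendant):

```latex
\nabla_{h^{(t)}} L
  = W^{\top} \operatorname{diag}\!\left(1 - \left(h^{(t+1)}\right)^{2}\right) \nabla_{h^{(t+1)}} L
  + V^{\top} \, \nabla_{o^{(t)}} L
```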
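Because the parameters are shared across time, each parameter gradient is a sum over all time steps. Assuming the forward equations a(t) = b + W h(t−1) + U x(t), h(t) = tanh(a(t)) and o(t) = c + V h(t), the sums work out to:

```latex
\begin{aligned}
\nabla_{c} L &= \sum_{t} \nabla_{o^{(t)}} L \\
\nabla_{b} L &= \sum_{t} \operatorname{diag}\!\left(1 - \left(h^{(t)}\right)^{2}\right) \nabla_{h^{(t)}} L \\
\nabla_{V} L &= \sum_{t} \left(\nabla_{o^{(t)}} L\right) {h^{(t)}}^{\top} \\
\nabla_{W} L &= \sum_{t} \operatorname{diag}\!\left(1 - \left(h^{(t)}\right)^{2}\right) \left(\nabla_{h^{(t)}} L\right) {h^{(t-1)}}^{\top} \\
\nabla_{U} L &= \sum_{t} \operatorname{diag}\!\left(1 - \left(h^{(t)}\right)^{2}\right) \left(\nabla_{h^{(t)}} L\right) {x^{(t)}}^{\top}
\end{aligned}
```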
We will implement a full recurrent neural network from scratch using Python and use it to build a text generation model. We train the model to predict the probability of the next token (a character in our case) given the preceding ones, which makes it a generative model: given an existing sequence of characters, we sample the next character from the predicted probabilities and repeat the process until we have a full sentence. This implementation is based on Andrej Karpathy's great post on building a character-level RNN. Here we will discuss the implementation details step by step.
General steps to follow:
- Initialize weight matrices U, V, W from a random distribution and biases b, c with zeros
- Forward propagation to compute predictions
- Compute the loss
- Back-propagation to compute gradients
- Update weights based on gradients
- Repeat steps 2–5
Step 1: Initialize
To start with the implementation of the basic RNN cell, we first define the dimensions of the various parameters U,V,W,b,c.
Dimensions: Let’s assume we pick a vocabulary size vocab_size = 8000 and a hidden layer size hidden_dim = 100. Then we have:
Vocabulary size can be the number of unique characters for a character-based model or the number of unique words for a word-based model.
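The shapes below follow from the definitions of the inputs, hidden state and outputs given earlier. They assume one-hot input vectors and column-vector biases b and c:

```latex
\begin{aligned}
x^{(t)} &\in \mathbb{R}^{8000}, & h^{(t)} &\in \mathbb{R}^{100}, & o^{(t)} &\in \mathbb{R}^{8000}, \\
U &\in \mathbb{R}^{100 \times 8000}, & W &\in \mathbb{R}^{100 \times 100}, & V &\in \mathbb{R}^{8000 \times 100}, \\
b &\in \mathbb{R}^{100}, & c &\in \mathbb{R}^{8000}
\end{aligned}
```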
With our few hyper-parameters and other model parameters, let us start defining our RNN cell.
Proper initialization of the weights has an impact on training results, and there has been a lot of research in this area. It turns out that the best initialization depends on the activation function (tanh in our case), and one recommended approach is to initialize the weights randomly in the interval [-1/sqrt(n), 1/sqrt(n)], where n is the number of incoming connections from the previous layer.
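A minimal sketch of this initialization scheme, assuming NumPy. The function name init_params, the dictionary layout and the seed argument are choices made for this illustration and are not taken from the original notebook.

```python
import numpy as np

def init_params(vocab_size, hidden_dim, seed=0):
    """Return U, W, V drawn uniformly from [-1/sqrt(n), 1/sqrt(n)] and zero biases."""
    rng = np.random.default_rng(seed)

    def uniform(shape, n_in):
        # n_in is the number of incoming connections for this weight matrix
        bound = 1.0 / np.sqrt(n_in)
        return rng.uniform(-bound, bound, size=shape)

    return {
        "U": uniform((hidden_dim, vocab_size), vocab_size),   # input -> hidden
        "W": uniform((hidden_dim, hidden_dim), hidden_dim),   # hidden -> hidden
        "V": uniform((vocab_size, hidden_dim), hidden_dim),   # hidden -> output
        "b": np.zeros((hidden_dim, 1)),                       # hidden bias
        "c": np.zeros((vocab_size, 1)),                       # output bias
    }

params = init_params(vocab_size=8000, hidden_dim=100)
for name, value in params.items():
    print(name, value.shape)
```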
Step 2: Forward pass
This follows our equations directly: for each time step t, we calculate the hidden state hs[t] and the output os[t], then apply softmax to get the probabilities for the next character.
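The forward pass could be sketched as below. This is an illustration under the assumptions of this post (one-hot inputs, tanh hidden units, softmax outputs), not the notebook's exact code. The function name forward is hypothetical, and the outputs are stored in os_ rather than os only to avoid clashing with Python's standard os module.

```python
import numpy as np

def forward(inputs, h_prev, U, W, V, b, c):
    """Run the RNN over a list of integer character indices.

    Returns dicts keyed by time step: one-hot inputs xs, hidden states hs,
    un-normalized outputs os_ and softmax probabilities ps.
    """
    vocab_size = U.shape[1]
    xs, hs, os_, ps = {}, {}, {}, {}
    hs[-1] = np.copy(h_prev)                      # state before the first step
    for t, idx in enumerate(inputs):
        xs[t] = np.zeros((vocab_size, 1))
        xs[t][idx] = 1.0                          # one-hot encoding of the input
        hs[t] = np.tanh(U @ xs[t] + W @ hs[t - 1] + b)
        os_[t] = V @ hs[t] + c                    # un-normalized log probabilities
        e = np.exp(os_[t] - np.max(os_[t]))       # shift by the max for stability
        ps[t] = e / np.sum(e)                     # probabilities for the next char
    return xs, hs, os_, ps
```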
Computing softmax and numerical stability:
The softmax function takes an N-dimensional vector of real numbers and transforms it into a vector of real numbers in the range (0, 1) that add up to 1. The mapping is done using the formula below.
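For an input vector x with components x_1, ..., x_N, the standard definition of softmax is:

```latex
\operatorname{softmax}(x)_i = \frac{e^{x_i}}{\sum_{j=1}^{N} e^{x_j}}, \qquad i = 1, \ldots, N
```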
The implementation of softmax is:
It looks fine, but when we call this softmax with larger numbers, as below, it gives ‘nan’ values.
The numerical range of the floating-point numbers used by NumPy is limited. For float64, the maximal representable number is on the order of 10³⁰⁸. Exponentiation in the softmax function makes it easy to overshoot this number, even for fairly modest-sized inputs. A nice way to avoid this problem is to normalize the inputs so that they are neither too large nor too small. The trick is that softmax is unchanged when the same constant is subtracted from every input, so we can subtract the maximum and make the largest exponent zero; refer here for details. So our softmax looks like:
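A direct translation of the formula, together with a call that triggers the problem, might look like this. The function name softmax_naive and the sample inputs are chosen for this illustration.

```python
import numpy as np

def softmax_naive(x):
    """Softmax exactly as in the formula, with no protection against overflow."""
    e = np.exp(x)
    return e / np.sum(e)

small = softmax_naive(np.array([1.0, 2.0, 3.0]))
print(small)  # a valid probability vector that sums to 1

# np.exp overflows to inf for inputs this large, and inf / inf is nan.
with np.errstate(over="ignore", invalid="ignore"):
    large = softmax_naive(np.array([1000.0, 2000.0, 3000.0]))
print(large)  # every entry is nan
```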
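A sketch of the stabilized version, assuming the shift used is the maximum of the inputs (a common choice). After the shift every exponent is at most zero, so np.exp cannot overflow.

```python
import numpy as np

def softmax(x):
    """Numerically stable softmax: subtract the max before exponentiating."""
    shifted = x - np.max(x)        # largest entry becomes 0
    e = np.exp(shifted)            # all values now lie in (0, 1]
    return e / np.sum(e)

probs = softmax(np.array([1000.0, 2000.0, 3000.0]))
print(probs)  # no nan: the mass concentrates on the largest input
```

For inputs that do not overflow, the result matches the unshifted formula, because the common factor cancels between numerator and denominator.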
Step 3: Compute Loss
Since we are implementing a text generation model, the next character can be any of the unique characters in our vocabulary, so our loss is the cross-entropy loss. In multi-class classification we sum the log loss values over the class predictions for each observation.
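In its usual form, the multi-class cross-entropy for a single observation O is written as follows, using the symbols explained next:

```latex
-\sum_{C=1}^{M} y_{O,C} \, \log\left(p_{O,C}\right)
```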
- M — number of possible class labels (unique characters in our vocab)
- y — a binary indicator (0 or 1) of whether class label C is the correct classification for observation O
- p — the model’s predicted probability that observation O is of class C
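Because y is one-hot, only the term for the correct class survives, so the loss per time step is just the negative log of the probability assigned to the true character. A small sketch, assuming ps is a dict of column vectors of probabilities keyed by time step and targets is a list of true character indices (both names are assumptions of this example):

```python
import numpy as np

def sequence_loss(ps, targets):
    """Sum of cross-entropy losses over all time steps of one sequence."""
    return sum(-np.log(ps[t][targets[t], 0]) for t in range(len(targets)))

# Three time steps with a uniform prediction over 4 characters.
uniform = {t: np.full((4, 1), 0.25) for t in range(3)}
print(sequence_loss(uniform, [0, 2, 1]))  # equals 3 * log(4)
```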
Step 4: Backward pass
The implementation follows the BPTT equations directly. The code is commented so that each step can be matched to its equation.
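The sketch below combines the forward pass and BPTT in one function. It follows the general structure of a character-level RNN, but the function name loss_and_grads, the argument order and the return values are assumptions of this example and not the notebook's exact code.

```python
import numpy as np

def loss_and_grads(inputs, targets, h_prev, U, W, V, b, c):
    """Forward pass plus BPTT over one sequence of character indices.

    Returns the loss, the gradients (dU, dW, dV, db, dc) and the last hidden state.
    """
    vocab_size = U.shape[1]
    xs, hs, ps = {}, {-1: np.copy(h_prev)}, {}
    loss = 0.0
    # Forward pass: store the states, they are reused in the backward pass.
    for t in range(len(inputs)):
        xs[t] = np.zeros((vocab_size, 1))
        xs[t][inputs[t]] = 1.0
        hs[t] = np.tanh(U @ xs[t] + W @ hs[t - 1] + b)
        o = V @ hs[t] + c
        e = np.exp(o - np.max(o))
        ps[t] = e / np.sum(e)
        loss += -np.log(ps[t][targets[t], 0])
    # Backward pass: walk from the last time step to the first.
    dU, dW, dV = np.zeros_like(U), np.zeros_like(W), np.zeros_like(V)
    db, dc = np.zeros_like(b), np.zeros_like(c)
    dh_next = np.zeros_like(h_prev)       # no gradient arrives from beyond the end
    for t in reversed(range(len(inputs))):
        do = np.copy(ps[t])
        do[targets[t]] -= 1.0             # gradient w.r.t. o(t): y_hat - y
        dV += do @ hs[t].T
        dc += do
        dh = V.T @ do + dh_next           # from the output and from h(t+1)
        da = (1.0 - hs[t] ** 2) * dh      # back through tanh
        db += da
        dU += da @ xs[t].T
        dW += da @ hs[t - 1].T
        dh_next = W.T @ da                # passed on to time step t-1
    return loss, (dU, dW, dV, db, dc), hs[len(inputs) - 1]
```

A quick way to gain confidence in such code is to compare one analytic gradient entry against a finite-difference estimate of the loss.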
While in principle the RNN is a simple and powerful model, in practice, it is hard to train properly. Among the main reasons why this model is so unwieldy are the vanishing gradient and exploding gradient problems. While training using BPTT the gradients have to travel from the last cell all the way to the first cell. The product of these gradients can go to zero or increase exponentially. The exploding gradients problem refers to the large increase in the norm of the gradient during training. The vanishing gradients problem refers to the opposite behavior, when long term components go exponentially fast to norm 0, making it impossible for the model to learn correlation between temporally distant events.
Whereas the exploding gradient problem can be fixed with the gradient clipping technique, as is used in the example code here, the vanishing gradient issue is still a major concern with an RNN.
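Element-wise clipping is one simple form of gradient clipping. A minimal sketch follows; the helper name clip_gradients and the limit of 5.0 are assumptions of this example.

```python
import numpy as np

def clip_gradients(grads, limit=5.0):
    """Clip every gradient entry to [-limit, limit], in place."""
    for g in grads:
        np.clip(g, -limit, limit, out=g)
    return grads

g = np.array([[-10.0, 0.5, 7.0]])
clip_gradients([g])
print(g)  # entries beyond the limit are capped at -5 and 5
```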
This vanishing gradient limitation was overcome by various networks such as long short-term memory (LSTM), gated recurrent units (GRUs), and residual networks (ResNets), where the first two are the most used RNN variants in NLP applications.
Step 5: Update weights
Using BPTT we calculated the gradient for each parameter of the model. It is now time to update the weights.
In the original implementation by Andrej Karpathy, Adagrad is used for the gradient update. Adagrad adapts the step size of each parameter individually and tends to work much better than plain SGD here. Please try both and compare.
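To make the comparison concrete, here is a sketch of both update rules. The function names, the default learning rate of 0.1 and the small constant eps are assumptions of this example; mem is a running sum of squared gradients with the same shape as the parameter.

```python
import numpy as np

def sgd_update(param, grad, learning_rate=0.1):
    """Plain SGD: the same step size for every entry."""
    param -= learning_rate * grad

def adagrad_update(param, grad, mem, learning_rate=0.1, eps=1e-8):
    """Adagrad: entries with a large gradient history take smaller steps."""
    mem += grad * grad                               # accumulate squared gradients
    param -= learning_rate * grad / np.sqrt(mem + eps)

w_sgd = np.array([1.0, 1.0])
w_ada = np.array([1.0, 1.0])
mem = np.zeros_like(w_ada)
grad = np.array([0.5, -2.0])

sgd_update(w_sgd, grad)
adagrad_update(w_ada, grad, mem)
print(w_sgd)  # step proportional to the gradient
print(w_ada)  # first Adagrad step has roughly equal magnitude for both entries
```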
Step 6: Repeat steps 2–5
In order for our model to learn from the data and generate text, we need to train it for some time and check the loss after each iteration. If the loss decreases over time, our model is learning what is expected of it.
We train for some time and if all goes well, we should have our model ready to predict some text. Let us see how it works for us.
We will implement a predict method to predict a few words, like below:
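One way such a method could work is to feed each sampled character back in as the next input. The sketch below is an illustration only: the function name sample, its arguments and the use of NumPy's random Generator are assumptions of this example. It returns character indices, which would still need to be mapped back to characters.

```python
import numpy as np

def sample(seed_idx, n, h, U, W, V, b, c, rng=None):
    """Generate n character indices, starting from seed_idx and hidden state h."""
    rng = np.random.default_rng() if rng is None else rng
    vocab_size = U.shape[1]
    x = np.zeros((vocab_size, 1))
    x[seed_idx] = 1.0
    indices = []
    for _ in range(n):
        h = np.tanh(U @ x + W @ h + b)
        o = V @ h + c
        e = np.exp(o - np.max(o))
        p = e / np.sum(e)                               # next-character probabilities
        idx = int(rng.choice(vocab_size, p=p.ravel()))  # sample rather than argmax
        x = np.zeros((vocab_size, 1))
        x[idx] = 1.0                                    # sampled char is the next input
        indices.append(idx)
    return indices
```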
Let us see how our RNN is learning after a few epochs of training.
The output looks more like real text, with word boundaries and some grammar as well. So our baby RNN has started learning the language and is able to predict the next few words.
The implementation presented here is just meant to be easy to understand and to convey the concepts. In case you want to play around with the model hyper-parameters, the notebook is here.
Bonus: If you want to visualize what’s actually going on while training an RNN, watch here.
Hope it was useful for you. Thanks for the read.