What are RNNs?
RNN stands for Recurrent Neural Network. RNNs are trained with backpropagation through time instead of a single forward pass. They work especially well for sequential data, and they are also ideal when input sequences are not all the same size. Think of RNNs when you have to:
- Handle variable-length input
- Capture dependencies across an interval
- Keep track of the order of events
How do RNNs work?
Recurrent neural networks contain loops that allow information to persist through time. This recurrence of information gives them the name Recurrent: information is passed from one time step to the next within the network.
Let xt be the input at time step t, yt the output, and ht the updated internal (hidden) state.
The recurrence relation can be written as
ht = fW(h(t-1), xt), where fW is a function parameterized by the weights W.
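The recurrence relation can be sketched in a few lines of numpy. The weight shapes and values below are illustrative assumptions, not part of the notes; the point is that one function fW maps the previous state and the current input to the new state.

```python
import numpy as np

rng = np.random.default_rng(0)

hidden_size, input_size = 4, 3
W_hh = rng.standard_normal((hidden_size, hidden_size)) * 0.1  # recurrent weights
W_xh = rng.standard_normal((hidden_size, input_size)) * 0.1   # input weights

def f_w(h_prev, x_t):
    """One step of the recurrence ht = fW(h(t-1), xt)."""
    return np.tanh(W_hh @ h_prev + W_xh @ x_t)

h = np.zeros(hidden_size)            # initial hidden state
x = rng.standard_normal(input_size)  # one input at time step t
h = f_w(h, x)
print(h.shape)  # (4,)
```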
What is the algorithm for RNNs?
- Initialize the RNN with a hidden state and random weights, and take the input
- Loop through each word in the sentence
- At each time step, feed the current word and the previous hidden state into the network
- Generate a prediction for the next word in the sequence and update the hidden state
- After the loop, once all words have been fed in, output the predicted next word
Formula for updating the state of the RNN:
ht = tanh(W1*h(t-1) + W2*xt)
where W1 and W2 are weight matrices.
Think of an RNN as multiple copies of the same network, where each copy sends a message to the next copy through ht, the internal state. The weight matrices W1 and W2 are shared across time steps. Summing the losses across time steps gives us the total loss.
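The steps above can be sketched as an unrolled next-word-prediction loop. The vocabulary, the toy word indices, the output projection W_out, and the cross-entropy loss are all assumptions added for illustration; the notes only specify the ht update and that losses are summed.

```python
import numpy as np

rng = np.random.default_rng(1)
vocab, hidden = 5, 8
W1 = rng.standard_normal((hidden, hidden)) * 0.1    # recurrent weights, shared across steps
W2 = rng.standard_normal((hidden, vocab)) * 0.1     # input weights, shared across steps
W_out = rng.standard_normal((vocab, hidden)) * 0.1  # output projection (an assumption)

def softmax(z):
    e = np.exp(z - z.max())
    return e / e.sum()

sequence = [0, 3, 1, 4]  # toy word indices; the target at step t is the word at t+1
h = np.zeros(hidden)
total_loss = 0.0
for t in range(len(sequence) - 1):
    x_t = np.eye(vocab)[sequence[t]]              # one-hot encoding of the current word
    h = np.tanh(W1 @ h + W2 @ x_t)                # same W1 and W2 at every time step
    probs = softmax(W_out @ h)                    # prediction for the next word
    total_loss += -np.log(probs[sequence[t + 1]]) # cross-entropy, summed across steps
print(total_loss)
```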
How to implement an RNN?
In TensorFlow we can use the SimpleRNN layer. At each time step, the layer:
- Takes the previous hidden state and the input xt and multiplies each by its weight matrix
- Sums the results
- Passes the sum through a non-linear function
- Updates the hidden state by the ht formula
- Returns the current output and the updated hidden state
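The per-step computation those bullets describe can be sketched in plain numpy (the names, shapes, and random inputs below are illustrative assumptions; TensorFlow's SimpleRNN layer performs this kind of computation internally, with learned weights and a bias term):

```python
import numpy as np

rng = np.random.default_rng(2)
input_size, hidden_size = 3, 4
W_xh = rng.standard_normal((hidden_size, input_size)) * 0.1
W_hh = rng.standard_normal((hidden_size, hidden_size)) * 0.1

def simple_rnn_step(x_t, h_prev):
    # Multiply the previous state and the input by their weight matrices,
    # sum, then pass through a non-linearity.
    h_t = np.tanh(W_hh @ h_prev + W_xh @ x_t)
    return h_t, h_t  # current output and updated hidden state

h = np.zeros(hidden_size)
xs = rng.standard_normal((6, input_size))  # a toy sequence of 6 inputs
outputs = []
for x_t in xs:
    y_t, h = simple_rnn_step(x_t, h)
    outputs.append(y_t)
print(np.array(outputs).shape)  # (6, 4)
```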
How to train an RNN model?
Instead of backpropagating errors through a single feed-forward network, RNNs backpropagate the error all the way back to the beginning of the sequence. Problems faced when training an RNN model are:
- Exploding Gradients
- Vanishing Gradients
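Why these problems arise can be seen with a toy scalar sketch (an assumption for illustration, not a full backpropagation-through-time derivation): the gradient picks up a factor related to the recurrent weight at every time step, so the factor compounds over long sequences.

```python
def compounded_gradient_factor(w, steps):
    """Product of `steps` copies of a scalar recurrent weight w,
    modeling the factor the gradient accumulates across time steps."""
    return w ** steps

# A weight below 1 shrinks the factor toward zero -> vanishing gradients.
print(compounded_gradient_factor(0.5, 50))
# A weight above 1 makes the factor blow up -> exploding gradients.
print(compounded_gradient_factor(1.5, 50))
```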
What are RNNs used for?
- Detecting the language used in the given text
- Machine translation
- Next word prediction
- Self-driving cars
- Text and music generation
- Weather forecasting
- Time series forecasting
What models are better than RNNs in some particular applications?
- LSTMs: Long Short-Term Memory models
- Attention-based models
Note: These are my notes based on MIT lectures.