Straight to the point.
Audience for this post
If you’re not comfortable with the why, how, and when of recurrent neural networks (RNNs), then this is for you.
Clarification: I’m focusing on NLP input, hence words, for consistency, but most of what I’m describing also applies to other sequence and time-series data. For simplicity’s sake I’m also not saying anything about how the words are represented (embeddings, integers, one-hot encoded vectors); words will just be words.
NN (neural network), DNN (deep neural network), RNN (recurrent neural network), CNN (convolutional neural network), DL (deep learning), SOTA (state of the art).
INTRO — Why RNNs are simpler than CNNs
First, a reminder: in the setting of this post, an RNN is a way to infer some category from a sequence of words.
Currently, most DL courses cover CNNs before RNNs. There is some logic to that. The data itself, images, is something we all have experience with, and some of us have even taken computer vision classes. Most programmers know something about how pixels are stored (RGB or grayscale channels), and they are familiar with an image’s basic properties, such as width and height. Furthermore, the goals of classification (“regular” multi-class classification, like cats vs. dogs), object detection with bounding boxes, and segmentation all have visual outcomes with mostly clear meaning. Nonetheless, CNNs rely on a non-trivial architecture to achieve these goals: convolutions, feature maps, and pooling are the basic building blocks, and each architecture among the various solutions tries to improve the SOTA by tweaking those blocks and adding layers. (Backpropagating errors through a CNN is not trivial either.)
My thesis is this: RNNs can be much simpler to implement and understand, as the numerous NumPy implementations available online show (this, or this), without any framework involved. I personally find the RNN architecture elegant, and its weight sharing ingenious in its simplicity. For some time I struggled with the concept of recurrent networks due to the lack of visual explanations and real-world analogies, which I need in order to learn new abstract concepts; that’s my main motivation for writing this post. If the “hidden state” of RNNs is truly hidden from you, I promise that some of the obscurity will be lifted soon.
Figure: Example of confusing math notation, from: link
Given that our recurrent neural network has the following structure:
There are two obstacles in our way:
- What does the hidden state learn/represent?
- How do the various words in the sequence (sentence) contribute to the outcome?
I’ll start with the hidden state due to its simplicity.
HIDDEN STATE’S GIST
Here’s the gist: the hidden state learns correlations between high-level features of sentences (parts of speech, for example) and the output target (sentiment, for example).
An RNN learns to map words to abstract concepts (for simplicity, let’s say parts of speech, mapping each word to ‘verb’, ‘noun’, and so forth) and to map these concepts to the target, for example sentiment. In other words, it learns the correlation between parts of speech and the target. Here’s a simple correlative rule: if a sentence contains mostly verbs, then its sentiment is negative.
I’ll clarify this even further with an example. Let’s say our input is the sentence “Kicking and screaming” and our target is detecting its sentiment (Negative or Positive). I’ve defined the mapping function as part-of-speech detection, so the mapping is as follows: kicking = verb, and = conjunction, screaming = verb. Two of the three words are verbs, so the simple rule above labels the sentence Negative.
By analogy with CNNs: if a picture contains a tail, whiskers, paws, and pointy ears, then it’s a cat. CNNs work the same way; they learn correlations between high-level features and the target.
We can use all of the words because the hidden state functions as a memory: at each step it aggregates the previous hidden state with the current word by addition (each is first multiplied by its own weights), so information from earlier words is carried forward.
Most courses use a well-known diagram with mathematical notation; personally, I find these visuals hard to swallow when first looking at them.
Instead, I pose a question: given that we have 4 layers in our network, how can we use this structure so that each word in the sentence contributes to our prediction? Recurrence is the RNN’s genius; it’s exactly what RNNs were designed for. How do they achieve this?
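To make the rule concrete, here is a hand-written stand-in for what the hidden state might learn implicitly. This is a toy sketch, not an RNN: `POS_TAGS`, `verb_fraction`, and `sentiment` are hypothetical names, and the lookup table only covers the example words. A real RNN learns its own internal features from data instead of using a fixed table.

```python
# Hypothetical, hand-written stand-in for what an RNN's hidden state might
# learn implicitly: a word -> part-of-speech mapping plus a correlative rule.
POS_TAGS = {  # hypothetical lookup table; only covers the example sentence
    "kicking": "verb",
    "and": "conjunction",
    "screaming": "verb",
}


def verb_fraction(sentence: str) -> float:
    """Fraction of words tagged as verbs (unknown words count as non-verbs)."""
    words = sentence.lower().split()
    if not words:
        return 0.0
    verbs = sum(1 for w in words if POS_TAGS.get(w) == "verb")
    return verbs / len(words)


def sentiment(sentence: str) -> str:
    """Toy rule: a sentence made of mostly verbs is Negative, otherwise Positive."""
    return "Negative" if verb_fraction(sentence) > 0.5 else "Positive"


print(sentiment("Kicking and screaming"))  # 2 of 3 words are verbs
```

The point of the contrast: here a human wrote both the mapping and the rule, while an RNN has to discover useful intermediate features and their correlation with the target on its own.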
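For readers who want the formula behind this aggregation, one common way to write a vanilla (Elman) RNN is shown below, where x_t is the t-th word, h_t the hidden state after it, and T the number of words. The symbol names and the tanh/softmax choices are my assumptions, not notation taken from the diagrams referenced in this post. The addition inside tanh is the aggregation step, and the same W_xh and W_hh are reused at every step t:

```latex
\begin{aligned}
h_0 &= \mathbf{0}, \\
h_t &= \tanh\left(W_{xh}\,x_t + W_{hh}\,h_{t-1} + b_h\right), \qquad t = 1, \dots, T, \\
\hat{y} &= \operatorname{softmax}\left(W_{hy}\,h_T + b_y\right).
\end{aligned}
```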
An illustration of this “recurrence” phenomenon is the image below. It’s a hidden state that’s being altered after each word (step).
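- Multiply the input (the current word) by its weights
- Add the result to the previous hidden state (itself multiplied by the hidden-state weights) and apply the activation function; this gives the new hidden state
- Check the output nodes after processing all of the words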
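The steps above can be sketched as a minimal NumPy forward pass for a vanilla RNN. Assumptions: each word is already a vector (the post deliberately leaves the representation open), tanh is the activation, and the weights are random and untrained, so the output probabilities carry no meaning. The names `rnn_forward`, `W_xh`, `W_hh`, and `W_hy` are my own, not from a particular library.

```python
import numpy as np


def rnn_forward(inputs, W_xh, W_hh, b_h, W_hy, b_y):
    """Run a vanilla RNN over a sequence of T >= 1 word vectors.

    inputs: array of shape (T, input_dim), one row per word.
    Returns (class probabilities, hidden states of shape (T, hidden_dim)).
    """
    h = np.zeros(W_hh.shape[0])  # h_0: empty memory before the first word
    hidden_states = []
    for x_t in inputs:  # one step per word; the SAME weights are reused each step
        # First two steps from the list: weight the current word, add the
        # weighted previous hidden state, then squash with tanh.
        h = np.tanh(W_xh @ x_t + W_hh @ h + b_h)
        hidden_states.append(h)
    # Last step: read the output nodes only after the final word.
    logits = W_hy @ h + b_y
    exp = np.exp(logits - logits.max())  # numerically stable softmax
    return exp / exp.sum(), np.stack(hidden_states)


rng = np.random.default_rng(0)
T, input_dim, hidden_dim, n_classes = 3, 4, 5, 2  # e.g. 3 words, Negative/Positive
inputs = rng.normal(size=(T, input_dim))  # hypothetical word vectors
W_xh = rng.normal(scale=0.5, size=(hidden_dim, input_dim))
W_hh = rng.normal(scale=0.5, size=(hidden_dim, hidden_dim))
b_h = np.zeros(hidden_dim)
W_hy = rng.normal(scale=0.5, size=(n_classes, hidden_dim))
b_y = np.zeros(n_classes)

probs, hs = rnn_forward(inputs, W_xh, W_hh, b_h, W_hy, b_y)
print(probs, hs.shape)  # hs has one row (hidden state) per word
```

Because only the last hidden state reaches the output layer, earlier words can influence the prediction only through whatever survives in `h`.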
Issues with RNNs
- See the first node on the left: at first it had large values, but over time its strength diminished because it was squashed by the activation function (and multiplied by the recurrent weights) at every step. This is not optimal because it leads to a recent sample bias.
- Recent sample bias: see the first node on the right. It’s really noticeable because it was calculated at the last time step and hasn’t been squashed many times over the iterations.
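A toy illustration of both issues, assuming a single hidden unit, a hypothetical recurrent weight of 0.5, and tanh activation (`run_scalar_rnn` is a made-up helper). A strong signal on the first word fades step by step, while the same signal on the last word dominates the final hidden state. In this toy, the fading comes from both the recurrent weight being below 1 and the tanh squashing.

```python
import math


def run_scalar_rnn(inputs, w_hh=0.5, w_xh=1.0):
    """Toy 1-unit RNN: h_t = tanh(w_xh * x_t + w_hh * h_{t-1}). Returns every h_t."""
    h, history = 0.0, []
    for x in inputs:
        h = math.tanh(w_xh * x + w_hh * h)
        history.append(h)
    return history


# Only the first word carries a strong signal; the remaining inputs are zero.
trace = run_scalar_rnn([2.0, 0.0, 0.0, 0.0, 0.0])
print([round(h, 3) for h in trace])  # shrinks at every step

# Same signal, placed first vs. last: what is left in the final hidden state?
first_word_signal = run_scalar_rnn([2.0, 0.0, 0.0, 0.0, 0.0])[-1]  # small
last_word_signal = run_scalar_rnn([0.0, 0.0, 0.0, 0.0, 2.0])[-1]   # large
print(first_word_signal, last_word_signal)
```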
Attention is all you need: adding an attention mechanism
- Which context/position of the word matters most? Is it the first? The middle? The end? When processing each word, the neighbouring words are probably the most important, so we should take them into account somehow. This is the kind of problem the attention mechanism addresses: it lets the model learn how much weight to give each position instead of relying only on the final hidden state.
I’ve tried to simplify RNNs as much as possible without leaving out the important parts. RNNs introduce a pattern of reusing weights that is part of most of this field’s major breakthroughs. I hope this resonates with you as well.
In the next instalment, I’ll go into detail about the pros and cons of other architectures for dealing with sequence data.
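A minimal sketch of one simple attention variant: plain dot-product attention over an RNN’s hidden states. Assumptions: the hidden states and the `query` vector are hypothetical toy values (in a trained model the query would be learned or come from a decoder), and `dot_product_attention` is a made-up helper. This is not the exact mechanism from the “Attention Is All You Need” paper, which uses a scaled dot product with learned query, key, and value projections.

```python
import numpy as np


def dot_product_attention(hidden_states, query):
    """Weight every time step's hidden state by its relevance to `query`.

    hidden_states: (T, d) array, one row per word (e.g. the RNN's h_1..h_T).
    query: (d,) array describing what we are looking for.
    Returns (context vector of shape (d,), attention weights of shape (T,)).
    """
    scores = hidden_states @ query            # one relevance score per word
    weights = np.exp(scores - scores.max())
    weights /= weights.sum()                  # softmax: non-negative, sums to 1
    context = weights @ hidden_states         # weighted average over ALL steps
    return context, weights


states = np.array([[1.0, 0.0], [0.0, 1.0], [1.0, 1.0]])  # hypothetical h_1..h_3
context, weights = dot_product_attention(states, query=np.array([2.0, 0.0]))
print(weights.round(3), context.round(3))
```

Unlike the final hidden state alone, the context vector can draw directly on the first word (here, the first and last steps get equal weight), so early information does not have to survive many squashing steps.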