In the previous post, we upgraded from representing words as strings to representing words as vectors with word embeddings. Around the same time, GPUs and strategies to train deep neural networks took off with the AlexNet paper and started the whole deep learning craze. Neural networks showed plenty of promise in many different fields, so, naturally, the question arises: can we use neural networks to improve the quality of our language models? Most of the neural network technologies had already existed for decades, but only then did it become feasible to train them with more parameters and on larger sets of data. Furthermore, neural networks operate on vectors, and we now know how to map words to vectors using embeddings. Let’s try to use the advancements in neural networks to see if we can create a better language model!

In this post, I’ll start to marry neural networks and language modeling with neural language models. Then we’ll build on those with a simple recurrence relation to construct plain/vanilla recurrent neural networks (RNNs). Then we’ll upgrade those into a more widely used modern version by replacing the vanilla RNN cells with long short-term memory (LSTM) cells. Finally, we’ll train a language model on some public-domain plays by Shakespeare and see how well it works!

Going forward, I’ll assume that you have a basic understanding of neural networks and how to train them.

Neural Language Models

Recall that for n-gram models, we represented the likelihood of the next word with a conditional distribution given the previous $n-1$ words. Take the bigram model for example: given the previous word, what’s the likelihood of the next word? Given a sequence of words, we can approximate its probability as the product of a single unigram/marginal probability and a series of conditionals.

\[p(w_1,\dots, w_N)\approx p(w_1) p(w_2|w_1)p(w_3|w_2)\cdots p(w_N|w_{N-1})\]
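As a quick refresher, the count-and-divide estimate of this bigram factorization can be sketched in a few lines of Python. The toy corpus and the helper name `bigram_prob` are made up for illustration, and there's no smoothing, so any unseen bigram gets probability zero.

```python
from collections import Counter

# hypothetical toy corpus, already tokenized into words
corpus = "the cat sat on the mat the cat ate".split()

unigrams = Counter(corpus)
bigrams = Counter(zip(corpus, corpus[1:]))

def bigram_prob(sentence: list[str]) -> float:
    """p(w_1) * p(w_2|w_1) * ... * p(w_N|w_{N-1}) via count-and-divide (no smoothing)."""
    prob = unigrams[sentence[0]] / len(corpus)  # marginal p(w_1)
    for prev, cur in zip(sentence, sentence[1:]):
        # p(cur | prev) = count(prev, cur) / count(prev)
        prob *= bigrams[(prev, cur)] / unigrams[prev]
    return prob

print(bigram_prob(["the", "cat", "sat"]))
```

Here $p(\text{the}) = 3/9$, $p(\text{cat}|\text{the}) = 2/3$, and $p(\text{sat}|\text{cat}) = 1/2$, so the sentence gets probability $1/9$.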

We computed n-grams by going through our corpus and counting and dividing. Instead, we can use a neural network to consume that information and predict the next word. Since neural networks operate on numerical data, we embed each word using either learned or pre-trained word embeddings, concatenate them all together, then pass the result through a neural network. Similar to a classification task, the output is a probability distribution over the vocabulary. Training is almost exactly the same as normal categorical training over class labels, except our “class labels” are the vocabulary of the corpus.

Neural Language Model

A simple neural language model consumes a context window of input embeddings, concatenates them together, and runs them through some fully-connected layers to get an output of logits. Then we can softmax the output and compute the categorical cross-entropy loss against the target, i.e., the true next word. This trains the model to predict the next word given the context window.
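Here's a minimal sketch of such a fixed-window model in PyTorch, assuming a window of 3 word indices, a single tanh hidden layer, and random stand-in data; the class name `FeedForwardLM` and all sizes are hypothetical choices, not a reference implementation.

```python
import torch
import torch.nn as nn

class FeedForwardLM(nn.Module):
    """Hypothetical fixed-window neural LM: embed, concatenate, MLP, logits."""

    def __init__(self, vocab_size: int, embed_dim: int, context_size: int, hidden_size: int):
        super().__init__()
        self.embedding = nn.Embedding(vocab_size, embed_dim)
        self.mlp = nn.Sequential(
            nn.Linear(context_size * embed_dim, hidden_size),
            nn.Tanh(),
            nn.Linear(hidden_size, vocab_size),  # logits over the vocabulary
        )

    def forward(self, context: torch.Tensor) -> torch.Tensor:
        # context: (batch, context_size) word indices
        embeds = self.embedding(context)    # (batch, context_size, embed_dim)
        flat = embeds.flatten(start_dim=1)  # concatenate the window into one vector
        return self.mlp(flat)               # (batch, vocab_size)

torch.manual_seed(0)
model = FeedForwardLM(vocab_size=50, embed_dim=8, context_size=3, hidden_size=32)
context = torch.randint(0, 50, (4, 3))  # 4 windows of 3 previous words
target = torch.randint(0, 50, (4,))     # the true next word for each window
logits = model(context)
loss = nn.CrossEntropyLoss()(logits, target)  # softmax is applied inside the loss
loss.backward()
```

Note that `flatten` is the only place the word order enters, which is exactly the weakness discussed next.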

This has similar properties to n-grams in that we have a context window of several words, but, instead of using count-and-divide explicit probabilities, we use the weights of the model to predict the next word. Naturally, we could use a huge context window to get better results (on average), but the trade-off is that it’s more computationally expensive. Even if we could afford that extra computation, the fact remains that we’re not really treating the data as sequential: we’re just concatenating the embeddings together, and there’s nothing to tell the model that the second word in the context window comes after the first one. This is missing the critical factor of n-grams: the sequential modeling that word $w_i$ is a function of the words that came before it. Can we change the construction of the neural network to better encapsulate the sequentiality of the input data?

Recurrent Neural Networks

The core issue with language modeling using regular neural networks is that nothing relates the previous embedding $x_{t-1}$ to the next one $x_t$: the network has no notion of order. Embedding the whole sequence, we have vectors $x_1,\dots,x_N$, but how do we relate them in sequence? To start, the embeddings themselves have limited expressivity, so rather than using those vectors directly, we map them to a latent/hidden space using a single fully-connected layer, just like we do with plain neural networks and non-text-based data. We now have an $h_t = Wx_t + b^{(x)}$ for each input vector: $h_1,\dots,h_N$. How do we relate $h_i$ to $h_{i-1}$ and $h_{i-2}$ and so on? The simplest thing to start with is another fully-connected layer: $h_t = Uh_{t-1} + b^{(h)}$. Combining these into a single equation, merging the biases, and applying a nonlinearity, we get a recurrence relation:

\[h_t = \tanh(Wx_t + Uh_{t-1} + b^{(h)})\]

where

  • $W$ is the weight matrix from the input to the hidden layer
  • $x_t$ is the input embedding
  • $U$ is the weight matrix from the hidden layer back to the hidden layer
  • $b^{(h)}$ is the combined bias for the hidden layer
  • $h_{t-1}$ is the previous hidden state
  • $h_t$ is the current hidden state

(We’ll get to the choice of hyperbolic tangent $\tanh$ over the sigmoid $\sigma$ activation function in just a little bit.)

(Alternatively, we could concatenate the input and previous hidden state into a single vector and pass the concatenated state $[x_t;h_{t-1}]$ through a single fully-connected layer like $[W\;U][x_t;h_{t-1}] + b$. It’s an equivalent formulation since $[W\;U][x_t;h_{t-1}] = Wx_t + Uh_{t-1}$ and we can define $b= b^{(x)} + b^{(h)}$.)
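To convince ourselves that the two formulations match, here's a small NumPy check with random matrices; the dimensions are arbitrary choices for illustration.

```python
import numpy as np

rng = np.random.default_rng(0)
input_dim, hidden_dim = 4, 3

W = rng.normal(size=(hidden_dim, input_dim))   # input -> hidden
U = rng.normal(size=(hidden_dim, hidden_dim))  # hidden -> hidden
b = rng.normal(size=hidden_dim)                # merged bias b^(x) + b^(h)
x_t = rng.normal(size=input_dim)
h_prev = rng.normal(size=hidden_dim)

# two separate fully-connected layers, summed
separate = W @ x_t + U @ h_prev + b
# one layer: stack [W U] horizontally and [x_t; h_{t-1}] vertically
concatenated = np.hstack([W, U]) @ np.concatenate([x_t, h_prev]) + b
```

Both `separate` and `concatenated` hold the same pre-activation vector, so which form to use is purely an implementation choice.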

With this recurrence relation, we can now handle arbitrary-length sequences! Beyond the hidden layer, we compute the output $\hat{y_t}$ at each timestep by running the hidden state through another fully-connected layer. Then we can normalize it using a softmax operation to produce a probability distribution over the vocabulary.

\[\begin{aligned} h_t &= \tanh(Wx_t + Uh_{t-1} + b^{(h)})\\ \hat{y_t} &= Vh_t + b^{(y)}\\ \end{aligned}\]

where

  • $V$ is the weight matrix from the hidden to the output layer
  • $b^{(y)}$ is the bias for the output layer
  • $\hat{y_t}$ is the output (logits over the vocabulary)

Since we’re using trainable fully-connected layers and a recurrence relation, we call this a Recurrent Neural Network (RNN)! The component that computes and propagates forward the hidden state is called an RNN Cell.
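The equations above can be sketched directly in NumPy. This is a minimal forward pass under stated assumptions: random small weights, zero biases, and random vectors standing in for embedded tokens; the function name `rnn_forward` is hypothetical.

```python
import numpy as np

def softmax(z: np.ndarray) -> np.ndarray:
    z = z - z.max()  # subtract the max for numerical stability
    e = np.exp(z)
    return e / e.sum()

def rnn_forward(xs, W, U, V, b_h, b_y):
    """Run h_t = tanh(W x_t + U h_{t-1} + b_h), y_t = V h_t + b_y over a sequence."""
    h = np.zeros(U.shape[0])  # h_0 is the zero vector
    probs = []
    for x_t in xs:
        h = np.tanh(W @ x_t + U @ h + b_h)
        y_hat = V @ h + b_y            # logits
        probs.append(softmax(y_hat))   # distribution over the vocabulary
    return np.stack(probs), h

rng = np.random.default_rng(0)
embed_dim, hidden_dim, vocab_size, seq_len = 5, 8, 10, 6
W = rng.normal(scale=0.1, size=(hidden_dim, embed_dim))
U = rng.normal(scale=0.1, size=(hidden_dim, hidden_dim))
V = rng.normal(scale=0.1, size=(vocab_size, hidden_dim))
b_h, b_y = np.zeros(hidden_dim), np.zeros(vocab_size)
xs = rng.normal(size=(seq_len, embed_dim))  # stand-in for embedded tokens

probs, h_last = rnn_forward(xs, W, U, V, b_h, b_y)
```

The same `W`, `U`, and `V` are reused at every timestep, which is what lets the loop run over sequences of any length.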

Vanilla Recurrent Neural Network

A recurrent neural network (RNN) has an input, recurrent hidden state, and output. The hidden state is fed back into itself across the entire input sequence. We can represent it “folded” or “unfolded” for a few timesteps.

Going back to the choice of activation function, one intuitive reason we use $\tanh$ instead of a sigmoid is that its range is $(-1,1)$, so hidden-state components can be negative as well as positive (and are centered around zero), as opposed to being confined to $(0,1)$. Also, unlike for binary classification or probabilities, we have no requirement to normalize the hidden state to $(0,1)$. Both $\tanh$ and $\sigma$ are bounded, which is what we want since we’re accumulating the hidden state over potentially a large number of timesteps and we don’t want it to go to infinity.

Practically speaking, for the task of language modeling, the number of possible inputs and the size of the output $\hat{y_t}$ both equal the vocabulary size. The embedding dimension of $x_t$ is a separate choice (the code later in this post sets it to the vocabulary size, but it need not be). The hidden state $h_t$ size is a hyperparameter that we can set.

Backpropagation with RNNs

Now that we’ve defined an RNN, let’s see how we can train one for language modeling. The first thing we need is a corpus of text. Recall that language modeling predicts the next word given the previous history, so all we need is a single corpus to construct a supervised training set. Given the corpus and a vocabulary size, we can tokenize the text and run it through an embedding matrix to get an embedded vector $x_t$.

At each timestep, we run each embedded vector $x_t$ through the RNN to compute the next hidden state $h_t$ and the output $\hat{y_t}$. The very first hidden state $h_0$ is normally set to the zero vector. As we progress through the sequence, we keep folding in the previous hidden state $h_{t-1}$ into the current one $h_t$ so that the current hidden state is a “summarization”/representation of all of the previous history that the RNN has seen so far.

The output of the RNN $\hat{y_t}$ is computed by passing the hidden state through a single-layer neural network to get an output vector the same size as the vocabulary. The target $y_t$ is a one-hot encoding of the input offset by one word into the future, since we’re trying to predict the next word. At each timestep, we run the predicted output through a softmax layer and compute the categorical cross-entropy between the softmax’ed prediction and the one-hot target; then we sum the per-timestep losses into a single loss over the entire sequence: $L = \sum_t L_{CE}(\hat{y_t}, y_t)$. This flavor of backpropagation is sometimes called backpropagation through time (BPTT).
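With a one-hot target, the cross-entropy at step $t$ reduces to $-\log \mathrm{softmax}(\hat{y_t})_{w_{t+1}}$, the negative log-probability of the true next token. The sketch below, using random logits as a stand-in for RNN outputs, computes the summed BPTT loss by hand and compares it to PyTorch's built-in `F.cross_entropy` with `reduction="sum"`.

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)
seq_len, vocab_size = 5, 7
logits = torch.randn(seq_len, vocab_size)           # stand-in for \hat{y}_t, t = 1..T
targets = torch.randint(0, vocab_size, (seq_len,))  # index of the true next token

# L = sum_t -log softmax(\hat{y}_t)[target_t]
manual = sum(-F.log_softmax(logits[t], dim=0)[targets[t]] for t in range(seq_len))

# the same quantity computed by the built-in loss
builtin = F.cross_entropy(logits, targets, reduction="sum")
```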

Backpropagation Through Time

Backpropagation through time “unrolls” the RNN across all of the timesteps and computes a loss between each output and each target at each timestep. The total loss is the sum of the individual losses at each timestep.

We accumulate this loss throughout the entire sequence and then backpropagate by “unrolling” the RNN through time. One issue with this is that the computational cost (and memory) grows with the length of the sequence; to bound this, we chop up the full input sequence into chunks of a fixed size and only unroll and backpropagate within those chunks. This technique is sometimes called truncated backpropagation through time (tBPTT).

Truncated Backpropagation Through Time

Truncated backpropagation through time “unrolls” the RNN only over a fixed-length chunk of the sequence and backpropagates over that chunk. However, we always propagate the hidden state forward and never reset it until we’ve gone through the entire sequence.

Even though we only backpropagate gradients within each chunk, we still carry the hidden state forward over the entire sequence; we just unroll and backpropagate over the chunk window. This is computationally fine since the hidden state itself is a fixed size.
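A minimal sketch of that loop, assuming PyTorch's built-in `nn.RNN` with an embedding and linear output head, a random token stream, and plain SGD (all stand-ins, not the setup used later in this post). The key line is `h.detach()`: it keeps the hidden state's value for the next chunk but cuts the autograd graph, so each `backward()` only unrolls over the current chunk.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
vocab_size, hidden_size, chunk_size = 20, 16, 8
data = torch.randint(0, vocab_size, (101,))  # hypothetical token stream

embedding = nn.Embedding(vocab_size, hidden_size)
rnn = nn.RNN(hidden_size, hidden_size)  # expects (seq_len, batch, features)
head = nn.Linear(hidden_size, vocab_size)
params = [*embedding.parameters(), *rnn.parameters(), *head.parameters()]
optimizer = torch.optim.SGD(params, lr=0.1)
criterion = nn.CrossEntropyLoss()

h = None  # nn.RNN treats None as a zero initial hidden state
num_chunks = 0
for i in range(0, len(data) - 1, chunk_size):
    end = min(i + chunk_size, len(data) - 1)
    source = data[i:end].unsqueeze(1)   # (chunk, batch=1)
    target = data[i + 1:end + 1]        # next token for each position
    out, h = rnn(embedding(source), h)  # out: (chunk, 1, hidden)
    loss = criterion(head(out.squeeze(1)), target)
    optimizer.zero_grad()
    loss.backward()                     # unrolls only over this chunk
    optimizer.step()
    h = h.detach()  # keep the state's value, but cut the graph to earlier chunks
    num_chunks += 1
```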

Sampling RNN Language Models

Given a trained RNN language model, we can sample from it to generate text. We’ll need a starting word or token, then we run it through the RNN to get the output vector. We run the output vector through a softmax and sample the next word from that probability distribution.

Sampling RNNs

When sampling RNNs, we need some initial seed word (or, alternatively, a dedicated start-of-sequence (SOS) token). After running that first token through the RNN, we take the output, normalize it into a distribution over the vocabulary using the softmax operator, sample from that distribution, and then use that sampled word as the next input.

Given the next word, we treat it as input into the next timestep and repeat the process until we produce a sequence of words of the desired length. An alternative is greedy decoding, where we always take the highest-likelihood word, but that removes all randomness and tends to restrict the output variability. There are even better (and more complex) decoding strategies such as top-k sampling, nucleus sampling, and beam search.
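Here's a sketch of how a single decoding step might choose the next token from a vector of logits under three of these strategies: greedy, plain sampling, and top-k sampling. The function name `next_token`, its `strategy`/`k`/`temperature` parameters, and the example logits are all hypothetical; temperature scaling is an extra knob not discussed above.

```python
import torch

def next_token(logits: torch.Tensor, strategy: str = "sample", k: int = 5,
               temperature: float = 1.0) -> int:
    """Pick the next token index from a 1-D tensor of logits."""
    if strategy == "greedy":
        # always take the most likely token (deterministic)
        return int(torch.argmax(logits))
    logits = logits / temperature  # <1 sharpens, >1 flattens the distribution
    if strategy == "top_k":
        # keep only the k largest logits; everything else gets zero probability
        top_vals, top_idx = torch.topk(logits, k)
        probs = torch.softmax(top_vals, dim=0)
        return int(top_idx[torch.multinomial(probs, 1)])
    # plain sampling from the full softmax distribution
    probs = torch.softmax(logits, dim=0)
    return int(torch.multinomial(probs, 1))

torch.manual_seed(0)
logits = torch.tensor([2.0, 0.5, 3.0, -1.0, 1.0])
greedy = next_token(logits, "greedy")
top2 = [next_token(logits, "top_k", k=2) for _ in range(50)]
sampled = next_token(logits)
```

Greedy always returns index 2 here, while top-2 sampling only ever returns index 2 or index 0.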

RNN Flavors

So far, we’ve discussed the most basic kind of vanilla RNN, but there are a number of different improvements on this that have been made over the years so we’ll discuss a few common ones briefly.

Bidirectional RNNs

When we’re training the RNN, we always run the sequence through the RNN sequentially forward in time. However, we can give the model more information if it could also “see into the future” by also running the sequence backward in time and combining the two before passing them to the output layer. This gives the RNN information about the previous history as well as the future at each timestep. Specifically, at each timestep $t$, we could create a joint hidden state $[h^{(f)}_t;h^{(b)}_t]$ from the forward hidden state $h^{(f)}_t$, computed by propagating the sequence forward, and the backward hidden state $h^{(b)}_t$, computed by running the sequence in reverse. This flavor of RNN is called a bidirectional RNN.

Bidirectional RNN

A bidirectional RNN passes the sequence forward to compute a forward hidden state $h^{(f)}_t$ and runs the sequence in reverse to compute a backward hidden state $h^{(b)}_t$. Both are concatenated together to get a joint hidden state $[h^{(f)}_t;h^{(b)}_t]$ at a particular timestep $t$. Intuitively, this gives the model more information (both the past and the future) to compute a better output in the present.

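Computing the output is the same as for regular RNNs: we run the concatenated hidden state through a fully-connected layer to get an output. Since the backward direction needs the whole sequence up front, bidirectional RNNs suit tasks where the full input is available rather than predicting the next word left to right.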
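In PyTorch, `nn.RNN(..., bidirectional=True)` produces this joint state for us. The sketch below, with arbitrary sizes and random input, shows that each output position has size `2 * hidden_size` (forward half, then backward half) and runs it through a hypothetical output layer.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
seq_len, batch, input_size, hidden_size = 7, 2, 5, 4
birnn = nn.RNN(input_size, hidden_size, bidirectional=True)
x = torch.randn(seq_len, batch, input_size)

out, h_n = birnn(x)
# out[t] is [h^(f)_t ; h^(b)_t], so its last dimension is 2 * hidden_size
forward_states = out[..., :hidden_size]
backward_states = out[..., hidden_size:]
head = nn.Linear(2 * hidden_size, 10)  # output layer over the joint state
logits = head(out)
```

The forward direction finishes at the last timestep and the backward direction finishes at the first, which is why the final forward state lives in `out[-1]` and the final backward state in `out[0]`.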

Stacked RNNs

Similar to how deep neural networks provide more expressive power when we add more hidden layers between the input and output, we can do the same thing with RNNs and stack RNN cells on top of each other. We do this by feeding a hidden state at a particular layer $h_t^{(l)}$ as the input to the next layer’s RNN cell at the same timestep. We keep doing this through the layers until we get to the last one, then we compute the output as usual. This flavor of RNN is sometimes called a stacked RNN.

Stacked RNN

A stacked RNN layers the hidden cells so that the output of a hidden cell at one layer $h_t^{(l)}$ is propagated as input to a hidden cell at a higher level $h_t^{(l+1)}$.

These are also trained and sampled in exactly the same way as regular RNNs. However, just like with deep neural networks, they have more parameters the more hidden layers we add, so they’re more expressive and tend to perform better (on average) but take longer to train!
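To make the layering concrete, here's a sketch of a stacked RNN built by hand from PyTorch's `nn.RNNCell`, where layer $l$'s hidden state $h_t^{(l)}$ is fed as the input to layer $l+1$ at the same timestep. The class name `StackedRNN` and sizes are hypothetical; in practice `nn.RNN(..., num_layers=3)` does the same thing more efficiently.

```python
import torch
import torch.nn as nn

class StackedRNN(nn.Module):
    """Hypothetical stacked RNN built from nn.RNNCell layers."""

    def __init__(self, input_size: int, hidden_size: int, num_layers: int):
        super().__init__()
        # layer 0 reads the input; every higher layer reads the hidden state of the layer below
        self.cells = nn.ModuleList([
            nn.RNNCell(input_size if l == 0 else hidden_size, hidden_size)
            for l in range(num_layers)
        ])
        self.hidden_size = hidden_size

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # x: (seq_len, batch, input_size)
        batch = x.size(1)
        h = [torch.zeros(batch, self.hidden_size) for _ in self.cells]
        top_states = []
        for x_t in x:
            layer_input = x_t
            for l, cell in enumerate(self.cells):
                h[l] = cell(layer_input, h[l])  # h_t^(l)
                layer_input = h[l]              # becomes the input to layer l+1
            top_states.append(layer_input)
        return torch.stack(top_states)  # top layer's hidden state at every timestep

torch.manual_seed(0)
model = StackedRNN(input_size=5, hidden_size=8, num_layers=3)
top = model(torch.randn(6, 2, 5))
```

The output layer would then read from `top`, exactly as it reads from $h_t$ in a single-layer RNN.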

Long Short-term Memory (LSTM) Cells

RNNs seem really great at capturing the sequence nature of text and language, but the vanilla RNNs we’ve seen so far suffer from two major issues: (i) exploding gradient and (ii) vanishing gradient. Both of these issues arise from backpropagating the gradient at the current timestep through all of the previous timesteps to the start of the sequence. Recall that when we’re backpropagating through hidden layers of a neural network, we use the chain rule of calculus to multiply the gradient by a factor for each layer we backpropagate through. Backpropagation through time does a similar thing except the gradient is backpropagated through the timesteps back to the start of the sequence. When we’re moving backwards through the timesteps, we’re multiplying the gradient by some factor, let’s call it $\alpha$, each time. So at a timestep $t$, going all the way back to the first timestep, we have a long product of those factors like $\alpha_1\cdots\alpha_{t}$. If each factor is exactly $1$, then the product is also $1$. If most of the factors are greater than $1$, then the product is going to go off to infinity. On the other hand, if most of the factors are less than $1$, then the product is going to go towards zero. The former problem is the exploding gradient problem and the latter is the vanishing gradient problem.
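A quick numerical illustration of how sensitive that long product is, assuming (purely for illustration) that every one of 50 factors is identical:

```python
T = 50  # number of timesteps we backpropagate through

vanishing = 0.9 ** T  # every factor slightly below 1
exploding = 1.1 ** T  # every factor slightly above 1
stable = 1.0 ** T     # every factor exactly 1

print(f"0.9^{T} = {vanishing:.5f}")
print(f"1.1^{T} = {exploding:.2f}")
```

Factors only 10% away from 1 shrink the gradient to about 0.005 or blow it up past 100 after 50 steps.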

Backpropagating a Single Output

For a single output, we have to backpropagate to the start of the sequence through the hidden state at each timestep since all of the weights and biases for the hidden state are re-used at every timestep. This means multiplying the gradient by some factor for each timestep we backpropagate through.

For the exploding gradient problem, a crude but very effective and direct solution is to clip the gradients into a finite range before updating the model parameters. A common choice is to clip each gradient component to $[-5, 5]$; another common variant rescales all of the gradients together so that their overall norm doesn’t exceed a threshold. While it might seem that intentionally clipping the gradients would make the RNN train slower, in practice it makes training far more stable and often faster overall since we won’t be jumping the parameters everywhere.
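PyTorch provides both flavors as `torch.nn.utils.clip_grad_value_` (element-wise) and `torch.nn.utils.clip_grad_norm_` (rescale by total norm). The sketch below sets made-up "exploded" gradients by hand to show the difference.

```python
import torch

# two hypothetical parameters whose gradients have exploded
p1 = torch.nn.Parameter(torch.zeros(3))
p2 = torch.nn.Parameter(torch.zeros(2))
p1.grad = torch.tensor([30.0, -0.5, 2.0])
p2.grad = torch.tensor([-12.0, 4.0])

# element-wise clipping into [-5, 5]; small components are untouched
torch.nn.utils.clip_grad_value_([p1, p2], clip_value=5.0)

# norm clipping rescales all gradients together so the total L2 norm is at most max_norm
q = torch.nn.Parameter(torch.zeros(2))
q.grad = torch.tensor([30.0, 40.0])  # L2 norm is 50
total_norm = torch.nn.utils.clip_grad_norm_([q], max_norm=5.0)  # returns the norm before clipping
```

After these calls, `p1.grad` is `[5.0, -0.5, 2.0]` and `p2.grad` is `[-5.0, 4.0]`, while `q.grad` is scaled to roughly `[3.0, 4.0]`. Value clipping can change the gradient's direction; norm clipping preserves it.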

Unfortunately, the vanishing gradient problem is more challenging to resolve. It isn’t a novel problem since we see it in regular neural networks: as we add more and more layers, the gradient gets smaller and smaller until it approaches zero and the earlier layers get no gradient signal so their weights and biases don’t update.

With RNNs, we have a similar problem with the gradient vanishing, but not in space, in time! For longer sequences, when we unroll the RNN, the gradient vanishes by the time we get to the earlier part of the sequence. This prevents our model from learning long-term relationships between our words.

For a more mathematical treatment of both of these issues, check the Appendix!

The vanishing gradient problem is fundamental to the RNN recurrence relation itself, so instead of trying to shoehorn “solutions” onto it, it would be better to redesign the entire RNN cell. Remember that the root of the issue is that, when the gradient backpropagates through the hidden states, we end up multiplying by a factor at each timestep. Instead of multiplicative operations, additive operations are a bit easier on the gradient since addition acts as a “gradient copier” and preserves, not attenuates, the gradient. So it seems like we need an alternative or additional mechanism that allows the gradient to flow more easily, unedited, to earlier timesteps.

(As an aside, I’m going to need to make some hand-wavy justifications and sequence of steps to get us to where we’re trying to go since it’s difficult to directly motivate a solution. This is a somewhat common theme in machine learning, but I think that’s fine. Research sometimes requires us to take a leap using our intuition and evaluate our solution to see how well our intuition works out.)

The first thing we can try is to define a new kind of state called the cell state $C_t$ that we can propagate forward. Ideally we want to avoid multiplicative operations on this state so that the gradient can flow, but how do we populate it? The simplest thing to do is to carry the previous cell state forward, $C_t=C_{t-1}$, but this doesn’t let any input into it. Similar to what we did with the hidden state, we can pass the input and previous hidden state through a fully-connected layer and $\tanh$ activation and add the result to the cell state.

\[\begin{align*} g_t &= \tanh(W^{(g)}x_t + U^{(g)}h_{t-1} + b^{(g)})\\ C_t &= C_{t-1} + g_t\\ \end{align*}\]

This new $g_t$ is a candidate gate that produces the candidate values we’d like to put into the cell state. However, the cell state isn’t bounded in any way: if we keep adding to it, it can grow without bound, and clipping the gradient won’t stop that! Furthermore, we’re assuming that we want to pass forward the entire previous cell state and the entire candidate input. In both scenarios, rather than hard-coding what information we should preserve and what information we should take into the cell state, we can have the model learn what to do. For the former, we want to learn which information we remember and which information we forget; for the latter, we want to learn which information to put into the cell state. We can use two more fully-connected layers to gate the information we pass forward to the current cell state from the previous one as well as the information we take from the input to the current cell state.

\[\begin{align*} f_t &= \sigma(W^{(f)}x_t + U^{(f)}h_{t-1} + b^{(f)})\\ g_t &= \tanh(W^{(g)}x_t + U^{(g)}h_{t-1} + b^{(g)})\\ i_t &= \sigma(W^{(i)}x_t + U^{(i)}h_{t-1} + b^{(i)})\\ C_t &= f_t\odot C_{t-1} + i_t\odot g_t\\ \end{align*}\]

These are called the forget gate $f_t$ and the input gate $i_t$. $\odot$ is the Hadamard product, which is a fancy name for an element-wise product, and $\sigma$ is the usual sigmoid function. Since the forget and input gates use the sigmoid, gate components close to $0$ mean we’ll “forget” or “ignore” the corresponding components of the previous cell state and the candidate values, respectively. For values close to $1$, we’ll “remember” or “retain” those components.

What about the hidden state? Do we even need it, or can we just use the new cell state that we’ve developed? As it turns out, yes, we do need it, because the two serve similar purposes but at different timescales. The intent of the cell state is to act as a long-term memory and retain long-term information (we’re being careful and intentional about gradient propagation), while the hidden state acts as a short-term memory or “working memory”.

Now that we have both hidden and cell states, how do we relate them to complete the loop? Since the cell state represents long-term memory, we don’t want to just copy it into the hidden state since that defeats the purpose of having two states. Instead, we can follow a similar pattern to the cell state and learn which parts of the cell state, i.e., long-term memory, to apply to the hidden state, i.e., working memory, using a new gate called the output gate $o_t$ that we element-wise multiply with the ($\tanh$-squashed) current cell state.

\[\begin{align*} f_t &= \sigma(W^{(f)}x_t + U^{(f)}h_{t-1} + b^{(f)})\\ g_t &= \tanh(W^{(g)}x_t + U^{(g)}h_{t-1} + b^{(g)})\\ i_t &= \sigma(W^{(i)}x_t + U^{(i)}h_{t-1} + b^{(i)})\\ o_t &= \sigma(W^{(o)}x_t + U^{(o)}h_{t-1} + b^{(o)})\\\\ C_t &= f_t\odot C_{t-1} + i_t\odot g_t\\ h_t &= o_t\odot\tanh(C_t)\\ \end{align*}\]

Intuitively, the model will learn which parts of the long-term memory to put into the working memory! Note that we pass the current cell state through a $\tanh$ layer so that the hidden state is still bounded in the same range of $(-1, 1)$ except we use the new output gate to determine which components of the cell state to write to the hidden state.

With some hand-waving and intuition, we’ve created the Long Short-term Memory (LSTM) cell!

Long Short-term Memory

A Long Short-term Memory cell has a cell state that’s propagated forward and written to in a more intentional way than in vanilla RNNs. It features four gates: forget, candidate, input, and output. The forget gate determines which parts of the cell state to forget; the candidate and input gates determine what to write to the cell state and which parts to write to; finally, the output gate determines which parts of the cell state are written to the hidden state.

In addition to the LSTM cell equations, we also have an output layer, so the full set of equations becomes the following.

\[\begin{align*} f_t &= \sigma(W^{(f)}x_t + U^{(f)}h_{t-1} + b^{(f)})\\ g_t &= \tanh(W^{(g)}x_t + U^{(g)}h_{t-1} + b^{(g)})\\ i_t &= \sigma(W^{(i)}x_t + U^{(i)}h_{t-1} + b^{(i)})\\ o_t &= \sigma(W^{(o)}x_t + U^{(o)}h_{t-1} + b^{(o)})\\\\ C_t &= f_t\odot C_{t-1} + i_t\odot g_t\\ h_t &= o_t\odot\tanh(C_t)\\ \hat{y_t} &= Vh_t + b^{(y)} \end{align*}\]
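The full set of equations translates almost line-for-line into NumPy. This is a sketch under stated assumptions: small random weights, zero biases, random vectors standing in for embeddings, and a hypothetical `params` dictionary mapping each gate name to its $(W, U, b)$ triple.

```python
import numpy as np

def sigmoid(z):
    return 1.0 / (1.0 + np.exp(-z))

def lstm_step(x_t, h_prev, C_prev, params):
    """One LSTM step; params maps gate name -> (W, U, b)."""
    def affine(gate):
        W, U, b = params[gate]
        return W @ x_t + U @ h_prev + b

    f_t = sigmoid(affine("f"))      # forget gate
    g_t = np.tanh(affine("g"))      # candidate values
    i_t = sigmoid(affine("i"))      # input gate
    o_t = sigmoid(affine("o"))      # output gate
    C_t = f_t * C_prev + i_t * g_t  # element-wise (Hadamard) products
    h_t = o_t * np.tanh(C_t)
    return h_t, C_t

rng = np.random.default_rng(0)
input_dim, hidden_dim, vocab_size = 4, 6, 10
params = {
    gate: (rng.normal(scale=0.1, size=(hidden_dim, input_dim)),
           rng.normal(scale=0.1, size=(hidden_dim, hidden_dim)),
           np.zeros(hidden_dim))
    for gate in "fgio"
}
V, b_y = rng.normal(scale=0.1, size=(vocab_size, hidden_dim)), np.zeros(vocab_size)

h, C = np.zeros(hidden_dim), np.zeros(hidden_dim)
for x_t in rng.normal(size=(5, input_dim)):  # stand-in embedded sequence
    h, C = lstm_step(x_t, h, C, params)
    y_hat = V @ h + b_y  # logits for this timestep
```

Real implementations usually fuse the four gates into one matrix multiply and split the result, but the math is the same.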

Now that we’ve intuited an LSTM cell with all of its gates and structure, let’s look a bit more closely at the individual gates and their intents.

The forget gate $f_t$ selectively removes/forgets information in the cell state/long-term memory as a function of the current input and the previous hidden state. The activation function we use for $f_t$ is the sigmoid, to get a value in $(0, 1)$, and then we take the element-wise product of it with the cell state. A value of 0 for a component $j$ means we’ll forget the $j$th component of the cell state, and a value of 1 means we’ll remember that component of the cell state.

The input gate, similar to the forget gate, determines which parts of the cell state will take on new values. The candidate gate actually produces those values (specifically in the $(-1,1)$ range). We take the element-wise product of those two gate vectors as the combined result of values to add into the current cell state (after applying the forget gate).

The output gate $o$ is used to determine which parts of the new, updated cell state/long-term memory make it to the hidden state/short-term memory.

Recall that the whole purpose of redesigning the vanilla RNN cell was to avoid/mitigate the vanishing gradient problem. Do we accomplish this with the cell state? With the hidden state for plain RNNs, backpropagating through each timestep multiplied the gradient by a factor involving the recurrent weights $U$ and the $\tanh$ derivative. With the cell state, we’re adding the gated candidate values $i_t\odot g_t$ to a scaled version of the previous cell state $f_t\odot C_{t-1}$. Addition means that the gradient is copied and preserved across the addition! The only complication is the forget gate: along this path, the gradient is scaled element-wise by $f_t$. This can still attenuate the gradient, but the model can learn to keep $f_t$ close to $1$ for the components it needs to remember, so the gradient decays far more slowly than with vanilla RNNs!
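We can check the claim about the direct cell-state path with autograd. In this sketch the gate activations are random stand-ins treated as fixed for one step (in a real LSTM they also depend on $h_{t-1}$, which opens additional gradient paths), so the only route from $C_{t-1}$ to $C_t$ is the additive one.

```python
import torch

torch.manual_seed(0)
hidden = 4
C_prev = torch.randn(hidden, requires_grad=True)
# gate activations treated as given for this single step
f_t = torch.sigmoid(torch.randn(hidden))
i_t = torch.sigmoid(torch.randn(hidden))
g_t = torch.tanh(torch.randn(hidden))

C_t = f_t * C_prev + i_t * g_t
C_t.sum().backward()  # gradient of sum(C_t) with respect to C_{t-1}
```

The resulting `C_prev.grad` equals `f_t` element-wise: no weight matrix and no $\tanh$ derivative, just the forget gate. Over $k$ steps this path contributes the product of $k$ forget gates, which stays close to $1$ when the gates are close to $1$.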

Again, for a more mathematical treatment, check the Appendix!

In terms of training and sampling LSTM cells, they’re exactly the same as plain RNN cells! Everything we’ve seen about training and sampling is exactly the same! In fact, we can construct training frameworks that are agnostic to the kind of RNN that we’re training since all RNNs fundamentally operate on the same kinds of inputs and outputs even if their internal cell representations are different.

There are so many different flavors of RNNs and LSTMs, and some of them work better than others for certain kinds of tasks. For example, there’s a different flavor called Peephole LSTMs where the gates can peek at the previous cell state in addition to the previous hidden state. There’s another very popular kind of cell called a Gated Recurrent Unit (GRU) that is functionally similar to an LSTM cell but computationally cheaper, with only two gates: (i) update and (ii) reset.
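One way to quantify "cheaper" is to count parameters. Each gate-like block in PyTorch's recurrent layers has an input matrix, a recurrent matrix, and two bias vectors, so for the same sizes an LSTM has four such blocks, a GRU three, and a vanilla RNN one. The sizes below are arbitrary.

```python
import torch.nn as nn

def num_params(module: nn.Module) -> int:
    return sum(p.numel() for p in module.parameters())

input_size, hidden_size = 10, 20
lstm = nn.LSTM(input_size, hidden_size)
gru = nn.GRU(input_size, hidden_size)
rnn = nn.RNN(input_size, hidden_size)

# one block: W (hidden x input) + U (hidden x hidden) + two bias vectors (PyTorch keeps both)
per_block = hidden_size * input_size + hidden_size * hidden_size + 2 * hidden_size
```

With these sizes, `per_block` is 640, so the LSTM has 2560 parameters, the GRU 1920, and the vanilla RNN 640: a GRU saves about a quarter of an LSTM's parameters and compute.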

Lots of people have tried a number of different things to help RNNs and I’d encourage you to experiment as well!

Training an RNN Language Model

So far we’ve done a lot of theory and maths behind RNNs and LSTMs, but now it’s time to train one! Specifically, let’s train a character-based RNN and LSTM on Shakespeare’s plays and see if we can get them to generate some dialogue. There’s going to be a lot of boilerplate code to load the dataset and set up logging and whatnot, so see the full (thoroughly-commented!) code on my GitHub!

The first thing we can do is define the vanilla RNN model in Pytorch (check out the Pytorch documentation for the APIs! They’re pretty straightforward, but I’ll try to explain the more complicated bits). To do this, we’ll need to know the input embedding size, the hidden state size, and the output size. We need an embedding layer that maps each vocabulary index to a dense, learned vector (equivalent to multiplying a one-hot vector by an embedding matrix); here we make the embedding size the same as the size of the vocabulary. Fortunately, Pytorch has a submodule called nn.Embedding that does that for us! Then we’ll need three fully-connected layers: input to hidden state, hidden state to hidden state, and hidden state to output! Pytorch also has nn.Linear that defines a fully-connected layer with weights and a bias.

import torch
import torch.nn as nn
import torch.nn.functional as F

class RNN(nn.Module):

    def __init__(self,
                 input_size: int,
                 hidden_size: int,
                 output_size: int,
                 _num_layers: int):
        super().__init__()

        self.input_size = input_size
        self.hidden_size = hidden_size
        self.output_size = output_size

        # define an embedding layer to map index inputs to a learned dense vector
        self.embedding = nn.Embedding(input_size, input_size)
        self.i2h = nn.Linear(input_size, hidden_size)
        self.h2h = nn.Linear(hidden_size, hidden_size)
        self.h2o = nn.Linear(hidden_size, output_size)

We’re not going to use _num_layers for the vanilla RNN; we just need it to have a uniform constructor, but you’re welcome to try to implement a stacked RNN! Afterwards, we can define the forward pass function as taking in an input sequence of size (seq_size, 1) (the first dimension is the sequence size for truncated backpropagation through time, and each entry is an index into the vocabulary) and a hidden state of size (1, hidden_size).

    def forward(self, x: torch.Tensor, h=None):
        # initialize the hidden state to zeros if none was provided
        if h is None:
            h = torch.zeros(1, self.hidden_size, device=x.device)

        seq_size, _ = x.size()
        out = []

        # run each token through the RNN and collect the outputs
        for t in range(seq_size):
            embedding = self.embedding(x[t])  # (1, input_size)
            h = torch.tanh(self.i2h(embedding) + self.h2h(h))
            o = self.h2o(h)
            out.append(o)
        out = torch.stack(out)  # (seq_size, 1, output_size)

        # detach the hidden state so gradients don't flow back into earlier chunks (truncated BPTT)
        return out, h.detach()

For each timestep, we run the token through the embedding layer, compute the hidden state (re-using the variable so we propagate it forward!), and finally compute the output. Since we have a sequence, we keep the list of outputs and stack them into a Pytorch tensor. Finally, we return the sequence of outputs as well as the accumulated hidden state!

The LSTM variant is also fairly straightforward. The only nuance is that the “hidden state” actually bundles the hidden state together with the cell state. We do this to abide by Pytorch conventions (nn.LSTM passes them around as a single tuple (h, c)), but there’s nothing stopping us from accepting multiple inputs and producing multiple outputs. As it turns out, Pytorch also has a (much better) implementation of RNNs and LSTMs, so we can use that as well! Check the GitHub for implementation details!
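For illustration, here's one possible LSTM counterpart with the same constructor and forward signature as the RNN class, built on Pytorch's `nn.LSTMCell` and carrying the state as an `(h, c)` tuple. This is a hypothetical sketch, not the implementation in the GitHub repo.

```python
import torch
import torch.nn as nn

class LSTM(nn.Module):
    """Hypothetical LSTM counterpart to the RNN class, built on nn.LSTMCell."""

    def __init__(self, input_size: int, hidden_size: int, output_size: int, _num_layers: int):
        super().__init__()
        self.input_size = input_size
        self.hidden_size = hidden_size
        self.output_size = output_size
        self.embedding = nn.Embedding(input_size, input_size)
        self.cell = nn.LSTMCell(input_size, hidden_size)
        self.h2o = nn.Linear(hidden_size, output_size)

    def forward(self, x: torch.Tensor, state=None):
        # state is a (hidden state, cell state) tuple, as with Pytorch's nn.LSTM
        if state is None:
            zeros = torch.zeros(1, self.hidden_size, device=x.device)
            state = (zeros, zeros.clone())
        h, c = state
        out = []
        for t in range(x.size(0)):
            h, c = self.cell(self.embedding(x[t]), (h, c))  # x[t]: (1,) -> embedding (1, input_size)
            out.append(self.h2o(h))
        # detach both states for truncated backpropagation through time
        return torch.stack(out), (h.detach(), c.detach())

torch.manual_seed(0)
model = LSTM(5, 8, 5, 1)
x = torch.randint(0, 5, (7, 1))  # (seq_size, 1) character indices
out, (h, c) = model(x)
out2, _ = model(x, (h, c))       # carry the state into the next chunk
```

Because the interface matches, the training loop and sampling code can treat either model the same way.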

Putting aside the model implementation, now let’s see how to prepare our text corpus and the main training loop. (I’m going to omit any boilerplate or additional logic for the sake of brevity.) First thing we need to do is load the corpus and create a “vocabulary” of characters. Then we can convert each character in the corpus into an index in the vocabulary and turn it into a Pytorch tensor.

with open(args.corpus, 'r') as f:
    corpus = f.read()

unique_chars = sorted(set(corpus))
vocab_size = len(unique_chars)

# create mappings between chars and indices
ch_to_ix = {ch: ix for ix, ch in enumerate(unique_chars)}
ix_to_ch = {ix: ch for ix, ch in enumerate(unique_chars)}

# convert string corpus into Pytorch tensors
data = [ch_to_ix[ch] for ch in corpus]
data = torch.tensor(data).to(device)

# reshape into tensor format: num_chars x 1
data = torch.unsqueeze(data, dim=1)
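To see what these mappings produce, here's the same preprocessing applied to a tiny stand-in string instead of the real corpus file, with a round trip back to text to confirm nothing is lost.

```python
import torch

corpus = "To be, or not to be"  # tiny stand-in for the Shakespeare corpus
unique_chars = sorted(set(corpus))
ch_to_ix = {ch: ix for ix, ch in enumerate(unique_chars)}
ix_to_ch = {ix: ch for ix, ch in enumerate(unique_chars)}

data = torch.tensor([ch_to_ix[ch] for ch in corpus]).unsqueeze(1)  # (num_chars, 1)
decoded = "".join(ix_to_ch[int(ix)] for ix in data.squeeze(1))
```

This string has 9 unique characters (including the space and the comma), so the "vocabulary" here has size 9.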

In practice, very large datasets generally can’t be loaded into memory all at once, so instead we stream them to the model and only keep a buffer in memory; it’s a bit slower than loading everything into memory, but it means we can train on very large datasets! Now we can create our model and define our optimizer (Adam) and loss function (categorical cross-entropy).

# create model
model_init = get_model_type(args.model_arch)
model = model_init(vocab_size, args.hidden_size, vocab_size, args.num_layers).to(device)

# create loss function and optimizer
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=args.learning_rate)

Since the model architecture is an input argument, we do a quick mapping of the string to a Python class (get_model_type) and instantiate it using model_init (note that all model classes share the same constructor for this reason). Pytorch’s nn.CrossEntropyLoss takes the raw logits and applies the log-softmax internally, so we don’t need an explicit softmax operation.
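A small check of what nn.CrossEntropyLoss does with raw logits: it matches a log-softmax followed by the negative log-likelihood loss. The random logits stand in for one chunk of model output. Note that the default `reduction="mean"` averages over the chunk rather than summing as in the BPTT derivation; that only rescales the gradient.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

torch.manual_seed(0)
logits = torch.randn(6, 10)  # (sequence_size, vocab_size) raw model outputs
targets = torch.randint(0, 10, (6,))

criterion = nn.CrossEntropyLoss()  # default reduction="mean"
loss = criterion(logits, targets)

# what CrossEntropyLoss computes: log-softmax, then negative log-likelihood
manual = F.nll_loss(F.log_softmax(logits, dim=1), targets)
```

Applying our own softmax before this loss would normalize twice and silently distort the training signal.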

Now we can get into the main training loop over the number of epochs. Since we’re using truncated backpropagation through time, we backpropagate over chunks of a fixed sequence size instead of the entire corpus, so we iterate over chunks of that size. Then we can chunk our source and target sequences. Remember that the target sequence is the source sequence offset by one character.

for e in range(args.num_epochs):
    epoch_loss = 0
    hidden_state = None

    # stop one short of the end so every source char has a target char
    for i in range(0, len(data) - 1, args.sequence_size):
        end = min(i + args.sequence_size, len(data) - 1)
        # extract source and target sequences of (at most) len sequence_size
        source = data[i:end]
        # target sequence is offset by 1 char
        target = data[i+1:end+1]

Now it’s as simple as running both through our model and backpropagating! Remember to clip the gradients after the backward pass and before the optimizer step!

# run source (and hidden state) through model and compute loss of target set
output, hidden_state = model(source, hidden_state)
loss = criterion(torch.squeeze(output), torch.squeeze(target))

# compute gradients
optimizer.zero_grad()
loss.backward()

# clip the gradient to prevent exploding gradient!
torch.nn.utils.clip_grad_norm_(model.parameters(), args.clip_grad)

# update parameters
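optimizer.step()

# detach the hidden state so the next chunk doesn't backpropagate into this one
# (truncated BPTT); nn.LSTM returns an (h, c) tuple, a plain RNN a single tensor
if isinstance(hidden_state, tuple):
    hidden_state = tuple(h.detach() for h in hidden_state)
else:
    hidden_state = hidden_state.detach()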
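For intuition about what clip_grad_norm_ does, here’s a pure-Python sketch of clipping by the global L2 norm, following the behavior described in PyTorch’s documentation: all gradients are treated as one concatenated vector, and if its norm exceeds max_norm, every gradient is rescaled by the same factor. The helper name and the small eps guard in the denominator (which, as I understand it, mirrors a similar constant in PyTorch’s implementation) are assumptions for this sketch; gradients are plain lists here instead of tensors.

```python
import math


def clip_by_global_norm(grads: list[list[float]], max_norm: float,
                        eps: float = 1e-6) -> tuple[list[list[float]], float]:
    """Rescale all gradients together so their combined L2 norm is at most max_norm.

    Returns the (possibly) rescaled gradients and the norm measured before clipping.
    """
    total_norm = math.sqrt(sum(g * g for grad in grads for g in grad))
    scale = min(1.0, max_norm / (total_norm + eps))  # never scale gradients up
    return [[g * scale for g in grad] for grad in grads], total_norm


# two "parameters" whose gradients together have norm 5 (a 3-4-5 triangle)
clipped, norm = clip_by_global_norm([[3.0], [4.0]], max_norm=1.0)
```

Because every gradient is scaled by the same factor, clipping shrinks the update without changing its direction.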

At each epoch, we can also sample from our model to see how it improves as it trains.

# sample output every epoch
sampled_output = ''.join(ix_to_ch[i] for i in sample(device, model, args.output_sequence_size))

Sampling is also fairly straightforward: we start by picking a random character to start the sequence (or we could select a seed input). Then we run the input and hidden state through the model, normalize the output into a probability distribution over the characters, and sample from that distribution. Remember to feed the sampled character back in as the next input!

@torch.no_grad()  # sampling only: skip building the autograd graph
def sample(device: torch.device,
           model: nn.Module,
           output_seq_size: int) -> list[int]:
    hidden_state = None

    # store output as list of indices
    sampled_output = []
    # create an input tensor from a random index/character in the input set
    random_idx = np.random.randint(model.input_size)
    seq = torch.tensor(random_idx).reshape(1, 1).to(device)

    for _ in range(output_seq_size):
        output, hidden_state = model(seq, hidden_state)

        # normalize output into probability distribution over all characters
        probs = F.softmax(torch.squeeze(output), dim=0)
        dist = torch.distributions.categorical.Categorical(probs)

        # sample from the distribution and append to list
        sampled_idx = dist.sample()
        sampled_output.append(sampled_idx.item())

        # reset sequence to sampled char for next loop iteration
        seq[0][0] = sampled_idx.item()
    return sampled_output
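To make the sampling loop concrete without a trained model, here’s a pure-Python sketch of the same steps: turn logits into probabilities with a softmax, draw an index from that categorical distribution, and feed it back in as the next input. The toy_step function is a hypothetical stand-in for the model (it passes the hidden state through untouched and strongly prefers “next index = current + 1”), and random.choices plays the role of Categorical.sample.

```python
import math
import random
from typing import Any, Callable

# A "step" maps (current index, hidden state) -> (logits over the vocabulary, new hidden state).
Step = Callable[[int, Any], tuple[list[float], Any]]


def softmax(logits: list[float]) -> list[float]:
    """Turn raw scores into probabilities (subtracting the max for numerical stability)."""
    m = max(logits)
    exps = [math.exp(z - m) for z in logits]
    total = sum(exps)
    return [e / total for e in exps]


def sample_sequence(step: Step, start_idx: int, length: int, rng: random.Random) -> list[int]:
    """Repeatedly sample the next index and feed it back in as the next input."""
    hidden: Any = None
    idx = start_idx
    sampled = []
    for _ in range(length):
        logits, hidden = step(idx, hidden)
        probs = softmax(logits)
        # draw one index from the categorical distribution defined by probs
        idx = rng.choices(range(len(probs)), weights=probs, k=1)[0]
        sampled.append(idx)
    return sampled


VOCAB_SIZE = 4


def toy_step(idx: int, hidden: Any) -> tuple[list[float], Any]:
    """Hypothetical stand-in for the model: strongly prefers (idx + 1) % VOCAB_SIZE."""
    logits = [0.0] * VOCAB_SIZE
    logits[(idx + 1) % VOCAB_SIZE] = 50.0
    return logits, hidden


print(sample_sequence(toy_step, start_idx=0, length=6, rng=random.Random(0)))
```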

In the GitHub repo, I’ve also put some logs and pre-trained models for our custom vanilla RNN as well as the PyTorch LSTM, both trained on the Shakespeare corpus for 32 epochs. We can see that, even for the vanilla RNN, within the first epoch it starts to learn a bit about how plays are structured and even starts to get character names right!

hatbetts,
Well by you shokseecing.

ANTONIO:
Son, wrworn, speak your fore them.

SEBASTIAN:
Witholers ndvery backs.

ASTONSON:
W

The LSTM does a bit better after the first epoch and already produces more complete words.

ry a poor tinis
Would tisle bechosh attein, and I,
My father, and risun.

ANTONIO:
You boumon manicable.

ANTONIO:
All old thou 

In the later epochs, we start to get some better results from the vanilla RNN.

ALONSO:
Of the so speanfelty,
I do my should shipt yould your and ateal,--
I pind, adve yeny youbt sones in you so.

SEBASTIAN:

The LSTM does even better.

GONZALO:
'Tis incapress to us in actions.

ANTONO:
Therefore I will not
Some pillaria, but what again, woul

See the full code, logs, and pre-trained models on my GitHub! Try training your own RNN on your own corpus or generating text with the pre-trained models!

Conclusion

In this post, we graduated from n-grams and plain neural networks to recurrent neural networks (RNNs) that can handle arbitrary-length sequences and better model the sequential nature of language! We discussed how to train them and how to sample from them. We also saw a few variants: one that also runs the sequence in reverse and combines the hidden states (bidirectional RNNs) and one that stacks layers deep like regular neural networks (stacked RNNs). Vanishing and exploding gradients are the primary issue with RNNs, which motivated the long short-term memory (LSTM) cell to help address the vanishing gradient. Finally, we saw how to train RNNs using PyTorch and looked at some example outputs during and near the end of training!

In the next post, we’ll finally get to a state-of-the-art language model architecture called the Transformer, the same architecture behind many Large Language Models (LLMs) such as OpenAI’s ChatGPT and Anthropic’s Claude! 🙂

Appendix

Vanishing/Exploding Gradient in Plain and LSTM Cells

To see mathematically how vanishing and exploding gradients arise, we have to derive the backpropagation equations from the RNN equations. Specifically, we compute the derivative of the loss function with respect to the hidden state weight matrix, $\frac{\p L}{\p U}$, since it’s the parameter that carries information from one hidden state to the next.

We’ll be accumulating the total gradient as we move backward so we’ll start with $\frac{\p L}{\p U}$ and then expand into smaller pieces using the chain rule of calculus.

For any RNN, the total loss is the sum of the individual losses at each timestep.

\[L = \sum_t L_t\]

Since the total loss simply sums the individual losses, the gradient is copied unchanged to each term; to keep things focused, let’s set that top-level sum aside and work with a single $\frac{\p L_t}{\p U}$, knowing that we can sum over $t$ to get the gradient of the total loss. For each individual loss, we use categorical cross-entropy between the true “next word” $y_t$ and the model’s prediction $\hat{y_t}$.

\[L_t = L_{\text{CE}}(y_t, \hat{y_t})\]

Instead of getting right into a generic solution, let’s try to compute it by hand for a small sequence of size three.

BPTT for an RNN

In a toy example of an RNN with three timesteps, the green arrows show the gradient moving backwards, and the local gradients are written on the arrows! To compute a derivative with respect to a parameter, we multiply the local gradients along each path from the loss back to that parameter and then sum over all such paths. For example, $\frac{\p L_1}{\p U}=\frac{\p L_1}{\p \hat{y_1}}\frac{\p \hat{y_1}}{\p h_1}\frac{\p h_1}{\p U}$.

Following the local gradients, the loss at the first timestep is the easiest since there are no previous timesteps to backpropagate into.

\[\frac{\p L_1}{\p U}=\frac{\p L_1}{\p \hat{y_1}}\frac{\p \hat{y_1}}{\p h_1}\bigg( \frac{\p h_1}{\p U} \bigg)\]

Pretty straightforward! Now let’s look at the loss at the second timestep, where $U$ affects $h_2$ both directly and indirectly through $h_1$.

\[\frac{\p L_2}{\p U}=\frac{\p L_2}{\p \hat{y_2}}\frac{\p \hat{y_2}}{\p h_2}\bigg( \frac{\p h_2}{\p U} + \frac{\p h_2}{\p h_1}\frac{\p h_1}{\p U} \bigg)\]

Notice the first term in the parentheses has the same form as before, but a second term appears because we also have to backpropagate to the first timestep using $\frac{\p h_2}{\p h_1}$. Now let’s do the same for the third timestep.

\[\frac{\p L_3}{\p U}=\frac{\p L_3}{\p \hat{y_3}}\frac{\p \hat{y_3}}{\p h_3}\bigg( \frac{\p h_3}{\p U} + \frac{\p h_3}{\p h_2}\frac{\p h_2}{\p U} + \frac{\p h_3}{\p h_2}\frac{\p h_2}{\p h_1}\frac{\p h_1}{\p U} \bigg)\]

See the pattern? At a particular timestep $t$, we backpropagate through the hidden states to each earlier timestep, and at every timestep we reach, we also branch off into the weight matrix $U$ through that timestep’s direct use of it.

Now that we’ve seen an example, let’s go back and try to formulate this more generically across an arbitrary number of timesteps.

\[\frac{\p L_t}{\p U} = \frac{\p L_t}{\p \hat{y_t}}\frac{\p \hat{y_t}}{\p U}\]

(I’m abusing notation a bit here: the derivative of a vector with respect to a matrix is really a higher-order tensor, so I’m being loose about shapes and transposes.) To compute $\hat{y_t}$, we apply the output weight matrix and bias to the hidden state at timestep $t$.

\[\hat{y_t} = Vh_t + b^{(y)}\]

To get to the hidden state, we have to backpropagate through the output weight matrix.

\[\frac{\p L_t}{\p U}=\frac{\p L_t}{\p \hat{y_t}}\frac{\p \hat{y_t}}{\p h_t}\frac{\p h_t}{\p U}\]
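For concreteness, the first two factors have simple closed forms when $\hat{y_t}$ are the logits fed to softmax cross-entropy and we treat $y_t$ as a one-hot vector over the vocabulary. This is a standard result, written in the same loose notation as above:

\[\frac{\p L_t}{\p \hat{y_t}} = \mathrm{softmax}(\hat{y_t}) - y_t, \qquad \frac{\p \hat{y_t}}{\p h_t} = V \quad\Rightarrow\quad \frac{\p L_t}{\p h_t} = V^\top\big(\mathrm{softmax}(\hat{y_t}) - y_t\big)\]

So the gradient arriving at $h_t$ is just the prediction error mapped back through the output weights; all of the interesting behavior over time comes from $\frac{\p h_t}{\p U}$.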

We have to break down $\frac{\p h_t}{\p U}$ carefully since we’re applying the hidden state weight matrix $U$ at each of the previous timesteps. Consider the diagram: we have multiple gradients going into $U$ so we have to sum over them. Getting the gradient at $t$ is straightforward, but what about the earlier timesteps? They also depend on $U$ since it’s the same one we use for all timesteps! We can move the gradient backwards through the hidden states since $h_t$ depends on $h_{t-1}$ and $h_{t-1}$ depends on $h_{t-2}$ and so on. So $\frac{\p h_t}{\p U}$ really expands into a sum over all of the previous timesteps.

\[\frac{\p L_t}{\p U}=\frac{\p L_t}{\p \hat{y_t}}\frac{\p \hat{y_t}}{\p h_t}\sum_{k=1}^t \frac{\p h_t}{\p h_k}\frac{\p h_k}{\p U}\]

Here $\frac{\p h_k}{\p U}$ means the immediate derivative of $h_k$ with respect to $U$, treating $h_{k-1}$ as a constant. The first factor in the sum moves the gradient back from timestep $t$ to an earlier timestep $k$, while the second factor backpropagates into the hidden state weight matrix $U$ at that timestep. And we sum over all of the timesteps in the sequence up to timestep $t$.

However, $\frac{\p h_t}{\p h_k}$ can be expanded again using the chain rule into a product! For example, if we’re at timestep $t$ trying to go back to timestep $t-3$, then we need to go back through the hidden states at $t-1$, $t-2$, and finally $t-3$, so the product looks like $\frac{\p h_t}{\p h_{t-1}}\frac{\p h_{t-1}}{\p h_{t-2}}\frac{\p h_{t-2}}{\p h_{t-3}}$. So we can expand $\frac{\p h_t}{\p h_k}$ into a product.

\[\frac{\p L_t}{\p U}=\frac{\p L_t}{\p \hat{y_t}}\frac{\p \hat{y_t}}{\p h_t}\sum_{k=1}^t \bigg(\prod_{j=k+1}^{t} \frac{\p h_j}{\p h_{j-1}}\bigg) \frac{\p h_k}{\p U}\]

This is the full gradient of $\frac{\p L_t}{\p U}$!
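As a sanity check on this kind of sum-over-paths gradient, here’s a pure-Python sketch for a scalar vanilla RNN (every weight is a single number, there are no biases, and $h_0 = 0$), with a squared-error loss standing in for cross-entropy to keep the arithmetic short. It evaluates $\sum_{k=1}^t \big(\prod_{j=k+1}^{t} \frac{\p h_j}{\p h_{j-1}}\big)\frac{\p h_k}{\p U}$ term by term, using $\frac{\p h_j}{\p h_{j-1}} = \tanh'(\cdot)\,u = (1 - h_j^2)\,u$ and the immediate derivative $\frac{\p h_k}{\p U} = (1 - h_k^2)\,h_{k-1}$, and compares the result with a central finite difference. All names and numbers are illustrative.

```python
import math


def forward(u: float, w: float, xs: list[float], h0: float = 0.0) -> list[float]:
    """Scalar vanilla RNN h_t = tanh(w * x_t + u * h_{t-1}); returns [h_0, h_1, ..., h_T]."""
    hs = [h0]
    for x in xs:
        hs.append(math.tanh(w * x + u * hs[-1]))
    return hs


def loss_at(t: int, u: float, w: float, v: float,
            xs: list[float], ys: list[float]) -> float:
    """L_t = 0.5 * (v * h_t - y_t)^2, a squared-error stand-in for cross-entropy."""
    h_t = forward(u, w, xs)[t]
    return 0.5 * (v * h_t - ys[t - 1]) ** 2


def bptt_grad_u(t: int, u: float, w: float, v: float,
                xs: list[float], ys: list[float]) -> float:
    """dL_t/du as a sum over k of (prod_{j=k+1}^t dh_j/dh_{j-1}) * (immediate dh_k/du)."""
    hs = forward(u, w, xs)
    dL_dh_t = (v * hs[t] - ys[t - 1]) * v  # dL_t/dyhat_t * dyhat_t/dh_t
    total = 0.0
    for k in range(1, t + 1):
        prod = 1.0
        for j in range(k + 1, t + 1):
            prod *= (1.0 - hs[j] ** 2) * u  # dh_j/dh_{j-1} = tanh'(z_j) * u
        immediate = (1.0 - hs[k] ** 2) * hs[k - 1]  # dh_k/du with h_{k-1} held fixed
        total += prod * immediate
    return dL_dh_t * total


def finite_diff_grad_u(t: int, u: float, w: float, v: float,
                       xs: list[float], ys: list[float], eps: float = 1e-6) -> float:
    """Central finite-difference estimate of dL_t/du, for comparison."""
    return (loss_at(t, u + eps, w, v, xs, ys) - loss_at(t, u - eps, w, v, xs, ys)) / (2 * eps)


xs = [0.2, -0.4, 0.9, 0.1]  # inputs x_1..x_4
ys = [0.1, 0.0, -0.2, 0.3]  # targets y_1..y_4
for t in range(1, len(xs) + 1):
    print(t, bptt_grad_u(t, 0.7, 0.5, 1.3, xs, ys), finite_diff_grad_u(t, 0.7, 0.5, 1.3, xs, ys))
```

The two columns should agree to many decimal places; at $t = 1$ both are zero because $h_0 = 0$ means $U$ has no effect on $L_1$.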

Now that we have the gradient of the hidden state weight matrix, we can finally investigate the vanishing and exploding gradient problems! Since both of these problems occur as the gradient moves backwards to earlier timesteps, the core of the issue lies in the product term, specifically the partial derivative of each hidden state with respect to the previous one.

\[\prod_{j=k+1}^{t} \frac{\p h_j}{\p h_{j-1}}\]

Recall from earlier that, since this is a product, if most of these terms are smaller than 1 in magnitude, the gradient vanishes; if most of them are greater than 1, it explodes. Let’s expand this term and investigate!

\[\begin{align*} \frac{\p h_j}{\p h_{j-1}} &= \frac{\p}{\p h_{j-1}}\tanh(Wx_j + Uh_{j-1} + b^{(h)})\\ &= \tanh'(Wx_j + Uh_{j-1} + b^{(h)})\frac{\p}{\p h_{j-1}}\bigg[ Wx_j + Uh_{j-1} + b^{(h)}\bigg]\\ &= \tanh'(Wx_j + Uh_{j-1} + b^{(h)})U\\ \end{align*}\]

Between the first two steps, we backpropagate through the $\tanh$ non-linearity, and then we take the derivative of the affine term directly. So the culprit is $U$! We compound $U$ at every timestep, which causes our gradients to either vanish or explode. Since $U$ is a matrix, it’s harder to reason about what kind of $U$ causes vanishing versus exploding gradients. Fortunately, that work has already been done for us in the Appendix of “On the Difficulty of Training Recurrent Neural Networks” by Pascanu, Mikolov, and Bengio. They show that if the magnitude of the largest eigenvalue of $U$ is smaller than $1$ (for $\tanh$, whose derivative is at most $1$), the contributions from distant timesteps shrink exponentially fast and approach $0$ in the limit; conversely, a largest eigenvalue with magnitude greater than $1$ is necessary (though not sufficient) for the gradient to explode.

This is the mathematical reason that we get vanishing and exploding gradients in RNNs!
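To see the eigenvalue criterion in action, here’s a pure-Python sketch that ignores the $\tanh'$ factor (equivalently, it assumes every pre-activation sits near $0$, where $\tanh' \approx 1$) and backpropagates a gradient vector through 50 timesteps by repeatedly multiplying by $U^\top$, which has the same eigenvalues as $U$. The two $2\times 2$ matrices are arbitrary examples. For this particular “growing” matrix the gradient does blow up, although in general an eigenvalue above $1$ doesn’t guarantee that.

```python
import math

Matrix = list[list[float]]


def spectral_radius_2x2(U: Matrix) -> float:
    """Largest |eigenvalue| of a 2x2 matrix, from its trace and determinant."""
    (a, b), (c, d) = U
    tr, det = a + d, a * d - b * c
    disc = tr * tr - 4.0 * det
    if disc >= 0.0:  # two real eigenvalues
        r = math.sqrt(disc)
        return max(abs((tr + r) / 2.0), abs((tr - r) / 2.0))
    return math.sqrt(det)  # complex-conjugate pair: |lambda|^2 = det


def backprop_norm(U: Matrix, steps: int) -> float:
    """Norm of a gradient after `steps` multiplications by U^T (tanh' taken to be 1)."""
    g = [1.0, 1.0]
    for _ in range(steps):
        # g <- U^T g
        g = [U[0][0] * g[0] + U[1][0] * g[1],
             U[0][1] * g[0] + U[1][1] * g[1]]
    return math.hypot(g[0], g[1])


shrinking = [[0.5, 0.2], [0.0, 0.8]]  # eigenvalues 0.5 and 0.8
growing = [[1.2, 0.0], [0.3, 0.9]]    # eigenvalues 1.2 and 0.9
for name, U in [("shrinking", shrinking), ("growing", growing)]:
    print(name, spectral_radius_2x2(U), backprop_norm(U, 50))
```

With the first matrix the gradient norm collapses toward $0$; with the second it grows by several orders of magnitude.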

How do LSTMs fare? Recall that with the LSTM, we propagate the cell state forward in a similar way to the hidden state, so we can look at the partial derivative of the current cell state with respect to the previous one.

\[\frac{\p c_k}{\p c_{k-1}} = \frac{\p}{\p c_{k-1}}\bigg[ f_k\odot c_{k-1} + i_k\odot g_k \bigg]\]

How far back in the equation do we go? We have to keep unraveling it until we’ve found every dependence on $c_{k-1}$. Remember that $f_k$, $i_k$, and $g_k$ are all functions of $h_{k-1}$, which is a function of $c_{k-1}$! Let’s start by applying the product rule.

\[\begin{align*} \frac{\p c_k}{\p c_{k-1}} &= \frac{\p}{\p c_{k-1}}f_k\odot c_{k-1} + f_k\odot \frac{\p}{\p c_{k-1}} c_{k-1} + \frac{\p}{\p c_{k-1}}i_k\odot g_k + i_k\odot \frac{\p}{\p c_{k-1}}g_k\\ &= \frac{\p f_k}{\p c_{k-1}}\odot c_{k-1} + f_k + \frac{\p i_k}{\p c_{k-1}}\odot g_k + i_k\odot \frac{\p g_k}{\p c_{k-1}}\\ &= c_{k-1}\frac{\p f_k}{\p c_{k-1}} + f_k + g_k\frac{\p i_k}{\p c_{k-1}} + i_k\frac{\p g_k}{\p c_{k-1}}\\ \end{align*}\]

So we have three other partial derivatives that we have to compute, one for each gate except the output gate (which we’ll encounter later). Let’s compute each one in turn.

\[\begin{align*} \frac{\p f_k}{\p c_{k-1}} &= \frac{\p}{\p c_{k-1}}\sigma(W_f x_k + U_f h_{k-1} + b_f)\\ &= \frac{\p}{\p c_{k-1}}\sigma(z_f)\\ &= \sigma'(z_f)\frac{\p}{\p c_{k-1}}\bigg[W_f x_k + U_f h_{k-1} + b_f\bigg]\\ &= \sigma'(z_f)U_f\frac{\p}{\p c_{k-1}}\bigg[h_{k-1}\bigg]\\ &= \sigma'(z_f)U_f\frac{\p}{\p c_{k-1}}\bigg[o_{k-1}\odot\tanh(c_{k-1})\bigg]\\ &= \sigma'(z_f)U_f o_{k-1}\frac{\p}{\p c_{k-1}}\bigg[\tanh(c_{k-1})\bigg]\\ &= \sigma'(z_f)U_f o_{k-1}\tanh'(c_{k-1})\\ \end{align*}\]

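The next-to-last step treats $o_{k-1}$ as a constant, since it depends on $x_{k-1}$ and $h_{k-2}$ but not on $c_{k-1}$. As it turns out, the other partial derivatives follow the same steps with their own weights and pre-activations, so I’ll skip those derivations.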

\[\begin{align*} \frac{\p i_k}{\p c_{k-1}} &= \sigma'(z_i)U_i o_{k-1}\tanh'(c_{k-1})\\ \frac{\p g_k}{\p c_{k-1}} &= \tanh'(z_g)U_g o_{k-1}\tanh'(c_{k-1})\\ \end{align*}\]

Combining all of these together, we have the partial derivative $\frac{\p c_k}{\p c_{k-1}}$.

\[\begin{align*} \frac{\p c_k}{\p c_{k-1}} = &c_{k-1}\sigma'(z_f)U_f o_{k-1}\tanh'(c_{k-1})\\ &+ f_k\\ &+ g_k\sigma'(z_i)U_i o_{k-1}\tanh'(c_{k-1})\\ &+ i_k\tanh'(z_g)U_g o_{k-1}\tanh'(c_{k-1}) \end{align*}\]
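A quick way to gain confidence in this expression is to check it numerically. The pure-Python sketch below runs a single step of a scalar LSTM (every weight and gate is one number), sets $h_{k-1} = o_{k-1}\tanh(c_{k-1})$ with $o_{k-1}$ held fixed as in the derivation, and compares the four-term formula with a central finite difference. The parameter values, and the dictionary P holding them, are arbitrary choices for illustration.

```python
import math


def sigmoid(z: float) -> float:
    return 1.0 / (1.0 + math.exp(-z))


def dsigmoid(z: float) -> float:
    s = sigmoid(z)
    return s * (1.0 - s)


def dtanh(z: float) -> float:
    return 1.0 - math.tanh(z) ** 2


# hypothetical scalar parameters (W, U, b) for the forget, input, and candidate parts
P = {"f": (0.4, 0.9, 0.5), "i": (-0.3, 0.7, 0.1), "g": (0.8, -1.1, 0.2)}


def preacts(c_prev: float, o_prev: float, x: float) -> dict[str, float]:
    """Pre-activations z_f, z_i, z_g; h_{k-1} = o_{k-1} * tanh(c_{k-1}) depends on c_{k-1}."""
    h_prev = o_prev * math.tanh(c_prev)
    return {name: W * x + U * h_prev + b for name, (W, U, b) in P.items()}


def next_cell(c_prev: float, o_prev: float, x: float) -> float:
    """c_k = f_k * c_{k-1} + i_k * g_k."""
    z = preacts(c_prev, o_prev, x)
    return sigmoid(z["f"]) * c_prev + sigmoid(z["i"]) * math.tanh(z["g"])


def dc_dcprev(c_prev: float, o_prev: float, x: float) -> float:
    """The four-term expression for dc_k/dc_{k-1}."""
    z = preacts(c_prev, o_prev, x)
    f, i, g = sigmoid(z["f"]), sigmoid(z["i"]), math.tanh(z["g"])
    common = o_prev * dtanh(c_prev)  # o_{k-1} * tanh'(c_{k-1}), shared by three terms
    return (c_prev * dsigmoid(z["f"]) * P["f"][1] * common
            + f
            + g * dsigmoid(z["i"]) * P["i"][1] * common
            + i * dtanh(z["g"]) * P["g"][1] * common)


c_prev, o_prev, x, eps = 0.6, 0.7, -0.5, 1e-6
numeric = (next_cell(c_prev + eps, o_prev, x) - next_cell(c_prev - eps, o_prev, x)) / (2 * eps)
analytic = dc_dcprev(c_prev, o_prev, x)
print(numeric, analytic)
```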

Let’s compare the product of this with the vanilla RNN side-by-side.

\[\begin{align*} \prod_{j=k+1}^{t} \frac{\p h_j}{\p h_{j-1}} &= \prod_{j=k+1}^{t}\tanh'(Wx_j + Uh_{j-1} + b^{(h)})U\\ \prod_{j=k+1}^{t} \frac{\p c_j}{\p c_{j-1}} &= \prod_{j=k+1}^{t}\bigg[ c_{j-1}\sigma'(z_f)U_f o_{j-1}\tanh'(c_{j-1}) + f_j + g_j\sigma'(z_i)U_i o_{j-1}\tanh'(c_{j-1}) + i_j\tanh'(z_g)U_g o_{j-1}\tanh'(c_{j-1})\bigg] \end{align*}\]

The LSTM expression is very different from (and more complex than) the one for plain RNNs! The most important difference is that each LSTM factor is additive, while the plain RNN factor is purely multiplicative. When we multiply the factors together for the vanilla RNN, we get a giant product of $\tanh'(\cdot)U$ terms that can explode or vanish. For the LSTM, each factor is still a sum that contains the forget gate $f_j$ on its own, not multiplied by any weight matrix, so when the forget gate stays close to $1$, each factor stays close to $1$ as well. This means there’s a much smaller chance of the gradient vanishing, since addition copies the gradient through. Also note that the forget gate appears directly in the equation and can adjust the gradient to keep it from vanishing; since the gate values are learned, the LSTM can learn when to preserve the gradient.
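To put rough numbers on this, here’s a pure-Python sketch that multiplies the per-step factors from the two expressions above over 50 timesteps for scalar models with no input ($x = 0$). The vanilla RNN uses $u = 0.9$; the LSTM uses small recurrent weights ($0.1$) and a forget-gate bias of $4$, so $f_j = \sigma(\cdot) \approx 0.98$. These settings, including the initial output gate $o_0 = 0.5$, are arbitrary assumptions picked to isolate the forget gate’s role, not a claim about trained models. In this setup every extra term in the LSTM factor is nonnegative, so the LSTM product can’t drop below $\sigma(4)^{50} \approx 0.40$, while the vanilla product is at most $0.9^{50} \approx 0.005$.

```python
import math


def sigmoid(z: float) -> float:
    return 1.0 / (1.0 + math.exp(-z))


def rnn_product(u: float, h0: float, steps: int) -> float:
    """prod_j dh_j/dh_{j-1} = prod_j tanh'(u * h_{j-1}) * u for a scalar RNN with x = 0."""
    h, prod = h0, 1.0
    for _ in range(steps):
        h = math.tanh(u * h)
        prod *= (1.0 - h * h) * u  # tanh'(z_j) = 1 - h_j^2
    return prod


def lstm_product(c0: float, steps: int, u: float = 0.1, b_f: float = 4.0,
                 o0: float = 0.5) -> float:
    """prod_j dc_j/dc_{j-1} using the four-term factor, for a scalar LSTM with x = 0.

    All recurrent weights equal u; only the forget gate has a (positive) bias.
    """
    c, o = c0, o0
    h = o * math.tanh(c)
    prod = 1.0
    for _ in range(steps):
        z_f, z_i, z_g, z_o = u * h + b_f, u * h, u * h, u * h
        f, i, g = sigmoid(z_f), sigmoid(z_i), math.tanh(z_g)
        common = o * (1.0 - math.tanh(c) ** 2)  # o_{j-1} * tanh'(c_{j-1})
        factor = (c * f * (1.0 - f) * u * common    # sigma'(z_f) = f (1 - f)
                  + f
                  + g * i * (1.0 - i) * u * common  # sigma'(z_i) = i (1 - i)
                  + i * (1.0 - g * g) * u * common) # tanh'(z_g) = 1 - g^2
        prod *= factor
        c = f * c + i * g     # c_j
        o = sigmoid(z_o)      # o_j
        h = o * math.tanh(c)  # h_j
    return prod


print("vanilla RNN:", rnn_product(u=0.9, h0=0.5, steps=50))
print("LSTM:       ", lstm_product(c0=0.5, steps=50))
```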