Recurrent Neural Networks & LSTM
From simple recurrence to gated memory cells. A comprehensive, visually interactive deep dive into sequence modeling architectures that learn temporal dependencies in data.
How the quest to model sequential data led to recurrent architectures and gated memory cells.
The idea of feeding a network's output back into itself dates to the earliest days of neural network research. In 1982, John Hopfield introduced the Hopfield Network, a fully connected recurrent system that could store and recall patterns through energy minimization. While not designed for sequences, it proved that recurrence could endow a network with a form of memory.
The modern era of sequence modeling began with Jeffrey Elman, who in 1990 published "Finding Structure in Time." Elman proposed a simple recurrent network (SRN) where hidden-state activations at time \( t-1 \) are copied into special "context units" and fed back as additional input at time \( t \). This elegant trick gave the network a short-term memory of its recent past, enabling it to learn temporal patterns in language and other sequential data.
Elman's architecture was remarkably simple: a standard feedforward network augmented with a single recurrent connection. Yet it was powerful enough to learn grammar-like rules from raw character sequences, sparking an explosion of interest in recurrent approaches to language modeling, speech recognition, and time-series prediction.
Despite Elman's success, vanilla RNNs suffered from a crippling limitation: they could not learn long-range dependencies. When the gap between relevant information and the point where it is needed grows large, gradients either vanish or explode during training, making learning impossible.
In 1997, Sepp Hochreiter and Jürgen Schmidhuber published their landmark paper introducing the Long Short-Term Memory (LSTM) network. The key innovation was the cell state -- a dedicated memory highway that runs through the entire sequence, regulated by learned gates that control what information to store, forget, and output. This architecture mitigated the vanishing gradient problem by allowing gradients to flow largely unchanged through the cell state.
The LSTM was refined over the following decade. In 2000, Felix Gers added the forget gate, which became essential for allowing the cell to reset its memory. By the mid-2010s, LSTMs had become the dominant architecture for virtually every sequence modeling task, from machine translation to speech synthesis.
Why sequences need memory and how the hidden state carries information forward through time.
Standard feedforward neural networks treat each input independently. They have no notion of order, context, or history. But many real-world problems are inherently sequential: the meaning of a word depends on the words before it; tomorrow's stock price depends on today's; the next note in a melody depends on the preceding phrase.
Consider predicting the next word in the sentence "The clouds are dark and it is likely to ___." A feedforward network seeing only the word "to" has no information to work with. But a recurrent network that has processed the entire sentence so far carries a compressed summary of "clouds," "dark," and "likely" in its hidden state, enabling it to predict "rain" with confidence.
The key insight is that sequence modeling requires memory. The network must maintain an internal state that summarizes everything it has seen so far, updating this state at each timestep as new information arrives. This is exactly what the hidden state \( h_t \) provides in a recurrent neural network.
At each timestep \( t \), the RNN receives two inputs: the current observation \( x_t \) and the previous hidden state \( h_{t-1} \). It combines these to produce a new hidden state \( h_t \) that encodes both the current input and the relevant history:

\[ h_t = f(h_{t-1}, x_t) \]
The hidden state \( h_t \) is a fixed-size vector that must compress all relevant information from the entire history \( (x_1, x_2, \ldots, x_t) \) into a single representation. This is both the power and the limitation of RNNs: the fixed-size bottleneck forces the network to learn what to remember and what to forget, but it also means that very long-range dependencies can be difficult to capture.
The output \( y_t \) at each timestep is then computed from the hidden state:

\[ y_t = g(h_t) \]
This simple formulation -- state update followed by output -- is the fundamental building block of all recurrent architectures, from vanilla RNNs to LSTMs and GRUs.
Entire sequence maps to a single output. Used for sentiment analysis, document classification, and activity recognition from sensor streams.
Each input timestep produces an output. Used for named entity recognition, part-of-speech tagging, and frame-level video classification.
Input sequence is encoded, then decoded into a different-length output. Used for machine translation, summarization, and chatbots.
The simplest recurrent architecture -- a single hidden state updated at every timestep with a tanh nonlinearity.
A vanilla RNN computes its hidden state at time \( t \) using the following recurrence:

\[ h_t = \tanh(W_h h_{t-1} + W_x x_t + b) \]
Where \( W_h \in \mathbb{R}^{d_h \times d_h} \) is the hidden-to-hidden weight matrix, \( W_x \in \mathbb{R}^{d_h \times d_x} \) is the input-to-hidden weight matrix, and \( b \in \mathbb{R}^{d_h} \) is a bias vector. The \( \tanh \) activation squashes the result to the range \([-1, 1]\).
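As a concrete illustration, the recurrence can be sketched in a few lines of NumPy. The dimensions and random weights here are illustrative, not trained values:

```python
import numpy as np

rng = np.random.default_rng(0)
d_x, d_h = 4, 8          # input and hidden sizes (illustrative)

W_h = rng.standard_normal((d_h, d_h)) * 0.1   # hidden-to-hidden
W_x = rng.standard_normal((d_h, d_x)) * 0.1   # input-to-hidden
b = np.zeros(d_h)

def rnn_step(h_prev, x_t):
    """One application of the vanilla RNN recurrence."""
    return np.tanh(W_h @ h_prev + W_x @ x_t + b)

# Run the recurrence over a short sequence of T = 5 random inputs
h = np.zeros(d_h)
for x_t in rng.standard_normal((5, d_x)):
    h = rnn_step(h, x_t)

print(h.shape)  # (8,) -- fixed-size state regardless of sequence length
```

Note that the same `W_h`, `W_x`, and `b` are reused at every step, which is the weight sharing described below.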
The output at each timestep is computed as:

\[ y_t = W_y h_t + b_y \]
Where \( W_y \in \mathbb{R}^{d_y \times d_h} \) maps the hidden state to the output space. For classification tasks, a softmax is applied: \( \hat{y}_t = \text{softmax}(y_t) \).
The critical feature is weight sharing: the same matrices \( W_h \), \( W_x \), and \( W_y \) are used at every timestep. This means an RNN can process sequences of any length with a fixed number of parameters, and patterns learned at one position in the sequence can be applied at any other position.
To understand how an RNN processes a sequence, we "unroll" it through time. The single recurrent cell is replicated \( T \) times (once per timestep), with the hidden state flowing from left to right. This unrolled view reveals the RNN as a very deep feedforward network with shared weights at each layer.
A vanilla RNN with hidden size \( d_h = 256 \) and input size \( d_x = 100 \) has only \( 256 \times 256 + 256 \times 100 + 256 = 91{,}392 \) parameters for the recurrent layer, regardless of the sequence length. This makes RNNs extremely parameter-efficient compared to architectures that use separate weights for each position.
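As a sanity check, the count can be compared against PyTorch. Note that `nn.RNN` keeps two separate bias vectors (`bias_ih` and `bias_hh`), so it reports one extra \( d_h \)-sized bias relative to the single-bias formulation:

```python
import torch.nn as nn

# Single-layer vanilla RNN with the sizes from the text
rnn = nn.RNN(input_size=100, hidden_size=256)
n_params = sum(p.numel() for p in rnn.parameters())

# weight_hh: 256*256, weight_ih: 256*100, plus two 256-dim biases
print(n_params)  # 91648
```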
Visualize how an RNN cell is unrolled across timesteps. Input \( x_t \) enters from below, hidden state \( h_t \) flows to the right, and output \( y_t \) exits above. Use the slider to control the number of timesteps, and click Animate to watch data pulse through the chain.
Adjust timesteps with the slider. Click Animate to pulse data through the recurrent chain from left to right.
How gradients flow backward through the unrolled computational graph to update shared weights.
Backpropagation Through Time (BPTT) is simply the standard backpropagation algorithm applied to the unrolled RNN graph. Because the RNN is unrolled into \( T \) copies of the same cell, BPTT must sum the gradient contributions from every timestep.
The total loss over a sequence of length \( T \) is the sum of per-timestep losses:

\[ L = \sum_{t=1}^{T} L_t \]
To compute the gradient of \( L \) with respect to the shared weight matrix \( W_h \), we apply the chain rule through every timestep where \( W_h \) is used:

\[ \frac{\partial L}{\partial W_h} = \sum_{t=1}^{T} \sum_{k=1}^{t} \frac{\partial L_t}{\partial h_t} \left( \prod_{j=k+1}^{t} \frac{\partial h_j}{\partial h_{j-1}} \right) \frac{\partial h_k}{\partial W_h} \]
The critical term is the product \( \prod_{j=k+1}^{t} \frac{\partial h_j}{\partial h_{j-1}} \). Each factor in this product involves the Jacobian of the recurrent transition function. When this product spans many timesteps (large \( t - k \)), its magnitude can either explode or vanish, depending on the spectral properties of \( W_h \).
For the vanilla RNN with \( h_t = \tanh(W_h h_{t-1} + W_x x_t + b) \), each Jacobian factor is:

\[ \frac{\partial h_t}{\partial h_{t-1}} = \text{diag}(1 - h_t^2)\, W_h \]
Where \( \text{diag}(1 - h_t^2) \) is a diagonal matrix of \( \tanh \) derivatives (values between 0 and 1). The product of \( (t - k) \) such Jacobians behaves roughly as:

\[ \left\| \prod_{j=k+1}^{t} \frac{\partial h_j}{\partial h_{j-1}} \right\| \approx \left( \gamma\, \sigma_{\max}(W_h) \right)^{t-k} \]
Where \( \gamma < 1 \) accounts for the \( \tanh \) derivative suppression. If the largest singular value of \( W_h \) multiplied by \( \gamma \) is less than 1, the gradient vanishes exponentially. If it exceeds 1, the gradient explodes. This is the fundamental instability of vanilla RNN training.
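The instability can be demonstrated numerically. The sketch below scales a random \( W_h \) to a chosen spectral radius and multiplies 50 Jacobian factors, using random tanh activations as a stand-in for a real forward pass (the sizes and scaling are illustrative):

```python
import numpy as np

def grad_norm_through_time(spectral_radius, T=50, d_h=16, seed=0):
    """Norm of the product of T Jacobians diag(1 - h^2) @ W_h,
    with random tanh activations as surrogate hidden states."""
    rng = np.random.default_rng(seed)
    W_h = rng.standard_normal((d_h, d_h))
    # Rescale W_h so its largest eigenvalue magnitude equals spectral_radius
    W_h *= spectral_radius / np.max(np.abs(np.linalg.eigvals(W_h)))
    J = np.eye(d_h)
    for _ in range(T):
        h = np.tanh(rng.standard_normal(d_h))   # surrogate hidden state
        J = np.diag(1 - h**2) @ W_h @ J
    return np.linalg.norm(J)

print(grad_norm_through_time(0.5))  # vanishes: astronomically small
print(grad_norm_through_time(5.0))  # explodes: astronomically large
```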
In practice, computing gradients across the entire sequence is both memory-intensive and numerically unstable. Truncated BPTT limits the backward pass to a fixed window of \( \tau \) timesteps, computing gradients only through the most recent \( \tau \) steps. This is an approximation that trades long-range gradient accuracy for computational efficiency.
Process the full sequence forward, computing \( h_t \) and \( L_t \) at each step. The hidden state carries information from the entire history.
At each timestep \( t \), backpropagate the gradient through only the previous \( \tau \) steps: \( h_t, h_{t-1}, \ldots, h_{t-\tau} \). Stop the gradient at \( h_{t-\tau} \).
Sum the truncated gradients over all timesteps and perform a single parameter update. This is the standard approach in frameworks like PyTorch.
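The steps above can be sketched in PyTorch using the standard hidden-state detaching pattern. The model, sizes, and toy data are illustrative, and this common variant performs one parameter update per window rather than a single update over the whole sequence:

```python
import torch
import torch.nn as nn

# Toy next-step prediction task on a sine wave (illustrative sizes)
rnn = nn.RNN(input_size=1, hidden_size=16, batch_first=True)
head = nn.Linear(16, 1)
opt = torch.optim.Adam(list(rnn.parameters()) + list(head.parameters()),
                       lr=1e-2)

seq = torch.sin(torch.linspace(0, 20, 200)).view(1, 200, 1)
tau = 25  # truncation window

h = torch.zeros(1, 1, 16)  # (num_layers, batch, hidden)
for start in range(0, 200 - tau, tau):
    chunk = seq[:, start:start + tau]
    target = seq[:, start + 1:start + tau + 1]  # next-step targets
    h = h.detach()  # stop gradients at the window boundary
    out, h = rnn(chunk, h)
    loss = nn.functional.mse_loss(head(out), target)
    opt.zero_grad()
    loss.backward()
    opt.step()
```

The single `h.detach()` call is what truncates the backward pass: the forward hidden state still carries the full history, but gradients cannot flow past the window boundary.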
The fundamental challenge of training recurrent networks -- and why it led to the invention of LSTM.
The gradient of the loss at time \( t \) with respect to the hidden state at an earlier time \( k \) involves a product of Jacobians spanning \( (t-k) \) timesteps. For a simplified scalar RNN \( h_t = w \cdot h_{t-1} \), this product reduces to:

\[ \frac{\partial h_t}{\partial h_k} = w^{t-k} \]
Three regimes emerge: when \( |w| < 1 \), the gradient vanishes exponentially; when \( |w| > 1 \), it explodes exponentially; only when \( |w| \approx 1 \) does it remain stable across many timesteps.
Vanishing gradients mean the network is effectively blind to long-range dependencies. If a word at position 5 is crucial for predicting the output at position 50, the gradient signal connecting them is astronomically small. The weights responsible for maintaining that memory receive essentially zero updates, making it impossible for the network to learn the dependency.
Exploding gradients are easier to detect (the loss suddenly becomes NaN) and easier to fix. The standard remedy is gradient clipping:

\[ g \leftarrow \frac{\theta}{\lVert g \rVert}\, g \quad \text{if } \lVert g \rVert > \theta \]
Where \( \theta \) is a threshold (typically 1.0 or 5.0). Gradient clipping rescales the gradient vector when its norm exceeds \( \theta \), preventing parameter updates from becoming destructively large while preserving the gradient direction.
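A minimal sketch of clipping by global norm, equivalent in spirit to PyTorch's `clip_grad_norm_`:

```python
import numpy as np

def clip_by_norm(g, theta=1.0):
    """Rescale g to norm theta when ||g|| > theta, preserving direction."""
    norm = np.linalg.norm(g)
    return g * (theta / norm) if norm > theta else g

g = np.array([3.0, 4.0])            # ||g|| = 5 exceeds the threshold
clipped = clip_by_norm(g, theta=1.0)
print(clipped)                       # [0.6 0.8] -- norm 1, same direction
```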
Watch how gradient magnitude changes across timesteps for different weight values. When \( w < 1 \), gradients vanish (blue). When \( w > 1 \), they explode (red). When \( w \approx 1 \), gradients remain stable (green).
Adjust the recurrent weight and sequence length. Bars show the gradient magnitude at each timestep relative to the final step.
The chart below compares gradient magnitude over 50 timesteps for four different weight values, illustrating how even small deviations from \( w = 1 \) cause dramatic exponential effects.
The gated memory cell that solved the vanishing gradient problem and dominated sequence modeling for two decades.
The central innovation of LSTM is the cell state \( C_t \) -- a dedicated memory channel that runs through the entire sequence like a conveyor belt. Information can flow along the cell state unchanged, with only minor linear interactions controlled by gates. This is why gradients can propagate through hundreds of timesteps without vanishing.
The cell state is updated at each timestep through a two-step process: first, selectively forget old information; then, selectively add new information. Both operations are controlled by learned sigmoid gates that output values between 0 and 1.
An LSTM cell contains four interacting components, each with its own weight matrices and bias:
1. Forget Gate -- decides what information to discard from the cell state:

\[ f_t = \sigma(W_f [h_{t-1}, x_t] + b_f) \]
2. Input Gate -- decides what new information to store in the cell state:

\[ i_t = \sigma(W_i [h_{t-1}, x_t] + b_i) \]
3. Cell State Update -- computes candidate values and updates the cell:

\[ \tilde{C}_t = \tanh(W_C [h_{t-1}, x_t] + b_C), \qquad C_t = f_t \odot C_{t-1} + i_t \odot \tilde{C}_t \]
4. Output Gate -- decides what to output based on the filtered cell state:

\[ o_t = \sigma(W_o [h_{t-1}, x_t] + b_o), \qquad h_t = o_t \odot \tanh(C_t) \]
Where \( \sigma \) is the sigmoid function, \( \odot \) denotes element-wise multiplication, and \( [\cdot, \cdot] \) denotes concatenation. Note that the cell state update \( C_t = f_t \odot C_{t-1} + i_t \odot \tilde{C}_t \) is a linear operation on \( C_{t-1} \), which is why gradients flow through it without vanishing.
An LSTM with hidden size \( d_h \) and input size \( d_x \) has \( 4 \times (d_h \times (d_h + d_x) + d_h) \) parameters -- four times that of a vanilla RNN, because it has four separate sets of weight matrices (one per gate plus the candidate cell).
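The four gate equations can be sketched directly in NumPy. The random, untrained weights are purely illustrative; each weight matrix acts on the concatenation \( [h_{t-1}, x_t] \):

```python
import numpy as np

def sigmoid(z):
    return 1.0 / (1.0 + np.exp(-z))

def lstm_step(x_t, h_prev, C_prev, params):
    """One LSTM step: forget, input, cell update, output."""
    z = np.concatenate([h_prev, x_t])
    f = sigmoid(params["W_f"] @ z + params["b_f"])        # forget gate
    i = sigmoid(params["W_i"] @ z + params["b_i"])        # input gate
    C_tilde = np.tanh(params["W_C"] @ z + params["b_C"])  # candidate
    C = f * C_prev + i * C_tilde                          # linear cell update
    o = sigmoid(params["W_o"] @ z + params["b_o"])        # output gate
    h = o * np.tanh(C)
    return h, C

rng = np.random.default_rng(0)
d_x, d_h = 3, 5  # illustrative sizes
params = {}
for g in "fiCo":  # four weight/bias pairs: forget, input, candidate, output
    params[f"W_{g}"] = rng.standard_normal((d_h, d_h + d_x)) * 0.1
    params[f"b_{g}"] = np.zeros(d_h)

h, C = np.zeros(d_h), np.zeros(d_h)
for x_t in rng.standard_normal((4, d_x)):
    h, C = lstm_step(x_t, h, C, params)
print(h.shape, C.shape)  # (5,) (5,)
```

Note that the cell update line is purely element-wise and linear in `C_prev`, which is the gradient highway the text describes.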
Explore how the forget and input gates control the LSTM cell state. The cell state highway runs along the top. The forget gate controls how much old memory to keep, while the input gate controls how much new information to add.
Adjust forget and input gate values with the sliders. Click Step Forward to advance one timestep and see how the cell state changes.
A streamlined gated architecture that merges the forget and input gates into a single update mechanism.
The Gated Recurrent Unit (GRU), introduced by Cho et al. in 2014, simplifies the LSTM by combining the forget and input gates into a single update gate and merging the cell state and hidden state into one. This reduces the number of parameters while maintaining most of the LSTM's ability to capture long-range dependencies.
Reset Gate -- decides how much past information to forget:

\[ r_t = \sigma(W_r [h_{t-1}, x_t] + b_r) \]
Update Gate -- decides how much of the new candidate to use vs keeping old state:

\[ z_t = \sigma(W_z [h_{t-1}, x_t] + b_z) \]
Candidate Hidden State:

\[ \tilde{h}_t = \tanh(W [r_t \odot h_{t-1}, x_t] + b) \]
Final Hidden State -- interpolates between old and new:

\[ h_t = (1 - z_t) \odot h_{t-1} + z_t \odot \tilde{h}_t \]
When \( z_t \approx 0 \), the hidden state is copied forward unchanged (memory preservation). When \( z_t \approx 1 \), the hidden state is completely replaced by the candidate. This linear interpolation provides a gradient highway similar to the LSTM cell state.
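The GRU step can be sketched in NumPy the same way (random, untrained weights for illustration; note the reset gate is applied to \( h_{t-1} \) before the candidate is computed):

```python
import numpy as np

def sigmoid(z):
    return 1.0 / (1.0 + np.exp(-z))

def gru_step(x_t, h_prev, p):
    """One GRU step: reset, update, candidate, then interpolation."""
    z_in = np.concatenate([h_prev, x_t])
    r = sigmoid(p["W_r"] @ z_in + p["b_r"])           # reset gate
    z = sigmoid(p["W_z"] @ z_in + p["b_z"])           # update gate
    cand_in = np.concatenate([r * h_prev, x_t])
    h_tilde = np.tanh(p["W_h"] @ cand_in + p["b_h"])  # candidate state
    return (1 - z) * h_prev + z * h_tilde             # interpolate old/new

rng = np.random.default_rng(1)
d_x, d_h = 3, 5  # illustrative sizes
p = {f"W_{g}": rng.standard_normal((d_h, d_h + d_x)) * 0.1 for g in "rzh"}
p.update({f"b_{g}": np.zeros(d_h) for g in "rzh"})

h = np.zeros(d_h)
for x_t in rng.standard_normal((4, d_x)):
    h = gru_step(x_t, h, p)
print(h.shape)  # (5,)
```

Only three weight/bias pairs appear, versus the LSTM's four, which is the source of the parameter saving discussed below.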
GRU has 3 gate matrices vs LSTM's 4, making it roughly 25% smaller. For hidden size 512 and input size 512, the recurrent layer has roughly 2.1M parameters for LSTM vs 1.6M for GRU.
GRU trains faster per epoch due to fewer matrix multiplications. On GPU, the difference is 10-15%. On CPU, it can be 20-30% faster.
Neither consistently outperforms the other. LSTM tends to be better on very long sequences (>500 steps) where the separate cell state provides an advantage. GRU often matches LSTM on shorter sequences.
Start with LSTM as the default. Try GRU if you need faster training or have limited compute. The performance difference is usually small enough that engineering constraints matter more.
Processing sequences in both directions to capture context from both the past and the future.
A standard (unidirectional) RNN processes the sequence from left to right, so the hidden state at time \( t \) only contains information from past inputs \( (x_1, \ldots, x_t) \). But in many tasks, context from future inputs is equally important. For example, in the sentence "He said the bank was steep," the meaning of "bank" depends on the word "steep" that comes after it.
A Bidirectional RNN (BiRNN) solves this by running two separate RNNs in parallel: one processing the sequence left-to-right (forward), and another processing it right-to-left (backward). The outputs of both are then combined at each timestep.
The forward hidden state captures past context:

\[ \overrightarrow{h}_t = f\big(x_t, \overrightarrow{h}_{t-1}\big) \]
The backward hidden state captures future context:

\[ \overleftarrow{h}_t = f\big(x_t, \overleftarrow{h}_{t+1}\big) \]
The final representation at each timestep concatenates both directions:

\[ h_t = \big[\overrightarrow{h}_t \,;\, \overleftarrow{h}_t\big] \]
This doubles the representation size at each timestep, providing the network with complete bidirectional context. Bidirectional LSTMs (BiLSTMs) apply this same principle using LSTM cells instead of vanilla RNN cells, and they are the standard architecture for tasks like named entity recognition and machine reading comprehension.
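In PyTorch, passing `bidirectional=True` to `nn.LSTM` performs this concatenation automatically; a quick shape check with illustrative sizes:

```python
import torch
import torch.nn as nn

# Bidirectional LSTM doubles the per-timestep representation size
bilstm = nn.LSTM(input_size=32, hidden_size=64,
                 batch_first=True, bidirectional=True)
x = torch.randn(8, 20, 32)       # batch=8, seq=20, features=32
out, (h_n, c_n) = bilstm(x)

print(out.shape)   # (8, 20, 128): forward and backward states concatenated
print(h_n.shape)   # (2, 8, 64): one final state per direction
```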
Bidirectional RNNs require the entire sequence to be available before processing. This makes them unsuitable for real-time applications like speech recognition during live conversation or online time-series forecasting where future data is not yet available. Use unidirectional RNNs for autoregressive and streaming tasks.
The key knobs that control RNN and LSTM behavior, and practical advice for tuning them.
The hidden size \( d_h \) determines the capacity of the recurrent layer -- how much information the hidden state can encode. Larger hidden sizes can model more complex patterns but increase computation quadratically (since \( W_h \in \mathbb{R}^{d_h \times d_h} \)).
Stacking multiple RNN layers creates a deep RNN where the hidden state of layer \( l \) becomes the input to layer \( l+1 \):

\[ h_t^{(l)} = \tanh\big(W_h^{(l)} h_{t-1}^{(l)} + W_x^{(l)} h_t^{(l-1)} + b^{(l)}\big) \]
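A quick shape check of stacking with PyTorch's `num_layers` argument (illustrative sizes):

```python
import torch
import torch.nn as nn

# Stacked (deep) LSTM: layer l's hidden states feed layer l+1
deep_lstm = nn.LSTM(input_size=32, hidden_size=64,
                    num_layers=3, batch_first=True)
x = torch.randn(4, 10, 32)       # batch=4, seq=10, features=32
out, (h_n, c_n) = deep_lstm(x)

print(out.shape)   # (4, 10, 64): top layer's hidden states only
print(h_n.shape)   # (3, 4, 64): final hidden state of each layer
```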
Essential for stable RNN training. Clips the global gradient norm before the optimizer step:
torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)

Applying dropout to recurrent connections is tricky because naive dropout at every timestep destroys the hidden state signal. The standard approach is variational dropout (Gal & Ghahramani, 2016): use the same dropout mask across all timesteps within a sequence, but different masks across sequences in a batch.
Applied to the input \( x_t \) at each layer. Standard values: 0.2-0.5. Use the same mask across timesteps.
Applied to the hidden state \( h_{t-1} \) in the recurrent connection. Values: 0.1-0.3. Must use the same mask across timesteps.
Applied between stacked layers. PyTorch's nn.LSTM(dropout=0.3) applies this automatically between layers (not after the last layer).
Real-world domains where RNNs and LSTMs deliver powerful results on sequential data.
Language modeling, machine translation, text generation, sentiment analysis, and named entity recognition. LSTMs capture grammatical dependencies across sentences.
Converting audio waveforms to text. BiLSTMs process mel spectrograms frame by frame, and CTC loss enables alignment-free training.
Stock prices, weather prediction, energy demand, and sensor data. LSTMs learn temporal patterns and seasonality from historical sequences.
Composing melodies note by note. RNNs learn musical structure, harmony, and rhythm from MIDI sequences to generate novel compositions.
Type a short string and see a simple character-level frequency-based prediction for the next character. The canvas visualizes characters flowing through RNN cells.
Enter text below and click Predict to see next-character probability bars and the sequence flow visualization.
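A frequency-based next-character predictor like the demo's can be sketched with bigram counts. The toy corpus and function names here are hypothetical:

```python
from collections import Counter, defaultdict

def build_bigram_counts(corpus):
    """Count which character follows each character in the corpus."""
    following = defaultdict(Counter)
    for prev, nxt in zip(corpus, corpus[1:]):
        following[prev][nxt] += 1
    return following

def predict_next(following, ch, k=3):
    """Top-k next-character candidates with relative frequencies."""
    counts = following.get(ch)
    if not counts:
        return []
    total = sum(counts.values())
    return [(c, n / total) for c, n in counts.most_common(k)]

corpus = "the cat sat on the mat and the rat ran"
following = build_bigram_counts(corpus)
print(predict_next(following, "t"))  # ' ' and 'h' follow 't' most often here
```

This is a pure frequency model with no learned state, which is why the page contrasts it with the recurrent models above.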
The chart below compares Vanilla RNN, LSTM, GRU, and BiLSTM across four common sequence tasks, showing accuracy or F1 score.
RNNs naturally handle sequences of any length with the same model, unlike CNNs which require fixed input sizes or padding.
Weight sharing across timesteps means the model size is independent of sequence length, enabling processing of very long sequences.
Unidirectional RNNs can process streaming data one element at a time, making them ideal for real-time applications.
The recurrent structure encodes the assumption that recent context matters more, which is correct for most sequential tasks.
RNNs process tokens one at a time, preventing parallelization across timesteps. This makes them much slower to train than Transformers on GPUs.
Even LSTMs struggle with dependencies spanning thousands of tokens. Transformers with attention handle this better.
Stacking many layers requires careful initialization, gradient clipping, and residual connections. Training deep RNNs is harder than deep Transformers.
For most NLP tasks, Transformers now achieve better accuracy. RNNs remain competitive mainly for streaming and low-latency applications.
From basic LSTM classification to complete sequence-to-sequence pipelines in PyTorch and Keras.
A complete LSTM model for classifying text sequences using PyTorch.
import torch
import torch.nn as nn
class LSTMClassifier(nn.Module):
    def __init__(self, vocab_size, embed_dim, hidden_dim,
                 output_dim, n_layers=2, dropout=0.3):
        super().__init__()
        self.embedding = nn.Embedding(vocab_size, embed_dim)
        self.lstm = nn.LSTM(
            embed_dim, hidden_dim, num_layers=n_layers,
            batch_first=True, dropout=dropout,
            bidirectional=True
        )
        self.fc = nn.Linear(hidden_dim * 2, output_dim)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x):
        embedded = self.dropout(self.embedding(x))
        output, (hidden, cell) = self.lstm(embedded)
        # Concatenate final forward and backward hidden states
        hidden = torch.cat((hidden[-2], hidden[-1]), dim=1)
        return self.fc(self.dropout(hidden))

# Usage
model = LSTMClassifier(
    vocab_size=10000, embed_dim=128,
    hidden_dim=256, output_dim=2
)
criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)

# Training loop
for epoch in range(10):
    model.train()
    for batch_x, batch_y in train_loader:
        optimizer.zero_grad()
        predictions = model(batch_x)
        loss = criterion(predictions, batch_y)
        loss.backward()
        torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
        optimizer.step()
Using Keras to build a stacked LSTM for predicting the next value in a time series.
import numpy as np
from tensorflow.keras.models import Sequential
from tensorflow.keras.layers import LSTM, Dense, Dropout

# Generate synthetic sine wave data
t = np.linspace(0, 100, 2000)
data = np.sin(t) + 0.1 * np.random.randn(len(t))

# Create sequences of length 50
seq_len = 50
X, y = [], []
for i in range(len(data) - seq_len):
    X.append(data[i:i+seq_len])
    y.append(data[i+seq_len])
X = np.array(X).reshape(-1, seq_len, 1)
y = np.array(y)

# Split data
split = int(0.8 * len(X))
X_train, X_test = X[:split], X[split:]
y_train, y_test = y[:split], y[split:]

# Build stacked LSTM model
model = Sequential([
    LSTM(64, return_sequences=True,
         input_shape=(seq_len, 1)),
    Dropout(0.2),
    LSTM(32, return_sequences=False),
    Dropout(0.2),
    Dense(1)
])
model.compile(optimizer='adam', loss='mse')
model.summary()

# Train
history = model.fit(
    X_train, y_train, epochs=20, batch_size=32,
    validation_data=(X_test, y_test)
)

# Evaluate
mse = model.evaluate(X_test, y_test)
print(f"Test MSE: {mse:.6f}")
A minimal character-level GRU language model trained to predict the next character.

import torch
import torch.nn as nn

class CharGRU(nn.Module):
    def __init__(self, vocab_size, embed_dim=64,
                 hidden_dim=128, n_layers=1):
        super().__init__()
        self.embedding = nn.Embedding(vocab_size, embed_dim)
        self.gru = nn.GRU(
            embed_dim, hidden_dim, n_layers,
            batch_first=True
        )
        self.fc = nn.Linear(hidden_dim, vocab_size)

    def forward(self, x, hidden=None):
        emb = self.embedding(x)
        out, hidden = self.gru(emb, hidden)
        logits = self.fc(out)
        return logits, hidden

# Build vocabulary from text
text = "hello world this is a character level model"
chars = sorted(set(text))
char2idx = {c: i for i, c in enumerate(chars)}
idx2char = {i: c for c, i in char2idx.items()}
vocab_size = len(chars)

# Encode text: inputs are all chars but the last, targets are shifted by one
encoded = [char2idx[c] for c in text]
X = torch.tensor(encoded[:-1]).unsqueeze(0)
y = torch.tensor(encoded[1:]).unsqueeze(0)

# Train
model = CharGRU(vocab_size)
optimizer = torch.optim.Adam(model.parameters(), lr=0.01)
criterion = nn.CrossEntropyLoss()

for epoch in range(200):
    logits, _ = model(X)
    loss = criterion(logits.view(-1, vocab_size), y.view(-1))
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
    if (epoch + 1) % 50 == 0:
        print(f"Epoch {epoch+1}, Loss: {loss.item():.4f}")
A bidirectional LSTM that emits a tag for every token, as used in named entity recognition.

import torch
import torch.nn as nn

class BiLSTM_NER(nn.Module):
    def __init__(self, vocab_size, tagset_size,
                 embed_dim=100, hidden_dim=128):
        super().__init__()
        self.embedding = nn.Embedding(vocab_size, embed_dim)
        self.lstm = nn.LSTM(
            embed_dim, hidden_dim, num_layers=2,
            batch_first=True, bidirectional=True,
            dropout=0.3
        )
        self.fc = nn.Linear(hidden_dim * 2, tagset_size)

    def forward(self, x):
        emb = self.embedding(x)
        lstm_out, _ = self.lstm(emb)
        logits = self.fc(lstm_out)
        return logits

# Example: predict a tag for every token
model = BiLSTM_NER(vocab_size=5000, tagset_size=9)
sample_input = torch.randint(0, 5000, (4, 30))  # batch=4, seq=30
output = model(sample_input)
print(f"Output shape: {output.shape}")  # (4, 30, 9)
Regardless of the architecture (LSTM, GRU, or vanilla), always use gradient clipping during training. Add torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0) before optimizer.step(). This single line prevents exploding gradients and makes training dramatically more stable.