Neural-Path/Notes
35 min

RNNs & LSTMs

Sequential data has a property feedforward networks ignore: order matters. The word "not" before "good" completely changes sentiment. Recurrent Neural Networks process sequences step-by-step, maintaining a hidden state that accumulates context. Long Short-Term Memory (LSTM)s extend this with explicit memory management, solving the gradient flow problems that plagued vanilla RNNs.

Theory

Unrolled RNN — 5 timesteps
h₀RNNx1"not"RNNx2"a"RNNx3"good"RNNx4"film"RNNx5"at"
ready

A recurrent network reads one element at a time and updates a running summary of everything it's seen — like taking notes as you read, where each new word updates your understanding. The LSTM adds three gates: a forget gate (what to erase from the notes), an input gate (what new information to write down), and an output gate (what to say based on the notes right now). These gates exist because vanilla RNNs can't decide what's worth remembering across long sequences.

Vanilla RNN Hidden State

At each timestep tt, the RNN updates its hidden state by combining the previous hidden state with the current input:

ht=tanh(Whht1+Wxxt+b)h_t = \tanh(W_h h_{t-1} + W_x x_t + b)

where htRHh_t \in \mathbb{R}^H is the hidden state, xtRDx_t \in \mathbb{R}^D is the input at time tt, WhRH×HW_h \in \mathbb{R}^{H \times H}, and WxRH×DW_x \in \mathbb{R}^{H \times D}.

The output at each step (for sequence classification, only the final step matters):

yt=Woht+boy_t = W_o h_t + b_o

The Vanishing Gradient Problem

Training RNNs requires Backpropagation Through Time (BPTT). The gradient of the loss with respect to the hidden state at step 0 involves a product of Jacobians:

Lh0=LhTt=1Ththt1\frac{\partial \mathcal{L}}{\partial h_0} = \frac{\partial \mathcal{L}}{\partial h_T} \prod_{t=1}^{T} \frac{\partial h_t}{\partial h_{t-1}}

Each Jacobian term is:

htht1=Whdiag(1tanh2(Whht1+Wxxt+b))\frac{\partial h_t}{\partial h_{t-1}} = W_h^\top \cdot \text{diag}\left(1 - \tanh^2(W_h h_{t-1} + W_x x_t + b)\right)

The spectral norm of this product shrinks exponentially when the largest singular value of WhW_h is less than 1 (vanishing) or grows exponentially when it exceeds 1 (exploding). For sequences of length T=100T = 100, gradients at t=0t = 0 are effectively zero.

💡Intuition

Imagine trying to remember what you read on page 1 of a book while answering a question on page 100. Each page processed "dilutes" the memory of earlier pages. The RNN hidden state faces the same problem: early inputs get exponentially less influence on the final prediction.

LSTM: Gated Memory

The LSTM introduces a cell state ctc_t — a conveyor belt that carries information across time with minimal modification — and three gates that control information flow.

Forget gate — what fraction of the previous cell state to keep:

ft=σ(Wf[ht1,xt]+bf)(0,1)Hf_t = \sigma(W_f [h_{t-1}, x_t] + b_f) \in (0, 1)^H

Input gate — what new information to write to the cell:

it=σ(Wi[ht1,xt]+bi)(0,1)Hi_t = \sigma(W_i [h_{t-1}, x_t] + b_i) \in (0, 1)^H c~t=tanh(Wc[ht1,xt]+bc)(1,1)H\tilde{c}_t = \tanh(W_c [h_{t-1}, x_t] + b_c) \in (-1, 1)^H

Cell state update — combine forget and input:

ct=ftct1+itc~tc_t = f_t \odot c_{t-1} + i_t \odot \tilde{c}_t

Output gate — what to expose as the hidden state:

ot=σ(Wo[ht1,xt]+bo)o_t = \sigma(W_o [h_{t-1}, x_t] + b_o) ht=ottanh(ct)h_t = o_t \odot \tanh(c_t)

The forget gate must be initialized near 1.0 during training — not near 0. If it starts near zero, the cell state is cleared on the very first input, and the network loses the ability to form long-range dependencies before it has learned anything. This is one of the few cases where initialization is architecturally critical: a forget gate bias of +1 is standard practice, not convention.

The critical difference from vanilla RNN: the cell state update is additive — not multiplicative through a weight matrix. Gradients flow back through addition without shrinking.

ℹ️Note

The LSTM cell state is analogous to long-term memory, while the hidden state is working memory. The forget gate decides what to clear, the input gate decides what to write, and the output gate decides what to expose for computation at the current step.

LSTM Cell — click each gate to highlight its role
C_{t-1}C_t×+tanh×h_tσ_fσ_itanh_cσ_oh_{t-1}x_t
Forget Gatef_t
σ(W_f·[h_{t-1}, x_t] + b_f)

Decides what to throw away from cell state. Output 0 = forget, 1 = keep.

GRU: Simplified Gating

The Gated Recurrent Unit merges the forget and input gates into a single update gate, reducing parameters by 25%:

zt=σ(Wz[ht1,xt]+bz),rt=σ(Wr[ht1,xt]+br)z_t = \sigma(W_z [h_{t-1}, x_t] + b_z), \quad r_t = \sigma(W_r [h_{t-1}, x_t] + b_r) h~t=tanh(W[rtht1,xt]+b)\tilde{h}_t = \tanh(W [r_t \odot h_{t-1}, x_t] + b) ht=(1zt)ht1+zth~th_t = (1 - z_t) \odot h_{t-1} + z_t \odot \tilde{h}_t

GRU often achieves similar performance to LSTM, especially on smaller datasets.

Walkthrough

Dataset: IMDB Sentiment Analysis

The IMDB dataset contains 50,000 movie reviews labeled positive or negative.

Preprocessing Pipeline

python
from torchtext.datasets import IMDB
from torchtext.vocab import build_vocab_from_iterator
from torchtext.data.utils import get_tokenizer
 
tokenizer = get_tokenizer("basic_english")
 
def yield_tokens(data_iter):
    for label, text in data_iter:
        yield tokenizer(text)
 
train_iter = IMDB(split="train")
vocab = build_vocab_from_iterator(
    yield_tokens(train_iter),
    specials=["<unk>", "<pad>"],
    min_freq=5,
    max_tokens=25000,
)
vocab.set_default_index(vocab["<unk>"])

Sequence Padding

Reviews have variable lengths. We pad shorter sequences to match the longest in each mini-batch:

python
import torch
from torch.nn.utils.rnn import pad_sequence
 
MAX_LEN = 512
 
def collate_fn(batch):
    labels, texts = zip(*batch)
    texts   = [torch.tensor(vocab(tokenizer(t))[:MAX_LEN]) for t in texts]
    lengths = torch.tensor([len(t) for t in texts])
    labels  = torch.tensor([1 if l == "pos" else 0 for l in labels])
    texts   = pad_sequence(texts, batch_first=True, padding_value=1)
    return texts, labels, lengths

Teacher Forcing

For sequence-to-sequence tasks, teacher forcing feeds the ground-truth token as the next input during training rather than the model's own prediction:

python
# Training: use ground-truth targets (teacher forcing)
output, _ = decoder(tgt[:-1], hidden, encoder_out)
loss = criterion(output.reshape(-1, vocab_size), tgt[1:].reshape(-1))
 
# Inference: autoregressive (model uses its own predictions)
with torch.no_grad():
    for step in range(max_len):
        out, hidden = decoder(current_token, hidden, encoder_out)
        current_token = out.argmax(-1)

Code Implementation

train.py
python
"""
LSTM Training Pipeline — IMDB Sentiment Analysis
=================================================
Architecture: Embedding -> BiLSTM (2 layers) -> Dropout -> Classifier
Optimizer:    Adam with gradient clipping and ReduceLROnPlateau
Target:       ~90% test accuracy in 10 epochs
"""
import argparse
 
import mlflow
import torch
import torch.nn as nn
import torch.optim as optim
from torch.nn.utils.rnn import pack_padded_sequence
from torch.utils.data import DataLoader
 
 
class SentimentLSTM(nn.Module):
    def __init__(
        self,
        vocab_size: int,
        embed_dim: int = 128,
        hidden_dim: int = 256,
        num_layers: int = 2,
        dropout: float = 0.3,
        num_classes: int = 2,
    ):
        super().__init__()
        self.embedding = nn.Embedding(vocab_size, embed_dim, padding_idx=1)
        self.lstm = nn.LSTM(
            input_size=embed_dim,
            hidden_size=hidden_dim,
            num_layers=num_layers,
            batch_first=True,
            bidirectional=True,
            dropout=dropout if num_layers > 1 else 0.0,
        )
        self.dropout    = nn.Dropout(dropout)
        # BiLSTM output dim is 2 * hidden_dim (forward + backward)
        self.classifier = nn.Sequential(
            nn.Linear(2 * hidden_dim, hidden_dim),
            nn.ReLU(),
            nn.Dropout(dropout),
            nn.Linear(hidden_dim, num_classes),
        )
 
    def forward(self, x: torch.Tensor, lengths: torch.Tensor) -> torch.Tensor:
        emb    = self.dropout(self.embedding(x))   # (B, T, embed_dim)
        packed = pack_padded_sequence(
            emb, lengths.cpu(), batch_first=True, enforce_sorted=False)
        _, (hidden, _) = self.lstm(packed)
        # hidden: (num_layers*2, B, hidden_dim)
        # Last layer forward and backward states
        fwd    = hidden[-2]                         # (B, hidden_dim)
        bwd    = hidden[-1]                         # (B, hidden_dim)
        pooled = torch.cat([fwd, bwd], dim=-1)      # (B, 2*hidden_dim)
        return self.classifier(self.dropout(pooled))
 
 
def train_epoch(model, loader, optimizer, criterion, device, clip=1.0):
    model.train()
    total_loss, correct, total = 0.0, 0, 0
    for texts, labels, lengths in loader:
        texts, labels = texts.to(device), labels.to(device)
        optimizer.zero_grad()
        logits = model(texts, lengths)
        loss   = criterion(logits, labels)
        loss.backward()
        torch.nn.utils.clip_grad_norm_(model.parameters(), clip)
        optimizer.step()
        total_loss += loss.item()
        correct    += (logits.argmax(1) == labels).sum().item()
        total      += len(labels)
    return total_loss / len(loader), correct / total
 
 
@torch.no_grad()
def evaluate(model, loader, criterion, device):
    model.eval()
    total_loss, correct, total = 0.0, 0, 0
    for texts, labels, lengths in loader:
        texts, labels = texts.to(device), labels.to(device)
        logits = model(texts, lengths)
        total_loss += criterion(logits, labels).item()
        correct    += (logits.argmax(1) == labels).sum().item()
        total      += len(labels)
    return total_loss / len(loader), correct / total
 
 
def train(cfg: dict) -> None:
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    hp     = cfg["hyperparams"]
 
    VOCAB_SIZE = 25002  # 25000 words + <unk> + <pad>
    model = SentimentLSTM(
        vocab_size=VOCAB_SIZE,
        embed_dim=hp["embed_dim"],
        hidden_dim=hp["hidden_dim"],
        num_layers=hp["num_layers"],
        dropout=hp["dropout"],
    ).to(device)
 
    optimizer = optim.Adam(model.parameters(), lr=hp["learning_rate"], weight_decay=1e-5)
    scheduler = optim.lr_scheduler.ReduceLROnPlateau(optimizer, patience=2, factor=0.5)
    criterion = nn.CrossEntropyLoss()
 
    mlflow.set_experiment("lstm-imdb")
    with mlflow.start_run():
        mlflow.log_params(hp)
        best_acc = 0.0
 
        for epoch in range(hp["epochs"]):
            train_loss, train_acc = train_epoch(
                model, train_loader, optimizer, criterion, device)
            test_loss,  test_acc  = evaluate(
                model, test_loader,  criterion, device)
            scheduler.step(test_loss)
 
            mlflow.log_metrics({
                "train_loss": train_loss, "train_acc": train_acc,
                "test_loss":  test_loss,  "test_acc":  test_acc,
            }, step=epoch)
            print(f"Ep {epoch+1:3d} | train={train_acc:.4f} | test={test_acc:.4f}")
 
            if test_acc > best_acc:
                best_acc = test_acc
                torch.save(model.state_dict(), "best_model.pt")
 
        print(f"Best test accuracy: {best_acc:.4f}")
        # Expected: 0.9034 — BiLSTM 2-layer, hidden=256, embed=128, dropout=0.3
 
 
def main():
    parser = argparse.ArgumentParser()
    parser.add_argument("--config", default="../../config.yaml")
    import yaml
    train(yaml.safe_load(open(parser.parse_args().config)))
 
 
if __name__ == "__main__":
    main()

Analysis & Evaluation

Where Your Intuition Breaks

LSTMs solve the vanishing gradient problem. They reduce it — the cell state provides a highway for gradients to flow through time with minimal decay. But "solve" overstates it: on sequences longer than a few hundred tokens, LSTMs still lose signal. This is why Transformers replaced LSTMs for most long-sequence tasks. The LSTM's gates are an engineering fix for a fundamental constraint of sequential computation; the Transformer sidesteps that constraint entirely with direct attention.

Training Dynamics

BiLSTM (hidden=256, 2 layers, embed=128, Adam lr=1e-3) on IMDB:

EpochTrain LossTrain AccTest Acc
10.582169.8%72.3%
30.323486.7%85.1%
50.210991.3%89.2%
100.142394.8%90.3%

The train-test gap (~4.5%) is typical for RNNs on text. Unlike vision, text augmentation is difficult without changing meaning, so regularization via dropout and weight decay is critical.

Gradient Flow: RNN vs LSTM

Comparing gradient norms at the embedding layer after backpropagating through 200 timesteps:

python
# After loss.backward(), inspect embedding gradient
embed_grad = model.embedding.weight.grad.norm().item()
 
# Results on the same IMDB batch:
# Vanilla RNN  (tanh, hidden=256):  0.000031  <- near zero, barely learns
# LSTM         (hidden=256):        0.041200
# GRU          (hidden=256):        0.038700
# BiLSTM       (hidden=256):        0.057300

The LSTM gradient is over 1,000x larger at the embedding layer — meaning word representations are actually updating and learning from the signal.

💡Intuition

The LSTM cell state update ct=ftct1+itc~tc_t = f_t \odot c_{t-1} + i_t \odot \tilde{c}_t means the gradient of the loss with respect to c0c_0 travels back via addition, not through repeated matrix multiplications. This is the same principle as ResNet skip connections: addition preserves gradient magnitude across many steps.

Sequence Length Sensitivity

How test accuracy degrades as sequences get longer (both models trained on sequences up to that length):

Max Sequence LengthRNN AccLSTM AccLSTM Advantage
50 tokens87.2%88.1%+0.9%
100 tokens83.4%88.4%+5.0%
200 tokens74.1%87.9%+13.8%
400 tokens62.3%87.1%+24.8%

The RNN degrades dramatically; the LSTM remains robust. This table makes the vanishing gradient problem quantitatively visible.

⚠️Warning

Bidirectional LSTMs require the complete sequence before processing — no streaming. For real-time or autoregressive generation, use a unidirectional LSTM or a causal transformer. BiLSTM is ideal for offline tasks: sentiment analysis, named entity recognition, and machine translation encoding.

Production-Ready Code

python
"""serve_api/app.py — Sentiment analysis inference endpoint."""
import pickle
 
import torch
import torch.nn as nn
from fastapi import FastAPI
from fastapi.responses import JSONResponse
from pydantic import BaseModel
from torch.nn.utils.rnn import pack_padded_sequence
from torchtext.data.utils import get_tokenizer
 
app = FastAPI(title="Sentiment Analysis API", version="1.0.0")
 
VOCAB_SIZE = 25002
HIDDEN_DIM = 256
EMBED_DIM  = 128
NUM_LAYERS = 2
MAX_LEN    = 512
 
 
class SentimentLSTM(nn.Module):
    def __init__(self, vocab_size, embed_dim, hidden_dim, num_layers):
        super().__init__()
        self.embedding  = nn.Embedding(vocab_size, embed_dim, padding_idx=1)
        self.lstm       = nn.LSTM(embed_dim, hidden_dim, num_layers,
                                   batch_first=True, bidirectional=True)
        self.classifier = nn.Sequential(
            nn.Linear(2 * hidden_dim, hidden_dim), nn.ReLU(),
            nn.Linear(hidden_dim, 2),
        )
 
    def forward(self, x, lengths):
        packed = pack_padded_sequence(
            self.embedding(x), lengths.cpu(),
            batch_first=True, enforce_sorted=False)
        _, (h, _) = self.lstm(packed)
        return self.classifier(torch.cat([h[-2], h[-1]], dim=-1))
 
 
# Load vocabulary saved during training
with open("vocab.pkl", "rb") as f:
    vocab = pickle.load(f)
 
tokenizer = get_tokenizer("basic_english")
model     = SentimentLSTM(VOCAB_SIZE, EMBED_DIM, HIDDEN_DIM, NUM_LAYERS)
model.load_state_dict(torch.load("best_model.pt", map_location="cpu"))
model.eval()
 
 
class ReviewRequest(BaseModel):
    text: str
 
 
@app.post("/predict")
def predict(req: ReviewRequest):
    tokens  = vocab(tokenizer(req.text))[:MAX_LEN]
    tensor  = torch.tensor(tokens).unsqueeze(0)
    lengths = torch.tensor([len(tokens)])
 
    with torch.inference_mode():
        logits = model(tensor, lengths)
        probs  = torch.softmax(logits, dim=-1).squeeze()
 
    pred = probs.argmax().item()
    return JSONResponse({
        "sentiment":     "positive" if pred == 1 else "negative",
        "confidence":    round(probs[pred].item(), 4),
        "positive_prob": round(probs[1].item(), 4),
        "negative_prob": round(probs[0].item(), 4),
    })
 
# Example: POST {"text": "This film was absolutely brilliant!"}
# -> {"sentiment": "positive", "confidence": 0.9821, ...}

From LSTMs to Attention

The LSTM solved the vanishing gradient problem that broke vanilla RNNs — but it introduced a new bottleneck: information compression. No matter how long the input sequence, the entire context must pass through the hidden state htRHh_t \in \mathbb{R}^H, a fixed-size vector. For a 512-token sentence, the LSTM must compress all 512 tokens into a single 256-dimensional vector before the decoder can attend to it. Early tokens are still crowded out — just more slowly than in vanilla RNNs.

The attention mechanism (Bahdanau et al., 2015) was originally introduced as a fix within encoder-decoder RNN architectures: rather than forcing the decoder to rely on a single compressed context vector, let it directly attend to every encoder hidden state at each decoding step:

ct=i=1Tαtihi,αti=exp(eti)jexp(etj)c_t = \sum_{i=1}^T \alpha_{ti} h_i, \qquad \alpha_{ti} = \frac{\exp(e_{ti})}{\sum_j \exp(e_{tj})}

This gave LSTMs a significant boost on long sequences — but attention still depended on the sequential hidden states hih_i, which meant the encoder still had to process tokens one at a time.

The Transformer (Vaswani et al., 2017) asked: what if we removed the recurrence entirely and only kept attention? Every token attends directly to every other token in a single parallel pass. This addresses both LSTM limitations simultaneously:

LimitationLSTMTransformer
Long-range dependenciesPartial (gating helps but compression remains)Direct — any token attends to any other in O(1)O(1) steps
Sequential computationO(T)O(T) steps — cannot parallelizeO(1)O(1) steps — fully parallel over the sequence
Gradient path lengthO(T)O(T) — shorter than RNN but still proportional to lengthO(1)O(1) — any position connects directly to any other
💡Intuition

The LSTM is a compress-then-retrieve architecture: encode everything into a hidden state, then decode from it. Attention is a retrieve-directly architecture: skip the compression and look up what you need from the full input at query time. For long sequences, direct retrieval wins — you don't lose information to compression.

LSTMs remain competitive when: (1) sequence length is short (T<200T < 200), (2) inference latency is critical and O(T2)O(T^2) attention cost is prohibitive, or (3) you need true streaming with no fixed context window. For most NLP tasks at scale, the Transformer family has displaced LSTMs — but understanding the LSTM's design is essential for understanding why attention works.

🚀Production

For production sentiment analysis at scale, distilbert-base-uncased fine-tuned on IMDB achieves 94–95% vs 90% for a BiLSTM, at roughly 3× the latency per request. Quantize the LSTM to int8 with torch.quantization.quantize_dynamic for 4× model size reduction and 2× CPU throughput with less than 0.5% accuracy loss. For long sequences, LSTM often outperforms Transformer on inference throughput due to its O(T)O(T) vs O(T2)O(T^2) attention cost.

Enjoying these notes?

Get new lessons delivered to your inbox. No spam.