Takeaway

Transformers replace recurrence and convolution with self-attention. Each token can draw on the most relevant tokens anywhere in the sequence, and training runs in parallel across positions, which scales well to large models and datasets.

The problem (before → after)

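  • Before: RNNs and LSTMs process tokens one at a time. Information between distant tokens must pass through many steps, which makes long-range dependencies hard to learn and training hard to parallelize across positions.
  • After: Self-attention compares every token with every other token in a single step. All pairwise relationships are computed in parallel, and the gradient path between any two positions shrinks to one attention layer.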

Mental model first

Reading a paragraph, you don’t memorize every word in order; you glance back and forth, focusing on key phrases. Self-attention works similarly. Each token asks, “Which tokens in this sentence should I pay attention to, given my current role?” The resulting attention weights form a heatmap of relevance.

Just-in-time concepts

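  • Queries, Keys, Values (Q, K, V): Three learned linear projections map each token embedding to a query, a key, and a value. The dot product of a token's query with every key scores relevance, and the normalized scores weight a mix of the values.
  • Scaled dot-product attention: softmax(QKᵀ / √d_k) V, where d_k is the query/key dimension. Dividing by √d_k keeps the dot products from growing with dimension, which would otherwise push the softmax into saturated regions with vanishing gradients.
  • Multi-head attention: Several attention “heads” run in parallel, each with its own Q/K/V projections into a smaller subspace. Their outputs are concatenated and projected back, and different heads can learn different relations (e.g., syntax, coreference, relative position).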
  • Positional encoding: Since attention is order-agnostic, add position information (sinusoidal or learned) to token embeddings.
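For the sinusoidal option, Vaswani et al. (2017) define PE(pos, 2i) = sin(pos / 10000^(2i/d_model)) and PE(pos, 2i+1) = cos(pos / 10000^(2i/d_model)). Below is a minimal PyTorch sketch of that table. The function name `sinusoidal_positions` is our own, and the sketch assumes an even model width.

```python
import math
import torch

def sinusoidal_positions(n_positions: int, d_model: int) -> torch.Tensor:
    """Return an (n_positions, d_model) table of sinusoidal position encodings.

    Even columns use sin and odd columns use cos; each column pair has its own
    frequency 1 / 10000^(2i / d_model). Assumes d_model is even.
    """
    assert d_model % 2 == 0, "this sketch assumes an even model width"
    pos = torch.arange(n_positions, dtype=torch.float32).unsqueeze(1)  # (n, 1)
    two_i = torch.arange(0, d_model, 2, dtype=torch.float32)            # 0, 2, 4, ...
    inv_freq = torch.exp(-math.log(10000.0) * two_i / d_model)          # 1 / 10000^(2i/d)
    pe = torch.zeros(n_positions, d_model)
    pe[:, 0::2] = torch.sin(pos * inv_freq)   # broadcast (n, 1) * (d/2,) -> (n, d/2)
    pe[:, 1::2] = torch.cos(pos * inv_freq)
    return pe

# Positions are added (not concatenated) to the token embeddings.
tokens = torch.randn(10, 16)                  # 10 hypothetical token embeddings, width 16
x = tokens + sinusoidal_positions(10, 16)
```

Each column oscillates at a different frequency, so every position gets a distinct pattern and the table needs no training.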

First-pass solution

Seq2seq models with attention already helped RNNs, but recurrence still forced computation to proceed one step at a time. The transformer removes recurrence entirely. It uses stacks of attention and feedforward blocks with residual connections and layer normalization, and it is trained with teacher forcing (the decoder receives the ground-truth previous tokens as input).
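One encoder block might be sketched as follows. This assumes PyTorch's `nn.MultiheadAttention` for the attention sublayer and illustrative sizes; the class name `EncoderBlock` and all dimensions are hypothetical. It follows the original post-norm arrangement, LayerNorm(x + Sublayer(x)). Many later models move the normalization before each sublayer instead.

```python
import torch
import torch.nn as nn

class EncoderBlock(nn.Module):
    """Post-norm encoder block: self-attention, then a position-wise feedforward
    network, each wrapped in a residual connection followed by layer norm."""

    def __init__(self, d_model: int = 64, n_heads: int = 4, d_ff: int = 256,
                 dropout: float = 0.1):
        super().__init__()
        self.attn = nn.MultiheadAttention(d_model, n_heads, dropout=dropout,
                                          batch_first=True)
        self.ff = nn.Sequential(nn.Linear(d_model, d_ff), nn.ReLU(),
                                nn.Linear(d_ff, d_model))
        self.norm1 = nn.LayerNorm(d_model)
        self.norm2 = nn.LayerNorm(d_model)
        self.drop = nn.Dropout(dropout)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # x: (batch, seq_len, d_model); queries, keys, and values all come from x.
        attn_out, _ = self.attn(x, x, x, need_weights=False)
        x = self.norm1(x + self.drop(attn_out))       # residual + layer norm
        x = self.norm2(x + self.drop(self.ff(x)))     # same pattern around the FFN
        return x

block = EncoderBlock()
out = block(torch.randn(2, 10, 64))   # batch of 2 sequences, 10 tokens each
```

A full encoder would stack several such blocks, e.g. `nn.Sequential(*[EncoderBlock() for _ in range(6)])`.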

Iterative refinement

  1. Encoder: Layers of multi-head self-attention + position-wise feedforward networks build contextual token representations.
  2. Decoder: Similar layers plus masked self-attention (to prevent peeking ahead) and cross-attention to the encoder outputs.
  3. Training tricks: Label smoothing, learning-rate warmup with Adam, dropout.
  4. Scaling: Performance has generally kept improving with larger models and datasets, but self-attention costs O(n²) time and memory in the sequence length n, motivating efficient variants (sparse, linearized, chunked).
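The warmup schedule and label smoothing from the training tricks can be wired up roughly as follows. The schedule lrate = d_model^(-0.5) · min(step^(-0.5), step · warmup_steps^(-1.5)) and the values d_model = 512, warmup_steps = 4000, Adam β₁ = 0.9, β₂ = 0.98, ε = 10⁻⁹, and label smoothing 0.1 come from Vaswani et al. (2017). The `nn.Linear` stand-in model, the random data, and the helper name `transformer_lr` are placeholders for illustration.

```python
import torch
import torch.nn as nn

def transformer_lr(step: int, d_model: int = 512, warmup: int = 4000) -> float:
    """Linear warmup for `warmup` steps, then decay proportional to 1/sqrt(step)."""
    step = max(step, 1)                       # avoid 0 ** -0.5 on the first call
    return d_model ** -0.5 * min(step ** -0.5, step * warmup ** -1.5)

model = nn.Linear(512, 1000)                  # stand-in for a real transformer
# Base lr = 1.0, so LambdaLR's multiplier *is* the learning rate.
opt = torch.optim.Adam(model.parameters(), lr=1.0, betas=(0.9, 0.98), eps=1e-9)
sched = torch.optim.lr_scheduler.LambdaLR(opt, lr_lambda=transformer_lr)
loss_fn = nn.CrossEntropyLoss(label_smoothing=0.1)

for _ in range(3):                            # a few toy training steps
    inputs = torch.randn(8, 512)
    targets = torch.randint(0, 1000, (8,))
    loss = loss_fn(model(inputs), targets)
    opt.zero_grad()
    loss.backward()
    opt.step()
    sched.step()                              # advance the schedule once per step
```

The learning rate peaks at step `warmup` and decays afterwards. Warmup keeps Adam's early updates small, while its second-moment estimates are still noisy.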

Code as a byproduct (minimal attention)

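```python
import torch
import torch.nn.functional as F

def scaled_dot_attn(Q, K, V):
    # Q: (..., n_q, d_k), K: (..., n_k, d_k), V: (..., n_k, d_v)
    d_k = Q.size(-1)
    scores = Q @ K.transpose(-2, -1) / d_k ** 0.5   # (..., n_q, n_k) query-key similarities
    weights = F.softmax(scores, dim=-1)              # normalize over keys; each row sums to 1
    return weights @ V, weights                      # weighted mix of values, plus the weights
```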
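The decoder's masked self-attention needs only an extra mask argument on top of this function. Below is a self-contained sketch that assumes a boolean mask where True means "may attend". The name `masked_attn` is hypothetical.

```python
import torch
import torch.nn.functional as F

def masked_attn(Q, K, V, mask=None):
    # mask: bool, broadcastable to (..., n_q, n_k); True = allowed, False = blocked.
    scores = Q @ K.transpose(-2, -1) / Q.size(-1) ** 0.5
    if mask is not None:
        scores = scores.masked_fill(~mask, float("-inf"))  # blocked pairs get weight 0
    weights = F.softmax(scores, dim=-1)
    return weights @ V, weights

n, d = 5, 8
x = torch.randn(1, n, d)                                   # (batch, seq, dim)
causal = torch.tril(torch.ones(n, n, dtype=torch.bool))    # token i sees tokens 0..i
out, w = masked_attn(x, x, x, mask=causal)
```

Each row keeps at least its own position unmasked, so no softmax row is entirely −inf, which would produce NaNs. In PyTorch 2.0 and later, `F.scaled_dot_product_attention(Q, K, V, is_causal=True)` offers a fused built-in alternative.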

Principles, not prescriptions

  • Separate content retrieval (attention) from computation (feedforward) and sequence order (positional encoding).
  • Parallelize wherever possible; remove unnecessary sequential bottlenecks.
  • Use multiple views (heads) to capture diverse relations.
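To make "multiple views" concrete, here is a minimal hand-written multi-head self-attention module. It uses PyTorch, the class name and sizes are illustrative, and dropout and masking are omitted. Each head attends within its own d_model / h slice of the width and produces its own n × n attention map. A final linear layer mixes the concatenated heads.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

class MultiHeadSelfAttention(nn.Module):
    """Minimal multi-head self-attention: h heads, each of width d_model // h."""

    def __init__(self, d_model: int, n_heads: int):
        super().__init__()
        assert d_model % n_heads == 0, "d_model must split evenly across heads"
        self.h, self.d_head = n_heads, d_model // n_heads
        self.qkv = nn.Linear(d_model, 3 * d_model)   # fused Q, K, V projections
        self.out = nn.Linear(d_model, d_model)       # mixes the concatenated heads

    def _split(self, t: torch.Tensor) -> torch.Tensor:
        # (b, n, d_model) -> (b, h, n, d_head): each head gets its own slice.
        b, n, _ = t.shape
        return t.reshape(b, n, self.h, self.d_head).transpose(1, 2)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        b, n, d = x.shape
        q, k, v = (self._split(t) for t in self.qkv(x).chunk(3, dim=-1))
        scores = q @ k.transpose(-2, -1) / self.d_head ** 0.5   # (b, h, n, n)
        heads = F.softmax(scores, dim=-1) @ v                    # (b, h, n, d_head)
        # Concatenate the heads back to (b, n, d_model), then project.
        return self.out(heads.transpose(1, 2).reshape(b, n, d))

mha = MultiHeadSelfAttention(d_model=64, n_heads=8)
y = mha(torch.randn(2, 10, 64))
```

With d_model fixed, raising `n_heads` shrinks `d_head`. This trade-off is behind the "too many heads" pitfall below.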

Common pitfalls

  • Ignoring positions: Without positional encoding, the model cannot distinguish permutations.
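  • Too many heads: With a fixed model width, each of h heads gets only d_model/h dimensions. Too many heads leaves each one too narrow to represent much, underutilizing capacity.
  • Length scaling: Attention costs O(n²) time and memory in the sequence length; use efficient attention variants for long inputs.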
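The first pitfall can be checked directly. In this sketch, the projection weights are random and hypothetical, and `pos` is a random stand-in for a real positional encoding. Without positions, reversing the input order only reverses the output rows. Adding per-position vectors breaks that symmetry.

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)
n, d = 5, 8
x = torch.randn(n, d)                                  # embeddings with no position info
Wq, Wk, Wv = (torch.randn(d, d) / d ** 0.5 for _ in range(3))  # hypothetical projections

def self_attn(x: torch.Tensor) -> torch.Tensor:
    q, k, v = x @ Wq, x @ Wk, x @ Wv
    return F.softmax(q @ k.T / d ** 0.5, dim=-1) @ v

perm = torch.tensor([4, 3, 2, 1, 0])                   # reverse the token order

# Without positions, reordering the inputs merely reorders the outputs.
no_pos_equivariant = torch.allclose(self_attn(x[perm]), self_attn(x)[perm], atol=1e-5)

# With per-position vectors added, the same reordering changes the representations.
pos = torch.randn(n, d)
with_pos_equivariant = torch.allclose(
    self_attn(x[perm] + pos), self_attn(x + pos)[perm], atol=1e-5
)
print(no_pos_equivariant, with_pos_equivariant)
```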

Connections and contrasts

  • See also: [/blog/gans] (alternative generative modeling), [/blog/graph-neural-networks] (message passing vs attention), [/blog/automatic-differentiation] (training mechanics).

Quick checks

  1. Why divide by √d in attention? — To keep dot-product magnitudes in a stable range for softmax.
  2. What does multi-head attention buy you? — Parallel relation types, improving expressiveness.
  3. Why mask decoder self-attention? — To prevent information leakage from future tokens during training.
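The √d answer can be checked numerically. For q and k with independent N(0, 1) entries, q·k has variance d, so unscaled scores grow with dimension and push the softmax toward a one-hot output. A small PyTorch experiment with arbitrary sample sizes:

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)
n = 2000                                          # number of sampled (q, k) pairs
results = []
for d in (16, 256, 1024):
    q, k = torch.randn(n, d), torch.randn(n, d)   # entries ~ N(0, 1)
    dots = (q * k).sum(-1)                        # n dot products; variance ~ d
    results.append((d, dots.var().item(), (dots / d ** 0.5).var().item()))
for d, raw_var, scaled_var in results:
    print(f"d={d:5d}  var(q.k)={raw_var:8.1f}  var(q.k/sqrt(d))={scaled_var:.2f}")

# Large raw scores make the softmax nearly one-hot, where gradients vanish.
d = 512
scores = torch.randn(8, d) @ torch.randn(d, 8)    # 8 queries x 8 keys
peak_raw = F.softmax(scores, dim=-1).max(-1).values.mean().item()
peak_scaled = F.softmax(scores / d ** 0.5, dim=-1).max(-1).values.mean().item()
print(f"mean max weight: raw={peak_raw:.2f}  scaled={peak_scaled:.2f}")
```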

Further reading

  • Vaswani et al., 2017, “Attention Is All You Need”
  • “The Annotated Transformer”
  • “Attention Is All You Need” reproducible implementations