Attention Is All You Need
Takeaway
Transformers replace recurrence and convolution with self-attention, letting models focus on the most relevant tokens anywhere in a sequence and enabling parallel training that scales.
The problem (before → after)
- Before: RNNs and LSTMs process tokens sequentially, making long-range dependencies hard to learn and training slow to parallelize.
- After: Self-attention compares every token with every other token at once, extracting relationships in parallel and shortening gradient paths.
Mental model first
Reading a paragraph, you don’t memorize every word in order—you glance back and forth, focusing on key phrases. Self-attention simulates that: each token asks, “Who in the sentence should I pay attention to for my current role?” The pattern of attention weights is a heatmap of relevance.
Just-in-time concepts
- Queries, Keys, Values (Q, K, V): Each token is projected into these three spaces; attention uses Q·K to score relevance and mixes Values accordingly.
- Scaled dot-product attention: softmax(QKᵀ/√d)·V; dividing by √d (the key dimension) keeps dot-product magnitudes from growing with dimension, which would otherwise saturate the softmax and shrink gradients.
- Multi-head attention: Multiple attention “heads” capture different relations (syntax, coreference, position).
- Positional encoding: Since attention is order-agnostic, add position information (sinusoidal or learned) to token embeddings; a sinusoidal sketch follows this list.
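For concreteness, a minimal sketch of the sinusoidal variant, assuming an even d_model (the function name is illustrative):

import math
import torch

def sinusoidal_positions(seq_len, d_model):
    # PE(pos, 2i) = sin(pos / 10000^(2i/d_model)), PE(pos, 2i+1) = cos(...); assumes even d_model.
    pos = torch.arange(seq_len, dtype=torch.float32).unsqueeze(1)      # (seq_len, 1)
    two_i = torch.arange(0, d_model, 2, dtype=torch.float32)           # even dimension indices
    freq = torch.exp(-math.log(10000.0) * two_i / d_model)             # 10000^(-2i/d_model)
    pe = torch.zeros(seq_len, d_model)
    pe[:, 0::2] = torch.sin(pos * freq)                                # sine on even dimensions
    pe[:, 1::2] = torch.cos(pos * freq)                                # cosine on odd dimensions
    return pe                                                          # added to token embeddings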
First-pass solution
Seq2seq with attention already helped RNNs, but recurrence still serialized computation. The transformer removes recurrence entirely: stacks of attention and feedforward blocks, with residual connections and layer normalization, trained with teacher forcing.
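As a sketch of that stacking, here is one encoder block in PyTorch, using the built-in nn.MultiheadAttention rather than hand-rolled attention; the default sizes mirror the paper's base model but are otherwise illustrative assumptions.

import torch
import torch.nn as nn

class EncoderLayer(nn.Module):
    # Post-norm block as in the original paper: self-attention, then a position-wise FFN,
    # each wrapped in a residual connection followed by layer normalization.
    def __init__(self, d_model=512, n_heads=8, d_ff=2048, dropout=0.1):
        super().__init__()
        self.attn = nn.MultiheadAttention(d_model, n_heads, dropout=dropout, batch_first=True)
        self.ffn = nn.Sequential(nn.Linear(d_model, d_ff), nn.ReLU(), nn.Linear(d_ff, d_model))
        self.norm1, self.norm2 = nn.LayerNorm(d_model), nn.LayerNorm(d_model)
        self.drop = nn.Dropout(dropout)

    def forward(self, x):                        # x: (batch, seq_len, d_model)
        a, _ = self.attn(x, x, x)                # every token attends to every token
        x = self.norm1(x + self.drop(a))         # residual connection + layer norm
        return self.norm2(x + self.drop(self.ffn(x)))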
Iterative refinement
- Encoder: Layers of multi-head self-attention + position-wise feedforward networks build contextual token representations.
- Decoder: Similar layers plus masked self-attention (to prevent peeking ahead) and cross-attention to the encoder outputs; masking is sketched after this list.
- Training tricks: Label smoothing, learning-rate warmup with Adam, dropout.
- Scaling: Larger models and datasets keep improving performance; attention cost is O(n²), motivating efficient variants (sparse, linearized, chunked).
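To make the decoder's mask concrete, here is a sketch of causal self-attention with the Q/K/V projections omitted for brevity; only the masking step matters here, and the function name is illustrative.

import torch
import torch.nn.functional as F

def causal_self_attention(x):
    # x: (batch, seq_len, d). Position t may only attend to positions <= t.
    n, d = x.size(1), x.size(-1)
    scores = x @ x.transpose(-2, -1) / (d ** 0.5)
    future = torch.ones(n, n).triu(diagonal=1).bool()     # strictly upper triangle = future tokens
    scores = scores.masked_fill(future, float("-inf"))    # -inf becomes 0 weight after softmax
    return F.softmax(scores, dim=-1) @ x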
Code as a byproduct (minimal attention)
import torch
import torch.nn.functional as F

def scaled_dot_attn(Q, K, V):
    # Q, K, V: (..., seq_len, d); returns the attended values and the attention weights.
    d = Q.size(-1)
    scores = Q @ K.transpose(-2, -1) / (d ** 0.5)   # similarity of every query to every key
    weights = F.softmax(scores, dim=-1)             # each query's scores normalized to sum to 1
    return weights @ V, weights
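A quick check of the function above, with arbitrary shapes: the output keeps the value dimension, and each query's attention weights sum to 1.

Q = torch.randn(2, 5, 16)          # (batch, seq_len, d), shapes chosen arbitrarily
K = torch.randn(2, 5, 16)
V = torch.randn(2, 5, 16)
out, weights = scaled_dot_attn(Q, K, V)
print(out.shape)                   # torch.Size([2, 5, 16])
print(weights.sum(dim=-1))         # all ones: each row is a probability distribution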
Principles, not prescriptions
- Separate content retrieval (attention) from computation (feedforward) and sequence order (positional encoding).
- Parallelize wherever possible; remove unnecessary sequential bottlenecks.
- Use multiple views (heads) to capture diverse relations.
Common pitfalls
- Ignoring positions: Without positional encoding, the model cannot distinguish permutations (see the check after this list).
- Mis-sized heads: Too many heads with tiny per-head dimensions can underutilize capacity.
- Length scaling: O(n²) attention becomes expensive; use efficient attention for long inputs.
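A small sanity check of the first pitfall: without positional information, self-attention is permutation-equivariant, so shuffling the tokens only shuffles the outputs the same way (the helper name is illustrative).

import torch
import torch.nn.functional as F

def plain_self_attn(x):                                      # no positions, no projections
    scores = x @ x.transpose(-2, -1) / (x.size(-1) ** 0.5)
    return F.softmax(scores, dim=-1) @ x

x = torch.randn(6, 16)                                       # 6 tokens, 16 dims
perm = torch.randperm(6)
# Permuting the input rows just permutes the output rows identically:
print(torch.allclose(plain_self_attn(x)[perm], plain_self_attn(x[perm]), atol=1e-6))   # True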
Connections and contrasts
- See also: [/blog/gans] (alternative generative modeling), [/blog/graph-neural-networks] (message passing vs attention), [/blog/automatic-differentiation] (training mechanics).
Quick checks
- Why divide by √d in attention? — To keep dot-product magnitudes in a stable range for softmax.
- What does multi-head attention buy you? — Parallel relation types, improving expressiveness.
- Why mask decoder self-attention? — To prevent information leakage from future tokens during training.
Further reading
- Vaswani et al., 2017 (paper above)
- “The Annotated Transformer”
- “Attention Is All You Need” reproducible implementations