Vector-Quantized VAEs (VQ-VAE)
Takeaway
VQ-VAE replaces continuous latents with a learned discrete codebook, improving reconstruction fidelity and sidestepping posterior collapse; commitment and codebook losses keep the encoder and codebook aligned.
The problem (before → after)
- Before: Continuous VAEs tend to produce blurry outputs, and strong decoders can ignore the latents entirely (posterior collapse).
- After: Discrete latents with a codebook capture structure and enable powerful autoregressive priors over indices.
Mental model first
Think of compressing images into indices of a palette; each index references a prototype patch from a dictionary. The decoder reconstructs by painting with these prototypes.
Just-in-time concepts
- Codebook E ∈ R^{K×D} of K embedding vectors; the encoder output z_e(x) is quantized to its nearest entry e_k.
- Straight-through estimator: quantization is non-differentiable, so the decoder's gradient at z_q is copied back to z_e unchanged.
- Loss: reconstruction + ||sg[z_e] − e||² (codebook term) + β ||z_e − sg[e]||² (commitment term), where sg[·] is the stop-gradient operator; a PyTorch sketch follows this list.
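A minimal PyTorch sketch of these two stop-gradient terms plus the straight-through estimator, assuming z_q holds the nearest codebook vectors for z_e (e.g., from the quantize helper shown later); beta is the commitment weight β.
import torch.nn.functional as F

def vq_losses(z_e, z_q, beta=0.25):
    # Codebook term: pull selected codebook vectors toward the (detached) encoder outputs.
    codebook_loss = F.mse_loss(z_q, z_e.detach())
    # Commitment term: keep encoder outputs close to their (detached) assigned codes.
    commitment_loss = F.mse_loss(z_e, z_q.detach())
    # Straight-through estimator: forward pass uses z_q, backward pass copies
    # gradients from z_q straight to z_e.
    z_q_st = z_e + (z_q - z_e).detach()
    return z_q_st, codebook_loss + beta * commitment_loss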
First-pass solution
Train the encoder and decoder jointly while updating the codebook (the EMA variant stabilizes learning), then fit an autoregressive prior over the discrete code indices for high-quality generation; a sketch of the EMA update follows.
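A minimal sketch of the EMA codebook update, assuming z_flat is the [N, D] matrix of flattened encoder outputs and idx their assigned code indices; the buffer names ema_counts and ema_sums are illustrative, not from any specific library.
import torch
import torch.nn.functional as F

@torch.no_grad()
def ema_codebook_update(codebook, z_flat, idx, ema_counts, ema_sums, decay=0.99, eps=1e-5):
    K = codebook.size(0)
    one_hot = F.one_hot(idx, K).type_as(z_flat)                  # [N, K] hard assignments
    # Running averages of per-code usage counts and assigned-vector sums.
    ema_counts.mul_(decay).add_(one_hot.sum(0), alpha=1 - decay)
    ema_sums.mul_(decay).add_(one_hot.t() @ z_flat, alpha=1 - decay)
    # Laplace smoothing keeps rarely used codes from dividing by zero.
    n = ema_counts.sum()
    counts = (ema_counts + eps) / (n + K * eps) * n
    codebook.copy_(ema_sums / counts.unsqueeze(1))               # updated code vectors (in place)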
Iterative refinement
- Multi-scale codebooks and hierarchical latents improve global coherence.
- Gumbel-softmax alternatives offer a differentiable relaxation of the discrete code choice (sketched after this list).
- Tokenizers power large-scale generative pipelines (e.g., image/audio).
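A minimal sketch of the Gumbel-softmax relaxation, assuming the encoder predicts per-position logits over the K codes; this is an alternative parameterization, not the nearest-neighbor quantizer of the original VQ-VAE.
import torch
import torch.nn.functional as F

def gumbel_quantize(logits, codebook, tau=1.0, hard=True):
    # logits: [B, K, H, W], codebook: [K, D]
    soft_one_hot = F.gumbel_softmax(logits, tau=tau, hard=hard, dim=1)  # [B, K, H, W]
    # With hard=True each position commits to one code in the forward pass,
    # while gradients flow through the soft relaxation.
    z_q = torch.einsum('bkhw,kd->bdhw', soft_one_hot, codebook)
    return z_q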
Code as a byproduct (nearest codebook vector)
import torch

def quantize(z_e, codebook):
    # z_e: [B, D, H, W], codebook: [K, D]
    B, D, H, W = z_e.shape
    # Flatten spatial positions so each row is one D-dimensional latent vector.
    z = z_e.permute(0, 2, 3, 1).reshape(-1, D)                   # [B*H*W, D]
    # Squared Euclidean distance from every latent vector to every codebook entry.
    d2 = (z.unsqueeze(1) - codebook.unsqueeze(0)).pow(2).sum(-1) # [B*H*W, K]
    idx = d2.argmin(dim=1)                                       # index of nearest code
    # Look up the chosen codes and restore the [B, D, H, W] layout.
    z_q = codebook[idx].view(B, H, W, D).permute(0, 3, 1, 2)
    return z_q, idx.view(B, H, W)
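A quick usage check with random tensors (the shapes are illustrative):
z_e = torch.randn(2, 64, 8, 8)        # fake encoder output: B=2, D=64, 8x8 grid
codebook = torch.randn(512, 64)       # K=512 codes of dimension D=64
z_q, idx = quantize(z_e, codebook)    # z_q: [2, 64, 8, 8], idx: [2, 8, 8]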
Principles, not prescriptions
- Use discrete bottlenecks to encourage informative latents.
- Stabilize codebook learning with EMA and commitment losses.
Common pitfalls
- Codebook collapse: only a handful of codes get used; encourage usage with diversity losses or warmup, and monitor code usage (a sketch follows this list).
- A mismatch between tokenizer fidelity and prior capacity limits final sample quality.
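One lightweight way to watch for collapse (a monitoring sketch, assuming you have the batch's code indices idx; not a prescription from the papers) is to track the perplexity of the empirical code distribution:
import torch

def code_perplexity(idx, K):
    # Empirical code usage in a batch: perplexity close to K means near-uniform
    # usage; close to 1 means collapse onto a few codes.
    probs = torch.bincount(idx.flatten(), minlength=K).float()
    probs = probs / probs.sum()
    return torch.exp(-(probs * torch.log(probs + 1e-10)).sum())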
Connections and contrasts
- See also: [/blog/variational-inference], [/blog/gans], [/blog/attention-is-all-you-need] (autoregressive priors over tokens).
Quick checks
- Why discrete latents? — Better compression and strong priors; avoids collapse.
- Why straight-through? — Enables gradient flow through quantization.
- How to sample? — Sample code indices from the prior, then decode.
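A minimal end-to-end sampling sketch; the prior sample is faked with random indices and the decoder call is left commented out, since both are trained models not defined here.
import torch
K, D, H, W = 512, 64, 8, 8                 # illustrative sizes
codebook = torch.randn(K, D)               # stands in for the learned codebook
idx = torch.randint(K, (4, H, W))          # stand-in for sampling indices from a trained prior
z_q = codebook[idx].permute(0, 3, 1, 2)    # look up codes -> [4, D, H, W]
# x = decoder(z_q)                         # hypothetical trained decoder maps codes to images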
Further reading
- VQ-VAE, VQ-VAE-2 papers (source above)
- Tokenizer + transformer pipelines