Generative Adversarial Networks
Takeaway
GANs pit a generator against a discriminator in a minimax game so the generator learns to produce samples indistinguishable from real data.
The problem (before → after)
- Before: Likelihood-based generative models require tractable densities or approximations; modeling high-dimensional data directly is hard.
- After: Train a sampler indirectly—if a discriminator cannot tell generated data from real data, the generator implicitly captures the data distribution’s support and structure.
Mental model first
Imagine a counterfeiter (generator) and a detective (discriminator). The counterfeiter improves by fooling the detective; the detective improves by catching fakes. Their contest escalates until fakes pass as real.
Just-in-time concepts
- Minimax objective: min_G max_D E_{x~p_data} [log D(x)] + E_{z~p_z} [log(1 - D(G(z)))]
- Non-saturating trick: instead of minimizing log(1 - D(G(z))), the generator maximizes log D(G(z)); this avoids vanishing gradients when D confidently rejects early fakes (see the sketch after this list).
- Mode collapse: Generator produces limited variety; use techniques like minibatch discrimination, unrolled GAN, or diversity regularizers.
- Evaluation: Inception Score (IS) and Fréchet Inception Distance (FID); both are imperfect proxies for fidelity and diversity.
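To make the objective and the non-saturating trick concrete, here is a minimal PyTorch sketch. It assumes a discriminator D that outputs raw logits; the names d_loss and g_loss are illustrative, not from the original paper.

```python
import torch
import torch.nn.functional as F

def d_loss(D, real, fake):
    # Discriminator: push D(real) toward 1 and D(fake) toward 0.
    real_logits = D(real)
    fake_logits = D(fake.detach())  # do not backprop into G here
    return (F.binary_cross_entropy_with_logits(real_logits, torch.ones_like(real_logits))
            + F.binary_cross_entropy_with_logits(fake_logits, torch.zeros_like(fake_logits)))

def g_loss(D, fake):
    # Non-saturating generator loss: minimize -log D(G(z))
    # rather than log(1 - D(G(z))), which saturates early in training.
    fake_logits = D(fake)
    return F.binary_cross_entropy_with_logits(fake_logits, torch.ones_like(fake_logits))
```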
First-pass solution
Train D to classify real vs fake; train G to fool D. Alternate gradient updates. Start with simple architectures and careful optimization (Adam, spectral norm for stability).
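Building on the loss sketch above, one alternating update might look like the following. This is a sketch under stated assumptions: opt_g and opt_d are assumed to be Adam optimizers, z_dim is an illustrative latent size, and train_step is a hypothetical helper name.

```python
import torch

def train_step(G, D, opt_g, opt_d, real, z_dim=128):
    # One alternating update: first the discriminator, then the generator.
    z = torch.randn(real.size(0), z_dim, device=real.device)
    fake = G(z)

    opt_d.zero_grad()
    d_loss(D, real, fake).backward()
    opt_d.step()

    opt_g.zero_grad()
    g_loss(D, G(z)).backward()  # fresh forward pass so gradients reach G
    opt_g.step()
```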
Iterative refinement
- Architecture: Convolutional GANs (DCGAN) for images; residual blocks for stability.
- Losses: Wasserstein GAN with gradient penalty (WGAN-GP) improves stability by approximately minimizing the Earth-Mover (Wasserstein-1) distance, whose gradients are smoother than those of the JS divergence.
- Conditioning: cGANs incorporate labels or text for controllable generation.
- Regularization: Spectral normalization, gradient penalties, path length regularization (spectral norm usage is sketched after this list).
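As one concrete example of these regularizers, PyTorch ships spectral normalization as a layer wrapper; the channel sizes below are illustrative.

```python
import torch.nn as nn
from torch.nn.utils import spectral_norm

# Wrapping a layer with spectral_norm constrains its largest singular value,
# which bounds the layer's Lipschitz constant and helps stabilize D.
conv = spectral_norm(nn.Conv2d(64, 128, kernel_size=3, padding=1))
```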
Code as a byproduct (WGAN-GP snippet)
```python
import torch

def gradient_penalty(D, real, fake):
    """WGAN-GP: penalize the critic's gradient norm away from 1
    on random interpolates between real and fake samples."""
    bsz = real.size(0)
    eps = torch.rand(bsz, 1, 1, 1, device=real.device)  # per-sample mixing weight
    # Detach fake so x_hat is a leaf tensor; requires_grad_ would raise
    # an error if x_hat still carried the generator's graph.
    x_hat = (eps * real + (1 - eps) * fake.detach()).requires_grad_(True)
    d_hat = D(x_hat)
    grads = torch.autograd.grad(d_hat.sum(), x_hat, create_graph=True)[0]
    # Penalize deviation of the per-sample gradient norm from 1
    # (a soft version of the 1-Lipschitz constraint).
    gp = ((grads.view(bsz, -1).norm(2, dim=1) - 1) ** 2).mean()
    return gp
```
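In use, the penalty is added to the critic loss with a weight lambda (the WGAN-GP paper suggests 10). The wiring below is a sketch; critic_loss and lambda_gp are illustrative names.

```python
lambda_gp = 10.0  # penalty weight; 10 is the value suggested in the WGAN-GP paper

def critic_loss(D, real, fake):
    # Wasserstein critic loss: maximize D(real) - D(fake),
    # i.e. minimize D(fake) - D(real), plus the gradient penalty.
    return (D(fake.detach()).mean() - D(real).mean()
            + lambda_gp * gradient_penalty(D, real, fake))
```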
Principles, not prescriptions
- Treat training as aligning two objectives; stabilize the game rather than overfitting one player.
- Prefer metrics that reflect perceptual quality and diversity, not only loss curves.
- Encourage diversity explicitly to combat mode collapse.
Common pitfalls
- Overpowering D early: G receives vanishing gradients; balance updates or use WGAN-GP.
- Ignoring evaluation: High realism but low diversity indicates collapse.
- Unstable training from missing normalization or mismatched learning rates; tune both players' optimizers together rather than in isolation.
Connections and contrasts
- See also: [/blog/attention-is-all-you-need] (sequence modeling), [/blog/variational-inference] (likelihood-based alternative), [/blog/normalizing-flows] (exact likelihood with invertibility).
Quick checks
- Why can GANs avoid explicit density modeling? — They train a sampler via a discriminator signal.
- What does WGAN improve? — Replaces JS divergence with Earth-Mover distance for smoother gradients.
- What is mode collapse? — Generator outputs low-diversity samples that fool D but miss modes of the data.
Further reading
- Goodfellow et al., 2014 (the original GAN paper)
- DCGAN, WGAN, WGAN-GP papers
- StyleGAN series for high-fidelity image synthesis