Automatic Differentiation in Machine Learning
Takeaway
Autodiff computes exact derivatives of programs by chaining local derivatives through a computational graph; reverse mode powers deep learning because it delivers the full gradient of a scalar loss at a small constant multiple of the cost of one forward pass, no matter how many parameters there are.
The problem (before → after)
- Before: Manual differentiation is tedious and error-prone; numerical (finite-difference) approximations suffer from truncation and round-off error and cost one extra evaluation per input.
- After: Build a graph during forward execution and apply the chain rule mechanically to get accurate gradients.
Mental model first
Think of a factory assembly line: each station outputs both a value and a recipe for how small changes to its inputs affect that value. Reverse mode replays the line in reverse, distributing credit for the final output back to the inputs.
Just-in-time concepts
- Forward vs reverse mode: cost scales with the number of inputs (forward) or outputs (reverse), so their ratio dictates which mode is cheaper; see the dual-number sketch after this list.
- Tape and graph: Record operations and their local Jacobians.
- Vector–Jacobian product (VJP): Core primitive for reverse-mode autodiff.
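To make the forward/reverse distinction concrete, here is a minimal dual-number sketch of forward mode in plain Python (the Dual class and the function f are illustrative assumptions, not any framework's API): every value carries its derivative with respect to one chosen input alongside it.

import math

class Dual:
    # Forward-mode value: carries f(x) and df/dx together through the program.
    def __init__(self, val, dot=0.0):
        self.val, self.dot = val, dot
    def __add__(self, other):
        return Dual(self.val + other.val, self.dot + other.dot)
    def __mul__(self, other):
        return Dual(self.val * other.val,
                    self.dot * other.val + self.val * other.dot)

def sin(x):
    return Dual(math.sin(x.val), math.cos(x.val) * x.dot)

def f(x, y):
    return sin(x * y) + x * x

# Seed dx/dx = 1 and dy/dx = 0 to get the partial derivative with respect to x.
x, y = Dual(2.0, 1.0), Dual(3.0, 0.0)
out = f(x, y)
print(out.val, out.dot)   # f(2, 3) and df/dx = y*cos(x*y) + 2*x

Recovering every partial derivative this way takes one pass per input, which is why reverse mode wins when a single scalar loss depends on millions of parameters.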
First-pass solution
Trace the program; accumulate gradients by backpropagating sensitivities from outputs to inputs using VJPs.
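A hand-traced reverse pass over the tiny program y = (a + b) * b (an illustrative example) shows the mechanics before any machinery is built:

# Forward pass for y = (a + b) * b at a = 2, b = 3, recording intermediates.
a, b = 2.0, 3.0
s = a + b              # s = 5
y = s * b              # y = 15

# Reverse pass: seed the output sensitivity, then apply each op's VJP.
dy = 1.0
ds = dy * b            # multiply: sensitivity of s is b        -> 3
db = dy * s            # multiply: sensitivity of b via this op -> 5
da = ds                # add passes the sensitivity through     -> 3
db += ds               # b also feeds the add; contributions accumulate -> 8

assert (da, db) == (3.0, 8.0)   # matches dy/da = b and dy/db = a + 2b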
Iterative refinement
- Memory–compute trade-offs: checkpointing stores only a subset of activations and recomputes the rest during the backward pass; see the sketch after this list.
- Higher-order derivatives: Forward-over-reverse and reverse-over-forward.
- JIT and fusion: Optimize graphs to remove overhead.
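As an illustration of the memory-compute trade-off, here is a minimal checkpointing sketch in plain Python; the tanh layer chain, the SEGMENT size, and the helper names are illustrative assumptions rather than a framework API.

import math

layers = [math.tanh] * 8
layer_grads = [lambda x: 1.0 - math.tanh(x) ** 2] * 8   # d tanh / dx
SEGMENT = 4   # keep one checkpoint per 4 layers instead of every activation

def forward(x):
    checkpoints = []
    for i, layer in enumerate(layers):
        if i % SEGMENT == 0:
            checkpoints.append(x)   # store only segment boundaries
        x = layer(x)
    return x, checkpoints

def backward(checkpoints, out_grad=1.0):
    g = out_grad
    for start in range(len(layers) - SEGMENT, -1, -SEGMENT):
        # Recompute this segment's activations from its checkpoint...
        x = checkpoints[start // SEGMENT]
        acts = [x]
        for layer in layers[start:start + SEGMENT]:
            x = layer(x)
            acts.append(x)
        # ...then backpropagate through the recomputed segment.
        for i in range(SEGMENT - 1, -1, -1):
            g *= layer_grads[start + i](acts[i])
    return g   # d(output)/d(input)

y, checkpoints = forward(0.5)
print(y, backward(checkpoints))

Peak memory now scales with the segment length rather than the network depth, paid for with roughly one extra forward pass of recomputation.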
Code as a byproduct (toy reverse-mode)
class Node:
    def __init__(self, value, parents=(), grad_fn=None):
        self.value = value
        self.parents = parents    # upstream Nodes this value depends on
        self.grad_fn = grad_fn    # maps the output gradient to (parent, grad) pairs
        self.grad = 0.0           # accumulated sensitivity, filled in by the reverse pass

def add(a, b):
    # Addition passes the incoming gradient `go` through to both inputs unchanged.
    out = Node(a.value + b.value, parents=(a, b),
               grad_fn=lambda go: [(a, go), (b, go)])
    return out
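One possible way to complete the toy example (the mul helper and backward driver below are sketches, not part of any library) is to walk the recorded graph in reverse topological order and accumulate each node's contribution into its parents:

def mul(a, b):
    # Product rule: d(ab)/da = b, d(ab)/db = a.
    return Node(a.value * b.value, parents=(a, b),
                grad_fn=lambda go: [(a, go * b.value), (b, go * a.value)])

def backward(output):
    # Build a topological order, seed the output, then push sensitivities back.
    order, seen = [], set()
    def topo(node):
        if node not in seen:
            seen.add(node)
            for p in node.parents:
                topo(p)
            order.append(node)
    topo(output)
    output.grad = 1.0
    for node in reversed(order):
        if node.grad_fn is not None:
            for parent, g in node.grad_fn(node.grad):
                parent.grad += g

# y = (a + b) * b  ->  dy/da = b = 3, dy/db = a + 2b = 8
a, b = Node(2.0), Node(3.0)
y = mul(add(a, b), b)
backward(y)
print(a.grad, b.grad)   # 3.0 8.0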
Principles, not prescriptions
- Choose reverse mode for many-parameter, single-scalar-loss problems.
- Keep numerical stability by fusing ops and using stable primitives, e.g. log-sum-exp instead of exponentiating and summing directly; see the sketch after this list.
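As a small example of a stable primitive, the classic log-sum-exp trick subtracts the maximum before exponentiating; the code below is an illustrative plain-Python sketch.

import math

def logsumexp(xs):
    # Subtract the max before exponentiating so large inputs cannot overflow;
    # the fused op and its gradient (a softmax) stay well-behaved.
    m = max(xs)
    return m + math.log(sum(math.exp(x - m) for x in xs))

print(logsumexp([1000.0, 1001.0]))   # ~1001.313
# math.log(sum(math.exp(x) for x in [1000.0, 1001.0]))  # overflows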
Common pitfalls
- Retaining computation graphs (and the activations they store) longer than needed, so memory blows up.
- Gradient signal lost or distorted by non-differentiable control flow (data-dependent branches, rounding, argmax); see the sketch after this list.
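A minimal illustration of the control-flow pitfall (the function f is an invented example): reverse mode differentiates only the branch that actually executed, and the comparison itself contributes nothing to the gradient.

def f(x):
    # Each branch is smooth, but the slope jumps from 1 to 2 at x = 1,
    # and the comparison contributes no gradient of its own.
    return x * x if x > 1.0 else x

def autodiff_slope(x):
    # What reverse mode reports: the derivative of whichever branch ran.
    return 2.0 * x if x > 1.0 else 1.0

print(autodiff_slope(0.999), autodiff_slope(1.001))   # 1.0 vs ~2.0
# Parameters that push x across the threshold see an abrupt change in the
# gradient with no signal from the branch condition itself.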
Connections and contrasts
- See also: [/blog/attention-is-all-you-need] (training), [/blog/variational-inference] (reparameterization uses autodiff).
Quick checks
- When is forward mode cheaper? — Few inputs, many outputs.
- What is a VJP? — Product of a vector with the Jacobian transpose.
- Why checkpointing? — To reduce memory by recomputing intermediates.
Further reading
- Autodiff survey (source above)
- Framework docs (PyTorch, JAX)