Takeaway

Autodiff computes exact derivatives of programs by chaining local derivatives through a computational graph; reverse mode powers deep learning by delivering the full gradient of a scalar loss for a small constant multiple of the cost of one forward pass, regardless of the number of parameters.
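
For intuition, here is the chain-rule bookkeeping that autodiff automates, worked by hand on one small composition (a toy example, not tied to any framework):

import math

# f(x) = sin(x**2): record each step's local derivative, then multiply them.
x = 1.5
u = x ** 2             # forward step: u = g(x)
y = math.sin(u)        # forward step: y = f(u)

du_dx = 2 * x          # local derivative of the squaring step
dy_du = math.cos(u)    # local derivative of the sine step
dy_dx = dy_du * du_dx  # chain rule: exact, no finite-difference error
print(y, dy_dx)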

The problem (before → after)

  • Before: Manual differentiation is error-prone, and finite-difference approximations are noisy (step-size sensitive) and cost extra function evaluations per parameter; see the sketch after this list.
  • After: Record a computational graph during forward execution and apply the chain rule mechanically to get derivatives that are exact up to floating-point round-off.
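
A minimal sketch of the finite-difference fragility mentioned above (plain Python; exp is just a stand-in function whose exact derivative we know):

import math

x = 1.0
exact = math.exp(x)  # d/dx exp(x) = exp(x)

for h in (1e-1, 1e-5, 1e-9, 1e-13):
    approx = (math.exp(x + h) - math.exp(x - h)) / (2 * h)  # central difference
    print(f"h={h:g}  error={abs(approx - exact):.2e}")
# The error first shrinks (truncation) and then grows again (round-off) as h decreases,
# so there is no universally safe step size.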

Mental model first

Think of a factory assembly line: each station outputs both a value and a recipe (its local derivative) for how small changes in its inputs move its output. Reverse mode runs the line in reverse, using those recipes to distribute credit for the final result back to the inputs.

Just-in-time concepts

  • Forward vs reverse mode: cost scales with the number of inputs (forward) or outputs (reverse), so the cheaper mode depends on which is smaller.
  • Tape and graph: record each operation and its local Jacobian during the forward pass.
  • Vector–Jacobian product (VJP): the core primitive of reverse mode; it pulls an output sensitivity back through one operation (see the sketch after this list).
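
Concretely, for a linear map y = W @ x the Jacobian is W, so the VJP with a sensitivity vector v is W.T @ v. A minimal check with jax.vjp, assuming JAX is available:

import jax
import jax.numpy as jnp

W = jnp.arange(6.0).reshape(2, 3)
f = lambda x: W @ x                   # linear map whose Jacobian is W

x = jnp.ones(3)
v = jnp.array([1.0, -2.0])            # output-side sensitivity

y, vjp_fn = jax.vjp(f, x)             # forward value plus a closure for the backward pass
(grad_x,) = vjp_fn(v)                 # v pulled back through the Jacobian
print(jnp.allclose(grad_x, W.T @ v))  # True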

First-pass solution

Trace the program; accumulate gradients by backpropagating sensitivities from outputs to inputs using VJPs.
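
For a single scalar output, this recipe collapses to one VJP seeded with sensitivity 1.0, which is what a framework's gradient function computes. A quick cross-check in JAX, assuming it is available:

import jax
import jax.numpy as jnp

def loss(w):
    return jnp.sum(jnp.tanh(w) ** 2)  # any smooth scalar function works here

w = jnp.array([0.5, -1.0, 2.0])

y, vjp_fn = jax.vjp(loss, w)
(grad_via_vjp,) = vjp_fn(jnp.ones_like(y))            # backpropagate d(output)/d(output) = 1
print(jnp.allclose(grad_via_vjp, jax.grad(loss)(w)))  # True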

Iterative refinement

  1. Memory–compute trade-offs: checkpointing and recomputation.
  2. Higher-order derivatives: forward-over-reverse and reverse-over-forward compositions (see the Hessian-vector product sketch after this list).
  3. JIT and fusion: Optimize graphs to remove overhead.
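
Item 2 in concrete form: composing the modes yields Hessian-vector products without materializing the Hessian. A sketch with JAX (forward mode via jax.jvp applied to the reverse-mode gradient), assuming JAX is available:

import jax
import jax.numpy as jnp

def loss(w):
    return jnp.sum(jnp.sin(w) ** 2)   # toy scalar loss

w = jnp.array([0.1, 0.2, 0.3])
v = jnp.ones(3)                       # direction to probe

# Forward-over-reverse: push the direction v through the gradient function.
_, hvp = jax.jvp(jax.grad(loss), (w,), (v,))
print(hvp)                            # H @ v, without forming the full Hessian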

Code as a byproduct (toy reverse-mode)

class Node:
    """One value in the computational graph, recorded during the forward pass."""
    def __init__(self, value, parents=(), grad_fn=None):
        self.value = value      # result of the forward computation
        self.parents = parents  # nodes this value was computed from
        self.grad_fn = grad_fn  # maps an output sensitivity to (parent, contribution) pairs
        self.grad = 0.0         # accumulated sensitivity d(output)/d(this node)

def add(a, b):
    # Addition passes the incoming sensitivity unchanged to both parents.
    return Node(a.value + b.value, parents=(a, b),
                grad_fn=lambda go: [(a, go), (b, go)])
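
To make the toy runnable end to end, here is a minimal sketch of the missing pieces, a mul op and a backward pass over a topological ordering; the names mul and backward are illustrative, not a framework API:

def mul(a, b):
    # Product rule: d(ab)/da = b and d(ab)/db = a scale the incoming sensitivity.
    return Node(a.value * b.value, parents=(a, b),
                grad_fn=lambda go: [(a, go * b.value), (b, go * a.value)])

def backward(output):
    # Collect nodes in topological order, then apply each grad_fn in reverse.
    order, seen = [], set()
    def visit(node):
        if node not in seen:
            seen.add(node)
            for parent in node.parents:
                visit(parent)
            order.append(node)
    visit(output)
    output.grad = 1.0  # d(output)/d(output)
    for node in reversed(order):
        if node.grad_fn is not None:
            for parent, contribution in node.grad_fn(node.grad):
                parent.grad += contribution  # accumulate: a node can feed several consumers

# Usage: y = x*x + x, so dy/dx = 2x + 1 = 7 at x = 3.
x = Node(3.0)
y = add(mul(x, x), x)
backward(y)
print(y.value, x.grad)  # 12.0 7.0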

Principles, not prescriptions

  • Choose reverse mode for many-parameter, single-scalar-loss problems.
  • Preserve numerical stability by using stable, fused primitives (e.g., a shifted logsumexp) rather than composing naive ones; see the sketch after this list.
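
For example, a naive log-sum-exp overflows where the standard shifted (fused) form does not; a small NumPy sketch:

import numpy as np

x = np.array([1000.0, 1000.5, 999.0])

naive = np.log(np.sum(np.exp(x)))                    # exp(1000) overflows to inf
shift = x.max()
stable = shift + np.log(np.sum(np.exp(x - shift)))   # same value, computed safely
print(naive, stable)                                 # inf vs roughly 1001.1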

Common pitfalls

  • Retaining computation graphs unnecessarily (e.g., accumulating loss tensors across iterations) makes memory blow up; see the sketch after this list.
  • Gradient leakage or silent zeros through non-differentiable operations and data-dependent control flow (e.g., rounding, argmax, hard branches).
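
A common instance of the first pitfall in PyTorch-style code: accumulating loss tensors keeps every iteration's graph alive, while .item() lets it be freed. A sketch assuming PyTorch is installed:

import torch

model = torch.nn.Linear(10, 1)
batches = [torch.randn(32, 10) for _ in range(100)]

running = 0.0
for xb in batches:
    loss = model(xb).pow(2).mean()
    running += loss.item()  # a plain float: this step's graph can be freed
    # running += loss       # pitfall: keeps every step's graph and activations alive
print(running / len(batches))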

Connections and contrasts

  • See also: [/blog/attention-is-all-you-need] (training), [/blog/variational-inference] (reparameterization uses autodiff).

Quick checks

  1. When is forward mode cheaper? — Few inputs, many outputs.
  2. What is a VJP? — Product of a vector with the Jacobian transpose.
  3. Why checkpointing? — To reduce memory by recomputing intermediates.

Further reading

  • Autodiff survey (source above)
  • Framework docs (PyTorch, JAX)