Automatic Differentiation
Introduction
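Suppose we have a function $f : \mathbb{R}^n \to \mathbb{R}^m$. We assume this function is differentiable and is implemented in code, for instance:

```python
from math import sin

def func(x):
    return sin(x**3)
```

We might want the derivative of func with respect to x. Given this premise, there are two common strategies for automatically computing derivatives: forward mode and reverse mode. Each mode has its own set of trade-offs.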
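As a point of reference before automating anything, the derivative of func can be worked out by hand with the chain rule and checked against a finite difference. This is only a sketch: the helper names dfunc and central_difference, the evaluation point, and the step size h are illustrative assumptions.

```python
import math


def func(x):
    return math.sin(x**3)


def dfunc(x):
    # Chain rule by hand: d/dx sin(x^3) = cos(x^3) * 3x^2
    return math.cos(x**3) * 3 * x**2


def central_difference(f, x, h=1e-6):
    # Numerical approximation of the derivative, used only as a sanity check
    return (f(x + h) - f(x - h)) / (2 * h)


x = 0.7
analytic = dfunc(x)
numeric = central_difference(func, x)
print(analytic, numeric)
```

The two values should agree to several decimal places. Automatic differentiation aims to produce the analytic value without the derivative being written by hand and without the truncation error of finite differences.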
The Jacobian
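Let $f : \mathbb{R}^n \to \mathbb{R}^m$ be differentiable at $x$. Its Jacobian $J_f(x)$ is the $m \times n$ matrix of partial derivatives, with entries $[J_f(x)]_{ij} = \partial f_i / \partial x_j$.
Interpreting the derivative as the best linear approximation of the function near $x$, the Jacobian maps tangent vectors from the input tangent space to the output tangent space: $v \mapsto J_f(x)\,v$. This map is known as the pushforward.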
The transpose of the Jacobian matrix also has significance. It pulls cotangent vectors from the output cotangent space back to the input. As you might guess, this is known as the pullback and will reappear when we deal with reverse mode AD and Vector-Jacobian products.
Although cotangent vectors are referenced here and later in code, this article does not require a precise understanding of what they are in differential geometric terms. For our purposes, it's sufficient to view them operationally as vectors that are pulled back through the transpose of the Jacobian.
The Gradient
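Throughout this article, vectors are treated as column vectors unless otherwise stated. With this convention, the Jacobian may also be written in terms of the gradients of the components of $f$, with each row being a transposed gradient:

$$J_f(x) = \begin{bmatrix} \nabla f_1(x)^\top \\ \vdots \\ \nabla f_m(x)^\top \end{bmatrix}$$

For the special case of scalar-valued functions ($m = 1$), the Jacobian is a single row, namely the transpose of the gradient: $J_f(x) = \nabla f(x)^\top$.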
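To make the pushforward and pullback concrete, here is a small numerical sketch. The function f, the point x, and the vectors named tangent and cotangent are arbitrary choices made for illustration, and the Jacobian is written out by hand from the component gradients.

```python
import numpy as np


def f(x):
    # An example map from R^3 to R^2
    return np.array([x[0] * x[1], np.sin(x[2]) + x[0]])


def jacobian(x):
    # Each row is the transposed gradient of one output component
    grad_f0 = np.array([x[1], x[0], 0.0])
    grad_f1 = np.array([1.0, 0.0, np.cos(x[2])])
    return np.stack([grad_f0, grad_f1])


x = np.array([0.5, -1.2, 2.0])
J = jacobian(x)

tangent = np.array([1.0, 2.0, -1.0])  # lives in the input space
cotangent = np.array([0.3, -0.7])     # lives in the output space

pushforward = J @ tangent    # shape (2,): a directional derivative
pullback = J.T @ cotangent   # shape (3,): the cotangent pulled back

# Finite-difference check of the directional derivative
h = 1e-6
fd = (f(x + h * tangent) - f(x - h * tangent)) / (2 * h)
print(pushforward, fd)
```

The pairing between the two spaces is preserved: the dot product of cotangent with pushforward equals the dot product of pullback with tangent, which is exactly what "transpose" means here.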
Forward Mode
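Consider the following function, written as a composition of three functions:

$$y = f(g(h(x)))$$

By the chain rule, we have:

$$\frac{\partial y}{\partial x} = \frac{\partial f}{\partial g}\,\frac{\partial g}{\partial h}\,\frac{\partial h}{\partial x}$$

Or, equivalently, in terms of Jacobians:

$$J(x) = J_f(g(h(x)))\; J_g(h(x))\; J_h(x)$$

The core idea of forward mode is to begin with a tangent vector $v$ in the input space and evaluate the product right to left, so that every step is a matrix-vector product:

$$J(x)\,v = J_f(g(h(x)))\,\bigl(J_g(h(x))\,\bigl(J_h(x)\,v\bigr)\bigr)$$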
Observe the following about the expression above:
- Forward mode computes Jacobian-Vector Products (commonly abbreviated as JVP) in step with the function evaluation itself, accumulating outwards. The intermediate full Jacobian matrices do not need to be explicitly computed. As an example, consider the Jacobian of an elementwise operation, say, $y = \sin(x)$ where $x \in \mathbb{R}^n$. Its Jacobian is an $n \times n$ diagonal matrix: $J = \operatorname{diag}(\cos x_1, \ldots, \cos x_n)$. However, to compute its JVP, you can simply evaluate the elementwise product $\cos(x) \odot v$.
- Initializing the tangent vector $v$ as a one-hot vector (e.g. $[1, 0, \ldots, 0]^\top$) yields the corresponding partial derivative (a single column in the final Jacobian matrix). More generally, forward mode propagates an arbitrary vector to compute a directional derivative, of which partial derivatives are the special case where the direction is aligned with a coordinate axis.
- Building on the observation above, you could compute the full Jacobian matrix $J$ one column at a time, iterating over each individual one-hot basis vector.
This begins to reveal a shortcoming of forward mode: constructing the full Jacobian is expensive when the input dimension $n$ is large, since each column requires its own forward pass.
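The observations above can be checked numerically. The sketch below compares the JVP of an elementwise sin computed through an explicit diagonal Jacobian with the direct elementwise product, and confirms that a one-hot tangent selects a single Jacobian column. The array size and random inputs are arbitrary illustrative choices.

```python
import numpy as np

rng = np.random.default_rng(0)
n = 5
x = rng.random(n)
v = rng.random(n)

# Explicit Jacobian of elementwise sin: an n x n diagonal matrix
full_jacobian = np.diag(np.cos(x))

# JVP through the explicit matrix: O(n^2) storage and work
jvp_via_matrix = full_jacobian @ v

# JVP computed directly: O(n), no matrix is ever formed
jvp_direct = np.cos(x) * v

# A one-hot tangent picks out one column of the Jacobian
one_hot = np.eye(n)[0]
first_column = np.cos(x) * one_hot

print(np.allclose(jvp_via_matrix, jvp_direct))
```

Recovering all n columns this way takes n separate JVP evaluations, which is the cost that grows with the input dimension.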
Implementation
One way to implement forward mode AD in Python would be to wrap existing primitives (like, say, NumPy's functions) such that they operate on a new Dual type (named after dual numbers) defined as follows:
```python
class Dual(NamedTuple):
    primal: np.ndarray
    tangent: np.ndarray
```

The primal is just the output of the function evaluated at the given input. The tangent is the vector mentioned above, pushed forward by the function's Jacobian (also evaluated at the given input). For instance, for sin, the wrapped function would look like this:
```python
def sin(x: Dual) -> Dual:
    return Dual(
        # Evaluate the function as usual
        primal=np.sin(x.primal),
        # JVP
        tangent=np.cos(x.primal) * x.tangent
    )
```

These primitives compose correctly according to the chain rule, and the tangent value at the end of evaluation gives the desired directional derivative.
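As a minimal illustration of that composition, here is a scalar sketch that differentiates the introductory func (sin(x**3)) by chaining two Dual primitives. It uses plain floats instead of NumPy arrays, and the cube primitive is a hypothetical helper introduced only for this example.

```python
import math
from typing import NamedTuple


class Dual(NamedTuple):
    primal: float
    tangent: float


def cube(x: Dual) -> Dual:
    # d/dx x^3 = 3x^2, applied to the incoming tangent
    return Dual(x.primal ** 3, 3 * x.primal ** 2 * x.tangent)


def sin(x: Dual) -> Dual:
    # d/dx sin(x) = cos(x), applied to the incoming tangent
    return Dual(math.sin(x.primal), math.cos(x.primal) * x.tangent)


def func(x: Dual) -> Dual:
    return sin(cube(x))


# Seeding the tangent with 1.0 yields the ordinary derivative
y = func(Dual(primal=0.7, tangent=1.0))
expected = math.cos(0.7 ** 3) * 3 * 0.7 ** 2
print(y.tangent, expected)
```

Neither primitive knows about the other, yet the tangent that comes out of func matches the hand-derived chain rule.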
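To put the pieces above into a complete forward-mode AD example, consider the function $f(x) = W \sin(\pi x^2) + b$, where $W \in \mathbb{R}^{m \times n}$, $b \in \mathbb{R}^m$, $x \in \mathbb{R}^n$, and both the square and the sine are applied elementwise: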
```python
import numpy as np
from typing import NamedTuple


class Dual(NamedTuple):
    primal: np.ndarray
    tangent: np.ndarray


# Primitives required for our function
add = lambda x, y: Dual(x.primal + y.primal, x.tangent + y.tangent)
mul = lambda x, y: Dual(x.primal * y.primal, x.tangent * y.primal + x.primal * y.tangent)
matmul = lambda x, y: Dual(x.primal @ y.primal, x.tangent @ y.primal + x.primal @ y.tangent)
square = lambda x: Dual(x.primal ** 2, 2 * x.primal * x.tangent)
sin = lambda x: Dual(np.sin(x.primal), np.cos(x.primal) * x.tangent)

# Constants (their tangents are zero)
n, m = 4, 3
constant = lambda c: Dual(primal=c, tangent=np.zeros_like(c))
W = constant(np.random.rand(m, n))
b = constant(np.random.rand(m))
pi = constant(np.pi)

# The function we want to differentiate
f = lambda x: add(matmul(W, sin(mul(pi, square(x)))), b)

# Wrap the input in a Dual, with tangent set to the standard basis vector
# for the variable we want to differentiate with respect to.
x = Dual(
    primal=np.random.rand(n),
    tangent=np.array([1, 0, 0, 0])
)

# Compute the function value and its derivative at x
y = f(x)
print(f'f(x) = {y.primal}')
print(f'df/dx = {y.tangent}')
```

Using np.random.seed(0), the example above outputs:
```
f(x) = [1.59230569 2.13912743 1.24483847]
df/dx = [0.30036226 0.23186367 0.5274067]
```

The full Jacobian can be built one column at a time by iterating over the standard basis vectors:
```python
# Reuse the input point from above so the columns line up with df/dx
x_primal = x.primal
jacobian = np.stack(
    [
        f(x=Dual(primal=x_primal, tangent=basis)).tangent
        for basis in np.eye(n)
    ],
    axis=1
)
print(jacobian)
```

Which produces:
```
[[ 0.30036226  0.09085468 -1.79903433 -0.86712717]
 [ 0.23186367  0.0820517  -1.30604222 -1.41916768]
 [ 0.5274067   0.04871081 -2.36301769 -0.84168345]]
```

If you're familiar with JAX, this is equivalent to (and produces the same Jacobian as):
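```python
import jax
import jax.numpy as jnp

# Here W, b, and x are plain arrays rather than Dual instances
f = lambda x: W @ jnp.sin(jnp.pi * x ** 2) + b
jacobian = jnp.stack(
    [
        jax.jvp(f, [x], [basis])[1]
        for basis in jnp.eye(n)
    ],
    axis=1
)
```

Reverse Mode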
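Returning to the chain rule: instead of pushing a tangent vector forward as in forward mode, we can instead pull back a vector $v$ (this time a cotangent vector in the output space). For a composition $y = f(g(h(x)))$ with overall Jacobian $J(x) = J_f(g(h(x)))\,J_g(h(x))\,J_h(x)$, this gives:

$$v^\top J(x) = \bigl(\bigl(v^\top J_f(g(h(x)))\bigr)\,J_g(h(x))\bigr)\,J_h(x)$$

The evaluation here proceeds left to right, computing Vector-Jacobian Products (commonly abbreviated as VJP). However, notice that the Jacobians are evaluated at the function outputs computed right to left (as usual). To avoid redundant computation, these intermediate outputs are typically cached during the forward evaluation of the function.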
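The following sketch spells out that two-phase structure for a three-stage composition: a forward pass that caches intermediate outputs, followed by a reverse pass that consumes them. The specific stages (an elementwise square, an elementwise sin, and a matrix product), the shapes, and the variable names are illustrative assumptions.

```python
import numpy as np

rng = np.random.default_rng(0)
W = rng.random((3, 4))
x = rng.random(4)

# Forward pass: evaluate right to left, caching intermediate outputs
a = x ** 2       # h(x)
b = np.sin(a)    # g(a)
y = W @ b        # f(b)

# Cotangent seed: selects the first output component
cotangent = np.array([1.0, 0.0, 0.0])

# Reverse pass: apply VJPs left to right, reusing the cached values
ct_b = W.T @ cotangent      # through f, whose Jacobian is W
ct_a = ct_b * np.cos(a)     # through g, whose Jacobian is diag(cos(a))
ct_x = ct_a * 2 * x         # through h, whose Jacobian is diag(2x)

# Reference: form the full Jacobian explicitly and take its first row
J = W @ np.diag(np.cos(a)) @ np.diag(2 * x)
print(np.allclose(ct_x, J[0]))
```

No Jacobian matrix is formed in the reverse pass itself. The explicit product at the end exists only to confirm that a one-hot cotangent recovers a row of the Jacobian.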
Implementation
Suppose you're dealing with the following function:
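```python
import numpy as np

def func(x):
    v = x**2
    return np.exp(v) * np.sin(v)
```

Visualizing this as a computational graph:

[Figure: computational graph in which x feeds square to produce v, v feeds both exp and sin, and their two outputs feed mul]

For reverse mode, the propagation of the gradient via Jacobians starts at the output and works back toward the input. This leads to the following requirements: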
- Each node must know its parents (in order to propagate the gradient up to them).
- A node with multiple parents, like mul, needs to know how to compute the derivative with respect to each parent. In our implementation, we'll pair the parent node with its VJP function:

```python
# A function that performs the Vector-Jacobian product
type VJP = Callable[[np.ndarray], np.ndarray]
# Parent node and its VJP
type Parent = tuple[Node, VJP]
```

- A node must first accumulate gradients from all its children before propagating the accumulated gradient to its own parents. In practice, this can be satisfied by propagating gradients in topological order.
Similar to forward mode, we'll use wrapped functions that operate on a Node structure that captures the constraints above:
```python
class Node:
    def __init__(self, value: np.ndarray, *parents: Parent):
        self.value = value
        self.parents = parents
        self.grad = 0.
```

The value member caches the output of the function evaluation. This will be used for evaluating the Jacobian at the linearization point later (e.g. the VJP for sin evaluates np.cos at its parent's cached value).
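Before the full example, here is a scalar sketch of these constraints applied to the small func above (exp(v) * sin(v) with v = x**2), where the node v has two children and must collect both contributions before passing anything on to x. It uses plain floats and the math module rather than NumPy, and the names topological_order and backward are illustrative.

```python
import math


class Node:
    def __init__(self, value, *parents):
        self.value = value
        self.parents = parents  # pairs of (parent node, VJP function)
        self.grad = 0.0


def square(x):
    return Node(x.value ** 2, (x, lambda g: g * 2 * x.value))


def exp(x):
    out = math.exp(x.value)
    return Node(out, (x, lambda g: g * out))


def sin(x):
    return Node(math.sin(x.value), (x, lambda g: g * math.cos(x.value)))


def mul(x, y):
    return Node(
        x.value * y.value,
        (x, lambda g: g * y.value),
        (y, lambda g: g * x.value),
    )


def topological_order(node, visited, ordered):
    # Post-order traversal: every parent is listed before its children
    if node not in visited:
        visited.add(node)
        for parent, _ in node.parents:
            topological_order(parent, visited, ordered)
        ordered.append(node)
    return ordered


def backward(output):
    nodes = topological_order(output, set(), [])
    for node in nodes:
        node.grad = 0.0
    output.grad = 1.0
    # Reversed order guarantees a node is complete before it propagates
    for node in reversed(nodes):
        for parent, vjp in node.parents:
            parent.grad += vjp(node.grad)


x = Node(0.5)
v = square(x)
y = mul(exp(v), sin(v))
backward(y)

# Hand-derived: exp(v) * (sin(v) + cos(v)) * 2x with v = x^2
expected = math.exp(0.25) * (math.sin(0.25) + math.cos(0.25)) * 2 * 0.5
print(x.grad, expected)
```

The reversed topological order visits mul, then sin and exp, and only then v, so v.grad already holds the sum of both branches when it is pushed through square.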
Putting it together to evaluate the same function as before:
```python
import numpy as np
from typing import Callable

# Note: the `type` statement requires Python 3.12+
type VJP = Callable[[np.ndarray], np.ndarray]
type Parent = tuple[Node, VJP]


class Node:
    def __init__(self, value: np.ndarray, *parents: Parent):
        self.value = value
        self.parents = parents
        self.grad = 0.


# Primitives required for our function
def sin(x: Node) -> Node:
    return Node(
        np.sin(x.value),
        (x, lambda g: g * np.cos(x.value))
    )


def square(x: Node) -> Node:
    return Node(
        x.value ** 2,
        (x, lambda g: g * 2 * x.value)
    )


def mul(x: Node, y: Node) -> Node:
    return Node(
        x.value * y.value,
        (x, lambda g: y.value * g),
        (y, lambda g: x.value * g)
    )


def matmul(x: Node, y: Node) -> Node:
    return Node(
        x.value @ y.value,
        (x, lambda g: g @ y.value.T),
        (y, lambda g: x.value.T @ g)
    )


def add(x: Node, y: Node) -> Node:
    return Node(
        x.value + y.value,
        (x, lambda g: g),
        (y, lambda g: g)
    )


# Reverse-mode auto-diff helpers
def sort_topologically(node, visited, ordered):
    # Post-order traversal: parents are appended before their children
    if node not in visited:
        visited.add(node)
        for parent, _ in node.parents:
            sort_topologically(parent, visited, ordered)
        ordered.append(node)
    return ordered


def compute_gradient(output: Node, v):
    nodes = sort_topologically(output, set(), [])
    # Reset gradients
    for node in nodes:
        node.grad = 0
    output.grad = v
    for node in reversed(nodes):
        for parent, vjp in node.parents:
            parent.grad += vjp(node.grad)


# Function constants
n, m = 4, 3
W = Node(np.random.rand(m, n))
b = Node(np.random.rand(m, 1))
pi = Node(np.array(np.pi))


# The function we want to differentiate
def f(x):
    return add(matmul(W, sin(mul(pi, square(x)))), b)


# Forward pass
x = Node(np.random.rand(n, 1))
y = f(x)

# Reverse accumulation
compute_gradient(y, v=np.array([[1], [0], [0]]))
print(f'f(x): {y.value.squeeze()}')
print(f'Gradient: {x.grad.squeeze()}')
```

Which produces:
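```
f(x): [1.59230569 2.13912743 1.24483847]
Gradient: [ 0.30036226  0.09085468 -1.79903433 -0.86712717]
```

Notice that when compared to the forward mode output, this is the first row of the Jacobian rather than its first column. The full Jacobian can be built one row at a time by iterating over the standard basis vectors of the output space: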
```python
def compute_jacobian(input, output):
    rows = []
    for basis in np.eye(m):
        compute_gradient(output, v=basis[..., None])
        rows.append(input.grad.squeeze())
    return np.vstack(rows)


jacobian = compute_jacobian(input=x, output=y)
print(f'Jacobian:\n{jacobian}')
```

Which produces the same Jacobian as forward mode:
```
[[ 0.30036226  0.09085468 -1.79903433 -0.86712717]
 [ 0.23186367  0.0820517  -1.30604222 -1.41916768]
 [ 0.5274067   0.04871081 -2.36301769 -0.84168345]]
```

And as before, this is equivalent to the following in JAX:
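```python
import jax
import jax.numpy as jnp

# Here W, b, and x are plain arrays rather than Node instances
f = lambda x: W @ jnp.sin(jnp.pi * x ** 2) + b
_, vjp_fun = jax.vjp(f, x)
jacobian = jnp.stack(
    [
        vjp_fun(basis[..., None])[0].squeeze()
        for basis in jnp.eye(m)
    ]
)
```

Gradient Accumulation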
As a quick aside, if you're familiar with gradient accumulation in frameworks like PyTorch, it might be tempting to think that skipping this bit in compute_gradient:
```python
for node in nodes:
    node.grad = 0
```

might be equivalent to skipping the zero_grad call in PyTorch (which generally results in gradients being accumulated when, say, loss.backward() is called). However, that's not quite the case here. Notice that the stale gradients left on intermediate nodes would accidentally get fed into the VJPs of the next pass. That said, it'd take just a slight tweak to make gradient accumulation work as expected:
```python
# Separately track the gradients for this pass
gradients = {node: 0 for node in nodes}
gradients[output] = v
for node in reversed(nodes):
    for parent, vjp in node.parents:
        # Use only the gradients from this pass for propagation
        contrib = vjp(gradients[node])
        gradients[parent] += contrib
        # Also accumulate into parent.grad across passes
        parent.grad += contrib
```

Note that in this variant, the caller is responsible for zeroing out node.grad at the appropriate times to ensure correctness.
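The difference between the two behaviours is easy to see on a tiny scalar graph. The sketch below runs two backward passes over sin(x**2), once with the reset simply removed and once with the per-pass bookkeeping shown above. The function names backward_without_reset and backward_accumulating, and the scalar Node, are assumptions made for this illustration.

```python
import math


class Node:
    def __init__(self, value, *parents):
        self.value = value
        self.parents = parents
        self.grad = 0.0


def square(x):
    return Node(x.value ** 2, (x, lambda g: g * 2 * x.value))


def sin(x):
    return Node(math.sin(x.value), (x, lambda g: g * math.cos(x.value)))


def order(node, visited, ordered):
    if node not in visited:
        visited.add(node)
        for parent, _ in node.parents:
            order(parent, visited, ordered)
        ordered.append(node)
    return ordered


def backward_without_reset(output, v):
    # The naive variant: the reset loop is skipped and nothing else changes
    nodes = order(output, set(), [])
    output.grad = v
    for node in reversed(nodes):
        for parent, vjp in node.parents:
            parent.grad += vjp(node.grad)


def backward_accumulating(output, v):
    # Per-pass gradients drive propagation; node.grad only accumulates
    nodes = order(output, set(), [])
    gradients = {node: 0.0 for node in nodes}
    gradients[output] = v
    for node in reversed(nodes):
        for parent, vjp in node.parents:
            contrib = vjp(gradients[node])
            gradients[parent] += contrib
            parent.grad += contrib


def build():
    x = Node(0.5)
    return x, sin(square(x))


true_grad = math.cos(0.25) * 2 * 0.5

x1, y1 = build()
backward_without_reset(y1, 1.0)
backward_without_reset(y1, 1.0)

x2, y2 = build()
backward_accumulating(y2, 1.0)
backward_accumulating(y2, 1.0)

print(x1.grad / true_grad, x2.grad / true_grad)
```

After two passes the naive variant ends up with three times the true gradient, because the intermediate node's stale value is propagated again on top of the fresh one. The accumulating variant yields exactly twice the true gradient, which is what accumulation is supposed to mean.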
VJP from JVP
Typically, frameworks pick reverse mode (like PyTorch) or forward mode (like Google's Ceres Solver) and write a specialized implementation around their particular flavor of AD. In contrast, JAX uses a neat trick that allows it to implement both forward and reverse modes without implementing per-mode primitives. This is described in the paper You Only Linearize Once, and more concisely in their earlier Decomposing reverse-mode automatic differentiation.
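The general idea is alluded to in the paper's subtitle: tangents transpose to gradients. As an example, consider the matmul primitive. Recall that for a linear map $L$, its transpose $L^\top$ is characterized by:

$$\langle L\,u, w \rangle = \langle u, L^\top w \rangle$$

Where $\langle \cdot, \cdot \rangle$ denotes the inner product (for matrices, $\langle A, B \rangle = \operatorname{tr}(A^\top B)$). The JVP of $Z = XY$ is the linear map $(\dot{X}, \dot{Y}) \mapsto \dot{X} Y + X \dot{Y}$,
and so, we have:

$$\langle \dot{X} Y + X \dot{Y},\ \bar{Z} \rangle = \langle \dot{X},\ \bar{Z}\, Y^\top \rangle + \langle \dot{Y},\ X^\top \bar{Z} \rangle$$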
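The transposition relationship for matmul can be verified numerically with what is sometimes called a dot-product test: pairing the JVP output with an output cotangent should give the same number as pairing the input tangents with the corresponding VJP outputs. The shapes, the random inputs, and the helper name inner are illustrative assumptions.

```python
import numpy as np

rng = np.random.default_rng(0)
X, Y = rng.random((3, 4)), rng.random((4, 2))
dX, dY = rng.random((3, 4)), rng.random((4, 2))  # input tangents
Z_bar = rng.random((3, 2))                       # output cotangent


def inner(A, B):
    # Frobenius inner product, i.e. trace(A.T @ B)
    return np.sum(A * B)


# Forward mode: the JVP of Z = X @ Y
jvp_out = dX @ Y + X @ dY

# Reverse mode: the VJPs with respect to X and Y
X_bar = Z_bar @ Y.T
Y_bar = X.T @ Z_bar

lhs = inner(jvp_out, Z_bar)
rhs = inner(dX, X_bar) + inner(dY, Y_bar)
print(lhs, rhs)
```

The two sides agree up to floating-point rounding for any choice of tangents and cotangent, which is the defining property of a transpose pair.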
And comparing the JVP for matmul to the VJP, you can observe this transposition:
```python
# Forward mode JVP
lambda x, y: Dual(..., x.tangent @ y.primal + x.primal @ y.tangent)

# Reverse mode VJP
def matmul(x: Node, y: Node) -> Node:
    return Node(
        ...,
        (x, lambda g: g @ y.value.T),
        (y, lambda g: x.value.T @ g)
    )
```

So you might wonder: is it possible to automatically arrive at the VJP if you have a JVP primitive? As shown in the paper (and done in practice in JAX), the answer is yes. You do this by proceeding in two steps, starting with the forward mode JVP primitives:
- Linearize. Evaluate the primal as usual, but defer fully evaluating the JVP. Instead capture its structure, say, in some intermediate representation (IR). For instance, for sin, you might do something like this:

```python
def sin(x: Dual) -> Dual:
    return Dual(
        # Evaluate the primal as usual
        np.sin(x.primal),
        # Partially evaluate the JVP, capture the rest in an IR.
        # Our IR is a tuple of the form (<kind>, ...args)
        (
            # Scale
            '*',
            # Scaling factor
            (
                # A "closed over" constant
                'const',
                # Partially evaluated term that depends on the primal
                np.cos(x.primal)
            ),
            x.tangent
        )
    )
```

Instead of fully evaluating the JVP, we've only evaluated the portion that depends on the primal, and captured it as a linear function of the tangent (notice that the cos(x.primal) portion is treated as a constant). The JAX paper considers this to be a form of partial evaluation where just the primal is known.
- Transpose. Evaluating the IR as-is yields the JVP. Transposing it during evaluation yields the VJP. For instance, if we encounter a matmul expression in the IR, we know it's either of the form $A x$ or $x A$, where $A$ is a constant (since the matmul must be linear in $x$). And depending on the case, the transposition will yield $A^\top c$ or $c\,A^\top$, where $c$ is the incoming cotangent.
Putting it all together, we have:
```python
from __future__ import annotations
from typing import Literal, NamedTuple
import numpy as np

# Note: the `type` statement requires Python 3.12+
type Expr = (
    # Input variable
    tuple[Literal['input']] |
    # Closed-over constant
    # Only appears as a coefficient (e.g. inside '*' or '@'),
    # never as a standalone tangent expression.
    tuple[Literal['const'], np.ndarray] |
    # Zero tangent (with constant and input shapes)
    tuple[Literal['zero'], tuple[int, ...], tuple[int, ...]] |
    # Binary Operations
    tuple[Literal['+', '*', '@'], Expr, Expr]
)


class Dual(NamedTuple):
    primal: np.ndarray
    tangent: Expr


def add(x: Dual, y: Dual) -> Dual:
    return Dual(
        x.primal + y.primal,
        ('+', x.tangent, y.tangent)
    )


def mul(x: Dual, y: Dual) -> Dual:
    return Dual(
        x.primal * y.primal,
        (
            '+',
            ('*', ('const', y.primal), x.tangent),
            ('*', ('const', x.primal), y.tangent)
        )
    )


def square(x: Dual) -> Dual:
    return Dual(
        x.primal ** 2,
        ('*', ('const', 2.0 * x.primal), x.tangent)
    )


def sin(x: Dual) -> Dual:
    return Dual(
        np.sin(x.primal),
        ('*', ('const', np.cos(x.primal)), x.tangent)
    )


def matmul(x: Dual, y: Dual) -> Dual:
    return Dual(
        x.primal @ y.primal,
        (
            '+',
            ('@', x.tangent, ('const', y.primal)),
            ('@', ('const', x.primal), y.tangent)
        )
    )


def jvp(expr: Expr, tangent: np.ndarray) -> np.ndarray:
    match expr:
        case ('input',):
            return tangent
        case ('zero', constant_shape, _):
            return np.zeros(constant_shape)
        case ('+', x, y):
            return jvp(x, tangent) + jvp(y, tangent)
        case ('*', ('const', c), x) | ('*', x, ('const', c)):
            return c * jvp(x, tangent)
        case ('@', ('const', A), x):
            return A @ jvp(x, tangent)
        case ('@', x, ('const', A)):
            return jvp(x, tangent) @ A
    raise ValueError(f'Unknown linear expression: {expr}')


def vjp(expr: Expr, cotangent: np.ndarray) -> np.ndarray:
    match expr:
        case ('input',):
            return cotangent
        case ('zero', _, input_shape):
            return np.zeros(input_shape)
        case ('+', x, y):
            return vjp(x, cotangent) + vjp(y, cotangent)
        case ('*', ('const', c), x) | ('*', x, ('const', c)):
            return vjp(x, c * cotangent)
        case ('@', ('const', A), x):
            return vjp(x, A.T @ cotangent)
        case ('@', x, ('const', A)):
            return vjp(x, cotangent @ A.T)
    raise ValueError(f'Unknown linear expression: {expr}')


# Define our constants
n, m = 4, 3
constant = lambda c: Dual(c, ('zero', c.shape, (n, 1)))
W = constant(np.random.random((m, n)))
b = constant(np.random.random((m, 1)))
pi = constant(np.array(np.pi))


# The function we want to differentiate
def f(x):
    return add(matmul(W, sin(mul(pi, square(x)))), b)


# Compute function output + linearized representation
x = np.random.random((n, 1))
y, f_lin = f(Dual(primal=x, tangent=('input',)))

# Forward mode (JVP)
d_fwd = jvp(f_lin, tangent=np.eye(4)[:, :1])

# Reverse mode (VJP)
d_rev = vjp(f_lin, cotangent=np.eye(3)[:, :1])
```

When we evaluate our function f, the output tangent is f_lin, which captures the partially evaluated JVP computation expressed in our IR. The jvp function takes this IR along with the input tangent and fully evaluates it.
[Figure: order of evaluation for jvp over f_lin]
For reverse mode, the vjp function accepts the same f_lin along with a cotangent and evaluates the VJP by using a transposed version of the JVP.
[Figure: order of evaluation for vjp over f_lin]
Note that vjp propagates cotangents backward through the computation as it recurses (e.g., vjp(x, cotangent @ A.T)), whereas jvp evaluates the innermost JVP first (e.g., jvp(x, tangent) @ A). This difference in evaluation order determines the structure of the graphs.
Shared Primitives. This route allows you to implement both forward and reverse mode using shared primitives. Now, you do need an additional set of "transposition rules". In practice, a framework will likely define far more primitives than transposition rules. For instance, in our toy example, you can add the cos primitive without having to add any new transposition rules.
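To illustrate that point, the sketch below adds a cos primitive to a stripped-down version of the IR. Only a JVP rule is written for cos. The jvp and vjp evaluators are untouched and handle it through the existing '*' rule. This is a reduced, self-contained variant with just the 'input' and '*' expression kinds, and it uses plain if statements instead of match so that it also runs on older Python versions.

```python
import numpy as np
from typing import NamedTuple


class Dual(NamedTuple):
    primal: np.ndarray
    tangent: tuple  # a linear expression in the toy IR


def sin(x: Dual) -> Dual:
    return Dual(np.sin(x.primal), ('*', ('const', np.cos(x.primal)), x.tangent))


def cos(x: Dual) -> Dual:
    # New primitive: only its JVP is specified, as a scaling by -sin(x)
    return Dual(np.cos(x.primal), ('*', ('const', -np.sin(x.primal)), x.tangent))


def jvp(expr, tangent):
    if expr[0] == 'input':
        return tangent
    if expr[0] == '*':
        _, (_, c), x = expr
        return c * jvp(x, tangent)
    raise ValueError(f'Unknown linear expression: {expr}')


def vjp(expr, cotangent):
    if expr[0] == 'input':
        return cotangent
    if expr[0] == '*':
        _, (_, c), x = expr
        # Elementwise scaling is its own transpose
        return vjp(x, c * cotangent)
    raise ValueError(f'Unknown linear expression: {expr}')


x = np.array([0.1, 0.2, 0.3])
y, lin = cos(sin(Dual(x, ('input',))))

seed = np.array([1.0, 0.0, 0.0])
d_fwd = jvp(lin, seed)
d_rev = vjp(lin, seed)

# Hand-derived derivative of cos(sin(x)), applied elementwise
expected = -np.sin(np.sin(x)) * np.cos(x) * seed
print(d_fwd, d_rev)
```

Because cos linearizes to the same kind of IR node as sin, reverse mode comes for free once the forward rule exists.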
Caching and Checkpointing. When we previously implemented reverse mode, we explicitly cached the primal values (in Node.value) for reuse later. The equivalent here is handled by the ('const', ...) nodes in the IR: they capture the dependence on the primal so it doesn't have to be re-computed. As before, however, it's worth noting that this doesn't necessarily have to be the case. For instance, the IR can be tweaked to accommodate recomputing the constants rather than storing them (e.g. to implement gradient/activation checkpointing).
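One way such a tweak might look is sketched below. The 'recompute' node kind is hypothetical (it is not part of the IR defined earlier and is not how JAX implements checkpointing): instead of storing the coefficient array, it stores a zero-argument function that rebuilds the coefficient when the VJP is evaluated. Note that the closure still holds on to x, so this trades storing cos(x) for keeping x and redoing the computation.

```python
import numpy as np


def sin_stored(x):
    # Linearize sin at x, storing the coefficient cos(x) in the IR
    return np.sin(x), ('*', ('const', np.cos(x)), ('input',))


def sin_recomputed(x):
    # Hypothetical variant: store a thunk that recomputes cos(x) on demand
    return np.sin(x), ('*', ('recompute', lambda: np.cos(x)), ('input',))


def coefficient(node):
    kind, payload = node
    if kind == 'const':
        return payload
    if kind == 'recompute':
        # Rebuild the coefficient only when it is actually needed
        return payload()
    raise ValueError(f'Unknown coefficient node: {kind}')


def vjp(expr, cotangent):
    if expr[0] == 'input':
        return cotangent
    if expr[0] == '*':
        _, coeff, x = expr
        return vjp(x, coefficient(coeff) * cotangent)
    raise ValueError(f'Unknown linear expression: {expr}')


x = np.array([0.1, 0.2, 0.3])
cotangent = np.ones(3)

_, lin_stored = sin_stored(x)
_, lin_recomputed = sin_recomputed(x)

grad_stored = vjp(lin_stored, cotangent)
grad_recomputed = vjp(lin_recomputed, cotangent)
print(grad_stored, grad_recomputed)
```

Both variants produce the same gradient. They differ only in when the coefficient is computed and in what is kept alive between the forward and reverse passes.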
Aside on Adjoints
The general idea here, of taking the transpose (or adjoint) of the linear bits to transform between "forward" and "reverse" variants of an algorithm predates JAX. Dan Piponi's paper Two Tricks for the Price of One from 2009 (cited in the JAX paper) applies the same technique for spatially varying convolutions and briefly discusses it in the context of AD. Piponi also references the idea in his earlier blog posts from 2005 on adjoints and AD. Following the breadcrumbs from there, Jon Claerbout has interesting observations about adjoints as back projections, as does Carlos Scheidegger in his article Adjoints and Inverses (related ideas are also explored in this video by Sam Levey).
See Also
- Matt Johnson's Auto Differentiation Video from Deep Learning Summer School 2017
- Matt Johnson's Autodidact, a pedagogical implementation of Autograd
- Dougal Maclaurin's PhD thesis also covers AutoGrad and has a nice concise intro to AD
- Autodidax goes into how to build an interpreter-based AD system, similar to JAX, from scratch
- JAX's guides on Advanced Automatic Differentiation
- JAX's AutoDiff Cookbook (has overlaps with the JAX advanced AD guides)
- Simon Peyton Jones' keynote talk at ECOOP 2017: Auto Differentiation for Dummies
- James Townsend's neat trick for calculating JVPs by composing two VJPs
- Evaluating Derivatives by Griewank and Walther