Automatic Differentiation

Introduction

Suppose we have a function and we want to compute its derivatives automatically. Not symbolically, and not numerically using finite differences, but via automatic differentiation.

We assume this function is composed of primitives with known derivatives. For instance, consider the following Python function:

from numpy import sin

def func(x):
    return sin(x**3)

We might consider $x^3$ and $\sin$ as the primitives with known derivatives. In this scenario, the chain rule tells us how to compute the derivative of func with respect to x: $\frac{d}{dx}\sin(x^3) = \cos(x^3) \cdot 3x^2$. Given this premise, there are two common strategies for automatically computing derivatives: forward and reverse mode. Each mode has its own set of trade-offs.

The Jacobian

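Let $f: \mathbb{R}^n \to \mathbb{R}^m$ have component functions $f_1, \dots, f_m$. The Jacobian of $f$ at $x$, denoted $J_f(x)$, is the $m \times n$ matrix of partial derivatives:

$$J_f(x) = \begin{bmatrix} \frac{\partial f_1}{\partial x_1}(x) & \cdots & \frac{\partial f_1}{\partial x_n}(x) \\ \vdots & \ddots & \vdots \\ \frac{\partial f_m}{\partial x_1}(x) & \cdots & \frac{\partial f_m}{\partial x_n}(x) \end{bmatrix}$$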

Interpreting the derivative as the best linear approximation of the function at a point $x$, small perturbations around $x$ can be viewed as vectors of the tangent space at $x$, $T_x\mathbb{R}^n$. Restricting ourselves to Euclidean space, $T_x\mathbb{R}^n$ is effectively just (isomorphic to) $\mathbb{R}^n$. The Jacobian can therefore be viewed as a linear map $T_x\mathbb{R}^n \to T_{f(x)}\mathbb{R}^m$, pushing input tangent vectors forward to their corresponding output vectors: $v \mapsto J_f(x)\,v$. This mapping is known as the pushforward and will resurface in the discussion of forward mode AD and Jacobian-Vector products.

The transpose of the Jacobian matrix also has significance. It pulls cotangent vectors from the output cotangent space back to the input: $u \mapsto J_f(x)^\top u$. As you might guess, this is known as the pullback and will reappear when we deal with reverse mode AD and Vector-Jacobian products.
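To make the pushforward and pullback concrete, here is a small sketch with a hand-written Jacobian for an arbitrary example $f: \mathbb{R}^2 \to \mathbb{R}^3$ (chosen only for illustration). It pushes a tangent forward with $J$, pulls a cotangent back with $J^\top$, and checks that the two agree in the sense that $u \cdot (Jv) = (J^\top u) \cdot v$.

```python
import numpy as np

def f(x):
    # An arbitrary example f: R^2 -> R^3
    return np.array([x[0] * x[1], np.sin(x[0]), x[1] ** 2])

def jacobian_f(x):
    # Each row is the gradient of one component function
    return np.array([
        [x[1], x[0]],
        [np.cos(x[0]), 0.0],
        [0.0, 2 * x[1]],
    ])

x = np.array([0.5, -1.2])
J = jacobian_f(x)

v = np.array([1.0, 2.0])        # tangent vector at x
u = np.array([0.3, -0.4, 1.5])  # cotangent vector at f(x)

pushforward = J @ v  # tangent at f(x)
pullback = J.T @ u   # cotangent at x

# Pairing the pushed-forward tangent with u gives the same number
# as pairing the pulled-back cotangent with v
print(np.isclose(u @ pushforward, pullback @ v))  # True
```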

Although cotangent vectors are referenced here and later in code, this article does not require a precise understanding of what they are in differential geometric terms. For our purposes, it's sufficient to view them operationally as vectors that are pulled back through the transpose of the Jacobian.

The Gradient

Throughout this article, vectors are treated as column vectors unless otherwise stated. With this convention, the Jacobian may also be written in terms of the gradients of the components of $f$:

$$J_f(x) = \begin{bmatrix} \nabla f_1(x)^\top \\ \vdots \\ \nabla f_m(x)^\top \end{bmatrix}$$

For the special case of scalar-valued functions ($m = 1$), $J_f(x) = \nabla f(x)^\top$: the Jacobian is just the transpose of the gradient. This is the common case in deep learning, where $n$ is the number of trainable parameters (potentially in the millions or greater) and the output is a scalar loss value (e.g. mean negative log likelihood).

Forward Mode

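Consider the following function that's a composition of other functions:

$$f(x) = f_3(f_2(f_1(x)))$$

with intermediate values $x_1 = f_1(x)$ and $x_2 = f_2(x_1)$. By the chain rule, we have:

$$\frac{\partial f}{\partial x} = \frac{\partial f_3}{\partial x_2}\,\frac{\partial f_2}{\partial x_1}\,\frac{\partial f_1}{\partial x}$$

Or, equivalently, in terms of Jacobians:

$$J_f(x) = J_{f_3}(x_2)\,J_{f_2}(x_1)\,J_{f_1}(x)$$

The core idea of forward mode is to begin with a tangent vector $v$ (a small perturbation direction) at the input and propagate it forward through the computation, applying each primitive operation's Jacobian as the function is evaluated (right to left):

$$J_f(x)\,v = J_{f_3}(x_2)\,\Big(J_{f_2}(x_1)\,\big(J_{f_1}(x)\,v\big)\Big)$$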

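Observe the following about the expression above: each parenthesized product is a Jacobian-vector product (JVP), so every intermediate result is a vector and no product of Jacobian matrices is ever formed. Moreover, a primitive's Jacobian can be large yet highly structured. For instance, applying $\sin$ elementwise to an $n$-vector $x$ has an $n \times n$ diagonal Jacobian, $\operatorname{diag}(\cos x)$. However, to compute its JVP, you can simply evaluate $\cos(x) \odot v$ elementwise, and avoid ever materializing the full matrix.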
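A quick numerical check of this point, as a sketch with arbitrary random inputs: the dense route builds an $n \times n$ matrix, while the elementwise route touches only $n$ numbers, yet both give the same JVP for elementwise $\sin$.

```python
import numpy as np

rng = np.random.default_rng(0)
n = 5
x = rng.standard_normal(n)  # point at which sin is linearized
v = rng.standard_normal(n)  # input tangent vector

# Materializing the n x n diagonal Jacobian of elementwise sin
J_sin = np.diag(np.cos(x))
jvp_dense = J_sin @ v

# The same JVP, computed elementwise without building the matrix
jvp_elementwise = np.cos(x) * v

print(np.allclose(jvp_dense, jvp_elementwise))  # True
print(J_sin.size, x.size)                       # 25 5
```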

This begins to reveal a shortcoming of forward mode. Each pass with $v$ set to a standard basis vector $e_i$ yields a single column $J_f(x)\,e_i$, so constructing the full Jacobian takes $n$ passes, which is expensive when the input dimension $n$ is large. In a real-world example such as a large language model, $x$ may represent a weight matrix with millions of parameters, while the training loss is a scalar. Forward mode then computes the gradient one parameter at a time.

Implementation

One way to implement forward mode AD in Python would be to wrap existing primitives (like, say, NumPy's functions) such that they operate on a new Dual type (named after dual numbers) defined as follows:

import numpy as np
from typing import NamedTuple

class Dual(NamedTuple):
    primal: np.ndarray
    tangent: np.ndarray

The primal is just the output of the function evaluated at the given input. The tangent is the vector mentioned above, pushed forward by the function's Jacobian (also evaluated at the given input). For instance, for sin, the wrapped function would look like this:

def sin(x: Dual) -> Dual:
    return Dual(
        # Evaluate the function as usual
        primal=np.sin(x.primal),
        # JVP
        tangent=np.cos(x.primal) * x.tangent
    )

These primitives compose correctly according to the chain rule, and the tangent value at the end of evaluation gives the desired directional derivative.
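The Dual type takes its name from dual numbers, which extend the reals with an element $\epsilon$ satisfying $\epsilon^2 = 0$, so that $g(a + b\epsilon) = g(a) + g'(a)\,b\,\epsilon$. As a scalar illustration (not the vectorized implementation used in this article), here is a minimal sketch using operator overloading; the names DualNumber and dual_sin are hypothetical. It differentiates func from the introduction.

```python
import math
from dataclasses import dataclass


@dataclass(frozen=True)
class DualNumber:
    """A scalar dual number a + b*eps, where eps**2 == 0."""
    a: float  # primal part
    b: float  # tangent (eps) part

    def __add__(self, other: "DualNumber") -> "DualNumber":
        return DualNumber(self.a + other.a, self.b + other.b)

    def __mul__(self, other: "DualNumber") -> "DualNumber":
        # (a + b eps)(c + d eps) = ac + (ad + bc) eps, since eps**2 == 0
        return DualNumber(self.a * other.a, self.a * other.b + self.b * other.a)

    def __pow__(self, k: int) -> "DualNumber":
        # Power rule for an integer exponent k >= 1
        return DualNumber(self.a ** k, k * self.a ** (k - 1) * self.b)


def dual_sin(x: DualNumber) -> DualNumber:
    # sin(a + b eps) = sin(a) + cos(a) b eps
    return DualNumber(math.sin(x.a), math.cos(x.a) * x.b)


def func(x: DualNumber) -> DualNumber:
    return dual_sin(x ** 3)


# Seeding the tangent part with 1.0 yields d/dx func at x = 0.7
out = func(DualNumber(0.7, 1.0))
print(out.a)  # sin(0.7**3)
print(out.b)  # cos(0.7**3) * 3 * 0.7**2
```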

To put the pieces above into a complete forward-mode AD example, consider the function $f: \mathbb{R}^n \to \mathbb{R}^m$:

$$f(x) = W \sin(\pi x^2) + b$$

where $W$ is an $m \times n$ matrix and $b$ is an $m$-vector, and the sine and square operations are applied elementwise.
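For reference, applying the chain rule by hand to $f(x) = W\sin(\pi x^2) + b$ gives the Jacobian that the code should reproduce, which makes a handy check (here $\odot$ is the elementwise product and $\operatorname{diag}$ builds a diagonal matrix from a vector):

$$J_f(x) = W\,\operatorname{diag}\big(\cos(\pi x^2)\big)\,\operatorname{diag}(2\pi x) = W\,\operatorname{diag}\big(2\pi x \odot \cos(\pi x^2)\big)$$

so column $i$ of the Jacobian is column $i$ of $W$ scaled by $2\pi x_i \cos(\pi x_i^2)$.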

import numpy as np
from typing import NamedTuple
 
class Dual(NamedTuple):
    primal: np.ndarray
    tangent: np.ndarray
 
# Primitives required for our function
add = lambda x, y: Dual(x.primal + y.primal, x.tangent + y.tangent)
mul = lambda x, y: Dual(x.primal * y.primal, x.tangent * y.primal + x.primal * y.tangent)
matmul = lambda x, y: Dual(x.primal @ y.primal, x.tangent @ y.primal + x.primal @ y.tangent)
square = lambda x: Dual(x.primal ** 2, 2 * x.primal * x.tangent)
sin = lambda x: Dual(np.sin(x.primal), np.cos(x.primal) * x.tangent)
 
# Constants
n, m = 4, 3
constant = lambda c: Dual(primal=c, tangent=np.zeros_like(c))
W = constant(np.random.rand(m, n))
b = constant(np.random.rand(m))
pi = constant(np.pi)
 
# The function we want to differentiate
f = lambda x: add(matmul(W, sin(mul(pi, square(x)))), b)
 
# Wrap the input in a Dual, with tangent set to the standard basis vector
# for the variable we want to differentiate with respect to.
x = Dual(
    primal=np.random.rand(n),
    tangent=np.array([1, 0, 0, 0])
)
 
# Compute the function value and its derivative at x
y = f(x)
print(f'f(x) = {y.primal}')
print(f'df/dx = {y.tangent}')

Using np.random.seed(0), the example above outputs:

f(x) = [1.59230569 2.13912743 1.24483847]
df/dx = [0.30036226 0.23186367 0.5274067]

The full Jacobian can be built one column at a time by iterating over the standard basis vectors:

x_primal = x.primal  # Reuse the same input point as above
jacobian = np.stack(
    [
        f(x=Dual(primal=x_primal, tangent=basis)).tangent
        for basis in np.eye(n)
    ],
    axis=1
)
print(jacobian)

Which produces:

[[ 0.30036226  0.09085468 -1.79903433 -0.86712717]
 [ 0.23186367  0.0820517  -1.30604222 -1.41916768]
 [ 0.5274067   0.04871081 -2.36301769 -0.84168345]]

If you're familiar with JAX, this is equivalent to (and produces the same Jacobian as):

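import jax
import jax.numpy as jnp

# Assumes W, b and x are plain arrays here (e.g. W.primal, b.primal, x.primal from above)
f = lambda x: W @ jnp.sin(jnp.pi * x ** 2) + b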
jacobian = jnp.stack(
    [
        jax.jvp(f, [x], [basis])[1]
        for basis in jnp.eye(n)
    ],
    axis=1
)

Reverse Mode

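Returning to the chain rule: rather than pushing a tangent vector forward as in forward mode, we can pull back a vector $u$ (an element of the cotangent space) at the output. For a composition $f(x) = f_3(f_2(f_1(x)))$ with intermediate values $x_1 = f_1(x)$ and $x_2 = f_2(x_1)$:

$$u^\top J_f(x) = \Big(\big(u^\top J_{f_3}(x_2)\big)\,J_{f_2}(x_1)\Big)\,J_{f_1}(x)$$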

The evaluation here proceeds left to right, computing Vector-Jacobian Products (commonly abbreviated as VJPs). However, notice that the Jacobians are still evaluated at the intermediate values (the inputs to each primitive), which are computed right to left (as usual). To avoid redundant computation, these intermediate values are typically cached during the forward evaluation of $f$, then reused when pulling the gradient back through the Jacobians.
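Here is a sketch of that bookkeeping for a simplified variant of the running example, $f(x) = W\sin(x^2)$ (the $\pi$ and bias are dropped for brevity), written as three explicit steps. The forward pass caches the intermediate values, and the backward pass reuses them while applying VJPs left to right. The result is compared against the transposed, fully materialized Jacobian.

```python
import numpy as np

rng = np.random.default_rng(0)
n, m = 4, 3
W = rng.random((m, n))
x = rng.random(n)

# Forward pass: evaluate right to left, caching intermediate values
x1 = x ** 2      # f1: elementwise square
x2 = np.sin(x1)  # f2: elementwise sine
y = W @ x2       # f3: linear map

# Backward pass: pull the cotangent u back left to right,
# reusing the cached x1 and x instead of recomputing them
u = np.array([1.0, 0.0, 0.0])
g2 = W.T @ u          # u^T J_f3
g1 = np.cos(x1) * g2  # ... times J_f2(x1) = diag(cos(x1))
g0 = 2 * x * g1       # ... times J_f1(x) = diag(2x)

# Compare against the materialized Jacobian J = J_f3 J_f2 J_f1
J = W @ np.diag(np.cos(x1)) @ np.diag(2 * x)
print(np.allclose(g0, J.T @ u))  # True
```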

Implementation

Suppose you're dealing with the following function:

import numpy as np

def func(x):
    v = x**2
    return np.exp(v) * np.sin(v)

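Visualizing this as a computational graph: x feeds a square node producing v, which fans out into both exp and sin, and their outputs meet at a final multiplication.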

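For reverse mode, the propagation of the gradient via Jacobians starts at the output and accumulates upwards. In terms of implementation, this has a few implications: each node must remember its parents along with a VJP for each of them, nodes have to be visited in reverse topological order (a node's gradient is complete only once all of its consumers have been processed), and a node with several consumers, such as v above, must sum their contributions.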

# A function that performs the Vector-Jacobian product
type VJP = Callable[[np.ndarray], np.ndarray]
# Parent node and its VJP
type Parent = tuple[Node, VJP]

Similar to forward mode, we'll use wrapped functions that operate on a Node structure that captures the constraints above:

class Node:
    def __init__(self, value: np.ndarray, *parents: Parent):
        self.value = value
        self.parents = parents
        self.grad = 0.

The value member caches the output of the function evaluation. This will be used later for evaluating the Jacobian at the linearization point (e.g. the $\cos(x)$ factor in the VJP of $\sin$). Strictly speaking, this is purely for efficiency: you could always recompute these values later, as is done in gradient checkpointing, a technique later adopted for training deep neural networks (where it is now more commonly referred to as activation checkpointing).
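As a rough illustration of that trade-off (a simplified sketch, not how production checkpointing is implemented), the hypothetical functions below compute the same gradient for a chain of elementwise primitives in two ways: grad_cached stores every intermediate value during the forward pass, while grad_recomputed keeps only the input and recomputes each intermediate when the backward pass needs it. This extreme form trades quadratic recomputation for minimal memory; practical schemes keep a subset of checkpoints.

```python
import numpy as np

# Hypothetical chain of elementwise primitives: (function, derivative) pairs
layers = [
    (np.sin, np.cos),
    (np.square, lambda v: 2 * v),
    (np.sin, np.cos),
]

def grad_cached(x):
    # Forward pass stores the input of every layer
    values = [x]
    for fn, _ in layers:
        values.append(fn(values[-1]))
    # Backward pass reuses the stored values
    g = np.ones_like(x)
    for (_, dfn), v in zip(reversed(layers), reversed(values[:-1])):
        g = g * dfn(v)
    return g

def grad_recomputed(x):
    # Only the input x is kept; each layer's input is recomputed on demand
    g = np.ones_like(x)
    for i in reversed(range(len(layers))):
        v = x
        for fn, _ in layers[:i]:
            v = fn(v)
        g = g * layers[i][1](v)
    return g

x = np.linspace(0.1, 1.0, 5)
print(np.allclose(grad_cached(x), grad_recomputed(x)))  # True
```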

Putting it together to evaluate the same $f(x) = W\sin(\pi x^2) + b$ as in the forward mode example:

import numpy as np
from typing import Callable
 
type VJP = Callable[[np.ndarray], np.ndarray]
type Parent = tuple[Node, VJP]
 
class Node:
    def __init__(self, value: np.ndarray, *parents: Parent):
        self.value = value
        self.parents = parents
        self.grad = 0.
 
# Primitives required for our function
 
def sin(x: Node) -> Node:
    return Node(
        np.sin(x.value),
        (x, lambda g: g * np.cos(x.value))
    )
 
def square(x: Node) -> Node:
    return Node(
        x.value ** 2,
        (x, lambda g: g * 2 * x.value)
    )
 
def mul(x: Node, y: Node) -> Node:
    return Node(
        x.value * y.value,
        (x, lambda g: y.value * g),
        (y, lambda g: x.value * g)
    )
 
def matmul(x: Node, y: Node) -> Node:
    return Node(
        x.value @ y.value,
        (x, lambda g: g @ y.value.T),
        (y, lambda g: x.value.T @ g)
    )
 
def add(x: Node, y: Node) -> Node:
    return Node(
        x.value + y.value,
        (x, lambda g: g),
        (y, lambda g: g)
    )
 
# Reverse-mode auto-diff helpers
 
def sort_topologically(node, visited, ordered):
    if node not in visited:
        visited.add(node)
        for parent, _ in node.parents:
            sort_topologically(parent, visited, ordered)
        ordered.append(node)
    return ordered
 
def compute_gradient(output: Node, v):
    nodes = sort_topologically(output, set(), [])
 
    # Reset gradients
    for node in nodes:
        node.grad = 0
    output.grad = v
 
    for node in reversed(nodes):
        for parent, vjp in node.parents:
            parent.grad += vjp(node.grad)
 
# Function constants
n, m = 4, 3
W = Node(np.random.rand(m, n))
b = Node(np.random.rand(m, 1))
pi = Node(np.array(np.pi))
 
# The function we want to differentiate
def f(x):
    return add(matmul(W, sin(mul(pi, square(x)))), b)
 
# Forward pass
x = Node(np.random.rand(n, 1))
y = f(x)
 
# Reverse accumulation
compute_gradient(y, v=np.array([[1], [0], [0]]))
 
print(f'f(x): {y.value.squeeze()}')
print(f'Gradient: {x.grad.squeeze()}')

Which produces:

f(x): [1.59230569 2.13912743 1.24483847]
Gradient: [ 0.30036226  0.09085468 -1.79903433 -0.86712717]

Notice that when compared to the forward mode output, $f(x)$ is unchanged, but the computed derivative is now a row of the Jacobian matrix rather than a column. Thus, if your function were scalar-valued ($f: \mathbb{R}^n \to \mathbb{R}$), you'd be done in a single pass. Since we're dealing with an $\mathbb{R}^m$ output, we need $m$ passes to construct the Jacobian one row at a time:

def compute_jacobian(input, output):
    rows = []
    for basis in np.eye(m):
        compute_gradient(output, v=basis[..., None])
        rows.append(input.grad.squeeze())
    return np.vstack(rows)
 
jacobian = compute_jacobian(input=x, output=y)
print(f'Jacobian:\n{jacobian}')

Which produces the same Jacobian as forward mode:

[[ 0.30036226  0.09085468 -1.79903433 -0.86712717]
 [ 0.23186367  0.0820517  -1.30604222 -1.41916768]
 [ 0.5274067   0.04871081 -2.36301769 -0.84168345]]

And as before, this is equivalent to the following in JAX:

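import jax
import jax.numpy as jnp

# Assumes W, b and x are plain arrays here (e.g. W.value, b.value, x.value from above)
f = lambda x: W @ jnp.sin(jnp.pi * x ** 2) + b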
_, vjp_fun = jax.vjp(f, x)
jacobian = jnp.stack(
    [
        vjp_fun(basis[..., None])[0].squeeze()
        for basis in jnp.eye(m)
    ]
)

Gradient Accumulation

As a quick aside, if you're familiar with gradient accumulation in frameworks like PyTorch, it might be tempting to think that skipping this bit in compute_gradient:

for node in nodes:
    node.grad = 0

is equivalent to skipping the zero_grad call in PyTorch (which generally results in gradients being accumulated when, say, loss.backward() is called). However, that's not the case here: compute_gradient propagates each node's grad through the VJPs, so stale gradients from a previous pass would be pushed through the graph a second time and counted more than once. That said, it takes just a slight tweak to make gradient accumulation work as expected:

def compute_gradient(output: Node, v):
    nodes = sort_topologically(output, set(), [])

    # Separately track the gradients for this pass
    gradients = {node: 0 for node in nodes}
    gradients[output] = v

    for node in reversed(nodes):
        for parent, vjp in node.parents:
            # Use only the gradients from this pass for propagation
            contrib = vjp(gradients[node])
            gradients[parent] += contrib
            # Also accumulate into parent.grad across passes
            parent.grad += contrib

Note that in this variant, the caller is responsible for zeroing out node.grad at the appropriate times to ensure correctness.
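The difference is easy to see numerically. The self-contained sketch below re-creates a stripped-down version of the Node machinery (the names backward_naive, backward_accumulating and run are hypothetical) on the two-step chain $y = \sin(x)^2$, then runs two backward passes with each strategy and compares the result to a single pass.

```python
import numpy as np

class Node:
    def __init__(self, value, *parents):
        self.value = value
        self.parents = parents  # (parent node, VJP) pairs
        self.grad = 0.0

def sin(x):
    return Node(np.sin(x.value), (x, lambda g: g * np.cos(x.value)))

def square(x):
    return Node(x.value ** 2, (x, lambda g: g * 2 * x.value))

def sort_topologically(node, visited, ordered):
    if node not in visited:
        visited.add(node)
        for parent, _ in node.parents:
            sort_topologically(parent, visited, ordered)
        ordered.append(node)
    return ordered

def backward_naive(output, v):
    # compute_gradient with the reset loop removed
    nodes = sort_topologically(output, set(), [])
    output.grad = v
    for node in reversed(nodes):
        for parent, vjp in node.parents:
            parent.grad += vjp(node.grad)

def backward_accumulating(output, v):
    # The per-pass dictionary variant from above
    nodes = sort_topologically(output, set(), [])
    gradients = {node: 0.0 for node in nodes}
    gradients[output] = v
    for node in reversed(nodes):
        for parent, vjp in node.parents:
            contrib = vjp(gradients[node])
            gradients[parent] += contrib
            parent.grad += contrib

def run(backward, passes):
    # Build y = sin(x)**2 afresh, then call the backward pass repeatedly
    x = Node(np.array(0.8))
    y = square(sin(x))
    for _ in range(passes):
        backward(y, 1.0)
    return x.grad

single = run(backward_accumulating, 1)
print(run(backward_accumulating, 2) / single)  # 2.0: both passes accumulated
print(run(backward_naive, 2) / single)         # about 3: stale gradient re-propagated
```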

VJP from JVP

Typically, frameworks pick reverse mode (like PyTorch) or forward mode (like Google's Ceres Solver) and write a specialized implementation around their particular flavor of AD. In contrast, JAX uses a neat trick that allows it to implement both forward and reverse modes without implementing per-mode primitives. This is described in the paper You Only Linearize Once, and more concisely in their earlier Decomposing reverse-mode automatic differentiation.

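The general idea is alluded to in the paper's subtitle: tangents transpose to gradients. As an example, consider the matmul primitive. Recall that for a linear map $L$ defined by $L(v) = Av$, its transpose $L^\top$ has the defining property that:

$$\langle u, L(v) \rangle = \langle L^\top(u), v \rangle \quad \text{for all } u, v$$

where $\langle \cdot, \cdot \rangle$ represents the inner product. Using $\langle a, b \rangle = a^\top b$,

$$\langle u, Av \rangle = u^\top A v = (A^\top u)^\top v = \langle A^\top u, v \rangle$$

and so, we have:

$$L^\top(u) = A^\top u$$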

And comparing the JVP for matmul to the VJP, you can observe this transposition:

# Forward mode JVP
lambda x, y: Dual(..., x.tangent @ y.primal + x.primal @ y.tangent)
 
# Reverse mode VJP
def matmul(x: Node, y: Node) -> Node:
    return Node(
        ...,
        (x, lambda g: g @ y.value.T),
        (y, lambda g: x.value.T @ g)
    )
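The transposition can also be checked numerically for both arguments of matmul, this time with the Frobenius inner product $\langle P, Q \rangle = \sum_{ij} P_{ij} Q_{ij}$ on matrices. In the sketch below, A and B are random stand-ins for x.primal and y.primal (or x.value and y.value), and the two VJP expressions are the ones from the reverse mode code above.

```python
import numpy as np

rng = np.random.default_rng(0)
A = rng.standard_normal((3, 4))  # stands in for x.primal / x.value
B = rng.standard_normal((4, 2))  # stands in for y.primal / y.value

def frobenius(P, Q):
    # <P, Q> = sum of elementwise products
    return np.sum(P * Q)

# Tangent perturbations of each argument, and an output cotangent G
dA = rng.standard_normal(A.shape)
dB = rng.standard_normal(B.shape)
G = rng.standard_normal((3, 2))

# JVP of matmul is linear in (dA, dB): dA @ B + A @ dB
jvp_out = dA @ B + A @ dB

# Transposed pieces, matching the VJP lambdas: g @ y.value.T and x.value.T @ g
vjp_A = G @ B.T
vjp_B = A.T @ G

# Defining property of the transpose: <G, L(dA, dB)> = <L^T(G), (dA, dB)>
lhs = frobenius(G, jvp_out)
rhs = frobenius(vjp_A, dA) + frobenius(vjp_B, dB)
print(np.isclose(lhs, rhs))  # True
```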

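So you might wonder whether it's possible to arrive at the VJP automatically, given only a JVP primitive. As shown in the paper (and done in practice in JAX), the answer is yes. You do this in two steps: first linearize, partially evaluating the JVP so that only a linear function of the tangent remains, and then transpose that linear function. Starting with the forward mode JVP primitives: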

def sin(x: Dual) -> Dual:
    return Dual(
        # Evaluate the primal as usual
        np.sin(x.primal),
 
        # Partially evaluate the JVP, capture the rest in an IR.
        # Our IR is a tuple of the form (<kind>, ...args)
        (
            # Scale
            '*',
            # Scaling factor
            (
                # A "closed over" constant
                'const',
                # Partially evaluated term that depends on the primal
                np.cos(x.primal)
            ),
            x.tangent
        )
    )

Instead of fully evaluating the JVP, we've only evaluated the portion that depends on the primal, and captured it as a linear function of the tangent (notice that the cos(x.primal) portion is treated as a constant). The JAX paper considers this to be a form of partial evaluation where just the primal is known.

Putting it all together, we have:

from __future__ import annotations
from typing import Literal, NamedTuple
 
import numpy as np
 
type Expr = (
    # Input variable
    tuple[Literal['input']] |
 
    # Closed-over constant
    # Only appears as a coefficient (e.g. inside '*' or '@'),
    # never as a standalone tangent expression.
    tuple[Literal['const'], np.ndarray] |
 
    # Zero tangent (with constant and input shapes)
    tuple[Literal['zero'], tuple[int, ...], tuple[int, ...]] |
 
    # Binary Operations
    tuple[Literal['+', '*', '@'], Expr, Expr]
)
 
class Dual(NamedTuple):
    primal: np.ndarray
    tangent: Expr
 
def add(x: Dual, y: Dual) -> Dual:
    return Dual(
        x.primal + y.primal,
        ('+', x.tangent, y.tangent)
    )
 
def mul(x: Dual, y: Dual) -> Dual:
    return Dual(
        x.primal * y.primal,
        (
            '+',
            ('*', ('const', y.primal), x.tangent),
            ('*', ('const', x.primal), y.tangent)
        )
    )
 
def square(x: Dual) -> Dual:
    return Dual(
        x.primal ** 2,
        ('*', ('const', 2.0 * x.primal), x.tangent)
    )
 
def sin(x: Dual) -> Dual:
    return Dual(
        np.sin(x.primal),
        ('*', ('const', np.cos(x.primal)), x.tangent)
    )
 
def matmul(x: Dual, y: Dual) -> Dual:
    return Dual(
        x.primal @ y.primal,
        (
            '+',
            ('@', x.tangent, ('const', y.primal)),
            ('@', ('const', x.primal), y.tangent)
        )
    )
 
def jvp(expr: Expr, tangent: np.ndarray) -> np.ndarray:
    match expr:
        case ('input',):
            return tangent
 
        case ('zero', constant_shape, _):
            return np.zeros(constant_shape)
 
        case ('+', x, y):
            return jvp(x, tangent) + jvp(y, tangent)
 
        case ('*', ('const', c), x) | ('*', x, ('const', c)):
            return c * jvp(x, tangent)
 
        case ('@', ('const', A), x):
            return A @ jvp(x, tangent)
 
        case ('@', x, ('const', A)):
            return jvp(x, tangent) @ A
 
    raise ValueError(f'Unknown linear expression: {expr}')
 
def vjp(expr: Expr, cotangent: np.ndarray) -> np.ndarray:
    match expr:
        case ('input',):
            return cotangent
 
        case ('zero', _, input_shape):
            return np.zeros(input_shape)
 
        case ('+', x, y):
            return vjp(x, cotangent) + vjp(y, cotangent)
 
        case ('*', ('const', c), x) | ('*', x, ('const', c)):
            return vjp(x, c * cotangent)
 
        case ('@', ('const', A), x):
            return vjp(x, A.T @ cotangent)
 
        case ('@', x, ('const', A)):
            return vjp(x, cotangent @ A.T)
 
    raise ValueError(f'Unknown linear expression: {expr}')
 
# Define our constants
n, m = 4, 3
constant = lambda c: Dual(c, ('zero', c.shape, (n, 1)))
W = constant(np.random.random((m, n)))
b = constant(np.random.random((m, 1)))
pi = constant(np.array(np.pi))
 
# The function we want to differentiate
def f(x):
    return add(matmul(W, sin(mul(pi, square(x)))), b)
 
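# Compute function output + linearized representation
x = np.random.random((n, 1))
y, f_lin = f(Dual(primal=x, tangent=('input',)))

# Forward mode (JVP): push forward the first standard basis vector of R^n,
# giving the first column of the Jacobian
d_fwd = jvp(f_lin, tangent=np.eye(n)[:, :1])

# Reverse mode (VJP): pull back the first standard basis vector of R^m,
# giving the first row of the Jacobian
d_rev = vjp(f_lin, cotangent=np.eye(m)[:, :1])

print(f'f(x): {y.squeeze()}')
print(f'JVP: {d_fwd.squeeze()}')
print(f'VJP: {d_rev.squeeze()}')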

When we evaluate our function f, the output tangent is f_lin, which captures the partially evaluated JVP computation expressed in our IR. The jvp function takes this IR along with the input tangent and fully evaluates it.

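The order of its evaluation, innermost first, is as follows (where $\dot{x}$ represents the input tangent, and the zero tangents of the constants $W$, $\pi$ and $b$ are omitted):

$$\dot{y} = W\Big(\cos(\pi x^2) \odot \big(\pi\,(2x \odot \dot{x})\big)\Big)$$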

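For reverse mode, the vjp function accepts the same f_lin along with a cotangent $\bar{y}$ and evaluates the VJP by using a transposed version of the JVP, applying the outermost operation first:

$$\bar{x} = 2x \odot \Big(\pi\,\big(\cos(\pi x^2) \odot (W^\top \bar{y})\big)\Big)$$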

Note that vjp transforms the cotangent before it recurses (e.g., vjp(x, cotangent @ A.T)), whereas jvp recurses first and evaluates the innermost JVP before applying the current operation (e.g., jvp(x, tangent) @ A). This difference in evaluation order is what turns the right-to-left products of forward mode into the left-to-right products of reverse mode.

Shared Primitives. This route allows you to implement both forward and reverse mode using shared primitives. You do, however, need an additional set of "transposition rules" (here, the cases in vjp), one per kind of linear operation in the IR. In practice, a framework will likely define far more primitives than transposition rules. For instance, in our toy example, you can add the cos primitive without having to add any new transposition rules.

Caching and Checkpointing. When we previously implemented reverse mode, we explicitly cached the primal values (in Node.value) for reuse later. The equivalent here is handled by the ('const', ...) nodes in the IR: they capture the dependence on the primal so it doesn't have to be re-computed. As before, however, it's worth noting that this doesn't necessarily have to be the case. For instance, the IR can be tweaked to accommodate recomputing the constants rather than storing them (e.g. to implement gradient/activation checkpointing).

Aside on Adjoints

The general idea here, of taking the transpose (or adjoint) of the linear bits to transform between "forward" and "reverse" variants of an algorithm, predates JAX. Dan Piponi's paper Two Tricks for the Price of One from 2009 (cited in the JAX paper) applies the same technique for spatially varying convolutions and briefly discusses it in the context of AD. Piponi also references the idea in his earlier blog posts from 2005 on adjoints and AD. Following the breadcrumbs from there, Jon Claerbout has interesting observations about adjoints as back projections, as does Carlos Scheidegger in his article Adjoints and Inverses (related ideas are also explored in this video by Sam Levey).

See Also