Symbolic, Numerical, and Automatic Differentiation

Machine Learning Fundamentals for Economists

Jesse Perla

University of British Columbia

Overview

Why the Emphasis on Differentiation?

  • Modern ML would be impossible without (1) software that makes calculating gradients easy; and (2) specialized hardware
  • Old methods, but flexible software + hardware have radically changed the scale of problems we can solve
  • You simply can’t solve large problems (or sample from high-dimensional distributions) without gradients, or Jacobians of constraints
  • The mental shift was towards “differentiable programming”, i.e. treating entire software programs as differentiable, nested functions
    • As long as you have helpful software to manage the bookkeeping
    • You can differentiate almost anything continuous, and at least expectations or distributions of almost anything discrete

Types of Differentiation

A few general types of differentiation

  1. Numerical Differentiation (i.e., finite differences)

  2. Symbolic Differentiation (i.e., chain rule and simplify subexpressions by hand)

  3. Automatic Differentiation (i.e., execute chain rule on computer)

    • Use the chain rule forwards vs. backwards
    • Think matrix-free methods
  4. Sparse Differentiation (i.e., use one of the above to calculate directional derivatives, potentially filling in sparse Jacobians with fewer passes)

Numerical Derivatives

Finite Differences

  • With \(f : \mathbb{R}^N \to \mathbb{R}^M\), take \(e_i\) as the \(i\)th standard basis vector

\[ \frac{\partial f(x)}{\partial x_i} \approx \frac{f(x + \epsilon e_i) - f(x)}{\epsilon} \]

  • Requires \(N\) forward evaluations for the full Jacobian. Same scaling as forward-mode AD.
  • Good rule of thumb with above is \(\epsilon = \sqrt{\epsilon_{\text{machine}}}\)
  • Tough tradeoffs: roundoff vs. truncation errors
    • \(\epsilon\) too small hits machine-precision (roundoff) errors, especially with GPUs
    • \(\epsilon\) too large and the approximation is bad
  • Still useful in many cases, especially for sparse problems
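A minimal sketch of the forward-difference rule above in plain NumPy (the toy function `f` and dimensions are illustrative):

```python
import numpy as np

def finite_diff_grad(f, x, eps=None):
    """Forward-difference gradient of f: R^N -> R at x (one extra pass per coordinate)."""
    x = np.asarray(x, dtype=float)
    if eps is None:
        eps = np.sqrt(np.finfo(float).eps)  # rule of thumb: sqrt of machine epsilon
    g = np.empty_like(x)
    fx = f(x)
    for i in range(x.size):
        e_i = np.zeros_like(x)
        e_i[i] = 1.0  # i-th standard basis vector
        g[i] = (f(x + eps * e_i) - fx) / eps
    return g

f = lambda x: np.sin(x[0]) * x[1]  # toy function R^2 -> R
g = finite_diff_grad(f, np.array([1.0, 2.0]))
# analytic gradient is [x2 cos(x1), sin(x1)] = [2 cos(1), sin(1)]
```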

More Points for More Accuracy

  • Trickier in practice to handle tradeoff than you might expect
  • Could use more points, which improves accuracy at the cost of more function evaluations, e.g. the five-point central difference

\[ f'(x) \approx \frac{f(x-2\epsilon) - 8f(x-\epsilon) + 8f(x+\epsilon) - f(x+2\epsilon)}{12\epsilon} \]

  • In that case, use \(\epsilon = \sqrt[5]{\epsilon_{\text{machine}}}\) (truncation error is \(O(\epsilon^4)\), roundoff is \(O(\epsilon_{\text{machine}}/\epsilon)\))
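A quick numerical comparison of the two stencils on a toy function (`np.exp`; step sizes are illustrative):

```python
import numpy as np

def forward_diff(f, x, h):
    return (f(x + h) - f(x)) / h

def central_5pt(f, x, h):
    # f'(x) ~ [f(x-2h) - 8 f(x-h) + 8 f(x+h) - f(x+2h)] / (12 h)
    return (f(x - 2*h) - 8*f(x - h) + 8*f(x + h) - f(x + 2*h)) / (12*h)

eps = np.finfo(float).eps
x0 = 1.3
err_fwd = abs(forward_diff(np.exp, x0, eps**0.5) - np.exp(x0))
err_5pt = abs(central_5pt(np.exp, x0, eps**0.2) - np.exp(x0))
# the higher-order stencil is typically several digits more accurate
```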

Symbolic Differentiation

Roll up Your Sleeves

  • Do it by hand, or use Mathematica/Sympy/etc
  • Seems like it should always be better?
    • Often identical to auto-differentiation, though it gives you more control over the algebra of subexpressions; it is also prone to algebra or coding errors
    • Substituting expressions could speed things up (or slow things down)
    • Less overhead than many auto-differentiation methods, which may lead to better performance. Or may not if you do a different calculation (e.g. flatten the computational graph)
  • Very useful in many cases, even if only for designing new AD “primitives”
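For instance, with SymPy (assuming it is installed) you can differentiate, simplify, and emit a numeric function; the expression here is purely illustrative:

```python
import sympy as sp

x = sp.symbols('x')
expr = sp.log(x**2 + 1) * sp.exp(-x)    # illustrative expression
d_expr = sp.simplify(sp.diff(expr, x))  # symbolic derivative, simplified

# compile the symbolic result into a fast numeric function
d_f = sp.lambdify(x, d_expr)
```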

Sub-Expressions and Computational Graphs

  • Take \(f(g(x), h(g(x)))\). Would you want to substitute/simplify the gradient?

\[ f'(x) = g'(x) f_1(g(x), h(g(x))) + g'(x) h'(g(x)) f_2(g(x), h(g(x))) \]

(Figure: computational graph with nodes \(x \to g(x)\), then \(g(x) \to h(g(x))\), and both \(g(x)\) and \(h(g(x))\) feeding into \(f(g(x), h(g(x)))\).)

Automatic Differentiation

Let the Computer Execute the Chain Rule

  • Auto-differentiation/differentiable programming works on “computer programs”. i.e., computational graphs are just functions

    1. Converts the program into a computational graph (i.e., nested functions)
    2. Applies the chain rule to the computational graph recursively
    3. Provides a library of “primitives” where the recursion stops, plus registration of new primitives to teach the computer calculus

Finally: many frameworks will compile the resulting sequence of operations to be efficient on a GPU since this is so central to deep learning performance

Example Program

  • AD systems take a “program” and convert it into a computational graph
  • Things that may seem incompatible with derivatives (e.g., loops, assignment, zeros) usually have some mapping to a function
  • Then using the graph, they can apply the chain rule automatically
import jax.numpy as jnp

def program(x1, x2):
    # Initial state
    a = x1 * x2

    # Hardcoded loop (2 iterations)
    for i in range(2):
        a = a + (x2 ** 2)

    y = jnp.log(a)
    return y

Computational Graph (Unrolling Loop and Adding Intermediates)

Step Operation Result
1 (in) \(x_1\) \(z_1\)
2 (in) \(x_2\) \(z_2\)
3 \(z_1 \cdot z_2\) \(z_3\)
4 (i=1) \(z_2^2\) \(z_4\)
5 \(z_3 + z_4\) \(z_5\)
6 (i=2) \(z_2^2\) \(z_6\)
7 \(z_5 + z_6\) \(z_7\)
8 (out) \(\log(z_7)\) \(z_8\)

(Figure: computational graph of the unrolled program: \(x_1 \to z_1\), \(x_2 \to z_2\), \(z_3 = z_1 z_2\), \(z_4 = z_2^2\), \(z_5 = z_3 + z_4\), \(z_6 = z_2^2\), \(z_7 = z_5 + z_6\), \(z_8 = \log(z_7)\).)

Forward and Reverse Mode

  • The chain rule can be done forwards or backwards
  • See Wikipedia for good examples. Intuition for \(f : \mathbb{R}^N \to \mathbb{R}^M\):
    • Forward-Mode: grab one of the \(N\) inputs and wiggle it to see impact on all \(M\) outputs. Need \(N\) passes to get full Jacobian
    • Reverse-Mode: grab one of the \(M\) outputs and wobble it to see impact on all \(N\) inputs. Need \(M\) passes to get full Jacobian
  • Hence, reverse-mode is good for calculating gradients when \(N \gg M\) (e.g. neural networks). If \(M = 1\) gradients are the same complexity as evaluating the function
  • Reverse-mode has significant overhead, so often forward-mode is preferred even if \(N > M\)

Forward and Backwards With the Computational Graph

  • See wikipedia for classic treatment, and ProbML: Introduction Section 13.3 for a special case
  • Useful to read, but missing key linear algebra interpretations that are useful for understanding how to adapt AD
  • Instead, we will think of AD as linearization/etc. and follow ProbML: Advanced Topics and the JAX documentation
    • While we won’t cover it, this is much more amenable to higher-order derivatives and perturbations

Reminder: Filling in a Matrix from a Linear Operator

  • The standard basis in \(\mathbb{R}^2\) is \(e_1 = \begin{bmatrix} 1 & 0 \end{bmatrix}^{\top}\) and \(e_2 = \begin{bmatrix} 0 & 1 \end{bmatrix}^{\top}\)
  • Given linear operator \(\mathcal{A} : \mathbb{R}^2 \to \mathbb{R}^3\) and adjoint \(\mathcal{A}^{\top} : \mathbb{R}^3 \to \mathbb{R}^2\) how can we get the underlying matrix (i.e. \(A\) such that \(\mathcal{A}(v) = A v\) for all \(v\in\mathbb{R}^2\))?
  1. Use the standard basis vectors \(e_1, e_2\) and calculate \(\mathcal{A}(e_1), \mathcal{A}(e_2)\)

    • Gives two columns of the \(A\) matrix, so \(A = \begin{bmatrix} \mathcal{A}(e_1) & \mathcal{A}(e_2) \end{bmatrix}\)
  2. Use the standard basis vectors \(e_1, e_2, e_3\) (now of \(\mathbb{R}^3\)) and calculate \(\mathcal{A}^{\top}(e_1), \mathcal{A}^{\top}(e_2), \mathcal{A}^{\top}(e_3)\)

    • Gives the three columns of \(A^{\top}\), i.e. \(A^{\top} = \begin{bmatrix} \mathcal{A}^{\top}(e_1) & \mathcal{A}^{\top}(e_2) & \mathcal{A}^{\top}(e_3) \end{bmatrix}\)
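The recipe above in NumPy, with the operator given only as a function (matrix-free); the hidden matrix `A` is an arbitrary example:

```python
import numpy as np

A = np.array([[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]])  # hidden R^2 -> R^3 operator
op   = lambda v: A @ v    # the operator A(v): we only get to call this
op_T = lambda u: A.T @ u  # its adjoint A^T(u)

# 1. columns of A from A(e_1), A(e_2)
A_rec = np.column_stack([op(e) for e in np.eye(2)])

# 2. columns of A^T from the adjoint applied to e_1, e_2, e_3
At_rec = np.column_stack([op_T(e) for e in np.eye(3)])
```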

Jacobians and Linearization

  • Differentiation linearizes around a point, yielding the Jacobian
  • i.e., for \(f : \mathbb{R}^N \to \mathbb{R}^M\), the map \(x \mapsto \partial f(x)\) yields an \(M \times N\) Jacobian matrix
    • But remember matrix-free linear operators!
  • Instead of the Jacobian as a matrix, think of matrix-vector products and \(\partial f(x) : \mathbb{R}^N \to \mathbb{R}^M\) as a linear operator
    • Note: \(x\) is the linearization point in that notation, not the argument
  • See ProbML: Advanced Topics Chapter 6

Push-Forwards and JVPs

  • Denote the operator, linearized around \(x\), and applied to \(v\in\mathbb{R}^N\) as

    \[ (x, v) \mapsto \partial f(x)[v] \in \mathbb{R}^M \]

    • This is called the “push-forward”. The Jacobian Vector Product (JVP)
    • i.e. \(\nabla f(x) \cdot v\), the product of the Jacobian and a direction
  • JAX (and others) will take an \(f\) and an \(x\) and compile a new function from \(\mathbb{R}^N\) to \(\mathbb{R}^M\) that calculates \(\partial f(x)[v]\)

Adjoints, Pullbacks, and VJPs

  • Just as we can transpose a linear operator, we can transpose the Jacobian around the linearization point, \(x\)

\[ \partial f(x)^{\top} : \mathbb{R}^M \to \mathbb{R}^N \]

  • Which lets us define the “pullback”: \((x, u) \mapsto \partial f(x)^{\top}[u] \in \mathbb{R}^N\)
  • Just as with matrix-free linear operators, we can think of this as an inner product: The Vector Jacobian Product (VJP)
  • i.e., \(u \cdot \nabla f(x)\) or \(\nabla f(x)^{\top} \cdot u\) is the reason for the “adjoint” terminology
  • JAX (and others) will take an \(f\) and an \(x\) and compile a new function from \(\mathbb{R}^M\) to \(\mathbb{R}^N\) that calculates \(\partial f(x)^{\top}[u]\)

Example of a Jacobian

Let \(f : \mathbb{R}^2 \to \mathbb{R}^2\) be defined as

\[ f(x) \equiv \begin{bmatrix} x_1^2 + x_2^2 \\ x_1 x_2 \end{bmatrix} \]

Then

\[ \nabla f(x) \equiv \begin{bmatrix} 2 x_1 & 2 x_2 \\ x_2 & x_1 \end{bmatrix} \]

JVP

Let \(v = \begin{bmatrix} 1 & 0 \end{bmatrix}^{\top}\), i.e. the \(e_1\) in the standard basis then

\[ \partial f(x)[v] = \nabla f(x) \cdot \begin{bmatrix} 1 \\ 0 \end{bmatrix} = \begin{bmatrix} 2 x_1 \\ x_2 \end{bmatrix} \]

  • Each basis vector gives a column of the Jacobian
  • Could use \(e_1, \ldots, e_N\) to get the full Jacobian

VJP

  • Let \(u = \begin{bmatrix} 1 & 0 \end{bmatrix}^{\top}\), i.e. the \(e_1\) in the standard basis then

\[ \partial f(x)^{\top}[u] = \begin{bmatrix} 1 & 0 \end{bmatrix} \cdot \nabla f(x) = \begin{bmatrix} 2 x_1 & 2 x_2 \end{bmatrix} \]

  • The first row of the Jacobian (or the first column of its transpose)
  • Using \(e_1,\ldots,e_M\) we can get the full Jacobian
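The example above, checked numerically in plain NumPy: JVPs against basis vectors give columns, VJPs give rows, and both assemble the full Jacobian:

```python
import numpy as np

# f and its Jacobian from the slides
f = lambda x: np.array([x[0]**2 + x[1]**2, x[0] * x[1]])
jac = lambda x: np.array([[2*x[0], 2*x[1]], [x[1], x[0]]])

x = np.array([1.5, -0.5])
J = jac(x)

jvp = lambda v: J @ v    # push-forward: a column combination
vjp = lambda u: J.T @ u  # pullback: a row combination

# assemble the full Jacobian both ways
J_from_jvp = np.column_stack([jvp(e) for e in np.eye(2)])  # N = 2 passes
J_from_vjp = np.stack([vjp(e) for e in np.eye(2)])         # M = 2 passes
```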

Chain Rule for JVP

  • Consider \(f : \mathbb{R}^N \rightarrow \mathbb{R}^M\) with \(f = c \circ b \circ a\)

\[ \partial f(x) = \partial c(b(a(x))) \circ \partial b(a(x)) \circ \partial a(x) \]

  • JVP against an input perturbation \(v \in \mathbb{R}^N\)
  • Moving inside out, because we are perturbing inputs

\[ \partial f(x)[v] = \partial c(b(a(x))) \left[ \partial b(a(x))[\partial a(x)[v]] \right] \]

Calculation Order for JVP Chain Rule

\[ \partial f(x)[v] = \partial c(b(a(x))) \left[ \partial b(a(x))[\partial a(x)[v]] \right] \]

  • Calculation order inside out, recursively finding linearization points:

    1. \(\partial a(x)[v]\) and \(a(x)\)
    2. \(\partial b(a(x))[\partial a(x)[v]]\) and \(b(a(x))\)
    3. \(\partial c(b(a(x)))[\partial b(a(x))[\partial a(x)[v]]]\) (and \(c(b(a(x)))\) if required)
  • Conveniently follows the “primal” calculation. Many ways to implement it (e.g. operator overloading, dual numbers)

  • Can calculate the “primal” and the “push-forward” at the same time
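The dual numbers mentioned above can be sketched in a few lines: each value carries its tangent, and every operation applies its JVP rule as the primal is computed (a toy implementation with only `+`, `*`, and `log`, applied to the running example, writing \(x_2^2\) as \(x_2 \cdot x_2\)):

```python
import math

class Dual:
    """Carry (primal, tangent) together; each op applies its JVP rule."""
    def __init__(self, val, dot=0.0):
        self.val, self.dot = val, dot
    def __add__(self, other):
        return Dual(self.val + other.val, self.dot + other.dot)
    def __mul__(self, other):
        # product rule: (uv)' = u'v + uv'
        return Dual(self.val * other.val,
                    self.dot * other.val + self.val * other.dot)

def dlog(z):  # log primitive with its JVP rule
    return Dual(math.log(z.val), z.dot / z.val)

def program(x1, x2):
    a = x1 * x2
    for _ in range(2):
        a = a + x2 * x2
    return dlog(a)

# directional derivative in the x1 direction at (x1, x2) = (2, 3)
y = program(Dual(2.0, 1.0), Dual(3.0, 0.0))
# y.val = log(2*3 + 2*9) = log(24); y.dot = 3/24
```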

Chain Rule for VJP

\[ \partial f(x) = \partial c(b(a(x))) \circ \partial b(a(x)) \circ \partial a(x) \]

  • Take the transpose,

\[ \partial f(x)^{\top} = \partial a(x)^{\top} \circ \partial b(a(x))^{\top} \circ \partial c(b(a(x)))^{\top} \]

  • In particular, if we multiply by some \(u \in \mathbb{R}^M\) (i.e., \(u \cdot \nabla f(x)\)), we get

\[ \partial f(x)^{\top}[u] = \partial a(x)^{\top} \left[ \partial b(a(x))^{\top} \left[ \partial c(b(a(x)))^{\top}[u] \right] \right] \]

Calculation Order for VJP Chain Rule

\[ \partial f(x)^{\top}[u] = \partial a(x)^{\top} \left[ \partial b(a(x))^{\top} \left[ \partial c(b(a(x)))^{\top}[u] \right] \right] \]

  • Calculation order outside in (of original):
    1. \(a(x), b(a(x)), c(b(a(x)))\) (i.e., the “primal” calculations required for linearization points)
    2. \(\partial c(b(a(x)))^{\top}[u]\)
    3. \(\partial b(a(x))^{\top}[\partial c(b(a(x)))^{\top}[u]]\)
    4. \(\partial a(x)^{\top}[\partial b(a(x))^{\top}[\partial c(b(a(x)))^{\top}[u]]]\)
  • Unlike with the JVP, we need the full primal calculation before going backwards through the computational graph (hence the “backprop” terminology)
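A toy tape illustrating the two phases on \(z = \log(x_1 x_2)\): the primal pass records each step's inputs, then the backward pass seeds the output cotangent and applies VJP rules in reverse (only multiply and log implemented, with the tape layout hard-coded for this tiny graph):

```python
import math

# primal pass: record each step's inputs on a tape
def primal_and_tape(x1, x2):
    tape = []
    z3 = x1 * x2
    tape.append(('mul', (x1, x2)))  # local derivatives: (x2, x1)
    z4 = math.log(z3)
    tape.append(('log', (z3,)))     # local derivative: 1/z3
    return z4, tape

# backward pass: seed the output cotangent, apply VJP rules in reverse
def backward(tape, z_bar=1.0):
    log_in, = tape[1][1]
    z3_bar = z_bar / log_in   # VJP of log
    x1, x2 = tape[0][1]
    x1_bar = z3_bar * x2      # VJP of multiply
    x2_bar = z3_bar * x1
    return x1_bar, x2_bar

y, tape = primal_and_tape(2.0, 3.0)
g1, g2 = backward(tape)
# d/dx1 log(x1 x2) = 1/x1 = 0.5;  d/dx2 = 1/x2
```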

Complexity with Reverse-Mode AD

  • In principle for \(f : \mathbb{R}^N \to \mathbb{R}\) can calculate \(\nabla f(x)\) in the same computational order as \(f\) itself - independent of \(N\)
    • This is a key part of the secret sauce that makes ML possible
  • But in practice it isn’t quite so simple
    • Requires storing the entire “primal” graph before going backwards (unlike forward-mode); in-place operations in the primal are often unusable
    • Requires more complicated code to keep track of the steps in the computational graph, which creates overhead
  • This means that often forward-mode will be faster even when \(N > M\) (e.g., maybe 50-100 dimensions, but it depends)

Fréchet Derivatives and Automatic Differentiation

Connection to the Fréchet Derivative

Automatic differentiation is best understood as the automatic composition of linear operators and their adjoints, not as providing Jacobians

For \(f:\mathcal{X}\to\mathcal{Y}\) between Hilbert spaces, the Fréchet derivative at \(x\) is a bounded linear operator \[ Df(x): \mathcal{X} \to \mathcal{Y} \]

such that for perturbation direction \(h\in\mathcal{X}\)

\[ \lim_{\|h\|\to 0} \frac{\|f(x+h)-f(x)-Df(x)[h]\|_{\mathcal{Y}}}{\|h\|_{\mathcal{X}}} = 0 \]

Pushforwards and Pullbacks

  • AD systems compute applications of these operators, not the operators explicitly
    • Linear operators and adjoints easily compose and chain!
    • Note: using \(^{\top}\) for adjoint notation.
  • Hilbert space: By the Riesz theorem, \(Df(x)^{\top}:\mathcal{Y}\to\mathcal{X}\) lives in the same primal spaces
  • Coordinate-free: Defined via linear operators, not bases or matrices
  • Pushforward (JVP): Apply the derivative to a tangent direction \(v\) \[ v \mapsto Df(x)[v] \]
  • Pullback (VJP): Apply the adjoint to a cotangent \(u\) \[ u \mapsto Df(x)^{\top}[u] \]

Chain Rule: Pushforward (Forward-Mode AD)

  • Let \(f:\mathcal{X}\to\mathcal{Y}, g:\mathcal{Y}\to\mathcal{Z}\) be Fréchet differentiable and \(h = g\circ f\)
  • The Fréchet derivative satisfies the operator chain rule \[ Dh(x) = Dg(f(x)) \circ Df(x) \]
  • Applying this operator to a tangent direction \(v\in\mathcal{X}\), \[ Dh(x)[v] = Dg(f(x))\big[Df(x)[v]\big] \]

Interpretation:

  • Forward-mode AD propagates tangents
  • Each primitive applies its linearization
  • Tangents flow in the same order as the primal computation

Chain Rule: Pullback (Reverse-Mode AD)

  • Taking adjoints of the operator chain rule, \[ Dh(x)^{\top} = Df(x)^{\top} \circ Dg(f(x))^{\top} \]

  • For a cotangent \(w\in\mathcal{Z}\), \[ Dh(x)^{\top}[w] = Df(x)^{\top}\big[Dg(f(x))^{\top}[w]\big] \]

Interpretation:

  • Reverse-mode AD propagates cotangents
  • Each primitive applies the adjoint of its linearization
  • Cotangents flow in the reverse order of the primal computation

Differentiable Programming Example

Reminder: Our Example Program

import jax.numpy as jnp

def program(x1, x2):
    # Initial state
    a = x1 * x2

    # Hardcoded loop (2 iterations)
    for i in range(2):
        a = a + (x2 ** 2)

    y = jnp.log(a)
    return y

Computational Graph (Unrolling Loop and Adding Intermediates)

Step Operation Result
1 (in) \(x_1\) \(z_1\)
2 (in) \(x_2\) \(z_2\)
3 \(z_1 \cdot z_2\) \(z_3\)
4 (i=1) \(z_2^2\) \(z_4\)
5 \(z_3 + z_4\) \(z_5\)
6 (i=2) \(z_2^2\) \(z_6\)
7 \(z_5 + z_6\) \(z_7\)
8 (out) \(\log(z_7)\) \(z_8\)

(Figure: computational graph of the unrolled program: \(x_1 \to z_1\), \(x_2 \to z_2\), \(z_3 = z_1 z_2\), \(z_4 = z_2^2\), \(z_5 = z_3 + z_4\), \(z_6 = z_2^2\), \(z_7 = z_5 + z_6\), \(z_8 = \log(z_7)\).)

Univariate Primitives

  • \(z = f(x)\) where \(f : \mathbb{R} \to \mathbb{R}\)
  • Standard notation in AD for linearization around \(x\)
    • Pushforward: \(\dot{z} = \partial f(x)[\dot{x}] = \nabla f(x) \cdot \dot{x}\)
    • Pullback: \(\bar{x} = \partial f(x)^\top[\bar{z}] = \bar{z} \cdot \nabla f(x)\)
    • i.e., compared to previous slides, \(\dot{z} = v\) and \(\bar{z} = u\)
Operation Primal (given \(x\)) JVP (given \(\dot{x}\)) VJP (given \(\bar{z}\))
Power \(x^n\) \(\dot{z} = n x^{n-1} \cdot \dot{x}\) \(\bar{x} = \bar{z} \cdot n x^{n-1}\)
Exponential \(\exp(x)\) \(\dot{z} = \exp(x) \cdot \dot{x}\) \(\bar{x} = \bar{z} \cdot \exp(x)\)
Logarithm \(\log(x)\) \(\dot{z} = \frac{\dot{x}}{x}\) \(\bar{x} = \frac{\bar{z}}{x}\)
Reciprocal \(\frac{1}{x}\) \(\dot{z} = -\frac{\dot{x}}{x^2}\) \(\bar{x} = -\frac{\bar{z}}{x^2}\)

Bivariate Primitives

  • \(z = f(x_1, x_2)\) where \(f : \mathbb{R}^2 \to \mathbb{R}\)
Operation Primal \(z\) JVP (given \(\dot{x}_1, \dot{x}_2\)) VJP (given \(\bar{z}\))
Addition \(x_1 + x_2\) \(\dot{z} = \dot{x}_1 + \dot{x}_2\) \(\bar{x}_1 = \bar{z}\), \(\bar{x}_2 = \bar{z}\)
Multiplication \(x_1 \cdot x_2\) \(\dot{z} = x_2 \dot{x}_1 + x_1 \dot{x}_2\) \(\bar{x}_1 = \bar{z} x_2\), \(\bar{x}_2 = \bar{z} x_1\)
Division \(\frac{x_1}{x_2}\) \(\dot{z} = \frac{\dot{x}_1}{x_2} - \frac{x_1 \dot{x}_2}{x_2^2}\) \(\bar{x}_1 = \frac{\bar{z}}{x_2}\), \(\bar{x}_2 = -\frac{\bar{z} x_1}{x_2^2}\)
Power \(x_1^{x_2}\) \(\dot{z} = x_2 x_1^{x_2-1} \dot{x}_1 + x_1^{x_2} \log(x_1) \dot{x}_2\) \(\bar{x}_1 = \bar{z} x_2 x_1^{x_2-1}\), \(\bar{x}_2 = \bar{z} x_1^{x_2} \log(x_1)\)

Forward-Mode AD for Our Example

  • Forward-mode takes the arguments \(x_1, x_2\) and “seed” direction \(\dot{x}_1, \dot{x}_2\)
  • Directional derivative (e.g., \(\dot{x}_1 = 1, \dot{x}_2 = 0\) gives \(\frac{\partial y}{\partial x_1}\))
  • Here we use the appropriate primitive JVP rules at each step
Step Primal Tangent (JVP)
1 (in) \(z_1 = x_1\) \(\dot{z}_1 = \dot{x}_1\)
2 (in) \(z_2 = x_2\) \(\dot{z}_2 = \dot{x}_2\)
3 \(z_3 = z_1 \cdot z_2\) \(\dot{z}_3 = z_2 \dot{z}_1 + z_1 \dot{z}_2\)
4 (i=1) \(z_4 = z_2^2\) \(\dot{z}_4 = 2\, z_2 \dot{z}_2\)
5 \(z_5 = z_3 + z_4\) \(\dot{z}_5 = \dot{z}_3 + \dot{z}_4\)
6 (i=2) \(z_6 = z_2^2\) \(\dot{z}_6 = 2\, z_2 \dot{z}_2\)
7 \(z_7 = z_5 + z_6\) \(\dot{z}_7 = \dot{z}_5 + \dot{z}_6\)
8 (out) \(z_8 = \log(z_7)\) \(\dot{z}_8 = \frac{\dot{z}_7}{z_7}\)
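The table above, transcribed directly into Python: primal and tangent computed side by side, seeding \(\dot{x}_1 = 1, \dot{x}_2 = 0\) to get \(\partial y / \partial x_1\):

```python
import math

def forward_mode(x1, x2, x1_dot, x2_dot):
    z1, z1_dot = x1, x1_dot
    z2, z2_dot = x2, x2_dot
    z3, z3_dot = z1 * z2, z2 * z1_dot + z1 * z2_dot  # multiplication JVP
    z4, z4_dot = z2**2, 2 * z2 * z2_dot              # power JVP
    z5, z5_dot = z3 + z4, z3_dot + z4_dot            # addition JVP
    z6, z6_dot = z2**2, 2 * z2 * z2_dot
    z7, z7_dot = z5 + z6, z5_dot + z6_dot
    z8, z8_dot = math.log(z7), z7_dot / z7           # log JVP
    return z8, z8_dot

y, dy_dx1 = forward_mode(2.0, 3.0, 1.0, 0.0)
# analytically: y = log(x1 x2 + 2 x2^2), dy/dx1 = x2 / (x1 x2 + 2 x2^2) = 3/24
```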

Reverse-Mode AD for Our Example

Step 1: Compute primal (forward). Step 2: Propagate cotangents (backward).

Primal (forward)

Step Primal
1 (in) \(z_1 = x_1\)
2 (in) \(z_2 = x_2\)
3 \(z_3 = z_1 \cdot z_2\)
4 (i=1) \(z_4 = z_2^2\)
5 \(z_5 = z_3 + z_4\)
6 (i=2) \(z_6 = z_2^2\)
7 \(z_7 = z_5 + z_6\)
8 (out) \(z_8 = \log(z_7)\)

Cotangent (backward, usually seed \(\bar{z}_8 = 1\))

Step VJP
8 \(\bar{z}_7 = \frac{\bar{z}_8}{z_7}\)
7 \(\bar{z}_5 \mathrel{+}= \bar{z}_7\), \(\bar{z}_6 \mathrel{+}= \bar{z}_7\)
6 \(\bar{z}_2 \mathrel{+}= 2\, z_2 \bar{z}_6\)
5 \(\bar{z}_3 \mathrel{+}= \bar{z}_5\), \(\bar{z}_4 \mathrel{+}= \bar{z}_5\)
4 \(\bar{z}_2 \mathrel{+}= 2\, z_2 \bar{z}_4\)
3 \(\bar{z}_1 \mathrel{+}= z_2 \bar{z}_3\), \(\bar{z}_2 \mathrel{+}= z_1 \bar{z}_3\)
2 \(\bar{x}_2 = \bar{z}_2\)
1 \(\bar{x}_1 = \bar{z}_1\)
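The two passes above in Python: the primal runs forward, then cotangents are seeded with \(\bar{z}_8 = 1\) and accumulated in reverse:

```python
import math

def reverse_mode(x1, x2):
    # primal (forward)
    z1, z2 = x1, x2
    z3 = z1 * z2
    z4 = z2**2
    z5 = z3 + z4
    z6 = z2**2
    z7 = z5 + z6
    z8 = math.log(z7)
    # cotangent (backward), seeding z8_bar = 1
    z8_bar = 1.0
    z7_bar = z8_bar / z7          # log VJP
    z5_bar = z7_bar               # addition VJP fans out
    z6_bar = z7_bar
    z2_bar = 2 * z2 * z6_bar      # power VJP
    z3_bar = z5_bar
    z4_bar = z5_bar
    z2_bar += 2 * z2 * z4_bar     # accumulate: z2 feeds several nodes
    z1_bar = z2 * z3_bar          # multiplication VJP
    z2_bar += z1 * z3_bar
    return z8, (z1_bar, z2_bar)

y, (g1, g2) = reverse_mode(2.0, 3.0)
# y = log(24); gradient = (x2, x1 + 4 x2) / (x1 x2 + 2 x2^2) = (3/24, 14/24)
```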

Differentiating “Neural Networks”

“Neural Networks” and Deep Learning

  • The term “neural network” is very broad and covers a variety of approximation classes, with a high degree of flexibility.

  • While there is no theoretical requirement to use these patterns, in practice hardware is optimized to make them fast.

  • For example, consider the function \(f : \mathbb{R}^N \to \mathbb{R}^M\) defined as \[ y = f_{\theta}(x) = W_3 \sigma(W_2 \sigma(W_1 x + b_1) + b_2) + b_3 \]

    • \(W_1 \in \mathbb{R}^{K \times N}, W_2 \in \mathbb{R}^{K \times K}, W_3 \in \mathbb{R}^{M \times K}, b_1 \in \mathbb{R}^K, b_2 \in \mathbb{R}^K, b_3 \in \mathbb{R}^M\)
    • \(\sigma(\cdot)\) is a scalar non-linear function applied componentwise
    • Combine into a set of parameters \(\theta \equiv \{W_1, W_2, W_3, b_1, b_2, b_3\}\)

Key Primitives and Their Adjoints

  • \(y = W x\) has adjoint \(\bar{x} = W^{\top} \bar{y}\)
  • \(y = a\, b\) (matrix product) has adjoints
    • \(\bar{a} = \bar{y} b^{\top}\) and \(\bar{b} = a^{\top} \bar{y}\)
  • \(y = \mathrm{muladd}(W, x, b) \equiv W x + b\) has adjoints
    • \(\bar{x} = W^{\top} \bar{y}\), \(\bar{W} = \bar{y} x^{\top}\), and \(\bar{b} = \bar{y}\)
  • Denote \(\sigma(z)\) as an \({\mathbb{R}}\to {\mathbb{R}}\) function applied componentwise
    • The Jacobian is a diagonal matrix, \(\mathrm{diag}(\sigma'(z))\) where the scalar derivative is applied componentwise

Reverse-Mode AD for Neural Network (Separated Steps)

\[ y = f_{\theta}(x) = W_3 \sigma(W_2 \sigma(W_1 x + b_1) + b_2) + b_3, \quad \text{given } \bar{y} \in {\mathbb{R}}^M \]

Primal (forward)

Step Primal
1 \(t_1 = W_1 x\)
2 \(z_1 = t_1 + b_1\)
3 \(a_1 = \sigma(z_1)\)
4 \(t_2 = W_2 a_1\)
5 \(z_2 = t_2 + b_2\)
6 \(a_2 = \sigma(z_2)\)
7 \(t_3 = W_3 a_2\)
8 \(y = t_3 + b_3\)

Cotangent (backward)

Step VJP
8 \(\bar{t}_3 = \bar{y}\), \(\bar{b}_3 = \bar{y}\)
7 \(\bar{a}_2 = W_3^\top \bar{t}_3\), \(\bar{W}_3 = \bar{t}_3 a_2^\top\)
6 \(\bar{z}_2 = \mathrm{diag}(\sigma'(z_2))\, \bar{a}_2\)
5 \(\bar{t}_2 = \bar{z}_2\), \(\bar{b}_2 = \bar{z}_2\)
4 \(\bar{a}_1 = W_2^\top \bar{t}_2\), \(\bar{W}_2 = \bar{t}_2 a_1^\top\)
3 \(\bar{z}_1 = \mathrm{diag}(\sigma'(z_1))\, \bar{a}_1\)
2 \(\bar{t}_1 = \bar{z}_1\), \(\bar{b}_1 = \bar{z}_1\)
1 \(\bar{x} = W_1^\top \bar{t}_1\), \(\bar{W}_1 = \bar{t}_1 x^\top\)
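The tables above as a NumPy sketch, checked against finite differences on the scalarized objective \(\bar{y} \cdot f(x)\) (random small dimensions; \(\tanh\) is the illustrative \(\sigma\)):

```python
import numpy as np

rng = np.random.default_rng(0)
N, K, M = 3, 4, 2
W1, W2, W3 = rng.normal(size=(K, N)), rng.normal(size=(K, K)), rng.normal(size=(M, K))
b1, b2, b3 = rng.normal(size=K), rng.normal(size=K), rng.normal(size=M)
x, y_bar = rng.normal(size=N), rng.normal(size=M)

sigma, dsigma = np.tanh, lambda z: 1 - np.tanh(z)**2

def f(x):
    return W3 @ sigma(W2 @ sigma(W1 @ x + b1) + b2) + b3

# primal (forward), keeping intermediates for the linearization points
z1 = W1 @ x + b1; a1 = sigma(z1)
z2 = W2 @ a1 + b2; a2 = sigma(z2)
y = W3 @ a2 + b3

# cotangent (backward)
a2_bar = W3.T @ y_bar
z2_bar = dsigma(z2) * a2_bar  # diag(sigma'(z2)) acts elementwise
a1_bar = W2.T @ z2_bar
z1_bar = dsigma(z1) * a1_bar
x_bar  = W1.T @ z1_bar        # = gradient of y_bar . f(x) w.r.t. x

# central-difference check of the VJP
eps = 1e-6
fd = np.array([(y_bar @ f(x + eps*e) - y_bar @ f(x - eps*e)) / (2*eps)
               for e in np.eye(N)])
```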

Reverse-Mode AD with Muladd Primitive

Primal (forward)

Step Primal
1 \(z_1 = \mathrm{muladd}(W_1, x, b_1)\)
2 \(a_1 = \sigma(z_1)\)
3 \(z_2 = \mathrm{muladd}(W_2, a_1, b_2)\)
4 \(a_2 = \sigma(z_2)\)
5 \(y = \mathrm{muladd}(W_3, a_2, b_3)\)

Cotangent (backward)

Step VJP
5a \(\bar{a}_2 = W_3^\top \bar{y}\)
5b \(\bar{W}_3 = \bar{y} a_2^\top\)
5c \(\bar{b}_3 = \bar{y}\)
4 \(\bar{z}_2 = \mathrm{diag}(\sigma'(z_2))\, \bar{a}_2\)
3a \(\bar{a}_1 = W_2^\top \bar{z}_2\)
3b \(\bar{W}_2 = \bar{z}_2 a_1^\top\)
3c \(\bar{b}_2 = \bar{z}_2\)
2 \(\bar{z}_1 = \mathrm{diag}(\sigma'(z_1))\, \bar{a}_1\)
1a \(\bar{x} = W_1^\top \bar{z}_1\)
1b \(\bar{W}_1 = \bar{z}_1 x^\top\)
1c \(\bar{b}_1 = \bar{z}_1\)

Accelerating with Batching

  • That is a lot of linear algebra operations!
    • And a few easily parallelized componentwise nonlinearities.
  • The core of ML hardware is that it is specialized for these types of operations.
  • Consider if we wanted \(\begin{bmatrix} f(x_1) & f(x_2) & \cdots & f(x_B) \end{bmatrix}\) for a batch of inputs \(\{x_i\}_{i=1}^B\)
    • Then notice that we can do all of these operations in parallel with matrix-matrix calls
    • Let \(X \equiv \begin{bmatrix} x_1 & x_2 & \cdots & x_B \end{bmatrix} \in {\mathbb{R}}^{N \times B}\).
  • Automatically doing these sorts of steps is the purpose of vmap
  • See section on accelerators for more

Batched Forward Pass (e.g., what vmap might do)

  • Single input \(x \in {\mathbb{R}}^N\) becomes batch \(X \in {\mathbb{R}}^{N \times B}\) (columns are inputs)
  • Matrix-vector products become matrix-matrix products

Single input

Step Primal
1 \(z_1 = W_1 x + b_1\)
2 \(a_1 = \sigma(z_1)\)
3 \(z_2 = W_2 a_1 + b_2\)
4 \(a_2 = \sigma(z_2)\)
5 \(y = W_3 a_2 + b_3\)

Batched (vmap)

Step Primal
1 \(Z_1 = W_1 X + b_1 \mathbf{1}^\top\)
2 \(A_1 = \sigma(Z_1)\)
3 \(Z_2 = W_2 A_1 + b_2 \mathbf{1}^\top\)
4 \(A_2 = \sigma(Z_2)\)
5 \(Y = W_3 A_2 + b_3 \mathbf{1}^\top\)
  • \(b_i \mathbf{1}^\top\) broadcasts bias to all \(B\) columns (often implicit in frameworks)
  • \(Y \in {\mathbb{R}}^{M \times B}\): each column is output for corresponding input

Batched Backward Pass

  • Cotangent seed \(\bar{Y} \in {\mathbb{R}}^{M \times B}\) (one cotangent per sample). \(\bar{b}_i = \bar{Z}_i \mathbf{1}\) sums gradients across batch (row sums)
Step Batched VJP
5a \(\bar{A}_2 = W_3^\top \bar{Y}\)
5b \(\bar{W}_3 = \bar{Y} A_2^\top\)
5c \(\bar{b}_3 = \bar{Y} \mathbf{1}\)
4 \(\bar{Z}_2 = \sigma'(Z_2) \odot \bar{A}_2\) (elementwise)
3a \(\bar{A}_1 = W_2^\top \bar{Z}_2\)
3b \(\bar{W}_2 = \bar{Z}_2 A_1^\top\)
3c \(\bar{b}_2 = \bar{Z}_2 \mathbf{1}\)
2 \(\bar{Z}_1 = \sigma'(Z_1) \odot \bar{A}_1\) (elementwise)
1a \(\bar{X} = W_1^\top \bar{Z}_1\)
1b \(\bar{W}_1 = \bar{Z}_1 X^\top\)
1c \(\bar{b}_1 = \bar{Z}_1 \mathbf{1}\)

Software Examples

Software and Implementation

  • Tracking the computational graph for reverse-mode is tricky (especially if there were inplace modifications)
    • Support for mutation is rare in reverse-mode; a functional style is typical
  • The recursion goes forwards, backwards, or both ways down the computational graph until it hits a primitive
    • Recursion stops when it hits a function that has JVP/VJP implemented
  • These are two extreme cases; in principle you can mix them (e.g., if an internal composition is \(\mathbb{R}^N \to \mathbb{R}^K \to \mathbb{R}\) with \(K \gg N\), use forward-mode for the \(\mathbb{R}^N \to \mathbb{R}^K\) stage and reverse-mode for the \(\mathbb{R}^K \to \mathbb{R}\) stage)
  • Similar methods apply for higher-order derivatives, e.g. Hessian-vector products and Taylor series

Sparse Differentiation

  • For the full Jacobian of \(f : \mathbb{R}^N \rightarrow \mathbb{R}^M\) you need either \(N\) forward passes or \(M\) reverse passes
    • But if sparse, then maybe could use better directional derivatives than \(e_i\)
    • e.g. Tridiagonal matrices can be done with 3 directional derivatives.
  • See SparseDiffTools.jl and FiniteDiff.jl in Julia
    • See sparsejac for an experimental version in JAX
  • Finding the right directional derivatives is hard and requires knowing the sparsity pattern (and solves a problem equivalent to graph coloring)
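An illustration of the coloring idea for a tridiagonal Jacobian: three seed vectors (one per "color") recover all nonzeros via finite differences, regardless of \(N\). The toy function and the hand-picked coloring are assumptions for the sketch:

```python
import numpy as np

N = 9
def f(x):  # tridiagonal coupling: f_i depends only on x_{i-1}, x_i, x_{i+1}
    out = x**2
    out[:-1] += x[1:]
    out[1:] += np.sin(x[:-1])
    return out

x = np.linspace(0.1, 1.0, N)
eps = np.sqrt(np.finfo(float).eps)
fx = f(x)

# three seeds: columns i with i % 3 == c are probed together
J = np.zeros((N, N))
for c in range(3):
    seed = np.zeros(N)
    seed[c::3] = 1.0
    d = (f(x + eps * seed) - fx) / eps  # sum of the probed columns
    for i in range(c, N, 3):
        # un-mix: row j only involves probed column i when |i - j| <= 1,
        # and probed columns are 3 apart, so the assignment is unambiguous
        for j in range(max(i - 1, 0), min(i + 2, N)):
            J[j, i] = d[j]
```

Compare with the naive approach: probing each column separately would cost \(N\) passes instead of 3.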

Pytorch

  • See Probabilistic ML: Chapter 8 Notebook
  • Reverse-mode centric, especially convenient for neural networks but can be confusing for general functions
  • In general, you will find it the most convenient for standard supervised learning problems (e.g. neural networks with empirical risk minimization)
  • We will discuss later when we look at ML pipelines
import torch

Example with Pytorch .backward()

  • Reverse-mode AD passes values with requires_grad=True
  • Traces the intermediates, and does the AD on .backward()
x = torch.tensor(2.0, requires_grad=True)
# Trace computations for the "forward" pass
y = torch.tanh(x)
# Do the "backward" pass for Reverse-mode AD
y.backward()
print(x.grad)

def f(x, y):
    return x**3 + 2 * y[0]**2 - 3 * y[1] + 1

x = torch.tensor(1.0, requires_grad=True)
y = torch.tensor([2.0, 3.0],
                  requires_grad=True)
z = f(x, y)
z.backward()
print(x.grad, y.grad)
tensor(0.0707)
tensor(3.) tensor([ 8., -3.])

JAX

  • See section on accelerators for more about the compilation pipeline
  • Very flexible with high level tools (e.g. grad is \(\mathbb{R}^N \to \mathbb{R}\) reverse-diff) as well as lower-level functions to directly use jvp, vjp, and hessian-vector products
  • Emphasizing JAX here because non-trivial algorithms will typically require more flexibility to scale (e.g., cross-derivatives, matrix-free, etc.)
  • Easier to use for general functions rather than in standard estimation pipelines
  • See JAX Autodiff Cookbook
  • See Probabilistic ML: Chapter 8 Notebook
  • See JAX Advanced Autodiff

JAX Setup

  • From Autodiff cookbook and ProbML book 1 chapter 8
  • Random numbers always require keys, which can be split for reproducibility
  • Use jax.config.update('jax_enable_x64', True) for 64bit precision (default is 32bit)
import jax
import jax.numpy as jnp
import numpy as np
from jax import grad, jit, vmap
from jax import random, vjp, jvp
key = random.PRNGKey(0)
subkey1, subkey2 = random.split(key)
random.normal(subkey1, (2,))
Array([ 1.0040143, -0.9063372], dtype=float32)

High Level Grad (i.e. Reverse-mode)

  • grad is the high-level reverse-mode AD function
  • Returns a new function, which could be compiled
grad_tanh = grad(jnp.tanh)
print(grad_tanh(2.0))
grad_tanh_jit = jit(grad_tanh)
print(grad_tanh_jit(2.0))

def f(x):
    return x**3 + 2 * x**2 - 3 * x + 1
print(grad(f)(1.0))
@jit
def f2(x):
    return x**3 + 2 * x**2 - 3 * x + 1
grad(f2)(1.0)
0.070650816
0.070650816
4.0
Array(4., dtype=float32, weak_type=True)

Fixing an argument

def f3(x, y):
    return x**2 + y
v, gx = jax.value_and_grad(f3, argnums=0)(2.0, 3.0)
print(v)
print(gx)

gy = grad(f3, argnums=1)(2.0, 3.0)
print(gy)
7.0
4.0
1.0

Full Jacobians (Forward and Reverse)

  • Goes through full basis forwards or backwards
def fun(x):
    return jnp.dot(A, x)
A = np.random.normal(size=(4, 3))
x = np.random.normal(size=(3,))
Jf = jax.jacfwd(fun)(x)
Jr = jax.jacrev(fun)(x)
print(np.allclose(Jf, Jr))
True

Setup for Logistic Regression

def sigmoid(x):
    return 0.5 * (jnp.tanh(x / 2) + 1)

def predict(W, b, inputs):
    return sigmoid(jnp.dot(inputs, W) + b)
inputs = jnp.array([[0.52, 1.12,  0.77],
                   [0.88, -1.08, 0.15],
                   [0.52, 0.06, -1.30],
                   [0.74, -2.49, 1.39]])
targets = jnp.array([True, True, False, True])                   
key, W_key, b_key = random.split(key, 3)
W = random.normal(W_key, (3,))
b = random.normal(b_key, ())                    

JVP

  • For \(y = f(W)\): returns \((y, \dot{y})\) where \(\dot{y} = \partial f(W)[\dot{W}]\)
  • Computes primal and tangent together in one forward pass (no function created)
f = lambda W: predict(W, b, inputs)
W_dot = jnp.array([1.0, 0.0, 0.0])  # basis vector e_0

# Push forward: directional derivative in W[0] direction
y, y_dot = jvp(f, (W,), (W_dot,))
print(f"y = {y}\ndy/dW[0] = {y_dot}")
y = [0.10947311 0.79829013 0.41004258 0.99217653]
dy/dW[0] = [0.05069415 0.14170025 0.12579198 0.00574409]

VJP

  • For \(y = f(W)\): returns \((y, \text{vjp\_fun})\) where vjp_fun is the pullback operator
  • Calling vjp_fun(y_bar) computes \(\bar{W} = \partial f(W)^\top[\bar{y}]\)
  • Must create function because primal is computed first, then pullback called later
y, vjp_fun = vjp(f, W)
key, subkey = random.split(key)
y_bar = random.normal(subkey, y.shape)  # random cotangent

# Pull back
W_bar, = vjp_fun(y_bar)
print(f"y = {y}\ny_bar = {y_bar}\nW_bar = {W_bar}")
y = [0.10947311 0.79829013 0.41004258 0.99217653]
y_bar = [-1.2574776  -0.4016044  -1.1213601   0.87837774]
W_bar = [-0.25666684 -0.10071305  0.2580282 ]

Differentiating PyTrees

  • Key JAX feature is “flattening” of nested data
  • Works for arbitrarily nested tree structures
def loss2(params_dict):
    preds = predict(params_dict['W'], params_dict['b'], inputs)
    label_probs = preds * targets + (1 - preds) * (1 - targets)
    return -jnp.sum(jnp.log(label_probs))
params = {'W': W, 'b': b}
print(grad(loss2)(params))
{'W': Array([-0.433146 , -0.7354605, -1.2598922], dtype=float32), 'b': Array(-0.69001776, dtype=float32)}

Implicit Differentiation and Custom Rules

Primal and JVP/VJP Calculations are Separate

  • For a JVP or VJP, we first need to calculate the \(f(x)\)

    • This could involve complicated algorithms, external libraries, etc.
  • Often madness to descend recursively into primal calculations

    • e.g. if \(f(x) = \cos(x)\) then should it step inside \(\cos(x)\)?
    • Alternatively, use the known derivative to find

    \[ \partial f(x)[v] = -\sin(x) \cdot v \]

  • AD systems all have a library of these rules, and typically a way to create new ones for “custom” rules for complicated functions

Custom Rules/Primitives

  • Derive the derivative by hand, and register it with the AD system
  • See Matrix Algebra and Matrix Derivative Results, and ChainRules.jl Docs for examples
  • Derivations for forward-mode are relatively easy using the total derivative
  • Derivations for reverse-mode are more difficult.
    • See tricks for reverse using the trace of the Frobenius Inner product.
  • See here for JAX implementation

Smooth Matrix Functions

  • Consider \(f : \mathbb{R}^{N \times N} \to \mathbb{R}^{N \times N}\) and following Mathias 1996 and Higham 2008
  • Assume a suitably smooth function and a perturbation \(\delta A\), where we want to calculate the forward-mode \(\partial f(A)[\delta A]\)
  • Then, apply \(f(\cdot)\) to the following \(\mathbb{R}^{2N \times 2 N}\) block matrix and extract the answer from the upper right corner \[ f\left(\begin{bmatrix} A & \delta A\\ 0 & A\end{bmatrix}\right) = \begin{bmatrix} f(A) & \partial f(A)[\delta A]\\ 0 & f(A)\end{bmatrix} \]
    • This is a remarkable result true for any \(\delta A\). Not always the most efficient way, but very general
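A quick check of the block-matrix identity with the polynomial \(f(A) = A^2\), where the push-forward is known exactly: \(\partial f(A)[\delta A] = A\, \delta A + \delta A\, A\) (random matrices are illustrative):

```python
import numpy as np

rng = np.random.default_rng(1)
N = 3
A, dA = rng.normal(size=(N, N)), rng.normal(size=(N, N))

f = lambda M: M @ M  # smooth matrix function (here a polynomial)

# build the 2N x 2N block matrix [[A, dA], [0, A]] and apply f
block = np.block([[A, dA], [np.zeros((N, N)), A]])
fb = f(block)

f_A   = fb[:N, :N]   # upper-left:  f(A)
df_dA = fb[:N, N:]   # upper-right: the push-forward df(A)[dA]
# for f(A) = A^2 the push-forward is A dA + dA A
```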

Registering JVPs in JAX

  • f.defjvp implements: \((x, \dot{x}) \mapsto (f(x), \partial f(x)[\dot{x}])\)
@jax.custom_jvp
def f(x, y):
  return jnp.sin(x) * y
@f.defjvp
def f_jvp(primals, tangents):
  x, y = primals
  x_dot, y_dot = tangents
  primal_out = f(x, y)
  tangent_out = jnp.cos(x) * x_dot * y + jnp.sin(x) * y_dot
  return primal_out, tangent_out

print(f(2., 3.))
y, y_dot = jvp(f, (2., 3.), (1., 0.)) # perturb x, not y
print(y, y_dot)
2.7278922
2.7278922 -1.2484405

Deriving Rules for Matrix Inverse

  • Let \(f : \mathbb{R}^{N \times N} \to \mathbb{R}^{N \times N}\) where \(f(A) = A^{-1} = C\)

\[ \begin{aligned} I &= C A\\ 0 &= \partial C A + C \partial A \\ 0 &= \partial C A C + C (\partial A) C \\ 0 &= \partial C + C (\partial A) C \\ \partial C &= -C (\partial A) C \\ \end{aligned} \]

  • So given \(A\), the “primal” can be calculated as \(C = A^{-1}\) using an appropriate method (e.g., LU decomposition, Cholesky, etc.)
  • Then the forward mode AD is just matrix products
  • Reverse is harder to derive (\(\partial A = - C^{\top} (\partial C) C^{\top}\))
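The forward-mode rule above can be checked numerically against JAX's built-in derivative for the matrix inverse (a sketch; the diagonal shift just keeps the example matrix safely invertible):

```python
import jax
import jax.numpy as jnp

# Shift the diagonal so A is safely invertible (illustrative data)
A = jax.random.normal(jax.random.PRNGKey(0), (3, 3)) + 3.0 * jnp.eye(3)
dA = jax.random.normal(jax.random.PRNGKey(1), (3, 3))

C = jnp.linalg.inv(A)
manual_jvp = -C @ dA @ C  # the hand-derived forward-mode rule

# Compare against JAX's built-in rule for jnp.linalg.inv
_, ad_jvp = jax.jvp(jnp.linalg.inv, (A,), (dA,))
print(bool(jnp.allclose(manual_jvp, ad_jvp, atol=1e-4)))  # True
```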

Implicit Functions

Differentiating a Fixed Point Solution

  • Solve primal problem \(z^*(a) = f(a, z^*(a))\) for \(z^*(a)\) using Anderson iteration, Newton, etc. fixing \(a\). Use implicit function theorem at \(z^* \equiv z^*(a_0)\) \[ \frac{\partial z^*(a)}{\partial a} = \left[ I - \frac{\partial f(a, z^*)}{\partial z} \right]^{-1} \frac{\partial f(a, z^*)}{\partial a}. \]

  • For JVP: \((a, v) \mapsto \frac{\partial z^*(a)}{\partial a}v\)

\[ \frac{\partial z^*(a)}{\partial a}\cdot v = \left[ I - \frac{\partial f(a,z^*)}{\partial z} \right]^{-1} \frac{\partial f(a, z^*)}{\partial a}\cdot v \]

  • Note that this requires the gradients of \(f(a, z)\) using symbolics, AD, etc.
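A minimal scalar sketch of this recipe, assuming a hypothetical contraction \(f(a, z) = a + 0.5\tanh(z)\): solve the primal by naive fixed-point iteration, apply the implicit function theorem formula, and cross-check by differentiating through the unrolled iterations (possible here, but expensive in general):

```python
import jax
import jax.numpy as jnp

# Hypothetical scalar contraction: z* solves z = f(a, z)
def f(a, z):
    return a + 0.5 * jnp.tanh(z)

def solve(a, iters=100):
    z = 0.0
    for _ in range(iters):
        z = f(a, z)  # simple fixed-point iteration
    return z

a0 = 1.0
z_star = solve(a0)

# Implicit function theorem: dz*/da = (1 - df/dz)^{-1} df/da
df_dz = jax.grad(f, argnums=1)(a0, z_star)
df_da = jax.grad(f, argnums=0)(a0, z_star)
dz_da = df_da / (1.0 - df_dz)

# Cross-check by differentiating through the unrolled iterations
dz_da_unrolled = jax.grad(solve)(a0)
print(bool(jnp.allclose(dz_da, dz_da_unrolled, atol=1e-4)))  # True
```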

JAX Packages with Builtin Implicit Differentiation

  • Most JAX and PyTorch packages are built with AD rules
import optimistix as optx

def F(x, factor):
    return factor * x ** 3 - x - 2

@jax.jit
def root(factor):
    solver = optx.Newton(rtol=1e-6, atol=1e-6)
    sol = optx.root_find(F, solver, y0=jnp.array(1.5),
                         args=factor, max_steps=20, throw=False)
    return sol.value

# Derivative of root with respect to factor at 2.0
print(grad(root)(2.0))
-0.22139916
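We can cross-check the AD result by applying the implicit function theorem by hand. For \(F(x, \text{factor}) = \text{factor}\cdot x^3 - x - 2 = 0\), it gives \(dx/d\text{factor} = -x^3/(3\cdot\text{factor}\cdot x^2 - 1)\). A sketch using plain bisection for the primal root (the solver choice here is illustrative only):

```python
# IFT by hand: dx/dfactor = -(dF/dfactor)/(dF/dx) = -x**3 / (3*factor*x**2 - 1)
factor = 2.0

# Find the root of 2x^3 - x - 2 = 0 by bisection (illustrative only)
lo, hi = 0.0, 2.0
for _ in range(60):
    mid = 0.5 * (lo + hi)
    if factor * mid**3 - mid - 2 < 0:
        lo = mid
    else:
        hi = mid
x = 0.5 * (lo + hi)

dx_dfactor = -x**3 / (3 * factor * x**2 - 1)
print(f"{dx_dfactor:.6f}")  # close to the AD result -0.22139916
```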

Hardware Acceleration

JAX, MLIR, and XLA

  • JAX uses a multi-stage compilation pipeline to optimize and execute code
┌─────────────┐    ┌─────────────┐    ┌─────────────┐    ┌─────────────┐
│   Python    │    │   Jaxpr     │    │ StableHLO   │    │  Machine    │
│    Code     │───▶│  (JAX IR)   │───▶│   (MLIR)    │───▶│    Code     │
└─────────────┘    └─────────────┘    └─────────────┘    └─────────────┘
     jax.jit           tracing          lowering         XLA compile
  • Jaxpr: JAX’s intermediate representation (functional, SSA-style)
  • StableHLO: ~100 portable ops (MLIR dialect), 5-year backward compatibility
    • Emphasis on operations common to ML workloads, not general purpose
    • Not just JAX: PyTorch/XLA, TensorFlow, Reactant.jl also target StableHLO
  • XLA: Hardware-specific compiler (GPU/TPU/CPU), performs fusion and optimization
    • Optimized for accelerator driven operations

Program Transformation

  • Compilers like JAX take the sorts of computational graphs described above and find ways to parallelize them by “fusing” multiple operations together on GPUs, etc.
    • In JAX, vmap makes the code clearer because you write natural per-example code and let the JAX compiler handle efficient parallelization
    • In PyTorch you often need to organize your code to do the batching yourself (e.g., using unsqueeze and expand to broadcast dimensions, and using tensor reductions rather than plain matrix/vector operations).
  • The core design of JAX is that vmap, jit, and vjp/jvp can all be composed together
    • Then the MLIR/XLA compiler fuses and optimizes the resulting code for the target hardware
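A minimal sketch of the vmap style: write the function for a single example and let JAX batch it (the tiny model and shapes here are made up for illustration):

```python
import jax
import jax.numpy as jnp

# Write natural per-example code...
def predict_one(w, x):
    return jnp.tanh(w @ x)  # scalar output for one example

W = jnp.ones(4)        # shared parameters
X = jnp.ones((10, 4))  # batch of 10 examples

# ...and let vmap batch it over axis 0 of X (W is not batched)
batched = jax.vmap(predict_one, in_axes=(None, 0))(W, X)
print(batched.shape)  # (10,)
```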

The Compilation Process

  • Next we will peek inside some of these JAX passes to get a glimpse of what goes on under the hood
  • You never need to do this yourself directly
  • While the JAX IR and MLIR are hardware-agnostic, the XLA compiler generates hardware-specific code (e.g., CUDA for NVIDIA GPUs) and you would see different output for different hardware targets

JAX Tracing to Jaxpr

  • make_jaxpr shows JAX’s internal representation of your program
  • Operations like mul and add are explicit primitives
  • The internal representation of JAX is functional (i.e., no side effects)
  • It needs to know the sizes at compilation time for GPU accelerators/etc.
    • e.g., a:f32[3] vs. a:f32[4] are different types and it needs to recompile
import jax
import jax.numpy as jnp
def my_program(a, b, c):
    return a * b + c
# Need "dummy" arguments for tracing dtypes, sizes
a, b, c = jnp.ones(3), jnp.ones(3), jnp.ones(3)
print(jax.make_jaxpr(my_program)(a, b, c))
a2, b2, c2 = jnp.ones(4), jnp.ones(4), jnp.ones(4)
print(jax.make_jaxpr(my_program)(a2, b2, c2))
{ lambda ; a:f32[3] b:f32[3] c:f32[3]. let
    d:f32[3] = mul a b
    e:f32[3] = add d c
  in (e,) }

{ lambda ; a:f32[4] b:f32[4] c:f32[4]. let
    d:f32[4] = mul a b
    e:f32[4] = add d c
  in (e,) }

Lowering to MLIR/StableHLO

  • .lower() converts Jaxpr to StableHLO (MLIR dialect that XLA consumes)
  • This is the portable representation between frameworks and compilers
    • i.e., GPUs, TPUs, and CPUs all share the same StableHLO representation, but it is tuned for ML-style workloads and parallelization
    • Again, notice that the types (e.g., tensor<3xf32>) are fixed
jitted_fn = jax.jit(my_program)
lowered = jitted_fn.lower(a, b, c)
# Show first 20 lines of StableHLO
print('\n'.join(lowered.as_text().split('\n')[:20]))
module @jit_my_program attributes {mhlo.num_partitions = 1 : i32, mhlo.num_replicas = 1 : i32} {
  func.func public @main(%arg0: tensor<3xf32>, %arg1: tensor<3xf32>, %arg2: tensor<3xf32>) -> (tensor<3xf32> {jax.result_info = "result"}) {
    %0 = stablehlo.multiply %arg0, %arg1 : tensor<3xf32>
    %1 = stablehlo.add %0, %arg2 : tensor<3xf32>
    return %1 : tensor<3xf32>
  }
}

jax.vjp: The VJP Primitive

  • vjp(f, *primals) returns (output, vjp_fn), separating the forward and backward passes
  • vjp_fn(cotangent) computes the cotangents (backward pass)
  • For \(y = a \cdot b + c\): \(\bar{a} = \bar{y} \cdot b\), \(\bar{b} = \bar{y} \cdot a\), \(\bar{c} = \bar{y}\)
  • jax.vjp creates a closure (with captured primals) that is a function of \(\bar{y}\)
def f(a, b, c):
    return a * b + c

a, b, c = 2.0, 3.0, 1.0
_, vjp_fn = jax.vjp(f, a, b, c)
# Show the Jaxpr for the backward pass

# \bar{y} = 1
print(jax.make_jaxpr(vjp_fn)(1.0))
# Jaxpr: Input a = y_bar. Captures primals: 2.0=a, 3.0=b
# b = mul(2.0, a) = a * y_bar -> b_bar
# c = mul(a, 3.0) = y_bar * b -> a_bar
# Output (c, b, a) = (a_bar, b_bar, c_bar)
{ lambda ; a:f32[]. let
    b:f32[] = mul 2.0:f32[] a
    c:f32[] = mul a 3.0:f32[]
  in (c, b, a) }

jax.grad: Convenience for Scalar Functions

  • grad(f): calls vjp with cotangent \(\bar{y} = 1.0\) returning the gradient
  • Works only for scalar-output functions
  • Can differentiate again for higher-order AD (it’s just a function!)
grad_fn = jax.grad(f, argnums=(0, 1, 2))
print(jax.make_jaxpr(grad_fn)(a, b, c))
# jaxpr Forward: d = mul(a,b), e = add(d,c)
# Backward: mul VJP -> (b, a), add VJP -> (1,)
# The mul 1.0 etc. is since y_bar = 1.0
{ lambda ; a:f32[] b:f32[] c:f32[]. let
    d:f32[] = mul a b
    _:f32[] = add d c
    e:f32[] = mul a 1.0:f32[]
    f:f32[] = mul 1.0:f32[] b
  in (f, e, 1.0:f32[]) }
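Since grad returns an ordinary function, higher-order derivatives are just repeated composition (a quick sketch with a cubic whose derivatives are known in closed form):

```python
import jax

f3 = lambda x: x**3
print(jax.grad(f3)(3.0))            # 3x^2 at x=3 -> 27.0
print(jax.grad(jax.grad(f3))(3.0))  # 6x at x=3 -> 18.0
```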

JAX Limitations

  • The JAX pipeline (tracing Python → MLIR → accelerator codegen) is extremely powerful
    • However, not ideal for arbitrary Python - works best for ML-style operations
    • Classic challenges: control flow, mutation, dynamic shapes, scalar-heavy code
  • Must map to StableHLO operations that XLA optimizes:
    • Element-wise: add, multiply, tanh, exp
    • Linear algebra: dot_general, cholesky, triangular_solve
    • Reductions: reduce, reduce_window (pooling)
    • Shape: broadcast_in_dim, reshape, transpose
  • See JAX Sharp Edges for details

JAX Challenge: Loops (The Problem)

  • Python loops get unrolled during tracing. Slow compilation, large IR, doesn’t scale
def sum_with_loop(x):
    total = 0.0
    for i in range(3):
        total = total + x
    return total

x = 2.0
print(jax.make_jaxpr(sum_with_loop)(x))
{ lambda ; a:f32[]. let
    b:f32[] = add 0.0:f32[] a
    c:f32[] = add b a
    d:f32[] = add c a
  in (d,) }

JAX Challenge: Loops (The Solution)

def sum_with_fori(x):
    def body(i, total):
        return total + x
    return jax.lax.fori_loop(0, 3, body, 0.0)

# Compact IR: fori_loop with static bounds lowers to a scan primitive
print(jax.make_jaxpr(sum_with_fori)(x))
{ lambda ; a:f32[]. let
    _:i32[] b:f32[] = scan[
      _split_transpose=False
      jaxpr={ lambda ; c:f32[] d:i32[] e:f32[]. let
          f:i32[] = add d 1:i32[]
          g:f32[] = add e c
        in (f, g) }
      length=3
      linear=(False, False, False)
      num_carry=2
      num_consts=1
      reverse=False
      unroll=1
    ] a 0:i32[] 0.0:f32[]
  in (b,) }

JAX Challenge: Conditionals (The Problem)

  • Python if/else cannot be traced - condition depends on runtime value
  • JAX doesn’t know which branch to take at trace time
def abs_python(x):
    if x >= 0:  # Error: can't convert tracer to bool
        return x
    else:
        return -x

try:
    print(jax.make_jaxpr(abs_python)(1.0))
except jax.errors.TracerBoolConversionError as e:
    print(e)
Attempted boolean conversion of traced array with shape bool[].
The error occurred while tracing the function abs_python at /tmp/ipykernel_4362/624549699.py:1 for jit. This concrete value was not available in Python because it depends on the value of the argument x.
See https://docs.jax.dev/en/latest/errors.html#jax.errors.TracerBoolConversionError

JAX Challenge: Conditionals (The Solution)

  • Use lax.cond for true conditional execution
  • Only the selected branch is executed at runtime (note the zero-argument lambdas)
def abs_lax(x):
    return jax.lax.cond(
        x >= 0,
        lambda: x,      # true branch
        lambda: -x      # false branch
    )

# Uses cond primitive - branches are separate
print(jax.make_jaxpr(abs_lax)(1.0))
{ lambda ; a:f32[]. let
    b:bool[] = ge a 0.0:f32[]
    c:i32[] = convert_element_type[new_dtype=int32 weak_type=False] b
    d:f32[] = cond[
      branches=(
        { lambda ; e:f32[]. let f:f32[] = neg e in (f,) }
        { lambda ; g:f32[]. let  in (g,) }
      )
    ] c a
  in (d,) }

JAX Challenge: Mutation (The Problem)

  • JAX arrays are immutable - direct assignment fails
  • This enables functional transformations but requires different patterns
x = jnp.array([1.0, 2.0, 3.0])
try:
    x[0] = 99.0
except TypeError as e:
    print(e)
JAX arrays are immutable and do not support in-place item assignment. Instead of x[idx] = y, use x = x.at[idx].set(y) or another .at[] method: https://docs.jax.dev/en/latest/_autosummary/jax.numpy.ndarray.at.html

JAX Challenge: Mutation (A Solution to Be Avoided)

  • Use .at[].set() for functional updates; it returns a new array. Avoid at all costs!
def set_first(x, val):
    return x.at[0].set(val)

x = jnp.array([1.0, 2.0, 3.0])

# Functional update - creates new array with each call
# complexity of the generated code hints at how
# inefficient this simple operation is in jax
print(jax.make_jaxpr(set_first)(x, 99.0))
{ lambda ; a:f32[3] b:f32[]. let
    c:i32[1] = broadcast_in_dim[
      broadcast_dimensions=()
      shape=(1,)
      sharding=None
    ] 0:i32[]
    d:f32[] = convert_element_type[new_dtype=float32 weak_type=False] b
    e:f32[3] = scatter[
      dimension_numbers=ScatterDimensionNumbers(update_window_dims=(), inserted_window_dims=(0,), scatter_dims_to_operand_dims=(0,), operand_batching_dims=(), scatter_indices_batching_dims=())
      indices_are_sorted=True
      mode=GatherScatterMode.FILL_OR_DROP
      unique_indices=True
      update_consts=()
      update_jaxpr=None
    ] a c d
  in (e,) }

JAX Challenge: Dynamic Shapes (The Problem)

  • JIT requires static shapes - boolean indexing creates dynamic shapes
  • Can’t know the output size at compile time, so it must fall back to slow, non-compiled code
def filter_positive(x):
    mask = x > 0
    return x[mask]  # Dynamic shape!

x = jnp.array([-1.0, 2.0, -3.0, 4.0])
# Works without JIT
print(f"Eager: {filter_positive(x)}")
# But fails with JIT
try:
    jax.jit(filter_positive)(x)
except Exception as e:
    print(e)
Eager: [2. 4.]
Array boolean indices must be concrete; got bool[4]

See https://docs.jax.dev/en/latest/errors.html#jax.errors.NonConcreteBooleanIndexError

JAX Challenge: Dynamic Shapes (The Solution)

  • Use jnp.where to maintain static shapes
  • MLIR/XLA are vector/matrix native, so prefer masking over dynamic shapes
    • These challenges are part of the reason that scalar code (e.g., loops, mutation) is often much slower in JAX.
    • Python has few general-purpose compilers (e.g., Numba); for scalar-heavy code, consider MATLAB/Julia/Fortran
def sum_positive(x):
    # Keep static shape, zero out negatives
    return jnp.where(x > 0, x, 0.0).sum()

x = jnp.array([-1.0, 2.0, -3.0, 4.0])
print(jax.make_jaxpr(sum_positive)(x))
{ lambda ; a:f32[4]. let
    b:bool[4] = gt a 0.0:f32[]
    c:f32[4] = jit[
      name=_where
      jaxpr={ lambda ; b:bool[4] a:f32[4] d:f32[]. let
          e:f32[] = convert_element_type[new_dtype=float32 weak_type=False] d
          f:f32[4] = broadcast_in_dim[
            broadcast_dimensions=()
            shape=(4,)
            sharding=None
          ] e
          c:f32[4] = select_n b f a
        in (c,) }
    ] b a 0.0:f32[]
    g:f32[] = reduce_sum[axes=(0,) out_sharding=None] c
  in (g,) }

Reactant.jl (with Enzyme.jl): Julia Targeting MLIR

  • Enzyme.jl: Source-to-source AD at the LLVM level for Julia
    • Supports both forward and reverse mode
    • Works on arbitrary Julia code, including mutation
  • Reactant.jl: Julia framework that compiles to StableHLO/MLIR
    • Targets the same XLA backend as JAX
    • Julia code gets the same GPU optimizations, kernel fusion, etc.
  • This is the power of MLIR as a common IR:
    • Multiple frontends (JAX, Julia via Reactant, PyTorch/XLA)
    • Share the same optimizing compiler backend (XLA)
    • Write in your preferred language, get world-class performance