Machine Learning Fundamentals for Economists
A few general types of differentiation
Numerical Differentiation (i.e., finite differences)
Symbolic Differentiation (i.e., chain rule and simplify subexpressions by hand)
Automatic Differentiation (i.e., execute chain rule on computer)
Sparse Differentiation (i.e., use one of the above to calculate directional derivatives, potentially filling in sparse Jacobians with fewer passes)
\[ \frac{\partial f(x)}{\partial x_i} \approx \frac{f(x + \epsilon e_i) - f(x)}{\epsilon} \]
\[ f'(x) \approx \frac{f(x-2\epsilon) - 8f(x-\epsilon) + 8f(x+\epsilon) - f(x+2\epsilon)}{12\epsilon} \]
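Both stencils are one-liners to implement; a plain-Python sketch (the test function \(\sin\) and the step sizes are arbitrary choices):

```python
import math

def forward_diff(f, x, eps=1e-6):
    # One-sided difference: O(eps) truncation error
    return (f(x + eps) - f(x)) / eps

def central_diff4(f, x, eps=1e-4):
    # Fourth-order central difference: O(eps^4) truncation error
    return (f(x - 2 * eps) - 8 * f(x - eps)
            + 8 * f(x + eps) - f(x + 2 * eps)) / (12 * eps)

exact = math.cos(1.0)  # derivative of sin at 1.0
print(abs(forward_diff(math.sin, 1.0) - exact))   # small
print(abs(central_diff4(math.sin, 1.0) - exact))  # far smaller
```

Note the trade-off: the higher-order stencil needs four function evaluations per derivative but gives many more accurate digits at moderate step sizes.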
\[ f'(x) = g'(x) f_1(g(x), h(g(x))) + g'(x) h'(g(x)) f_2(g(x), h(g(x))) \]
Auto-differentiation/differentiable programming works on “computer programs” (computational graphs), which are just functions
Finally: many frameworks will compile the resulting sequence of operations to be efficient on a GPU since this is so central to deep learning performance
| Step | Operation | Result |
|---|---|---|
| 1 (in) | \(x_1\) | \(z_1\) |
| 2 (in) | \(x_2\) | \(z_2\) |
| 3 | \(z_1 \cdot z_2\) | \(z_3\) |
| 4 (i=1) | \(z_2^2\) | \(z_4\) |
| 5 | \(z_3 + z_4\) | \(z_5\) |
| 6 (i=2) | \(z_2^2\) | \(z_6\) |
| 7 | \(z_5 + z_6\) | \(z_7\) |
| 8 (out) | \(\log(z_7)\) | \(z_8\) |
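The table is a straight-line program; transcribed to plain Python (the composite function, \(f(x) = \log(x_1 x_2 + 2 x_2^2)\), follows from the steps):

```python
import math

def f(x1, x2):
    # Replays the computational-graph steps from the table
    z1, z2 = x1, x2       # steps 1-2 (inputs)
    z3 = z1 * z2          # step 3
    z4 = z2 ** 2          # step 4 (i=1)
    z5 = z3 + z4          # step 5
    z6 = z2 ** 2          # step 6 (i=2)
    z7 = z5 + z6          # step 7
    z8 = math.log(z7)     # step 8 (output)
    return z8

print(f(1.0, 2.0))  # log(10) ~ 2.302585
```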
Use the standard basis vectors \(e_1, e_2\) and calculate \(\mathcal{A}(e_1), \mathcal{A}(e_2)\)
Use the standard basis vectors \(e_1, e_2, e_3\) (now of \(\mathbb{R}^3\)) and calculate \(\mathcal{A}^{\top}(e_1), \mathcal{A}^{\top}(e_2), \mathcal{A}^{\top}(e_3)\)
Denote the operator, linearized around \(x\), and applied to \(v\in\mathbb{R}^N\) as
\[ (x, v) \mapsto \partial f(x)[v] \in \mathbb{R}^M \]
JAX (and others) will take an \(f\) and an \(x\) and compile a new function from \(\mathbb{R}^N\) to \(\mathbb{R}^M\) that calculates \(\partial f(x)[v]\)
\[ \partial f(x)^{\top} : \mathbb{R}^M \to \mathbb{R}^N \]
Let \(f : \mathbb{R}^2 \to \mathbb{R}^2\) be defined as
\[ f(x) \equiv \begin{bmatrix} x_1^2 + x_2^2 \\ x_1 x_2 \end{bmatrix} \]
Then
\[ \nabla f(x) \equiv \begin{bmatrix} 2 x_1 & 2 x_2 \\ x_2 & x_1 \end{bmatrix} \]
Let \(v = \begin{bmatrix} 1 & 0 \end{bmatrix}^{\top}\), i.e., \(e_1\) in the standard basis. Then
\[ \partial f(x)[v] = \nabla f(x) \cdot \begin{bmatrix} 1 \\ 0 \end{bmatrix} = \begin{bmatrix} 2 x_1 \\ x_2 \end{bmatrix} \]
Similarly, for the cotangent \(u = e_1\),
\[ \partial f(x)^{\top}[u] = \begin{bmatrix} 1 & 0 \end{bmatrix} \cdot \nabla f(x) = \begin{bmatrix} 2 x_1 & 2 x_2 \end{bmatrix} \]
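Neither product requires materializing anything beyond the Jacobian entries already written down; a plain-Python check of both identities at the arbitrary point \(x = (3, 4)\):

```python
def jacobian(x1, x2):
    # Jacobian of f(x) = (x1^2 + x2^2, x1*x2) from the example above
    return [[2 * x1, 2 * x2],
            [x2,     x1]]

def jvp(x1, x2, v):
    # Push-forward: J @ v (one linear combination per output)
    J = jacobian(x1, x2)
    return [J[0][0] * v[0] + J[0][1] * v[1],
            J[1][0] * v[0] + J[1][1] * v[1]]

def vjp(x1, x2, u):
    # Pull-back: u^T @ J (one linear combination per input)
    J = jacobian(x1, x2)
    return [u[0] * J[0][0] + u[1] * J[1][0],
            u[0] * J[0][1] + u[1] * J[1][1]]

print(jvp(3.0, 4.0, [1.0, 0.0]))  # [2*x1, x2] = [6.0, 4.0]
print(vjp(3.0, 4.0, [1.0, 0.0]))  # [2*x1, 2*x2] = [6.0, 8.0]
```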
\[ \partial f(x) = \partial c(b(a(x))) \circ \partial b(a(x)) \circ \partial a(x) \]
\[ \partial f(x)[v] = \partial c(b(a(x))) \left[ \partial b(a(x))[\partial a(x)[v]] \right] \]
Calculation order inside out, recursively finding linearization points:
Conveniently follows the “primal” calculation. Many ways to implement it (e.g., operator overloading, dual numbers)
Can calculate the “primal” and the “push-forward” at the same time
\[ \partial f(x) = \partial c(b(a(x))) \circ \partial b(a(x)) \circ \partial a(x) \]
\[ \partial f(x)^{\top} = \partial a(x)^{\top} \circ \partial b(a(x))^{\top} \circ \partial c(b(a(x)))^{\top} \]
\[ \partial f(x)^{\top}[u] = \partial a(x)^{\top} \left[ \partial b(a(x))^{\top} \left[ \partial c(b(a(x)))^{\top}[u] \right] \right] \]
Automatic differentiation is best understood as the automatic composition of linear operators and their adjoints, not as a way of producing Jacobian matrices
For \(f:\mathcal{X}\to\mathcal{Y}\) between Hilbert spaces, the Fréchet derivative at \(x\) is a bounded linear operator \[ Df(x): \mathcal{X} \to \mathcal{Y} \]
such that for perturbation direction \(h\in\mathcal{X}\)
\[ \lim_{\|h\|\to 0} \frac{\|f(x+h)-f(x)-Df(x)[h]\|_{\mathcal{Y}}}{\|h\|_{\mathcal{X}}} = 0 \]
Interpretation:
Taking adjoints of the operator chain rule, \[ Dh(x)^{\top} = Df(x)^{\top} \circ Dg(f(x))^{\top} \]
For a cotangent \(w\in\mathcal{Z}\), \[ Dh(x)^{\top}[w] = Df(x)^{\top}\big[Dg(f(x))^{\top}[w]\big] \]
Interpretation:
| Step | Operation | Result |
|---|---|---|
| 1 (in) | \(x_1\) | \(z_1\) |
| 2 (in) | \(x_2\) | \(z_2\) |
| 3 | \(z_1 \cdot z_2\) | \(z_3\) |
| 4 (i=1) | \(z_2^2\) | \(z_4\) |
| 5 | \(z_3 + z_4\) | \(z_5\) |
| 6 (i=2) | \(z_2^2\) | \(z_6\) |
| 7 | \(z_5 + z_6\) | \(z_7\) |
| 8 (out) | \(\log(z_7)\) | \(z_8\) |
| Operation | Primal (given \(x\)) | JVP (given \(\dot{x}\)) | VJP (given \(\bar{z}\)) |
|---|---|---|---|
| Power | \(x^n\) | \(\dot{z} = n x^{n-1} \cdot \dot{x}\) | \(\bar{x} = \bar{z} \cdot n x^{n-1}\) |
| Exponential | \(\exp(x)\) | \(\dot{z} = \exp(x) \cdot \dot{x}\) | \(\bar{x} = \bar{z} \cdot \exp(x)\) |
| Logarithm | \(\log(x)\) | \(\dot{z} = \frac{\dot{x}}{x}\) | \(\bar{x} = \frac{\bar{z}}{x}\) |
| Reciprocal | \(\frac{1}{x}\) | \(\dot{z} = -\frac{\dot{x}}{x^2}\) | \(\bar{x} = -\frac{\bar{z}}{x^2}\) |
| Operation | Primal \(z\) | JVP (given \(\dot{x}_1, \dot{x}_2\)) | VJP (given \(\bar{z}\)) |
|---|---|---|---|
| Addition | \(x_1 + x_2\) | \(\dot{z} = \dot{x}_1 + \dot{x}_2\) | \(\bar{x}_1 = \bar{z}\), \(\bar{x}_2 = \bar{z}\) |
| Multiplication | \(x_1 \cdot x_2\) | \(\dot{z} = x_2 \dot{x}_1 + x_1 \dot{x}_2\) | \(\bar{x}_1 = \bar{z} x_2\), \(\bar{x}_2 = \bar{z} x_1\) |
| Division | \(\frac{x_1}{x_2}\) | \(\dot{z} = \frac{\dot{x}_1}{x_2} - \frac{x_1 \dot{x}_2}{x_2^2}\) | \(\bar{x}_1 = \frac{\bar{z}}{x_2}\), \(\bar{x}_2 = -\frac{\bar{z} x_1}{x_2^2}\) |
| Power | \(x_1^{x_2}\) | \(\dot{z} = x_2 x_1^{x_2-1} \dot{x}_1 + x_1^{x_2} \log(x_1) \dot{x}_2\) | \(\bar{x}_1 = \bar{z} x_2 x_1^{x_2-1}\), \(\bar{x}_2 = \bar{z} x_1^{x_2} \log(x_1)\) |
| Step | Primal | Tangent (JVP) |
|---|---|---|
| 1 (in) | \(z_1 = x_1\) | \(\dot{z}_1 = \dot{x}_1\) |
| 2 (in) | \(z_2 = x_2\) | \(\dot{z}_2 = \dot{x}_2\) |
| 3 | \(z_3 = z_1 \cdot z_2\) | \(\dot{z}_3 = z_2 \dot{z}_1 + z_1 \dot{z}_2\) |
| 4 (i=1) | \(z_4 = z_2^2\) | \(\dot{z}_4 = 2\, z_2 \dot{z}_2\) |
| 5 | \(z_5 = z_3 + z_4\) | \(\dot{z}_5 = \dot{z}_3 + \dot{z}_4\) |
| 6 (i=2) | \(z_6 = z_2^2\) | \(\dot{z}_6 = 2\, z_2 \dot{z}_2\) |
| 7 | \(z_7 = z_5 + z_6\) | \(\dot{z}_7 = \dot{z}_5 + \dot{z}_6\) |
| 8 (out) | \(z_8 = \log(z_7)\) | \(\dot{z}_8 = \frac{\dot{z}_7}{z_7}\) |
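The tangent column transcribes directly to code, computing primal and tangent side by side; a plain-Python sketch (seeding \(\dot{x} = e_1\) or \(e_2\) recovers each partial derivative):

```python
import math

def f_jvp(x1, x2, x1_dot, x2_dot):
    # Each line pairs a primal step with its tangent (JVP) step
    z1, z1d = x1, x1_dot
    z2, z2d = x2, x2_dot
    z3, z3d = z1 * z2, z2 * z1d + z1 * z2d
    z4, z4d = z2 ** 2, 2 * z2 * z2d
    z5, z5d = z3 + z4, z3d + z4d
    z6, z6d = z2 ** 2, 2 * z2 * z2d
    z7, z7d = z5 + z6, z5d + z6d
    z8, z8d = math.log(z7), z7d / z7
    return z8, z8d

# At (1, 2): f = log(10); df/dx1 = 2/10, df/dx2 = 9/10
print(f_jvp(1.0, 2.0, 1.0, 0.0))  # tangent 0.2
print(f_jvp(1.0, 2.0, 0.0, 1.0))  # tangent 0.9
```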
Step 1: Compute primal (forward). Step 2: Propagate cotangents (backward).
Primal (forward)
| Step | Primal |
|---|---|
| 1 (in) | \(z_1 = x_1\) |
| 2 (in) | \(z_2 = x_2\) |
| 3 | \(z_3 = z_1 \cdot z_2\) |
| 4 (i=1) | \(z_4 = z_2^2\) |
| 5 | \(z_5 = z_3 + z_4\) |
| 6 (i=2) | \(z_6 = z_2^2\) |
| 7 | \(z_7 = z_5 + z_6\) |
| 8 (out) | \(z_8 = \log(z_7)\) |
Cotangent (backward, usually seed \(\bar{z}_8 = 1\))
| Step | VJP |
|---|---|
| 8 | \(\bar{z}_7 = \frac{\bar{z}_8}{z_7}\) |
| 7 | \(\bar{z}_5 \mathrel{+}= \bar{z}_7\), \(\bar{z}_6 \mathrel{+}= \bar{z}_7\) |
| 6 | \(\bar{z}_2 \mathrel{+}= 2\, z_2 \bar{z}_6\) |
| 5 | \(\bar{z}_3 \mathrel{+}= \bar{z}_5\), \(\bar{z}_4 \mathrel{+}= \bar{z}_5\) |
| 4 | \(\bar{z}_2 \mathrel{+}= 2\, z_2 \bar{z}_4\) |
| 3 | \(\bar{z}_1 \mathrel{+}= z_2 \bar{z}_3\), \(\bar{z}_2 \mathrel{+}= z_1 \bar{z}_3\) |
| 2 | \(\bar{x}_2 = \bar{z}_2\) |
| 1 | \(\bar{x}_1 = \bar{z}_1\) |
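The two sweeps can likewise be transcribed line by line; a plain-Python sketch (the `+=` accumulations mirror the fan-out of \(z_2\) to several consumers):

```python
import math

def f_grad(x1, x2):
    # Forward (primal) sweep, storing every intermediate
    z1, z2 = x1, x2
    z3 = z1 * z2
    z4 = z2 ** 2
    z5 = z3 + z4
    z6 = z2 ** 2
    z7 = z5 + z6
    z8 = math.log(z7)
    # Backward (cotangent) sweep, seeded with z8_bar = 1
    z8b = 1.0
    z7b = z8b / z7            # step 8
    z5b, z6b = z7b, z7b       # step 7
    z2b = 2 * z2 * z6b        # step 6
    z3b, z4b = z5b, z5b       # step 5
    z2b += 2 * z2 * z4b       # step 4
    z1b = z2 * z3b            # step 3 (first input)
    z2b += z1 * z3b           # step 3 (second input)
    return z8, (z1b, z2b)

print(f_grad(1.0, 2.0))  # value log(10); gradient ~ (0.2, 0.9)
```

One backward pass produces the whole gradient, whereas forward mode above needed one pass per input.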
The term “neural network” is very broad and covers a variety of approximation classes, with a high degree of flexibility.
While there is no theoretical requirement to use these patterns, in practice hardware is optimized to make them fast.
For example, consider the function \(f : \mathbb{R}^N \to \mathbb{R}^M\) defined as \[ y = f_{\theta}(x) = W_3 \sigma(W_2 \sigma(W_1 x + b_1) + b_2) + b_3 \]
\[ y = f_{\theta}(x) = W_3 \sigma(W_2 \sigma(W_1 x + b_1) + b_2) + b_3, \quad \text{given } \bar{y} \in {\mathbb{R}}^M \]
Primal (forward)
| Step | Primal |
|---|---|
| 1 | \(t_1 = W_1 x\) |
| 2 | \(z_1 = t_1 + b_1\) |
| 3 | \(a_1 = \sigma(z_1)\) |
| 4 | \(t_2 = W_2 a_1\) |
| 5 | \(z_2 = t_2 + b_2\) |
| 6 | \(a_2 = \sigma(z_2)\) |
| 7 | \(t_3 = W_3 a_2\) |
| 8 | \(y = t_3 + b_3\) |
Cotangent (backward)
| Step | VJP |
|---|---|
| 8 | \(\bar{t}_3 = \bar{y}\), \(\bar{b}_3 = \bar{y}\) |
| 7 | \(\bar{a}_2 = W_3^\top \bar{t}_3\), \(\bar{W}_3 = \bar{t}_3 a_2^\top\) |
| 6 | \(\bar{z}_2 = \mathrm{diag}(\sigma'(z_2))\, \bar{a}_2\) |
| 5 | \(\bar{t}_2 = \bar{z}_2\), \(\bar{b}_2 = \bar{z}_2\) |
| 4 | \(\bar{a}_1 = W_2^\top \bar{t}_2\), \(\bar{W}_2 = \bar{t}_2 a_1^\top\) |
| 3 | \(\bar{z}_1 = \mathrm{diag}(\sigma'(z_1))\, \bar{a}_1\) |
| 2 | \(\bar{t}_1 = \bar{z}_1\), \(\bar{b}_1 = \bar{z}_1\) |
| 1 | \(\bar{x} = W_1^\top \bar{t}_1\), \(\bar{W}_1 = \bar{t}_1 x^\top\) |
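The backward table can be checked against finite differences; a self-contained sketch assuming NumPy, with illustrative layer sizes and a \(\tanh\) activation (only the \(\bar{x}\) chain is verified here):

```python
import numpy as np

rng = np.random.default_rng(0)
sigma = np.tanh                      # illustrative activation; sigma'(z) = 1 - tanh(z)^2
N, H, M = 3, 4, 2                    # illustrative layer sizes
W1, b1 = rng.standard_normal((H, N)), rng.standard_normal(H)
W2, b2 = rng.standard_normal((H, H)), rng.standard_normal(H)
W3, b3 = rng.standard_normal((M, H)), rng.standard_normal(M)
x = rng.standard_normal(N)
y_bar = rng.standard_normal(M)       # an arbitrary cotangent

# Forward sweep (primal), storing intermediates
z1 = W1 @ x + b1
a1 = sigma(z1)
z2 = W2 @ a1 + b2
a2 = sigma(z2)
y = W3 @ a2 + b3

# Backward sweep: propagate y_bar down to x_bar (steps 8 -> 1 of the table)
a2_bar = W3.T @ y_bar
z2_bar = (1 - np.tanh(z2) ** 2) * a2_bar   # diag(sigma'(z2)) a2_bar
a1_bar = W2.T @ z2_bar
z1_bar = (1 - np.tanh(z1) ** 2) * a1_bar
x_bar = W1.T @ z1_bar

# Finite-difference check of one component: d(y_bar . f(x)) / dx_0
def scalar_loss(xv):
    a1v = sigma(W1 @ xv + b1)
    a2v = sigma(W2 @ a1v + b2)
    return y_bar @ (W3 @ a2v + b3)

eps = 1e-6
e0 = np.zeros(N)
e0[0] = eps
fd = (scalar_loss(x + e0) - scalar_loss(x - e0)) / (2 * eps)
print(abs(fd - x_bar[0]))  # small
```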
Primal (forward)
| Step | Primal |
|---|---|
| 1 | \(z_1 = \mathrm{muladd}(W_1, x, b_1)\) |
| 2 | \(a_1 = \sigma(z_1)\) |
| 3 | \(z_2 = \mathrm{muladd}(W_2, a_1, b_2)\) |
| 4 | \(a_2 = \sigma(z_2)\) |
| 5 | \(y = \mathrm{muladd}(W_3, a_2, b_3)\) |
Cotangent (backward)
| Step | VJP |
|---|---|
| 5a | \(\bar{a}_2 = W_3^\top \bar{y}\) |
| 5b | \(\bar{W}_3 = \bar{y} a_2^\top\) |
| 5c | \(\bar{b}_3 = \bar{y}\) |
| 4 | \(\bar{z}_2 = \mathrm{diag}(\sigma'(z_2))\, \bar{a}_2\) |
| 3a | \(\bar{a}_1 = W_2^\top \bar{z}_2\) |
| 3b | \(\bar{W}_2 = \bar{z}_2 a_1^\top\) |
| 3c | \(\bar{b}_2 = \bar{z}_2\) |
| 2 | \(\bar{z}_1 = \mathrm{diag}(\sigma'(z_1))\, \bar{a}_1\) |
| 1a | \(\bar{x} = W_1^\top \bar{z}_1\) |
| 1b | \(\bar{W}_1 = \bar{z}_1 x^\top\) |
| 1c | \(\bar{b}_1 = \bar{z}_1\) |
Batching with `vmap` (writing out by hand what `vmap` might do)
Single input
| Step | Primal |
|---|---|
| 1 | \(z_1 = W_1 x + b_1\) |
| 2 | \(a_1 = \sigma(z_1)\) |
| 3 | \(z_2 = W_2 a_1 + b_2\) |
| 4 | \(a_2 = \sigma(z_2)\) |
| 5 | \(y = W_3 a_2 + b_3\) |
Batched (vmap)
| Step | Primal |
|---|---|
| 1 | \(Z_1 = W_1 X + b_1 \mathbf{1}^\top\) |
| 2 | \(A_1 = \sigma(Z_1)\) |
| 3 | \(Z_2 = W_2 A_1 + b_2 \mathbf{1}^\top\) |
| 4 | \(A_2 = \sigma(Z_2)\) |
| 5 | \(Y = W_3 A_2 + b_3 \mathbf{1}^\top\) |
| Step | Batched VJP |
|---|---|
| 5a | \(\bar{A}_2 = W_3^\top \bar{Y}\) |
| 5b | \(\bar{W}_3 = \bar{Y} A_2^\top\) |
| 5c | \(\bar{b}_3 = \bar{Y} \mathbf{1}\) |
| 4 | \(\bar{Z}_2 = \sigma'(Z_2) \odot \bar{A}_2\) |
| 3a | \(\bar{A}_1 = W_2^\top \bar{Z}_2\) |
| 3b | \(\bar{W}_2 = \bar{Z}_2 A_1^\top\) |
| 3c | \(\bar{b}_2 = \bar{Z}_2 \mathbf{1}\) |
| 2 | \(\bar{Z}_1 = \sigma'(Z_1) \odot \bar{A}_1\) |
| 1a | \(\bar{X} = W_1^\top \bar{Z}_1\) |
| 1b | \(\bar{W}_1 = \bar{Z}_1 X^\top\) |
| 1c | \(\bar{b}_1 = \bar{Z}_1 \mathbf{1}\) |
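Batching this way is equivalent to looping over single inputs; a quick NumPy check of the batched forward pass (sizes and seed are arbitrary):

```python
import numpy as np

rng = np.random.default_rng(2)
N, H, M, B = 3, 4, 2, 5              # illustrative sizes; B = batch
W1, b1 = rng.standard_normal((H, N)), rng.standard_normal(H)
W2, b2 = rng.standard_normal((H, H)), rng.standard_normal(H)
W3, b3 = rng.standard_normal((M, H)), rng.standard_normal(M)

def single(x):
    # Per-example forward pass (the "Single input" table)
    a1 = np.tanh(W1 @ x + b1)
    a2 = np.tanh(W2 @ a1 + b2)
    return W3 @ a2 + b3

def batched(X):
    # Batched forward pass: X is N x B, one column per example;
    # b[:, None] plays the role of b 1^T in the table
    A1 = np.tanh(W1 @ X + b1[:, None])
    A2 = np.tanh(W2 @ A1 + b2[:, None])
    return W3 @ A2 + b3[:, None]

X = rng.standard_normal((N, B))
Y_loop = np.stack([single(X[:, j]) for j in range(B)], axis=1)
print(np.max(np.abs(batched(X) - Y_loop)))  # ~0
```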
PyTorch: create tensors with `requires_grad=True`, trace the forward pass, then call `.backward()` on the output to populate `.grad` on the inputs
import torch
x = torch.tensor(2.0, requires_grad=True)
# Trace computations for the "forward" pass
y = torch.tanh(x)
# Do the "backward" pass for Reverse-mode AD
y.backward()
print(x.grad)
tensor(0.0707)
def f(x, y):
    return x**3 + 2 * y[0]**2 - 3 * y[1] + 1
x = torch.tensor(1.0, requires_grad=True)
y = torch.tensor([2.0, 3.0],
                 requires_grad=True)
z = f(x, y)
z.backward()
print(x.grad, y.grad)
tensor(3.) tensor([ 8., -3.])
`grad` is the high-level reverse-mode AD function (for \(\mathbb{R}^N \to \mathbb{R}\))
Lower-level functions directly expose `jvp`, `vjp`, and Hessian-vector products
`jax.config.update('jax_enable_x64', True)` enables 64-bit precision (the default is 32-bit)
0.070650816
0.070650816
4.0
Array(4., dtype=float32, weak_type=True)
def sigmoid(x):
    return 0.5 * (jnp.tanh(x / 2) + 1)

def predict(W, b, inputs):
    return sigmoid(jnp.dot(inputs, W) + b)

inputs = jnp.array([[0.52, 1.12, 0.77],
                    [0.88, -1.08, 0.15],
                    [0.52, 0.06, -1.30],
                    [0.74, -2.49, 1.39]])
targets = jnp.array([True, True, False, True])
key, W_key, b_key = random.split(key, 3)
W = random.normal(W_key, (3,))
b = random.normal(b_key, ())
y = [0.10947311 0.79829013 0.41004258 0.99217653]
dy/dW[0] = [0.05069415 0.14170025 0.12579198 0.00574409]
`vjp_fun` is the pullback operator: `vjp_fun(y_bar)` computes \(\bar{W} = \partial f(W)^\top[\bar{y}]\)
y = [0.10947311 0.79829013 0.41004258 0.99217653]
y_bar = [-1.2574776 -0.4016044 -1.1213601 0.87837774]
W_bar = [-0.25666684 -0.10071305 0.2580282 ]
{'W': Array([-0.433146 , -0.7354605, -1.2598922], dtype=float32), 'b': Array(-0.69001776, dtype=float32)}
For a JVP or VJP, we first need to calculate the primal \(f(x)\)
It is often madness to descend recursively into the primal calculations
For example, for \(f(x) = \cos(x)\) the rule is \[ \partial f(x)[v] = -\sin(x) \cdot v \]
AD systems all have a library of these rules, and typically a way to create new ones for “custom” rules for complicated functions
`f.defjvp` implements \((x, \dot{x}) \mapsto (f(x), \partial f(x)[\dot{x}])\)
@jax.custom_jvp
def f(x, y):
    return jnp.sin(x) * y

@f.defjvp
def f_jvp(primals, tangents):
    x, y = primals
    x_dot, y_dot = tangents
    primal_out = f(x, y)
    tangent_out = jnp.cos(x) * x_dot * y + jnp.sin(x) * y_dot
    return primal_out, tangent_out
print(f(2., 3.))
y, y_dot = jvp(f, (2., 3.), (1., 0.)) # perturb x, not y
print(y, y_dot)
2.7278922
2.7278922 -1.2484405
\[ \begin{aligned} I &= C A\\ 0 &= \partial C A + C \partial A \\ 0 &= \partial C A C + C (\partial A) C \\ 0 &= \partial C + C (\partial A) C \\ \partial C &= -C (\partial A) C \\ \end{aligned} \]
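The final line, \(\partial C = -C (\partial A) C\), is easy to verify numerically; a NumPy sketch (the shift by \(3I\) just keeps \(A\) well-conditioned):

```python
import numpy as np

rng = np.random.default_rng(1)
A = rng.standard_normal((3, 3)) + 3 * np.eye(3)   # shift keeps A well-conditioned
dA = rng.standard_normal((3, 3))                  # an arbitrary direction
C = np.linalg.inv(A)

# Directional derivative of A -> inv(A) in direction dA
dC = -C @ dA @ C

# Central finite-difference check
eps = 1e-6
dC_fd = (np.linalg.inv(A + eps * dA) - np.linalg.inv(A - eps * dA)) / (2 * eps)
print(np.max(np.abs(dC - dC_fd)))  # small
```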
Solve primal problem \(z^*(a) = f(a, z^*(a))\) for \(z^*(a)\) using Anderson iteration, Newton, etc. fixing \(a\). Use implicit function theorem at \(z^* \equiv z^*(a_0)\) \[ \frac{\partial z^*(a)}{\partial a} = \left[ I - \frac{\partial f(a, z^*)}{\partial z} \right]^{-1} \frac{\partial f(a, z^*)}{\partial a}. \]
For JVP: \((a, v) \mapsto \frac{\partial z^*(a)}{\partial a}v\)
\[ \frac{\partial z^*(a)}{\partial a}\cdot v = \left[ I - \frac{\partial f(a,z^*)}{\partial z} \right]^{-1} \frac{\partial f(a, z^*)}{\partial a}\cdot v \]
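Before reaching for a solver library, the scalar case of this formula is easy to verify: with one-dimensional \(z\) it reads \(\frac{dz^*}{da} = (1 - f_z)^{-1} f_a\). A plain-Python sketch with the hypothetical contraction \(f(a, z) = a + 0.5 \sin(z)\):

```python
import math

def solve(a):
    # Fixed-point iteration for z = f(a, z); |df/dz| <= 0.5, so it contracts
    z = 0.0
    for _ in range(200):
        z = a + 0.5 * math.sin(z)
    return z

a0 = 1.0
z_star = solve(a0)

# Implicit function theorem: dz*/da = (1 - df/dz)^{-1} df/da, with df/da = 1
dz_da = 1.0 / (1.0 - 0.5 * math.cos(z_star))

# Finite-difference check: differentiate through the whole solver
eps = 1e-6
fd = (solve(a0 + eps) - solve(a0 - eps)) / (2 * eps)
print(abs(dz_da - fd))  # small
```

The point is that the derivative never needs to differentiate through the iterations themselves, only through one application of \(f\) at the solution.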
import optimistix as optx
def F(x, factor):
    return factor * x ** 3 - x - 2

@jax.jit
def root(factor):
    solver = optx.Newton(rtol=1e-6, atol=1e-6)
    sol = optx.root_find(F, solver, y0=jnp.array(1.5),
                         args=factor, max_steps=20, throw=False)
    return sol.value
# Derivative of root with respect to factor at 2.0
print(grad(root)(2.0))
-0.22139916
┌─────────────┐ ┌─────────────┐ ┌─────────────┐ ┌─────────────┐
│ Python │ │ Jaxpr │ │ StableHLO │ │ Machine │
│ Code │───▶│ (JAX IR) │───▶│ (MLIR) │───▶│ Code │
└─────────────┘ └─────────────┘ └─────────────┘ └─────────────┘
jax.jit tracing lowering XLA compile
`vmap` makes the code clearer: you write natural per-example code and let the JAX compiler handle efficient parallelization
(compare manual batching, where you `unsqueeze` and `expand` to broadcast dimensions and use tensor reductions rather than just matrix/vector operations)
`vmap`, `jit`, and `vjp`/`jvp` can all be composed together
`make_jaxpr` shows JAX’s internal representation of your program
`mul` and `add` are explicit primitives
`a:f32[3]` vs. `a:f32[4]` are different types, and JAX needs to recompile for each
import jax
import jax.numpy as jnp
def my_program(a, b, c):
    return a * b + c
# Need "dummy" arguments for tracing dtypes, sizes
a, b, c = jnp.ones(3), jnp.ones(3), jnp.ones(3)
print(jax.make_jaxpr(my_program)(a, b, c))
a2, b2, c2 = jnp.ones(4), jnp.ones(4), jnp.ones(4)
print(jax.make_jaxpr(my_program)(a2, b2, c2))
{ lambda ; a:f32[3] b:f32[3] c:f32[3]. let
d:f32[3] = mul a b
e:f32[3] = add d c
in (e,) }
{ lambda ; a:f32[4] b:f32[4] c:f32[4]. let
d:f32[4] = mul a b
e:f32[4] = add d c
in (e,) }
`.lower()` converts Jaxpr to StableHLO (the MLIR dialect that XLA consumes)
Shapes (e.g. `tensor<3xf32>`) are fixed at compile time
module @jit_my_program attributes {mhlo.num_partitions = 1 : i32, mhlo.num_replicas = 1 : i32} {
func.func public @main(%arg0: tensor<3xf32>, %arg1: tensor<3xf32>, %arg2: tensor<3xf32>) -> (tensor<3xf32> {jax.result_info = "result"}) {
%0 = stablehlo.multiply %arg0, %arg1 : tensor<3xf32>
%1 = stablehlo.add %0, %arg2 : tensor<3xf32>
return %1 : tensor<3xf32>
}
}
jax.vjp: The VJP Primitive
`vjp(f, *primals)` returns `(output, vjp_fn)`, separating the forward and backward passes
`vjp_fn(cotangent)` computes the cotangents (backward pass)
`jax.vjp` creates a function (with captured primals) of \(\bar{y}\)
def f(a, b, c):
    return a * b + c
a, b, c = 2.0, 3.0, 1.0
_, vjp_fn = jax.vjp(f, a, b, c)
# Show the Jaxpr for the backward pass
# \bar{y} = 1
print(jax.make_jaxpr(vjp_fn)(1.0))
# Jaxpr: Input a = y_bar. Captures primals: 2.0=a, 3.0=b
# b = mul(2.0, a) = a * y_bar -> b_bar
# c = mul(a, 3.0) = y_bar * b -> a_bar
# Output (c, b, a) = (a_bar, b_bar, c_bar)
{ lambda ; a:f32[]. let
b:f32[] = mul 2.0:f32[] a
c:f32[] = mul a 3.0:f32[]
in (c, b, a) }
jax.grad: Convenience for Scalar Functions
`grad(f)` calls `vjp` with cotangent \(\bar{y} = 1.0\) and returns the gradient
{ lambda ; a:f32[] b:f32[] c:f32[]. let
d:f32[] = mul a b
_:f32[] = add d c
e:f32[] = mul a 1.0:f32[]
f:f32[] = mul 1.0:f32[] b
in (f, e, 1.0:f32[]) }
Elementwise: `add`, `multiply`, `tanh`, `exp`
Linear algebra: `dot_general`, `cholesky`, `triangular_solve`
Reductions: `reduce`, `reduce_window` (pooling)
Shape manipulation: `broadcast_in_dim`, `reshape`, `transpose`
Loops: `lax.fori_loop` or `lax.scan` lower to the `scan` primitive for any size known at compile-time
{ lambda ; a:f32[]. let
_:i32[] b:f32[] = scan[
_split_transpose=False
jaxpr={ lambda ; c:f32[] d:i32[] e:f32[]. let
f:i32[] = add d 1:i32[]
g:f32[] = add e c
in (f, g) }
length=3
linear=(False, False, False)
num_carry=2
num_consts=1
reverse=False
unroll=1
] a 0:i32[] 0.0:f32[]
in (b,) }
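JAX is not required to see what `scan` computes; a plain-Python sketch of its semantics, mirroring the jaxpr above (a constant `a` added to a running sum for 3 steps, alongside a counter):

```python
def scan(f, init, xs):
    # Plain-Python semantics of lax.scan: thread a carry through a
    # fixed-length loop, collecting per-step outputs
    carry = init
    ys = []
    for x in xs:
        carry, y = f(carry, x)
        ys.append(y)
    return carry, ys

# Mirror of the jaxpr: const a, carry = (counter, running sum), length 3
a = 2.0

def body(carry, _):
    i, total = carry
    return (i + 1, total + a), None

(final_i, final_total), _ = scan(body, (0, 0.0), range(3))
print(final_i, final_total)  # 3 6.0
```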
Python `if`/`else` cannot be traced: the condition depends on a runtime value
Attempted boolean conversion of traced array with shape bool[].
The error occurred while tracing the function abs_python at /tmp/ipykernel_4362/624549699.py:1 for jit. This concrete value was not available in Python because it depends on the value of the argument x.
See https://docs.jax.dev/en/latest/errors.html#jax.errors.TracerBoolConversionError
Use `lax.cond` for true conditional execution (branches are passed as `lambda`s)
{ lambda ; a:f32[]. let
b:bool[] = ge a 0.0:f32[]
c:i32[] = convert_element_type[new_dtype=int32 weak_type=False] b
d:f32[] = cond[
branches=(
{ lambda ; e:f32[]. let f:f32[] = neg e in (f,) }
{ lambda ; g:f32[]. let in (g,) }
)
] c a
in (d,) }
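The `cond` primitive’s semantics are just branch selection by index; a plain-Python sketch reproducing the `abs` example traced above (branch 0 negates, branch 1 is the identity):

```python
def cond(pred, true_fn, false_fn, operand):
    # Semantics of lax.cond: convert the predicate to a branch index,
    # then apply the selected branch to the operand
    branches = (false_fn, true_fn)   # index 0 = false, index 1 = true
    return branches[int(pred)](operand)

def abs_via_cond(x):
    return cond(x >= 0.0, lambda v: v, lambda v: -v, x)

print(abs_via_cond(-3.0), abs_via_cond(3.0))  # 3.0 3.0
```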
Use `.at[].set()` for functional updates; it returns a new array. Avoid at all costs!
{ lambda ; a:f32[3] b:f32[]. let
c:i32[1] = broadcast_in_dim[
broadcast_dimensions=()
shape=(1,)
sharding=None
] 0:i32[]
d:f32[] = convert_element_type[new_dtype=float32 weak_type=False] b
e:f32[3] = scatter[
dimension_numbers=ScatterDimensionNumbers(update_window_dims=(), inserted_window_dims=(0,), scatter_dims_to_operand_dims=(0,), operand_batching_dims=(), scatter_indices_batching_dims=())
indices_are_sorted=True
mode=GatherScatterMode.FILL_OR_DROP
unique_indices=True
update_consts=()
update_jaxpr=None
] a c d
in (e,) }
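The `.at[].set()` semantics (copy-and-update, leaving the input untouched) can be sketched in plain Python:

```python
def at_set(arr, idx, value):
    # Semantics of x.at[idx].set(value): copy, then update the copy;
    # the original array is never mutated
    out = list(arr)
    out[idx] = value
    return out

x = [1.0, 2.0, 3.0]
y = at_set(x, 0, 9.0)
print(x, y)  # [1.0, 2.0, 3.0] [9.0, 2.0, 3.0]
```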
Boolean-mask indexing works eagerly but fails under `jit`, since the output shape would depend on the data
Eager: [2. 4.]
Array boolean indices must be concrete; got bool[4]
See https://docs.jax.dev/en/latest/errors.html#jax.errors.NonConcreteBooleanIndexError
Use `jnp.where` to maintain static shapes
{ lambda ; a:f32[4]. let
b:bool[4] = gt a 0.0:f32[]
c:f32[4] = jit[
name=_where
jaxpr={ lambda ; b:bool[4] a:f32[4] d:f32[]. let
e:f32[] = convert_element_type[new_dtype=float32 weak_type=False] d
f:f32[4] = broadcast_in_dim[
broadcast_dimensions=()
shape=(4,)
sharding=None
] e
c:f32[4] = select_n b f a
in (c,) }
] b a 0.0:f32[]
g:f32[] = reduce_sum[axes=(0,) out_sharding=None] c
in (g,) }
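The `jnp.where` trick has simple semantics: keep every element, replacing non-matches with a neutral value, so the shape never depends on the data. A plain-Python sketch:

```python
def sum_positive(xs):
    # jnp.where-style masking: same length as xs, non-matches become 0.0,
    # then reduce; the intermediate shape is static
    kept = [x if x > 0 else 0.0 for x in xs]
    return sum(kept)

print(sum_positive([-1.0, 2.0, -3.0, 4.0]))  # 6.0
```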