Machine Learning Fundamentals for Economists
Installation note: uv is great as a pip replacement, but conda sometimes has better binary support.

Built-in composable transformations (from the JAX quickstart):
- jit: compilation to XLA, including on accelerators (e.g. GPUs)
- grad: automatic differentiation
- vmap: vectorization
These transformations compose freely, as sketched below.
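A minimal sketch of how the three transformations compose; the function f is an assumed example, not from the original notes:

import jax.numpy as jnp
from jax import grad, jit, vmap

# assumed scalar function used only to illustrate composition
def f(x):
    return jnp.sin(x) * x**2

df = grad(f)               # derivative of f
df_vec = vmap(df)          # derivative evaluated over an array
df_vec_fast = jit(df_vec)  # compiled with XLA

print(df_vec_fast(jnp.linspace(0.0, 1.0, 5)))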
jit

Timing the un-jitted function:
1.23 ms ± 7.03 μs per loop (mean ± std. dev. of 7 runs, 1,000 loops each)

Timing the jit-compiled version:
670 μs ± 1.3 μs per loop (mean ± std. dev. of 7 runs, 1,000 loops each)
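The function being timed is not shown in these notes; a minimal sketch of this kind of comparison, with an assumed workload and array size (so the numbers will differ):

import jax.numpy as jnp
from jax import jit, random

# assumed workload: elementwise operations reduced to a scalar
def slow_f(x):
    return jnp.sum(jnp.exp(x) + x * x)

x = random.normal(random.PRNGKey(0), (1_000_000,))

fast_f = jit(slow_f)
fast_f(x).block_until_ready()  # trigger compilation before timing

# In a notebook:
# %timeit slow_f(x).block_until_ready()
# %timeit fast_f(x).block_until_ready()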
jit can also be applied as a decorator, @jit.

Timing the @jit-decorated version:
638 μs ± 1.54 μs per loop (mean ± std. dev. of 7 runs, 1,000 loops each)

grad performs automatic differentiation: it takes a function and returns a new function that computes its gradient.
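A minimal sketch of the @jit decorator together with grad; the function g is an assumed example, not the one timed above:

import jax.numpy as jnp
from jax import grad, jit

# assumed example function, decorated so every call runs the compiled version
@jit
def g(x):
    return jnp.sum(x ** 2)

dg = grad(g)  # gradient of g, itself a function

x = jnp.array([1.0, 2.0, 3.0])
print(g(x))   # 14.0
print(dg(x))  # [2. 4. 6.]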
vmap

It is common to run the same function along one dimension of an array. The vmap transformation applies a function across a dimension, and dimensions can be fixed with in_axes. vmapped functions are jitable, although getting in_axes right can be tricky.

Output of the first vmap example: a scalar function evaluated at two points, then vmapped over both at once (the code is not shown; a sketch follows below):
11.0
25.0
[11. 25.]
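A minimal sketch that reproduces the output above, assuming the example function was f(x, y) = x**2 + y (an assumption, since the original code is missing):

import jax.numpy as jnp
from jax import vmap

# assumed example function (the original is not shown in the notes)
def f(x, y):
    return x ** 2 + y

print(f(3.0, 2.0))  # 11.0
print(f(5.0, 0.0))  # 25.0

# vmap maps f across axis 0 of both array arguments
print(vmap(f)(jnp.array([3.0, 5.0]), jnp.array([2.0, 0.0])))  # [11. 25.]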
The in_axes can match more complicated structures. Before the dictionary (PyTree) examples, here is the output of a simpler case in which one argument is held fixed while mapping over the other (the original code is not shown; a sketch follows below):
[1. 2. 3. 4. 5.]
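A minimal sketch reproducing the fixed-argument output above, assuming the mapped function simply adds its two arguments (an assumption):

import jax.numpy as jnp
from jax import vmap

def add(x, y):
    return x + y

# map over axis 0 of the first argument; hold the second argument fixed
out = vmap(add, in_axes=(0, None))(jnp.arange(5.0), 1.0)
print(out)  # [1. 2. 3. 4. 5.]

The dictionary examples below extend this idea: in_axes can itself be a PyTree that mirrors the structure of the argument.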
import jax.numpy as jnp
from jax import vmap

dct = {'a': jnp.array([3.0, 5.0]), 'b': jnp.array([2.0, 4.0])}

def foo2(dct, x):
    return dct['a'] + dct['b'] + x

x = 1.0
# in_axes must match the structure of the PyTree argument
out = vmap(foo2, in_axes=(
    {'a': 0, 'b': 0},  # map over axis 0 of both 'a' and 'b'
    None               # x is not mapped over
))(dct, x)
# each mapped call sees e.g. {'a': 3.0, 'b': 2.0}
print(out)

Output:
[ 6. 10.]
dct = {'a': jnp.array([3.0, 5.0]), 'b': jnp.arange(5.0)}

def foo3(dct, x):
    return dct['a'][0] * dct['a'][1] + dct['b'] + x

# in_axes must match the structure of the PyTree argument
out = vmap(foo3, in_axes=(
    {'a': None, 'b': 0},  # map over axis 0 of 'b'; pass 'a' whole to every call
    None                  # x is not mapped over
))(dct, x)
# each mapped call sees e.g. {'a': [3.0, 5.0], 'b': 0.0}
print(out)

Output:
[16. 17. 18. 19. 20.]
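As noted above, vmapped functions are jitable. A minimal sketch composing jit with the vmapped foo3 from the previous example:

from jax import jit

# compile the vmapped function; in_axes is the same as above
fast_foo3 = jit(vmap(foo3, in_axes=({'a': None, 'b': 0}, None)))
print(fast_foo3(dct, x))  # [16. 17. 18. 19. 20.]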