Machine Learning Fundamentals for Economists
# optax is a common package for ML optimization methods.
# JAX PRNG keys are explicit: reusing the same key reproduces the same
# draw, so the pattern is to split a key before each use, rebinding "key".
N = 500        # number of samples
M = 2          # number of features
sigma = 0.001  # observation-noise scale

key = random.PRNGKey(42)
key, *subkey = random.split(key, num=4)
theta = random.normal(subkey[0], (M,))
X = random.normal(subkey[1], (N, M))
Y = X @ theta + sigma * random.normal(subkey[2], (N,))  # linear model + noise


def data_loader(key, X, Y, batch_size):
    """Yield shuffled (X, Y) minibatches of size ``batch_size``.

    The shuffle order is determined entirely by ``key``; the last batch
    may be smaller when ``batch_size`` does not divide the sample count.
    """
    n_obs = X.shape[0]
    assert n_obs == Y.shape[0]
    # Permute row indices once, then slice consecutive batches from them.
    shuffled = random.permutation(key, jnp.arange(n_obs))
    for start in range(0, n_obs, batch_size):
        batch_idx = shuffled[start:start + batch_size]
        yield X[batch_idx], Y[batch_idx]
# Example: iterate the loader and inspect the first minibatch.
dl_test = data_loader(key, X, Y, 4)
# A generator is its own iterator, so we can call next() on it directly.
print(next(dl_test))
# Output:
# (Array([[-0.92034245, -0.7187076 ],
#         [-0.6151726 ,  0.47314   ],
#         [-0.35952824, -0.8299562 ],
#         [ 0.88198936, -0.3076048 ]], dtype=float32),
#  Array([-1.1311196 ,  0.0050716 , -0.88230723,  0.28763232], dtype=float32))
# theta_0 = [-0.21089035 -1.3627948 ], theta = [0.60576403 0.7990441 ]
# optax optimizers operate on the params pytree rather than the model
# itself; optimizer.init(theta_0) provides the initial state for the
# iterations.
for epoch in range(num_epochs):
    key, subkey = random.split(key)  # fresh subkey so shuffling differs each epoch
    train_loader = data_loader(subkey, X, Y, batch_size)
    for X_batch, Y_batch in train_loader:
        params, opt_state, train_loss = make_step(params, opt_state, X_batch, Y_batch)
    if epoch % 100 == 0:
        gap = jnp.linalg.norm(theta - params)
        print(f"Epoch {epoch},||theta - theta_hat|| = {gap}")
print(f"||theta - theta_hat|| = {jnp.linalg.norm(theta - params)}")
# Output:
# Epoch 0,||theta - theta_hat|| = 2.1659655570983887
# Epoch 100,||theta - theta_hat|| = 0.0036812787875533104
# Epoch 200,||theta - theta_hat|| = 6.539194873766974e-05
# ||theta - theta_hat|| = 6.539194873766974e-05
# Instead of the directly vectorized vectorized_residuals, write the
# per-observation residual and let vmap handle the batching.
def residual(theta, x, y):
    """Squared prediction error for a single observation (x, y)."""
    # NOTE(review): relies on predict(theta, x) defined elsewhere in the file.
    return (predict(theta, x) - y) ** 2
@jit
def residuals(theta, X, Y):
    """Mean squared residual over the whole dataset, via vmap."""
    # Map residual over rows of X and Y; theta is held fixed (in_axes=None
    # for the 1st argument).
    per_obs = jax.vmap(residual, in_axes=(None, 0, 0))
    return jnp.mean(per_obs(theta, X, Y))
# Sanity checks: one observation's residual, then the dataset mean.
single = residual(theta_0, X[0], Y[0])
print(single)
print(residuals(theta_0, X, Y))
# Output:
# 2.6319637
# 5.4140573
# Update the value_and_grad call to use the new residuals function
# (the optimizer is reset below).
@jax.jit
def make_step(params, opt_state, X, Y):
    """One optimizer step: loss + grads, then apply the optax update.

    Returns the updated (params, opt_state, loss_value) triple.
    """
    loss_value, grads = jax.value_and_grad(residuals)(params, X, Y)
    updates, opt_state = optimizer.update(grads, opt_state, params)
    new_params = optax.apply_updates(params, updates)
    return new_params, opt_state, loss_value
optimizer = optax.sgd(lr)  # plain SGD performs better than optax.adam here
opt_state = optimizer.init(theta_0)
params = theta_0

for epoch in range(num_epochs):
    key, subkey = random.split(key)  # fresh subkey: reshuffle every epoch
    train_loader = data_loader(subkey, X, Y, batch_size)
    for X_batch, Y_batch in train_loader:
        params, opt_state, train_loss = make_step(params, opt_state, X_batch, Y_batch)
    if epoch % 100 == 0:
        gap = jnp.linalg.norm(theta - params)
        print(f"Epoch {epoch},||theta - theta_hat|| = {gap}")
print(f"||theta - theta_hat|| = {jnp.linalg.norm(theta - params)}")
# Output:
# Epoch 0,||theta - theta_hat|| = 2.167938232421875
# Epoch 100,||theta - theta_hat|| = 0.003675078274682164
# Epoch 200,||theta - theta_hat|| = 6.522066541947424e-05
# ||theta - theta_hat|| = 6.522066541947424e-05
# Same pattern with flax.nnx: vmap as above, @nnx.jit replaces @jax.jit,
# and the train step takes the model object directly.
batch_size = 64
for epoch in range(500):
    key, subkey = random.split(key)  # reshuffle minibatches each epoch
    train_loader = data_loader(subkey, X, Y, batch_size)
    for X_batch, Y_batch in train_loader:
        loss = train_step(model, optimizer, X_batch, Y_batch)
    if epoch % 100 == 0:
        norm_diff = jnp.linalg.norm(theta - jnp.squeeze(model.kernel.value))
        print(f"Epoch {epoch},||theta-theta_hat|| = {norm_diff}")
norm_diff = jnp.linalg.norm(theta - jnp.squeeze(model.kernel.value))
print(f"||theta - theta_hat|| = {norm_diff}")
# Output:
# Epoch 0,||theta-theta_hat|| = 1.2717349529266357
# Epoch 100,||theta-theta_hat|| = 0.24903634190559387
# Epoch 200,||theta-theta_hat|| = 0.04919437691569328
# Epoch 300,||theta-theta_hat|| = 0.00985759124159813
# Epoch 400,||theta-theta_hat|| = 0.002040109597146511
# ||theta - theta_hat|| = 0.0004721158475149423
# Subclass nnx.Module and wrap differentiable values in nnx.Param:
# plain attributes (the sizes) stay fixed; only the kernel is trained.
class MyLinear(nnx.Module):
    """Linear map x -> kernel @ x with a trainable (out_size, in_size) kernel."""

    def __init__(self, in_size, out_size, rngs):
        self.in_size = in_size
        self.out_size = out_size
        # rngs() yields a fresh PRNG key for the random initialization.
        self.kernel = nnx.Param(jax.random.normal(rngs(), (out_size, in_size)))

    # Similar to PyTorch's forward.
    def __call__(self, x):
        return self.kernel @ x


model = MyLinear(M, 1, rngs=rngs)
optimizer = nnx.Optimizer(model, optax.sgd(0.001), wrt=nnx.Param)
# BUG FIX(review): the original loop iterated the *existing* train_loader,
# a generator that the previous training loop had already exhausted, so no
# minibatches were produced and the parameters never updated — the printed
# norm stayed frozen at 0.6275200247764587 for all 500 epochs.
# Recreate the loader each epoch with a fresh subkey, as the earlier loops do.
for epoch in range(500):
    key, subkey = random.split(key)  # fresh subkey: reshuffle each epoch
    train_loader = data_loader(subkey, X, Y, batch_size)
    for X_batch, Y_batch in train_loader:
        loss = train_step(model, optimizer, X_batch, Y_batch)
    if epoch % 100 == 0:
        norm_diff = jnp.linalg.norm(theta - jnp.squeeze(model.kernel.value))
        print(f"Epoch {epoch},||theta-theta_hat|| = {norm_diff}")
norm_diff = jnp.linalg.norm(theta - jnp.squeeze(model.kernel.value))
print(f"||theta - theta_hat|| = {norm_diff}")
# Original (buggy) output — no learning occurred:
# Epoch 0,||theta-theta_hat|| = 0.6275200247764587
# Epoch 100,||theta-theta_hat|| = 0.6275200247764587
# Epoch 200,||theta-theta_hat|| = 0.6275200247764587
# Epoch 300,||theta-theta_hat|| = 0.6275200247764587
# Epoch 400,||theta-theta_hat|| = 0.6275200247764587
# ||theta - theta_hat|| = 0.6275200247764587
In an nnx.Module, the nnx.Param attributes are the values you might want to differentiate; the other attributes are fixed. Here the attributes are out_size, in_size, and kernel. We only want to differentiate the kernel, since it is wrapped in nnx.Param. To separate the fixed and differentiable parts use nnx.split, and to recombine them use nnx.merge. GraphDef(nodes=[NodeDef(
type='MyLinear',
index=0,
outer_index=None,
num_attributes=5,
metadata=MyLinear
), NodeDef(
type='GenericPytree',
index=None,
outer_index=None,
num_attributes=0,
metadata=({}, PyTreeDef(CustomNode(PytreeState[(False, False)], [])))
), VariableDef(
type='Param',
index=1,
outer_index=None,
metadata=PrettyMapping({
'is_hijax': False,
'has_ref': False,
'is_mutable': True,
'eager_sharding': True
})
)], attributes=[('_pytree__nodes', Static(value={'_pytree__state': True, 'out_size': False, 'in_size': False, 'kernel': True, '_pytree__nodes': False})), ('_pytree__state', NodeAttr()), ('in_size', Static(value=2)), ('kernel', NodeAttr()), ('out_size', Static(value=1))], num_leaves=1)
Here graphdef is the fixed structure and state holds the differentiable parameters; use nnx.merge to combine the fixed and differentiable parts: State({ 'kernel': Param( # 2 (8 B) value=Array([[-0.2166012, -1.9878021]], dtype=float32) ) }) MyLinear( # Param: 2 (8 B) in_size=2, kernel=Param( # 2 (8 B) value=Array(shape=(1, 2), dtype=dtype('float32')) ), out_size=1 )
In general, replace jax transformations with their nnx equivalents:
nnx.jit, nnx.value_and_grad, etc. automatically filter for nnx.Param values.