Machine Learning Fundamentals for Economists
Population Risk
\[ {f^*}= \arg\min_{f \in {\mathcal{F}}} \underbrace{\mathbb{E}_{(x,y)\sim \mu^*} \left[\ell(f, x, y)\right]}_{\equiv R(f, \mu^*)} \]
Empirical Risk
\[ {\theta^*}= \arg\min_{\theta \in \Theta}\underbrace{\frac{1}{|{\mathcal{D}}|}\sum_{(x,y) \in {\mathcal{D}}} \ell(f_{\theta}, x, y)}_{\equiv \hat{R}(\theta,{\mathcal{D}})} \]
\[ \small \mathbb{E}_{{\mathcal{D}}\overset{\mathrm{iid}}{\sim}\mu^*}\left[\min_{\theta \in \Theta} \hat{R}(\theta, {\mathcal{D}}) - \min_{f \in \mathcal{F}} R(f, \mu^*)\right] = \underbrace{ R({f_{{\theta^*}}}, \mu^*) - R({f^*}, \mu^*)}_{\equiv {\varepsilon_{\mathrm{app}}}({f_{{\theta^*}}})} + \underbrace{\mathbb{E}_{{\mathcal{D}}\overset{\mathrm{iid}}{\sim}\mu^*}\left[\hat{R}(\theta^*, {\mathcal{D}}) - R({f_{{\theta^*}}}, \mu^*)\right]}_{\equiv {\varepsilon_{\mathrm{gen}}}({f_{{\theta^*}}})} \]
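The decomposition follows by noting \(\min_{\theta} \hat{R}(\theta, {\mathcal{D}}) = \hat{R}(\theta^*, {\mathcal{D}})\) by the definition of \(\theta^*\), and then adding and subtracting \(R(f_{\theta^*}, \mu^*)\):
\[ \mathbb{E}_{{\mathcal{D}}\overset{\mathrm{iid}}{\sim}\mu^*}\left[\hat{R}(\theta^*, {\mathcal{D}})\right] - R(f^*, \mu^*) = \left[R(f_{\theta^*}, \mu^*) - R(f^*, \mu^*)\right] + \mathbb{E}_{{\mathcal{D}}\overset{\mathrm{iid}}{\sim}\mu^*}\left[\hat{R}(\theta^*, {\mathcal{D}}) - R(f_{\theta^*}, \mu^*)\right] \]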
We can think of this manual process as taking the raw data \(x\) and transforming it into a representation \(z\in {\mathcal{Z}}\) with a feature map \(\phi : {\mathcal{X}}\to {\mathcal{Z}}\)
Then, instead of finding an \(f_{\theta} : {\mathcal{X}}\to {\mathcal{Y}}\), we can find an \(h_{\theta} : {\mathcal{Z}}\to {\mathcal{Y}}\) and use \(\ell(h_{\theta} \circ \phi, x, y)\); i.e., \({f_{{\theta^*}}}\equiv h_{{\theta^*}} \circ \phi\)
\[ {\theta^*}= \arg\min_{\theta \in \Theta}\frac{1}{|{\mathcal{D}}|}\sum_{(x,y) \in {\mathcal{D}}} \ell(h_{\theta} \circ \phi, x, y) \]
Suppose \(x \in {\mathbb{R}}\) and we want to approximate \(f(x)\) with a polynomial of degree \(d\)
For some polynomial basis \(T_1(x), \ldots, T_d(x)\) (e.g., monomials, Chebyshev, Legendre, etc.) \[ \phi(x) = \begin{bmatrix}1 & T_1(x) & T_2(x) & \cdots & T_d(x)\end{bmatrix}^{\top} \]
Approximate \(f_{\theta}(x)\) with \(h_{\theta}(z) = W^{\top} z\), where \(W \in {\mathbb{R}}^{d+1}\) and \(W \in \theta\) \[ f_{\theta}(x) \equiv W^{\top} \phi(x) = \sum_{i=0}^d W_i T_i(x) \]
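As a concrete sketch in numpy (names hypothetical): the feature map \(\phi\) built from the Chebyshev recurrence, with the approximation linear in the features:

```python
import numpy as np

def phi(x, d):
    # Chebyshev basis via the recurrence T_0 = 1, T_1 = x, T_{n+1} = 2x T_n - T_{n-1}
    T = [np.ones_like(x), x]
    for _ in range(d - 1):
        T.append(2 * x * T[-1] - T[-2])
    return np.stack(T[: d + 1])  # shape (d+1,) for scalar x

W = np.array([1.0, 0.5, 0.25])      # theta = W, here d = 2
f_theta = lambda x: W @ phi(x, 2)   # f_theta(x) = sum_i W_i T_i(x)
# at x = 1 every T_i(1) = 1, so f_theta(1.0) = 1 + 0.5 + 0.25 = 1.75
```

Swapping the recurrence changes the basis (monomials, Legendre, etc.) without touching the linear layer \(h_{\theta}\).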
The term “neural network” is very broad and covers a variety of approximation classes, with a high degree of flexibility.
While there is no theoretical requirement to use these patterns, in practice hardware is optimized to make them fast.
For example, consider the function \(f : \mathbb{R}^N \to \mathbb{R}^M\) defined as \[ y = f_{\theta}(x) = W_3 \sigma(W_2 \sigma(W_1 x + b_1) + b_2) + b_3 \]
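A minimal numpy sketch of this three-layer function with \(\sigma = \mathrm{ReLU}\) (dimensions and values hypothetical):

```python
import numpy as np

def relu(z):
    return np.maximum(z, 0.0)

N, H, M = 3, 4, 2  # input, hidden, and output dimensions
rng = np.random.default_rng(0)
W1, b1 = rng.normal(size=(H, N)), np.zeros(H)
W2, b2 = rng.normal(size=(H, H)), np.zeros(H)
W3, b3 = rng.normal(size=(M, H)), np.zeros(M)

def f_theta(x):
    # y = W3 sigma(W2 sigma(W1 x + b1) + b2) + b3
    return W3 @ relu(W2 @ relu(W1 @ x + b1) + b2) + b3

y = f_theta(np.ones(N))  # y has shape (M,)
```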
For examples of why multiple layers help, see Mark Schmidt’s CPSC 440 and CPSC 340, and ProbML Book 1, Section 13.2.1 on the XOR problem and Sections 13.2.5-13.2.7 for more
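To see the XOR point concretely: XOR is not linearly separable, but one hidden ReLU layer computes it exactly. The weights below are one hand-picked solution (hypothetical, not unique):

```python
import numpy as np

# hidden layer h = relu(W1 x + b1), output y = w2 . h
W1 = np.array([[1.0, 1.0],
               [1.0, 1.0]])
b1 = np.array([0.0, -1.0])
w2 = np.array([1.0, -2.0])

def xor_net(x):
    h = np.maximum(W1 @ x + b1, 0.0)
    return w2 @ h

# reproduces the XOR truth table exactly
for x, target in [((0, 0), 0), ((0, 1), 1), ((1, 0), 1), ((1, 1), 0)]:
    assert xor_net(np.array(x, dtype=float)) == target
```

No single linear layer can do this, which is the classic motivation for depth.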
Use economic intuition and problem-specific knowledge to design \({\mathcal{H}}\)
For example, you can approximate functions \(f : \mathbb{R}^N \to \mathbb{R}\) which are symmetric in their arguments (i.e., permutation invariant) with \(\rho : \mathbb{R}^M \to \mathbb{R}\), \(\phi : \mathbb{R} \to \mathbb{R}^M\)
\[ f(X) = \rho\left(\frac{1}{N}\sum_{x\in X} \phi(x)\right) \]
See Probabilistic Symmetries and Invariant Neural Networks or Exploiting Symmetry in High Dimensional Dynamic Programming
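A numpy sketch of this pooled architecture (the particular \(\rho\) and \(\phi\) below are hypothetical), with a check that permuting the arguments leaves the output unchanged:

```python
import numpy as np

M = 4
rng = np.random.default_rng(0)
A = rng.normal(size=(M, 1))   # phi(x) = relu(A x) : R -> R^M
c = rng.normal(size=M)        # rho(z) = c . z : R^M -> R

def f(X):
    # mean-pool the per-element embeddings, then apply rho
    pooled = np.mean([np.maximum(A @ np.array([x]), 0.0) for x in X], axis=0)
    return float(c @ pooled)

X = [0.5, -1.2, 2.0]
assert np.isclose(f(X), f(list(reversed(X))))  # permutation invariant
```

The invariance holds by construction: the mean over embeddings is unchanged by any reordering of \(X\).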
\[ \min_{\theta_e, \theta_d} \mathbb{E}_{x \sim \mu^*} (h(\phi(x;\theta_e);\theta_d) - x)^2 + \text{regularizer} \]
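As a sketch of this objective: a linear encoder/decoder in numpy, with a sampled batch standing in for the expectation (all names hypothetical; the regularizer is omitted):

```python
import numpy as np

N, K = 5, 2  # data dimension, representation dimension
rng = np.random.default_rng(0)
E = rng.normal(size=(K, N))  # encoder phi(x; theta_e) = E x
D = rng.normal(size=(N, K))  # decoder h(z; theta_d) = D z

def reconstruction_loss(X):
    # empirical analog of E_x[(h(phi(x)) - x)^2] over a batch X of shape (batch, N)
    Xhat = (D @ (E @ X.T)).T
    return float(np.mean((Xhat - X) ** 2))

X = rng.normal(size=(100, N))
loss = reconstruction_loss(X)  # minimize over E and D (plus a regularizer)
```

With linear maps and squared loss this recovers a PCA-like problem; nonlinear \(\phi\) and \(h\) give the general autoencoder.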
Consider MyLinear, a linear function without an affine term (i.e., no “bias”). Its trainable value is an nnx.Param stored inside the nnx.Module. When differentiating MyLinear, nnx.grad finds the nnx.Param; it does not perturb fixed attributes such as out_size, etc.
State({
  'kernel': Param( # 3 (12 B)
    value=Array([[ 0.60576403, 0.7990441 , -0.908927 ]], dtype=float32)
  )
})
Instead of writing the nnx.Module from scratch, you can use the built-in nnx.Linear to construct a simple NN:
class MyMLP(nnx.Module):
    def __init__(self, din, dout, width: int, *, rngs: nnx.Rngs):
        self.width = width
        self.linear1 = nnx.Linear(din, width, use_bias=False, rngs=rngs)
        self.linear2 = nnx.Linear(width, dout, use_bias=True, rngs=rngs)

    def __call__(self, x: jax.Array):
        x = self.linear1(x)
        x = nnx.relu(x)
        x = self.linear2(x)
        return x
m = MyMLP(N, 1, 2, rngs=rngs)
For scalar-valued \(f\), reverse-mode AD computes the pullback \[ \bar{m} = \partial f(m)^{\top}[\bar{f}] = \bar{f} \cdot \nabla f(m) \]
Just as jax.grad does for scalar functions, nnx.grad does this recursively through each nnx.Module and its nnx.Param values:
State({
  'linear1': {
    'kernel': Param( # 6 (24 B)
      value=Array([[ 0.2700633 , 0.36592388],
        [-0.6980436 , -0.23368652],
        [-0.1246624 , -1.0047528 ]], dtype=float32)
    )
  },
  'linear2': {
    'bias': Param( # 1 (4 B)
      value=Array([0.], dtype=float32)
    ),
    'kernel': Param( # 2 (8 B)
      value=Array([[-0.21603853],
        [-1.0095432 ]], dtype=float32)
    )
  }
})
graphdef contains fixed values and metadata:
GraphDef(nodes=[NodeDef(
type='MyMLP',
index=0,
outer_index=None,
num_attributes=5,
metadata=MyMLP
), NodeDef(
type='GenericPytree',
index=None,
outer_index=None,
num_attributes=0,
metadata=({}, PyTreeDef(CustomNode(PytreeState[(False, False)], [])))
), NodeDef(
type='Linear',
index=1,
outer_index=None,
num_attributes=13,
metadata=Linear
), NodeDef(
type='GenericPytree',
index=None,
outer_index=None,
num_attributes=0,
metadata=({}, PyTreeDef(CustomNode(PytreeState[(False, False)], [])))
), NodeDef(
type='NoneType',
index=None,
outer_index=None,
num_attributes=0,
metadata=None
), NodeDef(
type='NoneType',
index=None,
outer_index=None,
num_attributes=0,
metadata=None
), VariableDef(
type='Param',
index=2,
outer_index=None,
metadata=PrettyMapping({
'is_hijax': False,
'has_ref': False,
'is_mutable': True,
'eager_sharding': True
})
), NodeDef(
type='NoneType',
index=None,
outer_index=None,
num_attributes=0,
metadata=None
), NodeDef(
type='NoneType',
index=None,
outer_index=None,
num_attributes=0,
metadata=None
), NodeDef(
type='Linear',
index=3,
outer_index=None,
num_attributes=13,
metadata=Linear
), NodeDef(
type='GenericPytree',
index=None,
outer_index=None,
num_attributes=0,
metadata=({}, PyTreeDef(CustomNode(PytreeState[(False, False)], [])))
), VariableDef(
type='Param',
index=4,
outer_index=None,
metadata=PrettyMapping({
'is_hijax': False,
'has_ref': False,
'is_mutable': True,
'eager_sharding': True
})
), NodeDef(
type='NoneType',
index=None,
outer_index=None,
num_attributes=0,
metadata=None
), VariableDef(
type='Param',
index=5,
outer_index=None,
metadata=PrettyMapping({
'is_hijax': False,
'has_ref': False,
'is_mutable': True,
'eager_sharding': True
})
), NodeDef(
type='NoneType',
index=None,
outer_index=None,
num_attributes=0,
metadata=None
), NodeDef(
type='NoneType',
index=None,
outer_index=None,
num_attributes=0,
metadata=None
)], attributes=[('_pytree__nodes', Static(value={'_pytree__state': True, 'width': False, 'linear1': True, 'linear2': True, '_pytree__nodes': False})), ('_pytree__state', NodeAttr()), ('linear1', NodeAttr()), ('_pytree__nodes', Static(value={'_pytree__state': True, 'kernel': True, 'bias': True, 'in_features': False, 'out_features': False, 'use_bias': False, 'dtype': False, 'param_dtype': False, 'precision': False, 'dot_general': False, 'promote_dtype': False, 'preferred_element_type': False, '_pytree__nodes': False})), ('_pytree__state', NodeAttr()), ('bias', NodeAttr()), ('dot_general', Static(value=<function dot_general at 0x7f6720c4c220>)), ('dtype', NodeAttr()), ('in_features', Static(value=3)), ('kernel', NodeAttr()), ('out_features', Static(value=2)), ('param_dtype', Static(value=<class 'jax.numpy.float32'>)), ('precision', NodeAttr()), ('preferred_element_type', NodeAttr()), ('promote_dtype', Static(value=<function promote_dtype at 0x7f65c2cf8040>)), ('use_bias', Static(value=False)), ('linear2', NodeAttr()), ('_pytree__nodes', Static(value={'_pytree__state': True, 'kernel': True, 'bias': True, 'in_features': False, 'out_features': False, 'use_bias': False, 'dtype': False, 'param_dtype': False, 'precision': False, 'dot_general': False, 'promote_dtype': False, 'preferred_element_type': False, '_pytree__nodes': False})), ('_pytree__state', NodeAttr()), ('bias', NodeAttr()), ('dot_general', Static(value=<function dot_general at 0x7f6720c4c220>)), ('dtype', NodeAttr()), ('in_features', Static(value=2)), ('kernel', NodeAttr()), ('out_features', Static(value=1)), ('param_dtype', Static(value=<class 'jax.numpy.float32'>)), ('precision', NodeAttr()), ('preferred_element_type', NodeAttr()), ('promote_dtype', Static(value=<function promote_dtype at 0x7f65c2cf8040>)), ('use_bias', Static(value=True)), ('width', Static(value=2))], num_leaves=3)
The gradient is itself a State, mirroring the nnx.Param structure of the nnx.Module:
State({
'linear1': {
'kernel': Param( # 6 (24 B)
value=Array([[ 0., 0.],
[ 0., 0.],
[-0., -0.]], dtype=float32)
)
},
'linear2': {
'bias': Param( # 1 (4 B)
value=Array([1.], dtype=float32)
),
'kernel': Param( # 2 (8 B)
value=Array([[0.],
[0.]], dtype=float32)
)
}
})
Rebuild m by applying the graphdef to the state from before; to take a step, make a new state and merge it with the graphdef:
eta = 0.01 # e.g., a gradient descent update
# jax.tree.map recursively goes through the model
# Updates the underlying nnx.Param given the delta_m grad
new_state = jax.tree.map(
lambda p, g: p - eta*g,
state, delta_m) # new_state = state - eta * delta_m
m_new = nnx.merge(graphdef, new_state)
f(m_new)
Array(0.0327667, dtype=float32)
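The jax.tree.map update above can be sketched in plain Python over nested dicts (a stand-in for the State pytree; names and values hypothetical):

```python
def tree_map(f, *trees):
    # recursively apply f to matching leaves of nested dicts
    if isinstance(trees[0], dict):
        return {k: tree_map(f, *(t[k] for t in trees)) for k in trees[0]}
    return f(*trees)

eta = 0.01
state = {"linear1": {"kernel": 0.5}, "linear2": {"bias": 1.0}}
grads = {"linear1": {"kernel": 2.0}, "linear2": {"bias": -1.0}}
# each parameter moves opposite its own gradient leaf
new_state = tree_map(lambda p, g: p - eta * g, state, grads)
```

jax.tree.map does the same leaf-wise traversal, but over arbitrary registered pytrees and with JAX arrays as leaves.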
nnx.jit, nnx.vmap, and nnx.grad will automatically split and merge for you (i.e., filtering) on nnx.Module arguments, then call the underlying JAX functions
Raw JAX transformations (jax.jit, jax.grad, etc.) would not work without modification, since the NN combines differentiable and non-differentiable parts. Instead, split m into state and graphdef, and merge them back together inside the function. For the f_gen function, closing over graphdef and differentiating with respect to state:
graphdef, state = nnx.split(m)
@jax.jit # note jax.jit instead of nnx.jit
def f_split(state): # closure on graphdef
    m = nnx.merge(graphdef, state)
    return f_gen(m, x, b)

# Can use jax.grad, rather than nnx.grad
state_diff = jax.grad(f_split)(state)
print(state_diff)
new_state = jax.tree.map(
    lambda p, g: p - eta*g,
    state, state_diff) # new_state = state - eta * state_diff
m_new = nnx.merge(graphdef, new_state)
f(m_new)
State({
  'linear1': {
    'kernel': Param( # 6 (24 B)
      value=Array([[ 0., 0.],
        [ 0., 0.],
        [-0., -0.]], dtype=float32)
    )
  },
  'linear2': {
    'bias': Param( # 1 (4 B)
      value=Array([1.], dtype=float32)
    ),
    'kernel': Param( # 2 (8 B)
      value=Array([[0.],
        [0.]], dtype=float32)
    )
  }
})
Array(0.99, dtype=float32)
The analog of nnx.Param in PyTorch is torch.nn.Parameter:
class MyLinearTorch(nn.Module):
    def __init__(self, in_size, out_size):
        super(MyLinearTorch, self).__init__()
        self.out_size = out_size
        self.in_size = in_size
        self.kernel = nn.Parameter(torch.randn(out_size, in_size))

    # PyTorch uses forward instead of __call__
    def forward(self, x):
        return self.kernel @ x

def f_gen_torch(m, x, b):
    return torch.squeeze(m(x) + b)

Parameter containing:
tensor([[0.5495, 1.6847, 0.1938]], requires_grad=True)
class MyMLPTorch(nn.Module):
    def __init__(self, din, dout, width):
        super(MyMLPTorch, self).__init__()
        self.width = width
        self.linear1 = nn.Linear(din, width, bias=False)
        self.linear2 = nn.Linear(width, dout, bias=True)

    def forward(self, x):
        x = self.linear1(x)
        x = torch.relu(x)
        x = self.linear2(x)
        return x

m = MyMLPTorch(N, 1, 2)
m.zero_grad()
output = f_torch(m)
# Start with d(output) = [1.0]
output.backward()
# Now `m` has the gradients
# Manually update parameters recursively
# Done in-place, as torch optimizers will do
with torch.no_grad():
    # Recursively
    for param in m.parameters():
        param -= eta * param.grad
for name, param in m.named_parameters():
    print(f"{name}: {param.detach().numpy()}")

linear1.weight: [[-0.15995035 -0.04328426 0.27121732]
 [-0.01335035 -0.17065024 0.4135141 ]]
linear2.weight: [[-0.4177399 -0.45026302]]
linear2.bias: [-0.39413318]
Hyperparameters include the size (width) of the representations, as well as tweaks to the optimizer and algorithms. Given that you may want to solve your problem with a variety of different hyperparameters, possibly running in parallel, you need a convenient way to pass the values and see the results
One model-, framework-, and OS-independent way to do this is to use command-line arguments
For example, if you have a python file called mlp_regression_jax_nnx_logging.py that accepts arguments for the width and learning rate, you may call it with flags such as --width and --lr
Many python frameworks exist to help you take the CLI arguments and convert them into python function calls. One convenient tool is jsonargparse
Advantage: simply annotate a function with defaults, and it will generate the CLI
Then you can call python mlp_regression_jax_nnx_logging.py, python mlp_regression_jax_nnx_logging.py --width=64, etc.
While the CLI file could save output for later interpretation, a common approach in ML is to log results to visualize the optimization process, compare results, etc.
One package for this is Weights and Biases
This will log calculations to a website, usually organized by a project name, and let you sort by hyperparameters, etc.
To use this, set up an account, add code to initialize it in your python file, and then log intermediate results
Putting together the logging and the CLI, you can set up a process to run the code with a variety of parameters in a “sweep” (e.g., with a sweep file)
Each agent launched with wandb agent <sweep_id> runs python mlp_regression_jax_nnx_logging.py --width=128 --lr=0.0015, etc. You can run wandb agent <sweep_id> on multiple computers/processes/etc. The sweep chooses lr and width to minimize test_loss (if logged with wandb.log({"test_loss": test_loss}), etc.)
program: lectures/examples/mlp_regression_jax_nnx_logging.py
name: Sweep Example
method: bayes
metric:
  name: test_loss
  goal: minimize
parameters:
  num_epochs:
    value: 300 # fixed for all calls
  lr: # uniformly distributed
    min: 0.0001
    max: 0.01
  width: # discrete values to optimize over
    values: [64, 128, 256]
Build the NN with nnx.Linear and nnx.relu layers:
class MyMLP(nnx.Module):
    def __init__(self, din: int, dout: int, width: int, *, rngs: nnx.Rngs):
        self.linear1 = nnx.Linear(din, width, rngs=rngs)
        self.linear2 = nnx.Linear(width, width, rngs=rngs)
        self.linear3 = nnx.Linear(width, dout, rngs=rngs)

    def __call__(self, x: jax.Array):
        x = self.linear1(x)
        x = nnx.relu(x)
        x = self.linear2(x)
        x = nnx.relu(x)
        x = self.linear3(x)
        return x
This randomly generates a \(\theta\) and then generates data with
ERM: with \(m \in {\mathcal{H}}\), minimize the residuals for the batch (X, Y)
The optimizer uses the loss differentiated with respect to \(m\), as discussed
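That loop in miniature, in numpy with a linear model standing in for \(m\) (all names and values hypothetical):

```python
import numpy as np

rng = np.random.default_rng(0)
theta_true = np.array([2.0, -1.0])
X = rng.normal(size=(512, 2))          # a batch (X, Y)
Y = X @ theta_true

theta, lr = np.zeros(2), 0.1
for _ in range(200):
    resid = X @ theta - Y              # residuals on the batch
    grad = 2 * X.T @ resid / len(X)    # d/dtheta of the mean squared residual
    theta -= lr * grad                 # gradient-descent step

# theta converges toward theta_true
```

The real version replaces the hand-written gradient with nnx.grad and the linear model with the MLP.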
In order to use jsonargparse, create a function signature with defaults:
def fit_model(
    N: int = 500,
    M: int = 2,
    sigma: float = 0.0001,
    width: int = 128,
    lr: float = 0.001,
    num_epochs: int = 2000,
    batch_size: int = 512,
    seed: int = 42,
    wandb_project: str = "grad_econ_ML",
    wandb_mode: str = "offline",  # "online", "disabled"
):
    # ... generate data, fit model, save test_loss
To run this sweep, you can run the following command (checking the relative location of the file)
The output should be something along the lines of
(grad-econ-ml) /Users/username/GitHub/grad_econ_ML>wandb sweep lectures/examples/mlp_regression_jax_nnx_sweep.yaml
wandb: Creating sweep from: lectures/examples/mlp_regression_jax_nnx_sweep.yaml
wandb: Creating sweep with ID: virfdcn6
wandb: View sweep at: https://wandb.ai/highdimensionaleconlab/grad_econ_ML-lectures_examples/sweeps/virfdcn6
wandb: Run sweep agent with: wandb agent highdimensionaleconlab/grad_econ_ML-lectures_examples/virfdcn6
Then run wandb agent highdimensionaleconlab/grad_econ_ML-lectures_examples/virfdcn6 and go to the web to see the results in progress