import numpy as np
import matplotlib.pyplot as plt
import scipy
import pandas as pd
import torch
import jax
import jax.numpy as jnp
from jax import grad, hessian
import torch.nn as nn
import torch.optim as optim
from jax import random
import optax
import jax_dataloader as jdl
from jax_dataloader.loaders import DataLoaderJAX
from flax import nnx
from openai import OpenAI
import wandb
import jsonargparse

ECON622: Problem Set 4
Packages
Add whatever packages you wish here
Question 1: W&B Logging and CLI (JAX NNX)
For the linear regression examples with PyTorch we added in linear_regression_pytorch_logging.py logging and a CLI interface — which came for free with PyTorch Lightning.
In this question you will add in some of those features to the linear_regression_jax_nnx.py example.
Question 1.1: Add W&B Logging
Take the linear_regression_jax_nnx.py copied below for your convenience and:
- Setup W&B properly
- Add in logging of the `train_loss` at every step of the optimizer
- Remove the other epoch printing, or try to log an epoch-specific `||theta - theta_hat||` if you wish
- Log the final `||theta - theta_hat||` at the end of the training
# MODIFY CODE HERE
import jax
import jax.numpy as jnp
from jax import random
import optax
import jax_dataloader as jdl
from jax_dataloader.loaders import DataLoaderJAX
from flax import nnx
# Synthetic DGP for linear regression: Y = X @ theta + Gaussian noise.
N = 500  # samples
M = 2  # features (and true coefficients)
sigma = 0.001  # noise standard deviation
rngs = nnx.Rngs(42)  # seeded NNX key stream for reproducibility
theta = random.normal(rngs(), (M,))  # true coefficients to recover
X = random.normal(rngs(), (N, M))
Y = X @ theta + sigma * random.normal(rngs(), (N,))  # Adding noise
def residual(model, x, y):
    """Squared prediction error of `model` for a single observation (x, y)."""
    prediction = model(x)
    return (prediction - y) ** 2
def residuals_loss(model, X, Y):
    """Mean squared prediction error of `model` over the whole dataset.

    `residual` is vectorized over the rows of X and Y with vmap; the model
    argument is broadcast (in_axes=None), not mapped.
    """
    per_obs = jax.vmap(residual, in_axes=(None, 0, 0))(model, X, Y)
    return jnp.mean(per_obs)
# Linear model y_hat = x @ kernel with no bias term, matching the DGP's form.
model = nnx.Linear(M, 1, use_bias=False, rngs=rngs)
lr = 0.001  # SGD learning rate
# NNX optimizer wrapper; wrt=nnx.Param restricts updates to trainable parameters.
optimizer = nnx.Optimizer(model, optax.sgd(lr), wrt=nnx.Param)
@nnx.jit
def train_step(model, optimizer, X, Y):
    """Run one optimizer step on a batch and return the batch loss."""
    def batch_loss(m):
        # Close over the batch so value_and_grad differentiates w.r.t. the model only.
        return residuals_loss(m, X, Y)

    loss, grads = nnx.value_and_grad(batch_loss)(model)
    # Parameters are updated by side effect; `update` returns nothing we use.
    optimizer.update(model, grads)
    return loss
num_epochs = 1000
batch_size = 512  # larger than N=500, so each epoch runs as one full batch
dataset = jdl.ArrayDataset(X, Y)
train_loader = DataLoaderJAX(dataset, batch_size=batch_size, shuffle=True)
for epoch in range(num_epochs):
    for X_batch, Y_batch in train_loader:
        loss = train_step(model, optimizer, X_batch, Y_batch)
    # Report the parameter-recovery error every 100 epochs.
    if epoch % 100 == 0:
        print(
            f"Epoch {epoch},||theta - theta_hat|| = {jnp.linalg.norm(theta - jnp.squeeze(model.kernel.value))}"
        )
print(f"||theta - theta_hat|| = {jnp.linalg.norm(theta - jnp.squeeze(model.kernel.value))}")

Question 1.2: Add CLI Interface
Now, take the above code and copy it into a file named linear_regression_jax_cli.py.
We want to make it CLI ready:
- A package with many features, most of which you wouldn’t use directly, is jsonargparse. Besides the more advanced features like configuration files and the instantiation of classes/etc. as arguments, the main difference will be that it checks the types of arguments and converts them for you using python typehints.
- Alternatively, you can use the builtin Argparse or any other CLI framework
In that case, you can adapt the following code for your linear_regression_jax_cli.py:
import jsonargparse
def main_fn(lr: float = 0.001, N: int = 100):
    # jsonargparse builds the CLI (--lr, --N) from the signature's type hints and defaults.
    print(f"lr = {lr}, N = {N}")


if __name__ == "__main__":
    jsonargparse.CLI(main_fn)
Using your CLI
In either case, at that point you should be able to call this with python linear_regression_jax_cli.py and have it use all of the default values, python linear_regression_jax_cli.py --N=200 to change them, etc.
Either submit the file as part of the assignment or just paste the code into the notebook.
Question 1.3 (BONUS): W&B Sweep
Given the CLI you can now run a hyperparameter search. For this bonus problem, do a hyperparameter search over the --lr argument by following the W&B documentation.
To get you started, your sweep YAML might look something like this:
program: linear_regression_jax_cli.py
name: JAX Example
project: linear_regression_jax
description: JAX Sweep
method: random
parameters:
lr:
min: 0.0001
max: 0.01
Here the method is changed from Bayes to random because otherwise we would need to provide a metric to optimize over. Feel free to adapt any of these settings.
If you successfully run a sweep then paste in your own YAML file here, and a screenshot of the W&B dashboard showing something about the sweep results.
Question 2: PyTorch Neural Networks
In the repository you have code that does a linear regression with PyTorch: linear_regression_pytorch_sgd.py.
Question 2.1: Shallow MLP
Make a new file that does the same thing, but replace the nn.Linear(M, 1, bias=False) with code that gives a neural network with multiple layers. Maybe try:
M = 2 # loaded automatically in the code
num_width = 8 # etc. You can hardcode or add to the template/yaml code
nn.Sequential(
nn.Linear(M, num_width),
nn.ReLU(),
nn.Linear(num_width, 1, bias = False)
)

Sequential(
(0): Linear(in_features=2, out_features=8, bias=True)
(1): ReLU()
(2): Linear(in_features=8, out_features=1, bias=False)
)
This is a network with one “hidden” layer. Try this for a very shallow network (e.g. num_width = 8) and see if it converges with Adam or SGD. It is OK if it does not! Don’t spend too much time with it.
Question 2.2: Deep MLP
Now replace this with a deeper and wider network by the same pattern. Maybe something like:
M = 2 # loaded automatically in the code
num_width = 256 # etc. You can hardcode or add to the template/yaml code
nn.Sequential(
nn.Linear(M, num_width),
nn.ReLU(),
nn.Linear(num_width, num_width),
nn.ReLU(),
nn.Linear(num_width, num_width),
nn.ReLU(),
nn.Linear(num_width, num_width),
nn.ReLU(),
nn.Linear(num_width, 1, bias = False)
)

Sequential(
(0): Linear(in_features=2, out_features=256, bias=True)
(1): ReLU()
(2): Linear(in_features=256, out_features=256, bias=True)
(3): ReLU()
(4): Linear(in_features=256, out_features=256, bias=True)
(5): ReLU()
(6): Linear(in_features=256, out_features=256, bias=True)
(7): ReLU()
(8): Linear(in_features=256, out_features=1, bias=False)
)
Try to optimize this now and see if it fits well with this deeper network. If this is too slow, change the num_width or remove a layer to see if it helps.
Question 2.3 (BONUS): Nonlinear DGP
The above is trying to fit a linear DGP. Instead, increase the dimension M to something much larger, and modify the DGP to be something nonlinear. Try out the larger network to see if it fits it well.
Question 3 (BONUS): JAX NNX Neural Networks
Now we can try the same thing with JAX and NNX using linear_regression_jax_nnx.py.
Question 3.1
Take the code and modify the network from the simple nnx.Linear(M, 1, use_bias=False, rngs=rngs) to do a nonlinear function with more parameters and layers, as in Question 2.2.
There is no builtin MLP in nnx, but you can construct it manually by creating a class from nnx.Module and then nesting calls to nnx.Linear with an activation like nnx.relu. See the docs for more information.
Question 4: GP Regression (GPyTorch)
The following code comes from the GPyTorch documentation.
import math
import torch
import gpytorch
from gpytorch.kernels import ScaleKernel, RBFKernel, LinearKernel
from matplotlib import pyplot as plt
# Training data is 100 points in [0,1] inclusive, regularly spaced.
train_x = torch.linspace(0, 1, 100)
# True function is sin(2*pi*x) with Gaussian noise (variance 0.04, i.e. sd 0.2).
train_y = torch.sin(train_x * (2 * math.pi)) + torch.randn(train_x.size()) * math.sqrt(0.04)
# We will use the simplest form of GP model, exact inference
class ExactGPModel(gpytorch.models.ExactGP):
    """Exact-inference GP with a constant mean and a scaled RBF kernel."""

    def __init__(self, train_x, train_y, likelihood):
        super(ExactGPModel, self).__init__(train_x, train_y, likelihood)
        self.mean_module = gpytorch.means.ConstantMean()
        self.covar_module = gpytorch.kernels.ScaleKernel(gpytorch.kernels.RBFKernel())

    def forward(self, x):
        """Return the GP prior at x as a multivariate normal distribution."""
        mean = self.mean_module(x)
        cov = self.covar_module(x)
        return gpytorch.distributions.MultivariateNormal(mean, cov)
# initialize likelihood and model
likelihood = gpytorch.likelihoods.GaussianLikelihood()
model = ExactGPModel(train_x, train_y, likelihood)

# Find optimal model hyperparameters: put both modules in training mode.
model.train()
likelihood.train()

# Use the adam optimizer
optimizer = torch.optim.Adam(model.parameters(), lr=0.1)  # Includes GaussianLikelihood parameters

# "Loss" for GPs - the marginal log likelihood (negated below for minimization).
mll = gpytorch.mlls.ExactMarginalLogLikelihood(likelihood, model)

training_iter = 50
for i in range(training_iter):
    # Zero gradients from previous iteration
    optimizer.zero_grad()
    # Output from model
    output = model(train_x)
    # Calc loss and backprop gradients
    loss = -mll(output, train_y)
    loss.backward()
    print('Iter %d/%d - Loss: %.3f lengthscale: %.3f noise: %.3f' % (
        i + 1, training_iter, loss.item(),
        model.covar_module.base_kernel.lengthscale.item(),
        model.likelihood.noise.item()
    ))
    optimizer.step()
# Get into evaluation (predictive posterior) mode
model.eval()
likelihood.eval()

# Test points are regularly spaced along [0,1]
# Make predictions by feeding model through likelihood
with torch.no_grad(), gpytorch.settings.fast_pred_var():
    test_x = torch.linspace(0, 1, 51)
    observed_pred = likelihood(model(test_x))

with torch.no_grad():
    # Initialize plot
    f, ax = plt.subplots(1, 1, figsize=(4, 3))
    # Get upper and lower confidence bounds
    lower, upper = observed_pred.confidence_region()
    # Plot training data as black stars
    ax.plot(train_x.numpy(), train_y.numpy(), 'k*')
    # Plot predictive means as blue line
    ax.plot(test_x.numpy(), observed_pred.mean.numpy(), 'b')
    # Shade between the lower and upper confidence bounds
    ax.fill_between(test_x.numpy(), lower.numpy(), upper.numpy(), alpha=0.5)
    ax.set_ylim([-3, 3])
ax.legend(['Observed Data', 'Mean', 'Confidence'])

Iter 1/50 - Loss: 0.913 lengthscale: 0.693 noise: 0.693
Iter 2/50 - Loss: 0.881 lengthscale: 0.644 noise: 0.644
Iter 3/50 - Loss: 0.847 lengthscale: 0.598 noise: 0.598
Iter 4/50 - Loss: 0.810 lengthscale: 0.555 noise: 0.554
Iter 5/50 - Loss: 0.769 lengthscale: 0.514 noise: 0.513
Iter 6/50 - Loss: 0.724 lengthscale: 0.475 noise: 0.474
Iter 7/50 - Loss: 0.676 lengthscale: 0.439 noise: 0.437
Iter 8/50 - Loss: 0.627 lengthscale: 0.404 noise: 0.402
Iter 9/50 - Loss: 0.579 lengthscale: 0.372 noise: 0.370
Iter 10/50 - Loss: 0.535 lengthscale: 0.342 noise: 0.339
Iter 11/50 - Loss: 0.494 lengthscale: 0.315 noise: 0.311
Iter 12/50 - Loss: 0.457 lengthscale: 0.291 noise: 0.284
Iter 13/50 - Loss: 0.420 lengthscale: 0.272 noise: 0.260
Iter 14/50 - Loss: 0.385 lengthscale: 0.255 noise: 0.237
Iter 15/50 - Loss: 0.351 lengthscale: 0.243 noise: 0.216
Iter 16/50 - Loss: 0.316 lengthscale: 0.233 noise: 0.197
Iter 17/50 - Loss: 0.282 lengthscale: 0.227 noise: 0.179
Iter 18/50 - Loss: 0.247 lengthscale: 0.222 noise: 0.163
Iter 19/50 - Loss: 0.213 lengthscale: 0.220 noise: 0.148
Iter 20/50 - Loss: 0.179 lengthscale: 0.220 noise: 0.135
Iter 21/50 - Loss: 0.145 lengthscale: 0.222 noise: 0.122
Iter 22/50 - Loss: 0.112 lengthscale: 0.225 noise: 0.111
Iter 23/50 - Loss: 0.079 lengthscale: 0.230 noise: 0.101
Iter 24/50 - Loss: 0.049 lengthscale: 0.236 noise: 0.092
Iter 25/50 - Loss: 0.019 lengthscale: 0.244 noise: 0.084
Iter 26/50 - Loss: -0.008 lengthscale: 0.252 noise: 0.076
Iter 27/50 - Loss: -0.032 lengthscale: 0.261 noise: 0.069
Iter 28/50 - Loss: -0.054 lengthscale: 0.270 noise: 0.063
Iter 29/50 - Loss: -0.072 lengthscale: 0.280 noise: 0.058
Iter 30/50 - Loss: -0.087 lengthscale: 0.288 noise: 0.053
Iter 31/50 - Loss: -0.099 lengthscale: 0.296 noise: 0.049
Iter 32/50 - Loss: -0.107 lengthscale: 0.301 noise: 0.045
Iter 33/50 - Loss: -0.113 lengthscale: 0.304 noise: 0.042
Iter 34/50 - Loss: -0.117 lengthscale: 0.303 noise: 0.039
Iter 35/50 - Loss: -0.120 lengthscale: 0.300 noise: 0.036
Iter 36/50 - Loss: -0.121 lengthscale: 0.295 noise: 0.034
Iter 37/50 - Loss: -0.121 lengthscale: 0.288 noise: 0.032
Iter 38/50 - Loss: -0.119 lengthscale: 0.281 noise: 0.031
Iter 39/50 - Loss: -0.116 lengthscale: 0.273 noise: 0.029
Iter 40/50 - Loss: -0.112 lengthscale: 0.266 noise: 0.028
Iter 41/50 - Loss: -0.109 lengthscale: 0.260 noise: 0.028
Iter 42/50 - Loss: -0.106 lengthscale: 0.255 noise: 0.027
Iter 43/50 - Loss: -0.103 lengthscale: 0.251 noise: 0.027
Iter 44/50 - Loss: -0.102 lengthscale: 0.249 noise: 0.026
Iter 45/50 - Loss: -0.102 lengthscale: 0.247 noise: 0.026
Iter 46/50 - Loss: -0.103 lengthscale: 0.247 noise: 0.027
Iter 47/50 - Loss: -0.105 lengthscale: 0.248 noise: 0.027
Iter 48/50 - Loss: -0.108 lengthscale: 0.250 noise: 0.027
Iter 49/50 - Loss: -0.110 lengthscale: 0.252 noise: 0.028
Iter 50/50 - Loss: -0.113 lengthscale: 0.254 noise: 0.028

Question 4.1: Different Kernel
Using the code above, try a different kernel from the docs and plot the results. Doesn’t matter if it is better or worse, but you should try to pick a kernel that has different parameters to fit.
Question 4.2 (BONUS): Noiseless GP Interpolation
Now take the code above:
- Remove the noise in the DGP
- Decrease the number of generated datapoints to maybe 10 or so.
- Try to see how to fit a GP without any observational noise, so that it interpolates the data. This may require changing the likelihood object in the loop above.
Question 5: OpenAI API
Setting up OpenAI to access ChatGPT and embeddings programmatically.
Question 5.1: Setup and First Call
Setup an account on the OpenAI platform:
- Sign up
- Go to the `API keys` tab and create a key
- In your terminal, set `OPENAI_API_KEY` to this value (see here)
Given the change in the environment variable, you may need to restart your browser/editor/etc.
Get the following code running, which shows a prompt:
# Minimal chat-completion call; the client reads OPENAI_API_KEY from the environment.
client = OpenAI()
completion = client.chat.completions.create(
    model="gpt-4o-mini",
    temperature=0.7, max_tokens=300, #optional
    messages=[
        {
            "role": "system",
            "content": "You are generating numbers that are easy to parse."
        },
        {
            "role": "user",
            "content": "Give me a list of 3 numbers"
        }
    ]
)
print(completion.choices[0].message.content)

Question 5.2: Temperature Exploration
Take that prompt and run it a bunch of times with temperature = 0.0. What happens and why?
Next, run it a bunch of times with temperature = 2.0, which is the maximum entropy. What happens and why?
Question 5.3: Parsing LLM Output
Modify the system role in that prompt until you can easily parse the completion.choices[0].message.content into a list, reliably, with temperature = 0.7.
# Same call as above; tune the system prompt until the reply parses reliably.
client = OpenAI()
completion = client.chat.completions.create(
    model="gpt-4o-mini",
    temperature=0.7,
    messages=[
        {
            "role": "system",
            "content": "You are generating numbers that are easy to parse." # modify this system prompt
        },
        {
            "role": "user",
            "content": "Give me a list of 3 numbers"
        }
    ]
)
print(completion.choices[0].message.content)
# Add parsing logic