import numpy as np
import matplotlib.pyplot as plt
import scipy
import pandas as pd
import torch
import jax
import jax.numpy as jnp
from jax import grad, hessian
ECON622: Problem Set 1
Question 1
Follow the instructions in Python Environment Setup to set up your Python environment.
The imported packages will tell you whether your installation succeeded.
Adjust your setup until the following works:
x = torch.zeros((2, 3))
print(x)
tensor([[0., 0., 0.],
        [0., 0., 0.]])
x = jnp.zeros((2, 3))
print(x)
[[0. 0. 0.]
 [0. 0. 0.]]
There is no need to get GPUs etc. working. If you have issues on Windows, see the JAX Installation and PyTorch Installation instructions.
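As a quick check that a CPU-only installation is enough, the sketch below prints the JAX version and the backend in use. The variable names `backend` and `devices` are illustrative, and the printed values depend on your machine.

```python
import jax

# Report which JAX version was installed
print(jax.__version__)

# Name of the default backend; on a machine without GPU support this is "cpu"
backend = jax.default_backend()
print(backend)

# List of devices JAX can see; a CPU-only setup shows a single CPU device
devices = jax.devices()
print(devices)
```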
Question 2
This section just summarizes some reading on how JAX works, and asks you to come up with some demonstrations of its functionality. Read:
Question 2.1
Write a function that takes a vector of normally distributed random numbers and calculates its standard deviation manually (i.e., compute the mean, second moment, etc. yourself and then form the standard deviation inside the function).
Compile the function using both the jit function and the @jit decorator.
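One possible shape for this, offered as a sketch rather than the intended solution: the function below forms the population standard deviation (matching the default of `jnp.std`) from the mean and the centered second moment. The names `manual_std`, `manual_std_jit`, and `manual_std_decorated`, the seed, and the sample size are all assumptions made for illustration.

```python
import jax
import jax.numpy as jnp
from jax import jit

def manual_std(x):
    # Population standard deviation built from the first two moments.
    n = x.shape[0]                                 # static under jit
    mean = jnp.sum(x) / n                          # first moment
    centered_second = jnp.sum((x - mean) ** 2) / n # centered second moment
    return jnp.sqrt(centered_second)

# Option 1: call jit as a function on an existing function
manual_std_jit = jit(manual_std)

# Option 2: use jit as a decorator
@jit
def manual_std_decorated(x):
    return manual_std(x)

# Hypothetical test data: 1000 standard normal draws
key = jax.random.PRNGKey(0)
x = jax.random.normal(key, (1000,))

print(manual_std_jit(x), manual_std_decorated(x), jnp.std(x))
```

The centered form is used here instead of `E[x^2] - E[x]^2` because it is less prone to cancellation error in single precision.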
# MODIFY HERE
Question 2.2
Now generate a matrix of random normals and use vmap with that function to calculate the standard deviation across rows, then across columns. Compare with jnp.std(x, axis=0) and jnp.std(x, axis=1).
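A minimal sketch of how `vmap` could be applied here, assuming a manual standard deviation function like the one requested in Question 2.1. The names `manual_std`, `X`, `std_per_row`, and `std_per_col`, the seed, and the matrix shape are illustrative assumptions. Note how the `in_axes` argument of `vmap` lines up with the `axis` argument of `jnp.std`.

```python
import jax
import jax.numpy as jnp
from jax import jit, vmap

@jit
def manual_std(x):
    # Population standard deviation of a 1-D vector
    n = x.shape[0]
    mean = jnp.sum(x) / n
    return jnp.sqrt(jnp.sum((x - mean) ** 2) / n)

# Hypothetical data: 5 rows, 200 columns of standard normal draws
key = jax.random.PRNGKey(42)
X = jax.random.normal(key, (5, 200))

# in_axes=0 hands each row to manual_std: one value per row, shape (5,)
# This corresponds to jnp.std(X, axis=1), which reduces along the columns.
std_per_row = vmap(manual_std, in_axes=0)(X)

# in_axes=1 hands each column to manual_std: one value per column, shape (200,)
# This corresponds to jnp.std(X, axis=0), which reduces along the rows.
std_per_col = vmap(manual_std, in_axes=1)(X)

print(jnp.max(jnp.abs(std_per_row - jnp.std(X, axis=1))))
print(jnp.max(jnp.abs(std_per_col - jnp.std(X, axis=0))))
```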
# MODIFY HERE
Question 2.3
Take the following code and find the Hessian of f at x.
# MODIFY HERE
# Define a multivariate function
def f(x):
    return jnp.sum(x ** 2)
# Define a point at which to evaluate the function
x = jnp.array([1.0, 2.0, 3.0])
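One way to approach this, as a sketch: `jax.hessian` returns a function that evaluates the matrix of second derivatives, and composing `jacfwd` with `grad` gives an equivalent route. Since f is the sum of squares, its Hessian is twice the identity matrix at every point, which gives a simple check. The names `H` and `H_alt` are illustrative.

```python
import jax.numpy as jnp
from jax import grad, hessian, jacfwd

# Define a multivariate function
def f(x):
    return jnp.sum(x ** 2)

# Define a point at which to evaluate the function
x = jnp.array([1.0, 2.0, 3.0])

# Direct route: hessian(f) is a function; evaluate it at x
H = hessian(f)(x)

# Alternative route: forward-mode Jacobian of the gradient
H_alt = jacfwd(grad(f))(x)

print(H)
print(jnp.allclose(H, H_alt))
```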