ECON622: Problem Set 1

Author

Jesse Perla, UBC

Student Name/Number:

Question 1

Follow the instructions in Python Environment Setup to set up your Python environment.

The following imports will tell you whether your installation was successful:

import numpy as np
import matplotlib.pyplot as plt
import scipy
import pandas as pd
import torch
import jax
import jax.numpy as jnp
from jax import grad, hessian

Mess around with your setup until the following code runs:

x = torch.zeros((2, 3))
print(x)
tensor([[0., 0., 0.],
        [0., 0., 0.]])
x = jnp.zeros((2, 3))
print(x)
[[0. 0. 0.]
 [0. 0. 0.]]

You do not need to get GPUs etc. working. If you have issues on Windows, see the JAX Installation and PyTorch Installation instructions.
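If you want to confirm which hardware JAX picked up, a quick optional check like the one below should work. On a CPU-only install it typically lists a single CPU device and reports "cpu" as the default backend; nothing here is required for the problem set.

```python
import jax

# Devices JAX can see; a CPU-only install typically lists one CPU device.
print(jax.devices())

# Platform name of the default backend, e.g. "cpu" on a CPU-only install.
print(jax.default_backend())
```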

Question 2

This section summarizes some reading on how JAX works and asks you to write a few demonstrations of its functionality. Read:

Question 2.1

Write a function that takes a vector of normally distributed random numbers and calculates their standard deviation manually (i.e., calculate the mean, second moment, etc. yourself and then form the standard deviation from them inside the function).

Compile the function with jit, both by calling jit on it directly and by using the @jit decorator.

# MODIFY HERE
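One possible sketch, assuming a 1-D sample drawn with jax.random.normal; the names manual_std, manual_std_jit and manual_std_decorated are illustrative, not prescribed. It uses the population (ddof=0) convention so the result can be compared directly with jnp.std, whose default is ddof=0.

```python
import jax
import jax.numpy as jnp
from jax import jit


def manual_std(x):
    """Population standard deviation (ddof=0) of a 1-D array, built from moments."""
    n = x.shape[0]
    mean = jnp.sum(x) / n                 # first moment, E[x]
    second_moment = jnp.sum(x ** 2) / n   # second raw moment, E[x^2]
    variance = second_moment - mean ** 2  # Var[x] = E[x^2] - E[x]^2
    return jnp.sqrt(variance)


# Option 1: wrap an existing function with jit.
manual_std_jit = jit(manual_std)


# Option 2: apply jit as a decorator at definition time.
@jit
def manual_std_decorated(x):
    return manual_std(x)


# JAX uses explicit PRNG keys instead of a global random state.
key = jax.random.PRNGKey(0)
x = jax.random.normal(key, (1000,))

print(manual_std_jit(x), manual_std_decorated(x), jnp.std(x))
```

Both compiled versions are traced on the first call for a given input shape and dtype, and the compiled code is reused afterwards. The raw-moment formula E[x^2] - E[x]^2 follows the question's wording, but it can lose precision in float32 when the mean is large relative to the spread; subtracting the mean first and averaging the squared deviations is the more stable alternative.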

Question 2.2

Now generate a matrix of random normals and apply that function with vmap to calculate the standard deviation across rows, then across columns. Compare the results with jnp.std(x, axis=0) and jnp.std(x, axis=1).

# MODIFY HERE
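A sketch of the vmap step, with the per-vector standard deviation redefined so the block runs on its own; manual_std and X are illustrative names, and the 4 x 1000 shape is an arbitrary choice. The in_axes argument tells vmap which axis to map over: in_axes=0 feeds the function one row at a time, in_axes=1 one column at a time.

```python
import jax
import jax.numpy as jnp
from jax import jit, vmap


@jit
def manual_std(x):
    """Population standard deviation (ddof=0) of a 1-D array from raw moments."""
    n = x.shape[0]  # inside vmap this is the length of a single row/column
    mean = jnp.sum(x) / n
    second_moment = jnp.sum(x ** 2) / n
    return jnp.sqrt(second_moment - mean ** 2)


key = jax.random.PRNGKey(42)
X = jax.random.normal(key, (4, 1000))  # 4 rows, 1000 columns

# in_axes=0: each call receives one row, giving one std per row.
row_stds = vmap(manual_std, in_axes=0)(X)  # shape (4,)
# in_axes=1: each call receives one column, giving one std per column.
col_stds = vmap(manual_std, in_axes=1)(X)  # shape (1000,)

# Reducing along axis=1 collapses each row, so it pairs with the row-wise map;
# reducing along axis=0 collapses each column, pairing with the column-wise map.
print(jnp.max(jnp.abs(row_stds - jnp.std(X, axis=1))))
print(jnp.max(jnp.abs(col_stds - jnp.std(X, axis=0))))
```

Note the pairing: mapping over rows yields one value per row, which corresponds to reducing along axis=1, so vmap with in_axes=0 should match jnp.std(X, axis=1), and in_axes=1 should match jnp.std(X, axis=0).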

Question 2.3

Take the following code and find the Hessian of f at x.

# MODIFY HERE

# Define a multivariate function
def f(x):
    return jnp.sum(x ** 2)

# Define a point at which to evaluate the function
x = jnp.array([1.0, 2.0, 3.0])
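A minimal sketch using jax.hessian and jax.grad (both imported at the top of the notebook); the block repeats f and x so it runs on its own. Since f(x) is the sum of x_i^2, the gradient is 2x and the second derivative with respect to x_i and x_j is 2 when i = j and 0 otherwise, so the Hessian should be 2 times the 3x3 identity.

```python
import jax.numpy as jnp
from jax import grad, hessian


def f(x):
    return jnp.sum(x ** 2)


x = jnp.array([1.0, 2.0, 3.0])

# grad(f) returns a function computing the gradient vector df/dx_i.
g = grad(f)(x)

# hessian(f) returns a function computing the matrix d^2 f / dx_i dx_j.
H = hessian(f)(x)

print(g)
print(H)
```

Under the hood, jax.hessian composes forward-mode over reverse-mode differentiation (jacfwd of jacrev), which the JAX documentation describes as typically the most efficient way to build a dense Hessian.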