ECON622: Problem Set 2

Author

Jesse Perla, UBC

Student Name/Number: (doubleclick to edit)

Packages

Add whatever packages you wish here

import numpy as np
import matplotlib.pyplot as plt
import scipy
import pandas as pd
import jax
import jax.numpy as jnp
from jax import grad, hessian
from jax import random
import optax
import optimistix
import time

Question 1

The trace of the Hessian matrix is useful in a variety of applications in statistics, econometrics, and stochastic processes. It can also be used to regularize a loss function.

For a function \(f:\mathbb{R}^N\to\mathbb{R}\), denote the Hessian as \(\nabla^2 f(x) \in \mathbb{R}^{N\times N}\).

It can be shown that for any random vector \(v\in\mathbb{R}^N\) with \(\mathbb{E}(v) = 0\) and \(\mathbb{E}(v v^{\top}) = I\), the trace of the Hessian satisfies

\[ \mathrm{Tr}(\nabla^2 f(x)) = \mathbb{E}\left[v^{\top} \nabla^2 f(x)\, v\right] \]

This leads to a randomized algorithm: sample \(M\) vectors \(v_1,\ldots,v_M\) and use the Monte Carlo approximation of the expectation, known as the Hutchinson trace estimator

\[ \mathrm{Tr}(\nabla^2 f(x)) \approx \frac{1}{M} \sum_{m=1}^M v_m^{\top} \nabla^2 f(x)\, v_m \]

Question 1.1

Now, let’s take the function \(f(x) = \frac{1}{2}x^{\top} P x\), which is a quadratic form and where we know that \(\nabla^2 f(x) = P\).

The following code finds the trace of the Hessian, which for this simple function is just the sum of the diagonal of \(P\).

key = jax.random.PRNGKey(0)

N = 100  # Dimension of the matrix
A = jax.random.normal(key, (N, N))
# Create a positive-definite matrix P by forming A^T * A
P = jnp.dot(A.T, A)
def f(x):
    return 0.5 * jnp.dot(x.T, jnp.dot(P, x))
x = jax.random.normal(key, (N,))
print(jnp.trace(jax.hessian(f)(x)))
print(jnp.diag(P).sum())
10240.816
10240.817

Now, instead of calculating the whole Hessian, use a Hessian-vector product in JAX and the approximation above with \(M\) draws of random vectors to calculate an approximation of the trace of the Hessian. Increase the number of draws \(M\) to see what the variance of the estimator is, comparing to the above closed-form solution for this quadratic.

Hint: you will want to use forward-over-reverse mode differentiation here (i.e., the vjp gives a pullback function for the first derivative; since the resulting gradient map is \(\mathbb{R}^N \to \mathbb{R}^N\), it makes sense to differentiate it in forward mode with a jvp).

# ADD CODE HERE
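
For reference, here is a minimal sketch of one possible approach. It uses Rademacher draws for \(v\); the draw distribution and the values of \(M\) are illustrative choices, and any mean-zero, identity-covariance distribution would do.

def hvp(f, x, v):
    # Forward-over-reverse: jvp of grad(f) gives the Hessian-vector product
    # without ever materializing the N x N Hessian
    return jax.jvp(jax.grad(f), (x,), (v,))[1]

def hutchinson_trace(f, x, key, M):
    # Sample M Rademacher vectors and average the quadratic forms v' H v
    vs = jax.random.rademacher(key, (M, x.shape[0]), dtype=x.dtype)
    return jax.vmap(lambda v: v @ hvp(f, x, v))(vs).mean()

for M in (10, 100, 1000):
    key, subkey = jax.random.split(key)
    est = float(hutchinson_trace(f, x, subkey, M))
    print(f"M={M}: {est:.1f} vs exact {float(jnp.trace(P)):.1f}")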

Question 1.2 (Bonus)

If you wish, you can experiment with radically increasing the size of \(N\) and changing the function itself. One suggestion is to move toward a sparse or even matrix-free \(f(x)\) calculation so that \(P\) itself never needs to be materialized.

# ADD CODE HERE
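
As one possible starting point, here is a hypothetical matrix-free quadratic (the cycle-graph Laplacian applied via jnp.roll, with \(N\) an illustrative choice), reusing the hutchinson_trace sketch from Question 1.1. Its exact trace is \(2N\), since every diagonal entry of the implicit \(P\) is 2.

N_big = 100_000

def f_free(x):
    # 0.5 * x' P x with P = 2I - S - S^{-1} (cycle-graph Laplacian),
    # applied matrix-free so the N x N matrix never materializes
    Px = 2.0 * x - jnp.roll(x, 1) - jnp.roll(x, -1)
    return 0.5 * x @ Px

key_big = jax.random.PRNGKey(1)
x_big = jax.random.normal(key_big, (N_big,))
key_big, subkey = jax.random.split(key_big)
print(float(hutchinson_trace(f_free, x_big, subkey, 50)), 2.0 * N_big)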

Question 2

This section gives some hints on how to set up a differentiable likelihood function with implicitly defined functions.

Question 2.1

The following code uses scipy to find the equilibrium price and quantity for some simple supply and demand functions with embedded parameters.

from scipy.optimize import root_scalar

# Define the demand function with demand parameter c_d
def demand(P, c_d):
    return 100 - 2 * P**c_d

# Define the supply function with supply parameter c_s
def supply(P, c_s):
    return 5 * 3**(c_s * P)

# Define the excess demand function to find the root of, given c_d and c_s
def equilibrium(P, c_d, c_s):
    return demand(P, c_d) - supply(P, c_s)

# Use root_scalar to find the equilibrium price
def find_equilibrium(c_d, c_s):
    result = root_scalar(equilibrium, args=(c_d, c_s), bracket=[0, 100], method='brentq')
    return result.root, demand(result.root, c_d)

# Example usage
c_d = 0.5
c_s = 0.15
equilibrium_price, equilibrium_quantity = find_equilibrium(c_d, c_s)
print(f"Equilibrium Price: {equilibrium_price:.2f}")
print(f"Equilibrium Quantity: {equilibrium_quantity:.2f}")
Equilibrium Price: 17.65
Equilibrium Quantity: 91.60

First, convert this to use JAX and Optimistix for finding the root using optimistix.root_find(). Make sure you can jit the whole find_equilibrium function.

# ADD CODE HERE
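
A minimal sketch, assuming a Newton solver; the solver choice, tolerances, and initial guess of 10.0 are illustrative rather than prescribed.

def demand_jax(P, c_d):
    return 100 - 2 * P**c_d

def supply_jax(P, c_s):
    return 5 * 3.0**(c_s * P)

def excess_demand(P, args):
    # optimistix expects the residual function signature fn(y, args)
    c_d, c_s = args
    return demand_jax(P, c_d) - supply_jax(P, c_s)

@jax.jit
def find_equilibrium_jax(c_d, c_s):
    solver = optimistix.Newton(rtol=1e-8, atol=1e-8)
    sol = optimistix.root_find(excess_demand, solver, jnp.array(10.0), args=(c_d, c_s))
    return sol.value, demand_jax(sol.value, c_d)

p_star, q_star = find_equilibrium_jax(0.5, 0.15)
print(p_star, q_star)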

Question 2.2

Now, assume that you get a noisy signal of the price that clears that supply and demand system.

\[ \hat{p} \sim \mathcal{N}(p, \sigma^2) \]

In that case, the log likelihood for the Gaussian is

\[ \log \mathcal{L}(\hat{p}\,|\,c_d, c_s, p) = -\frac{1}{2} \log(2\pi\sigma^2) - \frac{1}{2\sigma^2} (\hat{p} - p)^2 \]

Or, if \(p\) was implicitly defined by the equilibrium conditions as some \(p(c_d, c_s)\) from above,

\[ \log \mathcal{L}(\hat{p}\,|\,c_d, c_s) = -\frac{1}{2} \log(2\pi\sigma^2) - \frac{1}{2\sigma^2} (\hat{p} - p(c_d, c_s))^2 \]

Then, for some \(\sigma = 0.01\), we can calculate this log likelihood as

def log_likelihood(p_hat, c_d, c_s, sigma):
    p, x = find_equilibrium(c_d, c_s)
    return -0.5 * np.log(2 * np.pi * sigma**2) - 0.5 * (p_hat - p)**2 / sigma**2

c_d = 0.5
c_s = 0.15
sigma = 0.01
p, x = find_equilibrium(c_d, c_s) # get the true value for simulation
p_hat = p + np.random.normal(0, sigma) # simulate a noisy signal
log_likelihood(p_hat, c_d, c_s, sigma)
np.float64(3.534192121654968)

Now, take this code for the likelihood and convert it to JAX and jit it, using your function from Question 2.1.

# ADD CODE HERE
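
A minimal sketch, building on the find_equilibrium_jax sketch from Question 2.1:

@jax.jit
def log_likelihood_jax(p_hat, c_d, c_s, sigma):
    p, q = find_equilibrium_jax(c_d, c_s)
    return -0.5 * jnp.log(2 * jnp.pi * sigma**2) - 0.5 * (p_hat - p)**2 / sigma**2

sigma = 0.01
p, q = find_equilibrium_jax(0.5, 0.15)  # get the true value for simulation
key_ll = jax.random.PRNGKey(2)
p_hat = p + sigma * jax.random.normal(key_ll)  # simulate a noisy signal
print(log_likelihood_jax(p_hat, 0.5, 0.15, sigma))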

Question 2.3

Use the function from the previous part and calculate the gradient with respect to params (i.e., c_d and c_s) using grad and JAX.

# ADD CODE HERE
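
One possible sketch; Optimistix differentiates through the root find via the implicit function theorem, so jax.grad works directly on the likelihood:

grad_ll = jax.grad(log_likelihood_jax, argnums=(1, 2))  # gradients w.r.t. c_d and c_s
print(grad_ll(p_hat, 0.5, 0.15, sigma))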

Question 2.4 (Bonus)

You could try to run maximum likelihood estimation by using a gradient-based optimizer. You can use either Optax (standard for ML optimization) or Optimistix with optimistix.minimise().

If you attempt this:

  • Consider starting your optimization at the “pseudo-true” values with the c_s, c_d, sigma you used to simulate the data and even start with p_hat = p.
  • You may find that it is a little too noisy with only one observation. If so, you could adapt your likelihood to take a vector of \(\hat{p}\) instead. The likelihood of IID Gaussians is a simple variation on the above.
# ADD CODE HERE
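
If you do attempt it, here is a minimal sketch using Optax gradient descent on the negative log likelihood, starting at the pseudo-true values; the optimizer, learning rate, and iteration count are illustrative choices.

params = jnp.array([0.5, 0.15])  # start at the pseudo-true (c_d, c_s)
optimizer = optax.adam(1e-3)
opt_state = optimizer.init(params)

@jax.jit
def step(params, opt_state):
    # minimize the negative log likelihood in the parameters
    loss, grads = jax.value_and_grad(
        lambda th: -log_likelihood_jax(p_hat, th[0], th[1], sigma)
    )(params)
    updates, opt_state = optimizer.update(grads, opt_state)
    params = optax.apply_updates(params, updates)
    return params, opt_state, loss

for _ in range(200):
    params, opt_state, loss = step(params, opt_state)
print(params, loss)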