import numpy as np
import matplotlib.pyplot as plt
import jax
import jax.numpy as jnp
from jax import grad, jvp, vjp
from jax import random
from numpy.linalg import cond
ECON622: Problem Set 3
Packages
Add whatever packages you wish here
Question 1: Conditioning and Preconditioned Gradient Descent
In this question, you will explore how the condition number affects gradient descent convergence and implement a simple diagonal preconditioner to improve performance.
We consider minimizing the quadratic function:
\[ f(x) = \frac{1}{2} x^\top P x \]
where \(P\) is a symmetric positive definite matrix. The gradient is \(\nabla f(x) = P x\), so the minimum is at \(x^* = 0\).
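As a quick sanity check (an aside, not part of the graded questions), the gradient formula \(\nabla f(x) = Px\) can be confirmed with JAX autodiff on a small hand-picked symmetric positive definite \(P\):

```python
import jax.numpy as jnp
from jax import grad

P = jnp.array([[2.0, 0.5],
               [0.5, 1.0]])          # symmetric positive definite
f = lambda x: 0.5 * x @ P @ x        # f(x) = 0.5 x^T P x
x = jnp.array([1.0, -2.0])

# the autodiff gradient should equal P @ x
print(grad(f)(x))
print(P @ x)
```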
Starter Code
The following function visualizes gradient descent on a 2D quadratic:
def plot_gd_steps(N_steps, x_0, Lambda, eta, theta=np.pi/6):
    """
    Plot gradient descent steps on a 2D quadratic f(x) = 0.5 * x^T P x
    where P = Q @ diag(Lambda) @ Q^T for rotation angle theta.

    Args:
        N_steps: Number of GD iterations to plot
        x_0: Initial point [x1, x2]
        Lambda: Eigenvalues [lambda1, lambda2] (condition number = max/min)
        eta: Step size (learning rate)
        theta: Rotation angle for Q (default: pi/6)
    """
    c, s = np.cos(theta), np.sin(theta)
    Q = np.array([[c, -s], [s, c]])
    P = Q @ np.diag(Lambda) @ Q.T
    gd_step = lambda x: x - eta * P @ x
    x_vals = np.linspace(-1, 1, 100)
    X1, X2 = np.meshgrid(x_vals, x_vals)
    Z = 0.5 * (P[0,0]*X1**2 + 2*P[0,1]*X1*X2 + P[1,1]*X2**2)
    plt.figure(figsize=(6, 5))
    plt.contour(X1, X2, Z, levels=20, cmap='viridis')
    x_current = np.array(x_0)
    for i in range(N_steps):
        x_next = gd_step(x_current)
        plt.arrow(x_current[0], x_current[1],
                  x_next[0] - x_current[0], x_next[1] - x_current[1],
                  head_width=0.03, head_length=0.02, fc='red', ec='red')
        x_current = x_next
    plt.xlabel('x1'); plt.ylabel('x2')
    plt.gca().set_aspect('equal')
    return plt.gcf()
Question 1.1: Explore Conditioning Visually
Run the visualization with different condition numbers and observe how the contours change shape.
x_0 = [0.9, 0.0]
eta = 0.5
# Well-conditioned: kappa = 2
Lambda_good = np.array([1.0, 0.5])
plot_gd_steps(5, x_0, Lambda_good, eta)
plt.title(r'$\kappa = 2$: Well-conditioned')
plt.show()
Now run with \(\kappa = 10\) (eigenvalues [1.0, 0.1]) and \(\kappa = 100\) (eigenvalues [1.0, 0.01]).
# modify here
(double click to edit your answer)
Explain in 2-3 sentences why gradient descent struggles with poorly-conditioned problems (narrow ellipses).
Question 1.2: Count Iterations to Convergence
Use the following function to count how many iterations GD needs to converge:
def count_gd_iterations(x_0, P, eta, tol=1e-6, max_iter=5000):
    """Count iterations until ||x|| < tol"""
    x = np.array(x_0)
    for i in range(max_iter):
        if np.linalg.norm(x) < tol:
            return i
        x = x - eta * P @ x
    return max_iter  # didn't converge
Helper function to create \(P\) from eigenvalues:
def make_P(Lambda, theta=np.pi/6):
    """Create P = Q @ diag(Lambda) @ Q^T with rotation angle theta"""
    c, s = np.cos(theta), np.sin(theta)
    Q = np.array([[c, -s], [s, c]])
    return Q @ np.diag(Lambda) @ Q.T
Count iterations for \(\kappa = 2, 10, 100\) with eta = 0.5 and x_0 = [0.9, 0.0]:
# modify here
(double click to edit your answer)
Report the iteration counts and explain the relationship between \(\kappa\) and convergence rate.
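One way the loop over condition numbers might look (a self-contained sketch that repeats the helpers above so it runs on its own; exact counts depend on tol and theta):

```python
import numpy as np

def make_P(Lambda, theta=np.pi/6):
    """P = Q @ diag(Lambda) @ Q^T for rotation angle theta."""
    c, s = np.cos(theta), np.sin(theta)
    Q = np.array([[c, -s], [s, c]])
    return Q @ np.diag(Lambda) @ Q.T

def count_gd_iterations(x_0, P, eta, tol=1e-6, max_iter=5000):
    """Iterations of x <- x - eta * P @ x until ||x|| < tol."""
    x = np.array(x_0)
    for i in range(max_iter):
        if np.linalg.norm(x) < tol:
            return i
        x = x - eta * P @ x
    return max_iter

counts = {}
for Lambda in ([1.0, 0.5], [1.0, 0.1], [1.0, 0.01]):
    kappa = max(Lambda) / min(Lambda)
    counts[kappa] = count_gd_iterations([0.9, 0.0], make_P(np.array(Lambda)), eta=0.5)
print(counts)  # iteration counts grow with kappa
```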
Question 1.3 (BONUS): Understand Preconditioned Gradient Descent
Standard gradient descent updates:
\[ x^{k+1} = x^k - \eta \nabla f(x^k) = x^k - \eta P x^k \]
Preconditioned gradient descent scales the gradient by a matrix \(D^{-1}\):
\[ x^{k+1} = x^k - \eta D^{-1} \nabla f(x^k) = x^k - \eta D^{-1} P x^k \]
Your task: Show that preconditioned GD is equivalent to standard GD on a transformed problem.
Let \(\tilde{x} = D^{1/2} x\) (assuming \(D\) is diagonal with positive entries). Show that:
- The objective becomes \(\tilde{f}(\tilde{x}) = \frac{1}{2} \tilde{x}^\top \tilde{P} \tilde{x}\) where \(\tilde{P} = D^{-1/2} P D^{-1/2}\)
- If we choose \(D = \text{diag}(P)\), what are the diagonal entries of \(\tilde{P}\)?
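One possible opening move for this derivation (a hint, not the full argument): since \(D\) is diagonal with positive entries, \(D^{1/2}\) and \(D^{-1/2}\) are well-defined, diagonal, and symmetric, so substituting \(x = D^{-1/2} \tilde{x}\) into the objective gives
\[ f(x) = \frac{1}{2} x^\top P x = \frac{1}{2} \tilde{x}^\top D^{-1/2} P D^{-1/2} \tilde{x} = \frac{1}{2} \tilde{x}^\top \tilde{P} \tilde{x} = \tilde{f}(\tilde{x}). \]
It then remains to show that a GD step on \(\tilde{f}\) in \(\tilde{x}\)-coordinates maps back to the preconditioned step in \(x\)-coordinates, and to read off \(\tilde{P}_{ii}\) when \(D = \text{diag}(P)\).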
(double click to edit your answer)
Derivation:
Question 1.4: Implement Diagonal Preconditioning
Implement preconditioned GD and compare iteration counts:
def count_preconditioned_gd_iterations(x_0, P, D_inv, eta, tol=1e-6, max_iter=5000):
    """Count iterations for preconditioned GD: x_{k+1} = x_k - eta * D^{-1} @ P @ x_k"""
    x = np.array(x_0)
    for i in range(max_iter):
        if np.linalg.norm(x) < tol:
            return i
        # modify here: implement preconditioned step
        x = x - eta * P @ x
    return max_iter
Test with \(\kappa = 100\):
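For orientation, the comparison might be wired up like this (a standalone sketch with the helpers repeated, using the same eta and x_0 as Question 1.2; `count_iters` is a hypothetical generic counter, not part of the starter code):

```python
import numpy as np

def make_P(Lambda, theta=np.pi/6):
    c, s = np.cos(theta), np.sin(theta)
    Q = np.array([[c, -s], [s, c]])
    return Q @ np.diag(Lambda) @ Q.T

def count_iters(x_0, step, tol=1e-6, max_iter=5000):
    """Generic iteration counter for an update map `step` (hypothetical helper)."""
    x = np.array(x_0)
    for i in range(max_iter):
        if np.linalg.norm(x) < tol:
            return i
        x = step(x)
    return max_iter

P = make_P(np.array([1.0, 0.01]))      # kappa = 100
D_inv = np.diag(1.0 / np.diag(P))      # D = diag(P)
eta, x_0 = 0.5, [0.9, 0.0]

plain = count_iters(x_0, lambda x: x - eta * P @ x)
precond = count_iters(x_0, lambda x: x - eta * D_inv @ P @ x)
print(plain, precond)  # preconditioning needs fewer iterations here
```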
Lambda = np.array([1.0, 0.01])
P = make_P(Lambda)
D_inv = np.diag(1.0 / np.diag(P)) # D^{-1} where D = diag(P)
# modify here: compare iterations with and without preconditioning
Question 1.5: Verify the Condition Number Improvement
Compute and compare the condition numbers before and after preconditioning:
Lambda = np.array([1.0, 0.01])
P = make_P(Lambda)
D = np.diag(np.diag(P))
D_sqrt_inv = np.diag(1.0 / np.sqrt(np.diag(P)))
# Preconditioned matrix: D^{-1/2} P D^{-1/2}
P_precond = D_sqrt_inv @ P @ D_sqrt_inv
# modify here: compute and print condition numbers
print(f"cond(P) = ...")
print(f"cond(P_precond) = ...")
cond(P) = ...
cond(P_precond) = ...
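The check itself is a one-liner with `numpy.linalg.cond` (a sketch that repeats the setup so it runs standalone):

```python
import numpy as np
from numpy.linalg import cond

theta = np.pi / 6
c, s = np.cos(theta), np.sin(theta)
Q = np.array([[c, -s], [s, c]])
P = Q @ np.diag([1.0, 0.01]) @ Q.T     # kappa = 100

D_sqrt_inv = np.diag(1.0 / np.sqrt(np.diag(P)))
P_precond = D_sqrt_inv @ P @ D_sqrt_inv   # D^{-1/2} P D^{-1/2}

print(f"cond(P) = {cond(P):.1f}")
print(f"cond(P_precond) = {cond(P_precond):.1f}")
```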
(double click to edit your answer)
Explain why the diagonal preconditioner improves the condition number for this problem.
Question 1.6: When Diagonal Preconditioning Fails
The effectiveness of diagonal preconditioning depends on the rotation angle \(\theta\) used to construct \(P = Q(\theta) \Lambda Q(\theta)^\top\).
Repeat the preconditioning experiment from Q1.4-1.5 with \(\theta = \pi/4\) (a 45-degree rotation). Compare iteration counts with and without preconditioning, and compute the condition numbers of \(P\) and \(\tilde{P}\).
Lambda = np.array([1.0, 0.01])
# modify here: construct P with theta = pi/4, compute D_inv, and compare
# iterations with and without preconditioning
# Also compute cond(P) and cond(P_precond)
(double click to edit your answer)
- What do you observe about the iteration counts?
- Examine the diagonal entries of \(P\) at \(\theta = \pi/4\). Why does diagonal preconditioning fail here?
- What is \(D^{-1}\) in this case, and what does the preconditioned step \(x - \eta D^{-1} P x\) simplify to?
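To see the mechanism numerically, here is a minimal standalone sketch of the \(\theta = \pi/4\) construction (it only prints the diagnostics; the interpretation is the exercise):

```python
import numpy as np
from numpy.linalg import cond

theta = np.pi / 4
c, s = np.cos(theta), np.sin(theta)
Q = np.array([[c, -s], [s, c]])
P = Q @ np.diag([1.0, 0.01]) @ Q.T

print(np.diag(P))                  # compare the two diagonal entries of P
D_sqrt_inv = np.diag(1.0 / np.sqrt(np.diag(P)))
P_precond = D_sqrt_inv @ P @ D_sqrt_inv
print(f"cond(P) = {cond(P):.1f}, cond(P_precond) = {cond(P_precond):.1f}")
```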
Question 2: Vectorized AD Rule for Matrix-Vector Product
In this question, you will derive the JVP and VJP rules for a vectorized operation (matrix-vector multiplication) and verify your implementation against JAX.
Consider the function \(f(A, x) = Ax\) where \(A \in \mathbb{R}^{m \times n}\) and \(x \in \mathbb{R}^n\).
Question 2.1: Derive the JVP Rule
Given tangents \(\dot{A} \in \mathbb{R}^{m \times n}\) and \(\dot{x} \in \mathbb{R}^n\), derive the pushforward (JVP):
\[ \dot{y} = \partial f(A, x)[(\dot{A}, \dot{x})] = ? \]
Hint: Use the product rule for differentiation.
(double click to edit your answer)
Derivation:
Question 2.2 (BONUS): Derive the VJP Rule
Given a cotangent \(\bar{y} \in \mathbb{R}^m\), derive the pullback (VJP):
\[ (\bar{A}, \bar{x}) = \partial f(A, x)^\top[\bar{y}] = ? \]
Hint for adjoint algebra: Start with the inner product \(\langle \bar{y}, \dot{y} \rangle = \langle \bar{y}, \dot{A} x + A \dot{x} \rangle\) and rearrange to identify \(\bar{A}\) and \(\bar{x}\).
Recall that for matrices, \(\langle U, V \rangle = \text{Tr}(U^\top V)\), and for vectors, \(\langle u, v \rangle = u^\top v\).
(double click to edit your answer)
Derivation:
Question 2.3: Implement and Verify JVP
Implement your JVP rule and verify against JAX:
def matvec(A, x):
    """f(A, x) = A @ x"""
    return A @ x

def matvec_jvp(A, x, A_dot, x_dot):
    """Your implementation of the JVP rule: returns y_dot"""
    # modify here
    pass
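For reference, the JAX side of the check can be set up with `jax.jvp`, which takes the function, a tuple of primals, and a tuple of tangents (a standalone sketch; the product-rule formula it checks against is the one written out in the Question 2.2 hint):

```python
import jax.numpy as jnp
from jax import jvp, random

def matvec(A, x):
    return A @ x

key = random.PRNGKey(42)
k0, k1, k2, k3 = random.split(key, 4)
m, n = 4, 3
A = random.normal(k0, (m, n))
x = random.normal(k1, (n,))
A_dot = random.normal(k2, (m, n))
x_dot = random.normal(k3, (n,))

# push the tangents (A_dot, x_dot) through matvec at the primals (A, x)
y, y_dot = jvp(matvec, (A, x), (A_dot, x_dot))
print(jnp.allclose(y_dot, A_dot @ x + A @ x_dot))
```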
# Test data
key = random.PRNGKey(42)
keys = random.split(key, 4)
m, n = 4, 3
A = random.normal(keys[0], (m, n))
x = random.normal(keys[1], (n,))
A_dot = random.normal(keys[2], (m, n))
x_dot = random.normal(keys[3], (n,))
# modify here: compare to JAX jvp
Question 2.4: Implement and Verify VJP
Implement your VJP rule and verify against JAX:
def matvec_vjp(A, x, y_bar):
    """Your implementation of the VJP rule: returns (A_bar, x_bar)"""
    # modify here
    pass
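`jax.vjp` returns the primal output plus a pullback closure; here is a standalone sketch of the mechanics only (it checks shapes rather than values, so as not to give the bonus derivation away):

```python
import jax.numpy as jnp
from jax import vjp, random

def matvec(A, x):
    return A @ x

key = random.PRNGKey(42)
k0, k1 = random.split(key)
m, n = 4, 3
A = random.normal(k0, (m, n))
x = random.normal(k1, (n,))
y_bar = random.normal(random.PRNGKey(99), (m,))

# pullback maps a cotangent on the output to cotangents on (A, x)
y, pullback = vjp(matvec, A, x)
A_bar, x_bar = pullback(y_bar)
print(A_bar.shape, x_bar.shape)   # cotangents match the primal input shapes
```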
# Test cotangent
y_bar = random.normal(random.PRNGKey(99), (m,))
# modify here: compare to JAX vjp
(double click to edit your answer)
Explain in 2-3 sentences the connection between the VJP formula \(\bar{x} = A^\top \bar{y}\) and the “adjoint” or “transpose” terminology used in automatic differentiation.
Question 3 (BONUS): Sensitivity Analysis via the LLS Adjoint
In this question, you will derive and apply the VJP (adjoint) rule for linear least squares to perform sensitivity analysis: understanding how each observation affects the estimated coefficients.
The Linear Least Squares Problem
Given design matrix \(X \in \mathbb{R}^{n \times p}\) and observations \(y \in \mathbb{R}^n\), the least squares solution is:
\[ \hat{\beta} = (X^\top X)^{-1} X^\top y \]
Question 3.1: The LLS Adjoint Rule
The VJP rule for \(\hat{\beta} = \text{LLS}(X, y)\) is derived by chaining through the matrix inverse and products.
Given: A cotangent \(\bar{\beta} \in \mathbb{R}^p\) (representing “how much we care about each coefficient”)
Compute:
- Solve the linear system: \(g = (X^\top X)^{-1} \bar{\beta}\)
- Compute the residual: \(r = y - X\hat{\beta}\)
Then the adjoints are:
\[ \bar{y} = X g \]
\[ \bar{X} = r\, g^\top - (Xg)\, \hat{\beta}^\top \]
(double click to edit your answer)
Explain in 2-3 sentences why this is computationally powerful: if we want to know how \(\hat{\beta}_0\) depends on ALL \(n\) observations \(y_i\), how many backward passes do we need?
Question 3.2: Generate Data and Compute LLS
# Generate random regression data
key = random.PRNGKey(123)
keys = random.split(key, 3)
n = 1000 # observations
p = 3 # parameters
# True coefficients
beta_true = jnp.array([2.0, -1.0, 0.5])
# Random design matrix and noisy observations
X = random.normal(keys[0], (n, p))
noise = 0.1 * random.normal(keys[1], (n,))
y = X @ beta_true + noise
# Compute LLS solution
beta_hat = jnp.linalg.lstsq(X, y)[0]
print(f"True beta: {beta_true}")
print(f"Estimated beta: {beta_hat}")
True beta: [ 2. -1. 0.5]
Estimated beta: [ 2.0004642 -1.0009896 0.49019286]
Question 3.3: Manual Sensitivity Calculation
Use the LLS adjoint rule to compute \(\partial \hat{\beta}_0 / \partial y_i\) for all observations.
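A self-contained sketch of how the three steps of the adjoint rule from Question 3.1 compose, regenerating the data from Question 3.2 so it runs on its own:

```python
import jax.numpy as jnp
from jax import random

# Regenerate the data from Question 3.2
key = random.PRNGKey(123)
keys = random.split(key, 3)
n, p = 1000, 3
beta_true = jnp.array([2.0, -1.0, 0.5])
X = random.normal(keys[0], (n, p))
y = X @ beta_true + 0.1 * random.normal(keys[1], (n,))
beta_hat = jnp.linalg.lstsq(X, y)[0]

# Adjoint rule from Question 3.1, with cotangent beta_bar = e_0
beta_bar = jnp.array([1.0, 0.0, 0.0])
g = jnp.linalg.solve(X.T @ X, beta_bar)   # g = (X^T X)^{-1} beta_bar
r = y - X @ beta_hat                      # residual (needed for X_bar in Q3.6)
y_bar = X @ g                             # one sensitivity per observation
print(y_bar.shape)
```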
# We want sensitivity of beta_hat[0] to each y_i
# Set beta_bar = e_0 = [1, 0, 0]
beta_bar = jnp.array([1.0, 0.0, 0.0])
# modify here: compute g, r, and y_bar using the adjoint formula
# g = (X^T X)^{-1} @ beta_bar
# r = y - X @ beta_hat
# y_bar = X @ g
Question 3.4: Verify with JAX VJP
Use JAX’s built-in autodiff to verify your manual calculation:
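One possible shape for this check (a standalone sketch; note it differentiates the normal-equations form `jnp.linalg.solve(X.T @ X, X.T @ y)` instead of `lstsq` — an assumption made so the function is unambiguously differentiable, and equivalent for full-rank X; `lls_normal` is a hypothetical name):

```python
import jax.numpy as jnp
from jax import vjp, random

# Same data-generating process as Question 3.2
key = random.PRNGKey(123)
keys = random.split(key, 3)
n, p = 1000, 3
beta_true = jnp.array([2.0, -1.0, 0.5])
X = random.normal(keys[0], (n, p))
y = X @ beta_true + 0.1 * random.normal(keys[1], (n,))

def lls_normal(y):
    # normal-equations LLS; agrees with lstsq for full-rank X
    return jnp.linalg.solve(X.T @ X, X.T @ y)

beta_bar = jnp.array([1.0, 0.0, 0.0])
beta_hat_jax, vjp_fun = vjp(lls_normal, y)
y_bar_jax, = vjp_fun(beta_bar)

# manual adjoint from Question 3.1: y_bar = X (X^T X)^{-1} beta_bar
y_bar_manual = X @ jnp.linalg.solve(X.T @ X, beta_bar)
print(jnp.allclose(y_bar_jax, y_bar_manual, atol=1e-6))
```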
def lls(y):
"""LLS as a function of y only (X is fixed)"""
return jnp.linalg.lstsq(X, y)[0]
# modify here: use vjp to compute sensitivity of beta_hat[0] to y
# beta_hat_jax, vjp_fun = vjp(lls, y)
# y_bar_jax, = vjp_fun(beta_bar)
# Compare to y_bar_manual
Question 3.5: Interpret the Sensitivities
Plot the sensitivities and identify influential observations:
# modify here: create a plot showing the sensitivities
# plt.figure(figsize=(10, 4))
# plt.plot(y_bar_manual, 'o', alpha=0.5, markersize=3)
# plt.xlabel('Observation index i')
# plt.ylabel(r'$\partial \hat{\beta}_0 / \partial y_i$')
# plt.title('Sensitivity of $\\hat{\\beta}_0$ to each observation')
# plt.show()
# Find the most influential observations
# top_indices = jnp.argsort(jnp.abs(y_bar_manual))[-5:]
(double click to edit your answer)
- What determines whether an observation has high sensitivity (large \(|\partial \hat{\beta}_0 / \partial y_i|\))?
- How does this relate to the concept of “leverage” in regression diagnostics?
Question 3.6 (Extra Credit): Sensitivity to X
The full adjoint rule also gives \(\bar{X}\). Compute and interpret the sensitivity of \(\hat{\beta}_0\) to changes in \(X_{ij}\).
# modify here: compute X_bar using the full adjoint formula
# X_bar = jnp.outer(r, g) - jnp.outer(X @ g, beta_hat)  # r g^T - (X g) beta_hat^T
# (with 1-D arrays, r @ g.T would not form the outer products)
# This gives d(beta_hat[0])/d(X_ij) for all i,j
(double click to edit your answer)
What is the computational advantage of using the adjoint/VJP approach for this sensitivity analysis compared to finite differences?