Machine Learning Fundamentals for Economists
Building on the previous lecture’s direct methods, we now explore iterative approaches.
Iterative methods and matrix conditioning. In this lecture you will learn:
Key insight: Performance depends on geometry (conditioning), not just size
These methods are essential for ML: optimization, regularization, and large-scale problems
\[ \text{cond}(A) \equiv \|A\| \|A^{-1}\|\geq 1 \]
\[ \text{cond}(A) = \left|\frac{\lambda_{max}}{\lambda_{min}}\right| \]
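As a quick numerical check, here is a minimal sketch (the 2×2 symmetric matrix is an arbitrary example; for symmetric matrices the eigenvalue ratio coincides with the 2-norm condition number):
import jax.numpy as jnp
A_demo = jnp.array([[2.0, 1.0],
                    [1.0, 3.0]])  # small symmetric example
cond_builtin = jnp.linalg.cond(A_demo)  # 2-norm condition number
cond_manual = jnp.linalg.norm(A_demo, 2) * jnp.linalg.norm(jnp.linalg.inv(A_demo), 2)
eigs = jnp.linalg.eigvalsh(A_demo)
cond_eigs = jnp.abs(eigs).max() / jnp.abs(eigs).min()
print(cond_builtin, cond_manual, cond_eigs)  # all three agree for a symmetric matrix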
Computing an explicit inverse (inv) is a bad idea when \(\text{cond}(A)\) is huge.
cond(L) = 1.73e+08
cond(L.T @ L) = inf
Matrix L:
[[1.e+00 1.e+00 1.e+00]
[1.e-08 0.e+00 0.e+00]
[0.e+00 1.e-08 0.e+00]
[0.e+00 0.e+00 1.e-08]]
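A minimal sketch reproducing these numbers, assuming L is the 4×3 matrix printed above (with ε = 1e-8):
import jax.numpy as jnp
eps = 1e-8
L = jnp.array([[1.0, 1.0, 1.0],
               [eps, 0.0, 0.0],
               [0.0, eps, 0.0],
               [0.0, 0.0, eps]])
print(f"cond(L) = {jnp.linalg.cond(L):.2e}")
# Forming the normal equations squares the conditioning; in floating point
# 1 + eps**2 rounds to 1, so L.T @ L is numerically singular
print(f"cond(L.T @ L) = {jnp.linalg.cond(L.T @ L):.2e}")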
\[ \rho v = r + Q v \]
import time

import jax
import jax.numpy as jnp
from jax import random

import cola
import lineax as lx

N = 100
a = 0.1     # upward transition rate
b = 0.05    # downward transition rate
rho = 0.05  # discount rate
# Define diagonals for the tridiagonal intensity matrix Q
lower_diag = jnp.full(N-1, b)
main_diag = jnp.concatenate([jnp.array([-a]),
                             jnp.full(N-2, -(a+b)),
                             jnp.array([-b])])
upper_diag = jnp.full(N-1, a)
Q = cola.ops.Tridiagonal(lower_diag, main_diag, upper_diag)
# For a direct solve, convert Q to dense and form A = rho*I - Q
r = jnp.linspace(0.0, 10.0, N)
Q_dense = Q.to_dense()
A = rho * jnp.eye(N) - Q_dense
v_direct = jnp.linalg.solve(A, r)
print(f"Mean value: {jnp.mean(v_direct):.6f}")Mean value: 101.963066
\[ |A_{ii}| \geq \sum_{j\neq i} |A_{ij}| \quad\text{for all } i = 1\ldots N \]
i.e., in each row the sum of the absolute values of the off-diagonal elements does not exceed the absolute value of the diagonal element (strict inequality gives strict diagonal dominance)
Note that for our problem the rows of \(Q\) sum to 0, so if \(\rho > 0\) then \(\rho I - Q\) is strictly diagonally dominant.
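A quick numerical check of this claim, assuming A = rho*I - Q_dense from the block above (a minimal sketch):
# Assumes A from the CTMC block above
diag_abs = jnp.abs(jnp.diag(A))
offdiag_sum = jnp.sum(jnp.abs(A), axis=1) - diag_abs
print("Strictly diagonally dominant:", bool(jnp.all(diag_abs > offdiag_sum)))
print(jnp.min(diag_abs - offdiag_sum))  # the dominance margin is rho in every row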
\[ A \equiv D + R \]
\[ D \equiv \begin{bmatrix} A_{11} & 0 & \ldots & 0\\ 0 & A_{22} & \ldots & 0\\ \vdots & \vdots & \vdots & \vdots\\ 0 & 0 & \ldots & A_{NN} \end{bmatrix}\,\, R \equiv \begin{bmatrix} 0 & A_{12} & \ldots & A_{1N} \\ A_{21} & 0 & \ldots & A_{2N} \\ \vdots & \vdots & \vdots & \vdots\\ A_{N1} & A_{N2} & \ldots & 0 \end{bmatrix} \]
\[ \begin{aligned} D x &= b - R x\\ x &= D^{-1} (b - R x) \end{aligned} \]
where \(D^{-1}\) is trivial to compute since \(D\) is diagonal. To solve, iterate on \(x^k\), starting from an initial guess \(x^0\):
\[ x^{k+1} = D^{-1}(b - R x^k) \]
See Jacobi Implementation in appendix for code example.
# Diagonal operator
D = cola.ops.Diagonal(jnp.arange(1.0, N+1))
# Compose without forming the matrix (lazy composition)
Op = Q + D
# CoLA handles the structure automatically in solves
b_test = jnp.ones(N)
x_cola = cola.solve(Op, b_test)
# Can also use eigenvalue methods
eigenvalues, eigenvectors = cola.eig(Op, k=3, which='SM')
print(f"Three smallest eigenvalues: {eigenvalues}")Three smallest eigenvalues: [0.89475226+0.j 1.8502488 +0.j 2.850001 +0.j]
The sketch above shows Q + D as a dense matrix.

N_cg = 100
key = random.PRNGKey(0)  # PRNG key; the seed choice is arbitrary
key, subkey = random.split(key)
A_sparse = random.uniform(subkey, (N_cg, N_cg))
key, subkey = random.split(key)
# Keep roughly 10% of the entries at random; zero out the rest
A_sparse = jnp.where(random.uniform(subkey, (N_cg, N_cg)) < 0.1, A_sparse, 0.0)
A_pd = A_sparse @ A_sparse.T + 0.5 * jnp.eye(N_cg)
key, subkey = random.split(key)
b_cg = random.uniform(subkey, (N_cg,))
x_direct = jnp.linalg.solve(A_pd, b_cg)
operator = lx.MatrixLinearOperator(A_pd, tags=lx.positive_semidefinite_tag)
solver = lx.CG(rtol=1e-5, atol=1e-5, max_steps=1000)
solution = lx.linear_solve(operator, b_cg, solver)
print(f"cond(A) = {jnp.linalg.cond(A_pd):.2e}, Iterations: {solution.stats['num_steps']}, Error: {jnp.linalg.norm(solution.value - x_direct):.2e}")cond(A) = 6.32e+01, Iterations: 32, Error: 2.61e-05
# Warmup both solvers (JIT compilation happens on first call)
_ = jnp.linalg.solve(A_pd, b_cg).block_until_ready()
_ = lx.linear_solve(operator, b_cg, solver).value.block_until_ready()
# Direct solve benchmark
start = time.perf_counter()
x_direct = jnp.linalg.solve(A_pd, b_cg)
x_direct.block_until_ready() # Wait for JAX async execution
direct_time = time.perf_counter() - start
# CG solve benchmark
start = time.perf_counter()
solution = lx.linear_solve(operator, b_cg, solver)
solution.value.block_until_ready()
cg_time = time.perf_counter() - start
print(f"Direct solve: {direct_time*1000:.2f} ms")
print(f"CG solve: {cg_time*1000:.2f} ms")
print(f"Speedup: {direct_time/cg_time:.2f}x")Direct solve: 2.51 ms
CG solve: 0.45 ms
Speedup: 5.60x
Key insights:
\[ \min_{\beta} \| X \beta -y \|^2 + \alpha \| \beta\|^2 \]
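For reference, the minimizer of this ridge objective solves the regularized normal equations; with \(\alpha = 0\) it reduces to the ordinary least-squares problem solved below:
\[ (X^{\top} X + \alpha I)\,\beta = X^{\top} y \]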
M = 1000
N_lls = 10000
sigma = 0.1
key, subkey = random.split(key)
X_sparse = random.uniform(subkey, (N_lls, M))
key, subkey = random.split(key)
X_sparse = jnp.where(random.uniform(subkey, (N_lls, M)) < 0.1, X_sparse, 0.0)
key, subkey = random.split(key)
beta_true = random.uniform(subkey, (M,))
key, subkey = random.split(key)
y = X_sparse @ beta_true + sigma * random.normal(subkey, (N_lls,))
beta_direct = jnp.linalg.lstsq(X_sparse, y, rcond=None)[0]
operator = lx.MatrixLinearOperator(X_sparse)
solver = lx.NormalCG(rtol=1e-5, atol=1e-5, max_steps=1000)
solution = lx.linear_solve(operator, y, solver)
beta_normalcg = solution.value
print(f"Norm difference: {jnp.linalg.norm(beta_direct - beta_normalcg):.2e}, Iterations: {solution.stats['num_steps']}")Norm difference: 1.97e-04, Iterations: 15
# Warmup both solvers (JIT compilation happens on first call)
_ = jnp.linalg.lstsq(X_sparse, y, rcond=None)[0].block_until_ready()
_ = lx.linear_solve(operator, y, solver).value.block_until_ready()
# Benchmark direct least squares
start = time.perf_counter()
beta_direct = jnp.linalg.lstsq(X_sparse, y, rcond=None)[0]
beta_direct.block_until_ready()
direct_time = time.perf_counter() - start
# Benchmark NormalCG
start = time.perf_counter()
solution = lx.linear_solve(operator, y, solver)
solution.value.block_until_ready()
normalcg_time = time.perf_counter() - start
print(f"Direct lstsq: {direct_time*1000:.2f} ms")
print(f"NormalCG: {normalcg_time*1000:.2f} ms")
print(f"Speedup: {direct_time/normalcg_time:.2f}x")
print(f"Iterations: {solution.stats['num_steps']}")Direct lstsq: 1283.48 ms
NormalCG: 127.66 ms
Speedup: 10.05x
Iterations: 15
Trade-offs:
FunctionLinearOperator

def matvec(vec):
    return X_sparse @ vec
input_structure = jax.ShapeDtypeStruct((M,), jnp.float32)
X_op = lx.FunctionLinearOperator(matvec, input_structure)
solver = lx.NormalCG(rtol=1e-5, atol=1e-5, max_steps=1000)
solution = lx.linear_solve(X_op, y, solver)
beta_matvec = solution.value
print(f"Norm diff: {jnp.linalg.norm(beta_direct - beta_matvec):.2e}, Iterations: {solution.stats['num_steps']}")Norm diff: 7.22e-04, Iterations: 16
N_eig = 4
a = 0.1
b = 0.05
lower_diag = jnp.full(N_eig-1, b)
main_diag = jnp.concatenate([jnp.array([-a]), jnp.full(N_eig-2, -(a+b)), jnp.array([-b])])
upper_diag = jnp.full(N_eig-1, a)
Q_eig = cola.ops.Tridiagonal(lower_diag, main_diag, upper_diag)
Q_T = cola.ops.Tridiagonal(upper_diag, main_diag, lower_diag)
eigenvalues, eigenvectors = cola.eig(Q_T, k=1, which='SM')
lambda_min = eigenvalues[0].real
phi = eigenvectors[:, 0].real
phi = phi / jnp.sum(phi)
print(f"λ_min: {lambda_min:.2e}, Mean(φ): {jnp.mean(phi):.6f}, Q.T:\n{Q_T.to_dense()}")λ_min: -2.50e-01, Mean(φ): 0.250000, Q.T:
[[-0.1 0.05 0. 0. ]
[ 0.1 -0.15 0.05 0. ]
[ 0. 0.1 -0.15 0.05]
[ 0. 0. 0.1 -0.05]]
def Q_adj_product(x):
    first = -a * x[0] + b * x[1]
    middle = a * x[:-2] - (a + b) * x[1:-1] + b * x[2:]
    last = a * x[-2] - b * x[-1]
    return jnp.concatenate([jnp.array([first]), middle, jnp.array([last])])
key, subkey = random.split(key)
x_check = random.uniform(subkey, (N_eig,))
Q_dense = Q_eig.to_dense()
error = jnp.linalg.norm(Q_adj_product(x_check) - Q_dense.T @ x_check)
print(f"Matrix-free error: {error:.2e}")Matrix-free error: 3.73e-09
The FunctionLinearOperator wrapper adds the structure that the solver algorithms require.

# Wrap in a Lineax FunctionLinearOperator
input_structure = jax.ShapeDtypeStruct((N_eig,), jnp.float32)
Q_adj_op = lx.FunctionLinearOperator(Q_adj_product, input_structure)
# For eigenvalues, we can use CoLA with the dense version
# (CoLA's matrix-free operator support is limited for eigenvalue problems)
Q_dense_T = Q_dense.T
Q_cola = cola.ops.Dense(Q_dense_T)
# Find smallest eigenvalue
eigenvalues_mf, eigenvectors_mf = cola.eig(Q_cola, k=1, which='SM')
lambda_min_mf = eigenvalues_mf[0].real
phi_mf = eigenvectors_mf[:, 0].real
phi_mf = phi_mf / jnp.sum(phi_mf)
print(f"Smallest eigenvalue (matrix-free): {lambda_min_mf:.2e}")
print(f"Mean of eigenvector: {jnp.mean(phi_mf):.6f}")Smallest eigenvalue (matrix-free): -2.50e-01
Mean of eigenvector: 0.250000
\[ \begin{aligned} A x &= b\\ A P^{-1} P x &= b\\ A P^{-1} y &= b\\ P x &= y \end{aligned} \]
That is, solve \((A P^{-1})y = b\) for \(y\), and then solve \(P x = y\) for \(x\). For CG the preconditioner is usually applied symmetrically, as \(P^{-1/2} A P^{-1/2}\), so the preconditioned system stays symmetric positive definite; the diagonal (Jacobi) preconditioner below does exactly this.
N_precond = 200
key, subkey = random.split(key)
A_sparse_precond = random.uniform(subkey, (N_precond, N_precond))
key, subkey = random.split(key)
A_sparse_precond = jnp.where(random.uniform(subkey, (N_precond, N_precond)) < 0.1, A_sparse_precond, 0.0)
A_precond = A_sparse_precond @ A_sparse_precond.T + 0.5 * jnp.eye(N_precond)
key, subkey = random.split(key)
b_precond = random.uniform(subkey, (N_precond,))
operator_precond = lx.MatrixLinearOperator(A_precond, tags=lx.positive_semidefinite_tag)
solver_precond = lx.CG(rtol=1e-6, atol=1e-6, max_steps=1000)
solution_no_precond = lx.linear_solve(operator_precond, b_precond, solver_precond)
print(f"cond(A) = {jnp.linalg.cond(A_precond):.2e}, Iterations: {solution_no_precond.stats['num_steps']}")cond(A) = 2.26e+02, Iterations: 59
D_inv_sqrt = 1.0 / jnp.sqrt(jnp.diag(A_precond))
P_diag = jnp.diag(D_inv_sqrt)
A_precond_system = P_diag @ A_precond @ P_diag
b_precond_system = P_diag @ b_precond
operator_precond_system = lx.MatrixLinearOperator(A_precond_system, tags=lx.positive_semidefinite_tag)
solution_precond = lx.linear_solve(operator_precond_system, b_precond_system, solver_precond)
x_precond = P_diag @ solution_precond.value

Iterations (with diagonal preconditioner): 57
Reduction: 3.4%
Error: 1.55e-05
print(f"Without preconditioner: {solution_no_precond.stats['num_steps']} iterations")
print(f"With diagonal precond: {solution_precond.stats['num_steps']} iterations")
reduction_pct = (1 - solution_precond.stats['num_steps']/solution_no_precond.stats['num_steps'])*100
print(f"Reduction: {reduction_pct:.1f}%")
print(f"\nCondition numbers:")
print(f"cond(A): {jnp.linalg.cond(A_precond):.2e}")
print(f"cond(P A P): {jnp.linalg.cond(A_precond_system):.2e}")Without preconditioner: 59 iterations
With diagonal precond: 57 iterations
Reduction: 3.4%
Condition numbers:
cond(A): 2.26e+02
cond(P A P): 2.26e+02
When preconditioning helps:
For sparse matrices, JAX has experimental support (jax.experimental.sparse). Sources: JAX GitHub Discussion #18452, JAX Sparse Documentation.
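A minimal sketch of that experimental interface (BCOO format; the small random matrix here is purely illustrative):
import jax.numpy as jnp
from jax import random
from jax.experimental import sparse

key_demo = random.PRNGKey(0)  # seed is arbitrary
A_demo = jnp.where(random.uniform(key_demo, (5, 5)) < 0.3, 1.0, 0.0)
A_bcoo = sparse.BCOO.fromdense(A_demo)  # store only the nonzero entries
v_demo = jnp.ones(5)
print(A_bcoo @ v_demo)  # sparse matrix-vector product
print(A_bcoo.nse)       # number of stored entries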
See Julia's IncompleteLU.jl for comparison.
SOR (Successive Over-Relaxation) is more complex to implement functionally in JAX because it requires sequential updates; use Krylov methods instead for better performance.
Educational implementation of Jacobi iteration using jax.lax.scan:
# Use the CTMC example from earlier
A_jacobi = A # From the CTMC example
b_jacobi = r
v_direct_jacobi = v_direct
# Jacobi iteration: x^{k+1} = D^{-1}(b - R x^k)
def jacobi_step(x, _):
    """Single Jacobi iteration step (functional, no mutation)"""
    D_inv = 1.0 / jnp.diag(A_jacobi)
    R = A_jacobi - jnp.diag(jnp.diag(A_jacobi))
    x_new = D_inv * (b_jacobi - R @ x)
    return x_new, None
# Run 40 iterations using scan
x0 = jnp.zeros(N)
x_jacobi, _ = jax.lax.scan(jacobi_step, x0, None, length=40)
error_jacobi = jnp.linalg.norm(x_jacobi - v_direct_jacobi, ord=jnp.inf)
print(f"Error after 40 iterations: {error_jacobi:.2e}")Error after 40 iterations: 1.77e-03