Machine Learning Fundamentals for Economists
In preparation for the ML lectures, we cover core numerical linear algebra concepts, focusing on direct methods and matrix factorizations.
These methods solve problems to machine precision, but their cost grows rapidly with matrix size.
In the next lecture, we'll see iterative methods that trade precision for scalability.
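The examples in this lecture assume a setup cell along these lines (the actual preamble isn't shown here, so treat this as a sketch):
import time
import numpy as np
import scipy.sparse as sp
import scipy.sparse.linalg as spla
import jax
import jax.numpy as jnp
import jax.scipy.linalg as jla
import lineax as lx
key = jax.random.PRNGKey(42)  # PRNG key threaded through the examples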
Big-O Notation
For functions \(f(N)\) and \(g(N)\), we say \(f(N)\) is \(O(g(N))\) if there exist positive constants \(C\) and \(N_0\) such that:
\[ 0 \leq f(N) \leq C \cdot g(N) \quad \text{for all } N \geq N_0 \]
Ask yourself how expensive operations like the following become as the matrix size increases.
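For example, a dense matrix-matrix multiply costs O(N^3) floating-point operations. A sketch of how you might time it (not from the original lecture):
for N in (500, 1000, 2000):
    key, subkey = jax.random.split(key)
    X = jax.random.uniform(subkey, (N, N))
    _ = (X @ X).block_until_ready()  # warmup (JIT compilation)
    start = time.perf_counter()
    (X @ X).block_until_ready()
    print(f"N = {N}: {(time.perf_counter() - start)*1000:.2f} ms")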
Machine Epsilon
For a given floating-point datatype, machine epsilon is defined as \(\epsilon = \min_{\delta > 0} \left\{ \delta : 1 + \delta > 1 \right\}\), where the comparison is evaluated in floating-point arithmetic.
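The values below can be reproduced with np.finfo (a sketch; the original cell isn't shown):
eps64 = np.finfo(np.float64).eps
print(f"machine epsilon for float64 = {eps64}")
print(f"1 + eps/2 == 1? {1.0 + eps64 / 2 == 1.0}")
print(f"machine epsilon for float32 = {float(np.finfo(np.float32).eps)}")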
machine epsilon for float64 = 2.220446049250313e-16
1 + eps/2 == 1? True
machine epsilon for float32 = 1.1920928955078125e-07
# Create sparse random matrix using scipy
np.random.seed(42)
A_sp = sp.random(10, 10, density=0.45, format='csr')
print(f"Non-zeros in A: {A_sp.nnz}")
# Invert (must convert to dense)
A_dense = A_sp.toarray()
invA_dense = jnp.linalg.inv(A_dense)
# Count non-zeros (threshold for numerical zeros)
invA_nnz = jnp.sum(jnp.abs(invA_dense) > 1e-10)
print(f"Non-zeros in inv(A): {invA_nnz}")Non-zeros in A: 45
Non-zeros in inv(A): 100
N = 5
# Create tridiagonal matrix
lower = jnp.concatenate([jnp.full(N-2, 0.1), jnp.array([0.2])])
diag = jnp.full(N, 0.8)
upper = jnp.concatenate([jnp.array([0.2]), jnp.full(N-2, 0.1)])
# Build full matrix for inversion
A_tri = jnp.diag(diag) + jnp.diag(lower, -1) + jnp.diag(upper, 1)
print("Inverse of tridiagonal (all elements non-zero):")
print(jnp.linalg.inv(A_tri))

Inverse of tridiagonal (all elements non-zero):
[[ 1.2909946e+00 -3.2795697e-01 4.1666660e-02 -5.3763431e-03
6.7204301e-04]
[-1.6397849e-01 1.3118279e+00 -1.6666664e-01 2.1505373e-02
-2.6881720e-03]
[ 2.0833330e-02 -1.6666664e-01 1.2916665e+00 -1.6666664e-01
2.0833334e-02]
[-2.6881718e-03 2.1505374e-02 -1.6666666e-01 1.3118279e+00
-1.6397850e-01]
[ 6.7204307e-04 -5.3763445e-03 4.1666668e-02 -3.2795700e-01
1.2909946e+00]]
N = 1000
key, subkey = jax.random.split(key)
b = jax.random.uniform(subkey, (N,))
# Create tridiagonal matrix
lower_diag = jnp.concatenate([jnp.full(N-2, 0.1), jnp.array([0.2])])
main_diag = jnp.full(N, 0.8)
upper_diag = jnp.concatenate([jnp.array([0.2]), jnp.full(N-2, 0.1)])
# Lineax tridiagonal operator (uses parallel scan, O(N))
A_tri_op = lx.TridiagonalLinearOperator(main_diag, lower_diag, upper_diag)
# Dense matrix for comparison
A_dense = jnp.diag(main_diag) + jnp.diag(lower_diag, -1) + jnp.diag(upper_diag, 1)
# Warmup both solvers (JIT compilation happens on first call)
_ = lx.linear_solve(A_tri_op, b).value.block_until_ready()
_ = jnp.linalg.solve(A_dense, b).block_until_ready()
start = time.perf_counter()
x_tri = lx.linear_solve(A_tri_op, b).value
x_tri.block_until_ready()
tri_time = time.perf_counter() - start
# Dense solve (O(N^3))
start = time.perf_counter()
x_dense = jnp.linalg.solve(A_dense, b)
x_dense.block_until_ready()
dense_time = time.perf_counter() - start
print(f"Tridiagonal (Lineax): {tri_time*1000:.2f} ms")
print(f"Dense (JAX): {dense_time*1000:.2f} ms")
print(f"Speedup: {dense_time/tri_time:.1f}x")Tridiagonal (Lineax): 2.01 ms
Dense (JAX): 14.95 ms
Speedup: 7.4x
Key insight: Lineax uses a direct parallel-scan solver (O(N)), which is much faster than the dense solve (O(N^3)).
\[ \begin{aligned} U x &= b\\ U &\equiv \begin{bmatrix} 3 & 1 \\ 0 & 2 \\ \end{bmatrix}, \quad b = \begin{bmatrix} 7 \\ 2 \\ \end{bmatrix} \end{aligned} \]
Solving bottom row for \(x_2\)
\[ 2 x_2 = 2,\quad x_2 = 1 \]
Move up a row, solving for \(x_1\), substituting for \(x_2\)
\[ 3 x_1 + 1 x_2 = 7,\quad 3 x_1 + 1 \times 1 = 7,\quad x_1 = 2 \]
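These two steps are exactly what jla.solve_triangular performs; a quick check on this example (a sketch):
U_ex = jnp.array([[3.0, 1.0], [0.0, 2.0]])
b_ex = jnp.array([7.0, 2.0])
print(jla.solve_triangular(U_ex, b_ex, lower=False))  # [2. 1.]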
This generalizes to any number of rows. For lower-triangular \(L\), the analogous procedure is "forward substitution".
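The cell that factored A isn't shown above; a sketch consistent with the PA = LU convention used below (the 4x4 size is inferred from the solution printed later):
key, subkey = jax.random.split(key)
A = jax.random.uniform(subkey, (4, 4))
key, subkey = jax.random.split(key)
b_lu = jax.random.uniform(subkey, (4,))
# jla.lu returns (P, L, U) with A = P L U; transposing P gives P A = L U
P, L, U = jla.lu(A)
P = P.T
print(f"P @ A ≈ L @ U? {jnp.allclose(P @ A, L @ U)}")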
P @ A ≈ L @ U? True
# Manual solve using LU: solve L(Ux) = Pb
y = jla.solve_triangular(L, P @ b_lu, lower=True)
x_lu = jla.solve_triangular(U, y, lower=False)
# Direct solve for comparison
x_direct = jnp.linalg.solve(A, b_lu)
print(f"LU solution: {x_lu}")
print(f"Direct solution: {x_direct}")
print(f"Solutions match? {jnp.allclose(x_lu, x_direct)}")LU solution: [-2.7407506 0.50105435 2.7111826 1.1074696 ]
Direct solution: [-2.7407506 0.50105435 2.7111826 1.1074696 ]
Solutions match? True
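The positive definite matrix A_pd and its size N_chol come from a cell that isn't shown; one way to construct them (a sketch, with names taken from the code below):
N_chol = 1000
key, subkey = jax.random.split(key)
X = jax.random.normal(subkey, (N_chol, N_chol))
A_pd = X @ X.T + N_chol * jnp.eye(N_chol)  # symmetric positive definite by construction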
key, subkey = jax.random.split(key)
b_chol = jax.random.uniform(subkey, (N_chol,))
# Cholesky factorization
L_chol = jla.cholesky(A_pd, lower=True)
# Solve using Cholesky
y_chol = jla.solve_triangular(L_chol, b_chol, lower=True)
x_chol = jla.solve_triangular(L_chol.T, y_chol, lower=False)
# Warmup both solvers (JIT compilation happens on first call)
_ = jnp.linalg.solve(A_pd, b_chol).block_until_ready()
_ = jla.cholesky(A_pd, lower=True).block_until_ready()
_ = jla.solve_triangular(L_chol, b_chol, lower=True).block_until_ready()
_ = jla.solve_triangular(L_chol.T, y_chol, lower=False).block_until_ready()
# Direct solve (doesn't know it's positive definite)
start = time.perf_counter()
x_direct = jnp.linalg.solve(A_pd, b_chol)
x_direct.block_until_ready()
direct_time = time.perf_counter() - start
# Cholesky solve
start = time.perf_counter()
L_chol = jla.cholesky(A_pd, lower=True)
y_chol = jla.solve_triangular(L_chol, b_chol, lower=True)
x_chol = jla.solve_triangular(L_chol.T, y_chol, lower=False)
x_chol.block_until_ready()
chol_time = time.perf_counter() - start
print(f"Direct solve: {direct_time*1000:.2f} ms")
print(f"Cholesky solve: {chol_time*1000:.2f} ms")
print(f"Speedup: {direct_time/chol_time:.1f}x")Direct solve: 3.02 ms
Cholesky solve: 2.63 ms
Speedup: 1.1x
\[ A = Q \Lambda Q^{-1} \]
For symmetric \(A\), the eigenvectors can be chosen orthonormal, so \(Q^{-1} = Q^T\).
key, subkey = jax.random.split(key)
A_sym = jax.random.uniform(subkey, (5, 5))
A_sym = (A_sym + A_sym.T) / 2 # Make symmetric
# Eigenvalue decomposition
eigenvalues, Q = jnp.linalg.eigh(A_sym) # eigh for symmetric/Hermitian
Lambda = jnp.diag(eigenvalues)
print(f"||Q Λ Q^-1 - A||: {jnp.linalg.norm(Q @ Lambda @ jnp.linalg.inv(Q) - A_sym):.2e}")
print(f"||Q Λ Q^T - A||: {jnp.linalg.norm(Q @ Lambda @ Q.T - A_sym):.2e}")||Q Λ Q^-1 - A||: 7.66e-07
||Q Λ Q^T - A||: 6.90e-07
# Create sparse system
N_sparse = 1000
A_sparse = sp.random(N_sparse, N_sparse, density=0.01, format='csr')
A_sparse = A_sparse + sp.eye(N_sparse) * 10 # Make diagonally dominant
b_sparse = np.random.rand(N_sparse)
# Solve using SciPy's sparse solver (uses UMFPACK/SuperLU)
x_sparse = spla.spsolve(A_sparse, b_sparse)
print(f"Solved sparse system of size {N_sparse}x{N_sparse} with {A_sparse.nnz} non-zeros")
print(f"Residual: {np.linalg.norm(A_sparse @ x_sparse - b_sparse):.2e}")Solved sparse system of size 1000x1000 with 10987 non-zeros
Residual: 1.71e-14
Note: For production sparse linear solves, use SciPy or interface with PETSc, not JAX
\[ \mathbb P \{ X(t + \Delta) = j \,|\, X(t) = i \} = \begin{cases} q_{ij} \Delta + o(\Delta) & i \neq j\\ 1 + q_{ii} \Delta + o(\Delta) & i = j \end{cases} \]
\[ Q = \begin{bmatrix} -0.1 & 0.1 & 0 & 0 & 0 & 0\\ 0.1 &-0.2 & 0.1 & 0 & 0 & 0\\ 0 & 0.1 & -0.2 & 0.1 & 0 & 0\\ 0 & 0 & 0.1 & -0.2 & 0.1 & 0\\ 0 & 0 & 0 & 0.1 & -0.2 & 0.1\\ 0 & 0 & 0 & 0 & 0.1 & -0.1\\ \end{bmatrix} \]
\[ \frac{d}{dt} \pi(t) = \pi(t) Q,\quad \text{ given }\pi(0) \]
alpha = 0.1
N_ctmc = 6
# Create tridiagonal Q matrix
lower_ctmc = jnp.full(N_ctmc-1, alpha)
main_ctmc = jnp.concatenate([jnp.array([-alpha]),
                             jnp.full(N_ctmc-2, -2*alpha),
                             jnp.array([-alpha])])
upper_ctmc = jnp.full(N_ctmc-1, alpha)
# Build dense matrix for display
Q = jnp.diag(main_ctmc) + jnp.diag(lower_ctmc, -1) + jnp.diag(upper_ctmc, 1)
print("Q matrix:")
print(Q)

Q matrix:
[[-0.1 0.1 0. 0. 0. 0. ]
[ 0.1 -0.2 0.1 0. 0. 0. ]
[ 0. 0.1 -0.2 0.1 0. 0. ]
[ 0. 0. 0.1 -0.2 0.1 0. ]
[ 0. 0. 0. 0.1 -0.2 0.1]
[ 0. 0. 0. 0. 0.1 -0.1]]
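As an aside, the KFE has the closed-form solution \(\pi(t) = \pi(0) e^{t Q}\), which we could evaluate directly with the matrix exponential (a sketch, not from the original):
pi_0 = jnp.zeros(N_ctmc).at[0].set(1.0)  # start in the first state
pi_t = pi_0 @ jla.expm(10.0 * Q)  # distribution at t = 10
print(pi_t)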
# Eigenvalue decomposition of Q^T
eigenvalues, eigenvectors = jnp.linalg.eig(Q.T)
# Find eigenvector corresponding to eigenvalue ≈ 0
idx = jnp.argmin(jnp.abs(eigenvalues))
pi_stationary = eigenvectors[:, idx].real
pi_stationary = pi_stationary / jnp.sum(pi_stationary)
print(f"Eigenvalues:\n{eigenvalues.real}")
print(f"\nStationary distribution:")
print(pi_stationary)

Eigenvalues:
[-3.7320518e-01 -3.0000022e-01 -2.0000000e-01 -9.9999972e-02
-1.0465228e-08 -2.6794920e-02]
Stationary distribution:
[0.16666669 0.16666669 0.16666669 0.16666666 0.16666664 0.16666664]
\[ \rho v = r + Q v \]
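The value function below solves the linear system \((\rho I - Q) v = r\) directly; the original cell isn't shown, but the printed values are reproduced by, for example, \(\rho = 0.05\) and \(r_i = 2i\) (a sketch):
rho = 0.05
r = 2.0 * jnp.arange(N_ctmc)
# Solve (rho I - Q) v = r
v = jnp.linalg.solve(rho * jnp.eye(N_ctmc) - Q, r)
print("Value function:")
print(v)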
Value function:
[ 38.153847 57.230774 84.92308 115.076935 142.76924 161.84616 ]
Teaser: Can we use iterative methods to avoid forming the full matrix? See next lecture!