Machine Learning Fundamentals for Economists
import torch
import torch.nn.functional as F

C, d, n = 3, 4, 3  # vocab size, embedding dim, sequence length
# One-hot encoding → embedding
token_ids = torch.tensor([0, 2, 1])
one_hot = F.one_hot(token_ids, num_classes=C).float()
W_E = torch.randn(C, d)
X = one_hot @ W_E # (n, d) embedded tokens
# Learned projection matrices
W_Q = torch.randn(d, d)
W_K = torch.randn(d, d)
W_V = torch.randn(d, d)
Q = X @ W_Q # broadcasts to all tokens
K = X @ W_K
V = X @ W_V
# Scaled dot-product attention (manual)
scores = (Q @ K.T) / d**0.5
weights = torch.softmax(scores, dim=-1)
output = weights @ V
print("Attention weights (each row sums to 1):")
print(weights.detach().numpy().round(3))
print("Output:")
print(output.detach().numpy().round(3))
Attention weights (each row sums to 1):
[[0.283 0.313 0.404]
[0.142 0.172 0.687]
[0.756 0.242 0.001]]
Output:
[[-1.087 1.036 -1.564 0.502]
[-1.771 1.595 -2.899 1.01 ]
[-0.184 0.118 0.385 -0.133]]
A single attention pattern captures only one sense of similarity between queries, keys, and values.
A powerful approach is to run multiple self-attention heads \(j=1,\ldots,J\) in parallel, each with its own projections \[ Z^j = \text{Attn}(X W_Q^j, X W_K^j, X W_V^j) \]
Then concatenate the \(Z^j\)’s along the feature dimension and project with a learned output matrix \(W_O\)
\[ \text{MultiHead}(X) = \text{Concat}(Z^1, \ldots, Z^J) \cdot W_O \]
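The concatenate-and-project step can be sketched by hand (a minimal sketch: the per-head dimension d/J and the random matrices are illustrative stand-ins for learned parameters):

```python
import torch

n, d, J = 3, 4, 2          # tokens, model dim, number of heads
d_h = d // J               # per-head dimension

X = torch.randn(n, d)
# One (W_Q, W_K, W_V) triple per head, plus the shared output projection W_O
W_Q = [torch.randn(d, d_h) for _ in range(J)]
W_K = [torch.randn(d, d_h) for _ in range(J)]
W_V = [torch.randn(d, d_h) for _ in range(J)]
W_O = torch.randn(d, d)

def attn(Q, K, V):
    w = torch.softmax((Q @ K.T) / Q.shape[-1]**0.5, dim=-1)
    return w @ V

Z = [attn(X @ W_Q[j], X @ W_K[j], X @ W_V[j]) for j in range(J)]  # each (n, d_h)
multi_head = torch.cat(Z, dim=-1) @ W_O                           # (n, d)
print(multi_head.shape)  # torch.Size([3, 4])
```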
nn.MultiheadAttention includes the output projection \(W_O\) (out_proj).

import torch.nn as nn

# X persists from earlier: (n=3, d=4) embedded tokens
mha = nn.MultiheadAttention(embed_dim=d, num_heads=2, batch_first=True)
# Self-attention: all heads run in parallel on the same X
mha_output, mha_weights = mha(
X.unsqueeze(0), X.unsqueeze(0), X.unsqueeze(0)
)
print("Attention weights (2 heads, averaged):")
print(mha_weights[0].detach().numpy().round(3))
print("\nOutput (n×d) — one row per token:")
print(mha_output[0].detach().numpy().round(3))
Attention weights (2 heads, averaged):
[[0.328 0.347 0.324]
[0.325 0.383 0.292]
[0.398 0.415 0.187]]
Output (n×d) — one row per token:
[[-0.311 0.217 -0.162 -0.223]
[-0.298 0.204 -0.149 -0.222]
[-0.381 0.261 -0.098 -0.273]]
# NNX wraps jax.nn.dot_product_attention with learned projections
import jax.numpy as jnp
from flax import nnx

mha_nnx = nnx.MultiHeadAttention(
num_heads=2, in_features=d, decode=False, rngs=nnx.Rngs(42)
)
# Self-attention: pass X once (key/value default to query)
X_jax = jnp.array(X.detach().numpy())
output_nnx = mha_nnx(X_jax[None, :, :]) # add batch dim
print("Output shape:", output_nnx.shape)
print(output_nnx[0].round(3))
Output shape: (1, 3, 4)
[[-0.45400003 -1.388 -0.75000006 -0.39100003]
[-0.45400003 -1.393 -0.734 -0.41900003]
[-0.393 -1.1960001 -0.698 -0.18100001]]
In a standard transformer block, attention is combined with residual connections, layer normalization, and a position-wise feed-forward network (FFN).
\[ \text{LayerNorm}(z) = \gamma \odot \frac{z - \mu_z}{\sqrt{\sigma_z^2 + \epsilon}} + \beta \]
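The formula can be checked against nn.LayerNorm (a minimal sketch; at initialization \(\gamma\) is all ones and \(\beta\) all zeros, so only the normalization term remains):

```python
import torch
import torch.nn as nn

z = torch.randn(3, 4)
ln = nn.LayerNorm(4)       # gamma = 1, beta = 0 at initialization

mu = z.mean(dim=-1, keepdim=True)
var = z.var(dim=-1, keepdim=True, unbiased=False)  # population variance, as in the formula
manual = (z - mu) / torch.sqrt(var + ln.eps)

assert torch.allclose(manual, ln(z), atol=1e-5)
```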

\[ \begin{aligned} z' &= \text{LayerNorm}\!\bigl(x + \text{MultiHead}(x)\bigr) \\ z_{\text{out}} &= \text{LayerNorm}\!\bigl(z' + \text{FFN}(z')\bigr) \end{aligned} \]
Data flow (one block):
Input → Multi-Head Attention → Add residual → LayerNorm → FFN → Add residual → LayerNorm → Output
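The data flow above can be sketched as a module (a minimal sketch, not a production implementation; the 4×d FFN hidden size is a common convention assumed here, not fixed by the equations):

```python
import torch
import torch.nn as nn

class TransformerBlock(nn.Module):
    def __init__(self, d, num_heads):
        super().__init__()
        self.mha = nn.MultiheadAttention(d, num_heads, batch_first=True)
        self.ffn = nn.Sequential(nn.Linear(d, 4 * d), nn.ReLU(), nn.Linear(4 * d, d))
        self.ln1 = nn.LayerNorm(d)
        self.ln2 = nn.LayerNorm(d)

    def forward(self, x):
        attn_out, _ = self.mha(x, x, x)   # self-attention
        z = self.ln1(x + attn_out)        # add residual, then LayerNorm
        return self.ln2(z + self.ffn(z))  # FFN sub-layer with its own residual

block = TransformerBlock(d=4, num_heads=2)
x = torch.randn(1, 3, 4)                  # (batch, tokens, features)
print(block(x).shape)  # torch.Size([1, 3, 4])
```

Note the block preserves the input shape, which is what lets transformer blocks be stacked.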


[MASK] fill-in task and the word embeddings from the Embeddings lecture
[CLS] token whose final representation is used for classification