Topic 7

Scaled dot-product attention as the algebraic heart of Transformer blocks

Equation (1) in the paper deserves flash-card status: softmax(QKᵀ / √d_k) V cleanly splits responsibilities between similarity scoring (Q,K) and content payload (V).
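Written out for a single query row, Equation (1) makes that split visible: the query-key dot products only set the weights α_ij, and the values only supply what gets mixed. The row notation q_i, k_j, v_j (rows of Q, K, V) and α_ij is introduced here for illustration; the sums run over the n keys.

```latex
\mathrm{Attention}(Q, K, V) = \mathrm{softmax}\!\left(\frac{Q K^{\top}}{\sqrt{d_k}}\right) V,
\qquad
\alpha_{ij} = \frac{\exp\!\left(q_i \cdot k_j / \sqrt{d_k}\right)}{\sum_{l=1}^{n} \exp\!\left(q_i \cdot k_l / \sqrt{d_k}\right)},
\qquad
\mathrm{out}_i = \sum_{j=1}^{n} \alpha_{ij}\, v_j .
```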

Math & statistics used here

  • Matrix form: softmax(Q Kᵀ / √d_k) V. Each softmax row (one query against all keys) is normalised independently; dividing by √d_k keeps the variance of the dot products from growing with d_k.
  • Masking replaces forbidden logits with −∞ before softmax so corresponding weights hit exactly zero.
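  • Soft lookup, not an associative law: softmax(Q Kᵀ) V parallels a soft database join in which similarities pick rows and values are the retrieved fields. Because the softmax sits between Q Kᵀ and V, the product cannot be regrouped as Q (Kᵀ V).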
  • Backpropagation through the softmax-then-V chain is textbook; watch numerical stability when the softmax saturates, because near-one-hot rows pass back very small gradients.
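  • Position-wise feed-forward network in the Vaswani et al. form: FFN(x) = max(0, x W_1 + b_1) W_2 + b_2, applied to each row independently. Modern stacks swap ReLU for GELU, which keeps the tensor shapes identical, or for SwiGLU, which adds a third (gating) weight matrix and usually shrinks the hidden width to keep the parameter count comparable.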
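The variance claim in the first bullet follows from a short calculation, under the simplifying assumption (the same one the paper uses to motivate the scaling) that the components q_r and k_r of a query and a key are independent with mean 0 and variance 1:

```latex
\mathbb{E}[q \cdot k] = \sum_{r=1}^{d_k} \mathbb{E}[q_r]\,\mathbb{E}[k_r] = 0,
\qquad
\operatorname{Var}(q \cdot k) = \sum_{r=1}^{d_k} \operatorname{Var}(q_r k_r)
  = \sum_{r=1}^{d_k} \mathbb{E}[q_r^2]\,\mathbb{E}[k_r^2] = d_k,
\qquad
\operatorname{Var}\!\left(\frac{q \cdot k}{\sqrt{d_k}}\right) = 1 .
```

Without the division, typical raw logits have a standard deviation of about √d_k (8 for d_k = 64), which is large enough to push softmax rows toward one-hot vectors with small gradients.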

Checklist you can map to code

  • Queries state what each position is looking for; keys advertise content addresses; values carry the payloads that the attention weights mix.
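  • Scaling by √d_k keeps the variance of q·k roughly O(1) as the key width d_k grows (with independent unit-variance components the raw dot product has variance d_k), which keeps the softmax out of its saturated, small-gradient regime.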
  • Masking inserts −∞ logits before softmax to zero undesirable connections (padding or future tokens).
  • Because the QKV projections are batched GEMMs, they run on the same optimised BLAS kernels as the feed-forward matrices.
  • Gradient flow through softmax benefits from logits staying in a moderate dynamic range; the 1/√d_k factor acts like a fixed softmax temperature, mirroring the temperature tricks used in knowledge distillation.
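  • A full Transformer layer alternates a multi-head self-attention (MHSA) sublayer with a position-wise dense FFN (two affine maps with an activation between them); decoder layers add a cross-attention sublayer between the two. Each sublayer is wrapped in a residual connection followed by layer normalisation, and depth comes from stacking N copies of that pattern (N = 6 in the base model).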
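A minimal pure-Python sketch of one position-wise FFN sublayer, assuming the paper's post-norm arrangement LayerNorm(x + Sublayer(x)) and omitting dropout and LayerNorm's learned gain and bias. The helper names (affine, ffn, layer_norm, ffn_sublayer) and the tiny weights are hypothetical, chosen only so the arithmetic is easy to check by hand.

```python
import math
from typing import List

Vector = List[float]
Matrix = List[List[float]]


def affine(x: Vector, w: Matrix, b: Vector) -> Vector:
    """Compute x W + b for a single row x (w has shape len(x) x len(b))."""
    return [sum(x[i] * w[i][j] for i in range(len(x))) + b[j] for j in range(len(b))]


def ffn(x: Vector, w1: Matrix, b1: Vector, w2: Matrix, b2: Vector) -> Vector:
    """Position-wise FFN in the ReLU form: max(0, x W1 + b1) W2 + b2."""
    hidden = [max(0.0, h) for h in affine(x, w1, b1)]
    return affine(hidden, w2, b2)


def layer_norm(x: Vector, eps: float = 1e-5) -> Vector:
    """Normalise one row to zero mean and unit variance (no learned gain/bias here)."""
    mean = sum(x) / len(x)
    var = sum((v - mean) ** 2 for v in x) / len(x)
    return [(v - mean) / math.sqrt(var + eps) for v in x]


def ffn_sublayer(rows: Matrix, w1: Matrix, b1: Vector, w2: Matrix, b2: Vector) -> Matrix:
    """Residual + post-norm wrapper, LayerNorm(x + FFN(x)), applied to each row independently."""
    out = []
    for x in rows:
        y = ffn(x, w1, b1, w2, b2)
        out.append(layer_norm([a + c for a, c in zip(x, y)]))
    return out


# Tiny hypothetical example: d_model = 2, d_ff = 3.
w1 = [[1.0, -1.0, 0.5], [0.0, 2.0, -1.0]]  # d_model -> d_ff
b1 = [0.0, 0.0, 0.0]
w2 = [[1.0, 0.0], [0.0, 1.0], [1.0, 1.0]]  # d_ff -> d_model
b2 = [0.0, 0.0]
rows = [[1.0, 2.0], [3.0, -1.0]]
out = ffn_sublayer(rows, w1, b1, w2, b2)
```

Because each row is processed on its own, running the sublayer on one row or on the whole batch gives the same result for that row; the attention sublayer is the only place where positions exchange information.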

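Implementation sketch for one head: take hidden states X ∈ ℝ^{n × d_model} and multiply by learnt matrices W_Q, W_K, W_V to get Q and K of shape n × d_k and V of shape n × d_v (or reshaped multi-head variants; the paper uses d_k = d_v = d_model / h = 64). Multiply Q Kᵀ, scale by 1/√d_k, apply the mask, take the softmax along the keys, and multiply by V. The heads' outputs are then concatenated and projected with W_O. Libraries such as FlashAttention fuse everything from Q Kᵀ through the multiplication by V into a single kernel when the shapes and dtypes are supported; the W_O projection remains a separate GEMM.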
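A minimal pure-Python sketch of the one-head recipe above, assuming a causal (decoder-style) mask and plain nested lists instead of a tensor library. The function and argument names are hypothetical; real code would use batched tensor operations or a fused kernel rather than these loops.

```python
import math
from typing import List, Tuple

Matrix = List[List[float]]


def matmul(a: Matrix, b: Matrix) -> Matrix:
    """Plain (rows x inner) @ (inner x cols) matrix product."""
    return [[sum(a[i][t] * b[t][j] for t in range(len(b))) for j in range(len(b[0]))]
            for i in range(len(a))]


def transpose(a: Matrix) -> Matrix:
    return [list(col) for col in zip(*a)]


def softmax_row(row: List[float]) -> List[float]:
    """Numerically stable softmax; -inf entries map to exactly 0.0."""
    m = max(row)  # assumes at least one finite entry per row
    exps = [math.exp(v - m) for v in row]
    total = sum(exps)
    return [e / total for e in exps]


def single_head_attention(x: Matrix, w_q: Matrix, w_k: Matrix, w_v: Matrix,
                          causal: bool = True) -> Tuple[Matrix, Matrix]:
    """Project X to Q, K, V, then compute softmax(mask(Q K^T / sqrt(d_k))) V.

    Returns (output, weights). W_O is not applied here: it acts on the
    concatenation of all heads, not on a single head.
    """
    q, k, v = matmul(x, w_q), matmul(x, w_k), matmul(x, w_v)
    d_k = len(w_q[0])
    scores = matmul(q, transpose(k))
    n = len(x)
    for i in range(n):
        for j in range(n):
            scores[i][j] /= math.sqrt(d_k)
            if causal and j > i:  # future position: forbidden
                scores[i][j] = -math.inf
    weights = [softmax_row(r) for r in scores]  # softmax along keys, row by row
    return matmul(weights, v), weights


# Tiny hypothetical example: 3 tokens, d_model = d_k = d_v = 2.
x = [[1.0, 0.0], [0.0, 1.0], [1.0, 1.0]]
w_q = [[1.0, 0.0], [0.0, 1.0]]
w_k = [[0.5, 0.0], [0.0, 0.5]]
w_v = [[1.0, 2.0], [3.0, 4.0]]
out, weights = single_head_attention(x, w_q, w_k, w_v)
```

The first token can only see itself, so its weight row is exactly [1.0, 0.0, 0.0] and its output equals its own value vector.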

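Why separate queries and keys? Because that asymmetry permits encoder–decoder cross-attention, where decoder queries inspect encoder memories that never served as queries themselves. Weight-sharing decisions differ per architecture variant. In pure encoder self-attention the asymmetry of sources disappears, since Q and K are both projected from the same X, but separate W_Q and W_K still make the score matrix X W_Q W_Kᵀ Xᵀ asymmetric, so token i can attend strongly to token j without j attending equally to i; decoding stacks restore the source asymmetry through cross-attention.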
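A small numeric check of the score asymmetry, assuming self-attention where Q and K are both projected from the same X: with separate W_Q and W_K the unscaled score matrix (X W_Q)(X W_K)ᵀ need not be symmetric, whereas tying W_Q = W_K always makes it symmetric. The projection values are arbitrary and hypothetical.

```python
from typing import List

Matrix = List[List[float]]


def matmul(a: Matrix, b: Matrix) -> Matrix:
    return [[sum(a[i][t] * b[t][j] for t in range(len(b))) for j in range(len(b[0]))]
            for i in range(len(a))]


def transpose(a: Matrix) -> Matrix:
    return [list(col) for col in zip(*a)]


def raw_scores(x: Matrix, w_q: Matrix, w_k: Matrix) -> Matrix:
    """Unscaled self-attention logits S = (X W_Q)(X W_K)^T."""
    return matmul(matmul(x, w_q), transpose(matmul(x, w_k)))


def is_symmetric(s: Matrix) -> bool:
    return all(s[i][j] == s[j][i] for i in range(len(s)) for j in range(len(s)))


x = [[1.0, 0.0], [0.0, 1.0]]    # two tokens, d_model = 2
w_q = [[1.0, 2.0], [0.0, 1.0]]  # hypothetical projections
w_k = [[1.0, 0.0], [0.0, 1.0]]

separate = raw_scores(x, w_q, w_k)  # W_Q != W_K
tied = raw_scores(x, w_q, w_q)      # W_Q == W_K
```

Here token 0 scores token 1 at 2.0 while token 1 scores token 0 at 0.0, so attention from one position to another is not forced to be reciprocated.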

Softmax masking is paramount for causal decoders. Before fused flash kernels, practitioners learned to write `-1e9` into forbidden entries; now specialised kernels take the mask (or a causal flag) directly. Mathematically, masked entries receive zero probability, so their values get zero gradient through that query, and the masked logits themselves also get zero gradient, because the softmax derivative with respect to a logit is proportional to that logit's probability.
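The zero-gradient claim follows from the softmax Jacobian. For one query row, write s_j for the logits, p = softmax(s), and g_i = ∂L/∂p_i for the upstream gradient (notation introduced here):

```latex
\frac{\partial p_i}{\partial s_j} = p_i\,(\delta_{ij} - p_j)
\quad\Longrightarrow\quad
\frac{\partial L}{\partial s_j} = \sum_i g_i\, p_i\,(\delta_{ij} - p_j)
  = p_j \Big( g_j - \sum_i g_i\, p_i \Big).
```

For a masked position p_j = 0, so ∂L/∂s_j = 0. Likewise, since out_i = Σ_j p_ij v_j, the value gradient ∂L/∂v_j = Σ_i p_ij ∂L/∂out_i receives nothing from queries that mask j. With a finite fill such as −1e9 instead of −∞, the masked weight underflows to exactly zero in float32 and float64, so the practical result is the same.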

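Studying extremes helps. Identical keys give every key the same logit, and a query orthogonal to all keys gives every logit zero; either way attention is uniform. Mutually orthogonal keys combined with a query strongly aligned to one of them give near one-hot weights. The inductive bias lives in these geometric arrangements together with the learned projections that produce them.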
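A toy check of both extremes, using hypothetical 4-dimensional vectors: when every key is identical, all logits match and the weights are uniform whatever the query; with mutually orthogonal keys and a large query aligned with one of them, almost all the mass lands on that key.

```python
import math
from typing import List

Vector = List[float]


def dot(a: Vector, b: Vector) -> float:
    return sum(x * y for x, y in zip(a, b))


def attention_weights(q: Vector, keys: List[Vector]) -> List[float]:
    """Scaled dot-product weights of one query over a list of keys."""
    d_k = len(q)
    logits = [dot(q, k) / math.sqrt(d_k) for k in keys]
    m = max(logits)
    exps = [math.exp(s - m) for s in logits]
    total = sum(exps)
    return [e / total for e in exps]


# Extreme 1: identical keys -> equal logits -> uniform weights for any query.
same_keys = [[1.0, 2.0, 0.0, 0.0] for _ in range(4)]
uniform = attention_weights([0.3, -1.0, 2.0, 0.5], same_keys)

# Extreme 2: mutually orthogonal keys, query aligned with key 0 and scaled up.
ortho_keys = [[1.0 if i == j else 0.0 for j in range(4)] for i in range(4)]
peaked = attention_weights([20.0, 0.0, 0.0, 0.0], ortho_keys)
```

Shrinking the aligned query from 20 to 2 gives a logit of 1 instead of 10, and the weight on key 0 drops to about 0.48; the norms produced by the learned projections, not only their directions, set how sharp a head is.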

Connect this with copy behaviour in translation: specialised heads concentrate almost all their mass on the aligned source positions, while others spread mass widely to summarise predicates and their context. Both patterns can be inspected with attention-rollout techniques, which are outside the paper's scope.
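One simple way to tell the two kinds of head apart, offered here as an assumption rather than as the paper's method, is the Shannon entropy of each attention row: a head that copies from one aligned position has entropy near 0, while a head that spreads mass evenly over n positions approaches log n. The function names and the attention maps below are hypothetical.

```python
import math
from typing import List


def row_entropy(weights: List[float]) -> float:
    """Shannon entropy (in nats) of one attention row; zero weights contribute 0."""
    return -sum(w * math.log(w) for w in weights if w > 0.0)


def mean_head_entropy(attn: List[List[float]]) -> float:
    """Average row entropy of one head's attention matrix (rows = queries)."""
    return sum(row_entropy(r) for r in attn) / len(attn)


# Hypothetical 4x4 attention maps for two heads.
copy_head = [[0.97, 0.01, 0.01, 0.01],
             [0.01, 0.97, 0.01, 0.01],
             [0.01, 0.01, 0.97, 0.01],
             [0.01, 0.01, 0.01, 0.97]]
spread_head = [[0.25] * 4 for _ in range(4)]

copy_h = mean_head_entropy(copy_head)
spread_h = mean_head_entropy(spread_head)
```

The copy-like head scores roughly 0.17 nats against log 4 ≈ 1.39 for the uniform head; averaging over many sentences would give a more stable per-head summary.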

This topic is a prerequisite for interpreting the Table 5 ablations on attention variants: changing the dot-product form or removing the scaling directly changes the softmax curvature, that is, how sharply the weights respond to changes in the logits.
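To see what removing the scaling does, compare the softmax of the same raw logits with and without division by √d_k, together with the diagonal softmax sensitivity ∂p_i/∂s_i = p_i (1 − p_i). The sketch assumes d_k = 64 (so the divisor is 8) and uses hypothetical logits of the size unscaled dot products can reach at that width.

```python
import math
from typing import List


def softmax(logits: List[float]) -> List[float]:
    m = max(logits)
    exps = [math.exp(s - m) for s in logits]
    total = sum(exps)
    return [e / total for e in exps]


def diag_sensitivity(p: List[float]) -> List[float]:
    """Diagonal of the softmax Jacobian, dp_i/ds_i = p_i * (1 - p_i)."""
    return [pi * (1.0 - pi) for pi in p]


d_k = 64
raw = [12.0, 4.0, -3.0, 0.0]                         # hypothetical unscaled q·k values
unscaled = softmax(raw)                              # scaling removed
scaled = softmax([s / math.sqrt(d_k) for s in raw])  # as in Equation (1)
```

Without the scaling the first weight is about 0.9997 and every diagonal sensitivity is below 0.001, so the gradient signal almost vanishes; with the scaling the first weight is about 0.57 and its sensitivity about 0.24.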