Chapter 01

Foundations

This chapter gives you the minimum math and systems intuition needed to follow the report. The point is not abstract ML theory. The point is understanding why the report keeps revisiting cross-entropy, AdamW, precision policy, activation memory, and denser supervision.

1. Prediction objective

Everything starts from next-token prediction, not from some special new loss

The report still trains a language model in the usual way: produce logits, turn them into a normalized distribution with softmax, and penalize the model when the correct next token gets low probability.

Softmax

Turns raw scores into a probability distribution over the vocabulary.

p_i = exp(z_i) / sum_j exp(z_j)

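A minimal sketch of the softmax above in plain Python. It subtracts the largest logit before exponentiating. This is a standard numerical trick, not something taken from the report: the factor exp(-max) cancels between numerator and denominator, so the probabilities are unchanged, but exp can no longer overflow on large logits.

```python
import math
from typing import List, Sequence


def softmax(logits: Sequence[float]) -> List[float]:
    """Map raw scores z_i to probabilities p_i = exp(z_i) / sum_j exp(z_j)."""
    m = max(logits)  # shift by the max so every exp(z - m) <= 1 (no overflow)
    exps = [math.exp(z - m) for z in logits]
    total = sum(exps)
    return [e / total for e in exps]


probs = softmax([2.0, 1.0, 0.1])
print(probs, sum(probs))
```

For example, softmax([1000.0, 1000.0]) returns [0.5, 0.5], whereas evaluating the formula literally would call math.exp(1000.0) and raise an OverflowError.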
Cross-entropy

Looks up the log-probability of the correct token and penalizes low confidence.

L = -log p(target)

Perplexity is a different view of the same signal: ppl = exp(average per-token loss), so a lower average loss means lower perplexity.
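Cross-entropy and perplexity can be sketched together in plain Python. The helper names cross_entropy and perplexity are illustrative, not from the report. The loss is computed directly in log space with log-sum-exp, which avoids forming a tiny probability and then taking its log; values are in nats because math.log is the natural logarithm.

```python
import math
from typing import Sequence


def cross_entropy(logits: Sequence[float], target: int) -> float:
    """L = -log p(target), computed from raw logits via a stable log-sum-exp."""
    m = max(logits)
    log_z = m + math.log(sum(math.exp(z - m) for z in logits))  # log sum_j exp(z_j)
    return log_z - logits[target]  # -log(exp(z_target) / sum_j exp(z_j))


def perplexity(per_token_losses: Sequence[float]) -> float:
    """ppl = exp(mean per-token cross-entropy)."""
    return math.exp(sum(per_token_losses) / len(per_token_losses))


# A uniform guess over 4 tokens costs log(4) nats per token, i.e. perplexity 4.
losses = [cross_entropy([0.0, 0.0, 0.0, 0.0], t) for t in (0, 1, 2)]
print(perplexity(losses))
```

The uniform case is a useful sanity check: a model that has learned nothing about a V-token vocabulary has perplexity V.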

2. Optimization

The optimizer remains conventional because the real complexity lives elsewhere

The report uses AdamW with the usual stabilizers: learning-rate warmup and decay, gradient clipping, and large-batch discipline. That matters because the paper's novelty is not a new optimizer; it is how architecture, precision, routing, and systems are made to cooperate under a standard one.
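The stabilizers named above can be sketched in a few lines. This assumes generic forms, a linear warmup followed by cosine decay and clipping by global L2 norm; they are not the report's exact schedule or values. The function names lr_at and clip_by_global_norm and all parameters are hypothetical.

```python
import math
from typing import List, Sequence, Tuple


def lr_at(step: int, peak_lr: float, warmup_steps: int, total_steps: int,
          min_lr: float = 0.0) -> float:
    """Linear warmup to peak_lr, then cosine decay down to min_lr at total_steps."""
    if step < warmup_steps:
        return peak_lr * (step + 1) / warmup_steps
    progress = (step - warmup_steps) / max(1, total_steps - warmup_steps)
    progress = min(progress, 1.0)  # hold at min_lr after total_steps
    return min_lr + 0.5 * (peak_lr - min_lr) * (1.0 + math.cos(math.pi * progress))


def clip_by_global_norm(grads: Sequence[float],
                        max_norm: float) -> Tuple[List[float], float]:
    """Rescale all gradients together so their joint L2 norm is at most max_norm."""
    norm = math.sqrt(sum(g * g for g in grads))
    if norm <= max_norm:
        return list(grads), norm
    scale = max_norm / norm
    return [g * scale for g in grads], norm


schedule = [lr_at(s, peak_lr=1.0, warmup_steps=10, total_steps=110) for s in range(111)]
clipped, norm_before = clip_by_global_norm([3.0, 4.0], max_norm=1.0)
```

Clipping by the global norm, rather than per element, preserves the direction of the update and only shrinks its length when a batch produces an unusually large gradient.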

AdamW update
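w <- w(1 - lr*wd) - lr*m_hat / (sqrt(v_hat) + eps)

Interpretation

The first term shrinks the weights (decoupled weight decay, scaled by lr*wd). The second follows a smoothed gradient: m_hat and v_hat are bias-corrected running averages of the gradient and of its square, so dividing by sqrt(v_hat) normalizes each parameter's step by its own gradient scale.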
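A minimal sketch of one AdamW step over a flat list of parameters, implementing the update rule above term by term. The names AdamWState and adamw_step are hypothetical, and the defaults beta1=0.9, beta2=0.999, eps=1e-8 are common library defaults, not the report's hyperparameters.

```python
import math
from dataclasses import dataclass
from typing import List, Sequence


@dataclass
class AdamWState:
    m: List[float]  # running mean of gradients (first moment)
    v: List[float]  # running mean of squared gradients (second moment)
    t: int = 0      # steps taken so far, used for bias correction


def adamw_step(w: Sequence[float], g: Sequence[float], state: AdamWState, lr: float,
               wd: float = 0.0, beta1: float = 0.9, beta2: float = 0.999,
               eps: float = 1e-8) -> List[float]:
    state.t += 1
    new_w = []
    for i, (wi, gi) in enumerate(zip(w, g)):
        state.m[i] = beta1 * state.m[i] + (1 - beta1) * gi
        state.v[i] = beta2 * state.v[i] + (1 - beta2) * gi * gi
        m_hat = state.m[i] / (1 - beta1 ** state.t)  # undo the bias toward zero
        v_hat = state.v[i] / (1 - beta2 ** state.t)
        # w <- w(1 - lr*wd) - lr*m_hat / (sqrt(v_hat) + eps)
        new_w.append(wi * (1 - lr * wd) - lr * m_hat / (math.sqrt(v_hat) + eps))
    return new_w


state = AdamWState(m=[0.0, 0.0], v=[0.0, 0.0])
w = adamw_step([1.0, 1.0], [0.5, -200.0], state, lr=0.1, wd=0.01)
print(w)
```

On this first step both parameters move by roughly lr (to about 0.899 and 1.099) even though their gradients differ by a factor of 400. That is the per-parameter normalization in action.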

3. Precision policy

Low precision works only because the report keeps a map of what must stay accurate

The paper's low-precision story is selective. Dense GEMMs can move to FP8, but some states still need BF16 or FP32. The right mental model is not one global dtype. It is a precision map tied to failure risk.

Precision by role

State | Why its precision matters
GEMM inputs/weights | Large speed and memory wins justify FP8, provided quantization error is well controlled.
Optimizer moments | Usually safer in BF16 or higher because they accumulate training signal over a long horizon.
Master weights / accumulated gradients | Need FP32-level precision because many tiny updates must add up over a very long run; in lower precision they can be rounded away entirely.

This is why the report also talks about recomputation and activation storage. Precision, memory, and communication are one combined design problem.
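Why master weights need FP32 can be shown with a toy experiment (an illustration, not the report's code). The sketch emulates FP32 and BF16 rounding in plain Python with the struct module and applies a constant, hypothetical update of 1e-4 to a single weight 1,000 times, rounding after every addition. The helpers to_fp32, to_bf16 and accumulate are illustrative names; NaN and infinity handling is ignored.

```python
import struct


def to_fp32(x: float) -> float:
    """Round a Python float (FP64) to the nearest FP32 value."""
    return struct.unpack("<f", struct.pack("<f", x))[0]


def to_bf16(x: float) -> float:
    """Round to BF16 (8 exponent bits, 7 mantissa bits), round-to-nearest-even.

    BF16 is the top 16 bits of an FP32 pattern, so we round that pattern.
    """
    bits = struct.unpack("<I", struct.pack("<f", x))[0]
    bits += 0x7FFF + ((bits >> 16) & 1)  # round half to even on the dropped bits
    return struct.unpack("<f", struct.pack("<I", bits & 0xFFFF0000))[0]


def accumulate(w0: float, update: float, steps: int, rounder) -> float:
    """Apply the same small update `steps` times, rounding the weight each time."""
    w = rounder(w0)
    for _ in range(steps):
        w = rounder(w + update)
    return w


w_bf16 = accumulate(1.0, 1e-4, 1000, to_bf16)
w_fp32 = accumulate(1.0, 1e-4, 1000, to_fp32)
print(w_bf16, w_fp32)
```

Near 1.0 the BF16 grid spacing is 2^-7 (about 0.0078), so adding 1e-4 rounds straight back to 1.0 and the BF16 weight never moves, while the FP32 copy ends close to 1.1.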

4. Multi-token prediction

Denser supervision adds extra prediction depth without replacing the base objective

The report adds multi-token prediction so that each position is trained on more than one future token. Conceptually, the main head still predicts the next token, while extra heads or modules predict tokens further ahead.

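Main loss

L_main = CE(logits_t, target_{t+1})

Extra signal

L_total = L_main + alpha * sum_{d=1..D} CE(logits_{t,d}, target_{t+1+d})

Here d = 1..D indexes the extra prediction depths, logits_{t,d} is the depth-d output at position t, and alpha weights the auxiliary terms against the main loss.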

This matters twice: first in training, because it densifies supervision; later in inference, because the same structure can help speculative decoding.
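The combined objective can be sketched over a short sequence in plain Python. The names mtp_loss and cross_entropy and the logits layout are assumptions for illustration: main_logits[t] scores tokens[t + 1], and extra_logits[d - 1][t] scores tokens[t + 1 + d]. Each term is averaged over the positions that still have a target; the report may weight or normalize the depth terms differently.

```python
import math
from typing import Sequence


def cross_entropy(logits: Sequence[float], target: int) -> float:
    """-log softmax(logits)[target], via a stable log-sum-exp."""
    m = max(logits)
    return m + math.log(sum(math.exp(z - m) for z in logits)) - logits[target]


def mtp_loss(main_logits: Sequence[Sequence[float]],
             extra_logits: Sequence[Sequence[Sequence[float]]],
             tokens: Sequence[int], alpha: float) -> float:
    """L_total = L_main + alpha * sum_d L_d (each L averaged over valid positions)."""
    n = len(tokens)
    main_terms = [cross_entropy(main_logits[t], tokens[t + 1]) for t in range(n - 1)]
    l_main = sum(main_terms) / len(main_terms)
    l_extra = 0.0
    for d, depth_logits in enumerate(extra_logits, start=1):
        # Depth d at position t predicts the token d steps beyond the next one.
        terms = [cross_entropy(depth_logits[t], tokens[t + 1 + d])
                 for t in range(n - 1 - d)]
        if terms:  # deep heads may have no valid target in a short sequence
            l_extra += sum(terms) / len(terms)
    return l_main + alpha * l_extra


tokens = [0, 1, 2, 3, 0]
uniform = [[0.0] * 4 for _ in range(len(tokens))]
print(mtp_loss(uniform, [uniform, uniform], tokens, alpha=0.3))
```

With uniform logits over 4 tokens every CE term equals log 4, so the result is log 4 * (1 + 0.3 * 2). Setting alpha to 0 recovers the plain next-token loss.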

5. Training loop

The report's large-system story still sits on top of one familiar loop

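Forward -> Loss -> Backward -> Optimizer step

Forward: run the model to produce logits, plus auxiliary outputs such as the MTP modules' predictions.

Loss: combine the CE terms. If one batch is split into equal-sized micro-batches for gradient accumulation, divide each micro-batch's loss by the number of accumulation steps so the accumulated gradient equals the full-batch average.

Backward: recompute selected activations during the backward pass instead of keeping every forward activation in memory.

Optimizer step: update the weights once the gradients from all micro-batches of the batch have been accumulated.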
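The loop above can be sketched on a toy one-parameter model, y_hat = w * x with mean squared error, to show why the loss is normalized during gradient accumulation. This is a hypothetical illustration: loss_and_grad and train_step are made-up names, gradients are derived by hand instead of by autograd, plain SGD stands in for AdamW, and activation recomputation is not modeled.

```python
from typing import List, Tuple

Batch = List[Tuple[float, float]]  # (x, y) pairs


def loss_and_grad(w: float, batch: Batch) -> Tuple[float, float]:
    """Forward + backward for y_hat = w * x with mean squared error."""
    n = len(batch)
    loss = sum((w * x - y) ** 2 for x, y in batch) / n
    grad = sum(2 * (w * x - y) * x for x, y in batch) / n  # dL/dw
    return loss, grad


def train_step(w: float, batch: Batch, accum_steps: int,
               lr: float) -> Tuple[float, float]:
    """One optimizer step over `batch`, split into `accum_steps` micro-batches."""
    assert len(batch) % accum_steps == 0, "sketch assumes equal-sized micro-batches"
    size = len(batch) // accum_steps
    total_loss, grad_acc = 0.0, 0.0
    for k in range(accum_steps):
        micro = batch[k * size:(k + 1) * size]
        loss, grad = loss_and_grad(w, micro)
        total_loss += loss / accum_steps  # normalize: the sum becomes a full-batch mean
        grad_acc += grad / accum_steps
    return w - lr * grad_acc, total_loss  # plain SGD stands in for AdamW here


batch = [(1.0, 2.0), (2.0, 4.0), (3.0, 6.0), (4.0, 8.0)]
print(train_step(0.0, batch, accum_steps=1, lr=0.01))
print(train_step(0.0, batch, accum_steps=2, lr=0.01))
```

Both calls return the same weight and loss. Without the division by accum_steps, the accumulated gradient would be accum_steps times too large, which silently acts like a larger learning rate.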

The later chapters mostly explain how the report makes this loop feasible at larger scale.