Train LLaMA from Scratch
No API, no pretrained weights: rebuild LLaMA's decoder architecture block by block in PyTorch (RMSNorm / RoPE / GQA / SwiGLU / KV cache), then train on a small corpus. The foundation under everything else.
Fine-tuning, RAG, and agents are all built on top of a large language model. This project takes that foundation apart: it rebuilds LLaMA's decoder-only architecture from scratch in PyTorch and trains it, writing RMSNorm / RoPE / GQA / SwiGLU / KV cache line by line instead of calling an API. It comes from the course's LLaMA architecture series.
Why build it from scratch
"Using it" and "understanding it" are a layer apart. Being able to rebuild LLaMA from scratch means you actually grasp it: why attention computes the way it does, how positional info is injected, why long-sequence generation needs caching. It's the prerequisite for everything above (fine-tuning / RL / agents).
LLaMA's decoder block (component by component)
LLaMA is a decoder-only architecture: a stack of N identical blocks. The components that go into each block (the KV cache is an inference-time mechanism inside attention rather than a separate layer):
| Component | Role | Difference from the original Transformer |
|---|---|---|
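| RMSNorm | normalization layer | rescales by the root mean square only, with no mean centering and no bias → cheaper than LayerNorm; applied pre-norm (to each sublayer's input) for training stability |
| RoPE | rotary positional encoding | rotates each Q/K feature pair by a position-dependent angle, so the Q·K score depends only on the relative offset (replaces the additive sinusoidal absolute encoding) |
| GQA | grouped-query attention | several Q heads share one K/V head (e.g. 8 Q heads / 4 KV heads = 2 Q heads per KV head) → in that example the KV cache is half the size of full multi-head attention's |
| KV cache | autoregressive speedup | keeps past tokens' K/V so each generation step computes K/V only for the new token: per-step attention work drops from O(n²) (recomputing the whole prefix) to O(n) |
| SwiGLU FFN | gated feed-forward | a SiLU-gated unit with three weight matrices replaces the two-layer ReLU MLP; empirically better quality at a similar parameter count |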
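A minimal PyTorch sketch of the RMSNorm row, assuming the usual formulation y = g · x / sqrt(mean(x²) + ε) with a learnable gain g and no bias. The class name and `eps=1e-6` are choices made here, not taken from the course code.

```python
import torch
import torch.nn as nn


class RMSNorm(nn.Module):
    """y = weight * x / sqrt(mean(x^2) + eps), computed over the last (feature) dim."""

    def __init__(self, dim: int, eps: float = 1e-6):
        super().__init__()
        self.eps = eps
        self.weight = nn.Parameter(torch.ones(dim))  # learnable per-feature gain, no bias

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # no mean subtraction: only the root mean square is normalized away
        inv_rms = torch.rsqrt(x.pow(2).mean(dim=-1, keepdim=True) + self.eps)
        return self.weight * (x * inv_rms)


if __name__ == "__main__":
    torch.manual_seed(0)
    x = torch.randn(2, 5, 16) + 3.0      # features with a clearly non-zero mean
    y = RMSNorm(16)(x)
    print(y.pow(2).mean(-1).sqrt())      # ≈ 1 for every vector
    print(y.mean(-1))                    # stays non-zero, unlike a LayerNorm output
```

Compared with `nn.LayerNorm`, it computes one statistic instead of two and has no bias term, which is where the savings come from.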
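A sketch of the RoPE row using complex rotations, assuming base 10000 and adjacent feature pairs (x0, x1), (x2, x3), … as in Meta's reference code (other implementations rotate the two halves of the head dimension instead). The function names are hypothetical. Multiplying a pair by e^{i·pos·θ} rotates it, so the dot product of a rotated query at position m and a rotated key at position n depends only on m − n.

```python
import torch


def rope_frequencies(head_dim: int, max_pos: int, base: float = 10000.0) -> torch.Tensor:
    """Unit complex numbers e^{i * pos * theta_j}, shape (max_pos, head_dim // 2)."""
    # one frequency per feature pair: theta_j = base^(-2j / head_dim)
    theta = base ** (-torch.arange(0, head_dim, 2, dtype=torch.float32) / head_dim)
    pos = torch.arange(max_pos, dtype=torch.float32)
    angles = torch.outer(pos, theta)                       # (max_pos, head_dim // 2)
    return torch.polar(torch.ones_like(angles), angles)    # magnitude 1, angle pos * theta_j


def apply_rope(x: torch.Tensor, freqs: torch.Tensor) -> torch.Tensor:
    """Rotate x of shape (..., seq, head_dim) by freqs of shape (seq, head_dim // 2)."""
    # view adjacent feature pairs (x0, x1), (x2, x3), ... as complex numbers
    pairs = x.float().reshape(*x.shape[:-1], -1, 2).contiguous()
    x_c = torch.view_as_complex(pairs)                     # (..., seq, head_dim // 2)
    x_rot = x_c * freqs                                    # complex multiply = 2-D rotation
    return torch.view_as_real(x_rot).flatten(-2).type_as(x)
```

Because each step is a pure rotation, vector norms are preserved and position 0 is left unchanged; positional information enters only through the relative angle in the Q·K score.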
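A sketch of the GQA row, assuming Q/K/V are already projected and split into heads with shape (batch, heads, seq, head_dim). `repeat_kv` borrows its name from the helper in Meta's reference code but is written here with `repeat_interleave`; it expands each KV head to the query heads that share it. `kv_cache_elems` is a hypothetical helper that makes the memory saving concrete.

```python
import torch
import torch.nn.functional as F


def repeat_kv(x: torch.Tensor, n_rep: int) -> torch.Tensor:
    """(B, n_kv_heads, T, D) -> (B, n_kv_heads * n_rep, T, D); with n_rep=2, query heads 2i and 2i+1 use KV head i."""
    return x if n_rep == 1 else x.repeat_interleave(n_rep, dim=1)


def gqa_attention(q: torch.Tensor, k: torch.Tensor, v: torch.Tensor) -> torch.Tensor:
    """Causal GQA. q: (B, n_heads, T, D); k, v: (B, n_kv_heads, T, D)."""
    n_heads, n_kv_heads = q.shape[1], k.shape[1]
    assert n_heads % n_kv_heads == 0, "query heads must split evenly into KV groups"
    n_rep = n_heads // n_kv_heads
    k, v = repeat_kv(k, n_rep), repeat_kv(v, n_rep)
    T, D = q.shape[-2], q.shape[-1]
    scores = q @ k.transpose(-2, -1) / D**0.5                 # (B, n_heads, T, T)
    causal = torch.ones(T, T, dtype=torch.bool).tril()        # position t sees 0..t
    scores = scores.masked_fill(~causal, float("-inf"))
    return F.softmax(scores, dim=-1) @ v                       # (B, n_heads, T, D)


def kv_cache_elems(n_layers: int, n_kv_heads: int, head_dim: int, seq_len: int, batch: int = 1) -> int:
    """Cached values: one K and one V vector per layer, KV head, and position."""
    return 2 * n_layers * n_kv_heads * head_dim * seq_len * batch
```

With the illustrative config (8 layers, 8 query heads of head_dim 64, 4 KV heads), `kv_cache_elems(8, 4, 64, 1024)` is 4,194,304 values for a 1024-token sequence, half of the 8,388,608 that one KV head per query head would need. Only the K/V projections and the cache shrink; the query side keeps all 8 heads.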
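A sketch of the SwiGLU row: FFN(x) = W2(SiLU(W1 x) ⊙ W3 x), with bias-free projections. The hidden-size rule (two thirds of 4·dim, rounded up to a multiple of 256) follows Meta's reference code and reproduces LLaMA-7B's FFN width of 11008 for dim 4096; the helper name `llama_hidden_dim` is hypothetical.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F


def llama_hidden_dim(dim: int, multiple_of: int = 256) -> int:
    """FFN width: 2/3 of 4*dim, rounded up to a multiple of `multiple_of`."""
    hidden = int(2 * (4 * dim) / 3)
    return multiple_of * ((hidden + multiple_of - 1) // multiple_of)


class SwiGLU(nn.Module):
    """FFN(x) = W2(SiLU(W1 x) * W3 x), all projections bias-free."""

    def __init__(self, dim: int, hidden_dim: int):
        super().__init__()
        self.w1 = nn.Linear(dim, hidden_dim, bias=False)  # gate branch (goes through SiLU)
        self.w3 = nn.Linear(dim, hidden_dim, bias=False)  # linear branch, multiplied by the gate
        self.w2 = nn.Linear(hidden_dim, dim, bias=False)  # projects back to the model dim

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.w2(F.silu(self.w1(x)) * self.w3(x))
```

The 2/3 factor compensates for the third matrix: at dim 512 the width is 1536, giving 3 × 512 × 1536 ≈ 2.36M weights, close to the ≈ 2.10M of a two-matrix MLP with a 4·dim hidden layer.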
Overall structure (Pre-Norm):
```
tokens → Embedding (weights tied with the LM Head in this build)
  → N × DecoderBlock:
      x = x + Attention(RMSNorm(x))    # RoPE + GQA + KV cache
      x = x + SwiGLU_FFN(RMSNorm(x))
  → final RMSNorm → LM Head → logits
```
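The pre-norm wiring above can be sketched in PyTorch as follows. To keep the focus on the residual structure and the tied head, the attention and FFN sublayers are passed in as modules; `make_demo_model` plugs in hypothetical stand-ins (a plain `Linear`, which does not mix information across tokens, and a SiLU MLP) where the real RoPE + GQA attention and SwiGLU would go. All class and function names here are illustrative.

```python
import torch
import torch.nn as nn


class RMSNorm(nn.Module):
    """Compact RMSNorm: no mean centering, no bias."""

    def __init__(self, dim: int, eps: float = 1e-6):
        super().__init__()
        self.eps = eps
        self.weight = nn.Parameter(torch.ones(dim))

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.weight * x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + self.eps)


class DecoderBlock(nn.Module):
    """Pre-norm block: each sublayer sees a normalized copy; inside the block the residual stream is not normalized."""

    def __init__(self, dim: int, attn: nn.Module, ffn: nn.Module):
        super().__init__()
        self.attn_norm, self.ffn_norm = RMSNorm(dim), RMSNorm(dim)
        self.attn, self.ffn = attn, ffn

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        x = x + self.attn(self.attn_norm(x))
        x = x + self.ffn(self.ffn_norm(x))
        return x


class TinyDecoder(nn.Module):
    """Embedding -> N blocks -> final RMSNorm -> LM head sharing the embedding matrix."""

    def __init__(self, vocab_size: int, dim: int, blocks: list):
        super().__init__()
        self.embed = nn.Embedding(vocab_size, dim)
        self.blocks = nn.ModuleList(blocks)
        self.norm = RMSNorm(dim)
        self.lm_head = nn.Linear(dim, vocab_size, bias=False)
        self.lm_head.weight = self.embed.weight  # tying: both are (vocab_size, dim)

    def forward(self, tokens: torch.Tensor) -> torch.Tensor:
        x = self.embed(tokens)
        for block in self.blocks:
            x = block(x)
        return self.lm_head(self.norm(x))       # (B, T, vocab_size) logits


def make_demo_model(vocab_size: int = 1000, dim: int = 64, n_layers: int = 2) -> TinyDecoder:
    # hypothetical stand-ins: Linear instead of RoPE+GQA attention, a SiLU MLP instead of SwiGLU
    blocks = [
        DecoderBlock(
            dim,
            attn=nn.Linear(dim, dim, bias=False),
            ffn=nn.Sequential(nn.Linear(dim, 4 * dim), nn.SiLU(), nn.Linear(4 * dim, dim)),
        )
        for _ in range(n_layers)
    ]
    return TinyDecoder(vocab_size, dim, blocks)
```

Because `lm_head.weight` and `embed.weight` are the same `Parameter`, `model.parameters()` yields the (vocab, dim) matrix only once, so tying saves vocab × dim weights.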
Training
Train from random init on a small corpus, watching the loss fall and sample quality rise:
```
step 0    : loss 8.2   generation ≈ gibberish ("the the , ,")
step ~800 : loss 4.0   generation ≈ half-coherent (words right, grammar off)
step ~1600: loss 2.9   generation ≈ coherent short sentences
```
The config (dim 512 / 8 layers / 8 query heads with 4 KV heads / vocab 32000) and the loss and step numbers above are illustrative; the architecture components are the real LLaMA structure. This flow is covered by the course's "2-hour train-LLaMA-from-scratch" session and its LLaMA architecture-rebuild series.
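A minimal next-token training loop, assuming the corpus is already tokenized into a 1-D `LongTensor`. A useful sanity check for any loss log: a freshly initialized model predicts close to uniformly, so step-0 cross-entropy should sit near ln(vocab_size), about 10.37 for a 32000-token vocabulary (the exact value shifts a little with initialization). The bigram model in the demo is a hypothetical stand-in for the full decoder, and the hyperparameters are arbitrary.

```python
import math

import torch
import torch.nn as nn
import torch.nn.functional as F


def next_token_loss(model: nn.Module, tokens: torch.Tensor) -> torch.Tensor:
    """Causal LM loss: the logits at position t are scored against token t+1."""
    logits = model(tokens[:, :-1])                     # (B, T-1, vocab)
    targets = tokens[:, 1:]                            # inputs shifted left by one
    return F.cross_entropy(logits.reshape(-1, logits.size(-1)), targets.reshape(-1))


def train(model: nn.Module, data: torch.Tensor, steps: int = 200,
          batch_size: int = 16, seq_len: int = 32, lr: float = 3e-3) -> list:
    opt = torch.optim.AdamW(model.parameters(), lr=lr)
    losses = []
    for _ in range(steps):
        # sample random windows of seq_len + 1 tokens from the 1-D token stream
        starts = torch.randint(0, len(data) - seq_len, (batch_size,)).tolist()
        batch = torch.stack([data[s : s + seq_len + 1] for s in starts])
        loss = next_token_loss(model, batch)
        opt.zero_grad(set_to_none=True)
        loss.backward()
        opt.step()
        losses.append(loss.item())
    return losses


if __name__ == "__main__":
    torch.manual_seed(0)
    vocab = 100
    # hypothetical stand-in for the decoder: a bigram model (embedding -> logits)
    model = nn.Sequential(nn.Embedding(vocab, 32), nn.Linear(32, vocab))
    data = torch.arange(vocab).repeat(50)              # toy "corpus": 0, 1, ..., 99, 0, 1, ...
    losses = train(model, data, steps=300, lr=1e-2)
    print(f"ln(vocab) = {math.log(vocab):.2f}, first = {losses[0]:.2f}, last = {losses[-1]:.3f}")
```

Swapping the bigram stand-in for the full decoder changes nothing in the loop: the model only has to map (B, T) token ids to (B, T, vocab) logits.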
What this signals
- Owning the foundation: you can implement the model from scratch, not just `from transformers import ...`
- Understanding every modern component: what RMSNorm / RoPE / GQA / SwiGLU / KV cache each solves
- Connecting the stack: prompt caching in [[上下文工程|context engineering]] is built on the KV cache (a shared prompt prefix's K/V are computed once and reused); architecture understanding also underpins fine-tuning and RL choices
- PyTorch engineering: from tensors to the training loop, all hand-written
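The KV-cache row and the prompt-caching connection can be checked numerically with a single-head causal attention, deliberately without RoPE or GQA to isolate the caching idea: prefilling a prompt and then decoding one token at a time against the cache gives the same outputs as recomputing attention over the full sequence. `CachedSelfAttention` and the dict-based cache are illustrative choices, not the course's API.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F


class CachedSelfAttention(nn.Module):
    """Single-head causal self-attention with an optional KV cache."""

    def __init__(self, dim: int):
        super().__init__()
        self.wq = nn.Linear(dim, dim, bias=False)
        self.wk = nn.Linear(dim, dim, bias=False)
        self.wv = nn.Linear(dim, dim, bias=False)

    def forward(self, x: torch.Tensor, cache: dict = None):
        """x: (B, T_new, dim); cache: {'k', 'v'} of shape (B, T_past, dim) or None."""
        q, k, v = self.wq(x), self.wk(x), self.wv(x)    # projections for the NEW tokens only
        if cache is not None:
            # past keys/values are read back instead of being recomputed
            k = torch.cat([cache["k"], k], dim=1)
            v = torch.cat([cache["v"], v], dim=1)
        t_new, t_total = q.shape[1], k.shape[1]
        scores = q @ k.transpose(1, 2) / q.shape[-1] ** 0.5   # (B, T_new, T_total)
        # new token i sits at absolute position t_total - t_new + i and sees positions <= that
        mask = torch.ones(t_new, t_total, dtype=torch.bool).tril(diagonal=t_total - t_new)
        scores = scores.masked_fill(~mask, float("-inf"))
        out = F.softmax(scores, dim=-1) @ v
        return out, {"k": k, "v": v}                    # updated cache for the next step
```

The cache returned by the prefill call over a prompt is roughly what a prompt cache keeps: another request that starts with the same prefix can resume decoding from it without recomputing those tokens.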
What the demo replays
The demo first assembles the LLaMA decoder layer by layer (Embedding → RMSNorm → RoPE/GQA attention → RMSNorm → SwiGLU → final RMSNorm → tied LM Head), then replays training, with the loss falling and generations going from gibberish to coherent. The components (RMSNorm / RoPE / GQA / SwiGLU / KV cache) are the real LLaMA architecture, taken from the course's LLaMA video series; the config and loss/step numbers are illustrative, and no training runs in the browser.