Train LLaMA from Scratch
Case Study

No API, no pretrained weights: rebuild LLaMA's decoder architecture block by block in PyTorch (RMSNorm / RoPE / GQA / SwiGLU / KV cache), then train on a small corpus. The foundation under everything else.

LLaMA · Transformer · RoPE · RMSNorm · PyTorch · from-scratch

Fine-tuning, RAG, and agents all sit on top of a large language model. This project takes that foundation apart: it rebuilds LLaMA's decoder-only architecture from scratch in PyTorch and trains it, writing RMSNorm, RoPE, GQA, SwiGLU, and the KV cache line by line instead of calling an API. Based on the course's LLaMA architecture series.

Why build it from scratch

Using a model and understanding it are different levels. Rebuilding LLaMA from scratch shows whether you actually grasp it: why attention is computed the way it is, how positional information gets injected, and why long-sequence generation needs a cache. That understanding is the prerequisite for everything built on top (fine-tuning, RL, agents).

LLaMA's decoder block (component by component)

LLaMA is a decoder-only architecture: N identical blocks stacked on top of each other. Each block is built from these components:

| Component | Role | Difference from the original Transformer |
|---|---|---|
| RMSNorm | layer normalization | rescales by the root-mean-square only (no mean centering, no bias) → cheaper than LayerNorm |
| RoPE | rotary positional encoding | rotates Q and K by a position-dependent angle, so q·k depends only on relative position; replaces the added sinusoidal absolute encoding |
| GQA | grouped-query attention | several Q heads share one K/V head (e.g. 8 Q heads / 4 KV heads) → smaller KV cache |
| KV cache | autoregressive decoding speedup | stores past tokens' K/V, so each new token costs O(n) attention instead of re-running O(n²) attention over the whole prefix |
| SwiGLU FFN | gated feed-forward | a SiLU-gated FFN replaces the ReLU MLP; typically better quality at a similar parameter count |
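
A minimal PyTorch sketch of the two simplest rows, RMSNorm and the SwiGLU feed-forward. It assumes eps=1e-6 and an FFN width of 1536 for dim 512 (LLaMA's "2/3 of 4·dim, rounded up to a multiple of 256" rule; an assumption here, not a number from the project). The class names are illustrative; the w1/w2/w3 names follow the convention of Meta's reference code.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F


class RMSNorm(nn.Module):
    """Scale by the root-mean-square of the last dim: no mean subtraction, no bias."""

    def __init__(self, dim: int, eps: float = 1e-6):
        super().__init__()
        self.eps = eps
        self.weight = nn.Parameter(torch.ones(dim))  # learnable per-channel gain

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # rsqrt(mean(x^2) + eps) rescales each vector to (almost exactly) unit RMS
        rms_inv = torch.rsqrt(x.pow(2).mean(dim=-1, keepdim=True) + self.eps)
        return x * rms_inv * self.weight


class SwiGLUFFN(nn.Module):
    """FFN(x) = W2(SiLU(W1 x) * W3 x): a SiLU gate multiplies a linear branch."""

    def __init__(self, dim: int, hidden_dim: int):
        super().__init__()
        self.w1 = nn.Linear(dim, hidden_dim, bias=False)  # gate projection
        self.w3 = nn.Linear(dim, hidden_dim, bias=False)  # value projection
        self.w2 = nn.Linear(hidden_dim, dim, bias=False)  # back to model dim

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.w2(F.silu(self.w1(x)) * self.w3(x))


x = torch.randn(2, 5, 512)              # (batch, seq, dim)
y = RMSNorm(512)(x)
print(y.pow(2).mean(-1).sqrt()[0, :3])  # each position now has RMS ≈ 1
print(SwiGLUFFN(512, 1536)(y).shape)    # torch.Size([2, 5, 512])
```

The FFN has three weight matrices instead of two; shrinking the hidden width to about 2/3 of 4·dim keeps its parameter count close to a plain 4·dim MLP (3 · dim · 8/3·dim = 8·dim²).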

Overall structure (Pre-Norm):

tokens → Embedding (weights tied with LM Head in this build; LLaMA 1/2 keep them separate)
      → N × DecoderBlock:
            x = x + Attention(RMSNorm(x))      # RoPE + GQA + KV cache
            x = x + SwiGLU_FFN(RMSNorm(x))
      → final RMSNorm → LM Head → logits
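
The Attention line (RoPE + GQA + KV cache) can be sketched as one module. This is a minimal version under stated assumptions: hypothetical names (`GQAttention`, `rope_cos_sin`, `apply_rope`), RoPE base 10000 with adjacent channels paired (one of two common layouts), `repeat_interleave` to share each KV head across its query group, and PyTorch 2.0 or newer for `F.scaled_dot_product_attention`. The usage lines decode token by token from the cache and compare against a single full pass.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F


def rope_cos_sin(head_dim: int, positions: torch.Tensor, base: float = 10000.0):
    # One rotation frequency per channel pair: theta_i = base^(-2i / head_dim)
    inv_freq = base ** (-torch.arange(0, head_dim, 2, dtype=torch.float32) / head_dim)
    angles = positions.float()[:, None] * inv_freq[None, :]  # (seq, head_dim // 2)
    return angles.cos(), angles.sin()


def apply_rope(x: torch.Tensor, cos: torch.Tensor, sin: torch.Tensor) -> torch.Tensor:
    # x: (batch, heads, seq, head_dim); rotate each adjacent (even, odd) channel pair
    x1, x2 = x[..., 0::2], x[..., 1::2]
    rotated = torch.stack((x1 * cos - x2 * sin, x1 * sin + x2 * cos), dim=-1)
    return rotated.flatten(-2)  # interleave the pairs back into head_dim


class GQAttention(nn.Module):
    def __init__(self, dim: int, n_heads: int, n_kv_heads: int):
        super().__init__()
        assert n_heads % n_kv_heads == 0, "query heads must split evenly into KV groups"
        self.n_heads, self.n_kv_heads = n_heads, n_kv_heads
        self.head_dim = dim // n_heads
        self.wq = nn.Linear(dim, n_heads * self.head_dim, bias=False)
        self.wk = nn.Linear(dim, n_kv_heads * self.head_dim, bias=False)
        self.wv = nn.Linear(dim, n_kv_heads * self.head_dim, bias=False)
        self.wo = nn.Linear(n_heads * self.head_dim, dim, bias=False)

    def forward(self, x, start_pos=0, cache=None):
        # x: (batch, seq, dim); cache: optional (k, v) covering positions 0..start_pos-1
        b, t, _ = x.shape
        q = self.wq(x).view(b, t, self.n_heads, self.head_dim).transpose(1, 2)
        k = self.wk(x).view(b, t, self.n_kv_heads, self.head_dim).transpose(1, 2)
        v = self.wv(x).view(b, t, self.n_kv_heads, self.head_dim).transpose(1, 2)

        # Rotate at absolute positions so fresh keys agree with already-cached ones
        cos, sin = rope_cos_sin(self.head_dim, torch.arange(start_pos, start_pos + t))
        q, k = apply_rope(q, cos, sin), apply_rope(k, cos, sin)

        if cache is not None:  # append this step's K/V to the cached prefix
            k = torch.cat([cache[0], k], dim=2)
            v = torch.cat([cache[1], v], dim=2)
        new_cache = (k, v)  # only n_kv_heads heads are stored

        # GQA: each KV head serves n_heads // n_kv_heads consecutive query heads
        rep = self.n_heads // self.n_kv_heads
        k_rep, v_rep = k.repeat_interleave(rep, dim=1), v.repeat_interleave(rep, dim=1)

        # Causal mask: the query at absolute position p sees keys 0..p (True = attend)
        q_pos = torch.arange(start_pos, start_pos + t)[:, None]
        mask = torch.arange(k_rep.shape[2])[None, :] <= q_pos
        out = F.scaled_dot_product_attention(q, k_rep, v_rep, attn_mask=mask)
        return self.wo(out.transpose(1, 2).reshape(b, t, -1)), new_cache


torch.manual_seed(0)
attn = GQAttention(dim=512, n_heads=8, n_kv_heads=4)
x = torch.randn(1, 6, 512)
full, _ = attn(x)                     # all 6 tokens in one pass
out, cache = attn(x[:, :4])           # prefill the first 4 tokens
steps = [out]
for pos in range(4, 6):               # then decode one token at a time
    y, cache = attn(x[:, pos:pos + 1], start_pos=pos, cache=cache)
    steps.append(y)
print(torch.allclose(full, torch.cat(steps, dim=1), atol=1e-5))  # True
print(cache[0].shape)                 # torch.Size([1, 4, 6, 64])
```

`start_pos` is the easy-to-miss detail: cached keys were rotated at their own positions, so new tokens must be rotated at theirs or cached and fresh tokens disagree about relative distance. The cache holds 4 KV heads rather than 8, which is where GQA's memory saving shows up.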

Training

Train from random init on a small corpus, watching the loss fall and sample quality rise:

step 0    : loss 8.2   generation ≈ gibberish ("the the , ,")
step ~800 : loss 4.0   generation ≈ half-coherent (words right, grammar off)
step ~1600: loss 2.9   generation ≈ coherent short sentences
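
The loop behind such a log is ordinary next-token prediction: targets are the inputs shifted left by one token, scored with cross-entropy. A minimal sketch, assuming a toy repeating token stream, a 1,000-token vocabulary, and AdamW. The model is a deliberately tiny stand-in (embedding plus linear head) so the block runs on its own; the project would plug in the full decoder stack there.

```python
import math

import torch
import torch.nn as nn
import torch.nn.functional as F

torch.manual_seed(0)
vocab_size, dim, seq_len, batch_size = 1000, 64, 32, 8

# Stand-in model so the block runs alone; the decoder stack would go here
model = nn.Sequential(nn.Embedding(vocab_size, dim), nn.Linear(dim, vocab_size))
corpus = torch.arange(20_000) % 50  # toy token stream with a learnable next-token pattern


def get_batch():
    # Random windows; targets are the same windows shifted left by one token
    starts = torch.randint(0, len(corpus) - seq_len - 1, (batch_size,)).tolist()
    x = torch.stack([corpus[s:s + seq_len] for s in starts])
    y = torch.stack([corpus[s + 1:s + seq_len + 1] for s in starts])
    return x, y


opt = torch.optim.AdamW(model.parameters(), lr=1e-2)
losses = []
for step in range(200):
    x, y = get_batch()
    logits = model(x)  # (batch, seq, vocab)
    loss = F.cross_entropy(logits.reshape(-1, vocab_size), y.reshape(-1))
    opt.zero_grad()
    loss.backward()
    opt.step()
    losses.append(loss.item())
    if step % 50 == 0:
        print(f"step {step:3d}: loss {loss.item():.3f}")

print(f"ln(vocab) = {math.log(vocab_size):.3f}")  # 6.908: the loss of a uniform guess
```

A useful sanity check: an untrained model predicts roughly uniformly, so the first loss should land near ln(vocab), about 6.9 here and about 10.4 for a 32,000-token vocabulary.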

The config (dim 512 / 8 layers / 8 query heads sharing 4 KV heads / vocab 32000) and the loss and step numbers above are illustrative; the architecture components follow the real LLaMA structure. The course's "2-hour train-LLaMA-from-scratch" lesson and its LLaMA architecture-rebuild series cover this workflow.
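
The illustrative config can be turned into concrete sizes with some arithmetic. A sketch in plain Python, assuming tied embeddings (as in the structure above), no biases, and LLaMA's FFN-width rule with `multiple_of=256`; `LlamaConfig`, `count_params`, and `kv_cache_bytes_per_token` are hypothetical names.

```python
from dataclasses import dataclass


@dataclass
class LlamaConfig:
    dim: int = 512
    n_layers: int = 8
    n_heads: int = 8
    n_kv_heads: int = 4
    vocab_size: int = 32000
    multiple_of: int = 256  # assumed LLaMA-style rounding of the FFN width

    @property
    def head_dim(self) -> int:
        return self.dim // self.n_heads

    @property
    def ffn_hidden(self) -> int:
        # 2/3 of 4*dim keeps SwiGLU's three matrices near a 4*dim MLP; round up
        hidden = int(2 * (4 * self.dim) / 3)
        return self.multiple_of * ((hidden + self.multiple_of - 1) // self.multiple_of)


def count_params(cfg: LlamaConfig, tied_embeddings: bool = True) -> int:
    kv_dim = cfg.n_kv_heads * cfg.head_dim
    attn = 2 * cfg.dim * cfg.dim + 2 * cfg.dim * kv_dim  # wq, wo + smaller wk, wv
    ffn = 3 * cfg.dim * cfg.ffn_hidden                   # w1, w2, w3
    norms = 2 * cfg.dim                                  # two RMSNorm gains per block
    embed = cfg.vocab_size * cfg.dim
    head = 0 if tied_embeddings else cfg.vocab_size * cfg.dim
    return embed + cfg.n_layers * (attn + ffn + norms) + cfg.dim + head  # + final norm


def kv_cache_bytes_per_token(cfg: LlamaConfig, bytes_per_value: int = 2) -> int:
    # K and V, for every layer and every KV head (2 bytes per value for fp16/bf16)
    return 2 * cfg.n_layers * cfg.n_kv_heads * cfg.head_dim * bytes_per_value


cfg = LlamaConfig()
print(cfg.ffn_hidden)                                    # 1536
print(count_params(cfg))                                 # 41558528
print(kv_cache_bytes_per_token(cfg))                     # 8192
print(kv_cache_bytes_per_token(LlamaConfig(n_kv_heads=8)))  # 16384 without grouping
```

Under these assumptions the model has about 41.6M parameters, roughly 40% of them in the tied embedding table, and grouping 8 query heads onto 4 KV heads halves the KV cache from 16 KiB to 8 KiB per token.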

What this signals

  • Owning the foundation: you can implement the model from scratch, not just from transformers import ...
  • Understanding every modern component: what RMSNorm / RoPE / GQA / SwiGLU / KV cache each solve
  • Connecting the stack: the KV cache is what prompt caching in the Context Engineering project ([[上下文工程]]) is built on; understanding the architecture underpins fine-tuning and RL choices
  • PyTorch engineering: from tensors to the training loop, all hand-written
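
The prompt-caching link can be made concrete: causal attention never lets earlier tokens see later ones, so the K/V computed for a shared prefix stay valid for any continuation (with RoPE this still holds, since the prefix keeps the same positions). A minimal sketch with one single-head attention layer and no positional encoding, both simplifications; `prefix_cache`, `attend`, and `run` are hypothetical names, not any provider's API.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

torch.manual_seed(0)
vocab, dim = 100, 32
embed = nn.Embedding(vocab, dim)
wq, wk, wv = (nn.Linear(dim, dim, bias=False) for _ in range(3))


def attend(tokens, past=None):
    """Causal single-head attention for new tokens, optionally extending cached (K, V)."""
    x = embed(tokens)  # (t, dim)
    q, k, v = wq(x), wk(x), wv(x)
    p = 0 if past is None else past[0].shape[0]  # number of already-cached tokens
    if past is not None:
        k, v = torch.cat([past[0], k]), torch.cat([past[1], v])
    # New token i sits at absolute position p + i and may attend to keys 0..p+i
    mask = torch.arange(k.shape[0])[None, :] <= (p + torch.arange(len(tokens)))[:, None]
    out = F.scaled_dot_product_attention(q[None], k[None], v[None], attn_mask=mask)[0]
    return out, (k, v)


prefix_cache = {}  # hypothetical store: token prefix -> (K, V)


def run(prefix, suffix):
    key = tuple(prefix.tolist())
    if key not in prefix_cache:  # only the first request pays for the prefix
        _, prefix_cache[key] = attend(prefix)
    out, _ = attend(suffix, past=prefix_cache[key])  # project only the suffix tokens
    return out


system = torch.tensor([1, 2, 3, 4, 5, 6])  # shared prompt prefix
for user in (torch.tensor([7, 8]), torch.tensor([9, 10, 11])):
    full, _ = attend(torch.cat([system, user]))
    print(torch.allclose(run(system, user), full[len(system):], atol=1e-5))  # True
```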

Demo strategy

What the demo replays

The demo first assembles the LLaMA decoder block layer by layer (Embedding → RMSNorm → RoPE/GQA attention → RMSNorm → SwiGLU → final RMSNorm → tied LM head), then trains it, showing the loss fall and generations move from gibberish to coherent text. The components (RMSNorm, RoPE, GQA, SwiGLU, KV cache) follow the real LLaMA architecture from the course's LLaMA series (a video course); the config and loss/step numbers are illustrative, and no training runs in the browser.
