从 0 训练 LLaMA
不调 API、不加载权重,用 PyTorch 从零把 LLaMA 的 decoder 架构一块块搭出来(RMSNorm / RoPE / GQA / SwiGLU / KV 缓存),再在小语料上训练。理解大模型的地基。
微调、RAG、Agent 都建在「大模型」这块地基上。这个项目把地基拆开:用 PyTorch 从零复刻 LLaMA 的 decoder-only 架构并训练——不是调 API,是把 RMSNorm / RoPE / GQA / SwiGLU / KV 缓存一行行写出来。来自大模型原理正课的 LLaMA 架构系列。
为什么从零写
「会用」和「理解」差一层。能把 LLaMA 从零搭出来,意味着真正吃透了:注意力为什么这么算、位置信息怎么注入、长序列生成为什么要缓存。这是上层一切([微调] / [RL] / [Agent])的前置认知。
LLaMA 的 decoder block(逐组件)
LLaMA 是 decoder-only 架构,N 层相同的 block 堆叠。每个 block 的真实组件:
| 组件 | 作用 | 与原始 Transformer 的差别 |
|---|---|---|
| RMSNorm | 层归一化 | 只按均方根缩放,去掉均值中心化 → 更快更稳 |
| RoPE | 旋转位置编码 | 把位置信息以旋转作用到 Q/K,天然支持相对位置 + 外推(不再用正弦绝对编码) |
| GQA | 分组查询注意力 | 多个 Q 头共享一组 KV 头(如 8 头 / 4 KV)→ 省 KV 缓存显存 |
| KV 缓存 | 自回归加速 | 缓存历史 token 的 K/V,避免每步重算,把 O(n²) 降到增量 O(n) |
| SwiGLU FFN | 门控前馈 | 门控激活替代 ReLU-MLP,效果更好 |
整体结构(Pre-Norm):
tokens → Embedding(与 LM Head 权重共享)
→ N × DecoderBlock:
x = x + Attention(RMSNorm(x)) # RoPE + GQA + KV cache
x = x + SwiGLU_FFN(RMSNorm(x))
→ 最终 RMSNorm → LM Head → logits
训练
在小语料上从随机初始化训练,监控 loss 下降 + 采样生成质量:
step 0 : loss 8.2 生成 ≈ 乱码("the the 的 的 ,,")
step ~800 : loss 4.0 生成 ≈ 半通顺(词对了语法乱)
step ~1600: loss 2.9 生成 ≈ 通顺短句
配置(dim 512 / 8 层 / 8 头-4 KV / vocab 32000)和上面的 loss/step 是示意值;架构组件是真实 LLaMA 结构。课程的「2 小时从 0 到 1 训练 LLaMA」+ LLaMA 架构复现系列覆盖这套流程。
价值点
- 吃透地基:能从零实现,而不只是
from transformers import ... - 理解每个现代组件:RMSNorm / RoPE / GQA / SwiGLU / KV 缓存 各自解决什么问题
- 打通上下层:KV 缓存正是 [[上下文工程]] 里 prompt cache 的底层;架构理解支撑 [微调] 和 [RL] 的选择
- PyTorch 工程:从张量到训练循环全手写
Demo 真实材料对应
互动 Demo 先逐层拼出 LLaMA decoder block(Embedding → RMSNorm → RoPE/GQA Attention → RMSNorm → SwiGLU → 最终 RMSNorm → tied LM Head),再跑训练看 loss 下降、生成从乱码变通顺。架构组件(RMSNorm/RoPE/GQA/SwiGLU/KV缓存)是真实 LLaMA 结构,来自大模型原理正课 LLaMA 系列(视频课);config 与 loss/step 数值为示意,不在浏览器真跑训练。