返回项目
从 0 训练 LLaMA
案例拆解

从 0 训练 LLaMA

不调 API、不加载权重,用 PyTorch 从零把 LLaMA 的 decoder 架构一块块搭出来(RMSNorm / RoPE / GQA / SwiGLU / KV 缓存),再在小语料上训练。理解大模型的地基。

LLaMATransformerRoPERMSNormPyTorchfrom-scratch

微调、RAG、Agent 都建在「大模型」这块地基上。这个项目把地基拆开:用 PyTorch 从零复刻 LLaMA 的 decoder-only 架构并训练——不是调 API,是把 RMSNorm / RoPE / GQA / SwiGLU / KV 缓存一行行写出来。来自大模型原理正课的 LLaMA 架构系列。

为什么从零写

「会用」和「理解」差一层。能把 LLaMA 从零搭出来,意味着真正吃透了:注意力为什么这么算、位置信息怎么注入、长序列生成为什么要缓存。这是上层一切([微调] / [RL] / [Agent])的前置认知。

LLaMA 的 decoder block(逐组件)

LLaMA 是 decoder-only 架构,N 层相同的 block 堆叠。每个 block 的真实组件:

组件作用与原始 Transformer 的差别
RMSNorm层归一化只按均方根缩放,去掉均值中心化 → 更快更稳
RoPE旋转位置编码把位置信息以旋转作用到 Q/K,天然支持相对位置 + 外推(不再用正弦绝对编码)
GQA分组查询注意力多个 Q 头共享一组 KV 头(如 8 头 / 4 KV)→ 省 KV 缓存显存
KV 缓存自回归加速缓存历史 token 的 K/V,避免每步重算,把 O(n²) 降到增量 O(n)
SwiGLU FFN门控前馈门控激活替代 ReLU-MLP,效果更好

整体结构(Pre-Norm):

tokens → Embedding(与 LM Head 权重共享)
      → N × DecoderBlock:
            x = x + Attention(RMSNorm(x))      # RoPE + GQA + KV cache
            x = x + SwiGLU_FFN(RMSNorm(x))
      → 最终 RMSNorm → LM Head → logits

训练

在小语料上从随机初始化训练,监控 loss 下降 + 采样生成质量:

step 0    : loss 8.2   生成 ≈ 乱码("the the 的 的 ,,")
step ~800 : loss 4.0   生成 ≈ 半通顺(词对了语法乱)
step ~1600: loss 2.9   生成 ≈ 通顺短句

配置(dim 512 / 8 层 / 8 头-4 KV / vocab 32000)和上面的 loss/step 是示意值;架构组件是真实 LLaMA 结构。课程的「2 小时从 0 到 1 训练 LLaMA」+ LLaMA 架构复现系列覆盖这套流程。

价值点

  • 吃透地基:能从零实现,而不只是 from transformers import ...
  • 理解每个现代组件:RMSNorm / RoPE / GQA / SwiGLU / KV 缓存 各自解决什么问题
  • 打通上下层:KV 缓存正是 [[上下文工程]] 里 prompt cache 的底层;架构理解支撑 [微调] 和 [RL] 的选择
  • PyTorch 工程:从张量到训练循环全手写
Demo strategy

Demo 真实材料对应

互动 Demo 先逐层拼出 LLaMA decoder block(Embedding → RMSNorm → RoPE/GQA Attention → RMSNorm → SwiGLU → 最终 RMSNorm → tied LM Head),再跑训练看 loss 下降、生成从乱码变通顺。架构组件(RMSNorm/RoPE/GQA/SwiGLU/KV缓存)是真实 LLaMA 结构,来自大模型原理正课 LLaMA 系列(视频课);config 与 loss/step 数值为示意,不在浏览器真跑训练。

Public preview can be enabled later without redesigning the case-study layout