GRPO 推理强化训练器(GSM8K · Qwen2.5-0.5B)
用 TRL 的 GRPOTrainer 在 Qwen2.5-0.5B 上复现 DeepSeek-R1 的 GRPO 算法,靠 5 个可验证奖励函数,在 GSM8K 数学题上把模型从「直接给答案」教成「先写推理链再给答案」。
DeepSeek-R1 用的是 GRPO(Group Relative Policy Optimization)。这套配方在 0.5B 模型上也跑得通——挑这个尺寸的好处是「整个 loop 装得下单卡」,能真正打开看 GRPO 是怎么靠奖励诱导出 reasoning chain 的,而不只是读论文。
为什么是 GRPO 不是 PPO / DPO
挑算法不靠时髦,靠匹配你手头的资源:
| 算法 | 需要什么 | 这里不合适的原因 |
|---|---|---|
| PPO | critic 网络 + reward model | 多一倍模型显存;GSM8K 的 reward 可验证,不需要训练 reward model |
| DPO | 人工标注的 chosen/rejected 偏好对 | 数学题答案要么对要么错,没必要花钱标 |
| GRPO | 一个 reward function(甚至硬规则)+ 一组 K 个采样 | 把「组内相对优势」当作 baseline,省掉 critic——这正是它存在的意义 |
GRPO 的核心:每个 prompt 采样 K 条 completion,每条算 reward,用组内均值当 baseline,组内偏好(reward 高的 completion)通过 advantage 推向更高概率。
真实训练 config(GRPOConfig)
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer
from datasets import load_dataset
from trl import GRPOConfig, GRPOTrainer
model = AutoModelForCausalLM.from_pretrained(
"Qwen/Qwen2.5-0.5B-Instruct",
torch_dtype=torch.bfloat16,
).to("cuda")
# GSM8K:每道题用 #### 分隔答案,特别适合做可验证 reward
dataset = load_dataset("openai/gsm8k", "main")["train"]
training_args = GRPOConfig(
learning_rate=5e-6,
lr_scheduler_type="cosine",
warmup_ratio=0.1,
per_device_train_batch_size=1,
gradient_accumulation_steps=4,
num_generations=16, # K=16 — GRPO 的心脏
max_prompt_length=256,
max_completion_length=200,
num_train_epochs=1,
max_grad_norm=0.1,
bf16=True,
use_vllm=False, # 单设备就够
)
模型被系统 prompt 强制按这个结构输出:
<reasoning>
…在这里逐步推理…
</reasoning>
<answer>
最终数值
</answer>
5 个奖励函数:可验证、可叠加
GRPO 完全靠 reward 形状塑造模型。这套复刻把 reward 拆成 5 个独立信号——分开是为了可观察、可单独 ablation:
def correctness_reward_func(prompts, completions, answer, **kwargs):
"""提取 <answer> 后做精确匹配 GSM8K gold"""
extracted = [extract_xml_answer(c) for c in completions]
return [2.0 if r == a else 0.0 for r, a in zip(extracted, answer)]
def int_reward_func(completions, **kwargs):
"""答案是不是纯整数(GSM8K 都是)"""
extracted = [extract_xml_answer(c) for c in completions]
return [0.5 if r.isdigit() else 0.0 for r in extracted]
def strict_format_reward_func(completions, **kwargs):
"""严格匹配 <reasoning>\n…\n</reasoning>\n<answer>…</answer> 整体结构"""
pattern = r"^<reasoning>\n.*?\n</reasoning>\n<answer>\n.*?\n</answer>\n?$"
return [0.5 if re.match(pattern, c, re.DOTALL) else 0.0 for c in completions]
def soft_format_reward_func(completions, **kwargs):
"""宽松版:只要两个 block 都出现就给分"""
pattern = r"<reasoning>.*?</reasoning>.*?<answer>.*?</answer>"
return [0.5 if re.search(pattern, c, re.DOTALL) else 0.0 for c in completions]
def xmlcount_reward_func(completions, **kwargs):
"""按 tag 计数给分(0.125/个),尾随多余文本扣分"""
return [count_xml_tags(c) for c in completions]
trainer = GRPOTrainer(
model=model,
args=training_args,
train_dataset=dataset,
reward_funcs=[
correctness_reward_func,
int_reward_func,
strict_format_reward_func,
soft_format_reward_func,
xmlcount_reward_func,
],
)
trainer.train()
| Reward | 信号 | 权重 |
|---|---|---|
correctness_reward_func | <answer> 精确匹配 GSM8K gold | +2.0 |
int_reward_func | 答案是不是纯整数 | +0.5 |
strict_format_reward_func | 严格 regex 匹配整体格式 | +0.5 |
soft_format_reward_func | 宽松版:两个 block 都在 | +0.5 |
xmlcount_reward_func | 标签计数(0.125/个)- 多余文本惩罚 | up to +0.5 |
correctness 占主导,4 个格式 reward 是「楼梯」——严格 → 宽松 → tag 计数,逐级把模型从自由文本拉向干净的 <reasoning>/<answer> 结构。
训练前后真实对比
这是个 demo 级训练(单卡 3090、~17 GB VRAM、几小时跑完 1 epoch),重点是机制可见性而不是 leaderboard 数字:
| 例题 | "Joy can read 8 pages in 20 minutes. How many hours to read 120 pages?" |
|---|---|
| 训练前 0.5B 直接给 | 5 (不解释、不分块、有时连结构都没有) |
| 训练 1 epoch 后 | <reasoning> 块:算每页 2.5 分钟、120 页 = 300 分钟 = 5 小时;<answer> 块:5 |
奖励单独就教会了 0.5B 模型「先 show your work, 再给答案」。这正是 DeepSeek-R1 在大模型上做的事。
诚实边界
- 模型:0.5B(Qwen2.5-0.5B-Instruct)—— 比 R1-scale 模型小约 100 倍
- 算力:单卡 ~17 GB VRAM,几小时跑完 1 epoch
- 没用 vLLM:
use_vllm=False,是个学习级复刻不是生产 trainer - 不追 SOTA 准确率:在 0.5B 上几小时 GRPO 不会跑出多牛的数字。展示的是机制——组采样 + 可验证奖励 → reasoning chain,DeepSeek-R1 把这套放到大尺度上跑
价值点
- 理解 DeepSeek-R1 的 GRPO 配方到能端到端跑通的程度,不只是看论文
- 设计可验证奖励 stack 并能讲清每条 reward 怎么塑形行为(correctness 主导、format 楼梯)
- 诚实叙事:在 0.5B 上演示机制,明确说出 claim 了什么、没 claim 什么
Demo 真实可跑
不是 replay:互动 Demo 是一个 live reward calculator。你编辑 model completion 和 gold answer,5 个奖励函数(correctness/int/strict_format/soft_format/xmlcount,从 notebook 逐字移植)在你浏览器里实时重算——和 GRPOTrainer 在做组内 advantage 之前打的分一模一样。双语 EN/中文。