返回项目
GRPO 推理强化训练器(GSM8K · Qwen2.5-0.5B)
案例拆解

GRPO 推理强化训练器(GSM8K · Qwen2.5-0.5B)

用 TRL 的 GRPOTrainer 在 Qwen2.5-0.5B 上复现 DeepSeek-R1 的 GRPO 算法,靠 5 个可验证奖励函数,在 GSM8K 数学题上把模型从「直接给答案」教成「先写推理链再给答案」。

GRPOTRLDeepSeek-R1GSM8KQwen2.5Reasoning

DeepSeek-R1 用的是 GRPO(Group Relative Policy Optimization)。这套配方在 0.5B 模型上也跑得通——挑这个尺寸的好处是「整个 loop 装得下单卡」,能真正打开看 GRPO 是怎么靠奖励诱导出 reasoning chain 的,而不只是读论文。

为什么是 GRPO 不是 PPO / DPO

挑算法不靠时髦,靠匹配你手头的资源:

算法需要什么这里不合适的原因
PPOcritic 网络 + reward model多一倍模型显存;GSM8K 的 reward 可验证,不需要训练 reward model
DPO人工标注的 chosen/rejected 偏好对数学题答案要么对要么错,没必要花钱标
GRPO一个 reward function(甚至硬规则)+ 一组 K 个采样把「组内相对优势」当作 baseline,省掉 critic——这正是它存在的意义

GRPO 的核心:每个 prompt 采样 K 条 completion,每条算 reward,用组内均值当 baseline,组内偏好(reward 高的 completion)通过 advantage 推向更高概率。

真实训练 config(GRPOConfig)

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer
from datasets import load_dataset
from trl import GRPOConfig, GRPOTrainer

model = AutoModelForCausalLM.from_pretrained(
    "Qwen/Qwen2.5-0.5B-Instruct",
    torch_dtype=torch.bfloat16,
).to("cuda")

# GSM8K:每道题用 #### 分隔答案,特别适合做可验证 reward
dataset = load_dataset("openai/gsm8k", "main")["train"]

training_args = GRPOConfig(
    learning_rate=5e-6,
    lr_scheduler_type="cosine",
    warmup_ratio=0.1,
    per_device_train_batch_size=1,
    gradient_accumulation_steps=4,
    num_generations=16,           # K=16 — GRPO 的心脏
    max_prompt_length=256,
    max_completion_length=200,
    num_train_epochs=1,
    max_grad_norm=0.1,
    bf16=True,
    use_vllm=False,                # 单设备就够
)

模型被系统 prompt 强制按这个结构输出:

<reasoning>
…在这里逐步推理…
</reasoning>
<answer>
最终数值
</answer>

5 个奖励函数:可验证、可叠加

GRPO 完全靠 reward 形状塑造模型。这套复刻把 reward 拆成 5 个独立信号——分开是为了可观察、可单独 ablation

def correctness_reward_func(prompts, completions, answer, **kwargs):
    """提取 <answer> 后做精确匹配 GSM8K gold"""
    extracted = [extract_xml_answer(c) for c in completions]
    return [2.0 if r == a else 0.0 for r, a in zip(extracted, answer)]

def int_reward_func(completions, **kwargs):
    """答案是不是纯整数(GSM8K 都是)"""
    extracted = [extract_xml_answer(c) for c in completions]
    return [0.5 if r.isdigit() else 0.0 for r in extracted]

def strict_format_reward_func(completions, **kwargs):
    """严格匹配 <reasoning>\n…\n</reasoning>\n<answer>…</answer> 整体结构"""
    pattern = r"^<reasoning>\n.*?\n</reasoning>\n<answer>\n.*?\n</answer>\n?$"
    return [0.5 if re.match(pattern, c, re.DOTALL) else 0.0 for c in completions]

def soft_format_reward_func(completions, **kwargs):
    """宽松版:只要两个 block 都出现就给分"""
    pattern = r"<reasoning>.*?</reasoning>.*?<answer>.*?</answer>"
    return [0.5 if re.search(pattern, c, re.DOTALL) else 0.0 for c in completions]

def xmlcount_reward_func(completions, **kwargs):
    """按 tag 计数给分(0.125/个),尾随多余文本扣分"""
    return [count_xml_tags(c) for c in completions]

trainer = GRPOTrainer(
    model=model,
    args=training_args,
    train_dataset=dataset,
    reward_funcs=[
        correctness_reward_func,
        int_reward_func,
        strict_format_reward_func,
        soft_format_reward_func,
        xmlcount_reward_func,
    ],
)
trainer.train()
Reward信号权重
correctness_reward_func<answer> 精确匹配 GSM8K gold+2.0
int_reward_func答案是不是纯整数+0.5
strict_format_reward_func严格 regex 匹配整体格式+0.5
soft_format_reward_func宽松版:两个 block 都在+0.5
xmlcount_reward_func标签计数(0.125/个)- 多余文本惩罚up to +0.5

correctness 占主导,4 个格式 reward 是「楼梯」——严格 → 宽松 → tag 计数,逐级把模型从自由文本拉向干净的 <reasoning>/<answer> 结构。

训练前后真实对比

这是个 demo 级训练(单卡 3090、~17 GB VRAM、几小时跑完 1 epoch),重点是机制可见性而不是 leaderboard 数字:

例题"Joy can read 8 pages in 20 minutes. How many hours to read 120 pages?"
训练前 0.5B 直接给5 (不解释、不分块、有时连结构都没有)
训练 1 epoch 后<reasoning> 块:算每页 2.5 分钟、120 页 = 300 分钟 = 5 小时;<answer> 块:5

奖励单独就教会了 0.5B 模型「先 show your work, 再给答案」。这正是 DeepSeek-R1 在大模型上做的事。

诚实边界

  • 模型:0.5B(Qwen2.5-0.5B-Instruct)—— 比 R1-scale 模型小约 100 倍
  • 算力:单卡 ~17 GB VRAM,几小时跑完 1 epoch
  • 没用 vLLMuse_vllm=False,是个学习级复刻不是生产 trainer
  • 不追 SOTA 准确率:在 0.5B 上几小时 GRPO 不会跑出多牛的数字。展示的是机制——组采样 + 可验证奖励 → reasoning chain,DeepSeek-R1 把这套放到大尺度上跑

价值点

  • 理解 DeepSeek-R1 的 GRPO 配方到能端到端跑通的程度,不只是看论文
  • 设计可验证奖励 stack 并能讲清每条 reward 怎么塑形行为(correctness 主导、format 楼梯)
  • 诚实叙事:在 0.5B 上演示机制,明确说出 claim 了什么、没 claim 什么
Demo strategy

Demo 真实可跑

不是 replay:互动 Demo 是一个 live reward calculator。你编辑 model completion 和 gold answer,5 个奖励函数(correctness/int/strict_format/soft_format/xmlcount,从 notebook 逐字移植)在你浏览器里实时重算——和 GRPOTrainer 在做组内 advantage 之前打的分一模一样。双语 EN/中文。

Public preview can be enabled later without redesigning the case-study layout