Deepseek RI最近火爆全网,甚至影响到了全球经济走势和政治格局,不得不说AI依然在迅猛的发展,R1的内置COT理念——就是人类的慢思考。

人的大脑有两个系统:系统一基于潜意识,快速做出判断,称为“快思考”;系统二则相反,处理复杂问题,但需要专注力和精力,称为“慢思考”。两者相互影响,大多数偏见来自系统一,但可以通过训练改善。

R1开始让AI有了”系统二“的能力,是一种高阶的智慧,最先让AI有慢思考能力的是openai的o1模型,不过o1不仅昂贵,而且是闭源的,是少数人才能接触到的资源。Deepseek R1 通过大量的创新,通过极低的成本训练出来有推理能力(慢思考能力)的AI,同时将训练方法进行了开源,引爆了整个技术和非技术圈。

今天笔者采用colab免费14G显存的单卡在qwen2.5-7B上面简单复现一下Deepseek RI的推理能力。

效果展示

训练了100step后的效果,模型已经能够通过推理解答数学题目了,同时可以看到此时GRPO的reward得分也在提升到一个新的水平。

Question:哈利的图书馆里有 50 本书。他的妹妹弗洛拉(Flora)拥有的书籍数量是哈利的两倍,而他们的表弟加里(Gary)拥有的书籍数量是哈利的一半。他们三人总共拥有多少本书? (Answer: 175) Response:<reasoning>哈利有50本书。弗洛拉拥有的书籍数量是哈利的两倍,即100本。加里拥有的书籍数量是哈利的一半,即25本。他们三人共有的书籍数量是50 + 100 + 25 = 175本。</reasoning><answer>175</answer>Question:由于新冠病毒的限制,当地商店老板特伦特先生将店内商品的价格每件提高了15%,以维持与之前相同的收入。如果购买一个橙子的价格为 40 美元,购买一个芒果的价格为 50 美元,请计算现在以新价格购买 10 个橙子和 10 个芒果总共需要花费多少钱。 (Answer:1035 )Response:<reasoning>原价总和:10 * 40 + 10 * 50 = 400 + 500 = 900提价幅度:15%提价后总和:900 * (1 + 15%) = 900 * 1.15 = 1035</reasoning><answer>1035</answer>

grpo reward

ac857f97eac6402c57167ccbce6a5a74.jpeg

GRPO算法简介

DeepseekR1的推理能力,是从强化学习GRPO进行冷启动的,没有采用大量的标注数据,直接通过强化学习就让模型有了很强的推理能力。那GRPO到底是个什么呢。如下图所示, 在传统的近端策略优化算法(PPO)中,通常需要同时训练策略模型和价值模型,后者用于估计每个状态的期望回报,并以此作为优势函数的基线。对于大型语言模型来说,训练与策略模型规模相当的价值网络不仅增加了计算量,还会带来显著的内存开销。为了解决这一问题,GRPO 提出了利用“组内”生成数据的思路:

189d593364b32742cc987ef6e32dd858.jpeg

  • 多样本生成:对于每个输入(例如一个问题),模型根据旧策略生成多个候选输出。

  • 奖励评估: 对每个候选输出采用特定的奖励函数进行评估,奖励可以包括答案正确性、格式符合要求、推理过程合理等指标(例如 DeepSeek 系列中常用的准确性奖励和格式奖励)

  • 组内优势计算: 将这组输出的奖励视为一个样本集,直接计算其均值和标准差,并将每个输出的奖励进行标准化(即减去均值、除以标准差),从而获得组内相对优势。这种方式能够反映出同一问题下各个候选答案的“相对好坏”,而不需要单独训练一个价值模型。

其优势省去价值网络,占用资源少,同时训练稳定性较PPO要高。

训练过程

加载模型:

安装必要的包,同时加载量化后的模型,同时增加lora参数作为训练参数,这里使用的qwen2.5-7B。

%%capture# Skip restarting message in Colabimport sys; modules = list(sys.modules.keys())for x in modules: sys.modules.pop(x) if "PIL" in x or "google" in x else None!pip install unsloth vllm!pip install --upgrade pillow# If you are running this notebook on local, you need to install `diffusers` too# !pip install diffusers# Temporarily install a specific TRL nightly version!pip install git+https://github.com/huggingface/trl.git@e95f9fb74a3c3647b86f251b7e230ec51c64b72bfrom unsloth import FastLanguageModel, PatchFastRLPatchFastRL("GRPO", FastLanguageModel)from unsloth import is_bfloat16_supportedimport torchmax_seq_length = 1024 # Can increase for longer reasoning traceslora_rank = 32 # Larger rank = smarter, but slowermodel, tokenizer = FastLanguageModel.from_pretrained(    model_name = "Qwen/Qwen2.5-7B-Instruct",    max_seq_length = max_seq_length,    load_in_4bit = True, # False for LoRA 16bit    fast_inference = True, # Enable vLLM fast inference    max_lora_rank = lora_rank,    gpu_memory_utilization = 0.6, # Reduce if out of memory)model = FastLanguageModel.get_peft_model(    model,    r = lora_rank, # Choose any number > 0 ! Suggested 8, 16, 32, 64, 128    target_modules = [        "q_proj", "k_proj", "v_proj", "o_proj",        "gate_proj", "up_proj", "down_proj",    ], # Remove QKVO if out of memory    lora_alpha = lora_rank,    use_gradient_checkpointing = "unsloth", # Enable long context finetuning    random_state = 3407,)

加载数据:

数据长这样,就是一个数学问题,一个答案,答案是数字。没有任何推理数据,后面靠强化学习让模型学会推理。

9ab70f99d772243ab04b1df62cd91199.jpeg

import refrom datasets import load_dataset, Dataset# Load and prep datasetSYSTEM_PROMPT = """采用下方格式回复用户问题:<reasoning>...</reasoning><answer>...</answer>"""XML_COT_FORMAT = """\\<reasoning>{reasoning}</reasoning><answer>{answer}</answer>"""def extract_xml_answer(text: str) -> str:    answer = text.split("<answer>")[-1]    answer = answer.split("</answer>")[0]    return answer.strip()def extract_hash_answer(text: str) -> str | None:    if "####" not in text:        return None    return text.split("####")[1].strip()# uncomment middle messages for 1-shot promptingdef get_gsm8k_questions(split = "train") -> Dataset:    data = load_dataset('swulling/gsm8k_chinese')[split] # type: ignore    data = data.map(lambda x: { # type: ignore        'prompt': [            {'role': 'system', 'content': SYSTEM_PROMPT},            {'role': 'user', 'content': x['question_zh-cn']}        ],        'answer': x['answer_only']    }) # type: ignore    return data # type: ignoredataset = get_gsm8k_questions()

定义奖励函数:

主要是包含正确性奖励和格式奖励这种基于rule的奖励。奖励函数的定义对于最终的结果非常重要。

# Reward functionsdef correctness_reward_func(prompts, completions, answer, **kwargs) -> list[float]:    responses = [completion[0]['content'] for completion in completions]    q = prompts[0][-1]['content']    extracted_responses = [extract_xml_answer(r) for r in responses]    print('-'*20, f"Question:\\n{q}", f"\\nAnswer:\\n{answer[0]}", f"\\nResponse:\\n{responses[0]}", f"\\nExtracted:\\n{extracted_responses[0]}")    return [2.0 if r == a else 0.0 for r, a in zip(extracted_responses, answer)]def int_reward_func(completions, **kwargs) -> list[float]:    responses = [completion[0]['content'] for completion in completions]    extracted_responses = [extract_xml_answer(r) for r in responses]    return [0.5 if r.isdigit() else 0.0 for r in extracted_responses]def strict_format_reward_func(completions, **kwargs) -> list[float]:    """Reward function that checks if the completion has a specific format."""    pattern = r"^<reasoning>\\n.*?\\n</reasoning>\\n<answer>\\n.*?\\n</answer>\\n$"    responses = [completion[0]["content"] for completion in completions]    matches = [re.match(pattern, r) for r in responses]    return [0.5 if match else 0.0 for match in matches]def soft_format_reward_func(completions, **kwargs) -> list[float]:    """Reward function that checks if the completion has a specific format."""    pattern = r"<reasoning>.*?</reasoning>\\s*<answer>.*?</answer>"    responses = [completion[0]["content"] for completion in completions]    matches = [re.match(pattern, r) for r in responses]    return [0.5 if match else 0.0 for match in matches]def count_xml(text) -> float:    count = 0.0    if text.count("<reasoning>\\n") == 1:        count += 0.125    if text.count("\\n</reasoning>\\n") == 1:        count += 0.125    if text.count("\\n<answer>\\n") == 1:        count += 0.125        count -= len(text.split("\\n</answer>\\n")[-1])*0.001    if text.count("\\n</answer>") == 1:        count += 0.125        count -= (len(text.split("\\n</answer>")[-1]) - 1)*0.001    return countdef xmlcount_reward_func(completions, **kwargs) -> list[float]:    contents = [completion[0]["content"] for completion in completions]    return [count_xml(c) for c in contents]

1. correctness_reward_func(正确性奖励函数)

检查模型的输出是否与参考答案 (answer) 完全匹配,匹配则奖励 2.0,否则 0.0

2. int_reward_func(整数检测奖励函数)

检查模型输出是否是纯数字(整数),是则奖励 0.5,否则 0.0

3. strict_format_reward_func(严格格式奖励函数)

检查模型输出是否完全符合 **严格的 XML 格式,**符合格式的奖励 0.5,否则 0.0

<reasoning>逻辑推理内容</reasoning><answer>答案内容</answer>

4. soft_format_reward_func(宽松格式奖励函数)

允许更灵活的格式,只要包含 <reasoning>...</reasoning> 和 <answer>...</answer>,即奖励 0.5

5. count_xml,xmlcount_reward_func(XML 结构评分函数)

计算模型输出 XML 结构的完整度,并给予相应奖励。奖励规则:

检查 XML 结构完整度:

每个正确的标签匹配增加 0.125 奖励:

  • <reasoning>\\n:+0.125

  • </reasoning>\\n:+0.125

  • <answer>\\n:+0.125

  • </answer>:+0.125

考虑额外文本的惩罚

  • 如果 </answer> 后面有多余的内容,则减少奖励 0.001 × 额外字符数

模型训练

采用trl实现的grpo就可以开始训练了,可以看到奖励慢慢在增加。

from trl import GRPOConfig, GRPOTrainertraining_args = GRPOConfig(    use_vllm = True, # use vLLM for fast inference!    learning_rate = 5e-6,    adam_beta1 = 0.9,    adam_beta2 = 0.99,    weight_decay = 0.1,    warmup_ratio = 0.1,    lr_scheduler_type = "cosine",    optim = "paged_adamw_8bit",    logging_steps = 1,    bf16 = is_bfloat16_supported(),    fp16 = not is_bfloat16_supported(),    per_device_train_batch_size = 1,    gradient_accumulation_steps = 1, # Increase to 4 for smoother training    num_generations = 6, # Decrease if out of memory    max_prompt_length = 256,    max_completion_length = 200,    # num_train_epochs = 1, # Set to 1 for a full training run    max_steps = 250,    save_steps = 250,    max_grad_norm = 0.1,    report_to = "none", # Can use Weights & Biases    output_dir = "outputs",)trainer = GRPOTrainer(    model = model,    processing_class = tokenizer,    reward_funcs = [        xmlcount_reward_func,        soft_format_reward_func,        strict_format_reward_func,        int_reward_func,        correctness_reward_func,    ],    args = training_args,    train_dataset = dataset,)trainer.train()

0b40a91bb9e379a5507cb719fcfd1859.jpeg

一套下来,模型就学会了思考,不得不说deepseek真牛皮,没有任何cot数据也能训练出推理模型。

参考

代码:https://colab.research.google.com/github/unslothai/notebooks/blob/main/nb/Llama3.1_(8B)-GRPO.ipynb4

数据:https://huggingface.co/datasets/swulling/gsm8k_chinese

模型:https://huggingface.co/Qwen/Qwen2.5-7B-Instruct

Logo

DAMO开发者矩阵,由阿里巴巴达摩院和中国互联网协会联合发起,致力于探讨最前沿的技术趋势与应用成果,搭建高质量的交流与分享平台,推动技术创新与产业应用链接,围绕“人工智能与新型计算”构建开放共享的开发者生态。

更多推荐