核心内容摘要
CarMaker新手必看:3个核心模块详解与快速上手技巧(附避坑指南)
论文下载地址https://arxiv.org/pdf/
2
03816由清华大学和智谱 AI 等机构的研究人员于 2024 年发表在NeurIPS 上的文章。
ReST-MCTS: LLM Self-Training via Process Reward Guided Tree Search(通过过程奖励引导的树搜索进行LLM自训练)这篇论文之所以重要是因为它试图解决大模型目前最大的痛点之一如何让模型具备类似人类“慢思考”System 2的复杂推理能力并能够通过自我对弈来持续进化。
本文将深入拆解经典论文ReST-MCTS探讨如何将蒙特卡洛树搜索 (MCTS)与过程奖励模型 (PRM)引入 LLM 训练闭环。
我们将从算法原理、架构设计、数学推导到实战效果全方位解析模型如何通过“搜索-筛选-自训练”的迭代在 MATH 等高难度数据集上实现能力的螺旋式上升。
如果你对下一代具备“慢思考”能力的推理模型感兴趣这篇文章不容错过。
第一讲背景与动机从 Chain-of-Thought 到 System 2为什么我们需要树搜索今天我们开始第一讲。
在钻进复杂的算法之前我们必须先搞清楚作者为什么要写这篇论文他们试图解决什么核心矛盾
LLM 的“线性直觉”陷阱之前的主流模型如 GPT-4 或 Claude在处理简单任务时表现出色。
它们主要依赖于Chain-of-Thought (CoT)也就是思维链。
CoT 的本质它是线性的。
模型预测下一个词再下一个词就像一个人在凭直觉说话不打草稿一条路走到黑。
致命弱点错误累积 (Error Propagation)。
在解决复杂的数学题或逻辑题时如果第一步推理错了后面所有的推导都会基于这个错误导致“一步错步步错”。
模型很难回过头来说“哎呀我刚才那步好像不对我退回去重算。
”
System 1 vs. System 2在认知科学中丹尼尔·卡尼曼《思考快与慢》System 1 (快思考)直觉、本能、快速反应。
目前的 LLM 只有 CoT 时主要是在用 System 1。
System 2 (慢思考)逻辑、规划、审视、回溯。
人类做难题时会尝试不同的路径如果发现走不通会退回来换个思路。
ReST-MCTS 的核心想法就是给 LLM 装上 System 2 的大脑。
为什么要引入树搜索 (Tree Search)?既然线性生成容易出错我们自然想到了树 (Tree)。
如果我们将推理的每一步看作树的一个节点。
那么解决问题的过程就不是一条线而是一棵分叉的树。
我们需要一种算法在这棵树上找到通往正确答案的“最优路径”。
这就是MCTS (蒙特卡洛树搜索)登场的原因。
AlphaGo 当年打败李世石靠的就是 MCTS。
ReST-MCTS 试图将 AlphaGo 的这种“走一步、看三步”的能力移植到大语言模型的推理中。
在我上一篇博客详细的讲解了MCTS不懂的朋友可以去看看我个人认为那篇博客算是中文互联网对MCTS讲解最详细深入的。
不过如果不想深入理解MCTS接下来我也会尽量讲清楚MCTS。
为什么要“自训练” (Self-Training)?有了搜索还不够。
AlphaGo 之所以强是因为它通过左右互搏Self-Play产生了大量高质量的棋谱然后用这些棋谱训练自己。
对于 LLM 来说高质量的推理数据尤其是带有详细步骤评分的数据极其昂贵且稀缺。
ReST-MCTS 的洞见既然 MCTS 可以通过搜索找到比模型原始输出更好的答案那我们为什么不把这些“搜索出来的更好答案”当作教材反过来训练模型呢这就是ReST (Reinforced Self-Training)的含义模型利用 MCTS 搜索生成数据然后用这些数据强化自己形成一个正向循环。
这个框架听起来很完美但实施起来有两个巨大的拦路虎LLM 的搜索空间太大了词汇表那么大树怎么建。
如何判断推理的“中间步骤”是对是错最终答案对不代表中间全对。
第二讲核心概念预备MCTS 与 PRM 的极简入门这一讲我们将脱离复杂的公式从直觉层面搞懂这两个组件是如何工作的以及为什么它们是绝配。
MCTS在迷雾中寻找最优路径大语言模型生成答案的过程本质上是一个在无限可能中做选择的过程。
传统的生成方式Greedy Search 或 Sampling就像是一个走夜路的醉汉只看脚下当前概率最高的一个词走一步算一步直到撞墙生成错误或到家生成正确。
MCTS (Monte Carlo Tree Search)则赋予了这个醉汉“模拟未来”的能力。
它不急着真正迈出这一步而是在脑海里先模拟各种走的后果。
标准的 MCTS 包含经典的4个步骤请务必记住这4个词后面会反复用到选择 (Selection)从根节点当前问题出发根据某种策略通常是 UCB 公式平衡“利用已知好路径”和“探索未知路径”选择一个最有希望的子节点一直走到叶子节点。
直觉挑一条目前看起来最靠谱的路走下去。
扩展 (Expansion)如果你走到了一个从未探索过的路口叶子节点那么就尝试迈出新的一步生成一个新的节点推理步骤。
直觉在这个路口我看看能往哪儿走多想一种可能性。
模拟 (Simulation / Rollout)从这个新节点开始快速地走到终点得出最终答案。
在围棋里是快速下完这盘棋在数学题里就是让模型把剩下的步骤做完。
直觉如果我按这个思路走最后能做对吗让我快速推演一下。
反向传播 (Backpropagation)根据模拟的结果赢了/输了做对了/做错了把这个分数值“回传”给路径上的所有父节点。
直觉如果这条路通向成功那就告诉前面的每一个路口“这条路不错下次多往这儿走”在 ReST-MCTS 中的特殊性传统的 MCTS如 AlphaGo主要处理离散的、有限的动作空间围棋盘面是固定的。
但 LLM 的词汇表是巨大的推理步骤是无限的。
因此ReST-MCTS 对上述步骤做了针对性的改造我们会在第四讲详细展开。
PRM从“期末考试”到“平时测验”有了 MCTS 这个导航仪我们还需要一个能够评估“这步走得好不好”的打分器。
在强化学习RLHF的早期大家主要用ORM (Outcome Reward Model结果奖励模型)。
ORM 的逻辑模型做完一整道复杂的数学题最后答案对就是 1错就是 0。
ORM 的问题稀疏奖励这就好比期末考试。
你写了20行的证明过程第3行犯了个小错导致最后结果错了。
老师ORM直接给你打0分却不告诉你是哪一步错了。
模型很委屈不知道该改哪里。
PRM (Process Reward Model过程奖励模型)应运而生。
PRM 的逻辑它像是一个耐心的辅导老师盯着你写的每一步。
Step 1: 设 x 为苹果数量... -Reward:
9(起步不错)Step 2: 根据题意列方程 2x 5
.. -Reward:
1(甚至负分因为方程列错了)PRM 的优势它提供了密集奖励 (Dense Reward)。
模型在第2步就能收到反馈MCTS 就会立刻知道“这条分支废了不用再往后模拟了浪费算力赶紧换条路。
”
强强联手Value Function (价值函数)在 ReST-MCTS 论文中你会频繁看到 V(s) 这个符号。
它就是MCTS 和 PRM 的结合点。
Policy (策略)负责生成下一步动作生成新的推理步骤。
这是 LLM 的本职工作。
Value (价值)负责评估当前状态当前的推理步骤到底有多好未来有多大几率能做对。
在 ReST-MCTS 中PRM 实际上充当了 Value Function 的角色或者辅助训练 Value Function。
它指导 MCTS 的搜索树向着高分值的方向生长修剪掉低分值的枯枝。
第三讲ReST-MCTS 总体架构鸟瞰全局搜索、生成与训练的闭环ReST-MCTS 是一个迭代式的、自我进化的闭环。
它像滚雪球一样随着迭代次数增加模型越来越强。
我们可以把这个架构拆解为三个核心步骤循环往复核心循环Generate - Refine - Train第一步搜索与生成 (Generate / Search)动作给定一系列数学难题模型并不直接写答案而是利用MCTS辅助自己思考。
过程在这个阶段模型会尝试生成很多条不同的推理路径Reasoning Traces。
有些路径走得通得出正确答案有些走不通逻辑错误或答案错误。
关键点由于 MCTS 具有“前瞻性”和 PRM 的“指引”这样搜索出来的路径通常比模型直接由 greedy decoding贪婪解码生成的路径质量要高得多。
这就好比虽然我水平一般但我可以多打几遍草稿最后拼凑出一个比我平时水平高得多的解题过程。
第二步筛选与精炼 (Refine / Filter)动作从第一步生成的大量“草稿”中挑选出那些真正正确且高质量的路径。
过程我们需要把那些最终答案错误的路径扔掉。
即使答案对了如果中间步骤很啰嗦或者有跳跃PRM 分数不高的也可能被降权。
最后留下来的是当前模型能力上限能触达到的最优推理过程。
第三步自训练 (Train / Learn)动作用筛选出来的黄金数据来微调Fine-tune模型本身。
双重更新注意这里其实有两个模型在进化或者说是模型的一体两面Policy Model (策略模型)学习那些好的推理步骤。
目的是下次不用费力搜索凭借直觉也能大概率走出这一步。
让 System 1 吸收 System 2 的经验ValueModel (价值模型)学习更准确地给步骤打分。
目的是下次搜索时指南针能指得更准。
架构的精髓Policy 与 Value 的“双螺旋”进化这篇论文最漂亮的地方在于它不仅仅训练了生成能力Policy还同步训练了判别能力Value。
Round 0:Policy 很弱Value 很弱。
MCTS 搜索得很费劲只能找到少量正确路径。
Round 1:用 Round 0 找到的少量正确路径训练 Policy - Policy 变强了。
同时训练 Value - 判别能力变强了。
Round 2:现在 Policy 变强了作为 MCTS 的“直觉”基础它能生成更好的候选步骤。
Value 变强了能更早发现错误路径。
结果MCTS 能搜索到更难、更复杂的题目的正确解法... Loop Continues ...这就解释了为什么论文题目叫ReST (Reinforced Self-Training)Reinforced: 通过 MCTS 和 Value 引导强化学习的思想。
Self-Training: 自己生成数据教自己半监督学习的思想。
在这个循环中如果第一步搜索生成的全是垃圾一道题都没做对整个系统会发生什么 答案系统会卡死。
这就引出了一个关键问题冷启动。
如果模型太弱连 MCTS 都救不了怎么办这也是为什么论文通常需要一个基础不错的 Base Model 和一些初始的种子数据。
第四讲搜索策略详解在这一讲中我们将重点搞清楚MCTS 是如何被魔改以适应大语言模型的毕竟下围棋只有361个落子点和写数学证明词表几万组合无穷完全是两码事。
MCTS 在推理空间中的“探路”机制在 ReST-MCTS 中搜索不再是简单的“走迷宫”而是一场精心设计的概率博弈。
核心定义节点与边首先我们必须重新定义树的构成元素这是理解算法的基础状态 (State / Node,)树的一个节点代表了当前生成的中间推理过程。
根节点是原始的问题Prompt。
比如“证明勾股定理”。
中间节点是“问题 前步推理”。
原始问题 你到现在为止写下的所有推理步骤。
比如s₃ 可能就是“问题 步骤
步骤
步骤3”。
动作 (Action / Edge,)这与传统 MCTS 最大的不同点在围棋里动作是“落一颗子”。
在 ReST-MCTS 中动作是一个完整的推理步骤 (Reasoning Step)。
一个推理步骤可能是“设直角三角形的两条直角边分别为 a 和 b斜边为 c。
” 或者 “根据已知条件X我们可以推导出Y。
”为什么如果按 Token字来搜索树会深不见底且极其庞大。
按“步骤”通常以换行符或句号分隔来搜索可以将树的深度控制在几十层以内大大降低计算量。
选择从大本营根节点出发根据现有地图信息选择一条看起来最有希望的路径分支往下走。
选择策略会平衡“探索新路”和“利用已知的好路”。
扩展当你走到一个地图边缘一个未被完全探索的节点时就让语言模型当“预言家”生成几个k个可能的“下一段小径”即几个候选的推理步骤。
把这几个新路口画到地图上创建新的子节点。
仿真Rollout为了快速评估这个新路口的前景探险队长会从这个路口开始不再仔细搜索而使用一个价值模型打分。
回溯将这次打分得到的结果成功/失败以及路径的质量评估沿着刚才走过的路一路传回大本营更新沿途每个节点的统计数据比如这个节点被访问过多少次从这里出发的探索平均成功率有多高。
这个过程反复进行。
“概率博弈” 就体现在选择时用类似UCB的公式计算是去探索访问次数少的新路口探索还是去走平均成功率高的老路利用。
扩展时语言模型生成的是概率最高的 k 个推理步骤本身就带有概率性。
仿真时快速生成的内容也具有随机性。
MCTS 的四步循环ReST 版深度解析ReST-MCTS 对经典的四步循环进行了适配第一步选择 (Selection)从根节点出发我们需要决定往哪走。
这时候就需要用到PUCT (Predictor UCT) 算法。
模型会计算每个子节点的得分得分最高的被选中。
公式如下请注意 LaTeX 细节这个公式非常精妙它包含了两个对抗的力量(Exploitation, 利用)代表动作的平均价值。
即“根据过往经验这条路通往正确答案的概率有多大”。
这部分数值主要由 Value Model 提供。
(Exploration, 探索)是 Policy Model (LLM) 认为这一步合理的先验概率。
是这一步被访问的次数。
分母越大这一项越小。
含义如果一条路 LLM 觉得很有戏高且大家都没怎么走过小那么这项得分就会很高鼓励程序去“探索”这个潜力股。
第二步扩展 (Expansion)当选择到达叶子节点还没走完的推理我们需要生成新的步骤。
模型调用 Policy Model (LLM)针对当前上下文采样生成 k 个不同的候选步骤Actions。
这 k 个步骤就成为了树上的新分支。
第三步评估 (Evaluation) —— 关键差异点在 AlphaGo 中评估通常靠“快速把棋下完”Rollout。
但在 LLM 中生成文本太慢了且方差极大。
ReST-MCTS 引入了Value Model (价值模型)来加速这一步不一定要跑到底我们直接把当前状态 s 喂给 Value Model。
打分Value Model 也是一个神经网络它会吐出一个标量值或预估从当前这一步继续走下去最终答对的概率。
注论文中其实保留了 Value Model 和 Rollout 结合的可能性但在核心训练循环中训练好的 Value Model 是效率的关键。
第四步反向传播 (Backpropagation)拿到评估值 V 后我们需要把这个好消息或坏消息告诉祖先节点。
沿着刚才走下来的路径更新沿途所有节点的 N(s, a)访问次数1和 Q(s, a)重新计算平均价值。
这样如果底下的叶子节点被发现是个“金矿”它的父节点、祖父节点的 Q 值都会升高下次 Selection 阶段被选中的概率就更大了。
这里的“魔法”在哪里你可能会问这不就是标准的 MCTS 吗ReST-MCTS 的特殊之处在于它处理了 LLM 的两个棘手特性无限的动作空间围棋只有 361 个点LLM 可以生成无数种句子。
策略ReST-MCTS 不会穷举所有可能的句子而是只从 LLM 当前采样的 top-k 个结果中构建树。
这意味着树是稀疏的但它仅仅覆盖了高概率区域。
思维链的连贯性ReST-MCTS 强制树按“步骤”生长保证了逻辑的颗粒度是人类可理解的。
这不仅有助于搜索也为后续的过程奖励 (PRM)训练提供了完美的数据颗粒度。
在这一讲我们拆解了 ReST-MCTS 的核心引擎Step-level Search按推理步骤而非 Token 建树降低复杂度。
PUCT 算法完美平衡了“信赖高分路径”和“尝试潜力路径”。
Value Model 替代 Rollout用神经网络的预测值来代替昂贵的蒙特卡洛模拟大幅提速。
思考题公式中的是一个超参数。
如果设得特别大搜索行为会变成什么样提示会变得极其发散像无头苍蝇一样到处乱试。
如果设得特别小又会怎样提示会变得极其保守只敢走最熟悉的老路失去发现新解法的能力。
搞懂了搜索策略你可能会好奇那个负责打分的神奇 Value Model (PRM) 到底是怎么训练出来的它怎么知道这一步是
8 分还是
2 分第五讲过程奖励 (Process Reward) 的奥秘如何评价推理的每一步是好是坏
痛点信用分配问题 (The Credit Assignment Problem)先回顾一下我们在第二讲提到的ORM (结果奖励)。
假设模型做一道复杂的证明题一共走了10步最后答案错了。
ORM 的反馈0分。
模型的困惑“我走了10步到底是第1步设错了还是第5步公式背错了还是第10步计算错了”这就是著名的信用分配问题——只要结果不对中间所有的努力都被否定模型很难学到具体的逻辑。
PRM (过程奖励)的目标就是解决这个问题给这10步中的每一步都打分。
ReST-MCTS 中的 PRM其实就是价值函数 V(s)在 ReST-MCTS 这篇论文中作者并没有训练一个独立的外部 PRM 模型像某些方法那样雇佣人类去标注每一步而是巧妙地利用了价值模型 (Value Model)来充当 PRM 的角色。
形式。
这是一个神经网络通常是基于 LLM 改装的输入是当前的问题和推理步骤输出是一个标量。
物理含义代表了从状态出发最终能推导出正确答案的估计概率。
核心魔法谁来给 Value Model 提供“标准答案”这是本讲最烧脑也最精彩的部分。
既然没有人类老师给每一步打分Value Model 怎么知道它的预测准不准呢答案是依靠 MCTS 的搜索结果进行“自我标注”。
想象一下MCTS 在搜索树上跑了很多次模拟节点 A被访问了 100 次其中 80 次最终都找到了正确答案。
那么节点 A 的真实价值Ground Truth Value大约是
8。
节点 B被访问了 50 次只有 1 次最终找到了正确答案。
那么节点 B 的真实价值大约是
02。
ReST-MCTS 的训练逻辑如下搜索 (Search)先用当前的 Policy 和 Value 模型跑一轮 MCTS。
统计 (Aggregate)观察搜索树上的每一个节点。
统计经过这个节点的路径中有多少条最终是正确的。
我们计算出一个基于搜索统计的价值记为。
常用公式蒸馏 (Distill)训练 Value Model去逼近这个统计出来的。
Loss Function:均方误差这就是“过程奖励”的奥秘模型不需要人类告诉它哪一步好。
它通过大量的模拟MCTS发现“哎每次只要我走到这一步后面大概率都能做对。
” 于是这一步就被标记为“高分步骤”。
为什么这样做比单纯的 CoT 强CoT是“盲推”。
ReST-MCTS 的 PRM是“后视镜智慧”。
通过 MCTS 搜索完之后我们拥有了上帝视角Gods Eye View。
我们知道哪些路径是死胡同哪些是康庄大道。
我们将这个上帝视角的经验压缩进 Value Model 里。
下次再遇到类似的情况模型不需要再跑几百次模拟只要过一下 Value Model即可。
思考题 这种方法有一个潜在的风险叫做“奖励欺骗” (Reward Hacking)或者说“过拟合”。
如果 MCTS 在搜索时运气好蒙对了一次导致一个错误的步骤被标记为 100% 正确率Value Model 学进去了怎么办 提示这就是为什么我们需要足够多的模拟次数 $N$以及为什么要迭代多次。
单一的样本可能会骗人但大数定律不会。
第六讲自训练 (Self-Training) 机制如何利用搜索生成的轨迹来反哺模型如果说 MCTS 是在“临时抱佛脚”查资料那么自训练就是“考后复盘”把这些知识真正背下来下次考试不用查资料也能直接写出来。
数据的“收割” (Trace Harvesting)在 MCTS 搜索结束后我们会得到一棵茂盛的树。
这棵树里包含了各种各样的路径有的通向正确答案成功路径。
有的通向错误答案失败路径。
有的半途而废未探索完。
第一步我们要从树中收割数据。
我们只关心那些通向正确答案的路径。
我们将这些路径从根节点到叶子节点提取出来形成一条完整的推理轨迹 (Reasoning Trace)。
假设问题是正确的推理步骤序列是。
我们把这些对收集起来放入一个用来训练的“经验池” (Experience Buffer)。
ReST 的核心筛选与加权 (Refining Filtering)ReST (Reinforced Self-Training) 这个名字里“Reinforced”体现在哪里体现在我们不是要把所有生成的数据都塞给模型而是要进行优胜劣汰。
虽然 MCTS 找到的正确路径通常都不错但质量也有高低之分。
路径 A步骤繁琐绕了弯路但最后碰巧做对了。
路径 B逻辑清晰步骤简洁直达核心。
在 ReST-MCTS 的训练过程中我们可以利用Value Model的打分或者路径的搜索统计信息如访问次数来对这些路径进行筛选。
策略通常只保留那些 Value 值高或者被 MCTS 频繁访问说明非常稳健的正确路径。
目的确保模型学到的是“最优解”而不是“凑合解”。
策略模型更新 (Policy Update)数据准备好后就到了真正的“学习”阶段。
我们要更新策略模型的参数。
这其实就是标准的SFT (Supervised Fine-Tuning有监督微调)过程但在强化学习语境下这通常被称为行为克隆 (Behavior Cloning)。
输入问题。
目标输出经过筛选的高质量推理步骤。
目标函数最大化生成这些步骤的概率。
用通俗的话说模型在对自己说“虽然这些步骤是我刚才费劲搜索出来的但现在我要强行记住它们。
下次看到类似的问题我要不假思索地直接生成。
”
闭环进化的魔力 (The Loop Effect)还记得第三讲提到的“闭环”吗这里是它发挥威力的地方Iteration 1:模型很菜 - MCTS 搜索 - 找到一些简单题的正确路径 - 训练模型。
结果模型学会了做简单题。
Iteration 2:模型变强了一点 - MCTS 基于更强的模型搜索 - 能够探索到更深的地方 - 找到了中等题的正确路径 - 训练模型。
结果模型学会了做中等题。
Iteration 3:模型更强了 - MCTS 攻克难题- 训练模型。
结果模型学会了做难题。
这就是Curriculum Learning (课程学习)的一种自动形式。
模型随着训练的进行自己生成的训练数据越来越难水平也螺旋式上升。
思考题 如果我们在自训练的时候不小心混入了一些“答案正确但推理过程错误”的数据False Positive会对模型造成什么影响 提示这是 LLM 推理训练中最大的噪音来源。
模型会学会“一本正经地胡说八道”逻辑不通却能硬凑出答案。
这也反过来证明了第五讲中 PRM (Value Model) 的重要性——它能帮我们识别出那些逻辑混乱的坏步骤。
第七讲损失函数与数学推导深入腹地Policy Loss 与 Value Loss 的设计细节在 ReST-MCTS 中模型训练的目标是双重的让生成的步骤更像搜索出来的最优解(Policy Optimization)。
让打分的直觉更接近搜索出来的真实胜率(Value Optimization)。
通常为了节省显存和提高效率Policy Model和Value Model会共享同一个 LLM 的主体Backbone只在最后加两个不同的“头” (Head)Policy Head: 输出下一个词的概率 (Vocabulary size)。
Value Head: 输出当前状态的标量评分 (Scalar)。
策略损失 (Policy Loss):这是让 System 1 (直觉) 模仿 System 2 (搜索) 的关键。
直觉如果你在 MCTS 搜索中发现了一条通往正确答案的路径那么这条路径上的每一步都是“好动作”。
我们要强迫模型的概率分布去拟合这些动作。
数学公式这本质上就是标准的语言建模损失Cross-Entropy Loss但只针对筛选出的高质量数据计算逐项拆解这是我们通过 MCTS 搜索并筛选出的“黄金数据集”。
x 是问题是搜索出来的正确且高质量的推理轨迹。
在给定上文的情况下模型预测出下一步刚好是的概率。
负对数。
概率越大Loss 越小概率越小Loss 越大惩罚。
一句话解释“对于 MCTS 辛辛苦苦找出来的正确步骤模型预测它的概率越大越好。
”
价值损失 (Value Loss): $\mathcal{L}_{V}$这是让 PRM (评分器) 变得更准的关键。
直觉我们在第五讲说过通过 MCTS 搜索我们可以统计出某个状态的“真实胜率”。
我们要训练 Value Model让它的预测值尽可能接近这个统计值。
数学公式这是一个典型的回归问题 (Regression)通常使用均方误差 (MSE)逐项拆解从搜索树中采集的所有状态推理步骤的集合。
模型当前对于状态的预测打分比如模型觉得这一步有
6 的概率赢。
MCTS 搜索后计算出的真实目标值比如搜索显示经过这个节点赢了 80 次输了 20 次那真实值就是
8。
误差的平方。
预测越准Loss 越接近 0。
一句话解释“模型预测的分数要尽可能接近 MCTS 实际跑出来的胜率。
”
总损失函数 (Total Loss)由于我们是同时训练这两个任务Multi-task Learning最终的损失函数是两者的加权和(Lambda)这是一个超参数用来平衡两个任务的权重。
如果太大模型会过于关注打分准不准而忽视了生成通顺的语句。
如果太小模型生成能力强了但自我评估能力Value跟不上下一轮搜索的效果就会打折。
思考题 我们在计算 Policy Loss 时只用了“正确路径”。
那么那些**“错误路径”**Negative Samples就没有用了吗在标准的 ReST-MCTS 中主要利用正样本进行 Self-Training。
但是在 Value Loss 的计算中错误路径产生的数据点胜率为0是非常重要的它们教会了模型什么是不该做的。
如果只拿正样本训练 Value Model它就会学会给所有东西都打满分那就彻底失效了。
第八讲实验设置与数据集在 GSM8K 和 MATH 上是如何评测的要验证一个“逻辑推理”模型强不强不能让它写诗或者闲聊必须做数学题。
ReST-MCTS 选择了两个最硬核的战场。
战场一GSM8K (小学数学应用题)难度入门级The Qualifying Exam。
内容大约
5k 道小学水平的应用题。
例子“小明有5个苹果吃了2个妈妈又给了3个现在有几个”特点逻辑链条相对较短
步不需要太高深的数学知识但非常考验模型“听懂人话”并将文字转化为算式的能力。
ReST-MCTS 的目标在这里主要验证“基础逻辑是否扎实”。
战场二MATH (竞赛级数学难题)难度地狱级The Final Boss。
内容包含代数、几何、微积分、概率论等 7 个领域的
1
5k 道难题来源于 AMC 10/12美国数学竞赛和 AIME 等。
特点逻辑链条极长可能需要
步推理。
需要领域专业知识。
陷阱多一步算错全盘皆输。
这是 MCTS 最能发挥威力的地方因为普通 CoT 很难蒙对这种题。
参赛选手基座模型 (Base Model)论文作者并没有使用 GPT-4 这种闭源巨无霸因为没法改它的参数进行自训练而是选择了当时开源社区的明星Llama-2 (7B 和 13B)Mistral-7B为什么要用小模型这是一个很聪明的选择。
如果用 7B 的小模型通过 ReST-MCTS 方法能打败参数量大它几倍比如 34B 或 70B的模型那才更能证明“算法的优越性”而不仅仅是靠堆算力。
对照组我们要打败谁 (Baselines)为了证明 ReST-MCTS 有多强必须找几个“陪练”Standard CoT (标准思维链)最基础的用法给几个 prompt 例子让模型直接生成。
代表了 System 1 的原始水平。
Self-Consistency (CoT-SC / Majority Voting)让模型生成 100 个答案然后投票选出最多的那个。
这是推理领域的“强力基线”。
任何新方法如果打不过简单的“人多力量大”的投票法那就没意义。
RFT (Rejection Sampling Fine-Tuning)这是 ReST-MCTS 最大的竞争对手。
原理让模型疯狂生成只保留对的然后微调。
区别RFT 只有“生成过滤”没有 MCTS 的树搜索也没有 Value Model 的过程指导。
胜负手如果 ReST-MCTS 能打败 RFT就证明了“过程奖励”和“树搜索”是必不可少的光靠“题海战术”是不够的。
实验流程迭代 (Iterations)实验不是跑一次就完了而是分了几轮RoundRound 0: 原始模型。
Round 1: 用 Round 0 模型跑 MCTS生成数据训练出 Model-1。
Round 2: 用 Model-1 跑 MCTS此时搜得更准了生成数据训练出 Model-2。
... 作者通常展示了2到3轮的迭代结果。
我们要观察的是随着轮次增加准确率是不是在稳步上升思考题 MATH 数据集特别难很多原始模型 (Round
的准确率可能只有 5% 甚至更低。
在这种情况下MCTS 搜索可能搜半天也搜不到一个正确答案全军覆没。
这时候 ReST-MCTS 还能训练吗提示这就是稀疏奖励的极端情况。
通常需要“热启动”——先用一些高质量的外部数据把模型微调到一个基本水平比如
%然后再开启自训练循环。
第九讲结果分析与消融实验ReST-MCTS 到底强在哪里哪些组件最关键
主战场战报碾压基线 (Main Results)让我们直接看最硬核的MATH数据集竞赛级难题上的表现。
这是区分“好学生”和“天才”的分水岭。
对手一SFT (普通微调)这是底线。
Llama-2 经过普通微调后在 MATH 上可能只能拿个及格分。
对手二RFT (拒绝采样微调)这是最强劲敌。
RFT 不用复杂的树搜索就是让模型生成几百个答案挑对的训练。
结果ReST-MCTS 显著击败了 RFT。
数据解读在同等规模的模型下ReST-MCTS 的准确率通常能比 RFT 高出几个百分点在 MATH 这种高难度数据集上1% 的提升都很难得。
结论这就证明了“有向搜索 (Guided Search)” “盲目抽样 (Random Sampling)”。
MCTS 就像一个经验丰富的探险家能挖掘出 RFT 根本碰运气碰不到的那些复杂解法。
进化的轨迹迭代的魔力 (The Power of Iteration)ReST-MCTS 是分多轮Round训练的。
我们来看看随着轮次增加模型发生了什么变化。
Round 0 (Base): 模型只能做简单的题。
Round 1: 模型学会了 MCTS 在第一轮搜出来的路径。
准确率大幅跳跃。
Round 2:惊人发现准确率继续上升原因因为 Round 1 的模型变强了它的 Value Model 也变准了。
于是 MCTS 在 Round 2 能搜索到更深、更难的题目答案。
模型“看到”了之前从未见过的解法。
Round 3: 提升幅度开始变缓边际效应递减但依然在涨。
这验证了那个核心假设模型可以通过“左脚踩右脚”的方式螺旋式上升突破原本的能力天花板。
除了做数学题这个模型还能干别的吗 实验显示通过 MATH 训练出来的 ReST-MCTS 模型在处理GSM8K简单题或者其他类似的逻辑推理任务时也表现出了极强的泛化性。
这意味着模型学到的不是“背诵这道题的答案”而是学到了通用的“System 2 思维模式”——即如何拆解问题、如何规划步骤、如何自我纠错。
作为这门课的最后一讲我们不仅要回头看看走过的路更要抬头看看天。
这篇论文虽然精彩但它不是终点而是通向下一代推理模型Reasoning Models的重要里程碑。
特别是考虑到 OpenAI 发布的o1 (Strawberry)系列模型还有DEEPMIND的gemini 3pro你会发现 ReST-MCTS 简直就是这一趋势的“预言家”。
总结与展望局限性分析及未来 LLM 推理的发展方向
并没有免费的午餐ReST-MCTS 的代价虽然 ReST-MCTS 在数学题上大杀四方但如果要把它部署到实际产品中我们必须面对几个严峻的现实问题推理解析度与延迟 (Latency)这是最大的痛点。
MCTS 需要生成几十甚至上百个节点才能产出一个最终答案。
体验你问 GPT-4 一个问题它几秒钟就开始打字。
如果你问 ReST-MCTS它可能需要“思考”几分钟甚至更久在后台疯狂搜索、剪枝最后才吐出一个字。
结论它目前适合离线任务如科学研究、解难题不适合实时对话。
算力黑洞 (Computational Cost)为了获得更高质量的数据我们在推理端消耗了巨大的算力Test-time Compute。
这比普通的 CoT 昂贵数个数量级。
alue Model 的训练难度训练一个能准确打分的 Value Model 极其困难。
如果题目太难MCTS 一次都没搜对全部分数为0Value Model 就学不到东西如果题目太简单又区分不出好坏。
这需要极其精细的工程调优Reward Engineering。
时代的趋势从 Training-time 到 Inference-timeReST-MCTS 的核心哲学正好契合了当前 LLM 发展的最大趋势用推理时的算力换取智能 (Trading Inference Compute for Intelligence)。
以前的法则 (Scaling Laws)想让模型变强堆更多的数据训练更大的模型Pre-training。
现在的法则 (Test-time Scaling)模型参数不用变大只要让它在回答前“思考”得久一点搜索、验证、回溯它的智商就能表现得更高。
OpenAI 的o1就是这个逻辑的极致体现。
它内部其实就在进行某种形式的“思维链搜索”和“自我强化”。
ReST-MCTS 提前展示了这一路径的可行性通过 Self-Training把这种费时的搜索能力逐渐内化为模型直觉。
未来的方向ReST-MCTS 还能怎么改如果你想在这个领域做研究以下几个方向非常有前景蒸馏 (Distillation)现在的 MCTS 太慢了。
未来的方向是用慢的 MCTS 老师教出一个快的 Student 模型。
让 Student 模型不需要搜索一眼就能看出 MCTS 思考半天得出的最优解。
超越数学 (Beyond Math)目前主要用在数学和逻辑题。
代码生成 (Coding)代码天然适合 MCTS因为代码可以运行Execution运行结果就是最完美的 Reward。
Agent (智能体)让 Agent 在复杂的环境中规划任务MCTS 用来模拟“如果我这样做后果是什么”。
更高效的搜索算法现在的 MCTS 还是比较原始的。
能否引入Lookahead(多看几步) 或者Beam Search的变体来加速