正文
关键在于合理使用强化学习的探索
仅靠 MCTS 无法让模型学会思考问题的关联,隐式自动化 CoT 的背后,是模型真正学会了合理的中间推理过程 Rationales。
当人们写作或说话时,常常会停下来思考。然而,大语言模型在通过 Next Token Prediction 生成回答时,更像是一种 “快思考” 过程。由于缺乏详细的中间推理步骤,模型一开始可能会犯错,而这些错误可能会传播,最终导致生成的答案也是错误的。
为了优化这一过程,产生了一系列方法,其中包括在
Token 级别
或
子句级别
提供奖励信号,帮助模型调整生成的回答。这些方法如
蒙特卡洛树搜索(MCTS)
,将输出建模为一系列节点,这些节点可以是 Token 级别或句子级别。例如:
-
Token 级别的节点
:每个节点对应生成序列中的一个 Token。通过 MCTS,模型可以探索不同的 Token 序列,最终生成更连贯的响应。
-
句子级别的节点
:在复杂推理任务中,每个节点可以代表一个完整的句子或推理步骤,帮助模型更好地处理多步推理任务。
另一种方式是通过
思维链(Chain of Thought, CoT)
优化模型输出。CoT
通过分步推理的方式,要求模型在生成最终答案之前,先生成一系列中间推理步骤。
这种 “思考链” 的生成过程有助于增强模型的推理能力,尤其在数学和代码生成等任务中表现出色。
然而,CoT 虽然能够生成中间步骤,但并未教会模型如何从内部深入思考问题的关联。特别是对于尤其复杂且需要多步推理规划的任务,这样的合理的中间 CoT 推理过程(Rationales) 更为重要。
类似的思路在 STaR [1] 和 Quiet-STaR [7] 中有所体现。
STaR 的核心思路是利用 LLM 已有的推理能力,迭代式的 Bootstrap 模型产生合理推理过程(Rationales) 的能力,并将 Rationales 融入到训练过程内,让模型学会进行推理。
-
推理
:起始数据集仅有 [Question, Answer] ,首先利用一些带有推理过程的 Few-Shot Examples 来 Prompt 模型对于数据集中的问题生成对应的推理过程和答案。
-
过滤
:如果生成的答案正确,则将推理过程加入到原有的数据集中;如果生成的答案错误,则尝试在给出正确答案的前提下再次生成推理过程。将最终生成正确答案的推理收集,构建一个构建一个微调数据集 [Question, Rationale, Answer ] 进行微调。
-
迭代
:重复这一过程,且每次获得一个新的数据集,都从原始的模型开始进行 Fine-tune 从而防止过拟合。
STaR 的思路和 RL 中策略梯度算法是近似的,甚至整体的优化目标可以近似为一个策略梯度优化的目标。
模型首先采样潜在的推理路径(rationale)的过程类似于 RL 中通过策略选择动作(action),基于环境状态选择一个可能的策略路径。STaR 中,通过计算目标函数,模型对整个数据集的预测结果进行评估,并且只根据预测正确的样本更新模型。
STaR 在同一批数据上进行多次梯度更新,这类似于某些策略梯度算法中的策略,即通过多次调整同一批数据来稳定学习过程。在 RL 中,策略梯度算法通过这种方式在探索动作空间时进行学习,而 STaR 则通过探索推理和答案空间,逐步改善推理生成的准确性。
这种方法和先前提到的通过细粒度奖励或 MCTS
优化输出
有所不同,模型在正确和错误的示例中更多的学会的是如何进行
显式的合理推理
。
与此同时,这种合理推理不只是问题拆解分步理,更适用于一般常识问答任务上。例如:
-
问题:什么可以被用来装一只小狗
-
选项:(a) 游泳池 (b) 篮子 (c) 后院 (d) 自己的家
-
合理推理:答案必须是可以用来携带一只小狗的东西。篮子是用来装东西的。因此,答案是 (b) 篮子。
-
对少样本示例的依赖
:STaR 在推理任务中高度依赖少量的 Few-Shot 推理示例,这导致模型的推理能力较为有限,难以应对复杂和广泛的任务。
-
泛化能力受限
:STaR 虽然能够通过迭代的方式提升模型的推理能力,但其应用主要局限于特定的结构化任务(如问题回答),难以在开放域或任意文本生成任务中取得同样的效果。
针对 STaR 的局限性,Quiet-STaR [7] 提出 “内部思维” 的概念,将显式的 Rationales 推理过程转化为模型内部隐式的推理过程,从而摆脱对于外部示例的依赖。
同时,引入可学习的 <|startofthought|> 和 <|endofthought|> token 来标记思维的开始和结束。
Quiet-STaR 还实现了在更一般文本上的推理学习,这意味着大量复杂任务下的非结构化语料(如医疗、金融等领域)都可以被加入学习过程。同时利用带推理过程的结果与真实结果的分布差异引入奖励信号,通过 REINFORCE 的方法优化生成的推理,使得基于这些推理的模型预测未来的 tokens 更为准确。
就目前来看,STaR 和 Quiet-STaR 是最接近 o1 的技术路线和模型表现效果的,但是如果想要进一步达到 OpenAI o1 的效果,还需要克服很多问题。