主要观点总结
该文章介绍了量化投资与机器学习微信公众号注意到的一篇论文,该论文提出了一种新型的训练大语言模型的方法SASR。文章详细描述了SASR是如何结合监督学习(SFT)和强化学习(RL)进行模型训练的,以及解决这两种方法各自存在的问题。同时介绍了GRPO算法的特点及其在SASR中的应用。文章还给出了实验数据,对比了SASR与其他训练方式在三个不同数据集上的表现,结果显示SASR表现最佳或次优。
关键观点总结
关键观点1: 论文背景及现状
随着大语言模型的发展,监督学习和强化学习是两种主要的训练方法。但是,它们都存在一些问题,如监督学习的依赖高质量标签数据,模型容易过拟合;强化学习训练不稳定,容易出现模式坍缩等问题。
关键观点2: SASR方法介绍
SASR是一种结合监督学习和强化学习的新型训练大语言模型的方法。它通过引入自适应决策函数I(t),在每一步训练开始前根据模型的当前状态来动态判断应该使用监督学习还是强化学习,实现了训练方式的智能适配。
关键观点3: GRPO算法介绍
GRPO是一种为大语言模型定制的强化学习算法。它通过生成多个答案并分组进行策略优化,解决了传统强化学习中不稳定和收敛性差的问题。
关键观点4: 实验数据与结果
作者在三个数据集上进行了实验,对比了五种方法的表现。结果显示SASR在所有任务中均取得最佳或次优成绩,尤其在逻辑推理任务KK上表现最为突出。
关键观点5: 总结
SASR通过动态平衡监督学习与强化学习的比例,显著提升了模型在不同类型推理任务上的表现,验证了其在任务特定训练中的有效性与优越性。
正文
input_ids, target_ids = batch
# 让模型根据输入预测出整个输出序列
outputs = model(input_ids, labels=target_ids)
# 计算交叉熵损失,目标是最大化“参考答案”的概率
loss = outputs.loss
# 反向传播 + 梯度更新
loss.backward()
optimizer.step()
optimizer.zero_grad()
这种方式训练稳定、效果直观,
但其最大问题在于高度依赖高质量标签数据,且模型更容易“背答案”,在新问题上的泛化能力较差。
与SFT不同,
强化学习采用的是“先尝试、后打分”的策略:模型生成多个答案,系统根据一定的奖励机制给出每个答案的评分,模型再根据这些反馈信号进行更新,从而学会输出更优结果。
以GRPO(Group Relative Policy Optimization)为例,它将多个输出分为“高优势组”和“低优势组”,鼓励模型向前者学习,同时通过KL正则项限制模型偏离原始分布太远。其训练过程如下:
# 假设我们已有初始模型和一个评估器 reward_fn(比如打分器)
for step in range(num_steps):
for question in questions_batch:
# 生成多个可能的答案
responses = [model.generate(question) for _ in range(G)]
# 对每个响应进行奖励打分(如是否正确、格式、逻辑清晰度)
rewards = [reward_fn(response) for response in responses]
# 计算每个响应的 advantage(优劣度)
median_reward = np.median(rewards)
advantages = [r - median_reward for r in rewards]
# 计算 GRPO 损失(鼓励高评分回答,惩罚离谱输出)
loss = compute_grpo_loss(model, responses, advantages)
# 加入 KL 正则项,防止模型输出偏离原始分布太远
loss += kl_penalty(model, ref_model)
# 更新模型参数
loss.backward()
optimizer.step()
optimizer.zero_grad()
这种方式的优点在于突破了SFT的模仿限制,能够探索更丰富的解法,但缺点是训练不稳定、容易出现“模式坍缩”或“奖励欺骗”等问题。
为了解决这两种方法各自存在的短板,本论文提出了“混合训练”策略,即将SFT与RL结合使用。然而现有方法大多采用
静态切换机制
:比如前两轮训练使用SFT,后两轮训练使用RL。这种方式的问题在于它缺乏对任务复杂度和模型学习状态的响应能力,容易导致转换阶段不平滑:切换太早,模型尚未掌握基础能力;切换太晚,则限制了模型的主动探索。
这种僵化策略很难适应多任务、多阶段、多样化训练需求。
# 设定切换阶段,例如前2轮用SFT,后1轮用RL
num_epochs_sft = 2
num_epochs_rl = 1
total_epochs = num_epochs_sft + num_epochs_rl
for epoch in range(total_epochs):
for batch in dataloader:
if epoch < num_epochs_sft:
# 执行 SFT:模仿标准答案
input_ids, target_ids = batch
outputs = model(input_ids, labels=target_ids)
loss = outputs.loss
else:
# 执行 RL(例如 GRPO):基于奖励优化
question = batch["question"]
responses = [model.generate(question) for _ in range(G)]
rewards = [reward_fn(r) for r in responses]
advantages = [r - np.median(rewards) for r in rewards]
loss = compute_grpo_loss(model, responses, advantages)
loss += kl_penalty(model, ref_model)
# 通用优化步骤
loss.backward()
optimizer.step()
optimizer.zero_grad()