大语言模型训练基本原理之——PPO算法
在大语言模型的训练流程中,通常会先经过预训练和监督微调,经过这两个步骤后,模型已经能够理解语言结构,也能掌握基本的知识和指令执行能力,但你可能会发现,模型有时候仍会胡说八道、答非所问——这是由于监督微调出来的模型还不够聪明,它只是单纯能模仿人类已经写好的答案,但并不明白什么样的回答是「好的」(更符合人类偏好)。说白了,模型在监督微调阶段学的是「怎么答」,但没学会「怎么才能答得好」。
为了进一步提升模型输出的质量和对齐程度,研究者引入了基于人类反馈的强化学习(RLHF)。通过奖励模型对不同响应进行偏好打分,再利用强化学习算法对语言模型进行微调,使其在生成文本时更加贴合人类价值与偏好。其中,PPO(Proximal Policy Optimization) 是 RLHF 阶段最常用的优化算法,也是在 InstructGPT 和 ChatGPT 等模型中取得显著效果的关键技术。
本文将先从这个PPO算法入手,拆解此算法的核心理论。
策略梯度算法
PPO算法是一种On-Policy的策略梯度算法,关于策略梯度,我在之前的一篇文章中曾提到过其核心公式的推导:
这里 表示采取 以 为参数的动作策略 时,能获得的所有状态下的回报的期望值。
简单来说,我们希望优化 ,让 变得更大。但策略梯度算法在实际应用时极不稳定,这种更新的方法容易让策略变化太剧烈,尤其是对于像大语言模型这样参数量巨大的网络(也算一种策略网络)而言,更是难以承受。为了引入对策略变化的约束,后续又提出了一些新的方法,例如TRPO、PPO等,前者直接在训练时强行限制新旧策略之间的KL散度,而后者则提出了一种更简单并且非常有效的手段。
PPO算法
我们回到 期望内部的 函数的定义:
我们要最大化 ,其实就相当于最大化 函数:
将 函数中的期望展开为积分:
我们发现,在给定状态 的情况下, 为关于动作 的单变量函数,此时有两种情况:
- ,说明它给我们带来的价值是正的,意味着这是一个比较好的动作,我们就应该进一步提升当前状态下这个动作被取到的概率,也就是 的值。
- ,说明它没有给我们带来价值或带来了负价值,意味着这个动作比较差,同理我们应该进一步降低当前状态下这个动作被取到的概率。
重要性采样
对于上面的两种情况,我们想要定量地描述「进一步」这个词,很自然地会想到,如果将优化前后的概率值做一个比值:
固定分母,优化分子,就可以体现新概率相对于旧概率的变化。
对于情况 1,我们希望新概率变大,故需要对 梯度上升;对于情况 2,我们希望新概率变小,故需要对 梯度下降。合而为一,我们总是需要对下式:
进行梯度上升。
这就将原先的优化目标转化为了:
这个操作也被称为重要性采样。
约束
回想策略梯度算法的缺陷:难以在优化过程中控制新旧策略的差异,导致策略剧烈波动,使得训练不稳定。
在经过了重要性采样以后,如何规避这个缺陷?
盯着重要性采样引出的优化目标看,这不答案已经拍脸上了吗?直接约束新旧概率之间的差距不就行了?
对于这件事,PPO算法使用的约束方法是对新旧概率的比值,也就是优化目标左边那一坨东西,进行一个裁剪:
将概率的比值保持在区间 之内,简单粗暴地控制了策略的差异,让策略在训练过程中能够逐步收敛,不至于步子跨太大。
另外,在大语言模型的训练中,我们还要对新策略和原始策略之间的KL散度进行惩罚,这同样也是为了防止新策略跑的离旧策略太远。计算KL散度有多种方式,这里暂时不管。
优势函数
设想我们的 Agent 因为前面操作太垃种种原因,处在一个已经非常糟糕的状态 ,这个状态下,无论这个 Agent 采取哪个动作 ,价值函数 都是负的,由刚刚重要性采样部分得出的结论,我们发现对于每个动作都要降低它被取到的概率,这不就是摆烂么。难道对于 Agent 而言,原地摆烂才是最优解?
虽然对于心理承受能力比较普通的人类玩家而言,恐怕大部分人都会选择摆烂吧,但 Agent 毕竟不是情绪化的玩家。作为没有任何感情的 Bot,它应该做的绝对不是摆烂,而是在逆境中找到那一记也许能够力挽狂澜的神之一手,无论最终结局如何。
从强化学习的角度来看,我们先前优化目标中的 函数就显得不够合理了,因此需要找出一个新的函数来代替 函数,其能够更准确地衡量某一个动作在当前局势下的优劣程度。
这便是优势函数。
我们定义 状态下,采取动作 的优势函数如下:
也就是采取了 动作 以后,能得到的预期回报与「遵循策略时能带来的预期回报的期望」之差。
说白了就是衡量你这个动作能让局势改善多少。如果有所改善,说明该动作是比原策略更优的,我们要增大这个动作的概率,反之亦然。
于是我们的优化目标变为了:
最终的目标函数
在实际的应用中,还会对优化目标进行以下 计算:
方便起见,定义 ,则 PPO 算法最终的优化目标如下:
对于这个操作,可能有一些初学者会有疑惑:已经clip
了为什么还要取min
?这一点许多其他博客都没有提到,其实可以简单分析一下如果不取min
会发生什么预想之外的事:
第一种情况: 时,此时这个动作是好的动作,因此需要把动作概率向上调整。此时又分为三种情况:
:此时,优化目标会把 进行一个
clip
,变成 ,此时由于clip
函数在阈值外不产生梯度,优化目标就不会对策略网络产生梯度累积,这意味着这个概率过大的动作不再被用于参数的更新,这是合理的。:此时不截断,正常计算梯度,正常更新策略网络参数,也非常合理。
:此时,与前面第一条子情况同理,由于产生了截断,优化目标同样不会对策略网络产生梯度累积,意味着这个概率过小的动作也没有被用于网络参数的更新,但这是不合理的:因为我们本应该将这个动作的概率调大。
- 第二种情况: 时,与第一种情况同理,我们会发现对于 的动作,将由于
clip
运算阻断了梯度,从而不参与网络参数的更新,这同样是不合理的。
对于上述没有取 min
运算的情况,有几条不合理性,而这些不合理都可以通过做min
运算得到解决。我们同样分类讨论:
时,需要把动作概率向上调整,此时若:
- :取
min
运算后,得到的是截断后的值,故梯度反馈为0,不更新参数,非常合理。 - :
min
运算不产生效果,正常计算梯度,更新参数,非常合理。 - :
min
运算后,得到的是未被截断的值,故梯度反馈不为0,能够正常更新参数,非常合理。
- :取
时,需要注意此时
min
运算作用在两个负数上,故会得到与前面相反的结果,即对于 的动作,进行了min
运算后得到的反而是未被截断的值,故这种情况下仍能够正常更新参数,也变得合理了起来!
因此,这个取min
运算是非常重要的,如果没有这个运算,在某些情况下参数会得不到我们希望的更新。
如何估计优势函数
优势函数这个东西看上去十分抽象,要怎么去估计呢?
为方便起见,我们记 为 的 阶近似估计, 为 的 阶近似估计, ,写出 时刻优势函数的 1 阶近似估计:
同理,2 阶近似估计:
我们可以写出 阶近似估计的通项:
近似阶数 越大,我们得到的估计值的偏差越小,但其中包含的随机变量()越多,因此方差反而变大。
接下来有个操作叫 算法,它的作用是平衡这些估计的偏差与方差。
简而言之,该方法使用一个 上的系数 ,对这些估计进行加权求和,即:
阶数越高的估计值,权重越小,以此降低其方差。由此得到的和式能够兼顾偏差和方差。
由于所有的 都是 的估计,上式的期望差不多相当于 , 因此我们还应该乘上一个系数,才能得到真正对 的估计:
这便是广义优势估计 (GAE) 算法。
对于PPO算法的基本原理就讲到这儿,接下来让我们来看看这个算法是如何应用在大语言模型的训练流程中的。为了让PPO能够发挥作用,我们还缺少一个模块,那就是用来计算上文中多次出现的 的值(Reward)的模型,也就是所谓 Reward Model,下一篇文章,博主将介绍如何训练一个 Reward Model。