RL 学习笔记 #10 近端策略优化(PPO)理论
本文最后更新于:2025年1月8日 下午
赵世钰老师的课程到 Actor-Critic 方法就结束了,接下来我们趁热打铁来学习 OpenAI 于 2017 年提出的近端策略优化(Proximal Policy Optimization,PPO)算法。它通过结合策略梯度方法和近端优化思想的优点,提供了一个高效且稳定的策略优化方案。
接下来的内容将会涵盖 PPO 的理论基础、LLM-RLHF 中 PPO 的应用、各种优化改进的技巧(PPO-Max)。
附上一些参考资料:
- 猛猿老师的文章:人人都能看懂的 PPO 原理与源码解读、人人都能看懂的 RL-PPO 理论知识
- OpenAI 论文:Proximal Policy Optimization Algorithms
- FudanNLP 论文:Secrets of RLHF in Large Language Models Part I: PPO
- GitHub 仓库:首推 OpenRLHF、其次 DeepSpeed-Chat
动机:朴素 Actor-Critic 的局限性
在深入探讨 PPO 之前,我们先回顾一下在上一节 Actor-Critic 方法中,我们做出的两个重要改进:
- 引入 TD Error 作为优势函数,以降低策略梯度估计的方差,提高了学习效率。
- 引入重要性采样,从 Off-Policy 数据中学习,提高了数据利用率。
然而,这些方法在实践中仍存在一些局限性。
TD Error 估计不准
在 Advantage Actor-Critic 方法中,Critic 通过 TD Error 来估计优势函数: \[ \delta_t\left(s_t, a_t\right) = r_{t+1} +\gamma v_t(s_{t+1}) -v_t(s_t) \] 然而,TD Error 的估计依赖于当前策略的价值函数 \(v_t\),这个估计本身可能存在偏差。特别是在训练的早期阶段,Critic 网络的参数尚未收敛,导致估计不准确。这种不准确性会直接影响策略更新的方向,可能导致策略性能的波动甚至下降。
此外,TD Error 的估计还受到采样噪声的影响。由于强化学习中的采样过程通常是随机的,Critic 网络可能会受到噪声的干扰,进一步加剧了估计的不准确性。因此,如何平衡 TD Error 的估计偏差和方差,是提升 Actor-Critic 方法稳定性的关键。
行为策略分布差异过大
在 Off-Policy Actor-Critic 方法中,我们使用重要性采样来修正行为策略 \(\pi_b\) 和目标策略 \(\pi_\theta\) 之间的分布差异。然而,当行为策略和目标策略的分布差异过大时,可能会使得重要性权重 \(\rho_t = \dfrac{\pi_\theta(a_t \mid s_t)}{\pi_b(a_t \mid s_t)}\) 变得非常大,导致梯度估计的方差急剧增大,进而使得训练过程不稳定。
通常,为了避免这种高方差,我们不会使用一个完全无关的行为策略,而是用前几个时间步的旧策略 \(\pi_{\theta_\text{old}}\)。这样既能够充分利用经验数据,又能够使训练变得稳定。唯一的问题就是我们需要控制策略更新的幅度,确保 \(\pi_{\theta_\text{old}}\) 和 \(\pi_\theta\) 不能差异过大。
Proximal Policy Optimization | PPO
为了解决上述问题,近端策略优化(Proximal Policy Optimization,PPO) 应运而生。PPO 的核心目标是避免策略更新中的剧烈变化,确保每次更新都不会使策略偏离原有策略太远,从而保持训练的稳定性。
「近端」一词来源于优化问题中的近端优化(Proximal Optimization)概念。它指的是在优化过程中,通过引入某种约束或惩罚机制,限制每次更新的幅度,使得新策略不会偏离当前策略太远,从而保证优化的稳定性和安全性。这种约束可以通过以下两种方式实现:
- 显式约束:例如,限制新旧策略之间的 KL 散度(Kullback-Leibler Divergence),确保它们之间的差异不超过某个阈值。
- 隐式约束:例如,通过裁剪(Clipping)重要性权重来间接限制策略更新的幅度。
下面,我们将介绍 PPO 如何解决我们前面说到的两个问题:TD Error 估计不准、行为策略分布差异过大。
平衡偏差与方差
在 TD Error 的估计过程中,偏差和方差始终是最关键的问题:
- 偏差(Bias):反映策略更新的方向是否准确,这里的偏差主要来自于 \(v_t\)。
- 方差(Variance):反映策略更新的稳定性,这里的方差主要来自于采样噪音。
我们自然希望能够实现一个低偏差、低方差的算法:当 \(v_t\) 能够准确估计策略 \(\pi\) 的价值 \(v_\pi\) 时,\(\delta_t\left(s_t, a_t\right)\) 至少就满足了低偏差的需求。注意,虽然优势函数已经显著降低了方差,但依然不能完全避免采样过程中随机变量 \(r_{t+1}\) 和 \(s_{t+1}\) 带来的方差。
然而,大部分时候 \(v_t\) 都不能够准确估计 \(v_\pi\),为了缓解高偏差问题,一个直接的想法是:减少对 \(v_t\) 的依赖,使用蒙特卡洛估计(类似 N-Step SARSA)!对于 \(r_{t+1} +\gamma v_t(s_{t+1}) -v_t(s_t)\),我们可以将 \(v_t(s_{t+1})\) 项展开: \[ \begin{aligned} A_t^{(\infty)}&= r_{t+1} +\gamma v_t(s_{t+1}) -v_t(s_t)\\ &= r_{t+1} +\gamma (r_{t+1} + \gamma (r_{t+2} + \cdots) ) - v_t(s_t) \\ &= \sum_{l=1}^{\infty}\gamma^{l-1} r_{t+l} - v_t(s_t) \\ \end{aligned} \]
其中,\(\{r_i\}\) 是采样到的即时奖励数据,如果 \(v_t\) 不准,我们就信任实际采样结果,这样至少不会造成太大偏差。
然而,这样做又会引发一个新问题:前面我们提到优势函数中的随机变量 \(r_{t+1}\) 和 \(s_{t+1}\) 会带来方差,此时我们使用了更多采样数据,相当于引入更多的随机变量,方差也将变得更大。
广义优势估计 | GAE
一个自然而然的想法就是:在两种估计间寻找平衡。为此,PPO 引入了广义优势估计(Generalized Advantage Estimation,GAE)。GAE 通过在时间步上对 TD Error 进行加权累加,提供了一个在偏差和方差之间可调的优势估计器。其定义为: \[ \begin{aligned} A_t^{\text{GAE}(\gamma, \lambda)} &= \delta_t + (\gamma \lambda) \delta_{t+1} + (\gamma \lambda)^2 \delta_{t+2} + \cdots \\ &= \sum_{l=0}^{\infty} (\gamma \lambda)^l \delta_{t+l}\\ \end{aligned} \] 其中:
- \(\delta_t\) 是第 \(t\) 步的 TD Error:\(\delta_t = r_t + \gamma v_t(s_{t+1}) - v_t(s_t)\),这里为了简洁将第 \(t\) 步的即时奖励记为 \(r_t\);
- \(\gamma\) 是折扣因子,\(\lambda \in [0, 1]\) 是 GAE 的衰减系数,控制偏差和方差之间的平衡。
其推导过程涉及到对展开项 \(A_t^{(1)},A_t^{(2)},\cdots,A_t^{(\infty)}\) 的 \(\lambda\) 指数衰减求和,可以查看发表于 ICLR 2016 的原论文。
通过调整 \(\lambda\) 的值,可以在偏差和方差之间进行调节:
- 当 \(\lambda = 0\) 时,只考虑一步的 TD Error,偏差较大,但方差较小。
- 当 \(\lambda = 1\) 时,优势估计等价于蒙特卡洛方法,偏差较小,但方差较大。
通过调整 \(\lambda\),我们可以在偏差和方差之间取得最佳平衡,估计出更准确的优势。
PPO 的前身:TRPO
在 PPO 出现之前,信赖域策略优化(Trust Region Policy Optimization,TRPO) 是一种解决策略更新过大问题的重要方法。其基本思想是,在策略更新过程中,通过限制新旧策略之间的距离,确保每次更新都不会偏离当前策略太远。
这里的「距离」使用的是 KL 散度(Kullback-Leibler Divergence),要求不超过一个预设的阈值。于是 TRPO 的优化目标为: \[ \max_\theta \mathbb{E}_{s \sim \pi_{\theta_{\text{old}}}, a \sim \pi_{\theta_{\text{old}}}} \left[ \frac{\pi_\theta(a\mid s)}{\pi_{\theta_{\text{old}}}(a\mid s)} A^{\pi_{\theta_{\text{old}}}}(s, a) \right] \]
约束条件:
\[ \mathbb{E}_{s \sim \pi_{\theta_\text{old}}} \left[ \mathrm{KL} \left( \pi_{\theta_\text{old}}(\cdot \mid s) \| \pi_\theta(\cdot \mid s) \right) \right] \leq \delta \]
其中:
- \(D_{\mathrm{KL}}\) 表示 KL 散度,用于度量新旧策略的差异;
- \(\delta\) 是预设的信赖域(Trust Region)阈值,所谓信赖域就是要让 \(\pi_\theta\) 落在 \(\pi_{\theta_\text{old}}\) 附近,这个区域是可信的;
- 注意:这里的 \(A^{\pi_{\theta_{\text{old}}}}(s, a)\) 是在旧策略上计算的优势,与之前都不同。
TRPO 在理论上具有良好的收敛性,但实现复杂,需要进行二阶导数的计算和共轭梯度优化,在高维参数空间中计算成本较高。更多细节可以参考原论文,发表于 ICML 2015。
PPO 目标函数
之所以 TRPO 的优化过程复杂,就是因为它将约束条件独立于目标函数之外。因此,为了简化 TRPO 的实现,PPO 提出了两种主要形式的目标函数:PPO-Penalty 和 PPO-Clip。
显式约束:PPO-Penalty
PPO-Penalty 在目标函数中直接加入了 KL 散度的惩罚项,将约束条件转化为惩罚项,使得优化过程可以使用一阶优化方法完成。
目标函数为:
\[ L^{\mathrm{PEN}}(\theta) = \mathbb{E}_{t} \left[ r_t(\theta) A_t^{\text{GAE}(\gamma, \lambda)} - \beta D_{\mathrm{KL}} \left( \pi_{\theta_{\text{old}}}(\cdot \mid s_t) \,||\, \pi_\theta(\cdot \mid s_t) \right) \right] \]
其中:
- \(r_t(\theta) = \frac{\pi_\theta(a_t|s_t)}{\pi_{\theta_{\text{old}}}(a_t|s_t)}\) 是新旧策略的概率比率;
- \(A_t^{\text{GAE}(\gamma, \lambda)}\) 是优势估计,这里引入了 GAE 来平衡;
- \(\beta\) 是惩罚系数,控制 KL 散度对目标函数的影响。
PPO-Penalty 通过在优化过程中监控 KL 散度,根据实际值调整 \(\beta\) 的大小,实现对策略更新幅度的动态控制:
- 首先,我们对 \(D_{\mathrm{KL}}\) 也会设置 threshold,我们分别记为 \(D_{\max}\) 和 \(D_{\min}\)。
- 当 \(D_{\mathrm{KL}} \ge D_{\max}\) 时,说明当前策略已经偏离旧策略较远了,这时我们应该增大 \(\beta\),把分布拉回来。
- 当 \(D_{\mathrm{KL}} \le D_{\min}\) 时,说明当前策略很可能找到了一条捷径,即它只优化 KL 散度一项,而不去优化前面优势相关的项,所以这时我们应该减小 \(\beta\),降低惩罚项的影响。
隐式约束:PPO-Clip
如果觉得 KL 散度计算也很麻烦,还有一个更简单的 PPO-Clip 算法。它在目标函数中直接对概率比率 \(r_t(\theta)\) 进行裁剪,避免了引入二阶导数或 KL 散度计算。
那么如何进行裁剪呢?首先,我们要明确一点,我们的目标是防止 \(\pi_\theta\) 偏离 \(\pi_{\theta_\text{old}}\) 太远——因此当 \(\pi_\theta\) 远离 \(\pi_{\theta_\text{old}}\) 时,我们需要缩小更新的幅度。
考虑到当优势 \(A_t>0\) 时,此时我们倾向于让 \(\pi_\theta(a_t\mid s_t)\) 变大,因此 \(r_t(\theta) = \frac{\pi_\theta(a_t|s_t)}{\pi_{\theta_{\text{old}}}(a_t|s_t)}\) 会变大:
- 如果两个分布相同,则 \(r_t(\theta)=1\),此时训练稳定,无需裁剪;
- 如果 \(\pi_\theta(a_t\mid s_t)\) 在之前的更新过程中已经变大,则 \(r_t(\theta)>1\),且我们需要进行裁剪,将其限制在 \(1 + \epsilon\) 内防止越变越大,分布越偏越远;
- 如果 \(\pi_\theta(a_t\mid s_t)\) 在之前的更新过程中变小了(注:这是可能发生的,因为策略作为一个参数网络,在优化其他状态和动作的时候也会相互影响),则 \(r_t(\theta)<1\),但此时的更新会导致 \(r_t(\theta)\to 1\),这是我们希望看到的变化,于是我们不再对其进行裁剪。
反之,当优势 \(A_t<0\) 时,我们也有类似的操作。
最终,目标函数可以写作: \[ L^{\mathrm{CLIP}}(\theta) = \mathbb{E}_{t} \left[ \min \left( r_t(\theta) A_t,\; \operatorname{clip}(r_t(\theta), 1 - \epsilon, 1 + \epsilon) A_t \right) \right] \]
其中:
- \(\operatorname{clip}(r_t(\theta), 1 - \epsilon, 1 + \epsilon)\) 表示将 \(r_t(\theta)\) 限制在 \([1 - \epsilon, 1 + \epsilon]\) 的范围内;
- \(\epsilon\) 是超参数,控制策略变化的范围,一般取值为 \(0.1\) 或 \(0.2\);
- \(\min(*)\) 操作实现了在优势为正的时候,仅对 \(1 + \epsilon\) 裁剪;在优势为负的时候,仅对 \(1 - \epsilon\) 裁剪。
PPO 算法实现
这里以 PPO-Clip 为例,简单介绍其算法实现。在这之前,我们知道 Actor 和 Critic 是需要交替更新的。而在 PyTorch 中,其更新步骤一般是:
1 |
|
其中的优化问题我们不再需要去计算 \(\theta_{t+1}\) 或 \(w_{t+1}\),而是使用损失函数和优化器(如 Adam)直接更新。唯一的问题就是需要明确两个 Loss 的表达式。
Actor Loss
策略网络的更新目标是最大化 PPO-Clip 的目标函数 \(L^{\text{CLIP}}(\theta)\)。具体来说,策略网络的参数 \(\theta\) 应该通过梯度上升法进行更新。现在我们已经知道其目标函数:
\[ L^{\mathrm{CLIP}}(\theta) = \mathbb{E}_{t} \left[ \min \left( r_t(\theta) A_t,\; \operatorname{clip}(r_t(\theta), 1 - \epsilon, 1 + \epsilon) A_t \right) \right] \]
其具体计算步骤如下:
- 计算概率比率:对于每个样本 \((s_t, a_t)\),计算 \(r_t(\theta)\)。
- 计算裁剪后的目标:根据 \(A_t\) 的符号,对 \(r_t(\theta)\) 进行裁剪,得到裁剪后的目标值。
- 计算目标函数:将裁剪后的目标值与未裁剪的目标值进行比较,取最小值作为最终的目标函数值。
Critic Loss
在前面的笔记中,我们知道估计值函数最常用的方法就是蒙特卡洛方法: \[ V_t^{\text{target}}(s_{t}) = \mathbb{E} \left[ \sum_{l=0}^{\infty} \gamma^l r_{t+l} \mid s_t \right] \] 而在引入优势估计后,我们还可以将状态价值的目标值表示为: \[ V_t^{\text{target}} = A_t + V_\phi(s_t) \] 其中,\(A_t\) 是优势估计,\(V_\phi(s_t)\) 是价值网络对当前状态的估计值。这个公式的意义是:目标值等于当前状态的估计值加上优势估计,即采取特定动作的额外收益。
而价值网络的更新目标是让 \(V_\phi(s_t)\) 尽可能接近 \(V_t^{\text{target}}\),换句话说,就是在目标策略下优势最小。因此,结合两个表达式,我们使用均方误差(MSE)作为损失函数,最小化优势: \[ L^{\text{VF}}(\phi) = \frac{1}{N} \sum_{t} \left( V_{\phi}(s_t) - V_t^{\text{target}} \right)^2 \] 其中,\(N\) 是样本数量。其具体计算步骤如下:
- 计算目标值:对于每个样本 \((s_t, a_t, r_t, s_{t+1})\),使用 GAE 计算优势估计 \(A_t\),然后计算 \(V_t^{\text{target}}\)。
- 计算均方误差:计算价值网络输出 \(V_{\phi}(s_t)\) 与目标值 \(V_t^{\text{target}}\) 之间的均方误差。
算法步骤
PPO 算法的整体步骤如下:
初始化策略网络 \(\pi_\theta\) 和价值网络 \(V_\phi\),设定超参数 \(\epsilon\)、\(\gamma\)、\(\lambda\) 等。
采集交互数据:使用旧策略 \(\pi_{\theta_{\text{old}}}\) 与环境进行交互,收集一批轨迹数据 \(\{ (s_t, a_t, r_t, s_{t+1}) \}\)。
计算优势:使用 GAE 计算每个时间步的优势估计 \(A_t\)。 \[ A_t^{\text{GAE}(\gamma, \lambda)} = \sum_{l=0}^{\infty} (\gamma \lambda)^l \delta_{t+l} \] 其中 \(\delta_t = r_t + \gamma V_{\phi}(s_{t+1}) - V_{\phi}(s_t)\)。
计算目标函数,更新策略网络:对于每个样本,计算概率比率 \(r_t(\theta)\) 和目标函数 \(L^{\text{CLIP}}(\theta)\),使用优化器更新策略网络参数 \(\theta\)。 \[ L^{\text{CLIP}}(\theta) = \mathbb{E}_{t} \left[ \min \left( r_t(\theta) A_t,\; \operatorname{clip}(r_t(\theta), 1 - \epsilon, 1 + \epsilon) A_t \right) \right] \]
计算状态价值的目标:对于每个样本,计算: \[ V_t^{\text{target}} = A_t + V_{\phi}(s_t) \]
更新价值网络:最小化均方误差损失,使用优化器更新价值网络参数 \(\phi\)。 \[ L^{\text{VF}}(\phi) = \frac{1}{N} \sum_{t} \left( V_{\phi}(s_t) - V_t^{\text{target}} \right)^2 \]
重复迭代:在更新多次后,重置 \(\theta_{\text{old}} \leftarrow \theta\),并继续下一个周期的采样和更新,直到策略收敛或达到预定的停止条件。