深度解析强化学习核心算法 PPO:原理与 PyTorch 代码
在强化学习的发展历程中,策略梯度(Policy Gradient)算法是解决连续决策问题的基石,但其固有缺陷始终限制着工业落地与大规模应用 —— 标准策略梯度算法使用环境采集的样本进行一次梯度更新后就必须丢弃旧数据,一旦更新步长控制不当,就会导致新策略与旧策略差异过大,后续采样数据无法匹配训练需求,直接导致训练不稳定、易崩溃,样本效率极低。
为了解决这一核心痛点,研究人员提出了 “信任区域”(Trust Region)的优化思想,即约束策略更新的幅度,确保新策略与旧策略不会偏离太远。基于这一思想,衍生出了 Trust Region Policy Optimization(TRPO)算法。TRPO 的核心是通过 KL 散度约束策略更新的幅度,理论上能保障训练的 monotonic improvement(单调提升)特性,但实际应用中却存在显著短板:TRPO 的约束条件需要计算二阶导数,实现复杂度极高,且无法兼容带噪声的网络结构(如 Dropout),也不支持策略网络与价值网络的参数共享,这极大限制了其在大规模复杂任务中的落地 —— 比如在模拟机器人 locomotion(运动控制)、Atari 游戏对战等主流强化学习基准任务中,TRPO 的实用性都受到限制。
Proximal Policy Optimization(PPO,近端策略优化)算法正是在这一背景下应运而生,它继承了 TRPO 中 “信任区域” 的核心优势,却仅采用一阶优化就能实现策略的稳定更新。这一平衡了 “样本效率” 与 “实现复杂度” 的算法,已经成为当前强化学习领域的主流工业方案:无论是 OpenAI 的强化学习研究、DeepMind 的机器人控制任务,还是大模型训练中对齐人类偏好的核心步骤 RLHF(基于人类反馈的强化学习),PPO 都是不可或缺的关键技术 —— 在大模型 RLHF 流程中,PPO 的核心作用是优化语言模型的响应策略,让模型生成的内容更符合人类的偏好标准(如有用性、安全性、表达准确性)。
本文将从技术原理出发,深入剖析 PPO 算法的核心设计逻辑,再通过经过行业验证的 CleanRL 库实现的完整 PyTorch 代码案例,详细拆解 PPO 的工程落地细节,帮助有基础的开发者彻底掌握这一核心算法。
PPO 算法核心原理
PPO 本质上是一种在线策略(On-Policy)的 Actor-Critic 类算法,它的核心设计逻辑非常简洁直接:在策略更新过程中,严格约束新策略与旧策略之间的差异幅度,避免单次更新步长过大导致训练崩溃。这一设计既保留了 TRPO 的训练稳定性优势,又将实现复杂度降低到了工程可接受的范围。
从技术架构上看,PPO 采用了典型的 Actor-Critic 架构,这是它能同时兼顾 “策略优化效率” 与 “训练稳定性” 的核心基础。在这个架构中,Actor 模块负责输出给定状态下的动作决策,Critic 模块负责评估当前状态的长期价值 —— 即从当前状态出发,后续所有奖励的折现值;基于这个价值评估结果,算法会计算优势函数(即当前动作的实际收益与平均预期收益的差值),并以此指导策略的更新方向。与传统的单网络模型方案不同,PPO 将 “动作输出” 与 “价值评估” 拆分为两个独立的网络头,这一设计让算法能够同时利用策略梯度和价值梯度的信息,从而显著提升样本的学习效率。
而 PPO 真正成为 “强化学习主流工业方案” 的核心原因,在于它用两种简单且工程易实现的技术方案,完美解决了传统策略梯度算法的 “更新幅度不可控” 核心痛点:其一,是通过对策略更新的代理目标(Surrogate Objective)进行截断,直接限制单次更新的幅度上限;其二,是在更新过程中引入自适应的 KL 散度惩罚项,从概率分布层面约束新旧策略的差异幅度。这两种方案都能有效避免策略更新时的 “步长过大” 问题,其中截断方案的综合表现更优,是当前行业的默认选择。
接下来,我们将深入拆解这两种核心技术方案的背后逻辑。
为什么需要 “近端”?
传统的无约束策略梯度算法,在实际应用中存在一个几乎无法解决的核心矛盾:更新步长太小的话,策略优化的进度会极其缓慢,样本效率低到无法落地;步长太大的话,策略会在某次更新后突然 “跑偏”—— 新策略与旧策略的输出分布差异过大,之前采集的所有样本数据都无法再匹配新的训练需求,训练精度会直接出现崩塌式下降。
这一问题的根源在于,策略梯度的更新方向,是基于旧策略采集的样本数据计算得到的 —— 如果新策略与旧策略的分布偏差过大,旧样本的信息对新策略的指导价值会急剧下降,甚至完全产生负面指导效果。TRPO 算法试图通过约束新旧策略的 KL 散度来解决这一问题,但它的实现复杂度太高,需要计算二阶导数,且无法兼容带噪声的网络结构(如 Dropout),也不支持策略网络与价值网络的参数共享,工程落地门槛极高。
PPO 的 “近端”(Proximal)设计思想,本质上是用一种更简单、更轻量化的方案,实现了 TRPO 中 “信任区域” 的核心效果 —— 它不再通过复杂的 KL 散度约束计算来限制策略更新幅度,而是通过对策略更新的代理目标(Surrogate Objective)进行直接的截断(Clipping),将每次策略更新的幅度严格限制在安全区间内。这一方案的核心逻辑是:允许策略在一定的幅度范围内进行优化,但绝对不允许出现偏离幅度过大的更新步骤。这就给策略梯度的优化过程拴上了一根 “安全绳”,即使在学习率设置偏大的情况下,也能将策略的变化幅度限制在可控的区间内,确保训练过程的稳定。
这一设计的巧妙之处在于,它用极低的实现成本,换取了训练稳定性的大幅提升。PPO 通过 “截断代理目标” 的轻量化方案,完美平衡了样本效率与实现复杂度:它仅使用一阶优化技术,就实现了与 TRPO 类似的稳定训练效果,同时显著提升了算法的样本效率,让工程落地变得简单可行。
数学推导:从目标函数到截断机制
要理解 PPO 的技术细节,我们需要先从策略梯度的基础理论说起。
在强化学习中,策略梯度的核心逻辑,是通过最大化一个 “代理目标函数”(Surrogate Objective Function),让策略网络在后续决策中更倾向于选择那些能获得更高奖励的动作。这个代理目标函数的物理意义是:在当前旧策略的基础上,通过调整策略参数,提高那些高优势动作的被选择概率。优势函数(Advantage Function)在这里表示为 ,它衡量了在当前状态下,选择某个动作的实际长期收益与平均预期收益之间的差值 —— 如果这个值为正,说明当前动作比平均预期收益更好,策略应该提高这个动作的被选择概率;如果为负,则说明这个动作比平均水平更差,策略应该降低其被选择概率。
在标准的策略梯度算法中,这个代理目标函数的数学形式如下:
这里的 是我们要优化的新策略,
是用于采集样本的旧策略,
是新旧策略选择同一个动作的概率比值 —— 在优化过程中,这个比值的变化幅度,直接决定了策略更新的幅度大小。
是指保守策略迭代(Conservative Policy Iteration)目标函数,这一函数的核心逻辑是,通过计算新旧策略的概率比,来指导策略的优化方向。
但如果直接最大化这个无约束的目标函数,策略更新幅度必然会失控:只要某个动作的优势函数是正的,优化过程就会不断提高这个动作的被选择概率,甚至会把概率比推到极端值 —— 这会导致新策略的输出分布与旧策略完全偏离,之前采集的样本数据全部失效,训练过程直接崩溃。
PPO 的核心创新,就是针对这个无约束目标函数的缺陷,提出了两种行之有效的幅度约束方案,两者的核心本质都是限制 的变化区间,最终实现策略更新幅度的可控性。
方案一:截断代理目标函数(Clipped Surrogate Objective)
这是 PPO 的核心默认方案,也是行业内应用最广泛的 PPO 实现方式。其核心设计逻辑是:通过对概率比 \(r_t(\theta)\) 的变化范围进行直接截断,强制限制新旧策略之间的差异幅度,从而避免策略更新幅度过大导致的训练崩溃。
具体来说,PPO 在原有的保守策略迭代目标函数的基础上,增加了一个明确的截断项,将概率比 的取值范围严格限制在
区间内,其中
是一个需要提前设置的超参数,其典型值为 0.2;这一数值是 OpenAI 在论文中通过大量基准实验验证得到的最优经验值。这一截断机制的核心逻辑是:当某个动作的优势函数为正时,只有在概率比小于
的前提下,才会去最大化代理目标函数;当概率比超过
时,优化过程会被直接截断,不再继续放大这个动作的被选择概率。反之,如果动作的优势函数为负,则只有在概率比大于
的前提下,才会去最小化目标函数;当概率比低于
时,优化过程同样会被截断。
这一机制的核心本质,是给策略更新设置了一个 “变化幅度安全区间”。它的巧妙之处在于,对概率比 进行截断的同时,将最终的优化目标设置为 “未截断目标函数” 和 “截断后目标函数” 中的较小值 —— 这相当于给策略优化过程提供了一个 “悲观的下界”,也就是在最坏情况下的最优更新幅度。这一设计可以通俗理解为:算法在优化过程中,只会在 “能提升目标函数” 的方向上移动到截断区间的边界位置,再继续优化下去,对目标函数的提升效果会被截断项完全抵消,不会再产生任何实际的更新增益。例如,当优势函数为正时,优化过程会在 “概率比达到
” 时停止,再继续更新参数也无法进一步提升目标函数;当优势函数为负时,概率比会被限制在
以上,避免策略过度倾向于其他动作。这就从数学形式上,彻底消除了 “单次更新步长过大” 的可能性。
用数学语言来描述,这个截断后的目标函数形式为:
这一目标函数是 PPO 的核心创新点,它的核心逻辑是:对每个样本的代理目标函数进行评估时,会同时计算 “未截断的原始概率比” 和 “经过截断处理后的概率比”,再取两者中的较小值 —— 这相当于在优化过程中,选择了一个更保守、更安全的更新方向。通过这个目标函数优化得到的策略参数,就是 PPO 的最终策略更新方向。
这一设计的巧妙之处在于,它不需要像 TRPO 那样,通过复杂的二阶计算来约束 KL 散度,只需要在一阶优化过程中,对概率比进行简单的数值截断,就能实现与 TRPO 类似的 “信任区域” 约束效果:在策略更新幅度较小时,优化过程可以正常推进;但一旦更新幅度超过预设的安全区间 ,截断机制就会立即生效,将进一步的更新增益直接降为零,从而限制策略的更新幅度。这一方案在实现成本可控的前提下,完美解决了策略更新不稳定的问题。
方案二:自适应 KL 惩罚项(Adaptive KL Penalty)
这是 PPO 的另一种可选方案,其核心逻辑与 TRPO 的约束思路更接近 —— 它不再直接限制概率比的数值变化,而是在目标函数中增加了一个基于 KL 散度的惩罚项,从概率分布的底层逻辑上,约束新旧策略之间的差异幅度。
具体来说,这一方案会在优化目标中,加入一个衡量新旧策略动作输出分布差异的 KL 散度项,再通过一个动态调整的惩罚系数 ,将策略的实际 KL 散度值控制在预设的目标区间内。这一方案的数学形式如下:
在实际优化过程中,算法会根据当前新旧策略的实际 KL 散度值,动态调整惩罚系数 的大小:如果实际 KL 散度值超过预设目标值的 1.5 倍,算法会将
的数值翻倍,增大惩罚项的权重,强制让后续更新后新旧策略的分布差异缩小;如果实际 KL 散度值低于预设目标值的 1/1.5,则将
的数值减半,降低惩罚项的权重,让策略在后续更新时有更大的优化空间。这一动态调整的逻辑,是启发式设计的,实际应用中对算法性能的影响极小,不会干扰核心优化流程。
这一方案的核心本质,是将新旧策略的分布差异约束,从 “直接截断数值” 升级为了 “基于分布差异的动态惩罚”,与 TRPO 的核心约束逻辑更接近。但需要明确的是,在 PPO 的原始论文中,通过大量的基准任务对比实验验证,这一方案的实际表现要略逊于截断方案 —— 无论是样本效率还是训练稳定性,都存在小幅的差距。因此,这一方案在实际工程中很少被采用,更多是作为理论研究的基线对比方案,或是在对策略分布一致性有极高要求的特定场景下,才会被考虑使用。
完整 PPO 算法流程
PPO 算法的完整流程可以分为三个核心阶段:数据采样、优势计算、策略与价值函数更新,这三个阶段会在训练过程中不断迭代,最终让策略收敛到最优的决策效果。从整体架构上看,PPO 采用了 “多环境并行采样 + 批量迭代优化” 的经典设计逻辑 —— 这一设计的核心目的,是在提升采样效率的同时,通过多轮小批量更新,最大化利用已有的样本数据,从而兼顾训练速度和样本效率。
下面我们将详细拆解这个流程的每一个技术细节:
- 初始化与多环境设置:在训练正式开始前,需要完成三个核心初始化步骤:第一,设置训练的相关超参数,包括学习率、折扣因子、并行环境数量等;第二,初始化策略网络(Actor)和价值网络(Critic)的权重参数;第三,通过旧策略网络与环境的交互,采集足够的轨迹数据,作为后续优化的样本基础。为了提升采样效率,PPO 一般会采用并行环境采样的设计逻辑 —— 通过多个并行的环境副本,同时执行旧策略采集多组不同的样本轨迹,这一设计可以在单位时间内采集到更多的样本数据,提升训练效率。在 CleanRL 的官方实现中,默认采用了 4 个并行环境同时采样,这一数量是经过实验验证的、在 “采样效率” 和 “训练资源消耗” 之间的最优平衡点。
- 并行数据采样:在每个训练迭代回合中,N 个并行环境副本会同步基于旧策略
执行 T 个时间步的交互采样 —— 在 CleanRL 的默认参数设置中,每个环境副本会连续执行 128 个时间步,采集足够的轨迹数据。在这个过程中,所有的交互关键信息都会被完整保存下来,包括每个时间步的环境状态、策略输出的动作、环境返回的奖励、交互结束的标志位,以及旧策略在当前状态下的动作输出概率对数、旧价值网络对当前状态的价值评估结果 —— 这些数据将作为后续优化的核心样本基础。这里的一个关键设计细节是,采样过程中使用的策略,是优化前的旧策略
—— 用旧策略采集的样本数据,来指导新策略的更新方向,这是为了保证样本分布与新策略更新方向的一致性,避免分布偏移;在整个优化过程中,旧策略的参数会被完整保留,不会被任何中间优化步骤所更新,直到当前迭代回合的所有优化步骤完成后,才会用新策略的参数覆盖旧策略参数。
- 优势与回报计算:在完成样本数据采集后,算法会基于保存好的轨迹数据,计算每个时间步的 “优势函数”(Advantage)和 “回报值”(Return)。这一步的核心技术方案是 “广义优势估计”(Generalized Advantage Estimation,GAE)—— 这是一种能有效降低优势函数方差的计算方案,它通过组合多个时间步的时序差分(TD)误差,在 “优势函数的偏差” 和 “方差” 之间找到了一个完美的平衡点,后续的策略梯度也会更稳定,让训练过程更易收敛。在计算过程中,GAE 会先计算每个时间步的时序差分误差
—— 即实际奖励与预期价值评估结果的差值;再基于这个差值,通过一个折扣系数
(其典型值为 0.95),将多个时间步的误差加权求和,得到最终的优势评估结果。回报值则是通过 “优势值 + 旧价值网络的评估结果” 计算得到的 —— 这一回报值将作为优化价值网络的 “真实标签”,用于后续的价值网络更新。这一步的关键技术细节是,所有的计算过程都必须在 “反传” 的模式下进行,即从轨迹的最后一个时间步开始,一步步向前推导计算,这是因为 GAE 的计算逻辑需要依赖下一个时间步的评估结果,才能保证整个轨迹的优势评估结果的连贯性和准确性。
- 多轮小批量优化:这是 PPO 算法的核心优化阶段,也是它区别于传统策略梯度算法的关键设计 —— 传统策略梯度算法只能对采集到的样本数据进行一次梯度更新,而 PPO 可以对同样的样本数据进行 K 轮的小批量梯度更新。这一设计的核心目的,是为了最大化利用每一批次的样本数据,提升算法的样本效率。在每一轮更新中,算法都会随机抽取一部分样本组成小批量数据,计算截断后的代理目标函数,再通过 Adam 优化器执行梯度上升更新 —— 这里的关键细节是,在每一次小批量更新之前,算法都会将旧策略的概率比固定下来,确保在多个小批量更新步骤中,截断机制仍然能够有效地约束策略更新的幅度。在这个过程中,算法会重点跟踪两个关键指标:一是 “截断比例”—— 即概率比被截断的样本占总样本数的比例;二是 “近似 KL 散度”—— 这一指标用于近似评估新旧策略的实际分布差异。这两个指标的主要作用是,在训练过程中实时监控更新的幅度是否正常,作为调试参考依据,并不会直接干预优化过程;只有在这个值超过一定的合理范围时,才会提前终止当前的优化步骤,避免进行无效更新。这一步的关键技术细节是,更新的轮次数 K 和小批量样本大小 M,都是需要提前设置的超参数 —— 在 CleanRL 的官方实现中,K 的默认值为 4,M 的默认值为总样本数的 1/4;这两个数值是经过大量基准任务实验验证的最优经验值,能够在 “样本利用效率” 和 “训练过拟合风险” 之间达到最优的平衡效果。
- 代理目标函数计算:在每一个小批量更新步骤中,算法会基于当前的小批量样本数据,计算出三个核心损失项,再将它们合并成最终的代理目标函数,用于指导网络参数的更新。第一项是策略代理损失项 —— 即我们前面提到的截断后的代理目标函数,这是优化策略网络的核心依据;第二项是价值函数损失项 —— 即价值网络的评估结果与实际回报值之间的均方误差,这一损失项用于优化价值网络,提升其对状态价值的评估准确性;第三项是熵正则化项 —— 这一项的作用是维持策略的探索性,避免策略过早收敛到局部最优的决策结果。在计算这三个损失项时,算法会将旧策略的概率比作为一个固定的常数参与计算,确保截断机制在多轮更新中仍然有效,不会因为多轮迭代优化而失效。而在实际优化过程中,我们会将这三个损失项加权组合,形成最终的代理目标函数:
其中,
和
是两个需要提前设置的超参数,分别用于平衡价值函数损失项和熵正则化项的权重;
是价值函数的损失项,
是熵正则化项。这里需要特别注意的是,价值函数损失项前面的符号是减号 —— 这是因为我们的优化目标是最大化策略的代理目标函数,而价值函数损失项的优化目标是最小化评估误差;通过减号将其转化为最大化问题后,才能与策略代理损失项的优化方向保持一致。熵正则化项前面的符号是加号 —— 这是为了在优化过程中,鼓励策略网络保持一定的随机探索性,避免它过早收敛到局部最优的决策结果;当熵值越高时,策略的探索性就越强,反之则越弱。
- 网络参数更新:在计算得到代理目标函数后,优化器会执行梯度下降(或上升)步骤,更新策略网络和价值网络的参数。这里的关键技术细节是,所有样本的梯度都会被计算完成后,才会统一执行参数更新步骤 —— 这是为了保证批量数据中的更新方向的一致性,避免参数更新过于频繁导致的训练不稳定。在更新过程中,为了避免梯度爆炸问题,算法会对梯度的范数进行强制截断,将其限制在一个合理的区间内。在 CleanRL 的官方实现中,梯度范数的最大默认值为 0.5—— 这是一个经过实验验证的经验值,能在保证更新方向有效的前提下,避免梯度过大导致的训练不稳定。此外,在更新过程中,算法还会对学习率进行衰减 —— 随着迭代次数的增加,学习率会逐步线性减小,这是为了在训练后期,让策略的更新幅度逐步放缓,避免出现训练震荡的情况。
- 旧策略参数更新:在完成 K 轮的小批量更新后,算法会用优化后的新策略参数,覆盖旧策略网络的参数。这里需要特别注意的是,在当前迭代回合的所有优化步骤中,旧策略的参数是被完全保留的,不会被任何中间优化步骤所更新;只有在当前迭代回合的所有优化步骤完成后,才会用新策略的参数覆盖旧策略参数。这一设计是为了保证下一轮迭代中,样本数据采集的策略基础与优化方向的一致性 —— 避免因为旧策略参数在优化过程中被中途更新,导致采样数据和优化方向出现分布偏移,影响训练的稳定性。
- 迭代与收敛判断:算法会不断重复上述 “采样 - 计算 - 优化” 的完整流程,直到策略网络的性能指标(如平均奖励、策略变化幅度)收敛到预设的停止标准,或是达到了预先设置的总训练步数为止。
通过这一整套流程的设计,PPO 在保证训练稳定性的前提下,最大化利用了样本数据,实现了性能与实现复杂度的完美平衡。
PyTorch 代码案例解析
通过行业公认的、结构清晰的 CleanRL 库中 PPO 的经典实现版本,完整拆解 PPO 算法的工程落地细节。
环境准备
在开始训练之前,我们需要先安装所有的依赖库。CleanRL 库的设计逻辑是将所有算法的核心逻辑封装在单一文件中,因此我们只需要安装它的基础依赖项,就可以运行 PPO 算法。具体的安装步骤如下:
首先,通过 git 命令将 CleanRL 的官方代码仓库克隆到本地,再进入项目的根目录:
git clone https://github.com/vwxyzjn/cleanrl.git
cd cleanrl
接下来,通过 poetry 命令安装项目的所有依赖项,包括 PyTorch、Gymnasium、Tensorboard 等强化学习和训练日志监控需要的核心工具包:
poetry install
需要特别注意的是,这里使用的 Python 版本必须在 3.8 到 3.9 之间 —— 这是因为 CleanRL 的部分依赖项,如 Gymnasium 的部分功能包,还不支持 Python 3.10 及以上的版本。如果你的本地环境中没有安装符合要求的 Python 版本,可以通过 conda 命令创建一个专属的虚拟环境,再执行上述安装命令,避免版本冲突导致的安装失败问题。
完整代码解析
CleanRL 的 PPO 实现代码位于 cleanrl/ppo.py 文件中,整个代码的逻辑结构可以分为五个核心部分:超参数设置、环境构建、网络结构定义、训练数据采集、损失计算与网络优化。在实际运行过程中,我们可以通过命令行参数的形式,灵活调整所有的超参数配置 —— 比如修改训练的总步数、调整并行环境的数量、截断阈值的大小,或者更换训练的环境 ID 等。
接下来,我们将逐行拆解这个核心实现文件。
1. 超参数配置与导入
首先,代码导入了所有需要的核心依赖库,包括 PyTorch 的核心库、Gymnasium 的环境交互库、用于并行数据处理的 NumPy 库,以及用于训练日志可视化的 Tensorboard 工具包等。
接着,代码通过dataclass定义了一个Args数据类,将所有的超参数都集中在这个类中 —— 这一设计可以避免超参数在代码中零散分布,大幅提升了代码的可维护性。在这个Args类中,定义了 PPO 算法的所有核心超参数,包括环境 ID、训练总步数、学习率、并行环境数量、每个环境的采样步数、折扣因子、GAE 的 lambda 系数、更新的轮次、小批量的大小、截断阈值,以及价值函数损失项和熵正则化项的权重系数等。
这部分的核心代码如下:
import os
import random
import time
from dataclasses import dataclass
import gymnasium as gym
import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
import tyro
from torch.distributions.categorical import Categorical
from torch.utils.tensorboard import SummaryWriter
@dataclass
class Args:
exp_name: str = os.path.basename(__file__)[: -len(".py")]
seed: int = 1
torch_deterministic: bool = True
cuda: bool = True
track: bool = False
wandb_project_name: str = "cleanRL"
wandb_entity: str = None
capture_video: bool = False
env_id: str = "CartPole-v1"
total_timesteps: int = 500000
learning_rate: float = 2.5e-4
num_envs: int = 4
num_steps: int = 128
anneal_lr: bool = True
gamma: float = 0.99
gae_lambda: float = 0.95
num_minibatches: int = 4
update_epochs: int = 4
norm_adv: bool = True
clip_coef: float = 0.2
clip_vloss: bool = True
ent_coef: float = 0.01
vf_coef: float = 0.5
max_grad_norm: float = 0.5
target_kl: float = None
batch_size: int = 0
minibatch_size: int = 0
num_iterations: int = 0
在这些超参数中,有几个对训练效果影响极大的核心参数,需要特别说明:
- num_envs:并行环境的数量,默认值为 4—— 这意味着算法会同时启动 4 个 CartPole-v1 环境的副本,在每个迭代回合中同时采集样本数据;这一设计可以在单位时间内采集到更多的样本数据,提升训练效率。
- num_steps:每个并行环境在每个迭代回合中采集的时间步数,默认值为 128—— 这意味着每个环境副本在每次迭代中,会连续执行 128 个时间步的交互采样,采集足够的轨迹数据。
- update_epochs:每个迭代回合中,对采集到的样本数据进行小批量更新的轮次,默认值为 4—— 这意味着每一批次的样本数据,会被重复使用 4 次,最大化利用样本的同时,避免过拟合风险。
- clip_coef:PPO 算法中截断机制的核心参数,默认值为 0.2—— 这也是 OpenAI 在论文中经过大量基准实验验证得到的最优经验值;这一参数会将概率比的变化区间严格限制在 [0.8, 1.2] 范围内,有效约束策略更新的幅度。
- vf_coef:价值函数损失项在总代理目标函数中的权重系数,默认值为 0.5—— 这一参数平衡了策略网络和价值网络的更新优先级。
- ent_coef:熵正则化项在总代理目标函数中的权重系数,默认值为 0.01—— 这一参数维持了策略的探索性,避免策略过早收敛到局部最优的决策结果。
在代码执行过程中,这些超参数的具体数值会被打印到控制台,同时也会被保存到 Tensorboard 的日志文件中,方便后续的调试和实验结果复现。
2. 环境构建与初始化
这部分代码的核心作用是,创建并初始化我们要训练的目标环境,同时对环境进行必要的封装适配,后续采集到的样本数据才能符合 PPO 算法的输入格式要求。
代码中定义了一个make_env函数,这是一个高阶函数,它会返回一个环境初始化函数实例。通过这个函数,我们可以对环境进行一些标准化的封装适配 —— 比如在训练过程中记录游戏的回放视频,在每个回合游戏结束后,统计智能体的累计奖励和回合长度等关键指标。
接着,代码通过 Gymnasium 的SyncVectorEnv接口,将多个环境副本打包成一个并行处理的环境向量 —— 这一设计的核心目的,是为了提升采样的效率,让多个环境副本在同一时间执行交互采样。在实际运行过程中,每个并行环境副本都会被分配一个唯一的 ID,用于区分不同的采样数据来源;并且会在单独的进程中运行,避免多个环境副本之间的交互计算相互干扰。
这部分的核心代码如下:
def make_env(env_id, idx, capture_video, run_name):
def thunk():
if capture_video and idx == 0:
env = gym.make(env_id, render_mode="rgb_array")
env = gym.wrappers.RecordVideo(env, f"videos/{run_name}")
else:
env = gym.make(env_id)
env = gym.wrappers.RecordEpisodeStatistics(env)
return env
return thunk
if __name__ == "__main__":
args = tyro.cli(Args)
args.batch_size = int(args.num_envs * args.num_steps)
args.minibatch_size = int(args.batch_size // args.num_minibatches)
args.num_iterations = args.total_timesteps // args.batch_size
run_name = f"{args.env_id}__{args.exp_name}__{args.seed}__{int(time.time())}"
envs = gym.vector.SyncVectorEnv(
[make_env(args.env_id, i, args.capture_video, run_name) for i in range(args.num_envs)],
)
在这部分代码中,有几个关键的技术细节需要说明:
- 并行环境的数量,是通过args.num_envs参数来控制的 —— 在 PPO 算法的论文和大多数工程实现中,8 个并行环境是一个比较常见的配置;但在 CleanRL 的默认实现中,为了降低运行的资源消耗,将其设置为 4 个;在实际工程场景中,我们可以根据训练服务器的计算资源情况,灵活调整这一参数的数值。
- 每个并行环境在执行step函数时,会接收一个批量的动作向量作为输入 —— 这个向量的长度,等于并行环境的数量;执行完成后,会返回一个批量的观察值、奖励值、终止状态标志位,以及额外的调试信息。这一设计可以最大化利用计算资源,在单位时间内采集到更多的样本数据。
- 代码中通过RecordEpisodeStatistics包装器,自动统计了每个回合的累计奖励和回合长度 —— 这些指标是评估训练效果的核心依据。
3. Actor-Critic 网络结构定义
这部分代码是 PPO 算法的核心基础,定义了 Actor-Critic 架构的网络结构。在 CleanRL 的实现中,采用了 “共享特征提取层 + 双输出头” 的经典设计逻辑 —— 这一设计的核心是,让 Actor 和 Critic 网络共享同一个底层的特征提取网络,再通过两个独立的全连接输出头,分别输出动作决策和状态价值评估结果。这一设计的优势是,减少了网络参数的数量,提升了计算效率;并且通过共享特征提取层,保证了策略网络和价值网络的特征输入的一致性,提升了训练的稳定性。
具体来说,这个网络架构包含两个核心部分:
- Actor 网络(策略网络) :输入是当前环境的状态向量,输出是动作空间的概率分布 —— 在 CartPole-v1 环境中,动作空间是离散的(向左移动或向右移动),因此网络输出的是动作类别的对数概率;在实际执行过程中,算法会从这个概率分布中采样,得到最终要执行的动作。
- Critic 网络(价值网络) :输入同样是当前环境的状态向量,输出是一个单独的标量值 —— 即对当前状态的长期价值评估结果,也就是从当前状态出发,后续所有奖励的折现值;这一结果将用于计算优势函数,指导策略的更新方向。
在网络结构的具体实现中,Actor 网络和 Critic 网络各采用了两个隐藏层,每层包含 64 个神经元;激活函数选择了 Tanh—— 这一选择是为了将网络的输出数据分布限制在 [-1, 1] 的区间内,这与我们在数据预处理时对环境状态的归一化处理逻辑是匹配的;如果使用其他激活函数,比如 ReLU,可能会导致输出数据分布差异过大,影响训练的稳定性。
需要特别说明的是,这一网络结构的设计,是针对 CartPole-v1 这类经典的低维控制任务优化的。如果我们要处理的是高维输入任务(比如以游戏画面作为输入的 Atari 任务),需要在网络结构中增加卷积层,用于提取输入画面中的空间特征;或是在网络结构中增加 LSTM 层,用于处理带时间依赖的状态序列。
这部分的核心代码如下:
def layer_init(layer, std=np.sqrt(2), bias_const=0.0):
torch.nn.init.orthogonal_(layer.weight, std)
torch.nn.init.constant_(layer.bias, bias_const)
return layer
class Agent(nn.Module):
def __init__(self, envs):
super().__init__()
self.critic = nn.Sequential(
layer_init(nn.Linear(np.array(envs.single_observation_space.shape).prod(), 64)),
nn.Tanh(),
layer_init(nn.Linear(64, 64)),
nn.Tanh(),
layer_init(nn.Linear(64, 1), std=1.0),
)
self.actor = nn.Sequential(
layer_init(nn.Linear(np.array(envs.single_observation_space.shape).prod(), 64)),
nn.Tanh(),
layer_init(nn.Linear(64, 64)),
nn.Tanh(),
layer_init(nn.Linear(64, envs.single_action_space.n), std=0.01),
)
def get_value(self, x):
return self.critic(x)
def get_action_and_value(self, x, action=None):
logits = self.actor(x)
probs = Categorical(logits=logits)
if action is None:
action = probs.sample()
return action, probs.log_prob(action), probs.entropy(), self.critic(x)
在这部分代码中,有几个关键的技术细节需要说明:
- 网络权重初始化:代码中定义了一个layer_init函数,用于对网络的权重参数进行初始化。这一函数采用了正交初始化方案 —— 这是在训练深度强化学习网络时,避免梯度消失或梯度爆炸问题的常用且非常有效的技术方案;这一方案可以保证网络在初始阶段,输出特征的分布保持在合理的区间内。
- Actor 网络的输出层初始化:Actor 网络的输出层权重初始化的标准差被设置为 0.01—— 这是一个经过实验验证的经验值,它可以保证在训练的初始阶段,网络输出的动作概率分布尽可能均匀,不会偏向某一个特定的动作;这一设计的目的,是为了让智能体在训练初期,能够充分探索环境中的各种可能动作,避免过早收敛到局部最优的决策结果。
- Critic 网络的输出层初始化:Critic 网络的输出层权重初始化的标准差被设置为 1.0—— 这一设置是为了让价值网络在初始阶段,对状态的价值评估结果有一个合理的数值基础;如果初始化的标准差太小,价值网络的初始评估结果会过于集中在一个较小的区间内,这会导致策略网络的更新方向不够明确,影响训练速度。
- get_action_and_value函数:这是网络的核心接口函数,它会同时输出四个核心内容:根据概率分布采样得到的动作、该动作对应的对数概率、当前动作概率分布的熵值,以及价值网络对当前状态的评估结果。在优化过程中,我们会将采集到的动作作为输入,传入这个函数中,重新计算当前策略下的动作对数概率和熵值 —— 这是为了计算新旧策略之间的概率比,是后续计算代理目标函数的关键依据。
4. 训练数据采集与存储
这部分代码是 PPO 算法的关键环节,负责在每个迭代回合中,通过旧策略与环境交互采集样本数据,再将这些样本数据批量保存下来,作为后续优化阶段的核心输入。
在采集数据之前,代码会初始化一系列的张量,用于存储采集到的多步轨迹数据 —— 这些张量的维度被设置为(num_steps, num_envs, ...),这一设计的核心是,按照 “时间步 - 环境副本” 的二维格式,批量存储所有并行环境的采样数据。在存储过程中,所有数据会被按照时间步长排列,保持轨迹的时序关联性;这些数据包含:环境的状态、执行的动作、动作对应的对数概率、环境返回的奖励值、环境是否结束的标志位,以及价值网络对每个状态的评估结果。
在每个迭代回合中,代码会执行num_steps步的采样逻辑 —— 在每一步中,首先,旧策略会根据当前的环境状态,输出一个动作决策;接着,这个动作会被传入到环境的step函数中,环境执行这个动作并返回对应的结果,包括下一个状态、动作的即时奖励值、环境是否结束的标志位、以及额外的调试信息;最后,所有的交互结果数据,会被保存到预先初始化的张量中。
这部分的核心代码如下:
# ALGO Logic: Storage setup
obs = torch.zeros((args.num_steps, args.num_envs) + envs.single_observation_space.shape).to(device)
actions = torch.zeros((args.num_steps, args.num_envs) + envs.single_action_space.shape).to(device)
logprobs = torch.zeros((args.num_steps, args.num_envs)).to(device)
rewards = torch.zeros((args.num_steps, args.num_envs)).to(device)
dones = torch.zeros((args.num_steps, args.num_envs)).to(device)
values = torch.zeros((args.num_steps, args.num_envs)).to(device)
# TRY NOT TO MODIFY: seeding
random.seed(args.seed)
np.random.seed(args.seed)
torch.manual_seed(args.seed)
torch.backends.cudnn.deterministic = args.torch_deterministic
device = torch.device("cuda" if torch.cuda.is_available() and args.cuda else "cpu")
# TRY NOT TO MODIFY: start the game
global_step = 0
start_time = time.time()
next_obs, _ = envs.reset(seed=args.seed)
next_obs = torch.Tensor(next_obs).to(device)
next_done = torch.zeros(args.num_envs).to(device)
for iteration in range(1, args.num_iterations + 1):
# Annealing the rate if instructed to do so.
if args.anneal_lr:
frac = 1.0 - (iteration - 1.0) / args.num_iterations
lrnow = frac * args.learning_rate
optimizer.param_groups[0]["lr"] = lrnow
for step in range(0, args.num_steps):
global_step += args.num_envs
obs[step] = next_obs
dones[step] = next_done
# ALGO LOGIC: action logic
with torch.no_grad():
action, logprob, _, value = agent.get_action_and_value(next_obs)
values[step] = value.flatten()
actions[step] = action
logprobs[step] = logprob
# TRY NOT TO MODIFY: execute the game and log data.
next_obs, reward, terminations, truncations, infos = envs.step(action.cpu().numpy())
next_done = np.logical_or(terminations, truncations)
rewards[step] = torch.tensor(reward).to(device).view(-1)
next_obs, next_done = torch.Tensor(next_obs).to(device), torch.Tensor(next_done).to(device)
在这部分代码中,有几个关键的技术细节需要说明:
- 梯度关闭:在采样过程中,代码中使用了torch.no_grad()装饰器 —— 这是为了关闭 PyTorch 的自动梯度计算引擎,避免采样过程中产生多余的梯度计算逻辑,这一设计可以显著节省显存空间和采样的计算时间。
- 数据存储格式:所有采集到的样本数据,都会被存储在连续的张量空间中 —— 这一设计是为了后续能够高效地将数据批量传入模型进行训练,提升数据加载的效率。
- 多环境并行采样:在每一步的采样逻辑中,所有并行环境的动作决策会被一次性批量执行 —— 这一设计可以最大化利用计算资源,在单位时间内采集到更多的样本数据。
- 学习率衰减:代码中实现了学习率的线性衰减逻辑 —— 随着迭代次数的增加,学习率会逐步线性减小,这是为了在训练后期,让策略的更新幅度逐步放缓,避免出现训练震荡的情况。这一逻辑是可配置的,通过anneal_lr参数来控制是否开启这一功能。
5. 优势计算与样本准备
在完成样本数据采集后,需要计算每个时间步的优势函数与回报值。这是连接数据采样与策略更新的关键环节 —— 优势函数是衡量动作决策相对优劣的核心依据,也是后续计算代理目标函数的关键输入。
这部分代码采用了 “广义优势估计”(Generalized Advantage Estimation,GAE)方案来计算优势函数 —— 这是当前行业内,降低优势函数的方差、提升训练稳定性的最优技术方案。在计算过程中,GAE 会先计算每个时间步的时序差分误差—— 即实际奖励与价值网络预期评估结果的差值;再基于这个差值,通过一个折扣系数
,将多个时间步的误差加权求和,得到最终的优势评估结果。这一方案可以在 “优势函数的偏差” 和 “方差” 之间找到最优平衡点,让后续的策略梯度更稳定,训练过程更易收敛。
在具体实现逻辑中,GAE 的计算需要从轨迹的最后一个时间步开始,一步步向前推导计算 —— 这是因为前一个时间步的优势计算结果,需要依赖下一个时间步的优势评估结果;所以在代码中,采用了逆序循环的方式来完成这一计算过程。
在完成优势函数计算后,代码会将采集到的多组连续轨迹数据,重新整理成适合小批量训练的格式 —— 将所有的轨迹数据拼接成一个完整的批量数据集,再将这些数据随机打乱,分成多个小批量数据,供后续的多轮更新使用。
这部分的核心代码如下:
# bootstrap value if not done
with torch.no_grad():
next_value = agent.get_value(next_obs).reshape(1, -1)
advantages = torch.zeros_like(rewards).to(device)
lastgaelam = 0
for t in reversed(range(args.num_steps)):
if t == args.num_steps - 1:
nextnonterminal = 1.0 - next_done
nextvalues = next_value
else:
nextnonterminal = 1.0 - dones[t + 1]
nextvalues = values[t + 1]
delta = rewards[t] + args.gamma * nextvalues * nextnonterminal - values[t]
advantages[t] = lastgaelam = delta + args.gamma * args.gae_lambda * nextnonterminal * lastgaelam
returns = advantages + values
在这部分代码中,有几个关键的技术细节需要说明:
- 逆序计算:代码中采用了逆序循环的方式,从最后一个时间步开始向前推导计算优势值 —— 这是因为 GAE 的计算逻辑需要依赖下一个时间步的评估结果,才能保证整个轨迹的优势评估结果的连贯性和准确性;这一步的计算复杂度是 O (T),T 是每个环境采集的时间步数。
- 终止状态处理:在计算过程中,代码会根据环境是否到达终止状态,来决定是否 “屏蔽” 下一步的价值评估结果 —— 如果环境已经到达终止状态,后续的价值评估结果会被设置为 0;这一设计是为了保证优势计算的准确性,避免将无效的未来价值评估结果,纳入到当前状态的优势计算中。
- 回报值计算:在完成优势函数计算后,代码会通过returns = advantages + values计算得到每个时间步的实际回报值 —— 这一回报值将作为优化价值网络的 “真实标签”,用于后续的价值网络更新;这一设计的核心本质,是将 “优势函数” 和 “价值网络评估结果” 结合起来,得到一个更准确的、对当前状态的长期价值评估结果。
6. 代理目标函数计算与网络优化
这是 PPO 算法的核心优化环节,也是整个代码中最关键的逻辑部分 —— 在这个环节中,算法会基于之前计算得到的优势值,构建代理目标函数,并执行梯度更新,优化 Actor 网络和 Critic 网络的参数。
在每个小批量的更新步骤中,代码会首先重新计算当前策略下的动作对数概率分布 —— 这是为了计算新旧策略之间的概率比。接着,代码会按照 PPO 的核心逻辑,构建并计算出三个损失项,将它们合并成最终的代理目标函数:
- 策略代理损失项:这是截断后的代理目标函数,由新旧策略的概率比乘以优势函数得到;这一损失项的核心作用是,指导 Actor 网络的参数更新,让策略在后续决策中,更倾向于选择那些能获得更高奖励的动作。
- 价值函数损失项:这是价值网络的评估结果与实际回报值之间的均方误差;这一损失项的核心作用是,指导 Critic 网络的参数更新,提升其对状态价值的评估准确性 —— 只有当价值网络的评估结果足够准确时,优势函数的计算结果才会有意义。
- 熵正则化项:这是策略输出的动作概率分布的熵值;这一项的核心作用是,维持策略的探索性,避免策略过早收敛到局部最优的决策结果 —— 当熵值越高时,策略的探索性就越强;反之,则越弱。
在计算得到代理目标函数后,代码会通过 Adam 优化器,执行梯度上升更新步骤 —— 在更新过程中,为了避免梯度爆炸问题,算法会对梯度的范数进行强制截断,将其限制在一个合理的区间内。此外,代码还会在更新过程中,实时计算新旧策略之间的近似 KL 散度,以及截断操作的触发比例 —— 这些指标会被记录到 Tensorboard 的日志中,用于监控训练的状态,以及调试超参数的配置是否合理。
这部分的核心代码如下:
# flatten the batch
b_obs = obs.reshape((-1,) + envs.single_observation_space.shape)
b_logprobs = logprobs.reshape(-1)
b_actions = actions.reshape((-1,) + envs.single_action_space.shape)
b_advantages = advantages.reshape(-1)
b_returns = returns.reshape(-1)
b_values = values.reshape(-1)
# Optimizing the policy and value network
b_inds = np.arange(args.batch_size)
clipfracs = []
for epoch in range(args.update_epochs):
np.random.shuffle(b_inds)
for start in range(0, args.batch_size, args.minibatch_size):
end = start + args.minibatch_size
mb_inds = b_inds[start:end]
_, newlogprob, entropy, newvalue = agent.get_action_and_value(b_obs[mb_inds], b_actions.long()[mb_inds])
logratio = newlogprob - b_logprobs[mb_inds]
ratio = logratio.exp()
with torch.no_grad():
# calculate approx_kl http://joschu.net/blog/kl-approx.html
old_approx_kl = (-logratio).mean()
approx_kl = ((ratio - 1) - logratio).mean()
clipfracs += [((ratio - 1.0).abs() > args.clip_coef).float().mean().item()]
mb_advantages = b_advantages[mb_inds]
if args.norm_adv:
mb_advantages = (mb_advantages - mb_advantages.mean()) / (mb_advantages.std() + 1e-8)
# Policy loss
pg_loss1 = -mb_advantages * ratio
pg_loss2 = -mb_advantages * torch.clamp(ratio, 1 - args.clip_coef, 1 + args.clip_coef)
pg_loss = torch.max(pg_loss1, pg_loss2).mean()
# Value loss
newvalue = newvalue.view(-1)
if args.clip_vloss:
v_loss_unclipped = (newvalue - b_returns[mb_inds]) ** 2
v_clipped = b_values[mb_inds] + torch.clamp(
newvalue - b_values[mb_inds],
-args.clip_coef,
args.clip_coef,
)
v_loss_clipped = (v_clipped - b_returns[mb_inds]) ** 2
v_loss_max = torch.max(v_loss_unclipped, v_loss_clipped)
v_loss = 0.5 * v_loss_max.mean()
else:
v_loss = 0.5 * ((newvalue - b_returns[mb_inds]) ** 2).mean()
# Entropy loss
entropy_loss = entropy.mean()
# Total loss
loss = pg_loss - args.vf_coef * v_loss + args.ent_coef * entropy_loss
optimizer.zero_grad()
loss.backward()
nn.utils.clip_grad_norm_(agent.parameters(), args.max_grad_norm)
optimizer.step()
在这部分代码中,有几个关键的技术细节需要特别说明:
- 新旧策略概率比计算:代码中通过logratio = newlogprob - b_logprobs[mb_inds]计算得到新旧策略的动作概率比的自然对数;再通过ratio = logratio.exp(),得到概率比的实际数值 —— 这一计算方式是为了提升数值计算的稳定性:在概率比的数值差异较大时,对对数概率进行减法运算,可以避免直接进行除法运算带来的数值溢出风险。
- 截断机制的实现逻辑:代码中通过torch.clamp函数,将概率比的变化区间严格限制在[1-args.clip_coef, 1+args.clip_coef]范围内 —— 这一函数会将超出这个区间的概率比强制设置为区间的边界值,如当概率比的数值大于1+args.clip_coef时,会被强制截断为1+args.clip_coef;小于1-args.clip_coef时,会被强制截断为1-args.clip_coef。这一逻辑完全遵循了 PPO 论文中的核心设计,将策略更新幅度的变化区间,限制在一个合理的安全范围内。
- 优势归一化:代码中通过norm_adv参数,控制是否对优势函数进行归一化处理 —— 这一设计是为了提升训练的稳定性:在实际场景中,不同环境状态下的优势函数数值差异可能很大;通过归一化处理,可以将优势函数的数值分布,调整到一个相对标准的区间内,让梯度更新的方向更加稳定。
- 策略损失的计算逻辑:代码中分别计算了未截断的代理目标函数和经过截断处理的代理目标函数,再取两者中的较大值,作为最终的策略损失项 —— 这一设计的核心本质,是在优化过程中选择一个更保守、更安全的更新方向;这与论文中定义的截断目标函数逻辑完全一致,保证了策略更新的稳定性。
- 价值损失的计算逻辑:代码中实现了对价值函数损失项的截断处理 —— 这一设计是为了防止价值网络的更新幅度过大,导致训练过程出现震荡的情况;这一功能是可配置的,通过clip_vloss参数来控制是否开启。如果开启了这一功能,代码会将价值网络的更新幅度,限制在与策略网络更新幅度相同的安全区间内。
- 熵损失的计算逻辑:熵正则化项的权重系数在代码中被设置为正数 —— 这意味着在优化过程中,我们会最大化策略的熵值,以鼓励智能体保持一定的探索性;避免它过早收敛到局部最优的决策结果。
- 梯度截断:代码中通过nn.utils.clip_grad_norm_函数,对梯度的范数进行强制截断,将其限制在预设的最大范数区间内 —— 这是训练深度强化学习网络时,避免梯度爆炸问题的一项非常关键的技术方案;在 CleanRL 的官方实现中,这一最大范数的默认值为 0.5,这是一个经过实验验证的经验值。
- KL 散度检查:在训练过程中,代码会实时计算新旧策略之间的近似 KL 散度。这一指标不会直接约束优化过程,仅作为参考指标被记录到日志中;但如果这一数值的变化幅度超过了合理区间,说明截断机制可能失效了,需要调整截断系数的大小。
7. 训练日志与监控
在每个迭代回合完成后,代码会将训练过程中的关键指标,通过 Tensorboard 工具记录到日志文件中 —— 这些指标包括:策略损失项、价值函数损失项、熵正则化项、新旧策略之间的近似 KL 散度、截断操作的触发比例、每个训练回合的累计奖励值、回合长度,以及 GPU 的利用率等。
在训练完成后,我们可以通过 Tensorboard 的可视化界面,实时查看这些指标的变化趋势,评估训练的效果,以及判断超参数的配置是否合理。
这部分的核心代码如下:
# Logging
writer.add_scalar("charts/episodic_return", info["episode"]["r"], global_step)
writer.add_scalar("charts/episodic_length", info["episode"]["l"], global_step)
writer.add_scalar("losses/policy_loss", pg_loss.item(), global_step)
writer.add_scalar("losses/value_loss", v_loss.item(), global_step)
writer.add_scalar("losses/entropy", entropy_loss.item(), global_step)
writer.add_scalar("losses/approx_kl", approx_kl.item(), global_step)
writer.add_scalar("losses/clipfrac", np.mean(clipfracs), global_step)
运行与效果验证
在完成代码解析后,我们可以通过以下命令,直接启动训练过程,验证 PPO 算法的实际训练效果:
poetry run python cleanrl/ppo.py \
--seed 1 \
--env-id CartPole-v1 \
--total-timesteps 500000
在这个命令中,我们设置了三个关键参数:随机种子为 1,训练环境为 CartPole-v1,训练的总步数为 50 万步 —— 这一训练量,足以让 PPO 算法在 CartPole-v1 环境中,收敛到最优的决策效果。
在训练过程中,代码会在控制台实时打印当前的训练进度信息,包括当前的全局训练步数、最近一个训练回合的累计奖励值、回合长度等核心指标。同时,所有的训练日志数据,都会被保存到当前目录的runs文件夹中。
我们可以通过以下命令,启动 Tensorboard 的可视化看板,在浏览器中实时监控训练过程中各项指标的变化趋势:
tensorboard --logdir runs
在训练结束后,我们可以在 Tensorboard 的可视化看板中,查看以下几项核心指标的变化趋势,验证训练效果:
- 累计奖励值:这是衡量策略性能的最直接指标 —— 随着训练步数的增加,CartPole-v1 环境的回合累计奖励值会逐步上升,最终收敛到最优水平;这一环境的理论最高奖励值为 500 分。
- 策略损失项:在训练初期,策略损失项的数值会比较大,随着训练的推进,其数值会逐步稳定在一个较低的区间内 —— 这说明策略的更新方向已经趋于稳定,不会再出现剧烈的参数变化。
- 价值函数损失项:这一指标会随着训练的推进,逐步稳定在一个较低的区间内 —— 这说明价值网络对状态价值的评估结果,已经足够准确。
- 近似 KL 散度:这一指标的数值,应该始终被限制在合理的区间内 —— 如果其数值突然大幅上升,说明截断机制的约束效果不足,需要调整截断系数或学习率的配置;如果这一数值一直处于合理区间内,说明截断机制正在正常工作。
如果在训练过程中,开启了视频录制的配置选项,代码会将训练过程中,智能体在第一个并行环境内的游戏画面,自动录制并保存到videos文件夹中 —— 我们可以通过这些回放视频,直观地查看智能体的决策表现。
总结
PPO 算法是当前强化学习领域的主流工业方案,它的核心优势在于,通过一个简单且有效的截断机制,在保证训练稳定性的前提下,实现了接近 TRPO 的样本效率,并且大幅降低了工程实现门槛。这一算法的核心设计逻辑,可以总结为三个关键环节:其一是通过 “信任区域” 的约束思想,限制单次策略更新的幅度;其二是通过截断代理目标函数或自适应 KL 散度惩罚项,实现对策略更新幅度的约束;其三是通过多轮小批量更新方案,最大化利用样本数据,提升算法的样本效率。
在技术实现层面,PPO 采用了 Actor-Critic 架构,并且通过截断机制的约束,将策略的更新幅度限制在安全区间内;这一设计的核心本质,是将 “策略优化” 与 “更新幅度约束” 这两个矛盾的目标,通过一个简单的一阶截断机制,完美平衡在了一起。
本文中我们拆解的 CleanRL 库实现的 PPO 代码案例,是学习 PPO 算法落地细节的最优参考案例 —— 这一实现方式,完全遵循了 PPO 论文中的原始设计,并且将所有核心逻辑封装在一个不到 300 行的文件中,没有任何多余的封装逻辑;其代码结构清晰,且复现难度较低,非常适合作为学习 PPO 算法工程落地细节的参考资料。
从实际应用的角度来看,PPO 算法是解决大部分连续决策类任务的首选方案 —— 无论是在经典的强化学习基准任务中,还是在复杂的工业场景中,比如机器人的运动控制、工业机械臂的动作规划、大模型的人类偏好对齐,或是游戏中的非玩家角色(NPC)的决策逻辑优化;PPO 都能在训练稳定性和样本效率之间,达到一个非常优异的平衡效果。
对于有进一步深入学习需求的开发者来说,在理解了本文所介绍的核心原理与基础实现逻辑后,可以再进一步拓展学习方向:第一,尝试将这一基础版本的 PPO 算法,应用到更复杂的强化学习任务场景中,比如 Atari 游戏的对战任务、机器人运动控制的 MuJoCo 环境任务等;第二,研究 PPO 算法的一些主流优化变体,比如采用并行环境提升采样效率的 PPO-x2、结合了 lstm 层处理带时间依赖的状态序列的 PPO、或是结合了分位数价值分布的 PPG 算法等;第三,尝试在网络结构中引入卷积层或 LSTM 层,处理高维输入任务,比如以游戏画面、传感器数据作为输入的决策任务;第四,研究如何调整算法的超参数,包括截断系数、学习率、折扣因子、GAE 的 lambda 系数等,进一步提升算法的样本效率和训练稳定性;第五,研究如何将 PPO 算法应用到多智能体的协同或对抗场景中,这也是当前强化学习领域的一个主流研究方向。
掌握 PPO 算法的原理与工程落地实现细节,是每一位强化学习开发者的必经之路 —— 通过本文的原理解读与代码细节拆解,相信你已经对这一核心算法,有了更深入的理解;并且具备了将这一算法,应用到实际项目中的基础能力。
DAMO开发者矩阵,由阿里巴巴达摩院和中国互联网协会联合发起,致力于探讨最前沿的技术趋势与应用成果,搭建高质量的交流与分享平台,推动技术创新与产业应用链接,围绕“人工智能与新型计算”构建开放共享的开发者生态。
更多推荐

所有评论(0)