强化学习之 PPO 算法:原理、实现与案例深度剖析

目录

    • 一、引言
    • 二、PPO 算法原理
      • 2.1 策略梯度
      • 2.2 PPO 核心思想
    • 三、PPO 算法公式推导
      • 3.1 重要性采样
      • 3.2 优势函数估计
    • 四、PPO 算法代码实现(以 Python 和 PyTorch 为例)
    • 五、PPO 算法案例应用
      • 5.1 机器人控制
      • 5.2 自动驾驶
    • 六、总结


一、引言

强化学习作为机器学习中的一个重要领域,旨在让智能体通过与环境交互,学习到最优的行为策略以最大化长期累积奖励。近端策略优化(Proximal Policy Optimization,PPO)算法是强化学习中的明星算法,它在诸多领域都取得了令人瞩目的成果。本文将深入探讨 PPO 算法,从原理到代码实现,再到实际案例应用,力求让读者全面掌握这一强大的算法。

二、PPO 算法原理

2.1 策略梯度

在强化学习里,策略梯度是一类关键的优化方法,你可以把它想象成是智能体在学习如何行动时的 “指南针”。假设策略由参数 θ \theta θ 表示,这就好比是智能体的 “行动指南” 参数,智能体在状态 s s s 下采取行动 a a a 的概率为 π θ ( a ∣ s ) \pi_{\theta}(a|s) πθ(as) ,即根据当前的 “行动指南”,在这个状态下选择这个行动的可能性。

策略梯度的目标是最大化累计奖励的期望,用公式表示就是: J ( θ ) = E s 0 , a 0 , ⋯ [ ∑ t = 0 T γ t r ( s t , a t ) ] J(\theta)=\mathbb{E}_{s_0,a_0,\cdots}\left[\sum_{t = 0}^{T}\gamma^{t}r(s_t,a_t)\right] J(θ)=Es0,a0,[t=0Tγtr(st,at)]

这里的 γ \gamma γ 是折扣因子,它的作用是让智能体更关注近期的奖励,因为越往后的奖励可能越不确定,就像我们在做决策时,往往会更看重眼前比较确定的好处。 r ( s t , a t ) r(s_t,a_t) r(st,at) 是在状态 s t s_t st 下采取行动 a t a_t at 获得的奖励,比如玩游戏时,在某个游戏场景下做出某个操作得到的分数。

根据策略梯度定理,策略梯度可以表示为: ∇ θ J ( θ ) = E s , a [ ∇ θ log ⁡ π θ ( a ∣ s ) A ( s , a ) ] \nabla_{\theta}J(\theta)=\mathbb{E}_{s,a}\left[\nabla_{\theta}\log\pi_{\theta}(a|s)A(s,a)\right] θJ(θ)=Es,a[θlogπθ(as)A(s,a)]

这里的 A ( s , a ) A(s,a) A(s,a) 是优势函数,它表示采取行动 a a a 相对于平均策略的优势。简单来说,就是判断这个行动比一般的行动好在哪里,好多少,帮助智能体决定是否要多采取这个行动。

2.2 PPO 核心思想

PPO 算法的核心是在策略更新时,限制策略的变化幅度,避免更新过大导致策略性能急剧下降。这就好像我们在调整自行车的变速器,如果一下子调得太猛,可能车子就没法正常骑了。

它通过引入一个截断的目标函数来实现这一点: L C L I P ( θ ) = E t [ min ⁡ ( r t ( θ ) A ^ t , clip ( r t ( θ ) , 1 − ϵ , 1 + ϵ ) A ^ t ) ] L^{CLIP}(\theta)=\mathbb{E}_{t}\left[\min\left(r_t(\theta)\hat{A}_t, \text{clip}(r_t(\theta), 1 - \epsilon, 1+\epsilon)\hat{A}_t\right)\right] LCLIP(θ)=Et[min(rt(θ)A^t,clip(rt(θ),1ϵ,1+ϵ)A^t)]

其中 r t ( θ ) = π θ ( a t ∣ s t ) π θ o l d ( a t ∣ s t ) r_t(\theta)=\frac{\pi_{\theta}(a_t|s_t)}{\pi_{\theta_{old}}(a_t|s_t)} rt(θ)=πθold(atst)πθ(atst) 是重要性采样比,它反映了新策略和旧策略对于同一个状态 - 行动对的概率差异。 A ^ t \hat{A}_t A^t 是估计的优势函数, ϵ \epsilon ϵ 是截断参数,通常设置为一个较小的值,如 0.2 。这个截断参数就像是给策略更新幅度设定了一个 “安全范围”,在这个范围内更新策略,能保证策略既有所改进,又不会变得太糟糕。

三、PPO 算法公式推导

3.1 重要性采样

重要性采样是 PPO 算法中的关键技术之一。由于直接从当前策略采样数据效率较低,我们可以从旧策略 π θ o l d \pi_{\theta_{old}} πθold 采样数据,然后通过重要性采样比 r t ( θ ) r_t(\theta) rt(θ) 来校正数据的分布。 E s ∼ π θ [ f ( s ) ] ≈ 1 N ∑ i = 1 N π θ ( s i ) π θ o l d ( s i ) f ( s i ) \mathbb{E}_{s\sim\pi_{\theta}}[f(s)]\approx\frac{1}{N}\sum_{i = 1}^{N}\frac{\pi_{\theta}(s_i)}{\pi_{\theta_{old}}(s_i)}f(s_i) Esπθ[f(s)]N1i=1Nπθold(si)πθ(si)f(si)

比如我们要了解一群鸟的飞行习惯,直接去观察所有鸟的飞行轨迹很困难,那我们可以先观察一部分容易观察到的鸟(旧策略采样),然后根据这些鸟和所有鸟的一些特征差异(重要性采样比),来推测整个鸟群的飞行习惯。

3.2 优势函数估计

优势函数 A ( s , a ) A(s,a) A(s,a) 可以通过多种方法估计,常用的是广义优势估计(Generalized Advantage Estimation,GAE): A ^ t = ∑ k = 0 ∞ ( γ λ ) k δ t + k \hat{A}_t=\sum_{k = 0}^{\infty}(\gamma\lambda)^k\delta_{t + k} A^t=k=0(γλ)kδt+k

其中 δ t = r t + γ V ( s t + 1 ) − V ( s t ) \delta_{t}=r_t+\gamma V(s_{t + 1})-V(s_t) δt=rt+γV(st+1)V(st) 是 TD 误差, λ \lambda λ 是 GAE 参数,通常在 0 到 1 之间。优势函数的估计就像是给智能体的行动打分,告诉它每个行动到底有多好,以便它做出更好的决策。

四、PPO 算法代码实现(以 Python 和 PyTorch 为例)

import torchimport torch.nn as nnimport torch.optim as optimimport gymclass Policy(nn.Module):def __init__(self, state_dim, action_dim):super(Policy, self).__init__()self.fc1 = nn.Linear(state_dim, 64)self.fc2 = nn.Linear(64, 64)self.mu_head = nn.Linear(64, action_dim)self.log_std_head = nn.Linear(64, action_dim)def forward(self, x):x = torch.relu(self.fc1(x))x = torch.relu(self.fc2(x))mu = torch.tanh(self.mu_head(x))log_std = self.log_std_head(x)std = torch.exp(log_std)dist = torch.distributions.Normal(mu, std)return distclass Value(nn.Module):def __init__(self, state_dim):super(Value, self).__init__()self.fc1 = nn.Linear(state_dim, 64)self.fc2 = nn.Linear(64, 64)self.v_head = nn.Linear(64, 1)def forward(self, x):x = torch.relu(self.fc1(x))x = torch.relu(self.fc2(x))v = self.v_head(x)return vdef ppo_update(policy, value, optimizer_policy, optimizer_value, states, actions, rewards, dones, gamma=0.99,clip_epsilon=0.2, lambda_gae=0.95):states = torch.FloatTensor(states)actions = torch.FloatTensor(actions)rewards = torch.FloatTensor(rewards)dones = torch.FloatTensor(dones)values = value(states).squeeze(1)returns = []gae = 0for i in reversed(range(len(rewards))):if i == len(rewards) - 1:next_value = 0else:next_value = values[i + 1]delta = rewards[i] + gamma * next_value * (1 - dones[i]) - values[i]gae = delta + gamma * lambda_gae * (1 - dones[i]) * gaereturns.insert(0, gae + values[i])returns = torch.FloatTensor(returns)old_dist = policy(states)old_log_probs = old_dist.log_prob(actions).sum(-1)for _ in range(3):dist = policy(states)log_probs = dist.log_prob(actions).sum(-1)ratios = torch.exp(log_probs - old_log_probs)advantages = returns - values.detach()surr1 = ratios * advantagessurr2 = torch.clamp(ratios, 1 - clip_epsilon, 1 + clip_epsilon) * advantagespolicy_loss = -torch.min(surr1, surr2).mean()optimizer_policy.zero_grad()policy_loss.backward()optimizer_policy.step()value_loss = nn.MSELoss()(value(states).squeeze(1), returns)optimizer_value.zero_grad()value_loss.backward()optimizer_value.step()def train_ppo(env_name, num_episodes=1000):env = gym.make(env_name)state_dim = env.observation_space.shape[0]action_dim = env.action_space.shape[0]policy = Policy(state_dim, action_dim)value = Value(state_dim)optimizer_policy = optim.Adam(policy.parameters(), lr=3e-4)optimizer_value = optim.Adam(value.parameters(), lr=3e-4)for episode in range(num_episodes):states, actions, rewards, dones = [], [], [], []state = env.reset()done = Falsewhile not done:state = torch.FloatTensor(state)dist = policy(state)action = dist.sample()next_state, reward, done, _ = env.step(action.detach().numpy())states.append(state)actions.append(action)rewards.append(reward)dones.append(done)state = next_stateppo_update(policy, value, optimizer_policy, optimizer_value, states, actions, rewards, dones)if episode % 100 == 0:total_reward = 0state = env.reset()done = Falsewhile not done:state = torch.FloatTensor(state)dist = policy(state)action = dist.meannext_state, reward, done, _ = env.step(action.detach().numpy())total_reward += rewardstate = next_stateprint(f"Episode {episode}, Average Reward: {total_reward}")if __name__ == "__main__":train_ppo('Pendulum-v1')

五、PPO 算法案例应用

5.1 机器人控制

在机器人控制领域,PPO 算法可以用于训练机器人的运动策略。例如,训练一个双足机器人行走,机器人的状态可以包括关节角度、速度等信息,行动则是关节的控制指令。通过 PPO 算法,机器人可以学习到如何根据当前状态调整关节控制,以实现稳定高效的行走。

5.2 自动驾驶

在自动驾驶场景中,车辆的状态包括位置、速度、周围环境感知信息等,行动可以是加速、减速、转向等操作。PPO 算法可以让自动驾驶系统学习到在不同路况和环境下的最优驾驶策略,提高行驶的安全性和效率。

六、总结

PPO 算法作为强化学习中的优秀算法,以其高效的学习能力和良好的稳定性在多个领域得到了广泛应用。通过深入理解其原理、公式推导,结合代码实现和实际案例分析,我们能够更好地掌握和运用这一算法,为解决各种复杂的实际问题提供有力的工具。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.rhkb.cn/news/15791.html

如若内容造成侵权/违法违规/事实不符,请联系长河编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

拾取丢弃物品(结构体/数组/子UI/事件分发器)

实现:场景中随机生成几种物品,玩家可以拾取这些物品,也可丢弃已经拾取到的物品。 拾取丢弃物品时UI能实时更新玩家身上的物品量。 一.物品信息的创建 1.枚举 物品名 2.结构体表示物体属性 3.物品缩略图(缩略图大小要为2的n次方…

KITE提示词框架:引导大语言模型的高效新工具

大语言模型的应用日益广泛。然而,如何确保这些模型生成的内容在AI原生应用中符合预期,仍是一个需要不断探索的问题。以下内容来自于《AI 原生应用开发:提示工程原理与实战》一书(京东图书:https://item.jd.com/1013604…

性能优化中的系统架构优化

系统架构优化是性能优化的一个重要方面,它涉及到对整个IT系统或交易链上各个环节的分析与改进。通过系统架构优化,可以提高系统的响应速度、吞吐量,并降低各层之间的耦合度,从而更好地应对市场的变化和需求。业务增长导致的性能问…

【学习笔记】计算机网络(三)

第3章 数据链路层 文章目录 第3章 数据链路层3.1数据链路层的几个共同问题3.1.1 数据链路和帧3.1.2 三个基本功能3.1.3 其他功能 - 滑动窗口机制 3.2 点对点协议PPP(Point-to-Point Protocol)3.2.1 PPP 协议的特点3.2.2 PPP协议的帧格式3.2.3 PPP 协议的工作状态 3.3 使用广播信…

机器学习 - 理解偏差-方差分解

为了避免过拟合,我们经常会在模型的拟合能力和复杂度之间进行权衡。拟合能力强的模型一般复杂度会比较高,容易导致过拟合。相反,如果限制模型的复杂度,降低其拟合能力,又可能会导致欠拟合。因此,如何在模型…

【STM32】ADC

本次实现的是ADC实现数字信号与模拟信号的转化,数字信号时不连续的,模拟信号是连续的。 1.ADC转化的原理 模拟-数字转换技术使用的是逐次逼近法,使用二分比较的方法来确定电压值 当单片机对应的参考电压为3.3v时,0~ 3.3v(模拟信号…

DeepSeek 助力 Vue 开发:打造丝滑的步骤条

前言:哈喽,大家好,今天给大家分享一篇文章!并提供具体代码帮助大家深入理解,彻底掌握!创作不易,如果能帮助到大家或者给大家一些灵感和启发,欢迎收藏关注哦 💕 目录 Deep…

基于Python的人工智能驱动基因组变异算法:设计与应用(下)

3.3.2 数据清洗与预处理 在基因组变异分析中,原始数据往往包含各种噪声和不完整信息,数据清洗与预处理是确保分析结果准确性和可靠性的关键步骤。通过 Python 的相关库和工具,可以有效地去除噪声、填补缺失值、标准化数据等,为后续的分析提供高质量的数据基础。 在基因组…

elasticsearch安装插件analysis-ik分词器(深度研究docker内elasticsearch安装插件的位置)

最近在学习使用elasticsearch,但是在安装插件ik的时候遇到许多问题。 所以在这里开始对elasticsearch做一个深度的研究。 首先提供如下链接: https://github.com/infinilabs/analysis-ik/releases 我们下载elasticsearch-7-17-2的Linux x86_64版本 …

linux部署ollama+deepseek+dify

Ollama 下载源码 curl -L https://ollama.com/download/ollama-linux-amd64.tgz -o ollama-linux-amd64.tgz sudo tar -C /usr -xzf ollama-linux-amd64.tgz启动 export OLLAMA_HOST0.0.0.0:11434 ollama serve访问ip:11434看到即成功 Ollama is running 手动安装deepseek…

力扣 单词拆分

动态规划,字符串截取,可重复用,集合类。 题目 单词可以重复使用,一个单词可用多次,应该是比较灵活的组合形式了,可以想到用dp,遍历完单词后的状态的返回值。而这里的wordDict给出的是list&…

【JVM详解二】常量池

一、常量池概述 JVM的常量池主要有以下几种: class文件常量池运行时常量池字符串常量池基本类型包装类常量池 它们相互之间关系大致如下图所示: 每个 class 的字节码文件中都有一个常量池,里面是编译后即知的该 class 会用到的字面量与符号引…

企业数据集成案例:吉客云销售渠道到MySQL

测试-查询销售渠道信息-dange:吉客云数据集成到MySQL的技术案例分享 在企业的数据管理过程中,如何高效、可靠地实现不同系统之间的数据对接是一个关键问题。本次我们将分享一个具体的技术案例——通过轻易云数据集成平台,将吉客云中的销售渠…

CTFHub-RCE系列wp

目录标题 引言什么是RCE漏洞 eval执行文件包含文件包含php://input读取源代码远程包含 命令注入无过滤过滤cat过滤空格过滤目录分隔符过滤运算符综合过滤练习 引言 题目共有如下类型 什么是RCE漏洞 RCE漏洞,全称是Remote Code Execution漏洞,翻译成中文…

深度学习之神经网络框架搭建及模型优化

神经网络框架搭建及模型优化 目录 神经网络框架搭建及模型优化1 数据及配置1.1 配置1.2 数据1.3 函数导入1.4 数据函数1.5 数据打包 2 神经网络框架搭建2.1 框架确认2.2 函数搭建2.3 框架上传 3 模型优化3.1 函数理解3.2 训练模型和测试模型代码 4 最终代码测试4.1 SGD优化算法…

STM32自学记录(十)

STM32自学记录 文章目录 STM32自学记录前言一、USART杂记二、实验1.学习视频2.复现代码 总结 前言 USART 一、USART杂记 通信接口:通信的目的:将一个设备的数据传送到另一个设备,扩展硬件系统。 通信协议:制定通信的规则&#x…

Linux --- 如何安装Docker命令并且使用docker安装Mysql【一篇内容直接解决】

目录 安装Docker命令 1.卸载原有的Docker: 2.安装docker: 3.启动docker: 4.配置镜像加速: 使用Docker安装Mysql 1.上传文件: 2.创建目录: 3.运行docker命令: 4.测试: 安装…

Linux磁盘空间使用率100%(解决删除文件后还是显示100%)

本文适用于,删除过了对应的数据文件,查看还是显示使用率100%的情况 首先使用df -h命令查看各个扇区所占用的情况 一、先对系统盘下所有文件大小进行统计,是否真的是数据存储以达到了磁盘空间 在对应的扇区路径下使用du -sh * | sort -hr 命…

Python——批量图片转PDF(GUI版本)

目录 专栏导读1、背景介绍2、库的安装3、核心代码4、完整代码总结专栏导读 🌸 欢迎来到Python办公自动化专栏—Python处理办公问题,解放您的双手 🏳️‍🌈 博客主页:请点击——> 一晌小贪欢的博客主页求关注 👍 该系列文章专栏:请点击——>Python办公自动化专…

IDEA查看项目依赖包及其版本

一.IDEA将现有项目转换为Maven项目 在IntelliJ IDEA中,将现有项目转换为Maven项目是一个常见的需求,可以通过几种不同的方法来实现。Maven是一个强大的构建工具,它可以帮助自动化项目的构建过程,管理依赖关系,以及其他许多方面。 添加Maven支持 如果你的项目还没有pom.xm…