PyTorch 深度学习实战(14):Deep Deterministic Policy Gradient (DDPG) 算法

在上一篇文章中,我们介绍了 Proximal Policy Optimization (PPO) 算法,并使用它解决了 CartPole 问题。本文将深入探讨 Deep Deterministic Policy Gradient (DDPG) 算法,这是一种用于连续动作空间的强化学习算法。我们将使用 PyTorch 实现 DDPG 算法,并应用于经典的 Pendulum 问题。


一、DDPG 算法基础

DDPG 是一种基于 Actor-Critic 框架的算法,专门用于解决连续动作空间的强化学习问题。它结合了深度 Q 网络(DQN)和策略梯度方法的优点,能够高效地处理高维状态和动作空间。

1. DDPG 的核心思想

  • 确定性策略

    • DDPG 使用确定性策略(Deterministic Policy),即给定状态时,策略网络直接输出一个确定的动作,而不是动作的概率分布。

  • 目标网络

    • DDPG 使用目标网络(Target Network)来稳定训练过程,类似于 DQN 中的目标网络。

  • 经验回放

    • DDPG 使用经验回放缓冲区(Replay Buffer)来存储和重用过去的经验,从而提高数据利用率。

2. DDPG 的优势

  • 适用于连续动作空间

    • DDPG 能够直接输出连续动作,适用于机器人控制、自动驾驶等任务。

  • 训练稳定

    • 通过目标网络和经验回放,DDPG 能够稳定地训练策略网络和价值网络。

  • 高效采样

    • DDPG 可以重复使用旧策略的采样数据,从而提高数据利用率。

3. DDPG 的算法流程

  1. 使用当前策略采样一批数据。

  2. 使用目标网络计算目标 Q 值。

  3. 更新 Critic 网络以最小化 Q 值的误差。

  4. 更新 Actor 网络以最大化 Q 值。

  5. 更新目标网络。

  6. 重复上述过程,直到策略收敛。


二、Pendulum 问题实战

我们将使用 PyTorch 实现 DDPG 算法,并应用于 Pendulum 问题。目标是控制摆杆使其保持直立。

1. 问题描述

Pendulum 环境的状态空间包括摆杆的角度和角速度。动作空间是一个连续的扭矩值,范围在 −2,2 之间。智能体每保持摆杆直立一步,就会获得一个负的奖励,目标是最大化累积奖励。

2. 实现步骤

  1. 安装并导入必要的库。

  2. 定义 Actor 网络和 Critic 网络。

  3. 定义 DDPG 训练过程。

  4. 测试模型并评估性能。

3. 代码实现

以下是完整的代码实现:

import gym
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
import numpy as np
import random
from collections import deque
import matplotlib.pyplot as plt
​
# 设置 Matplotlib 支持中文显示
plt.rcParams['font.sans-serif'] = ['SimHei']
plt.rcParams['axes.unicode_minus'] = False
​
# 检查 GPU 是否可用
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"使用设备: {device}")
​
# 环境初始化
env = gym.make('Pendulum-v1')
state_dim = env.observation_space.shape[0]
action_dim = env.action_space.shape[0]
max_action = float(env.action_space.high[0])
​
# 随机种子设置
SEED = 42
torch.manual_seed(SEED)
np.random.seed(SEED)
random.seed(SEED)
​
​
# 定义 Actor 网络
class Actor(nn.Module):def __init__(self, state_dim, action_dim, max_action):super(Actor, self).__init__()self.fc1 = nn.Linear(state_dim, 512)self.ln1 = nn.LayerNorm(512)  # 层归一化self.fc2 = nn.Linear(512, 512)self.ln2 = nn.LayerNorm(512)self.fc3 = nn.Linear(512, action_dim)self.max_action = max_action
​def forward(self, x):x = F.relu(self.ln1(self.fc1(x)))x = F.relu(self.ln2(self.fc2(x)))return self.max_action * torch.tanh(self.fc3(x))
​
​
# 定义 Critic 网络
class Critic(nn.Module):def __init__(self, state_dim, action_dim):super(Critic, self).__init__()self.fc1 = nn.Linear(state_dim + action_dim, 256)self.fc2 = nn.Linear(256, 256)self.fc3 = nn.Linear(256, 1)
​def forward(self, x, u):x = F.relu(self.fc1(torch.cat([x, u], 1)))x = F.relu(self.fc2(x))x = self.fc3(x)return x
​
​
# 添加OU噪声类
class OUNoise:def __init__(self, action_dim, mu=0, theta=0.15, sigma=0.2):self.mu = mu * np.ones(action_dim)self.theta = thetaself.sigma = sigmaself.reset()
​def reset(self):self.state = np.copy(self.mu)
​def sample(self):dx = self.theta * (self.mu - self.state) + self.sigma * np.random.randn(len(self.state))self.state += dxreturn self.state
​
​
# 定义 DDPG 算法
class DDPG:def __init__(self, state_dim, action_dim, max_action):self.actor = Actor(state_dim, action_dim, max_action).to(device)self.actor_target = Actor(state_dim, action_dim, max_action).to(device)self.actor_target.load_state_dict(self.actor.state_dict())self.actor_optimizer = optim.Adam(self.actor.parameters(), lr=1e-4)
​self.critic = Critic(state_dim, action_dim).to(device)self.critic_target = Critic(state_dim, action_dim).to(device)self.critic_target.load_state_dict(self.critic.state_dict())self.critic_optimizer = optim.Adam(self.critic.parameters(), lr=1e-3)self.noise = OUNoise(action_dim, sigma=0.2)  # 示例:Ornstein-Uhlenbeck噪声
​self.max_action = max_actionself.replay_buffer = deque(maxlen=1000000)self.batch_size = 64self.gamma = 0.99self.tau = 0.005self.noise_sigma = 0.5  # 初始噪声强度self.noise_decay = 0.995
​self.actor_lr_scheduler = optim.lr_scheduler.StepLR(self.actor_optimizer, step_size=100, gamma=0.95)self.critic_lr_scheduler = optim.lr_scheduler.StepLR(self.critic_optimizer, step_size=100, gamma=0.95)
​def select_action(self, state):state = torch.FloatTensor(state).unsqueeze(0).to(device)self.actor.eval()with torch.no_grad():action = self.actor(state).cpu().data.numpy().flatten()self.actor.train()return action
​def train(self):if len(self.replay_buffer) < self.batch_size:return
​# 从经验回放缓冲区中采样batch = random.sample(self.replay_buffer, self.batch_size)state = torch.FloatTensor(np.array([transition[0] for transition in batch])).to(device)action = torch.FloatTensor(np.array([transition[1] for transition in batch])).to(device)reward = torch.FloatTensor(np.array([transition[2] for transition in batch])).reshape(-1, 1).to(device)next_state = torch.FloatTensor(np.array([transition[3] for transition in batch])).to(device)done = torch.FloatTensor(np.array([transition[4] for transition in batch])).reshape(-1, 1).to(device)
​# 计算目标 Q 值next_action = self.actor_target(next_state)target_Q = self.critic_target(next_state, next_action)target_Q = reward + (1 - done) * self.gamma * target_Q
​# 更新 Critic 网络current_Q = self.critic(state, action)critic_loss = F.mse_loss(current_Q, target_Q.detach())self.critic_optimizer.zero_grad()critic_loss.backward()self.critic_optimizer.step()
​# 更新 Actor 网络actor_loss = -self.critic(state, self.actor(state)).mean()self.actor_optimizer.zero_grad()actor_loss.backward()self.actor_optimizer.step()
​# 更新目标网络for param, target_param in zip(self.critic.parameters(), self.critic_target.parameters()):target_param.data.copy_(self.tau * param.data + (1 - self.tau) * target_param.data)for param, target_param in zip(self.actor.parameters(), self.actor_target.parameters()):target_param.data.copy_(self.tau * param.data + (1 - self.tau) * target_param.data)
​def save(self, filename):torch.save(self.actor.state_dict(), filename + "_actor.pth")torch.save(self.critic.state_dict(), filename + "_critic.pth")
​def load(self, filename):self.actor.load_state_dict(torch.load(filename + "_actor.pth"))self.critic.load_state_dict(torch.load(filename + "_critic.pth"))
​
​
# 训练流程
def train_ddpg(env, agent, episodes=500):rewards_history = []moving_avg = []
​for ep in range(episodes):state,_ = env.reset()episode_reward = 0done = False
​while not done:action = agent.select_action(state)next_state, reward, done, _, _ = env.step(action)agent.replay_buffer.append((state, action, reward, next_state, done))state = next_stateepisode_reward += rewardagent.train()
​rewards_history.append(episode_reward)moving_avg.append(np.mean(rewards_history[-50:]))
​if (ep + 1) % 50 == 0:print(f"Episode: {ep + 1}, Avg Reward: {moving_avg[-1]:.2f}")
​return moving_avg, rewards_history
​
​
# 训练启动
ddpg_agent = DDPG(state_dim, action_dim, max_action)
moving_avg, rewards_history = train_ddpg(env, ddpg_agent)
​
# 可视化结果
plt.figure(figsize=(12, 6))
plt.plot(rewards_history, alpha=0.6, label='single round reward')
plt.plot(moving_avg, 'r-', linewidth=2, label='moving average (50 rounds)')
plt.xlabel('episodes')
plt.ylabel('reward')
plt.title('DDPG training performance on Pendulum-v1')
plt.legend()
plt.grid(True)
plt.show()

三、代码解析

  1. Actor 和 Critic 网络

    • Actor 网络输出连续动作,通过 tanh 函数将动作限制在 −max_action,max_action 范围内。

    • Critic 网络输出状态-动作对的 Q 值。

  2. DDPG 训练过程

    • 使用当前策略采样一批数据。

    • 使用目标网络计算目标 Q 值。

    • 更新 Critic 网络以最小化 Q 值的误差。

    • 更新 Actor 网络以最大化 Q 值。

    • 更新目标网络。

  3. 训练过程

    • 在训练过程中,每 50 个 episode 打印一次平均奖励。

    • 训练结束后,绘制训练过程中的总奖励曲线。


四、运行结果

运行上述代码后,你将看到以下输出:

  • 训练过程中每 50 个 episode 打印一次平均奖励。

  • 训练结束后,绘制训练过程中的总奖励曲线。


五、总结

本文介绍了 DDPG 算法的基本原理,并使用 PyTorch 实现了一个简单的 DDPG 模型来解决 Pendulum 问题。通过这个例子,我们学习了如何使用 DDPG 算法进行连续动作空间的策略优化。

在下一篇文章中,我们将探讨更高级的强化学习算法,如 Twin Delayed DDPG (TD3)。敬请期待!

代码实例说明

  • 本文代码可以直接在 Jupyter Notebook 或 Python 脚本中运行。

  • 如果你有 GPU,可以将模型和数据移动到 GPU 上运行,例如:actor = actor.to('cuda')state = state.to('cuda')

希望这篇文章能帮助你更好地理解 DDPG 算法!如果有任何问题,欢迎在评论区留言讨论。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.rhkb.cn/news/34549.html

如若内容造成侵权/违法违规/事实不符,请联系长河编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

3.14-1列表

列表 一.列表的介绍和定义 1 .列表 类型: <class list> 2.符号:[] 3.定义列表: 方式1:[] 通过[] 来定义 list[1,2,3,4,6] print(type(list)) #<class list> 方式2: 通过list 转换 str2"12345" print(type(str2)) #<class str> list2lis…

Java集合 - HashMap

HashMap 是 Java 集合框架中的一个重要类&#xff0c;位于 java.util 包中。它实现了 Map 接口&#xff0c;基于哈希表的数据结构来存储键值对&#xff08;key-value pairs&#xff09;。HashMap 允许使用 null 作为键和值&#xff0c;并且是非同步的&#xff08;非线程安全的&…

有效的山脉数组 力扣941

一、题目 给定一个整数数组 arr&#xff0c;如果它是有效的山脉数组就返回 true&#xff0c;否则返回 false。 让我们回顾一下&#xff0c;如果 arr 满足下述条件&#xff0c;那么它是一个山脉数组&#xff1a; arr.length > 3在 0 < i < arr.length - 1 条件下&am…

本地部署Spark集群

部署Spark集群大体上分为两种模式&#xff1a;单机模式与集群模式 大多数分布式框架都支持单机模式&#xff0c;方便开发者调试框架的运行环境。但是在生产环境中&#xff0c;并不会使用单机模式。 下面详细列举了Spark目前支持的部署模式。 &#xff08;1&#xff09;Local…

前端---初识HTML(前端三剑客)

1.HTML 先为大家介绍几个学习前端的网站&#xff1a;菜鸟教程&#xff0c;w3school&#xff0c;CSS HTML&#xff1a;超文本标记语言 超⽂本: ⽐⽂本要强⼤. 通过链接和交互式⽅式来组织和呈现信息的⽂本形式. 不仅仅有⽂本, 还可能包含图⽚, ⾳频, 或者⾃已经审阅过它的学者…

AcWing 4905. 面包店 二分

类似还有一个题是二分&#xff0c;用区间来判断是否有解 这个爆long long 有点坑了 const int N 1e2 10;LL n,tc,Tm; LL a[N],b[N],c[N];bool check(LL mid) {LL minx max(0LL,mid 1 - Tm),maxx min(tc - 1LL,mid);//将y转为x的函数,此时判断x是否有解//枚举所有客户的需…

SpringBoot 第一课(Ⅲ) 配置类注解

目录 一、PropertySource 二、ImportResource ①SpringConfig &#xff08;Spring框架全注解&#xff09; ②ImportResource注解实现 三、Bean 四、多配置文件 多Profile文件的使用 文件命名约定&#xff1a; 激活Profile&#xff1a; YAML文件支持多文档块&#xff…

2025年西安交通大学少年班招生考试初试数学试题(初中组)

1、已知正整数 x 、 y 、 z x、y、z x、y、z 满足 x y z 2025 xyz2025 xyz2025 &#xff0c; x 2 y y 2 z z 2 x x y 2 y z 2 z x 2 x^2yy^2zz^2xxy^2yz^2zx^2 x2yy2zz2xxy2yz2zx2&#xff0c;则 x 、 y 、 z x、y、z x、y、z 共有 ___ 组解。 2、在数 1 、 2 、 …

开发、科研、日常办公工具汇总(自用,持续更新)

主要记录汇总一下自己平常会用到的网站工具&#xff0c;方便查阅。 update&#xff1a;2025/2/11&#xff08;开发网站补一下&#xff09; update&#xff1a;2025/2/21&#xff08;补充一些AI工具&#xff0c;刚好在做AI视频相关工作&#xff09; update&#xff1a;2025/3/7&…

软件架构设计习题及复习

软件系统需求分析 系统需求模型转换为架构模型 软件架构设计 架构风格领域 难点 单选 平衡点是敏感点的一种&#xff0c;如果达到了平衡点一定要选平衡点&#xff0c;不能选敏感点添加层次不能提高系统性能&#xff0c;任何时候直接沟通性能最高效

ccf3501密码

//密码 #include<iostream> #include<cstring> using namespace std; int panduan(char a[]){int lstrlen(a);int s0;int zm0,sz0,t0;int b[26]{0},c[26]{0},d[10]{0},e0,f0;while(s<l&&l>6){if(a[s]<Z&&a[s]>A){b[a[s]-A];zm;}if(a[s…

【JavaEE进阶】Spring事务

目录 &#x1f343;前言 &#x1f334;事务简介 &#x1f6a9; 什么是事务? &#x1f6a9;为什么需要事务? &#x1f6a9;事务的操作 &#x1f340;Spring 中事务的实现 &#x1f6a9;Spring 编程式事务 &#x1f6a9;Spring声明式事务Transactional &#x1f6a9;T…

MySQL索引特性——会涉及索引的底层B+树

1 没有索引&#xff0c;可能会有什么问题 索引&#xff1a;提高数据库的性能&#xff0c;索引是物美价廉的东西了。不用加内存&#xff0c;不用改程序&#xff0c;不用调sql&#xff0c;只要执行正确的 create index &#xff0c;查询速度就可能提高成百上千倍。但是天下没有免…

给单片机生成字库的方案

Python 这段代码用来将txt文件中储存的字符串转变成二进制的像素数据 from PIL import Image, ImageFont, ImageDraw import osdef find_microsoft_yahei():"""Windows系统定位微软雅黑字体"""font_paths ["C:/Windows/Fonts/msyh.ttc&q…

01Spring Security框架

Spring Security是什么&#xff1f; Spring Security是⼀个提供身份验证、授权和针对常见攻击的保护的框架。 Spring Security做什么&#xff1f; 作为开发⼈员&#xff0c;在⽇常开发过程中需要⽤到Spring Security的场景⾮常多。事实上&#xff0c;对Web应⽤程序⽽⾔&#xf…

「BWAPP靶场通关记录(1)」A1注入漏洞

BWAPP通关秘籍&#xff08;1&#xff09;&#xff1a;A1 injection 1.HTML Injection - Reflected (GET)1.1Low1.2Medium1.3High 2.HTML Injection - Reflected (POST)2.1Low2.2Medium2.3High 3.HTML Injection - Reflected (URL)3.1Low3.2/3.3Medium/HIgh 4.HTML Injection - …

机器学习算法实战——敏感词检测(主页有源码)

✨个人主页欢迎您的访问 ✨期待您的三连 ✨ ✨个人主页欢迎您的访问 ✨期待您的三连 ✨ ✨个人主页欢迎您的访问 ✨期待您的三连✨ ​ ​​​ 1. 引言 随着互联网的快速发展&#xff0c;信息传播的速度和范围达到了前所未有的高度。然而&#xff0c;网络空间中也充斥着大量的…

Ollama+DeepSeek+NatCross内网穿透本地部署外网访问教程

目录 一、Ollama 简介 二、下载Ollama 三、下载并运行 DeepSeek 模型 四、运行 DeepSeek 模型 五、NatCross免费内网穿透 六、配置 ChatBox 连接 Ollama 七、外网使用ChatBox体验 一、Ollama 简介 Ollama 是一个开源的本地大模型部署工具&#xff0c;旨在让用户能够在个…

联想台式电脑启动项没有U盘

开机按F12&#xff0c;进入启动设备菜单&#xff0c;发现这里没有识别到插在主机的U盘&#xff1f; 解决方法 1、选上图的Enter Setup或者开机按F2&#xff0c;进入BIOS设置 选择Startup -> Primary Boot Sequence 2、选中“Excludeed from boot order”中U盘所在的一行 …

开源链动 2+1 模式 AI 智能名片 S2B2C 商城小程序助力社群发展中榜样影响力的提升

摘要&#xff1a;本文深入剖析了社群发展进程中榜样所承载的关键影响力&#xff0c;并对开源链动 21 模式 AI 智能名片 S2B2C 商城小程序在增强这一影响力方面所具备的潜力进行了全面探讨。通过对不同类型社群&#xff0c;如罗辑思维社群和 007 不出局社群中灵魂人物或意见领袖…