动手学强化学习 第 12 章 PPO 算法 训练代码

基于 Hands-on-RL/第12章-PPO算法.ipynb at main · boyu-ai/Hands-on-RL · GitHub

理论 PPO 算法

修改了警告和报错

运行环境

Debian GNU/Linux 12
Python 3.9.19
torch 2.0.1
gym 0.26.2

运行代码

PPO.py

#!/usr/bin/env pythonimport gym
import torch
import torch.nn.functional as F
import numpy as np
import matplotlib.pyplot as plt
import rl_utilsclass PolicyNet(torch.nn.Module):def __init__(self, state_dim, hidden_dim, action_dim):super(PolicyNet, self).__init__()self.fc1 = torch.nn.Linear(state_dim, hidden_dim)self.fc2 = torch.nn.Linear(hidden_dim, action_dim)def forward(self, x):x = F.relu(self.fc1(x))return F.softmax(self.fc2(x), dim=1)class ValueNet(torch.nn.Module):def __init__(self, state_dim, hidden_dim):super(ValueNet, self).__init__()self.fc1 = torch.nn.Linear(state_dim, hidden_dim)self.fc2 = torch.nn.Linear(hidden_dim, 1)def forward(self, x):x = F.relu(self.fc1(x))return self.fc2(x)class PPO:''' PPO算法,采用截断方式 '''def __init__(self, state_dim, hidden_dim, action_dim, actor_lr, critic_lr,lmbda, epochs, eps, gamma, device):self.actor = PolicyNet(state_dim, hidden_dim, action_dim).to(device)self.critic = ValueNet(state_dim, hidden_dim).to(device)self.actor_optimizer = torch.optim.Adam(self.actor.parameters(),lr=actor_lr)self.critic_optimizer = torch.optim.Adam(self.critic.parameters(),lr=critic_lr)self.gamma = gammaself.lmbda = lmbdaself.epochs = epochs  # 一条序列的数据用来训练轮数self.eps = eps  # PPO中截断范围的参数self.device = devicedef take_action(self, state):state = torch.tensor(np.array([state]), dtype=torch.float).to(self.device)probs = self.actor(state)action_dist = torch.distributions.Categorical(probs)action = action_dist.sample()return action.item()def update(self, transition_dict):states = torch.tensor(np.array(transition_dict['states']),dtype=torch.float).to(self.device)actions = torch.tensor(transition_dict['actions']).view(-1, 1).to(self.device)rewards = torch.tensor(transition_dict['rewards'],dtype=torch.float).view(-1, 1).to(self.device)next_states = torch.tensor(np.array(transition_dict['next_states']),dtype=torch.float).to(self.device)dones = torch.tensor(transition_dict['dones'],dtype=torch.float).view(-1, 1).to(self.device)td_target = rewards + self.gamma * self.critic(next_states) * (1 -dones)td_delta = td_target - self.critic(states)advantage = rl_utils.compute_advantage(self.gamma, self.lmbda,td_delta.cpu()).to(self.device)old_log_probs = torch.log(self.actor(states).gather(1,actions)).detach()for _ in range(self.epochs):log_probs = torch.log(self.actor(states).gather(1, actions))ratio = torch.exp(log_probs - old_log_probs)surr1 = ratio * advantagesurr2 = torch.clamp(ratio, 1 - self.eps,1 + self.eps) * advantage  # 截断actor_loss = torch.mean(-torch.min(surr1, surr2))  # PPO损失函数critic_loss = torch.mean(F.mse_loss(self.critic(states), td_target.detach()))self.actor_optimizer.zero_grad()self.critic_optimizer.zero_grad()actor_loss.backward()critic_loss.backward()self.actor_optimizer.step()self.critic_optimizer.step()actor_lr = 1e-3
critic_lr = 1e-2
num_episodes = 500
hidden_dim = 128
gamma = 0.98
lmbda = 0.95
epochs = 10
eps = 0.2
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")env_name = 'CartPole-v1'
env = gym.make(env_name)
env.reset(seed=0)
torch.manual_seed(0)
state_dim = env.observation_space.shape[0]
action_dim = env.action_space.n
agent = PPO(state_dim, hidden_dim, action_dim, actor_lr, critic_lr, lmbda,epochs, eps, gamma, device)return_list = rl_utils.train_on_policy_agent(env, agent, num_episodes)episodes_list = list(range(len(return_list)))
plt.plot(episodes_list, return_list)
plt.xlabel('Episodes')
plt.ylabel('Returns')
plt.title('PPO on {}'.format(env_name))
plt.show()mv_return = rl_utils.moving_average(return_list, 9)
plt.plot(episodes_list, mv_return)
plt.xlabel('Episodes')
plt.ylabel('Returns')
plt.title('PPO on {}'.format(env_name))
plt.show()

rl_utils.py 参考

动手学强化学习 第 11 章 TRPO 算法 训练代码-CSDN博客

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.rhkb.cn/news/388313.html

如若内容造成侵权/违法违规/事实不符,请联系长河编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

【python】PyQt5中QRadioButton的详细用法教程与应用实战

✨✨ 欢迎大家来到景天科技苑✨✨ 🎈🎈 养成好习惯,先赞后看哦~🎈🎈 🏆 作者简介:景天科技苑 🏆《头衔》:大厂架构师,华为云开发者社区专家博主,…

[windows] 关于多线程中使用SendMessage

https://developer.aliyun.com/article/228325

@antv/x6 利用工具,在节点的左上角,或者节点的右上角,增加一个X的红色删除小按钮。

1、上个图: 官方地址:https://x6.antv.antgroup.com/tutorial/intermediate/tools 2、鼠标移上去,左上角会有一个删除小按钮,这个是x6自动的功能,只要稍微写二行代码就实现了: graph.on(node:mouseenter,…

leetcode 矩阵专题——java实现

73. 矩阵置零 给定一个 m x n 的矩阵,如果一个元素为 0 ,则将其所在行和列的所有元素都设为 0 。请使用 原地 算法。 输入:matrix [[1,1,1],[1,0,1],[1,1,1]] 输出:[[1,0,1],[0,0,0],[1,0,1]] 关键在于:一次扫描全表…

【云服务器】vscode + onethingAi + SSH远程连接

通过VS code远程连接服务器,并进行上传和下载文件操作_vs code 上传制定文件-CSDN博客 vscode远程连接服务器(remote ssh)上传本地文件到服务器(sftp)_vscode上传文件到服务器-CSDN博客 vscode连接远程服务器(傻瓜式教学&#x…

苹果电脑怎么录制屏幕?3招教你轻松录制,高效实用

随着数字化时代的快速发展,屏幕录制已经成为我们日常工作和生活中不可或缺的一部分。它不仅是展示产品、教授知识、分享经验的重要工具,更是我们展现个性和创造力的新舞台。在苹果电脑上,屏幕录制功能的应用更是将这一体验推向了新的高度。 …

【屏驱MCU】RT-Thread 文件系统接口解析

本文主要介绍【屏驱MCU】基于RT-Thread 系统的文件系统原理介绍与代码接口梳理 目录 0. 个人简介 && 授权须知1. 文件系统架构1.1 虚拟文件系统目录架构 2. menuconfig 分析3. 代码接口分析3.1 DFS框架挂载目录3.2 【FAL抽象层】分区表和设备表3.3 如何将【文件路径】挂…

多任务协程处理的流程,看看是否和你想像的一样

import time import asyncioasync def func1():print("你好,我是第一个任务")await asyncio.sleep(3)print("你好,我是第二个任务")async def func2():print("你好,我是第3个任务")await asyncio.sleep(2)…

GNSS形变监测系统

TH-WY1 GNSS形变监测系统采用扼流圈设计有以下几个优势: 高精度测量:扼流圈是一种高精度的传感器,可以提供非常精确的测量结果。这使得GNSS形变监测系统能够准确地测量结构物的形变变化。 高稳定性:扼流圈设计使得传感器具有良好…

第33篇 计算数据中最长的连续1的个数<三>

Q:如何将计算出的结果(最长的连续1的个数)显示在DE2-115开发板的HEX上? A:基本原理:DE2-115_Computer_System中的HEX并行端口作为内存映射设备连接到DE2-115开发板的七段数码管,每个端口都对应…

大模型提示工程(Prompt),让LLM自己优化提示词

前言 随着大家对于prompt提问的研究以及对于高质量回答的追求,现在有一个比较热的词叫做prompt creator。 Prompt Creator 实际上是使得 ChatGPT 更好的引导你去完善自己的提问,同时也完善自己的回答,更好地指导自己回答出更加令使用者满意…

win10桌面任务栏美化(不用软件)(任务栏应用居中,透明任务栏)

透明任务栏 1、打开设置——个性化——颜色,打开透明效果; 2、在搜索框搜索注册表编辑器; 3、找如下路径:计算机\HKEY-CURRENT-USER\Software\Microsoft\Windows\CurrentVersion\Explorer\Advanced; 4、寻找文件&a…

【TS】TypeScript类型断言:掌握类型转换的艺术

🌈个人主页: 鑫宝Code 🔥热门专栏: 闲话杂谈| 炫酷HTML | JavaScript基础 ​💫个人格言: "如无必要,勿增实体" 文章目录 TypeScript类型断言:掌握类型转换的艺术1. 引言2. 什么是类型断言&a…

链表的实现(C++版)

对于链表的学习,之前在C语言部分的时候就已经有学习过,也学会了使用C语言来打造一个链表.如今学了C 则想通过C来打造一个链表,以达到锻炼自己的目的. 1.链表的初步实现 1.节点模板的设置 template <class T> struct ListNode{ListNode <T>* _next;ListNode <T…

k8s学习--使用kubepshere部署devops项目时遇到的报错(无法找到gitee仓库)

今天在kubesphere部署devops项目&#xff0c;编辑流水线的时候&#xff0c;发现怎么也访问不到gitee仓库 报错的流水线位置 报错日志 报错原因 变量问题 因为看见了csy/sangomall&#xff0c;所以理所当然的把路径变量GITEE_ACCOUNT写成了用户名 解决方法 结果发现仓库…

可靠的图纸加密软件,七款图纸加密软件推荐

大家好啊,我是小固,今天跟大家聊聊图纸加密软件。 作为一名设计师,我深知保护自己的知识产权有多重要。曾经就因为图纸泄露,差点血本无归,那个教训可真是惨痛啊!所以我今天就给大家推荐几款靠谱的图纸加密软件,希望能帮到你们。 固信软件https://www.gooxion.com/ 首先要隆重…

Java语言程序设计——篇十一(1)

&#x1f33f;&#x1f33f;&#x1f33f;跟随博主脚步&#xff0c;从这里开始→博主主页&#x1f33f;&#x1f33f;&#x1f33f; 欢迎大家&#xff1a;这里是CSDN&#xff0c;我的学习笔记、总结知识的地方&#xff0c;喜欢的话请三连&#xff0c;有问题可以私信&#x1f33…

Vue 的安装与配置

今天是八月一日&#xff0c;我也开启了Vue的学习&#xff0c;希望和大家一起学编程&#xff0c;相互督促&#xff0c;相互进步&#xff01; 安装vscode 安装Node.js 官网&#xff1a;https://nodejs.org/zh-cn 下载完正常安装就行 可以winr输入cmd&#xff0c;也可以vscod…

springboot智能健康管理平台-计算机毕业设计源码57256

摘要 在当今社会&#xff0c;人们越来越重视健康饮食和健康管理。借助SpringBoot框架和MySQL数据库的支持&#xff0c;开发智能健康管理平台成为可能。该平台结合了小程序技术的便利性和SpringBoot框架的快速开发能力&#xff0c;为用户提供了便捷的健康管理解决方案。 通过智能…