Reinforcement Learning with Code【Code 5. Policy Gradient Methods】

Reinforcement Learning with Code【Code 5. Policy Gradient Methods】

This note records how the author begin to learn RL. Both theoretical understanding and code practice are presented. Many material are referenced such as ZhaoShiyu’s Mathematical Foundation of Reinforcement Learning, .

文章目录

  • Reinforcement Learning with Code【Code 5. Policy Gradient Methods】
    • 1. Policy Gradient 回顾
    • 2. Policy Gradient Code
    • Reference

1. Policy Gradient 回顾

之前介绍的 Q-learning、DQN 及 DQN 改进算法都是基于价值(value-based)的方法,其中 Q-learning 是处理有限状态的算法,而 DQN 可以用来解决连续状态的问题。在强化学习中,除了基于值函数的方法,还有一支非常经典的方法,那就是基于策略(policy-based)的方法。对比两者,基于值函数的方法主要是学习值函数,然后根据值函数导出一个策略,学习过程中并不存在一个显式的策略;而基于策略的方法则是直接显式地学习一个目标策略。策略梯度是基于策略的方法的基础,本章从策略梯度算法说起。

由之前的学习参考Reinforcement Learning with Code 【Chapter 9. Policy Gradient Methods】,可以策略梯度有三个metric可以使用,分别是平均状态值(average state value),平均奖励(average reward)和从特定状态出发的平均状态值(state value of a specific starting state)。其中使用最多就是从特定状态出发的平均状态值,记为 v π ( s 0 ) v_\pi(s_0) vπ(s0),其中 s 0 s_0 s0表示初始状态。所有当我们使用从特点状态出发的平均状态值(state value of a specific starting state)作为优化目标函数的时候,我们的待优化函数可以写作
max ⁡ θ J ( θ ) = E [ v π θ ( s 0 ) ] \max_\theta J(\theta) = \mathbb{E}[v_{\pi_\theta}(s_0)] θmaxJ(θ)=E[vπθ(s0)]
再根据策略梯度定理,则有证明略(可以参考Hands on RL)
∇ θ J ( θ ) = E [ ∇ θ ln ⁡ π ( A ∣ S ; θ ) q π ( S , A ) ] \nabla_\theta J(\theta) = \mathbb{E}[\nabla_\theta \ln \pi(A|S;\theta)q_\pi(S,A)] θJ(θ)=E[θlnπ(AS;θ)qπ(S,A)]
这一梯度更新法则是不能使用的,这是因为 q π ( S , A ) q_\pi(S,A) qπ(S,A)是真值,我们不能获得,我们可以借助Monte-Carlo的思想来对此进行更新,用一个episode的return来代替这个Q-value,即
q π ( s t , a t ) = ∑ k = t + 1 T γ k − t − 1 r k q_\pi(s_t,a_t) = \sum^T_{k=t+1}\gamma^{k-t-1} r_k qπ(st,at)=k=t+1Tγkt1rk
那我们获得的梯度更新法则为
∇ θ J ( θ ) = E [ ∇ θ ln ⁡ π ( A ∣ S ; θ ) × ∑ k = t + 1 T γ k − t − 1 r k ) ] \nabla_\theta J(\theta) = \mathbb{E}[\nabla_\theta \ln \pi(A|S;\theta) \times \sum^T_{k=t+1}\gamma^{k-t-1} r_k)] θJ(θ)=E[θlnπ(AS;θ)×k=t+1Tγkt1rk)]
还原出待优化的目标函数即为
max ⁡ θ J ( θ ) = E [ ln ⁡ π ( A ∣ S ; θ ) × ∑ k = t + 1 T γ k − t − 1 r k ) ] \max_\theta J(\theta) = \mathbb{E}[\ln \pi(A|S;\theta) \times \sum^T_{k=t+1}\gamma^{k-t-1} r_k)] θmaxJ(θ)=E[lnπ(AS;θ)×k=t+1Tγkt1rk)]
这个应用了Monte-Carlo思想的算法又被称为REINFORCE。

2. Policy Gradient Code

智能体的交互环境采用的是gymCartPole-v1环境,已经在中Reinforcement Learning with Code 【Code 4. Vanilla DQN】进行过介绍,此处不再赘述。

import gym
import torch
import torch.nn as nn
import torch.nn.functional as F
from tqdm import tqdm
import matplotlib.pyplot as plt
import numpy as np# Define the policy network
class PolicyNet(nn.Module):def __init__(self, state_dim, hidden_dim, action_dim):super(PolicyNet, self).__init__()self.fc1 = nn.Linear(state_dim, hidden_dim)self.fc2 = nn.Linear(hidden_dim, action_dim)def forward(self, observation):x = F.relu(self.fc1(observation))x = F.softmax(self.fc2(x), dim=1)return x# Implement REINFORCE algorithm
class REINFORCE():def __init__(self, state_dim, hidden_dim, action_dim, learning_rate, gamma, device):self.policy_net = PolicyNet(state_dim, hidden_dim, action_dim).to(device)self.optimizer = torch.optim.Adam(self.policy_net.parameters(), lr=learning_rate)self.gamma = gammaself.device = devicedef choose_action(self, state):state = torch.tensor([state], dtype=torch.float).to(self.device)probs = self.policy_net(state)action_probs_dist = torch.distributions.Categorical(probs)    # generate prob distribution according to probsaction = action_probs_dist.sample().item()return actiondef learn(self, transition_dict):reward_list = transition_dict['rewards']state_list = transition_dict['states']action_list = transition_dict['actions']G = 0self.optimizer.zero_grad()for i in reversed(range(len(reward_list))):reward = reward_list[i]state = torch.tensor([state_list[i]], dtype=torch.float).to(self.device)action = torch.tensor([action_list[i]]).view(-1,1).to(self.device)log_prob = torch.log(self.policy_net(state).gather(dim=1,index=action))G = self.gamma * G + rewardloss = -log_prob * G    # 计算每一步的损失函数,有负号是因为我们需要max这个lossloss.backward()     # 反向传播累计梯度self.optimizer.step()   # after one episode 梯度更新def train_policy_net_agent(env, agent, num_episodes, seed):return_list = []for i in range(10):with tqdm(total = int(num_episodes/10), desc="Iteration %d"%(i+1)) as pbar:for i_episode in range(int(num_episodes/10)):episode_return = 0transition_dict = {'states': [],'actions': [],'next_states': [],'rewards': [],'dones': []}observation, _ = env.reset(seed=seed)done = Falsewhile not done:if render:env.render()action = agent.choose_action(observation)observation_, reward, terminated, truncated, _ = env.step(action)done = terminated or truncated# save one episode experience into a dicttransition_dict['states'].append(observation)transition_dict['actions'].append(action)transition_dict['next_states'].append(observation_)transition_dict['rewards'].append(reward)transition_dict['dones'].append(done)# swap stateobservation = observation_# compute one episode returnepisode_return += rewardreturn_list.append(episode_return)agent.learn(transition_dict)if((i_episode + 1) % 10 == 0):pbar.set_postfix({'episode': '%d'%(num_episodes / 10 * i + i_episode + 1),'return': '%.3f'%(np.mean(return_list[-10:]))})pbar.update(1)env.close()return return_listdef moving_average(a, window_size):cumulative_sum = np.cumsum(np.insert(a, 0, 0)) middle = (cumulative_sum[window_size:] - cumulative_sum[:-window_size]) / window_sizer = np.arange(1, window_size-1, 2)begin = np.cumsum(a[:window_size-1])[::2] / rend = (np.cumsum(a[:-window_size:-1])[::2] / r)[::-1]return np.concatenate((begin, middle, end))def plot_curve(return_list, mv_return, algorithm_name, env_name):episodes_list = list(range(len(return_list)))plt.plot(episodes_list, return_list, c='gray', alpha=0.6)plt.plot(episodes_list, mv_return)plt.xlabel('Episodes')plt.ylabel('Returns')plt.title('{} on {}'.format(algorithm_name, env_name))plt.show()if __name__=="__main__":learning_rate = 1e-3    # learning ratenum_episodes = 1000     # episodes lengthhidden_dim = 128        # hidden layers dimensiongamma = 0.98            # discounted ratedevice = torch.device('cuda' if torch.cuda.is_available() else 'gpu')env_name = 'CartPole-v1'    # gym env name  render = False              # render or not# reproducibleseed_number = 0torch.manual_seed(seed_number)np.random.seed(seed_number)if render:env = gym.make(id=env_name, render_mode='human')else:env = gym.make(id=env_name, render_mode=None)state_dim = env.observation_space.shape[0]action_dim = env.action_space.nagent = REINFORCE(state_dim, hidden_dim, action_dim, learning_rate, gamma, device)return_list = train_policy_net_agent(env, agent, num_episodes, seed_number)mv_return = moving_average(return_list, 9)plot_curve(return_list, mv_return, 'REINFORCE', env_name)

REINFORCE的效果如下图所示

Image

Reference

赵世钰老师的课程
Reinforcement Learning with Code 【Chapter 9. Policy Gradient Methods】
Hands on RL
Reinforcement Learning with Code 【Code 4. Vanilla DQN】

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.rhkb.cn/news/86247.html

如若内容造成侵权/违法违规/事实不符,请联系长河编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

Kotlin~Mediator中介者模式

概念 创建一个中介来降低对象之间的耦合度,关系”多对多“变为“一对多”。 角色介绍 Mediator:抽象中介者,接口或者抽象类。ConcreteMediator:中介者具体实现,实现中介者接口,定义一个List管理Colleagu…

制造业企业数字化转型之设备数据采集

导 读 ( 文/ 1894 ) 随着信息技术的快速发展和制造业的转型升级,企业数字化转型已成为保持竞争力和实现可持续发展的关键。在数字化转型过程中,设备数据采集作为重要的一环,发挥着关键的作用。设备数据采集通过收集、分析和利用设备所产生的数…

激活函数总结(六):ReLU系列激活函数补充(RReLU、CELU、ReLU6)

激活函数总结(六):ReLU系列激活函数补充 1 引言2 激活函数2.1 RReLU激活函数2.2 CELU激活函数2.3 ReLU6 激活函数 3. 总结 1 引言 在前面的文章中已经介绍了介绍了一系列激活函数 (Sigmoid、Tanh、ReLU、Leaky ReLU、PReLU、Swish、ELU、SEL…

PHP 求解两字符串所有公共子序列及最长公共子序列 支持多字节字符串

/*** 获取两字符串所有公共子序列【不连续的】 例:abc ac > ac** param string $str1 字符串1* param string $str2 字符串2** return array*/ function public_sequence(string $str1, string $str2): array {$data [[-1, -1, , 0, ]]; // 子序列容器【横坐标 …

喜报!诚恒科技与赛时达科技达成BI金蝶云星空项目合作

随着全球数字化浪潮轰轰烈烈袭来,仅仅凭借手工处理的方式难以在庞大的数据海洋中精准获取信息、把握市场需求、了解目标用户,为企业创新提供强有力的支持。深圳赛时达科技有限公司(简称赛时达科技)希望通过数字化转型实现从手工处…

StarRocks企业级数据库

第1章 StarRocks简介 1.1 StarRocks介绍 StarRocks是新一代极速全场景MPP数据库 StraRocks充分吸收关系型OLAP数据库和分布式存储系统在大数据时代的优秀研究成果,在业界实践的基础上,进一步改进优化、升级架构,并增添了众多全新功能&…

06-2_Qt 5.9 C++开发指南_自定义对话框及其调用

本篇介绍到的对话框及其调用实例较为复杂但十分详细,如果做了解可以先参考:QT从入门到实战x篇_13_模态和非模态对话框创建。 文章目录 1. 对话框的不同调用方式2. 对话框QWDialogSize 的创建和使用2.1 创建对话框QWDialogSize2.2 对话框的调用和返回值 …

PE启动盘和U启动盘(第三十六课)

PE启动盘和U启动盘(第三十六课) 一 WindowsPE工具盘 1. 制作WinPE镜像光盘 双击WePE64_V2.2-是-点击右下角光盘图标-选择ISO的输出位置-立即生成ISO 2. 通过光盘启动WinPE

Android平台GB28181设备接入端如何实现多视频通道接入?

技术背景 我们在设计Android平台GB28181设备接入模块的时候,有这样的场景诉求,一个设备可能需要多个通道,常见的场景,比如车载终端,一台设备,可能需要接入多个摄像头,那么这台车载终端设备可以…

QT创建项目

可选择CMake或qmake

实例036 使窗体标题栏文字右对齐

实例说明 窗口标题栏中的文字是窗口的重要说明,该文字可以标示窗口的功能、状态或名称等信息,一般该文字是居左显示的,在本例中设计一个标题栏文字右对齐的窗口。本实例运行结果如图1.36所示。 技术要点 在C# 2.0中实现这一功能非常容易&am…

【Spring】-Spring的IoC和DI

作者:学Java的冬瓜 博客主页:☀冬瓜的主页🌙 专栏:【Framework】 主要内容:什么是spring?IoC容器是什么?如何使代码解耦合?IoC的核心原理,IoC的优点。依赖注入/对象装配/…

【C++基础(九)】C++内存管理--new一个对象出来

💓博主CSDN主页:杭电码农-NEO💓   ⏩专栏分类:C从入门到精通⏪   🚚代码仓库:NEO的学习日记🚚   🌹关注我🫵带你学习C   🔝🔝 C内存管理 1. 前言2. new2.1 new的使用方法2.2 …

MIT 6.830数据库系统 -- lab six

MIT 6.830数据库系统 -- lab six 项目拉取引言steal/no-force策略redo log与undo log日志格式和检查点 开始回滚练习1:LogFile.rollback() 恢复练习2:LogFile.recover() 测试结果疑问点分析 项目拉取 原项目使用ant进行项目构建,我已经更改为…

Nodejs安装及环境变量配置(修改全局安装依赖工具包和缓存文件夹及npm镜像源)

本机环境:win11家庭中文版 一、官网下载 二、安装 三、查看nodejs及npm版本号 1、查看node版本号 node -v 2、查看NPM版本号(安装nodejs时已自动安装npm) npm -v 四、配置npm全局下载工具包和缓存目录 1、查看安装目录 在本目录下创建no…

【学习日记】【FreeRTOS】手动任务切换详解

前言 本文是关于 FreeRTOS 中实现两个任务轮流切换并执行的代码详解。目前不支持优先级,仅实现两个任务轮流切换。 一、任务的自传 任务从生到死的过程究竟是怎么样的呢?(其实也没死),这个问题一直困扰着我&#xf…

【前端 | CSS】滚动到底部加载,滚动监听、懒加载

背景 在日常开发过程中,我们会遇到图片懒加载的功能,基本原理是,滚动条滚动到底部后再次获取数据进行渲染。 那怎么判断滚动条是否滚动到底部呢?滚动条滚动到底部触发时间的时机和方法又该怎样定义? 针对以上问题我…

关于技术转管理角色的认知

软件质量保障:所寫即所思|一个阿里质量人对测试的所感所悟。 程序员发展的岔路口 技术人做了几年专业工作之后,会来到一个重要的“分岔路口”,一边是专业的技术路线,一边是技术团队的管理路线。不少人就开始犯难&…

redis 数据结构(一)

Redis 为什么那么快 redis是一种内存数据库,所有的操作都是在内存中进行的,还有一种重要原因是:它的数据结构的设计对数据进行增删查改操作很高效。 redis的数据结构是什么 redis数据结构是对redis键值对值的数据类型的底层的实现&#xff0c…

网站SSL安全证书是什么及其重要性

网站SSL安全证书具体来说是一个数字文件,是由受信任的数字证书颁发机构(CA机构)进行审核颁发的,其中包含CA发布的信息,该信息表明该网站已使用加密连接进行了安全保护。 网站SSL安全证书也被称为SSL证书、https证书和…