A2C原理和代码实现

参考王树森《深度强化学习》课程和书籍


1、A2C原理:

在这里插入图片描述


Observe a transition: ( s t , a t , r t , s t + 1 ) (s_t,{a_t},r_t,s_{t+1}) (st,at,rt,st+1)

TD target:
y t = r t + γ ⋅ v ( s t + 1 ; w ) . y_{t} = r_{t}+\gamma\cdot v(s_{t+1};\mathbf{w}). yt=rt+γv(st+1;w).
TD error:
δ t = v ( s t ; w ) − y t . \quad\delta_t = v(s_t;\mathbf{w})-y_t. δt=v(st;w)yt.
Update the policy network (actor) by:
θ ← θ − β ⋅ δ t ⋅ ∂ ln ⁡ π ( a t ∣ s t ; θ ) ∂ θ . \mathbf{\theta}\leftarrow\mathbf{\theta}-\beta\cdot\delta_{t}\cdot\frac{\partial\ln\pi(a_{t}\mid s_{t};\mathbf{\theta})}{\partial \mathbf{\theta}}. θθβδtθlnπ(atst;θ).


def compute_value_loss(self, bs, blogp_a, br, bd, bns):# 目标价值。with torch.no_grad():target_value = br + self.args.discount * torch.logical_not(bd) * self.V_target(bns).squeeze()# torch.logical_not 对输入张量取逻辑非# 计算value loss。value_loss = F.mse_loss(self.V(bs).squeeze(), target_value)return value_loss

Update the value network (critic) by:
w ← w − α ⋅ δ t ⋅ ∂ v ( s t ; w ) ∂ w . \mathbf{w}\leftarrow\mathbf{w}-\alpha\cdot\delta_{t}\cdot{\frac{\partial{v(s_{t}};\mathbf{w})}{\partial\mathbf{w}}}. wwαδtwv(st;w).


def compute_policy_loss(self, bs, blogp_a, br, bd, bns):# 建议对比08_a2c.py,比较二者的差异。with torch.no_grad():value = self.V(bs).squeeze()policy_loss = 0for i, logp_a in enumerate(blogp_a):policy_loss += -logp_a * value[i]policy_loss = policy_loss.mean()return policy_loss

2、A2C完整代码实现:

参考后修改注释:最初的代码在https://github.com/wangshusen/DRL

"""8.3节A2C算法实现。"""
import argparse
import os
from collections import defaultdict
import gym
import matplotlib.pyplot as plt
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.distributions import Categoricalclass ValueNet(nn.Module):def __init__(self, dim_state):super().__init__()self.fc1 = nn.Linear(dim_state, 64)self.fc2 = nn.Linear(64, 32)self.fc3 = nn.Linear(32, 1)def forward(self, state):x = F.relu(self.fc1(state))x = F.relu(self.fc2(x))x = self.fc3(x)return xclass PolicyNet(nn.Module):def __init__(self, dim_state, num_action):super().__init__()self.fc1 = nn.Linear(dim_state, 64)self.fc2 = nn.Linear(64, 32)self.fc3 = nn.Linear(32, num_action)def forward(self, state):x = F.relu(self.fc1(state))x = F.relu(self.fc2(x))x = self.fc3(x)prob = F.softmax(x, dim=-1)return probclass A2C:def __init__(self, args):self.args = argsself.V = ValueNet(args.dim_state)self.V_target = ValueNet(args.dim_state)self.pi = PolicyNet(args.dim_state, args.num_action)self.V_target.load_state_dict(self.V.state_dict())def get_action(self, state):probs = self.pi(state)m = Categorical(probs)action = m.sample()logp_action = m.log_prob(action)return action, logp_actiondef compute_value_loss(self, bs, blogp_a, br, bd, bns):# 目标价值。with torch.no_grad():target_value = br + self.args.discount * torch.logical_not(bd) * self.V_target(bns).squeeze()# 计算value loss。value_loss = F.mse_loss(self.V(bs).squeeze(), target_value)return value_lossdef compute_policy_loss(self, bs, blogp_a, br, bd, bns):# 目标价值。with torch.no_grad():target_value = br + self.args.discount * torch.logical_not(bd) * self.V_target(bns).squeeze()# 计算policy loss。with torch.no_grad():advantage = target_value - self.V(bs).squeeze()policy_loss = 0for i, logp_a in enumerate(blogp_a):policy_loss += -logp_a * advantage[i]policy_loss = policy_loss.mean()return policy_lossdef soft_update(self, tau=0.01):def soft_update_(target, source, tau_=0.01):for target_param, param in zip(target.parameters(), source.parameters()):target_param.data.copy_(target_param.data * (1.0 - tau_) + param.data * tau_)soft_update_(self.V_target, self.V, tau)class Rollout:def __init__(self):self.state_lst = []self.action_lst = []self.logp_action_lst = []self.reward_lst = []self.done_lst = []self.next_state_lst = []def put(self, state, action, logp_action, reward, done, next_state):self.state_lst.append(state)self.action_lst.append(action)self.logp_action_lst.append(logp_action)self.reward_lst.append(reward)self.done_lst.append(done)self.next_state_lst.append(next_state)def tensor(self):bs = torch.as_tensor(self.state_lst).float()ba = torch.as_tensor(self.action_lst).float()blogp_a = self.logp_action_lstbr = torch.as_tensor(self.reward_lst).float()bd = torch.as_tensor(self.done_lst)bns = torch.as_tensor(self.next_state_lst).float()return bs, ba, blogp_a, br, bd, bnsclass INFO:def __init__(self):self.log = defaultdict(list)self.episode_length = 0self.episode_reward = 0self.max_episode_reward = -float("inf")def put(self, done, reward):if done is True:self.episode_length += 1self.episode_reward += rewardself.log["episode_length"].append(self.episode_length)self.log["episode_reward"].append(self.episode_reward)if self.episode_reward > self.max_episode_reward:self.max_episode_reward = self.episode_rewardself.episode_length = 0self.episode_reward = 0else:self.episode_length += 1self.episode_reward += rewarddef train(args, env, agent: A2C):V_optimizer = torch.optim.Adam(agent.V.parameters(), lr=3e-3)pi_optimizer = torch.optim.Adam(agent.pi.parameters(), lr=3e-3)info = INFO()rollout = Rollout()state, _ = env.reset()for step in range(args.max_steps):action, logp_action = agent.get_action(torch.tensor(state).float())next_state, reward, terminated, truncated, _ = env.step(action.item())done = terminated or truncatedinfo.put(done, reward)rollout.put(state,action,logp_action,reward,done,next_state,)state = next_stateif done is True:# 模型训练。bs, ba, blogp_a, br, bd, bns = rollout.tensor()value_loss = agent.compute_value_loss(bs, blogp_a, br, bd, bns)V_optimizer.zero_grad()value_loss.backward(retain_graph=True)V_optimizer.step()policy_loss = agent.compute_policy_loss(bs, blogp_a, br, bd, bns)pi_optimizer.zero_grad()policy_loss.backward()pi_optimizer.step()agent.soft_update()# 打印信息。info.log["value_loss"].append(value_loss.item())info.log["policy_loss"].append(policy_loss.item())episode_reward = info.log["episode_reward"][-1]episode_length = info.log["episode_length"][-1]value_loss = info.log["value_loss"][-1]print(f"step={step}, reward={episode_reward:.0f}, length={episode_length}, max_reward={info.max_episode_reward}, value_loss={value_loss:.1e}")# 重置环境。state, _ = env.reset()rollout = Rollout()# 保存模型。if episode_reward == info.max_episode_reward:save_path = os.path.join(args.output_dir, "model.bin")torch.save(agent.pi.state_dict(), save_path)if step % 10000 == 0:plt.plot(info.log["value_loss"], label="value loss")plt.legend()plt.savefig(f"{args.output_dir}/value_loss.png", bbox_inches="tight")plt.close()plt.plot(info.log["episode_reward"])plt.savefig(f"{args.output_dir}/episode_reward.png", bbox_inches="tight")plt.close()def eval(args, env, agent):agent = A2C(args)model_path = os.path.join(args.output_dir, "model.bin")agent.pi.load_state_dict(torch.load(model_path))episode_length = 0episode_reward = 0state, _ = env.reset()for i in range(5000):episode_length += 1action, _ = agent.get_action(torch.from_numpy(state))next_state, reward, terminated, truncated, info = env.step(action.item())done = terminated or truncatedepisode_reward += rewardstate = next_stateif done is True:print(f"episode reward={episode_reward}, length={episode_length}")state, _ = env.reset()episode_length = 0episode_reward = 0if __name__ == "__main__":parser = argparse.ArgumentParser()parser.add_argument("--env", default="CartPole-v1", type=str, help="Environment name.")parser.add_argument("--dim_state", default=4, type=int, help="Dimension of state.")parser.add_argument("--num_action", default=2, type=int, help="Number of action.")parser.add_argument("--output_dir", default="output", type=str, help="Output directory.")parser.add_argument("--seed", default=42, type=int, help="Random seed.")parser.add_argument("--max_steps", default=100_000, type=int, help="Maximum steps for interaction.")parser.add_argument("--discount", default=0.99, type=float, help="Discount coefficient.")parser.add_argument("--lr", default=1e-3, type=float, help="Learning rate.")parser.add_argument("--batch_size", default=32, type=int, help="Batch size.")parser.add_argument("--no_cuda", action="store_true", help="Avoid using CUDA when available")parser.add_argument("--do_train", action="store_true", help="Train policy.")parser.add_argument("--do_eval", action="store_true", help="Evaluate policy.")args = parser.parse_args()env = gym.make(args.env)agent = A2C(args)if args.do_train:train(args, env, agent)if args.do_eval:eval(args, env, agent)

3、torch.distributions.Categorical()

probs = policy_network(state)
# Note that this is equivalent to what used to be called multinomial
m = Categorical(probs) # 用probs构造一个分布
action = m.sample() # 按照probs进行采样
next_state, reward = env.step(action)
loss = -m.log_prob(action) * reward # log_prob 计算log(probs[action])的值
loss.backward()

Probability distributions - torch.distributions — PyTorch 2.0 documentation

next_state, reward = env.step(action)
loss = -m.log_prob(action) * reward # log_prob 计算log(probs[action])的值
loss.backward()


[Probability distributions - torch.distributions — PyTorch 2.0 documentation](https://pytorch.org/docs/stable/distributions.html)[【PyTorch】关于 log_prob(action) - 简书 (jianshu.com)](https://www.jianshu.com/p/06a5c47ee7c2)

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.rhkb.cn/news/82714.html

如若内容造成侵权/违法违规/事实不符,请联系长河编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

如何在Spring MVC中使用@ControllerAdvice创建全局异常处理器

文章目录 前言一、认识注解:RestControllerAdvice和ExceptionHandler二、使用步骤1、封装统一返回结果类2、自定义异常类封装3、定义全局异常处理类4、测试 总结 前言 全局异常处理器是一种 🌟✨机制,用于处理应用程序中发生的异常&#xff…

ROS入门核心教材重要节选

ROS核心教程 1、文件系统 使用下述命令查看包 rospack ros pack(age) 如rospack find roscpp roscd ros cd 如roscd roscpp rosls ros ls 如rosls roscpp2、ROS节点 节点可以理解为人工定义一个机器人模块,然后抽象成可执行文件。 rosnode li…

TCP的四次挥手与TCP状态转换

文章目录 四次挥手场景步骤TCP状态转换 四次挥手场景 TCP客户端与服务器断开连接的时候,在程序中使用close()函数,会使用TCP协议四次挥手。 客户端和服务端都可以主动发起。 因TCP连接时候是双向的,所以断开的时候也是双向的。 步骤 三次…

LabVIEW开发3D颈动脉图像边缘检测

LabVIEW开发3D颈动脉图像边缘检测 近年来,超声图像在医学领域对疾病诊断具有重要意义。边缘检测是图像处理技术的重要组成部分。边缘包含图像信息。边缘检测的主要目的是根据强度和纹理等属性识别图像中均匀区域的边界。超声(US)图像存在视觉…

vue项目实战-脑图编辑管理系统kitymind百度脑图

前言 项目为前端vue项目,把kitymind百度脑图整合到前端vue项目中,显示了脑图的绘制,编辑,到处为json,png,text等格式的功能 文章末尾有相关的代码链接,代码只包含前端项目,在原始的…

微服务与Nacos概述

微服务概述 软件架构的演变:单体架构、垂直应用架构、流式计算架构 SOA、微服务架构和服务网格。 微服务是一种软件开发架构,它将一个大型应用程序拆分为一系列小型、独立的服务。每个服务都可以独立开发、部署和扩展,并通过轻量级的通信机…

事务,不只ACID | 京东物流技术团队

1. 什么是事务? 应用在运行时可能会发生数据库、硬件的故障,应用与数据库的网络连接断开或多个客户端端并发修改数据导致预期之外的数据覆盖问题,为了提高应用的可靠性和数据的一致性,事务应运而生。 从概念上讲,事务…

开发中常用的数据库日志都长啥样呢?

目录 常见日志级别 数据库日志 Undo log 逻辑日志 redolog binlog 慢查询日志 AOF 文本文件 RDB 二进制文件 常见日志级别 DEBUG:用于详细记录应用程序的运行过程,如变量值、执行流程等。DEBUG级别的日志通常用于开发和调试过程中,以…

[保研/考研机试] 约瑟夫问题No.2 C++实现

题目要求&#xff1a; 输入、输出样例&#xff1a; 源代码&#xff1a; #include<iostream> #include<queue> #include<vector> using namespace std;//例题5.2 约瑟夫问题No.2 int main() {int n, p, m;while (cin >> n >> p >> m) {//如…

业务中如何过滤敏感词

在我们访问网站的时候&#xff0c;如果发现我们发布的内容有色情暴力的东西等等&#xff0c;会屏蔽掉&#xff0c;这种行为就是过滤敏感词。 从技术层面实现起来&#xff0c;其实比较简单&#xff0c;因为我们输入的内容就是一个大型的字符串&#xff0c;我们要调用某些api来判…

ESP32开发阶段启用 Secure Boot 与 Flash encryption

Secure Boot 与 Flash encryption详情 请参考&#xff1a;https://blog.csdn.net/espressif/article/details/79362094 1、开发环境 AT版本&#xff1a;2.4.0.0 发布IDF 与 python&#xff1a; idf4.3_py3.10_env系统&#xff1a;虚拟机 ubuntu 20 2、使能 secure boot 和 …

【动态规划刷题 6】 删除并获得点数 粉刷房子

740. 删除并获得点数 给你一个整数数组 nums &#xff0c;你可以对它进行一些操作。 每次操作中&#xff0c;选择任意一个 nums[i] &#xff0c;删除它并获得 nums[i] 的点数。之后&#xff0c;你必须删除 所有 等于 nums[i] - 1 和 nums[i] 1 的元素。 开始你拥有 0 个点数。…

list模拟实现【引入反向迭代器】

文章目录 1.适配器1.1传统意义上的适配器1.2语言里的适配器1.3理解 2.list模拟实现【注意看反向迭代器】2.1 list_frame.h2.2riterator.h2.3list.h2.4 vector.h2.5test.cpp 3.反向迭代器的应用1.使用要求2.迭代器的分类 1.适配器 1.1传统意义上的适配器 1.2语言里的适配器 容…

实现链式队列

dl.h dl.c main.c 结果

BM5 合并k个已排序的链表 javascript

描述 合并 k 个升序的链表并将结果作为一个升序的链表返回其头节点。 数据范围&#xff1a; 示例1 输入&#xff1a; [{1,2,3},{4,5,6,7}] 返回值&#xff1a; {1,2,3,4,5,6,7}示例2 输入&#xff1a; [{1,2},{1,4,5},{6}] 返回值&#xff1a; {1,1,2,4,5,6}解题思路 利用两个…

RabbitMQ 发布确认机制

发布确认模式是避免消息由生产者到RabbitMQ消息丢失的一种手段 发布确认模式 原理说明实现方式开启confirm&#xff08;确认&#xff09;模式阻塞确认异步确认 总结 原理说明 生产者通过调用channel.confirmSelect方法将信道设置为confirm模式&#xff0c;之后RabbitMQ会返回Co…

spring 面试题

一、Spring面试题 专题部分 1.1、什么是spring? Spring是一个轻量级Java开发框架&#xff0c;最早有Rod Johnson创建&#xff0c;目的是为了解决企业级应用开发的业务逻辑层和其他各层的耦合问题。它是一个分层的JavaSE/JavaEE full-stack&#xff08;一站式&#xff09;轻量…

Unity之ShaderGraph 节点介绍 Utility节点

Utility 逻辑All&#xff08;所有分量都不为零&#xff0c;返回 true&#xff09;Any&#xff08;任何分量不为零&#xff0c;返回 true&#xff09;And&#xff08;A 和 B 均为 true&#xff09;Branch&#xff08;动态分支&#xff09;Comparison&#xff08;两个输入值 A 和…

未来C#上位机软件发展趋势

C#上位机软件迎来新的发展机遇。随着工业自动化的快速发展&#xff0c;C#作为一种流行的编程语言在上位机软件领域发挥着重要作用。未来&#xff0c;C#上位机软件可能会朝着以下几个方向发展&#xff1a; 1.智能化&#xff1a;随着人工智能技术的不断演进&#xff0c;C#上位机…

中间件RabbitMQ消息队列介绍

1. MQ的相关概念 1.1 什么是MQ MQ&#xff08;message queue&#xff09;&#xff0c;从字面意思上看&#xff0c;本质是个队列&#xff0c;FIFO先入先出&#xff0c;只不过队列中存放的内容是message而已&#xff0c;还是一种跨进程的通信机制&#xff0c;用于上下游传递消息…