强化学习--多维动作状态空间的设计

目录

  • 一、离散动作
  • 二、连续动作
    • 1、例子1
    • 2、知乎给出的示例
    • 2、github里面的代码

免责声明:以下代码部分来自网络,部分来自ChatGPT,部分来自个人的理解。如有其他观点,欢迎讨论!

一、离散动作

注意:本文均以PPO算法为例。

# time: 2023/11/22 21:04
# author: YanJPimport torch
import torch
import torch.nn as nn
from torch.distributions import Categoricalclass MultiDimensionalActor(nn.Module):def __init__(self, input_dim, output_dims):super(MultiDimensionalActor, self).__init__()# Define a shared feature extraction networkself.feature_extractor = nn.Sequential(nn.Linear(input_dim, 128),nn.ReLU(),nn.Linear(128, 64),nn.ReLU())# Define individual output layers for each action dimensionself.output_layers = nn.ModuleList([nn.Linear(64, num_actions) for num_actions in output_dims])def forward(self, state):# Feature extractionfeatures = self.feature_extractor(state)# Generate Categorical objects for each action dimensioncategorical_objects = [Categorical(logits=output_layer(features)) for output_layer in self.output_layers]return categorical_objects# 定义主函数
def main():# 定义输入状态维度和每个动作维度的动作数input_dim = 10output_dims = [5, 8]  # 两个动作维度,分别有 3 和 4 个可能的动作# 创建 MultiDimensionalActor 实例actor_network = MultiDimensionalActor(input_dim, output_dims)# 生成输入状态(这里使用随机数据作为示例)state = torch.randn(1, input_dim)# 调用 actor 网络categorical_objects = actor_network(state)# 输出每个动作维度的采样动作和对应的对数概率for i, categorical in enumerate(categorical_objects):sampled_action = categorical.sample()log_prob = categorical.log_prob(sampled_action)print(f"Sampled action for dimension {i+1}: {sampled_action.item()}, Log probability: {log_prob.item()}")if __name__ == "__main__":main()#Sampled action for dimension 1: 1, Log probability: -1.4930928945541382
#Sampled action for dimension 2: 3, Log probability: -2.1875085830688477

注意代码中categorical函数的两个不同传入参数的区别:参考链接
简单来说,logits是计算softmax的,probs直接就是已知概率的时候传进去就行。

二、连续动作

参考链接:github、知乎
为什么取对数概率?参考回答
在这里插入图片描述

1、例子1

先看如下的代码:

# time: 2023/11/21 21:33
# author: YanJP
#这是对应多维连续变量的例子:
# 参考链接:https://github.com/XinJingHao/PPO-Continuous-Pytorch/blob/main/utils.py
# https://www.zhihu.com/question/417161289
import torch.nn as nn
import torch
class Policy(nn.Module):def __init__(self, in_dim, n_hidden_1, n_hidden_2, num_outputs):super(Policy, self).__init__()self.layer = nn.Sequential(nn.Linear(in_dim, n_hidden_1),nn.ReLU(True),nn.Linear(n_hidden_1, n_hidden_2),nn.ReLU(True),nn.Linear(n_hidden_2, num_outputs))class Normal(nn.Module):def __init__(self, num_outputs):super().__init__()self.stds = nn.Parameter(torch.zeros(num_outputs))  #创建一个可学习的参数 def forward(self, x):dist = torch.distributions.Normal(loc=x, scale=self.stds.exp())action = dist.sample((every_dimention_output,))  #这里我觉得是最重要的,不填sample的参数的话,默认每个分布只采样一个值!!!!!!!!return actionif __name__ == '__main__':policy = Policy(4,20,20,5)normal = Normal(5) #设置5个维度every_dimention_output=10  #每个维度10个输出observation = torch.Tensor(4)action = normal.forward(policy.layer( observation))print("action: ",action)
  • self.stds.exp(),表示求指数,因为正态分布的标准差都是正数。
  • action = dist.sample((every_dimention_output,))这里最重要!!!

2、知乎给出的示例


class Agent(nn.Module):def __init__(self, envs):super(Agent, self).__init__()self.actor_mean = nn.Sequential(layer_init(nn.Linear(np.array(envs.single_observation_space.shape).prod(), 64)),nn.Tanh(),layer_init(nn.Linear(64, 64)),nn.Tanh(),layer_init(nn.Linear(64, np.prod(envs.single_action_space.shape)), std=0.01),)self.actor_logstd = nn.Parameter(torch.zeros(1, np.prod(envs.single_action_space.shape)))def get_action_and_value(self, x, action=None):action_mean = self.actor_mean(x)action_logstd = self.actor_logstd.expand_as(action_mean)action_std = torch.exp(action_logstd)probs = Normal(action_mean, action_std)if action is None:action = probs.sample()return action, probs.log_prob(action).sum(1), probs.entropy().sum(1), self.critic(x)

这里的np.prod(envs.single_action_space.shape),表示每个维度的动作数相乘,然后初始化这么多个actor网络的标准差和均值,最后action里面的sample就是采样这么多个数据。(感觉还是拉成了一维计算)

2、github里面的代码

github

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.distributions import Beta,Normalclass GaussianActor_musigma(nn.Module):def __init__(self, state_dim, action_dim, net_width):super(GaussianActor_musigma, self).__init__()self.l1 = nn.Linear(state_dim, net_width)self.l2 = nn.Linear(net_width, net_width)self.mu_head = nn.Linear(net_width, action_dim)self.sigma_head = nn.Linear(net_width, action_dim)def forward(self, state):a = torch.tanh(self.l1(state))a = torch.tanh(self.l2(a))mu = torch.sigmoid(self.mu_head(a))sigma = F.softplus( self.sigma_head(a) )return mu,sigmadef get_dist(self, state):mu,sigma = self.forward(state)dist = Normal(mu,sigma)return distdef deterministic_act(self, state):mu, sigma = self.forward(state)return mu

上述代码主要是通过设置mu_head 和sigma_head的个数,来实现多维动作。

class GaussianActor_mu(nn.Module):def __init__(self, state_dim, action_dim, net_width, log_std=0):super(GaussianActor_mu, self).__init__()self.l1 = nn.Linear(state_dim, net_width)self.l2 = nn.Linear(net_width, net_width)self.mu_head = nn.Linear(net_width, action_dim)self.mu_head.weight.data.mul_(0.1)self.mu_head.bias.data.mul_(0.0)self.action_log_std = nn.Parameter(torch.ones(1, action_dim) * log_std)def forward(self, state):a = torch.relu(self.l1(state))a = torch.relu(self.l2(a))mu = torch.sigmoid(self.mu_head(a))return mudef get_dist(self,state):mu = self.forward(state)action_log_std = self.action_log_std.expand_as(mu)action_std = torch.exp(action_log_std)dist = Normal(mu, action_std)return distdef deterministic_act(self, state):return self.forward(state)
class Critic(nn.Module):def __init__(self, state_dim,net_width):super(Critic, self).__init__()self.C1 = nn.Linear(state_dim, net_width)self.C2 = nn.Linear(net_width, net_width)self.C3 = nn.Linear(net_width, 1)def forward(self, state):v = torch.tanh(self.C1(state))v = torch.tanh(self.C2(v))v = self.C3(v)return v

上述代码只定义了mu的个数与维度数一样,std作为可学习的参数之一。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.rhkb.cn/news/201405.html

如若内容造成侵权/违法违规/事实不符,请联系长河编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

【Docker】Docker安装Nginx配置静态资源

1.下载镜像 2.创建nginx配置文件 3.创建nginx容器运行 4.配置nginx静态资源 1.下载镜像 Dockerhub官网:Docker docker pull nginx docker pull nginx下载最新版本 默认latest 下载指定版本docker pull nginx:xxx 2.创建nginx配置文件 启动容器之前要创建nginx…

【云原生】初识 Service Mesh

目录 一、什么是Service Mesh 二、微服务发展历程 2.1 微服务架构演进历史 2.1.1 单体架构 2.1.2 SOA阶段 2.1.3 微服务阶段 2.2 微服务治理中的问题 2.2.1 技术栈庞杂 2.2.2 版本升级碎片化 2.2.3 侵入性强 2.2.4 中间件多,学习成本高 2.2.5 服务治理功…

智能座舱架构与芯片- (8) 视觉篇

一、概述 相比起用于ADAS感知系统的摄像头,用于智能座舱内部的摄像头,其功能特性和性能要求相对简单。例如,OMS乘客监控摄像头,一般达到5MP即可有良好的效果。同时,OMS也可应用于车内会议系统,还应用于车内…

selenium判断元素是否存在的方法

文章目录 快捷方法完整示例程序 快捷方法 selenium没有exist_xxx相关的方法,无法直接判断元素存在。但是锁定元素时使用的browser.find_elements(By.CSS_SELECTOR, "css元素")会返回一个列表list,如果不存在这个元素就会返回一个空列表。因此…

【办公常识】写好的代码如何上传?使用svn commit

首先找到对应的目录 找到文件之后点击SVN Commit

ModernCSS.dev - 来自微软前端工程师的 CSS 高级教程,讲解如何用新的 CSS 语法来解决旧的问题

今天给大家安利一套现代 CSS 的教程,以前写网页的问题,现在都可以用新的写法来解决了。 ModernCSS.dev 是一个现代 CSS 语法的教程,讲解新的 CSS 语法如何解决一些传统问题,一共有30多课。 这套教程的作者是 Stephanie Eckles&am…

git clone慢的解决办法

在网站 https://www.ipaddress.com/ 分别搜索: github.global.ssl.fastly.net github.com 得到ip: 打开hosts文件 sudo vim /etc/hosts 在hosts文件末尾添加 140.82.114.3 github.com 151.101.1.194 github.global-ssl.fastly.net 151.101.65.194 g…

汇编-间接寻址(处理数组)

直接寻址很少用于数组处理,因为用常数偏移量来寻址多个数组元素时,直接寻址并不实用。取而代之的是使用寄存器作为指针(称为间接寻址(indirect addressing) ) 并控制该寄存器的值。如果一个操作数使用的是间接寻址, 就称之为间接操作数(indie…

minio安装使用-linux

下载地址:MinIO | Code and downloads to create high performance object storage 选择 minio server 可以直接下载二进制文件。 将下载的文件传输到服务器的指定文件夹下,如 /opt/minio。 然后在,命令行启动minio: /opt/mini…

网工内推 | 合资公司网工,CCNP/HCIP认证优先,朝九晚六

01 中企网络通信技术有限公司 招聘岗位:网络工程师 职责描述: 1、按照工作流程和指引监控网络运行情况和客户连接状况; 2、确保各监控系统能正常运作; 3、快速响应各个网络告警事件; 4、判断出网络故障,按…

利用OpenCV实现图片中导线的识别

下面是一个需求,识别图片中的导线,要在图像中检测导线,我们需要采用不同于直线检测的方法。由于OpenCV没有直接的曲线检测函数,如同它对直线提供的HoughLines或HoughLinesP,检测曲线通常需要更多的图像处理步骤和算法&…

集成多元算法,打造高效字面文本相似度计算与匹配搜索解决方案,助力文本匹配冷启动[BM25、词向量、SimHash、Tfidf、SequenceMatcher]

搜索推荐系统专栏简介:搜索推荐全流程讲解(召回粗排精排重排混排)、系统架构、常见问题、算法项目实战总结、技术细节以及项目实战(含码源) 专栏详细介绍:搜索推荐系统专栏简介:搜索推荐全流程讲解(召回粗排精排重排混排)、系统架构、常见问题、算法项目实战总结、技术…

单体进化微服务:拆分、注册、调用、网关、过滤、治理、分布式事务

这里写目录标题 基本介绍生产-消费-网关父依赖生产者服务消费者服务网关服务common服务 感想 基本介绍 Spring Cloud 是一个用于构建分布式系统和微服务架构的开发工具包。它提供了一系列的功能和组件,用于解决微服务架构中的常见问题,如服务注册与发现…

Hadoop -hdfs的读写请求

1、HDFS写数据(宏观): 1、首先,客户端发送一个写数据的请求,通过rpc与NN建立连接,NN会做一些简单的校验,文件是否存在,是否有空间存储数据等。 2、NN就会将校验的结果发送给客户端…

在win10上安装pytorch-gpu版本2

安装anaconda即下载了python,还可以创建虚拟环境。 目录 1.1 anaconda安装 1.2 pytorch-gpu安装 1.1 Anaconda安装 anaconda的安装请看我之前发的tensoflow-gpu安装,里面有详细的安装过程,这里不做重复描述,传送门 1.2 pyt…

注解案例:山寨Junit与山寨JPA

作者简介:大家好,我是smart哥,前中兴通讯、美团架构师,现某互联网公司CTO 联系qq:184480602,加我进群,大家一起学习,一起进步,一起对抗互联网寒冬 上篇讲了什么是注解&am…

echarts的使用

1. 普通版 其实主要就是option1&#xff0c;option1就是画的图 echats不能响应刷新&#xff0c;要想实时刷新监听刷新的值重新调用一下方法即可 html <div class"echart" style"width: 100%;height: calc(100% - 130px)" ref"main1">&l…

排序算法-----快速排序(非递归实现)

目录 前言 快速排序 基本思路 非递归代码实现 前言 很久没跟新数据结构与算法这一栏了&#xff0c;因为数据结构与算法基本上都发布完了&#xff0c;哈哈&#xff0c;那今天我就把前面排序算法那一块的快速排序完善一下&#xff0c;前面只发布了快速排序递归算法&#xff0c;…

Java架构师软件架构风格

目录 1 数据流风格1.1 管道过滤器1.2 数据流风格的优点2 调用返回风格2.1 面向对象风格2.2 调用返回风格总结3 独立构件风格3.1 事件驱动系统风格的主要特点3.2 独立构件风格总结4 虚拟机风格4.1 虚拟机风格总结5 仓库风格5.1 仓库风格总结想学习架构师构建流程请跳转:Java架构…

VSCode任务tasks.json中的问题匹配器problemMatcher的问题匹配模式ProblemPattern详解

☞ ░ 前往老猿Python博客 ░ https://blog.csdn.net/LaoYuanPython 一、简介 在 VS Code 中&#xff0c;tasks.json 文件中的 problemMatcher 字段用于定义如何解析任务输出中的问题&#xff08;错误、警告等&#xff09;。 problemMatcher有三种配置方式&#xff0c;具体可…