论文阅读-多目标强化学习-envelope MOQ-learning

introduction

一种多目标强化学习算法,来自2019 Nips《A Generalized Algorithm for Multi-Objective Reinforcement Learning and Policy Adaptation》本文引用代码全部来源于论文中的链接。主要参考run_e3c_double.py文件

1 总体思想

1.将输入中加入多目标的偏好参数。 2. 在输出中改本为标量的状态价值为向量的状态价值。 3. 实现了可以在多个目标上寻找帕累托前沿,也即多目标最优解的算法。

2 算法

虽然论文中用的是Q-learning的架构,但是在提供的代码中,采用的是A3C的架构,使用envelope 网络作为价值网络,估计状态价值用于更新,所以接下来以代码为准,结合论文思想,展示用到的输入、输出和损失函数。

2.1 输入

以多目标马里奥环境为例,输入为连续四帧状态 S,随机采样的偏好w。w的值均为正数,且和为1,每一位的值,代表对该维目标的偏好大小。

2.2 输出

Actor 网络和Value网络共享同一个特征提取网络,Actor网络输出维度为所有可能动作数,Value网络输出维度为偏好的维度,也即多目标的目标维度数。

2.3 损失函数

2.3.1 Critic loss

        mse = nn.MSELoss()critic_loss_l1 = mse(wvalue, wtarget)critic_loss_l2 = mse(value.view(-1), target_batch.view(-1))loss += 0.5 * (self.beta * critic_loss_l1 + (1-self.beta) * critic_loss_l2)

Critic 网络的损失由critic loss1和critic loss2加权和组成,critic loss2 理解为多目标损失函数,即当Critic网络能够准确评估多目标状态时,所有pareto前沿上的点都满足critic loss2 为零。因此用梯度下降优化CL2显得不平滑且困难(因为它的解不止一个,而是很多个)。所以引入critic loss1 来减少这种不平滑,critic loss 1 是某种偏好下,critic网络的TD LOSS,因为偏好确定了,所以解只有一个,作者认为这样的损失函数更容易优化,更平滑。

操作上,wvalue和wtarget的唯独都是(batch_size, 1) ; 而 value和target的维度都是(batch_size,reward_size)。显然也是前者的优化更简单。

2.3.2 Actor loss

wadv_batch = torch.bmm(adv_batch.unsqueeze(1), w_batch.unsqueeze(2)).squeeze()
actor_loss = -m.log_prob(action_batch) * wadv_batch

actor loss形式上和带基线的policy gradient的损失函数类似,只不过Critic网络输出的维度不是1而是reward_size,优势adv先与偏好权重w矩阵相乘,得到维度为1的优势adv后再输入actor loss中,这也说明actor loss 的优化方向是朝着使得当前偏好的期望回报最大的方向优化的。

2.4 更新方式

2.4.1 数据收集方式

论文中伪代码表示用类似Q-learning 离线更新的方式, 给出的代码中使用类A3C在线更新的方式,以下以代码为准。

在一个epsiode开始前,随机初始化一个preference,并用这个偏好贯穿这一幕,直至结束。

explore_w = generate_w(args.num_worker, pref_param)

每一步,模型输入状态和偏好,输出动作

 while True:actions = agent.get_action(states, explore_w)for parent_conn in parent_conns:s, r, d, rd, mor, sc = parent_conn.recv()

将一幕中数据以此收齐后立即用于更新神经网络参数(因为A3C是在线算法,所以E3C也是在线)

2.4.2 参数更新方式

 value, next_value, policy = agent.forward_transition(total_state, total_next_state, total_update_w)

1.将收集到的状态,下一状态,偏好的序列输入网络,得到价值(5维)下一状态价值(5维)策略(和动作维度相同)

 for idx in range(args.num_worker):target = make_train_data(args,total_moreward[idx*args.num_step+idw*ofs : (idx+1)*args.num_step+idw*ofs],total_done[idx*args.num_step+idw*ofs: (idx+1)*args.num_step+idw*ofs],value[idx*args.num_step+idw*ofs : (idx+1)*args.num_step+idw*ofs],next_value[idx*args.num_step+idw*ofs : (idx+1)*args.num_step+idw*ofs],reward_size)

2.从最后一状态以此计算 TD-error中的taget,target = r+v(s'),target也是五维

> (ps:一直不知道为什么在线算法要从最后一步一直迭代倒推到第一步,都用r+γv(s')来做代表当前状态价值,导致第一个状态v(s0)=r0+γ*r1+γ**2+....+γ**nV(Sn),导致方差很大。为什么不每一步直接从价值网络导出,这样v(s0)=r0+v(s1),这样方差小的方法呢?很奇怪)

total_target, total_adv = envelope_operator(args, update_w, total_target, value, reward_size, global_step)

 3. 使用envelope operator函数对target做处理,在训练初期,只计算优势 adv = target - value,

在训练中后期用于从随机采样的多个偏好(代码默认八个偏好,总和维度为(8,5))中,挑选出能使target最大的一种偏好。和Q-learning中取q=r+qmax(s')有点像。[这里的reshape我也有点看不懂,此观点只做参考]

agent.train_model()
actor_loss = -m.log_prob(action_batch) * wadv_batch# Entropy(for more exploration)
entropy = m.entropy()# Critic lossmse = nn.MSELoss()
critic_loss_l1 = mse(wvalue, wtarget)
critic_loss_l2 = mse(value.view(-1), target_batch.view(-1))# Total loss (don't compute tempreture)loss = actor_loss.mean()loss += 0.5 * (self.beta * critic_loss_l1 + (1-self.beta) * critic_loss_l2)loss -= self.entropy_coef * entropy.mean()

4.计算loss,反向传播。这一部分就很明了了,计算前面提到的几种loss,给与不同权重后反向传播,唯一特别注意的是,actor loss中使用的优势adv,不知出于什么理由,使用了优势向量与偏好向量做内积后的偏好,(可能是因为解唯一,优化方便)

5.其他注意事项:1、用于和环境交互的偏好并不被保存,更新参数时会重新抽样偏好,这样做有什么理论依据嘛?暂时还没想明白。

2.5 损失函数中偏好和输入网络偏好的关系

从伪代码,和代码中可见,在进行前向推导时输入网络的preference 和在训练时使用的preference并不是同一个。并且,前向时所用的preference并没有被replayer buffer记录下来。训练时actor 和 critic里用的偏好仍然是随机抽取的偏好。

3 其他bug和优化技巧

1.为达到论文所示的训练速度,需要使用简化后的Mario-v3环境,并且跳5帧。

2.由于A3C是异步算法,有多个环境并行采样,所以环境初始化的位置应在启动进程的代码之后,即在multiprocess的run函数之中再reset环境,否则会发生内存地址错误,找不到创造的环境的错误。

对于文中的问题,欢迎有不同见解的同学在评论区讨论交流学习,祝你学习愉快!

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.rhkb.cn/news/165702.html

如若内容造成侵权/违法违规/事实不符,请联系长河编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

Leetcode 202 快乐数(HashSet,环形链表思想)

Leetcode 202 快乐数(HashSet) 解法1 : 用HashSet来检测循环:star:为什么说数字n的位数由log n给定呢?解法2 : 链表的思想[出现循环表示链表出现环],使用快慢指针法 题目链接>>>>>>>&…

用 Java 在 PDF 中创建和管理图层,实现交互式文档

PDF 图层(也称为可见图层或附加图层等)是组织和管理 PDF 文档中内容可见性的一种方法。PDF 图层可用于创建交互式文档、隐藏或显示特定信息、创建多语言版本文档等。通过添加和删除图层,用户可以根据需要定制 PDF 文档指定内容的可见性与显示…

PO模式在selenium自动化测试框架的优势

大家都知道po模式可以提高代码的可读性和减少了代码的重复,但是相对的缺点还有,今天通过本文一起学习下PO模式在selenium自动化测试框架的优势,需要的朋友可以参考下 PO模式简介 1.什么是PO模式 PO模型是:Page Object Model的简写 页面对象…

国内有哪些做得好的企业协同办公软件

在当今信息化时代,企业协同办公软件成为了提升企业效率和推动协作的重要工具。国内市场涌现出许多优秀的企业协同办公软件,为企业提供了高效、便捷的协同办公解决方案。在本文中,我们将向大家介绍3款在国内好评如潮的企业协同办公软件&#x…

项目知识点总结-过滤器-MD5注册-邮箱登录

(1)过滤器 使用过滤器验证用户是否登录 /** * Title: NoLoginFilter.java * Package com.qfedu.web.filter * Description: TODO(用一句话描述该文件做什么) * author Feri * date 2018年5月28日 * version V1.0 */ package com.gdsdx…

Mingw快捷安装教程 并完美解决出现的下载错误:The file has been downloaded incorrectly

安装c语言编译器的时候,老是出现The file has been downloaded incorrectly,真的让人 直接去官网拿压缩包:https://sourceforge.net/projects/mingw-w64/files/ (往下拉找到那个x86_64-win32-seh的链接,点击后会自动…

【Java题】实现继承和多态的例子

一:题目 1.员工类Employee: (1)私有成员变量:姓名,年龄,工资 (2)提供无参,有参构造 (3)成员方法:work()方法——员工工作 …

AUTOSAR AP 硬核知识点梳理(2)— 架构详解

一 AUTOSAR 平台逻辑体系结构 图示逻辑体系结构描述了平台是如何组成的,有哪些模块,模块之间的接口是如何工作的。 经典平台具有分层的软件体系结构。定义明确的抽象层,每个抽象层都有精确定义的角色和接口。 对于应用程序,我们需要考虑使用的软件组件,希望它们是可重用的…

洗地机哪个好用?2023年洗地机推荐指南

说到提高家庭幸福生活的家电,洗地机肯定是少不了的,特别对于现在快节奏的生活来说,高效率的解决家务活,而且能够大幅度的提高生活质量。在市场上,消费者面临着选择合适洗地机的难题,因为有各种型号、功能和…

vue3 + fastapi 实现选择目录所有文件自定义上传到服务器

文章目录 ⭐前言💖 技术栈选择 ⭐前端页面搭建💖 调整请求content-type传递formData ⭐后端接口实现💖 swagger文档测试接口 ⭐前后端实现效果💖 上传单个文件💖 上传目录文件 ⭐总结⭐结束 ⭐前言 大家好&#xff0c…

嵌入式养成计划-45----QT--事件机制--定时器事件--键盘事件和鼠标事件--绘制事件

一百一十五、事件机制 当这件事情发生时,会自动走对应的函数处理(重写的事件函数) 115.1 事件处理简介 什么是事件? (重点) 件是由窗口系统或者自身产生的,用以响应所发生的各类事情,比如用户按下并释放…

进程与线程

进程 进程锁 进程之间数据不共享,但是共享同一套文件系统,所以访问同一个文件,或同一个打印终端,是没有问题的,而共享带来的是竞争,竞争带来的结果就是错乱,如何控制,就是加锁处理 part1:多个进程共享同一打印终端 …

【框架源码篇 03】Spring源码手写篇-手写AOP

Spring源码手写篇-手写AOP 手写IoC和DI后已经实现的类图结构。 一、AOP分析 1.AOP是什么? AOP[Aspect Oriented Programming] 面向切面编程,在不改变类的代码的情况下,对类方法进行功能的增强。 2.我们要做什么? 我们需要在前面手写IoC&…

排序算法,冒泡排序算法及优化,选择排序SelectionSort,快速排序(递归-分区)

一、冒泡排序算法: 介绍: 冒泡排序(Bubble Sort)是一种简单直观的排序算法。它重复地走访过要排序的数列,一次比较两个元素,如果他们的顺序错误就把他们交换过来。走访数列的工作是重复地进行直到没有再需…

关于SparkRdd和SparkSql的几个指标统计,scala语言,打包上传到spark集群,yarn模式运行

需求: ❖ 要求:分别用SparkRDD, SparkSQL两种编程方式完成下列数据分析,结合webUI监控比较性能优劣并给出结果的合理化解释. 1、分别统计用户,性别,职业的个数: 2、查看统计年龄分布情况(按照年龄分段为7段&#xff0…

初识树结构和二叉树

一,树概念及结构 1.1树结构的概念 树是一种非线性的数据结构,它是由n(n>0)个有限结点组成一个具有层次关系的集合。把它叫做树是因为它看起来像一棵倒挂的树,也就是说它是根朝上,而叶朝下的。 注意&a…

AI全栈大模型工程师(九)Function Calling 的机制

文章目录 Function Calling 的机制Function Calling 示例 1:加法计算器Function Calling 实例 2:四则混合运算计算器后记Function Calling 的机制 Function Calling 示例 1:加法计算器 需求:用户输入任意可以用加法解决的问题,都能得到计算结果。 # 加载环境变量import o…

人工智能发展与结构科学

人工智能(AI)在各种应用中的影响力不断增强,从简单的计算任务到复杂的决策支持。但在这背后,AI的发展其实是一个关于结构演变的故事。从最早的线性结构,到今天的复杂网络结构,结构的演变对AI的能力和效率产…

【linux】查看下载应用在服务器的日志

查看日志路径 一般在配置文件中logback.xml 账号密码xshell连接服务器,进入日志路径 根据搜索关键字查看xxx.log文件内容 cat xxx.log | grep 关键字 下载 xxx.log 到本地,一般可以下载当天的日志文件到本地查看比较方便 sz xxx.log 参考文章&#xff…

Adobe 推出 Photoshop Elements 2024 新版

🦉 AI新闻 🚀 Adobe 推出 Photoshop Elements 2024 新版 摘要:Adobe 最新发布 Photoshop Elements 2024 版本,新增引入 AI 功能,提供匹配颜色、创建照片卷、一键选择照片天空或背景等新功能,界面也进行了优化更新。本次发布重点加强了 AI 支持,简化复杂…