[PyTorch][chapter 64][强化学习-DQN]

前言:

            DQN 就是结合了深度学习和强化学习的一种算法,最初是 DeepMind 在 NIPS 2013年提出,它的核心利润包括马尔科夫决策链以及贝尔曼公式。

            Q-learning的核心在于Q表格,通过建立Q表格来为行动提供指引,但这适用于状态和动作空间是离散且维数不高时,当状态和动作空间是高维连续时Q表格将变得十分巨大,对于维护Q表格和查找都是不现实的。


1: DQN 历史

2:  DQN 网络参数配置

3:DQN 网络模型搭建


一 DQN 历史

     DQN 跟机器学习的时序差分学习里面的Q-Learning 算法相似

    1.1 Q-Learning 算法

在Q Learning 中,我们有个Q table ,记录不同状态下,各个动作的Q 值

我们通过Q table 更新当前的策略

Q table 的作用: 是我们输入S,通过查表返回能够获得最大Q值的动作A.

但是很多场景状态S 并不是离散的,很难去定义

 1.2  DQN 发展史

     Deep network+Q-learning = DQN

     DQN 和 Q-tabel 没有本质区别:

     Q-table: 内部维护 Q Tabel

     DQN:   通过神经网络  a= NN(s), 替代了 Q Tabel

   


二 网络模型

    2.1 DQN 算法

  2.1 模型

模型参数


三  代码实现:

 5.1 main.py

   

# -*- coding: utf-8 -*-
"""
Created on Fri Nov 17 16:53:02 2023@author: chengxf2
"""import numpy as np
import torch
import gym
import random 
from Replaybuffer import Replay
from Agent import DQN
import rl_utils
import matplotlib.pyplot as plt
from tqdm import tqdm  #生成进度条lr = 5e-3
hidden_dim = 128
num_episodes = 500
minimal_size = 500
gamma = 0.98
epsilon =0.01
target_update = 10
buffer_size = 10000
mini_size = 500
batch_size = 64
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")if __name__ == "__main__":env_name = 'CartPole-v0'env = gym.make(env_name)random.seed(0)np.random.seed(0)env.seed(0)torch.manual_seed(0)replay_buffer = Replay(buffer_size)state_dim = env.observation_space.shape[0]action_dim = env.action_space.nagent = DQN(state_dim, hidden_dim, action_dim, lr, gamma, epsilon,target_update, device)return_list = []for i in range(10):with tqdm(total=int(num_episodes / 10), desc='Iteration %d' % i) as pbar:for i_episode in range(int(num_episodes / 10)):episode_return = 0state = env.reset()done = Falsewhile not done:action = agent.take_action(state)next_state, reward, done, _ = env.step(action)replay_buffer.add(state, action, reward, next_state, done)state = next_stateepisode_return += reward# 当buffer数据的数量超过一定值后,才进行Q网络训练if replay_buffer.size() > minimal_size:b_s, b_a, b_r, b_ns, b_d = replay_buffer.sample(batch_size)transition_dict = {'states': b_s,'actions': b_a,'next_states': b_ns,'rewards': b_r,'dones': b_d}agent.update(transition_dict)return_list.append(episode_return)if (i_episode + 1) % 10 == 0:pbar.set_postfix({'episode':'%d' % (num_episodes / 10 * i + i_episode + 1),'return':'%.3f' % np.mean(return_list[-10:])})pbar.update(1)episodes_list = list(range(len(return_list)))plt.figure(1) plt.subplot(1, 2, 1)  # fig.1是一个一行两列布局的图,且现在画的是左图plt.plot(episodes_list, return_list,c='r')plt.xlabel('Episodes')plt.ylabel('Returns')plt.title('DQN on {}'.format(env_name))plt.figure(1)  # 当前要处理的图为fig.1,而且当前图是fig.1的左图plt.subplot(1, 2, 2)  # 当前图变为fig.1的右图mv_return = rl_utils.moving_average(return_list, 9)plt.plot(episodes_list, mv_return,c='g')plt.xlabel('Episodes')plt.ylabel('Returns')plt.title('DQN on {}'.format(env_name))plt.show()

5.2  Agent.py

# -*- coding: utf-8 -*-
"""
Created on Fri Nov 17 16:00:46 2023@author: chengxf2
"""import random 
import numpy as np
from   torch import nn
import torch
import torch.nn.functional as Fclass QNet(torch.nn.Module):def __init__(self, state_dim, hidden_dim, action_dim):super(QNet, self).__init__()self.net = nn.Sequential(nn.Linear(state_dim, hidden_dim),nn.Linear(hidden_dim, action_dim))def forward(self, state):qvalue = self.net(state)return qvalueclass  DQN:def __init__(self,state_dim, hidden_dim, action_dim,learning_rate,discount, epsilon, target_update, device):self.action_dim = action_dimself.q_net = QNet(state_dim, hidden_dim, action_dim).to(device)self.target_q_net = QNet(state_dim, hidden_dim, action_dim).to(device)#Adam 优化器self.optimizer = torch.optim.Adam(self.q_net.parameters(),lr=learning_rate)self.gamma = discount #折扣因子self.epsilon = epsilon  # e-贪心算法self.target_update = target_update  # 目标网络更新频率self.device = deviceself.count = 0 #计数器def  take_action(self, state):rnd = np.random.random() #产生随机数if rnd <self.epsilon:action = np.random.randint(0, self.action_dim)else:state = torch.tensor([state], dtype=torch.float).to(self.device)qvalue = self.q_net(state)action = qvalue.argmax().item()return actiondef update(self, data):states = torch.tensor(data['states'],dtype=torch.float).to(self.device)actions = torch.tensor(data['actions']).view(-1, 1).to(self.device)rewards = torch.tensor(data['rewards'],dtype=torch.float).view(-1, 1).to(self.device)next_states = torch.tensor(data['next_states'],dtype=torch.float).to(self.device)dones = torch.tensor(data['dones'],dtype=torch.float).view(-1, 1).to(self.device)#从完整数据中按索引取值[64]#print("\n actions ",actions,actions.shape)q_value = self.q_net(states).gather(1,actions) #Q值#下一个状态的Q值max_next_q_values = self.target_q_net(next_states).max(1)[0].view(-1,1)q_targets = rewards + self.gamma * max_next_q_values * (1 - dones)loss = F.mse_loss(q_value, q_targets)loss = torch.mean(loss)self.optimizer.zero_grad()loss.backward()self.optimizer.step()if self.count %self.target_update  ==0:#更新目标网络self.target_q_net.load_state_dict(self.q_net.state_dict())self.count +=1

 5.3 Replaybuffer.py

   

# -*- coding: utf-8 -*-
"""
Created on Fri Nov 17 15:50:07 2023@author: chengxf2
"""import collections 
import random 
import numpy as np
class Replay:def __init__(self, capacity):#双向队列,可以在队列的两端任意添加或删除元素。self.buffer = collections.deque(maxlen = capacity)def add(self, state, action ,reward, next_state, done):#数据加入bufferself.buffer.append((state,action,reward, next_state, done))def sample(self, batch_size):#采样数据data = random.sample(self.buffer, batch_size)state,action, reward, next_state,done = zip(*data)return np.array(state), action, reward, np.array(next_state), donedef size(self):return len(self.buffer)

 5.4 rl_utils.py

from tqdm import tqdm
import numpy as np
import torch
import collections
import randomclass ReplayBuffer:def __init__(self, capacity):self.buffer = collections.deque(maxlen=capacity) def add(self, state, action, reward, next_state, done): self.buffer.append((state, action, reward, next_state, done)) def sample(self, batch_size): transitions = random.sample(self.buffer, batch_size)state, action, reward, next_state, done = zip(*transitions)return np.array(state), action, reward, np.array(next_state), done def size(self): return len(self.buffer)def moving_average(a, window_size):cumulative_sum = np.cumsum(np.insert(a, 0, 0)) middle = (cumulative_sum[window_size:] - cumulative_sum[:-window_size]) / window_sizer = np.arange(1, window_size-1, 2)begin = np.cumsum(a[:window_size-1])[::2] / rend = (np.cumsum(a[:-window_size:-1])[::2] / r)[::-1]return np.concatenate((begin, middle, end))def train_on_policy_agent(env, agent, num_episodes):return_list = []for i in range(10):with tqdm(total=int(num_episodes/10), desc='Iteration %d' % i) as pbar:for i_episode in range(int(num_episodes/10)):episode_return = 0transition_dict = {'states': [], 'actions': [], 'next_states': [], 'rewards': [], 'dones': []}state = env.reset()done = Falsewhile not done:action = agent.take_action(state)next_state, reward, done, _ = env.step(action)transition_dict['states'].append(state)transition_dict['actions'].append(action)transition_dict['next_states'].append(next_state)transition_dict['rewards'].append(reward)transition_dict['dones'].append(done)state = next_stateepisode_return += rewardreturn_list.append(episode_return)agent.update(transition_dict)if (i_episode+1) % 10 == 0:pbar.set_postfix({'episode': '%d' % (num_episodes/10 * i + i_episode+1), 'return': '%.3f' % np.mean(return_list[-10:])})pbar.update(1)return return_listdef train_off_policy_agent(env, agent, num_episodes, replay_buffer, minimal_size, batch_size):return_list = []for i in range(10):with tqdm(total=int(num_episodes/10), desc='Iteration %d' % i) as pbar:for i_episode in range(int(num_episodes/10)):episode_return = 0state = env.reset()done = Falsewhile not done:action = agent.take_action(state)next_state, reward, done, _ = env.step(action)replay_buffer.add(state, action, reward, next_state, done)state = next_stateepisode_return += rewardif replay_buffer.size() > minimal_size:b_s, b_a, b_r, b_ns, b_d = replay_buffer.sample(batch_size)transition_dict = {'states': b_s, 'actions': b_a, 'next_states': b_ns, 'rewards': b_r, 'dones': b_d}agent.update(transition_dict)return_list.append(episode_return)if (i_episode+1) % 10 == 0:pbar.set_postfix({'episode': '%d' % (num_episodes/10 * i + i_episode+1), 'return': '%.3f' % np.mean(return_list[-10:])})pbar.update(1)return return_listdef compute_advantage(gamma, lmbda, td_delta):td_delta = td_delta.detach().numpy()advantage_list = []advantage = 0.0for delta in td_delta[::-1]:advantage = gamma * lmbda * advantage + deltaadvantage_list.append(advantage)advantage_list.reverse()return torch.tensor(advantage_list, dtype=torch.float)

DQN 算法
遇强则强(八):从Q-table到DQN - 知乎使用Pytorch实现强化学习——DQN算法_dqn pytorch-CSDN博客

https://www.cnblogs.com/xiaohuiduan/p/12993691.html

https://www.cnblogs.com/xiaohuiduan/p/12945449.html

强化学习第五节(DQN)【个人知识分享】_哔哩哔哩_bilibili

CSDN

组会讲解强化学习的DQN算法_哔哩哔哩_bilibili

3-ε-greedy_ReplayBuffer_FixedQ-targets_哔哩哔哩_bilibili

4-代码实战DQN_Agent和Env整体交互_哔哩哔哩_bilibili

DQN基本概念和算法流程(附Pytorch代码) - 知乎

CSDN

DQN 算法

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.rhkb.cn/news/207407.html

如若内容造成侵权/违法违规/事实不符,请联系长河编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

社区便利店销售微信APP的设计与实现

摘 要 社区便利店销售小程序采用的技术&#xff1a;第一是Mysql数据库&#xff1b;第二是java程序开发语言&#xff1b;第三是ssm框架&#xff1b;第四是B/S结构。系统主要分为管理员、商家、用户三部分&#xff0c;这个销售小程序的功能有首页和个人中心&#xff0c;同时还有…

计算机毕业设计|基于SpringBoot+MyBatis框架健身房管理系统的设计与实现

计算机毕业设计|基于SpringBootMyBatis框架的健身房管理系统的设计与实现 摘 要:本文基于Spring Boot和MyBatis框架&#xff0c;设计并实现了一款综合功能强大的健身房管理系统。该系统涵盖了会员卡查询、会员管理、员工管理、器材管理以及课程管理等核心功能&#xff0c;并且…

MySQL 教程 1.4

MySQL 连接 使用mysql二进制方式连接 您可以使用MySQL二进制方式进入到mysql命令提示符下来连接MySQL数据库。 实例 以下是从命令行中连接mysql服务器的简单实例&#xff1a; [roothost]# mysql -u root -p Enter password:****** 在登录成功后会出现 mysql> 命令提示窗…

Android11编译第八弹:root用户密码设置

问题&#xff1a;user版本增加su 指令以后&#xff0c;允许切换root用户&#xff0c;但是&#xff0c;root用户默认没有设置密码&#xff0c;这样访问不安全。 需要增加root用户密码。 一、Linux账户管理 1.1 文件和权限 Linux一切皆文件。文件和目录都有相应的权限&#x…

微信小程序踩坑记录

一、引言 作者在开发微信小程序《目的地到了》的过程中遇到过许多问题&#xff0c;这里讲讲一些技术和经验问题。 基本目录机构&#xff1a; 二、问题 1、定位使用 获取定位一定要在app.json里面申明&#xff0c;不然是没办法获取定位信息的 "requiredPrivateInfos"…

jetson nano SSH远程连接(使用MobaXterm)

文章目录 SSH远程连接1.SSH介绍2.准备工作3.连接步骤3.1 IP查询3.2 新建会话和连接 SSH远程连接 本节课的实现&#xff0c;需要将Jetson Nano和电脑保持在同一个局域网内&#xff0c;也就是连接同一个路 由器&#xff0c;通过SSH的方式来实现远程登陆。 1.SSH介绍 SSH是一种网…

腾讯云最新优惠券领取入口,总面值2000元代金券,新用户、老用户、企业用户均可领取!

腾讯云推出年末感恩回馈活动&#xff0c;新老用户可免费领取总面值2000元的代金券礼包&#xff0c;适用于多种预付费产品&#xff0c;最高可抵扣36个月订单&#xff0c;领取后30天内有效。 领取入口&#xff1a; https://curl.qcloud.com/UpmL4ho3 领取说明&#xff1a; 腾…

制作太阳能小车

今天偶然星期想搞一个太阳能小车耍一下子&#xff0c;那么接下来就介绍下相关的准备物品吧 首先介绍下需要准备的物品&#xff1a; 1、玩具车拆下四个轮子 2、小马达一个 3、1.5v太阳能板&#xff08;根据自己的需求购买相应的电压1.5v 3.7v 5v 12v等等&#xff09; 4、3D打…

11.28~11.29基本二叉树的性质、定义、复习;排序算法;堆

完全二叉树&#xff08;Complete Binary Tree&#xff09;是一种特殊的二叉树结构&#xff0c;它具有以下特点&#xff1a; 所有的叶子节点都集中在树的最后两层&#xff1b;最后一层的叶子节点都靠左排列&#xff1b;除了最后一层&#xff0c;其他层的节点数都达到最大值。 …

Python 进阶(十二):随机数(random 模块)

《Python入门核心技术》专栏总目录・点这里 文章目录 1. 导入random库2. 常用随机数函数2.1 生成随机浮点数2.2 生成随机整数2.3 从序列中随机选择2.4 随机打乱序列3. 设置随机数种子4. 应用实例4.1 游戏开发4.2 数据分析4.3 加密与安全4.4 模拟实验

C++核心编程——运算符重载

C核心编程——运算符重载 运算符重载的方法运算符重载函数作成员函数与友元函数重载双目运算符重载单目运算符重载流插入运算符和"<<"和流提取运算符">>"重载流插入运算符和"<<"流提取运算符">>" 运算符重载的…

finebi 新手入门案例

finebi 新手入门案例 连锁超市销售数据分析 步骤&#xff1a; 准备公共数据新建分析主题处理数据在数据中分析在图形中分析数据大屏 准备公共数据 点击公共数据 点击新建文件夹 修改文件夹名称 上传数据 鼠标悬停在文件夹上&#xff0c;右侧出现 鼠标悬停在文件夹上&#x…

Ubuntu中MySQL安装与使用

一、安装教程&#xff1a;移步 二、通过sql文件创建表格&#xff1a; 首先进入mysql&#xff1a; mysql -u 用户 -p 回车 然后输入密码source sql文件&#xff08;路径&#xff09;;上面是sql语句哈&#xff0c;所以记得加分号。 sql文件部分截图&#xff1a; 创建成功后的部…

《opencv实用探索·七》一文看懂图像卷积运算

1、图像卷积使用场景 图像卷积是图像处理中的一种常用的算法&#xff0c;它是一种基本的滤波技术&#xff0c;通过卷积核&#xff08;也称为滤波器&#xff09;对图像进行操作&#xff0c;使用场景如下&#xff1a; 模糊&#xff08;Blur&#xff09;&#xff1a; 使用加权平…

与原有视频会议系统对接

要实现与原有视频会议系统对接&#xff0c;需要确保通信协议的一致性。连通宝视频会议系统可与第三方视频会议系统对接。实现与第三方会议系统对接还可以使用会议室连接器&#xff0c;可以确保不同系统之间的数据传输和交互。 具体对接流程可能因不同品牌和类型的视频会议系统而…

蓝桥杯第四场双周赛(1~6)

1、水题 2、模拟题&#xff0c;写个函数即可 #define pb push_back #define x first #define y second #define int long long #define endl \n const LL maxn 4e057; const LL N 5e0510; const LL mod 1e097; const int inf 0x3f3f; const LL llinf 5e18;typedef pair…

BEV+Transformer架构加速“上车”,智能驾驶市场变革开启

BEVTransformer成为了高阶智能驾驶领域最为火热的技术趋势。 近日&#xff0c;在2023年广州车展期间&#xff0c;不少车企及智能驾驶厂商都发布了BEVTransformer方案。其中&#xff0c;极越01已经实现了“BEVTransformer”的“纯视觉”方案的量产&#xff0c;成为国内唯一量产…

Leetcode-二叉树oj题

1.二叉树的前序遍历 144. 二叉树的前序遍历https://leetcode.cn/problems/binary-tree-preorder-traversal/这个题目在遍历的基础上还要求返回数组&#xff0c;数组里面按前序存放二叉树节点的值。 既然要返回数组&#xff0c;就必然要malloc一块空间&#xff0c;那么我们需…

C# WPF上位机开发(第一个应用)

【 声明&#xff1a;版权所有&#xff0c;欢迎转载&#xff0c;请勿用于商业用途。 联系信箱&#xff1a;feixiaoxing 163.com】 万事开头难&#xff0c;很多事情都是难在第一步。走出了这第一步&#xff0c;回过头看以前走的每一步&#xff0c;发现其实也不难。用c# wpf编写界…

系列九、声明式事务(xml方式)

一、概述 声明式事务(declarative transaction management)是Spring提供的对程序事务管理的一种方式&#xff0c;Spring的声明式事务顾名思义就是采用声明的方式来处理事务。这里所说的声明&#xff0c;是指在配置文件中声明&#xff0c;用在Spring配置文件中声明式的处理事务来…