基于“动手学强化学习”的知识点(二):第 15 章 模仿学习(gym版本 >= 0.26)

第 15 章 模仿学习(gym版本 >= 0.26)

  • 摘要

摘要

本系列知识点讲解基于动手学强化学习中的内容进行详细的疑难点分析!具体内容请阅读动手学强化学习!


对应动手学强化学习——模仿学习


# -*- coding: utf-8 -*-import gym
import torch
import torch.nn.functional as F
import torch.nn as nn
import numpy as np
import matplotlib.pyplot as plt
from tqdm import tqdm
import random
import rl_utilsclass PolicyNet(torch.nn.Module):def __init__(self, state_dim, hidden_dim, action_dim):super(PolicyNet, self).__init__()self.fc1 = torch.nn.Linear(state_dim, hidden_dim)self.fc2 = torch.nn.Linear(hidden_dim, action_dim)def forward(self, x):x = F.relu(self.fc1(x))return F.softmax(self.fc2(x), dim=1)class ValueNet(torch.nn.Module):def __init__(self, state_dim, hidden_dim):super(ValueNet, self).__init__()self.fc1 = torch.nn.Linear(state_dim, hidden_dim)self.fc2 = torch.nn.Linear(hidden_dim, 1)def forward(self, x):x = F.relu(self.fc1(x))return self.fc2(x)class PPO:''' PPO算法,采用截断方式 '''def __init__(self, state_dim, hidden_dim, action_dim, actor_lr, critic_lr,lmbda, epochs, eps, gamma, device):self.actor = PolicyNet(state_dim, hidden_dim, action_dim).to(device)self.critic = ValueNet(state_dim, hidden_dim).to(device)self.actor_optimizer = torch.optim.Adam(self.actor.parameters(), lr=actor_lr)self.critic_optimizer = torch.optim.Adam(self.critic.parameters(), lr=critic_lr)self.gamma = gammaself.lmbda = lmbdaself.epochs = epochs  # 一条序列的数据用于训练轮数self.eps = eps  # PPO中截断范围的参数self.device = devicedef take_action(self, state):if isinstance(state, tuple):state = state[0]state = torch.tensor([state], dtype=torch.float).to(self.device)probs = self.actor(state)'''根据概率分布创建一个离散分类分布对象,用于采样离散动作。离散的概率模型。'''action_dist = torch.distributions.Categorical(probs)action = action_dist.sample()return action.item()def update(self, transition_dict):    processed_state = []for s in transition_dict['states']:if isinstance(s, tuple):# 如果元素是元组,则取元组的第一个元素processed_state.append(s[0])else:processed_state.append(s)# states = torch.tensor(transition_dict['states'], dtype=torch.float).to(self.device)states = torch.tensor(processed_state, dtype=torch.float).to(self.device)actions = torch.tensor(transition_dict['actions']).view(-1, 1).to(self.device)rewards = torch.tensor(transition_dict['rewards'],dtype=torch.float).view(-1, 1).to(self.device)next_states = torch.tensor(transition_dict['next_states'],dtype=torch.float).to(self.device)dones = torch.tensor(transition_dict['dones'],dtype=torch.float).view(-1, 1).to(self.device)'''计算 TD 目标(即回归目标):td_target=r+γ×V(s′)×(1−done)'''td_target = rewards + self.gamma * self.critic(next_states) * (1 -dones)'''计算 TD 残差(或优势估计的基础):当前状态的 TD 目标减去当前 critic 估计的状态价值。'''td_delta = td_target - self.critic(states)'''调用辅助函数(在 rl_utils 模块中定义)计算优势函数,通常使用广义优势估计(GAE)。'''advantage = rl_utils.compute_advantage(self.gamma, self.lmbda, td_delta.cpu()).to(self.device)'''先将状态输入 actor 网络得到动作概率分布(例如 shape 为 (batch_size, action_dim))。使用 .gather(1, actions) 选出每个样本所执行动作对应的概率(注意 actions 的形状必须匹配)。取对数得到旧的对数概率,再 detach() 阻断梯度流,保存旧策略下的概率值。'''old_log_probs = torch.log(self.actor(states).gather(1, actions)).detach()for _ in range(self.epochs):'''在当前策略下重新计算所有样本的对数概率,与旧对数概率进行比较。'''log_probs = torch.log(self.actor(states).gather(1, actions))'''计算概率比率,即新旧策略的概率之比,用于 PPO 的 clip 损失计算。'''ratio = torch.exp(log_probs - old_log_probs)'''计算无截断的策略目标,乘上优势值。'''surr1 = ratio * advantage'''对 ratio 进行截断,确保其在 [1−ϵ,1+ϵ] 范围内(例如 [0.8, 1.2]),然后乘以优势。'''surr2 = torch.clamp(ratio, 1 - self.eps, 1 + self.eps) * advantage  # 截断'''PPO 算法的目标是最大化最小值,因此这里取两者中的较小值再取负号作为损失。对整个 batch 求均值。'''actor_loss = torch.mean(-torch.min(surr1, surr2))  # PPO损失函数'''计算 critic 的均方误差(MSE)损失:当前 critic 估计与 TD 目标之间的误差,对整个 batch 取平均。'''critic_loss = torch.mean(F.mse_loss(self.critic(states), td_target.detach()))self.actor_optimizer.zero_grad()self.critic_optimizer.zero_grad()actor_loss.backward()critic_loss.backward()self.actor_optimizer.step()self.critic_optimizer.step()actor_lr = 1e-3
critic_lr = 1e-2
num_episodes = 250
hidden_dim = 128
gamma = 0.98
lmbda = 0.95
epochs = 10
eps = 0.2
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")env_name = 'CartPole-v0'
env = gym.make(env_name)
if not hasattr(env, 'seed'):def seed_fn(self, seed=None):env.reset(seed=seed)return [seed]env.seed = seed_fn.__get__(env, type(env))
# env.seed(0)
torch.manual_seed(0)
state_dim = env.observation_space.shape[0]
action_dim = env.action_space.n
ppo_agent = PPO(state_dim, hidden_dim, action_dim, actor_lr, critic_lr, lmbda, epochs, eps, gamma, device)return_list = rl_utils.train_on_policy_agent(env, ppo_agent, num_episodes)def sample_expert_data(n_episode):states = []actions = []for episode in range(n_episode):state = env.reset()done = Falsewhile not done:action = ppo_agent.take_action(state)states.append(state)actions.append(action)result = env.step(action)if len(result) == 5:next_state, reward, done, truncated, info = resultdone = done or truncated  # 可合并 terminated 和 truncated 标志else:next_state, reward, done, info = result# next_state, reward, done, _ = env.step(action)state = next_stateprocessed_states = []for s in states:if isinstance(s, tuple):# 如果元素是元组,则取元组的第一个元素processed_states.append(s[0])else:processed_states.append(s)return np.array(processed_states), np.array(actions)if not hasattr(env, 'seed'):def seed_fn(self, seed=None):env.reset(seed=seed)return [seed]env.seed = seed_fn.__get__(env, type(env))
# env.seed(0)
torch.manual_seed(0)
random.seed(0)
n_episode = 1
expert_s, expert_a = sample_expert_data(n_episode)n_samples = 30  # 采样30个数据
random_index = random.sample(range(expert_s.shape[0]), n_samples)
expert_s = expert_s[random_index]
expert_a = expert_a[random_index]class BehaviorClone:def __init__(self, state_dim, hidden_dim, action_dim, lr):self.policy = PolicyNet(state_dim, hidden_dim, action_dim).to(device)self.optimizer = torch.optim.Adam(self.policy.parameters(), lr=lr)def learn(self, states, actions):"""解释:定义一个学习函数,接收一批专家数据中的状态和动作,用于更新策略网络。"""states = torch.tensor(states, dtype=torch.float).to(device)actions = torch.tensor(actions).view(-1, 1).to(device)'''- 将 states 输入 policy 网络,得到每个状态下所有动作的概率分布,假设输出形状为 (batch_size, action_dim);- 使用 .gather(1, actions.long()) 从概率分布中取出对应专家动作的概率(注意动作需要转换为长整型索引);- 对这些概率取对数,得到对数概率(log likelihood)。'''log_probs = torch.log(self.policy(states).gather(1, actions.long()))# log_probs = torch.log(self.policy(states).gather(1, actions))'''计算行为克隆的损失,即负对数似然损失。对所有样本的负对数概率取均值。'''bc_loss = torch.mean(-log_probs)  # 最大似然估计self.optimizer.zero_grad()bc_loss.backward()self.optimizer.step()def take_action(self, state):if isinstance(state, tuple):state = state[0]state = torch.tensor([state], dtype=torch.float).to(device)probs = self.policy(state)action_dist = torch.distributions.Categorical(probs)action = action_dist.sample()return action.item()def test_agent(agent, env, n_episode):return_list = []for episode in range(n_episode):episode_return = 0state = env.reset()done = Falsewhile not done:action = agent.take_action(state)result = env.step(action)if len(result) == 5:next_state, reward, done, truncated, info = resultdone = done or truncated  # 可合并 terminated 和 truncated 标志else:next_state, reward, done, info = result# next_state, reward, done, _ = env.step(action)state = next_stateepisode_return += rewardreturn_list.append(episode_return)return np.mean(return_list)if not hasattr(env, 'seed'):def seed_fn(self, seed=None):env.reset(seed=seed)return [seed]env.seed = seed_fn.__get__(env, type(env))
# env.seed(0)
torch.manual_seed(0)
np.random.seed(0)lr = 1e-3
bc_agent = BehaviorClone(state_dim, hidden_dim, action_dim, lr)
n_iterations = 1000
batch_size = 64
test_returns = []with tqdm(total=n_iterations, desc="进度条") as pbar:for i in range(n_iterations):sample_indices = np.random.randint(low=0, high=expert_s.shape[0], size=batch_size)bc_agent.learn(expert_s[sample_indices], expert_a[sample_indices])current_return = test_agent(bc_agent, env, 5)test_returns.append(current_return)if (i + 1) % 10 == 0:pbar.set_postfix({'return': '%.3f' % np.mean(test_returns[-10:])})pbar.update(1)iteration_list = list(range(len(test_returns)))
plt.plot(iteration_list, test_returns)
plt.xlabel('Iterations')
plt.ylabel('Returns')
plt.title('BC on {}'.format(env_name))
plt.show()class Discriminator(nn.Module):def __init__(self, state_dim, hidden_dim, action_dim):super(Discriminator, self).__init__()self.fc1 = torch.nn.Linear(state_dim + action_dim, hidden_dim)self.fc2 = torch.nn.Linear(hidden_dim, 1)def forward(self, x, a):cat = torch.cat([x, a], dim=1)x = F.relu(self.fc1(cat))return torch.sigmoid(self.fc2(x))class GAIL:def __init__(self, agent, state_dim, action_dim, hidden_dim, lr_d):print(state_dim, action_dim, hidden_dim)self.discriminator = Discriminator(state_dim, hidden_dim, action_dim).to(device)self.discriminator_optimizer = torch.optim.Adam(self.discriminator.parameters(), lr=lr_d)self.agent = agentdef learn(self, expert_s, expert_a, agent_s, agent_a, next_s, dones):expert_states = torch.tensor(expert_s, dtype=torch.float).to(device)expert_actions = torch.tensor(expert_a).to(device)processed_state = []for s in agent_s:if isinstance(s, tuple):# 如果元素是元组,则取元组的第一个元素processed_state.append(s[0])else:processed_state.append(s)agent_states = torch.tensor(processed_state, dtype=torch.float).to(device)agent_actions = torch.tensor(agent_a).to(device)'''作用:将专家动作转换为 one-hot 编码形式,转换为浮点数。'''expert_actions = F.one_hot(expert_actions.long(), num_classes=2).float()agent_actions = F.one_hot(agent_actions.long(), num_classes=2).float()expert_prob = self.discriminator(expert_states, expert_actions)agent_prob = self.discriminator(agent_states, agent_actions)'''作用:计算二元交叉熵损失(BCE):- 对 agent 数据,目标标签设为 1(即希望判别器认为 agent 数据为“真”),损失为 BCE(agent_prob, 1);- 对专家数据,目标标签设为 0(希望判别器认为专家数据为“假”),损失为 BCE(expert_prob, 0)。- 然后将两部分损失相加。'''discriminator_loss = nn.BCELoss()(agent_prob, torch.ones_like(agent_prob)) + nn.BCELoss()(expert_prob, torch.zeros_like(expert_prob))self.discriminator_optimizer.zero_grad()discriminator_loss.backward()self.discriminator_optimizer.step()'''作用:利用判别器对 agent 数据输出计算奖励:- 计算 –log(agent_prob) 作为奖励信号(当 agent_prob 较小时,奖励较高,鼓励 agent 模仿专家);- detach() 阻断梯度,转移到 CPU 并转换为 numpy 数组,方便后续传入 agent.update。'''rewards = -torch.log(agent_prob).detach().cpu().numpy()transition_dict = {'states': agent_s,'actions': agent_a,'rewards': rewards,'next_states': next_s,'dones': dones}self.agent.update(transition_dict)if not hasattr(env, 'seed'):def seed_fn(self, seed=None):env.reset(seed=seed)return [seed]env.seed = seed_fn.__get__(env, type(env))
# env.seed(0)
torch.manual_seed(0)
lr_d = 1e-3
agent = PPO(state_dim, hidden_dim, action_dim, actor_lr, critic_lr, lmbda, epochs, eps, gamma, device)
gail = GAIL(agent, state_dim, action_dim, hidden_dim, lr_d)
n_episode = 500
return_list = []with tqdm(total=n_episode, desc="进度条") as pbar:for i in range(n_episode):episode_return = 0state = env.reset()done = Falsestate_list = []action_list = []next_state_list = []done_list = []while not done:action = agent.take_action(state)result = env.step(action)if len(result) == 5:next_state, reward, done, truncated, info = resultdone = done or truncated  # 可合并 terminated 和 truncated 标志else:next_state, reward, done, info = result# next_state, reward, done, _ = env.step(action)state_list.append(state)action_list.append(action)next_state_list.append(next_state)done_list.append(done)state = next_stateepisode_return += rewardreturn_list.append(episode_return)gail.learn(expert_s, expert_a, state_list, action_list, next_state_list, done_list)if (i + 1) % 10 == 0:pbar.set_postfix({'return': '%.3f' % np.mean(return_list[-10:])})pbar.update(1)    iteration_list = list(range(len(return_list)))
plt.plot(iteration_list, return_list)
plt.xlabel('Episodes')
plt.ylabel('Returns')
plt.title('GAIL on {}'.format(env_name))
plt.show()

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.rhkb.cn/news/33718.html

如若内容造成侵权/违法违规/事实不符,请联系长河编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

OpenHarmony自定义子系统、部件与模块

如图所示,OpenHarmony系统源码中,大体上按照不同种类的功能分成多个子系统,然后一个子系统内部进一步在同类功能上的差异性划分成一个或多个部件,也就是说一个部件表示一个具体功能的源码集合。最后一个部件的源码再划分成一个或多…

【论文笔记】Contrastive Learning for Compact Single Image Dehazing(AECR-Net)

文章目录 问题创新网络主要贡献Autoencoder-like Dehazing NetworkAdaptive Mixup for Feature PreservingDynamic Feature Enhancement1. 可变形卷积的使用2. 扩展感受野3. 减少网格伪影4. 融合空间结构信息 Contrastive Regularization1. 核心思想2. 正样本对和负样本对的构建…

uni-app打包h5并部署到nginx,路由模式history

uni-app打包有些坑,当时运行的基础路径填写了./,导致在二级页面刷新之后,页面直接空白。就只能换一个路径了,nginx也要跟着改,下面是具体步骤。 manifest.json配置web 运行路径写/h5/,或者写你们网站的目…

SQLiteStudio:一款免费开源跨平台的SQLite管理工具

目录 1.简介 2.下载与安装 3.实现分析 4.总结 1.简介 SQLiteStudio 是一款专门用于管理 SQLite 数据库的图形化工具,由波兰开发者开发并维护。由于 SQLite 以其轻量级、零配置、嵌入式等特性被广泛应用于各种小型项目、移动应用和桌面应用中,而 SQLi…

Java入职篇(2)——开发流程以及专业术语

Java入职篇(2)——开发流程以及专业术语 开发流程 开发术语 测试用例(用例) 测试人员写的测试方案,基本上就是编写的测试过程,以及测试的预取结果 灰度测试 现在小部分范围内使用,然后逐步…

Figma介绍(基于云的协作式界面设计工具,主要用于UI/UX设计、原型制作和团队协作)

文章目录 注册和登录简单操作说明Figma介绍**核心特点**1. **云端协作与实时同步**2. **跨平台兼容**3. **高效设计工具**4. **原型交互与动效**5. **开发对接友好**6. **插件生态**7. **版本控制与历史记录** **适用场景**- **团队协作**:远程团队共同设计、评审、…

RAW图与BAYER图异同

RAW图是一种未经处理、未压缩的图像文件格式,它记录了图像传感器捕捉到的原始数据,包含了拍摄时的大量图像信息。下面从多个方面详细介绍RAW图: 参考:B站大清光学 定义与基本概念 定义:RAW文件是图像传感器将捕捉到…

mac安装navicat及使用

0.删除旧的 sudo rm -Rf /Applications/Navicat\ Premium.app sudo rm -Rf /private/var/db/BootCaches/CB6F12B3-2C14-461E-B5A7-A8621B7FF130/app.com.prect.NavicatPremium.playlist sudo rm -Rf ~/Library/Caches/com.apple.helpd/SDMHelpData/Other/English/HelpSDMIndexF…

Windows11【1001问】打开Windows 11控制面板的14种方法

在Windows 11中,尽管微软逐渐转向现代的“设置”应用,但传统的“控制面板”仍然是许多用户管理系统、调整硬件设置和自定义功能的首选工具。然而,由于Windows 11的界面设计更注重简洁性,控制面板的访问方式可能对部分用户来说不够…

Language Models are Few-Shot Learners,GPT-3详细讲解

GPT的训练范式:预训练Fine-Tuning GPT2的训练范式:预训练Prompt predict (zero-shot learning) GPT3的训练范式:预训练Prompt predict (few-shot learning) GPT2的性能太差,新意高&…

数据结构--图的基本操作

知识总览: 一、图的基本操作 1.Adjacent(G,x,y),判断图G是否有边---对于有向图和无向图来说,邻间接矩阵的时复杂度更低。 邻接矩阵时间复杂度 O(1) 邻接表时间复杂度 O(1)~~O(v) 2.Neighbors(G,x):判断图G与结点x邻接的边.---邻间接矩…

Unity中解锁图片像素点,动态闭合轨迹检测

Unity中解锁图片像素点,动态闭合轨迹检测 介绍资源下载搭建总结 介绍 因为最近在研究Mane天蚕变的游戏完整逻辑,研究了两套方案做解锁图片的功能,这里我先讲一下我的这个图片像素点的方案解锁图片,这个逻辑其实很简单就是利用划线…

buu-ciscn_2019_ne_5-好久不见50

1. 背景分析 目标程序是一个存在漏洞的二进制文件,我们可以通过以下方式利用漏洞获取 shell: 程序中存在 system() 函数,但没有明显的 /bin/sh 字符串。 使用工具(如 ROPgadget)发现程序中有 sh 字符串,可…

图论part4|827. 最大人工岛、127. 单词接龙、463. 岛屿的周长

827. 最大人工岛 🔗:827. 最大人工岛 - 力扣(LeetCode)827. 最大人工岛 - 给你一个大小为 n x n 二进制矩阵 grid 。最多 只能将一格 0 变成 1 。返回执行此操作后,grid 中最大的岛屿面积是多少?岛屿 由一…

SpeechCraf论文学习

Abstract 核心问题 挑战 语音风格包含细微的多样化信息(如情感、语调、节奏),传统基于标签/模板的标注方法难以充分捕捉,制约了语音-语言多模态模型的性能。 数据瓶颈: 大规模数据收集与高质量标注之间存在矛盾&…

SAIL-RK3576核心板应用方案——无人机视觉定位与地面无人设备通信控制方案

本方案以 EFISH-RK3576-SBC工控板 或 SAIL-RK3576核心板 为核心,结合高精度视觉定位、实时通信与智能控制技术,实现无人机与地面无人设备的协同作业。方案适用于物流巡检、农业植保、应急救援等场景,具备高精度定位、低延迟通信与强环境适应性…

PostgreSQL的学习心得和知识总结(一百七十一)|深入理解PostgreSQL数据库之 外连接消除 的使用和实现

目录结构 注:提前言明 本文借鉴了以下博主、书籍或网站的内容,其列表如下: 1、参考书籍:《PostgreSQL数据库内核分析》 2、参考书籍:《数据库事务处理的艺术:事务管理与并发控制》 3、PostgreSQL数据库仓库…

C语言实现括号匹配检查及栈的应用详解

目录 栈数据结构简介 C语言实现栈 栈的初始化 栈的销毁 栈的插入 栈的删除 栈的判空 获取栈顶数据 利用栈实现括号匹配检查 总结 在编程中,经常会遇到需要检查括号是否匹配的问题,比如在编译器中检查代码的语法正确性,或者在…

【机器学习chp12】半监督学习(自我训练+协同训练多视角学习+生成模型+半监督SVM+基于图的半监督算法+半监督聚类)

目录 一、半监督学习简介 1、半监督学习的定义和基本思想 2、归纳学习 和 直推学习 (1)归纳学习 (2)直推学习 3、半监督学习的作用与优势 4、半监督学习的关键假设 5、半监督学习的应用 6、半监督学习的常见方法 7、半…

2024 年第四届高校大数据挑战赛-赛题 A:岩石的自动鉴定

↓↓↓↓↓↓↓↓↓↓↓↓↓↓↓↓↓↓↓↓↓↓↓↓↓↓↓↓↓↓↓↓↓↓↓↓↓↓↓↓↓↓↓↓↓↓↓↓↓↓↓↓↓↓↓↓↓↓↓↓↓↓↓↓↓↓↓↓↓↓↓↓↓↓↓↓↓↓↓↓