第 15 章 模仿学习(gym版本 >= 0.26)
- 摘要
摘要
本系列知识点讲解基于《动手学强化学习》中的内容,对其中的疑难点进行详细分析。具体内容请阅读《动手学强化学习》。
对应《动手学强化学习》中的"模仿学习"一章。
# -*- coding: utf-8 -*-
import random

import gym
import matplotlib.pyplot as plt
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from tqdm import tqdm
import rl_utils


class PolicyNet(torch.nn.Module):
    """Policy network: one hidden ReLU layer followed by a softmax over actions."""

    def __init__(self, state_dim, hidden_dim, action_dim):
        super(PolicyNet, self).__init__()
        self.fc1 = torch.nn.Linear(state_dim, hidden_dim)
        self.fc2 = torch.nn.Linear(hidden_dim, action_dim)

    def forward(self, x):
        """Return action probabilities; the softmax is taken over dim 1."""
        x = F.relu(self.fc1(x))
        return F.softmax(self.fc2(x), dim=1)


class ValueNet(torch.nn.Module):
    """State-value network: one hidden ReLU layer and a single linear output."""

    def __init__(self, state_dim, hidden_dim):
        super(ValueNet, self).__init__()
        self.fc1 = torch.nn.Linear(state_dim, hidden_dim)
        self.fc2 = torch.nn.Linear(hidden_dim, 1)

    def forward(self, x):
        """Return one value estimate per row of ``x``."""
        x = F.relu(self.fc1(x))
        return self.fc2(x)


class PPO:
    """PPO agent using the clipped surrogate objective."""

    def __init__(self, state_dim, hidden_dim, action_dim, actor_lr, critic_lr,
                 lmbda, epochs, eps, gamma, device):
        self.actor = PolicyNet(state_dim, hidden_dim, action_dim).to(device)
        self.critic = ValueNet(state_dim, hidden_dim).to(device)
        self.actor_optimizer = torch.optim.Adam(self.actor.parameters(), lr=actor_lr)
        self.critic_optimizer = torch.optim.Adam(self.critic.parameters(), lr=critic_lr)
        self.gamma = gamma
        self.lmbda = lmbda
        self.epochs = epochs  # number of optimisation passes over one trajectory
        self.eps = eps  # clipping range parameter of PPO
        self.device = device

    @staticmethod
    def _strip_info(state):
        """Return ``state[0]`` when ``state`` is a tuple, else ``state`` unchanged.

        ``env.reset()`` results are passed in as-is by the callers in this
        file, so a tuple here is treated as ``(observation, ...)``.
        """
        return state[0] if isinstance(state, tuple) else state

    def take_action(self, state):
        """Sample one action for a single state and return it as a Python int."""
        state = self._strip_info(state)
        # Go through one ndarray first: torch.tensor on a list of ndarrays is
        # slow and emits a warning.
        state = torch.tensor(np.array([state]), dtype=torch.float).to(self.device)
        probs = self.actor(state)
        # Categorical distribution over the discrete actions, used for sampling.
        action_dist = torch.distributions.Categorical(probs)
        action = action_dist.sample()
        return action.item()

    def update(self, transition_dict):
        """Run ``self.epochs`` PPO-clip updates on one batch of transitions.

        ``transition_dict`` must provide the keys ``'states'``, ``'actions'``,
        ``'rewards'``, ``'next_states'`` and ``'dones'``. Entries of
        ``'states'`` may be tuples, in which case only their first element is
        used.
        """
        processed_state = [self._strip_info(s) for s in transition_dict['states']]
        states = torch.tensor(np.array(processed_state), dtype=torch.float).to(self.device)
        actions = torch.tensor(transition_dict['actions']).view(-1, 1).to(self.device)
        rewards = torch.tensor(transition_dict['rewards'],
                               dtype=torch.float).view(-1, 1).to(self.device)
        next_states = torch.tensor(np.array(transition_dict['next_states']),
                                   dtype=torch.float).to(self.device)
        dones = torch.tensor(transition_dict['dones'],
                             dtype=torch.float).view(-1, 1).to(self.device)

        # TD target: r + gamma * V(s') * (1 - done).
        td_target = rewards + self.gamma * self.critic(next_states) * (1 - dones)
        # TD residual: TD target minus the critic's current estimate of V(s).
        td_delta = td_target - self.critic(states)
        # Advantage estimate from the rl_utils helper (presumably GAE; the
        # helper is defined outside this file -- TODO confirm).
        advantage = rl_utils.compute_advantage(self.gamma, self.lmbda,
                                               td_delta.cpu()).to(self.device)
        # Log-probabilities of the taken actions under the pre-update policy;
        # detached so they act as constants in the ratio below.
        old_log_probs = torch.log(self.actor(states).gather(1, actions)).detach()
        # The regression target is a constant for the critic loss.
        critic_target = td_target.detach()

        for _ in range(self.epochs):
            # Log-probabilities of the same actions under the current policy.
            log_probs = torch.log(self.actor(states).gather(1, actions))
            # Probability ratio between the current and the old policy.
            ratio = torch.exp(log_probs - old_log_probs)
            surr1 = ratio * advantage  # unclipped objective
            # Objective with the ratio clipped to [1 - eps, 1 + eps].
            surr2 = torch.clamp(ratio, 1 - self.eps, 1 + self.eps) * advantage
            # PPO maximises the minimum of the two, so the loss is its negation,
            # averaged over the batch.
            actor_loss = torch.mean(-torch.min(surr1, surr2))
            # F.mse_loss already averages over the batch.
            critic_loss = F.mse_loss(self.critic(states), critic_target)
            self.actor_optimizer.zero_grad()
            self.critic_optimizer.zero_grad()
            actor_loss.backward()
            critic_loss.backward()
            self.actor_optimizer.step()
            self.critic_optimizer.step()


actor_lr = 1e-3
# Hyper-parameters for the PPO expert.
critic_lr = 1e-2
num_episodes = 250
hidden_dim = 128
gamma = 0.98
lmbda = 0.95
epochs = 10
eps = 0.2
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

env_name = 'CartPole-v0'
env = gym.make(env_name)

# Compatibility shim: if the environment has no ``seed`` attribute, attach one
# that forwards the seed to ``reset``.
if not hasattr(env, 'seed'):
    def seed_fn(self, seed=None):
        self.reset(seed=seed)
        return [seed]

    env.seed = seed_fn.__get__(env, type(env))
# NOTE(review): the ``env.seed(0)`` call is disabled in this script, so only
# PyTorch is seeded here and the environment itself is not.
torch.manual_seed(0)
state_dim = env.observation_space.shape[0]
action_dim = env.action_space.n
# Train the PPO agent that serves as the expert.
ppo_agent = PPO(state_dim, hidden_dim, action_dim, actor_lr, critic_lr, lmbda,
                epochs, eps, gamma, device)
return_list = rl_utils.train_on_policy_agent(env, ppo_agent, num_episodes)


def sample_expert_data(n_episode):
    """Roll out ``ppo_agent`` in ``env`` for ``n_episode`` episodes.

    Returns a pair ``(states, actions)`` of numpy arrays holding every visited
    state and the action taken in it, in visiting order.
    """
    states = []
    actions = []
    for _ in range(n_episode):
        state = env.reset()
        # ``reset`` may return a tuple; keep only its first element so that
        # every stored state has the same form.
        if isinstance(state, tuple):
            state = state[0]
        done = False
        while not done:
            action = ppo_agent.take_action(state)
            states.append(state)
            actions.append(action)
            result = env.step(action)
            if len(result) == 5:
                next_state, _, terminated, truncated, _ = result
                # Either flag ends the episode.
                done = terminated or truncated
            else:
                next_state, _, done, _ = result
            state = next_state
    return np.array(states), np.array(actions)


# Compatibility shim: if the environment has no ``seed`` attribute, attach one
# that forwards the seed to ``reset``.
if not hasattr(env, 'seed'):
    def seed_fn(self, seed=None):
        self.reset(seed=seed)
        return [seed]

    env.seed = seed_fn.__get__(env, type(env))
# NOTE(review): the ``env.seed(0)`` call is disabled in this script, so the
# environment itself is not seeded.
torch.manual_seed(0)
random.seed(0)
n_episode = 1
expert_s, expert_a = sample_expert_data(n_episode)

n_samples = 30  # keep 30 expert samples
# random.sample raises ValueError if fewer than n_samples states were collected.
random_index = random.sample(range(expert_s.shape[0]), n_samples)
expert_s = expert_s[random_index]
expert_a = expert_a[random_index]


class BehaviorClone:
    """Behaviour cloning: fit a ``PolicyNet`` to expert (state, action) pairs."""

    def __init__(self, state_dim, hidden_dim, action_dim, lr):
        self.policy = PolicyNet(state_dim, hidden_dim, action_dim).to(device)
        self.optimizer = torch.optim.Adam(self.policy.parameters(), lr=lr)

    def learn(self, states, actions):
        """Take one gradient step on a batch of expert states and actions."""
        states = torch.tensor(states, dtype=torch.float).to(device)
        # gather() needs int64 indices, one per row.
        actions = torch.tensor(actions, dtype=torch.int64).view(-1, 1).to(device)
        # Log-probability that the policy assigns to each expert action.
        log_probs = torch.log(self.policy(states).gather(1, actions))
        # Negative log-likelihood averaged over the batch (maximum likelihood).
        bc_loss = torch.mean(-log_probs)
        self.optimizer.zero_grad()
        bc_loss.backward()
        self.optimizer.step()

    def take_action(self, state):
        """Sample one action for a single state and return it as a Python int."""
        if isinstance(state, tuple):
            state = state[0]
        # Go through one ndarray first: torch.tensor on a list of ndarrays is
        # slow and emits a warning.
        state = torch.tensor(np.array([state]), dtype=torch.float).to(device)
        probs = self.policy(state)
        action_dist = torch.distributions.Categorical(probs)
        action = action_dist.sample()
        return action.item()


def test_agent(agent, env, n_episode):
    """Run ``agent`` in ``env`` for ``n_episode`` episodes; return the mean return."""
    return_list = []
    for _ in range(n_episode):
        episode_return = 0
        state = env.reset()
        # ``reset`` may return a tuple; keep only its first element.
        if isinstance(state, tuple):
            state = state[0]
        done = False
        while not done:
            action = agent.take_action(state)
            result = env.step(action)
            if len(result) == 5:
                next_state, reward, terminated, truncated, _ = result
                # Either flag ends the episode.
                done = terminated or truncated
            else:
                next_state, reward, done, _ = result
            state = next_state
            episode_return += reward
        return_list.append(episode_return)
    return np.mean(return_list)


# Compatibility shim: if the environment has no ``seed`` attribute, attach one
# that forwards the seed to ``reset``.
if not hasattr(env, 'seed'):
    def seed_fn(self, seed=None):
        self.reset(seed=seed)
        return [seed]

    env.seed = seed_fn.__get__(env, type(env))
# NOTE(review): the ``env.seed(0)`` call is disabled in this script, so the
# environment itself is not seeded.
torch.manual_seed(0)
np.random.seed(0)

lr = 1e-3
bc_agent = BehaviorClone(state_dim, hidden_dim, action_dim, lr)
n_iterations = 1000
batch_size = 64
test_returns = []

with tqdm(total=n_iterations, desc="进度条") as pbar:
    for i in range(n_iterations):
        # Draw a batch of expert samples with replacement.
        sample_indices = np.random.randint(low=0,
                                           high=expert_s.shape[0],
                                           size=batch_size)
        bc_agent.learn(expert_s[sample_indices], expert_a[sample_indices])
        # Evaluate after every update: mean return over 5 episodes.
        current_return = test_agent(bc_agent, env, 5)
        test_returns.append(current_return)
        if (i + 1) % 10 == 0:
            pbar.set_postfix({'return': '%.3f' % np.mean(test_returns[-10:])})
        pbar.update(1)

iteration_list = list(range(len(test_returns)))
plt.plot(iteration_list, test_returns)
plt.xlabel('Iterations')
plt.ylabel('Returns')
plt.title('BC on {}'.format(env_name))
plt.show()


class Discriminator(nn.Module):
    """Binary classifier over concatenated (state, one-hot action) pairs."""

    def __init__(self, state_dim, hidden_dim, action_dim):
        super(Discriminator, self).__init__()
        self.fc1 = torch.nn.Linear(state_dim + action_dim, hidden_dim)
        self.fc2 = torch.nn.Linear(hidden_dim, 1)

    def forward(self, x, a):
        """Return a sigmoid output in (0, 1) for each (state, action) row."""
        cat = torch.cat([x, a], dim=1)
        x = F.relu(self.fc1(cat))
        return torch.sigmoid(self.fc2(x))


class GAIL:
    """Generative adversarial imitation learning around an existing agent."""

    def __init__(self, agent, state_dim, action_dim, hidden_dim, lr_d):
        print(state_dim, action_dim, hidden_dim)
        self.discriminator = Discriminator(state_dim, hidden_dim, action_dim).to(device)
        self.discriminator_optimizer = torch.optim.Adam(
            self.discriminator.parameters(), lr=lr_d)
        self.agent = agent
        # Kept so that the one-hot width in learn() follows the action space
        # instead of being hard-coded.
        self.action_dim = int(action_dim)

    def learn(self, expert_s, expert_a, agent_s, agent_a, next_s, dones):
        """Update the discriminator once, then update the agent.

        The agent is updated through ``self.agent.update`` with the rewards
        ``-log D(s, a)`` computed from the discriminator output on the agent's
        own samples. Entries of ``agent_s`` may be tuples, in which case only
        their first element is fed to the discriminator.
        """
        expert_states = torch.tensor(expert_s, dtype=torch.float).to(device)
        expert_actions = torch.tensor(expert_a).to(device)
        processed_state = [s[0] if isinstance(s, tuple) else s for s in agent_s]
        # Go through one ndarray first: torch.tensor on a list of ndarrays is
        # slow and emits a warning.
        agent_states = torch.tensor(np.array(processed_state),
                                    dtype=torch.float).to(device)
        agent_actions = torch.tensor(agent_a).to(device)

        # One-hot encode the discrete actions as floats.
        expert_actions = F.one_hot(expert_actions.long(),
                                   num_classes=self.action_dim).float()
        agent_actions = F.one_hot(agent_actions.long(),
                                  num_classes=self.action_dim).float()
        expert_prob = self.discriminator(expert_states, expert_actions)
        agent_prob = self.discriminator(agent_states, agent_actions)

        # Binary cross-entropy with label 1 for agent samples and label 0 for
        # expert samples; the two terms are summed.
        bce = nn.BCELoss()
        discriminator_loss = (bce(agent_prob, torch.ones_like(agent_prob))
                              + bce(expert_prob, torch.zeros_like(expert_prob)))
        self.discriminator_optimizer.zero_grad()
        discriminator_loss.backward()
        self.discriminator_optimizer.step()

        # Reward is -log D(s, a): the smaller the discriminator output on an
        # agent sample (i.e. the closer to the expert label), the larger the
        # reward. Detached and moved to numpy before being handed to the agent.
        rewards = -torch.log(agent_prob).detach().cpu().numpy()
        transition_dict = {
            'states': agent_s,
            'actions': agent_a,
            'rewards': rewards,
            'next_states': next_s,
            'dones': dones,
        }
        self.agent.update(transition_dict)


# Compatibility shim: if the environment has no ``seed`` attribute, attach one
# that forwards the seed to ``reset``.
if not hasattr(env, 'seed'):
    def seed_fn(self, seed=None):
        self.reset(seed=seed)
        return [seed]

    env.seed = seed_fn.__get__(env, type(env))
# NOTE(review): the ``env.seed(0)`` call is disabled in this script, so the
# environment itself is not seeded.
torch.manual_seed(0)
lr_d = 1e-3
agent = PPO(state_dim, hidden_dim, action_dim, actor_lr, critic_lr, lmbda,
            epochs, eps, gamma, device)
gail = GAIL(agent, state_dim, action_dim, hidden_dim, lr_d)
n_episode = 500
return_list = []

with tqdm(total=n_episode, desc="进度条") as pbar:
    for i in range(n_episode):
        episode_return = 0
        state = env.reset()
        done = False
        state_list = []
        action_list = []
        next_state_list = []
        done_list = []
        while not done:
            action = agent.take_action(state)
            result = env.step(action)
            if len(result) == 5:
                next_state, reward, terminated, truncated, info = result
                # Either flag ends the episode.
                done = terminated or truncated
            else:
                next_state, reward, done, info = result
            state_list.append(state)
            action_list.append(action)
            next_state_list.append(next_state)
            done_list.append(done)
            state = next_state
            # The environment reward is only logged; the agent is trained on
            # the discriminator-based reward inside gail.learn.
            episode_return += reward
        return_list.append(episode_return)
        gail.learn(expert_s, expert_a, state_list, action_list,
                   next_state_list, done_list)
        if (i + 1) % 10 == 0:
            pbar.set_postfix({'return': '%.3f' % np.mean(return_list[-10:])})
        pbar.update(1)

iteration_list = list(range(len(return_list)))
plt.plot(iteration_list, return_list)
plt.xlabel('Episodes')
plt.ylabel('Returns')
plt.title('GAIL on {}'.format(env_name))
plt.show()