参考王树森《深度强化学习》课程和书籍
1、A2C原理:
Observe a transition: $(s_t, a_t, r_t, s_{t+1})$
TD target:
$y_t = r_t + \gamma \cdot v(s_{t+1}; \mathbf{w}).$（下面的代码中，$v(s_{t+1};\mathbf{w})$ 由目标网络 `V_target` 计算；当回合在 $s_{t+1}$ 处结束时，该项乘以 0。）
TD error:
$\delta_t = v(s_t; \mathbf{w}) - y_t.$
Update the policy network (actor) by:
$\boldsymbol{\theta} \leftarrow \boldsymbol{\theta} - \beta \cdot \delta_t \cdot \dfrac{\partial \ln \pi(a_t \mid s_t; \boldsymbol{\theta})}{\partial \boldsymbol{\theta}}.$（由于 $-\delta_t = y_t - v(s_t;\mathbf{w})$ 是优势（advantage）的估计，代码中的 `advantage = target_value - V(bs)` 对应的就是 $-\delta_t$。）
def compute_value_loss(self, bs, blogp_a, br, bd, bns):
    """Return the critic's TD regression loss for one batch of transitions.

    Args:
        bs: batch of states, fed to ``self.V``.
        blogp_a: unused here; kept so both loss functions share one signature.
        br: rewards, one per transition.
        bd: done flags; where an entry is truthy the bootstrap term is dropped.
        bns: batch of next states, fed to ``self.V_target``.

    Returns:
        Scalar tensor ``mean((V(s) - y)^2)`` where
        ``y = r + discount * (not d) * V_target(s')``.
    """
    # Target value: a fixed regression label, so no gradient reaches V_target.
    with torch.no_grad():
        # squeeze(-1) rather than squeeze(): keep the batch dimension even for a
        # one-transition batch so input/target shapes always match.
        # (assumes V / V_target return shape (batch, 1) — TODO confirm for other nets)
        next_value = self.V_target(bns).squeeze(-1)
        target_value = br + self.args.discount * torch.logical_not(bd) * next_value
    # Value loss.
    value_loss = F.mse_loss(self.V(bs).squeeze(-1), target_value)
    return value_loss
Update the value network (critic) by:
$\mathbf{w} \leftarrow \mathbf{w} - \alpha \cdot \delta_t \cdot \dfrac{\partial v(s_t; \mathbf{w})}{\partial \mathbf{w}}.$
def compute_policy_loss(self, bs, blogp_a, br, bd, bns):
    """Return the actor loss ``-sum_t log pi(a_t | s_t) * V(s_t)`` for one episode.

    Unlike the A2C version, the log-probabilities here are weighted by the
    critic's value V(s_t) itself rather than by an advantage / TD error.
    Suggested exercise: compare with 08_a2c.py to see how the two differ.

    Args:
        bs: batch of states, fed to ``self.V``.
        blogp_a: sequence of log-probability tensors (one per step) that still
            carry their autograd graph back to the policy network.
        br, bd, bns: unused here; kept so both loss functions share one signature.

    Returns:
        Scalar tensor whose gradient flows only into the policy.
    """
    # Critic values are constants for the actor update.
    with torch.no_grad():
        # squeeze(-1) keeps the batch dimension for a one-step episode
        # (squeeze() would give a 0-dim tensor that cannot be indexed).
        value = self.V(bs).squeeze(-1)
    # (T, k) with k == 1 for the usual scalar log-probs.
    logp = torch.stack(list(blogp_a)).reshape(len(blogp_a), -1)
    # Sum over time steps, then average any trailing elements: exactly what the
    # original accumulate-in-a-loop + .mean() computed.
    policy_loss = -(logp * value.unsqueeze(-1)).sum(dim=0).mean()
    return policy_loss
2、A2C完整代码实现:
在参考代码的基础上修改了注释，原始代码见 https://github.com/wangshusen/DRL
"""8.3节A2C算法实现。"""
import argparse
import os
from collections import defaultdict
import gym
import matplotlib.pyplot as plt
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.distributions import Categorical


class ValueNet(nn.Module):
    """State-value network v(s; w): a 3-layer MLP mapping a state to one scalar."""

    def __init__(self, dim_state):
        super().__init__()
        self.fc1 = nn.Linear(dim_state, 64)
        self.fc2 = nn.Linear(64, 32)
        self.fc3 = nn.Linear(32, 1)

    def forward(self, state):
        """Return v(s) with a trailing dimension of size 1."""
        x = F.relu(self.fc1(state))
        x = F.relu(self.fc2(x))
        return self.fc3(x)


class PolicyNet(nn.Module):
    """Policy network pi(a | s; theta): a 3-layer MLP followed by a softmax."""

    def __init__(self, dim_state, num_action):
        super().__init__()
        self.fc1 = nn.Linear(dim_state, 64)
        self.fc2 = nn.Linear(64, 32)
        self.fc3 = nn.Linear(32, num_action)

    def forward(self, state):
        """Return action probabilities (softmax over the last dimension)."""
        x = F.relu(self.fc1(state))
        x = F.relu(self.fc2(x))
        x = self.fc3(x)
        return F.softmax(x, dim=-1)


class A2C:
    """Advantage actor-critic agent with a soft-updated target critic.

    Attributes:
        V: critic v(s; w), trained by TD regression.
        V_target: slowly tracking copy of V used to build TD targets.
        pi: actor pi(a | s; theta).
    """

    def __init__(self, args):
        self.args = args
        self.V = ValueNet(args.dim_state)
        self.V_target = ValueNet(args.dim_state)
        self.pi = PolicyNet(args.dim_state, args.num_action)
        # Start the target critic as an exact copy of the online critic.
        self.V_target.load_state_dict(self.V.state_dict())

    def get_action(self, state):
        """Sample an action from pi(. | state) and return ``(action, log_prob)``.

        ``log_prob`` keeps its autograd graph so the policy loss can later be
        back-propagated through it.
        """
        probs = self.pi(state)
        m = Categorical(probs)
        action = m.sample()
        return action, m.log_prob(action)

    def _td_target(self, br, bd, bns):
        """Return y = r + discount * (not d) * V_target(s'), without gradients."""
        with torch.no_grad():
            # squeeze(-1) keeps the batch dimension even for a 1-step episode.
            next_value = self.V_target(bns).squeeze(-1)
            return br + self.args.discount * torch.logical_not(bd) * next_value

    def compute_value_loss(self, bs, blogp_a, br, bd, bns):
        """Critic loss: MSE between V(s_t) and the TD target y_t."""
        target_value = self._td_target(br, bd, bns)
        return F.mse_loss(self.V(bs).squeeze(-1), target_value)

    def compute_policy_loss(self, bs, blogp_a, br, bd, bns):
        """Actor loss ``-sum_t log pi(a_t | s_t) * (y_t - V(s_t))``.

        The advantage ``y_t - V(s_t)`` (= -delta_t) is computed without
        gradients, so only the policy network receives gradients.
        """
        target_value = self._td_target(br, bd, bns)
        with torch.no_grad():
            advantage = target_value - self.V(bs).squeeze(-1)
        # One log-prob per step -> shape (T,); reshape guards against (T, 1).
        logp = torch.stack(list(blogp_a)).reshape(-1)
        # Summed (not averaged) over the episode, as the original loop did.
        return -(logp * advantage).sum()

    def soft_update(self, tau=0.01):
        """Polyak-average V into V_target: w_target <- (1 - tau) * w_target + tau * w."""
        for target_param, param in zip(self.V_target.parameters(), self.V.parameters()):
            target_param.data.copy_(target_param.data * (1.0 - tau) + param.data * tau)


class Rollout:
    """Buffer of one episode's transitions; ``tensor()`` turns it into a batch."""

    def __init__(self):
        self.state_lst = []
        self.action_lst = []
        self.logp_action_lst = []
        self.reward_lst = []
        self.done_lst = []
        self.next_state_lst = []

    def put(self, state, action, logp_action, reward, done, next_state):
        """Append one transition; ``done`` is used as the no-bootstrap mask."""
        self.state_lst.append(state)
        self.action_lst.append(action)
        self.logp_action_lst.append(logp_action)
        self.reward_lst.append(reward)
        self.done_lst.append(done)
        self.next_state_lst.append(next_state)

    def tensor(self):
        """Return ``(bs, ba, blogp_a, br, bd, bns)``.

        ``blogp_a`` stays a Python list of log-prob tensors so their autograd
        graphs survive; everything else becomes a tensor.
        """
        # Go through numpy first: torch.as_tensor on a list of ndarrays is very slow.
        bs = torch.as_tensor(np.asarray(self.state_lst, dtype=np.float32))
        ba = torch.as_tensor([int(a) for a in self.action_lst], dtype=torch.float32)
        blogp_a = self.logp_action_lst
        br = torch.as_tensor(np.asarray(self.reward_lst, dtype=np.float32))
        bd = torch.as_tensor(np.asarray(self.done_lst, dtype=bool))
        bns = torch.as_tensor(np.asarray(self.next_state_lst, dtype=np.float32))
        return bs, ba, blogp_a, br, bd, bns


class INFO:
    """Tracks per-episode length and reward and keeps their history in ``log``."""

    def __init__(self):
        self.log = defaultdict(list)
        self.episode_length = 0
        self.episode_reward = 0
        self.max_episode_reward = -float("inf")

    def put(self, done, reward):
        """Account for one environment step; at episode end, log and reset counters."""
        self.episode_length += 1
        self.episode_reward += reward
        # Truthiness instead of `is True`: environments may return numpy.bool_.
        if done:
            self.log["episode_length"].append(self.episode_length)
            self.log["episode_reward"].append(self.episode_reward)
            if self.episode_reward > self.max_episode_reward:
                self.max_episode_reward = self.episode_reward
            self.episode_length = 0
            self.episode_reward = 0


def train(args, env, agent: A2C):
    """Train ``agent`` on ``env`` for ``args.max_steps`` environment steps.

    The networks are updated once per finished episode, using all of that
    episode's transitions. The policy weights are saved to
    ``<output_dir>/model.bin`` whenever an episode's reward equals the best so
    far, and loss/reward curves are plotted every 10000 steps.
    """
    # torch.save / plt.savefig fail if the output directory does not exist yet.
    os.makedirs(args.output_dir, exist_ok=True)
    V_optimizer = torch.optim.Adam(agent.V.parameters(), lr=args.lr)
    pi_optimizer = torch.optim.Adam(agent.pi.parameters(), lr=args.lr)

    info = INFO()
    rollout = Rollout()
    state, _ = env.reset(seed=args.seed)
    for step in range(args.max_steps):
        action, logp_action = agent.get_action(torch.as_tensor(state, dtype=torch.float32))
        next_state, reward, terminated, truncated, _ = env.step(action.item())
        done = terminated or truncated
        info.put(done, reward)
        # Only a true termination zeroes the bootstrap term: an episode cut off
        # by a time limit (truncated) still has value beyond next_state.
        rollout.put(state, action, logp_action, reward, terminated, next_state)
        state = next_state

        if done:
            # Train on the finished episode.
            bs, ba, blogp_a, br, bd, bns = rollout.tensor()

            value_loss = agent.compute_value_loss(bs, blogp_a, br, bd, bns)
            V_optimizer.zero_grad()
            # No retain_graph needed: the critic loss shares no graph with the actor.
            value_loss.backward()
            V_optimizer.step()

            policy_loss = agent.compute_policy_loss(bs, blogp_a, br, bd, bns)
            pi_optimizer.zero_grad()
            policy_loss.backward()
            pi_optimizer.step()

            agent.soft_update()

            # Log progress.
            value_loss_val = value_loss.item()
            info.log["value_loss"].append(value_loss_val)
            info.log["policy_loss"].append(policy_loss.item())
            episode_reward = info.log["episode_reward"][-1]
            episode_length = info.log["episode_length"][-1]
            print(f"step={step}, reward={episode_reward:.0f}, length={episode_length}, max_reward={info.max_episode_reward}, value_loss={value_loss_val:.1e}")

            # Start a new episode.
            state, _ = env.reset()
            rollout = Rollout()

            # Save the policy whenever this episode ties or sets the best reward.
            if episode_reward == info.max_episode_reward:
                save_path = os.path.join(args.output_dir, "model.bin")
                torch.save(agent.pi.state_dict(), save_path)

        if step % 10000 == 0:
            plt.plot(info.log["value_loss"], label="value loss")
            plt.legend()
            plt.savefig(f"{args.output_dir}/value_loss.png", bbox_inches="tight")
            plt.close()

            plt.plot(info.log["episode_reward"])
            plt.savefig(f"{args.output_dir}/episode_reward.png", bbox_inches="tight")
            plt.close()


def eval(args, env, agent):
    """Roll out the saved policy for 5000 steps and print each episode's result.

    The ``agent`` argument is not used: a fresh ``A2C`` is built and its policy
    is loaded from ``<output_dir>/model.bin``. The name shadows the builtin
    ``eval``; it is kept so existing callers keep working.
    """
    agent = A2C(args)
    model_path = os.path.join(args.output_dir, "model.bin")
    agent.pi.load_state_dict(torch.load(model_path, map_location="cpu"))

    episode_length = 0
    episode_reward = 0
    state, _ = env.reset(seed=args.seed)
    for _step in range(5000):
        episode_length += 1
        # Inference only: don't build autograd graphs.
        with torch.no_grad():
            action, _ = agent.get_action(torch.as_tensor(state, dtype=torch.float32))
        next_state, reward, terminated, truncated, _ = env.step(action.item())
        done = terminated or truncated
        episode_reward += reward
        state = next_state
        if done:
            print(f"episode reward={episode_reward}, length={episode_length}")
            state, _ = env.reset()
            episode_length = 0
            episode_reward = 0


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument("--env", default="CartPole-v1", type=str, help="Environment name.")
    parser.add_argument("--dim_state", default=4, type=int, help="Dimension of state.")
    parser.add_argument("--num_action", default=2, type=int, help="Number of action.")
    parser.add_argument("--output_dir", default="output", type=str, help="Output directory.")
    parser.add_argument("--seed", default=42, type=int, help="Random seed.")
    parser.add_argument("--max_steps", default=100_000, type=int, help="Maximum steps for interaction.")
    parser.add_argument("--discount", default=0.99, type=float, help="Discount coefficient.")
    # Used by both the actor and the critic Adam optimizers.
    parser.add_argument("--lr", default=3e-3, type=float, help="Learning rate.")
    # NOTE(review): --batch_size and --no_cuda are parsed but not used by this script.
    parser.add_argument("--batch_size", default=32, type=int, help="Batch size.")
    parser.add_argument("--no_cuda", action="store_true", help="Avoid using CUDA when available")
    parser.add_argument("--do_train", action="store_true", help="Train policy.")
    parser.add_argument("--do_eval", action="store_true", help="Evaluate policy.")
    args = parser.parse_args()

    # Seed torch so weight initialisation and action sampling are reproducible.
    torch.manual_seed(args.seed)
    env = gym.make(args.env)
    agent = A2C(args)
    if args.do_train:
        train(args, env, agent)
    if args.do_eval:
        eval(args, env, agent)
3、torch.distributions.Categorical()
probs = policy_network(state)
# Note that this is equivalent to what used to be called multinomial
m = Categorical(probs)  # build a categorical distribution from the action probabilities
action = m.sample()  # sample an action index according to probs
# Gymnasium-style step, as in the training code: pass a plain int, get 5 values back.
next_state, reward, terminated, truncated, _ = env.step(action.item())
loss = -m.log_prob(action) * reward  # log_prob(action) computes log(probs[action])
loss.backward()
[Probability distributions - torch.distributions — PyTorch 2.0 documentation](https://pytorch.org/docs/stable/distributions.html)
# NOTE(review): these lines repeat the end of the Categorical example (apparently a
# copy/paste artifact); `m` and `action` come from that example.
# Gymnasium-style step: pass a plain int, get 5 values back.
next_state, reward, terminated, truncated, _ = env.step(action.item())
loss = -m.log_prob(action) * reward  # log_prob(action) computes log(probs[action])
loss.backward()
[Probability distributions - torch.distributions — PyTorch 2.0 documentation](https://pytorch.org/docs/stable/distributions.html)[【PyTorch】关于 log_prob(action) - 简书 (jianshu.com)](https://www.jianshu.com/p/06a5c47ee7c2)