1. Environment setup. This walkthrough targets gym version 0.26.2; the reset() and step() signatures used below rely on the 0.26 API.
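A minimal install might look like this (a sketch; it assumes pip and a recent Python, and the classic_control extra is only needed for the pygame-based renderer used later):

pip install "gym[classic_control]==0.26.2"
pip install torch numpy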
2. Write the network code
# Import the required libraries
import gym
import torch
import torch.nn as nn
import torch.optim as optim
import numpy as np
from collections import deque
import random

# Define the DQN network
class DQN(nn.Module):
    def __init__(self, state_size, action_size):
        super(DQN, self).__init__()
        # Three fully connected layers
        self.fc1 = nn.Linear(state_size, 24)
        self.fc2 = nn.Linear(24, 24)
        self.fc3 = nn.Linear(24, action_size)

    def forward(self, x):
        x = torch.relu(self.fc1(x))
        x = torch.relu(self.fc2(x))
        return self.fc3(x)
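Before wiring up the agent, a quick shape check can confirm the network does what we expect (illustrative only; CartPole-v1 happens to have a 4-dimensional state and 2 discrete actions):

net = DQN(4, 2)
q = net(torch.randn(1, 4))  # a batch containing one fake state
print(q.shape)              # torch.Size([1, 2]), one Q-value per action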
# Define the DQN agent
class DQNAgent:
    def __init__(self, state_size, action_size):
        self.state_size = state_size
        self.action_size = action_size
        self.memory = deque(maxlen=2000)  # experience replay buffer
        self.gamma = 0.95                 # discount factor
        self.epsilon = 1.0                # exploration rate
        self.epsilon_min = 0.01
        self.epsilon_decay = 0.995
        self.learning_rate = 0.001
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        self.model = DQN(state_size, action_size).to(self.device)
        self.optimizer = optim.Adam(self.model.parameters(), lr=self.learning_rate)

    def remember(self, state, action, reward, next_state, done):
        # Store a transition in the replay buffer
        self.memory.append((state, action, reward, next_state, done))

    def act(self, state):
        # epsilon-greedy action selection
        if np.random.rand() <= self.epsilon:
            return random.randrange(self.action_size)
        state = torch.FloatTensor(state).unsqueeze(0).to(self.device)
        with torch.no_grad():
            act_values = self.model(state)
        return int(np.argmax(act_values.cpu().numpy()))

    def replay(self, batch_size):
        # Sample a random minibatch from the replay buffer and learn from it
        minibatch = random.sample(self.memory, batch_size)
        for state, action, reward, next_state, done in minibatch:
            target = reward
            if not done:
                next_state_t = torch.FloatTensor(next_state).unsqueeze(0).to(self.device)
                with torch.no_grad():
                    target = reward + self.gamma * self.model(next_state_t).max().item()
            state_t = torch.FloatTensor(state).unsqueeze(0).to(self.device)
            # Build the training target: keep the network's own predictions for
            # the other actions and overwrite only the taken action's Q-value.
            # detach() stops gradients from flowing through the target.
            target_f = self.model(state_t).detach().clone()
            target_f[0][action] = target
            self.optimizer.zero_grad()
            loss = nn.MSELoss()(self.model(state_t), target_f)
            loss.backward()
            self.optimizer.step()
        # Decay the exploration rate
        if self.epsilon > self.epsilon_min:
            self.epsilon *= self.epsilon_decay

    def load(self, name):
        # Load model weights
        self.model.load_state_dict(torch.load(name))

    def save(self, name):
        # Save model weights
        torch.save(self.model.state_dict(), name)
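The agent's pieces can also be exercised in isolation (a minimal sketch using a made-up transition; this is not part of the training script):

agent = DQNAgent(4, 2)
s = np.random.rand(4).astype(np.float32)  # fake CartPole state
a = agent.act(s)                          # epsilon-greedy; random while epsilon is 1.0
agent.remember(s, a, 1.0, s, False)       # store one transition
# agent.replay(32) becomes callable once the buffer holds at least 32 transitions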
# Training function
def train_dqn():
    env = gym.make('CartPole-v1')
    state_size = env.observation_space.shape[0]
    action_size = env.action_space.n
    agent = DQNAgent(state_size, action_size)
    episodes = 1000
    batch_size = 32
    for e in range(episodes):
        state, _ = env.reset()  # gym 0.26 reset() returns (observation, info)
        for time in range(500):
            action = agent.act(state)
            # gym 0.26 step() returns (obs, reward, terminated, truncated, info)
            next_state, reward, terminated, truncated, _ = env.step(action)
            done = terminated or truncated
            reward = reward if not terminated else -10  # penalize failure (pole fell), not time-limit truncation
            agent.remember(state, action, reward, next_state, done)
            state = next_state
            if done:
                print(f"episode: {e}/{episodes}, score: {time}, epsilon: {agent.epsilon:.2}")
                break
            if len(agent.memory) > batch_size:
                agent.replay(batch_size)
        if e % 100 == 0:
            agent.save(f"cartpole-dqn-{e}.pth")  # save a checkpoint every 100 episodes
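Two details in this loop are worth calling out: the -10 reward is reward shaping layered on top of the environment (CartPole itself returns +1 per step), and because replay() runs after every step once the buffer is full enough, epsilon decays once per learning step rather than once per episode, so it approaches epsilon_min well before training ends.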
# Play the game with the trained model
def play_cartpole():
    # In gym 0.26, rendering is requested via render_mode at creation time;
    # a bare env.render() call no longer opens a window
    env = gym.make('CartPole-v1', render_mode='human')
    state_size = env.observation_space.shape[0]
    action_size = env.action_space.n
    agent = DQNAgent(state_size, action_size)
    agent.load("cartpole-dqn-900.pth")  # load the trained model
    agent.epsilon = agent.epsilon_min   # act (almost) greedily instead of exploring
    for e in range(10):  # play 10 episodes
        state, _ = env.reset()
        for time in range(500):
            action = agent.act(state)
            next_state, reward, terminated, truncated, _ = env.step(action)
            state = next_state
            if terminated or truncated:
                print(f"episode: {e}, score: {time}")
                break
    env.close()

if __name__ == '__main__':
    # To train a model, uncomment the line below
    # train_dqn()
    # To play with a trained model, uncomment the line below
    play_cartpole()
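Note that play_cartpole assumes a checkpoint named cartpole-dqn-900.pth already exists, i.e. train_dqn must have completed at least 900 episodes first. Resetting epsilon to epsilon_min before playing keeps the loaded policy from being drowned out by random exploration, which a fresh agent's exploration rate of 1.0 would otherwise guarantee.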
For more detailed explanations, see:
https://zhuanlan.zhihu.com/p/29283993
https://zhuanlan.zhihu.com/p/29213893