玩转Atari-Pong游戏
- Atari: 雅达利,最初是一家游戏公司,旗下有超过200款游戏,不过已经破产。在强化学习中,Atari游戏是经典的实验环境之一,因此,本项目旨在学习使用强化学习算法玩Atari游戏。
- Pong: 1972年,雅达利(Atari)创办人布什内尔及达布尼推出首款街机Pong,最初仅生产12部,以简单点线接口仿真打乒乓球的游戏,奠定街机始祖地位。该游戏的简略版英文描述为:
You control the right paddle, you compete against the left paddle controlled by the computer. You each try to keep deflecting the ball away from your goal and into your opponent’s goal.
翻译成中文就是:
你控制右边的球拍,你与电脑控制的左边的球拍竞争。你们各自努力使球不断偏离自己的目标,进入对手的目标。
游戏示意图:
从该动态图可以看出,不经训练的右侧球拍完全打不过左侧球拍的,因此我们的目标就是训练右侧球拍使其战胜左侧球拍。
-
Pong环境的状态、动作与奖励:
- 状态:Pong环境提供的状态默认是
Box(210, 160, 3)
,也就是3通道的彩色图 - 动作:Pong-v0和Pong-V4版本返回的动作都是
Discrete(6)
,也就是离散的6个动作。网上有介绍:Pong 环境介绍,提到其实6个动作中有用的只有3个,可以参考该介绍,加深理解。 - 奖励:奖励有三种状态:-1,0,1,分别表示右侧未接到球;中间过程;左侧未接到球。
- 状态:Pong环境提供的状态默认是
-
训练结果展示:
我们同时提供了动态图Pong-v4_trained.gif
,因为该动态图超过10MB,无法展示,可自行下载观看。
1.Atari环境的安装
在运行man.ipynb之前,请先运行help.ipynb生成我们的依赖环境!!!
目前Ai studio平台并没有内嵌Atari环境,需要我们自行安装,为避免反复安装,我们将安装过程写到help.ipynb。可运行我们的help.ipynb
进行持久化安装。主要的安装命令如下所示:
- ! pip install atari_py==0.2.6 -i https://pypi.tuna.tsinghua.edu.cn/simple -t /home/aistudio/external-libraries
- ! pip install ale-py -i https://pypi.tuna.tsinghua.edu.cn/simple -t /home/aistudio/external-libraries
- ! pip install pyglet -i https://pypi.tuna.tsinghua.edu.cn/simple -t /home/aistudio/external-libraries
- ! pip install autorom -i https://pypi.tuna.tsinghua.edu.cn/simple -t /home/aistudio/external-libraries
- ! pip install AutoROM.accept-rom-license -i https://pypi.tuna.tsinghua.edu.cn/simple -t /home/aistudio/external-libraries
- !rar x Roms.rar
- !python -m atari_py.import_roms ROMS
其中需要注意:第4、5条安装命令可能无法一次成功,多运行几次即可;第6条命令一个项目仅运行一次即可。
2.导入我们的依赖包
注意要先将我们自行安装的Atari环境加入到系统中,即
sys.path.append(‘/home/aistudio/external-libraries’)
import sys
sys.path.append('/home/aistudio/external-libraries')import gym
import numpy as np
import time
import matplotlib.pyplot as plt
import paddle
import os
from collections import deque,Counter
from visualdl import LogWriter
import copy
from collections import Counter
from matplotlib import animation
from PIL import Image
3.环境测试
检测我们是否可以成功加载环境,并查看我们的状态空间和动作空间
env = gym.make('Pong-v4')
print(env.observation_space)
print(env.action_space)
Box(210, 160, 3)
Discrete(6)
4.状态的预处理
在这里我们首先定义了状态的预处理函数preprocess
,该函数说明如下:
- 输入:状态,Pong环境给出的不加任何处理的环境状态,Box(210, 160, 3)
- 处理:处理过程可以看我们下边的过程图片。
- 裁剪:将实际没有用的部分去除,主要是Pong环境返回的图像的上边和下边的部分
- 下采样:在保留特征的前提下进行像素点的缩减
- 擦除背景,在我们下采样后,环境的背景其实是有两种(109,144),这个也需要多观察才能看出,可以参考我们给出的示例图。
- 转为灰度图:非0即1,我们仅保留左右球拍和球,减少不必要因素的干扰
- 打平:将图像打平,进而只使用线性层进行特征学习
4.1 preprocess函数
def preprocess(image):""" 预处理 210x160x3 uint8 frame into 6400 (80x80) 1维 float vector """image = image[35:195] # 裁剪image = image[::2, ::2, 0] # 下采样,缩放2倍image[image == 144] = 0 # 擦除背景 (background type 1)image[image == 109] = 0 # 擦除背景 image[image != 0] = 1 # 转为灰度图,除了黑色外其他都是白色return image.astype(np.float).ravel() #打平,(6400,)
4.2 对preprocess函数进行可视化说明,展示中间过程
def show_image(status):status1=status[35:195] #裁剪有效区域status2 = status1[::2, ::2, 0] #下采样,缩减# 观察我们的像素点构成def see_color(status):allcolor=[]for i in range(80):allcolor.extend(status[i])dict_color=Counter(allcolor)print("像素点构成: ",dict_color)see_color(status2)# 观察好像素点后,擦除背景def togray(image_in):image=image_in.copy()image[image == 144] = 0 # 擦除背景 (background type 1)image[image == 109] = 0 # 擦除背景image[image != 0] = 1 # 转为灰度图,除了黑色外其他都是白色return imagestatus3=togray(status2)# 可视化我们的操作中间图def show_status(list_status):fig = plt.figure(figsize=(8, 32), dpi=200)plt.subplots_adjust(left=None, bottom=None, right=None, top=None,wspace=0.3, hspace=0)for i in range(len(list_status)):plt.subplot(1,len(list_status),i+1)plt.imshow(list_status[i],cmap=plt.cm.binary)plt.show()show_status([status,status1,status2,status3])
4.3 背景为109的preprocess展示
status = env.reset() #原始图
show_image(status)
像素点构成: Counter({109: 6382, 101: 16, 53: 2})/opt/conda/envs/python35-paddle120-env/lib/python3.7/site-packages/matplotlib/cbook/__init__.py:2349: DeprecationWarning: Using or importing the ABCs from 'collections' instead of from 'collections.abc' is deprecated, and in 3.8 it will stop workingif isinstance(obj, collections.Iterator):
/opt/conda/envs/python35-paddle120-env/lib/python3.7/site-packages/matplotlib/cbook/__init__.py:2366: DeprecationWarning: Using or importing the ABCs from 'collections' instead of from 'collections.abc' is deprecated, and in 3.8 it will stop workingreturn list(data) if isinstance(data, collections.MappingView) else data
/opt/conda/envs/python35-paddle120-env/lib/python3.7/site-packages/matplotlib/image.py:425: DeprecationWarning: np.asscalar(a) is deprecated since NumPy v1.16, use a.item() insteada_min = np.asscalar(a_min.astype(scaled_dtype))
/opt/conda/envs/python35-paddle120-env/lib/python3.7/site-packages/matplotlib/image.py:426: DeprecationWarning: np.asscalar(a) is deprecated since NumPy v1.16, use a.item() insteada_max = np.asscalar(a_max.astype(scaled_dtype))
4.4 背景为144的preprocess展示
for i in range(200):action=env.action_space.sample()status,reward,done,info=env.step(action)show_image(status)
像素点构成: Counter({144: 6366, 213: 16, 92: 16, 236: 2})
5.模型的定义,简单的全连接层
class Model(paddle.nn.Layer):""" 使用全连接网络.参数:obs_dim (int): 观测空间的维度.act_dim (int): 动作空间的维度."""def __init__(self, obs_dim, act_dim):super(Model, self).__init__()hid1_size = 256hid2_size = 64self.fc1 = paddle.nn.Linear(obs_dim, hid1_size)self.fc2 = paddle.nn.Linear(hid1_size, hid2_size)self.fc3 = paddle.nn.Linear(hid2_size, act_dim)def forward(self, obs): h1 = paddle.nn.functional.relu(self.fc1(obs))h2 = paddle.nn.functional.relu(self.fc2(h1))prob = paddle.nn.functional.softmax(self.fc3(h2), axis=-1)return prob
6.策略梯度算法
强化学习的经典算法之一,可以参考我们之前的项目【强化学习】REINFORCE算法
在这里我们仅定义预测
和更新
两个函数。
# 梯度下降算法
class PolicyGradient():def __init__(self, model, lr):self.model = modelself.optimizer = paddle.optimizer.Adam(learning_rate=lr, parameters=self.model.parameters())def predict(self, obs):prob = self.model(obs)return probdef learn(self, obs, action, reward):prob = self.model(obs)#print("prob: ",prob)log_prob = paddle.distribution.Categorical(prob).log_prob(action)loss = paddle.mean(-1 * log_prob * reward)self.optimizer.clear_grad()loss.backward()self.optimizer.step()return loss
7.策略梯度智能体
- 我们默认从文件中加载参数进行训练,因为PG算法+Pong环境的训练需要大量的时间,一次直接训练完成很耗时;当然我们支持从0开始训练
- sample: 在训练时调用的函数,带探索
- predict:在预测(测试)时调用的函数,不带探索
- learn:更新函数
- save和load:保存参数和加载参数。注意:这里我们保存了优化器的参数,但是在加载是并未加载上优化器的参数,有报错,未进行修复,但是不加载优化器参数几乎不影响我们的训练的。(这里我其实不太明白到底需不需加载优化器参数,还望大佬不吝赐教,拜谢)
class Agent():def __init__(self, algorithm):self.alg=algorithmif os.path.exists("./savemodel"):print("开始从文件加载参数....")try:self.load()print("从文件加载参数结束....")except:print("从文件加载参数失败,从0开始训练....")def sample(self, obs):""" 根据观测值 obs 采样(带探索)一个动作"""obs = paddle.to_tensor(obs, dtype='float32')prob = self.alg.predict(obs)#print("prob:",prob)prob = prob.numpy()act = np.random.choice(len(prob), 1, p=prob)[0] # 根据动作概率选取动作return actdef predict(self, obs):""" 根据观测值 obs 选择最优动作"""obs = paddle.to_tensor(obs, dtype='float32')prob = self.alg.predict(obs)act = prob.argmax().numpy()[0] # 根据动作概率选择概率最高的动作return actdef learn(self, obs, act, reward):""" 根据训练数据更新一次模型参数"""act = np.expand_dims(act, axis=-1)reward = np.expand_dims(reward, axis=-1)obs = paddle.to_tensor(obs, dtype='float32')act = paddle.to_tensor(act, dtype='int32')reward = paddle.to_tensor(reward, dtype='float32')#print("gggggggggggggg",obs.shape,act.shape,reward.shape)loss = self.alg.learn(obs, act, reward)return loss.numpy()[0]def save(self):paddle.save(self.alg.model.state_dict(),'./savemodel/PG-Pong_net.pdparams')paddle.save(self.alg.optimizer.state_dict(), "./savemodel/opt.pdopt")def load(self):# 加载网络参数model_state_dict=paddle.load('./savemodel/PG-Pong_net.pdparams')self.alg.model.set_state_dict(model_state_dict)# # 加载优化器参数# optimizer_state_dict=paddle.load("./savemodel/opt.pdopt")# self.alg.optimizer.set_state_dict(optimizer_state_dict)
8. 训练与测试
8.1 定义训练函数
# 训练一个episode
def run_train_episode(agent, env):obs_list, action_list, reward_list = [], [], []obs = env.reset()while True:obs = preprocess(obs) # from shape (210, 160, 3) to (6400,)obs_list.append(obs)action = agent.sample(obs)action_list.append(action)obs, reward, done, info = env.step(action)# if reward!=0:# print("reward: ",action)reward_list.append(reward)if done:breakreturn obs_list, action_list, reward_list
8.2 定义预测函数
# 评估 agent, 跑 5 个episode,总reward求平均
def run_evaluate_episodes(agent, env, render=False):eval_reward = []for i in range(5):obs = env.reset()episode_reward = 0while True:obs = preprocess(obs) # from shape (210, 160, 3) to (6400,)action = agent.predict(obs)obs, reward, isOver, _ = env.step(action)episode_reward += rewardif render:env.render()if isOver:breakeval_reward.append(episode_reward)return np.mean(eval_reward)
8.3 定义奖励处理函数
进行奖励衰减操作,衰减因子gamma默认为0.99
def calc_reward_to_go(reward_list, gamma=0.99):"""calculate discounted reward"""reward_arr = np.array(reward_list)for i in range(len(reward_arr) - 2, -1, -1):# G_t = r_t + γ·r_t+1 + ... = r_t + γ·G_t+1reward_arr[i] += gamma * reward_arr[i + 1]# normalize episode rewardsreward_arr -= np.mean(reward_arr)reward_arr /= np.std(reward_arr)return reward_arr
8.4 训练与预测的主函数
便于演示,我们仅进行100次的继续训练,读者可自行增加次数以获得更好的训练效果
def main():env = gym.make('Pong-v4')obs_dim = 80 * 80act_dim = env.action_space.nprint('obs_dim {}, act_dim {}'.format(obs_dim, act_dim))# 根据parl框架构建agentLEARNING_RATE = 5e-4model = Model(obs_dim=obs_dim, act_dim=act_dim)alg = PolicyGradient(model, lr=LEARNING_RATE)agent = Agent(alg)twriter=LogWriter('./logs/PG_Pong')for i in range(100): # default 3000obs_list, action_list, reward_list = run_train_episode(agent, env)twriter.add_scalar('reward',sum(reward_list),i)if i % 50 == 0:print("Episode {}, Reward Sum {}.".format(i, sum(reward_list)))batch_obs = np.array(obs_list)batch_action = np.array(action_list)batch_reward = calc_reward_to_go(reward_list)#print("ggggggggggggg",batch_obs.shape)agent.learn(batch_obs, batch_action, batch_reward)last_test_total_reward=0if (i + 1) % 100 == 0:# render=True 查看显示效果total_reward = run_evaluate_episodes(agent, env, render=False)print('Test reward: {}'.format(total_reward))# save the parametersif last_test_total_reward<total_reward:last_test_total_reward=total_rewardagent.save()# 运行整个程序
main()
obs_dim 6400, act_dim 6W1022 22:01:06.998914 174 gpu_resources.cc:61] Please NOTE: device: 0, GPU Compute Capability: 7.0, Driver API Version: 11.2, Runtime API Version: 11.2
W1022 22:01:07.003042 174 gpu_resources.cc:91] device: 0, cuDNN Version: 8.2.开始从文件加载参数....
从文件加载参数结束....
Episode 0, Reward Sum 14.0.
Episode 50, Reward Sum 8.0.
Test reward: 12.0
9.使用训练好的网络进行测试并生成动图
9.1 gif动图生成函数
def save_frames_as_gif(frames, filename):#Mess with this to change frame sizeplt.figure(figsize=(frames[0].shape[1]/100, frames[0].shape[0]/100), dpi=300)patch = plt.imshow(frames[0])plt.axis('off')def animate(i):patch.set_data(frames[i])anim = animation.FuncAnimation(plt.gcf(), animate, frames = len(frames), interval=50)anim.save(filename, writer='pillow', fps=60)
9.2 从文件加载模型参数
model=Model(6400,6)
model_state_dict=paddle.load("./savemodel/PG-Pong_net.pdparams")
model.set_state_dict(model_state_dict)
9.4 使用训练好的模型进行测试并保存过程为动图
env=gym.make('Pong-v4')state=env.reset()
frames = []
done=0
i=0
reward_list=[]
while not done:frames.append(env.render(mode="rgb_array"))obs = preprocess(state)obs = paddle.to_tensor(obs, dtype='float32')prob = model(obs)action = prob.argmax().numpy()[0]next_state,reward,done,_=env.step(action)if reward!=0:reward_list.append(reward)print(i," ",reward,done)state=next_statei+=1reward_counter=Counter(reward_list)
print(reward_counter)
print("你的得分为:",reward_counter[1.0],'对手得分为:',reward_counter[-1.0])
if reward_counter[1.0]>reward_counter[-1.0]:print("恭喜您赢了!!!")
else:print("惜败,惜败,训练一下智能体网络再来挑战吧QWQ")save_frames_as_gif(frames, filename="Pong-v4_trained.gif")env.close()
199 1.0 False
732 1.0 False
937 1.0 False
1547 1.0 False
1676 1.0 False
1877 1.0 False
2165 1.0 False
2451 1.0 False
2575 1.0 False
2705 1.0 False
2995 1.0 False
3125 1.0 False
3331 1.0 False
3454 1.0 False
3584 1.0 False
3793 1.0 False
4885 1.0 False
5096 1.0 False
5698 1.0 False
5992 1.0 False
6202 1.0 True
Counter({1.0: 21})
你的得分为: 21 对手得分为: 0
恭喜您赢了!!!
[外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传(img-dsXN1i1W-1667103194205)(main_files/main_37_1.png)]
10. 总结
本项目参考自飞桨PARL,鼓励大家给点点stars
本项目目前通过5000+回合的训练,我们的智能体已经学会通过快速抖动法
取得游戏的胜利了,但是大概率还不能完全碾压,后续有时间会继续训练或采取更加高效的算法进行改进。然后,这是我的第一个Atari游戏项目,之前都在在经典的控制游戏下进行实验,环境的转变使得学习的难度也上升,训练时间也在增加,学到的东西也在增加,挺好的…还请大佬多多指教,小黑还有很多路要走,嘿嘿!
之前的强化学习项目有:
- DQN+CartPole-v0
- A2C+CartPole-v0
- DDPG+Pendulum-v0
- TD3+Pendulum-v0
- REINFORCE+CartPole-v0
- PPO+CartPole-v0
- SAC+Pendulum-v0
欢迎大家来交流学习!!!
此文章为搬运
原项目链接