AI玩Flappy Bird || 基于Q-Learning和DQN的机器学习

一、游戏介绍

Flappy Bird 游戏需要玩家控制一只小鸟越过管道障碍物。玩家只可以进行“跳跃”或者“不操作”两种操作,即点或不点。点则让小鸟上升一段距离,不点小鸟继续下降。若小鸟碰到障碍物或地面,则游戏失败。
本项目目的是开发一个深层神经网络模型,具体地,利用图像中的不同对象训练卷积神经网络,进行基于游戏画面场景状态分析进行图像识别分类。从原始像素中学习游戏的特性,并决定采取相应行动,本质上是一个对游戏场景中特定状态的模式识别过程,在此设计了一个强化学习系统,通过自主学习来玩这款游戏。

二、问题定义 

当通过很少预定的行为进行编程不能充分解决问题时,可采用强化学习方式,这是一种通过进行场景训练,使算法在输入未知和多维数据(如彩色图片)时做出正确的决策方式。通过这种方式,算法可以学会自动对图像进行特征提取,对于训练中未出现的场景和状态也同样可以进行分类和预测

三、算法介绍

1、预处理

(1)去除背景颜色——背景使用黑色 

为了节省内存将其缩小为84x84大小的图像,每帧图像色阶都是0-255。 此外,为了提高卷积神经网络的精度,在这一步去除背景层并用纯黑色背景代替,以去除噪声。

 

(2) 对图像进行处理——灰度处理

依次对所得游戏图像进行缩放、灰度化以及调整亮度处理。在当前帧进入一个状态之前,处理几帧图像叠加组合的多维图像数据(如在模型构建部分提到的),当前帧与先前帧重叠时,灰度稍有降低,当我们远离最新帧时强度降低。因此,这样输入的图像将提供关于小鸟当前所在轨迹的良好信息,其处理过程如图2所示。

代码:

2、Q-Learning

(1)贝尔曼方程

强化学习的目标是使总回报(奖励)最大化。在Q-Learning中,它是非策略的,迭代更新使用的是贝尔曼方程,获得Q值的目标值,

其中s′和a′ 分别是下一帧的状态和动作(1或0),r是奖励(-1,0.1,1),γ是折扣因子。Qi(s,a)是为( s , a )矩阵在第i次迭代的Q值。这种更新迭代将收敛得到一个最佳的Q函数。为了防止学习僵化,这个动作值函数可以用一个函数(这里为深度学习网络)近似,以便能更好概括不可预见的状态。 

具体应用:

在训练的每个迭代中,通过以上代码应用贝尔曼方程来计算目标 Q 值(记为y_batch); 在这里,对于每个批次中的样本,目标 Q 值通过应用贝尔曼方程计算得到,考虑了当前奖励以及下一个状态的最大预测Q值。 

(2) 各个重要值的具体设定:

学习率 (lr): 1e-6  

折扣因子 (gamma): 0.99  

初始探索率 (initial_epsilon): 0.1  

最终探索率 (final_epsilon): 1e-4  

每批图像数量 (batch_size): 32  

重放记忆池大小 (replay_memory_size): 3000  

图像预处理尺寸 (image_size): 84  

最大迭代次数 (num_iters): 200000 

然后通过当前模型的预测值和动作来计算当前状态的 Q 值(记为q_value):

这里,将当前模型的预测结果与动作进行乘积,然后对结果进行求和,从而得到当前状态的Q值。 在损失计算部分,使用均方误差损失函数(MSELoss)来衡量模型的训练效果:

 

这里将当前状态的 Q 值(q_value)和目标 Q 值(y_batch)之间的均方误差作为损失值,然后通过反向传播和优化器更新模型参数。

均方误差(mean square error, MSE),是反应估计量与被估计量之间差异程度的一种度量,设t是根据子样确定的总体参数Ɵ的一个估计量,(Ɵ-t) ^2 的数学期望,称为估计量t 的均方误差。(n为样本个数) 

3、神经网络

在当前模型结构中,首先有三个卷积层,然后是两个完全连接层,最终完全连接层的输出是两个动作的得分,结果由损失函数得出。损失函数自动进行Q学习参数设置。遵循空间批量规范,在每个卷积层后都添加ReLu。输入图像的大小84×84,每个时刻有两种可能的输出操作,每次动作将会获得一个得分值,以此决定最佳动作。 

图像resize成84x84大小!

 

 

可以看到,这里的网络使用了连续三个卷积层+两个全连接层的形式。最后输出为2个值,即动作选择。 

4、DQN结构

(1)增加样本池

在Q-Learning中,以连续方式记录的经验数据是高度相关的。若使用相同的顺序更新DQN参数,训练过程就会受到干扰。与从一个标记的数据集中采样小批量训练分类模型类似,这里同样应该在抽取出的获得更新的DQN经验中引入一定的随机性。为此设置一个经验回放存储器,用来存储每帧游戏画面的经验数据,直到达到其最大存储容量。(DQN的一大特点就是设置了数据库,后续的每次训练从数据库中抽取数据。这样可以使得训练更加有效。) 

程序中,使用了一个队列replay_memory来当作经验池,经验池大小replay_memory_size设置为30000,如果数据库容量达到上限,将会把最先进入的数据抛出,即队列的先入先出。 

过大的经验池会占用更多的内存资源,可能导致计算效率低下或资源不足。
过小的经验池可能导致模型无法充分利用历史经验数据,从而影响模型的训练效果和性能。
经验池大小与批次大小(batch size)密切相关,它们共同决定了模型每次更新时能够处理的经验数据数量。批次大小也是一个需要仔细调整的超参数。

(2) 利用神经网络计算Q值

输入状态值,输出为Q值,根据大量的数据去训练神经网络的参数,最终得到Q-Learning的计算模型。

 

这里有三个卷积层(conv1、conv2、conv3),它们依次对输入数据进行处理。每个卷积层都会提取输入数据的某些特征,并将结果传递给下一个层。每次卷积操作后,output变量都会更新为新的特征映射。

在将数据传递给全连接层之前,通常需要将卷积层输出的多维张量(通常是一个四维张量:批量大小 x 通道数 x 高度 x 宽度)展平为一个二维张量(批量大小 x 特征数量)。.view()函数就是用来进行这种展平操作的。在这里,output.size(0)保留了批量大小不变,而-1则让PyTorch自动计算第二个维度的大小,以确保数据的总元素数量不变。

经过展平后,数据被传递给两个全连接层(fc1和fc2)。这些层通常用于对提取的特征进行分类或回归。

四、算法设计

1、 Train.py算法

(1)开启游戏模拟器,会打开一个窗口,实时显示游戏的信息,获取游戏的状态

(2)创建样本池

(3)当训练次数小于设置的迭代次数(300万)时,进入训练 获得的第一个数值, 也就是从神经网络当中的q数值

(4)执行一个随机动作或者神经网络计算的Q(s,a)值选择对应的动作

(5)样本池使用一个大小确定的队列来进行维护,其中存放的是游戏过程中的数据state, action, reward, next_state, terminal

(6)得到下一帧图像进行数据预处理

(7)每执行一次动作,游戏会返回执行该动作之后的一帧图像,把样本池更新, 

2、 DQN设计

(1)初始化Q函数Q,目标Q函数Q ̂= Q对于每一个回合

a)对于每一个时间步iter

探索与利用(随着训练的次数越来越多,Q值函数越来越精确,比较能确定较好的动作,把epsilon的值变小,减少探索,即较少随机决定动作)

b)对于给定的状态state ,基于Q (epsilon - 贪心)执行动作action

c)获得反馈reward,并获得新的状态next_state

d)将(state, action , reward , next_state)存储到缓冲区中(更新经验池)

e)从缓冲区中采样(通常以批量形式)( state, action , reward , next_state)

f)目标值是y = reward + 〖max〗_a Q ̂ (state , action)

(2)更新Q的参数使得Q(state , action)尽可能接近于回归

(3)每C步重置Q ̂=Q

3、Test.py算法

使用train.py每隔50000次训练产生保存的模型,产生游戏对应画面的下一个动作,累计计算得分,直到小鸟掉落或撞管道换下一个模型测试,最后根据每个模型的得分,产生得分曲线图。 

通过get_args()函数获取用户输入的参数。循环测试不同迭代次数下的游戏表现。将每次迭代的得分记录下来,并输出到控制台。使用matplotlib绘制图表,展示迭代次数与游戏得分的关系。最后保存图表为图片文件。 

五、训练结果及分析

200万次

 

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.rhkb.cn/news/34266.html

如若内容造成侵权/违法违规/事实不符,请联系长河编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

【Linux内核系列】:文件系统收尾以及软硬链接详解

🔥 本文专栏:Linux 🌸作者主页:努力努力再努力wz 💪 今日博客励志语录: 世界上只有一种个人英雄主义,那么就是面对生活的种种失败却依然热爱着生活 内容回顾 那么在之前的学习中,我们…

【eNSP实战】三层交换机使用ACL实现网络安全

拓图 要求: vlan1可以访问Internetvlan2和vlan3不能访问Internet和vlan1vlan2和vlan3之间可以互相访问PC配置如图所示,这里不展示 LSW1接口vlan配置 vlan batch 10 20 30 # interface Vlanif1ip address 192.168.40.2 255.255.255.0 # interface Vla…

Trae与Builder模式初体验

说明 下载的国际版:https://www.trae.ai/ 建议 要选新模型 效果 还是挺不错的,遇到问题反馈一下,AI就帮忙解决了,真是动动嘴(打打字就行了),做些小的原型效果或演示Demo很方便呀&#xff…

Canoe Panel常用控件

文章目录 一、Panel 中控件分类1. 指示类控件2. 功能类控件3. 信号值交互类控件4. 其他类控件 二、控件使用方法1. Group Box 控件2. Input/Output Box控件3. Static Text控件4. Button控件5. Switch/Indicator 控件 提示:Button 和 Switch 的区别参考 一、Panel 中…

睡不着运动锻炼贴士

在快节奏的现代生活中,失眠似乎已成为许多人的“夜间伴侣”。夜晚辗转反侧,白天精神不振,这样的恶性循环让许多人苦不堪言。其实,除了调整作息和饮食习惯,适当的运动也是改善睡眠的一剂良药。今天,就让我们…

java数据结构(复杂度)

一.时间复杂度和空间复杂度 1.时间复杂度 衡量一个程序好坏的标准,除了能处理各种异常,还有就是时间效率,当然,对于一些配置好的电脑数据处理起来就是比配置低的高,但从后期发展来看,当数据量足够庞大时&…

NAT和NAPT的介绍

一、NAT的介绍以及作用 二、NAPT的介绍以及作用 三、NAT vs NAPT 一、NAT的介绍以及作用 1.1 NAT的介绍 NAT(Network Address Translation)是一种广泛应用于互联网的技术,主要用于解决IPv4地址耗尽问题,同时提供网络安全和网络…

VSCode通过SSH免密远程登录Windows服务器

系列 1.1 VSCode通过SSH远程登录Windows服务器 1.2 VSCode通过SSH免密远程登录Windows服务器 文章目录 系列1 准备工作2 本地电脑配置2.1 生成密钥2.2 VS Code配置密钥 3. 服务端配置3.1 配置SSH服务器sshd_config3.2 复制公钥3.3 配置权限(常见问题)3.…

大模型训练全流程深度解析

前些天发现了一个巨牛的人工智能学习网站,通俗易懂,风趣幽默,忍不住分享一下给大家。点击跳转到网站。https://www.captainbed.cn/north 文章目录 1. 大模型训练概览1.1 训练流程总览1.2 关键技术指标 2. 数据准备2.1 数据收集与清洗2.2 数据…

export、export default 和 module.exports 深度解析

文章目录 1. 模块系统概述1.1 模块系统对比1.2 模块加载流程 2. ES Modules2.1 export 使用2.2 export default 使用2.3 混合使用 3. CommonJS3.1 module.exports 使用3.2 exports 使用 4. 对比分析4.1 语法对比4.2 使用场景 5. 互操作性5.1 ES Modules 中使用 CommonJS5.2 Com…

AI芯片设计

目的:未来的时代,一定会是AI的时代,那么,AI时代的三个重要组成部分,我要参与其中之一! 参考视频:AI芯片设计第一讲_哔哩哔哩_bilibili 端处理 云端

动手学深度学习:CNN和LeNet

前言 该篇文章记述从零如何实现CNN,以及LeNet对于之前数据集分类的提升效果。 从零实现卷积核 import torch def conv2d(X,k):h,wk.shapeYtorch.zeros((X.shape[0]-h1,X.shape[1]-w1))for i in range(Y.shape[0]):for j in range(Y.shape[1]):Y[i,j](X[i:ih,j:jw…

【开源代码解读】AI检索系统R1-Searcher通过强化学习RL激励大模型LLM的搜索能力

关于R1-Searcher的报告: 第一章:引言 - AI检索系统的技术演进与R1-Searcher的创新定位 1.1 信息检索技术的范式转移 在数字化时代爆发式增长的数据洪流中,信息检索系统正经历从传统关键词匹配到语义理解驱动的根本性变革。根据IDC的统计…

使用Node的http模块创建web服务,给客户端返回html页面时,css失效的根本原因(有助于理解http)

最近正在尝试使用node写后端,使用node创建http服务的时候,碰到了这样的一个问题: 这是我的源代码: import { createServer } from http import { join, dirname, extname } from path import { fileURLToPath } from url import…

JVM 2015/3/15

定义:Java Virtual Machine -java程序的运行环境(java二进制字节码的运行环境) 好处: 一次编写,到处运行 自动内存管理,垃圾回收 数组下标越界检测 多态 比较:jvm/jre/jdk 常见的JVM&…

IP风险度自检,互联网的安全“指南针”

IP地址就像我们的网络“身份证”,而IP风险度则是衡量这个“身份证”安全性的重要指标。它关乎着我们的隐私保护、账号安全以及网络体验,今天就让我们一起深入了解一下IP风险度。 什么是IP风险度 IP风险度是指一个IP地址可能暴露用户真实身份或被网络平台…

【鸿蒙】封装日志工具类 ohos.hilog打印日志

封装一个ohos.hilog打印日志 首先要了解hilog四大日志类型: info、debug、warm、error 方法中四个参数的作用 domain: number tag: string format: string ...args: any[ ] 实例: //普通的info日志,使用info方法来打印 //第一个参数 : 0x0…

走路碎步营养补充贴士

走路碎步,这种步伐不稳的现象,在日常生活中并不罕见,特别是对于一些老年人或身体较为虚弱的人来说,更是一种常见的行走状态。然而,这种现象可能不仅仅是肌肉或骨骼的问题,它还可能是身体在向我们发出营养缺…

Python软件和搭建运行环境

目录 一、Python安装全流程(Windows/Mac/Linux) 1. 下载官方安装包 2. 详细安装步骤(以Windows为例) 3. 环境变量配置(Mac/Linux) 二、虚拟环境管理(关键!) 为什么需…

【蓝桥杯】省赛:神奇闹钟

思路 python做这题很简单,灵活用datetime库即可 code import os import sys# 请在此输入您的代码 import datetimestart datetime.datetime(1970,1,1,0,0,0) for _ in range(int(input())):ls input().split()end datetime.datetime.strptime(ls[0]ls[1],&quo…