基于强化学习算法玩CartPole游戏

什么事CartPole游戏

CartPole(也称为倒立摆问题)是一个经典的控制理论和强化学习的基础问题,通常用于测试和验证控制算法的性能。具体来说,它是一个简单的物理模拟问题,其目标是通过在一个平衡杆(倒立摆)上安装在小车(或称为平衡车)上的水平移动,使杆子保持竖直直立的状态。

有两个动作(action):

左移(0)

右移(1)

四个状态(state): 1. 小车在轨道上的位置 2. 杆子与竖直方向的夹角 3. 小车速度 4. 角度变化率

神经网络设计

1、强化学习的训练网络cartpole_train.py

import  gym
import pygame
import time
import random
import torch
from torch.distributions import Categoricalfrom torch import nn, optim
import torch.nn.functional as Fdef compute_policy_loss(n, log_p):r = list()#构造奖励r列表for i in range(n, 0 ,-1):r.append(i *1.0)r = torch.tensor(r)r = (r - r.mean()) / r.std() #进行标准化处理loss = 0#计算损失函数for pi, ri in zip(log_p, r):loss += -pi * rireturn  lossclass CartPolePolicy(nn.Module):def __init__(self):super(CartPolePolicy, self).__init__()self.fc1 = nn.Linear(in_features = 4, out_features = 128)self.fc2 = nn.Linear(128, 2) #输出为神经元个数为2表示,向左和向向右self.drop = nn.Dropout(p=0.6)def forward(self, x):x = self.fc1(x)x = self.drop(x)x = F.relu(x)x = self.fc2(x)#使用softmax决策最终的行动,是向左还是右return F.softmax(x, dim=1)if __name__ == '__main__':env = gym.make("CartPole-v1") #启动环境env.reset(seed= 543)torch.manual_seed(543)policy = CartPolePolicy() #定义模型optimizer = optim.Adam(policy.parameters(), lr = 0.01) #优化器#我们一共最多训练1000个回合#每个回合最多行动10000次#当某一回合的游戏步数超过5000时,就认为完成训练max_episod = 1000 #最大游戏回合数max_action = 10000 #每回合最大行动数max_steps = 5000 #完成训练的步数for episod in range(1, max_episod + 1):# 对于每一轮循环,都要重新启动一次游戏环境state, _ = env.reset()step = 0log_p = list()for step in range(1, max_action + 1):state = torch.from_numpy(state).float().unsqueeze(0)probs = policy(state) #计算神经网络给出的行动概率# 基于网络给出的概率分布,随机选择行动m = Categorical(probs)# 这里并不是直接使用概率较大的行动,而是通过概率分布生成action, 这样可以进一步探索低概率行动action = m.sample()state, _, done, _, _ = env.step(action.item())if done:break #表示跳出该for循环log_p.append(m.log_prob(action)) #保存每次行动对应的概率分布if step > max_steps: #当step大于最大步数时print(f"Done! last episode {episod} Run steps {step}")break #跳出循序,结束训练#每一回合游戏,都会做一次梯度下降算法optimizer.zero_grad()loss = compute_policy_loss(step, log_p)loss.backward()optimizer.step()if episod % 10 ==0:print(f"Episode {episod} Run step {step}")#保存模型torch.save(policy.state_dict(), f"cartpole_policy.pth")

2、验证:cartpole_eval.py

import  gym
import pygame
import torch.nn as nn
import torch.nn.functional as F
import time
import torch
class CartPolePolicy(nn.Module):def __init__(self):super(CartPolePolicy, self).__init__()self.fc1 = nn.Linear(4, 128)self.fc2 = nn.Linear(128, 2)self.drop = nn.Dropout(p=0.6)def forward(self, x):x = self.fc1(x)x = self.drop(x)x = F.relu(x)x = self.fc2(x)return F.softmax(x, dim=1)if __name__ == '__main__':pygame.init() #初始化pygame#使用gym, 创建一个artPole游戏的运行环境,这个环境是提供给人类玩家使用的env = gym.make('CartPole-v1', render_mode = "human")state, _ =env.reset()#使用env.reset重置环境后,会得到CartPole游戏中关键参数statecart_position = state[0] #小车位置cart_speed = state[1] #小车速度pole_angle = state[2] #杆的角度pole_speed = state[3] #杆的尖端速度#加载网络policy = CartPolePolicy()policy.load_state_dict(torch.load("cartpole_policy.pth"))policy.eval()start_time =time.time()max_action =1000 #设置游戏最大执行次数#最多执行1000次方向键,游戏就可以通关结束step = 0fail = Falsefor step in range(1, max_action + 1):#首先使用time.sleep,使游戏暂停0.3s,用于人的反应,觉得自己反应慢可以设置更长时间# time.sleep(0.3)#小车的控制方式,通过神经网络,来决定小车的运动方向#将环境参数state转为张量state = torch.from_numpy(state).float().unsqueeze(0)#输入至网络模型,计算行动概率probsprobs = policy(state)#选取行动概率最大的行动action =torch.argmax(probs, dim = 1).item()state, _, done, _, _ = env.step(action) #done为True,表示杆倒了if done:fail = Truebreakprint(f"step = {step} action = {action} angle = {state[2]:.2f}  position = {state[0]:.2f}")end_time = time.time()game_time = end_time - start_timeif fail:print(f"Game over ,you play {game_time:.2f} seconds, {step} steps.")else:print(f"Congratulations! you play  {game_time:.2f} seconds, {step} steps.")env.close()

视频讲解:

什么是reinforce强化学习算法,基于强化学习玩CartPole游戏_哔哩哔哩_bilibili

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.rhkb.cn/news/393125.html

如若内容造成侵权/违法违规/事实不符,请联系长河编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

PCIe学习笔记(15)

设备就绪状态 (Device Readiness Status,DRS)消息 (Device Readiness Status (DRS) 是PCIe规范中引入的一种机制,旨在改进设备初始化和就绪状态的检测与报告。 在以往的PCIe版本中,系统通常依赖于固定的超时机制来判断设备是否已…

【ML】transform 之 encode 及其实现细节

transform 之 encode 及其实现细节 1. transform (seq2seq) 是什么2. transform (seq2seq) 具体如何实现3. transform (seq2seq) 可以解决哪些类型的问题4. 补充问题4.1 残差连接(Residual Connection)是什么如何实现4.1.2 残差连接的具体实现&#xff1…

2024年武汉东湖高新区职称第二批次开始了

众所周知,武汉市东湖高新区职称一年两批次,今年下半年第二批水平能力测试报名也已经开始了,请注意报名时间,别错过!! 2024年武汉东湖高新区第二批次水测报名时间:(一)网上…

第十一章 数据仓库和商务智能 10分

11.1.0语境关系图 11.1 Q 建立数据仓库,有哪些步骤?如何建设?【6 个步骤非常重要!必须知道】 1. 理解需求(P)(目的明确,ETL) (1) 考虑业务目标和业务战略。 (2) 确定业…

FFMPEG 序列帧图片合成视频

需求: 将多张.png图片合成为视频 注意: 1需要Windows电脑 2将图片重命名 下载EXE 官网 https://ffmpeg.org/download.html#build-windows 解压后长这样 将图片和exe放在同一目录下 文件中找个空白地,Shift右键 进入PowerShell 输入命令: ./ffm…

Python 画 等高线图

Python 画 等高线图 flyfish 通过三维图形与投影等高线相结合的方式,能够直观地看到三维函数的形状以及在平面上等值线的分布。 等高线是一种用来表示三维表面在二维平面上的方法。它们是通过在固定高度(或深度)处切割三维表面来创建的平面…

Java零基础之多线程篇:不得不学的并发工具类!

哈喽,各位小伙伴们,你们好呀,我是喵手。运营社区:C站/掘金/腾讯云;欢迎大家常来逛逛 今天我要给大家分享一些自己日常学习到的一些知识点,并以文字的形式跟大家一起交流,互相学习,一…

数学建模--智能算法之鱼群算法

目录 核心原理 应用与实现 实现步骤 性能分析与改进 鱼群算法在解决哪些具体优化问题方面表现最佳? 如何根据不同的应用场景调整鱼群算法的参数设置以提高其性能? 鱼群算法与其他群体智能优化算法(如遗传算法、粒子群优化)…

C++ | Leetcode C++题解之第316题去除重复字母

题目&#xff1a; 题解&#xff1a; class Solution { public:string removeDuplicateLetters(string s) {vector<int> vis(26), num(26);for (char ch : s) {num[ch - a];}string stk;for (char ch : s) {if (!vis[ch - a]) {while (!stk.empty() && stk.back(…

html+css前端作业和平精英2个页面(无js)

htmlcss前端作业和平精英2个页面&#xff08;无js&#xff09;有视频播放器等功能效果 网页作品代码简单&#xff0c;可使用任意HTML编辑软件&#xff08;如&#xff1a;Dreamweaver、HBuilder、Vscode 、Sublime 、Webstorm、Text 、Notepad 等任意html编辑软件进行运行及修改…

lvs的dr模式综合实践

目录 ​编辑虚拟机准备工作 ​编辑​编辑​编辑 配置过程 配置client主机 配置router主机 配置lvs主机&#xff08;vip使用环回来创建&#xff09; 配置server1主机&#xff08;vip使用环回来创建&#xff09; 配置server2主机&#xff08;vip使用环回来创建&#xff0…

《数据结构》(C语言版)第1章 绪论(下)

第1章 绪论 1.3 抽象数据类型的表示与实现1.4 算法与算法分析 1.3 抽象数据类型的表示与实现 数据类型 数据类型是一组性质相同的值的集合, 以及定义于这个集合上的一组运算的总称。 抽象数据类型(ADTs: Abstract Data Types) 更高层次的数据抽象。由用户定义&#xff0c;用…

3DM游戏运行库合集离线安装包2024最新版

3DM游戏运行库合集离线安装包是一款由国内最大的游戏玩家论坛社区3DM推出的集成式游戏运行库合集软件&#xff0c;旨在解决玩家在玩游戏时遇到的运行库缺失或错误问题。该软件包含多种常用的系统运行库组件&#xff0c;支持32位和64位操作系统&#xff0c;能够自动识别系统版本…

快速上手AWS cloudfront产品

AWS CloudFront&#xff0c;亚马逊推出的卓越全球内容分发网络服务&#xff0c;专为加速网站内容的极速传输而设计&#xff0c;旨在大幅度削减加载延迟&#xff0c;同时确保内容传递过程中的高度安全性和无懈可击的可靠性。借助CloudFront的强大功能&#xff0c;用户能够轻松实…

腾讯云服务器windows系统如何转linux系统

本人购买了腾讯云服务&#xff0c;进去后发现是windows系统的&#xff0c;有点郁闷&#xff08;使用不习惯&#xff09;&#xff0c;于是就去查查看看能不能将Windows系统转成linux系统&#xff0c;网上也有解决办法&#xff0c;但是貌似跟现在的腾讯云后台不一致&#xff0c;下…

Flink学习之Flink SQL(补)

Flink SQL 1、SQL客户端 1.1 基本使用 启动yarn-session yarn-session.sh -d启动Flink SQL客户端 sql-client.sh--退出客户端 exit;测试 重启SQL客户端之后&#xff0c;需要重新建表 -- 构建Kafka Source -- 无界流 drop table if exists students_kafka_source; CREATE TABL…

软件生命周期(二)

1. 软件生命周期定义 软件生命周期&#xff08;SDLC&#xff09;是软件开始研制到最终废弃不用所经历的各个阶段 – 软件开发模型 2. 瀑布型生命周期模型 瀑布模型规定自上而下&#xff0c;相互衔接的固定次序&#xff0c;如同瀑布流水&#xff0c;逐级下落&#xff0c;具有…

sqli-labs(超详解)——Lass32~Lass38

Lass32&#xff08;宽字节注入&#xff09; 源码 function check_addslashes($string) {$string preg_replace(/. preg_quote(\\) ./, "\\\\\\", $string); //escape any backslash$string preg_replace(/\/i, \\\, $string); …

double类型 精度丢失的问题

前言 精度丢失的问题是在其他计算机语言中也都会出现&#xff0c;float和double类型的数据在执行二进制浮点运算的时候&#xff0c;并没有提供完全精确的结果。产生误差不在于数的大小&#xff0c;而是因为数的精度。 一、double进行运算时,经常出现精度丢失 0.10.2使用计算…

记录一次网关无响应的排查

1. 使用jstack pid > thread.txt 打印进 thread.txt 文件里 去观察线程的状态。 我发现&#xff0c;一个线程在经过 rateliter的prefilter后, 先是调用 consume方法&#xff0c;获取到锁。 接着在执行 jedis的 evalsha命令时 一直卡在socket.read()的状态。 发现jedis官…