强化学习-python案例

强化学习是一种机器学习方法,旨在通过与环境的交互来学习最优策略。它的核心概念是智能体(agent)在环境中采取动作,从而获得奖励或惩罚。智能体的目标是最大化长期奖励,通过试错的方式不断改进其决策策略。

在强化学习中,智能体观察当前状态,选择动作,并根据环境反馈(奖励和下一个状态)调整其策略。常见的强化学习算法包括Q-learning、策略梯度方法和深度强化学习等。强化学习广泛应用于游戏、机器人控制、推荐系统等领域。

  1. 奖励(Reward)
    r t = R ( s t , a t ) r_t = R(s_t, a_t) rt=R(st,at)
    其中 r t r_t rt 是在时间步 t t t 时,智能体在状态 s t s_t st 下采取动作 a t a_t at 所获得的奖励。

  2. 状态价值函数(State Value Function)
    V ( s ) = E [ ∑ t = 0 ∞ γ t r t ∣ s 0 = s ] V(s) = \mathbb{E} \left[ \sum_{t=0}^{\infty} \gamma^t r_t \mid s_0 = s \right] V(s)=E[t=0γtrts0=s]
    其中 V ( s ) V(s) V(s) 是状态 s s s 的价值, γ \gamma γ 是折扣因子 ( 0 ≤ γ < 1 ( 0 \leq \gamma < 1 (0γ<1),表示未来奖励的重要性。

  3. 动作价值函数(Action Value Function)
    Q ( s , a ) = E [ ∑ t = 0 ∞ γ t r t ∣ s 0 = s , a 0 = a ] Q(s, a) = \mathbb{E} \left[ \sum_{t=0}^{\infty} \gamma^t r_t \mid s_0 = s, a_0 = a \right] Q(s,a)=E[t=0γtrts0=s,a0=a]
    其中 Q ( s , a ) Q(s, a) Q(s,a) 是在状态 s s s 下采取动作 a a a 的价值。

  4. 贝尔曼方程(Bellman Equation)

    • 状态价值函数的贝尔曼方程:
      V ( s ) = ∑ a π ( a ∣ s ) ∑ s ′ , r P ( s ′ , r ∣ s , a ) [ r + γ V ( s ′ ) ] V(s) = \sum_{a} \pi(a \mid s) \sum_{s', r} P(s', r \mid s, a) \left[ r + \gamma V(s') \right] V(s)=aπ(as)s,rP(s,rs,a)[r+γV(s)]
    • 动作价值函数的贝尔曼方程:
      Q ( s , a ) = ∑ s ′ , r P ( s ′ , r ∣ s , a ) [ r + γ max ⁡ a ′ Q ( s ′ , a ′ ) ] Q(s, a) = \sum_{s', r} P(s', r \mid s, a) \left[ r + \gamma \max_{a'} Q(s', a') \right] Q(s,a)=s,rP(s,rs,a)[r+γamaxQ(s,a)]
  5. 策略(Policy)
    π ( a ∣ s ) = P ( a ∣ s ) \pi(a \mid s) = P(a \mid s) π(as)=P(as)
    其中 π ( a ∣ s ) \pi(a \mid s) π(as) 是在状态 s s s 下选择动作 a a a 的概率。

目标函数

  1. 策略梯度目标函数
    J ( θ ) = E τ ∼ π θ [ ∑ t = 0 T r t ] J(\theta) = \mathbb{E}_{\tau \sim \pi_\theta} \left[ \sum_{t=0}^{T} r_t \right] J(θ)=Eτπθ[t=0Trt]
    • 说明 J ( θ ) J(\theta) J(θ) 是关于策略参数 θ \theta θ 的目标函数,表示在策略 π θ \pi_\theta πθ 下,执行轨迹 τ \tau τ 的预期总奖励。目标是最大化该期望值,通常通过梯度上升方法进行优化。

损失函数

  1. 策略损失函数(使用REINFORCE算法):
    L ( θ ) = − E τ ∼ π θ [ ∑ t = 0 T r t log ⁡ π θ ( a t ∣ s t ) ] L(\theta) = -\mathbb{E}_{\tau \sim \pi_\theta} \left[ \sum_{t=0}^{T} r_t \log \pi_\theta(a_t \mid s_t) \right] L(θ)=Eτπθ[t=0Trtlogπθ(atst)]

    • 说明:这个损失函数的目的是最小化负的期望总奖励。通过优化该损失函数,可以最大化目标函数 J ( θ ) J(\theta) J(θ)。这里的 log ⁡ π θ ( a t ∣ s t ) \log \pi_\theta(a_t \mid s_t) logπθ(atst) 是对策略的对数概率,表示在状态 s t s_t st 下采取动作 a t a_t at 的可能性。
  2. 价值函数损失(对于Q-learning):
    L ( θ ) = E [ ( r t + γ max ⁡ a ′ Q ( s ′ , a ′ ; θ ) − Q ( s , a ; θ ) ) 2 ] L(\theta) = \mathbb{E} \left[ \left( r_t + \gamma \max_{a'} Q(s', a'; \theta) - Q(s, a; \theta) \right)^2 \right] L(θ)=E[(rt+γamaxQ(s,a;θ)Q(s,a;θ))2]

    • 说明:该损失函数用于最小化当前动作价值函数 Q ( s , a ; θ ) Q(s, a; \theta) Q(s,a;θ) 和目标价值 r t + γ max ⁡ a ′ Q ( s ′ , a ′ ; θ ) r_t + \gamma \max_{a'} Q(s', a'; \theta) rt+γmaxaQ(s,a;θ) 之间的均方误差。通过最小化该损失,更新网络参数 θ \theta θ 以更准确地预测价值。

细节总结

  • 目标函数:用于衡量当前策略的性能,指导优化过程。强化学习的目标是通过更新策略来最大化期望奖励。
  • 损失函数:是优化过程中实际最小化的函数,直接反映模型的学习效果。损失函数的设计直接影响学习的效率和效果。

这些公式是强化学习中策略优化和价值评估的核心,理解它们有助于深入掌握强化学习的理论基础和应用。

代码

import torch
import torch.nn as nn
import torch.optim as optim
import numpy as np# 环境假设
class SimpleEnv:def reset(self):return np.random.rand(4)  # 随机状态def step(self, action):next_state = np.random.rand(4)reward = np.random.rand()  # 随机奖励done = np.random.rand() > 0.9  # 随机结束return next_state, reward, done# 策略网络
class PolicyNetwork(nn.Module):def __init__(self):super(PolicyNetwork, self).__init__()self.fc = nn.Sequential(nn.Linear(4, 128),nn.ReLU(),nn.Linear(128, 2),  # 假设有两个动作)def forward(self, x):return torch.softmax(self.fc(x), dim=-1)# 计算折扣奖励
def compute_discounted_rewards(rewards, discount_factor=0.99):discounted_rewards = []cumulative_reward = 0for r in reversed(rewards):cumulative_reward = r + cumulative_reward * discount_factordiscounted_rewards.insert(0, cumulative_reward)return discounted_rewards# 训练函数
def train(env, policy_net, optimizer, episodes=1000):for episode in range(episodes):state = env.reset()rewards = []log_probs = []while True:state_tensor = torch.FloatTensor(state)probs = policy_net(state_tensor)action = np.random.choice(len(probs), p=probs.detach().numpy())log_prob = torch.log(probs[action])next_state, reward, done = env.step(action)log_probs.append(log_prob)rewards.append(reward)state = next_stateif done:break# 计算折扣奖励discounted_rewards = compute_discounted_rewards(rewards)# 更新策略optimizer.zero_grad()loss = -sum(log_prob * reward for log_prob, reward in zip(log_probs, discounted_rewards))loss.backward()optimizer.step()# 输出每个回合的总奖励total_reward = sum(rewards)print(f"Episode {episode + 1}, Total Reward: {total_reward:.2f}")# 测试函数
def test(env, policy_net, episodes=10):for episode in range(episodes):state = env.reset()total_reward = 0while True:state_tensor = torch.FloatTensor(state)with torch.no_grad():probs = policy_net(state_tensor)action = torch.argmax(probs).item()next_state, reward, done = env.step(action)total_reward += rewardstate = next_stateif done:breakprint(f"Test Episode {episode + 1}, Total Reward: {total_reward:.2f}")# 主程序
env = SimpleEnv()
policy_net = PolicyNetwork()
optimizer = optim.Adam(policy_net.parameters(), lr=0.01)train(env, policy_net, optimizer)
test(env, policy_net)

在这里插入图片描述

训练奖励图:显示每个训练回合的总奖励变化,帮助评估模型在训练过程中的学习效果。
测试奖励图:展示在测试回合中模型的总奖励,反映训练后的表现。

代码结构

  1. 环境(Environment)

    • SimpleEnv 类:模拟一个简单的环境,包含 resetstep 方法。
      • reset():初始化并返回一个随机状态。
      • step(action):根据所采取的动作返回下一个状态、奖励和是否结束标志。
      • 奖励和结束状态是随机生成的,模拟了一个非常简化的环境。
  2. 策略网络(Policy Network)

    • PolicyNetwork 类:定义一个神经网络,用于近似策略。
      • 使用全连接层,输入状态维度为 4(环境状态的维度),输出动作概率的维度为 2(假设有两个可能的动作)。
      • forward 方法通过 softmax 函数输出每个动作的概率。
  3. 折扣奖励计算

    • compute_discounted_rewards(rewards, discount_factor=0.99):计算每个时间步的折扣奖励。
      • 从后往前遍历奖励列表,使用折扣因子更新累计奖励,生成折扣奖励列表。
  4. 训练函数(Training Function)

    • train(env, policy_net, optimizer, episodes=1000):进行训练的主函数。
      • 循环执行指定的回合数:
        • 重置环境,初始化奖励和日志概率列表。
        • 在回合中循环,使用当前状态选择动作并记录日志概率和奖励。
        • 计算并更新策略网络的损失,使用反向传播更新参数。
        • 每个回合结束后打印总奖励,帮助监控训练进度。
  5. 测试函数(Testing Function)

    • test(env, policy_net, episodes=10):用于评估训练后模型表现的函数。
      • 重置环境并执行多个测试回合,选择最大概率的动作。
      • 累计并打印每个测试回合的总奖励,评估训练的效果。
  6. 主程序

    • 创建环境和策略网络实例,定义优化器(Adam)。
    • 调用训练函数进行训练,然后调用测试函数进行评估。

整体逻辑

  1. 环境设置:定义了一个非常简单的环境,主要用于演示如何应用策略梯度方法。实际应用中,可以替换为更复杂的环境,比如OpenAI的Gym库中的环境。

  2. 策略学习:使用神经网络近似策略,通过与环境的交互收集状态、动作、奖励,并更新网络参数,以优化策略。

  3. 输出和评估:通过在训练过程中的总奖励输出和测试过程中的评估,可以观察到模型的学习进展。

小结

这段代码是一个简单的强化学习示例,展示了如何使用策略梯度方法和PyTorch进行训练和测试。虽然环境和任务是简化的,但它提供了一个良好的基础,便于理解强化学习的核心概念和实现。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.rhkb.cn/news/436267.html

如若内容造成侵权/违法违规/事实不符,请联系长河编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

Linux操作系统中MongoDB

1、什么是MongoDB 1、非关系型数据库 NoSQL&#xff0c;泛指非关系型的数据库。随着互联网web2.0网站的兴起&#xff0c;传统的关系数据库在处理web2.0网站&#xff0c;特别是超大规模和高并发的SNS类型的web2.0纯动态网站已经显得力不从心&#xff0c;出现了很多难以克服的问…

sysbench 命令:跨平台的基准测试工具

一、命令简介 sysbench 是一个跨平台的基准测试工具&#xff0c;用于评估系统性能&#xff0c;包括 CPU、内存、文件 I/O、数据库等性能。 ‍ 比较同类测试工具 bench.sh 在上文 bench.sh&#xff1a;Linux 服务器基准测试中介绍了 bench.sh 一键测试脚本&#xff0c;它对…

曲线图异常波形检测系统源码分享

曲线图异常波形检测检测系统源码分享 [一条龙教学YOLOV8标注好的数据集一键训练_70全套改进创新点发刊_Web前端展示] 1.研究背景与意义 项目参考AAAI Association for the Advancement of Artificial Intelligence 项目来源AACV Association for the Advancement of Comput…

华为OD机试 - 最长元音子串的长度(Python/JS/C/C++ 2024 E卷 100分)

华为OD机试 2024E卷题库疯狂收录中&#xff0c;刷题点这里 专栏导读 本专栏收录于《华为OD机试真题&#xff08;Python/JS/C/C&#xff09;》。 刷的越多&#xff0c;抽中的概率越大&#xff0c;私信哪吒&#xff0c;备注华为OD&#xff0c;加入华为OD刷题交流群&#xff0c;…

Redis入门第三步:Redis事务处理

欢迎继续跟随《Redis新手指南&#xff1a;从入门到精通》专栏的步伐&#xff01;在本文中&#xff0c;我们将探讨Redis的事务处理机制。了解如何使用事务来保证一系列操作的原子性和一致性&#xff0c;这对于构建可靠的应用程序至关重要 1 什么是Redis事务&#x1f340; ​ R…

解锁数据宝藏:AI驱动搜索工具,让非结构化数据“说话

哈哈,说起这个 AI 搜索演示啊,那可真是个有意思的话题!非结构化数据,这家伙虽然难搞,但价值却是杠杠的。今天呢,咱就好好聊聊怎么借助 Fivetran 和 Milvus,快速搭建一个 AI 驱动的搜索工具,让企业能从那些乱七八糟的数据里淘到金子! 一、非结构化数据的挑战与机遇 首…

堆【数据结构C语言版】【 详解】

目录-笔记整理 一、思考二、堆概念与性质三、堆的构建、删除、添加1. 构建2. 删除3. 添加 四、复杂度分析4.1 时间复杂度4.2 空间复杂度 五、总结 一、思考 设计一种数据结构&#xff0c;来存放整数&#xff0c;要求三个接口&#xff1a; 1&#xff09;获取序列中的最值&#…

Thinkphp/Laravel旅游景区预约系统的设计与实现

目录 技术栈和环境说明具体实现截图设计思路关键技术课题的重点和难点&#xff1a;框架介绍数据访问方式PHP核心代码部分展示代码目录结构解析系统测试详细视频演示源码获取 技术栈和环境说明 采用PHP语言开发&#xff0c;开发环境为phpstudy 开发工具notepad并使用MYSQL数据库…

景联文科技入选《2024中国AI大模型产业图谱2.0版》数据集代表厂商

近日&#xff0c;大数据产业领域头部媒体数据猿携手上海大数据联盟联合发布了备受瞩目的《2024中国AI大模型产业图谱2.0版》。以大数据与AI为代表的智能技术为主要视角&#xff0c;聚焦全产业链&#xff0c;为业内提供更为专业直观的行业指导。 景联文科技凭借高质量数据集&…

基于大数据的学生体质健康信息系统

作者&#xff1a;计算机学姐 开发技术&#xff1a;SpringBoot、SSM、Vue、MySQL、JSP、ElementUI、Python、小程序等&#xff0c;“文末源码”。 专栏推荐&#xff1a;前后端分离项目源码、SpringBoot项目源码、Vue项目源码、SSM项目源码、微信小程序源码 精品专栏&#xff1a;…

Vue Mini基于 Vue 3 的小程序框架

新的小程序框架 https://vuemini.org/ Vue Mini 是一个基于 Vue 3 的小程序框架&#xff0c;它允许开发者利用 Vue 3 的强大功能来构建微信小程序。Vue Mini 的核心优势在于它的响应式系统和组合式 API&#xff0c;这些特性让开发者能够以一种更声明式、更高效的方式来编写和…

江科大笔记——新建工程

STM32的开发方式 目前STM32的开发方式主要有基于寄存器的方式、基于标准库的方式&#xff08;库函数的方式&#xff09;、基于HAL库的方式&#xff1a; 基于库函数的方式是使用ST官方提供的封装好的函数&#xff0c;通过调用这些函数来间接地配置寄存器。基于HAL库的方式可以…

【机器学习(七)】分类和回归任务-K-近邻 (KNN)算法-Sentosa_DSML社区版

文章目录 一、算法概念二、算法原理&#xff08;一&#xff09;K值选择&#xff08;二&#xff09;距离度量1、欧式距离2、曼哈顿距离3、闵可夫斯基距离 &#xff08;三&#xff09;决策规则1、分类决策规则2、回归决策规则 三、算法优缺点优点缺点 四、KNN分类任务实现对比&am…

【CKA】二、节点管理-设置节点不可用

2、节点管理-设置节点不可用 1. 考题内容&#xff1a; 2. 答题思路&#xff1a; 先设置节点不可用&#xff0c;然后驱逐节点上的pod 这道题就两条命令&#xff0c;直接背熟就行。 也可以查看帮助 kubectl cordon -h kubectl drain -h 参数详情&#xff1a; –delete-empty…

【COSMO-SkyMed系列的4颗卫星主要用途】

COSMO-SkyMed系列的4颗卫星主要用于提供一个多用途的对地观测平台&#xff0c;服务于民间、公共机构、军事和商业领域。以下是这4颗卫星的主要用途&#xff1a; 民防与环境风险管理&#xff1a; 卫星的高分辨率雷达图像可用于监测自然灾害&#xff0c;如地震、洪水、滑坡等&am…

【计算机网络】网络层详解

文章目录 一、引言二、IP 基础知识1、IP 地址2、路由3、IP报文4、IP报文的分片与重组 三、IP 属于面向无连接型四、IP协议相关技术1、DNS2、ICMP3、NAT技术4、DHCP 一、引言 TCP/IP的心脏是网络层。这一层主要由 IP 和 ICMP 两个协议组成。网络层的主要作用是“实现终端节点之…

Redis进阶篇 - 缓存穿透、缓存击穿、缓存雪崩问题及其解决方案

文章目录 1 文章概述2 缓存穿透2.1 什么是缓存穿透&#xff1f;2.2 缓存穿透的解决方法2.2.1 做好参数校验2.2.2 缓存无效Key2.2.3 使用布隆过滤器2.2.4 接口限流 3 缓存击穿3.1 什么是缓存击穿&#xff1f;3.2 缓存击穿的解决方法3.2.1 调整热点数据过期时间3.2.2 热点数据预热…

Postgresql怎么查询数据库中所有的表,odoo17数据库最依赖表整理

今天遇到了一个需求,需要梳理odoo中数据库表的分类,所以想要知道怎么查询当前数据库中所有的表,特此记录. 一个简单的SQL语句: select * from pg_tables;得到的结果如下: 显然这个有点杂乱,我们换一个SQL语句: select tablename from pg_tables where schemanamepublic不过…

软件测试学习笔记丨Mock的价值与实战

本文转自测试人社区&#xff0c;原文链接&#xff1a;https://ceshiren.com/t/topic/32331 一、Mock的价值与意义 1.1 简介 测试过程中&#xff0c;对于一些不容易构造或获取的对象&#xff0c;用一个虚拟的对象来替代它&#xff0c;达到相同的效果&#xff0c;这个虚拟的对象…

启动服务并登录MySQL9数据库

【图书推荐】《MySQL 9从入门到性能优化&#xff08;视频教学版&#xff09;》-CSDN博客 《MySQL 9从入门到性能优化&#xff08;视频教学版&#xff09;&#xff08;数据库技术丛书&#xff09;》(王英英)【摘要 书评 试读】- 京东图书 (jd.com) Windows平台下安装与配置MyS…