策略梯度 (Policy Gradient):直接优化策略的强化学习方法

策略梯度 (Policy Gradient) 是强化学习中的一种方法,用于优化智能体的策略,使其在给定环境中表现得更好。与值函数方法(如 Q-learning)不同,策略梯度方法直接对策略进行优化,而不是通过学习一个值函数来间接估计最优策略。

核心思想:

在策略梯度方法中,智能体的策略是一个参数化的函数(通常是神经网络),通过梯度上升法来优化该策略的参数,使得智能体在与环境互动时获得最大的预期奖励。该方法通过计算策略相对于策略参数的梯度来更新策略参数,从而改善智能体的行为。

实现方式:

  1. 收集经验: 智能体与环境互动,收集状态-动作对以及相应的奖励。
  2. 计算梯度: 基于当前策略和收集到的经验,计算梯度。
  3. 更新策略: 使用计算出的梯度更新策略参数。

优点:

  • 可以直接优化策略,适用于连续动作空间。
  • 不依赖于环境的价值函数,适用于部分可观测或高维的状态空间。

缺点:

  • 策略梯度的估计通常具有较高的方差,需要更多的样本来获得稳定的结果。
  • 收敛速度较慢,可能需要更多的计算资源。

简单例子:

import torch
import torch.nn as nn
import torch.optim as optim
import numpy as np
import matplotlib.pyplot as plt# 1D迷宫环境,目标是从位置0移动到位置10
class SimpleMazeEnv:def __init__(self):self.state = 0  # 初始位置self.target = 10  # 目标位置self.max_steps = 20  # 最大步数def reset(self):self.state = 0return self.statedef step(self, action):if action == 0:  # 向左移动self.state = max(0, self.state - 1)elif action == 1:  # 向右移动self.state = min(self.target, self.state + 1)# 计算奖励,靠近目标位置时奖励更高reward = -abs(self.state - self.target)  # 离目标越远奖励越低done = (self.state == self.target)  # 到达目标时结束return self.state, reward, done# 策略网络
class PolicyNetwork(nn.Module):def __init__(self, input_dim, output_dim):super(PolicyNetwork, self).__init__()self.fc1 = nn.Linear(input_dim, 128)self.fc2 = nn.Linear(128, 128)self.fc3 = nn.Linear(128, output_dim)self.softmax = nn.Softmax(dim=-1)def forward(self, x):x = torch.relu(self.fc1(x))x = torch.relu(self.fc2(x))x = self.fc3(x)return self.softmax(x)# 策略梯度算法(REINFORCE)
def reinforce(env, policy, optimizer, episodes=1000, gamma=0.99):episode_rewards = []best_reward = -float('inf')best_path = []for episode in range(episodes):state = env.reset()state = torch.tensor([state], dtype=torch.float32)done = Falserewards = []log_probs = []path = []  # 记录当前回合的路径while not done:# 选择动作action_probs = policy(state)dist = torch.distributions.Categorical(action_probs)action = dist.sample()# 执行动作并观察结果next_state, reward, done = env.step(action.item())next_state = torch.tensor([next_state], dtype=torch.float32)# 保存奖励和动作的log概率rewards.append(reward)log_probs.append(dist.log_prob(action))path.append(state.item())  # 记录当前位置state = next_state# 计算回报returns = []G = 0for r in reversed(rewards):G = r + gamma * Greturns.insert(0, G)# 计算损失并更新模型returns = torch.tensor(returns, dtype=torch.float32)log_probs = torch.stack(log_probs)loss = -torch.sum(log_probs * returns)optimizer.zero_grad()loss.backward()optimizer.step()total_reward = sum(rewards)episode_rewards.append(total_reward)if total_reward > best_reward:best_reward = total_rewardbest_path = pathif (episode + 1) % 100 == 0:print(f"Episode {episode + 1}, Total Reward: {total_reward}, Best Reward: {best_reward}")return episode_rewards, best_path# 初始化环境和模型
env = SimpleMazeEnv()
input_dim = 1  # 状态是一个标量
output_dim = 2  # 动作是向左或向右
policy = PolicyNetwork(input_dim, output_dim)
optimizer = optim.Adam(policy.parameters(), lr=0.001)# 训练模型
episode_rewards, best_path = reinforce(env, policy, optimizer, episodes=1000)# 可视化训练结果
plt.figure(figsize=(12, 6))# 绘制奖励曲线
plt.subplot(1, 2, 1)
plt.plot(episode_rewards)
plt.xlabel('Episode')
plt.ylabel('Total Reward')
plt.title('Training Progress')# 绘制最优路径图
plt.subplot(1, 2, 2)
plt.plot(best_path, marker='o', markersize=5, label="Best Path")
for i, coord in enumerate(best_path):plt.text(i, coord, f"({i}, {coord})", fontsize=8)  # 显示坐标
plt.xlabel('Steps')
plt.ylabel('State')
plt.title('Best Path Taken')
plt.legend()plt.tight_layout()
plt.show()
  1. 环境SimpleMazeEnv是一个非常简单的1D迷宫环境,智能体的目标是从位置0移动到目标位置10。每步,智能体可以选择向左或向右移动。
  2. 策略网络PolicyNetwork是一个简单的神经网络,输出的是两个动作的概率(向左和向右)。
  3. 训练过程:采用策略梯度算法(REINFORCE),在每一轮训练中,智能体根据当前策略选择动作,通过累积奖励(回报)来更新策略网络。
  4. 奖励:智能体的奖励是与目标位置的距离成反比,离目标越近奖励越高。

预期效果:

  • 训练过程:每个回合的奖励会逐渐增加,智能体会逐步学习到正确的动作。
  • 可视化:我们会看到训练过程中每个回合的奖励曲线,以及最优路径(即智能体最终到达目标位置时的移动轨迹)。

运行后的图:

  • 左图:训练过程中的奖励变化。
  • 右图:最优路径的轨迹图,标记了每一步的位置。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.rhkb.cn/news/10007.html

如若内容造成侵权/违法违规/事实不符,请联系长河编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

新版231普通阿里滑块 自动化和逆向实现 分析

声明: 本文章中所有内容仅供学习交流使用,不用于其他任何目的,抓包内容、敏感网址、数据接口等均已做脱敏处理,严禁用于商业用途和非法用途,否则由此产生的一切后果均与作者无关! 逆向过程 补环境逆向 部分补环境 …

探索AI(chatgpt、文心一言、kimi等)提示词的奥秘

大家好,我是老六哥,我正在共享使用AI提高工作效率的技巧。欢迎关注我,共同提高使用AI的技能,让AI成功你的个人助理。 "AI提示词究竟是什么?" 这是许多初学者在接触AI时的共同疑问。 "我阅读了大量关于…

商密测评题库详解:商用密码应用安全性评估从业人员考核题库详细解析(9)

1. 申请商用密码测评机构需提交材料考点 根据《商用密码应用安全性测评机构管理办法(试行)》,申请成为商用密码应用安全性测评机构的单位应当提交的材料不包括( )。 A. 从事与普通密码相关工作情况的说明 B. 开展测评工作所需的软硬件及其他服务保障设施配备情况 C. 管…

Flink中的时间和窗口

在批处理统计中,我们可以等待一批数据都到齐后,统一处理。但是在实时处理统计中,我们是来一条就得处理一条,那么我们怎么统计最近一段时间内的数据呢?引入“窗口”。 所谓的“窗口”,一般就是划定的一段时…

Linux 进程概念

目录 一、前言 二、概念实例,正在执行的程序等 三、描述进程-PCB 四、组织进程 五、查看进程 ​编辑六、通过系统调用获取进程标示符 七、进程切换和上下文数据 1.进程切换 2.上下文数据 一、前言 在Linux中,每个执行的程序叫做进程&#xff…

allegro修改封闭图形线宽

说在前面 我们先把最优解说在前面,然后后面再说如果当时不熟悉软件的时候为了挖孔是用了shapes该怎么修改回来。 挖空最方便的方式是在cutout层画一个圆弧,下面开始图解,先add一个圆弧 z 最好是在画的时候就选择好层,如果忘记了后续再换回去也行,但好像软件有bug,此处并…

使用openwrt搭建ipsec隧道

背景:最近同事遇到了个ipsec问题,做的ipsec特性,ftp下载ipv6性能只有100kb, 正面定位该问题也蛮久了,项目没有用openwrt, 不过用了开源组件strongswan, 加密算法这些也是内核自带的,想着开源的不太可能有问题&#xff…

Day29(补)-【AI思考】-精准突围策略——从“时间贫困“到“效率自由“的逆袭方案

文章目录 精准突围策略——从"时间贫困"到"效率自由"的逆袭方案**第一步:目标熵减工程(建立四维坐标)** 与其他学习方法的结合**第二步:清华方法本土化移植** 与其他工具对比**~~第三步:游戏化改造…

手撕Diffusion系列 - 第十一期 - lora微调 - 基于Stable Diffusion(代码)

手撕Diffusion系列 - 第十一期 - lora微调 - 基于Stable Diffusion(代码) 目录 手撕Diffusion系列 - 第十一期 - lora微调 - 基于Stable Diffusion(代码)Stable Diffusion 原理图Stable Diffusion的原理解释Stable Diffusion 和Di…

vscode+WSL2(ubuntu22.04)+pytorch+conda+cuda+cudnn安装系列

最近在家过年闲的没事,于是研究起深度学习开发工具链的配置和安装,之前欲与天公试比高,尝试在win上用vscodecuda11.6vs2019的cl编译器搭建cuda c编程环境,最后惨败,沦为笑柄,痛定思痛,这次直接和…

【ESP32】ESP-IDF开发 | WiFi开发 | TCP传输控制协议 + TCP服务器和客户端例程

1. 简介 TCP(Transmission Control Protocol),全称传输控制协议。它的特点有以下几点:面向连接,每一个TCP连接只能是点对点的(一对一);提供可靠交付服务;提供全双工通信&…

AI时序预测: iTransformer算法代码深度解析

在之前的文章中,我对iTransformer的Paper进行了详细解析,具体文章如下: 文章链接:深度解析iTransformer:维度倒置与高效注意力机制的结合 今天,我将对iTransformer代码进行解析。回顾Paper,我…

某盾Blackbox参数参数逆向

以前叫同盾,现在改名了,叫小盾安全,好像不做验证码了

docker中运行的MySQL怎么修改密码

1,进入MySQL容器 docker exec -it 容器名 bash 我运行了 docker ps命令查看。正在运行的容器名称。可以看到MySQL的我起名为db docker exec -it db bash 这样就成功的进入到容器中了。 2,登录MySQL中 mysql -u 用户名 -p 回车 密码 mysql -u root -p roo…

春节期间,景区和酒店如何合理用工?

春节期间,景区和酒店如何合理用工? 春节期间,旅游市场将迎来高峰期。景区与酒店,作为旅游产业链中的两大核心环节,承载着无数游客的欢乐与期待。然而,也隐藏着用工管理的巨大挑战。如何合理安排人力资源&a…

初始化mysql报错cannot open shared object file: No such file or directory

报错展示 我在初始化msyql的时候报错:mysqld: error while loading shared libraries: libaio.so.1: cannot open shared object file: No such file or directory 解读: libaio包的作用是为了支持同步I/O。对于数据库之类的系统特别重要,因此…

C语言------数组从入门到精通

1.一维数组 目标:通过思维导图了解学习一维数组的核心知识点: 1.1定义 使用 类型名 数组名[数组长度]; 定义数组。 // 示例: int arr[5]; 1.2一维数组初始化 数组的初始化可以分为静态初始化和动态初始化两种方式。 它们的主要区别在于初始化的时机和内存分配的方…

Docker/K8S

文章目录 项目地址一、Docker1.1 创建一个Node服务image1.2 volume1.3 网络1.4 docker compose 二、K8S2.1 集群组成2.2 Pod1. 如何使用Pod(1) 运行一个pod(2) 运行多个pod 2.3 pod的生命周期2.4 pod中的容器1. 容器的生命周期2. 生命周期的回调3. 容器重启策略4. 自定义容器启…

【开源免费】基于SpringBoot+Vue.JS公交线路查询系统(JAVA毕业设计)

本文项目编号 T 164 ,文末自助获取源码 \color{red}{T164,文末自助获取源码} T164,文末自助获取源码 目录 一、系统介绍二、数据库设计三、配套教程3.1 启动教程3.2 讲解视频3.3 二次开发教程 四、功能截图五、文案资料5.1 选题背景5.2 国内…

< OS 有关 > Android 手机 SSH 客户端 app: connectBot

connectBot 开源且功能齐全的SSH客户端,界面简洁,支持证书密钥。 下载量超 500万 方便在 Android 手机上,连接 SSH 服务器,去运行命令。 Fail2ban 12小时内抓获的 IP ~ ~ ~ ~ rootjpn:~# sudo fail2ban-client status sshd Status for the jail: sshd …