强化学习笔记之【SAC算法】

强化学习笔记之【SAC算法】


前言:

本文为强化学习笔记第三篇,第一篇讲的是Q-learning和DQN,第二篇DDPG,第三篇TD3

TD3比DDPG少了一个target_actor网络,其它地方有点小改动

CSDN主页:https://blog.csdn.net/rvdgdsva

博客园主页:https://www.cnblogs.com/hassle

博客园本文链接:https://www.cnblogs.com/hassle/p/18459320


文章目录

  • 强化学习笔记之【SAC算法】
      • 前言:
      • 一、SAC算法
      • 二、SAC算法Latex解释
      • 三、SAC五大网络和模块
        • 3.1 Actor 网络
        • 3.2 Critic1 和 Critic2 网络
        • 3.3 Target Critic1 和 Target Critic2 网络
        • 3.4 软更新模块
        • 3.5 总结

STAND ALONE COMPLEX = S . A . C

首先,我们需要明确,Q-learning算法发展成DQN算法,DQN算法发展成为DDPG算法,而DDPG算法发展成TD3算法,TD3算法发展成SAC算法

Soft Actor-Critic (SAC) 是一种基于策略梯度的深度强化学习算法,它具有最大化奖励与最大化熵(探索性)的双重目标。SAC 通过引入熵正则项,使策略在决策时具有更大的随机性,从而提高探索能力。

一、SAC算法

OK,先用伪代码让你们感受一下SAC算法

# 定义 SAC 超参数
alpha = 0.2               # 熵正则项系数
gamma = 0.99              # 折扣因子
tau = 0.005               # 目标网络软更新参数
lr = 3e-4                 # 学习率# 初始化 Actor、Critic、Target Critic 网络和优化器
actor = ActorNetwork()                      # 策略网络 π(s)
critic1 = CriticNetwork()                   # 第一个 Q 网络 Q1(s, a)
critic2 = CriticNetwork()                   # 第二个 Q 网络 Q2(s, a)
target_critic1 = CriticNetwork()            # 目标 Q 网络 1
target_critic2 = CriticNetwork()            # 目标 Q 网络 2# 将目标 Q 网络的参数设置为与 Critic 网络相同
target_critic1.load_state_dict(critic1.state_dict())
target_critic2.load_state_dict(critic2.state_dict())# 初始化优化器
actor_optimizer = torch.optim.Adam(actor.parameters(), lr=lr)
critic1_optimizer = torch.optim.Adam(critic1.parameters(), lr=lr)
critic2_optimizer = torch.optim.Adam(critic2.parameters(), lr=lr)# 经验回放池(Replay Buffer)
replay_buffer = ReplayBuffer()# SAC 训练循环
for each iteration:# Step 1: 从 Replay Buffer 中采样一个批次 (state, action, reward, next_state)batch = replay_buffer.sample()state, action, reward, next_state, done = batch# Step 2: 计算目标 Q 值 (y)with torch.no_grad():# 从 Actor 网络中获取 next_state 的下一个动作next_action, next_log_prob = actor.sample(next_state)# 目标 Q 值的计算:使用目标 Q 网络的最小值 + 熵项target_q1_value = target_critic1(next_state, next_action)target_q2_value = target_critic2(next_state, next_action)min_target_q_value = torch.min(target_q1_value, target_q2_value)# 目标 Q 值 y = r + γ * (最小目标 Q 值 - α * next_log_prob)target_q_value = reward + gamma * (1 - done) * (min_target_q_value - alpha * next_log_prob)# Step 3: 更新 Critic 网络# Critic 1 损失current_q1_value = critic1(state, action)critic1_loss = F.mse_loss(current_q1_value, target_q_value)# Critic 2 损失current_q2_value = critic2(state, action)critic2_loss = F.mse_loss(current_q2_value, target_q_value)# 反向传播并更新 Critic 网络参数critic1_optimizer.zero_grad()critic1_loss.backward()critic1_optimizer.step()critic2_optimizer.zero_grad()critic2_loss.backward()critic2_optimizer.step()# Step 4: 更新 Actor 网络# 通过 Actor 网络生成新的动作及其 log 概率new_action, log_prob = actor.sample(state)# 计算 Actor 的目标损失:L = α * log_prob - Q1(s, π(s))q1_value = critic1(state, new_action)actor_loss = (alpha * log_prob - q1_value).mean()# 反向传播并更新 Actor 网络参数actor_optimizer.zero_grad()actor_loss.backward()actor_optimizer.step()# Step 5: 软更新目标 Q 网络参数with torch.no_grad():for param, target_param in zip(critic1.parameters(), target_critic1.parameters()):target_param.data.copy_(tau * param.data + (1 - tau) * target_param.data)for param, target_param in zip(critic2.parameters(), target_critic2.parameters()):target_param.data.copy_(tau * param.data + (1 - tau) * target_param.data)

二、SAC算法Latex解释

1、初始化 Actor、Critic1、Critic2、TargetCritic1 、TargetCritic2 网络
2、Buffer中采样 (state, action, reward, next_state)

3、Actor 输入 next_state 对应输出 next_action 和 next_log_prob
4、Actor 输入 state 对应输出 new_action 和 log_prob
5、Critic1 和 Critic2 分别输入next_state 和 next_action 取其中较小输出经熵正则计算得 target_q_value

6、使用 MSE_loss(Critic1(state, action), target_q_value) 更新 Critic1
7、使用 MSE_loss(Critic2(state, action), target_q_value) 更新 Critic2
8、使用 (alpha * log_prob - critic1(state, new_action)).mean() 更新 Actor


三、SAC五大网络和模块

SAC 算法 中,Actor、Critic1、Critic2、Target Critic1 和 Target Critic2 网络是核心模块,它们分别用于输出动作、评估状态-动作对的价值,并通过目标网络进行稳定的更新。

3.1 Actor 网络

Actor 网络用于在给定状态下输出一个高斯分布的均值和标准差(即策略)。它是通过神经网络近似的随机策略。用于选择动作。

import torch
import torch.nn as nnclass ActorNetwork(nn.Module):def __init__(self, state_dim, action_dim):super(ActorNetwork, self).__init__()self.fc1 = nn.Linear(state_dim, 256)self.fc2 = nn.Linear(256, 256)self.mean_layer = nn.Linear(256, action_dim)  # 输出动作的均值self.log_std_layer = nn.Linear(256, action_dim)  # 输出动作的log标准差def forward(self, state):x = torch.relu(self.fc1(state))x = torch.relu(self.fc2(x))mean = self.mean_layer(x)  # 输出动作均值log_std = self.log_std_layer(x)  # 输出 log 标准差log_std = torch.clamp(log_std, min=-20, max=2)  # 限制标准差范围return mean, log_stddef sample(self, state):mean, log_std = self.forward(state)std = torch.exp(log_std)  # 将 log 标准差转为标准差normal = torch.distributions.Normal(mean, std)action = normal.rsample()  # 通过重参数化技巧进行采样log_prob = normal.log_prob(action).sum(-1)  # 计算 log 概率return action, log_prob

3.2 Critic1 和 Critic2 网络

Critic 网络用于计算状态-动作对的 Q 值,SAC 使用两个 Critic 网络(Critic1 和 Critic2)来缓解 Q 值的过估计问题。

class CriticNetwork(nn.Module):def __init__(self, state_dim, action_dim):super(CriticNetwork, self).__init__()self.fc1 = nn.Linear(state_dim + action_dim, 256)self.fc2 = nn.Linear(256, 256)self.q_value_layer = nn.Linear(256, 1)  # 输出 Q 值def forward(self, state, action):x = torch.cat([state, action], dim=-1)  # 将 state 和 action 作为输入x = torch.relu(self.fc1(x))x = torch.relu(self.fc2(x))q_value = self.q_value_layer(x)  # 输出 Q 值return q_value

3.3 Target Critic1 和 Target Critic2 网络

Target Critic 网络的结构与 Critic 网络相同,用于稳定 Q 值更新。它们通过软更新(即在每次训练后慢慢接近 Critic 网络的参数)来保持训练的稳定性。

class TargetCriticNetwork(nn.Module):def __init__(self, state_dim, action_dim):super(TargetCriticNetwork, self).__init__()self.fc1 = nn.Linear(state_dim + action_dim, 256)self.fc2 = nn.Linear(256, 256)self.q_value_layer = nn.Linear(256, 1)  # 输出 Q 值def forward(self, state, action):x = torch.cat([state, action], dim=-1)  # 将 state 和 action 作为输入x = torch.relu(self.fc1(x))x = torch.relu(self.fc2(x))q_value = self.q_value_layer(x)  # 输出 Q 值return q_value

3.4 软更新模块

在 SAC 中,目标网络会通过软更新逐渐逼近 Critic 网络的参数。每次更新后,目标网络参数会按照 ττ 的比例向 Critic 网络的参数靠拢。

def soft_update(critic, target_critic, tau=0.005):for param, target_param in zip(critic.parameters(), target_critic.parameters()):target_param.data.copy_(tau * param.data + (1 - tau) * target_param.data)

3.5 总结
  1. 初始化网络和参数:
    • Actor 网络:用于选择动作。
    • Critic 1 和 Critic 2 网络:用于估计 Q 值。
    • Target Critic 1 和 Target Critic 2:与 Critic 网络架构相同,用于生成更稳定的目标 Q 值。
  2. 目标 Q 值计算:
    • 使用目标网络计算下一状态下的 Q 值。
    • 取两个 Q 网络输出的最小值,防止 Q 值的过估计。
    • 引入熵正则项,计算公式: y = r + γ ⋅ min ⁡ ( Q 1 , Q 2 ) − α ⋅ log ⁡ π ( a ∣ s ) y=r+\gamma\cdot\min(Q_1,Q_2)-\alpha\cdot\log\pi(a|s) y=r+γmin(Q1,Q2)αlogπ(as)
  3. 更新 Critic 网络:
    • 最小化目标 Q 值与当前 Q 值的均方误差 (MSE)。
  4. 更新 Actor 网络:
    • 最大化目标损失: L = α ⋅ log ⁡ π ( a ∣ s ) − Q 1 ( s , π ( s ) ) L=\alpha\cdot\log\pi(a|s)-Q_1(s,\pi(s)) L=αlogπ(as)Q1(s,π(s)),即在保证探索的情况下选择高价值动作。
  5. 软更新目标网络:
    • 软更新目标 Q 网络参数,使得目标网络参数缓慢向当前网络靠近,避免振荡。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.rhkb.cn/news/444815.html

如若内容造成侵权/违法违规/事实不符,请联系长河编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

思迈特:在AI时代韧性增长的流量密码

作者 | 曾响铃 文 | 响铃说 “超级人工智能将在‘几千天内’降临。” 最近,OpenAI 公司 CEO 山姆奥特曼在社交媒体罕见发表长文,预言了这一点。之前,很多专家预测超级人工智能将在五年内到来,奥特曼的预期,可能让这…

图论day57|建造最大岛屿(卡码网)【截至目前,图论的最高难度】

图论day57|建造最大岛屿(卡码网)【截至目前所做的题中,图论的最高难度】 思维导图分析 104.建造最大岛屿(卡码网)【截至目前所做的题中,图论的最高难度】 思维导图分析 104.建造最大岛屿(卡码网…

i18n多语言项目批量翻译工具(支持84种语言)

这里写自定义目录标题 打开‘i18n翻译助手’小程序快捷访问 打开‘i18n翻译助手’小程序 1.将需要翻译的json文件复制到输入框(建议一次不要翻译过多,测试1000条以内没什么问题) 2.等待翻译 3.翻译完成,复制结果 快捷访问

从容应对DDoS攻击:小网站的防守之战

前几天收到云服务商短信,服务器正在遭受DDoS攻击 说实话,我的网站只是一个小型站点,平时访问量并不高,没想到会成为攻击的目标。当我看到这次DDoS攻击的通知时,我其实既惊讶又有点小小的“荣幸”,毕竟我的小…

火山引擎边缘智能×扣子,拓展AI Agent物理边界

9月21日, 火山引擎边缘智能扣子技术沙龙在上海圆满落地,沙龙以“探索端智能,加速大模型应用”为主题,边缘智能、扣子、地瓜机器人以及上海交大等多位重磅嘉宾出席,分享 AI 最新趋势及端侧大模型最新探索与应用实践。 …

Java项目-----图形验证码登陆实现

原理: 验证码在前端显示,但是是在后端生成, 将生成的验证码存入redis,待登录时,前端提交验证码,与后端生成的验证码比较. 详细解释: 图形验证码的原理(如下图代码).前端发起获取验证码的请求后, 1 后端接收请求,生成一个键key(随机的键) 然后生成一个验证码作为map的valu…

JAVA接入GPT开发

Spring AI Alibaba:Java开发者的GPT集成新标准 目前,像OpenAI等GPT服务提供商主要提供HTTP接口,这导致大部分Java开发者在接入GPT时缺乏标准化的方法。为解决这一问题,Spring团队推出了Spring AI Alibaba,它作为一套标…

基于Java的可携宠物酒店管理系统的设计与实现(论文+源码)_kaic

摘 要 随着社会经济的不断发‎‏展,现如今出行并住酒店的人越来越多,与之而来的是酒店行业的工作量日益增加,酒店的管理效率亟待提升。此外很多人出门旅游时会有携带宠物的情况,但是现如今酒店对宠物的限制,导致许多…

Java学习-JVM

目录 1. 基本常识 1.1 JVM是什么 1.2 JVM架构图 1.3 Java技术体系 1.4 Java与JVM的关系 2. 类加载系统 2.1 类加载器种类 2.2 执行顺序 2.3 类加载四个时机 2.4 生命周期 2.5 类加载途径 2.6 双亲委派模型 3. 运行时数据区 3.1 运行时数据区构成 3.2 堆 3.3 栈…

【RabbitMQ高级——过期时间TTL+死信队列】

1. 过期时间TTL概述 过期时间TTL表示可以对消息设置预期的时间,在这个时间内都可以被消费者接收获取;过了之后消息将自动被删除。RabbitMQ可以对消息和队列设置TTL。 目前有两种方法可以设置。 第一种方法是通过队列属性设置,队列中所有消…

基于Springboot的宠物咖啡馆平台的设计与实现(源码+定制+参考)

博主介绍: ✌我是阿龙,一名专注于Java技术领域的程序员,全网拥有10W粉丝。作为CSDN特邀作者、博客专家、新星计划导师,我在计算机毕业设计开发方面积累了丰富的经验。同时,我也是掘金、华为云、阿里云、InfoQ等平台…

【操作系统】四、文件管理:1.文件系统基础(文件属性、文件逻辑结构、文件物理结构、文件存储管理、文件目录、基本操作、文件共享、文件保护)

文件管理 文章目录 文件管理八、文件系统基础1.文件的属性2.文件的逻辑结构2.1顺序文件2.2索引文件2.3索引顺序文件2.4多级索引顺序文件 3.目录文件❗3.1文件控制块FCB3.1.1对目录进行的操作 3.2目录结构3.2.1单级目录结构3.2.2两级目录结构3.2.3多级目录结构(树形目…

【大模型部署】本地运行自己的大模型--ollama

ollama简介 ollama是一款开源的、轻量级的框架,它可以快速在本地构建及运行大模型,尤其是一些目前最新开源的模型,如 Llama 3, Mistral, Gemma等。 官网上有大量已经开源的模型,部分针对性微调过的模型也可以选择到,…

Qt源码-Qt多媒体音频框架

Qt 多媒体音频框架 一、概述二、音频设计1. ALSA 基础2. Qt 音频类1. 接口实现2. alsa 插件实现 一、概述 环境详细Qt版本Qt 5.15操作系统Deepin v23代码工具Visual Code源码https://github.com/qt/qtmultimedia/tree/5.15 这里记录一下在Linux下Qt 的 Qt Multimedia 模块的设…

Javascript笔试题目(一)

1.JS查找文章中出现频率最高的单词? 要在JavaScript中查找文章中出现频率最高的单词,你可以按照以下步骤进行操作: 将文章转换为小写:这可以确保单词的比较是大小写不敏感的。移除标点符号:标点符号会干扰单词的计数。将文章拆…

基于Web的停车场管理系统(论文+源码)_kaic

摘要 我国经济的发展愈发迅速,车辆也随之增加的难以想象,因此车位的治理也越来越繁杂,为了方便停车位相关信息的管理,设计开发一个合理的停车位管理系统尤为重要。因而,具有信息方便读取和操作简便的停车位管理系统的设…

在启智AI平台实践ChatGLM4-9B聊天机器人@MindSpore

前段时间在昇思训练营发现一个好东西,就是昇思AI实验室:昇思大模型平台 在官方提供的jupyter AI编程实践样例中,发现了这个项目:ChatGLM4-9B实践样例 GLM-4-9B是智谱 AI 推出的最新一代预训练模型 GLM-4 系列中的开源版本。 在语…

两个数相加(c语言)

1./给定一个整数数组 nums 和一个整数目标值 target,请你在该数组中找出 和为目标值 target // 的那 两个 整数,并返回它们的数组下标。 //你可以假设每种输入只会对应一个答案,并且你不能使用两次相同的元素。你可以按任意顺序返回答案。 /…

Windows电脑本地安装AI文生音乐软件结合内网穿透远程访问制作

文章目录 前言1. 本地部署2. 使用方法介绍3. 内网穿透工具下载安装4. 配置公网地址5. 配置固定公网地址 前言 今天和大家分享一下在Windows系统电脑上本地快速部署一个文字生成音乐的AI创作服务MusicGPT,并结合cpolar内网穿透工具实现随时随地远程访问使用进行AI音…

TextView把其它控件挤出屏幕的处理办法

1.如果TextView后面的控件是紧挨着TextView的&#xff0c;可以给TextView添加maxWidth限制其最大长度 上有问题的布局代码 <?xml version"1.0" encoding"utf-8"?> <layout xmlns:android"http://schemas.android.com/apk/res/android&qu…