pytorch实现长短期记忆网络 (LSTM)

 人工智能例子汇总:AI常见的算法和例子-CSDN博客 

LSTM 通过 记忆单元(cell)三个门控机制(遗忘门、输入门、输出门)来控制信息流:

 记忆单元(Cell State)

  • 负责存储长期信息,并通过门控机制决定保留或丢弃信息。

 遗忘门(Forget Gate, ftf_tft​)

 输入门(Input Gate, iti_tit​)

 输出门(Output Gate, oto_tot​)

特性

传统 RNNLSTM
记忆能力短期记忆长短期记忆
计算复杂度
解决梯度消失
适用场景短序列数据长序列数据

LSTM 应用场景

  • 自然语言处理(NLP):文本生成、情感分析、机器翻译
  • 时间序列预测:股票预测、天气预报、传感器数据分析
  • 语音识别:自动字幕生成、语音转文字(ASR)
  • 机器人与控制系统:智能体决策、自动驾驶

例子:

下面例子实现了一个 基于 LSTM 的强化学习智能体,在 1D 网格环境 里移动,并找到最优路径。
最终,我们 绘制 5 条测试路径,并高亮显示最佳路径(红色)

import torch
import torch.nn as nn
import torch.optim as optim
import matplotlib.pyplot as plt# ========== 1. 定义 LSTM 策略网络 ==========
class LSTMPolicy(nn.Module):def __init__(self, input_size, hidden_size, output_size, num_layers=1):super(LSTMPolicy, self).__init__()self.hidden_size = hidden_sizeself.num_layers = num_layersself.lstm = nn.LSTM(input_size, hidden_size, num_layers, batch_first=True)self.fc = nn.Linear(hidden_size, output_size)self.softmax = nn.Softmax(dim=-1)def forward(self, x, hidden_state):batch_size = x.size(0)# 确保 hidden_state 维度正确if hidden_state[0].dim() == 2:hidden_state = (hidden_state[0].unsqueeze(1).repeat(1, batch_size, 1),hidden_state[1].unsqueeze(1).repeat(1, batch_size, 1))out, hidden_state = self.lstm(x, hidden_state)out = self.fc(out[:, -1, :])  # 取最后时间步的输出action_prob = self.softmax(out)  # 归一化输出,作为策略return action_prob, hidden_statedef init_hidden(self, batch_size=1):return (torch.zeros(self.num_layers, batch_size, self.hidden_size),torch.zeros(self.num_layers, batch_size, self.hidden_size))# ========== 2. 创建网格环境 ==========
class GridWorld:def __init__(self, grid_size=10, goal_position=9):self.grid_size = grid_sizeself.goal_position = goal_positionself.reset()def reset(self):self.position = 0return self.positiondef step(self, action):if action == 0:self.position = max(0, self.position - 1)elif action == 1:self.position = min(self.grid_size - 1, self.position + 1)reward = 1 if self.position == self.goal_position else -0.1done = self.position == self.goal_positionreturn self.position, reward, done# ========== 3. 训练智能体 ==========
def train(num_episodes=500, max_steps=50):env = GridWorld()input_size = 1hidden_size = 64output_size = 2num_layers = 1policy = LSTMPolicy(input_size, hidden_size, output_size, num_layers)optimizer = optim.Adam(policy.parameters(), lr=0.01)gamma = 0.99for episode in range(num_episodes):state = torch.tensor([[env.reset()]], dtype=torch.float32).unsqueeze(0)  # (1, 1, input_size)hidden_state = policy.init_hidden(batch_size=1)log_probs = []rewards = []for step in range(max_steps):action_probs, hidden_state = policy(state, hidden_state)action = torch.multinomial(action_probs, 1).item()log_prob = torch.log(action_probs.squeeze(0)[action])log_probs.append(log_prob)next_state, reward, done = env.step(action)rewards.append(reward)if done:breakstate = torch.tensor([[next_state]], dtype=torch.float32).unsqueeze(0)# 计算回报并更新策略returns = []R = 0for r in reversed(rewards):R = r + gamma * Rreturns.insert(0, R)returns = torch.tensor(returns, dtype=torch.float32)returns = (returns - returns.mean()) / (returns.std() + 1e-9)loss = sum([-log_prob * R for log_prob, R in zip(log_probs, returns)])optimizer.zero_grad()loss.backward()optimizer.step()if (episode + 1) % 50 == 0:print(f"Episode {episode + 1}/{num_episodes}, Total Reward: {sum(rewards)}")torch.save(policy.state_dict(), "policy.pth")# 训练智能体
train(500)# ========== 4. 测试智能体并绘制最佳路径 ==========
def test(num_episodes=5):env = GridWorld()input_size = 1hidden_size = 64output_size = 2num_layers = 1policy = LSTMPolicy(input_size, hidden_size, output_size, num_layers)policy.load_state_dict(torch.load("policy.pth"))plt.figure(figsize=(10, 5))best_path = Nonebest_steps = float('inf')for episode in range(num_episodes):state = torch.tensor([[env.reset()]], dtype=torch.float32).unsqueeze(0)  # (1, 1, input_size)hidden_state = policy.init_hidden(batch_size=1)positions = [env.position]  # 记录位置变化while True:action_probs, hidden_state = policy(state, hidden_state)action = torch.argmax(action_probs, dim=-1).item()next_state, reward, done = env.step(action)positions.append(next_state)if done:breakstate = torch.tensor([[next_state]], dtype=torch.float32).unsqueeze(0)# 记录最佳路径(最短步数)if len(positions) < best_steps:best_steps = len(positions)best_path = positions# 绘制普通路径(蓝色)plt.plot(range(len(positions)), positions, marker='o', linestyle='-', color='blue', alpha=0.6,label=f'Episode {episode + 1}' if episode == 0 else "")# 绘制最佳路径(红色)if best_path:plt.plot(range(len(best_path)), best_path, marker='o', linestyle='-', color='red', linewidth=2,label="Best Path")# 打印最佳路径print(f"Best Path (steps={best_steps}): {best_path}")plt.xlabel("Time Steps")plt.ylabel("Agent Position")plt.title("Agent's Movement Path (Best Path in Red)")plt.legend()plt.grid(True)plt.show()# 测试并绘制智能体移动路径
test(5)

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.rhkb.cn/news/12302.html

如若内容造成侵权/违法违规/事实不符,请联系长河编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

CDDIS从2025年2月开始数据迁移

CDDIS 将从 2025 年 2 月开始将我们的网站从 cddis.nasa.gov 迁移到 earthdata.nasa.gov&#xff0c;并于 2025 年 6 月结束。 期间可能对GAMIT联网数据下载造成影响。

【Redis】主从模式,哨兵,集群

主从复制 单点问题&#xff1a; 在分布式系统中&#xff0c;如果某个服务器程序&#xff0c;只有一个节点&#xff08;也就是一个物理服务器&#xff09;来部署这个服务器程序的话&#xff0c;那么可能会出现以下问题&#xff1a; 1.可用性问题&#xff1a;如果这个机器挂了…

华为云kubernetes部署deepseek r1、ollama和open-webui(已踩过坑)

1 概述 ollama是一个管理大模型的一个中间层&#xff0c;通过它你可以下载并管理deepseek R1、llama3等大模型。 open-webui是一个web界面&#xff08;界面设计受到chatgpt启发&#xff09;&#xff0c;可以集成ollama API、 OpenAI的 API。 用常见的web应用架构来类比&#x…

在Mac mini M4上部署DeepSeek R1本地大模型

在Mac mini M4上部署DeepSeek R1本地大模型 安装ollama 本地部署&#xff0c;我们可以通过Ollama来进行安装 Ollama 官方版&#xff1a;【点击前往】 Web UI 控制端【点击安装】 如何在MacOS上更换Ollama的模型位置 默认安装时&#xff0c;OLLAMA_MODELS 位置在"~/.o…

CSS 背景与边框:从基础到高级应用

CSS 背景与边框&#xff1a;从基础到高级应用 1. CSS 背景样式1.1 背景颜色示例代码&#xff1a;设置背景颜色 1.2 背景图像示例代码&#xff1a;设置背景图像 1.3 控制背景平铺行为示例代码&#xff1a;控制背景平铺 1.4 调整背景图像大小示例代码&#xff1a;调整背景图像大小…

数据思维错题知识点整理(复习)

小的知识点整理 目前常见的数据采集方案有什么。 埋点、可视化埋点、无埋点&#xff08;无埋点并不是字面意思不埋点&#xff0c;其实也是一种埋点&#xff0c;只是让开发人员完全无感知&#xff0c;直接嵌入sdk&#xff0c;然后每个元素都能查看他们的情况&#xff0c;后续开…

PyQt4学习笔记2】QMainWindow

目录 一、创建 QMainWindow 组件 1. 创建工具栏 2. 创建停靠窗口 3. 设置状态栏 4. 设置中央窗口部件 二、QMainWindow 的主要方法 1. addToolBar() 2. addDockWidget() 3. setStatusBar() 4. setCentralWidget() 5. menuBar() 6. saveState() 和 restoreState() 三、QMainWind…

Linux:文件系统(软硬链接)

目录 inode ext2文件系统 Block Group 超级块&#xff08;Super Block&#xff09; GDT&#xff08;Group Descriptor Table&#xff09; 块位图&#xff08;Block Bitmap&#xff09; inode位图&#xff08;Inode Bitmap&#xff09; i节点表&#xff08;inode Tabl…

ubuntu22.40安装及配置静态ip解决重启后配置失效

遇到这种错误&#xff0c;断网安装即可&#xff01; 在Ubuntu中配置静态IP地址的步骤如下。根据你使用的Ubuntu版本&#xff08;如 Netplan 或传统的 ifupdown&#xff09;&#xff0c;配置方法有所不同。以下是基于 Netplan 的配置方法&#xff08;适用于Ubuntu 17.10及更高版…

手写MVVM框架-实现简单的数据代理

MVVM框架最显著的特点就是虚拟dom和响应式的数据、我们以Vue为例&#xff0c;分别实现data、computed、created、methods以及虚拟dom。 这一章我们先实现简单的响应式&#xff0c;修改数据之后在控制台打印。 我们将该框架命名为MiniVue。 首先我们需要创建MiniVue的类(src/co…

ESLint

ESLint ESLint 是一个针对 JS 的代码风格检查工具&#xff0c;当不满足其要求的风格时&#xff0c;会给予警告或错误。 官网&#xff1a;https://eslint.org/ 中文网&#xff1a;https://eslint.nodejs.cn/ 安装使用 在你的项目中安装 ESLint 包&#xff1a; npm install -…

kaggle视频行为分析1st and Future - Player Contact Detection

这次比赛的目标是检测美式橄榄球NFL比赛中球员经历的外部接触。您将使用视频和球员追踪数据来识别发生接触的时刻&#xff0c;以帮助提高球员的安全。两种接触&#xff0c;一种是人与人的&#xff0c;另一种是人与地面&#xff0c;不包括脚底和地面的&#xff0c;跟我之前做的这…

Chapter 6 -Fine-tuning for classification

Chapter 6 -Fine-tuning for classification 本章内容涵盖 引入不同的LLM微调方法准备用于文本分类的数据集修改预训练的 LLM 进行微调微调 LLM 以识别垃圾邮件评估微调LLM分类器的准确性使用微调的 LLM 对新数据进行分类 现在&#xff0c;我们将通过在大语言模型上对特定目标任…

【从零开始的LeetCode-算法】922. 按奇偶排序数组 II

给定一个非负整数数组 nums&#xff0c; nums 中一半整数是 奇数 &#xff0c;一半整数是 偶数 。 对数组进行排序&#xff0c;以便当 nums[i] 为奇数时&#xff0c;i 也是 奇数 &#xff1b;当 nums[i] 为偶数时&#xff0c; i 也是 偶数 。 你可以返回 任何满足上述条件的…

python 小游戏:扫雷

目录 1. 前言 2. 准备工作 3. 生成雷区 4. 鼠标点击扫雷 5. 胜利 or 失败 6. 游戏效果展示 7. 完整代码 1. 前言 本文使用 Pygame 实现的简化版扫雷游戏。 如上图所示&#xff0c;游戏包括基本的扫雷功能&#xff1a;生成雷区、左键点击扫雷、右键标记地雷、显示数字提示…

安全策略实验报告

1.实验拓扑图 2.实验需求 vlan2属于办公区&#xff0c;vlan3生产区 办公区pc在工作日时间可以正常访问OAserver&#xff0c;i其他时间不允许 办公区pc可以在任意时间访问Web server 生产区pc可以在任意时间访问OA server但不能访问web server 特例&#xff1a;生产区pc可以…

力扣73矩阵置零

给定一个 m x n 的矩阵&#xff0c;如果一个元素为 0 &#xff0c;则将其所在行和列的所有元素都设为 0 。请使用 原地 算法。 输入&#xff1a;matrix [[1,1,1],[1,0,1],[1,1,1]] 输出&#xff1a;[[1,0,1],[0,0,0],[1,0,1]] 输入&#xff1a;matrix [[0,1,2,0],[3,4,5,2],[…

蓝桥杯C语言组:暴力破解

基于C语言的暴力破解方法详解 暴力破解是一种通过穷举所有可能的解来找到正确答案的算法思想。在C语言中&#xff0c;暴力破解通常用于解决那些问题规模较小、解的范围有限的问题。虽然暴力破解的效率通常较低&#xff0c;但它是一种简单直接的方法&#xff0c;适用于一些简单…

【自然语言处理(NLP)】生成词向量:GloVe(Global Vectors for Word Representation)原理及应用

文章目录 介绍GloVe 介绍核心思想共现矩阵1. 共现矩阵的定义2. 共现概率矩阵的定义3. 共现概率矩阵的意义4. 共现概率矩阵的构建步骤5. 共现概率矩阵的应用6. 示例7. 优缺点优点缺点 **总结** 目标函数训练过程使用预训练的GloVe词向量 优点应用总结 个人主页&#xff1a;道友老…

介绍一下Mybatis的Executor执行器

Executor执行器是用来执行我们的具体的SQL操作的 有三种基本的Executor执行器&#xff1a; SimpleExecutor简单执行器 每执行一次update或select&#xff0c;就创建一个Statement对象&#xff0c;用完立刻关闭Statement对象 ReuseExecutor可重用执行器 可重复利用Statement…