游戏AI的创造思路-技术基础-深度学习(3)

继续填坑,本篇介绍深度学习中的长短期记忆网络~~~~

目录

3.3. 长短期记忆网络(LSTM)

3.3.1. 什么是长短期记忆网络

3.3.2. 形成过程与运行原理

3.3.2.1. 细胞状态与门结构

3.3.2.2. 遗忘门

3.3.2.3. 输入门

3.3.2.4. 细胞状态更新

3.3.2.5. 输出门

3.3.2.6. 以上各步骤的示例代码

3.3.3. 优缺点

3.3.4. 存在的问题及解决方法

3.3.5. 示例代码


3.3. 长短期记忆网络(LSTM)

3.3.1. 什么是长短期记忆网络

长短期记忆网络(LSTM,Long Short-Term Memory)算法是一种特殊的循环神经网络(RNN),它旨在解决传统RNN在处理长序列数据时遇到的梯度消失和梯度爆炸问题,从而更有效地学习序列中的长期依赖关系。

  • 为了最小化训练误差,通常使用梯度下降法,如应用时序性倒传递算法,来依据错误修改每次的权重。此外,LSTM有多种变体,其中一个重要的版本是门控循环单元(GRU)。
  • LSTM适合于处理和预测时间序列中间隔和延迟非常长的重要事件。其表现通常比时间递归神经网络及隐马尔科夫模型(HMM)更好。例如,在不分段连续手写识别上,LSTM模型曾赢得过ICDAR手写识别比赛冠军。此外,LSTM还广泛应用于自主语音识别,并在2013年使用TIMIT自然演讲数据库达成了17.7%的错误率纪录。
  • LSTM的成功在很大程度上促进了深度学习和人工智能领域的发展。尽管近年来出现了新的模型结构,如基于注意力机制的Transformer,但LSTM仍然是许多序列建模任务的可靠选择。随着时间的推移,LSTM被广泛应用于自然语言处理、语音识别、文本生成、视频分析等多个领域

3.3.2. 形成过程与运行原理

LSTM通过引入“”结构和“细胞状态”来更好地捕捉序列中的长期依赖关系。(通过借鉴脑神经学的知识来组建序列中的长期依赖关系)

3.3.2.1. 细胞状态与门结构

LSTM的核心是细胞状态,它像一条传送带,在整个链上运行,只有一些小的线性操作作用其上,信息在上面流传保持不变会很容易。LSTM通过精心设计的门结构来去除或增加信息到细胞状态,这些门结构包括遗忘门、输入门和输出门。

3.3.2.2. 遗忘门

决定从细胞状态中丢弃什么信息。它查看当前的输入和前一个时间步的隐藏状态,并为细胞状态中的每个数字输出一个在0到1之间的数字,1表示“完全保留”,0表示“完全舍弃”。

遗忘门决定了从上一个时间步的细胞状态中丢弃哪些信息。其计算公式为:

[ I_t = \sigma(X_tW_{xi} + H_{t-1}W_{hi} + b_i) ]

其中,( I_t )表示输入门在时刻( t )的值,( X_t )是时刻 ( t ) 的输入,( H_{t-1} )是前一个时刻的隐藏状态,( W_{xi} )( W_{hi} ) 是对应的权重矩阵,而( b_i )( b_i )是偏置项。函数( \sigma )表示sigmoid激活函数。

3.3.2.3. 输入门

决定什么新信息将被存储在细胞状态中。这包括两部分,一部分是输入门决定我们将更新哪些部分,另一部分是tanh层创建一个新的候选值向量,这个向量可能会被添加到细胞状态中。

[ F_t = \sigma(X_tW_{xf} + H_{t-1}W_{hf} + b_f) ]

类似地,( F_t )表示遗忘门在时刻( t )的值,其他符号的含义与输入门公式中的相同,只是权重和偏置项是针对遗忘门的。

3.3.2.4. 细胞状态更新

首先,旧细胞状态与遗忘门相乘,丢弃掉需要丢弃的信息。然后,将输入门的输出与tanh层的输出相乘,得出新的候选细胞状态。最后,将这两个值相加,形成新的细胞状态。

  • 旧细胞状态与遗忘门相乘

[ \tilde{C}t = C{t-1} \odot F_t ]

这里,( \tilde{C}t )表示经过遗忘门处理后的旧细胞状态,( C{t-1} )是前一个时刻的细胞状态,( F_t ) 是遗忘门在时刻( t )的输出,而( \odot )表示逐元素相乘(Hadamard乘积)。这一步的目的是丢弃掉不需要的信息。

  • 计算新的候选细胞状态

[ \hat{C}t = \tanh(X_tW{xc} + H_{t-1}W_{hc} + b_c) ]

其中,( \hat{C}t )是新的候选细胞状态,( X_t )是时刻 ( t )的输入,( H{t-1} ) 是前一个时刻的隐藏状态,( W_{xc} )( W_{hc} ) 是对应的权重矩阵,( b_c )是偏置项。函数 ( \tanh )是双曲正切激活函数,它将输入值压缩到 ( -1 ) 到 ( 1 ) 的范围内。

  • 将候选细胞状态与输入门相乘

[ i_t \odot \hat{C}_t ]

这里,( i_t )是输入门在时刻( t )的输出,( \odot )表示逐元素相乘。这一步的目的是根据输入门的选择来决定哪些新的信息被加入到细胞状态中。

  • 更新细胞状态

[ C_t = \tilde{C}_t + i_t \odot \hat{C}_t ]

最终,新的细胞状态( C_t )是经过遗忘门处理后的旧细胞状态 ( \tilde{C}_t )与经过输入门处理后的新候选细胞状态 ( i_t \odot \hat{C}_t ) 之和。这一步完成了细胞状态的更新,使得LSTM能够记住长期依赖关系。

3.3.2.5. 输出门

基于细胞状态来决定输出什么。首先,运行一个sigmoid层来确定细胞状态的哪个部分将输出,然后将细胞状态通过tanh进行处理(得到一个在-1到1之间的值),并将其与sigmoid门的输出相乘,最终得到输出。

[ O_t = \sigma(X_tW_{xo} + H_{t-1}W_{ho} + b_o) ]

在这里,( O_t )是输出门在时刻( t )的值,其他参数和符号的意义与前面公式中的一致,但针对输出门。

3.3.2.6. 以上各步骤的示例代码

Python代码示例

import numpy as np  def sigmoid(x):  return 1 / (1 + np.exp(-x))  def tanh(x):  return np.tanh(x)  # LSTM Cell 参数初始化  
input_size = 10  
hidden_size = 20  Wf = np.random.randn(hidden_size, hidden_size + input_size) # 遗忘门权重  
Wi = np.random.randn(hidden_size, hidden_size + input_size) # 输入门权重  
Wc = np.random.randn(hidden_size, hidden_size + input_size) # 候选细胞状态权重  
Wo = np.random.randn(hidden_size, hidden_size + input_size) # 输出门权重  # LSTM Cell 前向传播  
def lstm_cell_forward(xt, ht_prev, ct_prev, Wf, Wi, Wc, Wo):  # 拼接前一个隐藏状态和当前输入  concat = np.concatenate((ht_prev, xt), axis=0)  # 计算遗忘门  ft = sigmoid(np.dot(Wf, concat))  # 计算输入门  it = sigmoid(np.dot(Wi, concat))  # 计算候选细胞状态  cct = tanh(np.dot(Wc, concat))  # 细胞状态更新  ct = ft * ct_prev + it * cct  # 计算输出门  ot = sigmoid(np.dot(Wo, concat))  # 计算隐藏状态  ht = ot * tanh(ct)  return ht, ct  # 示例使用  
xt = np.random.randn(input_size) # 当前输入  
ht_prev = np.zeros(hidden_size) # 前一个隐藏状态  
ct_prev = np.zeros(hidden_size) # 前一个细胞状态  ht, ct = lstm_cell_forward(xt, ht_prev, ct_prev, Wf, Wi, Wc, Wo)

C++代码示例

#include <Eigen/Dense>  
#include <cmath>  using namespace Eigen;  // 激活函数  
double sigmoid(double x) {  return 1.0 / (1.0 + std::exp(-x));  
}  double tanh(double x) {  return std::tanh(x);  
}  // LSTM单元前向传播  
void LSTMCellForward(const VectorXd& xt, const VectorXd& ht_prev, const VectorXd& ct_prev,   const MatrixXd& Wf, const MatrixXd& Wi, const MatrixXd& Wc, const MatrixXd& Wo,  VectorXd& ht, VectorXd& ct) {  int input_size = xt.size();  int hidden_size = ht_prev.size();  VectorXd concat(input_size + hidden_size);  concat << ht_prev, xt;  // 计算遗忘门  VectorXd ft = concat.unaryExpr([](double elem) { return sigmoid(elem); }) * Wf.transpose();  // 计算输入门  VectorXd it = concat.unaryExpr([](double elem) { return sigmoid(elem); }) * Wi.transpose();  // 计算候选细胞状态  VectorXd cct = concat.unaryExpr([](double elem) { return tanh(elem); }) * Wc.transpose();  // 细胞状态更新  ct = ft.array() * ct_prev.array() + it.array() * cct.array();  // 计算输出门  VectorXd ot = concat.unaryExpr([](double elem) { return sigmoid(elem); }) * Wo.transpose();  // 计算隐藏状态  ht = ot.array() * ct.array().unaryExpr([](double elem) { return tanh(elem); });  
}  int main() {  int input_size = 10;  int hidden_size = 20;  MatrixXd Wf = MatrixXd::Random(hidden_size, hidden_size + input_size); // 遗忘门权重  MatrixXd Wi = MatrixXd::Random(hidden_size, hidden_size + input_size); // 输入门权重  MatrixXd Wc = MatrixXd::Random(hidden_size, hidden_size + input_size); // 候选细胞状态权重  MatrixXd Wo = MatrixXd::Random(hidden_size, hidden_size + input_size); // 输出门权重  VectorXd xt = VectorXd::Random(input_size); // 当前输入  VectorXd ht_prev = VectorXd::Zero(hidden_size); // 前一个隐藏状态  VectorXd ct_prev = VectorXd::Zero(hidden_size); // 前一个细胞状态  VectorXd ht(hidden_size), ct(hidden_size);  LSTMCellForward(xt, ht_prev, ct_prev, Wf, Wi, Wc, Wo, ht, ct);  // Do something with ht and ct...  return 0;  
}

这些代码是简化示例,实际应用中LSTM的实现会更加复杂,包括多个时间步的迭代、批处理支持、梯度计算和权重更新等。

在生产环境中,建议使用成熟的深度学习框架如TensorFlow或PyTorch来实现LSTM哦。

3.3.3. 优缺点

优点

  1. 能够有效地解决传统RNN中的梯度消失和梯度爆炸问题。
  2. 能够更好地捕捉序列中的长期依赖关系。
  3. 在处理长序列数据时具有优势。

缺点

  1. LSTM模型相对复杂,计算成本较高。
  2. 对于输入序列长度较长时,可能会出现过拟合现象,导致泛化能力下降。

3.3.4. 存在的问题及解决方法

过拟合问题:可以通过正则化、dropout等技术来减轻过拟合现象。

无法有效捕捉时间上下文关系:可以引入双向LSTM(Bidirectional LSTM)结构来提高对于时间上下文之间关系的建模能力。

对输入数据序列顺序敏感:在实际应用中,可以通过数据增强、序列颠倒等方法来减轻模型对输入数据序列顺序的敏感性。

3.3.5. 示例代码

Python代码

由于篇幅限制,这里提供一个简化的Python示例,使用PyTorch库实现LSTM:

import torch  
import torch.nn as nn  # 定义一个简单的LSTM模型  
class SimpleLSTM(nn.Module):  def __init__(self, input_size, hidden_size, output_size):  super(SimpleLSTM, self).__init__()  self.hidden_size = hidden_size  self.lstm = nn.LSTM(input_size, hidden_size, batch_first=True)  self.fc = nn.Linear(hidden_size, output_size)  def forward(self, x, hidden):  lstm_out, hidden = self.lstm(x, hidden)  output = self.fc(lstm_out[:, -1, :])  # 取最后一个时间步的输出进行分类  return output, hidden  def init_hidden(self, batch_size):  return (torch.zeros(1, batch_size, self.hidden_size),  torch.zeros(1, batch_size, self.hidden_size))  # 模型参数  
input_size = 10  
hidden_size = 20  
output_size = 2  
batch_size = 1  
sequence_length = 5  # 创建模型实例  
model = SimpleLSTM(input_size, hidden_size, output_size)  # 创建虚拟输入数据和初始隐藏状态  
x = torch.randn(batch_size, sequence_length, input_size)  
hidden = model.init_hidden(batch_size)  # 前向传播  
output, hidden = model(x, hidden)  
print(output)

C++代码

在C++中使用LSTM,我们通常会借助PyTorch的C++ API,也称为LibTorch。以下是一个简单的示例:

#include <torch/script.h> // 包含TorchScript的头文件  
#include <iostream>  int main() {  // 加载一个预先训练好的LSTM模型(这里假设你已经有一个用PyTorch训练的模型并导出了TorchScript)  torch::jit::script::Module module;  try {  module = torch::jit::load("lstm_model.pt"); // 加载模型  } catch (const c10::Error& e) {  std::cerr << "模型加载错误\n";  return -1;  }  // 创建一个输入张量,假设输入大小为[1, 5, 10](batch_size, sequence_length, input_size)  torch::Tensor input = torch::randn({1, 5, 10});  // 执行模型前向传播  std::vector<torch::jit::IValue> inputs;  inputs.push_back(input);  torch::Tensor output = module.forward(inputs).toTensor();  std::cout << output << std::endl;  return 0;  
}

请注意,C++ 示例中的模型需要是预先训练好并导出为TorchScript的模型。TorchScript是PyTorch的一个子集,允许模型在没有Python运行时的环境中执行。

在C++中直接使用LSTM而不依赖预先训练的模型会更复杂,因为你需要手动实现LSTM的所有细节。这通常不是推荐的做法,除非你有特定的性能要求或需要深度定制LSTM的行为。

在大多数情况下,使用PyTorch等高级库会更加方便和高效。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.rhkb.cn/news/363499.html

如若内容造成侵权/违法违规/事实不符,请联系长河编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

Younger 数据集:人工智能生成神经网络

设计和优化神经网络架构通常需要广泛的专业知识&#xff0c;从手工设计开始&#xff0c;然后进行手动或自动化的精细化改进。这种依赖性成为快速创新的重要障碍。认识到从头开始自动生成神经网络架构的复杂性&#xff0c;本文引入了Younger&#xff0c;这是一个开创性的数据集&…

机器学习python实践——关于管道模型Pipeline和网格搜索GridSearchCV的一些个人思考

最近在利用python跟着指导书进行机器学习的实践&#xff0c;在实践中使用到了Pipeline类方法和GridSearchCV类方法&#xff0c;并且使用过程中发现了一些问题&#xff0c;所以本文主要想记录并分享一下个人对于这两种类方法的思考&#xff0c;如果有误&#xff0c;请见谅&#…

Kubernetes 容器编排技术

Kubernetes 容器编排 前言 知识扩展 早在 2015 年 5 月&#xff0c;Kubernetes 在 Google 上的搜索热度就已经超过了 Mesos 和 Docker Swarm&#xff0c;从那儿之后更是一路飙升&#xff0c;将对手甩开了十几条街,容器编排引擎领域的三足鼎立时代结束。 目前&#xff0c;AWS…

蚂蚁- 定存

一&#xff1a;收益变动&&收益重算 1.1: 场景组合 1: 澳门元个人活期&#xff0c;日终余额大于0&#xff0c;当日首次、本周本月非首次系统结息&#xff0c;结息后FCDEPCORE_ASYN_CMD_JOB捞起进行收益计算 【depc_account_revenue_detail】收益日 > 【depc_accoun…

Linux驱动开发笔记(十一)tty子系统及其驱动

文章目录 前言一、串口驱动框架1.1 核心数据结构1.2 数据处理流程 二、驱动编写1. 设备树的修改2. 相关API函数3. 驱动框架4. 具体功能的实现4.1 出入口函数的编写4.2 读写函数 前言 之前已经讲过应用层的应用&#xff0c;接下来我们继续进行驱动的学习。其实实际上我们很少主动…

【Redis四】主从复制、哨兵以及Cluster集群

目录 一.主从复制、哨兵、集群的区别 二.Redis主从复制 1.作用 2.原理 3.流程 三.搭建Redis 主从复制 1.源码编译安装以及配置文件修改 1.1.修改 Redis 配置文件&#xff08;Slave节点操作&#xff09; 2.验证主从复制 2.1.在Master节点上看日志 2.2.在Master节点上…

学习感悟丨在誉天学习数通HCIP怎么样

大家好&#xff0c;我是誉天学员的徐同学&#xff0c;学习的数通HCIP课程。 在学校的时候&#xff0c;听说下半年就要出去实习了&#xff0c;心中坎坷不安&#xff0c;现在我学到的知识远远不够的。然后就想着学点东西充实一下自己的知识面和专业能力&#xff0c;有一次和同学谈…

有没有能用蓝牙的游泳耳机,性能超凡的4大游泳耳机力荐

在现代科技的推动下&#xff0c;越来越多具备蓝牙功能的游泳耳机正在改变游泳爱好者的体验方式。这些创新产品不仅在防水性能上有了显著提升&#xff0c;还能让您在水中享受到高质量的音乐。然而&#xff0c;选择一款优秀的蓝牙游泳耳机并不简单&#xff0c;需要考虑到防水等级…

vite vue3使用axios解决跨域问题

引入依赖 npm install axios 在main.js中全局引入 import { createApp } from vue import App from ./App.vue import axios from axiosconst app createApp(App)// 全局引入axios app.config.globalProperties.$axios axiosapp.mount(#app) 修改vite.config.js的代理配置…

Java | Leetcode Java题解之第189题轮转数组

题目&#xff1a; 题解&#xff1a; class Solution {public void rotate(int[] nums, int k) {k % nums.length;reverse(nums, 0, nums.length - 1);reverse(nums, 0, k - 1);reverse(nums, k, nums.length - 1);}public void reverse(int[] nums, int start, int end) {whil…

搭建企业内网pypi镜像库,让python在内网也能像互联网一样安装pip库

目录 知识点实验1.服务器安装python2.新建一个目录/mirror/pip&#xff0c;用于存储pypi文件&#xff0c;作为仓库目录3.下载python中的所需包放至仓库文件夹/mirror/pip3.1. 新建requirement.py脚本&#xff08;将清华pypi镜像库文件列表粘贴到requirement.txt文件中&#xff…

代码随想录算法训练营第三十七天|01背包问题、分割等和子集

01背包问题 题目链接&#xff1a;46. 携带研究材料 文档讲解&#xff1a;代码随想录 状态&#xff1a;忘了 二维dp 问题1&#xff1a;为啥会想到i代表第几个物品&#xff0c;j代表容量变化&#xff1f; 动态规划中&#xff0c;每次决策都依赖于前一个状态的结果&#xff0c;在…

Radxa 学习摘录

文章目录 1、参考资料2、硬件知识3、shell4、交叉编译工具链5、问题6、DTS 1、参考资料 技术论坛&#xff08;推荐&#xff09; 官方资料下载 wiki资料 u-boot 文档 u-boot 源码 内核文档 内核源码 原理图 radxa-repo radxa-build radxa-pkg radxa-docs 2、硬件知识 Rad…

RabbitMQ(七)Shovel插件对比Federation插件

文章目录 Shovel和Federation的主要区别&#xff08;重点&#xff09;一、启用Shovel插件二、配置Shovel三、测试1、测试计划2、测试效果发布消息源节点目标节点 Shovel和Federation的主要区别&#xff08;重点&#xff09; • Shovel更简洁一些 • Federation更倾向于跨集群使…

国外的Claude3.5 Sonnet Artifacts和国内的CodeFlying孰强孰弱?

在Claude 3.5 Sonnet发布后&#xff0c;最受大家关注的问题应该就是它在编写代码能力上的变化。 要知道在Claude3.0发布以来的这几个月就因为它的编写代码能力而一直受到人们的诟病。 那Anthropic这次终于是不负众望&#xff0c;在Claude 3.5 Sonnet中更新了一个叫做Artifact…

mysql是什么

mysql是什么 是DBMS软件系统&#xff0c;并不是一个数据库&#xff0c;管理数据库 DBMS相当于用户和数据库之间的桥梁&#xff0c;有超过300种不同的dbms系统 mysql是关系型数据库&#xff0c;关系型数据库存储模型很想excel&#xff0c;用行和列组织数据 sql是一门编程语言…

致敬经典:在国产开源操作系统 RT-Thread 重温 UNIX 彩色终端

引言 上篇文章里我们向大家介绍了 RT-Thread v5.1.0 的一些新特性。其中包括了终端环境的进一步完善。终端是人机交互的重要接口。实用的终端工具可以显著地提升系统使用者的幸福指数。举例来说&#xff0c;当我们想要修改一些系统配置&#xff0c;或是编写脚本时&#xff0c;一…

鸿蒙开发设备管理:【@ohos.distributedHardware.deviceManager (设备管理)】

设备管理 本模块提供分布式设备管理能力。 系统应用可调用接口实现如下功能&#xff1a; 注册和解除注册设备上下线变化监听发现周边不可信设备认证和取消认证设备查询可信设备列表查询本地设备信息&#xff0c;包括设备名称&#xff0c;设备类型和设备标识 说明&#xff1a…

什么是ArchiMate?有优缺点和运用场景?

一、什么是ArchiMate? ArchiMate是一种由The Open Group发布的企业级标准&#xff0c;它是一种整合多种架构的可视化业务分析模型语言&#xff0c;也属于架构描述语言&#xff08;ADL&#xff09;。ArchiMate主要从业务、应用和技术三个层次&#xff08;Layer&#xff09;&…

Web渗透:php反序列化漏洞

反序列化漏洞&#xff08;Deserialization Vulnerability&#xff09;是一种在应用程序处理数据的过程中&#xff0c;因不安全的反序列化操作引发的安全漏洞&#xff1b;反序列化是指将序列化的数据&#xff08;通常是字节流或字符串&#xff09;转换回对象的过程&#xff0c;如…