实现pytorch注意力机制-one demo

主要组成部分:

1. 定义注意力层

定义一个Attention_Layer类,接受两个参数:hidden_dim(隐藏层维度)和is_bi_rnn(是否是双向RNN)。

2. 定义前向传播:

定义了注意力层的前向传播过程,包括计算注意力权重和输出。

3. 数据准备

生成一个随机的数据集,包含3个句子,每个句子10个词,每个词128个特征。

4. 实例化注意力层:

实例化一个注意力层,接受两个参数:hidden_dim(隐藏层维度)和is_bi_rnn(是否是双向RNN)。

5. 前向传播

将数据传递给注意力层的前向传播方法。

6. 分析结果

获取第一个句子的注意力权重。

7. 可视化注意力权重

使用matplotlib库可视化了注意力权重。

**主要函数和类:**
Attention_Layer类:定义了注意力层的结构和前向传播过程。
forward方法:定义了注意力层的前向传播过程。
torch.from_numpy函数:将numpy数组转换为PyTorch张量。
matplotlib库:用于可视化注意力权重。
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
import matplotlib.pyplot as plt# 定义注意力层
class Attention_Layer(nn.Module):def __init__(self, hidden_dim, is_bi_rnn):super(Attention_Layer,self).__init__()self.hidden_dim = hidden_dimself.is_bi_rnn = is_bi_rnnif is_bi_rnn:self.Q_linear = nn.Linear(hidden_dim * 2, hidden_dim * 2, bias = False)self.K_linear = nn.Linear(hidden_dim * 2, hidden_dim * 2, bias = False)self.V_linear = nn.Linear(hidden_dim * 2, hidden_dim * 2, bias = False)else:self.Q_linear = nn.Linear(hidden_dim, hidden_dim, bias = False)self.K_linear = nn.Linear(hidden_dim, hidden_dim, bias = False)self.V_linear = nn.Linear(hidden_dim, hidden_dim, bias = False)def forward(self, inputs, lens):# 获取输入的大小size = inputs.size()Q = self.Q_linear(inputs) K = self.K_linear(inputs).permute(0, 2, 1)V = self.V_linear(inputs)max_len = max(lens)sentence_lengths = torch.Tensor(lens)mask = torch.arange(sentence_lengths.max().item())[None, :] < sentence_lengths[:, None]mask = mask.unsqueeze(dim = 1)mask = mask.expand(size[0], max_len, max_len)padding_num = torch.ones_like(mask)padding_num = -2**31 * padding_num.float()alpha = torch.matmul(Q, K)alpha = torch.where(mask, alpha, padding_num)alpha = F.softmax(alpha, dim = 2)out = torch.matmul(alpha, V)return out# 准备数据
data = np.random.rand(3, 10, 128)  # 3个句子,每个句子10个词,每个词128个特征
lens = [7, 10, 4]  # 每个句子的长度# 实例化注意力层
hidden_dim = 64
is_bi_rnn = True
att_L = Attention_Layer(hidden_dim, is_bi_rnn)# 前向传播
att_out = att_L(torch.from_numpy(data).float(), lens)# 分析结果
attention_weights = att_out[0, :, :].detach().numpy()  # 获取第一个句子的注意力权重# 可视化注意力权重
plt.imshow(attention_weights, cmap='hot', interpolation='nearest')
plt.colorbar()
plt.show()

在这里插入图片描述

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.rhkb.cn/news/19454.html

如若内容造成侵权/违法违规/事实不符,请联系长河编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

SAP-ABAP:SAP的Screen Layout Designer屏幕布局设计器详解及示例

在SAP中&#xff0c;Screen Layout Designer&#xff08;屏幕布局设计器&#xff09;是用于设计和维护屏幕&#xff08;Dynpro&#xff09;布局的工具。通过Screen Layout Designer&#xff0c;您可以创建和修改屏幕元素&#xff08;如输入字段、按钮、文本、表格控件等&#x…

windows11+ubuntu20.04双系统下卸载ubuntu并重新安装

windows11ubuntu20.04双系统下卸载ubuntu并重新安装 背景&#xff1a;昨晚我电脑ubuntu20.04系统突然崩溃了&#xff0c;无奈只能重装系统了&#xff08;好在没有什么重要数据&#xff09;。刚好趁着这次换个ubuntu24.04系统玩一下&#xff0c;学习一下ROS2。 现系统&#xff…

SpringBoot速成(11)更新用户头像,密码P13-P14

更新头像&#xff1a; 1.代码展示: 1.RequestParam 是 Spring MVC 中非常实用的注解&#xff0c;用于从 HTTP 请求中提取参数并绑定到控制器方法的参数上。 2.PatchMapping 是 Spring MVC 中的一个注解&#xff0c;用于处理 HTTP 的 PATCH 请求。PATCH 请求通常用于对资源的部…

DeepSeek R1 与 OpenAI O1:机器学习模型的巅峰对决

我的个人主页 我的专栏&#xff1a;人工智能领域、java-数据结构、Javase、C语言&#xff0c;希望能帮助到大家&#xff01;&#xff01;&#xff01;点赞&#x1f44d;收藏❤ 一、引言 在机器学习的广袤天地中&#xff0c;大型语言模型&#xff08;LLM&#xff09;无疑是最…

Datawhale 数学建模导论二 笔记1

第6章 数据处理与拟合模型 本章主要涉及到的知识点有&#xff1a; 数据与大数据Python数据预处理常见的统计分析模型随机过程与随机模拟数据可视化 本章内容涉及到基础的概率论与数理统计理论&#xff0c;如果对这部分内容不熟悉&#xff0c;可以参考相关概率论与数理统计的…

【个人开发】deepspeed+Llama-factory 本地数据多卡Lora微调

文章目录 1.背景2.微调方式2.1 关键环境版本信息2.2 步骤2.2.1 下载llama-factory2.2.2 准备数据集2.2.3 微调模式2.2.3.1 zero-3微调2.2.3.2 zero-2微调2.2.3.3 单卡Lora微调 2.3 踩坑经验2.3.1 问题一&#xff1a;ValueError: Undefined dataset xxxx in dataset_info.json.2…

STM32 如何使用DMA和获取ADC

目录 背景 ‌摇杆的原理 程序 端口配置 ADC 配置 DMA配置 背景 DMA是一种计算机技术&#xff0c;允许某些硬件子系统直接访问系统内存&#xff0c;而不需要中央处理器&#xff08;CPU&#xff09;的介入&#xff0c;从而减轻CPU的负担。我们可以通过DMA来从外设&#xf…

Jvascript网页设计案例:通过js实现一款密码强度检测,适用于等保测评整改

本文目录 前言功能预览样式特点总结&#xff1a;1. 整体视觉风格2. 密码输入框设计3. 强度指示条4. 结果文本与原因说明 功能特点总结&#xff1a;1. 密码强度检测2. 实时反馈机制3. 详细原因说明4. 视觉提示5. 交互体验优化 密码强度检测逻辑Html代码Javascript代码 前言 能满…

Mybatis高级(动态SQL)

目录 一、动态SQL 1.1 数据准备&#xff1a; 1.2 <if>标签 1.3<trim> 标签 1.4<where>标签 1.5<set>标签 1.6 <foreach>标签 1.7<include> 标签 一、动态SQL 动态SQL是Mybatis的强⼤特性之⼀&#xff0c;能够完成不同条件下不同…

mac 意外退出移动硬盘后再次插入移动硬盘不显示怎么办

第一步&#xff1a;sudo ps aux | grep fsck 打开mac控制台输入如下指令&#xff0c;我们看到会出现两个进程&#xff0c;看进程是root的这个 sudo ps aux|grep fsck 第二步&#xff1a;杀死进程 在第一步基础上我们知道不显示u盘的进程是&#xff1a;62319&#xff0c;我们…

(2025)深度分析DeepSeek-R1开源的6种蒸馏模型之间的逻辑处理和编写代码能力区别以及配置要求,并与ChatGPT进行对比(附本地部署教程)

(2025)通过Ollama光速部署本地DeepSeek-R1模型(支持Windows10/11)_deepseek猫娘咒语-CSDN博客文章浏览阅读1k次&#xff0c;点赞19次&#xff0c;收藏9次。通过Ollama光速部署本地DeepSeek-R1(支持Windows10/11)_deepseek猫娘咒语https://blog.csdn.net/m0_70478643/article/de…

qt + opengl 给立方体增加阴影

在前几篇文章里面学会了通过opengl实现一个立方体&#xff0c;那么这篇我们来学习光照。 风氏光照模型的主要结构由3个分量组成&#xff1a;环境(Ambient)、漫反射(Diffuse)和镜面(Specular)光照。下面这张图展示了这些光照分量看起来的样子&#xff1a; 1 环境光照(Ambient …

机器学习-监督学习

1. 定义与原理 监督学习依赖于标记数据&#xff08;即每个输入样本都对应已知的输出标签&#xff09;&#xff0c;模型通过分析这些数据中的规律&#xff0c;建立从输入特征到目标标签的映射函数。例如&#xff0c;在垃圾邮件检测中&#xff0c;输入是邮件内容&#xff0c;输出…

使用grafana v11 建立k线(蜡烛图)仪表板

先看实现的结果 沪铜主力合约 2025-02-12 的1分钟k线图 功能介绍: 左上角支持切换主力合约,日期,实现动态加载数据. 项目背景: 我想通过前端展示期货指定品种某1天的1分钟k线,类似tqsdk 的web_gui 生成图形化界面— TianQin Python SDK 3.7.8 文档 项目架构: 后端: fastap…

我们来学HTTP/TCP -- 另辟蹊径从响应入手

从响应入手 题记响应结语 题记 很多“废话”&#xff0c;在很多文章中出奇的一致那种感觉是&#xff0c;说了好像又没说一样&#xff0c;可以称之为“电子技术垃圾”当然&#xff0c;是从个人主观的感受&#xff0c;这该死的回旋镖估计也会打在自己头上但咱也学学哪吒精神“我…

Golang官方编程指南

文章目录 1. Golang 官方编程指南2. Golang 标准库API文档 1. Golang 官方编程指南 Golang 官方网站&#xff1a;https://go.dev/ 点击下一步&#xff0c;查看官方手册怎么用 https://tour.go-zh.org/welcome/1 手册中的内容比较简单 go语言是以包的形式化管理函数的 搜索包名…

开源语音克隆项目 OpenVoice V2 本地部署

#本机环境 WIN11 I5 GPU 4060ti 16G 内存 32G #开始 git clone https://github.com/myshell-ai/OpenVoice.git conda create -n opvenv python3.9 -y conda activate opvenv pip3 install torch torchvision torchaudio --index-url https://download.pytorch.org/whl/…

Java8适配的markdown转换html工具(FlexMark)

坐标地址&#xff1a; <dependency><groupId>com.vladsch.flexmark</groupId><artifactId>flexmark-all</artifactId><version>0.60.0</version> </dependency> 工具类代码&#xff1a; import com.vladsch.flexmark.ext.tab…

Linux-文件IO

1.open函数 【1】基本概念和使用 #include <fcntl.h> int open(const char *pathname&#xff0c;int flags); int open(const char *pathname&#xff0c;int flags&#xff0c;mode_t mode); 功能: 打开或创建文件 参数: pathname //打开的文件名 f…

flutter 专题四十八 Google发布Flutter 2.0正式版,支持全平台程序构建

今天&#xff0c;Google发布了 Flutter 2.0的正式版本&#xff0c;至2018年Flutter 1.0版本发布以来&#xff0c;在最近的3年的时间礼&#xff0c;Flutter进行了大量的升级以支持更多平台的开发需求。作为 Flutter 的重大升级&#xff0c;Flutter 2.0 增加了对桌面和 Web 应用程…