【深度学习实验】注意力机制(四):点积注意力与缩放点积注意力之比较

文章目录

  • 一、实验介绍
  • 二、实验环境
    • 1. 配置虚拟环境
    • 2. 库版本介绍
  • 三、实验内容
    • 0. 理论介绍
      • a. 认知神经学中的注意力
      • b. 注意力机制
    • 1. 注意力权重矩阵可视化(矩阵热图)
    • 2. 掩码Softmax 操作
    • 3. 打分函数——加性注意力模型
    • 3. 打分函数——点积注意力与缩放点积注意力
      • a. 缩放点积注意力模型
      • b. 点积注意力模型
      • c. 模拟实验
      • d. 模型比较与选择
      • e. 代码整合

一、实验介绍

  注意力机制作为一种模拟人脑信息处理的关键工具,在深度学习领域中得到了广泛应用。本系列实验旨在通过理论分析和代码演示,深入了解注意力机制的原理、类型及其在模型中的实际应用。

本文将介绍将介绍带有掩码的 softmax 操作

二、实验环境

  本系列实验使用了PyTorch深度学习框架,相关操作如下:

1. 配置虚拟环境

conda create -n DL python=3.7 
conda activate DL
pip install torch==1.8.1+cu102 torchvision==0.9.1+cu102 torchaudio==0.8.1 -f https://download.pytorch.org/whl/torch_stable.html
conda install matplotlib
 conda install scikit-learn

2. 库版本介绍

软件包本实验版本目前最新版
matplotlib3.5.33.8.0
numpy1.21.61.26.0
python3.7.16
scikit-learn0.22.11.3.0
torch1.8.1+cu1022.0.1
torchaudio0.8.12.0.2
torchvision0.9.1+cu1020.15.2

三、实验内容

0. 理论介绍

a. 认知神经学中的注意力

  人脑每个时刻接收的外界输入信息非常多,包括来源于视
觉、听觉、触觉的各种各样的信息。单就视觉来说,眼睛每秒钟都会发送千万比特的信息给视觉神经系统。人脑通过注意力来解决信息超载问题,注意力分为两种主要类型:

  • 聚焦式注意力(Focus Attention):
    • 这是一种自上而下的有意识的注意力,通常与任务相关。
    • 在这种情况下,个体有目的地选择关注某些信息,而忽略其他信息。
    • 在深度学习中,注意力机制可以使模型有选择地聚焦于输入的特定部分,以便更有效地进行任务,例如机器翻译、文本摘要等。
  • 基于显著性的注意力(Saliency-Based Attention)
    • 这是一种自下而上的无意识的注意力,通常由外界刺激驱动而不需要主动干预。
    • 在这种情况下,注意力被自动吸引到与周围环境不同的刺激信息上。
    • 在深度学习中,这种注意力机制可以用于识别图像中的显著物体或文本中的重要关键词。

  在深度学习领域,注意力机制已被广泛应用,尤其是在自然语言处理任务中,如机器翻译、文本摘要、问答系统等。通过引入注意力机制,模型可以更灵活地处理不同位置的信息,提高对长序列的处理能力,并在处理输入时动态调整关注的重点。

b. 注意力机制

  1. 注意力机制(Attention Mechanism):

    • 作为资源分配方案,注意力机制允许有限的计算资源集中处理更重要的信息,以应对信息超载的问题。
    • 在神经网络中,它可以被看作一种机制,通过选择性地聚焦于输入中的某些部分,提高了神经网络的效率。
  2. 基于显著性的注意力机制的近似: 在神经网络模型中,最大汇聚(Max Pooling)和门控(Gating)机制可以被近似地看作是自下而上的基于显著性的注意力机制,这些机制允许网络自动关注输入中与周围环境不同的信息。

  3. 聚焦式注意力的应用: 自上而下的聚焦式注意力是一种有效的信息选择方式。在任务中,只选择与任务相关的信息,而忽略不相关的部分。例如,在阅读理解任务中,只有与问题相关的文章片段被选择用于后续的处理,减轻了神经网络的计算负担。

  4. 注意力的计算过程:注意力机制的计算分为两步。首先,在所有输入信息上计算注意力分布,然后根据这个分布计算输入信息的加权平均。这个计算依赖于一个查询向量(Query Vector),通过一个打分函数来计算每个输入向量和查询向量之间的相关性。

    • 注意力分布(Attention Distribution):注意力分布表示在给定查询向量和输入信息的情况下,选择每个输入向量的概率分布。Softmax 函数被用于将分数转化为概率分布,其中每个分数由一个打分函数计算得到。

    • 打分函数(Scoring Function):打分函数衡量查询向量与输入向量之间的相关性。文中介绍了几种常用的打分函数,包括加性模型、点积模型、缩放点积模型和双线性模型。这些模型通过可学习的参数来调整注意力的计算。

      • 加性模型 s ( x , q ) = v T tanh ⁡ ( W x + U q ) \mathbf{s}(\mathbf{x}, \mathbf{q}) = \mathbf{v}^T \tanh(\mathbf{W}\mathbf{x} + \mathbf{U}\mathbf{q}) s(x,q)=vTtanh(Wx+Uq)

      • 点积模型 s ( x , q ) = x T q \mathbf{s}(\mathbf{x}, \mathbf{q}) = \mathbf{x}^T \mathbf{q} s(x,q)=xTq

      • 缩放点积模型 s ( x , q ) = x T q D \mathbf{s}(\mathbf{x}, \mathbf{q}) = \frac{\mathbf{x}^T \mathbf{q}}{\sqrt{D}} s(x,q)=D xTq (缩小方差,增大softmax梯度)

      • 双线性模型 s ( x , q ) = x T W q \mathbf{s}(\mathbf{x}, \mathbf{q}) = \mathbf{x}^T \mathbf{W} \mathbf{q} s(x,q)=xTWq (非对称性)

  5. 软性注意力机制

    • 定义:软性注意力机制通过一个“软性”的信息选择机制对输入信息进行汇总,允许模型以概率形式对输入的不同部分进行关注,而不是强制性地选择一个部分。

    • 加权平均:软性注意力机制中的加权平均表示在给定任务相关的查询向量时,每个输入向量受关注的程度,通过注意力分布实现。

    • Softmax 操作:注意力分布通常通过 Softmax 操作计算,确保它们成为一个概率分布。

1. 注意力权重矩阵可视化(矩阵热图)

【深度学习实验】注意力机制(一):注意力权重矩阵可视化(矩阵热图heatmap)

在这里插入图片描述

2. 掩码Softmax 操作

【深度学习实验】注意力机制(二):掩码Softmax 操作
在这里插入图片描述

3. 打分函数——加性注意力模型

  • 加性模型 s ( x , q ) = v T tanh ⁡ ( W x + U q ) \mathbf{s}(\mathbf{x}, \mathbf{q}) = \mathbf{v}^T \tanh(\mathbf{W}\mathbf{x} + \mathbf{U}\mathbf{q}) s(x,q)=vTtanh(Wx+Uq)

【深度学习实验】注意力机制(三):打分函数——加性注意力模型

3. 打分函数——点积注意力与缩放点积注意力

  • 点积模型 s ( x , q ) = x T q \mathbf{s}(\mathbf{x}, \mathbf{q}) = \mathbf{x}^T \mathbf{q} s(x,q)=xTq

  • 缩放点积模型 s ( x , q ) = x T q D \mathbf{s}(\mathbf{x}, \mathbf{q}) = \frac{\mathbf{x}^T \mathbf{q}}{\sqrt{D}} s(x,q)=D xTq (缩小方差,增大softmax梯度)

a. 缩放点积注意力模型

class DotProductAttention(nn.Module):"""缩放点积注意力"""def __init__(self, dropout, **kwargs):super(DotProductAttention, self).__init__(**kwargs)# 使用暂退法进行模型正则化self.dropout = nn.Dropout(dropout)def forward(self, queries, keys, values, valid_lens=None):d = queries.shape[-1]self.scores = torch.bmm(queries, keys.transpose(1, 2)) / math.sqrt(d)self.attention_weights = masked_softmax(self.scores, valid_lens)return torch.bmm(self.dropout(self.attention_weights), values)
  1. 初始化方法 (__init__):

    • 参数:
      • dropout: Dropout 正则化的概率。
    • 说明: 初始化方法定义了模型的组件,仅包含了一个 Dropout 正则化层。
  2. 前向传播方法 (forward):

    • 参数:
      • queries: 查询张量,形状为 (batch_size, num_queries, d)
      • keys: 键张量,形状为 (batch_size, num_kv_pairs, d)
      • values: 值张量,形状为 (batch_size, num_kv_pairs, value_size)
      • valid_lens: 有效长度张量,形状为 (batch_size,)(batch_size, num_queries)
    • 返回值: 加权平均后的值张量,形状为 (batch_size, num_queries, value_size)
  3. 实现细节:

    • 计算缩放点积得分:通过张量乘法计算 querieskeys 的点积,然后除以 d \sqrt{d} d 进行缩放,其中 d d d 是查询或键的维度。
    • 使用 masked_softmax 函数计算注意力权重,根据有效长度对注意力进行掩码。
    • 将注意力权重应用到值上,得到最终的加权平均结果。
    • 使用 Dropout 对注意力权重进行正则化。

b. 点积注意力模型

class DotProductAttention2(nn.Module):"""点积注意力"""def __init__(self, dropout, **kwargs):super(DotProductAttention2, self).__init__(**kwargs)# 使用暂退法进行模型正则化self.dropout = nn.Dropout(dropout)def forward(self, queries, keys, values, valid_lens=None):# P195:(8.3),(8.4)# 在计算得分时不进行缩放操作(即不再除以sqrt(d))self.scores = torch.bmm(queries, keys.transpose(1, 2))self.attention_weights = masked_softmax(self.scores, valid_lens)return torch.bmm(self.dropout(self.attention_weights), values)
  • 区别:
    • 计算点积得分:通过张量乘法计算 querieskeys 的点积。

c. 模拟实验

  • 模型数据
queries, keys = torch.normal(0, 1, (2, 1, 2)), torch.ones((2, 10, 2))
values = torch.arange(40, dtype=torch.float32).reshape(1, 10, 4).repeat(2, 1, 1)
valid_lens = torch.tensor([2, 6])
  • 模型应用
# 创建缩放点积注意力模型
attention = DotProductAttention(0.5)
attention.eval()
# 使用模型进行前向传播
attention(queries, keys, values, valid_lens)# 创建点积注意力模型
attention2 = DotProductAttention2(0.5)
attention2.eval()
# 使用模型进行前向传播
attention2(queries, keys, values, valid_lens)

在这里插入图片描述

  • 权重可视化
      为了直观地展示模型在不同输入下的注意力分布,使用前文 show_heatmaps 函数,通过热图的形式展示注意力权重。
# 可视化缩放点积注意力权重
show_heatmaps(attention.attention_weights.reshape((1, 1, 2, 10)),xlabel='Keys', ylabel='Queries')# 可视化点积注意力权重
show_heatmaps(attention2.attention_weights.reshape((1, 1, 2, 10)),xlabel='Keys', ylabel='Queries')

在这里插入图片描述

d. 模型比较与选择

  • 缩放点积注意力模型

    • 适用于处理高维度的查询和键。
    • 通过缩放操作有助于防止点积得分的方差过大。
  • 点积注意力模型

    • 适用于处理相对较低维度的查询和键。
    • 更方便地利用矩阵乘积提高计算效率。
      在这里插入图片描述

e. 代码整合

# 导入必要的库
import math
import torch
from torch import nn
import torch.nn.functional as F
from d2l import torch as d2l
from torch.utils import datadef masked_softmax(X, valid_lens):"""通过在最后一个轴上掩蔽元素来执行softmax操作"""# X:3D张量,valid_lens:1D或2D张量if valid_lens is None:return nn.functional.softmax(X, dim=-1)else:shape = X.shapeif valid_lens.dim() == 1:valid_lens = torch.repeat_interleave(valid_lens, shape[1])else:valid_lens = valid_lens.reshape(-1)# 最后一轴上被掩蔽的元素使用一个非常大的负值替换,从而其softmax输出为0X = d2l.sequence_mask(X.reshape(-1, shape[-1]), valid_lens, value=-1e6)return nn.functional.softmax(X.reshape(shape), dim=-1)def show_heatmaps(matrices, xlabel, ylabel, titles=None, figsize=(2.5, 2.5), cmap='Reds'):"""显示矩阵热图"""d2l.use_svg_display()num_rows, num_cols = matrices.shape[0], matrices.shape[1]fig, axes = d2l.plt.subplots(num_rows, num_cols, figsize=figsize,sharex=True, sharey=True, squeeze=False)for i, (row_axes, row_matrices) in enumerate(zip(axes, matrices)):for j, (ax, matrix) in enumerate(zip(row_axes, row_matrices)):pcm = ax.imshow(matrix.detach().numpy(), cmap=cmap)if i == num_rows - 1:ax.set_xlabel(xlabel)if j == 0:ax.set_ylabel(ylabel)if titles:ax.set_title(titles[j])fig.colorbar(pcm, ax=axes, shrink=0.6)class DotProductAttention(nn.Module):"""缩放点积注意力"""def __init__(self, dropout, **kwargs):super(DotProductAttention, self).__init__(**kwargs)# 使用暂退法进行模型正则化self.dropout = nn.Dropout(dropout)# queries的形状:(batch_size,查询的个数,d)# keys的形状:(batch_size,“键-值”对的个数,d)# values的形状:(batch_size,“键-值”对的个数,值的维度)# valid_lens的形状:(batch_size,)或者(batch_size,查询的个数)def forward(self, queries, keys, values, valid_lens=None):print(queries)d = queries.shape[-1]print(d)self.scores = torch.bmm(queries, keys.transpose(1, 2)) / math.sqrt(d)print(self.scores)self.attention_weights = masked_softmax(self.scores, valid_lens)return torch.bmm(self.dropout(self.attention_weights), values)class DotProductAttention2(nn.Module):"""点积注意力"""def __init__(self, dropout, **kwargs):super(DotProductAttention2, self).__init__(**kwargs)# 使用暂退法进行模型正则化self.dropout = nn.Dropout(dropout)# queries的形状:(batch_size,查询的个数,d)# keys的形状:(batch_size,“键-值”对的个数,d)# values的形状:(batch_size,“键-值”对的个数,值的维度)# valid_lens的形状:(batch_size,)或者(batch_size,查询的个数)def forward(self, queries, keys, values, valid_lens=None):# P195:(8.3),(8.4)# 在计算得分时不进行缩放操作(即不再除以sqrt(d))self.scores = torch.bmm(queries, keys.transpose(1, 2))print(self.scores)self.attention_weights = masked_softmax(self.scores, valid_lens)return torch.bmm(self.dropout(self.attention_weights), values)queries, keys = torch.normal(0, 1, (2, 1, 2)), torch.ones((2, 10, 2))
values = torch.arange(40, dtype=torch.float32).reshape(1, 10, 4).repeat(2, 1, 1)
valid_lens = torch.tensor([2, 6])# 缩放点积注意力模型
attention = DotProductAttention(0.5)
attention.eval()
attention(queries, keys, values, valid_lens)# 点积注意力模型
attention2 = DotProductAttention2(0.5)
attention2.eval()
attention2(queries, keys, values, valid_lens)

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.rhkb.cn/news/201120.html

如若内容造成侵权/违法违规/事实不符,请联系长河编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

设计模式(二)-创建者模式(3)-抽象工厂模式

一、为什么需要抽象工厂模式? 在工厂模式中,我们需要定义多个继承于共同工厂抽象基类的工厂子类,这些子类负责创建一个对应的对象。工厂模式存在一个缺点就是:每次扩展新的工厂子类,就会增加系统的复杂度。 如果我们…

[论文笔记] Scaling Laws for Neural Language Models

概览: 一、总结 计算量、数据集大小、模型参数量大小的幂律 与 训练损失呈现 线性关系。 三个参数同时放大时,如何得到最佳的性能? 更大的模型 需要 更少的样本 就能达到相同的效果。 </

如何使用 WPF 应用程序连接 FastReport报表

随着期待已久的FastReport WPF的发布&#xff0c;您不再需要使用 FastReport .NET 来处理基于 WPF 的项目。 不久前&#xff0c;在 FastReport .NET 中使用 WPF 还相当不方便。并非一切都进展顺利&#xff1b;连接 FastReport.dll 和许多其他问题存在问题。我们重新思考了该方…

No such module ‘FacebookCore‘

在下面的地方添加这个库

2023亚太杯数学建模思路 - 案例:最短时间生产计划安排

文章目录 0 赛题思路1 模型描述2 实例2.1 问题描述2.2 数学模型2.2.1 模型流程2.2.2 符号约定2.2.3 求解模型 2.3 相关代码2.4 模型求解结果 建模资料 0 赛题思路 &#xff08;赛题出来以后第一时间在CSDN分享&#xff09; https://blog.csdn.net/dc_sinor?typeblog 最短时…

Jmeter之压力测试总结!

一、基本概念 1.线程组N&#xff1a;代表一定数量的并发用户&#xff0c;所谓并发就是指同一时刻访问发送请求的用户。线程组就是模拟并发用户访问。 2.Ramp-Up Period(in seconds)&#xff1a;建立所有线程的周期&#xff0c;就是告诉jmeter要在多久没启动所有线程&#xff…

【论文阅读笔记】Deep learning for time series classification: a review

【论文阅读笔记】Deep learning for time series classification: a review 摘要 在这篇文章中&#xff0c;作者通过对TSC的最新DNN架构进行实证研究&#xff0c;探讨了深度学习算法在TSC中的当前最新性能。文章提供了对DNNs在TSC的统一分类体系下在各种时间序列领域中的最成功…

和鲸 × 暨大经管:高效 SAAS 服务持续赋能交叉学科应用型数据人才培养

随着新一轮科技革命与产业变革的加速演进&#xff0c;拥有学科背景的应用型数据科学人才逐渐成为我国政产学研各界的人力资源需求重点。为响应需求&#xff0c;国家愈发重视新生力量数据思维与意识的培养&#xff0c;各高校也纷纷探索如何以新兴信息技术赋能传统主流学科。 在…

[Linux] shell脚本相关知识

一、shell脚本基础 1.1 shell脚本的作用 shell将人类使用的高级语言翻译成二进制&#xff0c;再将二进制翻译成高级语言。换句话就是人类写了一个命令集合&#xff0c;然后用bash去翻译给硬件执行。 linux中常见的shell&#xff1a; bash:基于gun的框架下发展的shell csh:类…

竞赛选题 车道线检测(自动驾驶 机器视觉)

0 前言 无人驾驶技术是机器学习为主的一门前沿领域&#xff0c;在无人驾驶领域中机器学习的各种算法随处可见&#xff0c;今天学长给大家介绍无人驾驶技术中的车道线检测。 1 车道线检测 在无人驾驶领域每一个任务都是相当复杂&#xff0c;看上去无从下手。那么面对这样极其…

计算机网络学习笔记(六):应用层(待更新)

目录​​​​​​​ 6.2 文件传送协议FTP(File Transfer Protocol) 6.2.1 FTP概述 6.2.2 FTP的基本工作原理 6.5 电子邮件&#xff1a;SMTP、POP3、IMAP 6.5.1 电子邮件概述 6.5.2 发邮件&#xff1a;简单邮件传送协议SMTP 6.5.3 电子邮件的信息格式、地址格式 6.5.4 收…

Linux编辑器-gcc/g++使用

> 作者简介&#xff1a;დ旧言~&#xff0c;目前大二&#xff0c;现在学习Java&#xff0c;c&#xff0c;c&#xff0c;Python等 > 座右铭&#xff1a;松树千年终是朽&#xff0c;槿花一日自为荣。 > 目标&#xff1a;熟练使用gcc/g编译器 > 毒鸡汤&#xff1a;真正…

Leaflet结合Echarts实现迁徙图

效果图如下&#xff1a; <!DOCTYPE html> <html><head><title>Leaflet结合Echarts4实现迁徙图</title><meta charset"utf-8" /><meta name"viewport" content"widthdevice-width, initial-scale1.0">…

2023年中国离心制冷机产量及产业链分析[图]

离心制冷机是一种常用的空调制冷设备&#xff0c;目前主要应用于酒店、写字楼、商场、学校等众多大型场所的集中制冷场景。离心制冷机由离心式制冷压缩机、蒸发器、冷凝器、主电动机、抽气回收装置、润滑系统、控制柜和起动柜等零部件组成。这些零部件的组成有的采用分散型组装…

Rust语言精讲:数据类型全解析

大家好&#xff01;我是lincyang。 今天&#xff0c;我们将深入探讨Rust语言中的数据类型&#xff0c;这是理解和掌握Rust的基础。 Rust语言数据类型概览 Rust是静态类型语言&#xff0c;所有变量类型在编译时确定。Rust的数据类型分为两类&#xff1a;标量类型和复合类型。…

内容输入.type

内容输入.type 查看完整说明 语法 .type(text) .type(text, options)正确用法 cy.get(input).type(Hello, World) // Type Hello, World into the input错误用法 cy.type(Welcome) // Errors, cannot be chained off cy cy.clock().type(www.cypress.io) // Errors, clock…

Windows安装MongoDB

1、下载MongoDB的zip&#xff0c;解压 2、创建目录 mkdir D:\JavaSoftware\Database\MongoDB\mongodb-win32-x86_64-windows-5.0.8\data\db mkdir D:\JavaSoftware\Database\MongoDB\mongodb-win32-x86_64-windows-5.0.8\data\log 3、创建一个配置文件mongod.cfg&#xff0c…

YOLOv8改进 | DAttention (DAT)注意力机制实现极限涨点

论文地址&#xff1a; DAT论文地址 官方地址&#xff1a;官方代码的地址 代码地址&#xff1a;文末有修改了官方代码BUG的代码块复制粘贴即可 一、本文介绍 本文给大家带来的是YOLOv8改进DAT(Vision Transformer with Deformable Attention)的教程&#xff0c;其发布于2022…

数字IC基础:有符号数和无符号数加、减法的Verilog设计

相关阅读 数字IC基础https://blog.csdn.net/weixin_45791458/category_12365795.html?spm1001.2014.3001.5482 本文是对数字IC基础&#xff1a;有符号数和无符号数的加减运算一文中的谈到的有符号数加减法的算法进行Verilog实现&#xff0c;有关算法细节请阅读原文&#xff0…

git merge 和 git rebase

一、是什么 在使用 git 进行版本管理的项目中&#xff0c;当完成一个特性的开发并将其合并到 master 分支时&#xff0c;会有两种方式&#xff1a; git merge git rebasegit rebase 与 git merge都有相同的作用&#xff0c;都是将一个分支的提交合并到另一分支上&#xff0c;…