仅仅使用pytorch来手撕transformer架构(2):多头注意力MultiHeadAttention类的实现和向前传播

手撕MultiHeadAttention 类的代码,结合具体的例子来说明每一步的作用和计算过程。

往期文章:
仅仅使用pytorch来手撕transformer架构(1):位置编码的类的实现和向前传播

最适合小白入门的Transformer介绍

1. 初始化方法 __init__

def __init__(self, embed_size, heads):super(MultiHeadAttention, self).__init__()self.embed_size = embed_sizeself.heads = headsself.head_dim = embed_size // headsassert self.head_dim * heads == embed_size, "Embed size needs to be divisible by heads"self.values = nn.Linear(self.head_dim, self.head_dim, bias=False)self.keys = nn.Linear(self.head_dim, self.head_dim, bias=False)self.queries = nn.Linear(self.head_dim, self.head_dim, bias=False)self.fc_out = nn.Linear(heads * self.head_dim, embed_size)

1.1参数解释

  • embed_size:嵌入向量的维度,表示每个输入向量的大小。
  • heads:注意力头的数量。多头注意力机制将输入分割成多个“头”,每个头学习不同的特征。
  • head_dim:每个注意力头的维度大小,计算公式为 embed_size // heads。这意味着每个头处理的特征子集的大小。

1.2线性变换层

  • self.valuesself.keysself.queries

    • 这些是线性变换层,用于将输入的嵌入向量分别转换为值(Values)、键(Keys)和查询(Queries)。
    • 每个线性层的输入和输出维度都是 self.head_dim,因为每个头处理的特征子集大小为 self.head_dim
    • 使用 bias=False 是为了简化计算,避免引入额外的偏置项。
  • self.fc_out

    • 在多头注意力计算完成后,将所有头的输出拼接起来,并通过一个线性层将维度转换回原始的嵌入维度 embed_size

2. 前向传播方法 forward

def forward(self, values, keys, query, mask):N = query.shape[0]  # Batch sizevalue_len, key_len, query_len = values.shape[1], keys.shape[1], query.shape[1]

2.1输入参数

  • valueskeysquery
    • 这三个输入张量的形状通常为 (batch_size, seq_len, embed_size)
    • 它们分别对应于值(Values)、键(Keys)和查询(Queries)。
  • mask
    • 用于遮蔽某些位置的注意力权重,避免模型关注到不应该关注的部分(例如,解码器中的未来信息)。

2.2多头注意力计算过程

2.2.1 将输入嵌入分割为多个头:
values = values.reshape(N, value_len, self.heads, self.head_dim)
keys = keys.reshape(N, key_len, self.heads, self.head_dim)
queries = query.reshape(N, query_len, self.heads, self.head_dim)
  • 将输入的嵌入向量分割成 heads 个头,每个头的维度为 self.head_dim
  • 例如,如果 embed_size = 256heads = 8,则 self.head_dim = 32,每个头处理 32 维的特征。
  • 重塑后的形状为 (N, seq_len, heads, head_dim)
2.2.2 线性变换:
values = self.values(values)
keys = self.keys(keys)
queries = self.queries(queries)
  • 对每个头的值、键和查询分别进行线性变换。
  • 这一步将输入特征投影到不同的子空间中,使得每个头可以学习不同的特征。
2.2.3计算注意力分数(Attention Scores):
energy = torch.einsum("nqhd,nkhd->nhqk", [queries, keys])
  • 使用 torch.einsum 计算查询和键之间的点积,得到注意力分数矩阵。
  • 公式 nqhd,nkhd->nhqk 表示:
    • n:批量大小(Batch Size)。
    • q:查询序列的长度。
    • k:键序列的长度。
    • h:头的数量。
    • d:每个头的维度。
  • 输出的 energy 形状为 (N, heads, query_len, key_len)
2.2.4应用掩码(Masking):
if mask is not None:energy = energy.masked_fill(mask == 0, float("-1e20"))
  • 如果提供了掩码,将掩码为 0 的位置的注意力分数设置为一个非常小的值(如 -1e20),这样在后续的 softmax 计算中,这些位置的权重会趋近于 0。
2.2.5计算注意力权重:
attention = torch.softmax(energy / (self.embed_size ** (0.5)), dim=3)
  • 对注意力分数进行 softmax 归一化,得到注意力权重。
  • 除以 sqrt(embed_size) 是为了缩放点积结果,避免梯度消失或爆炸。
2.2.6应用注意力权重:
out = torch.einsum("nhql,nlhd->nqhd", [attention, values]).reshape(N, query_len, self.heads * self.head_dim
)
  • 使用 torch.einsum 将注意力权重与值相乘,得到加权的值。
  • 公式 nhql,nlhd->nqhd 表示:
    • n:批量大小。
    • h:头的数量。
    • q:查询序列的长度。
    • l:值序列的长度。
    • d:每个头的维度。
  • 输出的 out 形状为 (N, query_len, heads * self.head_dim)
2.2.7线性变换输出:
out = self.fc_out(out)
  • 将所有头的输出拼接起来,并通过一个线性层将维度转换回原始的嵌入维度 embed_size

3. 示例矩阵计算

假设:

  • embed_size = 4
  • heads = 2
  • head_dim = embed_size // heads = 2
  • 输入序列长度为 3,批量大小为 1。

3.1输入张量

values = torch.tensor([[[1, 2, 3, 4], [5, 6, 7, 8], [9, 10, 11, 12]]], dtype=torch.float32)
keys = torch.tensor([[[13, 14, 15, 16], [17, 18, 19, 20], [21, 22, 23, 24]]], dtype=torch.float32)
query = torch.tensor([[[25, 26, 27, 28], [29, 30, 31, 32], [33, 34, 35, 36]]], dtype=torch.float32)
mask = None

3.2重塑为多头

values = values.reshape(1, 3, 2, 2)  # (N, value_len, heads, head_dim)
keys = keys.reshape(1, 3, 2, 2)
queries = query.reshape(1, 3, 2, 2)

3.3线性变换

假设线性变换层的权重为单位矩阵(简化计算),则:

values = self.values(values)  # 不改变值
keys = self.keys(keys)
queries = self.queries(queries)

3.4计算注意力分数

energy = torch.einsum("nqhd,nkhd->nhqk", [queries, keys])

假设:

  • queries = [[[1, 2], [3, 4]], [[5, 6], [7, 8]], [[9, 10], [11, 12]]]
  • keys = [[[13, 14], [15, 16]], [[17, 18], [19, 20]], [[21, 22], [23, 24]]]

计算点积:

energy = [[[[1*13 + 2*14, 1*15 + 2*16], [1*17 + 2*18, 1*19 + 2*20]],

完整代码:

class MultiHeadAttention(nn.Module):def __init__(self, embed_size, heads):super(MultiHeadAttention, self).__init__()self.embed_size = embed_sizeself.heads = headsself.head_dim = embed_size // headsassert self.head_dim * heads == embed_size, "嵌入尺寸需要被头部整除"self.values = nn.Linear(self.head_dim, self.head_dim, bias=False)self.keys = nn.Linear(self.head_dim, self.head_dim, bias=False)self.queries = nn.Linear(self.head_dim, self.head_dim, bias=False)self.fc_out = nn.Linear(heads * self.head_dim, embed_size)def forward(self, values, keys, query, mask):N = query.shape[0]value_len, key_len, query_len = values.shape[1], keys.shape[1], query.shape[1]# Split the embedding into self.heads different piecesvalues = values.reshape(N, value_len, self.heads, self.head_dim)keys = keys.reshape(N, key_len, self.heads, self.head_dim)queries = query.reshape(N, query_len, self.heads, self.head_dim)values = self.values(values)keys = self.keys(keys)queries = self.queries(queries)energy = torch.einsum("nqhd,nkhd->nhqk", [queries, keys])if mask is not None:energy = energy.masked_fill(mask == 0, float("-1e20"))attention = torch.softmax(energy / (self.embed_size ** (1 / 2)), dim=3)out = torch.einsum("nhql,nlhd->nqhd", [attention, values]).reshape(N, query_len, self.heads * self.head_dim)out = self.fc_out(out)return out

作者码字不易,觉得有用的话不妨点个赞吧,关注我,持续为您更新AI的优质内容。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.rhkb.cn/news/31120.html

如若内容造成侵权/违法违规/事实不符,请联系长河编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

MCP极简入门:超快速上手运行简单的MCP服务和MCP客户端

MCP是什么? 首先我们快速过一下MCP的基本概念,接着我们会通过一个简单的天气服务的教程,来上手学会使用MCP服务和在主机运行服务。本文根据官方教程改编。 1. MCP的基本概念 MCP(Model Context Protocol,模型上下文…

DeepSeek进阶应用(一):结合Mermaid绘图(流程图、时序图、类图、状态图、甘特图、饼图)

🌟前言: 在软件开发、项目管理和系统设计等领域,图表是表达复杂信息的有效工具。随着AI助手如DeepSeek的普及,我们现在可以更轻松地创建各种专业图表。 名人说:博观而约取,厚积而薄发。——苏轼《稼说送张琥》 创作者&…

海康线扫相机平场矫正教程

0、平场矫正前的准备确认 1、白纸准备 确保视野中有一张平整且无折痕的白纸,使其完全铺满相机的整个视野。 2、行高设置 将行高参数设定为 2048。 3、灰度值控制 相机端图像的灰度值应维持在 120 - 180 这个区间内。同时,最亮像素点与最暗像素点的灰度…

数智读书笔记系列015 探索思维黑箱:《心智社会:从细胞到人工智能,人类思维的优雅解读》读书笔记

引言 《The Society of Mind》(《心智社会》)的作者马文・明斯基(Marvin Minsky),是人工智能领域的先驱和奠基者之一 ,1969 年获得图灵奖,被广泛认为是对人工智能领域影响最大的科学家之一。他…

游戏引擎学习第148天

回顾并规划今天的工作 没有使用引擎,也没有任何库支持,只有我们自己,编写游戏的所有代码,不仅仅是小小的部分,而是从头到尾。现在,我们正处于一个我一直想做的任务中,虽然一切都需要按部就班&a…

bug-Ant中a-select的placeholder不生效(绑定默认值为undefined)

1.问题 Ant中使用a-select下拉框时,placeholder设置输入框显示默认值提示,vue2ant null与undefined在js中明确的区别: null:一个值被定义,定义为“空值” undefined:根本不存在定义 2.解决 2.1 a-select使…

DeepSeek教我写词典爬虫获取单词的音标和拼写

Python在爬虫领域展现出了卓越的功能性,不仅能够高效地抓取目标数据,还能便捷地将数据存储至本地。在众多Python爬虫应用中,词典数据的爬取尤为常见。接下来,我们将以dict.cn为例,详细演示如何编写一个用于爬取词典数据…

springboot-自定义注解

1.注解的概念 注解是一种能被添加到java代码中的【元数据,类、方法、变量、参数和包】都可以用注解来修饰。用来定义一个类、属性或一些方法,以便程序能被捕译处理。 相当于一个说明文件,告诉应用程序某个被注解的类或属性是什么&#xff0c…

低代码开发直聘管理系统

低代码 DeepSeek 组合的方式开发直聘管理系统,兼职是开挂的存在。整个管理后台系统 小程序端接口的输出,只花了两个星期不到。 一、技术栈 后端:SpringBoot mybatis MySQL Redis 前端:Vue elementui 二、整体效果 三、表结…

【面试】Kafka

Kafka 1、为什么要使用 kafka2、Kafka 的架构是怎么样的3、什么是 Kafka 的重平衡机制4、Kafka 几种选举过程5、Kafka 高水位了解过吗6、Kafka 如何保证消息不丢失7、Kafka 如何保证消息不重复消费8、Kafka 为什么这么快 1、为什么要使用 kafka 1. 解耦:在一个复杂…

文件操作详解(万字长文)

C语言文件操作 一、为什么使用文件?二、文件分类三、文件的打开和关闭四、文件的顺序读写4.1fputc4.2fgetc4.3fputs4.4fgets4.5 fprintf4.6 fscanf4.7 fwrite4.8 fread 五、文件的随机读写5.1 fseek5.2 ftell和rewind六、文件读取结束的判定七、文件缓冲区 一、为什…

突破极限!蓝耘通义万相2.1引爆AI多模态新纪元——性能与应用全方位革新

云边有个稻草人-CSDN博客 目录 一、 引言 二、 蓝耘通义万相2.1版本概述 三、 蓝耘通义万相2.1的核心技术改进 【多模态数据处理】 【语音识别与文本转化】 【自然语言处理(NLP)改进】 【跨平台兼容性】 四、 蓝耘注册 部署流程—新手也能轻松…

力扣-股票买入问题

dp dp元素代表最大利润 f[j][1] 代表第 j 次交易后持有股票的最大利润。在初始状态,持有股票意味着你花钱买入了股票,此时的利润应该是负数(扣除了买入股票的成本),而不是 0。所以,把 f[j][1] 初始化为负…

ubuntu22.04本地部署OpenWebUI

一、简介 Open WebUI 是一个可扩展、功能丰富且用户友好的自托管 AI 平台,旨在完全离线运行。它支持各种 LLM 运行器,如 Ollama 和 OpenAI 兼容的 API,并内置了 RAG 推理引擎,使其成为强大的 AI 部署解决方案。 二、安装 方法 …

Unity开发——CanvasGroup组件介绍和应用

CanvasGroup是Unity中用于控制UI的透明度、交互性和渲染顺序的组件。 一、常用属性的解释 1、alpha:控制UI的透明度 类型:float,0.0 ~1.0, 其中 0.0 完全透明,1.0 完全不透明。 通过调整alpha值可以实现UI的淡入淡…

LVGL直接解码png图片的方法

通过把png文件解码为.C文件,再放到工程中的供使用,这种方式随时速度快(应为已经解码,代码中只要直接加载图片数据显示出来即可),但是不够灵活,适用于哪些简单又不经常需要更换UI的场景下使用。如…

【算法day5】最长回文子串——马拉车算法

最长回文子串 给你一个字符串 s,找到 s 中最长的 回文 子串。 https://leetcode.cn/problems/longest-palindromic-substring/description/ 算法思路: class Solution { public:string longestPalindrome(string s) {int s_len s.size();string tmp …

JavaWeb-HttpServletRequest请求域接口

文章目录 HttpServletRequest请求域接口HttpServletRequest请求域接口简介关于请求域和应用域的区别 请求域接口中的相关方法获取前端请求参数(getParameter系列方法)存储请求域名参数(Attribute系列方法)获取客户端的相关地址信息获取项目的根路径 关于转发和重定向的细致剖析…

Dify 本地部署教程

目录 一、下载安装包 二、修改配置 三、启动容器 四、访问 Dify 五、总结 本篇文章主要记录 Dify 本地部署过程,有问题欢迎交流~ 一、下载安装包 从 Github 仓库下载最新稳定版软件包,点击下载~,当然也可以克隆仓库或者从仓库里直接下载zip源码包。 目前最新版本是V…

css错峰布局/瀑布流样式(类似于快手样式)

当样式一侧比较高的时候会自动换行,尽量保持高度大概一致, 例: 一侧元素为5,另一侧元素为6 当为5的一侧过于高的时候,可能会变为4/7分部dom节点 如果不需要这样的话删除样式 flex-flow:column wrap; 设置父级dom样…