大模型面试准备(五):图解 Transformer 最关键模块 MHA

节前,我们组织了一场算法岗技术&面试讨论会,邀请了一些互联网大厂朋友、参加社招和校招面试的同学,针对大模型技术趋势、大模型落地项目经验分享、新手如何入门算法岗、该如何备战、面试常考点分享等热门话题进行了深入的讨论。


合集在这里:《大模型面试宝典》(2024版) 正式发布!


Transformer 原始论文中的模型结构如下图所示:
图片

上一篇文章讲解了 Transformer 的关键模块 Positional Encoding(大家可以自行翻阅),本篇文章讲解一下 Transformer 的最重要模块 Multi-Head Attention(MHA),毕竟 Transformer 的论文名称就叫 《Attention Is All You Need》。

Transformer 中的 Multi-Head Attention 可以细分为3种,Multi-Head Self-Attention(对应上图左侧Multi-Head Attention模块),Multi-Head Cross-Attention(对应上图右上Multi-Head Attention模块),Masked Multi-Head Self-Attention(对应上图右下Masked Multi-Head Attention模块)。

其中 Self 和 Cross 的区分是对应的 Q和 K、 V是否来自相同的输入。是否Mask的区分是是否需要看见全部输入和预测的输出,Encoder需要看见全部的输入问题,所以不能Mask;而Decoder是预测输出,当前预测只能看见之前的全部预测,不能看见之后的预测,所以需要Mask。

本篇文章主要通过图解的方式对 Multi-Head Attention 的核心思想和计算过程做讲解,喜欢本文记得收藏、点赞、关注。技术和面试交流,文末加入我们

MHA核心思想

在这里插入图片描述

MHA过程图解

注意力计算公式如下:

在这里插入图片描述

图示过程图下:

图片

多头注意力

MHA通过多个头的方式,可以增强自注意力机制聚合上下文信息的能力,以关注上下文的不同侧面,作用类似于CNN的多个卷积核。下面我们就通过一张图来完成MHA的解析:

图片

在这里插入图片描述

单头注意力

知道了多头注意力的实现方式后,那如果是通过单头注意力完成同样的计算,矩阵形式是什么样的呢?下面我还是以一图胜千言的方式来回答这个问题:

图片通过单头注意力的比较,相信大家对多头注意力(MHA)应该有了更好的理解。我们可以发现多头注意力就是将一个单头进行了切分计算,最后又将结果进行了合并,整个过程中的整体维度和计算量基本是不变的,但提升了模型的学习能力。

最后附上一份MHA的实现和Transformer的构建代码:

import torch
import torch.nn as nn# 定义多头自注意力层
class MultiHeadAttention(nn.Module):def __init__(self, d_model, n_heads):super(MultiHeadAttention, self).__init__()self.n_heads = n_heads  # 多头注意力的头数self.d_model = d_model  # 输入维度(模型的总维度)self.head_dim = d_model // n_heads  # 每个注意力头的维度assert self.head_dim * n_heads == d_model, "d_model必须能够被n_heads整除"  # 断言,确保d_model可以被n_heads整除# 线性变换矩阵,用于将输入向量映射到查询、键和值空间self.wq = nn.Linear(d_model, d_model)  # 查询(Query)的线性变换self.wk = nn.Linear(d_model, d_model)  # 键(Key)的线性变换self.wv = nn.Linear(d_model, d_model)  # 值(Value)的线性变换# 最终输出的线性变换,将多头注意力结果合并回原始维度self.fc_out = nn.Linear(d_model, d_model)  # 输出的线性变换def forward(self, query, key, value, mask):# 将嵌入向量分成不同的头query = query.view(query.shape[0], -1, self.n_heads, self.head_dim)key = key.view(key.shape[0], -1, self.n_heads, self.head_dim)value = value.view(value.shape[0], -1, self.n_heads, self.head_dim)# 转置以获得维度 batch_size, self.n_heads, seq_len, self.head_dimquery = query.transpose(1, 2)key = key.transpose(1, 2)value = value.transpose(1, 2)# 计算注意力得分scores = torch.matmul(query, key.transpose(-2, -1)) / self.head_dimif mask is not None:scores = scores.masked_fill(mask == 0, -1e9)attention = torch.nn.functional.softmax(scores, dim=-1)out = torch.matmul(attention, value)# 重塑以恢复原始输入形状out = out.transpose(1, 2).contiguous().view(query.shape[0], -1, self.d_model)out = self.fc_out(out)return out# 定义Transformer编码器层
class TransformerEncoderLayer(nn.Module):def __init__(self, d_model, n_heads, dim_feedforward, dropout):super(TransformerEncoderLayer, self).__init__()# 多头自注意力层,接收d_model维度输入,使用n_heads个注意力头self.self_attn = MultiHeadAttention(d_model, n_heads)# 第一个全连接层,将d_model维度映射到dim_feedforward维度self.linear1 = nn.Linear(d_model, dim_feedforward)# 第二个全连接层,将dim_feedforward维度映射回d_model维度self.linear2 = nn.Linear(dim_feedforward, d_model)# 用于随机丢弃部分神经元,以减少过拟合self.dropout = nn.Dropout(dropout)# 第一个层归一化层,用于归一化第一个全连接层的输出self.norm1 = nn.LayerNorm(d_model)# 第二个层归一化层,用于归一化第二个全连接层的输出self.norm2 = nn.LayerNorm(d_model)def forward(self, src, src_mask):# 使用多头自注意力层处理输入src,同时提供src_mask以屏蔽不需要考虑的位置src2 = self.self_attn(src, src, src, src_mask)# 残差连接和丢弃:将自注意力层的输出与原始输入相加,并应用丢弃src = src + self.dropout(src2)# 应用第一个层归一化src = self.norm1(src)# 经过第一个全连接层,再经过激活函数ReLU,然后进行丢弃src2 = self.linear2(self.dropout(torch.nn.functional.relu(self.linear1(src))))# 残差连接和丢弃:将全连接层的输出与之前的输出相加,并再次应用丢弃src = src + self.dropout(src2)# 应用第二个层归一化src = self.norm2(src)# 返回编码器层的输出return src# 实例化模型
vocab_size = 10000  # 词汇表大小(根据实际情况调整)
d_model = 512  # 模型的维度
n_heads = 8  # 多头自注意力的头数
num_encoder_layers = 6  # 编码器层的数量
dim_feedforward = 2048  # 全连接层的隐藏层维度
max_seq_length = 100  # 最大序列长度
dropout = 0.1  # 丢弃率# 创建Transformer模型实例
model = Transformer(vocab_size, d_model, n_heads, num_encoder_layers, dim_feedforward, max_seq_length, dropout)

最后的最后再贴上一张非常不错的 Transformer 手绘吧!

在这里插入图片描述

技术交流群

前沿技术资讯、算法交流、求职内推、算法竞赛、面试交流(校招、社招、实习)等、与 10000+来自港科大、北大、清华、中科院、CMU、腾讯、百度等名校名企开发者互动交流~

我们建了算法岗技术与面试交流群, 想要进交流群、需要源码&资料、提升技术的同学,可以直接加微信号:mlc2040。加的时候备注一下:研究方向 +学校/公司+CSDN,即可。然后就可以拉你进群了。

方式①、微信搜索公众号:机器学习社区,后台回复:加群
方式②、添加微信号:mlc2040,备注:技术交流

用通俗易懂方式讲解系列

  • 《大模型面试宝典》(2024版) 正式发布!
  • 《大模型实战宝典》(2024版)正式发布!
  • 大模型面试准备(一):LLM主流结构和训练目标、构建流程
  • 大模型面试准备(二):LLM容易被忽略的Tokenizer与Embedding
  • 大模型面试准备(三):聊一聊大模型的幻觉问题
  • 大模型面试准备(四):大模型面试必会的位置编码(绝对位置编码sinusoidal,旋转位置编码RoPE,以及相对位置编码ALiBi)

参考文献:

参考资料:
[1] https://jalammar.github.io/illustrated-transformer/
[2] https://zhuanlan.zhihu.com/p/264468193
[3] https://zhuanlan.zhihu.com/p/662777298

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.rhkb.cn/news/287148.html

如若内容造成侵权/违法违规/事实不符,请联系长河编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

【Godot4自学手册】第二十九节使用Shader来实现敌人受伤的闪白效果

在Godot 4中,Shader是用来为材质提供自定义渲染效果的程序。材质可以应用于MeshInstance、CanvasItem和ParticleEmitter等节点。Shader可以影响顶点的变换、片段(像素)的颜色,以及光照与物体的交互。 在Godot中,Shader…

C#_事件_多线程(基础)

文章目录 事件通过事件使用委托 多线程(基础)进程:线程: 多线程线程生命周期主线程Thread 类中的属性和方法创建线程管理线程销毁线程 昨天习题答案 事件 事件(Event)本质上来讲是一种特殊的多播委托,只能从声明它的类中进行调用,基本上说是…

【小沐学AI】智谱AI大模型的一点点学习(Python)

文章目录 1、简介1.1 大模型排行榜 2、智谱AI2.1 GLM2.1.1 模型简介2.1.2 开源代码2.1.2.1 GLM-130B 2.2 ChatGLM2.2.1 模型简介2.2.2 开源代码2.2.2.1 ChatGLM2.2.2.2 ChatGLM22.2.2.3 ChatGLM3 2.3 CodeGeeX2.3.1 模型简介2.3.2 开源代码 2.4 CogView2.4.1 模型简介2.4.2 开源…

在存在代理的主机上,为docker容器配置代理

1、配置Firefox的代理 (只配置域名或者ip,前面不加http://) 2、为容器中的Git配置代理 git config --global http.proxy http://qingteng:8080 3、Git下载时忽略证书校验 env GIT_SSL_NO_VERIFYtrue git clone https://github.com/nginx/nginx.git 4、docker的…

《剑指 Offer》专项突破版 - 面试题 93 : 最长斐波那契数列(C++ 实现)

题目链接:最长斐波那契数列 题目: 输入一个没有重复数字的单调递增的数组,数组中至少有 3 个数字,请问数组中最长的斐波那契数列的长度是多少?例如,如果输入的数组是 [1, 2, 3, 4, 5, 6, 7, 8]&#xff0…

Redission 分布式锁原理分析

一、前言 我们先来说说分布式锁,为啥要有分布式锁呢? 像 JDK 提供的 synchronized、Lock 等实现锁不香吗?这是因为在单进程情况下,多个线程访问同一资源,可以使用 synchronized 和 Lock 实现;在多进程情况下&#xff…

MATLAB 公共区域的点云合并(46)

MATLAB 公共区域的点云合并(46) 一、算法介绍二、算法实现1.代码2.效果一、算法介绍 点云配准后,或者公共区域存在多片点云对场景进行冗余过量表达时,我们需要将点云进行合并,Matlab点云工具中提供了这样的合并函数,通过指定网格步长,对初始点云进行过滤。 函数主要实…

ReactNative项目构建分析与思考之RN组件化

传统RN项目对比 ReactNative项目构建分析与思考之react-native-gradle-plugin ReactNative项目构建分析与思考之native_modules.gradle ReactNative项目构建分析与思考之 cli-config 在之前的文章中,已经对RN的默认项目有了一个详细的分析,下面我们来…

Linux之文件系统与软硬链接

前言 我们之前阐述的内容都是在文件打开的前提下, 但是事实上不是所有文件都是被打开的, 且大部分文件都不是被打开的(也就是文件当前并不需要被访问), 都在磁盘中进行保存. 那这些没有被(进程)打开的文件, 也是需要被管理的! 对于这部分文件核心工作之一是能够快速定位文件…

P1135 奇怪的电梯 (双向bfs)

输入输出样例 输入 5 1 5 3 3 1 2 5输出 3说明/提示 对于 100%100% 的数据,1≤N≤200,1≤A,B≤N,0≤Ki​≤N。 本题共 1616 个测试点,前 1515 个每个测试点 66 分,最后一个测试点 10 分。 重写AC代码&#xff1…

UVa1483/LA5075 Intersection of Two Prisms

题目链接 本题是2010年ICPC亚洲区域赛东京赛区的I题 题意 求两个无限高棱柱的交。其中一个棱柱是把xy平面上的凸多边形沿z轴无限拉长得到,另外一个棱柱是把xz平面上的凸多边形沿y轴无限拉长得到。输入给出第一个棱柱在xy平面的凸多边形坐标和另外一个棱柱在xz平面的…

voxelize_cuda安装教程 python+windows环境

import voxelize_cuda报错 安装步骤: 克隆voxelize项目 官网:https://github.com/YuliangXiu/neural_voxelization_layer.git git clone https://github.com/YuliangXiu/neural_voxelization_layer.git下载一些必备的解析c文件的依赖 官网&#xff1a…

鸿蒙应用开发-录音保存并播放音频

功能介绍: 录音并保存为m4a格式的音频,然后播放该音频,参考文档使用AVRecorder开发音频录制功能(ArkTS),更详细接口信息请查看接口文档:ohos.multimedia.media (媒体服务)。 知识点: 熟悉使用AVRecorder…

007 日期类型相关工具类

推荐一篇文章 http://t.csdnimg.cn/72F7Jhttp://t.csdnimg.cn/72F7J

agent利用知识来做规划:《KnowAgent: Knowledge-Augmented Planning for LLM-Based Agents》笔记

文章目录 简介KnowAgent思路准备知识Action Knowledge的定义Planning Path Generation with Action KnowledgePlanning Path Refinement via Knowledgeable Self-LearningKnowAgent的实验结果 总结参考资料 简介 《KnowAgent: Knowledge-Augmented Planning for LLM-Based Age…

盛⽔最多的容器【双指针】

首先我们设该容器的两边为左右两边界。 这道题中的:盛⽔最大容量 底 * 高 左右两边界距离 * 左右两边界的较短板。 这道题如果用暴力求解,是个人都能想到怎么做,遍历所有的情况即可。 有没有更好的办法呢?我是搜了资料了解的。我…

Covalent Network(CQT)的以太坊时光机:在 Rollup 时代确保长期数据可用性

以太坊正在经历一场向 “Rollup 时代” 的转型之旅,这一转型由以太坊改进提案 EIP-4844 推动。这标志着区块链技术的一个关键转折,采用了一种被称为“数据块(blobs)”的新型数据结构。为了与以太坊的扩容努力保持一致,…

MATLAB 自定义生成平面点云(可指定方向,添加噪声)(48)

MATLAB 自定义生成平面点云(可指定方向,添加噪声)(48) 一、算法介绍二、算法步骤三、算法实现1.代码2.效果一、算法介绍 通过这里的平面生成方法,可以生成模拟平面的点云数据,并可以人为设置平面方向,平面大小,并添加噪声来探索不同类型的平面数据。这种方法可以用于…

mysql刨根问底

索引:排好序的数据结构 二叉树: 红黑树 hash表: b-tree: 叶子相同深度,叶节点指针空,索引元素不重复,从左到右递增排序 节点带data btree: 非叶子节点只存储索引,可…

Java_15 删除排序数组中的重复项

删除排序数组中的重复项 给你一个 非严格递增排列 的数组 nums ,请你 原地 删除重复出现的元素,使每个元素 只出现一次 ,返回删除后数组的新长度。元素的 相对顺序 应该保持 一致 。然后返回 nums 中唯一元素的个数。 考虑 nums 的唯一元素的…