动手学深度学习68 Transformer

1. Transformer

在这里插入图片描述
在这里插入图片描述
在这里插入图片描述
在这里插入图片描述
在这里插入图片描述
在这里插入图片描述
在这里插入图片描述
在这里插入图片描述
在这里插入图片描述

2. 多头注意力代码

通过不断地reshape,避免forloop操作。
什么样的shape进去,怎样的shape出来。

#@save
class MultiHeadAttention(nn.Module):"""多头注意力"""def __init__(self, key_size, query_size, value_size, num_hiddens,num_heads, dropout, bias=False, **kwargs):super(MultiHeadAttention, self).__init__(**kwargs)self.num_heads = num_headsself.attention = d2l.DotProductAttention(dropout)self.W_q = nn.Linear(query_size, num_hiddens, bias=bias)self.W_k = nn.Linear(key_size, num_hiddens, bias=bias)self.W_v = nn.Linear(value_size, num_hiddens, bias=bias)self.W_o = nn.Linear(num_hiddens, num_hiddens, bias=bias)def forward(self, queries, keys, values, valid_lens):# queries,keys,values的形状:# (batch_size,查询或者“键-值”对的个数,num_hiddens)# valid_lens 的形状:# (batch_size,)或(batch_size,查询的个数)# 经过变换后,输出的queries,keys,values 的形状:# (batch_size*num_heads,查询或者“键-值”对的个数,# num_hiddens/num_heads)queries = transpose_qkv(self.W_q(queries), self.num_heads)keys = transpose_qkv(self.W_k(keys), self.num_heads)values = transpose_qkv(self.W_v(values), self.num_heads)if valid_lens is not None:# 在轴0,将第一项(标量或者矢量)复制num_heads次,# 然后如此复制第二项,然后诸如此类。valid_lens = torch.repeat_interleave(valid_lens, repeats=self.num_heads, dim=0)# output的形状:(batch_size*num_heads,查询的个数,# num_hiddens/num_heads)output = self.attention(queries, keys, values, valid_lens)# output_concat的形状:(batch_size,查询的个数,num_hiddens)output_concat = transpose_output(output, self.num_heads)return self.W_o(output_concat)#@save
def transpose_qkv(X, num_heads):"""为了多注意力头的并行计算而变换形状"""# 输入X的形状:(batch_size,查询或者“键-值”对的个数,num_hiddens)# 输出X的形状:(batch_size,查询或者“键-值”对的个数,num_heads,# num_hiddens/num_heads)X = X.reshape(X.shape[0], X.shape[1], num_heads, -1)# 输出X的形状:(batch_size,num_heads,查询或者“键-值”对的个数,# num_hiddens/num_heads)X = X.permute(0, 2, 1, 3)# 最终输出的形状:(batch_size*num_heads,查询或者“键-值”对的个数,# num_hiddens/num_heads)return X.reshape(-1, X.shape[2], X.shape[3])#@save
def transpose_output(X, num_heads):"""逆转transpose_qkv函数的操作"""X = X.reshape(-1, num_heads, X.shape[1], X.shape[2])X = X.permute(0, 2, 1, 3)return X.reshape(X.shape[0], X.shape[1], -1)

3. Transformer代码

输出不变,模型容易叠加
重要参数:hidden-size, head

#@save
class EncoderBlock(nn.Module):"""Transformer编码器块"""def __init__(self, key_size, query_size, value_size, num_hiddens,norm_shape, ffn_num_input, ffn_num_hiddens, num_heads,dropout, use_bias=False, **kwargs):super(EncoderBlock, self).__init__(**kwargs)self.attention = d2l.MultiHeadAttention(key_size, query_size, value_size, num_hiddens, num_heads, dropout,use_bias)self.addnorm1 = AddNorm(norm_shape, dropout)self.ffn = PositionWiseFFN(ffn_num_input, ffn_num_hiddens, num_hiddens)self.addnorm2 = AddNorm(norm_shape, dropout)def forward(self, X, valid_lens):Y = self.addnorm1(X, self.attention(X, X, X, valid_lens))return self.addnorm2(Y, self.ffn(Y))#@save
class TransformerEncoder(d2l.Encoder):"""Transformer编码器"""def __init__(self, vocab_size, key_size, query_size, value_size,num_hiddens, norm_shape, ffn_num_input, ffn_num_hiddens,num_heads, num_layers, dropout, use_bias=False, **kwargs):super(TransformerEncoder, self).__init__(**kwargs)self.num_hiddens = num_hiddensself.embedding = nn.Embedding(vocab_size, num_hiddens)self.pos_encoding = d2l.PositionalEncoding(num_hiddens, dropout)self.blks = nn.Sequential()for i in range(num_layers):self.blks.add_module("block"+str(i),EncoderBlock(key_size, query_size, value_size, num_hiddens,norm_shape, ffn_num_input, ffn_num_hiddens,num_heads, dropout, use_bias))def forward(self, X, valid_lens, *args):# 因为位置编码值在-1和1之间,# 因此嵌入值乘以嵌入维度的平方根进行缩放,# 然后再与位置编码相加。X = self.pos_encoding(self.embedding(X) * math.sqrt(self.num_hiddens))self.attention_weights = [None] * len(self.blks)for i, blk in enumerate(self.blks):X = blk(X, valid_lens)self.attention_weights[i] = blk.attention.attention.attention_weightsreturn Xclass DecoderBlock(nn.Module):"""解码器中第i个块"""def __init__(self, key_size, query_size, value_size, num_hiddens,norm_shape, ffn_num_input, ffn_num_hiddens, num_heads,dropout, i, **kwargs):super(DecoderBlock, self).__init__(**kwargs)self.i = iself.attention1 = d2l.MultiHeadAttention(key_size, query_size, value_size, num_hiddens, num_heads, dropout)self.addnorm1 = AddNorm(norm_shape, dropout)self.attention2 = d2l.MultiHeadAttention(key_size, query_size, value_size, num_hiddens, num_heads, dropout)self.addnorm2 = AddNorm(norm_shape, dropout)self.ffn = PositionWiseFFN(ffn_num_input, ffn_num_hiddens,num_hiddens)self.addnorm3 = AddNorm(norm_shape, dropout)def forward(self, X, state):enc_outputs, enc_valid_lens = state[0], state[1]# 训练阶段,输出序列的所有词元都在同一时间处理,# 因此state[2][self.i]初始化为None。# 预测阶段,输出序列是通过词元一个接着一个解码的,# 因此state[2][self.i]包含着直到当前时间步第i个块解码的输出表示if state[2][self.i] is None:key_values = Xelse:key_values = torch.cat((state[2][self.i], X), axis=1)state[2][self.i] = key_valuesif self.training:batch_size, num_steps, _ = X.shape# dec_valid_lens的开头:(batch_size,num_steps),# 其中每一行是[1,2,...,num_steps]dec_valid_lens = torch.arange(1, num_steps + 1, device=X.device).repeat(batch_size, 1)else:dec_valid_lens = None# 自注意力X2 = self.attention1(X, key_values, key_values, dec_valid_lens)Y = self.addnorm1(X, X2)# 编码器-解码器注意力。# enc_outputs的开头:(batch_size,num_steps,num_hiddens)Y2 = self.attention2(Y, enc_outputs, enc_outputs, enc_valid_lens)Z = self.addnorm2(Y, Y2)return self.addnorm3(Z, self.ffn(Z)), stateclass TransformerDecoder(d2l.AttentionDecoder):def __init__(self, vocab_size, key_size, query_size, value_size,num_hiddens, norm_shape, ffn_num_input, ffn_num_hiddens,num_heads, num_layers, dropout, **kwargs):super(TransformerDecoder, self).__init__(**kwargs)self.num_hiddens = num_hiddensself.num_layers = num_layersself.embedding = nn.Embedding(vocab_size, num_hiddens)self.pos_encoding = d2l.PositionalEncoding(num_hiddens, dropout)self.blks = nn.Sequential()for i in range(num_layers):self.blks.add_module("block"+str(i),DecoderBlock(key_size, query_size, value_size, num_hiddens,norm_shape, ffn_num_input, ffn_num_hiddens,num_heads, dropout, i))self.dense = nn.Linear(num_hiddens, vocab_size)def init_state(self, enc_outputs, enc_valid_lens, *args):return [enc_outputs, enc_valid_lens, [None] * self.num_layers]def forward(self, X, state):X = self.pos_encoding(self.embedding(X) * math.sqrt(self.num_hiddens))self._attention_weights = [[None] * len(self.blks) for _ in range (2)]for i, blk in enumerate(self.blks):X, state = blk(X, state)# 解码器自注意力权重self._attention_weights[0][i] = blk.attention1.attention.attention_weights# “编码器-解码器”自注意力权重self._attention_weights[1][i] = blk.attention2.attention.attention_weightsreturn self.dense(X), state@propertydef attention_weights(self):return self._attention_weights

4. QA

在这里插入图片描述

14 可能区别不大
15 看芯片运算力多钱
16 transformer架构
17 concat在保留信息方面比加权平均好
18
19 无
20 n个小的attention,类似卷积多通道
21 no https://zhuanlan.zhihu.com/p/674797180
22 bert gpt 有硬件要求

在这里插入图片描述
在这里插入图片描述
23 没有实际意义
24 一般选hidden-size 256 1024等
25 变种 不是标准正态
26 bert只有encoder
27 可以 图片扣出很多batch就可以认为是一个序列
29 可以先用bert做预处理,提取特征。
在这里插入图片描述
30 没什么联系 位置编码 和自注意力有点关系
31 也可以 用的不多 attention比cnn要好,视野广,bert一层可以看到很长的信息–整张图片的信息。
在这里插入图片描述

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.rhkb.cn/news/469579.html

如若内容造成侵权/违法违规/事实不符,请联系长河编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

晨控RFID技术助力半导体制造业革新之路

晨控RFID技术助力半导体制造业革新之路 应用背景 随着信息技术的快速发展,无线射频识别技术逐渐成为连接物理世界与数字世界的桥梁。尤其在半导体产业中,RFID技术的应用不仅提升了生产效率,还加强了产品追踪与管理能力,对推动产…

ReactPress:构建高效、灵活、可扩展的开源发布平台

ReactPress Github项目地址:https://github.com/fecommunity/reactpress 欢迎Star。 在当今数字化时代,内容管理系统(CMS)已成为各类网站和应用的核心组成部分。ReactPress,作为一款融合了现代Web开发多项先进技术的开…

PyTorch版本的3D网络Grad-CAM可视化实验记录

前言 最近在跑代码的时候需要可视化一些网络中间层特征来诊断网络,但是我的backbone是一个3D网络,一般的Grad-CAM都是在2D网络中应用更广泛,查了一下也只有几篇博文是关于3D Grad-CAM的介绍的。自己参照他们的代码试了一下,但是可…

基于STM32的智能充电桩:集成RTOS、MQTT与SQLite的先进管理系统设计思路

一、项目概述 随着电动车的普及,充电桩作为关键基础设施,其智能化、网络化管理显得尤为重要。本项目旨在基于STM32微控制器开发一款智能充电桩,能够实现高效的充电监控与管理。项目通过物联网技术,提供实时数据监测、远程管理、用…

.NET中通过C#实现Excel与DataTable的数据互转

在.NET框架中,使用C#进行Excel数据与DataTable之间的转换是数据分析、报表生成、数据迁移等操作中的常见需求。这一过程涉及到将Excel文件中的数据读取并加载至DataTable中,以便于利用.NET提供的丰富数据处理功能进行操作,同时也包括将DataTa…

域名服务系统DNS (Domain Name System)

域名的介绍 熟悉了域名之后,不仅仅是应对考试,生活中看到一个常规的网址,我们也能快速想到这个网址对应的含义是什么,并且在记忆网址的时候也更加得心应手,快速了解域名中各个层级的含义,这是 非常有趣呢…

Kettle——CSV文件转换成excel文件输出

1.点击—文件—新建—转换 拖入两个组件: 按shift+鼠标左击建立连接,并点击主输出步骤, 点击CSV文件输入,选择浏览的csv文件,然后点击确定 同样,Excel也同上,只是要删除这个xls 并…

Select,poll,epoll和IO多路复用和NIO

Select,poll,epoll和IO多路复用和NIO IO 多路复用:是一种 I/O 处理机制,它允许单个线程同时处理多个 I/O 流(如多个文件描述符对应的网络连接、文件操作等)的输入输出操作,通过一种机制来监听这…

希尔排序(C语言)

一、步骤: 希尔排序的基本思想:先选定一个整数,把待排序文件中所有记录分成个组,所有距离为gap的记录分在同一组内,并对每一组内的记录进行排序。然后取重复上述分组和排序的工作。当到gap 1时,所有记录在…

自动驾驶为什么需要时间同步?高精度时间同步如何实现?

自动驾驶作为汽车与物联网技术、人工智能等高新技术融合的产物,具有新颖性、复杂性和巨大的挑战性。自动驾驶需要实时传输、处理海量数据,并实时做出决策,这不仅要求有通畅网络通信环境、强大的数据算力,更要求时间同步的超低时延…

【信号处理】基于联合图像表示的深度学习卷积神经网络

Combined Signal Representations for Modulation Classification Using Deep Learning: Ambiguity Function, Constellation Diagram, and Eye Diagram 信号表示 Ambiguity Function(AF) 模糊函数描述了信号的两个维度(dimensions):延迟(delay)和多普勒(Doppler)。 …

C++20 概念与约束(1)—— SFINAE

●《C20 概念与约束(1)—— SFINAE》 《C20 概念与约束(2)—— 初识概念与约束》 《C20 概念与约束(3)—— 约束的进阶用法》 1、从模板说起 众所周知,C在使用模板时,如果有多…

[极客大挑战 2019]PHP 1

[极客大挑战 2019]PHP 1 审题 猜测备份在www.zip中,输入下载文件。 知识点 反序列化 解题 查看代码 看到index.php中包含了class.php,直接看class.php中的代码 查看条件 当usernameadmin,password100时输出flag 构造反序列化 输入select中&#…

Kubebot:一款Google云平台下的Slackbot安全测试工具

Kubebot 今天给大家介绍的是一款名叫Kubebot的安全测试Slackbot,该工具基于Google 云平台搭建,并且提供了Kubernetes后端。 项目架构 数据流 1.API请求由Slackbot发起,发送至API服务器,API服务器以Kubernetes(K8s)集群中的Docke…

Kubernetes的基本构建块和最小可调度单元pod-0

文章目录 一,什么是pod1.1pod在k8s中使用方法(1)使用方法一(2)使用方法二 1.2pod中容器的进程1.3pod的网络隔离管理(1)pause容器的作用 1.4 Pod分类:(1)自主式…

Redis安装(Windows环境)

目录 1.下载2.双击安装后配置环境变量3.启动服务4.设置Windows服务5.启动客户端6.常用的Redis服务命令7.使用图形化界面工具查看Redis内部数据情况类似Navicat 连接数据库 1.下载 1.点击github下载地址 2.上方资源链接下载安装 2.双击安装后配置环境变量 3.启动服务 上图虽然…

windows下qt5.12.11使用ODBC远程连接mysql数据库

1、下载并安装mysql驱动,下载地址:https://dev.mysql.com/downloads/ 2、配置ODBC数据源,打开64位的ODBC数据源配置工具:

【AI写作宝-注册安全分析报告-无验证方式导致安全隐患】

前言 由于网站注册入口容易被黑客攻击,存在如下安全问题: 1. 暴力破解密码,造成用户信息泄露 2. 短信盗刷的安全问题,影响业务及导致用户投诉 3. 带来经济损失,尤其是后付费客户,风险巨大,造…

GFPS技术原理(四)GATT特征值

Fast Pair服务 fast pair 服务的UUID 是0xFE2C,然后它又包含多个特征值,下面一一分析: Model ID UUID是0x1233,设备在谷歌云注册的时候会分配一个24 bit的ID。 Key-based Pairing UUID是0x1234,这个是用来做DH密钥…

3.2 软件需求:面对过程分析模型

面对过程分析模型 1. 需求分析的模型概述1.1 面对过程分析模型-结构化分析方法1.2 结构化分析的过程 2. 功能模型:数据流图初步2.1 加工2.2 外部实体(数据源点/终点)2.3 数据流2.4 数据存储2.5 注意事项 3. 功能模型:数据流图进阶…