TR2 - Transformer模型的复现


目录

  • 理论知识
  • 模型结构
    • 结构分解
      • 黑盒
      • 两大模块
      • 块级结构
      • 编码器的组成
      • 解码器的组成
  • 模型实现
    • 多头自注意力块
    • 前馈网络块
    • 位置编码
    • 编码器
    • 解码器
    • 组合模型
    • 最后附上引用部分
  • 模型效果
  • 总结与心得体会


理论知识

Transformer是可以用于Seq2Seq任务的一种模型,和Seq2Seq不冲突。

模型结构

模型整体结构

结构分解

黑盒

以机器翻译任务为例
黑盒

两大模块

在Transformer内部,可以分成Encoder编码器和和Decoder解码器两部分,这也是Seq2Seq的标准结构。
两大模块

块级结构

继续拆解,可以发现模型的由许多的编码器块和解码器块组成并且每个解码器都可以获取到最后一层编码器的输出以及上一层解码器的输出(第一个当然是例外的)。
块组成

编码器的组成

继续拆解,一个编码器是由一个自注意力块和一个前馈网络组成。
编码器的组成

解码器的组成

而解码器,是在编码器的结构中间又插入了一个Encoder-Decoder Attention层。
解码器的组成

模型实现

通过前面自顶向下的拆解,已经基本掌握了模型的总体结构。接下来自底向上的复现Transformer模型。

多头自注意力块

class MultiHeadAttention(nn.Module):"""多头注意力模块"""def __init__(self, dims, n_heads):"""dims: 每个词向量维度n_heads: 注意力头数"""super().__init__()self.dims = dimsself.n_heads = n_heads# 维度必需整除注意力头数assert dims % n_heads == 0# 定义Q矩阵self.w_Q = nn.Linear(dims, dims)# 定义K矩阵self.w_K = nn.Linear(dims, dims)# 定义V矩阵self.w_V = nn.Linear(dims, dims)self.fc = nn.Linear(dims, dims)# 缩放self.scale = torch.sqrt(torch.FloatTensor([dims//n_heads])).to(device)def forward(self, query, key, value, mask=None):batch_size = query.shape[0]# 例如: [32, 1024, 300] 计算10头注意力Q = self.w_Q(query)K = self.w_K(key)V = self.w_V(value)# [32, 1024, 300] -> [32, 1024, 10, 30] 把向量重新分组Q = Q.view(batch_size, -1, self.n_heads, self.dims//self.n_heads).permute(0, 2, 1, 3)K = K.view(batch_size, -1, self.n_heads, self.dims//self.n_heads).permute(0, 2, 1, 3)V = V.view(batch_size, -1, self.n_heads, self.dims//self.n_heads).permute(0, 2, 1, 3)# 1. 计算QK/根dk# [32, 1024, 10, 30] * [32, 1024, 30, 10] -> [32, 1024, 10, 10] 交换最后两维实现乘法attention = torch.matmul(Q, K.permute(0, 1, 3, 2)) / self.scaleif mask is not None:# 将需要mask掉的部分设置为很小的值attention = attention.masked_fill(mask==0, -1e10)# 2. softmaxattention = torch.softmax(attention, dim=-1)# 3. 与V相乘# [32, 1024, 10, 10] * [32, 1024, 10, 30] -> [32, 1024, 10, 30]x = torch.matmul(attention, V)# 恢复结构# 0 2 1 3 把 第2,3维交换回去x = x.permute(0, 2, 1, 3).contiguous()# [32, 1024, 10, 30] -> [32, 1024, 300]x = x.view(batch_size, -1, self.n_heads*(self.dims//self.n_heads))# 走一个全连接层x = self.fc(x)return x

前馈网络块

class FeedForward(nn.Module):"""前馈传播"""def __init__(self, d_model, d_ff, dropout=0.1):super().__init__()self.linear1 = nn.Linear(d_model, d_ff)self.dropout = nn.Dropout(dropout)self.linear2 = nn.Linear(d_ff, d_model)def forward(self, x):x = F.relu(self.linear1(x))x = self.dropout(x)x = self.linear2(x)return x

位置编码

class PositionalEncoding(nn.Module):"""位置编码"""def __init__(self, d_model, dropout=0.1, max_len=5000):super().__init__()self.dropout = nn.Dropout(dropout)# 用来存位置编码的向量pe = torch.zeros(max_len, d_model).to(device)# 准备位置信息position = torch.arange(0, max_len).unsqueeze(1)div_term = torch.exp(torch.arange(0, d_model, 2)* -(math.log(10000.0) / d_model))pe[:, 0::2] = torch.sin(position * div_term)pe[:, 1::2] = torch.cos(position * div_term)pe = pe.unsqueeze(0)# 注册一个不参数梯度下降的模型参数self.register_buffer('pe', pe)def forward(self, x):x  = x + self.pe[:, :x.size(1)].requires_grad_(False)return self.dropout(x)

编码器

class EncoderLayer(nn.Module):def __init__(self, d_model, n_heads, d_ff, dropout=0.1):super().__init__()self.self_attn = MultiHeadAttention(d_model, n_heads)self.feedforward = FeedForward(d_model, d_ff, dropout)self.norm1 = nn.LayerNorm(d_model)self.norm2 = nn.LayerNorm(d_model)self.dropout = nn.Dropout(dropout)def forward(self, x, mask):attn_output = self.self_attn(x, x, x, mask)x = x + self.dropout(attn_output)x = self.norm1(x)ff_output = self.feedforward(x)x = x + self.dropout(ff_output)x = self.norm2(x)return x

解码器

class DecoderLayer(nn.Module):def __init__(self, d_model, n_heads, d_ff, dropout=0.1):super().__init__()self.self_attn = MultiHeadAttention(d_model, n_heads)self.enc_attn = MultiHeadAttention(d_model, n_heads)self.feedforward = FeedForward(d_model, d_ff, dropout)self.norm1 = nn.LayerNorm(d_model)self.norm2 = nn.LayerNorm(d_model)self.norm3 = nn.LayerNorm(d_model)self.dropout = nn.Dropout(dropout)def forward(self, x, enc_output, mask, enc_mask):# 自注意力attn_output = self.self_attn(x, x, x, mask)x = x + self.dropout(attn_output)x = self.norm1(x)# 编码器-解码器注意力attn_output = self.enc_attn(x, enc_output, enc_output, enc_mask)x = x + self.dropout(attn_output)x = self.norm2(x)# 前馈网络ff_output = self.feedforward(x)x = x + self.dropout(ff_output)x = self.norm3(x)return x

组合模型

class Transformer(nn.Module):def __init__(self, vocab_size, d_model, n_heads, n_encoder_layers, n_decoder_layers, d_ff, dropout=0.1):super().__init__()self.embedding = nn.Embedding(vocab_size, d_model)self.positional_encoding = PositionalEncoding(d_model)self.encoder_layers = nn.ModuleList([EncoderLayer(d_model, n_heads, d_ff, dropout) for _ in range(n_encoder_layers)])self.decoder_layers = nn.ModuleList([DecoderLayer(d_model, n_heads, d_ff, dropout) for _ in range(n_decoder_layers)])self.fc_out = nn.Linear(d_model, vocab_size)self.dropout = nn.Dropout(dropout)def forward(self, src, trg, src_mask, trg_mask):# 词嵌入src = self.embedding(src)src = self.positional_encoding(src)trg = self.embedding(trg)trg = self.positional_encoding(trg)# 编码器for layer in self.encoder_layers:src = layer(src, src_mask)# 解码器for layer in self.decoder_layers:trg = layer(trg, src, trg_mask, src_mask)output = self.fc_out(trg)return output

最后附上引用部分

import torch
import torch.nn as nn
import torch.nn.functional as F
import math

模型效果

编写代码测试模型的复现是否正确(没有跑任务)

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')vocab_size = 10000
d_model = 512
n_heads = 8
n_encoder_layers = 6
n_decoder_layers = 6
d_ff = 2048
dropout = 0.1transformer_model = Transformer(vocab_size, d_model, n_heads, n_encoder_layers, n_decoder_layers, d_ff, dropout).to(device)src = torch.randint(0, vocab_size, (32, 10)).to(device) # 源语言
trg = torch.randint(0, vocab_size, (32, 20)).to(device) # 目标语言src_mask = (src != 0).unsqueeze(1).unsqueeze(2).to(device)
trg_mask = (trg != 0).unsqueeze(1).unsqueeze(2).to(device)output = transformer_model(src, trg, src_mask, trg_mask)
print(output.shape)

打印结果

torch.Size([32, 20, 10000])

说明模型正常运行了

总结与心得体会

我是从CV模型学到Transfromer来的,通过对Transformer模型的复现我发现:

  • 类似于残差的连接在Transformer中也十分常见,还有先缩小再放大的Bottleneck结构。
  • 整个Transformer模型的核心处理对特征的维度没有变化,这一点和CV模型完全不同。
  • Transformer的核心是多头自注意机制。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.rhkb.cn/news/290909.html

如若内容造成侵权/违法违规/事实不符,请联系长河编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

伦敦银短线交易频率可以有多高?

伦敦银是很适合于短线交易的品种,至于交易的频率可以短到什么程度,取决于投资者采用的是手动交易,还是程序化的交易。高频交易(HFT)是一种利用计算机算法和高速网络进行的快速交易策略。高频交易者会利用复杂的数学模型…

提取gdip-yolo与ia-seg中的图像自适应模块进行图像去雾与亮度增强

gdip-yolo与ia-seg都是一种将图像自适应模块插入模型前面,从而提升模型在特定数据下检测能力的网络结构。gdip-yolo提出了gdip模块,可以应用到大雾数据与低亮度数据(夜晚环境),然后用于目标检测训练;ia-seg将ia-yolo中的代码修改了一下修车了ipam模块,应用到低亮度数据(…

新能源充电桩站场视频汇聚系统建设方案及技术特点分析

随着新能源汽车的普及,充电桩作为新能源汽车的基础设施,其安全性和可靠性越来越受到人们的关注。为了更好地保障充电桩的安全运行与站场管理,TSINGSEE青犀&触角云推出了一套新能源汽车充电桩视频汇聚管理与视频监控方案。 方案采用高清摄…

游戏赛道新机会:善用数据分析,把握游戏赛道广告变现良机 | TOPON变现干货

12月10日,由罗斯基联合TopOn、钛动科技共同主办的《游戏赛道新机会》主题系列沙龙在武汉举办。活动邀请了国内外多家业内知名公司的负责人到场分享,现场嘉宾分别从自己擅长的领域出发,通过数据分析,案例复盘等多个维度方向进行讲解…

国内IP代理软件电脑版:深入解析与应用指南

随着互联网技术的快速发展,网络活动日益丰富多样,IP代理软件也因其独特的功能和优势,成为许多电脑用户不可或缺的工具。在国内,由于网络环境的复杂性和特殊性,选择一款稳定、高效的IP代理软件电脑版尤为重要。虎观代理…

SpringBoot 登录认证(二)

SpringBoot 登录认证(一)-CSDN博客 SpringBoot 登录认证(二)-CSDN博客 SpringBoot登录校验(三)-CSDN博客 HTTP是无状态协议 HTTP协议是无状态协议。什么又是无状态的协议? 所谓无状态&…

刷题之动态规划

前言 大家好,我是jiantaoyab,开始刷动态规划的题目了,要特别注意初始化的时候给什么值。 动态规划5个步骤 状态表示 :dp数组中每一个下标对应值的含义是什么->dp[i]表示什么状态转移方程: dp[i] 等于什么1 和 2 是…

用grafana+prometheus+cadvisor监控容器指标数据,并查询当前容器的网速网络用量

前言 整理技术,在这篇文章中,将会搭建grafanaprometheuscadvisor监控容器,并使用一个热门数据看板,再监控容器的性能指标 dashboard效果 这个是node-exporter采集到的数据,我没装node-exporter,而且这也…

第三十二天-Django模板-DTL模板引擎

目录 1.介绍 2. 使用 1.配置jinja2 2.DTL模板变量使用 3.与jinja2区别 4.模板标签使用 1.循环 2.条件控制 3.注释 4.url解析 5.显示时间 5.模板的基础与包含 6.过滤器 内置过滤器 自定义过滤器 1.介绍 2. 使用 1.配置jinja2 2.DTL模板变量使用 与jinja2语法相似…

短视频矩阵系统--技术3年源头迭代

短视频矩阵系统核心技术算法主要包括以下几个方面: 1. 视频剪辑:通过剪辑工具或API从各大短视频平台抓取符合要求的视频。这些视频通常符合某些特定条件,如特定关键词、特定时间段发布的视频、视频点赞评论转发等数据表现良好的视频。 2. 视…

vue3+threejs新手从零开发卡牌游戏(二十一):添加战斗与生命值关联逻辑

首先将双方玩家的HP存入store中,stores/common.ts代码如下: import { ref, computed } from vue import { defineStore } from piniaexport const useCommonStore defineStore(common, () > {const _font ref() // 字体const p1HP ref(4000) // 己…

[Python GUI PyQt] PyQt5快速入门

PyQt5快速入门 PyQt5的快速入门0. 写在前面1. 思维导图2. 第一个PyQt5的应用程序3. PyQt5的常用基本控件和布局3.1 PyQt5的常用基本控件3.1.1 按钮控件 QPushButton3.1.2 文本标签控件 QLabel3.1.3 单行输入框控件 QLineEdit3.1.4 A Quick Widgets Demo 3.2 PyQt5的常用基本控件…

大模型重塑电商,淘宝、百度、京东讲出新故事

配图来自Canva可画 随着AI技术日渐成熟,大模型在各个领域的应用也越来越深入,国内互联网行业也随之进入了大模型竞赛的后半场,开始从“百模大战”转向了实际应用。大模型从通用到细分垂直领域的跨越,也让更多行业迎来了新的商机。…

c语言游戏实战(7):扫雷

前言: 扫雷是一款经典的单人益智游戏,它的目标是在一个方格矩阵中找出所有的地雷,而不触碰到任何一颗地雷。在计算机编程领域,扫雷也是一个非常受欢迎的项目,因为它涉及到许多重要的编程概念,如数组、循环…

3D人体姿态估计项目 | 从2D视频中通过检测人体关键点来估计3D人体姿态实现

项目应用场景 人体姿态估计是关于图像或视频中人体关节的 2D 或 3D 定位。一般来说,这个过程可以分为两个部分:(1) 2D 视频中的 2D 关键点检测;(2) 根据 2D 关键点进行 3D 位姿估计。这个项目使用 Detectron2 从任意的 2D 视频中检测 2D 关节…

海豚【货运系统源码】货运小程序【用户端+司机端app】源码物流系统搬家系统源码师傅接单

技术栈:前端uniapp后端vuethinkphp 主要功能: 不通车型配置不通价格参数 多城市定位服务 支持发货地 途径地 目的地智能费用计算 支持日期时间 预约下单 支持添加跟单人数选择 支持下单优惠券抵扣 支持司机收藏订单评价 支持订单状态消息通知 支…

FANUC机器人故障诊断—报警代码更新(三)

FANUC机器人故障诊断中,有些报警代码,继续更新如下。 一、报警代码(SRVO-348) SRVO-348DCS MCC关闭报警a,b [原因]向电磁接触器发出了关闭指令,而电磁接触器尚未关闭。 [对策] 1.当急停单元上连接了CRMA…

备考ICA----Istio实验14---出向流量管控Egress Gateways实验

备考ICA----Istio实验14—出向流量管控Egress Gateways实验 1. 发布测试用 pod kubectl apply -f istio/samples/sleep/sleep.yaml kubectl get pods -l appsleep2. ServiceEntry 创建一个ServiceEntry允许流量访问edition.cnn.com egressgw/edition-ServiceEntry.yaml api…

使用anime.js实现列表滚动轮播

官网&#xff1a;https://animejs.com/ html <div id"slide1"><div class"weather-item" v-for"item in weatherList"><div><img src"../../images/hdft/position.png" alt"">{{item.body.cityInf…

Topaz Gigapixel AI for Mac 图像放大软件

Topaz Gigapixel AI for Mac是一款专为Mac用户设计的智能图像放大软件。它采用了人工智能技术&#xff0c;特别是深度学习算法&#xff0c;以提高图像的分辨率和质量&#xff0c;使得图像在放大后仍能保持清晰的细节。这款软件的特点在于其能够将低分辨率的图片放大至高分辨率&…