PyTorch从零开始实现Transformer

文章目录

    • 自注意力
    • Transformer块
    • 编码器
    • 解码器块
    • 解码器
    • 整个Transformer
    • 参考来源
    • 全部代码(可直接运行)

自注意力

计算公式

在这里插入图片描述

代码实现


class SelfAttention(nn.Module):def __init__(self, embed_size, heads):super(SelfAttention, self).__init__()self.embed_size = embed_sizeself.heads = headsself.head_dim = embed_size // headsassert (self.head_dim * heads == embed_size),  "Embed size needs  to  be div by heads"self.values = nn.Linear(self.head_dim, self.head_dim, bias=False)self.keys = nn.Linear(self.head_dim, self.head_dim, bias=False)self.queries = nn.Linear(self.head_dim, self.head_dim, bias=False)self.fc_out = nn.Linear(heads*self.head_dim, embed_size)def forward(self, values, keys, query, mask):N = query.shape[0] # the number of training examplesvalue_len, key_len, query_len = values.shape[1], keys.shape[1], query.shape[1]# Split embedding into self.heads piecesvalues = values.reshape(N, value_len, self.heads, self.head_dim)keys = keys.reshape(N, key_len, self.heads, self.head_dim)queries = query.reshape(N, query_len, self.heads, self.head_dim)values = self.values(values)keys = self.keys(keys)queries = self.queries(queries)energy = torch.einsum("nqhd,nkhd->nhqk", [queries, keys]) # 矩阵乘法,使用爱因斯坦标记法# queries shape: (N, query_len, heads, heads_dim)# keys shape: (N, key_len, heads, heads_dim)# energy shape: (N, heads, query_len, key_len)if mask is not None:energy = energy.masked_fill(mask==0, float("-1e20")) #Fills elements of self tensor with value where mask is Trueattention = torch.softmax(energy / (self.embed_size ** (1/2)), dim=3)out = torch.einsum("nhql, nlhd->nqhd", [attention, values]).reshape(N, query_len, self.heads*self.head_dim) # 矩阵乘法,使用爱因斯坦标记法einsum# attention shape: (N, heads, query_len, key_len)# values shape: (N, value_len, heads, head_dim)# after einsum (N, query_len, heads, head_dim) then flatten last two dimensionsout = self.fc_out(out)return out

Transformer块

我们把Transfomer块定义为如下图所示的结构,这个Transformer块在编码器和解码器中都有出现过。
在这里插入图片描述

代码实现

class TransformerBlock(nn.Module):def __init__(self, embed_size, heads, dropout, forward_expansion):super(TransformerBlock, self).__init__()self.attention = SelfAttention(embed_size, heads)self.norm1 = nn.LayerNorm(embed_size)self.norm2 = nn.LayerNorm(embed_size)self.feed_forward = nn.Sequential(nn.Linear(embed_size, forward_expansion*embed_size),nn.ReLU(),nn.Linear(forward_expansion*embed_size, embed_size))self.dropout = nn.Dropout(dropout)def forward(self, value, key, query, mask):attention = self.attention(value, key, query, mask)x = self.dropout(self.norm1(attention + query))forward = self.feed_forward(x)out = self.dropout(self.norm2(forward + x))return out

编码器

编码器结构如下所示,Inputs经过Input Embedding 和Positional Encoding之后,通过多个Transformer块

在这里插入图片描述

代码实现

class Encoder(nn.Module):def __init__(self, src_vocab_size,embed_size,num_layers,heads,device,forward_expansion,dropout,max_length):super(Encoder, self).__init__()self.embed_size = embed_sizeself.device = deviceself.word_embedding = nn.Embedding(src_vocab_size, embed_size)self.position_embedding = nn.Embedding(max_length, embed_size)self.layers = nn.ModuleList([TransformerBlock(embed_size,heads,dropout=dropout,forward_expansion=forward_expansion)for _ in range(num_layers)])self.dropout = nn.Dropout(dropout)def forward(self, x, mask):N, seq_lengh = x.shapepositions = torch.arange(0, seq_lengh).expand(N, seq_lengh).to(self.device)out = self.dropout(self.word_embedding(x) + self.position_embedding(positions))for layer in self.layers:out = layer(out, out, out, mask)return out

解码器块

解码器块结构如下图所示

在这里插入图片描述

代码实现

class DecoderBlock(nn.Module):def __init__(self, embed_size, heads, forward_expansion, dropout, device):super(DecoderBlock, self).__init__()self.attention = SelfAttention(embed_size, heads)self.norm = nn.LayerNorm(embed_size)self.transformer_block = TransformerBlock(embed_size, heads, dropout, forward_expansion)self.dropout = nn.Dropout(dropout)def forward(self, x, value, key, src_mask, trg_mask):attention = self.attention(x, x, x, trg_mask)query = self.dropout(self.norm(attention + x))out = self.transformer_block(value, key, query, src_mask)return out

解码器

解码器块加上word embedding 和 positional embedding之后构成解码器

在这里插入图片描述

代码实现

class Decoder(nn.Module):def __init__(self, trg_vocab_size, embed_size, num_layers, heads, forward_expansion, dropout, device, max_length):super(Decoder, self).__init__()self.device = deviceself.word_embedding = nn.Embedding(trg_vocab_size, embed_size)self.position_embedding = nn.Embedding(max_length, embed_size)self.layers = nn.ModuleList([DecoderBlock(embed_size, heads, forward_expansion, dropout, device)for _ in range(num_layers)])self.fc_out = nn.Linear(embed_size, trg_vocab_size)self.dropout = nn.Dropout(dropout)def forward(self, x, enc_out, src_mask, trg_mask):N, seq_length = x.shapepositions = torch.arange(0, seq_length).expand(N, seq_length).to(self.device)x = self.dropout((self.word_embedding(x) + self.position_embedding(positions)))for layer in self.layers:x = layer(x, enc_out, enc_out, src_mask, trg_mask)out = self.fc_out(x)return out

整个Transformer

在这里插入图片描述

代码实现

class Transformer(nn.Module):def __init__(self,src_vocab_size, trg_vocab_size,src_pad_idx,trg_pad_idx,embed_size=256,num_layers=6,forward_expansion=4,heads=8,dropout=0,device="cuda",max_length=100):super(Transformer, self).__init__()self.encoder = Encoder(src_vocab_size,embed_size,num_layers,heads,device,forward_expansion,dropout,max_length)self.decoder = Decoder(trg_vocab_size,embed_size,num_layers,heads,forward_expansion,dropout,device,max_length)self.src_pad_idx = src_pad_idxself.trg_pad_idx = trg_pad_idxself.device = devicedef make_src_mask(self, src):src_mask = (src != self.src_pad_idx).unsqueeze(1).unsqueeze(2)#(N, 1, 1, src_len)return src_mask.to(self.device)def make_trg_mask(self, trg):N, trg_len = trg.shapetrg_mask = torch.tril(torch.ones((trg_len, trg_len))).expand(N, 1, trg_len, trg_len)return trg_mask.to(self.device)def forward(self, src, trg):src_mask = self.make_src_mask(src)trg_mask = self.make_trg_mask(trg)enc_src = self.encoder(src, src_mask)out = self.decoder(trg, enc_src, src_mask,  trg_mask)return out

参考来源

[1] https://www.youtube.com/watch?v=U0s0f995w14
[2] https://github.com/aladdinpersson/Machine-Learning-Collection/blob/master/ML/Pytorch/more_advanced/transformer_from_scratch/transformer_from_scratch.py

[3] https://arxiv.org/abs/1706.03762
[4] https://www.youtube.com/watch?v=pkVwUVEHmfI

全部代码(可直接运行)

import torch
import torch.nn as nnclass SelfAttention(nn.Module):def __init__(self, embed_size, heads):super(SelfAttention, self).__init__()self.embed_size = embed_sizeself.heads = headsself.head_dim = embed_size // headsassert (self.head_dim * heads == embed_size),  "Embed size needs  to  be div by heads"self.values = nn.Linear(self.head_dim, self.head_dim, bias=False)self.keys = nn.Linear(self.head_dim, self.head_dim, bias=False)self.queries = nn.Linear(self.head_dim, self.head_dim, bias=False)self.fc_out = nn.Linear(heads*self.head_dim, embed_size)def forward(self, values, keys, query, mask):N = query.shape[0] # the number of training examplesvalue_len, key_len, query_len = values.shape[1], keys.shape[1], query.shape[1]# Split embedding into self.heads piecesvalues = values.reshape(N, value_len, self.heads, self.head_dim)keys = keys.reshape(N, key_len, self.heads, self.head_dim)queries = query.reshape(N, query_len, self.heads, self.head_dim)values = self.values(values)keys = self.keys(keys)queries = self.queries(queries)energy = torch.einsum("nqhd,nkhd->nhqk", [queries, keys])# queries shape: (N, query_len, heads, heads_dim)# keys shape: (N, key_len, heads, heads_dim)# energy shape: (N, heads, query_len, key_len)if mask is not None:energy = energy.masked_fill(mask==0, float("-1e20")) #Fills elements of self tensor with value where mask is Trueattention = torch.softmax(energy / (self.embed_size ** (1/2)), dim=3)out = torch.einsum("nhql, nlhd->nqhd", [attention, values]).reshape(N, query_len, self.heads*self.head_dim)# attention shape: (N, heads, query_len, key_len)# values shape: (N, value_len, heads, head_dim)# after einsum (N, query_len, heads, head_dim) then flatten last two dimensionsout = self.fc_out(out)return outclass TransformerBlock(nn.Module):def __init__(self, embed_size, heads, dropout, forward_expansion):super(TransformerBlock, self).__init__()self.attention = SelfAttention(embed_size, heads)self.norm1 = nn.LayerNorm(embed_size)self.norm2 = nn.LayerNorm(embed_size)self.feed_forward = nn.Sequential(nn.Linear(embed_size, forward_expansion*embed_size),nn.ReLU(),nn.Linear(forward_expansion*embed_size, embed_size))self.dropout = nn.Dropout(dropout)def forward(self, value, key, query, mask):attention = self.attention(value, key, query, mask)x = self.dropout(self.norm1(attention + query))forward = self.feed_forward(x)out = self.dropout(self.norm2(forward + x))return outclass Encoder(nn.Module):def __init__(self, src_vocab_size,embed_size,num_layers,heads,device,forward_expansion,dropout,max_length):super(Encoder, self).__init__()self.embed_size = embed_sizeself.device = deviceself.word_embedding = nn.Embedding(src_vocab_size, embed_size)self.position_embedding = nn.Embedding(max_length, embed_size)self.layers = nn.ModuleList([TransformerBlock(embed_size,heads,dropout=dropout,forward_expansion=forward_expansion)for _ in range(num_layers)])self.dropout = nn.Dropout(dropout)def forward(self, x, mask):N, seq_lengh = x.shapepositions = torch.arange(0, seq_lengh).expand(N, seq_lengh).to(self.device)out = self.dropout(self.word_embedding(x) + self.position_embedding(positions))for layer in self.layers:out = layer(out, out, out, mask)return outclass DecoderBlock(nn.Module):def __init__(self, embed_size, heads, forward_expansion, dropout, device):super(DecoderBlock, self).__init__()self.attention = SelfAttention(embed_size, heads)self.norm = nn.LayerNorm(embed_size)self.transformer_block = TransformerBlock(embed_size, heads, dropout, forward_expansion)self.dropout = nn.Dropout(dropout)def forward(self, x, value, key, src_mask, trg_mask):attention = self.attention(x, x, x, trg_mask)query = self.dropout(self.norm(attention + x))out = self.transformer_block(value, key, query, src_mask)return outclass Decoder(nn.Module):def __init__(self, trg_vocab_size, embed_size, num_layers, heads, forward_expansion, dropout, device, max_length):super(Decoder, self).__init__()self.device = deviceself.word_embedding = nn.Embedding(trg_vocab_size, embed_size)self.position_embedding = nn.Embedding(max_length, embed_size)self.layers = nn.ModuleList([DecoderBlock(embed_size, heads, forward_expansion, dropout, device)for _ in range(num_layers)])self.fc_out = nn.Linear(embed_size, trg_vocab_size)self.dropout = nn.Dropout(dropout)def forward(self, x, enc_out, src_mask, trg_mask):N, seq_length = x.shapepositions = torch.arange(0, seq_length).expand(N, seq_length).to(self.device)x = self.dropout((self.word_embedding(x) + self.position_embedding(positions)))for layer in self.layers:x = layer(x, enc_out, enc_out, src_mask, trg_mask)out = self.fc_out(x)return outclass Transformer(nn.Module):def __init__(self,src_vocab_size, trg_vocab_size,src_pad_idx,trg_pad_idx,embed_size=256,num_layers=6,forward_expansion=4,heads=8,dropout=0,device="cuda",max_length=100):super(Transformer, self).__init__()self.encoder = Encoder(src_vocab_size,embed_size,num_layers,heads,device,forward_expansion,dropout,max_length)self.decoder = Decoder(trg_vocab_size,embed_size,num_layers,heads,forward_expansion,dropout,device,max_length)self.src_pad_idx = src_pad_idxself.trg_pad_idx = trg_pad_idxself.device = devicedef make_src_mask(self, src):src_mask = (src != self.src_pad_idx).unsqueeze(1).unsqueeze(2)#(N, 1, 1, src_len)return src_mask.to(self.device)def make_trg_mask(self, trg):N, trg_len = trg.shapetrg_mask = torch.tril(torch.ones((trg_len, trg_len))).expand(N, 1, trg_len, trg_len)return trg_mask.to(self.device)def forward(self, src, trg):src_mask = self.make_src_mask(src)trg_mask = self.make_trg_mask(trg)enc_src = self.encoder(src, src_mask)out = self.decoder(trg, enc_src, src_mask,  trg_mask)return outif __name__ == "__main__":device = torch.device("cuda" if torch.cuda.is_available() else "cpu")print(device)x = torch.tensor([[1, 5, 6, 4, 3, 9, 5, 2, 0], [1, 8, 7, 3, 4, 5, 6, 7, 2]]).to(device)trg = torch.tensor([[1, 7, 4, 3, 5, 9, 2, 0], [1, 5, 6, 2, 4, 7, 6, 2]]).to(device)src_pad_idx = 0trg_pad_idx = 0src_vocab_size = 10trg_vocab_size = 10model = Transformer(src_vocab_size, trg_vocab_size, src_pad_idx, trg_pad_idx, device=device).to(device)out = model(x, trg[:, :-1])print(out.shape)

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.rhkb.cn/news/73721.html

如若内容造成侵权/违法违规/事实不符,请联系长河编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

Prometheus + Grafana安装

Prometheus是一款基于时序数据库的开源监控告警系统,非常适合Kubernetes集群的监控。Prometheus的基本原理是通过HTTP协议周期性抓取被监控组件的状态,任意组件只要提供对应的HTTP接口就可以接入监控。不需要任何SDK或者其他的集成过程。这样做非常适合做…

7、Kubernetes核心技术 - Secret

目录 一、Secret概述 二、Secret 三种类型 2.1、Opaque 2..2、kubernetes.io/dockerconfigjson 2.3、kubernetes.io/service-account-token 三、Secret创建 3.1、命令行方式创建 Secret 3.2、yaml方式创建 Secret 四、Secret解码 五、Secret使用 5.1、将 Secret 挂载…

银河麒麟V10 SP1安装网络调试助手

文章目录 系统环境文件准备软件配置过程系统环境 系统镜像:Kylin-Desktop-V10-SP1-General-Release-2203-ARM64.iso 内核:5.4.18-53-generic 文件准备 网络调试助手可执行文件压缩包下载m-net-assist-arm64-main.zip 链接:https://pan.baidu.com/s/10Vu8Z6wOzCImXZWAW0Y…

SpringBoot+Vue开发笔记

参考:https://www.bilibili.com/video/BV1nV4y1s7ZN?p1 ----------------------------------------------------------概要总结---------------------------------------------------------- 1、MVC架构: View:与用户交互 Controller&…

解密外接显卡:笔记本能否接外置显卡?如何连接外接显卡?

伴随着电脑游戏和图形处理的需求不断增加,很多笔记本电脑使用者开始考虑是否能够通过外接显卡来提升性能。然而,外接显卡对于笔记本电脑是否可行,以及如何连接外接显卡,对于很多人来说仍然是一个迷。本文将为您揭秘外接显卡的奥秘…

小研究 - 微服务系统服务依赖发现技术综述(一)

微服务架构得到了广泛的部署与应用, 提升了软件系统开发的效率, 降低了系统更新与维护的成本, 提高了系统的可扩展性. 但微服务变更频繁、异构融合等特点使得微服务故障频发、其故障传播快且影响大, 同时微服务间复杂的调用依赖关系或逻辑依赖关系又使得其故障难以被及时、准确…

无涯教程-jQuery - css( properties )方法函数

css(properties)方法将键/值对象设置为所有匹配元素的样式属性。 css( properties ) - 语法 selector.css( properties ) 上面的语法可以写成如下- selector.css( {key1:val1, key2:val2....keyN:valN}) 这是此方法使用的所有参数的描述- key:value - 设置为样式属…

【MySQL】复合查询

复合查询目录 一、基本查询二、多表查询三、自连接四、子查询4.1 单行子查询4.2 多行子查询4.3 多列子查询4.4 在from子句中使用子查询4.5 合并查询4.5.1 union4.5.2 union all 五、实战OJ 一、基本查询 --查询工资高于500或岗位为MANAGER的雇员,同时还要满足他们的…

【数据结构与算法——TypeScript】数组、栈、队列、链表

【数据结构与算法——TypeScript】 算法(Algorithm)的认识 解决问题的过程中,不仅仅 数据的存储方式会影响效率,算法的优劣也会影响效率 什么是算法? 定义: 🟢 一个有限指令集,每条指令的描述不依赖于言语…

【音视频SDK测评】线上K歌软件开发技术选型

摘要 在线K歌软件的开发有许多技术难点,需考虑到音频录制和处理、实时音频传输和同步、音频压缩和解压缩、设备兼容性问题等技术难点外,此外,开发者还应关注音乐版权问题,确保开发的应用合规合法。 前言 前面写了几期关于直播 …

[STL]详解list模拟实现

[STL]list模拟实现 文章目录 [STL]list模拟实现1. 整体结构总览2. 成员变量解析3. 默认成员函数构造函数1迭代器区间构造函数拷贝构造函数赋值运算符重载析构函数 4. 迭代器及相关函数迭代器整体结构总览迭代器的模拟实现begin函数和end函数begin函数和end函数const版本 5. 数据…

C语言指针详解

C语言指针详解 字符指针1.如何定义2.类型和指向的内容3.代码例子 指针数组1.如何定义2.类型和内容 数组指针1.如何定义2.类型和指向类型3.数组名vs&数组名数组指针运用 数组参数&指针参数一维数组传参二维数组传参一级指针传参二级指针传参 函数指针1.如何定义2.类型和…

【前端知识】React 基础巩固(三十九)——React-Router的基本使用

React 基础巩固(三十九)——React-Router的基本使用 一、Router的基本使用 Router中包含了对路径改变的监听,并且会将相应的路径传递给子组件。 Router包括两个API: BrowserRouter使用history模式 HashRouter使用hash模式(路径后面带有#号…

APP自动化测试-Python+Appium+Pytest+Allure框架实战封装(详细)

目录:导读 前言一、Python编程入门到精通二、接口自动化项目实战三、Web自动化项目实战四、App自动化项目实战五、一线大厂简历六、测试开发DevOps体系七、常用自动化测试工具八、JMeter性能测试九、总结(尾部小惊喜) 前言 pytest只是单独的…

无人驾驶实战-第五课(动态环境感知与3D检测算法)

激光雷达的分类: 机械式Lidar:TOF、N个独立激光单元、旋转产生360度视场 MEMS式Lidar:不旋转 激光雷达的输出是点云,点云数据特点: 简单:x y z i (i为信号强度) 稀疏:7%&…

【工具】自动搜索Research网站的学术会议排名

转载请注明出处:小锋学长生活大爆炸[xfxuezhang.cn] Research.com是一个可以搜索学术会议网站的影响因子的网站。 好用是好用,但有一个缺点:得手动选择类目。有这么多类目,一个个手动选也太累了。 所以做了一个自动搜索的小工具&a…

HTTP杂谈之Referer和Origin请求头再探

一 关于Referer和Origin的汇总 1) 知识是凌乱的,各位看官看个热闹即可2) 内容不断更新1、理解有盲区,需要及时纠正2、内容交叉有重复,需要适当删减3、扩展视野3) 以下内容都与Referer和Origin请求头有关联 nginx防盗链 HTTP杂谈之Referrer-Policy响应头 iframe标签referre…

新手入门Jenkins自动化部署入门详细教程

1. 背景 在实际开发中,我们经常要一边开发一边测试,当然这里说的测试并不是程序员对自己代码的单元测试,而是同组程序员将代码提交后,由测试人员测试; 或者前后端分离后,经常会修改接口,然后重新…

vue element el-upload附件上传、在线预览、下载当前预览文件

上传 在线预览&#xff08;iframe&#xff09;&#xff1a; payload&#xff1a; response&#xff1a; 全部代码&#xff1a; <template><div><el-table :data"tableData" border style"width: 100%"><el-table-column prop"d…

.Net6 Core Web API 配置 log4net + MySQL

目录 一、导入NuGet 包 二、添加配置文件 log4net.config 三、创建MySQL表格 四、Program全局配置 五、帮助类编写 六、效果展示 小编没有使用依赖注入的方式。 一、导入NuGet 包 ---- log4net 基础包 ---- Microsoft.Extensions.Logging.Log4Net…