Pytorch深度解析:Transformer嵌入层源码逐行解读

前言

本部分博客需要先阅读博客:
《Transformer实现以及Pytorch源码解读(一)-数据输入篇》
作为知识储备。

Embedding使用方式

如下面的代码中所示,embedding一般是先实例化nn.Embedding(vocab_size, embedding_dim)。实例化的过程中输入两个参数:vocab_size和embedding_dim。其中的vocab_size是指输入的数据集合中总共涉及多少个去重后的单词;embedding_dim是指,每个单词你希望用多少维度的向量表示。随后,实例化的embedding在forward中被调用self.embeddings(inputs)。

class Transformer(nn.Module):def __init__(self, vocab_size, embedding_dim, hidden_dim, num_class,dim_feedforward=512, num_head=2, num_layers=2, dropout=0.1, max_len=512, activation: str = "relu"):super(Transformer, self).__init__()# 词嵌入层self.embedding_dim = embedding_dimself.embeddings = nn.Embedding(vocab_size, embedding_dim)self.position_embedding = PositionalEncoding(embedding_dim, dropout, max_len)# 编码层:使用Transformerencoder_layer = nn.TransformerEncoderLayer(hidden_dim, num_head, dim_feedforward, dropout, activation)self.transformer = nn.TransformerEncoder(encoder_layer, num_layers)# 输出层self.output = nn.Linear(hidden_dim, num_class)def forward(self, inputs, lengths):inputs = torch.transpose(inputs, 0, 1)hidden_states = self.embeddings(inputs)hidden_states = self.position_embedding(hidden_states)attention_mask = length_to_mask(lengths) == Falsehidden_states = self.transformer(hidden_states, src_key_padding_mask=attention_mask).transpose(0, 1)logits = self.output(hidden_states)log_probs = F.log_softmax(logits, dim=-1)return log_probs

数据被怎样变换了?

如下图所示,第一个tensor表示input,该input表示一个句子( sentence),只是该句子中的单词用整数进行了代替,相同的整数表示相同的单词。而每个1在embedding之后,变成了相同过的向量。

我们将以上的代码重新的运行一遍,发现表示1的向量改变了,这说明embedding 的过程不是确定的,而是随机的。

数据是怎样被变化的?

Embedding类在调用过程中主要涉及到以下几个核心方法:_
init
,rest_parameters,forward:

Embedding类的初始化过程如下所示。当_weight没有的情况下调用Parameter初始化一个空的向量,该向量的维度与输入数据中的去重单词个数(num_bembeddings)一样。然后调用reset_parameters方法。

 def __init__(self, num_embeddings: int, embedding_dim: int, padding_idx: Optional[int] = None,max_norm: Optional[float] = None, norm_type: float = 2., scale_grad_by_freq: bool = False,sparse: bool = False, _weight: Optional[Tensor] = None,device=None, dtype=None) -> None:factory_kwargs = {'device': device, 'dtype': dtype}super(Embedding, self).__init__()self.num_embeddings = num_embeddingsself.embedding_dim = embedding_dimif padding_idx is not None:if padding_idx > 0:assert padding_idx < self.num_embeddings, 'Padding_idx must be within num_embeddings'elif padding_idx < 0:assert padding_idx >= -self.num_embeddings, 'Padding_idx must be within num_embeddings'padding_idx = self.num_embeddings + padding_idxself.padding_idx = padding_idxself.max_norm = max_normself.norm_type = norm_typeself.scale_grad_by_freq = scale_grad_by_freqif _weight is None:self.weight = Parameter(torch.empty((num_embeddings, embedding_dim), **factory_kwargs))# print("===========================================1")# print(self.weight)#将self.weight进行nornal归一化self.reset_parameters()print("===========================================2")print(self.weight)else:assert list(_weight.shape) == [num_embeddings, embedding_dim], \'Shape of weight does not match num_embeddings and embedding_dim'self.weight = Parameter(_weight)self.sparse = sparse

reset_parameters的实现如下所示,主要是调用了init.norma_方法。

    def reset_parameters(self) -> None:init.normal_(self.weight)self._fill_padding_idx_with_zero()

init.normal_又调用了torch.nn.init中的normal方法。该方法将空的self.weight矩阵填充为一个符合 (0,1)正太分布的矩阵。

N

(

mean

,

std

2

)

.

\mathcal{N}(\text{mean}, \text{std}^2).

N

(

mean

,

std

2

)

.

def normal_(tensor: Tensor, mean: float = 0., std: float = 1.) -> Tensor:r"""Fills the input Tensor with values drawn from the normaldistribution :math:`\mathcal{N}(\text{mean}, \text{std}^2)`.Args:tensor: an n-dimensional `torch.Tensor`mean: the mean of the normal distributionstd: the standard deviation of the normal distributionExamples:>>> w = torch.empty(3, 5)>>> nn.init.normal_(w)"""return _no_grad_normal_(tensor, mean, std)

继续追踪_no_grad_normal_(tensor, mean, std)我们发现,该方法是通过c++实现,所在的源码文件目录为:

namespace torch {
namespace nn {
namespace init {
namespace {
struct Fan {explicit Fan(Tensor& tensor) {const auto dimensions = tensor.ndimension();TORCH_CHECK(dimensions >= 2,"Fan in and fan out can not be computed for tensor with fewer than 2 dimensions");if (dimensions == 2) {in = tensor.size(1);out = tensor.size(0);} else {in = tensor.size(1) * tensor[0][0].numel();out = tensor.size(0) * tensor[0][0].numel();}}int64_t in;int64_t out;
};
Tensor normal_(Tensor tensor, double mean, double std) {NoGradGuard guard;return tensor.normal_(mean, std);
}

forward方法的c++实现如下所示。

torch::Tensor EmbeddingImpl::forward(const Tensor& input) {return F::detail::embedding(input,weight,options.padding_idx(),options.max_norm(),options.norm_type(),options.scale_grad_by_freq(),options.sparse());
}

继续追踪,发现weight中的每个变量被下面的c++代码填充了正太分布的随机数。

void normal_kernel(const TensorBase &self, double mean, double std, c10::optional<Generator> gen) {CPUGeneratorImpl* generator = get_generator_or_default<CPUGeneratorImpl>(gen, detail::getDefaultCPUGenerator());templates::cpu::normal_kernel(self, mean, std, generator);
}

随机数的生成调用如下的代码,首先询问:目前代码是在什么设备上运行,并调用cpu或者gup上的随机数生成方法。

template <typename T>
static inline T * check_generator(c10::optional<Generator> gen) {TORCH_CHECK(gen.has_value(), "Expected Generator but received nullopt");TORCH_CHECK(gen->defined(), "Generator with undefined implementation is not allowed");TORCH_CHECK(T::device_type() == gen->device().type(), "Expected a '", T::device_type(), "' device type for generator but found '", gen->device().type(), "'");return gen->get<T>();
}/*** Utility function used in tensor implementations, which* supplies the default generator to tensors, if an input generator* is not supplied. The input Generator* is also static casted to* the backend generator type (CPU/CUDAGeneratorImpl etc.)*/
template <typename T>
static inline T* get_generator_or_default(const c10::optional<Generator>& gen, const Generator& default_gen) {return gen.has_value() && gen->defined() ? check_generator<T>(gen) : check_generator<T>(default_gen);
}

至此,embedding的每个随机数的生成过程都清楚了。

总结

Embedding的过程,其实就是为每个单词对应一个向量的过程。该向量为(0,1)正太分布,该矩阵在Embedding的实例化过程就已经被初始化完成。在调用Embedding示例的时候即forward开始工作的时候,只是做了一个匹配的过程,也就是将<字典,向量>的对应关系应用到input上。前期解读该部分源码的困惑是一只找不到forward中的对应处理过程,以为embedding的处理逻辑是在forward的阶段展开的,显然这种想法是不对的。Pytorch的架构设计的的确优雅!

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.rhkb.cn/news/359242.html

如若内容造成侵权/违法违规/事实不符,请联系长河编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

【shell脚本速成】mysql备份脚本

文章目录 案例需求脚本应用场景&#xff1a;解决问题脚本思路实现代码 &#x1f308;你好呀&#xff01;我是 山顶风景独好 &#x1f388;欢迎踏入我的博客世界&#xff0c;能与您在此邂逅&#xff0c;真是缘分使然&#xff01;&#x1f60a; &#x1f338;愿您在此停留的每一刻…

更改ip后还被封是ip质量的原因吗?

不同的代理IP的质量相同&#xff0c;一般来说可以根据以下几个因素来进行判断&#xff1a; 1.可用率 可用率就是提取的这些代理IP中可以正常使用的比率。假如我们无法使用某个代理IP请求目标网站或者请求超时&#xff0c;那么就代表这个代理不可用&#xff0c;一般来说免费代…

最强铁基超导磁体诞生!科学家基于机器学习设计新研究体系,磁场强度超过先前记录2.7倍

超导现象&#xff0c;自 1911 年被发现以来&#xff0c;始终保持着前沿性与高价值&#xff0c;吸引了大批学者投身其研究中。超导现象是指某些材料在低于特定温度时电阻突然降为零&#xff0c;这不仅是材料学的革命性突破&#xff0c;也为电力传输、磁悬浮交通和医疗成像等领域…

【CentOS7】Linux安装Docker教程(保姆篇)

文章目录 查看是否已安装卸载&#xff08;已安装过&#xff09;docker安装友情提示 更多相关内容可查看 注&#xff1a;本篇为Centos7安装Docker&#xff0c;若为其他系统请理性参考 查看是否已安装 如果已安装&#xff0c;请卸载重新安装 docker --version这里显示已安装 …

mac鼠标自动点击工具:RapidClick for Mac 激活版

RapidClick是一种简单易用的点击工具&#xff0c;它可以帮助用户快速进行连续的鼠标点击操作。该软件可用于自动点击鼠标&#xff0c;从而提高用户在电脑上的效率和速度。RapidClick还具有一些自定义设置&#xff0c;比如点击间隔和点击频率&#xff0c;可以根据用户的需求进行…

Redis-数据结构-跳表详解

Redis概述 Redis-数据结构-跳表详解 跳表&#xff08;Skip List&#xff09;是一种基于并联的链表结构&#xff0c;用于在有序元素序列中快速查找元素的数据结构。 Redis 中广泛使用跳表来实现有序集合&#xff08;Sorted Set&#xff09;这一数据结构。 1.跳表的基本概念和…

Java程序之可爱的小兔兔

题目&#xff1a; 古典问题&#xff0c;有一对兔子&#xff0c;从出生后第3个月起每个月都生一对兔子&#xff0c;小兔子长到第三个月后每个月又生一对兔子&#xff0c;假如兔子都不死&#xff0c;问每个月的兔子总数为多少? 程序分析&#xff1a; 兔子的规律为数列1,1,2,3,…

.locked勒索病毒详解 | 防御措施 | 恢复数据

引言 在数字化飞速发展的今天&#xff0c;我们享受着信息技术带来的便捷与高效&#xff0c;然而&#xff0c;网络安全问题也随之而来&#xff0c;且日益严重。其中&#xff0c;勒索病毒以其狡猾的传播方式和巨大的破坏性&#xff0c;成为了网络安全领域中的一大难题。.locked勒…

捷瑞数字业绩波动性明显:关联交易不低,募资必要性遭质疑

《港湾商业观察》施子夫 5月22日&#xff0c;山东捷瑞数字科技股份有限公司&#xff08;以下简称&#xff0c;捷瑞数字&#xff09;及保荐机构国新证券披露第三轮问询的回复&#xff0c;继续推进北交所上市进程。 从2023年6月递表开始&#xff0c;监管层已下发三轮审核问询函…

项目训练营第二天

项目训练营第二天 用户登录逻辑 1、账户名不少于4位 2、密码不少于8位 3、数据库表中能够查询到账户、密码 4、密码查询时用同样加密脱敏处理手段处理后再和数据库中取出字段进行对比&#xff0c;如果账户名未查询到&#xff0c;直接返回null 5、后端设置相应的脱敏后用户的s…

我的常见问题记录

1,maven在idea工具可以正常使用,在命令窗口执行出现问题 代码: E:\test-hello\simple-test>mvn clean compile [INFO] Scanning for projects... [WARNING] [WARNING] Some problems were encountered while building the effective model for org.consola:simple-test:jar…

一个完整的Flutter应用

本文基于以下链接进行细节补充15.2 Flutter APP代码结构 | 《Flutter实战第二版》 代码结构 我们先来创建一个全新的Flutter工程&#xff0c;命名为"github_client_app" 我们在项目根目录下分别创建imgs和fonts、jsons、l10n文件夹 工程目录如下&#xff1a; 在l…

LLC开关电源开发:LLC设计参考文档(模态分析)

电源简析和全桥LLC模型分析 1.1模拟电源、开关电源和数字电源简介 1.1.1 模拟电源 模拟电源&#xff1a;即变压器电源&#xff0c;通过铁芯、线圈来实现&#xff0c;线圈的匝数决定了两端的电压比&#xff0c;铁芯的作用是传递变化磁场&#xff0c;&#xff08;我国&#xff09…

MySQL数据库(五):事务

MySQL数据库中的事务是一种用来保证一系列操作要么全部成功&#xff0c;要么全部取消的机制。想象一下你去超市购物&#xff0c;拿了很多商品&#xff0c;如果中途发现没带钱包&#xff0c;你可以放弃这次购买&#xff0c;所有商品会回到原位。通过事务&#xff0c;可以确保数据…

dial tcp 10.96.0.1:443: connect: no route to host

1、创建Pod一直不成功&#xff0c;执行kubectl describe pod runtime-java-c8b465b98-47m82 查看报错 Warning FailedCreatePodSandBox 2m17s kubelet Failed to create pod sandbox: rpc error: code Unknown desc failed to setup network for…

WebHttpServletRequestResponse(完整知识点汇总)

额外知识点 Web核心 Web 全球广域网&#xff0c;也成为万维网&#xff08;www&#xff09;&#xff0c;可通过浏览器访问的网站 JavaWeb 使用Java技术来解决相关Web互联网领域的技术栈 JavaWeb技术栈 B/S架构&#xff1a;Browser/Server&#xff0c;即浏览器/服务器 架构模式…

Vue核心指令解析:探索MVVM与数据操作之美

文章目录 前言一、Vue.js1. MVVM模式介绍2. 单页面组件介绍及案例讲解3. 插值表达式介绍及案例讲解 二、Vue常用指令详解1. 数据绑定指令v-textv-html 2. 条件渲染指令v-ifv-show 3. 列表渲染指令v-for循环数组介绍及案例讲解循环对象介绍及案例讲解 4. 事件监听指令v-on事件修…

Python-矩阵元素定位

[题目描述] 小理得到了一个 n 行 m 列的矩阵&#xff0c;现在他想知道第 x 行第 y 列的值是多少&#xff0c;请你帮助他完成这个任务。输入格式&#xff1a; 第一行包含两个数 n 和m &#xff0c;表示这个矩阵包含 n行 m 列。从第 2 行到第 n1 行&#xff0c;每行输入 m 个整数…

【JS逆向百例】某点数据逆向分析,多方法详解

前言 最近收到粉丝的私信&#xff0c;其在逆向某个站点时遇到了些问题&#xff0c;在查阅资料未果后&#xff0c;来询问K哥&#xff0c;K哥一向会尽力满足粉丝的需求。网上大多数分析该站点的教程已经不再适用&#xff0c;本文K哥将提供 3 种解决方案&#xff0c;对于 webpack…

[个人感悟] MySQL应该考察哪些问题?

前言 数据存储一直是软件开发中必不可少的一环, 从早期的文件存储txt, Excel, Doc, Access, 以及关系数据库时代的MySQL,SQL Server, Oracle, DB2, 乃至最近的大数据时代f非关系型数据库:Hadoop, HBase, MongoDB. 此外还有顺序型数据库InfluxDB, 图数据库Neo4J, 分布式数据库T…