nn.embedding函数详解(pytorch)

提示:文章附有源码!!!

文章目录

  • 前言
  • 一、nn.embedding函数解释
  • 二、nn.embedding函数使用方法
  • 四、模型训练与预测的权重变化探讨


前言

最近发现prompt工程(如sam模型),也有transform的detr模型等都使用了nn.Embedding函数,对points、boxes或learn query进行编码或解码。因此,我想写一篇文章作为记录,本想简单对其 介绍,但写着写着就想把所有与它相关东西作为记录。本文章探讨了nn.Embedding参数、使用方法、模型训练与预测的变化,并附有列子源码作为支撑 ,呈现一个较为完善的理解内容。

一、nn.embedding函数解释

Embedding实际是一个索引表或查找表,它是符合随机初始化生成的正太分布的表,将输入向量化,其结构如下:

nn.Embedding(num_embeddings, embedding_dim)

第1个参数 num_embeddings 就是生成num_embeddings个嵌入向量。
第2个参数 embedding_dim 就是嵌入向量的维度,即用embedding_dim值的维数来表示一个基本单位。

当然,该函数还有很多其它参数,解释如下:

参数源码注释如下:

num_embeddings (int): size of the dictionary of embeddings
embedding_dim (int): the size of each embedding vector
padding_idx (int, optional): If specified, the entries at :attr:`padding_idx` do not contribute to the gradient;therefore, the embedding vector at :attr:`padding_idx` is not updated during training,i.e. it remains as a fixed "pad". For a newly constructed Embedding,the embedding vector at :attr:`padding_idx` will default to all zeros,but can be updated to another value to be used as the padding vector.
max_norm (float, optional): If given, each embedding vector with norm larger than :attr:`max_norm`is renormalized to have norm :attr:`max_norm`.
norm_type (float, optional): The p of the p-norm to compute for the :attr:`max_norm` option. Default ``2``.
scale_grad_by_freq (boolean, optional): If given, this will scale gradients by the inverse of frequency ofthe words in the mini-batch. Default ``False``.
sparse (bool, optional): If ``True``, gradient w.r.t. :attr:`weight` matrix will be a sparse tensor.See Notes for more details regarding sparse gradients.

参数中文解释:

num_embeddings (python:int) – 词典的大小尺寸,比如总共出现5000个词,那就输入5000。此时index为(0-4999embedding_dim (python:int) – 嵌入向量的维度,即用多少维来表示一个符号。
padding_idx (python:int, optional) – 填充id,比如,输入长度为100,但是每次的句子长度并不一样,后面就需要用统一的数字填充,而这里就是指定这个数字,这样,网络在遇到填充id时,就不会计算其与其它符号的相关性。(初始化为0max_norm (python:float, optional) – 最大范数,如果嵌入向量的范数超过了这个界限,就要进行再归一化。
norm_type (python:float, optional) – 指定利用什么范数计算,并用于对比max_norm,默认为2范数。
scale_grad_by_freq (boolean, optional) – 根据单词在mini-batch中出现的频率,对梯度进行放缩。默认为False.
sparse (bool, optional) – 若为True,则与权重矩阵相关的梯度转变为稀疏张量

注:该函数服从正太分布,该函数可参与训练,我将在后面做解释。

二、nn.embedding函数使用方法

该函数实际是对词的编码,假如你有2句话,每句话有四个词,那么你想对每个词使用6个维度表达,其代码如下:

import torch.nn as nn
import torch
if __name__ == '__main__':embedding = nn.Embedding(100, 6)  # 我设置100个索引,每个使用6个维度表达。input = torch.LongTensor([[1, 2, 4, 5],[4, 3, 2, 3]])  # a batch of 2 samples of 4 indices eache = embedding(input)print('输出尺寸', e.shape)print('输出值:\n',e)weights=embedding.weightprint('embed权重输出值:\n', weights[:6])

输出结果:
在这里插入图片描述

从图上可看出,输入编码是通过索引查找已编号embedding的权重,并将其赋值替换表达。换句话说,nn.Embedding(100, 6)生成正太分布100行6列数据,行必须超过输入句子词语长度,而句子每个词使用整数编码成索引,该索引对应之前embedding行寻找,得到对应行
维度,即可转为表达该词的特征向量。

四、模型训练与预测的权重变化探讨

之前已说过nn.Embedding()在训练过程中会发生变化,但在预测中将不在变化,应该是被训练成最佳词的向量维度表达,也就是说每个词唯一对应索引,被Embedding特征表达训练成最佳特征表达,也可说训练词索引特征表达固定。为探讨此过程,我写了对应示列,如下:

import torch
from torch.nn import Embeddingclass Model(torch.nn.Module):def __init__(self):super(Model, self).__init__()self.emb = Embedding(5, 3)def forward(self,vec):input = torch.tensor([0, 1, 2, 3, 4])emb_vec1 = self.emb(input)# print(emb_vec1)  ### 输出对同一组词汇的编码output = torch.einsum('ik, kj -> ij', emb_vec1, vec)return output
def simple_train():model = Model()vec = torch.randn((3, 1))label = torch.Tensor(5, 1).fill_(3)loss_fun = torch.nn.MSELoss()opt = torch.optim.SGD(model.parameters(), lr=0.015)print('初始化emebding参数权重:\n',model.emb.weight)for iter_num in range(100):output = model(vec)loss = loss_fun(output, label)opt.zero_grad()loss.backward(retain_graph=True)opt.step()# print('第{}次迭代emebding参数权重{}:\n'.format(iter_num, model.emb.weight))print('训练后emebding参数权重:\n',model.emb.weight)torch.save(model.state_dict(),'./embeding.pth')return modeldef simple_test():model = Model()ckpt = torch.load('./embeding.pth')model.load_state_dict(ckpt)model=model.eval()vec = torch.randn((3, 1))print('加载emebding参数权重:\n', model.emb.weight)for iter_num in range(100):output = model(vec)print('n次预测后emebding参数权重:\n', model.emb.weight)if __name__ == '__main__':simple_train()  # 训练与保存权重simple_test()

结果如下:

在这里插入图片描述
训练代码参考博客:点击这里

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.rhkb.cn/news/183924.html

如若内容造成侵权/违法违规/事实不符,请联系长河编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

数据结构大体体系

逻辑结构 线性结构线性表一串珠子用线连起来,这就是典型的“线性存储结构”。每颗珠子之间的关系结构也很简单,包括头尾的话,它们最少有一个关系对象,而中间的珠子无论前后都只有一个关系对象,即 one-to-one栈队列字符…

Chatgpt人工智能对话源码系统分享 带完整搭建教程

ChatGPT的开发基于大规模预训练模型技术。预训练模型是一种在大量文本数据上进行训练的模型,可以学习到各种语言模式和知识。在ChatGPT中,预训练模型被用于学习如何生成文本,并且可以用于各种不同的任务,如对话生成、问答、摘要等…

时序预测 | MATLAB实现基于LSSVM-Adaboost最小二乘支持向量机结合AdaBoost时间序列预测

时序预测 | MATLAB实现基于LSSVM-Adaboost最小二乘支持向量机结合AdaBoost时间序列预测 目录 时序预测 | MATLAB实现基于LSSVM-Adaboost最小二乘支持向量机结合AdaBoost时间序列预测预测效果基本介绍模型描述程序设计参考资料 预测效果 基本介绍 1.MATLAB实现基于LSSVM-Adaboos…

全球10米土地覆盖产品(ESA)数据集2020和2021年

简介 全球10米土地覆盖产品(ESA)来源于欧空局,是基于哨兵一号、哨兵二号数据制作的2020年的10m分辨率的全球土地覆盖数据。土地利用数据一共分为11类,分别是:林地、灌木、草地、耕地、建筑、裸地/稀疏植被区、雪和冰、开阔水域、草本湿地、红树林、苔藓…

贰[2],QT异常处理

1,异常:QT编译警告 warning LNK4042: 对象被多次指定;已忽略多余的指定 处理办法,检查.pri文件,是否关联了多个相同的文件(头文件.h/源文件.cpp) 2,异常:C4819: 该文件包含不能在当前代码页(936…

云尘 命令执行系列

第一题 system <?php include "flag.php";if (isset($_POST[cmd])) {system($_POST[cmd]); }show_source(__FILE__);代码如上 system($_POST[cmd]); POST请求发送一个名为 cmd 的参数&#xff0c;然后将该参数的值传递给系统命令执行函数 system()&#xff0c…

C语言学习笔记之结构篇

C语言是一门结构化程序设计语言。在C语言看来&#xff0c;现实生活中的任何事情都可看作是三大结构或者三大结构的组合的抽象&#xff0c;即顺序&#xff0c;分支&#xff08;选择&#xff09;&#xff0c;循环。 所谓顺序就是一条路走到黑&#xff1b;生活中在很多事情上我们都…

Spring Boot项目中通过 Jasypt 对属性文件中的账号密码进行加密

下面是在Spring Boot项目中对属性文件中的账号密码进行加密的完整步骤&#xff0c;以MySQL的用户名为root&#xff0c;密码为123321为例&#xff1a; 步骤1&#xff1a;引入Jasypt依赖 在项目的pom.xml文件中&#xff0c;添加Jasypt依赖&#xff1a; <dependency><…

ClickHouse 学习之从高级到监控以及备份(二)

第 一 部分 高级篇 第 1 章 Explain 查看执行计划 在 clickhouse 20.6 版本之前要查看 SQL 语句的执行计划需要设置日志级别为 trace 才能可以看到&#xff0c;并且只能真正执行 sql&#xff0c;在执行日志里面查看。在 20.6 版本引入了原生的执行计划的语法。在 20.6.3 版本成…

ubuntu 20.04 server安装

ubuntu 20.04 server安装 ubuntu-20.04.6-live-server-amd64.iso 安装 安装ubuntu20.04 TLS系统后&#xff0c;开机卡在“A start job is running for wait for network to be Configured”等待连接两分多钟。 cd /etc/systemd/system/network-online.target.wants/在[Servi…

揭开堆叠式自动编码器的强大功能

一、介绍 在不断发展的人工智能和机器学习领域&#xff0c;深度学习技术因其处理复杂和高维数据的能力而广受欢迎。在各种深度学习模型中&#xff0c;堆叠式自动编码器是一种多功能且功能强大的工具&#xff0c;可用于特征学习、降维和数据表示。本文探讨了堆叠式自动编码器在深…

R语言实操记录——导出高清图片(矢量图)

R语言 R语言实操记录——导出高清图片&#xff08;矢量图&#xff09; 文章目录 R语言一、起因&#xff08;闲聊&#xff0c;可跳过&#xff09;二、如何在R中导出高清图片&#xff08;矢量图&#xff09;2.1、保存为EPS图片格式后转AI编辑2.2、保存为PDF格式&#xff08;推荐…

LabVIEW实现变风量VAV终端干预PID控制

LabVIEW实现变风量VAV终端干预PID控制 变风量&#xff08;VAV&#xff09;控制方法的研究一直是VAV空调研究的重点。单端PID控制在温差较大时&#xff0c;系统容易出现过冲。针对空调终端单端PID控制的不足&#xff0c;设计一种干预控制与PID控制耦合的控制方法。项目使用LabV…

关于Alibaba Cloud Toolkit 下载配置以及后端自动部署

idea中File-Settings-Plugins 搜索Alibaba Cloud Toolkit点击下载&#xff0c;下载完成重启 1、点击 Tools-Alibaba Cloud-Deploy to Host 部署到主机 2、配置服务器ip、jar包启动命令、服务器jar存放位置 3、设置服务器ip用户名密码&#xff0c;点击测试连接情况 4、配置脚本…

Flink SQL TopN语句详解

TopN 定义&#xff08;⽀持 Batch\Streaming&#xff09;&#xff1a; TopN 对应离线数仓的 row_number()&#xff0c;使⽤ row_number() 对某⼀个分组的数据进⾏排序。 应⽤场景&#xff1a; 根据 某个排序 条件&#xff0c;计算 某个分组 下的排⾏榜数据。 SQL 语法标准&am…

基于Java+SpringBoot+LayUI仓库管理系统

一.项目介绍 本项目是使用JavaSpringBoot开发&#xff0c;可以实现仓库的注册、登录&#xff0c;登录后可进入系统&#xff0c;进行客户管理、供应商管理、商品管理、商品退货查询管理、登录日志及退出等几大模块。系统界面采用传统的后台管理界面&#xff0c;界面简单、直观。…

【大数据】NiFi 中的处理器(一):GenerateTableFetch

NiFi 中的处理器&#xff08;一&#xff09;&#xff1a;GenerateTableFetch 1.简介2.应用场景3.示例3.1 案例一&#xff1a;无输入流文件&#xff0c;来源表含增量字段3.2 案例二&#xff1a;无输入流文件&#xff0c;不含增量字段3.3 案例三&#xff1a;无输入流文件&#xf…

Transformer的最简洁pytorch实现

目录 前言 1. 数据预处理 2. 模型参数 3. Positional Encoding 4. Pad Mask 5. Subsequence Mask 6. ScaledDotProductAttention 7. MultiHeadAttention 8. FeedForward Networks 9. Encoder Layer 10. Encoder 11. Decoder Layer 12. Decoder 13. Transformer 1…

【单片机基础小知识-如何通过指针来读写寄存器】

寄存器的本质就是内存&#xff0c;RAM&#xff0c;而指针是可以对内存进行操作的&#xff0c;因此可以通过指针来读写寄存器。 如何读取以下一片地址&#xff1a; 步骤1、首地址 结构体&#xff0c;它所占用的内存空间大小与它内部成员有关。 构造一个28字节的类型 type…

计算机服务器中了locked勒索病毒怎么办,勒索病毒解密,数据恢复

随着网络技术的不断成熟&#xff0c;网络中存在的病毒威胁也不断增多&#xff0c;近期&#xff0c;云天数据恢复中心陆续接到很多企业的求助&#xff0c;企业的计算机服务器数据库遭到了勒索病毒攻击&#xff0c;并且勒索病毒的攻击与加密形式也发生了许多变化。其中攻击次数较…