pytorch小记(十四):pytorch中 nn.Embedding 详解

pytorch小记(十四):pytorch中 nn.Embedding 详解

  • PyTorch 中的 nn.Embedding 详解
    • 1. 什么是 nn.Embedding?
    • 2. nn.Embedding 的基本使用
      • 示例 1:基础用法
      • 示例 2:处理批次输入
    • 3. nn.Embedding 与 nn.Linear 的区别
      • 3.1 nn.Embedding
      • 3.2 nn.Linear
    • 4. nn.Embedding 与 nn.Sequential 的区别
    • 5. 应用场景
    • 6. 总结


PyTorch 中的 nn.Embedding 详解

在自然语言处理、推荐系统以及其他处理离散输入的任务中,我们常常需要将离散的标识符(例如单词、字符、用户 ID 等)转换为连续的、低维的向量表示。PyTorch 提供了专门的模块——nn.Embedding,用来实现这种“嵌入”操作。本文将详细解释 nn.Embedding 的工作原理、使用方法以及与普通线性层(nn.Linear)和顺序模块(nn.Sequential)的区别,并给出清晰的代码示例。


1. 什么是 nn.Embedding?

nn.Embedding 实际上是一个查找表(lookup table),它内部维护一个矩阵,每一行对应一个离散标识符的向量表示。

  • 假设你有一个词汇表,大小为 num_embeddings,每个词将映射到一个 embedding_dim 维的向量上。
  • nn.Embedding 会创建一个形状为 [num_embeddings, embedding_dim] 的矩阵。
  • 当你输入一个包含单词索引的张量时,模块会直接从这个矩阵中查找出相应行的向量,作为单词的嵌入表示。

这种方式的好处是直接“查找”而非进行繁琐的矩阵乘法计算,既高效又直观。


2. nn.Embedding 的基本使用

示例 1:基础用法

下面的例子展示了如何使用 nn.Embedding 将一组单词索引转换为对应的嵌入向量。

import torch
import torch.nn as nn# 定义一个嵌入层
# 假设词汇表大小为 10,每个单词用 5 维向量表示
embedding = nn.Embedding(num_embeddings=10, embedding_dim=5)# 打印嵌入矩阵的形状
print("嵌入矩阵形状:", embedding.weight.shape)
# 输出: torch.Size([10, 5])# 定义一个包含单词索引的张量,例如 [3, 7, 1]。索引 embedding 表中[3],[7],[1]行
indices = torch.tensor([3, 7, 1])# 使用嵌入层查找对应的嵌入向量
embedded_vectors = embedding(indices)
print("查找到的嵌入向量:")
print(embedded_vectors)

说明:

  • 输入是一个包含索引 [3, 7, 1] 的 1D 张量,输出是一个形状为 [3, 5] 的张量。
  • 每一行就是词汇表中对应索引的嵌入向量。

示例 2:处理批次输入

在实际任务中,我们通常一次处理多个样本。例如,一个批次中包含多个句子,每个句子由若干单词索引组成。下面的例子展示了如何处理批次数据。

# 假设有一个批次,包含 2 个样本,每个样本包含 4 个单词索引
'''
对应原数据的
[[[1],[2],[3],[4]],[[5],[6],[7],[8]]]
'''
batch_indices = torch.tensor([[1, 2, 3, 4],[5, 6, 7, 8]
])# 使用嵌入层查找嵌入向量
batch_embeddings = embedding(batch_indices)print("批次嵌入向量形状:", batch_embeddings.shape)
# 输出形状: torch.Size([2, 4, 5])

说明:

  • 输入的 batch_indices 形状为 [2, 4],表示 2 个样本,每个样本 4 个单词索引。
  • 输出为 [2, 4, 5],每个单词索引转换成 5 维嵌入向量。

3. nn.Embedding 与 nn.Linear 的区别

虽然 nn.Embedding 和 nn.Linear 都涉及到矩阵的操作,但二者解决的问题大不相同。

3.1 nn.Embedding

  • 用途:专门用于将离散的索引(如单词 ID)转换为连续的向量表示,是一种查找表操作。
  • 输入:通常为整数索引。
  • 输出:直接返回查找表中对应的向量,效率高,不进行额外的计算。

3.2 nn.Linear

  • 用途:用于实现线性变换,即对输入做矩阵乘法加上偏置,计算公式为 y = x W T + b y = xW^T + b y=xWT+b
  • 输入:需要连续数值的张量。
  • 应用:若要模拟嵌入操作,需要先将整数索引转换成 one-hot 编码,再通过 nn.Linear 进行计算,这样既低效又不直观。

总结

  • 使用 nn.Embedding 更直接、更高效,因为它只进行查找操作;
  • nn.Linear 则用于对连续特征进行线性变换。

4. nn.Embedding 与 nn.Sequential 的区别

  • nn.Sequential 是一个模块容器,用于按顺序组合多个层,适用于前向传播流程固定的情况。
  • nn.Embedding 则是一个具体的层,用于实现查找表功能。
  • 在模型中,我们通常将 nn.Embedding 放在最前面,将离散输入转换为连续向量,再结合 nn.Sequential 里的其他层进行进一步处理。

例如,在 NLP 模型中常常这样使用:

class TextModel(nn.Module):def __init__(self, vocab_size, embed_dim):super(TextModel, self).__init__()# 使用 nn.Embedding 将单词索引映射为嵌入向量self.embedding = nn.Embedding(vocab_size, embed_dim)# 使用 nn.Sequential 组合后续的线性层self.fc = nn.Sequential(nn.Linear(embed_dim, 10),nn.ReLU(),nn.Linear(10, 2))def forward(self, x):# x 的形状可能为 (batch_size, sequence_length)x_embed = self.embedding(x)  # 变为 (batch_size, sequence_length, embed_dim)# 对嵌入向量进行池化,变为 (batch_size, embed_dim)x_pool = x_embed.mean(dim=1)out = self.fc(x_pool)return out# 假设词汇表大小 100,嵌入维度 8
model = TextModel(vocab_size=100, embed_dim=8)

在这个例子中,nn.Embedding 将离散单词转换为连续向量,而 nn.Sequential 则定义了后续的前向传播步骤。


5. 应用场景

nn.Embedding 常用于:

  • 自然语言处理(NLP):将单词、子词、字符等离散输入转换为低维向量表示,为后续的 RNN、Transformer 模型提供输入。
  • 推荐系统:将用户 ID、商品 ID 映射为嵌入向量,用于捕捉用户和物品之间的相似性。
  • 图神经网络:将节点或边的离散标签转换为连续向量表示。

6. 总结

  • nn.Embedding 是一个查找表,用于将离散索引映射为连续向量。
  • 它的输入通常是整数张量,输出是对应的嵌入向量。
  • 与 nn.Linear 相比,nn.Embedding 不需要进行大量的计算,只是直接查找,所以更高效。
  • nn.Embedding 经常与 nn.Sequential 结合使用:先将离散数据转换为嵌入向量,再通过连续层进行处理。

通过以上详细解释和分步代码示例,希望大家能对 nn.Embedding 有一个全面的理解,并能在实际项目中正确使用它来提升模型的表现。

🚀 写在最后
利用 nn.Embedding,你可以轻松将离散数据转换为高质量的连续表示,这在各种深度学习任务中都是至关重要的!

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.rhkb.cn/news/37883.html

如若内容造成侵权/违法违规/事实不符,请联系长河编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

MySQL0基础学习记录-下载与安装

下载 下载地址: (Windows)https://dev.mysql.com/downloads/file/?id536787 安装 直接点next,出现: 点execute 然后一直next到这页: next 然后需要给root设置一个密码: 在next。。很多页…

React基础语法速览

一、项目创建 npm create vite 这里选择react即可,如图: 二、基本文件说明 react函数式编程时,用的是JSX语法进行开发的,这里注意,return时只能有一个根标签; 三、React核心语法 1.插值功能 插值可以使用…

IT工具 | node.js 进程管理工具 PM2 大升级!支持 Bun.js

P(rocess)M(anager)2 是一个 node.js 下的进程管理器,内置负载均衡,支持应用自动重启,常用于生产环境运行 node.js 应用,非常好用👍 🌼概述 2025-03-15日,PM2发布最新版本v6.0.5,这…

teaming技术

一.介绍 在CentOS 6与RHEL 6系统中,双网卡绑定采用的是bonding技术。到了CentOS 7,不仅能继续沿用bonding,还新增了teaming技术。在此推荐使用teaming,因其在查看与监控方面更为便捷 。 二.原理 这里介绍两种最常见的双网卡绑定…

SpringSecurity配置(自定义认证过滤器)

文末有本篇文章的项目源码文件可供下载学习 在这个案例中,我们已经实现了自定义登录URI的操作,登录成功之后,我们再次访问后端中的API的时候要在请求头中携带token,此时的token是jwt字符串,我们需要将该jwt字符串进行解析,查看解析后的User对象是否处于登录状态.登录状态下,将…

【机器学习-模型评估】

“评估”已建立的模型 在进行回归和分类时,为了进行预测,定义了预测函数fθ(x) 然后根据训练数据求出了预测函数的参数θ(即对目标函数进行微分,然后求出参数更新表达式的操作) 之前求出参数更新表达式之后就结束了。但是,其实我…

区块链开发技术公司:引领数字经济的创新力量

在数字化浪潮席卷全球的今天,区块链技术作为新兴技术的代表,正以其独特的去中心化、不可篡改和透明性等特点,深刻改变着各行各业的发展格局。区块链开发技术公司,作为这一领域的先锋和推动者,正不断研发创新&#xff0…

油候插件、idea、VsCode插件推荐(自用)

开发软件: 之前的文章: 开发必装最实用工具软件与网站 推荐一下我使用的开发工具 目前在用的 油候插件 AC-baidu-重定向优化百度搜狗谷歌必应搜索_favicon_双列 让查询变成多列,而且可以流式翻页 Github 增强 - 高速下载 github下载 TimerHo…

Linux中find 命令的高级用法 组合条件 与、或、非(-a、-o、!) 以及通过 -regex 和 -iregex 选项使用正则表达式

find 命令详解 find 是 Unix 和类 Unix 操作系统(如 Linux 和 macOS)中一个非常强大的命令行工具,用于在文件系统中搜索文件和目录。find 命令可以根据多种条件(如文件名、类型、大小、修改时间等)进行搜索&#xff0c…

基于Python的垃圾短信分类

垃圾短信分类 1 垃圾短信分类问题介绍 1.1 垃圾短信 随着移动互联科技的高速发展,信息技术在不断改变着我们的生活,让我们的生活更方便,其中移动通信技术己经在我们生活起到至关重要的作用,与我们每个人人息息相关。短信作为移…

go语言中空结构体

空结构体(struct{}) 普通理解 在结构体中,可以包裹一系列与对象相关的属性,但若该对象没有属性呢?那它就是一个空结构体。 空结构体,和正常的结构体一样,可以接收方法函数。 type Lamp struct{}func (l Lamp) On()…

Transformer-GRU、Transformer、CNN-GRU、GRU、CNN五模型多变量回归预测

Transformer-GRU、Transformer、CNN-GRU、GRU、CNN五模型多变量回归预测 目录 Transformer-GRU、Transformer、CNN-GRU、GRU、CNN五模型多变量回归预测预测效果基本介绍程序设计参考资料 预测效果 基本介绍 Transformer-GRU、Transformer、CNN-GRU、GRU、CNN五模型多变量回归预…

大数据学习(80)-数仓分层

🍋🍋大数据学习🍋🍋 🔥系列专栏: 👑哲学语录: 用力所能及,改变世界。 💖如果觉得博主的文章还不错的话,请点赞👍收藏⭐️留言📝支持一…

flink 写入es的依赖导入问题(踩坑记录)

flink 写入es的依赖导入问题(踩坑记录) ps:可能只是flink低版本才会有这个问题 1. 按照官网的导入方式: 2. 你会在运行sql-client的时候完美得到一个错误: Exception in thread "main" org.apache.flink.table.client.SqlClientEx…

Python 用户账户(创建用户账户)

Web应用程序的核心是让任何用户都能够注册账户并能够使用它,不管用户身处何方。在本章中,你将创建一些表单,让用户能够添加主题和条目,以及编辑既有的 条目。你还将学习Django如何防范对基于表单的网页发起的常见攻击,…

10-BST(二叉树)-建立二叉搜索树,并进行前中后遍历

题目 来源 3540. 二叉搜索树 - AcWing题库 思路 建立二叉搜索树(注意传参时用到了引用,可以直接对root进行修改),同时进行递归遍历;遍历可以分前中后三种写,也可以用标志来代替合在一起。其余详见代码。…

无人机点对点技术要点分析!

一、技术架构 1. 网络拓扑 Ad-hoc网络:无人机动态组建自组织网络,节点自主协商路由,无需依赖地面基站。 混合架构:部分场景结合中心节点(如指挥站)与P2P网络,兼顾集中调度与分布式协同。 2.…

[极客大挑战 2019]Knife——3.20BUUCTF练习day4(2)

[极客大挑战 2019]Knife——3.20BUUCTF练习day4(2) 解题内容 在一个文件中输入以下内容,该文件是phtml文件(HTML嵌套PHP代码,可以绕过很多限制)但在本题中要先改文件名为2.gif然后抓包修改后缀名为phtml,因为只可以上传gif和jpg…

1、环境初始化--Linux安装dockerCE

主要安装环境ubuntu、centos、Windows 因某些原因,使用阿里镜像源: https://developer.aliyun.com/mirror/docker-ce?spma2c6h.13651102.0.0.4a451b11EjxMKe Ubuntu安装步骤&相应解释 sudo apt-get update 解释: 刷新软件源列表 该命…

什么是 BA ?BA怎么样?BA和BI是什么关系?

前几天有朋友在评论区提到了BA这个角色,具体是干什么的,我大概来说一下。 什么是BA BA 英文的全称是Business Analyst,从字面上意思就是商业分析师,做过商业智能BI项目的应该比较了解。实际上以我个人的经验,BA 的角…