Transformer 代码剖析6 - 位置编码 (pytorch实现)

一、位置编码的数学原理与设计思想

1.1 核心公式解析

位置编码采用正弦余弦交替编码方案:
P E ( p o s , 2 i ) = sin ⁡ ( p o s 1000 0 2 i / d m o d e l ) P E ( p o s , 2 i + 1 ) = cos ⁡ ( p o s 1000 0 2 i / d m o d e l ) PE_{(pos,2i)} = \sin\left(\frac{pos}{10000^{2i/d_{model}}}\right) \\ PE_{(pos,2i+1)} = \cos\left(\frac{pos}{10000^{2i/d_{model}}}\right) PE(pos,2i)=sin(100002i/dmodelpos)PE(pos,2i+1)=cos(100002i/dmodelpos)

式中:

  • p o s pos pos:当前词在序列中的绝对位置
  • i i i:特征维度的索引( 0 ≤ i < d m o d e l / 2 0 \leq i < d_{model}/2 0i<dmodel/2
  • 1000 0 2 i / d m o d e l 10000^{2i/d_{model}} 100002i/dmodel:频率控制项,形成指数衰减的频率分布

1.2 设计优势分析

1. 绝对位置感知: 每个位置生成唯一编码模式
2. 相对位置建模: 通过三角函数加法公式可推导任意两个位置的关联度
3. 多频特征捕捉: 不同频率的正余弦波组合形成丰富的表征空间
4. 值域归一化: 所有编码值分布在[-1,1]区间,与词嵌入维度保持数值一致性
(图示:不同维度上的位置编码波形,高频维度对应快速变化,低频维度对应缓慢变化)
(图示:不同维度上的位置编码波形,高频维度对应快速变化,低频维度对应缓慢变化)

二、代码架构与执行流程

2.1 类结构设计

PositionalEncoding
__init__构造函数
创建零矩阵
配置梯度策略
构建位置索引
生成维度索引
计算正弦编码
计算余弦编码
forward前向传播
获取输入尺寸
返回截断编码

2.2 核心代码模块

class PositionalEncoding(nn.Module):def __init__(self, d_model, max_len, device):super().__init__()# 编码矩阵初始化(关键参数说明)self.encoding = torch.zeros(max_len, d_model, device=device)self.encoding.requires_grad = False  # 冻结梯度计算# 位置索引构建(维度变换演示)pos = torch.arange(0, max_len, device=device).float().unsqueeze(dim=1)# 维度索引生成(步长控制逻辑)_2i = torch.arange(0, d_model, step=2, device=device).float()# 编码计算过程(数学实现)self.encoding[:, 0::2] = torch.sin(pos / (10000  (_2i / d_model)))self.encoding[:, 1::2] = torch.cos(pos / (10000  (_2i / d_model)))def forward(self, x):batch_size, seq_len = x.size()return self.encoding[:seq_len, :]

三、逐行代码深度解析

3.1 构造函数解析

def __init__(self, d_model, max_len, device):super(PositionalEncoding, self).__init__()
  • 功能说明:继承PyTorch模块基类,初始化可训练参数
  • 参数详解:
    • d_model:编码维度(需与词嵌入维度一致)
    • max_len:预计算的最大序列长度(如512对应BERT标准配置)
    • device:硬件加速配置(实现跨平台兼容)
    self.encoding = torch.zeros(max_len, d_model, device=device)self.encoding.requires_grad = False
  • 设计意图:创建静态编码矩阵,避免反向传播计算
  • 内存优化:通过requires_grad=False节省显存占用
  • 维度说明:矩阵形状为[max_len, d_model],例如max_len=512时生成512x512矩阵
    pos = torch.arange(0, max_len, device=device)pos = pos.float().unsqueeze(dim=1)
  • 位置索引构建:生成[0,1,…,max_len-1]的连续位置序列
  • 维度变换:通过unsqueeze将1D张量转换为2D(max_len,1),便于广播计算
    _2i = torch.arange(0, d_model, step=2, device=device).float()
  • 步长控制:step=2确保交替访问奇偶索引
  • 数值范围:当d_model=512时,生成[0,2,4,…,510]的索引序列
    self.encoding[:, 0::2] = torch.sin(pos / (10000  (_2i / d_model)))self.encoding[:, 1::2] = torch.cos(pos / (10000  (_2i / d_model)))
  • 分片赋值:通过0::21::2实现奇偶列交替填充
  • 频率控制:10000 (_2i/d_model)生成指数衰减的频率系数

3.2 前向传播解析

def forward(self, x):batch_size, seq_len = x.size()return self.encoding[:seq_len, :]
  • 动态适配:根据实际输入序列长度截取编码
  • 广播机制:自动扩展编码矩阵到批次维度(无需显式复制)
  • 数值叠加:后续与词嵌入进行element-wise相加操作

四、张量运算可视化演示

4.1 示例参数配置

假设:

  • d_model = 4
  • max_len = 3
  • device = 'cpu'

4.2 计算过程推演

步骤1:生成位置索引

pos = [[0],[1],[2]]  # shape (3,1)

步骤2:创建维度索引

_2i = [0, 2]  # d_model=4时step=2生成

步骤3:计算频率项

频率项 = 10000^( (0/4), (2/4) ) = [1, 10000^0.5] ≈ [1, 100]

步骤4:计算位置编码

sin项:
pos / [1, 100] = [[0/1, 0/100],[1/1, 1/100],[2/1, 2/100]]= [[0, 0],[1, 0.01],[2, 0.02]]
sin值:
[[0, 0],[0.8415, 0.00999983],[0.9093, 0.01999867]]cos项计算同理...

最终编码矩阵:

PE = [[sin(0), cos(0), sin(0), cos(0)],      # 位置0[sin(1), cos(0.01), sin(1), cos(0.01)],# 位置1[sin(2), cos(0.02), sin(2), cos(0.02)] # 位置2
]

五、工程实践与优化策略

5.1 配置参数建议

  1. max_len设定:应大于训练数据最大序列长度20%
  2. 设备兼容性:通过device参数统一管理计算设备
  3. 混合精度训练:可将编码矩阵转为half精度

5.2 性能优化技巧

  1. 预计算缓存:提前生成编码矩阵避免运行时计算
  2. 内存映射:对超长序列使用内存映射文件
  3. 稀疏矩阵:对长文本场景采用分块加载策略

六、与其他模块的协同工作

6.1 与词嵌入的集成

class TransformerEmbedding(nn.Module):def __init__(self, vocab_size, d_model, max_len, device, dropout):super().__init__()self.tok_emb = nn.Embedding(vocab_size, d_model)self.pos_emb = PositionalEncoding(d_model, max_len, device)self.dropout = nn.Dropout(dropout)def forward(self, x):tok_emb = self.tok_emb(x)pos_emb = self.pos_emb(x)return self.dropout(tok_emb + pos_emb)
  • 加法融合:通过element-wise相加实现信息融合
  • 梯度隔离:位置编码不参与梯度更新
  • 维度验证:确保tok_embpos_emb维度严格一致

七、典型应用场景分析

7.1 文本生成任务

  • 长序列处理:通过位置编码捕获远距离依赖
  • 解码器优化:在自回归生成时动态调整位置编码

7.2 语音识别系统

  • 时序建模:精确捕捉语音信号的时序特征
  • 多尺度编码:结合不同频率分量处理语音信号

八、扩展研究方向

  1. 相对位置编码:改进绝对位置编码的局限性
  2. 动态频率调整:根据输入数据自动调节频率参数
  3. 混合编码方案:结合可学习参数与固定编码
  4. 量子化压缩:对编码矩阵进行低比特量化

原项目代码(附)

"""
@author : Hyunwoong
@when : 2019-10-22
@homepage : https://github.com/gusdnd852
"""import torch
from torch import nn# 定义一个名为PositionalEncoding的类,它继承自nn.Module,用于计算正弦位置编码。
class PositionalEncoding(nn.Module):"""计算正弦位置编码的类。"""def __init__(self, d_model, max_len, device):"""PositionalEncoding类的构造函数。:param d_model: 模型的维度(即嵌入向量的大小)。:param max_len: 序列的最大长度。:param device: 硬件设备设置(CPU或GPU)。"""super(PositionalEncoding, self).__init__()  # 调用父类nn.Module的构造函数。# 初始化一个与输入矩阵大小相同的零矩阵,用于存储位置编码,以便后续与输入矩阵相加。self.encoding = torch.zeros(max_len, d_model, device=device)self.encoding.requires_grad = False  # 我们不需要计算位置编码的梯度。# 创建一个从0到max_len-1的一维张量,表示序列中的位置索引。pos = torch.arange(0, max_len, device=device)# 将位置索引张量转换为浮点数,并增加一个维度,从1D变为2D,以表示每个位置的索引。pos = pos.float().unsqueeze(dim=1)# 1D => 2D,增加维度以表示单词的位置。# 创建一个从0到d_model-1,步长为2的一维浮点数张量,用于计算正弦和余弦函数的指数部分。_2i = torch.arange(0, d_model, step=2, device=device).float()# 'i'表示d_model的索引(例如,嵌入大小=50时,'i'的范围为[0,50])。# "step=2"意味着'i'每次增加2(相当于2*i)。# 使用正弦函数计算位置编码的偶数索引位置的值。self.encoding[:, 0::2] = torch.sin(pos / (10000 ** (_2i / d_model)))# 使用余弦函数计算位置编码的奇数索引位置的值。self.encoding[:, 1::2] = torch.cos(pos / (10000 ** (_2i / d_model)))# 计算位置编码,以考虑单词的位置信息。def forward(self, x):# self.encoding是预先计算好的位置编码矩阵。# [max_len = 512, d_model = 512],表示最大长度为512,维度为512的位置编码。# 获取输入x的批次大小和序列长度。batch_size, seq_len = x.size()# [batch_size = 128, seq_len = 30],表示批次大小为128,序列长度为30。# 返回与输入序列长度相匹配的位置编码。return self.encoding[:seq_len, :]# [seq_len = 30, d_model = 512],返回的形状为序列长度乘以维度。# 它将与输入嵌入(tok_emb)相加,tok_emb的形状通常为[128, 30, 512]。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.rhkb.cn/news/27874.html

如若内容造成侵权/违法违规/事实不符,请联系长河编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

CF 452A.Eevee(Java实现)

题目分析 输入一个数字-长度&#xff0c;输入一个字符串。判断这个字符串是具体的哪一个单词 思路分析 首先给了长度&#xff0c;那我先判断长度相同的单词&#xff0c;然后再一一对比&#xff0c;如果都能通过&#xff0c;那就输出这个单词 代码 import java.util.*;public …

【监控】使用Prometheus+Grafana搭建服务器运维监控面板(含带BearerToken的Exporter配置)

【监控】使用PrometheusGrafana搭建服务器运维监控面板&#xff08;含带BearerToken的Exporter配置&#xff09; 文章目录 1、Grafana 数据可视化面板2、Prometheus - 收集和存储指标数据3、Exporter - 采集和上报指标数据 1、Grafana 数据可视化面板 Grafana 是一个开源的可视…

ADC采集模块与MCU内置ADC性能对比

2.5V基准电压源&#xff1a; 1. 精度更高&#xff0c;误差更小 ADR03B 具有 0.1% 或更小的初始精度&#xff0c;而 电阻分压方式的误差主要来自电阻的容差&#xff08;通常 1% 或 0.5%&#xff09;。长期稳定性更好&#xff0c;分压电阻容易受到温度、老化的影响&#xff0c;长…

UDP协议(20250303)

1. UDP UDP:用户数据报协议&#xff08;User Datagram Protocol&#xff09;&#xff0c;传输层协议之一&#xff08;UDP&#xff0c;TCP&#xff09; 2. 特性 发送数据时不需要建立链接&#xff0c;节省资源开销不安全不可靠的协议 //一般用在实时性比较高…

基于https虚拟主机配置

一、https介绍 http 明文&#xff0c;80/tcp https 密文&#xff0c;443/tcp 二、安全性保障 1、数据安全性 数据加密 2、数据完整性 3、验证身份的真实性、有效性 三、数据安全性 手段&#xff1a;加密 发送方加密数据&#xff0c;接收方解密数据 对称加密算法 加密、解密数据…

机器学习(五)

一&#xff0c;多类&#xff08;Multiclass&#xff09; 多类是指输出不止有两个输出标签&#xff0c;想要对多个种类进行分类。 Softmax回归算法&#xff1a; Softmax回归算法是Logistic回归在多类问题上的推广&#xff0c;和线性回归一样&#xff0c;将输入的特征与权重进行…

概率论基础概念

前言 本文隶属于专栏《机器学习数学通关指南》&#xff0c;该专栏为笔者原创&#xff0c;引用请注明来源&#xff0c;不足和错误之处请在评论区帮忙指出&#xff0c;谢谢&#xff01; 本专栏目录结构和参考文献请见《机器学习数学通关指南》 正文 &#x1f3b2; 1. 随机事件 …

动漫短剧开发公司,短剧小程序搭建快速上线

在当今快节奏的生活里&#xff0c;人们的娱乐方式愈发多元&#xff0c;而动漫短剧作为新兴娱乐形式&#xff0c;正以独特魅力迅速崛起&#xff0c;成为娱乐市场的耀眼新星。近年来&#xff0c;动漫短剧市场呈爆发式增长&#xff0c;吸引众多创作者与观众目光。 从市场规模来看…

MySQL零基础教程15—简单的表连接(join)

在学习子查询的时候&#xff0c;我们已经感受到了&#xff0c;在一个语句中&#xff0c;通过访问不同表的数据最终获取我们想要的结果这种操作方式&#xff0c;实际上在mysql中&#xff0c;还有更加有趣的一个功能&#xff0c;就是表连接&#xff0c;同样是在查询数据的时候连接…

【AVRCP】深入剖析 AVRCP 命令体系:从单元到特定命令的全面解读

在蓝牙音频 / 视频远程控制规范&#xff08;AVRCP&#xff09;中&#xff0c;丰富的命令体系是实现设备间高效交互的关键。这些命令涵盖了单元命令、通用单元与子单元命令、特定命令等多个层面&#xff0c; 一、支持的单元命令 1.1 单元命令概述 AVRCP中支持的单元命令在设备…

物业管理系统源码 物业小程序源码

物业管理系统源码 物业小程序源码 一、基础信息管理 1. 房产信息管理 记录楼栋、单元、房间的详细信息&#xff08;面积、户型、产权等&#xff09;。 管理业主/租户的档案&#xff0c;包括联系方式、合同信息等。 2. 公共资源管理 管理停车场、电梯、绿化带、公…

专题二最大连续1的个数|||

1.题目 题目分析&#xff1a; 给一个数字k&#xff0c;可以把数组里的0改成1&#xff0c;但是只能改k次&#xff0c;然后该变得到的数组能找到最长的子串且都是1。 2.算法原理 这里不用真的把0变成1&#xff0c;因为改了比较麻烦&#xff0c;下次用就要改回成1&#xff0c;这…

【计算机网络入门】初学计算机网络(十一)重要

目录 1. CIDR无分类编址 1.1 CIDR的子网划分 1.1.1 定长子网划分 1.1.2 变长子网划分 2. 路由聚合 2.1 最长前缀匹配原则 3. 网络地址转换NAT 3.1 端口号 3.2 IP地址不够用&#xff1f; 3.3 公网IP和内网IP 3.4 NAT作用 4. ARP协议 4.1 如何利用IP地址找到MAC地址…

精讲坐标轴系统(Axis)

续前文&#xff1a; 保姆级matplotlib教程&#xff1a;详细目录 保姆级seaborn教程&#xff1a;详细目录 seaborn和matplotlib怎么选&#xff0c;还是两个都要学&#xff1f; 详解Python matplotlib深度美化&#xff08;第一期&#xff09; 详解Python matplotlib深度美化&…

Metal学习笔记十:光照基础

光和阴影是使场景流行的重要要求。通过一些着色器艺术&#xff0c;您可以突出重要的对象、描述天气和一天中的时间并设置场景的气氛。即使您的场景由卡通对象组成&#xff0c;如果您没有正确地照亮它们&#xff0c;场景也会变得平淡无奇。 最简单的光照方法之一是 Phong 反射模…

动态规划_路径问题(典型算法思想)—— OJ例题算法解析思路

目录 一、62. 不同路径 - 力扣&#xff08;LeetCode&#xff09; 算法代码&#xff1a; 代码思路分析 问题定义&#xff1a; 动态规划定义&#xff1a; 边界条件&#xff1a; 填表过程&#xff1a; 返回结果&#xff1a; 代码优化思路 空间优化&#xff1a; 滚动数组…

【AI论文】ViDoRAG:通过动态迭代推理代理实现视觉文档检索增强生成

摘要&#xff1a;理解富含视觉信息的文档中的信息&#xff0c;对于传统的检索增强生成&#xff08;Retrieval-Augmented Generation&#xff0c;RAG&#xff09;方法来说&#xff0c;仍然是一个重大挑战。现有的基准测试主要集中在基于图像的问答&#xff08;Question Answerin…

【赵渝强老师】监控Redis

对运行状态的Redis实例进行监控是运维管理中非常重要的内容&#xff0c;包括&#xff1a;监控Redis的内存、监控Redis的吞吐量、监控Redis的运行时信息和监控Redis的延时。通过Redis提供的监控命令便能非常方便地实现对各项指标的监控。 一、监控Redis的内存 视频讲解如下 【…

HTML前端手册

HTML前端手册 记录前端框架在使用过程中遇到的各种问题和解决方案&#xff0c;供后续快速进行手册翻阅使用 文章目录 HTML前端手册1-前端框架1-TypeScript框架2-CSS框架 2-前端Demo1-Html常用代码 2-知云接力3-Live2D平面动画 3-前端运维1-NPM版本管理 1-前端框架 1-TypeScrip…

C++:类和对象(下篇)

1. 再谈构造函数 1.1 构造函数体赋值 在创建对象时&#xff0c;编译器通过调用构造函数&#xff0c;给对象中各个成员变量一个合适的初始值。 class Date { public:Date(int year, int month, int day){_year year;_month month;_day day;} private:int _year;int _mont…