2024年大模型面试准备(四):大模型面试必会的位置编码(绝对位置编码sinusoidal,旋转位置编码RoPE,以及相对位置编码ALiBi)

节前,我们组织了一场算法岗技术&面试讨论会,邀请了一些互联网大厂朋友、参加社招和校招面试的同学,针对大模型技术趋势、大模型落地项目经验分享、新手如何入门算法岗、该如何备战、面试常考点分享等热门话题进行了深入的讨论。


合集在这里:《大模型面试宝典》(2024版) 正式发布!


Transformer 模型在处理序列数据时,其自注意力机制使得模型能够全局地捕捉不同元素之间的依赖关系,但这样做的代价是丧失了序列中的元素顺序信息。由于自注意力机制并不考虑元素在序列中的位置,所以在输入序列的任何置换下都是不变的,这就意味着模型无法区分序列中元素的相对位置。在许多自然语言处理任务中,词语之间的顺序是至关重要的,所以需要一种方法来让模型捕获这一信息。

因此在送入编码器端建模其上下文语义之前,一个非常重要的操作是 在词嵌入中加入位置编码(Positional Encoding) 这一特征。

具体来说,序列中每一个单词所在的位置都对应一个向量。这一向量会与单词表示对应相加并送入到后续模块中做进一步处理。在训练的过程当中,模型会自动地学习到如何利用这部分位置信息。

常见的位置编码主要有绝对位置编码(sinusoidal),旋转位置编码(RoPE),以及相对位置编码ALiBi

1、绝对位置编码sinusoidal

绝对位置编码是直接将序列中每个位置的信息编码进模型的,从而使模型能够了解每个元素在序列中的具体位置。原始Transformer提出时采用了sinusoidal位置编码,通过使用不同频率的正弦和余弦的函数,使得模型捕获位置之间的复杂关系,且这些编码与序列中每个位置的绝对值有关。sinusoidal位置编码公式如下:

在这里插入图片描述
在这里插入图片描述

原始 Transformer 的位置编码虽然是基于绝对位置的,但其数学结构使其能够捕获一些相对位置信息。使用正弦和余弦函数的组合为每个位置创建编码,波长呈几何级数排列,意味着每个位置的编码都是独特的。同时,正弦和余弦函数的周期性特性确保了不同位置之间的编码关系是连续且平滑的。

比如:

  • 对于相邻位置,位置编码的差异较小,与两者之间的距离成正比。

  • 对于相隔较远的位置,位置编码的差异较大,与两者之间的距离也成正比。

这种连续和平滑的关系允许模型学习位置之间的相对关系,而不仅仅是各自的绝对位置。考虑两个位置和,由于正弦和余弦函数的性质,位置编码的差值将与和之间的差值有关。这意味着通过比较不同位置编码之间的差值,模型可以推断出它们之间的相对位置。

总结来说,通过上面这种方式计算位置编码有这样几个好处:

  • 首先,正余弦函数的范围是在 [-1,+1],导出的位置编码与原词嵌入相加,不会使得结果偏离过远而破坏原有单词的语义信息。

  • 其次,依据三角函数的基本性质,可以得知第 pos + k 个位置的编码是第 pos 个位置的编码的线性组合,这就意味着位置编码中蕴含着单词之间的距离信息。

使用 Pytorch 实现的位置编码参考代码如下:

class PositionalEncoder(nn.Module):def __init__(self, d_model, max_seq_len = 80):super().__init__()self.d_model = d_model# 根据 pos 和 i 创建一个常量 PE 矩阵pe = torch.zeros(max_seq_len, d_model)for pos in range(max_seq_len):for i in range(0, d_model, 2):pe[pos, i] = math.sin(pos / (10000 ** ((2 * i)/d_model)))pe[pos, i + 1] = math.cos(pos / (10000 ** ((2 * (i + 1))/d_model)))pe = pe.unsqueeze(0)self.register_buffer('pe', pe)def forward(self, x):# 使得单词嵌入表示相对大一些x = x * math.sqrt(self.d_model)# 增加位置常量到单词嵌入表示中seq_len = x.size(1)x = x + Variable(self.pe[:,:seq_len], requires_grad=False).cuda()return x

另一份实现和可视化代码如下:

import math
import torch
import torch.nn as nn
import numpy as npclass PositionalEncoding(nn.Module):def __init__(self, d_model: int, dropout_prob: float, max_len: int = 5000):super().__init__()self.dropout = nn.Dropout(dropout_prob)encodings = self.get_positional_encoding(d_model, max_len)self.register_buffer('positional_encodings', encodings, False)@staticmethoddef get_positional_encoding(d_model: int, max_len: int):position = torch.arange(0, max_len, dtype=torch.float32).unsqueeze(1)two_i = torch.arange(0, d_model, 2, dtype=torch.float32)div_term = torch.exp(two_i * -(math.log(10000.0) / d_model))encodings = torch.zeros(max_len, d_model)encodings[:, 0::2] = torch.sin(position * div_term)encodings[:, 1::2] = torch.cos(position * div_term)return encodings.unsqueeze(0).requires_grad_(False)def forward(self, x: torch.Tensor):pe = self.positional_encodings[:x.shape[1]].detach().requires_grad_(False)return self.dropout(x + pe)def _test_positional_encoding():import matplotlib.pyplot as pltplt.figure(figsize=(15, 5))pe = PositionalEncoding.get_positional_encoding(20, 100)print(pe.shape)plt.plot(np.arange(100), pe[:, 0, 4:8].numpy())plt.legend(["dim %d" % p for p in [4, 5, 6, 7]])plt.title("Positional encoding")plt.show()if __name__ == '__main__':_test_positional_encoding()

可以看到不同维度沿着序列方向的位置编码变化如下图所示:

图片

2、旋转位置编码RoPE

sinusoidal位置编码对相对位置关系的表示还是比较间接的,那有没有办法更直接的表示相对位置关系呢?那肯定是有的,而且有许多不同的方法,旋转位置编码(Rotary Position Embedding,RoPE)是一种用绝对位置编码来表征相对位置编码的方法,并被用在了很多大语言模型的设计中,很多成功的LLM,例如LLAMA系列、GLM、百川、通义千问等,都使用了RoPE。

RoPE 借助了复数的思想,出发点是通过绝对位置编码的方式实现相对位置编码。

在这里插入图片描述

图片

本文主要介绍思想,具体推导感兴趣的可以看苏剑林老师的博客或者论文:RoFormer: Enhanced Transformer with Rotary Position Embedding

有了这一形式后,具体实现有两种方式:

  • 转到复数域,对两个向量进行旋转,再转回实数域

  • 由于上述矩阵 Rn 具有稀疏性,因此可以使用逐位相乘 ⊗ 操作进一步加快计算速度,直接在实数域通过向量和正余弦函数的乘法进行运算,也就是下面这个公式:
    图片

LLaMA的代码实现就是采用了第一种形式,如下:

def precompute_freqs_cis(dim: int, end: int, theta: float = 10000.0):freqs = 1.0 / (theta ** (torch.arange(0, dim, 2)[: (dim // 2)].float() / dim))t = torch.arange(end, device=freqs.device)  # type: ignorefreqs = torch.outer(t, freqs).float()  # type: ignorefreqs_cis = torch.polar(torch.ones_like(freqs), freqs)  # complex64return freqs_cisdef reshape_for_broadcast(freqs_cis: torch.Tensor, x: torch.Tensor):ndim = x.ndimassert 0 <= 1 < ndimassert freqs_cis.shape == (x.shape[1], x.shape[-1])shape = [d if i == 1 or i == ndim - 1 else 1 for i, d in enumerate(x.shape)]return freqs_cis.view(*shape)def apply_rotary_emb(xq: torch.Tensor,xk: torch.Tensor,freqs_cis: torch.Tensor,
) -> Tuple[torch.Tensor, torch.Tensor]:xq_ = torch.view_as_complex(xq.float().reshape(*xq.shape[:-1], -1, 2))xk_ = torch.view_as_complex(xk.float().reshape(*xk.shape[:-1], -1, 2))freqs_cis = reshape_for_broadcast(freqs_cis, xq_)xq_out = torch.view_as_real(xq_ * freqs_cis).flatten(3)xk_out = torch.view_as_real(xk_ * freqs_cis).flatten(3)return xq_out.type_as(xq), xk_out.type_as(xk)

第二种方式的实现会更加易懂一点,就是把前面那个公式中的四个向量凑出来,然后照着公式算一下就可以了,示例代码如下:

import torch
import mathdef rotary_position_embedding(q, k):"""Rotary Position Embedding (RoPE) for queries and keys.Args:q: tensor for queries of shape (batch_size, num_heads, seq_len, dim)k: tensor for keys of shape (batch_size, num_heads, seq_len, dim)Returns:Rotated queries and keys"""batch_size, num_heads, seq_len, dim = q.size()# Begin of sinusoidal_position_embedding contentposition = torch.arange(seq_len, dtype=torch.float).unsqueeze(-1).to(q.device)div_term = torch.exp(torch.arange(0, dim, 2, dtype=torch.float) * -(math.log(10000.0) / dim)).to(q.device)pos_emb = position * div_termpos_emb = torch.stack([torch.sin(pos_emb), torch.cos(pos_emb)], dim=-1).flatten(-2, -1)pos_emb = pos_emb.unsqueeze(0).unsqueeze(1)pos_emb = pos_emb.expand(batch_size, num_heads, -1, -1)# End of sinusoidal_position_embedding content# Extract and duplicate cosine and sine embeddingscos_emb = pos_emb[..., 1::2].repeat_interleave(2, dim=-1)sin_emb = pos_emb[..., ::2].repeat_interleave(2, dim=-1)# Create alternate versions of q and kq_alternate = torch.stack([-q[..., 1::2], q[..., ::2]], dim=-1).reshape(q.size())k_alternate = torch.stack([-k[..., 1::2], k[..., ::2]], dim=-1).reshape(k.size())# Rotate queries and keysq_rotated = q * cos_emb + q_alternate * sin_embk_rotated = k * cos_emb + k_alternate * sin_embreturn q_rotated, k_rotated

相对位置编码AliBi

受到 T5 Bias 的启发,Press 等人提出了 ALiBi 算法,是一种预定义的相对位置编码。与传统方法不同,ALiBi 不向单词embedding中添加位置embedding,而是根据token之间的距离给 attention score 加上一个预设好的偏置矩阵,比如 和 相对位置差 1 就加上一个 -1 的偏置,两个 token 距离越远这个负数就越大,代表他们的相互贡献越低。由于注意力机制一般会有多个head,这里针对每一个head会乘上一个预设好的斜率项(Slope)。

在这里插入图片描述

图片

ALiBi 对最近性具有归纳偏差,它对远程查询-键对之间的注意力分数进行惩罚,随着键和查询之间的距离增加,惩罚增加。不同的注意头以不同的速率增加其惩罚,这取决于斜率幅度。实验证明这组斜率参数适用于各种文本领域和模型尺寸,不需要在新的数据和架构上调整斜率值。

因此ALiBi方法不需要对原始网络进行改动,允许在较短的输入序列上训练模型,同时在推理时能够有效地外推到较长的序列,从而实现了更高的效率和性能。

下面给出一份实验代码,大家可以自己跑一跑感觉一下:

import math
import torch
from torch import nndef get_slopes(n_heads: int):n = 2 ** math.floor(math.log2(n_heads))m_0 = 2.0 ** (-8.0 / n)m = torch.pow(m_0, torch.arange(1, 1 + n))if n < n_heads:m_hat_0 = 2.0 ** (-4.0 / n)m_hat = torch.pow(m_hat_0, torch.arange(1, 1 + 2 * (n_heads - n), 2))m = torch.cat([m, m_hat])return m@torch.no_grad()
def get_alibi_biases(n_heads: int, mask: torch.Tensor):m = get_slopes(n_heads).to(mask.device)seq_len = mask.size(0)distance = torch.tril(torch.arange(0, -seq_len, -1).view(-1, 1).expand(seq_len, seq_len))print(distance)return distance[:, :, None] * m[None, None, :]seq_len = 10
n_heads = 8m = get_slopes(n_heads)
print(m)alibi_biases = torch.zeros(seq_len,seq_len)
for j in range(1,seq_len):for i in range(j, seq_len):alibi_biases[i, i - j] = -j
print(alibi_biases)print(alibi_biases[:, :, None].shape, m[None, None, :].shape)alibi_biases[:, :, None] * m[None, None, :]

代码打印出的结果如下:

图片

可以看到上面打印出来的第一行就是给不同head的系数 ,第二行矩阵就是基于两个 token 之间距离计算的偏置矩阵。

技术交流群

前沿技术资讯、算法交流、求职内推、算法竞赛、面试交流(校招、社招、实习)等、与 10000+来自港科大、北大、清华、中科院、CMU、腾讯、百度等名校名企开发者互动交流~

我们建了算法岗技术与面试交流群, 想要进交流群、需要源码&资料、提升技术的同学,可以直接加微信号:mlc2040。加的时候备注一下:研究方向 +学校/公司+CSDN,即可。然后就可以拉你进群了。

方式①、微信搜索公众号:机器学习社区,后台回复:加群
方式②、添加微信号:mlc2040,备注:技术交流

用通俗易懂方式讲解系列

  • 《大模型面试宝典》(2024版) 正式发布!
  • 《大模型实战宝典》(2024版)正式发布!
  • 2024年大模型面试准备(一):LLM主流结构和训练目标、构建流程
  • 2024年大模型面试准备(二):LLM容易被忽略的Tokenizer与Embedding
  • 2024年大模型面试准备(三):聊一聊大模型的幻觉问题

参考文献:

[1] Touvron H, Lavril T, Izacard G, et al. Llama: Open and efficient foundation language models[J]. arXiv preprint arXiv:2302.13971, 2023.
[2] Raffel C, Shazeer N, Roberts A, et al. Exploring the limits of transfer learning with a unified text-to-text transformer[J/OL]. Journal of Machine Learning Research, 2020, 21(140):1-67. http://jmlr.org/papers/v21/20-074.html.
[3] Press O, Smith N A, Lewis M. Train short, test long: Attention with linear biases enables input length extrapolation[J]. arXiv preprint arXiv:2108.12409, 2021.
[4] Sun Y, Dong L, Patra B, et al. A length-extrapolatable transformer[J]. arXiv preprint arXiv:2212.10554, 2022.
[5] Chen S, Wong S, Chen L, et al. Extending context window of large language models via positional interpolation[J]. arXiv preprint arXiv:2306.15595, 2023.

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.rhkb.cn/news/287193.html

如若内容造成侵权/违法违规/事实不符,请联系长河编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

生成可读取配置文件的独立运行jar程序

前言: 周五刚躺下,前线打来语音要个下载文件的小程序,下载路径和下载码需要根据配置获取,程序需要在服务器执行。当然配置的设计是个人设计的,不然每次更新下载码都要重新出具jar包,太麻烦。多年没写独立运行的jar包了,翻阅了相关资料,最终还是功夫不负有心人。想着这种…

鸿蒙 HarmonyOS应用开发之API:Context

Context 是应用中对象的上下文&#xff0c;其提供了应用的一些基础信息&#xff0c;例如resourceManager&#xff08;资源管理&#xff09;、applicationInfo&#xff08;当前应用信息&#xff09;、dir&#xff08;应用文件路径&#xff09;、area&#xff08;文件分区&#x…

uni-app攻略:如何对接驰腾打印机

一.引言 在当前的移动开发生态中&#xff0c;跨平台框架如uni-app因其高效、灵活的特点受到了开发者们的青睐。同时&#xff0c;随着物联网技术的飞速发展&#xff0c;智能打印设备已成为许多业务场景中不可或缺的一环。今天&#xff0c;我们就来探讨如何使用uni-app轻松对接驰…

阿赵UE学习笔记——22、动画合成

阿赵UE学习笔记目录 大家好&#xff0c;我是阿赵。   继续学习虚幻引擎的使用。这次来看看动画合成功能。   所谓的动画合成&#xff0c;意思就是把多段已经存在的动画拼接在一起&#xff0c;成为一段新的动画。比如之前做的钢铁侠例子里面&#xff0c;钢铁侠的待机动作感觉…

零基础机器学习(3)之机器学习的一般过程

文章目录 一、机器学习一般过程1.数据获取2.特征提取3.数据预处理①去除唯一属性②缺失值处理A. 均值插补法B. 同类均值插补法 ③重复值处理④异常值⑤数据定量化 4.数据标准化①min-max标准化&#xff08;归一化&#xff09;②z-score标准化&#xff08;规范化&#xff09; 5.…

AI入侵游戏业:是颠覆者还是创新助手?揭秘未来游戏新趋势!

在科技日新月异的今天&#xff0c;人工智能&#xff08;AI&#xff09;已经成为各行各业的关注焦点。而在娱乐产业中&#xff0c;AI技术的引入也让人们对电子游戏的未来发展产生了无限遐想。那么&#xff0c;AI究竟会给电子游戏行业带来怎样的变革&#xff1f;它会成为行业的颠…

【嵌入式学习】Qtday03.26

一、思维导图 二、练习 头文件 #ifndef WIDGET_H #define WIDGET_H#include <QWidget> #include <QTime> #include <QTextToSpeech>QT_BEGIN_NAMESPACE namespace Ui { class Widget; } QT_END_NAMESPACEclass Widget : public QWidget {Q_OBJECTpublic:Wi…

128天创作纪念日

分享的力量 hellohello~大家好✨✨我是大耳朵土土垚&#x1f973;&#x1f973;&#x1f973;今天是我创作128天的纪念日&#x1f389;&#x1f389;&#x1f389;&#xff0c;今天听到一句话——分享自己开心的事情就像有丝分裂一样会将快乐一直扩散&#x1f496;&#x1f496…

基于前端技术实现的全面预算编制系统

前言 在现代商业环境中&#xff0c;预测销售数据和实际成本是每个公司CEO和领导都极为重视的关键指标。然而&#xff0c;由于市场的不断变化&#xff0c;准确地预测和管理这些数据变得愈发具有挑战性。为了应对这一挑战&#xff0c;建立一个高效的系统来管理和审查销售数据的重…

机器人路径规划:基于斑翠鸟优化算法(Pied Kingfisher Optimizer ,PKO)的机器人路径规划(提供MATLAB代码)

一、机器人路径规划介绍 移动机器人&#xff08;Mobile robot&#xff0c;MR&#xff09;的路径规划是 移动机器人研究的重要分支之&#xff0c;是对其进行控制的基础。根据环境信息的已知程度不同&#xff0c;路径规划分为基于环境信息已知的全局路径规划和基于环境信息未知或…

【go从入门到精通】if else 条件控制

作者简介&#xff1a; 高科&#xff0c;先后在 IBM PlatformComputing从事网格计算&#xff0c;淘米网&#xff0c;网易从事游戏服务器开发&#xff0c;拥有丰富的C&#xff0c;go等语言开发经验&#xff0c;mysql&#xff0c;mongo&#xff0c;redis等数据库&#xff0c;设计模…

行业研究数据/报告网站 - 好用免费

前言 转眼进互联网 12 年了&#xff0c;先后在百度、汽车之家、某独角兽做业务和商业分析。与CEO、VP、业务owner等都对接过&#xff0c;1 个明显共性&#xff0c;就是大家都很关注外部行业数据&#xff0c;为决策提供参考。接下来&#xff0c;就和大家分享一下我收藏的 34 个…

迷宫(蓝桥杯)——DFS和BFS

迷宫 题目描述 下图给出了一个迷宫的平面图&#xff0c;其中标记为 1 的为障碍&#xff0c;标记为 0 的为可以通行的地方。 010000 000100 001001 110000迷宫的入口为左上角&#xff0c;出口为右下角&#xff0c;在迷宫中&#xff0c;只能从一个位置走到这 个它的上、下、左…

【数据结构】双向奔赴的爱恋 --- 双向链表

关注小庄 顿顿解馋๑ᵒᯅᵒ๑ 引言&#xff1a;上回我们讲解了单链表(单向不循环不带头链表)&#xff0c;我们可以发现他是存在一定缺陷的&#xff0c;比如尾删的时候需要遍历一遍链表&#xff0c;这会大大降低我们的性能&#xff0c;再比如对于链表中的一个结点我们是无法直接…

思通舆情 是一款开源免费的舆情系统 介绍

思通舆情 是一款开源免费的舆情系统。 支持本地化部署&#xff0c;支持在线体验。 支持对海量舆情数据分析和挖掘。 无论你是使用者还是共同完善的开发者&#xff0c;欢迎 pull request 或者 留言对我们提出建议。 您的支持和参与就是我们坚持开源的动力&#xff01;请 sta…

超越Sora!StreamingT2V AI视频模型,轻松打造120秒视觉盛宴

近日&#xff0c;来自美国德克萨斯大学奥斯汀分校&#xff08;UT奥斯丁&#xff09;等机构的研究人员提出了一项名为StreamingT2V的AI视频生成技术&#xff0c;引起了业界的广泛关注。这项技术打破了传统视频生成的局限&#xff0c;实现了高度一致且长度可扩展的视频生成&#…

C语言(结构体,联合体,枚举的讲解)

这期我们来讲解结构体&#xff0c;联合体&#xff0c;以及枚举的讲解&#xff0c;首先我们从概念开始一步一步的了解。 1&#xff0c;结构体 1.1概念 C 语言中的结构体是一种用户自定义的数据类型&#xff0c;它允许你将不同类型的变量组合在一起&#xff0c;从而形成一个新…

1978-2022年全国31省社会消费品零售总额数据

1978-2022年全国31省社会消费品零售总额数据 1、时间&#xff1a;1978-2022年 2、指标&#xff1a;社会消费品零售总额 3、范围&#xff1a;31省市 4、来源&#xff1a;整理自国家统计J和各省年鉴 5、缺失情况说明&#xff1a;1997-2022年31省市均无缺失&#xff0c; 199…

海外媒体发稿:9种高效的媒体套餐内容发稿策略分析-华媒舍

海外媒体发稿&#xff1a;9种高效的媒体套餐内容发稿策略分析高效的媒体发布和营销推广策略对公司、本人的成就尤为重要。下面我们就对于媒体套餐内容发稿营销推广策略开展全面解析&#xff0c;帮助读者掌握并应用这9种合理的思路&#xff0c;进而获得更好的媒体营销效果。 1.媒…

【Python系列】Python 中 YAML 文件与字典合并的实用技巧

&#x1f49d;&#x1f49d;&#x1f49d;欢迎来到我的博客&#xff0c;很高兴能够在这里和您见面&#xff01;希望您在这里可以感受到一份轻松愉快的氛围&#xff0c;不仅可以获得有趣的内容和知识&#xff0c;也可以畅所欲言、分享您的想法和见解。 推荐:kwan 的首页,持续学…