LLaMA长度外推高性价比trick:线性插值法及相关改进源码阅读及相关记录

前言

最近,开源了可商用的llama2,支持长度相比llama1的1024,拓展到了4096长度,然而,相比GPT-4、Claude-2等支持的长度,llama的长度外推显得尤为重要,本文记录了三种网络开源的RoPE改进方式及相关源码的阅读。

关于长度外推性:https://kexue.fm/archives/9431

关于RoPE:https://kexue.fm/archives/8265

1、线性插值法

论文:EXTENDING CONTEXT WINDOW OF LARGE LANGUAGE MODELS VIA POSITION INTERPOLATION

链接:https://arxiv.org/pdf/2306.15595.pdf

思想:不进行长度外推,而是直接缩小位置索引。即:将4096的位置编码通过线性插值法压缩到2048内,这样只需在少量的4096长度的数据上继续预训练,便可达到不错的效果。

在这里插入图片描述

源码阅读(附注释)

class LlamaLinearScaledRotaryEmbedding(torch.nn.Module):def __init__(self, dim, max_position_embeddings=2048, base=10000, scale=1, device=None):super().__init__()# 相比RoPE增加scale参数self.scale = scale# inv_freq为基值向量inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2).float().to(device) / dim))self.register_buffer("inv_freq", inv_freq)# Build here to make `torch.jit.trace` work.self.max_seq_len_cached = max_position_embeddings# 构建max_seq_len_cached大小的张量tt = torch.arange(self.max_seq_len_cached, device=self.inv_freq.device, dtype=self.inv_freq.dtype)# 张量t归一化,RoPE没有这一步t /= self.scale# einsum计算频率矩阵# 'i, j->i j’表示分别输入尺寸为[i]、[j]的向量,做笛卡尔运算得到尺寸为[i, j]的矩阵。freqs = torch.einsum("i,j->ij", t, self.inv_freq)# Different from paper, but it uses a different permutation in order to obtain the same calculation# 在-1维做一次拷贝、拼接emb = torch.cat((freqs, freqs), dim=-1)dtype = torch.get_default_dtype()# 注册为模型的缓冲区cos_cached和sin_cachedself.register_buffer("cos_cached", emb.cos()[None, None, :, :].to(dtype), persistent=False)self.register_buffer("sin_cached", emb.sin()[None, None, :, :].to(dtype), persistent=False)def forward(self, x, seq_len=None):# x: [bs, num_attention_heads, seq_len, head_size]# This `if` block is unlikely to be run after we build sin/cos in `__init__`. Keep the logic here just in case.# seq_len为序列长度,seq_len大于max_seq_len_cached,则重新计算频率矩阵,并更新cos_cached和sin_cached的缓冲区if seq_len > self.max_seq_len_cached:self.max_seq_len_cached = seq_lent = torch.arange(self.max_seq_len_cached, device=x.device, dtype=self.inv_freq.dtype)t /= self.scalefreqs = torch.einsum("i,j->ij", t, self.inv_freq)# Different from paper, but it uses a different permutation in order to obtain the same calculationemb = torch.cat((freqs, freqs), dim=-1).to(x.device)self.register_buffer("cos_cached", emb.cos()[None, None, :, :].to(x.dtype), persistent=False)self.register_buffer("sin_cached", emb.sin()[None, None, :, :].to(x.dtype), persistent=False)# 长度裁剪:返回cos_cached和sin_cached中与seq_len(序列长度)return (self.cos_cached[:, :, :seq_len, ...].to(dtype=x.dtype),self.sin_cached[:, :, :seq_len, ...].to(dtype=x.dtype),)

线性插值法的相关实验效果:https://lmsys.org/blog/2023-06-29-longchat/

2、NTK插值法

NTK插值改进llama中使用的RoPE插值方法,同样,对于RoPE代码改动更小,其他地方与线性插值法实现一致。

reddit原帖:NTK-Aware Scaled RoPE allows LLaMA models to have extended (8k+) context size without any fine-tuning and minimal perplexity degradation

链接:https://www.reddit.com/r/LocalLLaMA/comments/14lz7j5/ntkaware_scaled_rope_allows_llama_models_to_have/?rdt=58346

源码阅读:

class LlamaNTKScaledRotaryEmbedding(torch.nn.Module):def __init__(self, dim, max_position_embeddings=2048, base=10000, alpha=1, device=None):super().__init__()# 与线性插值法相比,实现更简单,alpha仅用来改变basebase = base * alpha ** (dim / (dim-2))inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2).float().to(device) / dim))self.register_buffer("inv_freq", inv_freq)# Build here to make `torch.jit.trace` work.self.max_seq_len_cached = max_position_embeddingst = torch.arange(self.max_seq_len_cached, device=self.inv_freq.device, dtype=self.inv_freq.dtype)freqs = torch.einsum("i,j->ij", t, self.inv_freq)# Different from paper, but it uses a different permutation in order to obtain the same calculationemb = torch.cat((freqs, freqs), dim=-1)dtype = torch.get_default_dtype()self.register_buffer("cos_cached", emb.cos()[None, None, :, :].to(dtype), persistent=False)self.register_buffer("sin_cached", emb.sin()[None, None, :, :].to(dtype), persistent=False)def forward(self, x, seq_len=None):# x: [bs, num_attention_heads, seq_len, head_size]# This `if` block is unlikely to be run after we build sin/cos in `__init__`. Keep the logic here just in case.if seq_len > self.max_seq_len_cached:self.max_seq_len_cached = seq_lent = torch.arange(self.max_seq_len_cached, device=x.device, dtype=self.inv_freq.dtype)freqs = torch.einsum("i,j->ij", t, self.inv_freq)# Different from paper, but it uses a different permutation in order to obtain the same calculationemb = torch.cat((freqs, freqs), dim=-1).to(x.device)self.register_buffer("cos_cached", emb.cos()[None, None, :, :].to(x.dtype), persistent=False)self.register_buffer("sin_cached", emb.sin()[None, None, :, :].to(x.dtype), persistent=False)return (self.cos_cached[:, :, :seq_len, ...].to(dtype=x.dtype),self.sin_cached[:, :, :seq_len, ...].to(dtype=x.dtype),)

3、动态插值法

动态插值法又是对NTK插值法和线性插值法的改进,可以看作是上述两者的一种结合思想,旨在减少困惑度损失并实现更大的缩放。

reddit原帖:Dynamically Scaled RoPE further increases performance of long context LLaMA with zero fine-tuning

链接:https://www.reddit.com/r/LocalLLaMA/comments/14mrgpr/dynamically_scaled_rope_further_increases/

源码阅读

class LlamaDynamicScaledRotaryEmbedding(torch.nn.Module):def __init__(self, dim, max_position_embeddings=2048, base=10000, ntk=False, device=None):super().__init__()# 是否开启NTK(Neural Tangent Kernel)self.ntk = ntkself.base = baseself.dim = dimself.max_position_embeddings = max_position_embeddings# inv_freq为基值向量inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2).float().to(device) / dim))self.register_buffer("inv_freq", inv_freq)# Build here to make `torch.jit.trace` work.self.max_seq_len_cached = max_position_embeddingst = torch.arange(self.max_seq_len_cached, device=self.inv_freq.device, dtype=self.inv_freq.dtype)freqs = torch.einsum("i,j->ij", t, self.inv_freq)# Different from paper, but it uses a different permutation in order to obtain the same calculation# emb:[max_seq_len_cached, dim]emb = torch.cat((freqs, freqs), dim=-1)dtype = torch.get_default_dtype()self.register_buffer("cos_cached", emb.cos()[None, None, :, :].to(dtype), persistent=False)self.register_buffer("sin_cached", emb.sin()[None, None, :, :].to(dtype), persistent=False)def forward(self, x, seq_len=None):# x: [bs, num_attention_heads, seq_len, head_size]# This `if` block is unlikely to be run after we build sin/cos in `__init__`. Keep the logic here just in case.if seq_len > self.max_seq_len_cached:self.max_seq_len_cached = seq_lenif self.ntk:base = self.base * ((self.ntk * seq_len / self.max_position_embeddings) - (self.ntk - 1)) ** (self.dim / (self.dim-2))# 计算新的inv_freqinv_freq = 1.0 / (base ** (torch.arange(0, self.dim, 2).float().to(x.device) / self.dim))self.register_buffer("inv_freq", inv_freq)t = torch.arange(self.max_seq_len_cached, device=x.device, dtype=self.inv_freq.dtype)if not self.ntk:# 缩放t *= self.max_position_embeddings / seq_len# 得到新的频率矩阵freqsfreqs = torch.einsum("i,j->ij", t, self.inv_freq)# Different from paper, but it uses a different permutation in order to obtain the same calculation# freqs与自身拼接得到新的embemb = torch.cat((freqs, freqs), dim=-1).to(x.device)# 注册为模型的缓冲区cos_cached和sin_cachedself.register_buffer("cos_cached", emb.cos()[None, None, :, :].to(x.dtype), persistent=False)self.register_buffer("sin_cached", emb.sin()[None, None, :, :].to(x.dtype), persistent=False)# 长度裁剪return (self.cos_cached[:, :, :seq_len, ...].to(dtype=x.dtype),self.sin_cached[:, :, :seq_len, ...].to(dtype=x.dtype),)

网友对于困惑度的实验并取得了一定的效果:https://github.com/turboderp/exllama/pull/118

总结

本文介绍了llama通过线性插值法及相关改进方案进行长度外推的trcik,并对相关源码阅读及网络资源进行记录,个人粗浅认为,相比LongLLaMA,基于线性插值法+Finetune的方式,是一种高性价比的长度外推方案。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.rhkb.cn/news/93646.html

如若内容造成侵权/违法违规/事实不符,请联系长河编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

基于Mysqlrouter+MHA+keepalived实现高可用半同步 MySQL Cluster项目

目录 项目名称: 基于Mysqlrouter MHA keepalived实现半同步主从复制MySQL Cluster MySQL Cluster: 项目架构图: 项目环境: 项目环境安装包: 项目描述: 项目IP地址规划: 项目步骤: 一…

利用Python隧道爬虫ip轻松构建全局爬虫网络

嘿,爬虫程序员们!你们有没有碰到过需要大规模数据爬取的情况?也许你们之前遇到过网站的反爬措施,卡住你们的进度。别担心,今天我来分享一个利用Python隧道爬虫ip实现的方法,帮助你们轻松搭建全局爬虫ip网络…

Springboot 实践(4)swagger-ui 测试controller

前文项目操作,完成了项目的创建、数据源的配置以及数据库DAO程序的生成与配置。此文讲解利用swagger-ui界面,测试生成的数据库DAO程序。目前,项目swagger-ui界面如下: 以”用户管理”为例,简单讲述swagger-ui测试数据库…

案例15 Spring Boot入门案例

1. 选择Spring Initializr快速构建项目 ​ 2. 设置项目信息 ​ 3. 选择依赖 ​ 4. 设置项目名称 ​ 5. 项目结构 ​ 6. 项目依赖 自动配置了Spring MVC、内置了Tomcat、配置了Logback(日志)、配置了JSON。 ​ 7. 创建HelloController类 com.wfit.boot.hello目录下创建HelloCo…

内网搭建电影网站的实现和进行公网访问

如何实现内网搭建电影网站并进行公网访问 文章目录 如何实现内网搭建电影网站并进行公网访问前言1. 把软件分别安装到本地电脑上1.1 打开PHPStudy软件,安装一系列电影网站所需的支持软件1.2 设置MacCNS10的运行环境1.3 进入电影网页的安装程序1.4 对运行环境进行检测…

ArcGIS Pro基础入门、制图、空间分析、影像分析、三维建模、空间统计分析与建模、python融合、案例全流程科研能力提升

目录 第一章 入门篇 GIS理论及ArcGIS Pro基础 第二章 基础篇 ArcGIS数据管理与转换 第三章 数据编辑与查询、拓扑检查 第四章 制图篇 地图符号与版面设计 第五章 空间分析篇 ArcGIS矢量空间分析及应用 第六章 ArcGIS栅格空间分析及应用 第七章 影像篇 遥感影像处理 第八…

回归预测 | MATLAB实现基于SSA-KELM-Adaboost麻雀算法优化核极限学习机结合AdaBoost多输入单输出回归预测

回归预测 | MATLAB实现基于SSA-KELM-Adaboost麻雀算法优化核极限学习机结合AdaBoost多输入单输出回归预测 目录 回归预测 | MATLAB实现基于SSA-KELM-Adaboost麻雀算法优化核极限学习机结合AdaBoost多输入单输出回归预测预测效果基本介绍模型描述程序设计参考资料 预测效果 基本…

jenkins一键部署github项目

个人目前理解jenkins部署分为两步: 构建项目,如生成jar自动执行sh脚本 如果没有jenkins,我们可能需要将jar移动到服务器,然后执行java -jar跑程序,jenkins可以替代我们执行这些东西,下面从0开始&#xff0…

Nginx反向代理技巧

跨域 作为一个前端开发者来说不可避免的问题就是跨域,那什么是跨域呢? 跨域:指的是浏览器不能执行其他网站的脚本。它是由浏览器的同源策略造成的,是浏览器对javascript施加的安全限制。浏览器的同源策略是指协议,域名…

web实现酷炫的canvas粒子动画背景

文章目录 前言一、particle-bg1. git地址:2. 安装3. 使用4. 完整demo 二、tsParticles1. 源码地址:2. 安装3. 引入4. 使用5. 几个例子5.1 ts粒子五彩纸屑烟花5.2 多粒子产卵器-用tsParticles制作5.3 ts粒子鼠标吸引力5.4 粒子烟花 源码地址完结 前言 粒…

【制作npm包5】npm包制作完整教程,我的第一个npm包

制作npm包目录 本文是系列文章, 作者一个橙子pro,本系列文章大纲如下。转载或者商业修改必须注明文章出处 一、申请npm账号、个人包和组织包区别 二、了解 package.json 相关配置 三、 了解 tsconfig.json 相关配置 四、 api-extractor 学习 五、npm包…

macOS(m1/m2)破解Sublime Text和Navicat16

破解Sublime Text 说明:全程使用的是终端操作 1. 下载Sublime Text,建议使用brew下载 2. 进入到下载的app的文件夹 cd "/Applications/Sublime Text.app/Contents/MacOS/"3. 执行以下操作以确认版本是否匹配 md5 -q sublime_text | grep -i…

Git 目录详解

一、Git目录详解 在使用Git时,有几个目录和文件在Git项目中扮演着重要的角色,下面详细介绍一下这些目录和文件的作用 1、.git目录 .git目录是Git项目的核心,包含了Git的版本库和元数据等重要信息。在该目录中,有一些重要的子目录和…

Maven之mirrorof范围

mirrorOf 是 central 还是 * 的问题 在配置阿里对官方中央仓库的镜像服务器时&#xff0c;我们使用到了 <mirror> 元素。 <mirror><id>aliyunmaven</id><mirrorOf>central</mirrorOf><name>阿里云公共仓库</name><url>…

idea格式化日志打印

Live Template 需要在Live Templates里面创建一个模板组为MyTemplate 触发时机选择java 1、创建一个loge log.error($content$,$params$); content groovyScript("def params _3.collect {【it {}】}.join(, ); return \" _1 . _2 () exception > (params…

CentOS8安装Git

错误1. 执行yum命令报错 【错误&#xff1a;Invalid configuration value: failovermethodpriority in /etc/yum.repos.d/CentOS-epel.repo; 配置&#xff1a;ID 为 "failovermethod" 的 OptionBinding 不存在】 1.cd /etc/yum.repos.d 2.vim CentOS-epel.repo //…

【uniapp】中 微信小程序实现echarts图表组件的封装

插件地址&#xff1a;echarts-for-uniapp - DCloud 插件市场 图例&#xff1a; 一、uniapp 安装 npm i uniapp-echarts --save 二、文件夹操作 将 node_modules 下的 uniapp-echarts 文件夹复制到 components 文件夹下 当前不操作此步骤的话&#xff0c;运行 -> 运行到小…

vue实现可缩放拖拽盒子(亲测可用)

特征 没有依赖 使用可拖动&#xff0c;可调整大小或两者兼备定义用于调整大小的句柄限制大小和移动到父元素或自定义选择器将元素捕捉到自定义网格将拖动限制为垂直或水平轴保持纵横比启用触控功能使用自己的样式为句柄提供自己的样式 安装和基本用法 npm install --save vue-d…

【QT】 Word模板编辑、转PDF格式

很高兴在雪易的CSDN遇见你 ,给你糖糖 欢迎大家加入雪易社区-CSDN社区云 前言 本文分享基于QT进行Word模板编辑以及Word转PDF的技术,希望对各位小伙伴有所帮助! 感谢各位小伙伴的点赞+关注,小易会继续努力分享,一起进步! 你的点赞就是我的动力(^U^)ノ~YO 目录 …

排污口水质的在线监测,实时掌握排口水质助力生态治理

水是生命之源&#xff0c;良好的水生态环境是社会发展的必然要求。然而随着工业化和城市化的发展&#xff0c;人类面临空气和水环境污染等严峻挑战&#xff0c;其中水环境问题尤为突出。排污成为城市和工业生产过程中不可避免的环保问题。 为加快解决生态环境突出问题&#xff…