大模型增量预训练新技巧:解决灾难性遗忘

大家好,目前不少开源模型在通用领域具有不错的效果,但由于缺乏领域数据,往往在一些垂直领域中表现不理想,这时就需要增量预训练和微调等方法来提高模型的领域能力。

但在领域数据增量预训练或微调时,很容易出现灾难性遗忘现象,也就是学会了垂直领域知识,但忘记了通用领域知识,之前介绍过增量预训练以及领域大模型训练技巧。

今天给大家带来一篇增量预训练方法-Llama-Pro,对LLMs进行Transformer块扩展后,增量预训练过程中仅对新增块进行训练,有效地进行模型知识注入,并且极大程度地避免灾难性遗忘。

图片

LLaMA Pro: Progressive LLaMA with Block Expansion

LLaMA Pro: Progressive LLaMA with Block Expansion
Paper: https://arxiv.org/abs/2401.02415
Github: https://github.com/TencentARC/LLaMA-Pro

文章目录

    • 技术交流群
    • 用通俗易懂方式讲解系列
    • 块扩展方法
    • 实验细节
    • 讨论分析
    • 写在最后

技术交流群

前沿技术资讯、算法交流、求职内推、算法竞赛、面试交流(校招、社招、实习)等、与 10000+来自港科大、北大、清华、中科院、CMU、腾讯、百度等名校名企开发者互动交流~

我们建了大模型面试与技术交流群, 想要进交流群、需要源码&资料、提升技术的同学,可以直接加微信号:mlc2060。加的时候备注一下:研究方向 +学校/公司+CSDN,即可。然后就可以拉你进群了。

方式①、微信搜索公众号:机器学习社区,后台回复:加群
方式②、添加微信号:mlc2060,备注:技术交流

资料1
在这里插入图片描述

用通俗易懂方式讲解系列

  • 用通俗易懂的方式讲解:自然语言处理初学者指南(附1000页的PPT讲解)
  • 用通俗易懂的方式讲解:1.6万字全面掌握 BERT
  • 用通俗易懂的方式讲解:NLP 这样学习才是正确路线
  • 用通俗易懂的方式讲解:28张图全解深度学习知识!
  • 用通俗易懂的方式讲解:不用再找了,这就是 NLP 方向最全面试题库
  • 用通俗易懂的方式讲解:实体关系抽取入门教程
  • 用通俗易懂的方式讲解:灵魂 20 问帮你彻底搞定Transformer
  • 用通俗易懂的方式讲解:图解 Transformer 架构
  • 用通俗易懂的方式讲解:大模型算法面经指南(附答案)
  • 用通俗易懂的方式讲解:十分钟部署清华 ChatGLM-6B,实测效果超预期
  • 用通俗易懂的方式讲解:内容讲解+代码案例,轻松掌握大模型应用框架 LangChain
  • 用通俗易懂的方式讲解:如何用大语言模型构建一个知识问答系统
  • 用通俗易懂的方式讲解:最全的大模型 RAG 技术概览
  • 用通俗易懂的方式讲解:利用 LangChain 和 Neo4j 向量索引,构建一个RAG应用程序
  • 用通俗易懂的方式讲解:使用 Neo4j 和 LangChain 集成非结构化知识图增强 QA
  • 用通俗易懂的方式讲解:面了 5 家知名企业的NLP算法岗(大模型方向),被考倒了。。。。。
  • 用通俗易懂的方式讲解:NLP 算法实习岗,对我后续找工作太重要了!。
  • 用通俗易懂的方式讲解:理想汽车大模型算法工程师面试,被问的瑟瑟发抖。。。。
  • 用通俗易懂的方式讲解:基于 Langchain-Chatchat,我搭建了一个本地知识库问答系统
  • 面试了字节大模型算法岗(实习),快被问哭了。。。。

块扩展方法

块扩展,顾名思义,就是在原始模型中每个Transformer块或者某几个Transformer块后增加一个Transformer块,但为了保持扩展后的模型输出保持不变,需要增加的块为恒等块(输入输出相同),如下图所示。

图片

在构建恒等块过程中,主要是将多头注意力层和FFN层中的最后一个线性层(Linear)权重置为0变成Zero-Linear,即可保持经过该块的输入输出一致。

PS:论文附录A中写了大段的推导公式来证明,在此不做过多介绍。

块的增加方式是,对原始模型的 个Transformer块分成 组,每组中包含 个Transformer块,对于每组后添加 个恒等块。代码实现具体如下:

model = AutoModelForCausalLM.from_pretrained(model_path, torch_dtype=torch.float16)
ckpt = model.state_dict()# original_layers是模型原始层数,layers是模型最后达到层数
split = int(original_layers / (layers - original_layers))layer_cnt = 0output = {}
for i in range(original_layers):for k in ckpt:if ('layers.' + str(i) + '.') in k:output[k.replace(('layers.' + str(i) + '.'), ('layers.' + str(layer_cnt) + '.'))] = ckpt[k]layer_cnt += 1if (i+1) % split == 0:for k in ckpt:if ('layers.' + str(i) + '.') in k:if 'down_proj' in k or 'o_proj' in k:output[k.replace(('layers.' + str(i) + '.'), ('layers.' + str(layer_cnt) + '.'))] = torch.zeros_like(ckpt[k])else:output[k.replace(('layers.' + str(i) + '.'), ('layers.' + str(layer_cnt) + '.'))] = ckpt[k]layer_cnt += 1assert layer_cnt==layers
for k in ckpt:if not 'layers' in k:output[k] = ckpt[k]torch.save(output, output_path)

实验细节

数据由代码和数学组成,其中代码数据采用The-Stack-Dedup数据集中Python语言部分共22B Token,数学数据采用Proof-Pile-2数据集中AlgebraicStack、OpenWebMath和ArXiv部分共55B,详细如下表所示。

图片

数据分布

基础模型为LLaMA2-7B模型,通过块扩展方法将32层模型扩展到40层,其中 、 、 ,每个组从4个Transformer块扩展到5个Transformer块。

对于代码和数学数据进行增量预训练,批量大小为1024,序列最大长度为4096,预热比率为6%,学习率为2e-4,采用余弦学习率调度器,BF16混合精度训练,权重衰减为0.1。使用16个NVIDIA H800 GPU进行了15900个步骤的训练,大约耗费2830个GPU/小时。

在ARC、HellaSwag、MMLU、TruthfulQA、Winogrande、GSM8K、GSM8K-PoT、HumanEval、MBPP等多个评测数据集中进行评测,可以看出,在保持通用任务能力不下降的情况下,数学和代码能力较原始LLaMA2-7B模型有很大提升。

图片

图片

讨论分析

对比块扩展方法与正常训练和Lora方法之间的区别,采用TRACE基准利用总体性能(OP)和逆向转移(BWT)指标进行评估。,如下表所示,块扩展方法整体提升较大。

图片

对比块个数对块扩展方法的影响,进行了不同个数块的实验,并且对比了MoE的方法,训练损失如下,MoE方法的损失下降程度与添加四个块相当。

图片

在代码和法律(16.7B)领域数据下进行增量预训练,在通用任务以及领域任务上比较不同个数块之间的差异,同时比较扩展块全部添加到模型底部或顶部之间的差别,如下所示。可以发现块个数为8时效果最佳,并且不能直接将扩展块全部堆积在头部或尾部,需要分开插入。

图片

写在最后

该方法主要通过增加恒定块扩展模型层数,使模型在增量训练过程中仅训练新增层、冻结原始层,保持模型原有能力,防止模型出现灾难性遗忘现象。

但有两点存疑:

  • 目前来说mistral要好于llama,为啥不用mistral进行实验

  • 不用恒定块,性能会差多少

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.rhkb.cn/news/251026.html

如若内容造成侵权/违法违规/事实不符,请联系长河编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

【开源】基于JAVA+Vue+SpringBoot的河南软件客服系统

目录 一、摘要1.1 项目介绍1.2 项目录屏 二、功能模块2.1 系统管理人员2.2 业务操作人员 三、系统展示四、核心代码4.1 查询客户4.2 新增客户跟进情况4.3 查询客户历史4.4 新增服务派单4.5 新增客户服务费 五、免责说明 一、摘要 1.1 项目介绍 基于JAVAVueSpringBootMySQL的河…

css新手教程

css新手教程 课程:14、盒子模型及边框使用_哔哩哔哩_bilibili 一.什么是CSS 1.什么是CSS Cascading Style Sheet 层叠样式表。 CSS:表现(美化网页) 字体,颜色,边距,高度,宽度&am…

Node.js-1

Node.js 简介 定义:Node.js 是一个跨平台 JavaScript 运行环境,使开发者可以搭建服务器端的 JavaScript 应用程序 为什么 Node.js 能执行 JS 代码: Chrome 浏览器能执行 JS 代码,依靠的是内核中的 V8引擎(即&#x…

【Iot】什么是串口?什么是串口通信?串口通信(串口通讯)原理,常见的串口通信方式有哪些?

串口通信原理 1. 串口2. 串口通信4. 波特率与比特率5. 帧格式3. 串口通讯的通讯协议3.1. RS2323.2. RS485 总结 1. 串口 串行接口简称串口,也称串行通信接口或串行通讯接口(通常指COM接口),是采用串行通信方式的扩展接口。 串口可…

python二维高斯热力图绘制简单的思路代码

import numpy as np import matplotlib.pyplot as plt from scipy.ndimage import gaussian_filter import cv2# 生成一个示例图像 image_size 100 image np.zeros((image_size, image_size))# 在图像中心创建一个高亮区域 center_x, center_y image_size // 2, image_size …

[工具探索]Safari 和 Google Chrome 浏览器内核差异

最近有些Vue3的项目,使用了safari进行测试环境搞开发,发现页面存在不同程序的页面乱码情况,反而google浏览器没问题,下面我们就对比下他们之间的差异点: 日常开发google chrome占多数;现在主流浏览器 Goog…

【LeetCode: 292. Nim 游戏+ 博弈问题】

🚀 算法题 🚀 🌲 算法刷题专栏 | 面试必备算法 | 面试高频算法 🍀 🌲 越难的东西,越要努力坚持,因为它具有很高的价值,算法就是这样✨ 🌲 作者简介:硕风和炜,…

MagicVideo-V2:多阶段高保真视频生成框架

本项工作介绍了MagicVideo-V2,将文本到图像模型、视频运动生成器、参考图像embedding模块和帧内插模块集成到端到端的视频生成流程中。由于这些架构设计的好处,MagicVideo-V2能够生成具有极高保真度和流畅度的美观高分辨率视频。通过大规模用户评估&…

(bean配置类的注解开发)学习Spring的第十三天

bean配置类的注解开发 问题提出 用类充当配置文件 applicationcontext.xml : Configuration注解标识此类为配置类,替代原有xml文件 看原配置文件applicationcontext.xml代码 <?xml version"1.0" encoding"UTF-8"?> <beans xmlns"http:/…

cesium-水平测距

cesium测量两点间的距离 <template><div id"cesiumContainer" style"height: 100vh;"></div><div id"toolbar" style"position: fixed;top:20px;left:220px;"><el-breadcrumb><el-breadcrumb-item&…

Qt实现类似ToDesk顶层窗口 不规则按钮

先看效果&#xff1a; 在进行多进程开发时&#xff0c;可能会遇到需要进行全局弹窗的需求。 因为平时会使用ToDesk进行远程桌面控制&#xff0c;在电脑被控时&#xff0c;ToDesk会在右下角进行一个顶层窗口的提示&#xff0c;效果如下&#xff1a; 其实要实现顶层窗口&#xf…

【ELK】logstash快速入门

1.概述 1.1.什么是logstash&#xff1f; 之前我们聊了es&#xff0c;并且用docker搭建了一个eskibana的环境。es目前最普遍的用法是用来存储日志的&#xff0c;然后结合kibana对日志做一些可视化的工作。既然要收集日志&#xff0c;就面临着一个问题&#xff1a; 各个系统的…

如何保证MySQL和Redis中的数据一致性?

文章目录 前言一、缓存案例1.1 缓存常见用法1.2 缓存不一致产生的原因 二、解决方案2.1 先删除缓存&#xff0c;再更新数据库2.2 先更新数据库&#xff0c;删除缓存2.3 只更新缓存&#xff0c;由缓存自己同步更新数据库2.4 只更新缓存&#xff0c;由缓存自己异步更新数据库2.5 …

ElementUI 组件:Layout布局(el-row、el-col)

ElementUI安装与使用指南 Layout布局 点击下载learnelementuispringboot项目源码 效果图 el-row_el-col.vue页面效果图 项目里el-row_el-col.vue代码 <script> export default {name:el-row_el-col 布局 }</script><template><div class"roo…

C/C++ (stdio.h)标准库详解

cstdio,在C语言中称为stdio.h。该库使用所谓的流与物理设备&#xff08;如键盘、打印机、终端&#xff09;或系统支持的任何其他类型的文件一起操作。 在本文将会通过介绍函数参数&#xff0c;举出实际的简单例子来帮助大家快速上手使用函数。 目录 一、流 二、库函数 1、F…

离线数仓-数据治理

目录 一、前言 1.1 数据治理概念 1.2 数据治理目标 1.3 数据治理要解决的问题 1.3.1 合规性 元数据合规性 数据质量合规性 数据安全合规性 1.3.2 成本 存储资源成本 计算资源成本 二、数据仓库发展阶段 2.1 初始期 2.2 扩张期 2.3 缓慢发展期 2.4 变革期 三、…

LabVIEW电能质量监测系统

LabVIEW电能质量监测系统 随着全球能源需求的增加以及能源危机的加剧&#xff0c;对电能的有效利用和质量监控变得越来越重要。特别是在电力系统中&#xff0c;电能质量的监测对于保证电力设备的稳定运行和提高能源利用效率具有重要意义。采用LabVIEW软件开发了一套高效的电能…

查看自己电脑是arm还是x64(x86);linux操作系统识别

1、查看自己电脑是arm还是x64&#xff08;x86&#xff09; linux 参考&#xff1a; https://liuweiqing.blog.csdn.net/article/details/131783851 uname -a如果输出是 x86_64&#xff0c;那么你的系统是 64 位的 x86 架构&#xff08;通常我们称之为 x64&#xff09;。如果…

MySQL原理(五)事务

一、介绍&#xff1a; 1、介绍&#xff1a; 在计算机术语中&#xff0c;事务(Transaction)是访问并可能更新数据库中各种数据项的一个程序执行单元(unit)。事务是恢复和并发控制的基本单位。 2、事务的4大特性 原子性、一致性、隔离性、持久性。这四个属性通常称为ACID特性…

Zoho Projects与Jira:中国市场的理想替代品之争?

在软件开发生命周期中&#xff0c;项目管理一直是一个非常重要的环节。为了更好地协作、追踪项目的进程和管理任务&#xff0c;许多公司选择了Jira这款著名的项目管理工具&#xff0c;它是个非常强大的工具&#xff0c;但是作为一款纯国外产品&#xff0c;他可能不适合中国市场…