【Datawhale AI 夏令营】第四期 基于2B源大模型 微调

定位:代码复现贴
教程:https://datawhaler.feishu.cn/wiki/PLCHwQ8pai12rEkPzDqcufWKnDd

模型加载

model = AutoModelForCausalLM.from_pretrained(path, device_map="auto", torch_dtype=torch.bfloat16, trust_remote_code=True
)
  • AutoModelForCausalLM.from_pretrained(path):

    • 这是 transformers 库中的一种通用方法,用于从预训练模型路径(path)加载一个因果语言模型(Causal Language Model,CLM)。
    • 因果语言模型是一种序列到序列的模型,通常用于生成任务,例如自动完成或文本生成。
  • device_map="auto":

    • 该参数用于自动选择计算设备(如 GPU 或 CPU)来加载模型。设置为 "auto" 后,模型会根据可用资源自动映射到适当的设备。
  • torch_dtype=torch.bfloat16:

    • 这将模型的计算精度设置为 bfloat16(一种 16 位浮点格式),这通常用于加速计算和减少显存占用,同时保持数值稳定性。
  • trust_remote_code=True:

    • 这个参数表示信任远程代码,允许加载自定义模型结构。如果预训练模型所在的路径中包含自定义的模型定义文件(而不是标准的 transformers 库模型),这个选项允许这些自定义代码被执行。

输出的模型如下:
在这里插入图片描述

模型结构分析

Yuan 在 Transformer 的 Decoder 进行改进,引入了一种新的注意力机制 Localized Filtering-based Attention(LFA)

在这里插入图片描述

  • YuanForCausalLM:

    • 这是一个自定义的因果语言模型类,可能来自于远程代码定义。该模型包含了实际的 YuanModel 和一个 lm_head(语言模型的输出头)。
  • YuanModel:

    • 该模型是 YuanForCausalLM 的核心部分,包含嵌入层、多个解码器层(YuanDecoderLayer)、和一个归一化层。
  • embed_tokens:

    • 这是词嵌入层,用于将输入的标记(tokens)转换为高维向量表示。这里的词表大小为 135040,每个标记被嵌入到一个 2048 维的向量空间中。
  • layers:

    • 这是模型的主体,由 24YuanDecoderLayer 组成,每个解码器层包含自注意力机制、MLP(多层感知器)层、和归一化层。
  • YuanAttention:

    • 这是一个自注意力机制模块,包含了查询(q_proj)、键(k_proj)、值(v_proj)的线性投影,以及一个旋转嵌入(rotary_emb)和本地过滤模块(lf_gate)。
  • YuanMLP:

    • 这是一个 MLP 层,包含了向上和向下的线性投影(up_projdown_proj),以及一个激活函数 SiLU
  • YuanRMSNorm:

    • 这是一个归一化层,使用 RMSNorm(Root Mean Square Layer Normalization)来稳定训练过程。
  • lm_head:

    • 这是模型的输出层,用于将解码器层的输出转换为预测的词概率分布。它是一个线性层,输入维度为 2048,输出维度为 135040(与词表大小一致)。

配置Lora

from peft import LoraConfig, TaskType, get_peft_modelconfig = LoraConfig(task_type=TaskType.CAUSAL_LM, target_modules=["q_proj", "k_proj", "v_proj", "o_proj", "gate_proj", "up_proj", "down_proj"],inference_mode=False, # 训练模式r=8, # Lora 秩lora_alpha=32, # Lora alaph,具体作用参见 Lora 原理lora_dropout=0.1# Dropout 比例
)

我们输出config,可以观测到其中的完整配置选项。

LoraConfig(peft_type=<PeftType.LORA: 'LORA'>, auto_mapping=None, base_model_name_or_path=None, revision=None, task_type=<TaskType.CAUSAL_LM: 'CAUSAL_LM'>, inference_mode=False, r=8, target_modules={'k_proj', 'down_proj', 'o_proj', 'up_proj', 'gate_proj', 'v_proj', 'q_proj'},lora_alpha=32, lora_dropout=0.1, fan_in_fan_out=False, bias='none', use_rslora=False, modules_to_save=None, init_lora_weights=True, layers_to_transform=None, layers_pattern=None, rank_pattern={}, alpha_pattern={}, megatron_config=None, megatron_core='megatron.core', loftq_config={}, use_dora=False, # <=== doralayer_replication=None, runtime_config=LoraRuntimeConfig(ephemeral_gpu_offload=False)) 

没想到后面还有一个use_dora的选项,碰巧之前浏览过这块,可以分享一下:

DoRA

首先对预训练模型的权重进行分解,将每个权重矩阵分解为幅度(magnitude)向量和方向(direction)矩阵

在微调过程中,DoRA使用LoRA进行方向性更新,只调整方向部分的参数,而保持幅度部分不变。这种方式可以减少需要调整的参数数量,提高微调的效率。

在这里插入图片描述

后面,我们构建一个 PeftModel并且查看对应的训练参数量占比:

# 构建PeftModel
model = get_peft_model(model, config)
model.print_trainable_parameters()

输出如下:

trainable params: 9,043,968 || all params: 2,097,768,448 || trainable%: 0.4311

总参数量为 2,097,768,448(~ 21亿参数),使用LoRA后只需要微调的参数量为 9,043,968(~904万参数),约占总参数量的0.4311%

但是后面微调还是爆了,所以稍微去除一点不太重要的微调目标模块(个人观点),但是肯定会损耗微调性能的。

config = LoraConfig(task_type=TaskType.CAUSAL_LM, target_modules=["q_proj", "k_proj", "v_proj"],inference_mode=False, # 训练模式r=4, # Lora 秩lora_alpha=32, # Lora alaph,具体作用参见 Lora 原理lora_dropout=0.1# Dropout 比例
)

后续输出微调的参数占比为:

trainable params: 2,359,296 || all params: 2,091,083,776 || trainable%: 0.1128

当然,也降低了批处理大小 (牺牲速度):

# 设置训练参数
args = TrainingArguments(output_dir="./output/Yuan2.0-2B_lora_bf16",per_device_train_batch_size=6, # <===== 12gradient_accumulation_steps=1,logging_steps=1,save_strategy="epoch",num_train_epochs=3,learning_rate=5e-5,save_on_each_node=True,gradient_checkpointing=True,bf16=True
)

微调成功之后效果如下,即便增加了一些其他信息,也能保持相关的抽取。

在这里插入图片描述
(但是多次几次依旧容易翻车,会输出极其符合数据集分布的答案。)

数据集中的组织名和姓名是互斥的,且中国难识别归类到国籍。

在这里插入图片描述

关于更多的微调知识,感觉可以参考这篇知乎大佬的笔记:https://zhuanlan.zhihu.com/p/696837567

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.rhkb.cn/news/406306.html

如若内容造成侵权/违法违规/事实不符,请联系长河编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

AI学习记录 - 如何快速构造一个简单的token词汇表

创作不易&#xff0c;有用的话点个赞 先直接贴代码&#xff0c;我们再慢慢分析&#xff0c;代码来自openai的图像分类模型的一小段 def bytes_to_unicode():"""Returns list of utf-8 byte and a corresponding list of unicode strings.The reversible bpe c…

压测工具哪个好?LoadRunner、Jmeter、Locust、Wrk 全方位对比....

当你想做性能测试的时候&#xff0c;你会选择什么样的测试工具呢&#xff1f;是会选择wrk&#xff1f;jmeter&#xff1f;locust&#xff1f;还是loadrunner呢&#xff1f;今天&#xff0c;笔者将根据自己使用经验&#xff0c;针对jmeter、locust、wrk和loadrunner常用的性能测…

前后端部署-服务器linux中部署Node.js环境

一.安装分布式版本管理系统Git (Alibaba Cloud Linux 3/2、CentOS 7.x) sudo yum install git -y 二.使用Git将NVM的源码克隆到本地的~/.nvm目录下&#xff0c;并检查最新版本。 git clone https://gitee.com/mirrors/nvm.git ~/.nvm && cd ~/.nvm && gi…

RVG29;狂犬病病毒肽;狂犬病病毒糖蛋白;115136-25-9;YTIWMPENPRPGTPCDIFTNSRGKRASNG

【RVG29 简介】 RVG29&#xff08;狂犬病病毒肽&#xff09;是一种由29个氨基酸组成的细胞穿透肽&#xff0c;它来源于狂犬病病毒糖蛋白。RVG肽能够特异性识别并结合中枢神经系统中普遍存在的烟碱型乙酰胆碱受体&#xff08;nAChR&#xff09;&#xff0c;并通过受体介导的转胞…

AR 眼镜之-系统应用音效-实现方案

目录 &#x1f4c2; 前言 AR 眼镜系统版本 系统应用音效 1. &#x1f531; 技术方案 1.1 技术方案概述 1.2 实现方案 1&#xff09;初始化 2&#xff09;播放音效 3&#xff09;释放资源 2. &#x1f4a0; 播放音效 2.1 静音不播放 2.2 获取音效默认音量 3. ⚛️ …

2.初识springcloud

文章目录 1.什么是SpringCloud1.1版本的介绍 2.Spring Cloud实现方案3.环境搭建4.服务拆分原则5.数据准备5.1订单服务5.2商品服务 大家好&#xff0c;我是晓星航。今天为大家带来的是 初识springcloud 相关的讲解&#xff01;&#x1f600; 1.什么是SpringCloud 简单来说&…

【算法基础实验】图论-最小生成树-Prim的即时实现

理论知识 Prim算法是一种用于计算加权无向图的最小生成树&#xff08;MST, Minimum Spanning Tree&#xff09;的贪心算法。最小生成树是一个连通的无向图的子图&#xff0c;它包含所有的顶点且总权重最小。Prim算法从一个起始顶点开始&#xff0c;不断将权重最小的边加入生成…

Excel表格添加趋势线_数据拟合

一个曲线通过补偿算法拟合为另一个曲线&#xff0c;通常可以通过多种数学和计算技术实现。这里也可以通过Excel表格添加趋势线&#xff0c;然后对趋势线进行拟合&#xff0c;得到趋势预测公式来达到数据补偿。 通过把你需要的数据导入到Excel表格中。 通过 “ 插入 ” --> “…

谷歌云AI新作:CROME,跨模态适配器高效多模态大语言模型

CROME: Cross-Modal Adapters for Efficient Multimodal LLM https://arxiv.org/pdf/2408.06610 Abstract 研究对象&#xff1a;Multimodal Large Language Models (MLLMs) demonstrate remarkable imagelanguage capabilities, but their widespread use faces challenges in…

论坛 推荐

畅议论坛&#xff1a;http://udbbs.top/http://udbbs.top/

查看U盘的具体信息,分区表格式、实际容量和分区状态

查看U盘的具体信息&#xff0c;分区表格式、实际容量和分区状态 前言&#xff1a; 利用windows自带的命令行窗口就可以 1、使用命令提示符查看MBR和GPT分区类型 &#xff08;1&#xff09;按“Windows R”键&#xff0c;在弹出的运行对话框中输入“diskpart”&#xff0c;并按…

游戏开发设计模式之工厂模式

目录 简单工厂模式&#xff08;Simple Factory Pattern&#xff09; 应用场景&#xff1a; 优缺点&#xff1a; 工厂方法模式&#xff08;Factory Method Pattern&#xff09; 应用场景&#xff1a; 优缺点&#xff1a; 抽象工厂模式&#xff08;Abstract Factory Patte…

碰撞检测 | 基于ROS Rviz插件的多边形碰撞检测仿真平台

目录 0 专栏介绍1 基于多边形的碰撞检测2 碰撞检测仿真平台搭建2.1 多边形实例2.2 外部服务接口2.3 Rviz插件化 3 案例演示3.1 功能介绍3.2 绘制多边形 0 专栏介绍 &#x1f525;课设、毕设、创新竞赛必备&#xff01;&#x1f525;本专栏涉及更高阶的运动规划算法轨迹优化实战…

Debian12安装tomcat8

jdk安装 安装Tomcat前需要先安装JDK&#xff0c;JDK安装参见&#xff1a; https://zhengshaoshaolin.blog.csdn.net/article/details/141407600 tomcat安装 1、下载安装 Apache Tomcat 访问官方 Apache Tomcat 下载页面&#xff0c;获取最新的二进制文件 或者使用如下的 wg…

Spring DI 数据类型—— set 方法注入

首先新建项目&#xff0c;可参考 初识IDEA、模拟三层--控制层、业务层和数据访问层 一、spring 环境搭建 &#xff08;一&#xff09;pom.xml 导相关坐标 <?xml version"1.0" encoding"UTF-8"?> <project xmlns"http://maven.apache.or…

【Win开发环境搭建】Redis与可视化工具详细安装与配置过程

&#x1f3af;导读&#xff1a;本文档提供了Redis的简介、安装指南、配置教程及常见操作方法。包括了安装包的选择与配置环境变量的过程&#xff0c;详细说明了如何通过修改配置文件来设置密码和端口等内容。同时&#xff0c;文档还介绍了如何使用命令行工具连接Redis&#xff…

ArcGIS如何将投影坐标系转回为地理坐标系

有时候两个数据&#xff0c;一个为投影坐标系&#xff0c;另一个为地理坐标系时&#xff0c;在GIS软件中位置无法叠加到一起&#xff0c;这需要将两个或多个数据的坐标系统一&#xff0c;可以直接将地理坐标系的数据进行投影&#xff0c;或将投影坐标系转为地理坐标系。下面介绍…

在使用Simulink进行FOC(Field-Oriented Control,场向量控制)仿真时,如果遇到波形丢失精度的问题,该这么解决

&#x1f3c6;本文收录于《CSDN问答解惑-专业版》专栏&#xff0c;主要记录项目实战过程中的Bug之前因后果及提供真实有效的解决方案&#xff0c;希望能够助你一臂之力&#xff0c;帮你早日登顶实现财富自由&#x1f680;&#xff1b;同时&#xff0c;欢迎大家关注&&收…

【Qt】输入类控件QLineEdit

目录 输入类控件QLineEdit 例子&#xff1a;录入个人信息 例子&#xff1a;使用正则表达式验证输入框的数据 例子&#xff1a;验证俩次输入密码一致 例子&#xff1a;切换显示代码 输入类控件QLineEdit QLineEdit 用来表示单行输入框&#xff0c;可以输入一段文本&#xf…

网络安全入门教程(非常详细)从零基础入门到精通_网路安全 教程

前言 1.入行网络安全这是一条坚持的道路&#xff0c;三分钟的热情可以放弃往下看了。2.多练多想&#xff0c;不要离开了教程什么都不会了&#xff0c;最好看完教程自己独立完成技术方面的开发。3.有时多百度&#xff0c;我们往往都遇不到好心的大神&#xff0c;谁会无聊天天给…