[大模型]LLaMA3-8B-Instruct Lora 微调

本节我们简要介绍如何基于 transformers、peft 等框架,对 LLaMA3-8B-Instruct 模型进行 Lora 微调。Lora 是一种高效微调方法,深入了解其原理可参见博客:知乎|深入浅出 Lora。

这个教程会在同目录下给大家提供一个 nodebook 文件,来让大家更好的学习。

环境准备

在 Autodl 平台中租赁一个 3090 等 24G 显存的显卡机器,如下图所示镜像选择 PyTorch-->2.1.0-->3.10(ubuntu22.04)-->12.1
接下来打开刚刚租用服务器的 JupyterLab,并且打开其中的终端开始环境配置、模型下载和运行演示。

在这里插入图片描述

环境配置

在完成基本环境配置和本地模型部署的情况下,你还需要安装一些第三方库,可以使用以下命令:

python -m pip install --upgrade pip
# 更换 pypi 源加速库的安装
pip config set global.index-url https://pypi.tuna.tsinghua.edu.cn/simplepip install modelscope==1.9.5
pip install "transformers>=4.40.0"
pip install streamlit==1.24.0
pip install sentencepiece==0.1.99
pip install accelerate==0.29.3
pip install datasets==2.19.0
pip install peft==0.10.0MAX_JOBS=8 pip install flash-attn --no-build-isolation

注意:flash-attn 安装会比较慢,大概需要十几分钟。

考虑到部分同学配置环境可能会遇到一些问题,我们在 AutoDL 平台准备了 LLaMA3 的环境镜像,该镜像适用于该仓库的所有部署环境。点击下方链接并直接创建 Autodl 示例即可。
https://www.codewithgpu.com/i/datawhalechina/self-llm/self-llm-LLaMA3

在本节教程里,我们将微调数据集放置在根目录 /dataset。

模型下载

使用 modelscope 中的 snapshot_download 函数下载模型,第一个参数为模型名称,参数 cache_dir 为模型的下载路径。

在 /root/autodl-tmp 路径下新建 model_download.py 文件并在其中输入以下内容,粘贴代码后记得保存文件,如下图所示。并运行 python /root/autodl-tmp/model_download.py 执行下载,模型大小为 15 GB,下载模型大概需要 2 分钟。

import torch
from modelscope import snapshot_download, AutoModel, AutoTokenizer
import osmodel_dir = snapshot_download('LLM-Research/Meta-Llama-3-8B-Instruct', cache_dir='/root/autodl-tmp', revision='master')

指令集构建

LLM 的微调一般指指令微调过程。所谓指令微调,是说我们使用的微调数据形如:

{"instruction": "回答以下用户问题,仅输出答案。","input": "1+1等于几?","output": "2"
}

其中,instruction 是用户指令,告知模型其需要完成的任务;input 是用户输入,是完成用户指令所必须的输入内容;output 是模型应该给出的输出。

即我们的核心训练目标是让模型具有理解并遵循用户指令的能力。因此,在指令集构建时,我们应针对我们的目标任务,针对性构建任务指令集。例如,在本节我们使用由笔者合作开源的 Chat-甄嬛 项目作为示例,我们的目标是构建一个能够模拟甄嬛对话风格的个性化 LLM,因此我们构造的指令形如:

{"instruction": "你是谁?","input": "","output": "家父是大理寺少卿甄远道。"
}

我们所构造的全部指令数据集在根目录下。

数据格式化

Lora 训练的数据是需要经过格式化、编码之后再输入给模型进行训练的,如果是熟悉 Pytorch 模型训练流程的同学会知道,我们一般需要将输入文本编码为 input_ids,将输出文本编码为 labels,编码之后的结果都是多维的向量。我们首先定义一个预处理函数,这个函数用于对每一个样本,编码其输入、输出文本并返回一个编码后的字典:

def process_func(example):MAX_LENGTH = 384    # Llama分词器会将一个中文字切分为多个token,因此需要放开一些最大长度,保证数据的完整性input_ids, attention_mask, labels = [], [], []instruction = tokenizer(f"<|start_header_id|>user<|end_header_id|>\n\n{example['instruction'] + example['input']}<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n\n", add_special_tokens=False)  # add_special_tokens 不在开头加 special_tokensresponse = tokenizer(f"{example['output']}<|eot_id|>", add_special_tokens=False)input_ids = instruction["input_ids"] + response["input_ids"] + [tokenizer.pad_token_id]attention_mask = instruction["attention_mask"] + response["attention_mask"] + [1]  # 因为eos token咱们也是要关注的所以 补充为1labels = [-100] * len(instruction["input_ids"]) + response["input_ids"] + [tokenizer.pad_token_id]if len(input_ids) > MAX_LENGTH:  # 做一个截断input_ids = input_ids[:MAX_LENGTH]attention_mask = attention_mask[:MAX_LENGTH]labels = labels[:MAX_LENGTH]return {"input_ids": input_ids,"attention_mask": attention_mask,"labels": labels}

Llama-3-8B-Instruct 采用的Prompt Template格式如下:

<|start_header_id|>system<|end_header_id|>
You are a helpful assistant<|eot_id|>'
<|start_header_id|>user<|end_header_id|>
你是谁?<|eot_id|>'
<|start_header_id|>assistant<|end_header_id|>
我是一个有用的助手。<|eot_id|>"

加载 tokenizer 和半精度模型

模型以半精度形式加载,如果你的显卡比较新的话,可以用torch.bfolat形式加载。对于自定义的模型一定要指定trust_remote_code参数为True

tokenizer = AutoTokenizer.from_pretrained('/root/autodl-tmp/LLM-Research/Meta-Llama-3-8B-Instruct', use_fast=False, trust_remote_code=True)model = AutoModelForCausalLM.from_pretrained('/root/autodl-tmp/LLM-Research/Meta-Llama-3-8B-Instruct', device_map="auto",torch_dtype=torch.bfloat16)

定义 LoraConfig

LoraConfig这个类中可以设置很多参数,但主要的参数没多少,简单讲一讲,感兴趣的同学可以直接看源码。

  • task_type:模型类型
  • target_modules:需要训练的模型层的名字,主要就是attention部分的层,不同的模型对应的层的名字不同,可以传入数组,也可以字符串,也可以正则表达式。
  • rlora的秩,具体可以看Lora原理
  • lora_alphaLora alaph,具体作用参见 Lora 原理

Lora的缩放是啥嘞?当然不是r(秩),这个缩放就是lora_alpha/r, 在这个LoraConfig中缩放就是 4 倍。

config = LoraConfig(task_type=TaskType.CAUSAL_LM,target_modules=["q_proj", "k_proj", "v_proj", "o_proj", "gate_proj", "up_proj", "down_proj"],inference_mode=False, # 训练模式r=8, # Lora 秩lora_alpha=32, # Lora alaph,具体作用参见 Lora 原理lora_dropout=0.1# Dropout 比例
)

自定义 TrainingArguments 参数

TrainingArguments这个类的源码也介绍了每个参数的具体作用,当然大家可以来自行探索,这里就简单说几个常用的。

  • output_dir:模型的输出路径
  • per_device_train_batch_size:顾名思义 batch_size
  • gradient_accumulation_steps: 梯度累加,如果你的显存比较小,那可以把 batch_size 设置小一点,梯度累加增大一些。
  • logging_steps:多少步,输出一次log
  • num_train_epochs:顾名思义 epoch
  • gradient_checkpointing:梯度检查,这个一旦开启,模型就必须执行model.enable_input_require_grads(),这个原理大家可以自行探索,这里就不细说了。
args = TrainingArguments(output_dir="./output/llama3",per_device_train_batch_size=4,gradient_accumulation_steps=4,logging_steps=10,num_train_epochs=3,save_steps=100,learning_rate=1e-4,save_on_each_node=True,gradient_checkpointing=True
)

使用 Trainer 训练

trainer = Trainer(model=model,args=args,train_dataset=tokenized_id,data_collator=DataCollatorForSeq2Seq(tokenizer=tokenizer, padding=True),
)
trainer.train()

保存 lora 权重

lora_path='./llama3_lora'
trainer.model.save_pretrained(lora_path)
tokenizer.save_pretrained(lora_path)

加载 lora 权重推理

训练好了之后可以使用如下方式加载lora权重进行推理:

from transformers import AutoModelForCausalLM, AutoTokenizer
import torch
from peft import PeftModelmode_path = '/root/autodl-tmp/LLM-Research/Meta-Llama-3-8B-Instruct'
lora_path = './llama3_lora' # lora权重路径# 加载tokenizer
tokenizer = AutoTokenizer.from_pretrained(mode_path)# 加载模型
model = AutoModelForCausalLM.from_pretrained(mode_path, device_map="auto",torch_dtype=torch.bfloat16)# 加载lora权重
model = PeftModel.from_pretrained(model, model_id=lora_path, config=config)prompt = "你是谁?"
messages = [# {"role": "system", "content": "现在你要扮演皇帝身边的女人--甄嬛"},{"role": "user", "content": prompt}
]text = tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)model_inputs = tokenizer([text], return_tensors="pt").to('cuda')generated_ids = model.generate(model_inputs.input_ids,max_new_tokens=512,do_sample=True,top_p=0.9, temperature=0.5, repetition_penalty=1.1,eos_token_id=tokenizer.encode('<|eot_id|>')[0],
)
generated_ids = [output_ids[len(input_ids):] for input_ids, output_ids in zip(model_inputs.input_ids, generated_ids)
]response = tokenizer.batch_decode(generated_ids, skip_special_tokens=True)[0]print(response)

完整微调源码:
https://github.com/datawhalechina/self-llm/blob/master/LLaMA3/LLaMA3-8B-Instruct%20Lora.ipynb

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.rhkb.cn/news/347311.html

如若内容造成侵权/违法违规/事实不符,请联系长河编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

服务器如何远程桌面连接不上,服务器远程桌面连接不上解决办法

服务器远程桌面连接不上&#xff0c;是IT运维中常见的挑战之一。针对这一问题&#xff0c;专业的解决方法通常涉及以下几个方面的排查与操作&#xff1a; 首先&#xff0c;我们需要检查网络连接是否正常。远程桌面连接依赖于稳定的网络连接&#xff0c;因此&#xff0c;确认服务…

【云原生】创建harbor私有仓库及使用aliyun个人仓库

1.安装docker #删除已有dockersystemctl stop docker yum remove docker \docker-client \docker-client-latest \docker-common \docker-latest \docker-latest-logrotate \docker-logrotate \docker-engine #安装docker yum install -y docker-ce-20.10.1…

企业级开源项目,云缓存解决方案:CacheCloud

CacheCloud&#xff1a;简化缓存管理&#xff0c;释放数据潜力- 精选真开源&#xff0c;释放新价值。 概览 CacheCloud是由搜狐视频团队开发的一款开源的Redis缓存云平台&#xff0c;支持Redis多种架构(Standalone、Sentinel、Cluster)高效管理、有效降低大规模redis运维成本&…

8.transformers量化

Transformers 核心设计Auto Classes Transformers Auto Classes 设计:统一接口、自动检索 AutoClasses 旨在通过全局统一的接口 from_pretrained() ,实现基于名称(路径)自动检索预训练权重(模 型)、配置文件、词汇表等所有与模型相关的抽象。 灵活扩展的配置AutoConfig…

webshell三巨头 综合分析(蚁剑,冰蝎,哥斯拉)

考点: 蚁剑,冰蝎,哥斯拉流量解密 存在3个shell 过滤器 http.request.full_uri contains "shell1.php" or http.response_for.uri contains "shell1.php" POST请求存在明文传输 ant 一般蚁剑执行命令 用垃圾字符在最开头填充 去掉垃圾字符直到可以正常bas…

Bankless:为什么 AI 需要 Crypto 的技术?

原文标题&#xff1a;《Why AI Needs Crypto’s Values》 撰文&#xff1a;Arjun Chand&#xff0c;Bankless 编译&#xff1a;Chris&#xff0c;Techub News 原文来自香港Web3媒体&#xff1a;Techub News 人工智能革命的梦想一直是一把双刃剑。 释放人工智能的潜力可以解…

世优科技AI数字人多模态交互系统“世优波塔”正式发布

2024年6月6日&#xff0c;世优科技“波塔发布会”在北京举办&#xff0c;本次发布会上&#xff0c;世优科技以全新的“波塔”产品诠释了更高效、更智能、更全面的AI数字人产品及软硬件全场景解决方案&#xff0c;实现了世优品牌、产品和价值的全面跃迁。来自行业协会、数字产业…

AIGC简介

目录 1.概述 2.诞生背景 3.作用 4.优缺点 4.1.优点 4.2.缺点 5.应用场景 5.1.十个应用场景 5.2.社交媒体内容 6.如何使用 7.未来展望 8.总结 1.概述 AIGC 是“人工智能生成内容”&#xff08;Artificial Intelligence Generated Content&#xff09;的缩写&#x…

【计算机网络】P3 计算机网络协议、接口、服务的概念、区别以及计算机网络提供的三种服务方式

目录 协议什么是协议协议是水平存活的协议的组成 接口服务服务是什么服务原语 协议与服务的区别计算机网络提供的服务的三种方式面向连接服务与无连接服务可靠服务与不可靠服务有应答服务与无应答服务 协议 什么是协议 协议&#xff0c;就是规则的集合。 在计算机网络中&…

33 _ 跨站脚本攻击(XSS):为什么Cookie中有HttpOnly属性?

通过上篇文章的介绍&#xff0c;我们知道了同源策略可以隔离各个站点之间的DOM交互、页面数据和网络通信&#xff0c;虽然严格的同源策略会带来更多的安全&#xff0c;但是也束缚了Web。这就需要在安全和自由之间找到一个平衡点&#xff0c;所以我们默认页面中可以引用任意第三…

ARM32开发--PWM高级定时器

目录 文章目录 前言 目标 学习内容 需求 高级定时器通道互补输出 开发流程 通道配置 打开互补保护电路 完整代码 练习题 总结 前言 在嵌入式软件开发中&#xff0c;PWM&#xff08;脉冲宽度调制&#xff09;技术被广泛应用于控制各种电子设备的亮度、速度等参数。…

大疆智图_空三二维重建成果传输

一、软件环境 1.1 所需软件 1、 大疆智图&#xff1a;点击下载&#xff1b;   2、 ArcGIS Pro 3.1.5&#xff1a;点击下载&#xff0c;建议使用IDM或Aria2等多线程下载器&#xff1b;   3、 IDM下载器&#xff1a;点击下载&#xff0c;或自行搜索&#xff1b;   4、 Fas…

攻防演练之-网络安全产品大巡礼二

书接上文&#xff0c;《网络安全攻防演练风云》专栏之攻防演练之-网络安全产品大巡礼一&#xff0c;这里。 “咱们中场休息一会&#xff0c;我去接杯水哈”&#xff0c;看着认真听讲的众人&#xff0c;王工很是满意&#xff0c;经常夹在甲乙两方受气的他&#xff0c;这次终于表…

VBA即用型代码手册:删除空列Delete Empty Columns

我给VBA下的定义&#xff1a;VBA是个人小型自动化处理的有效工具。可以大大提高自己的劳动效率&#xff0c;而且可以提高数据的准确性。我这里专注VBA,将我多年的经验汇集在VBA系列九套教程中。 作为我的学员要利用我的积木编程思想&#xff0c;积木编程最重要的是积木如何搭建…

面试题:ArrayList和LinkedList的区别

ArrayList和LinkedList都是Java中实现List接口的集合类&#xff0c;用于存储和操作对象列表&#xff0c;但它们在内部数据结构、性能特性和适用场景上有所不同&#xff1a; 1.内部数据结构&#xff1a; ArrayList&#xff1a;基于动态数组实现。这意味着它在内存中是连续存储…

鸿蒙元服务未来是能一“通”多端的前端形态?

2024年&#xff0c;华为鸿蒙的热度只增不减。 在2023年底就有业内人士透露&#xff0c;华为明年将推出不兼容安卓的鸿蒙版本&#xff0c;未来IOS、鸿蒙、安卓将成为三个各自独立的系统。 果不其然&#xff0c;执行力超强的华为&#xff0c;与2024年1月18日的开发者&#xff0…

web刷题记录(5)

[羊城杯 2020]easycon 进来以后就是一个默认测试页面&#xff0c; 在这种默认界面里&#xff0c;我觉得一般不会有什么注入点之类的&#xff0c;所以这里先选择用御剑扫扫目录看看有没有什么存在关键信息的页面 扫了一半发现&#xff0c;很多都是和index.php文件有关&#xff0…

C# Winform内嵌窗体(在主窗体上显示子窗体)

在开发Winform项目中&#xff0c;经常会要切换不同的窗体。通常程序都有一个主窗体&#xff0c;在切换窗体时往往需要关闭其他子窗体&#xff0c;这个实例就来介绍MDI主窗体内嵌子窗体的实现方法。 MDI主窗体要设置一个比较重要的属性&#xff0c;IsMdiContainertrue。子窗体的…

VRRP多备份组(华为)

#交换设备 VRRP多备份组 当 VRRP 配置为单备份组时&#xff0c;业务全部由 Master 设备承担&#xff0c;而 Backup 设备完全处于空闲状态&#xff0c;没有得到充分利用。VRRP 可以通过配置多备份组来实现负载分担&#xff0c;有效地解决了这一问题。 VRRP 允许同一台设备的…

ClickHouse内幕(1)数据存储与过滤机制

本文主要讲述ClickHouse中的数据存储结构&#xff0c;包括文件组织结构和索引结构&#xff0c;以及建立在其基础上的数据过滤机制&#xff0c;从Part裁剪到Mark裁剪&#xff0c;最后到基于SIMD的行过滤机制。 数据过滤机制实质上是构建在数据存储格式之上的算法&#xff0c;所…