QLoRA代码实战

QLoRA原理参考:
BiliBili:4bit量化与QLoRA模型训练
zhihu:QLoRA(Quantized LoRA)详解

下载llama3-8b模型

from modelscope import snapshot_download
model_dir = snapshot_download('LLM-Research/Meta-Llama-3-8B-Instruct')

设置quantization_config

from transformers import BitsAndBytesConfigquantization_config = BitsAndBytesConfig(load_in_4bit=True,bnb_4bit_quant_type="nf4",bnb_4bit_use_double_quant=True,bnb_4bit_compute_dtype=torch.bfloat16,
)

加载模型

加载量化后的llama3-8b模型,大概需要6G的GPU显存。

from transformers import AutoModelForCausalLM,AutoTokenizer,TrainingArguments,Trainer,DataCollatorForSeq2Seq
model = AutoModelForCausalLM.from_pretrained(model_dir,quantization_config=quantization_config,low_cpu_mem_usage=True)
tokenizer = AutoTokenizer.from_pretrained(model_dir)

一层的数据类型,可以看到除了layernorm,linear层都进行了量化。

model.layers.0.self_attn.q_proj.weight torch.uint8
model.layers.0.self_attn.k_proj.weight torch.uint8
model.layers.0.self_attn.v_proj.weight torch.uint8
model.layers.0.self_attn.o_proj.weight torch.uint8
model.layers.0.mlp.gate_proj.weight torch.uint8
model.layers.0.mlp.up_proj.weight torch.uint8
model.layers.0.mlp.down_proj.weight torch.uint8
model.layers.0.input_layernorm.weight torch.float16
model.layers.0.post_attention_layernorm.weight torch.float16

预处理模型

from peft import prepare_model_for_kbit_training
model = prepare_model_for_kbit_training(model)

设置LoRA参数

这里使用了默认设置,参数target_modules和modules_to_save可以设置具体训练哪些模块。
在peft/utils/constants.py中,默认定义了各种模型的LoRA target modules,llama模型对Q和V进行lora。

"llama": ["q_proj", "v_proj"],
config = LoraConfig(task_type=TaskType.CAUSAL_LM)
model = get_peft_model(model, config)
model.print_trainable_parameters()
#trainable params: 3,407,872 || all params: 8,033,669,120 || trainable%: 0.0424
print(model) #加入了LoRA后的模型结构。

加载并处理数据

数据下载:AI-ModelScope/alpaca-gpt4-data-zh
需要把下载的数据中dataset_infos.json 重命名为datasets_info.json,这样才能正确加载。

from datasets import load_datasetdataset = load_dataset("alpaca-data-zh")def process_func(example):# print(example)MAX_LENGTH = 256input_ids, attention_mask, labels = [], [], []# 将prompt进行tokenize,这里我们没有利用tokenizer进行填充和截断# 这里我们自己进行截断,在DataLoader的collate_fn函数中进行填充input = example["input"] if example["input"] is not None else ''instruction = tokenizer("\n".join(["Human: " + example["instruction"], input]).strip() + "\n\nAssistant: ")# 将output进行tokenize,注意添加eos_tokenresponse = tokenizer(example["output"] + tokenizer.eos_token)# 将instruction + output组合为inputinput_ids = instruction["input_ids"] + response["input_ids"]attention_mask = instruction["attention_mask"] + response["attention_mask"]# prompt设置为-100,不计算losslabels = [-100] * len(instruction["input_ids"]) + response["input_ids"]# 设置最大长度,进行截断if len(input_ids) > MAX_LENGTH:input_ids = input_ids[:MAX_LENGTH]attention_mask = attention_mask[:MAX_LENGTH]labels = labels[:MAX_LENGTH]return {"input_ids": input_ids,"attention_mask": attention_mask,"labels": labels}tokenized_ds = dataset['train'].map(process_func, remove_columns=dataset['train'].column_names)

设置TrainingArguments

在per_device_train_batch_size=1的情况下,大概需要9G显存。

args = TrainingArguments(output_dir="./llama3_4bit",per_device_train_batch_size=4,gradient_accumulation_steps=32,logging_steps=10,num_train_epochs=1,save_strategy='epoch',learning_rate=1e-4,# gradient_checkpointing=True,# optim="paged_adamw_32bit")

训练

trainer = Trainer(model=model,args=args,tokenizer=tokenizer,train_dataset=tokenized_ds,data_collator=DataCollatorForSeq2Seq(tokenizer=tokenizer, padding=True),
)
trainer.train(resume_from_checkpoint=False)

加载qlora

from transformers import AutoModelForCausalLM,AutoTokenizer
model_path = model_dir #llama3-8b的路径
model = AutoModelForCausalLM.from_pretrained(model_path,quantization_config=config,low_cpu_mem_usage=True)
tokenizer = AutoTokenizer.from_pretrained(model_path)
model_qlora = PeftModel.from_pretrained(model=model,model_id="llama3_4bit/checkpoint-7") #qlora路径
#预测
ipt = tokenizer("Human: {}\n{}".format("怎么学习llm", "").strip() + "\n\nAssistant: ", return_tensors="pt").to(model.device)
tokenizer.decode(model_qlora.generate(**ipt, max_length=128, do_sample=True)[0], skip_special_tokens=True)

合并LoRA

合并后的模型大概5.4G。

merge_model = model_qlora.merge_and_unload()
merge_model.save_pretrained("llama3")

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.rhkb.cn/news/443303.html

如若内容造成侵权/违法违规/事实不符,请联系长河编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

【Linux】man手册安装使用

目录 man(manual,手册) 手册安装: 章节区分: 指令参数: 使用场景: 手册内容列表: 手册查看快捷键: 实例: 仍致谢:Linux常用命令大全(手册) – 真正好用的Linux命令在线查询网站 提供的命令查询 在开头先提醒一下:在 man 手册中退出的方法很简单…

给Windows系统设置代理的操作方法

一、什么是代理 网络代理是一种特殊的网络服务,允许一个网络终端通过这个服务与另一个网络终端进行非直接的连接,而提供代理服务的电脑系统或其它类型的网络终端被称为代理服务器。 代理服务器是网络信息的中转站,代理服务器就像是一个很大的…

map和set(c++)

前言 在前面我们在介绍二叉搜索树时我们分别实现了一个key结构和key-val结构,如果我们再进一步完善这棵树,将二叉搜索树升级为红黑树去存储key和key-val那么我们就可以得到我们今天要介绍的主角map和set。当然了标准库的实现还是有很多需要注意的地方&a…

玩机搞机基本常识-----如何在 Android 中实现默认开启某个功能 修改方法列举

我们有时候需要对安卓系统进行修改。实现其中的某些功能。让用户使用得心应手。节约时间。那么如果要实现系统中的有些功能选项开启或者关闭。就需要对系统有一定的了解。那么在 Android 中实现默认开启某个功能可以通过以下几种方式: 一、在应用的设置中添加选项 …

C语言练习

题目: 1.如果在int型变量的声明中为变量赋一个实数值(如3.12或4.6)的初始值会怎样呢?请打一段代码来看看 分析:……不用分析,开个玩笑,虽然很简单但是还是按照惯例水上一波数字 1.首先按照题目要求用函数类型int整型给变量赋值…

鸿蒙网络管理模块05——数据流量统计

如果你也对鸿蒙开发感兴趣,加入“Harmony自习室”吧!扫描下方名片,关注公众号,公众号更新更快,同时也有更多学习资料和技术讨论群。 1、概述 HarmonyOS供了基于物理网络的数据流量统计能力,支持基于网卡/U…

【PS2020】Adobe Photoshop 2020 中文免费版

photoshop 2020是全球最大的图像处理软件,为用户提供了广泛的专业级润饰工具套件,集成了专为激发灵感而设计的强大编辑功能,帮助用户制作出满意的图片效果,是很多摄影师、广告师等专业人员必备的一款图像及照片后期处理大型专业软…

网络受限情况下安装openpyxl模块提示缺少Jdcal,et_xmlfile

1.工作需要处理关于Excel文件内容的东西 2.用公司提供的openpyxl模块总是提示缺少jdcal文件,因为网络管控,又没办法直接使用命令下载,所以网上找了资源,下载好后上传到个人资源里了 资源路径 openpyxl jdcal et_xmlfile 以上模块来源于:Py…

Java-进阶二

单列集合: ----------List ArrayList的源代码分析(扩容原理) 1 使用空参构造的集合,在底层创建一个容量为0的数组。2 添加第一个元素时,底层会扩容创建一个容量为10的数组。3 存满时会扩容1.5倍。4 如果一次添加多个…

大模型基础:基本概念、Prompt、RAG、Agent及多模态

随着大模型的迅猛发展,LLM 作为人工智能的核心力量,正以前所未有的方式重塑着我们的生活、学习和工作。无论是智能语音助手、自动驾驶汽车,还是智能决策系统,大模型都是幕后英雄,让这些看似不可思议的事情变为可能。本…

Redis SpringBoot项目学习

Redis 是一个高性能的key-value内存数据库。它支持常用的5种数据结构:String字符串、Hash哈希表、List列表、Set集合、Zset有序集合 等数据类型。 Redis它解决了2个问题: 第一个是:性能 通常数据库的读操作,一般都要几十毫秒&…

虚拟机没有网络怎么解决

CentOS7为例 进入虚拟网络编辑器 1.更改设置 2.选中NAT模式点击3点击移除网络 4添加网络,随便选一个 5.点开NAT设置,记住网关 6.DHCP设置,注意虚拟机设置ip必须在起始ip和结束ip范围内 进入虚拟机网络适配器,自定义选中第4步操作…

【Kubernetes】常见面试题汇总(五十二)

目录 116. K8S 集群服务暴露失败? 117.外网无法访问 K8S 集群提供的服务? 特别说明: 题目 1-68 属于【Kubernetes】的常规概念题,即 “ 汇总(一)~(二十二)” 。 题目 69-…

torchvision.transforms.Resize()的用法

今天我在使用torchvision.transforms.Resize()的时候发现,一般Resize中放的是size或者是(size,size)这样的二元数。 这两个里面,torchvision.transforms.Resize((size,size)),大家都很清楚,会将图像的h和w大小都变成size。 但是…

大学生课程设计报告--基于JavaGUI的贪吃蛇

前言 ​ 贪吃蛇游戏是一个基础且经典的视频游戏,它适合作为学习编程的人进行一些更深入的学习,可以更加了解关于循环,函数的使用,以及面向对象是如何应用到实际项目中的; ​ 不仅如此,贪吃蛇游戏的规则在思考后可以拆分,有利于学生将更多精力去设计游戏的核心逻辑,而…

TM1618控制共阳极数码管的数据传送问题

数据传送中的问题 首先每个字节是按照一个地址写入的,而共阳极数码管的公共端是SEG引脚连接的。这使得数码管显示的编码是按照竖向的字节。如下图所示中,横向是公共端,竖向是实际编码字符字节。 数据转换方式 这样可以一次写入所有需要显示…

腾讯云SDK项目管理

音视频终端 SDK(腾讯云视立方)控制台提供项目管理功能,您可参照以下步骤为您的应用快速添加音视频通话能力和多人音视频互动能力。 若需正式开发并上线音视频应用,请在完成创建后,参照 集成指南 进行开发包下载、集成…

yolov11人物背景扣除

有时候我们需要对图片进行背景扣除和替换,本文将基于yolov11对一张图片进行背景扣除,对视频的处理同理。 安装 pip install ultralytics 2 、获取测试图片 3、代码 from ultralytics import YOLO import cv2 import nu

【概率论】泊松分布

泊松分布 若 ,则 归一性 例子 泊松分布多出现在当X表示一定时间或一定空间内出现的事件的个数这种场合,如在一定时间内某交通路口所发生的事故的个数。 将泊松分布假设为二项分布 假设条件: (1)泊松分布一般为一段时间或一…

ChatGPT:引领人工智能新潮流!

一、ChatGPT 是什么? 1. ChatGPT 的强大功能和广泛应用。 ChatGPT 作为一款先进的 AI 语言模型,拥有众多强大功能。它可以进行文本生成、文本分类、情感分析、机器翻译等多种自然语言处理任务。同时,ChatGPT 还能进行对话式交互,…