在自定义数据集上微调Alpaca和LLaMA

本文将介绍使用LoRa在本地机器上微调Alpaca和LLaMA,我们将介绍在特定数据集上对Alpaca LoRa进行微调的整个过程,本文将涵盖数据处理、模型训练和使用流行的自然语言处理库(如Transformers和hugs Face)进行评估。此外还将介绍如何使用grado应用程序部署和测试模型。

配置

首先,alpaca-lora1 GitHub存储库提供了一个脚本(finetune.py)来训练模型。在本文中,我们将利用这些代码并使其在Google Colab环境中无缝地工作。

首先安装必要的依赖:

 !pip install -U pip!pip install accelerate==0.18.0!pip install appdirs==1.4.4!pip install bitsandbytes==0.37.2!pip install datasets==2.10.1!pip install fire==0.5.0!pip install git+https://github.com/huggingface/peft.git!pip install git+https://github.com/huggingface/transformers.git!pip install torch==2.0.0!pip install sentencepiece==0.1.97!pip install tensorboardX==2.6!pip install gradio==3.23.0

安装完依赖项后,继续导入所有必要的库,并为matplotlib绘图配置设置:

 import transformersimport textwrapfrom transformers import LlamaTokenizer, LlamaForCausalLMimport osimport sysfrom typing import Listfrom peft import (LoraConfig,get_peft_model,get_peft_model_state_dict,prepare_model_for_int8_training,)import fireimport torchfrom datasets import load_datasetimport pandas as pdimport matplotlib.pyplot as pltimport matplotlib as mplimport seaborn as snsfrom pylab import rcParams%matplotlib inlinesns.set(rc={'figure.figsize':(10, 7)})sns.set(rc={'figure.dpi':100})sns.set(style='white', palette='muted', font_scale=1.2)DEVICE = "cuda" if torch.cuda.is_available() else "cpu"DEVICE

数据

我们这里使用BTC Tweets Sentiment dataset4,该数据可在Kaggle上获得,包含大约50,000条与比特币相关的tweet。为了清理数据,删除了所有以“转发”开头或包含链接的推文。

使用Pandas来加载CSV:

 df = pd.read_csv("bitcoin-sentiment-tweets.csv")df.head()

通过清理的数据集有大约1900条推文。

情绪标签用数字表示,其中-1表示消极情绪,0表示中性情绪,1表示积极情绪。让我们看看它们的分布:

 df.sentiment.value_counts()# 0.0    860# 1.0    779# -1.0    258# Name: sentiment, dtype: int64

数据量差不多,虽然负面评论较少,但是可以简单的当成平衡数据来对待:

 df.sentiment.value_counts().plot(kind='bar');

构建JSON数据集

原始Alpaca存储库中的dataset5格式由一个JSON文件组成,该文件具有具有指令、输入和输出字符串的对象列表。

让我们将Pandas的DF转换为一个JSON文件,该文件遵循原始Alpaca存储库中的格式:

 def sentiment_score_to_name(score: float):if score > 0:return "Positive"elif score < 0:return "Negative"return "Neutral"dataset_data = [{"instruction": "Detect the sentiment of the tweet.","input": row_dict["tweet"],"output": sentiment_score_to_name(row_dict["sentiment"])}for row_dict in df.to_dict(orient="records")]dataset_data[0]

结果如下:

 {"instruction": "Detect the sentiment of the tweet.","input": "@p0nd3ea Bitcoin wasn't built to live on exchanges.","output": "Positive"}

然后就是保存生成的JSON文件,以便稍后使用它来训练模型:

 import jsonwith open("alpaca-bitcoin-sentiment-dataset.json", "w") as f:json.dump(dataset_data, f)

模型权重

虽然原始的Llama模型权重不可用,但它们被泄露并随后被改编用于HuggingFace Transformers库。我们将使用decapoda-research6:

 BASE_MODEL = "decapoda-research/llama-7b-hf"model = LlamaForCausalLM.from_pretrained(BASE_MODEL,load_in_8bit=True,torch_dtype=torch.float16,device_map="auto",)tokenizer = LlamaTokenizer.from_pretrained(BASE_MODEL)tokenizer.pad_token_id = (0  # unk. we want this to be different from the eos token)tokenizer.padding_side = "left"

这段代码使用来自Transformers库的LlamaForCausalLM类加载预训练的Llama 模型。load_in_8bit=True参数使用8位量化加载模型,以减少内存使用并提高推理速度。

代码还使用LlamaTokenizer类为同一个Llama模型加载标记器,并为填充标记设置一些附加属性。具体来说,它将pad_token_id设置为0以表示未知的令牌,并将padding_side设置为“left”以填充左侧的序列。

数据集加载

现在我们已经加载了模型和标记器,下一步就是加载之前保存的JSON文件,使用HuggingFace数据集库中的load_dataset()函数:

 data = load_dataset("json", data_files="alpaca-bitcoin-sentiment-dataset.json")data["train"]

结果如下:

 Dataset({features: ['instruction', 'input', 'output'],num_rows: 1897})

接下来,我们需要从加载的数据集中创建提示并标记它们:

 def generate_prompt(data_point):return f"""Below is an instruction that describes a task, paired with an input that provides further context. Write a response that appropriately completes the request.  # noqa: E501### Instruction:{data_point["instruction"]}### Input:{data_point["input"]}### Response:{data_point["output"]}"""def tokenize(prompt, add_eos_token=True):result = tokenizer(prompt,truncation=True,max_length=CUTOFF_LEN,padding=False,return_tensors=None,)if (result["input_ids"][-1] != tokenizer.eos_token_idand len(result["input_ids"]) < CUTOFF_LENand add_eos_token):result["input_ids"].append(tokenizer.eos_token_id)result["attention_mask"].append(1)result["labels"] = result["input_ids"].copy()return resultdef generate_and_tokenize_prompt(data_point):full_prompt = generate_prompt(data_point)tokenized_full_prompt = tokenize(full_prompt)return tokenized_full_prompt

第一个函数generate_prompt从数据集中获取一个数据点,并通过组合指令、输入和输出值来生成提示。第二个函数tokenize接收生成的提示,并使用前面定义的标记器对其进行标记。它还向输入序列添加序列结束标记,并将标签设置为与输入序列相同。第三个函数generate_and_tokenize_prompt结合了前两个函数,生成并标记提示。

数据准备的最后一步是将数据集分成单独的训练集和验证集:

 train_val = data["train"].train_test_split(test_size=200, shuffle=True, seed=42)train_data = (train_val["train"].map(generate_and_tokenize_prompt))val_data = (train_val["test"].map(generate_and_tokenize_prompt))

我们还需要数据进行打乱,并且获取200个样本作为验证集。generate_and_tokenize_prompt()函数应用于训练和验证集中的每个示例,生成标记化的提示。

训练

训练过程需要几个参数,这些参数主要来自原始存储库中的微调脚本:

 LORA_R = 8LORA_ALPHA = 16LORA_DROPOUT= 0.05LORA_TARGET_MODULES = ["q_proj","v_proj",]BATCH_SIZE = 128MICRO_BATCH_SIZE = 4GRADIENT_ACCUMULATION_STEPS = BATCH_SIZE // MICRO_BATCH_SIZELEARNING_RATE = 3e-4TRAIN_STEPS = 300OUTPUT_DIR = "experiments"

下面就可以为训练准备模型了:

 model = prepare_model_for_int8_training(model)config = LoraConfig(r=LORA_R,lora_alpha=LORA_ALPHA,target_modules=LORA_TARGET_MODULES,lora_dropout=LORA_DROPOUT,bias="none",task_type="CAUSAL_LM",)model = get_peft_model(model, config)model.print_trainable_parameters()#trainable params: 4194304 || all params: 6742609920 || trainable%: 0.06220594176090199

我们使用LORA算法初始化并准备模型进行训练,通过量化可以减少模型大小和内存使用,而不会显着降低准确性。

LoraConfig7是一个为LORA算法指定超参数的类,例如正则化强度(lora_alpha)、dropout概率(lora_dropout)和要压缩的目标模块(target_modules)。

然后就可以直接使用Transformers库进行训练:

 training_arguments = transformers.TrainingArguments(per_device_train_batch_size=MICRO_BATCH_SIZE,gradient_accumulation_steps=GRADIENT_ACCUMULATION_STEPS,warmup_steps=100,max_steps=TRAIN_STEPS,learning_rate=LEARNING_RATE,fp16=True,logging_steps=10,optim="adamw_torch",evaluation_strategy="steps",save_strategy="steps",eval_steps=50,save_steps=50,output_dir=OUTPUT_DIR,save_total_limit=3,load_best_model_at_end=True,report_to="tensorboard")

这段代码创建了一个TrainingArguments对象,该对象指定用于训练模型的各种设置和超参数。这些包括:

  • gradient_accumulation_steps:在执行向后/更新之前累积梯度的更新步数。
  • warmup_steps:优化器的预热步数。
  • max_steps:要执行的训练总数。
  • learning_rate:学习率。
  • fp16:使用16位精度进行训练。

DataCollatorForSeq2Seq是transformer库中的一个类,它为序列到序列(seq2seq)模型创建一批输入/输出序列。在这段代码中,DataCollatorForSeq2Seq对象用以下参数实例化:

 data_collator = transformers.DataCollatorForSeq2Seq(tokenizer, pad_to_multiple_of=8, return_tensors="pt", padding=True)

pad_to_multiple_of:表示最大序列长度的整数,四舍五入到最接近该值的倍数。

padding:一个布尔值,指示是否将序列填充到指定的最大长度。

以上就是训练的所有代码准备,下面就是训练了

 trainer = transformers.Trainer(model=model,train_dataset=train_data,eval_dataset=val_data,args=training_arguments,data_collator=data_collator)model.config.use_cache = Falseold_state_dict = model.state_dictmodel.state_dict = (lambda self, *_, **__: get_peft_model_state_dict(self, old_state_dict())).__get__(model, type(model))model = torch.compile(model)trainer.train()model.save_pretrained(OUTPUT_DIR)

在实例化训练器之后,代码在模型的配置中将use_cache设置为False,并使用get_peft_model_state_dict()函数为模型创建一个state_dict,该函数为使用低精度算法进行训练的模型做准备。

然后在模型上调用torch.compile()函数,该函数编译模型的计算图并准备使用PyTorch 2进行训练。

训练过程在A100上持续了大约2个小时。我们看一下Tensorboard上的结果:

训练损失和评估损失呈稳步下降趋势。看来我们的微调是有效的。

如果你想将模型上传到Hugging Face上,可以使用下面代码,

 from huggingface_hub import notebook_loginnotebook_login()model.push_to_hub("curiousily/alpaca-bitcoin-tweets-sentiment", use_auth_token=True)

推理

我们可以使用generate.py脚本来测试模型:

 !git clone https://github.com/tloen/alpaca-lora.git%cd alpaca-lora!git checkout a48d947

我们的脚本启动的gradio应用程序

 !python generate.py \--load_8bit \--base_model 'decapoda-research/llama-7b-hf' \--lora_weights 'curiousily/alpaca-bitcoin-tweets-sentiment' \--share_gradio

简单的界面如下:

总结

我们已经成功地使用LoRa方法对Llama 模型进行了微调,还演示了如何在Gradio应用程序中使用它。

如果你对本文感兴趣,请看原文:

https://avoid.overfit.cn/post/34b6eaf7097a4929b9aab7809f3cfeaa

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.rhkb.cn/news/73864.html

如若内容造成侵权/违法违规/事实不符,请联系长河编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

【C++】开源:跨平台轻量日志库easyloggingpp

&#x1f60f;★,:.☆(&#xffe3;▽&#xffe3;)/$:.★ &#x1f60f; 这篇文章主要介绍跨平台轻量日志库easyloggingpp。 无专精则不能成&#xff0c;无涉猎则不能通。。——梁启超 欢迎来到我的博客&#xff0c;一起学习&#xff0c;共同进步。 喜欢的朋友可以关注一下&am…

RL — 强化学习技巧

一、说明 深度学习&#xff08;DL&#xff09;很难训练&#xff0c;强化学习&#xff08;RL&#xff09;要差得多。在早期开发中&#xff0c;遵循与 DL 相同的策略&#xff1a;保持简单&#xff01;消除任何妨碍您的花里胡哨的东西&#xff0c;并将不确定性降至最低。具体到RL&…

git clone 登录 github

git clone 登录 github 目录概述需求&#xff1a; 设计思路实现思路分析1.github 设置setting2.输入passwd 参考资料和推荐阅读 Survive by day and develop by night. talk for import biz , show your perfect code,full busy&#xff0c;skip hardness,make a better result…

脑电信号处理与特征提取——6.运用机器学习技术和脑电进行大脑解码(涂毅恒)

目录 六、运用机器学习技术和脑电进行大脑解码 6.1 前言 6.2 基于脑电数据的机器学习基础分析 6.3 基于脑电数据的机器学习进阶分析 6.4 代码解读 六、运用机器学习技术和脑电进行大脑解码 6.1 前言 6.2 基于脑电数据的机器学习基础分析 6.3 基于脑电数据的机器学习进阶分…

C# 关于使用newlife包将webapi接口寄宿于一个控制台程序、winform程序、wpf程序运行

C# 关于使用newlife包将webapi接口寄宿于一个控制台程序、winform程序、wpf程序运行 安装newlife包 Program的Main()函数源码 using ConsoleApp3; using NewLife.Log;var server new NewLife.Http.HttpServer {Port 8080,Log XTrace.Log,SessionLog XTrace.Log }; serv…

hdu Perfect square number

题意&#xff1a; 有n个数&#xff08;n<300&#xff09;&#xff0c;将其中的任意的一个数改为x&#xff08;x在[1,300]&#xff09;&#xff0c;求改之后&#xff0c;区间和为完全平方数的最大区间个数是多少 思路&#xff1a; 将a[x]改之后的区间个数等于&#xff1a;改…

计算机和汇编语言

1.用电表示数字 我们已经学习过二进制来表示数字 二进制计数采用0和1组合表示数字 0和1很适合使用开关闭合&#xff0c;导线上有电流是1&#xff0c;无电流是 我们还可以加上小灯泡&#xff0c;来表示数 2.二进制加法机 上述这个加法机器是接受左边和下面的输入&#xff0c;把…

TCP三次握手

文章目录 目的场景TCP头部结构 目的 保证双方互相建立了连接。 场景 发生在客户端连接服务器的时候&#xff0c;当调用connect()&#xff1b;时&#xff0c;底层会通过TCP协议进行三次握手。 客户端发送 和 服务器接收客户端确定服务器可以收发&#xff0c;自己可以发送服务…

sqlyog导出mysql数据字典

1.打开sqlyog执行sql获取字典数据 SELECTt.COLUMN_NAME AS 字段名,t.COLUMN_TYPE AS 数据类型,CASE IFNULL(t.COLUMN_DEFAULT,Null) WHEN THEN 空字符串 WHEN Null THEN NULL ELSE t.COLUMN_DEFAULT END AS 默认值,CASE t.IS_NULLABLE WHEN YES THEN 是 ELSE 否 END AS 是否…

JSON动态生成表格

<!DOCTYPE html> <html><head><meta charset"utf-8"><title></title></head><body><script>var fromjava"{\"total\":3,\"students\":[{\"name\":\"张三\",\&q…

Spring的创建及使用

文章目录 什么是SpringSpring项目的创建存储Bean对象读取Bean对象getBean()方法 更简单的读取和存储对象的方式路径配置使用类注解存储Bean对象关于五大类注解使用方法注解Bean存储对象Bean重命名 Bean对象的读取 使用Resource注入对象Resource VS Autowired同一类型多个bean对…

怎么让表格中的一行数据 转置 为一列数据 (WPS )

例如 我现在有一列数据 我想要 变成一行 数据 1.首先选中想要转置的数据&#xff0c;然后control C 2.接着 点击你想放置数据的位置 右键 其实 关键是 找到 选择性复制 3. 找到转置&#xff0c;勾选 最后 确定 反之亦然

PoseiSwap:基于 Nautilus Chain ,构建全新价值体系

在 DeFi Summer 后&#xff0c;以太坊自身的弊端不断凸显&#xff0c;而以 Layer2 的方式为其扩容成为了行业很长一段时间的叙事方向之一。虽然以太坊已经顺利的从 PoW 的 1.0 迈向了 PoS 的 2.0 时代&#xff0c;但以太坊创始人 Vitalik Buterin 表示&#xff0c; Layer2 未来…

TWS真无线蓝牙耳机哪家好?六款口碑好的TWS真无线蓝牙耳机分享

为了帮助大家在这个充满选择的世界中找到最理想的蓝牙耳机&#xff0c;我们特别为您精心挑选了几款备受赞誉的产品&#xff0c;它们在音质、舒适度、功能和性价比等方面都有出色的表现。在本文中&#xff0c;我们将深入探讨这些蓝牙耳机的特点和优势&#xff0c;帮助您更好地了…

智能化RFID耳机装配系统:提升效率、精准追踪与优化管理

智能化RFID耳机装配系统&#xff1a;提升效率、精准追踪与优化管理 在当今的智能化时代&#xff0c;无线射频识别技术&#xff08;RFID&#xff09;被广泛应用于各个行业。本文将介绍一种基于RFID技术的智能耳机装配案例&#xff0c;通过RFID技术实现耳机装配过程的自动化控制…

Github Pages自定义域名

Github Pages自定义域名 当你想在网上发布内容时&#xff0c;配置Github Pages是一个很好的选择。如果你想要在自己的域名上发布&#xff0c;你可以使用Github Pages来创建自己的网站。本文将介绍如何使用Github Pages自定义域名。 这里呢先列出前置条件&#xff1a; 您的Gi…

在使用Python爬虫时遇到503 Service Unavailable错误解决办法汇总

在进行Python爬虫的过程中&#xff0c;有时会遇到503 Service Unavailable错误&#xff0c;这意味着所请求的服务不可用&#xff0c;无法获取所需的数据。为了解决这个常见的问题&#xff0c;本文将提供一些解决办法&#xff0c;希望能提供实战价值&#xff0c;让爬虫任务顺利完…

mysql的json处理

需要注意&#xff0c;5.7以上版本才支持&#xff0c;但如果是生产环境需要使用的话&#xff0c;尽量使用8.0版本&#xff0c;因为8.0版本对json处理做了比较大的性能优化。你你可以使用select version();来查看版本信息。 本文看下MySQL的json处理。在正式开始让我们先来准备一…

DTCC2023第十四届中国数据库大会分享:MySQL性能诊断平台:利用eBPF技术实现高效的根因诊断

主题 8月16-18日 DTCC2023第十四届中国数据库大会在北京国际会议中心召开&#xff0c;17日下午在云原生数据库开发与实践分论坛&#xff0c;我将带来分享&#xff1a;《MySQL性能诊断平台&#xff1a;利用eBPF技术实现高效的根因诊断》敬请期待&#xff01; 欢迎大家提前试用我…

zsh中安装ros-<ros2-distro>-turtlebot3*失败 || 以humble为例

在zsh中尝试使用 sudo apt install ros-<ros2-distro>-turtlebot3* 安装turtlebot3相关仿真包失败&#xff0c;报错E: 无法定位软件包。 但是在bash中尝试使用同样的命令却可以安装。 原因是zsh中如果要使用通配符&#xff0c;那么一定要放在字符串里&#xff0c;以上…