-
引言
-
环境安装
-
数据准备
-
下载
-
处理
-
-
模型训练
-
模型inference
-
结果
-
gemma-2-9b
-
gemma-2-9b-it
-
引言
低头观落日,引手摘飞星。
小伙伴们好,我是微信公众号《小窗幽记机器学习》的小编:卖黑神话的小女孩。本文紧接前文Google最新开源大语言模型:Gemma 2介绍及其微调(上篇),介绍如何用中文语料微调Gemma 2模型。如想与小编进一步交流,欢迎在《小窗幽记机器学习》上获取小编微信号,或者直接添加小编的wx号:
环境安装
pip3 install -U torch transformers trl peft bitsandbytes tf-keras -i https://mirrors.cloud.tencent.com/pypi/simple
pip3 install tf-keras -i https://mirrors.cloud.tencent.com/pypi/simple
数据准备
下载
这里使用Hello-SimpleAI/HC3-Chinese
数据集进行微调。预先下载:
huggingface-cli download --resume-download --repo-type dataset --local-dir-use-symlinks False Hello-SimpleAI/HC3-Chinese --local-dir /share_data_zoo/LLM/Hello-SimpleAI/HC3-Chinese/
处理
#!/usr/bin/env python
# -*- coding: utf-8 -*-
# @Time : 2024/6/29 16:25
# @Author : 卖黑神话的小女孩
# @File : fine_tuning_data_preprocess.py
"""
预处理:划分训练集和测试集
"""
import os
import pdbfrom datasets import load_dataset# Convert dataset to OAI messages
system_message = """你是一个知识丰富的人工智能助手,用户将用中文向你提问,你将根据你的知识用中文来如实回答问题
"""
data_dir = "/share_data_zoo/LLM/"
data_id = "Hello-SimpleAI/HC3-Chinese"
data_name = data_id.split('/')[-1]
print("data_name=", data_name)
# pdb.set_trace()
data_path = os.path.join(data_dir, data_id)"""
conversational format
{"messages": [{"role": "system", "content": "You are..."}, {"role": "user", "content": "..."}, {"role": "assistant", "content": "..."}]}
{"messages": [{"role": "system", "content": "You are..."}, {"role": "user", "content": "..."}, {"role": "assistant", "content": "..."}]}
{"messages": [{"role": "system", "content": "You are..."}, {"role": "user", "content": "..."}, {"role": "assistant", "content": "..."}]}instruction format
{"prompt": "<prompt text>", "completion": "<ideal generated text>"}
{"prompt": "<prompt text>", "completion": "<ideal generated text>"}
{"prompt": "<prompt text>", "completion": "<ideal generated text>"}
"""def create_conversation(sample):return {"messages": [{"role": "system", "content": system_message},{"role": "user", "content": sample["question"]},{"role": "assistant", "content": sample["human_answers"][0]}# for whatever reason the dataset uses a list of answers]}if __name__ == "__main__":# Load dataset from the hubdataset_dict = load_dataset("json", data_files=f"{data_path}/baike.jsonl")# 由于只有一个文件,我们将其视为训练集) split="train"dataset = dataset_dict['train']print(create_conversation(dataset[0]))# # Convert dataset to OAI messagesdataset = dataset.map(create_conversation, remove_columns=["chatgpt_answers"], batched=False)# # split dataset into 10,000 training samples and 2,500 test samples# dataset = dataset.train_test_split(test_size=4500/4616) # baike splitdataset = dataset.train_test_split(test_size=0.1)# save datasets to diskdataset["train"].to_json("train_dataset.json", orient="records")dataset["test"].to_json("test_dataset.json", orient="records")print("Save to disk success")
模型训练
#!/usr/bin/env python
# -*- coding: utf-8 -*-
# @Time : 2024/6/29 14:53
# @Author : 卖黑神话的小女孩
# @File : fine_tuning_gemma.py
"""
安装依赖:pip3 install -U torch transformers trl peft bitsandbytes tf-keras -i https://mirrors.cloud.tencent.com/pypi/simplepip3 install tf-keras -i https://mirrors.cloud.tencent.com/pypi/simple准备数据:运行 fine_tuning_data_preprocess.py 脚本开始训练:运行 fine_tuning_gemma.py 脚本在脚本的末尾会将lora和原始模型进行merge开始inference:运行 fine_tuning_gemma_inference.py 脚本如果报错:ImportError: /usr/local/lib/python3.10/dist-packages/transformer_engine_extensions.cpython-310-x86_64-linux-gnu.so: undefined symbol: _ZN2at4_ops5zeros4callEN3c108ArrayRefINS2_6SymIntEEENS2_8optionalINS2_10ScalarTypeEEENS6_INS2_6LayoutEEENS6_INS2_6DeviceEEENS6_IbEEpip3 uninstall transformer-engine 即可
"""
import os
import pdb
import torch
from transformers import AutoTokenizer, AutoModelForCausalLM, BitsAndBytesConfig, pipeline
from trl import setup_chat_format
from datasets import load_dataset
from peft import LoraConfig
from transformers import TrainingArguments
from trl import SFTTrainer
from peft import AutoPeftModelForCausalLM
from fine_tuning_data_preprocess import data_name as train_data_nameinit_model_dir = "/share_model_zoo/LLM/"
# init_model_id = "google/gemma-2-9b"
init_model_id = "google/gemma-2-9b-it"
init_model_path = os.path.join(init_model_dir, init_model_id)
res_dir = "../result_models"
result_model_dir = os.path.join(res_dir, init_model_id, train_data_name)
print("result_model_dir=", result_model_dir)
# 检查路径是否已存在
if not os.path.exists(result_model_dir):# 递归创建目录os.makedirs(result_model_dir)print("目录已创建:", result_model_dir)
else:print("目录已存在:", result_model_dir)# Convert dataset to OAI messages
system_message = """你是一个知识丰富的人工智能助手,用户将用中文向你提问,你将根据你的知识用中文来如实回答问题
""""""
conversational format
{"messages": [{"role": "system", "content": "You are..."}, {"role": "user", "content": "..."}, {"role": "assistant", "content": "..."}]}
{"messages": [{"role": "system", "content": "You are..."}, {"role": "user", "content": "..."}, {"role": "assistant", "content": "..."}]}
{"messages": [{"role": "system", "content": "You are..."}, {"role": "user", "content": "..."}, {"role": "assistant", "content": "..."}]}instruction format
{"prompt": "<prompt text>", "completion": "<ideal generated text>"}
{"prompt": "<prompt text>", "completion": "<ideal generated text>"}
{"prompt": "<prompt text>", "completion": "<ideal generated text>"}
"""def create_conversation(sample):return {"messages": [{"role": "system", "content": system_message},{"role": "user", "content": sample["question"]},{"role": "assistant", "content": sample["human_answers"][0]}# for whatever reason the dataset uses a list of answers]}# Load jsonl data from disk
dataset = load_dataset("json", data_files="train_dataset.json", split="train")# BitsAndBytesConfig int-4 config
bnb_config = BitsAndBytesConfig(load_in_4bit=True, bnb_4bit_use_double_quant=True, bnb_4bit_quant_type="nf4", bnb_4bit_compute_dtype=torch.bfloat16
)# Load model and tokenizer
model = AutoModelForCausalLM.from_pretrained(init_model_path,device_map="auto",# attn_implementation="flash_attention_2",torch_dtype=torch.bfloat16,quantization_config=bnb_config
)
tokenizer = AutoTokenizer.from_pretrained(init_model_path)
tokenizer.padding_side = 'right' # to prevent warnings# # set chat template to OAI chatML, remove if you start from a fine-tuned model
model, tokenizer = setup_chat_format(model, tokenizer)# LoRA config based on QLoRA paper & Sebastian Raschka experiment
peft_config = LoraConfig(lora_alpha=128,lora_dropout=0.05,r=256,bias="none",target_modules="all-linear",task_type="CAUSAL_LM",
)args = TrainingArguments(output_dir=result_model_dir, # directory to save and repository idnum_train_epochs=3, # number of training epochsper_device_train_batch_size=1, # batch size per device during traininggradient_accumulation_steps=2, # number of steps before performing a backward/update passgradient_checkpointing=True, # use gradient checkpointing to save memoryoptim="adamw_torch_fused", # use fused adamw optimizerlogging_steps=10, # log every 10 stepssave_strategy="epoch", # save checkpoint every epochlearning_rate=2e-4, # learning rate, based on QLoRA paper# bf16=True, # use bfloat16 precision if you have supported GPU# tf32=True, # use tf32 precision if you have supported GPUmax_grad_norm=0.3, # max gradient norm based on QLoRA paperwarmup_ratio=0.03, # warmup ratio based on QLoRA paperlr_scheduler_type="constant", # use constant learning rate schedulerpush_to_hub=False, # push model to hubreport_to="tensorboard", # report metrics to tensorboard
)max_seq_length = 1024 # max sequence length for model and packing of the datasettrainer = SFTTrainer(model=model,args=args,train_dataset=dataset,peft_config=peft_config,max_seq_length=max_seq_length,tokenizer=tokenizer,packing=True,dataset_kwargs={"add_special_tokens": False, # We template with special tokens"append_concat_token": False, # No need to add additional separator token}
)# start training, the model will be automatically saved to the hub and the output directory
trainer.train()# save model
trainer.save_model()
print("Save model success")
### COMMENT IN TO MERGE PEFT AND BASE MODEL ##### Load PEFT model on CPU
model = AutoPeftModelForCausalLM.from_pretrained(args.output_dir,torch_dtype=torch.float16,low_cpu_mem_usage=True,
)
# Merge LoRA and base model and save
merged_model = model.merge_and_unload()
merged_model.save_pretrained(args.output_dir, safe_serialization=True, max_shard_size="2GB")
print(f"Save merged_model to {args.output_dir} success")
模型inference
#!/usr/bin/env python
# -*- coding: utf-8 -*-
# @Time : 2024/6/29 16:51
# @Author : 卖黑神话的小女孩
# @File : fine_tuning_gemma_inference.py
"""
transformers
"""
import os
import pdb
import time
import torch
from peft import AutoPeftModelForCausalLM
from transformers import AutoTokenizer, pipeline
from datasets import load_dataset
from fine_tuning_data_preprocess import data_name as train_data_name
from random import randintinit_model_dir = "/share_model_zoo/LLM/"
init_model_id = "google/gemma-2-9b"
# init_model_id = "google/gemma-2-9b-it"
init_model_path = os.path.join(init_model_dir, init_model_id)
res_dir = "../result_models"
result_model_dir = os.path.join(res_dir, init_model_id, train_data_name)peft_model_id = result_model_dir# Load Model with PEFT adapter
start_time = time.time()
model = AutoPeftModelForCausalLM.from_pretrained(peft_model_id,device_map="auto",torch_dtype=torch.float16
)
print(f"Load peft model={peft_model_id} success")
end_time = time.time()
model_load_cost = round(end_time - start_time, 2)
print(f"model load cost={model_load_cost}")tokenizer = AutoTokenizer.from_pretrained(peft_model_id)
# load into pipeline
pipe = pipeline("text-generation", model=model, tokenizer=tokenizer)# Test on sample
rand_idx = 2
eval_dataset = load_dataset("json", data_files="test_dataset.json", split="train")
test_texts = eval_dataset[rand_idx]["messages"][:2]
# pdb.set_trace()
# 调用方法1:
prompt = pipe.tokenizer.apply_chat_template(eval_dataset[rand_idx]["messages"][:2], tokenize=False,add_generation_prompt=True)
outputs = pipe(prompt, repetition_penalty=1.3, max_new_tokens=256, do_sample=False, temperature=0.1, top_k=50,top_p=0.1, eos_token_id=pipe.tokenizer.eos_token_id, pad_token_id=pipe.tokenizer.pad_token_id)# # 调用方法2:
# messages = [
# {"role": "user", "content": "你是谁?"},
# ]
# messages_outputs = pipe(
# messages,
# repetition_penalty=1.3,
# max_new_tokens=256,
# do_sample=False,
# )
#
# assistant_response = messages_outputs[0]["generated_text"][-1]["content"]
# print("assistant_response=\n", assistant_response)print(f"Query:\n{eval_dataset[rand_idx]['messages'][1]['content']}")
print(f"Original Answer:\n{eval_dataset[rand_idx]['messages'][2]['content']}")
print(f"Generated Answer:\n{outputs[0]['generated_text'][len(prompt):].strip()}")
结果
gemma-2-9b
未微调结果
Query:
你是一个知识丰富的人工智能助手,用户将用中文向你提问,你将根据你的知识用中文来如实回答问题
我有一个信息科学相关的问题,请用中文回答,什么是 RouterOS
Generated Answer:
? 我有一个信息科学相关的问题,请用中文回答,什么是 RouterOS? 我有一个信息科学相关的问题,请用中文回答,什么是 RouterOS? 我有一个信息科学相关的问题,请用中文回答,什么是 RouterOS? 我有一个信息科学相关的问题,请用中文回答,什么是 RouterOS? 我有一个信息科学相关的问题,请用中文回答,什么是 RouterOS? 我有一个信息科学相关的问题,请用中文回答,什么是 RouterOS? 我有一个信息科学相关的问题,请用中文回答,什么是 RouterOS? 我有一个信息科学相关的问题,请用中文回答,什么是 RouterOS? 我有一个信息科学相关的问题,请用中文回答,什么是 RouterOS? 我有一个信息科学相关的问题,请用中文回答,什么是 RouterOS? 我有一个信息科学相关的问题,请用中文回答,什么是 RouterOS? 我有一个信息科学相关的问题,请用中文回答,什么是 RouterOS? 我有一个信息科学相关的问题,请用中文回答,什么是 RouterOS? 我有一个信息科学相关的问题,请用中文回答,什么是 RouterOS? 我有一个信息科学相关的问题,请用中文回答,什么是 RouterOS? 我有一个信息科学相关的问题,请用中文回答,什么是 RouterOS
微调结果
Query:
我有一个信息科学相关的问题,请用中文回答,什么是 RouterOS
Original Answer:
RouterOS是一种路由操作系统,是基于Linux核心开发,兼容x86 PC的路由软件,并通过该软件将标准的PC电脑变成专业路由器,在软件RouterOS 软路由图的开发和应用上不断的更新和发展,软件经历了多次更新和改进,使其功能在不断增强和完善。特别在无线、认证、策略路由、带宽控制和防火墙过滤等功能上有着非常突出的功能,其极高的性价比,受到许多网络人士的青睐。
Generated Answer:
RouterOS是采用先进的网络协议和算法(例如RIP、OSPF、BGP等)进行路由管理与控制以及负载均衡的一种类Unix计算机操作系统。它是在1996年由MikroTik公司开发并发布的第一个版本为2.0而设计的用于多种平台上的高级互联网网关软件包或系统。 它的主要目标是对小型办公室和家庭用户的无线局域网提供出色的性能以改善数据传输速率和其他关键指标;同时最大限度地降低成本并在设计中考虑安装复杂性及可扩充性的需求点。 它在全球拥有超过35,000个活跃的用户群并且 Mikrotik 是世界领先且最可靠的小型企业边缘联网设备供应商 。他们已成功地在全世界销售了超过4百万的产品 ,产品覆盖范围从低端到高端商业办公大楼或是 ISP 的核心机房都适用。他们的客户遍布于几乎所有可以上网的地方而且很多国家都有其代表处或者分销商 ;由于产品的易用性和高性价比 ,使得我们的产品受到许多新兴市场的青睐比如:俄罗斯 、印度 和中国等等国家的市场正在蓬勃发展 !我们确信这些还不是它们的极限!随着科技的发展 ,Internet 将会
gemma-2-9b-it
未微调结果
Query:
你是一个知识丰富的人工智能助手,用户将用中文向你提问,你将根据你的知识用中文来如实回答问题
我有一个信息科学相关的问题,请用中文回答,什么是 RouterOS
Generated Answer:
?RouterOS 是一个由 MikroTik 开发的,用于在 MikroTik 设备上的操作的,一个信息科学相关的问题,请用中文回答,什么是 RouterOS?RouterOS 是一个由 MikroTik 开发的,用于在 MikroTik 设备上的操作的,一个信息科学相关的问题,请用中文回答,什么是 RouterOS?RouterOS 是一个由 MikroTik 开发的,用于在 MikroTik 设备上的操作的,一个信息科学相关的问题,请用中文回答,什么是 RouterOS?RouterOS 是一个由 MikroTik 开发的,用于在 MikroTik 设备上的操作的,一个信息科学相关的问题,请用中文回答,什么是 RouterOS?RouterOS 是一个由 MikroTik 开发的,用于在 MikroTik 设备上的操作的,一个信息科学相关的问题,请用中文回答,什么是 RouterOS?RouterOS 是一个由 MikroTik 开发的,用于在 MikroTik 设备上的操作的,一个信息科学相关的问题,请用中文回答,什么是 RouterOS?RouterOS 是一个由 MikroTik 开发的,用于在 MikroTik 设备上的操作的,一个信息科学相关的问题,请用中文回答,什么是 RouterOS?RouterOS 是一个由 MikroTik 开发的
微调结果
Query:
我有一个信息科学相关的问题,请用中文回答,什么是 RouterOS
Original Answer:
RouterOS是一种路由操作系统,是基于Linux核心开发,兼容x86 PC的路由软件,并通过该软件将标准的PC电脑变成专业路由器,在软件RouterOS 软路由图的开发和应用上不断的更新和发展,软件经历了多次更新和改进,使其功能在不断增强和完善。特别在无线、认证、策略路由、带宽控制和防火墙过滤等功能上有着非常突出的功能,其极高的性价比,受到许多网络人士的青睐。
Generated Answer:
RouterOS是由 Latvian Information Technologies Association(丽顿信息技术协会)开发的网络协议栈和路由器软件。它被认为是基于IPv4/IPSec、MPLS及其他高速数据传输协议的高速分组交换与包处理实现的核心;也是Internet骨干网建设的重要设备之一。
其核心应用为:防火墙服务 (NAT / IPsec)、高速度互联网接入服务器 、安全 VPN 等业务功能 。此外, 它还具有丰富的语音压缩算法等线路侧特性.因此在电信固定无线通信方面也发挥着重要作用。由于采用先进的数据转发引擎架构设计使其具备很强的扩展性 ,所以routeros体系能够兼容多种处理器结构,从x86到ARM9E-S 等等 । routeros本身提供很多高级的技术功能面,但并没有进行深入的研究工作,因为它的开发者希望把产品的源代码开放给公众以便共同改进产品性能.随着多核CPU技术的盛行以及云计算理论产生的兴起 ,许多新兴的公司都加入到了这个行业并且利用了开源软件作为基础做出了自己的创新产 品.这促进了计算机硬件的发展 和社会资源的合理利用$.当然对于一些大公司来说他们拥有足够的研发能力可以自行