11. DPO 微调示例:根据人类偏好优化LLM大语言模型

在部署大模型之后,我们必然要和微调打交道。现在大模型的微调有非常多的方法,过去的文章中提到的微调方法通常依赖于问题和答案对,标注成本较高。

2023 年所提出的 Direct Preference Optimization(DPO)为我们提供了一种无需标准标注答案的高效微调方法。DPO 依赖于人类对文本的偏好对(preference pairs),也就是说,数据集中只包含人类对两段文本中哪段更好的判断,而不是具体的正确答案。

在本文中,我们将利用 DPO 来微调一个模型让其按照偏好进行输出。这篇文章也为生成式人工智能导论课程中 HW6: LLM Values Alignment 提供中文引导。

代码文件下载 | 作业PDF

安装和导入一些必要的库

pip install bitsandbytes==0.43.1 datasets==2.19.0 peft==0.10.0 trl==0.8.6 accelerate==0.29.3
import os
import re
import jsonimport torch
import pandas as pd
from tqdm.auto import tqdmfrom datasets import Dataset
from peft import LoraConfig
from transformers import AutoModelForCausalLM, AutoTokenizer, TrainingArguments, BitsAndBytesConfig, GenerationConfig
from trl import DPOTrainer

可能的问题:Keras 3 与 Transformers 不兼容

在导入时,你可能会看到以下报错:

RuntimeError: Failed to import trl.trainer.dpo_trainer because of the following error (look up to see its traceback):
Failed to import transformers.trainer because of the following error (look up to see its traceback):
Failed to import transformers.integrations.integration_utils because of the following error (look up to see its traceback):
Failed to import transformers.modeling_tf_utils because of the following error (look up to see its traceback):
Your currently installed version of Keras is Keras 3, but this is not yet supported in Transformers. Please install the backwards-compatible tf-keras package with pip install tf-keras.

transformers 库建议安装兼容的 tf-keras 包来解决这个兼容性问题。你可以通过以下命令安装:

pip install tf-keras

现在问题应该得到了解决。

加载数据集

我们将使用预先提供的数据集,包括带标签的偏好数据和测试提示数据。

这个数据集来自于生成式人工智能导论的HW6,处理的问题是:是否应该将动漫真人化?两个回答分别对应支持和不支持(由GPT生成),在后面的代码中你将选择支持的占比。

git clone https://github.com/Baiiiiiiiiii/GenAI_hw6_dataset.git
with open("./GenAI_hw6_dataset/labelled_data.json", 'r') as jsonfile:full_data = json.load(jsonfile)with open("./GenAI_hw6_dataset/test_prompt.json", 'r') as jsonfile:test_data = json.load(jsonfile)

直观理解数据集:

full_data

image-20240919114655048

使用 HFD 下载模型

我们这里使用多线程的方法进行快速下载。

如果直接运行以下命令报错,根据 a. 使用 HFD 加快 Hugging Face 模型和数据集的下载 进行前置安装。

当然,你也可以取消我注释的部分,使用官方的命令进行安装,但是会很慢。

安装工具

sudo apt-get update
sudo apt-get install git wget curl aria2 git-lfs
git lfs install

下载 hfd 并修改权限

wget https://hf-mirror.com/hfd/hfd.sh
chmod a+x hfd.sh

多线程下载模型

export HF_ENDPOINT=https://hf-mirror.com
./hfd.sh 'MediaTek-Research/Breeze-7B-Instruct-v0_1' --tool aria2c -x 16

下载

加载模型

将使用MediaTek-Research/Breeze-7B-Instruct-v0_1模型进行微调。

model = AutoModelForCausalLM.from_pretrained('MediaTek-Research/Breeze-7B-Instruct-v0_1',device_map='auto',trust_remote_code=True,quantization_config=BitsAndBytesConfig(load_in_4bit=True,bnb_4bit_compute_dtype=torch.bfloat16,bnb_4bit_use_double_quant=True,bnb_4bit_quant_type='nf4')
)

这里,我们采用了4位量化(4-bit quantization)来减少模型的内存占用,加快推理速度。

查看未经过微调的模型原始输出

在进行微调之前,我们首先查看一下原始模型的输出效果。首先,加载分词器:

tokenizer = AutoTokenizer.from_pretrained('MediaTek-Research/Breeze-7B-Instruct-v0_1')
tokenizer.padding_side = "right"
tokenizer.pad_token = tokenizer.eos_token

定义一个数据处理函数,将数据格式化为模型可以接受的输入,我们这里的 prompt 延续原来的繁体(因为Breeze-7B-Instruct-v0_1更多使用繁体中文进行训练,你并不需要修改它):

def data_formulate(data):messages = [{"role": "system", "content": '回覆請少於20字'},{"role": "user", "content": data['prompt']},]prompt = tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)return prompt

接下来,生成原始模型的响应:

original_model_response = []
for data in tqdm(test_data):id = data['id']print(f'Question {id}:\n'+data['prompt'])inputs = tokenizer(data_formulate(data), return_tensors="pt").to('cuda')generation_config=GenerationConfig(do_sample=False,max_new_tokens = 200,pad_token_id = tokenizer.pad_token_id)output = model.generate(**inputs, generation_config=generation_config)output = tokenizer.batch_decode(output, skip_special_tokens=True)[0].split('[/INST] ')[1]original_model_response.append(output)print('Response from original model:\n'+output+'\n')

这段代码将遍历测试数据集,生成并打印每个问题的原始模型响应。

image-20240919113918391

设置参数

你只需要修改这个模块,不需要改变其他的,除非你真的知道自己在做什么。

support_ratio 将反映你的偏好:

  • 0 表示完全不支持(反对)真人化
  • 1 表示完全支持真人化
  • 0.1 表示 10% 支持真人化
num_epoch = 1
data_size = 50
support_ratio = 0.1

准备训练数据

这里,我们将数据集分为支持(support)和反对(oppose)两部分,构建一个包含偏好对的训练数据集(是的,这里就是 DPO)。

# 选择部分数据用于训练
training_data = full_data[:data_size]# 定义 support 数据集的大小
support_data_size = int(data_size * support_ratio)# 为训练数据集准备数据
prompt_list = [data_formulate(data) for data in training_data]
chosen_list = [data['support'] for data in training_data[:support_data_size]] + [data['oppose'] for data in training_data[support_data_size:]]
rejected_list = [data['oppose'] for data in training_data[:support_data_size]] + [data['support'] for data in training_data[support_data_size:]]
position_list = ['support' for _ in range(support_data_size)] + ['oppose' for _ in range(data_size - support_data_size)]# 创建训练数据集
train_dataset = Dataset.from_dict({'prompt': prompt_list, 'position': position_list, 'chosen': chosen_list, 'rejected': rejected_list})
pd.DataFrame(train_dataset).rename(columns={"chosen": "preferred", "rejected": "non-preferred"})

总共有 50 笔训练数据,当 support 设置为 0.1 时,前 50*0.1=5 笔训练资料的偏好将倾向于支持真人化,后 50-5=45 笔资料反对真人化。

image-20240919114949791

训练

现在,我们进入训练阶段。首先,设置训练参数:

training_args = TrainingArguments(output_dir='./',per_device_train_batch_size=1,num_train_epochs=num_epoch,gradient_accumulation_steps=8,gradient_checkpointing=False,learning_rate=2e-4,optim="paged_adamw_8bit",logging_steps = 1,warmup_ratio = 0.1,report_to = 'none'
)

接下来,配置PEFT(Parameter-Efficient Fine-Tuning):

peft_config = LoraConfig(lora_alpha=16,lora_dropout=0.1,r=64,bias="none",task_type="CAUSAL_LM",
)

然后,初始化DPO训练器:

dpo_trainer = DPOTrainer(model,args=training_args,beta=0.1,train_dataset=train_dataset,tokenizer=tokenizer,peft_config=peft_config,
)

开始训练:

dpo_trainer.train()

image-20240919115410184

查看微调后的模型输出

训练完成后,我们需要查看微调后的模型效果。以下是生成训练后模型响应的代码:

trained_model_response = []
for data in tqdm(test_data):id = data['id']print(f'Question {id}:\n'+data['prompt'])inputs = tokenizer(data_formulate(data), return_tensors="pt").to('cuda')generation_config=GenerationConfig(do_sample=False,max_new_tokens = 200,pad_token_id = tokenizer.pad_token_id)output = model.generate(**inputs, generation_config=generation_config)output = tokenizer.batch_decode(output, skip_special_tokens=True)[0].split('[/INST] ')[1]trained_model_response.append(output)print('Response from trained model:\n'+output+'\n')

这段代码与之前生成原始模型响应的代码类似,但这次生成的是经过微调后的模型响应:

image-20240919115643310

观察输出结果

最后,我们对比微调前后的模型响应,观察DPO方法带来的效果提升:

model_response = []
print(f'num_epoch: {num_epoch}\ndata_size: {data_size}\nsupport_ratio: {support_ratio}')
print()
for data in test_data:id = data['id']ref_output = original_model_response[id-1]output = trained_model_response[id-1]print(f'Question {id}:\n'+data['prompt'])print('Response from original model:\n'+ref_output)print('Response from trained model:\n'+output)print()model_response.append({'id':data['id'], 'prompt':data['prompt'], 'response_from_original_model':ref_output, 'response_from_trained_model':output})

image-20240919115708299

拓展

在使用 GPT 的时候你应该也见到过其同时生成两个回答让我们选择更倾向于哪个,这个和 Google 验证码有着异曲同工之妙。

推荐阅读

Direct Preference Optimization: Your Language Model is Secretly a Reward Model

下一章

12. Inseq 特征归因:可视化解释 LLM 的输出

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.rhkb.cn/news/430742.html

如若内容造成侵权/违法违规/事实不符,请联系长河编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

C语言----指针

基本知识点:指针的定义、指针运算符和指针运算等基本概念。重 点:字符指针、指针数组和多级指针。难 点:利用指针类型解决复杂的应用问题。 指针的概念 要点归纳 1.指针变量 在计算机中,所有数据都通过变量存放在内存中,每个变量都…

【matlab】将程序打包为exe文件(matlab r2023a为例)

文章目录 一、安装运行时环境1.1 安装1.2 简介 二、打包三、打包文件为什么很大 一、安装运行时环境 使用 Application Compiler 来将程序打包为exe,相当于你使用C编译器把C语言编译成可执行程序。 在matlab菜单栏–App下面可以看到Application Compiler。 或者在…

啤酒过滤——关于过滤助剂的介绍

在啤酒的酿造过程中,过滤是一个关键步骤,在啤酒厂中最常用的过滤助剂主要有两种:硅藻土和珍珠岩。它们能够帮助去除杂质,确保啤酒的清澈和口感。过滤助剂通常以粉状形式存在,它们被涂抹在过滤机的支撑材料上&#xff0…

深度合成算法备案和大模型备案的区别是什么

以下是关于大语言模型上线备案和深度合成算法备案区别的文档内容: 一、大语言模型上线备案与深度合成算法备案的区别 备案对象 大语言模型上线备案:主要针对生成式人工智能(AIGC)产品中的大型语言模型,能够生成文本、图…

MT6765/MT6762(R/D/M)/MT6761(MT8766)安卓核心板参数比较_MTK联发科4G智能模块

联发科Helio P35 MT6765安卓核心板 MediaTek Helio P35 MT6765是智能手机的主流ARM SoC,于2018年末推出。它在两个集群中集成了8个ARM Cortex-A53内核(big.LITTLE)。四个性能内核的频率高达2.3GHz。集成显卡为PowerVR GE8320,频率…

MATLAB系列09:图形句柄

MATLAB系列09:图形句柄 9. 图形句柄9.1 MATLAB图形系统9.2 对象句柄9.3 对象属性的检测和更改9.3.1 在创建对象时改变对象的属性9.3.2 对象创建后改变对象的属性 9.4 用 set 函数列出可能属性值9.5 自定义数据9.6 对象查找9.7 用鼠标选择对象9.8 位置和单位9.8.1 图…

Leetcode面试经典150题-39.组合总数进阶:40.组合总和II

本题是扩展题,真实考过,看这个题之前先看一下39题 Leetcode面试经典150题-39.组合总数-CSDN博客 给定一个候选人编号的集合 candidates 和一个目标数 target ,找出 candidates 中所有可以使数字和为 target 的组合。 candidates 中的每个数…

E2VPT: An Effective and Efficient Approach for Visual Prompt Tuning

论文汇总 存在的问题 1.以前的提示微调方法那样只关注修改输入,而应该明确地研究在微调过程中改进自注意机制的潜力,并探索参数效率的极限。 2.探索参数效率的极值来减少可调参数的数量? 解决办法 提示嵌入进行transformer中 提示剪枝 Token-wise …

004_动手实现MLP(pytorch)

import torch from torch import nn from torch.nn import init import numpy as np import sys import d2lzh_pytorch as d2l # 1.数据预处理 mnist_train torchvision.datasets.FashionMNIST(root/Users/w/PycharmProjects/DeepLearning_with_LiMu/datasets/FashionMnist, t…

DevExpress WPF中文教程:如何解决行焦点、选择的常见问题?

DevExpress WPF拥有120个控件和库,将帮助您交付满足甚至超出企业需求的高性能业务应用程序。通过DevExpress WPF能创建有着强大互动功能的XAML基础应用程序,这些应用程序专注于当代客户的需求和构建未来新一代支持触摸的解决方案。 无论是Office办公软件…

0-1开发自己的obsidian plugin DAY 2

今天上午解决了三个问题 1. typescript长得丑/一片飘红/格式检查太严格 在vscode的settings里搜索下面这个然后false掉: "typescript.validate.enable": false 就不会一片飘红了(其他下载第三方插件如TSLint和typescript hero的方法都不好使&…

虚幻引擎的三种输入模式和将控件显示到屏幕上

首先要知道一个概念 , HUD 和 Input 都是由 PlayerController 来控制的 而虚幻的Input控制模式有三种 Set Input Mode Game Only (设置输入模式仅限游戏): 视角会跟着鼠标旋转 , 就是正常游戏的模式 , 这也是游戏默认输入模式 Set Input Mode UI Only (设置输入模式仅限UI): …

DHCP协议原理(网络协议)

DHCP简介 定义 DHCP(动态主机配置协议)是一种网络管理协议,能够自动为局域网中的每台计算机分配IP地址及其他网络配置参数,包括子网掩码、默认网关和DNS服务器等。这一机制极大简化了网络管理,尤其在大型局域网中&am…

sheng的学习笔记-AI-K-摇臂赌博机(K-armed bandit)

AI目录:sheng的学习笔记-AI目录-CSDN博客 强化学习 sheng的学习笔记-AI-强化学习(Reinforcement Learning, RL)-CSDN博客 基础知识 单步强化学习任务 先考虑比较简单的情形:最大化单步奖赏,即仅考虑一步操作。需注意…

使用API有效率地管理Dynadot域名,注册域名服务器(NS)信息

前言 Dynadot是通过ICANN认证的域名注册商,自2002年成立以来,服务于全球108个国家和地区的客户,为数以万计的客户提供简洁,优惠,安全的域名注册以及管理服务。 Dynadot平台操作教程索引(包括域名邮箱&…

GPU共享技术深度剖析与总结

在人工智能和深度学习领域,GPU(图形处理器)已成为不可或缺的计算工具。随着深度学习模型的规模和复杂性的增加,单个GPU已经难以满足所有训练需求,GPU共享技术应运而生,成为提高训练效率的重要手段。本文将深…

聊聊AUTOSAR:基于Vector MICROSAR的TC8测试开发方案

技术背景 车载以太网技术作为汽车智能化和网联化的重要组成部分,正逐步成为现代汽车网络架构的核心,已广泛应用于汽车诊断(如OBD)、ECU软件更新、智能座舱系统、高清摄像头环视泊车系统等多个领域。 在这个过程中,ET…

oklink爬虫逆向分析

目标网站 aHR0cHM6Ly93d3cub2tsaW5rLmNvbS96aC1oYW5zL2tsYXl0bi9ibG9jay1saXN0L3BhZ2UvMg 一、抓包分析 请求头有很多加密参数,不过经过观察,发现只有X-Apikey是检测的 二、逆向分析 发包类型不是XMLHttpRequest,不能下xhr断点 打开启动器…

【项目案例】物联网比较好的10+练手项目推荐,附项目文档/源码/视频

练手项目推荐 1 智能小车 项目功能介绍: 本项目由三部分组成:应用端(微信小程序)、设备端(Hi3861)、驱动端(UPS)。 1. 应用端,采用微信小程序作为应用端控制界面。在开…

spring里面内置的非常实用的工具

一 、请求数据记录 Spring Boot提供了一个内置的日志记录解决方案,通过 AbstractRequestLoggingFilter 可以记录请求的详细信息。 AbstractRequestLoggingFilter 有两个不同的实现类,我们常用的是 CommonsRequestLoggingFilter。 通过 CommonsRequestL…