【大模型学习】多模态大模型进行偏好优化

一、简介

训练模型以理解并预测人类偏好是一项复杂的任务。传统方法如SFT(监督微调)通常需要较高的成本,因为这些算法需要对数据进行特定标签的标注。偏好优化(Preference Optimization)作为一种替代方案,可以简化这一过程并提供更准确的结果。通过对候选回答的对比和排序,而不是赋予固定的标签,偏好优化能够更高效地捕捉人类偏好的细微差别。

虽然偏好优化已经在大语言模型中广泛使用,但现在它也可以应用于视觉语言模型(VLM)。得益于TRL(Transformer Reinforcement Learning)的开发,现在我们可以使用TRL对VLM进行直接偏好优化(Direct Preference Optimization)。本文将介绍使用TRL和DPO对视觉语言模型进行训练的全过程。

二、偏好数据集

进行偏好优化,首先需要有一个能体现用户偏好的数据集。在双项选择的设定下,相应的数据一般包含一个提示词(Prompt)和两个候选回答,其中一个被标记为选中(chosen),另一个被标记为淘汰(rejected)。模型需要学习选择正确的回答,而不是被淘汰的回答。下图展示了一个例子:

❔ 问题: 有多少个家庭?

  • ❌ 被淘汰的回答: 图片没有提供关于家庭的信息。
  • ✅ 选中的回答: 图片显示了一个工会组织的表格,包含18000个家庭。

尽管选中的回答也不是完全正确(应该是18000000个家庭),但比被淘汰的回答更好。

本文将使用openbmb/RLAIF-V-Dataset作为示例数据集,该数据集包含超过83000条标注数据。可以通过以下代码查看数据集:

from datasets import load_datasetdataset = load_dataset("openbmb/RLAIF-V-Dataset", split="train[:1%]")
sample = dataset[1]
sample["image"].show()
sample["question"]
'how many families?'
sample["rejected"]
'The image does not provide any information about families.'
sample["chosen"]
'The image shows a Union Organization table setup with 18,000 families.'

我们将要训练的 VLM 模型需要文本和图像同时作为输入,所以这里的第一步还是要对数据集格式进行改造。一条数据应该被结构化成能模拟人机对话的形式。用户提供一个提示语,其中包含一张图片和一个问题,然后模型需要能够给出一个回答。我们用以下代码实现格式转换:

from datasets import features
from transformers import AutoProcessorprocessor = AutoProcessor.from_pretrained("HuggingFaceM4/idefics2-8b", do_image_splitting=False)def format(example):# Prepare the input for the chat templateprompt = [{"role": "user","content": [{"type": "image"}, {"type": "text", "text": example["question"]}],},]chosen = [{"role": "assistant","content": [{"type": "text", "text": example["chosen"]}],},]rejected = [{"role": "assistant","content": [{"type": "text", "text": example["rejected"]}],},]# Apply the chat templateprompt = processor.apply_chat_template(prompt, tokenize=False)chosen = processor.apply_chat_template(chosen, tokenize=False)rejected = processor.apply_chat_template(rejected, tokenize=False)# Resize the image to ensure it fits within the maximum allowable# size of the processor to prevent OOM errors.max_size = processor.image_processor.size["longest_edge"]example["image"].thumbnail((max_size, max_size))return {"images": [example["image"]], "prompt": prompt, "chosen": chosen, "rejected": rejected}# Apply the formatting function to the dataset,
# remove columns to end up with only "images", "prompt", "chosen", "rejected" columns
dataset = dataset.map(format, remove_columns=dataset.column_names)# Make sure that the images are decoded, it prevents from storing bytes.
# More info here https://github.com/huggingface/blog/pull/2148#discussion_r1667400478
f = dataset.features
f["images"] = features.Sequence(features.Image(decode=True)) # to avoid bytes
dataset = dataset.cast(f)

完成了格式转换,我们来看看第一条数据:

>>> dataset[1]
{'images': [<PIL.JpegImagePlugin.JpegImageFile image mode=L size=980x812 at 0x154505570>],'prompt': 'User:<image>how many families?<end_of_utterance>\n','rejected': 'Assistant: The image does not provide any information about families.<end_of_utterance>\n','chosen': 'Assistant: The image shows a Union Organization table setup with 18,000 families.<end_of_utterance>\n'}

三、训练

3.1 训练需要多大的 GPU 显存?

以微调1B的模型为例子,假设模型的的每个参数用32bit存储,32bit=4byte。

每个参数通常以浮点数形式存储。FP32(32位浮点数)每个参数占用4字节的存储空间,而BF16(16位浮点数)每个参数占用2字节的存储空间。

需要用到GPU的部分:模型权重(需要加载进去)、梯度(更新参数)、优化器(状态量,SGD和Adam占用的显存空间不一样)、激活值等等

1 Byte = 1 \times 10^{-9} GB

  • 模型权重1B = 1b x 4 byte = 4GB;
  • 梯度的显存需求与模型权重相同 4GB;
  • 以Adam优化器(LLM用的多)为例,Adam需要维护模型的参数、每个参数的动量和平方梯度信息,因此占用的显存大约是模型权重的3倍 [一阶动量估计(类似动量)、二阶动量估计(平方梯度)];

注意,优化器都是用FP32进行存储的,因为大量的小值累加(sum、mean)操作,如果用FP16进行会损失精度,太小的值用FP16会表示为0。

  • 激活值(中间结果),反向传播和前向传播会用到,这边只是简单起见,bs=1,和模型参数一样是4GB,实际上这个计算推导很复杂,后面有机会再写~,同时Transformer中激活值和序列长度以平方次数增长;
  • 输入数据:跟Batch size、样本I大小有关系,就是B x I x 4 字节,这边暂时忽略;
参数来源计算公式显存需求
要训练的模型8 \times 10^9 \times 432 GB
参考模型(这个任务额外要的,防止模型发生偏移,和要训练的模型一样大)8 \times 10^9 \times 432 GB
梯度8 \times 10^9 \times 432 GB
优化器状态量3 \times 8 \times 10^9 \times 472 GB
合计168 GB

可以使用量化、LoRA 等技术来大幅度地减少显存需求,让训练可以进行。

3.2 使用 bfloat16 和 LoRA 后的显存需求

参数来源计算公式显存需求
要训练的模型8 \mathrm{G} \times 216 GB
参考模型8 \mathrm{G} \times 216 GB
梯度55 \mathrm{M} \times 20.1 GB
优化器状态量3 \times 55 \mathrm{M} \times 20.3 GB
合计32.4 GB

四、微调Llava 1.5和PaliGemma等模型

TRL的DPO实现已支持Idefics2、Llava 1.5和PaliGemma,同时TRL也在努力支持更多的模型。最简单的调用方法是使用TRL提供的示例脚本。例如,如果你想微调PaliGemma,可以使用以下命令:

accelerate launch examples/scripts/dpo_visual.py \--dataset_name HuggingFaceH4/rlaif-v_formatted \--model_name_or_path google/paligemma-3b-pt-224 \--per_device_train_batch_size 2 \--gradient_accumulation_steps 32 \--dataset_num_proc 32 \--output_dir dpo_paligemma_rlaif-v \--bf16 \--torch_dtype bfloat16 \--gradient_checkpointing \--use_peft \--lora_target_modules=all-linear

五、可视化结果

下表展示了一些可视化的结果:

ImageQuestionIdefics2Idefics2+DPO
Are there two ships in this image?YesNo
Is the ground uneven in this image?NoYes
Is there one shovel in this image?YesNo

六、参考链接

[1] https://huggingface.co/docs/peft/en/index

[2] https://cloud.google.com/vertex-ai/generative-ai/docs/model-garden/lora-qlora?hl=zh-cn

[3] https://huggingface.co/blog/zh/dpo_vlm

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.rhkb.cn/news/395807.html

如若内容造成侵权/违法违规/事实不符,请联系长河编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

【多线程-从零开始-捌】阻塞队列,消费者生产者模型

什么是阻塞队列 阻塞队里是在普通的队列&#xff08;先进先出队列&#xff09;基础上&#xff0c;做出了扩充 线程安全 标准库中原有的队列 Queue 和其子类&#xff0c;默认都是线程不安全的 具有阻塞特性 如果队列为空&#xff0c;进行出队列操作&#xff0c;此时就会出现阻…

C++ 重要特性探究

shared_from_this 使用分析 场景 类的成员函数需要获取指向自身的shared_ptr的时候类成员函数传递shared_ptr给其他函数或者对象的时候&#xff0c;目的是为了管理对象生命周期使用方法 首先类必须继承 std::enable_shared_from_this<T>必须使用 shared_from_this 获取指…

智慧交通:将物联网与人工智能完美融合

智慧交通是当今社会面临的一个重要挑战&#xff0c;也是人们生活质量提高的一个重要方面。通过将物联网技术与人工智能相结合&#xff0c;我们能够实现智慧交通系统的全面升级和优化&#xff0c;为人们带来更加便捷、高效和安全的出行体验。 在智慧交通领域&#xff0c;物联网…

电脑图片损坏打不开怎么办?能修复吗?

照片和视频是记录和保存现实生活中的事件的最好方式。由于手机储存空间有限&#xff0c;一般我们会把有纪念意义的照片放到电脑上进行保存&#xff0c;但有时难免会遇到照片被损坏打不开的情况&#xff0c;一旦遇到这种情况&#xff0c;先不要急&#xff0c;也不要因为照片打不…

【智能控制】第7章 神经网络理论基础,神经网络的分类,原理,发展,神经网络学习算法(北京航天航空大学)

目录 第7章 神经网络理论基础 1. 神经网络的发展 2. 神经网络原理 3. 神经网络的分类 (1) 前向网络 (2) 反馈网络 (3) 自组织网络 4. 神经网络学习算法 (1) 智能Hebb学习规则 (2) Delta&#xff08;δ&#xff09;学习规则 5. 神经网络的特征及…

【Mind+】 掌控板入门教程09 魔法之光

光是地球生命的来源&#xff0c;是人类生活的依据&#xff0c;更是人类认识外部世界的工具。在科技发达的今天&#xff0c;我们可以通过传感器来检测光&#xff0c;利用光帮助我们更好的生活。 今天就让我们一起通过几个小项目来感受光的魔法吧。 项目示例 掌控板…

经验是负债,学习是资产

经验是负债&#xff0c;学习是资产 经验是负债&#xff0c;学习是资产。这是李嘉诚先生的一句名言。他一语道出了学习在企业发展中的推动作用。 企业家经营的目的&#xff0c;无非就是将利润最大化。企业能够产生利润&#xff0c;靠的是提升自身业绩、降低运营成本&#xff0c;…

使用 Java Swing 创建一个最大公约数计算器 GUI 应用

使用Java语言,设计一个程序,实现求取两个正整数的最大公约数。 比较基础的一个Java小程序。 1、效果展示 2、程序代码 package demo; import javax.swing.*; import java.awt.*;

Kafka基本讲解

Kafka基本讲解 一&#xff1a;Kafka介绍 Kafka是分布式消息队列&#xff0c;主要设计用于高吞吐量的数据处理和消息传输&#xff0c;适用于日志处理、实时数据管道等场景。Kafka作为实时数仓架构的核心组件&#xff0c;用于收集、缓存和分发实时数据流&#xff0c;支持复杂的…

【博客搭建 第二篇章】项目中怎么引入其他的 icon

一、注册账号并将图标添加到自己的项目中 1、网站地址&#xff1a;https://www.iconfont.cn/ 2、注册 iconfont 账号 3、登录 iconfont 网站中 4、添加图标到购物车中 5、添加图标到项目中 6、生成在线连接 7、复制连接 二、项目中配置连接地址 找到项目中的 them…

R语言医疗数据分析笔记

分组因子又是什么意思&#xff0c;分组因子和数组的区别是什么 举个实际的例子 分组因子 分组因子是分类变量&#xff0c;用于将数据分成不同组以便于比较或分析。例如&#xff0c;在一项研究中&#xff0c;研究对象的性别&#xff08;男性和女性&#xff09;可以视为一个分组…

OBC充电机电力系统的安全保障

OBC&#xff08;On-Board Charger&#xff09;充电机是电动汽车的关键部件&#xff0c;它负责将外部交流电转换为直流电&#xff0c;为电动汽车的动力电池充电。因此&#xff0c;OBC充电机的电力系统安全保障至关重要。 首先&#xff0c;OBC充电机需要有良好的电气隔离和保护功…

【mysql 第三篇章】一条 update语句是怎么持久化到磁盘上的?

首先看一下这个 SQL 语句你会不会写? 下面是说明执行这个 SQL 语句&#xff0c;数据库底层做了什么操作。 update users set namexxx where id10;在引擎要执行更新语句的时候&#xff0c;比如更新 id10 这行数据时&#xff0c;他会先查看数据在缓冲池中是否存在&#xff0c;如…

C语言指针详解-包过系列(二)目录版

C语言指针详解-包过系列&#xff08;二&#xff09;目录版 1、数组名的深入理解1.1、数组名的本质1.2、数组名本质的两个例外1.2.1、sizeof&#xff08;数组名&#xff09;1.2.2、&数组名 2、使用指针访问数组3、一维数组传参本质4、二级指针4.1、二级指针介绍4.2、二级指针…

8.9 C++

1.思维导图 2. 搭建一个货币的场景&#xff0c;创建一个名为 RMB 的类&#xff0c;该类具有整型私有成员变量 yuan&#xff08;元&#xff09;、jiao&#xff08;角&#xff09;和 fen&#xff08;分&#xff09;&#xff0c;并且具有以下功能&#xff1a; (1)重载算术运算符…

PCL 曲线4点细分算法

文章目录 一、简介二、实现代码三、实现效果参考资料一、简介 四点细分算法(Four-Point Subdivision Scheme)是一种用于生成平滑曲线的细分算法。与 Chaikin 逼近型细分算法不同,四点细分算法通过插入新的控制点来细化曲线,并生成一条逐步逼近的平滑曲线。该算法通常用于生…

高效管理:如何快速查询并跟踪批量快递物流信息

在现代快节奏的生活中&#xff0c;我们经常需要处理大量的快递单号&#xff0c;以跟踪货物的物流轨迹。无论是电商卖家、物流公司还是个人用户&#xff0c;都希望能够快速、准确地获取到快递的实时信息。为了解决这个问题&#xff0c;我们可以借助一款名为“固乔快递查询助手”…

八、MyBatis

一、MyBatis介绍 MyBatis 是持久层框架&#xff0c;它支持自定义 SQL、存储过程以及⾼级映射。MyBatis 去除了几乎所有的 JDBC 代码以及设置参数和获取结果集的工作。MyBatis 可以通过简单的 XML 或注解来配置 和映射原始类型、接口和 Java POJO&#xff08;Plain Old Java Obj…

最新版的AutoGPT,我搭建好了

最近AutoGPT不是更新了嘛 安装 我按照官方的教程 在本地搭建好了 改动 可见的改动&#xff0c;主要是把原来的纯命令行改成前后端的形式 看下前端界面 界面比较简单&#xff0c;主要分3个大块 监控 第一个是监控 主要是看你在 build 里构建的Agents的运行情况 build 第一个是Ag…

前端项目中的Server-sent Events(SSE)项目实践及其与websocket的区别

前端项目中的Server-sent Events(SSE)项目实践 前言 在前端开发中&#xff0c;实时数据更新是提升用户体验的重要因素之一。Server-SentEvents(SSE)是一种高效的技术&#xff0c;允许服务器通过单向连接将实时数据推送到客户端。下面将从SSE的基本改变&#xff0c;使用场景展…