微调 Florence-2 - 微软的尖端视觉语言模型

Florence-2 是微软于 2024 年 6 月发布的一个基础视觉语言模型。该模型极具吸引力,因为它尺寸很小 (0.2B 及 0.7B) 且在各种计算机视觉和视觉语言任务上表现出色。

Florence 开箱即用支持多种类型的任务,包括: 看图说话、目标检测、OCR 等等。虽然覆盖面很广,但仍有可能你的任务或领域不在此列,也有可能你希望针对自己的任务更好地控制模型输出。此时,你就需要微调了!

本文,我们展示了一个在 DocVQA 上微调 Florence 的示例。尽管原文宣称 Florence 2 支持视觉问答 (VQA) 任务,但最终发布的模型并未包含 VQA 功能。因此,我们正好拿这个任务练练手,看看我们能做点什么!

预训练细节与模型架构

98ccecd0587f449e3674afd0908280f1.png
Florence-2 架构

无论执行什么样的计算机视觉任务,Florence-2 都会将其建模为序列到序列的任务。Florence-2 以图像和文本作为输入,并输出文本。模型结构比较简单: 用 DaViT 视觉编码器将图像转换为视觉嵌入,并用 BERT 将文本提示转换为文本和位置嵌入; 然后,生成的嵌入由标准编码器 - 解码器 transformer 架构进行处理,最终生成文本和位置词元。Florence-2 的优势并非源自其架构,而是源自海量的预训练数据集。作者指出,市面上领先的计算机视觉数据集通常所含信息有限 - WIT 仅有图文对,SA-1B仅有图像及相关分割掩码。因此,他们决定构建一个新的 FLD-5B 数据集,其中的每个图像都包含最广泛的信息 - 目标框、掩码、描述文本及标签。在创建数据集时,很大程度采用了自动化的过程,作者使用现成的专门任务模型,并用一组启发式规则及质检过程来清理所获得的结果。最终生成的用于预训练 Florence-2 模型的新数据集中包含了 1.26 亿张图像、超过 50 亿个标注。

SA-1Bhttps://ai.meta.com/datasets/segment-anything/

VQA 上的原始性能

我们尝试了各种方法来微调模型以使其适配 VQA (视觉问答) 任务的响应方式。迄今为止,我们发现最有效方法将其建模为图像区域描述任务,尽管其并不完全等同于 VQA 任务。看图说话任务虽然可以输出图像的描述性信息,但其不允许直接输入问题。

我们还测试了几个“不支持”的提示,例如 “<VQA>”、“<vqa>” 以及 “<Visual question answering>”。不幸的是,这些尝试的产生的结果都不可用。

微调后在 DocVQA 上的性能

我们使用 DocVQA 数据集的标准指标Levenshtein 相似度来测量性能。微调前,模型在验证集上的输出与标注的相似度为 0,因为模型输出与标注差异不小。对训练集进行 7 个 epoch 的微调后,验证集上的相似度得分提高到了 57.0。

Levenshtein 相似度https://en.wikipedia.org/wiki/Levenshtein_distance

我们创建了一个🤗 空间以演示微调后的模型。虽然该模型在 DocVQA 上表现良好,但在一般文档理解方面还有改进的空间。但我们仍然认为,它成功地完成了任务,展示了 Florence-2 对下游任务进行微调的潜力。我们建议大家使用The Cauldron数据集对 Florence-2 进行微调,大家可以在我们的 GitHub 页面上找到必要的代码。

  • 🤗 空间https://hf.co/spaces/andito/Florence-2-DocVQA

  • The Cauldronhttps://hf.co/datasets/HuggingFaceM4/the_cauldron

  • 我们的 GitHub 页面https://github.com/andimarafioti/florence2-finetuning

下图给出了微调前后的推理结果对比。你还可以至此处亲自试用模型。

模型试用地址https://hf.co/spaces/andito/Florence-2-DocVQA

266edf3c8245404862cb97421a8aa420.png
微调前后的结果

微调细节

由原文我们可以知道,基础模型在预训练时使用的 batch size 为 2048,大模型在预训练时使用的 batch size 为 3072。另外原文还说: 与冻结图像编码器相比,使用未冻结的图像编码器进行微调能带来性能改进。

我们在低资源的情况下进行了多组实验,以探索模型如何在更受限的条件下进行微调。我们冻结了视觉编码器,并在Colab的分别使用单张 A100 GPU (batch size 6) 、单张 T4 (batch size 1) 顺利完成微调。

Colab 链接https://colab.research.google.com/drive/1hKDrJ5AH_o7I95PtZ9__VlCTNAo1Gjpf?usp=sharing

与此同时,我们还对更多资源的情况进行了实验,以 batch size 64 对整个模型进行了微调。在配备 8 张 H100 GPU 的集群上该训练过程花费了 70 分钟。你可以在这里找到我们训得的模型。

模型地址https://hf.co/HuggingFaceM4/Florence-2-DocVQA

我们都发现 1e-6 的小学习率适合上述所有训练情形。如果学习率变大,模型将很快过拟合。

遛代码

如果你想复现我们的结果,可以在此处找到我们的 Colab 微调笔记本。下面,我们遛一遍在DocVQA上微调Florence-2-base-ft模型。

  • Colab 地址https://colab.research.google.com/drive/1hKDrJ5AH_o7I95PtZ9__VlCTNAo1Gjpf?usp=sharing

  • DocVQAhttps://hf.co/datasets/HuggingFaceM4/DocumentVQA

  • Florence-2-base-fthttps://hf.co/microsoft/Florence-2-base-ft

我们从安装依赖项开始。

!pip install -q datasets flash_attn timm einops

接着,从 Hugging Face Hub 加载 DocVQA 数据集。

import torch
from datasets import load_datasetdata = load_dataset("HuggingFaceM4/DocumentVQA")

我们可以使用 transformers 库中的 AutoModelForCausalLMAutoProcessor 类来加载模型和处理器,并设 trust_remote_code=True ,因为该模型尚未原生集成到 transformers 中,因此需要使用自定义代码。我们还会冻结视觉编码器,以降低微调成本。

from transformers import AutoModelForCausalLM, AutoProcessor
import torchdevice = torch.device("cuda" if torch.cuda.is_available() else "cpu")model = AutoModelForCausalLM.from_pretrained("microsoft/Florence-2-base-ft",trust_remote_code=True,revision='refs/pr/6'
).to(device)
processor = AutoProcessor.from_pretrained("microsoft/Florence-2-base-ft",trust_remote_code=True, revision='refs/pr/6')for param in model.vision_tower.parameters():param.is_trainable = False

现在开始微调模型!我们构建一个训练 PyTorch 数据集,并为数据集中的每个问题添加 <DocVQA> 前缀。

import torch from torch.utils.data import Datasetclass DocVQADataset(Dataset):def __init__(self, data):self.data = datadef __len__(self):return len(self.data)def __getitem__(self, idx):example = self.data[idx]question = "<DocVQA>" + example['question']first_answer = example['answers'][0]image = example['image'].convert("RGB")return question, first_answer, image

接着,构建数据整理器,从数据集样本构建训练 batch,以用于训练。在 40GB 内存的 A100 中,batch size 可设至 6。如果你在 T4 上进行训练,batch size 就只能是 1。

import os
from torch.utils.data import DataLoader
from tqdm import tqdm
from transformers import AdamW, get_schedulerdef collate_fn(batch):questions, answers, images = zip(*batch)inputs = processor(text=list(questions), images=list(images), return_tensors="pt", padding=True).to(device)return inputs, answerstrain_dataset = DocVQADataset(data['train'])
val_dataset = DocVQADataset(data['validation'])
batch_size = 6
num_workers = 0train_loader = DataLoader(train_dataset, batch_size=batch_size,collate_fn=collate_fn, num_workers=num_workers, shuffle=True)
val_loader = DataLoader(val_dataset, batch_size=batch_size,collate_fn=collate_fn, num_workers=num_workers)

开始训练模型:

epochs = 7
optimizer = AdamW(model.parameters(), lr=1e-6)
num_training_steps = epochs * len(train_loader)lr_scheduler = get_scheduler(name="linear", optimizer=optimizer,num_warmup_steps=0, num_training_steps=num_training_steps,)for epoch in range(epochs):model.train()train_loss = 0i = -1for inputs, answers in tqdm(train_loader, desc=f"Training Epoch {epoch + 1}/{epochs}"):i += 1input_ids = inputs["input_ids"]pixel_values = inputs["pixel_values"]labels = processor.tokenizer(text=answers, return_tensors="pt", padding=True, return_token_type_ids=False).input_ids.to(device)outputs = model(input_ids=input_ids, pixel_values=pixel_values, labels=labels)loss = outputs.lossloss.backward()optimizer.step()lr_scheduler.step()optimizer.zero_grad()train_loss += loss.item()avg_train_loss = train_loss / len(train_loader)print(f"Average Training Loss: {avg_train_loss}")model.eval()val_loss = 0with torch.no_grad():for batch in tqdm(val_loader, desc=f"Validation Epoch {epoch + 1}/{epochs}"):inputs, answers = batchinput_ids = inputs["input_ids"]pixel_values = inputs["pixel_values"]labels = processor.tokenizer(text=answers, return_tensors="pt", padding=True, return_token_type_ids=False).input_ids.to(device)outputs = model(input_ids=input_ids, pixel_values=pixel_values, labels=labels)loss = outputs.lossval_loss += loss.item()print(val_loss / len(val_loader))

你可以分别对模型和处理器调用 save_pretrained() 以保存它们。微调后的模型在此处,你还可以在此处找到其演示。

  • 模型链接https://hf.co/HuggingFaceM4/Florence-2-DocVQA

  • 示例地址https://hf.co/spaces/andito/Florence-2-DocVQA

    2cc9b35547030333924ea6481df35f9a.png  

    演示示例

总结

本文,我们展示了如何有效地针对自定义数据集微调 Florence-2,以在短时间内在全新任务上取得令人眼前一亮的性能。对于那些希望在设备上或在生产环境中经济高效地部署小模型的人来说,该做法特别有价值。我们鼓励开源社区利用这个微调教程,探索 Florence-2 在各种新任务中的巨大潜力!我们迫不及待地想在 🤗 Hub 上看到你的模型!

有用资源

  • 视觉语言模型详解https://hf.co/blog/zh/vlms

  • 微调 Colabhttps://colab.research.google.com/drive/1hKDrJ5AH_o7I95PtZ9__VlCTNAo1Gjpf?usp=sharing

  • 微调 Github 代码库https://github.com/andimarafioti/florence2-finetuning

  • Florence-2 推理 Notebookhttps://hf.co/microsoft/Florence-2-large/blob/main/sample_inference.ipynb

  • Florence-2 DocVQA 演示https://hf.co/spaces/andito/Florence-2-DocVQA

  • Florence-2 演示https://hf.co/spaces/gokaygo

感谢 Pedro Cuenca 对本文的审阅。


英文原文: https://hf.co/blog/finetune-florence2

原文作者: Andres Marafioti,Merve Noyan,Piotr Skalski

译者: Matrix Yao (姚伟峰),英特尔深度学习工程师,工作方向为 transformer-family 模型在各模态数据上的应用及大规模模型的训练推理。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.rhkb.cn/news/378219.html

如若内容造成侵权/违法违规/事实不符,请联系长河编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

气膜的制作工艺—轻空间

气膜结构是现代建筑领域中的一项创新技术&#xff0c;以其独特的优点在多种应用场景中得到广泛推广和应用。轻空间将详细介绍气膜的制作工艺&#xff0c;帮助大家更好地理解气膜结构的制造过程及其技术优势。 1. 选材与设计 气膜的制作首先从材料的选择和设计开始。气膜材料主要…

SAP ABAP性能优化

1.前言 ABAP作为SAP的专用的开发语言&#xff0c;衡量其性能的指标主要有以下两个方面&#xff1a; 响应时间&#xff1a;对于某项特定的业务请求&#xff0c;系统在收到请求后需要多久返回结果 吞吐量&#xff1a;在给定的时间能&#xff0c;系统能够处理的数据量 2. ABAP语…

看番工具 -- oneAnime v1.2.5绿色版

软件简介 OneAnime是一款专为动漫爱好者设计的应用程序&#xff0c;它提供了一个庞大的动漫资源库&#xff0c;用户可以在这里找到各种类型的动漫&#xff0c;包括热门的、经典的、新番的等等。OneAnime的界面设计简洁明了&#xff0c;操作方便&#xff0c;用户可以轻松地搜索…

产品经理-一份标准需求文档的8个模块(14)

一份标准优秀的产品需求文档包括&#xff1a; ❑ 封面&#xff1b; ❑ 文档修订记录表&#xff1b; ❑ 目录&#xff1b; ❑ 引言&#xff1b; ❑ 产品概述&#xff1a;产品结构图 ❑ 详细需求说明&#xff1a;产品逻辑图、功能与特性简述列表、交互/视觉设计、需求详细描述&am…

MySQL执行状态查看与分析

当mysql出现性能问题时&#xff0c;一般会查看mysql的执行状态&#xff0c;执行命令&#xff1a; show processlist 各列的含义 列名含义id一个标识&#xff0c;你要kill一个语句的时候使用&#xff0c;例如 mysql> kill 207user显示当前用户&#xff0c;如果不是root&…

MySQL中IF()、IFNULL()、NULLIF()、ISNULL()函数的奇妙之旅

在MySQL这片浩瀚的数据海洋中&#xff0c;函数如同航海家的罗盘&#xff0c;指引着数据处理的航向。今天&#xff0c;就让我们踏上一场探索之旅&#xff0c;深入了解MySQL中几位不可或缺的“航海家”——IF()、IFNULL()、NULLIF()、ISNULL()函数&#xff0c;看它们如何在数据处…

Redis 数据类型

Redis 数据类型 文章目录 Redis 数据类型1. String类型2. key的层级结构3. Hash类型4. List类型5. Set类型6. SortedSet类型 1. String类型 String类型是redis中最常用的存储类型&#xff0c;即字符串类型&#xff0c;同时根据字符串的格式不同&#xff0c;可以将value分为三类…

shell脚本-linux如何在脚本中远程到一台linux机器并执行命令

需求&#xff1a;我们需要从11.0.1.17远程到11.0.1.16上执行命令 实现&#xff1a; 1.让11.0.1.17 可以免密登录到11.0.1.16 [rootlocalhost ~]# ssh-keygen Generating public/private rsa key pair. Enter file in which to save the key (/root/.ssh/id_rsa): Created d…

前端基础之JavaScript学习——变量、数据类型、类型转换

大家好&#xff0c;我是来自CSDN的博主PleaSure乐事&#xff0c;今天我们开始有关JS的学习&#xff0c;希望有所帮助并巩固有关前端的知识。 我使用的编译器为vscode&#xff0c;浏览器使用为谷歌浏览器&#xff0c;使用webstorm或其他环境效果几乎一样&#xff0c;使用系统自…

数电基础 - 硬件描述语言

目录 一. 简介 二. Verilog简介和基本程序结构 三. 应用场景 四. Verilog的学习方法 五.调式方法 一. 简介 硬件描述语言&#xff08;Hardware Description Language&#xff0c;HDL&#xff09;是用于描述数字电路和系统的形式化语言。 常见的硬件描述语言包括 VHDL&…

zephyr设置BLE广播数据实例

目录 实例1&#xff1a;静态开启广播数据实例2&#xff1a;动态更改广播数据实例3&#xff1a;创建可连接的广播 实例1&#xff1a;静态开启广播数据 新建一个hello world的工程模板。 在prj.conf中开启蓝牙 CONFIG_BTy这个宏&#xff0c;默认会开启广播支持 ( BT_BROADCAS…

组网升级,双击热备和宽带管理

拓扑 要求&#xff1a; 要求12&#xff1a; 要求13&#xff1a; 要求14&#xff1a; 要求15&#xff1a; 要求16&#xff1a;

解决 Vscode不支持c++11的语法

问题&#xff1a; 解决方案&#xff1a; 1、按 CtrlShiftP 调出命令面板&#xff0c;输入 C/C: Edit Configurations (UI) 并选择它。这将打开 C/C 配置界面 2、打开 c_cpp_properties.json 文件 3、编辑 c_cpp_properties.json 4、保存 c_cpp_properties.json 文件。 关闭并…

使用JS和CSS制作的小案例(day二)

一、写在开头 本项目是从github上摘取&#xff0c;自己练习使用后分享&#xff0c;方便登录github的小伙伴可以看本篇文章 50项目50天​编辑https://github.com/bradtraversy/50projects50dayshttps://github.com/bradtraversy/50projects50days有兴趣的小伙伴可以自己去gith…

SpringBoot详细解析

1.什么是springboot springboot也是spring公司开发的一款框架。为了简化spring项目的初始化搭建的。那么spring对应springboot有什么缺点呢&#xff1f; spring项目搭建的缺点: 配置麻烦依赖tomcat启动慢 2.springboot的特点 自动配置 Spring Boot的自动配置是一个运行时&…

JVM垃圾回收-----垃圾分类

一、垃圾分类定义 垃圾分类是JVM垃圾分类中的第一步&#xff0c;这一步将堆中的对象分为存活对象和垃圾对象两类。 在垃圾分类阶段&#xff0c;JVM会从一组根对象开始&#xff0c;通过对象之间的引用关系&#xff0c;遍历所有的对象&#xff0c;并将所有存活的对象进行标记。…

flutter 手写 TabBar

前言&#xff1a; 这几天在使用 flutter TabBar 的时候 我们的设计给我提了一个需求&#xff1a; 如下 Tabbar 第一个元素 左对齐&#xff0c;试了下TabBar 的配置&#xff0c;无法实现这个需求&#xff0c;他的 配置是针对所有元素的。而且 这个 TabBar 下面的 滑块在移动的时…

idea中使用maven

默认情况下&#xff0c;idea会自动下载并安装maven&#xff0c;这不便于我们管理。 最好是自行下载maven&#xff0c;然后在idea中指定maven的文件夹路径

解析 Mira :基于 Web3,让先进的 AI 技术易于访问和使用

“Mira 平台正在以 Web3 的方式解决当前 AI 开发面临的复杂性问题&#xff0c;同时保护 AI 贡献者的权益&#xff0c;让他们可以自主拥有并货币化自己的模型、数据和应用&#xff0c;以使先进的 AI 技术更加易于访问和使用。” AI 代表着一种先进的生产力&#xff0c;它通过深…

【UE5.1】NPC人工智能——02 NPC移动到指定位置

效果 步骤 1. 新建一个蓝图&#xff0c;父类选择“AI控制器” 这里命名为“BP_NPC_AIController”&#xff0c;表示专门用于控制NPC的AI控制器 2. 找到我们之前创建的所有NPC的父类“BP_NPC” 打开“BP_NPC”&#xff0c;在类默认值中&#xff0c;将“AI控制器类”一项设置为“…