MM-LLM:使用Llava类构建图文多模态大模型实践

在这里插入图片描述
多模态大模型的结构如上,llava是用两层MLP作为连接器。该模式也是后续很多工作的基础。

本文主要参考了https://github.com/yuanzhoulvpi2017/zero_nlp/tree/main/train_llava的工作,最初是在b站看到的,讲解的很细致。

基础模型

大语言模型:Qwen2-1.5B-Instruct
视觉模型:clip-vit-large-patch14-336
连接器:MLP
框架:llava模型

1.LLM的处理

下载模型权重到本地后,修改Qwen2-1.5B-Instruct/tokenizer_config.json的added_tokens_decoder的值,添加

"151646": {"content": "<image>","lstrip": false,"normalized": false,"rstrip": false,"single_word": false,"special": true}

additional_special_tokens添加 "<image>"

2.初始化llava模型

# 模型权重路径
modify_qwen_tokenizer_dir = "autodl-tmp/Qwen2-1.5B-Instruct"
clip_model_name_or_path = ("autodl-tmp/clip-vit-large-patch14-336"
)# 加载qwen2
qwen_tokenizer = AutoTokenizer.from_pretrained(modify_qwen_tokenizer_dir)
qwen_model = AutoModelForCausalLM.from_pretrained(modify_qwen_tokenizer_dir, device_map='cuda:0', torch_dtype=torch.bfloat16)# 加载clip
clip_model = AutoModel.from_pretrained(clip_model_name_or_path, device_map="cuda:0")
processor = AutoProcessor.from_pretrained(clip_model_name_or_path)# 将clip模型和llm_model模型的config拿出来,初始化一个llava model
# Initializing a CLIP-vision config
vision_config = clip_model.vision_model.config
# Initializing a Llama config
text_config = qwen_model.config
# Initializing a Llava llava-1.5-7b style configuration
configuration = LlavaConfig(vision_config, text_config)
# Initializing a model from the llava-1.5-7b style configuration
model = LlavaForConditionalGeneration(configuration)

输出:

LlavaForConditionalGeneration((vision_tower): CLIPVisionModel((vision_model): CLIPVisionTransformer((embeddings): CLIPVisionEmbeddings((patch_embedding): Conv2d(3, 1024, kernel_size=(14, 14), stride=(14, 14), bias=False)(position_embedding): Embedding(577, 1024))(pre_layrnorm): LayerNorm((1024,), eps=1e-05, elementwise_affine=True)(encoder): CLIPEncoder((layers): ModuleList((0-23): 24 x CLIPEncoderLayer((self_attn): CLIPAttention((k_proj): Linear(in_features=1024, out_features=1024, bias=True)(v_proj): Linear(in_features=1024, out_features=1024, bias=True)(q_proj): Linear(in_features=1024, out_features=1024, bias=True)(out_proj): Linear(in_features=1024, out_features=1024, bias=True))(layer_norm1): LayerNorm((1024,), eps=1e-05, elementwise_affine=True)(mlp): CLIPMLP((activation_fn): QuickGELUActivation()(fc1): Linear(in_features=1024, out_features=4096, bias=True)(fc2): Linear(in_features=4096, out_features=1024, bias=True))(layer_norm2): LayerNorm((1024,), eps=1e-05, elementwise_affine=True))))(post_layernorm): LayerNorm((1024,), eps=1e-05, elementwise_affine=True)))(multi_modal_projector): LlavaMultiModalProjector((linear_1): Linear(in_features=1024, out_features=1536, bias=True)(act): GELUActivation()(linear_2): Linear(in_features=1536, out_features=1536, bias=True))(language_model): Qwen2ForCausalLM((model): Qwen2Model((embed_tokens): Embedding(151936, 1536)(layers): ModuleList((0-27): 28 x Qwen2DecoderLayer((self_attn): Qwen2SdpaAttention((q_proj): Linear(in_features=1536, out_features=1536, bias=True)(k_proj): Linear(in_features=1536, out_features=256, bias=True)(v_proj): Linear(in_features=1536, out_features=256, bias=True)(o_proj): Linear(in_features=1536, out_features=1536, bias=False)(rotary_emb): Qwen2RotaryEmbedding())(mlp): Qwen2MLP((gate_proj): Linear(in_features=1536, out_features=8960, bias=False)(up_proj): Linear(in_features=1536, out_features=8960, bias=False)(down_proj): Linear(in_features=8960, out_features=1536, bias=False)(act_fn): SiLU())(input_layernorm): Qwen2RMSNorm()(post_attention_layernorm): Qwen2RMSNorm()))(norm): Qwen2RMSNorm())(lm_head): Linear(in_features=1536, out_features=151936, bias=False))
)

这样得到了llava模型的结构,但是旧有的权重参数还没迁移过来,要将其移动到新model里。

# 权重复制
model.vision_tower.vision_model = clip_model.vision_model
model.language_model = qwen_model

然后保存到本地,注意要将autodl-tmp/processor的preprocessor_config.json复制到autodl-tmp/vlm_1

# 保存模型
model.save_pretrained("autodl-tmp/vlm_1")
qwen_tokenizer.save_pretrained("autodl-tmp/vlm_1")
processor.save_pretrained("autodl-tmp/processor")

3.数据集加载代码

采用该数据集:https://huggingface.co/datasets/OpenGVLab/ShareGPT-4o

主要代码:

class LlavaDataset(Dataset):def __init__(self, dataset_dir: str) -> None:super().__init__()self.chat_data, self.image_dir = self.build_dataset(dataset_dir)def build_dataset(self, data_dir: str) -> Tuple[List[Dict], Path]:# 得到对话文件和图像文件的路径data_dir = Path(data_dir) # 父文件夹路径chat_file = data_dir.joinpath("final_data.jsonl") # 对话文件image_dir = data_dir.joinpath("image") # 图像文件夹# 读取为记录,转为dictchat_data = pd.read_json(chat_file, lines=True).to_dict(orient="records")return chat_data, image_dirdef __len__(self):return len(self.chat_data)def __getitem__(self, index) -> Tuple[str, str, Path]:# 根据索引定位到记录cur_data = self.chat_data[index] # 定位conversations = cur_data.get("conversations") # 字典格式获取到对话记录human_input = conversations[0].get("value") # 查询chatbot_output = conversations[1].get("value") # 回复image_path = self.image_dir.joinpath(cur_data.get("image")) # 图片的路径,由图片文件夹+图片名构成return human_input, chatbot_output, image_path

4.训练

使用deepseed训练,主要代码

def train():parser = transformers.HfArgumentParser((ModelArguments, DataArguments, TrainingArguments))model_args, data_args, training_args = parser.parse_args_into_dataclasses()model, processor = load_model_processor(model_args)data_collator = TrainLLavaModelCollator(processor, -100)train_dataset = load_dataset(data_args)trainer = Trainer(model=model,args=training_args,train_dataset=train_dataset,eval_dataset=None,data_collator=data_collator,)trainer.train()trainer.save_state()trainer.save_model(output_dir=training_args.output_dir)

5.推理

没有训练的模型进行推理的结果:

很抱歉,我无法看到或描述图片,因为我是一个文本生成模型,无法处理图像。如果您需要帮助,可以提供文字描述,我会尽力帮助您。

训练后的模型推理:

The image depicts a scene of a person sitting on a chair with their
legs crossed. The person is wearing a white shirt and dark blue jeans.
The person’s hair is styled in a messy, tousled manner, which adds to
the casual and relaxed atmosphere of the image. The person’s eyes are
closed, and they appear to be in a state of deep thought or
contemplation.

In the background, there is a small, white, rectangular object that
appears to be a piece of paper or a piece of writing. The object is
positioned in a manner that suggests it might be part of a document or
a note. The background is a light beige color, which contrasts with
the person’s clothing and the white object.

The chair is a wooden chair with a simple design, featuring a single
armrest and a backrest. The chair is positioned on a dark wooden
floor, which adds to the overall casual and comfortable feel of the
scene. The floor is also light beige, which complements the background
and the person’s clothing.

The lighting in the image is soft and diffused, giving the scene a
warm and inviting atmosphere. The person’s posture suggests they are
in a relaxed position, possibly after a long day or a moment of
reflection.

In summary, the image captures a person sitting on a chair with their
legs crossed, wearing casual clothing, and in a relaxed position. The
background includes a small white object, and the lighting is soft and
diffused, creating a warm and inviting atmosphere.

我仅仅训练了三轮,使用了不到300条数据。虽然结果不是很好,但是可以看出来是有成效的。
在这里插入图片描述

在我查找的多模态大模型实现中性价比是最高的,不用重写LLM的forward函数什么的。

相关代码放在https://github.com/stay-leave/enhance_llm。

参考:
https://github.com/yuanzhoulvpi2017/zero_nlp/tree/main/train_llava
https://github.com/OpenGVLab/InternVL/blob/main/internvl_chat
https://github.com/AviSoori1x/seemore
https://github.com/alexander-moore/vlm
https://github.com/WatchTower-Liu/VLM-learning

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.rhkb.cn/news/366709.html

如若内容造成侵权/违法违规/事实不符,请联系长河编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

Xilinx FPGA:vivado利用单端RAM/串口传输数据实现自定义私有协议

一、项目要求 实现自定义私有协议&#xff0c;如&#xff1a;pc端产生数据&#xff1a;02 56 38 &#xff0c;“02”代表要发送数据的个数&#xff0c;“56”“38”需要写进RAM中。当按键信号到来时&#xff0c;将“56”“38”读出返回给PC端。 二、信号流向图 三、状态…

真的假不了,假的真不了

大家好&#xff0c;我是瑶琴呀&#xff0c;拥有一头黑长直秀发的女程序员。 最近&#xff0c;17岁的中专生姜萍参加阿里巴巴 2024 年的全球数学竞赛&#xff0c;取得了 12 名的好成绩&#xff0c;一时间在网上沸腾不止。 从最开始的“数学天才”&#xff0c;到被质疑&#xff…

Linux CentOS 环境 MySQL 主从复制集群搭建

环境说明 MySQL版本8.4.0 操作系统 Linux CentOS 7.9 官网文档 https://dev.mysql.com/doc/refman/8.4/en/replication-configuration.html 以下代码片段中带分号都是在MySQL命令行( mysql -uroot -p)中执行 1. 首先在两个节点上安装数据库 参考 Linux CentOS安装MySQL8.0 …

C++编程(五)单例模式 友元

文章目录 一、单例模式&#xff08;一&#xff09;概念&#xff08;二&#xff09;实现方式1. 饿汉式2. 懒汉式 二、友元&#xff08;一&#xff09;概念&#xff08;二&#xff09;友元函数1.概念2.语法格式3. 使用示例访问静态成员变量访问非静态成员变量 &#xff08;三&…

Victor CMS v1.0 SQL 注入漏洞(CVE-2022-28060)

前言 CVE-2022-28060 是 Victor CMS v1.0 中的一个SQL注入漏洞。该漏洞存在于 /includes/login.php 文件中的 user_name 参数。攻击者可以通过发送特制的 SQL 语句&#xff0c;利用这个漏洞执行未授权的数据库操作&#xff0c;从而访问或修改数据库中的敏感信息。 漏洞详细信…

78.Vue 3 重用性模态框组件

模态框是大多数 Web 应用程序中的基本构建块。虽然最初实现起来可能看起来有点棘手&#xff0c;但实际上&#xff0c;使用 Vue 和一些 Flexbox 技巧&#xff0c;这不仅可行&#xff0c;而且非常简单。 让我们一起实现一个基础的模态框组件。 架构如下&#xff1a; AppModal.vue…

Ubuntu22 更新内核后终端输入卡顿,最简单的解决方案

在系统升级后相信很多人都遇到了这个问题&#xff0c;系统终端输入卡顿&#xff0c;但是ssh远程进来不卡&#xff0c;使用第三方终端也不卡,…&#xff0c;今天终于忍不了&#xff0c;解决了 现象&#xff1a; 更新Nvidia驱动后,内核进行了自动编译升级。 之后的一段时间使用…

用MySQL+node+vue做一个学生信息管理系统(五):学生信息增删改的实现

先实现增加信息&#xff1a; post参数的获取&#xff1a;express中接受post请求参数需要借助第三方包 body-parser 下载npm install body-parser //引入body-parser模块 const bodyParser require(body-parser); //拦截所有请求,配置body-parser模块 //extended:false 方法…

Linux多线程【线程互斥】

文章目录 Linux线程互斥进程线程间的互斥相关背景概念互斥量mutex模拟抢票代码 互斥量的接口初始化互斥量销毁互斥量互斥量加锁和解锁改进模拟抢票代码&#xff08;加锁&#xff09;小结对锁封装 lockGuard.hpp 互斥量实现原理探究可重入VS线程安全概念常见的线程不安全的情况常…

Python基础001

Python输出语句 print输出字符串 print("中国四大名著&#xff1a;","西游记|","三国演义|","红楼梦|","水浒传") print(6) print(1 1)Python输入语句 input函数 input() input("我的名字是&#xff1a;") p…

五、Spring IoCDI ★ ✔

5. Spring IoC&DI 1. IoC & DI ⼊⻔1.1 Spring 是什么&#xff1f;★ &#xff08;Spring 是包含了众多⼯具⽅法的 IoC 容器&#xff09;1.1.1 什么是容器&#xff1f;1.1.2 什么是 IoC&#xff1f;★ &#xff08;IoC: Inversion of Control (控制反转)&#xff09;总…

Excel 宏录制与VBA编程 —— 14、使用VBA处理Excel事件

简介 若希望特定事件处理程序在触发特定事件时运行&#xff0c;可以为 Application 对象编写事件处理程序。 Application 对象的事件处理程序是全局的&#xff0c;这意味着只要 Microsoft Excel 处于打开状态&#xff0c;事件处理程序将在发生相应的事件时运行&#xff0c;而不…

数据结构与算法笔记:高级篇 - 搜索:如何用 A* 搜索算法实现游戏中的寻路功能?

概述 魔兽世界、仙剑奇侠传这类 MMRPG 游戏&#xff0c;不知道你玩过没有&#xff1f;在这些游戏中&#xff0c;有一个非常重要的功能&#xff0c;那就是任务角色自动寻路。当任务处于游戏地图中的某个位置时&#xff0c;我们用鼠标点击另外一个相对较远的位置&#xff0c;任务…

简单分享 for循环,从基础到高级

1. 基础篇&#xff1a;Hello, For Loop! 想象一下&#xff0c;你想给班上的每位同学发送“Hello!”&#xff0c;怎么办&#xff1f;那就是for循环啦&#xff0c; eg&#xff1a;首先有个名字的列表&#xff0c;for循环取出&#xff0c;分别打印 names ["Alice", …

LabVIEW与PLC通讯方式及比较

LabVIEW与PLC之间的通讯方式多样&#xff0c;包括使用MODBUS协议、OPC&#xff08;OLE for Process Control&#xff09;、Ethernet/IP以及串口通讯等。这些通讯方式各有特点&#xff0c;选择合适的通讯方式可以提高系统的效率和稳定性。以下将详细介绍每种通讯方式的特点、优点…

Ubuntu24.04 Isaacgym的安装

教程1 教程2 教程3 1.下载压缩包 link 2. 解压 tar -xvf IsaacGym_Preview_4_Package.tar.gz3. 从源码安装 Ubuntu24.04还需首先进入虚拟环境 python -m venv myenv # 创建虚拟环境&#xff0c;已有可跳过 source myenv/bin/activate # 激活虚拟环境python编译 cd isaa…

Redis---保证主从节点一致性问题 +与数据库数据保持一致性问题

保证主从节点一致性问题 Redis的同步方式默认是异步的&#xff0c;这种异步的同步方式导致了主从之间的数据存在一定的延迟&#xff0c;因此Redis默认是弱一致性的。 解决&#xff1a; 1.使用Redisson这样的工具&#xff0c;它提供了分布式锁的实现&#xff0c;确保在分布式环…

React 中 useEffect

React 中 useEffect 是副作用函数&#xff0c;副作用函数通常是处理外围系统交互的逻辑。那么 useEffect 是怎处理的呢&#xff1f;React 组件都是纯函数&#xff0c;需要将副作用的逻辑通过副作用函数抽离出去&#xff0c;也就是副作用函数是不影响函数组件的返回值的。例如&a…

Codeforces Round 954 (Div. 3)(A~E)

目录 A. X Axis B. Matrix Stabilization C. Update Queries D. Mathematical Problem A. X Axis Problem - A - Codeforces 直接找到第二大的数&#xff0c;答案就是这个数与其他两个数的差值的和。 void solve() {vector<ll>a;for (int i 1; i < 3; i){int x;…

【实战】EasyExcel实现百万级数据导入导出

文章目录 前言技术积累实战演示实现思路模拟代码测试结果 前言 最近接到一个百万级excel数据导入导出的需求&#xff0c;大概就是我们在进行公众号API群发的时候&#xff0c;需要支持500w以上的openid进行群发&#xff0c;并且可以提供发送openid数据的导出功能。可能有的同学…