基于Qwen2-VL模型针对LaTeX OCR任务进行微调训练 - 原模型 多图推理
flyfish
输入
输出
[‘第一张图片是一幅中国山水画,描绘了一座山峰和周围的树木。第二张图片是一张现代照片,展示了一座山峰和周围的自然景观,包括水体和植被。’]
from PIL import Image
import requests
import torch
from torchvision import io
from typing import Dict
from transformers import Qwen2VLForConditionalGeneration, AutoTokenizer, AutoProcessor
from modelscope import snapshot_download
from qwen_vl_utils import process_vision_info# 下载模型快照并指定保存目录
model_dir = snapshot_download("qwen/Qwen2-VL-7B-Instruct")# 加载模型到可用设备(CPU或GPU),并使用自动精度(根据设备自动选择)
# 使用 attn_implementation="flash_attention_2" 以利用更快的注意力机制实现
model = Qwen2VLForConditionalGeneration.from_pretrained(model_dir,torch_dtype="auto",device_map="auto",attn_implementation="flash_attention_2",
)# 打印 device_map 和模型信息
print("\nDevice Map:")
print(model.device_map)print("\nModel Information:")
print(model)# 打印模型的所有属性及其值(不包括方法)
print("\nModel Attributes and Their Values:")
attributes_and_methods = dir(model)
for attr in attributes_and_methods:try:value = getattr(model, attr)if not callable(value):print(f"{attr}: {value}")except AttributeError:continue# 加载图像处理器
processor = AutoProcessor.from_pretrained(model_dir, low_cpu_mem_usage=False)# 从本地获取文件
image_path1 = "./QueHuaQiuSe1.png"
image_path2 = "./QueHuaQiuSe2.png"# 定义对话历史,包括用户输入的文本和多个图像
messages = [{"role": "user","content": [{"type": "image", "image": image_path1},{"type": "image", "image": image_path2},{"type": "text", "text": "识别这些图像之间的不同之处。"},],}
]# 使用处理器应用聊天模板,并添加生成提示
# tokenize=False 表示不进行分词处理,add_generation_prompt=True 添加生成提示
text = processor.apply_chat_template(messages, tokenize=False, add_generation_prompt=True
)# 处理视觉信息(图像和视频)
image_inputs, video_inputs = process_vision_info(messages)# 预处理输入数据,将文本和图像转换为模型可以接受的格式
inputs = processor(text=[text],images=image_inputs,videos=video_inputs,padding=True,return_tensors="pt",
)# 将输入数据移动到CUDA设备上(如果可用的话)
inputs = inputs.to("cuda")# 推理:生成输出文本
generated_ids = model.generate(**inputs, max_new_tokens=128) # 最大新生成token数量为128# 提取生成的token ID,去掉输入的原始token ID
generated_ids_trimmed = [out_ids[len(in_ids) :] for in_ids, out_ids in zip(inputs.input_ids, generated_ids)
]# 解码生成的token ID为人类可读的文本
output_text = processor.batch_decode(generated_ids_trimmed, skip_special_tokens=True, clean_up_tokenization_spaces=False
)# 打印生成的描述文本
print("\nGenerated Description Text:")
print(output_text)