通常情况下，大多数视频数据并不附带相应的描述性文本，因此有必要将视频数据转换为文本描述，为文本到视频模型提供必要的训练数据。CogVLM2-Caption 是一个视频字幕模型，用于为 CogVideoX 模型生成训练数据。
文件
使用
import argparse
import io

import numpy as np
import torch
from decord import cpu, VideoReader, bridge
from transformers import AutoModelForCausalLM, AutoTokenizer

# Model identifier handed to the `from_pretrained` loaders further down.
MODEL_PATH = "THUDM/cogvlm2-llama3-caption"

# Use the GPU when torch can see one, otherwise fall back to the CPU.
DEVICE = 'cuda' if torch.cuda.is_available() else 'cpu'

# bfloat16 is selected only when a CUDA device reports a compute-capability
# major version of 8 or higher; every other configuration gets float16.
TORCH_TYPE = (
    torch.bfloat16
    if torch.cuda.is_available() and torch.cuda.get_device_capability()[0] >= 8
    else torch.float16
)

parser = argparse.ArgumentParser(description="CogVLM2-Video CLI Demo")
# NOTE(review): the default (0) is outside `choices`; argparse does not
# validate defaults, so 0 presumably means "no quantization" - confirm.
parser.add_argument(
    '--quant',
    type=int,
    choices=[4, 8],
    help='Enable 4-bit or 8-bit precision loading',
    default=0,
)
# An explicit empty argv is parsed, so real command-line flags are ignored and
# every option keeps its default value.
args = parser.parse_args([])


def load_video(video_data, strategy='chat'):
    """Decode an in-memory video and sample a fixed budget of frames.

    Args:
        video_data: Raw bytes of the encoded video (the code wraps them in
            ``io.BytesIO`` before handing them to decord).
        strategy: ``'base'`` samples 24 evenly spaced frames from the first
            60 seconds; ``'chat'`` picks, for each whole second of the video,
            the frame whose start timestamp is closest, up to 24 frames.

    Returns:
        The batch returned by ``VideoReader.get_batch`` with its last axis
        moved to the front via ``permute(3, 0, 1, 2)``.
        NOTE(review): presumably (frames, H, W, C) -> (C, frames, H, W);
        confirm against decord's torch bridge.

    Raises:
        ValueError: If ``strategy`` is neither ``'base'`` nor ``'chat'``.
    """
    # Make decord hand back torch tensors so `.permute` is available below.
    bridge.set_bridge('torch')
    num_frames = 24
    decord_vr = VideoReader(io.BytesIO(video_data), ctx=cpu(0))
    total_frames = len(decord_vr)

    if strategy == 'base':
        clip_start_sec = 0
        clip_end_sec = 60
        fps = decord_vr.get_avg_fps()
        start_frame = int(clip_start_sec * fps)
        # Never index past the end of videos shorter than the clip window.
        end_frame = min(total_frames, int(clip_end_sec * fps))
        frame_id_list = np.linspace(
            start_frame, end_frame - 1, num_frames, dtype=int
        )
    elif strategy == 'chat':
        timestamps = decord_vr.get_frame_timestamp(np.arange(total_frames))
        # Only the first entry of each timestamp record is used.
        start_times = np.asarray([stamp[0] for stamp in timestamps])
        max_second = round(float(start_times.max())) + 1
        frame_id_list = []
        for second in range(max_second):
            # argmin returns the first minimum, which matches the previous
            # `min(..., key=...)` followed by `list.index` lookup.
            nearest = int(np.argmin(np.abs(start_times - second)))
            frame_id_list.append(nearest)
            if len(frame_id_list) >= num_frames:
                break
    else:
        # Previously an unknown strategy passed `None` on to decord.
        raise ValueError(
            f"Unknown strategy {strategy!r}; expected 'base' or 'chat'"
        )

    video_data = decord_vr.get_batch(frame_id_list)
    video_data = video_data.permute(3, 0, 1, 2)
    return video_data


# NOTE(review): trust_remote_code=True allows code shipped with the checkpoint
# to run locally; keep MODEL_PATH pointed at a trusted repository.
tokenizer = AutoTokenizer.from_pretrained(
    MODEL_PATH,
    trust_remote_code=True,
)
model = AutoModelForCausalLM.from_pretrained(
    MODEL_PATH,
    torch_dtype=TORCH_TYPE,
    trust_remote_code=True,
).eval().to(DEVICE)


def predict(prompt, video_data, temperature):
    """Generate a text response for ``prompt`` about the given video.

    Args:
        prompt: The user query passed to the model's conversation builder.
        video_data: Raw bytes of the encoded video, forwarded to
            ``load_video``.
        temperature: Value forwarded to ``model.generate`` as ``temperature``.

    Returns:
        The decoded continuation (prompt tokens stripped, special tokens
        skipped) as a string.
    """
    strategy = 'chat'
    video = load_video(video_data, strategy=strategy)
    built = model.build_conversation_input_ids(
        tokenizer=tokenizer,
        query=prompt,
        images=[video],
        history=[],
        template_version=strategy,
    )
    # Add a batch dimension and move everything to the device the model lives
    # on. A hard-coded 'cuda' here used to crash on CPU-only hosts.
    inputs = {
        'input_ids': built['input_ids'].unsqueeze(0).to(DEVICE),
        'token_type_ids': built['token_type_ids'].unsqueeze(0).to(DEVICE),
        'attention_mask': built['attention_mask'].unsqueeze(0).to(DEVICE),
        'images': [[built['images'][0].to(DEVICE).to(TORCH_TYPE)]],
    }
    # NOTE(review): with do_sample=False generation is presumably greedy, so
    # top_k / top_p / temperature are not expected to change the output;
    # they are kept so the call stays identical. 128002 is assumed to be the
    # tokenizer's padding id - confirm against the checkpoint.
    gen_kwargs = {
        "max_new_tokens": 2048,
        "pad_token_id": 128002,
        "top_k": 1,
        "do_sample": False,
        "top_p": 0.1,
        "temperature": temperature,
    }
    with torch.no_grad():
        outputs = model.generate(**inputs, **gen_kwargs)
    # Drop the echoed prompt tokens; keep only the newly generated ones.
    outputs = outputs[:, inputs['input_ids'].shape[1]:]
    response = tokenizer.decode(outputs[0], skip_special_tokens=True)
    return response


def test():
    """Caption ``test.mp4`` from the working directory and print the result."""
    prompt = "Please describe this video in detail."
    temperature = 0.1
    # Context manager guarantees the file handle is closed after reading.
    with open('test.mp4', 'rb') as video_file:
        video_data = video_file.read()
    response = predict(prompt, video_data, temperature)
    print(response)


if __name__ == '__main__':
    test()
感谢大家花时间阅读我的文章，你们的支持是我不断前进的动力。期望未来能为大家带来更多有价值的内容，请多多关注我的动态！