[CLIP-VIT-L + Qwen] 多模态大模型学习笔记 - 4

[CLIP-VIT-L + Qwen] 多模态大模型学习笔记 - 4

  • 前情提要
  • 源码解读(MQwenLMHeadModel类)
    • init函数
      • 总体含义
      • 逐行解读
    • prepare_inputs_for_generation函数
      • 整体含义
      • 逐行解读
    • forward函数
      • 整体含义
      • 逐行解读
    • main函数
      • 逐行解读

参考repo:WatchTower-Liu/VLM-learning; url: VLLM-BASE

前情提要

有关MQwenModel的代码请看(多模态大模型学习笔记 - 1、 多模态大模型学习笔记 - 2, 多模态大模型学习笔记 - 3)
本节中将接着看MQwen.py中的剩余源码,即MQwenLMHeadModel和main函数源码,MQwen.py重构了Qwen大模型中QwenModel的前向传播代码和QwenLMHeadModel的部分代码,以适配视觉编码器CLIP-VIT-L和语言模型Qwen的多模态架构,QwenModel类作为基座模型,QwenLMHeadModel 是基于 QwenModel 的一个扩展,加入了针对特定下游任务的头,在后续我们主要使用重写后的MQwenLMHeadModel作为多模态架构中的语言模型。

源码解读(MQwenLMHeadModel类)

init函数

  class MQWenLMHeadModel(QWenLMHeadModel):  def __init__(self, config, otherConfig):super().__init__(config)self.transformer = MQWenModel(config, otherConfig)if config.bf16:self.transformer.bfloat16()if config.fp16:self.transformer.half()

总体含义

初始化一个使用MQwenModel的类的实例,以便后续使用MQwenModel进行前向传播。

逐行解读

config 和 otherconfig一个作为初始化模型的通用配置参数,一个是用户自定义的额外参数传入。

    def __init__(self, config, otherConfig):

使用通用配置参数初始化父类。

 super().__init__(config)

初始化基座模型MQwenModel用于前向传播,传递入通用配置参数和自定义参数,赋值给成员变量self.transformer

self.transformer = MQWenModel(config, otherConfig)

确定将模型的权重转换为bf16(brain float 16)还是单精度浮点数(float16)数据格式,bf16和双精度浮点数(float32)有相同的动态范围,但是只需要16位的存储空间,可以看做单精度浮点数(float16)的变体。这两种精度都支持自动混合精度训练,可以减少内存占用,提高性能。

        if config.bf16:self.transformer.bfloat16()if config.fp16:self.transformer.half()

prepare_inputs_for_generation函数

    def prepare_inputs_for_generation(self, input_ids, past_key_values=None, inputs_embeds=None, **kwargs):if past_key_values:input_ids = input_ids[:, -1].unsqueeze(-1)if input_ids.size(0) == 1:attention_mask = Noneelse:attention_mask = kwargs.get("attention_mask", None)if inputs_embeds is not None and past_key_values is None:model_inputs = {"inputs_embeds": inputs_embeds}else:model_inputs = {"input_ids": input_ids}model_inputs.update({"past_key_values": past_key_values,"use_cache": kwargs.get("use_cache"),"attention_mask": attention_mask,"images": kwargs.get("images")})return model_inputs

整体含义

这段代码主要用于准备模型输入部分,对Input_ids,past_key_values等参数进行预处理,并返回一个包含这些值预处理结果的字典

逐行解读

input_ids: 由分词后的token被映射为词汇表中的唯一数字索引。例如’hello, world’分词后被映射为{111,222}(这里仅举例,不代表真实索引)
past_key_values:过去时间步中处理的序列输入数据的键值对缓存结果,通常为一个元组。
inputs_embeds:input_ids经过word_embedding层处理为的词嵌入向量。
一般来说,只需要传入input_ids和input_embeds中的其中一个即可、

    def prepare_inputs_for_generation(self, input_ids, past_key_values=None, inputs_embeds=None, **kwargs):

判断当前是否为推理、训练的第一步(如果启用use_cache的话)。input_ids的size通常为(batch_size,seq_len),这里我们只取每一批次input_ids的最后一个索引,这是因为如果我们有缓存的键值对,那么就无需重复计算先前缓存的键值对。unqueeze(-1)是因为当我们只取一个索引的时候,Input_ids会降维成(batch_size,),因此我们需要重新将其扩充为二维张量,size为(batch_size,1)。

        if past_key_values:input_ids = input_ids[:, -1].unsqueeze(-1)

如果当前批次为1,说明输入只有单个样本,因此不需要使用注意力掩码。否则我们就从关键词参数中获取,如果提供的话,否则置为None。

if input_ids.size(0) == 1:attention_mask = Noneelse:attention_mask = kwargs.get("attention_mask", None)

如果我们传入了inputs_embeds并且有缓存的键值对,这代表我们处于推理或训练的中间步骤,我们初始化一个包含input_embeds的字典,否则初始化一个包含input_ids的字典。

        if inputs_embeds is not None and past_key_values is None:model_inputs = {"inputs_embeds": inputs_embeds}else:model_inputs = {"input_ids": input_ids}

我们用update函数更新字典,update函数接受一个字典参数,并且覆盖原字典中键相同元素的值。最后将字典作为返回值返回。

        model_inputs.update({"past_key_values": past_key_values,"use_cache": kwargs.get("use_cache"),"attention_mask": attention_mask,"images": kwargs.get("images")})return model_inputs

forward函数

    def forward(self,input_ids: Optional[torch.LongTensor] = None,images: Optional[torch.Tensor] = None,past_key_values: Optional[Tuple[Tuple[torch.Tensor]]] = None,attention_mask: Optional[torch.FloatTensor] = None,token_type_ids: Optional[torch.LongTensor] = None,position_ids: Optional[torch.LongTensor] = None,head_mask: Optional[torch.FloatTensor] = None,inputs_embeds: Optional[torch.FloatTensor] = None,encoder_hidden_states: Optional[torch.Tensor] = None,encoder_attention_mask: Optional[torch.FloatTensor] = None,labels: Optional[torch.LongTensor] = None,use_cache: Optional[bool] = None,output_attentions: Optional[bool] = None,output_hidden_states: Optional[bool] = None,return_dict: Optional[bool] = None,) -> Union[Tuple, CausalLMOutputWithPast]:return_dict = (return_dict if return_dict is not None else self.config.use_return_dict)transformer_outputs = self.transformer(input_ids,images=images,past_key_values=past_key_values,attention_mask=attention_mask,token_type_ids=token_type_ids,position_ids=position_ids,head_mask=head_mask,inputs_embeds=inputs_embeds,encoder_hidden_states=encoder_hidden_states,encoder_attention_mask=encoder_attention_mask,use_cache=use_cache,output_attentions=output_attentions,output_hidden_states=output_hidden_states,return_dict=return_dict,)hidden_states = transformer_outputs[0]lm_logits = self.lm_head(hidden_states)loss = Noneif labels is not None:labels = labels.to(lm_logits.device)shift_logits = lm_logits[..., :-1, :].contiguous()shift_labels = labels[..., 1:].contiguous()loss_fct = CrossEntropyLoss()loss = loss_fct(shift_logits.view(-1, shift_logits.size(-1)), shift_labels.view(-1))if not return_dict:output = (lm_logits,) + transformer_outputs[1:]return ((loss,) + output) if loss is not None else outputreturn CausalLMOutputWithPast(loss=loss,logits=lm_logits,past_key_values=transformer_outputs.past_key_values,hidden_states=transformer_outputs.hidden_states,attentions=transformer_outputs.attentions,)

整体含义

这段代码是MQwenLMHeadModel的前向传播函数,包括输入处理、损失计算和输出格式化。

逐行解读

对于传递的参数不在赘述,在前几期的笔记中都有详细记载,除了labels参数,这一参数用于有监督训练,作为标签计算损失。

    def forward(self,input_ids: Optional[torch.LongTensor] = None,images: Optional[torch.Tensor] = None,past_key_values: Optional[Tuple[Tuple[torch.Tensor]]] = None,attention_mask: Optional[torch.FloatTensor] = None,token_type_ids: Optional[torch.LongTensor] = None,position_ids: Optional[torch.LongTensor] = None,head_mask: Optional[torch.FloatTensor] = None,inputs_embeds: Optional[torch.FloatTensor] = None,encoder_hidden_states: Optional[torch.Tensor] = None,encoder_attention_mask: Optional[torch.FloatTensor] = None,labels: Optional[torch.LongTensor] = None,use_cache: Optional[bool] = None,output_attentions: Optional[bool] = None,output_hidden_states: Optional[bool] = None,return_dict: Optional[bool] = None,) -> Union[Tuple, CausalLMOutputWithPast]:

确定返回值的输出格式,如果return_dict不为None,则以字典形式返回,否则获取通用配置参数中的默认值。

        return_dict = (return_dict if return_dict is not None else self.config.use_return_dict)

使用MQwenModel的前向传播函数获取输出,输出通常包括最后一层的隐藏状态,缓存的键值对,注意力分数,每一层的隐藏状态等等,具体可以参考repo中的MQwenModel的前向传播函数,或者前几期的学习笔记。

        transformer_outputs = self.transformer(input_ids,images=images,past_key_values=past_key_values,attention_mask=attention_mask,token_type_ids=token_type_ids,position_ids=position_ids,head_mask=head_mask,inputs_embeds=inputs_embeds,encoder_hidden_states=encoder_hidden_states,encoder_attention_mask=encoder_attention_mask,use_cache=use_cache,output_attentions=output_attentions,output_hidden_states=output_hidden_states,return_dict=return_dict,)

获取last_hidden_state,使用语言头来处理并生成预测用的对数概率。这里的self.lm_head继承自QwenLMHeadModel,具体可以参考Qwen模型的源码。
初始化loss为None,以便后续更新loss值。

        hidden_states = transformer_outputs[0]lm_logits = self.lm_head(hidden_states)loss = None

如果提供了训练用的标签,首先将它转移到对数概率挂载的设备(gpu或者cpu)上,取对数概率除了最后一个时间步的所有元素,labels取除了第一个标签外的所有标签,这样做是为了让预测结果和实际标签错开,让每个时间步的输出都有对应的下一个词的标签。例如我们的输出’ABCD’,标签也为’ABCD’,经过处理后输出为’ABC’,标签为‘BCD’,这样A对应B,C对应D,每个时间步的输出结果都有对应的下一个词作为标签。contiguous()函数目的是让变量在内存中连续。
损失函数设定为交叉熵损失函数,shift_logits原本的size为(batch_size, seq_len - 1,vocab_size),shift_labels的size为(batc_size,seq_len - 1),将这两个变量除了最后一个维度,其余维度展平,以便进行损失计算。

        if labels is not None:labels = labels.to(lm_logits.device)shift_logits = lm_logits[..., :-1, :].contiguous()shift_labels = labels[..., 1:].contiguous()loss_fct = CrossEntropyLoss()loss = loss_fct(shift_logits.view(-1, shift_logits.size(-1)), shift_labels.view(-1))

如果不设定返回值类型为字典。则初始化output为对数概率加上前向传播输出除了last_hidden_state外的所有输出。
如果损失值不为空,将其与output一起返回,否则只返回output

        if not return_dict:output = (lm_logits,) + transformer_outputs[1:]return ((loss,) + output) if loss is not None else output

反之,我们返回一个CausalLMOutputWithPast类型的输出结果,这个类用于封装因果模型的前向传播输出结果

        return CausalLMOutputWithPast(loss=loss,logits=lm_logits,past_key_values=transformer_outputs.past_key_values,hidden_states=transformer_outputs.hidden_states,attentions=transformer_outputs.attentions,)

main函数

def main():MQ = MQWenLMHeadModel.from_pretrained("huggingface_model/qwen/Qwen-1_8B/", torch_dtype = torch.bfloat16, trust_remote_code = True)if __name__ == "__main__":main()

逐行解读

使用MQWenLMHeadModel从huggingface加载预训练模型的权重配置,在保留原有模型的基础功能和预训练权重的同时,添加新的功能或改进现有功能。frompretrained方法通常继承自huggingface的pretrainedmodel类

MQ = MQWenLMHeadModel.from_pretrained("huggingface_model/qwen/Qwen-1_8B/", torch_dtype = torch.bfloat16, trust_remote_code = True)

至此,MQwen.py讲解完毕,后续会讲解repo中的其余部分,并动手训练实现一个多模态大模型

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.rhkb.cn/news/405524.html

如若内容造成侵权/违法违规/事实不符,请联系长河编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

【Go】实现字符切片零拷贝开销转为字符串

package mainimport ("fmt""unsafe" )func main() {bytes : []byte("hello world")s : *(*string)(unsafe.Pointer(&bytes))fmt.Println(s)bytes[0] Hfmt.Println(s) }slice的底层结构是底层数组、len字段、cap字段。string的底层结构是底层…

产品帮助中心如何搭建?五步让客户满意度提升100%

一、引言 创建帮助文章的好处是节省了招募大量客户联系代理的昂贵成本。它们现在通过解决客户的早期问题而无需支持干预,并为自助提供逐步指导,从而取代了支持代理。 当您创建帮助文章时,您会构建知识库并为将来保留它。这些帮助文章充当新…

案例:ZooKeeper + Kafka消息队列集群部署

目录 消息队列 概念 使用场景 不适宜 适宜 消息队列的特征 存储 异步 异步的优点 同步 为什么需要消息队列 解耦 作用 冗余 扩展性 灵活性 峰值处理能力 可恢复性 顺序保证 Kafka 概念 Kafka技术名词 (1)Broker (2&a…

C语言一笔画迷宫

目录 开头程序程序的流程图程序游玩的效果结尾 开头 大家好&#xff0c;我叫这是我58。 程序 #define _CRT_SECURE_NO_WARNINGS 1 #include <stdio.h> #include <string.h> #include <Windows.h> void printmaze(const char strmaze[11][11]) {int ia 0;…

BUG——imx6u开发_结构体导致的死机问题(未解决)

简介&#xff1a; 最近在做imx6u的linux下裸机驱动开发&#xff0c;由于是学习的初级阶段&#xff0c;既没有现成的IDE可以使用&#xff0c;也没有GDB等在线调试工具&#xff0c;只能把代码烧写在SD卡上再反复插拔&#xff0c;仅靠卑微的亮灯来判断程序死在哪一步。 至于没有使…

41-设计规则:线宽规则

1.设置电源线规则和信号线规则 2.设置信号线规则 3.设置电源线规则 如果未生效&#xff1a; ① 提升优先级即可。 ②查看使能选项有没有勾选

20:【stm32】定时器一:时基单元

时基单元 1、什么是定时器2、时基单元的基本结构2.1&#xff1a;脉冲的来源2.2&#xff1a;预分频器PSC2.3&#xff1a;计数器CNT2.4&#xff1a;update事件与预加载 3、标准库编程3.1&#xff1a;通过定时器中断来设置延迟函数 1、什么是定时器 定时器是一种专门负责定时功能…

Vue 满屏纵向轮播图

目录 前言轮播图效果展示具体实现实现思路具体代码前言 今天汇总一个需求,还是之前写的,要求写一个满屏的轮播图,准确的说,是鼠标滑动到轮播图的时候,轮播图固定在屏幕上,随着其中的轮播子项遍历结束后,解除固定的效果。原本我最开始想直接修改Element-UI的组件的,但是…

CISAW认证考试的时间是多久

CISAW&#xff0c;即中国信息安全保障人员&#xff0c;是中国信息安全认证与审查中心进行权威认证的缩写。它是全国范围内最为权威、最高端的信息安全认证之一。作为信息安全领域的重要认证&#xff0c;对于从事网络安全工作的人员来说具有极其重要的意义。因此&#xff0c;备考…

【容器安全系列Ⅲ】- 深入了解Capabilities的作用

在本系列的上一部分中&#xff0c;我们提到 Docker 容器尚未使用 time 命名空间。我们还探讨了容器在许多情况下如何以 root 用户身份运行。考虑到这两点&#xff0c;如果我们尝试更改容器内的日期和时间会发生什么&#xff1f; 为了测试这一点&#xff0c;我们先运行 docker r…

入门网络安全工程师要学习哪些内容

大家都知道网络安全行业很火&#xff0c;这个行业因为国家政策趋势正在大力发展&#xff0c;大有可为!但很多人对网络安全工程师还是不了解&#xff0c;不知道网络安全工程师需要学什么?知了堂小编总结出以下要点。 网络安全工程师是一个概称&#xff0c;学习的东西很多&…

2000-2023年逐年最大NDVI数据集(500m)

植被指数&#xff08;NDVI, Normalized Difference Vegetation Index&#xff09;可以准确反映地表植被覆盖状况。目前&#xff0c;NDVI时序数据已经在各尺度区域的植被动态变化监测、土地利用/覆被变化检测、宏观植被覆盖分类和净初级生产力估算等研究中得到了广泛的应用。 中…

【java】RuoYi-Vue前后端分离版本-请求被拦截,怎么修改拦截过滤器,解决方案

【java】RuoYi-Vue前后端分离版本-请求被拦截&#xff0c;怎么修改拦截过滤器 它用到了一个安全管理框架Spring Security 你可以通过这篇文章《Spring Security 详解》 去了解它&#xff0c;怎么使用或者使用原理。 所有业务都受SecurityConfig配置所过滤 SecurityConfig配置…

【功能自动化】使用Excel文档获取参数数据

环境搭建&#xff1a; 1.需要配置WebTours网站 2.安装pandas pip install -i https://pypi.tuna.tsinghua.edu.cn/simple numpy pip install -i https://pypi.tuna.tsinghua.edu.cn/simple pandas pip install -i https://pypi.tuna.tsinghua.edu.cn/simple python_dateutil…

设计模式(3)结构型模式

结构型模式 结构型模式1. Adapter&#xff08;适配器模式&#xff09;2. Bridge&#xff08;桥接模式&#xff09;3.Composite&#xff08;组合模式&#xff09;4.Decorator&#xff08;装饰模式&#xff09;5.Facade&#xff08;外观模式&#xff09;6.Flyweight&#xff08;享…

14、Ripper

难度 低->中 目标 一个root 两个flag kali 192.168.135.58 靶机 192.168.135.104 netdiscover -i eth0 -r 192.168.135.0/24 端口扫描 先访问一下80端口和10000端口&#xff0c;这两个都是web服务的样子 80端口是初始化界面&#xff0c;可以尝试扫扫目录 访问10000端口…

Linux升级lib64中的libc.so.6导致所有命令失效

ls: relocation error: libpthread.so.0: symbol __libc_dl_error_tsd, version GLIBC_PRIVATE not defined in file libc.so.6 with link time reference 升级Glibc后出现所有shell命令都不可用 # systemctl status systemctl: relocation error: /lib64/libpthread.so.0: sy…

Ollama 企业私有化部署大模型最佳解决方案

为什么要私有化部署大模型&#xff1f; 很多企业为了控制成本和减少核心数据外泄的风险&#xff0c;会通过私有化部署大模型&#xff0c;来控制成本和保障企业的数据安全。 说到本地化部署&#xff0c;这时就需要说到Ollama框架了。 Ollama 是什么&#xff1f; Ollama 是一个开…

张宇1000题vs武忠祥严选题,哪本更接近真题?

张宇1000题强化篇难度还是挺大的 首先是综合度比较高&#xff0c;如果你基础复习的不好&#xff0c;不建议做&#xff0c;张宇1000题强化篇的难度还是比较大的&#xff0c;适合基础已经比较扎实的同学来做&#xff01; 张宇1000题与张宇的高数18讲等课程紧密结合&#xff0c;…

BEV世界:通过统一的BEV潜在空间实现自动驾驶的多模态世界模型

BEVWorld: A Multimodal World Model for Autonomous Driving via Unified BEV Latent Space BEV世界&#xff1a;通过统一的BEV潜在空间实现自动驾驶的多模态世界模型 Abstract World models are receiving increasing attention in autonomous driving for their ability t…