【HuggingFace Transformers】OpenAIGPTModel源码解析

OpenAIGPTModel源码解析

  • 1. GPT 介绍
  • 2. OpenAIGPTModel类 源码解析

说到ChatGPT,大家可能都使用过吧。2022年,ChatGPT的推出引发了广泛的关注和讨论。这款对话生成模型不仅具备了强大的语言理解和生成能力,还能进行非常自然的对话,给用户带来了全新的互动体验。然而,ChatGPT的成功背后离不开它的前身——GPT

1. GPT 介绍

GPT(Generative Pre-trained Transformer)是由OpenAI开发的一种基于Transformer架构的大型语言模型。它由多个堆叠的自注意力解码器层(Transformer Blocks)组成,每一层包含多头自注意力机制和前馈神经网络,并配有残差连接和层归一化以稳定训练。GPT采用自回归方式生成文本,通过在大规模互联网数据上进行预训练,具备强大的自然语言理解和生成能力,能够完成对话生成、文本补全等多种任务。其结构如下:

在这里插入图片描述

2. OpenAIGPTModel类 源码解析

源码地址:transformers/src/transformers/models/openai/modeling_openai.py

# -*- coding: utf-8 -*-
# @time: 2024/9/3 20:39
from typing import Optional, Union, Tupleimport torchfrom torch import nn
from transformers import add_start_docstrings, OpenAIGPTPreTrainedModel
from transformers.modeling_outputs import BaseModelOutput
from transformers.models.openai.modeling_openai import OPENAI_GPT_START_DOCSTRING, Block, OPENAI_GPT_INPUTS_DOCSTRING, _CHECKPOINT_FOR_DOC, _CONFIG_FOR_DOC
from transformers.utils import add_start_docstrings_to_model_forward, add_code_sample_docstrings@add_start_docstrings("The bare OpenAI GPT transformer model outputting raw hidden-states without any specific head on top.",OPENAI_GPT_START_DOCSTRING,
)
class OpenAIGPTModel(OpenAIGPTPreTrainedModel):def __init__(self, config):super().__init__(config)self.tokens_embed = nn.Embedding(config.vocab_size, config.n_embd)  # 定义 token 嵌入层self.positions_embed = nn.Embedding(config.n_positions, config.n_embd)  # 定义 position 嵌入层self.drop = nn.Dropout(config.embd_pdrop)  # 定义 drop 层self.h = nn.ModuleList([Block(config.n_positions, config, scale=True) for _ in range(config.n_layer)]) # 定义多个 Block 层# 注册一个缓冲区用于存储position_ids,初始化为从 0 到 config.n_positions 的序列self.register_buffer("position_ids", torch.arange(config.n_positions), persistent=False)# Initialize weights and apply final processingself.post_init()def get_input_embeddings(self):return self.tokens_embeddef set_input_embeddings(self, new_embeddings):self.tokens_embed = new_embeddingsdef _prune_heads(self, heads_to_prune):"""Prunes heads of the model. heads_to_prune: dict of {layer_num: list of heads to prune in this layer}"""# 剪掉模型多头注意力机制中的一些头,heads_to_prune 是一个字典,键为layer_num,值为需要剪枝的 heads 列表。for layer, heads in heads_to_prune.items():self.h[layer].attn.prune_heads(heads)@add_start_docstrings_to_model_forward(OPENAI_GPT_INPUTS_DOCSTRING)@add_code_sample_docstrings(checkpoint=_CHECKPOINT_FOR_DOC,output_type=BaseModelOutput,config_class=_CONFIG_FOR_DOC,)def forward(self,input_ids: Optional[torch.LongTensor] = None,attention_mask: Optional[torch.FloatTensor] = None,token_type_ids: Optional[torch.LongTensor] = None,position_ids: Optional[torch.LongTensor] = None,head_mask: Optional[torch.FloatTensor] = None,inputs_embeds: Optional[torch.FloatTensor] = None,output_attentions: Optional[bool] = None,output_hidden_states: Optional[bool] = None,return_dict: Optional[bool] = None,) -> Union[Tuple[torch.Tensor], BaseModelOutput]:# 根据 config 配置设定 output_attentions, output_hidden_states, return_dict 的值output_attentions = output_attentions if output_attentions is not None else self.config.output_attentionsoutput_hidden_states = (output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states)return_dict = return_dict if return_dict is not None else self.config.use_return_dict# 获取 input_ids 或者 inputs_embeds 以及 input_shapeif input_ids is not None and inputs_embeds is not None:  # 当 input_ids 和 inputs_embeds 同时存在时,抛出错误raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time")elif input_ids is not None:  # 如果存在 input_ids,将其形状调整为 (batch_size, sequence_length)self.warn_if_padding_and_no_attention_mask(input_ids, attention_mask)input_shape = input_ids.size()input_ids = input_ids.view(-1, input_shape[-1])elif inputs_embeds is not None:  # 如果存在 inputs_embeds,获取其形状input_shape = inputs_embeds.size()[:-1]else:  # 如果 input_ids 和 inputs_embeds 都不存在,抛出错误raise ValueError("You have to specify either input_ids or inputs_embeds")# 如果没有传入 position_ids,则生成默认的 position_idsif position_ids is None:# Code is different from when we had a single embedding matrix from position and token embeddingsposition_ids = self.position_ids[None, : input_shape[-1]]# ------------------------------------- 1. 获取 attention_mask -----------------------------## Attention mask.if attention_mask is not None:# We create a 3D attention mask from a 2D tensor mask.# Sizes are [batch_size, 1, 1, to_seq_length]# So we can broadcast to [batch_size, num_heads, from_seq_length, to_seq_length]# this attention mask is more simple than the triangular masking of causal attention# used in OpenAI GPT, we just need to prepare the broadcast dimension here.attention_mask = attention_mask.unsqueeze(1).unsqueeze(2)  # 将 2D 掩码扩展为 3D 掩码,适用于批量输入# Since attention_mask is 1.0 for positions we want to attend and 0.0 for# masked positions, this operation will create a tensor which is 0.0 for# positions we want to attend and the dtype's smallest value for masked positions.# Since we are adding it to the raw scores before the softmax, this is# effectively the same as removing these entirely.# 将注意力掩码转换为与模型参数相同的数据类型,并进行数值变换,torch.finfo(self.dtype).min 返回数据类型的最小值。attention_mask = attention_mask.to(dtype=next(self.parameters()).dtype)  # fp16 compatibilityattention_mask = (1.0 - attention_mask) * torch.finfo(self.dtype).min# ----------------------------------------------------------------------------------------## ------------------------------------- 2. 获取 head_mask ---------------------------------## Prepare head mask if neededhead_mask = self.get_head_mask(head_mask, self.config.n_layer)# ---------------------------------------------------------- -----------------------------## ------------------------------------- 3. 获取 hidden_states -----------------------------## 如果 inputs_embeds 为 None,则使用 tokens_embed 对 input_ids 计算if inputs_embeds is None:inputs_embeds = self.tokens_embed(input_ids)# 计算 position_embedsposition_embeds = self.positions_embed(position_ids)# 如果存在 token_type_ids,使用 tokens_embed 计算;否则 token_type_embeds 为 0if token_type_ids is not None:token_type_ids = token_type_ids.view(-1, token_type_ids.size(-1))token_type_embeds = self.tokens_embed(token_type_ids)else:token_type_embeds = 0# 计算 hidden_states,即inputs_embeds、position_embeds 和 token_type_embeds 之和,并使用 dropouthidden_states = inputs_embeds + position_embeds + token_type_embedshidden_states = self.drop(hidden_states)# -------------------------------------------------------------------------------------## 获取输出形状,以及初始化输出结果 all_attentions 和 all_hidden_statesoutput_shape = input_shape + (hidden_states.size(-1),)all_attentions = () if output_attentions else Noneall_hidden_states = () if output_hidden_states else None# -----------------------------------4. Block逐层计算处理(核心部分)--------------------#for i, block in enumerate(self.h):# 如果需要输出 hidden states,将当前 hidden_states 添加到 all_hidden_statesif output_hidden_states:all_hidden_states = all_hidden_states + (hidden_states,)# 通过当前 Block 处理 hidden_states,得到新的 hidden_states 和 attentionsoutputs = block(hidden_states, attention_mask, head_mask[i], output_attentions=output_attentions)hidden_states = outputs[0]# 如果需要输出 attentions,将当前 attentions 添加到 all_attentionsif output_attentions:all_attentions = all_attentions + (outputs[1],)# ---------------------------------------------------------------------------------## 将 hidden_states 的形状调整为输出形状hidden_states = hidden_states.view(*output_shape)# 如果需要输出 hidden states,将最后的 hidden_states 添加到 all_hidden_statesif output_hidden_states:all_hidden_states = all_hidden_states + (hidden_states,)# -----------------------------------5. 根据配置的输出方式输出结果-------------------------------#if not return_dict:return tuple(v for v in [hidden_states, all_hidden_states, all_attentions] if v is not None)return BaseModelOutput(last_hidden_state=hidden_states,hidden_states=all_hidden_states,attentions=all_attentions,)

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.rhkb.cn/news/418848.html

如若内容造成侵权/违法违规/事实不符,请联系长河编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

MapSet之二叉搜索树

系列文章: 1. 先导片--Map&Set之二叉搜索树 2. Map&Set之相关概念 目录 前言 1.二叉搜索树 1.1 定义 1.2 操作-查找 1.3 操作-新增 1.4 操作-删除(难点) 1.5 总体实现代码 1.6 性能分析 前言 TreeMap 和 TreeSet 是 Java 中基于搜索树实现的 M…

图形语言传输格式glTF和三维瓦片数据3Dtiles(b3dm、pnts)学习

文章目录 一、3DTiles二、b3dm三、glTF1.glTF 3D模型格式有两种2.glTF 场景描述结构和坐标系3.glTF的索引访问与ID4.glTF asset5.glTF的JSON结构scenesscene.nodes nodesnodes.children transformations对外部数据的引用buffers 原始二进制数据块,没有固有的结构或含…

表单项标签简单学习

目录 1. 单选框 radio​编辑​编辑​编辑​编辑 2. 复选框 checkbox ​编辑​编辑​编辑 3. 隐藏域 hidden 4. 多行文本框 textarea​编辑​编辑 5. 下拉框 select​编辑​编辑 6. 选择头像​编辑​编辑 <!DOCTYPE html> <html lang"en"> <head&…

自用NAS系列1-设备

拾光坞 拾光坞多账号绑定青龙面板SMBWebdav小雅alist下载到NASDocker安装迅雷功能利用qBittorrentEEJackett打造一站式下载工具安装jackett插件 外网访问内网拾光客户端拾光穿透公网ipv6路由器配置ipv6拾光坞公网验证拾光坞域名验证 拾光坞 多账号绑定 手机注册拾光坞账号&am…

GEE数据集:加拿大卫星森林资源调查 (SBFI)-2020 年加拿大森林覆盖、干扰恢复、结构、物种、林分年龄以及 1985-2020 年林分替代干扰的信息

目录 简介 数据集后处理 数据下载链接 矢量属性 代码 代码链接 引用 许可 网址推荐 0代码在线构建地图应用 机器学习 加拿大卫星森林资源调查 (SBFI) 简介 卫星森林资源清查&#xff08;SBFI&#xff09;提供了 2020 年加拿大森林覆盖、干扰恢复、结构、物种、林分…

海外云手机是否适合运营TikTok?

随着科技的迅猛发展&#xff0c;海外云手机逐渐成为改变工作模式的重要工具。这种基于云端技术的虚拟手机&#xff0c;不仅提供了更加便捷、安全的使用体验&#xff0c;还在电商引流和海外社媒管理等领域展示了其巨大潜力。那么&#xff0c;海外云手机究竟能否有效用于运营TikT…

828华为云征文 | Flexus X 实例服务器网络性能深度评测

引言 随着互联网应用的快速发展&#xff0c;网络带宽和性能对云服务器的表现至关重要。在不同的云服务平台上&#xff0c;即便配置相同的带宽&#xff0c;实际的网络表现也可能有所差异。因此&#xff0c;了解并测试服务器的网络性能变得尤为重要。本文将以华为云X实例服务器为…

Open-Sora代码详细解读(1):解读DiT结构

Diffusion Models专栏文章汇总&#xff1a;入门与实战 前言&#xff1a;目前开源的DiT视频生成模型不是很多&#xff0c;Open-Sora是开发者生态最好的一个&#xff0c;涵盖了DiT、时空DiT、3D VAE、Rectified Flow、因果卷积等Diffusion视频生成的经典知识点。本篇博客从Open-S…

【MySQL】MySQL基础

目录 什么是数据库主流数据库基本使用MySQL的安装连接服务器服务器、数据库、表关系使用案例数据逻辑存储 MySQL的架构SQL分类什么是存储引擎 什么是数据库 mysql它是数据库服务的客户端mysqld它是数据库服务的服务器端mysql本质&#xff1a;基于C&#xff08;mysql&#xff09…

linux系统中,计算两个文件的相对路径

realpath --relative-to/home/itheima/smartnic/smartinc/blocks/ruby/seanet_diamond/tb/parser/test_parser_top /home/itheima/smartnic/smartinc/corundum/fpga/lib/eth/lib/axis/rtl/axis_fifo.v 检验方式就是直接在当前路径下&#xff0c;把输出的路径复制一份&#xff0…

Nginx跨域运行案例:云台控制http请求,通过 http server 代理转发功能,实现跨域运行。(基于大华摄像头WEB无插件开发包)

文章目录 引言I 跨域运行案例开发资源测试/生产环境,Nginx代理转发,实现跨域运行本机开发运行II nginx的location指令Nginx配置中, 获取自定义请求header头Nginx 配置中,获取URL参数引言 背景:全景监控 需求:感知站点由于云台相关操作为 http 请求,http 请求受浏览器…

Redis-主从集群

主从架构 单节点Redis的并发能力是有上限的&#xff0c;要进一步提高Redis的并发能力&#xff0c;就需要搭建主从集群&#xff0c;实现读写分离。 主从数据同步原理 全量同步 主从第一次建立连接时&#xff0c;会执行全量同步&#xff0c;将master节点的所有数据都拷贝给sla…

34465A-61/2 数字万用表(六位半)

34465A-61/2 数字万用表(六位半) 文章目录 34465A-61/2 数字万用表(六位半)前言一、测DC/AC电压二、测DC/AC电流四、测电阻五、测电容六、测二极管七、保存截图流程前言 1、6位半数字万用表通常具有200,000个计数器,可以显示最大为199999的数值。相比普通数字万用表,6位半…

注册安全分析报告:熊猫频道

前言 由于网站注册入口容易被黑客攻击&#xff0c;存在如下安全问题&#xff1a; 暴力破解密码&#xff0c;造成用户信息泄露短信盗刷的安全问题&#xff0c;影响业务及导致用户投诉带来经济损失&#xff0c;尤其是后付费客户&#xff0c;风险巨大&#xff0c;造成亏损无底洞…

【笔记】Java | 三目运算符和Math函数的比较

实际效果 比较两数并赋值&#xff0c;如下两种方法的耗时不会有差异。 result Math.min(result, subLen);result result < subLen ? result : subLen; 源码解析 因为源码Math.min的源码本质就算三目运算符的比较&#xff0c;所以执行结果是一样的。 三目运算符简介 概…

怎么强制撤销excel工作表保护?

经常不是用的Excel文件设置了工作表保护&#xff0c;偶尔打开文件的时候想要编辑文件&#xff0c;但是发现忘记了密码&#xff0c;那么这种情况&#xff0c;我们怎么强制撤销excel工作表保护&#xff1f;今天分享两种解决方法。 方法一、 将excel文件转换为其他文件格式&…

新品上市丨科学级新款制冷相机sM4040A/sM4040B

sM4040B科学级显微制冷相机 特性 sM4040B搭载了 GSENSE4040BSI 3.2 英寸图像传感器&#xff0c;针对传感器固有的热噪声&#xff0c;专门设计了高效制冷模块&#xff0c;使得相机传感器的工作温度比环境温度低达 35-40 度。针对制冷相机常见的低温结雾现象设计了防结雾机制&a…

二百五十九、Java——采集Kafka数据,解析成一条条数据,写入另一Kafka中(一般JSON)

一、目的 由于部分数据类型频率为1s&#xff0c;从而数据规模特别大&#xff0c;因此完整的JSON放在Hive中解析起来&#xff0c;尤其是在单机环境下&#xff0c;效率特别慢&#xff0c;无法满足业务需求。 而Flume的拦截器并不能很好的转换数据&#xff0c;因为只能采用Java方…

鸿蒙自动化发布测试版本app

创建API客户端 API客户端是AppGallery Connect用于管理用户访问AppGallery Connect API的身份凭据&#xff0c;您可以给不同角色创建不同的API客户端&#xff0c;使不同角色可以访问对应权限的AppGallery Connect API。在访问某个API前&#xff0c;必须创建有权访问该API的API…

UE5.3_跟一个插件—Socket.IO Client

网上看到这个插件,挺好! 项目目前也没有忙到不可开交,索性跟着测一下吧: 商城可见,售价72.61人民币! 但是,git上有仓库哦,免费!! 跟着链接先准备起来: Documentation: GitHub - getnamo/SocketIOClient-Unreal: Socket.IO client plugin for the Unreal Engin…