深度学习_GPT2Block详解(casual attention)

一、GTP2Block 整体结构

1.1 block准备

import torch 
from torch import nn
from transformers import GPT2Model, GPT2Config
from transformers.models.gpt2.modeling_gpt2 import GPT2Blockcfg = GPT2Config()
print(cfg.add_cross_attention)
blk = GPT2Block(cfg, layer_idx=0)
hidden_states = torch.randn(10, 1024, 768)

1.2 block架构

经典的preNorm TFDecoder架构

GPT2Block((ln_1): LayerNorm((768,), eps=1e-05, elementwise_affine=True)(attn): GPT2Attention((c_attn): Conv1D()(c_proj): Conv1D()(attn_dropout): Dropout(p=0.1, inplace=False)(resid_dropout): Dropout(p=0.1, inplace=False))(ln_2): LayerNorm((768,), eps=1e-05, elementwise_affine=True)(mlp): GPT2MLP((c_fc): Conv1D()(c_proj): Conv1D()(act): NewGELUActivation()(dropout): Dropout(p=0.1, inplace=False))
)

1.3 forward-preNorm

y = attn(ln_1(x)) + x
O = mlp(ln_2(y)) + y

在这里插入图片描述
在这里插入图片描述

二、GPT2Attention

  1. hidden 拆分成 q k v: query, key, value = gpt2_att.c_attn(hidden_states).split(split_size, dim=2)
  2. q k v 拆分成多头
query = gpt2_att._split_heads(query, gpt2_att.num_heads, gpt2_att.head_dim)
key = gpt2_att._split_heads(key, gpt2_att.num_heads, gpt2_att.head_dim)
value = gpt2_att._split_heads(value, gpt2_att.num_heads, gpt2_att.head_dim)
print(f'{query.shape=}') # [batch, n_head, len, head_emb] 
  1. 计算attention
    1. A ^ = Q K T K d i m \hat{A}=\frac{QK^T}{\sqrt{K_{dim}}} A^=Kdim QKT 代码中用的是 V d i m \sqrt{V_{dim}} Vdim
    2. casual attention: 对原始attn进行mask
    3. 计算mask后的attention: A = s o f t m a x ( A ^ , d i m = − 1 ) A=softmax(\hat{A}, dim=-1) A=softmax(A^,dim=1)
    4. O = A V O=AV O=AV
# 3- attention 
#  3.1 A = QK^T
attn_weights = torch.matmul(query, key.transpose(-1, -2)) / torch.full([], value.size(-1) ** 0.5)
#  3.2 mask 
max_positions = 1024
causal_mask = torch.tril(torch.ones((max_positions, max_positions), dtype=torch.bool)
).view(1, 1, max_positions, max_positions)
mask_value = torch.finfo(attn_weights.dtype).min
mask_value = torch.full([], mask_value, dtype=attn_weights.dtype).to(attn_weights.device)
# where mask
attn_weights = torch.where(causal_mask, attn_weights.to(attn_weights.dtype), mask_value)
#  3.3 A = softmax(A)
attn_weights = nn.functional.softmax(attn_weights, dim=-1) # [batch, n_head, len, len] 
#  3.4  O = AV
attn_output = torch.matmul(attn_weights, value)            # [batch, n_head, len, head_emb] 
# 4- q k v -> merge head -> attn_out # [batch, len, head_emb*n_head] 
attn_output = gpt2_att._merge_heads(attn_output, gpt2_att.num_heads, gpt2_att.head_dim)
  1. 多头合并 [batch, n_head, len, head_emb] =>> [batch, len, head_emb*n_head]
    1. attn_output = gpt2_att._merge_heads(attn_output, gpt2_att.num_heads, gpt2_att.head_dim)

pic-attn_weights mask前后

三、GPT2MLP

结构比较简单 O = d r o p O u t ( σ ( X W 1 ) W 2 ) O=dropOut(\sigma (XW_1)W_2) O=dropOut(σ(XW1)W2),主要是激活函数 NewGELU

GPT2MLP((c_fc): Conv1D()(c_proj): Conv1D()(act): NewGELUActivation()(dropout): Dropout(p=0.1, inplace=False)
)class NewGELUActivation(nn.Module):"""Implementation of the GELU activation function currently in Google BERT repo (identical to OpenAI GPT). Also seethe Gaussian Error Linear Units paper: https://arxiv.org/abs/1606.08415"""def forward(self, input: Tensor) -> Tensor:return 0.5 * input * (1.0 + torch.tanh(math.sqrt(2.0 / math.pi) * (input + 0.044715 * torch.pow(input, 3.0))))

在这里插入图片描述

NewGELUActivation 它是高斯误差线性单元(Gaussian Error Linear Unit,简称 GELU)的一种变体。GELU 激活函数在近年来的深度学习模型中越来越受欢迎,尤其是在自然语言处理(NLP)领域,如 BERT 和 GPT 等模型中。

GELU 激活函数的数学定义是输入值 x 乘以标准正态分布的累积分布函数(CDF)在该点的值。具体来说,GELU 的表达式为:
G E L U ( x ) = x Φ ( x ) GELU(x)=x \Phi(x) GELU(x)=xΦ(x)

其中 Φ ( x ) \Phi(x) Φ(x) 是标准正态分布的 CDF,可以通过误差函数(error function,记为 erf)来计算:
Φ ( x ) = 1 2 ( 1 + e r f ( x 2 ) ) \Phi(x)=\frac{1}{2}(1+erf(\frac{x}{\sqrt 2})) Φ(x)=21(1+erf(2 x))
GPT2中用了近似公式:
σ ( x ) = 0.5 x [ 1 + t a n h ( 2 π ( x + 0.044715 x 3 ) ) ] \sigma(x) = 0.5x [1+ tanh(\sqrt{\frac{2}{\pi}} (x + 0.044715 x^3))] σ(x)=0.5x[1+tanh(π2 (x+0.044715x3))]

GELU 激活函数的优点包括:

  • 平滑性:GELU 在整个实数域上都是平滑的,这有助于梯度的传播,减少了梯度消失或爆炸的问题
  • 非单调性:GELU 函数是非单调的,这意味着它能够捕捉数据中的更复杂模式
  • 改善性能:在某些任务中,使用 GELU 激活函数的模型性能优于使用传统的 ReLU 或其他激活函数的模型。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.rhkb.cn/news/422598.html

如若内容造成侵权/违法违规/事实不符,请联系长河编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

《ECMAScript 与 JavaScript:差异与共通》

一、概念辨析 《ECMAScript 与 JavaScript:差异与共通》 ECMAScript(简称 ES)是一种由 Ecma International 标准化的脚本语言规范。它定义了脚本语言的核心特性,包括语法、类型、语句、关键字等。例如,ECMAScript 规定…

被要求撤回Blackwell?一家初创企业称英伟达侵权自家技术,忍无可忍!英伟达和伙伴微软被齐齐告上法庭,赔偿或高达数十亿!

刚刚,一家初创公司居然把巨头英伟达和微软一起告了! 名为Xockets的初创公司在诉讼中称,英伟达和微软公司窃取了其DPU技术,用以开发AI产品,并相互串通以压低其技术的价格,是名副其实的垄断行为!…

智汇创想pytest接口自动化测试框架

本测试框架是基于pytest搭建的接口自动化框架,对象为深圳智汇创想官方网站。深圳智汇创想科技有限责任公司(深圳智汇创想科技有限责任公司),是一家专注于跨境电子商务的集团公司,全球电商平台多品类多品牌的零售商&…

MATLAB | R2024b更新了哪些好玩的东西?

Hey, 又到了一年两度的MATLAB更新时刻,MATLAB R2024b正式版发布啦!,直接来看看有哪些我认为比较有意思的更新吧! 1 小提琴图 天塌了,我这两天才写了个半小提琴图咋画,MATLAB 官方就出了小提琴图绘制方法。 小提琴图…

客户端负载均衡Ribbon实例

文章目录 一,概述二,实现过程三,项目源码1. 源码放送:2. 部署方式 四,功能演示五,其他 一,概述 一般来说,提到负载均衡,大家一般很容易想到浏览器 -> NGINX -> 反…

加密与安全_ sm-crypto 国密算法sm2、sm3和sm4的Java库

文章目录 Presm-crypto如何使用如何引入依赖 sm2获取密钥对加密解密签名验签获取椭圆曲线点 sm3sm4加密解密 Pre 加密与安全_三种方式实现基于国密非对称加密算法的加解密和签名验签 sm-crypto https://github.com/antherd/sm-crypto 国密算法sm2、sm3和sm4的java版。基于js…

PMP--一模--解题--21-30

文章目录 9.资源管理21、 [单选] 项目经理发现一个不可预料的高影响风险已经成为项目的一个因素,团队成员之间的自身利益导致问题得不到解决,项目经理必须快速行动,让团队重新集中精力,以便项目恢复进度,项目经理应该使…

vue3项目实现全局国际化

本文主要梳理vue3项目实现全项目格式化,例如在我前面文章使用若依创建vue3的项目中,地址:若依搭建vue3项目在导航栏中切换,页面中所有的组件的默认语言随之切换,使用的组件库依旧是element-plus,搭配vue-i1…

09-排序1 排序(C)

这一节,测试各类排序算法的运行速度(没有基数排序(桶) 其实在实际学习中,还是有意义的 给定 n 个(长整型范围内的)整数,要求输出从小到大排序后的结果。 本题旨在测试各种不同的排序…

Windows与Linux下 SDL2的第一个窗口程序

Windows效果和Linux效果如下&#xff1a; 下面是代码&#xff1a; #include <stdio.h> #include "SDL.h"int main(int argc, char* argv[]) { // 初始化SDL视频子系统if (SDL_Init(SDL_INIT_VIDEO) ! 0){// 如果初始化失败&#xff0c;打印错误信息printf(&…

proteus+51单片机+实验(LCD1620、定时器)

目录 1.LCD1602液晶显示屏 1.1基本概念 1.1.1LCD的简介 1.1.2LCD的显示原理 ​​​1.1.3LCD的硬件电路 1.1.4LCD的常见指令 1.1.5LCD的时序 ​​​​​​​1.2代码 1.2.1写命令和写数据操作 1.2.2初始化和测试代码 1. 3.3功能函数 1.3proteus代码 1.3.1器件代码 1.…

探索Python世界的隐藏宝石:Pika库的神秘力量

文章目录 探索Python世界的隐藏宝石&#xff1a;Pika库的神秘力量背景&#xff1a;为何选择Pika&#xff1f;Pik库简介如何安装Pika&#xff1f;简单库函数使用方法场景应用常见Bug及解决方案总结 探索Python世界的隐藏宝石&#xff1a;Pika库的神秘力量 背景&#xff1a;为何…

ELK预警方案:API+XXLJob

目录 步骤一&#xff1a;出一个接口&#xff0c;接口内查询出10分钟内是否有异常信息 步骤二&#xff1a;XXLJob中设置预警的频率 步骤三&#xff1a;在重要的业务处输出指定格式日志即可 步骤一&#xff1a;出一个接口&#xff0c;接口内查询出10分钟内是否有异常信息 {&qu…

Java | Leetcode Java题解之第402题移掉K位数字

题目&#xff1a; 题解&#xff1a; class Solution {public String removeKdigits(String num, int k) {Deque<Character> deque new LinkedList<Character>();int length num.length();for (int i 0; i < length; i) {char digit num.charAt(i);while (!…

C语言字符函数和字符串函数(20)

文章目录 前言一、字符分类函数小练习 二、字符转换函数三、strlen的使用和模拟实现四、strcpy的使用和模拟实现五、strcat的使用和模拟实现六、strcmp的使用和模拟实现七、strncpy函数的使用八、strncat函数的使用九、strncmp函数的使用十、strstr函数的使用和模拟实现十一、s…

OpenGL3.3_C++_Windows(37)

调试&#xff1a; 视觉错误与CPU调试不同&#xff0c;在GLSL代码中也不能设置断点&#xff0c;出现错误的时候寻找错误的源头可能会非常困难。 glGetError&#xff08;&#xff09; GLenum glGetError();返回整形数字&#xff0c;查询错误标记&#xff0c;但是当一个错误标记…

C#开发基础之使用四种流行的数据库访问技术ADO.NET、Dapper、EF Core 和 SqlSugar 连接 SQL Server

前言 在这篇文章中&#xff0c;我们将介绍四种流行的数据库访问技术&#xff1a;ADO.NET、Dapper、Entity Framework Core (EF Core) 和 SqlSugar。每种技术都提供了与 SQL Server 进行交互的不同方法&#xff0c;我们将以 TestDB 数据库中的 User 表为例&#xff0c;展示如何…

关于malloc/free的一些知识点

序 关于malloc/free&#xff0c;我们都不陌生&#xff0c;在最开始学习c语言时就相当了解&#xff0c;包括c中的new也是封装的malloc。下边我以glibc实现的malloc来讲述一些关于malloc/free的知识点。 malloc/free malloc和free并不是系统调用&#xff0c;而是运行时库&…

C语言的结构体类型

在我们使用C语言进行编写代码时&#xff0c;常常会使用已经给定的类型来创建变量&#xff0c;比如int型&#xff0c;char型&#xff0c;double型等&#xff0c;而当我们想创建一些较为复杂的东西时&#xff0c;单单用一个类型变量是没办法做到的&#xff0c;比如我们想创建一个…

shader 案例学习笔记之fract函数

fract函数 可以理解为模1取余&#xff0c;获取一个数的小数部分&#xff0c;如果参数是向量&#xff0c;那就是获取每个向量分量上的小数 案例一 #ifdef GL_ES precision mediump float; #endif// 渲染分辨率 uniform vec2 u_resolution; // 程序运行时间 uniform float u_ti…