大模型基础——从零实现一个Transformer(2)

大模型基础——从零实现一个Transformer(1)

一、引言

上一章主要实现了一下Transformer里面的BPE算法和 Embedding模块定义
本章主要讲一下 Transformer里面的位置编码以及多头注意力

二、位置编码

2.1正弦位置编码(Sinusoidal Position Encoding)

其中:

pos:表示token在文本中的位置
: i代表词向量具体的某一维度,即位置编码的每个维度对应一个波长不同的正弦或余弦波
d : d表示位置编码的最大维度,和词嵌入的维度相同,假设是512

对于位置0的编码为:

对于位置1的编码为:

2.2 正弦位置编码特性

  • 相对位置关系:pos + k的位置编码可以被位置pos的位置编码线性表示
    三角函数公式如下:

对于pos + k的位置编码:

根据式( 3 )和( 4 )整理上式有:

  • 位置之间的相对距离

𝑃𝐸𝑝𝑜𝑠+𝑘∙𝑃𝐸𝑝𝑜𝑠 的内积:

位置之间内积的关系大小如下:

可以看到内积会随着相对位置的递增而减少,从而可以表示位置的相对距离。内积的结果是对称的,所以没有方向信息。

2.3 代码实现

import torch
from torch import nn,Tensor
import mathclass PositionalEmbedding(nn.Module):def __init__(self,d_model:int=512,dropout:float=0.1,max_positions:int=1024) -> None:''':param d_model: embedding向量的维度:param dropout::param max_positions: 最大长度'''super().__init__()self.dropout = nn.Dropout(p=dropout)# Position Embedding  (max_positions,d_model)pe = torch.zeros(max_positions,d_model)# 创建position index列表 ,形状为:(max_positions, 1)position = torch.arange(0,max_positions).unsqueeze(1)# d_model 维度 偶数位是sin ,奇数位是cos# 计算除数,这里的除数将用于计算正弦和余弦的频率div_term = torch.exp(torch.arange(0,d_model,2) * -(math.log(10000.0) /d_model))# 对矩阵的偶数列(0,2,4...)进行正弦函数编码pe[:, 0::2] = torch.sin(position * div_term)# 对矩阵的奇数列(1,3,5...)进行余弦函数编码pe[:, 1::2] = torch.cos(position * div_term)# 扩展维度,增加batch_size: pe (1, max_positions, d_model)pe = pe.unsqueeze(0)# buffers will not be trainedself.register_buffer("pe", pe)def forward(self,x:Tensor) ->Tensor:"""Args:x (Tensor): (batch_size, seq_len, d_model) embeddingsReturns:Tensor: (batch_size, seq_len, d_model)"""# x.size(1)是指当前x的最大长度x = x + self.pe[:,:x.size(1)]return self.dropout(x)if __name__ == '__main__':seq_len = 128d_model = 512pe = PositionalEmbedding(d_model)x = torch.rand((1,100,d_model))print(pe(x).shape)

三、多头注意力

3.1 自注意力

公式如下:

  • 假设一个矩阵X,分别乘上权重矩阵,,就得到了Q , K , V向量矩阵

  • 然后除以 𝑑𝑘 进行缩放,再经过Softmax,得到注意力权重矩阵,接着乘以value向量矩阵V,就一次得到了所有单词的输出矩阵Z

3.2 多头注意力

将原来n_head分割乘Nx n_sub_head.对于每个头i,都有它自己不同的key,query和value矩阵: 𝑊𝑖𝐾,𝑊𝑖𝑄,𝑊𝑖𝑉 。在多头注意力中,key和query的维度是 𝑑𝑘 ,value嵌入的维度是 𝑑𝑣 (其中key,query和value的维度可以不同,Transformer里面一般设置的是相同的),这样每个头i,权重 𝑊𝑖𝑄∈𝑅𝑑×𝑑𝑘,𝑊𝑖𝐾∈𝑅𝑑×𝑑𝑘,𝑊𝑖𝑉∈𝑅𝑑×𝑑𝑣 ,然后与压缩到X中的输入相乘,得到 𝑄∈𝑅𝑁×𝑑𝑘,𝐾∈𝑅𝑁×𝑑𝑘,𝑉∈𝑅𝑁×𝑑𝑣 .

3.3 代码实现

import mathimport torch
from torch import nn,Tensor
from typing import *class MultiHeadAttention(nn.Module):def __init__(self,d_model: int = 512,n_heads: int=8,dropout: float = 0.1):''':param d_model: embedding大小:param n_heads: 多头个数:param dropout:'''super().__init__()assert d_model % n_heads == 0self.d_model = d_modelself.n_heads = n_headsself.d_key = d_model // n_headsself.q = nn.Linear(d_model,d_model)self.k = nn.Linear(d_model,d_model)self.k = nn.Linear(d_model,d_model)self.concat = nn.Linear(d_model,d_model)self.dropout = nn.Dropout(dropout)def split_heads(self,x:Tensor,is_key : bool = False) -> Tensor:'''分割向量为N个头,如果是key的话,softmax时候,key需要转置一下:param x::param is_key::return:'''batch_size = x.size(0)# x (batch_size,seq_len,n_heads,d_key)x = x.view(batch_size,-1,self.n_heads,self.d_key)if is_key:# (batch_size,n_heads,d_key,seq_len)return x.permute(0,2,3,1)# (batch_size,n_heads,seq_len,d_keyreturn x.transpose(1,2)def merge_heads(self,x: Tensor) -> Tensor:x = x.transpose(1,2).contigouse().view(x.size(0),-1,self.d_model)return xdef attention(self,query:Tensor,key:Tensor,value:Tensor,mask:Tensor = None,keep_attentions:bool = False):scores = torch.matmul(query,key) / math.sqrt(self.d_key)if mask is not None:scores = scores.masked_fill(mask == 0, -1e9)# weights (batch_size,n_heads,q_length,k_length)weights = self.dropout(torch.softmax(scores,dim=-1))# (batch_size,n_heads,q_length,k_length) x (batch_size,n_heads,v_length,d_key)# -> (batch_size,n_heads,q_length,d_key)# assert k_length == v_length# attn_output (batch_size, n_heads, q_length, d_key)atten_output = torch.matmul(weights,value)if keep_attentions:self.weights = weightselse:del weightsreturn atten_outputdef forward(self,query: Tensor,key: Tensor,value: Tensor,mask: Tensor = None,keep_attentions: bool = False)-> Tuple[Tensor,Tensor]:''':param query:(batch_size, q_length, d_model):param key:(batch_size, k_length, d_model):param value:(batch_size, v_length, d_model):param mask: mask for padding or decoder. Defaults to None.:param keep_attentions: whether keep attention weigths or not. Defaults to False.:return: (batch_size, q_length, d_model) attention output'''query = self.q(query)key = self.k(key)value = self.v(value)query,key,value = (self.split_heads(query),self.split_heads(key,is_key=True),self.split_heads(value))atten_output = self.attention(query,key,value,mask,keep_attentions)del querydel keydel value# concatconcat_output = self.merge_heads(atten_output)# the final liear# output (batch_size, q_length, d_model)output = self.concat(concat_output)return output

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.rhkb.cn/news/346090.html

如若内容造成侵权/违法违规/事实不符,请联系长河编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

jmeter性能优化之mysql配置

一、连接数据库和grafana 准备:连接好数据库和启动grafana并导入mysql模板 大批量注册、登录、下单等,还有过节像618,双11和数据库交互非常庞大,都会存在数据库的某一张表里面,当用户在登录或者查询某一个界面时&…

第十二届蓝桥杯单片机国赛练习代码

文章目录 前言一、问题重述二、主函数总结 前言 第十五蓝桥杯国赛落幕已有十天,是时候总结一下,这个专栏也将结束。虽然并没有取得预期的结果,但故事结尾并不总是美满的。下面是赛前练习的第十二届国赛的代码。 一、问题重述 二、主函数 完整…

渗透测试模拟实战-tomexam网络考试系统

渗透测试,也称为“pentest”或“道德黑客”,是一种模拟攻击的网络安全评估方法,旨在识别和利用系统中的安全漏洞。这种测试通常由专业的安全专家执行,他们使用各种技术和工具来尝试突破系统的防御,如网络、应用程序、主…

高德地图简单实现点标,和区域绘制

高德地图开发文档:https://lbs.amap.com/api/javascript-api/guide/abc/quickstart 百度搜索高德地图开发平台 注册高德地图开发账号 在应用管理中 我的应用中 添加一个Key 点击提交 进入高德地图开发文档:https://lbs.amap.com/api/javascript-api/guide/abc/quickstart …

定个小目标之刷LeetCode热题(15)

这道题直接就采用两数相加的规则,维护一个进阶值(n)即可,代码如下 class Solution {public ListNode addTwoNumbers(ListNode l1, ListNode l2) {// 新建一个值为0的头结点ListNode newHead new ListNode(0);// 创建几个指针用于…

从零开始,手把手教你文旅产业策划全攻略

如果你想深入了解文旅策划的世界,那么有很多途径可以获取知识和灵感。 首先,阅读一些专业书籍也是一个不错的选择。书店或图书馆里有许多关于文旅策划的书籍,它们通常涵盖了策划的基本理论、方法和实践案例。通过阅读这些书籍,你…

对接专有钉钉(浙政钉)登陆步骤

背景 因为项目需要对接浙政钉,我想应该和之前对接阿里云的钉钉登陆钉钉登陆类似,就上网搜索看看,出现了个专有钉钉的概念,就一时间搞不清楚,钉钉,专有钉钉,浙政钉的区别,后续稍微理…

20240606更新Toybrick的TB-RK3588开发板在Android12下的内核

20240606更新Toybrick的TB-RK3588开发板在Android12下的内核 2024/6/6 10:51 0、整体编译: 1、cat android12-rk-outside.tar.gz* | tar -xzv 2、cd android12 3、. build/envsetup.sh 4、lunch rk3588_s-userdebug 5、./build.sh -AUCKu -d rk3588-toybrick-x0-a…

Vue3中的常见组件通信之mitt

Vue3中的常见组件通信之mitt 概述 ​ 在vue3中常见的组件通信有props、mitt、v-model、 r e f s 、 refs、 refs、parent、provide、inject、pinia、slot等。不同的组件关系用不同的传递方式。常见的撘配形式如下表所示。 组件关系传递方式父传子1. props2. v-model3. $refs…

Spring Boot框架基础

文章目录 1 Spring Boot概述2 Spring Boot入门2.1 项目搭建2.2 入门程序 3 数据请求与响应3.1 数据请求3.2 数据响应 4 分层解耦4.1 三层架构4.2 控制反转4.3 依赖注入 5 参考资料 1 Spring Boot概述 Spring是Java EE编程领域的一个轻量级开源框架,是为了解决企业级…

欧美北美南美国外媒体投稿和东南亚中东亚洲媒体海外新闻发稿软文推广营销策略有哪些?

在当今全球化的浪潮中,中国品牌正积极拓展海外市场,寻求更广阔的发展空间。面对国际竞争,有效的海外媒体发稿营销策略对于品牌国际化至关重要。以下是一些关键点和建议,以帮助品牌在海外市场取得成功。 深入了解目标市场&#xf…

11. MySQL 备份、恢复

文章目录 【 1. MySQL 备份类型 】【 2. 备份数据库 mysqldump 】2.1 备份单个数据表2.2 备份多个数据库2.3 备份所有数据库2.4 备份文件解析 【 3. 恢复数据库 mysql 】【 4. 导出表数据 OUTFILE 】【 5. 恢复表数据 INFILE 】 问题背景 尽管采取了一些管理措施来保证数据库的…

【python】unindent does not match any outer indentation level错误的解决办法

【Python】"unindent does not match any outer indentation level"错误的解决办法 在Python编程中,缩进是定义代码块的关键。与其它编程语言使用花括号或特定关键字不同,Python完全依赖缩进来区分代码结构。如果你在编码时遇到了错误信息unin…

stm32编写Modbus步骤

1. modbus协议简介: modbus协议基于rs485总线,采取一主多从的形式,主设备轮询各从设备信息,从设备不主动上报。 日常使用都是RTU模式,协议帧格式如下所示: 地址 功能码 寄存器地址 读取寄存器…

MySQL基础---库的操作和表的操作(配着自己的实操图,简单易上手)

绪论​ 勿问成功的秘诀为何,且尽全力做您应该做的事吧。–美华纳;本章是MySQL的第二章,本章主要写道MySQL中库和表的增删查改以及对库和表的备份处理,本章是基于上一章所写若没安装mysql可以查看Linux下搭建mysql软件及登录和基本…

ubuntu18.04离线安装Mysql

查看系统位数 首先看自己Ubuntu是32还是64位的 sudo uname --m 我是64位。 下载mysql MySQL :: Download MySQL Community Server (Archived Versions) 上传到服务器上 解压 mkdir mysql8 sudo tar -xvf mysql-server_8.0.11-1ubuntu18.04_amd64.deb-bundle.tar -C ./mysql8…

1 c++多线程创建和传参

什么是进程? 系统资源分配的最小单位。 什么是线程? 操作系统调度的最小单位,即程序执行的最小单位。 为什么需要多线程? (1)加快程序执行速度和响应速度, 使得程序充分利用CPU资源。 (2&…

【MySQL】(基础篇二) —— MySQL初始用

MySQL初始用 目录 MySQL初始用基本语法约定选择数据库查看数据库和表其它的SHOW 在Navicat中,大部分数据库管理相关的操作都可以通过图形界面完成,这个很简单,大家可以自行探索。虽然Navicat等图形化数据库管理工具为操作和管理数据库提供了非…

upload-labs-第五关

目录 第五关 1、构造.user.ini文件 2、构造一个一句话木马文件,后缀名为jpg 3、上传.user.ini文件后上传flag.jpg 4、上传成功后访问上传路径 第五关 原理: 这一关采用黑名单的方式进行过滤,不允许上传php、php3、.htaccess等这几类文件…

AB测试学习(附有相关代码)

目录 一、基本概念1. 定义2. 作用3. 原理 二、实验基本原则三、实验步骤四、实验步骤详解1. 确定实验目的2. 确定实验变量3. 实验指标设计3.1 实验指标类型(按作用区分)3.1.1 核心指标3.1.2 驱动指标(跟踪指标)3.1.3 护栏指标 3.2…