Transformer详解(3)-多头自注意力机制

attention

在这里插入图片描述

在这里插入图片描述

multi-head attention

在这里插入图片描述
在这里插入图片描述

pytorch代码实现

import math
import torch
from torch import nn
import torch.nn.functional as Fclass MultiHeadAttention(nn.Module):def __init__(self, heads=8, d_model=128, droput=0.1):super().__init__()self.d_model = d_model  # 128self.d_k = d_model // heads  # 128//8=16self.h = heads  # 8self.q_linear = nn.Linear(d_model, d_model)  # (50,128)*(128,128)=(50,128),其中(128*128)属于权重,在网络训练中学习。self.k_linear = nn.Linear(d_model, d_model)self.v_linear = nn.Linear(d_model, d_model)self.dropout = nn.Dropout(droput)self.out = nn.Linear(d_model, d_model)def attention(self, q, k, v, d_k, mask=None, dropout=None):scores = torch.matmul(q, k.transpose(-2, -1)) / math.sqrt(d_k)  # 矩阵乘法 (32,8,50,16)*(32,8,16,50)->(32,8,50,50)if mask is not None:mask = mask.unsqueeze(1)scores = scores.masked_fill(mask == 0, -1e9)scores = F.softmax(scores, dim=-1)if dropout is not None:scores = dropout(scores)output = torch.matmul(scores, v)  # (32,8,50,50)*(32,8,50,16)->(32,8,50,16)return outputdef forward(self, q, k, v, mask=None):bs = q.size(0)  # batch_size 大小  这里的例子是32k = self.k_linear(k).view(bs, -1, self.h, self.d_k)q = self.k_linear(q).view(bs, -1, self.h, self.d_k)v = self.k_linear(v).view(bs, -1, self.h, self.d_k)# (32,50,128)->(32,50,128)->(32,50,8,16)  8*16=128 每个embedding拆成的8份,也就是8个头k = k.transpose(1, 2)  # (32,50,8,16)->(32,8,50,16)q = q.transpose(1, 2)v = v.transpose(1, 2)scores = self.attention(q, k, v, self.d_k, mask, self.dropout)  # (32,8,50,16)concat = scores.transpose(1, 2).contiguous().view(bs, -1, self.d_model)  # (32,50,128)output = self.out(concat)  # (32,50,128)return outputif __name__ == '__main__':multi_head_attention = MultiHeadAttention(8, 128)normal_tensor = torch.randn(32, 50, 128)  # 随机生成均值为0,方差为1的正态分布。batch_size=32,序列长度=50,embedding维度=128。x = torch.sigmoid(normal_tensor)  # 把每个数缩放到(0,1)output = multi_head_attention(x, x, x)print('done')

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.rhkb.cn/news/330282.html

如若内容造成侵权/违法违规/事实不符,请联系长河编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

大模型日报|今日必读的 13 篇大模型论文

大家好,今日必读的大模型论文来啦! 1.MIT新研究:并非所有语言模型特征都是线性的 最近的研究提出了线性表征假说:语言模型通过操作激活空间中概念(“特征”)的一维表征来执行计算。与此相反,来…

【EXCEL_VBA_基础知识】15 使用ADO操作外部数据

课程来源:王佩丰老师的《王佩丰学VBA视频教程》,如有侵权,请联系删除! 目录 1. 使用ADO链接外部数据源 2. 常用SQL语句(Execute(SQL语句)) 2.1 查询数据、查询某几个字段、带条件查询、合并两表数据、插…

GPT提示词技巧,使用教程,国内版官网直达,非套壳

GPT提示词技巧,使用教程,国内版官网直达,非套壳 主站点:https://chatgpt-plus.top(江苏福建地区打不开,需要魔法) 店铺地址:https://buy.chatgpt-plus.top/ 选择plus账号进入&…

工地升降机AI人数识别系统

工地升降机人数识别系统采用了AI神经网络和深度学习算法,工地升降机AI人数识别系统通过升降机内置的摄像头实时监测轿厢内的人员数量。通过图像处理和人脸识别算法,系统能够精确地识别升降机内的人数。一旦系统识别到人数达到或者超过设定的阈值&#xf…

C++中的四种类型转换运算符

隐式类型转换是安全的,显式类型转换是有风险的,C语言之所以增加强制类型转换的语法,就是为了强调风险,让程序员意识到自己在做什么。但是,这种强调风险的方式还是比较粗放,粒度比较大,它并没有表…

虚拟化技术[1]之服务器虚拟化

文章目录 虚拟化技术简介数据中心虚拟化 服务器虚拟化服务器虚拟化层次寄居虚拟化裸机虚拟化VMM无法直接捕获特权指令解决方案 服务器虚拟化底层实现CPU虚拟化内存虚拟化I/O设备虚拟化 虚拟机迁移虚拟机动态迁移迁移内容:内存迁移迁移内容:网络资源迁移迁…

mysqldump提示Using a password on the command line interface can be insecured的解决办法

mysql数据库备份一句话执行命令 mysqldump --all-databases -h127.0.0.1 -uroot -p123456 > allbackupfile.sql 提示如下提示 [rootyfvyy5b2on3knb8q opt]# mysqldump --all-databases -h127.0.0.1 > allbackupfile.sql mysqldump: Couldnt execute SELECT COLUMN_NA…

3D模型旋转显示不全怎么办---模大狮模型网

在3D建模和渲染过程中,我们有时会遇到旋转模型时显示不全的问题。这种情况可能由多种原因造成,包括模型本身的问题、软件设置不当、硬件配置不足等。本文将为您详细介绍几种可能的解决方法,帮助您解决3D模型旋转显示不全的问题。 一、检查模型…

防火墙技术基础篇:基于IP地址的转发策略

防火墙技术基础篇:基于IP地址的转发策略的应用场景及实现 什么是基于IP地址的转发策略? 基于IP地址的转发策略是一种网络管理方法,它允许根据目标IP地址来选择数据包的转发路径。这种策略比传统的基于目的地地址的路由更灵活,因…

K8s-yaml文件

一.Yaml文件详解: Kubernetes 支持 YAML 和 JSON 格式管理资源对象 JSON 格式:主要用于 api 接口之间消息的传递YAML 格式:用于配置和管理,YAML 是一种简洁的非标记性语言,内容格式人性化,较易读 YAML 语…

寻找峰值 ---- 二分查找

题目链接 题目: 分析: 因为题目中要找的是任意一个峰值即可, 所以和<山脉数组的峰值索引>这道题差不多因为峰值左右都小于峰值, 所以具有"二段性", 可以使用二分查找算法如果nums[mid] < nums[mid 1], mid一定不是峰值, 所以left mid 1如果nums[mid] &…

【数据库】MySQL

文章目录 概述DDL数据库操作查询使用创建删除 表操作创建约束MySqL数据类型数值类型字符串类型日期类型 查询修改删除 DMLinsertupdatedelete DQL基本查询条件查询分组查询分组查询排序查询分页查询 多表设计一对多一对一多对多设计步骤 多表查询概述内连接外连接 子查询标量子…

【LeetCode 随笔】面试经典 150 题【中等+困难】持续更新中。。。

文章目录 12.【中等】整数转罗马数字151.【中等】反转字符串中的单词6.【中等】Z 字形变换68.【困难】文本左右对齐167.【中等】两数之和 II - 输入有序数组 &#x1f308;你好呀&#xff01;我是 山顶风景独好 &#x1f49d;欢迎来到我的博客&#xff0c;很高兴能够在这里和您…

基于YOLO系列算法(YOLOv5、YOLOv6、YOLOv8以及YOLOv9)和Streamlit框架的行人头盔检测系统

摘要 本文基于最新的基于深度学习的目标检测算法 (YOLOv5、YOLOv6、YOLOv8)以及YOLOv9) 对头盔数据集进行训练与验证&#xff0c;得到了最好的模型权重文件。使用Streamlit框架来搭建交互式Web应用界面&#xff0c;可以在网页端实现模型对图像、视频和实时摄像头的目标检测功能…

网络考试系统的设计与实现参考论文(论文 + 源码)

网络考试系统的设计与实现 摘 要 科技在进步&#xff0c;人们生活和工作的方式正发生着改变&#xff0c;不仅体现在人们的衣食住行&#xff0c;也体现在与时俱进的考试形式上。以前的考试需要组织者投入大量的时间和精力&#xff0c;需要对考试的试题进行筛选&#xff0c;对后期…

变量命名的艺术:从蛇形到驼峰

新书上架~&#x1f447;全国包邮奥~ python实用小工具开发教程http://pythontoolsteach.com/3 欢迎关注我&#x1f446;&#xff0c;收藏下次不迷路┗|&#xff40;O′|┛ 嗷~~ 目录 一、蛇形命名法的魅力 二、类名和模块名的特殊规则 三、驼峰命名法的魅力与挑战 四、保持…

DLRover:蚂蚁集团开源的AI训练革命

在当前的深度学习领域&#xff0c;大规模训练作业面临着一系列挑战。首先&#xff0c;硬件故障或软件错误导致的停机时间会严重影响训练效率和进度。其次&#xff0c;传统的检查点机制在大规模训练中效率低下&#xff0c;耗时长且容易降低训练的有效时间。资源管理的复杂性也给…

新建项目上传gitee

1.在项目根目录下打开黑窗口执行初始化 git init2.复制码云上新建仓库地址 3.本地仓库和远程仓库建立连接 远程仓库地址是之前复制的仓库地址&#xff0c;复制后直接在命令窗口中鼠标右键Paste即可在命令窗口粘贴出来 git remote add origin 远程仓库地址4.每次上传之前先更…

【嵌入式软件工程师面经】Socket,TCP,HTTP之间的区别

目录&#xff1a; 目录 目录&#xff1a; 一、Socket原理与TCP/IP协议 1.1 Socket概念&#xff1a; 1.2 建立Socket连接&#xff1a; 1.3 SOCKET连接与TCP/IP连接 二、HTTP连接&#xff1a; 2.1 HTTP原理 三、三者的区别和联系 前些天发现了一个巨牛的人工智能学习网站&#xf…

头歌openGauss-存储过程第1关:创建存储过程

编程要求 1、创建第1个存储过程&#xff0c;并调用&#xff1b; 1&#xff09;创建存储过程&#xff0c;查询emp表数据&#xff1b; 2&#xff09;调用存储过程&#xff1b; --创建存储过程&#xff0c;获得计算机&#xff08;cs&#xff09;系学生选课情况并将结果写入临时表t…