VIT总结

关于transformer、VIT和Swin T的总结

1.transformer

1.1.注意力机制

An attention function can be described as mapping a query and a set of key-value pairs to an output, where the query, keys, values, and output are all vectors. The output is computed as a weighted sum of the values, where the weight assigned to each value is computed by a compatibility function of the query with the corresponding key.[1]
输入是query和 key-value,注意力机制首先计算query与每个key的关联性(compatibility)每个关联性作为每个value的权重(weight),各个权重与value的乘积相加得到输出

Attention Is All You Need 中用到的attention叫做“Scaled Dot-Product Attention”,具体过程如下图所示:
在这里插入图片描述
代码实现:

import torch
import torch.nn as nnclass SelfAttention(nn.Module):def __init__(self, embed_size, heads):super(SelfAttention, self).__init__()self.embed_size = embed_sizeself.heads = headsself.head_dim = embed_size // headsassert (self.head_dim * heads == embed_size), "Embed size needs  to  be div by heads"self.values = nn.Linear(self.head_dim, self.head_dim, bias=False)self.keys = nn.Linear(self.head_dim, self.head_dim, bias=False)self.queries = nn.Linear(self.head_dim, self.head_dim, bias=False)self.fc_out = nn.Linear(heads * self.head_dim, embed_size)def forward(self, values, keys, query, mask):N = query.shape[0]  # the number of training examplesvalue_len, key_len, query_len = values.shape[1], keys.shape[1], query.shape[1]# Split embedding into self.heads piecesvalues = values.reshape(N, value_len, self.heads, self.head_dim)keys = keys.reshape(N, key_len, self.heads, self.head_dim)queries = query.reshape(N, query_len, self.heads, self.head_dim)values = self.values(values)keys = self.keys(keys)queries = self.queries(queries)energy = torch.einsum("nqhd,nkhd->nhqk", [queries, keys])# queries shape: (N, query_len, heads, heads_dim)# keys shape: (N, key_len, heads, heads_dim)# energy shape: (N, heads, query_len, key_len)if mask is not None:energy = energy.masked_fill(mask == 0, float("-1e20"))# Fills elements of self tensor with value where mask is Trueattention = torch.softmax(energy / (self.embed_size ** (1 / 2)), dim=3)out = torch.einsum("nhql, nlhd->nqhd", [attention, values]).reshape(N, query_len, self.heads * self.head_dim)# attention shape: (N, heads, query_len, key_len)# values shape: (N, value_len, heads, head_dim)# after einsum (N, query_len, heads, head_dim) then flatten last two dimensionsout = self.fc_out(out)return out

1.为什么有mask?
NLP处理不定长文本需要padding,但是padding的内容无意义,所以处理时需要mask.
2.关于qkv
qkv是相同的,需要查询的q,与每一个key相乘得到权重信息,权重与v相乘,这样结果受权重大的v影响
3.为什么除以根号dk

We suspect that for large values of dk, the dot products grow large in magnitude, pushing the softmax function into regions where it has extremely small gradients 4. To counteract this effect, we scale the dot products by 1 √dk
点积过大,经过softmax,进入饱和区,梯度很小

4.为什么需要多头
在这里插入图片描述
不同头部的output就是从不同层面(representation subspace)考虑关联性而得到的输出。

1.2.TransformerBlock

解码端的后面两部分和编码段一样,所以打包成一个类
在这里插入图片描述

class TransformerBlock(nn.Module):def __init__(self, embed_size, heads, dropout, forward_expansion):super(TransformerBlock, self).__init__()self.attention = SelfAttention(embed_size, heads)self.norm1 = nn.LayerNorm(embed_size)self.norm2 = nn.LayerNorm(embed_size)self.feed_forward = nn.Sequential(nn.Linear(embed_size, forward_expansion * embed_size),nn.ReLU(),nn.Linear(forward_expansion * embed_size, embed_size))self.dropout = nn.Dropout(dropout)def forward(self, value, key, query, mask):attention = self.attention(value, key, query, mask)x = self.dropout(self.norm1(attention + query))forward = self.feed_forward(x)out = self.dropout(self.norm2(forward + x))return out

1.3.Encoder

关键的就是位置编码

class Encoder(nn.Module):def __init__(self,src_vocab_size,embed_size,num_layers,heads,device,forward_expansion,dropout,max_length):super(Encoder, self).__init__()self.embed_size = embed_sizeself.device = deviceself.word_embedding = nn.Embedding(src_vocab_size, embed_size)self.position_embedding = nn.Embedding(max_length, embed_size)self.layers = nn.ModuleList([TransformerBlock(embed_size,heads,dropout=dropout,forward_expansion=forward_expansion)for _ in range(num_layers)])self.dropout = nn.Dropout(dropout)def forward(self, x, mask):N, seq_lengh = x.shapepositions = torch.arange(0, seq_lengh).expand(N, seq_lengh).to(self.device)out = self.dropout(self.word_embedding(x) + self.position_embedding(positions))for layer in self.layers:out = layer(out, out, out, mask)return out

2.VIT

在这里插入图片描述

Reference:

[1].Attention Is All You Need
[2].https://zhuanlan.zhihu.com/p/366592542
[3].代码实现:https://zhuanlan.zhihu.com/p/653170203
[4].An Image is Worth 16x16 Words: Transformers for Image Recognition at Scale

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.rhkb.cn/news/212777.html

如若内容造成侵权/违法违规/事实不符,请联系长河编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

MySQL数据库,函数与分组

单行函数: 操作数据对象 接受参数返回一个结果 只对一行进行变换 每行返回一个结果 可以嵌套 参数也可以是一列或一个值 数值函数 基本函数: 注:ROUND(x,y)函数的y是负数时,即往高位进行四舍五入,如-3就是按百位…

Plonky2 = Plonk + FRI

Plonky2由Polygon Zero团队开发,实现了一种快速的递归SNARK,据其团队公开的基准测试,2020年,以太坊第一笔递归证明需要60s生成,而于今Plonky2在 MacBook Pro上生成只需 170 毫秒。 下面将逐步剖析Plonky2。 整体构造 …

acwing-Linux学习笔记

acwing-Linux课上的笔记 acwing-Linux网址 文章目录 1.1常用文件管理命令homework作业测评命令 2.1 简单的介绍tmux与vimvimhomeworktmux教程vim教程homework中的一些操作 3 shell语法概论注释变量默认变量数组expr命令read命令echo命令printf命令test命令与判断符号[]逻辑运算…

Hive HWI 配置

前言 1、下载安装好hive后,发现hive有hwi界面功能,研究下是否可以运行,于是使用hive –service hwi命令启动hwi界面报错。 启动hwi功能 2、访问192.168.126.110:9999/hwi,发现访问错误 一、HWI介绍 HWI(Hive Web Int…

借助乔拓云,轻松驾驭小程序开发

在当今数字化时代,微信小程序已经成为企业和个人开展业务、提升用户体验的重要工具。然而,要想成功地开发并运营一个小程序,却不是一件容易的事情。在这个过程中,第三方开发平台的出现为开发者提供了一个便捷、高效、可靠的解决方…

【android开发-22】android中音频和视频用法详解

1,播放音频 MediaPlayer是Android中用于播放音频和视频的类。它提供了许多方法来控制播放,例如播放、暂停、停止、释放等。下面是一个简单的MediaPlayer用法详解和参考代码例子。 首先,确保在布局文件中添加了一个MediaPlayer控件&#xff…

如何进行代码混淆?方法与常见工具介绍

​ 如何进行代码混淆?方法与常见工具介绍 目录 什么是代码混淆? 代码混淆的方法 常见代码混淆工具 什么是代码混淆? 代码混淆是指将计算机程序的代码转换成一种功能上等价,但难于阅读和理解的形式的行为。混淆后的代码很难被…

探究Logistic回归:用数学解释分类问题

文章目录 前言回归和分类Logistic回归线性回归Sigmoid函数把回归变成分类Logistic回归算法的数学推导Sigmoid函数与其他激活函数的比较 Logistic回归实例1. 数据预处理2. 模型定义3. 训练模型4. 结果可视化 结语 前言 当谈论当论及机器学习中的回归和分类问题时,很…

cpu 300% 爆满 内存占用不高 排查

top查询 cpu最高的PID ps -ef | grep PID 查看具体哪一个jar服务 jstack -l PID > ./jstack.log 下载/打印进程的线程栈信息 可以加信息简单分析 或进一步 查看堆内存使用情况 jmap -heap Java进程id jstack.log 信息示例 Full thread dump Java HotSpot(TM) 64-Bit Se…

MQTT 协议入门:轻松上手,快速掌握核心要点

文章目录 什么是 MQTT?MQTT 的工作原理MQTT 客户端MQTT Broker发布-订阅模式主题QoS MQTT 的工作流程开始使用 MQTT:快速教程准备 MQTT Broker准备 MQTT 客户端创建 MQTT 连接通过通配符订阅主题发布 MQTT 消息MQTT 功能演示保留消息Clean Session遗嘱消…

【广州华锐互动VRAR】VR戒毒科普宣传系统有效提高戒毒成功率

随着科技的不断发展,虚拟现实(VR)技术已经逐渐渗透到各个领域,为人们的生活带来了前所未有的便利。在教育科普领域,VR技术的应用也日益广泛,本文将详细介绍广州华锐互动开发的VR戒毒科普宣传系统&#xff0…

长沙电信大楼火灾调查报告发布:系烟头引发。FIS来护航安全

近日,长沙电信大楼的火灾调查报告引起广泛关注。调查发现,火灾是由未熄灭的烟头引发,烟头点燃了室外平台的易燃物,迅速蔓延至整个建筑。这起悲剧再次提醒我们,小小的疏忽可能酿成大灾难。但如果我们能及时发现并处理这…

Linus:我休假的时候也会带着电脑,否则会感觉很无聊

目录 Linux 内核最新版本动态 关于成为内核维护者 代码好写,人际关系难处理 内核维护者老龄化 内核中 Rust 的使用 关于 AI 的看法 参考 12.5-12.6 日,Linux 基金会组织的开源峰会(OSS,Open Source Summit)在日…

java设计模式学习之【适配器模式】

文章目录 引言适配器模式简介定义与用途:实现方式:类型 使用场景优势与劣势适配器模式在Spring中的应用多媒体播放器示例代码地址 引言 在我们的日常生活中,适配器无处不在:无论是将不同国家的插头转换成本地标准,还是…

JavaSE基础50题:11. 输出一个整数的每一位

概述 输出一个整数的每一位。 如:1234的每一位是4,3,2,1 。 个位:1234 % 10 4 十位:1234 / 10 123 123 % 10 3 百位:123 / 10 12 12 % 10 2 千位: 12 / 10 1 代码 ublic sta…

yolo目标检测+目标跟踪+车辆计数+车辆分割+车道线变更检测+速度估计

这个项目使用YOLO进行车辆检测,使用SORT(简单在线实时跟踪器)进行车辆跟踪。该项目实现了以下任务: 车辆计数车道分割车道变更检测速度估计将所有这些详细信息转储到CSV文件中 车辆计数是指在道路上安装相应设备,通过…

Linux文件部分知识

目录 认识inode 如何理解创建一个空文件? 如何理解对文件写入信息? 如何理解删除一个文件? 为什么拷贝文件的时候很慢,而删除文件的时候很快? 如何理解目录 ​编辑 文件的三个时间 ​编辑 Access: …

【文件上传系列】No.1 大文件分片、进度图展示(原生前端 + Node 后端 Koa)

分片(500MB)进度效果展示 效果展示,一个分片是 500MB 的 分片(10MB)进度效果展示 大文件分片上传效果展示 前端 思路 前端的思路:将大文件切分成多个小文件,然后并发给后端。 页面构建 先在页…

超大规模集成电路设计----FPGA时序模型及FSM的设计(八)

本文仅供学习,不作任何商业用途,严禁转载。绝大部分资料来自----数字集成电路——电路、系统与设计(第二版)及中国科学院段成华教授PPT 超大规模集成电路设计----RTL级设计之FSM(八) 7.1 CPLD的时序模型7.1.1 XPLA3 时序模型7.1.…

电话卡Giffgaff激活

Giffgaff是一家总部位于英国的移动电话公司。作为一家移动虚拟网络电信运营商,Giffgaff使用O2的网络,是O2的全资子公司,成立于2009年11月25日。 Giffgaff与传统的移动电话运营商不同,区别在于其用户也可以参与公司的部分运营&…