transfomer中Decoder和Encoder的base_layer的源码实现

简介

Encoder和Decoder共同组成transfomer,分别对应图中左右浅绿色框内的部分.
在这里插入图片描述
Encoder:
目的:将输入的特征图转换为一系列自注意力的输出。
工作原理:首先,通过卷积神经网络(CNN)提取输入图像的特征。然后,这些特征通过一系列自注意力的变换层进行处理,每个变换层都会将特征映射进行编码并产生一个新的特征映射。这个过程旨在捕捉图像中的空间和通道依赖关系。
作用:通过处理输入特征,提取图像特征并进行自注意力操作,为后续的目标检测任务提供必要的特征信息。
Decoder:
目的:接受Encoder的输出,并生成对目标类别和边界框的预测。
工作原理:首先,它接收Encoder的输出,然后使用一系列解码器层对目标对象之间的关系和全局图像上下文进行推理。这些解码器层将最终的目标类别和边界框的预测作为输出。
作用:基于Encoder的输出和全局上下文信息,生成目标类别和边界框的预测结果。
总结:Encoder就是特征提取类似卷积;Decoder用于生成box,类似head

源码实现:

Encoder 通常是6个encoder_layer组成,Decoder 通常是6个decoder_layer组成
我实现了核心的BaseTransformerLayer层,可以用来定义encoder_layer和decoder_layer

具体源码及其注释如下,配好环境可直接运行(运行依赖于上一个博客的代码):

import torch
from torch import nn
from ZMultiheadAttention import MultiheadAttention  # 来自上一次写的attensionclass FFN(nn.Module):def __init__(self,embed_dim=256,feedforward_channels=1024,act_cfg='ReLU',ffn_drop=0.,):super(FFN, self).__init__()self.l1 = nn.Linear(in_features=embed_dim, out_features=feedforward_channels)if act_cfg == 'ReLU':self.act1 = nn.ReLU(inplace=True)else:self.act1 = nn.SiLU(inplace=True)self.d1 = nn.Dropout(p=ffn_drop)self.l2 = nn.Linear(in_features=feedforward_channels, out_features=embed_dim)self.d2 = nn.Dropout(p=ffn_drop)def forward(self, x):tmp = self.d1(self.act1(self.l1(x)))tmp = self.d2(self.l2(tmp))x = tmp + xreturn x# transfomer encode和decode的最小循环单元,用于打包self_attention或者cross_attention
class BaseTransformerLayer(nn.Module):def __init__(self,attn_cfgs=[dict(embed_dim=64, num_heads=4), dict(embed_dim=64, num_heads=4)],fnn_cfg=dict(embed_dim=64, feedforward_channels=128, act_cfg='ReLU', ffn_drop=0.),operation_order=('self_attn', 'norm', 'cross_attn', 'norm', 'ffn', 'norm')):super(BaseTransformerLayer, self).__init__()self.attentions = nn.ModuleList()# 搭建att层for attn_cfg in attn_cfgs:self.attentions.append(MultiheadAttention(**attn_cfg))self.embed_dims = self.attentions[0].embed_dim# 统计norm数量 并搭建self.norms = nn.ModuleList()num_norms = operation_order.count('norm')for _ in range(num_norms):self.norms.append(nn.LayerNorm(normalized_shape=self.embed_dims))# 统计ffn数量 并搭建self.ffns = nn.ModuleList()self.ffns.append(FFN(**fnn_cfg))self.operation_order = operation_orderdef forward(self, query, key=None, value=None, query_pos=None, key_pos=None):attn_index = 0norm_index = 0ffn_index = 0for order in self.operation_order:if order == 'self_attn':temp_key = temp_value = query  # 不用担心三个值一样,在attention里面会重映射qkvquery, attention = self.attentions[attn_index](query,temp_key,temp_value,query_pos=query_pos,key_pos=query_pos)attn_index += 1elif order == 'cross_attn':query, attention = self.attentions[attn_index](query,key,value,query_pos=query_pos,key_pos=key_pos)attn_index += 1elif order == 'norm':query = self.norms[norm_index](query)norm_index += 1elif order == 'ffn':query = self.ffns[ffn_index](query)ffn_index += 1return queryif __name__ == '__main__':query = torch.rand(size=(10, 2, 64))key = torch.rand(size=(5, 2, 64))value = torch.rand(size=(5, 2, 64))query_pos = torch.rand(size=(10, 2, 64))key_pos = torch.rand(size=(5, 2, 64))# encoder 通常是6个encoder_layer组成 每个encoder_layer['self_attn', 'norm', 'ffn', 'norm']encoder_layer = BaseTransformerLayer(attn_cfgs=[dict(embed_dim=64, num_heads=4)],fnn_cfg=dict(embed_dim=64, feedforward_channels=1024, act_cfg='ReLU',ffn_drop=0.),operation_order=('self_attn', 'norm', 'ffn', 'norm'))encoder_layer_output = encoder_layer(query=query, query_pos=query_pos, key_pos=key_pos)# decoder 通常是6个decoder_layer组成 每个decoder_layer['self_attn', 'norm', 'cross_attn', 'norm', 'ffn', 'norm']decoder_layer = BaseTransformerLayer(attn_cfgs=[dict(embed_dim=64, num_heads=4), dict(embed_dim=64, num_heads=4)],fnn_cfg=dict(embed_dim=64, feedforward_channels=1024, act_cfg='ReLU',ffn_drop=0.),operation_order=('self_attn', 'norm', 'cross_attn', 'norm', 'ffn', 'norm'))decoder_layer_output = decoder_layer(query=query, key=key, value=value, query_pos=query_pos, key_pos=key_pos)pass

具体流程说明:

Encoder 通常是6个encoder_layer组成,每个encoder_layer[‘self_attn’, ‘norm’, ‘ffn’, ‘norm’]
Decoder 通常是6个decoder_layer组成,每个decoder_layer[‘self_attn’, ‘norm’, ‘cross_attn’, ‘norm’, ‘ffn’, ‘norm’]
按照以上方式搭建网络即可
其中norm为LayerNorm,在样本内部进行归一化。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.rhkb.cn/news/238947.html

如若内容造成侵权/违法违规/事实不符,请联系长河编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

开发需求总结9-el-tree获取选中节点,节点全选时返回被全选子级的父节点,未全选则返回被选中的节点

目录 需求描述 代码实现: 需求描述 需要获取树组件选中的节点,假如父节点被选中(该节点全选),即只返回父节点的数据,如父节点未被全选,则正常返回被选中节点的数据。 示例一: 如上图…

Python展示 RGB立方体的二维切面视图

代码实现 import numpy as np import matplotlib.pyplot as plt# 生成 24-bit 全彩 RGB 立方体 def generate_rgb_cube():# 初始化一个 256x256x256 的三维数组rgb_cube np.zeros((256, 256, 256, 3), dtypenp.uint8)# 填充立方体for r in range(256):for g in range(256):fo…

编曲混音FL Studio21.2对电脑有什么配置要求

FL Studio 21是一款非常流行的音乐制作软件,它可以帮助音乐人和制作人创作出高质量的音乐作品。然而,为了保证软件的稳定性和流畅性,用户需要知道FL Studio 21对电脑的配置要求。本文将介绍FL Studio 21的配置要求,以帮助用户选择…

32 二叉树的定义

之前的通用树结构 采用双亲孩子表示法模型 孩子兄弟表示法模型 引出二叉树 二叉树的定义: 满二叉树和完全二叉树 对此图要有印象 满二叉树一定是完全二叉树,但是完全二叉树不一定是满二叉树 小结

RabbitMQ交换机(2)-Direct

1.Direct 直连(路由)交换机,生产者将消息发送到交换机,并指定消息的Routing Key(路由键)。交换机会将Routing Key与队列绑定进行匹配,如果匹配成功,则将该消息路由到对应的队列中。如果没有匹配成功,该消息…

小程序中使用微信同声传译插件实现语音识别、语音合成、文本翻译功能----语音识别(一)

官方文档链接:https://mp.weixin.qq.com/wxopen/plugindevdoc?appidwx069ba97219f66d99&token370941954&langzh_CN#- 要使用插件需要先在小程序管理后台的设置->第三方设置->插件管理中添加插件,目前该插件仅认证后的小程序。 语音识别…

JS | JS调用EXE

JS | JS调用EXE 网上洋洋洒洒一大堆文章提供,然我还是没找打合适的方案: 注册表方案做了如下测试(可行但是不推荐?): 先,键入文件名为 myprotocal.reg 的注册表,并键入一下信息: Windows Registry Editor Version 5.00[HKEY_CLASSES_ROOT\openExe] //协议名…

Redis相关命令详解及其原理

Redis概念 Redis,英文全称是remote dictionary service,也就是远程字典服务。这是kv存储数据库。Redis,包括所有的数据库,都是请求-回应模式,通俗来说就是数据库不会主动地要给前台推送数据,只有前台发送了…

MySQL/Oracle 的 字符串拼接

目录 MySQL、Oracle 的 字符串拼接1、MySQL 的字符串拼接1.1 CONCAT(str1,str2,...) : 可以拼接多个字符串1.2 CONCAT_WS(separator,str1,str2,...) : 指定分隔符拼接多个字符串1.3 GROUP_CONCAT(expr) : 聚合函数,用于将多行的值连接成一个字符串。 2、Oracle 的字…

广州市生物医药及高端医疗器械产业链大会暨联盟会员大会召开,天空卫士数据安全备受关注

12月20日,广州市生物医药及高端医疗器械产业链大会暨联盟会员大会在广州举办。在本次会议上,作为大会唯一受邀参加主题分享的技术供应商,天空卫士南区技术总监黄军发表《生物制药企业如何保护数据安全》的主题演讲。 做好承上启下“连心桥”…

C++设计模式-- 2.代理模式 和 外观模式

文章目录 代理模式外观模式角色和职责代码演示一:代码演示二:外观模式适用场景 代理模式 代理模式的定义:为其他对象提供一种代理以控制对这个对象的访问。在某些情况下,一个对象不适合 或不能直接引用另一个对象,而代…

【实战记录】 vagrant+virtualbox+docker 轻松用虚拟机集成组件

用途 最近要学一大堆组件,不想直接安装本机上,然后gpt说:你可以用vagrant起个虚拟机(然后docker拉取各种组件的镜像);或者k8s 实战的整体思路 首先安装virtualbox和vagrant。然后cmd依次键入三条命令 安…

无需编程,简单易上手的家具小程序搭建方法分享

想要开设一家家具店的小程序吗?现在,我将为大家介绍如何使用乔拓云平台搭建一个家具小程序,帮助您方便快捷地开展线上家具销售业务。 第一步,登录乔拓云平台进入商城后台管理页面。 第二步,在乔拓云平台的后台管理页面…

云畅科技技术中心被认定为湖南省省级企业技术中心

近日,湖南省工业和信息化厅公布《2023年第二批湖南省省级企业技术中心(第29批)》,云畅科技技术中心作为研发设计型代表入选。 省级企业技术中心是强化企业技术创新主体地位,增强企业自主创新能力,推动工业企业高质量发展的一个重要…

深圳三维扫描分析/偏差检测模具型腔三维尺寸及形位偏差测量公司

CASAIM中科广电三维扫描模具型腔深圳案例: 模具型腔的三维扫描分析/偏差检测是一项重要的质量控制过程,旨在确保模具制造过程中的精确度和一致性。 CASAIM中科广电通过使用高精度的三维扫描设备,可以获取模具型腔的实际形状和尺寸数据&…

使用vue快速开发一个带弹窗的Chrome插件

vue-chrome-extension-quickstart 说在前面 🎈平时我们使用Chrome插件通常都只是用来编写简单的js注入脚本,大家有没有遇到过需要插件在页面上注入一个弹窗呢?比如我们希望可以通过快捷键快速唤起ChatGPT面板或者快速唤起一个翻译面板&#x…

案例:应用内字体大小调节

文章目录 介绍相关概念完整实例 代码结构解读保存默认大小获取字体大小修改字体大小 介绍 本篇Codelab将介绍如何使用基础组件Slider,通过拖动滑块调节应用内字体大小。要求完成以下功能: 实现两个页面的UX:主页面和字体大小调节页面。拖动…

compose 实验

cd /opt mkdir compose_nginx cd compose_nginx mkdir nginx cd nginx/ 此时顺便将nginx安装包拖进来 vim Dockerfile mkdir /opt/compose_nginx/wwwroot echo "<h1>this is test web</h1>" > /opt/compose_nginx/wwwroot/index.html docker netw…

漏洞复现-金和OA jc6/servlet/Upload接口任意文件上传漏洞(附漏洞检测脚本)

免责声明 文章中涉及的漏洞均已修复&#xff0c;敏感信息均已做打码处理&#xff0c;文章仅做经验分享用途&#xff0c;切勿当真&#xff0c;未授权的攻击属于非法行为&#xff01;文章中敏感信息均已做多层打马处理。传播、利用本文章所提供的信息而造成的任何直接或者间接的…

力扣刷题(无重复字符的最长子串)

3. 无重复字符的最长子串https://leetcode.cn/problems/longest-substring-without-repeating-characters/ 给定一个字符串 s &#xff0c;请你找出其中不含有重复字符的 最长子串 的长度。 示例 1: 输入: s "abcabcbb" 输出: 3 解释: 因为无重复字符的最长子串是…