PixelSNAIL论文代码学习(2)——门控残差网络的实现

文章目录

    • 引言
    • 正文
      • 门控残差网络介绍
      • 门控残差网络具体实现代码
      • 使用pytorch实现
    • 总结

引言

  • 阅读了pixelSNAIL,很简短,就用了几页,介绍了网络结构,介绍了试验效果就没有了,具体论文学习链接
  • 这段时间看他的代码,还是挺痛苦的,因为我对于深度学习的框架尚且不是很熟练 ,而且这个作者很厉害,很多东西都是自己实现的,所以看起来十分费力,本来想逐行分析,结果发现逐行分析不现实,所以这里按照模块进行分析。
  • 今天就专门来学习一下他门门控控残差模块如何实现。

正文

门控残差网络介绍

  • 介绍

    • 通过门来控制每一个残差模块,门通常是由sigmoid函数组成
    • 作用:有效建模复杂函数,有助于缓解梯度消失和爆炸的问题
  • 基本步骤

    • 卷积操作:对输入矩阵执行卷积操作
    • 非线性激活:应用非线性激活函数,激活卷积操作的输出
    • 第二次卷积操作:对上一个层的输出进行二次卷积
    • 门控操作:将二次卷积的输出分为a和b两个部分,并且通过sigmoid函数进行门控 a , b = S p l i t ( c 2 ) G a t e : g = a × s i g m o i d ( b ) a,b = Split(c_2) \\ Gate:g = a \times sigmoid(b) a,b=Split(c2)Gate:g=a×sigmoid(b)
      • 这里一般是沿着最后一个通道,将原来的矩阵拆解成a和b,然后在相乘,确保每一个矩阵有一个门控参数
    • 将门控输出 g g g和原始输入 x x x相加
  • 具体流程图如下

    • x: 输入
    • c1: 第一次卷积操作(Conv1)
    • a1: 非线性激活函数(例如 ReLU)
    • c2: 第二次卷积操作(Conv2),输出通道数是输入通道数的两倍
    • split: 将c2 分为两部分 a 和 b
    • a, b: 由 c2 分割得到的两部分
    • sigmoid: 对b 应用 sigmoid 函数
    • gated: 执行门控操作 a×sigmoid(b)
    • y: 输出,由原始输入 x 和门控输出相加得到

在这里插入图片描述

  • 这里参考一下论文中的图片,可以看到和基本的门控神经网络是近似的,只不过增加了一些辅助输入还有条件矩阵

在这里插入图片描述

门控残差网络具体实现代码

  • 具体和上面描述的差不多,这里增加了两个额外的参数,分别是辅助输入a和条件矩阵b

  • 注意,这里的二维卷积就是加上了简单的权重归一化的普通二维卷积。

  • 辅助输入a

    • 用途:提供额外的信息,帮助网络更好地执行任务,比如说在多模态场景或者多任务学习中,会通过a提供主输入x相关联的信息
    • 操作:如果提供了a,那么在第一次卷积之后,会经过全连接层与c1相加
  • 条件矩阵h

    • 用途:主要用于条件生成任务,因为条件生成任务的网络行为会受到某些条件和上下文影响。比如,在文本生成图像中,h会是一个文本描述的嵌入
    • 操作:如果提供了 h,那么 h 会被投影到一个与 c2 具有相同维度的空间中,并与 c2 相加。这是通过一个全连接层实现的,该层的权重是 hw。
def gated_resnet(x, a=None, h=None, nonlinearity=concat_elu, conv=conv2d, init=False, counters={}, ema=None, dropout_p=0., **kwargs):xs = int_shape(x)num_filters = xs[-1]# 执行第一次卷积c1 = conv(nonlinearity(x), num_filters)# 查看是否有辅助输入aif a is not None:  # add short-cut connection if auxiliary input 'a' is givenc1 += nin(nonlinearity(a), num_filters)# 执行非线性单元c1 = nonlinearity(c1)if dropout_p > 0:c1 = tf.nn.dropout(c1, keep_prob=1. - dropout_p)# 执行第二次卷积c2 = conv(c1, num_filters * 2, init_scale=0.1)# add projection of h vector if included: conditional generation# 如果有辅助输入h,那么就将h投影到c2的维度上if h is not None:with tf.variable_scope(get_name('conditional_weights', counters)):hw = get_var_maybe_avg('hw', ema, shape=[int_shape(h)[-1], 2 * num_filters], dtype=tf.float32,initializer=tf.random_normal_initializer(0, 0.05), trainable=True)if init:hw = hw.initialized_value()c2 += tf.reshape(tf.matmul(h, hw), [xs[0], 1, 1, 2 * num_filters])# Is this 3,2 or 2,3 ?a, b = tf.split(c2, 2, 3)c3 = a * tf.nn.sigmoid(b)return x + c3

使用pytorch实现

  • tensorflow的模型定义过程和pytorch的定义过程就是不一样,tensorflow中的conv2d只需要给出输出的channel,直接输入需要卷积的部分即可。但是使用pytorch,需要进行给定输入的 channel,然后在给出输出的filter_size,很麻烦。
  • 除此之外,在定义模型的层的过程中,我们不能在forward中定义层,只能在init函数中定义层。
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.nn.utils import weight_normclass GatedResNet(nn.Module):def __init__(self, num_filters, nonlinearity=F.elu, dropout_p=0.0):super(GatedResNet, self).__init__()self.num_filters = num_filtersself.nonlinearity = nonlinearityself.dropout_p = dropout_p# 第一卷积层self.conv1 = nn.Conv2d(num_filters, num_filters, kernel_size=3, padding=1)
#         self.conv1 = weight_norm(self.conv1)# 第二卷积层,输出通道是 2 * num_filters,用于门控机制self.conv2 = nn.Conv2d(num_filters, 2 * num_filters, kernel_size=3, padding=1)
#         self.conv2 = weight_norm(self.conv2)# 条件权重用于 h,初始化在前向传播过程中self.hw = Nonedef forward(self, x, a=None, h=None):c1 = self.conv1(self.nonlinearity(x))# 检查是否有辅助输入 'a'if a is not None:c1 += a  # 或使用 NIN 使维度兼容c1 = self.nonlinearity(c1)if self.dropout_p > 0:c1 = F.dropout(c1, p=self.dropout_p, training=self.training)c2 = self.conv2(c1)print('the shape of c2',c2.shape)# 如果有辅助输入 h,则加入 h 的投影if h is not None:if self.hw is None:self.hw = nn.Parameter(torch.randn(h.size(1),  self.num_filters) * 0.05)print(self.hw.shape)c2 +=  (h @ self.hw).view(h.size(0), 1, 1, self.num_filters)# 将通道分为两组:'a' 和 'b'a, b = c2.chunk(2, dim=1)c3 = a * torch.sigmoid(b)return x + c3# 测试
x = torch.randn(16, 32, 32, 32)  # [批次大小,通道数,高度,宽度]
a = torch.randn(16, 32, 32, 32)  # 和 x 维度相同的辅助输入
h = torch.randn(16, 64)  # 可选的条件变量
model = GatedResNet(32)
out = model(x, a , h)

在这里插入图片描述

总结

  • 遇到了很多问题,是因为经验不够,而且很多东西都不了解,然后改的很痛苦,而且现在完全还没有跑起来,完整的组件都没有搭建完成,这里还需要继续努力。
  • 关于门控残差网络这里,这里学到了很多,知道了具体的运作流程,也知道他是专门针对序列数据,防止出现梯度爆炸的。以后可以多用用看。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.rhkb.cn/news/118558.html

如若内容造成侵权/违法违规/事实不符,请联系长河编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

时序预测 | MATLAB实现CNN-GRU卷积门控循环单元时间序列预测(风电功率预测)

时序预测 | MATLAB实现CNN-GRU卷积门控循环单元时间序列预测(风电功率预测) 目录 时序预测 | MATLAB实现CNN-GRU卷积门控循环单元时间序列预测(风电功率预测)预测效果基本介绍程序设计参考资料 预测效果 基本介绍 1.时序预测 | MA…

《爵士乐史》乔德.泰亚 笔记

第一章 【美国音乐的非洲化】 【乡村布鲁斯和经典布鲁斯】 布鲁斯:不止包括忧愁、哀痛 十二小节布鲁斯特征: 1.乐型(A:主、B:属、C/D:下属):A→A→B→A→C→D→A→A 2.旋律:大三、小三、降七、降五 盲人…

漏洞修复:在应用程序中发现不必要的 Http 响应头

描述 blablabla描述,一般是在返回的响应表头中出现了Server键值对,那我们要做的就是移除它,解决方案中提供了nginx的解决方案 解决方案 第一种解决方案 当前解决方案会隐藏nginx的版本号,但还是会返回nginx字样,如…

Gateway的服务网关

Gateway服务网关 Gateway网关是我们服务的守门神&#xff0c;所有微服务的统一入口。 网关的核心功能特性&#xff1a; 请求路由 权限控制 限流 架构如下&#xff1a; gateway使用 引入依赖 创建gateway服务&#xff0c;引入依赖 <!--网关--> <dependency>…

Spring Boot中通过maven进行多环境配置

上文 java Spring Boot将不同配置拆分入不同文件管理 中 我们说到了&#xff0c;多环境的多文件区分管理 说到多环境 其实不止我们 Spring Boot有 很多的东西都有 那么 这就有一个问题 如果 spring 和 maven 都配置了环境 而且他们配的不一样 那么 会用谁的呢&#xff1f; 此…

镜之Json Compare Diff

前言 “镜” 寓意是凡事都有两面性,Json 对比也不例外! 因公司业务功能当中有一个履历的功能,它有多个版本的 JSON 数据需要对比出每个版本的不同差异节点并且将差异放置在一个新的 JSON 当中原有结构不能变动,差异节点使用数组对象的形式存储,前端点击标红即可显示多个版本的节…

lenovo联想笔记本小新 潮7000-14IKBR 2018款(81GA)原装出厂Windows10系统镜像

自带所有驱动、出厂主题壁纸LOGO、Office办公软件、联想电脑管家等预装程序 链接&#xff1a;https://pan.baidu.com/s/1ynP4d5z7MPF9l5U5lCjDzQ?pwdhjvj 提取码&#xff1a;hjvj 所需要工具&#xff1a;16G或以上的U盘 文件格式&#xff1a;ISO 文件大小&#x…

JUC并发编程--------CAS、Atomic原子操作

什么是原子操作&#xff1f;如何实现原子操作&#xff1f; 什么是原子性&#xff1f; 事务的一大特性就是原子性&#xff08;事务具有ACID四大特性&#xff09;&#xff0c;一个事务包含多个操作&#xff0c;这些操作要么全部执行&#xff0c;要么全都不执行 并发里的原子性…

eureka服务注册和服务发现

文章目录 问题实现以orderservice为例orderservice服务注册orderservice服务拉取 总结 问题 我们要在orderservice中根据查询到的userId来查询user&#xff0c;将user信息封装到查询到的order中。 一个微服务&#xff0c;既可以是服务提供者&#xff0c;又可以是服务消费者&a…

0基础学习VR全景平台篇 第94篇:智慧景区浏览界面介绍

一、景区详细信息介绍 点击左上角的图标就可以看到景区详细信息例如景区简介&#xff0c;地址&#xff0c;开放信息&#xff0c;联系电话等 二、问题反馈中心 点击左下角的【问题反馈】按钮向作者进行问题反馈 三、开场地图 1、直接点击开场地图页面上的图标浏览该场景 2、通…

【JS】—闭包—双例对比法学习总结

一、选定知识点&#xff1a;闭包 二、指令学习 1. 闭包MDN的定义 闭包&#xff08;closure&#xff09;是一个函数以及其捆绑的周边环境状态&#xff08;lexical environment&#xff0c;词法环境&#xff09;的引用的组合。换而言之&#xff0c;闭包让开发者可以从内部函数…

常用的msvcp140.dll丢失的解决方法,msvcp140.dll丢失的原因

自从电脑出现故障&#xff0c;我的生活变得一团糟。他每天都需要使用电脑处理工作&#xff0c;可是突然有一天&#xff0c;他发现许多软件和游戏都无法正常运行。错误提示显示“找不到msvcp140.dll”&#xff0c;这让他感到非常困扰。今天想和大家分享一个在计算机使用过程中经…

使用爬虫代码获得深度学习目标检测或者语义分割中的图片。

问题描述&#xff1a;目标检测或者图像分割需要大量的数据&#xff0c;如果手动从网上找的话会比较慢&#xff0c;这时候&#xff0c;我们可以从网上爬虫下来&#xff0c;然后自己筛选即可。 代码如下&#xff08;不要忘记安装代码依赖的库&#xff09;&#xff1a; # -*- co…

linux 内存一致性

linux 出现内存一致性的场景 1、编译器优化 &#xff0c;代码上下没有关联的时候&#xff0c;因为编译优化&#xff0c;会有执行执行顺序不一致的问题&#xff08;多核单核都会出现&#xff09; 2、多核cpu乱序执行&#xff0c;cpu的乱序执行导致内存不一致&#xff08;多核出…

elasticSearch+kibana+logstash+filebeat集群改成https认证

文章目录 一、生成相关证书二、配置elasticSearh三、配置kibana四、配置logstash五、配置filebeat六、连接https es的java api 一、生成相关证书 ps&#xff1a;主节点操作 切换用户&#xff1a;su es 进入目录&#xff1a;cd /home/es/elasticsearch-7.6.2 创建文件&#x…

OpenCV(十三):图像中绘制直线、圆形、椭圆形、矩形、多边形和文字

目录 1.绘制直线line() 2.绘制圆形circle() 3.绘制椭圆形ellipse() 4.绘制矩形rectangle() 5.绘制多边形 fillPoly() 6.绘制文字putText() 7.例子 1.绘制直线line() CV_EXPORTS_W void line(InputOutputArray img,Point pt1, Point pt2,const Scalar& color,int t…

【数据结构与算法 三】常见数据结构与算法组合应用方式

一般的数据结构和对应的 很抱歉,作为一个文本AI模型,我无法直接绘制图表,但我可以为您列出常见的算法和数据结构分类,并为每个分类提供简要说明。您可以根据这些信息自行绘制图表。 算法分类: 搜索算法:用于在数据集中查找特定元素的算法,如线性搜索、二分搜索等。 排…

MTK6761/MT6761安卓核心板4G安卓智能模块详细参数性能介绍

MTK6761 安卓核心板采用12nm制程四核Cortex-A53、最高主频2.0GHZ 处理器&#xff0c;板载内存为 1GB8GB(2GB16GB、3GB32GB、4GB64GB)&#xff0c;搭载Android 9.0操作系统。 MTK6761&#xff08;曦力 A22&#xff09;安卓核心板基本概述 MTK6761安卓核心板 是一款高性能低功耗…

HikariCP源码修改,使其连接池支持Kerberos认证

HikariCP-4.0.3 修改HikariCP源码,使其连接池支持Kerberos认证 修改后的Hikari源码地址:https://github.com/Raray-chuan/HikariCP-4.0.3 Springboot使用hikari连接池并进行Kerberos认证访问Impala的demo地址:https://github.com/Raray-chuan/springboot-kerberos-hikari-im…

2023 AZ900备考

文章目录 如何学习最近准备考AZ900考试&#xff0c;找了一圈文档&#xff0c;结果发现看那么多文档&#xff0c;不如直接看官方的教程https://learn.microsoft.com/zh-cn/certifications/exams/az-900/ &#xff0c;简单直接&#xff0c;突然想到纳瓦尔宝典中提到多花时间进行思…