基于Pytorch构建DenseNet网络对cifar-10进行分类

DenseNet是指Densely connected convolutional networks(密集卷积网络)。它的优点主要包括有效缓解梯度消失、特征传递更加有效、计算量更小、参数量更小、性能比ResNet更好。它的缺点主要是较大的内存占用。

DenseNet网络与Resnet、GoogleNet类似,都是为了解决深层网络梯度消失问题的网络。

Resnet从深度方向出发,通过建立前面层与后面层之间的“短路连接”或“捷径”,从而能训练出更深的CNN网络。

GoogleNet从宽度方向出发,通过Inception(利用不同大小的卷积核实现不同尺度的感知,最后进行融合来得到图像更好的表征)。

DenseNet从特征入手,通过对前面所有层与后面层的密集连接,来极致利用训练过程中的所有特征,进而达到更好的效果和减少参数。

DenseNet网络

Dense Block:像GoogLeNet网络由Inception模块组成、ResNet网络由残差块(Residual Building Block)组成一样,DenseNet网络由Dense Block组成,论文截图如下所示:每个层从前面的所有层获得额外的输入,并将自己的特征映射传递到后续的所有层,使用级联(Concatenation)方式,每一层都在接受来自前几层的”集体知识(collective knowledge)”。增长率(growth rate)k是每个层的额外通道数。

58c8038c8e5f0cf7dea34eb09bd15c88.png

其实说了那么多我也不大明白原理和数学推理,只需要按照相关代码做就行了

class Bottleneck(nn.Module):def __init__(self, input_channel, growth_rate):super(Bottleneck, self).__init__()self.bn1 = nn.BatchNorm2d(input_channel)self.relu1 = nn.ReLU(inplace=True)self.conv1 = nn.Conv2d(input_channel, 4 * growth_rate, kernel_size=1)self.bn2 = nn.BatchNorm2d(4 * growth_rate)self.relu2 = nn.ReLU(inplace=True)self.conv2 = nn.Conv2d(4 * growth_rate, growth_rate, kernel_size=3, padding=1)def forward(self, x):out = self.conv1(self.relu1(self.bn1(x)))out = self.conv2(self.relu2(self.bn2(out)))out = torch.cat([out, x], 1)return out
class Transition(nn.Module):def __init__(self, input_channels, out_channels):super(Transition, self).__init__()self.bn = nn.BatchNorm2d(input_channels)self.relu = nn.ReLU(inplace=True)self.conv = nn.Conv2d(input_channels, out_channels, kernel_size=1)def forward(self, x):out = self.conv(self.relu(self.bn(x)))out = F.avg_pool2d(out, 2)return out
class DenseNet(nn.Module):def __init__(self, nblocks, growth_rate, reduction, num_classes):super(DenseNet, self).__init__()self.growth_rate = growth_ratenum_planes = 2 * growth_rateself.basic_conv = nn.Sequential(nn.Conv2d(3, 2 * growth_rate, kernel_size=7, stride=2, padding=3),nn.BatchNorm2d(2 * growth_rate),nn.ReLU(inplace=True),nn.MaxPool2d(kernel_size=3, stride=2, padding=1))self.dense1 = self._make_dense_layers(num_planes, nblocks[0])num_planes += nblocks[0] * growth_rateout_planes = int(math.floor(num_planes * reduction))self.trans1 = Transition(num_planes, out_planes)num_planes = out_planesself.dense2 = self._make_dense_layers(num_planes, nblocks[1])num_planes += nblocks[1] * growth_rateout_planes = int(math.floor(num_planes * reduction))self.trans2 = Transition(num_planes, out_planes)num_planes = out_planesself.dense3 = self._make_dense_layers(num_planes, nblocks[2])num_planes += nblocks[2] * growth_rateout_planes = int(math.floor(num_planes * reduction))self.trans3 = Transition(num_planes, out_planes)num_planes = out_planesself.dense4 = self._make_dense_layers(num_planes, nblocks[3])num_planes += nblocks[3] * growth_rateself.AdaptiveAvgPool2d = nn.AdaptiveAvgPool2d(1)# 全连接层self.fc = nn.Sequential(nn.Linear(num_planes, 256),nn.ReLU(inplace=True),# 使一半的神经元不起作用,防止参数量过大导致过拟合nn.Dropout(0.5),nn.Linear(256, 128),nn.ReLU(inplace=True),nn.Dropout(0.5),nn.Linear(128, 10))def _make_dense_layers(self, in_planes, nblock):layers = []for i in range(nblock):layers.append(Bottleneck(in_planes, self.growth_rate))in_planes += self.growth_ratereturn nn.Sequential(*layers)def forward(self, x):out = self.basic_conv(x)out = self.trans1(self.dense1(out))out = self.trans2(self.dense2(out))out = self.trans3(self.dense3(out))out = self.dense4(out)out = self.AdaptiveAvgPool2d(out)out = out.view(out.size(0), -1)out = self.fc(out)return out
def DenseNet121():return DenseNet([6, 12, 24, 16], growth_rate=32, reduction=0.5, num_classes=10)
def DenseNet169():return DenseNet([6, 12, 32, 32], growth_rate=32, reduction=0.5, num_classes=10)
def DenseNet201():return DenseNet([6, 12, 48, 32], growth_rate=32, reduction=0.5, num_classes=10)
def DenseNet265():return DenseNet([6, 12, 64, 48], growth_rate=32, reduction=0.5, num_classes=10)
# 初始化模型
from torchstat import stat
# 定义模型输出模式,GPU和CPU均可
model = DenseNet121().to(DEVICE)

在NVIDIA GeForce GTX 1660 SUPER显卡上训练了100轮,大致上一轮1分钟,这是DenseNet网络训练的损失率和准确率,在验证集也是保持80%的准确率。

fef4e1c3ccec7ba873ae14960d444595.png

DenseNet也是一个系列,包括DenseNet-121、DenseNet-169等等,论文中给出了4种层数的DenseNet,论文截图如下所示:所有网络的增长率k是32,表示每个Dense Block中每层输出的feature map个数。

410bfbcfc3a6141a28efec184547aa49.png

关于图像分类的模型算法,热情也没了,到此也就告一段落了,后续再讨论一些新的话题。

最后欢迎关注公众号:python与大数据分析

47d362ba65d9cc25fac0dea80aa05dc0.jpeg

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.rhkb.cn/news/98679.html

如若内容造成侵权/违法违规/事实不符,请联系长河编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

机器学习深度学习——transformer(机器翻译的再实现)

👨‍🎓作者简介:一位即将上大四,正专攻机器学习的保研er 🌌上期文章:机器学习&&深度学习——自注意力和位置编码(数学推导代码实现) 📚订阅专栏:机器…

【Golang系统开发】搜索引擎(2) 压缩词典

写在前面 这篇文章我们就给出一系列的数据结构,使得词典能达到越来越高的压缩比。当然,和倒排索引记录表的大小相比,词典只占据了非常小的空间。那么为什么要对词典进行压缩呢? 这是因为决定信息检索系统的查询响应时间的一个重…

Spring Boot 如何通过jdbc+HikariDataSource 完成对Mysql 操作

😀前言 本篇博文是关于Spring Boot 如何通过jdbcHikariDataSource 完成对Mysql 操作的说明,希望你能够喜欢😊 🏠个人主页:晨犀主页 🧑个人简介:大家好,我是晨犀,希望我的…

lvs-DR

lvs-DR数据包流向分析 client向目标VIP发出请求。 DIR根据负载均衡算法一台active的RS(RIR1),将RIP1所在的网卡的mac地址作为目标的mac地址,发送到局域网里。 RIRI在局域网中的收到这个帧,拆开后发现目标&#xff08…

CSRF

CSRF CSRF,跨站域请求伪造,通常攻击者会伪造一个场景(例如一条链接),来诱使用户点击,用户一旦点击,黑客的攻击目的也就达到了,他可以盗用你的身份,以你的名义发送恶意请…

Vue-6.编译器webstorm

Vue专栏(帮助你搭建一个优秀的Vue架子) Vue-1.零基础学习Vue Vue-2.Nodejs的介绍和安装 Vue-3.Vue简介 Vue-4.编译器VsCode Vue-5.编译器Idea Vue-6.编译器webstorm Vue-7.命令创建Vue项目 Vue-8.Vue项目配置详解 Vue-9.集成(.editorconfig、…

Docker搭建LNMP运行Wordpress平台

一、项目1.1 项目环境1.2 服务器环境1.3 任务需求 二、Linux 系统基础镜像三、Nginx1、建立工作目录2、编写 Dockerfile 脚本3、准备 nginx.conf 配置文件4、生成镜像5、创建自定义网络6、启动镜像容器7、验证 nginx 四、Mysql1、建立工作目录2、编写 Dockerfile3、准备 my.cnf…

Java自学到什么程度就可以去找工作了?

引言 Java作为一门广泛应用于软件开发领域的编程语言,对于初学者来说,了解到什么程度才能开始寻找实习和入职机会是一个常见的问题。 本文将从实习和入职这两个方面,分点详细介绍Java学习到什么程度才能够开始进入职场。并在文章末尾给大家安…

设计模式之迭代器模式(Iterator)的C++实现

1、迭代器模式的提出 在软件开发过程中,操作的集合对象内部结构常常变化,在访问这些对象元素的同时,也要保证对象内部的封装性。迭代器模式提供了一种利用面向对象的遍历方法来遍历对象元素。迭代器模式通过抽象一个迭代器类,不同…

【Leetcode】98. 验证二叉搜索树

一、题目 1、题目描述 给你一个二叉树的根节点 root ,判断其是否是一个有效的二叉搜索树。 有效 二叉搜索树定义如下: 节点的左子树只包含 小于 当前节点的数。节点的右子树只包含 大于 当前节点的数。所有左子树和右子树自身必须也是二叉搜索树。示例1: 输入:root = …

FT2000+低温情况下RTC守时问题

1、背景介绍 飞腾2000芯片通过I2C连接一块RTC时钟芯片(BellingBL5372)来实现麒麟信安系统下后的守时功能。目前BIOS支持UEFI功能,BIOS上电后能获取RTC时间,并将时间写入相应的UEFI变量或内存区域,操作系统上电后使用U…

antd5源码调试环境启动(MacOS)

将源码下载至本地 这里antd5 版本是5.8.3 $ git clone gitgithub.com:ant-design/ant-design.git $ cd ant-design $ npm install $ npm start前提:安装python3、node版本18.14.0(这是本人当前下载的版本) python3安装教程可参考:https://…

Vue3 用父子组件通信实现页面页签功能

一、大概流程 二、用到的Vue3知识 1、组件通信 (1)父给子 在vue3中父组件给子组件传值用到绑定和props 因为页签的数组要放在父页面中, data(){return {tabs: []}}, 所以顶部栏需要向父页面获取页签数组 先在页签页面中定义props用来接…

Docker 常规软件安装

1. 总体安装步骤 1. 搜索镜像 search 2. 拉取镜像 pull 3. 查看镜像 images 4. 启动镜像 - 端口映射 run 5. 停止容器 stop 6. 移除容器 rm 2. 安装tomcat 1. 搜索 docker search tomcat 2. 拉取 docker pull tomcat 3. 查看本地镜像 docker images tomcat 4. 创建容器实…

搭载KaihongOS的工业平板、机器人、无人机等产品通过3.2版本兼容性测评,持续繁荣OpenHarmony生态

近日,搭载深圳开鸿数字产业发展有限公司(简称“深开鸿”)KaihongOS软件发行版的工业平板、机器人、无人机等商用产品均通过OpenAtom OpenHarmony(以下简称“OpenHarmony”)3.2 Release版本兼容性测评,获颁O…

漏洞指北-VulFocus靶场专栏-中级01

漏洞指北-VulFocus靶场专栏-中级01 中级001 🌸dcrcms 文件上传 (CNVD-2020-27175)🌸step1:输入账号 密码burp suite 拦截 修改类型为 jpeg 中级002 🌸thinkphp3.2.x 代码执行🌸step1:burpsuite …

《HeadFirst设计模式(第二版)》第十一章代码——代理模式

代码文件目录: RMI: MyRemote package Chapter11_ProxyPattern.RMI;import java.rmi.Remote; import java.rmi.RemoteException;public interface MyRemote extends Remote {public String sayHello() throws RemoteException; }MyRemoteClient packa…

Java之优雅处理 NullPointerException空指针异常

前言 NPE问题就是,我们在开发中经常碰到的NullPointerException。假设我们有两个类,他们的UML类图如下图所示 在这种情况下,有如下代码 user.getAddress().getProvince(); 这种写法,在user为null时,是有可能报Nul…

网络编程面试笔试题

一、OSI 7层模型,TCP/IP 4层模型 5层模型。 以及每一层的功能(重点:第三层 第四层) 答: 7层模型(①物理层:二进制比特流传输,②数据链路层:相邻结点的可靠传输&#xf…

奇舞周刊第503期:图解串一串 webpack 的历史和核心功能

记得点击文章末尾的“ 阅读原文 ”查看哟~ 下面先一起看下本期周刊 摘要 吧~ 奇舞推荐 ■ ■ ■ 图解串一串 webpack 的历史和核心功能 提到打包工具,可能你会首先想到 webpack。那没有 webpack 之前,都是怎么打包的呢?webpack 都有哪些功能&…