分类神经网络1:VGGNet模型复现

目录

分类网络的常见形式

VGG网络架构

VGG网络部分实现代码


分类网络的常见形式

常见的分类网络通常由特征提取部分分类部分组成。

特征提取部分实质就是各种神经网络,如VGG、ResNet、DenseNet、MobileNet等。其负责捕获数据的有用信息,一般是通过堆叠多个卷积层和池化层来实现的,这些层有助于检测图像中的边缘、纹理和特征。

分类部分通常是一个全连接层,负责将特征提取部分输出的信息映射到最终的类别或标签。这些全连接层通常包括一个或多个隐藏层,以及一个输出层,其中输出层的节点数量等于任务中的类别数量。

VGG网络架构

论文原址:https://arxiv.org/pdf/1409.1556v6.pdf

VGG 网络是由牛津大学的Visual Geometry Group 开发的,其结构特点在于使用了多个 3x3 的小卷积核,并通过这些小卷积层的重复堆叠来构建网络,从而能够捕捉到更加复杂和抽象的特征表示。VGG 网络的模型结构如下:

VGG网络的核心架构可以分为以下几个部分:

  1. 输入层:VGG网络接受224x224像素的RGB图像作为输入。
  2. 卷积层:网络的前几层由多个卷积层组成,每个卷积层都使用3x3的卷积核来提取图像的特征。这些卷积层后面通常跟着一个2x2 最大池化,用于逐步减小特征图的空间尺寸,同时增加特征深度。
  3. 池化层:在卷积层之后,网络使用最大池化层来降低特征图的空间分辨率,这有助于减少计算量并提取更加抽象的特征。
  4. 全连接层:经过多个卷积和池化层之后,网络的特征图被展平并通过几个全连接层进行处理。全连接层的作用是将学习到的特征映射到最终的分类结果。
  5. 输出层:VGG网络的最后是一个softmax层,它将网络的输出转换为概率分布,以便进行多类别的分类任务。

VGG网络的一个显著特点是其深度,其相关配置信息如下:

VGG系列不同变体内容如下:

  • VGG A:这是一个基础的配置,没有特别独特的设计。
  • VGG A-LRN:在这个版本中,加入了局部响应归一化(LRN),这是一种在AlexNet中首次使用的技术。不过,LRN在当前的深度学习实践中已经较少被采用。
  • VGG B:相较于A版本,B版本增加了两个卷积层,以增强网络的学习能力。
  • VGG C:在B的基础上,C版本进一步增加了三个卷积层,但这些层使用的是1x1的卷积核。1x1卷积核可以看作是对输入特征图进行线性变换,有助于减少参数数量并增加非线性。
  • VGG D:D版本在C版本的基础上做了调整,将1x1的卷积核替换为3x3的卷积核,这个配置后来被称为VGG16,因为它总共有16层。
  • VGG E:在D版本的基础上,E版本进一步增加了三个3x3的卷积层,形成了VGG19,总共有19层。

从图中可以看出,随着网络深度的加深,模型变得更为复杂。通常来说,增加网络的深度可以增加模型的表示能力,使其能够学习到更复杂的特征和模式,从而在某些任务上取得更好的性能。然而,随着网络深度的增加,模型的参数数量也会增加,导致模型的复杂度增加,训练和推理的计算成本也会增加,同时可能会增加过拟合的风险。

VGG网络部分实现代码

废话不多说,直接上干货

import torch
import torch.nn as nn__all__ = ["VGG", "vgg11_bn", "vgg13_bn", "vgg16_bn", "vgg19_bn"]cfg = {'A': [64,     'M', 128,      'M', 256, 256,           'M', 512, 512,           'M', 512, 512,           'M'],'B': [64, 64, 'M', 128, 128, 'M', 256, 256,           'M', 512, 512,           'M', 512, 512,           'M'],'C': [64, 64, 'M', 128, 128, 'M', 256, 256, 256,      'M', 512, 512, 512,      'M', 512, 512, 512,      'M'],'D': [64, 64, 'M', 128, 128, 'M', 256, 256, 256, 256, 'M', 512, 512, 512, 512, 'M', 512, 512, 512, 512, 'M'],
}class ConvBNReLU(nn.Module):def __init__(self, in_channels, out_channels, stride=1,  kernel_size=3, padding=1):super(ConvBNReLU, self).__init__()self.conv = nn.Conv2d(in_channels=in_channels, out_channels=out_channels, kernel_size=kernel_size, stride=stride, padding=padding)self.bn = nn.BatchNorm2d(num_features=out_channels)self.relu = nn.ReLU(inplace=True)def forward(self, x):x = self.conv(x)x = self.bn(x)x = self.relu(x)return xclass VGG(nn.Module):def __init__(self, features, num_classes=1000, init_weights=True):super(VGG, self).__init__()self.features = featuresself.avgpool = nn.AdaptiveAvgPool2d((7, 7))self.classifier = nn.Sequential(nn.Linear(512 * 7 * 7, 4096),nn.ReLU(True),nn.Dropout(),nn.Linear(4096, 4096),nn.ReLU(True),nn.Dropout(),nn.Linear(4096, num_classes),)if init_weights:self._initialize_weights()def forward(self, x):for layer in self.features:x = layer(x)x = self.avgpool(x)x = torch.flatten(x, 1)x = self.classifier(x)return xdef _initialize_weights(self):for m in self.modules():if isinstance(m, nn.Conv2d):nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')if m.bias is not None:nn.init.constant_(m.bias, 0)elif isinstance(m, nn.BatchNorm2d):nn.init.constant_(m.weight, 1)nn.init.constant_(m.bias, 0)elif isinstance(m, nn.Linear):nn.init.normal_(m.weight, 0, 0.01)nn.init.constant_(m.bias, 0)def make_layers(cfg):layers = nn.ModuleList()in_channels = 3for i in cfg:if i == 'M':layers.append(nn.MaxPool2d(kernel_size=2, stride=2))else:layers.append(ConvBNReLU(in_channels=in_channels, out_channels=i))in_channels = ireturn layersdef vgg11_bn(num_classes):model = VGG(make_layers(cfg['A']), num_classes=num_classes)return modeldef vgg13_bn(num_classes):model = VGG(make_layers(cfg['B']), num_classes=num_classes)return modeldef vgg16_bn(num_classes):model = VGG(make_layers(cfg['C']), num_classes=num_classes)return modeldef vgg19_bn(num_classes):model = VGG(make_layers(cfg['D']), num_classes=num_classes)return modelif __name__=='__main__':import torchsummarydevice = 'cuda' if torch.cuda.is_available() else 'cpu'input = torch.ones(2, 3, 224, 224).to(device)net = vgg16_bn(num_classes=4)net = net.to(device)out = net(input)print(out)print(out.shape)torchsummary.summary(net, input_size=(3, 224, 224))# Total params: 134,285,380

这只是一个网络架构部分实现代码,其中 cfg 列表是 VGG 卷积和池化后的通道数,大家可以结合 VGG 的配置信息图一起对比理解。希望对大家有所帮助呀!

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.rhkb.cn/news/315430.html

如若内容造成侵权/违法违规/事实不符,请联系长河编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

ASP.NET基于WEB的选课系统

摘要 设计本系统的目的是对选课信息进行管理。学生选课系统维护模块主要完成的是系统管理与维护功能。课题研究过程中,首先对系统管理模块进行了详尽的需求分析,经分析得到系统管理模块主要完成如下的功能:用户基本信息、选课信息的录入,查看…

Spring Boot 如何实现缓存预热

Spring Boot 实现缓存预热 1、使用启动监听事件实现缓存预热。2、使用 PostConstruct 注解实现缓存预热。3、使用 CommandLineRunner 或 ApplicationRunner 实现缓存预热。4、通过实现 InitializingBean 接口,并重写 afterPropertiesSet 方法实现缓存预热。 1、使用…

华为先进芯片麒麟9010效能再升级,挑战新高度 | 百能云芯

根据最新的彭博资讯报道,华为再次引领了智能手机行业的先进技术,其最新发布的Pura 70系列智能手机搭载了由中芯国际生产的麒麟9010高阶处理器。这一消息再次证明了华为在芯片设计和生产领域的持续创新能力,并且表明华为对于提升智能手机性能和…

【机器学习】集成学习---Bagging之随机森林(RF)

【机器学习】集成学习---Bagging之随机森林(RF) 一、引言1. 简要介绍集成学习的概念及其在机器学习领域的重要性。2. 引出随机森林作为Bagging算法的一个典型应用。 二、随机森林原理1. Bagging算法的基本思想2. 随机森林的构造3. 随机森林的工作机制 三…

开源文本嵌入模型M3E

进入正文前,先扯点题外话 这两天遇到一个棘手的问题,在用 docker pull 拉取镜像时,会报错: x509: certificate has expired or is not yet valid 具体是下面👇这样的 rootDS918:/volume2/docker/xiaoya# docker pul…

一款神奇的地理数据可视化python库

在地理信息系统(GIS)和地理数据可视化领域,Python的易用性和强大的库支持使其成为处理地理数据的理想选择之一。今天我们介绍Cartopy库,它为地理数据可视化提供了强大的支持。无论是对于GIS专业人士还是对地理数据可视化感兴趣的初…

同事上班这样摸鱼,我坐边上咋看他都在专心写代码啊

我边上有个同事,我坐他边上,但是每天看着他都眉头紧锁,忙的不亦乐乎,但终于有一天,我发现了他上班摸鱼的秘诀。 我劝你千万不要学会这4招,要不就该不好好上班了。 目录 1 上班看电影? 2 上班…

<计算机网络自顶向下> Internet Protocol(未完成)

互联网中的网络层 IP数据报格式 ver: 四个比特的版本号(IPV4 0100, IPV6 0110) headlen:head的长度(头部长度字段(IHL)指定了头部的长度,以32位字(4字节)为单位计算。这…

pytest测试基础

assert 验证关键字 需要pahton版本大于3.6,因为有个工具pip3;因为做了映射,所以下面命令pip3即pip pip install -U pytest -U参数可选,是如果已安装可更新。 如果上述demo变化 通过验证代码,测试环境没问题。…

接口测试-笔记

Date 2024年4月23日21:19:51 Author KarrySmile 1. 前言 因为想更加规范地开发接口,同时让自己测试接口的时候更加高效,更好地写好接口文档。所以学习黑马的《接口自动化测试》课程。链接:黑马程序员软件测试接口自动化测试全套视频教程&a…

MATLAB 运算符

MATLAB 运算符 运算符是一个符号,告诉编译器执行特定的数学或逻辑操作。MATLAB设计为主要在整个矩阵和数组上运行。因此,MATLAB中的运算符既可以处理标量数据,也可以处理非标量数据。MATLAB允许以下类型的基本运算- 算术运算符 关系运算符…

【linux】Linux第一个小程序-进度条

1. 预备知识:回车和换行 回车(Carriage Return,CR): 在早期的机械打字机中,回车指的是将打字机的打印头移回到行首的操作,这样打印头就可以开始新的一行的打印。在ASCII编码中,回车用…

数据库介绍(Mysql安装)

前言 工程师再在存储数据用文件就可以了,为什么还要弄个数据库? 一、什么是数据库? 文件保存数据有以下几个缺点: 文件的安全性问题文件不利于数据查询和管理文件不利于存储海量数据文件在程序中控制不方便 数据库存储介质: 磁…

编译支持播放H265的cef控件

接着在上次编译的基础上增加h265支持编译支持视频播放的cef控件(h264) 测试页面,直接使用cef_enhancement,里边带着的那个html即可,h265视频去这个网站下载elecard,我修改的这个版本参考了里边的修改方式,不过我的这个…

Blender面操作

1.细分Subdivide -选择一个面 -右键,细分 -微调,设置切割次数 2.删除 -选择一个或多个面,按X键 -选择要删除的是面,线还是点 3.挤出面Extrude -选择一个面 -Extrude工具 -拖拽手柄,向外挤出 -微调&#xff…

Opencv | 边缘提取

目录 一. 边缘检测1. 边缘的定义2. Sobel算子 边缘提取3. Scharr算子 边缘提取4. Laplacian算子 边缘提取5. Canny 边缘检测算法5.1 计算梯度的强度及方向5.2 非极大值抑制5.3 双阈值检测5.4 抑制孤立弱边缘 二. 轮廓信息1. 获取轮廓信息2. 画轮廓 一. 边缘检测 1. 边缘的定义…

自动化爬虫工具:you-get安装与使用

Windows下的安装命令: pip install you-get linux下的安装命令: pip3 install you-get 下载完成后,我们可以看到如下的警告,意思就是这个工具并未被添加到环境变量中,如果我们想在命令行中直接调用,需要…

sql今天学习总结

排序order by(默认升序) order by id desc(降序排序) order by id,number(先按id排再按name排序) in,not in and or 通配符 where name like "Aa%";选取所有以Aa开头的名字 like "%r" 以r结…

Matlab 使用subplot绘制多个子图,一元拟合

实现效果: clc; clear;filename sri.xlsx; % 确认文件路径data readtable(filename); datavalue data{:,2:end}; datavalue datavalue;fig figure(Position, [0, 0, 1500, 900]); indexString ["(a)","(b)","(c)","(d)&qu…

python自动生成SQL语句自动化

👽发现宝藏 前些天发现了一个巨牛的人工智能学习网站,通俗易懂,风趣幽默,忍不住分享一下给大家。【点击进入巨牛的人工智能学习网站】。 Python自动生成SQL语句自动化 在数据处理和管理中,SQL(Structured …