经典卷积神经网络-VGGNet

经典卷积神经网络-VGGNet

一、背景介绍

VGG是Oxford的Visual Geometry Group的组提出的。该网络是在ILSVRC 2014上的相关工作,主要工作是证明了增加网络的深度能够在一定程度上影响网络最终的性能。VGG有两种结构,分别是VGG16和VGG19,两者并没有本质上的区别,只是网络深度不一样。

在这里插入图片描述

二、VGG-16网络结构

在这里插入图片描述

其中VGG系列具体的网络结构如下表所示:

在这里插入图片描述

如图所示,这是论文中所有VGG网络的详细信息,D列对应的为VGG-16网络。16指的是在这个网络中包含16个卷积层和全连接层(不算池化层和Softmax)。

  • VGG-16的卷积层没有那么多的超参数,在整个网络模型中,所有卷积核的大小都是 3 × 3的,并且padding为same,stride为1。所有池化层的池化核大小都是 2 × 2 的,并且步长为2。在几次卷积之后紧跟着池化,整个网络结构很规整。

  • 总共包含约1.38亿个参数,但其结构并不复杂,结构很规整,都是几个卷积层后面跟着可以压缩图像大小的池化层,同时,卷积层的卷积核数量的变化也存在一定的规律,都是池化之后图像高度宽度减半,但在下一个卷积层中通道数翻倍,这正是这种简单网络结构的一个规则。

  • VGG16相比AlexNet的一个改进是采用连续的几个3x3的卷积核代替AlexNet中的较大卷积核(11x11,7x7,5x5)。对于给定的感受野(与输出有关的输入图片的局部大小),采用堆积的小卷积核是优于采用大的卷积核,因为多层非线性层可以增加网络深度来保证学习更复杂的模式,而且代价还比较小(参数更少)。在VGG中,使用了3个3x3卷积核来代替7x7卷积核,使用了2个3x3卷积核来代替5×5卷积核,这样做的主要目的是在保证具有相同感受野的条件下,提升了网络的深度,在一定程度上提升了神经网络的效果。

  • 它的主要缺点就是需要训练的特征数量非常大。有些文章介绍了VGG-19,但通过研究发现VGG-19和VGG-16的性能表现几乎不分高下,所以很多人还是使用VGG-16,这也说明了单纯的增加网络深度,其性能不会有太大的提升。

  • 论文中还介绍了权重初始化方法,即预训练低层模型参数为深层模型参数初始化赋值。原文:网络权重初始化是非常重要的,坏的初始化会使得深度网络的梯度的不稳定导致无法学习。为了解决这个问题,我们首先在网络A中使用随机初始化进行训练。然后到训练更深的结构时,我们将第一层卷积层和最后三层全连接层的参数用网络A中的参数初始化(中间层的参数随机初始化)。

  • 论文中揭示了,随着网络深度的增加,图像的高度和宽度都以一定规律不断缩小,每次池化之后刚好缩小一半,而通道数量在不断增加,而且刚好也是在每组卷积操作后增加一倍。也就是说,图像缩小和通道增加的比例是有规律的,从这个角度看,这篇论文很吸引人。

三、VGG-16的Pytorch实现

我们可以根据:https://dgschwend.github.io/netscope/#/preset/vgg-16,来搭建VGG-16。

在这里插入图片描述

后面要将VGG-16Net应用到CIFAR10数据集上,所以对网络做了一些修改,具体代码如下:

from torch import nnclass Vgg16_Net(nn.Module):def __init__(self):super(Vgg16_Net, self).__init__()self.layer1 = nn.Sequential(# input_size = (3, 32, 32)nn.Conv2d(in_channels=3, out_channels=64, kernel_size=3, stride=1, padding=1),nn.BatchNorm2d(64),nn.ReLU(inplace=True),# input_size = (64, 32, 32)nn.Conv2d(in_channels=64, out_channels=64, kernel_size=3, stride=1, padding=1),nn.BatchNorm2d(64),nn.ReLU(inplace=True),# input_size = (64, 32, 32)nn.MaxPool2d(kernel_size=2, stride=2))self.layer2 = nn.Sequential(# input_size = (64, 16, 16)nn.Conv2d(in_channels=64, out_channels=128, kernel_size=3, stride=1, padding=1),nn.BatchNorm2d(128),nn.ReLU(inplace=True),# input_size = (128, 16, 16)nn.Conv2d(in_channels=128, out_channels=128, kernel_size=3, stride=1, padding=1),nn.BatchNorm2d(128),nn.ReLU(inplace=True),# input_size = (128, 16, 16)nn.MaxPool2d(2, 2))self.layer3 = nn.Sequential(# input_size = (128, 8, 8)nn.Conv2d(in_channels=128, out_channels=256, kernel_size=3, stride=1, padding=1),nn.BatchNorm2d(256),nn.ReLU(inplace=True),# input_size = (256, 8, 8)nn.Conv2d(in_channels=256, out_channels=256, kernel_size=3, stride=1, padding=1),nn.BatchNorm2d(256),nn.ReLU(inplace=True),# input_size = (256, 8, 8)nn.Conv2d(in_channels=256, out_channels=256, kernel_size=3, stride=1, padding=1),nn.BatchNorm2d(256),nn.ReLU(inplace=True),# input_size = (256, 8, 8)nn.MaxPool2d(2, 2))self.layer4 = nn.Sequential(# input_size = (256, 4, 4)nn.Conv2d(in_channels=256, out_channels=512, kernel_size=3, stride=1, padding=1),nn.BatchNorm2d(512),nn.ReLU(inplace=True),# input_size = (512, 4, 4)nn.Conv2d(in_channels=512, out_channels=512, kernel_size=3, stride=1, padding=1),nn.BatchNorm2d(512),nn.ReLU(inplace=True),# input_size = (512, 4, 4)nn.Conv2d(in_channels=512, out_channels=512, kernel_size=3, stride=1, padding=1),nn.BatchNorm2d(512),nn.ReLU(inplace=True),# input_size = (512, 4, 4)nn.MaxPool2d(2, 2))self.layer5 = nn.Sequential(# input_size = (512, 2, 2)nn.Conv2d(in_channels=512, out_channels=512, kernel_size=3, stride=1, padding=1),nn.BatchNorm2d(512),nn.ReLU(inplace=True),# input_size = (512, 2, 2)nn.Conv2d(in_channels=512, out_channels=512, kernel_size=3, stride=1, padding=1),nn.BatchNorm2d(512),nn.ReLU(inplace=True),# input_size = (512, 2, 2)nn.Conv2d(in_channels=512, out_channels=512, kernel_size=3, stride=1, padding=1),nn.BatchNorm2d(512),nn.ReLU(inplace=True),# input_size = (512, 2, 2)nn.MaxPool2d(2, 2)# output_size = (512, 1, 1))self.conv = nn.Sequential(self.layer1,self.layer2,self.layer3,self.layer4,self.layer5)self.fc = nn.Sequential(# input_size = 512nn.Linear(512, 512),nn.ReLU(inplace=True),nn.Dropout(0.5),nn.Linear(512, 256),nn.ReLU(inplace=True),nn.Dropout(0.5),nn.Linear(256, 10))def forward(self, x):x = self.conv(x)# -1表示自动计算行数# -1也可以改成x.size(0) 表示batch_size的大小x = x.view(-1, 512 * 1 * 1)x = self.fc(x)return x

四、案例:CIFAR-10分类问题

import time
import torch
import torchvision
from model import *
from torch.utils.data import DataLoader
import torchvision.transforms as transforms
from matplotlib import pyplot as plt# 加载数据集 拿到dataloader
def load_dataset(batch_size):train_data = torchvision.datasets.CIFAR10("../dataset/CIFAR10", train=True, download=True, transform=transforms.ToTensor())test_data = torchvision.datasets.CIFAR10("../dataset/CIFAR10", train=False, download=True, transform=transforms.ToTensor())train_dataloader = DataLoader(train_data, batch_size=batch_size, shuffle=True, num_workers=2)test_dataloader = DataLoader(test_data, batch_size=batch_size, shuffle=False, num_workers=2)return train_dataloader, test_dataloader# 模型训练
def train(model, train_dataloader, criterion, optimizer, epochs, device, num_print, lr_scheduler=None, test_dataloader=None):# 记录train和test的acc方便绘制学习曲线record_train = list()record_test = list()# 开始训练model.train()for epoch in range(epochs):print("========== epoch: [{}/{}] ==========".format(epoch + 1, epochs))# total记录样本数 correct记录正确预测样本数total, correct, train_loss = 0, 0, 0start = time.time()# 结合enumerate函数和迭代器的unpacking 可以在获取数据的同时获取该批次数据对应的索引for i, (image, target) in enumerate(train_dataloader):image, target = image.to(device), target.to(device)output = model(image)loss = criterion(output, target)optimizer.zero_grad()loss.backward()optimizer.step()train_loss += loss.item()total += target.size(0)correct += (output.argmax(dim=1) == target).sum().item()train_acc = 100.0 * correct / totalif (i + 1) % num_print == 0:print("step: [{}/{}], train_loss: {:.3f} | train_acc: {:6.3f}% | lr: {:.6f}".format(i + 1,len(train_dataloader), train_loss / (i + 1), train_acc, get_cur_lr(optimizer)))# 更新当前优化器的学习率if lr_scheduler is not None:lr_scheduler.step()print("--- cost time: {:.4f}s ---".format(time.time() - start))if test_dataloader is not None:record_test.append(test(model, test_dataloader, criterion, device))record_train.append(train_acc)# 保存当前模型torch.save(model.state_dict(), "train_model/VGG-16Net_{}.pth".format(epoch + 1))return record_train, record_test# 模型测试
def test(model, test_dataloader, criterion, device):# total记录样本数 correct记录正确预测样本数total, correct = 0, 0# 开始测试model.eval()with torch.no_grad():print("*************** test ***************")for X, y in test_dataloader:X, y = X.to(device), y.to(device)output = model(X)loss = criterion(output, y)total += y.size(0)correct += (output.argmax(dim=1) == y).sum().item()test_acc = 100.0 * correct / totalprint("test_loss: {:.3f} | test_acc: {:6.3f}%".format(loss.item(), test_acc))print("************************************\n")# 记得重新调用model.train()model.train()return test_acc# 获取当前的学习率 这里直接返回了第一个参数分组的学习率
def get_cur_lr(optimizer):for param_group in optimizer.param_groups:return param_group['lr']# 绘制学习曲线
def learning_curve(record_train, record_test=None):# 设置 Matplotlib 图形样式# ggplot2 是一个用于数据可视化的流行 R 语言包,以其优雅和灵活的语法而闻名plt.style.use("ggplot")plt.plot(range(1, len(record_train) + 1), record_train, label="train acc")if record_test is not None:plt.plot(range(1, len(record_test) + 1), record_test, label="test acc")plt.legend(loc=4)plt.title("learning curve")plt.xticks(range(0, len(record_train) + 1, 5))plt.yticks(range(0, 101, 5))plt.xlabel("epoch")plt.ylabel("accuracy")plt.show()# 定义超参数
BATCH_SIZE = 128
NUM_EPOCHS = 20
NUM_CLASSES = 10
LEARNING_RATE = 0.02
MOMENTUM = 0.9
WEIGHT_DECAY = 0.0005
NUM_PRINT = 100
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"def main():model = Vgg16_Net()model = model.to(DEVICE)# 加载数据train_dataloader, test_dataloader = load_dataset(BATCH_SIZE)# 定义损失函数criterion = nn.CrossEntropyLoss()# 定义优化器optimizer = torch.optim.SGD(model.parameters(),lr=LEARNING_RATE,momentum=MOMENTUM,weight_decay=WEIGHT_DECAY,nesterov=True)# 定义学习率调度器lr_scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=7, gamma=0.1)# 进行训练 返回训练集正确率和测试集正确率record_train, record_test = train(model, train_dataloader, criterion, optimizer, NUM_EPOCHS, DEVICE, NUM_PRINT, lr_scheduler, test_dataloader)# 绘制学习曲线learning_curve(record_train, record_test)if __name__ == '__main__':main()

查看训练结果可以发现,测试集正确率基本保持在87.3%左右,训练集正确率接近100%:

在这里插入图片描述

学习曲线如下:

在这里插入图片描述

参考链接:

  • https://cloud.tencent.com/developer/article/1638597

  • https://blog.csdn.net/m0_50127633/article/details/117047057?spm=1001.2014.3001.5502

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.rhkb.cn/news/229485.html

如若内容造成侵权/违法违规/事实不符,请联系长河编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

Cuk、Zeta和Sepic开关电源拓扑结构

Cuk、Zeta和Sepic变换器,三种拓扑结构大致类似。不同点在于电感和二极管,MOS管的位置关系的变化。 Cuk电源是一种非隔离的直流电源转换器,其基本结构包括输入滤波电容、开关管、输入电感、输出电感和输出电容等元件。Cuk电路可以看作是Boost和Buck电路的…

适用于各种危险区域的火焰识别摄像机,实时监测、火灾预防、安全监控,为安全保驾护航

火灾是一种极具破坏力的灾难,对人们的生命和财产造成了严重的威胁。为了更好地预防和防范火灾,火焰识别摄像机作为一种先进的监控设备,正逐渐受到人们的重视和应用。本文将介绍火焰识别摄像机在安全监控和火灾预防方面的全面应用方案。 一、火…

自动驾驶论文

文章目录 一、Convolutional Social Pooling for Vehicle Trajectory Prediction二、QCNet:Query-Centric Trajectory Prediction三、VectorNet: Encoding HD Maps and Agent Dynamics from Vectorized Representation 一、Convolutional Social Pooling for Vehicl…

【Spark精讲】一文讲透Spark宽窄依赖的区别

宽依赖窄依赖的区别 窄依赖:RDD 之间分区是一一对应的宽依赖:发生shuffle,多对多的关系 宽依赖是子RDD的一个分区依赖了父RDD的多个分区父RDD的一个分区的数据,分别流入到子RDD的不同分区特例:cartesian算子对应的Car…

vue +elementui 项目登录通过不同账号切换侧边栏菜单的颜色

前景提要:要求不同权限账号登录侧边栏颜色不一样。分为 theme:1代表默认样式,theme:2代表深色主题样式。 1.首先定义一个主题文件 theme.js,定义两个主题样式 // 主要是切换菜单栏和菜单头部主题的设计,整体主题样式切…

JavaScript系列——正则表达式

文章目录 需求场景正则表达式的定义创建正则表达式通过 / 表示式/ 创建通过构造函数创建 编写一个正则表达式的模式使用简单模式使用特殊字符常用特殊字符列表特殊字符组和范围 正则表达式使用代码演示 常用示例验证手机号码合法性 小结 需求场景 在前端开发领域,在…

MVCC 并发控制原理-源码解析(非常详细)

基础概念 并发事务带来的问题 1)脏读:一个事务读取到另一个事务更新但还未提交的数据,如果另一个事务出现回滚或者进一步更新,则会出现问题。 2)不可重复读:在一个事务中两次次读取同一个数据时&#xff0c…

数字化制造安全防线:迅软DSE助力通用设备企业终端安全卫士

客户简要介绍 某公司是一家主要生产新型激光打印机、喷墨打印机、其它打印机、精密多功能机、传真机等办公自动化用品的企业。公司与顾客建立长期的信赖忠诚关系”的方针,逐步完善公司的各项运营,不断扩充市场前景。产品除国内销售外,还销往…

uni-app模版(扩展插件)

锋哥原创的uni-app视频教程: 2023版uniapp从入门到上天视频教程(Java后端无废话版),火爆更新中..._哔哩哔哩_bilibili2023版uniapp从入门到上天视频教程(Java后端无废话版),火爆更新中...共计23条视频,包括:第1讲 uni…

python统计分析——直方图(plt.hist)

使用matplotlib.pyplot.hist()函数绘制直方图 from matplotlib.pyplot as pltdata_setnp.array([2,3,3,4,4,4,4,5,5,6]) plt.hist(fish_data) 下面介绍plt.hist()函数中常用的几个重要参数(参数等号后为默认设置): (1&#xff0…

WebStorm 创建一个Vue项目(1)

一、下载并安装WebStorm 步骤一 步骤二 选择激活方式 激活码: I2A0QUY8VU-eyJsaWNlbnNlSWQiOiJJMkEwUVVZOFZVIiwibGljZW5zZWVOYW1lIjoiVU5JVkVSU0lEQURFIEVTVEFEVUFMIERFIENBTVBJTkFTIiwiYXNzaWduZWVOYW1lIjoiVGFvYmFv77yaSkVU5YWo5a625qG25rAIOa0uW3peS9nOWup…

[Angular] 笔记 16:模板驱动表单 - 选择框与选项

油管视频: Select & Option (Template Driven Forms) Select & Option 在 pokemon.ts 中新增 interface: export interface Pokemon {id: number;name: string;type: string;isCool: boolean;isStylish: boolean;acceptTerms: boolean; }// new interface…

从0搭建github.io网页

点击跳转到🔗我的博客文章目录 从0搭建github.io网页 文章目录 从0搭建github.io网页1.成果展示1.1 网址和源码1.2 页面展示 2.new对象2.1 创建仓库 3.github.io仓库的初始化3.1 千里之行,始于足下3.2 _config.yml3.3 一点杂活 4.PerCheung.github.io.p…

工业4G 物联网网关——机房动环监控系统应用方案介绍

机房动环监控系统是什么?机房动环监控系统的全称为机房动力环境监控系统,是一套安装在机房内的监控系统,可以对分散在机房各处的独立动力设备、环境和安防进行实时监测,统计和分析处理相关数据,第一时间侦测到故障发生…

MyBatis学习一:快速入门

前言 公司要求没办法,前端也要了解一下后端知识,这里记录一下自己的学习 学习教程:黑马mybatis教程全套视频教程,2天Mybatis框架从入门到精通 文档: https://mybatis.net.cn/index.html MyBatis 快速入门&#xf…

回顾 2023,展望 2024

by zhengkai.blog.csdn.net 项目与心得 今年最大的项目和心得,非GCP莫属,作为全球顶尖的云平台, GCP有他的优势,也有很多难用的地方。但是作为当时的一个strategic solution,我们的印度本地化项目必须使用GCP&#xf…

非线性最小二乘问题的数值方法 —— 从牛顿迭代法到高斯-牛顿法 (实例篇 V)

Title: 非线性最小二乘问题的数值方法 —— 从牛顿迭代法到高斯-牛顿法 (实例篇 V) 姊妹博文 非线性最小二乘问题的数值方法 —— 从牛顿迭代法到高斯-牛顿法 (I) 非线性最小二乘问题的数值方法 —— 从牛顿迭代法到高斯-牛顿法 (II) 非线性最小二乘问题的数值方法 —— 从牛顿…

Docker 从入门到实践:Docker介绍

前言 在当今的软件开发和部署领域,Docker已经成为了一个不可或缺的工具。Docker以其轻量级、可移植性和标准化等特点,使得应用程序的部署和管理变得前所未有的简单。无论您是一名开发者、系统管理员,还是IT架构师,理解并掌握Dock…

CSS 纵向底部往上动画

<template><div class"container" mouseenter"startAnimation" mouseleave"stopAnimation"><!-- 旋方块 --><div class"box" :class"{ scale-up-ver-bottom: isAnimating }"><!-- 元素内容 --&g…

c++_STL容器总结

STL容器总结 1.STL的基本概念1.2STL的六大组件 2.string类2.1string的基本概念2.2string容器常用操作 3.vector容器3.1vector容器基本概述 4.deque容器4.1deque容器的基本概念4.2deque容器的实现原理4.3deque常用API 5. stack容器5.2stack常用API 6.queue容器6.1 queue 容器基本…