经典卷积神经网络 - VGG

使用块的网络 - VGG。

使用多个 3 × 3 3\times 3 3×3的要比使用少个 5 × 5 5\times 5 5×5的效果要好。
在这里插入图片描述
在这里插入图片描述
VGG全称是Visual Geometry Group,因为是由Oxford的Visual Geometry Group提出的。AlexNet问世之后,很多学者通过改进AlexNet的网络结构来提高自己的准确率,主要有两个方向:小卷积核和多尺度。而VGG的作者们则选择了另外一个方向,即加深网络深度。

网络架构

卷积网络的输入是224 * 224RGB图像,整个网络的组成是非常格式化的,基本上都用的是3 * 3的卷积核以及 2 * 2max pooling,少部分网络加入了1 * 1的卷积核。因为想要体现出“上下左右中”的概念,3*3的卷积核已经是最小的尺寸了。

VGG16相比之前网络的改进是3个33卷积核来代替7x7卷积核,2个33卷积核来代替5*5卷积核,这样做的主要目的是在保证具有相同感知野的条件下,减少参数,提升了网络的深度。

多个VGG块后接全连接层。

不同次数的重复块得到不同的架构,如VGG-16,VGG-19等。

VGG:更大更深的AlexNet。

总结:

  • VGG使用可重复使用的卷积块来构建深度卷积神经网络
  • 不同的卷积块个数和超参数可以得到不同复杂度的变种

代码实现

使用数据集CIFAR

model.py

import torch
from torch import nnclass Vgg16(nn.Module):def __init__(self, *args, **kwargs) -> None:super().__init__(*args, **kwargs)self.model = nn.Sequential(nn.Conv2d(3,64,kernel_size=3,padding=1),nn.ReLU(),nn.Conv2d(64,64,kernel_size=3,padding=1),nn.ReLU(),nn.MaxPool2d(2,2),nn.Conv2d(64,128,kernel_size=3,padding=1),nn.ReLU(),nn.Conv2d(128,128,kernel_size=3,padding=1),nn.ReLU(),nn.MaxPool2d(2,2),nn.Conv2d(128,256,kernel_size=3,padding=1),nn.ReLU(),nn.Conv2d(256,256,kernel_size=3,padding=1),nn.ReLU(),nn.Conv2d(256,256,kernel_size=3,padding=1),nn.ReLU(),nn.MaxPool2d(2,2),nn.Conv2d(256,512,kernel_size=3,padding=1),nn.ReLU(),nn.Conv2d(512,512,kernel_size=3,padding=1),nn.ReLU(),nn.Conv2d(512,512,kernel_size=3,padding=1),nn.ReLU(),nn.MaxPool2d(2,2),nn.Conv2d(512, 512, kernel_size=3, padding=1),nn.ReLU(),nn.Conv2d(512, 512, kernel_size=3, padding=1),nn.ReLU(),nn.Conv2d(512, 512, kernel_size=3, padding=1),nn.ReLU(),nn.MaxPool2d(2,2),nn.Flatten(),nn.Linear(7*7*512,4096),nn.Dropout(0.5),nn.Linear(4096,4096),nn.Dropout(0.5),nn.Linear(4096,10))def forward(self,x):return self.model(x)# 验证模型正确性
if __name__ == '__main__':net = Vgg16()x = torch.ones((64,3,244,244))output = net(x)print(output)

train.py

import torch
from torch import nn
from torch.utils.data import DataLoader
from torch.utils.tensorboard import SummaryWriter
from torchvision import datasets
from torchvision.transforms import transforms
from model import Vgg16# 扫描数据次数
epochs = 3
# 分组大小
batch = 64
# 学习率
learning_rate = 0.01
# 训练次数
train_step = 0
# 测试次数
test_step = 0# 定义图像转换
transform = transforms.Compose([transforms.Resize(224),transforms.ToTensor()
])
# 读取数据
train_dataset = datasets.CIFAR10(root="./dataset",train=True,transform=transform,download=True)
test_dataset = datasets.CIFAR10(root="./dataset",train=False,transform=transform,download=True)
# 加载数据
train_dataloader = DataLoader(train_dataset,batch_size=batch,shuffle=True,num_workers=0)
test_dataloader = DataLoader(test_dataset,batch_size=batch,shuffle=True,num_workers=0)
# 数据大小
train_size = len(train_dataset)
test_size = len(test_dataset)
print("训练集大小:{}".format(train_size))
print("验证集大小:{}".format(test_size))# GPU
device = torch.device("mps" if torch.backends.mps.is_available() else "cpu")
print(device)
# 创建网络
net = Vgg16()
net = net.to(device)
# 定义损失函数
loss = nn.CrossEntropyLoss()
loss = loss.to(device)
# 定义优化器
optimizer = torch.optim.SGD(net.parameters(),lr=learning_rate)writer = SummaryWriter("logs")
# 训练
for epoch in range(epochs):print("-------------------第 {} 轮训练开始-------------------".format(epoch))net.train()for data in train_dataloader:train_step = train_step + 1images,targets = dataimages = images.to(device)targets = targets.to(device)outputs = net(images)loss_out = loss(outputs,targets)optimizer.zero_grad()loss_out.backward()optimizer.step()if train_step%100==0:writer.add_scalar("Train Loss",scalar_value=loss_out.item(),global_step=train_step)print("训练次数:{},Loss:{}".format(train_step,loss_out.item()))# 测试net.eval()total_loss = 0total_accuracy = 0with torch.no_grad():for data in test_dataloader:test_step = test_step + 1images, targets = dataimages = images.to(device)targets = targets.to(device)outputs = net(images)loss_out = loss(outputs, targets)total_loss = total_loss + loss_outaccuracy = (targets == torch.argmax(outputs,dim=1)).sum()total_accuracy = total_accuracy + accuracy# 计算精确率print(total_accuracy)accuracy_rate = total_accuracy / test_sizeprint("第 {} 轮,验证集总损失为:{}".format(epoch+1,total_loss))print("第 {} 轮,精确率为:{}".format(epoch+1,accuracy_rate))writer.add_scalar("Test Total Loss",scalar_value=total_loss,global_step=epoch+1)writer.add_scalar("Accuracy Rate",scalar_value=accuracy_rate,global_step=epoch+1)torch.save(net,"./model/net_{}.pth".format(epoch+1))print("模型net_{}.pth已保存".format(epoch+1))

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.rhkb.cn/news/169104.html

如若内容造成侵权/违法违规/事实不符,请联系长河编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

R语言生物群落(生态)数据统计分析与绘图

R 语言作的开源、自由、免费等特点使其广泛应用于生物群落数据统计分析。生物群落数据多样而复杂,涉及众多统计分析方法。以生物群落数据分析中的最常用的统计方法回归和混合效应模型、多元统计分析技术及结构方程等数量分析方法为主线,通过多个来自经典…

React 框架

1、React 框架简介 1.1、介绍 CS 与 BS结合:像 React,Vue 此类框架,转移了部分服务器的功能到客户端。将CS 和 BS 加以结合。客户端只用请求一次服务器,服务器就将所有js代码返回给客户端,所有交互类操作都不再依赖服…

muduo源码学习base——TimeStamp(UTC时间戳)

TimeStamp(UTC时间戳) 前置copyable和noncopyableTimeStampnow() 此博客跟着b站上的:大并发服务器开发(实战)学习做的笔记 前置 值语义、对象语义: 值语义:所谓值语义是一个对象被系统标准的复制方式复制后…

Spring Cloud之服务注册与发现(Eureka)

目录 Eureka 介绍 角色 实现流程 单机构建 注册中心 服务提供者 服务消费者 集群搭建 注册中心 服务提供者 自我保护机制 原理分析 Eureka 介绍 Eureka是spring cloud中的一个负责服务注册与发现的组件,本身是基于REST的服务,同时还提供了…

(完全解决)如何输入一个图的邻接矩阵(每两个点的亲密度矩阵affinity),然后使用sklearn进行谱聚类

文章目录 背景输入点直接输入邻接矩阵 背景 网上倒是有一些关于使用sklearn进行谱聚类的教程,但是这些教程的输入都是一些点的集合,然后根据谱聚类的原理,其会每两个点计算一次亲密度(可以认为两个点距离越大,亲密度越…

【学习草稿】bert文本分类

https://github.com/google-research/bert https://github.com/CyberZHG/keras-bert 在 BERT 中,每个单词的嵌入向量由三部分组成: Token 嵌入向量:该向量是 WordPiece 分词算法得到的子单词 ID 对应的嵌入向量。 Segment 嵌入向量&#x…

安科瑞带防逆流功能的数据通讯网关-安科瑞黄安南

AWT200 数据通讯网关应用于各种终端设备的数据采集与数据分析。用于实现设备的监测、控制、计算,为系统与设备之间建立通讯纽带,实现双向的数据通讯。实时监测并及时发现异常,同时自身根据用户规则进行逻辑判断,可以节省人力和通讯…

【机器学习】模型平移不变性/等变性归纳偏置Attention机制

Alphafold2具有旋转不变性吗——从图像识别到蛋白结构预测的旋转对称性实现 通过Alphafold2如何预测蛋白质结构,看有哪些机制或tricks可以利用? 一、等变Transformer 等变Transformer是Transformer众多变体的其中一种,其强调等变性。不变性…

CentOS7 安装 nodejs

获取安装文件 node历史版本地址 安装 上传到服务器安装位置cd 到压缩包位置,执行解压安装操作 [rootps-fdcnops-01 /]# cd usr/local/nodejs/ [rootps-fdcnops-01 nodejs]# tar -xzvf node-v16.16.0-linux-x64 配置环境变量 [rootps-fdcnops-01 nodejs]# vim /…

5G RedCap工业智能网关

5G RedCap工业智能网关是当前工业智能化发展领域的重要技术之一。随着物联网和工业互联网的迅速发展,企业对于实时数据传输和高速通信需求越来越迫切。在这种背景下,5G RedCap工业智能网关以其卓越的性能和功能,成为众多企业的首选。 5G RedC…

Qt扫盲-QPen 理论使用总结

QPen 理论使用总结 一、概述二、Pen Style 画笔风格三、Cap Style 帽风格四、Join Style 连接处样式 一、概述 QPen 是Qt绘图控件里面的一个重要的组件,和QColor 一样也是类似的一个属性类。这个类就是描述一个画笔具有的属性。 一个画笔 Pen 有style()&#xff0…

ExoPlayer架构详解与源码分析(6)——MediaPeriod

系列文章目录 ExoPlayer架构详解与源码分析(1)——前言 ExoPlayer架构详解与源码分析(2)——Player ExoPlayer架构详解与源码分析(3)——Timeline ExoPlayer架构详解与源码分析(4)—…

【C语言】关于char的取值范围的讨论

前提知识: 计算机内存中存的是整数的补码。 正数的原反补相同! 负数的补码 (除符号位以外)原码取反 1 负数的源码 (除符号位以外)补码取反 1 有符号的char,最高位二进制位表示符号位 …

蓝桥杯每日一题2023.10.19

题目描述 完全二叉树的权值 - 蓝桥云课 (lanqiao.cn) 题目分析 我们以每一个节点的坐标来将这一深度的权值之和相加从而算出权值和 要清楚每一个深度的其实节点和末尾节点,使用双指针将这个深度节点的权值和计算出来,记录所 需要的深度即可 #includ…

uni-app:引用文件的方法

绝对定位 ①import common from "/utils/common.js" ②import common from "utils/common.js" <template><view></view> </template> <script>import common from "/utils/common.js"export default {data() {ret…

通过内网穿透快速搭建公网可访问的Spring Boot接口调试环境

&#x1f525;博客主页&#xff1a; 小羊失眠啦 &#x1f516;系列专栏&#xff1a; C语言 、Cpolar、Linux ❤️感谢大家点赞&#x1f44d;收藏⭐评论✍️ 文章目录 前言1. 本地环境搭建1.1 环境参数1.2 搭建springboot服务项目 2. 内网穿透2.1 安装配置cpolar内网穿透2.1.1 …

vue3 源码解析(1)— reactive 响应式实现

前言 本文是 vue3 源码解析系列的第一篇文章&#xff0c;项目代码的整体实现是参考了 v3.2.10 版本&#xff0c;项目整体架构可以参考之前我写过的文章 rollup 实现多模块打包。话不多说&#xff0c;让我们通过一个简单例子开始这个系列的文章。 举个例子 <!DOCTYPE html…

微信小程序OA会议系统数据交互

前言 经过我们所写的上一文章&#xff1a;微信小程序会议OA系统其他页面-CSDN博客 在我们的是基础面板上面&#xff0c;可以看到出来我们的数据是死数据&#xff0c;今天我们就完善我们的是数据 后台 在我们去完成项目之前我们要把我们的项目后台准备好资源我放在我资源中&…

怎么进行设备维护与保养?智能巡检系统有什么用?

设备维护与保养需要遵循三个原则&#xff1a;故障设备全面分析的原则、故障设备深入检查的原则以及故障设备分析排查的原则。 一、故障设备全面分析的原则   检修人员在对设备维护与保养时&#xff0c;如果看到设备在运行中出现了异常的现象&#xff0c;要立刻停止设备的工作…

【C语言】指针错题(类型分析)

题目&#xff1a; #include <stdio.h> int main () {int*p NULL;int arr[10] {0}; return 0; } 选项&#xff1a; A、p arr ; B、 int (* ptr )[10]& arr ; C、 p & arr [ 0 ]; D、 p & arr ; 解析&#xff1a; 1、 p 是一个指针变量&#xff0c;指…