【PyTorch】现代卷积神经网络

文章目录

  • 1. 理论介绍
    • 1.1. 深度卷积神经网络(AlexNet)
      • 1.1.1. 概述
      • 1.1.2. 模型设计
    • 1.2. 使用块的网络(VGG)
    • 1.3. 网络中的网络(NiN)
  • 2. 实例解析
    • 2.1. 实例描述
    • 2.2. 代码实现
      • 2.2.1. 在FashionMNIST数据集上训练AlexNet
        • 2.2.1.1. 主要代码
        • 2.2.1.2. 完整代码
        • 2.2.1.3. 输出结果
      • 2.2.2. 在FashionMNIST数据集上训练VGG-11
        • 2.2.2.1. 主要代码
        • 2.2.2.2. 完整代码
        • 2.2.2.3. 输出结果
        • 2.2.2.1. 主要代码
        • 2.2.2.2. 完整代码
        • 2.2.2.3. 输出结果

1. 理论介绍

1.1. 深度卷积神经网络(AlexNet)

1.1.1. 概述

  • 特征本身应该被学习,而且在合理地复杂性前提下,特征应该由多个共同学习的神经网络层组成,每个层都有可学习的参数。在机器视觉中,最底层可能检测边缘、颜色和纹理,更高层建立在这些底层表示的基础上,以表示更大的特征。
  • 包含许多特征的深度模型需要大量的有标签数据,才能显著优于基于凸优化的传统方法(如线性方法和核方法)。ImageNet挑战赛为研究人员提高超大规模的数据集。
  • 深度学习对计算资源要求很高,训练可能需要数百个迭代轮数,每次迭代都需要通过代价高昂的许多线性代数层传递数据。用GPU训练神经网络大大提高了训练速度,相比于CPU,GPU由100~1000个小的处理单元(核心)组成,庞大的核心数量使GPU比CPU快几个数量级;GPU核心更简单,功耗更低;GPU拥有更高的内存带宽;卷积和矩阵乘法,都是可以在GPU上并行化的操作。
  • AlexNet首次证明了学习到的特征可以超越手工设计的特征。

1.1.2. 模型设计

模型设计

  • AlexNet使用ReLU而不是sigmoid作为其激活函数。
  • 在最后一个卷积层后有两个全连接层,分别有4096个输出。 这两个巨大的全连接层拥有将近1GB的模型参数。 由于早期GPU显存有限,原版的AlexNet采用了双数据流设计,使得每个GPU只负责存储和计算模型的一半参数。
  • AlexNet通过暂退法控制全连接层的模型复杂度。
  • 为了进一步扩充数据,AlexNet在训练时增加了大量的图像增强数据,如翻转、裁切和变色。 这使得模型更健壮,更大的样本量有效地减少了过拟合。

1.2. 使用块的网络(VGG)

  • 经典卷积神经网络的基本组成部分
    • 带填充以保持分辨率的卷积层
    • 非线性激活函数
    • 池化层
  • VGG块由一系列卷积层组成,后面再加上用于空间下采样的最大池化层。深层且窄的卷积(即 3 × 3 3\times3 3×3)比较浅层且宽的卷积更有效。不同的VGG模型可通过每个VGG块中卷积层数量和输出通道数量的差异来定义。
  • VGG网络由几个VGG块和全连接层组成。超参数变量conv_arch指定了每个VGG块里卷积层个数和输出通道数。全连接模块则与AlexNet中的相同。
  • VGG-11有5个VGG块,其中前两个块各有一个卷积层,后三个块各包含两个卷积层,共8个卷积层。 第一个模块有64个输出通道,每个后续模块将输出通道数量翻倍,直到该数字达到512。

1.3. 网络中的网络(NiN)

  • 网络中的网络(NiN)在每个像素的通道上分别使用多层感知机

2. 实例解析

2.1. 实例描述

  • 在FashionMNIST数据集上训练AlexNet。注:输入图像增加 224 × 224 224\times224 224×224,不使用跨GPU分解模型。
  • 在FashionMNIST数据集上训练VGG-11。注:可以按比例缩小通道数以减少参数量,学习率可以设置得略高。

2.2. 代码实现

2.2.1. 在FashionMNIST数据集上训练AlexNet

2.2.1.1. 主要代码
AlexNet = nn.Sequential(nn.Conv2d(1, 96, kernel_size=11, stride=4, padding=1), nn.ReLU(),nn.MaxPool2d(kernel_size=3, stride=2),nn.Conv2d(96, 256, kernel_size=5, padding=2), nn.ReLU(),nn.MaxPool2d(kernel_size=3, stride=2),nn.Conv2d(256, 384, kernel_size=3, padding=1), nn.ReLU(),nn.Conv2d(384, 384, kernel_size=3, padding=1), nn.ReLU(),nn.Conv2d(384, 256, kernel_size=3, padding=1), nn.ReLU(),nn.MaxPool2d(kernel_size=3, stride=2),nn.Flatten(),nn.Linear(6400, 4096), nn.ReLU(),nn.Dropout(p=0.5),nn.Linear(4096, 4096), nn.ReLU(),nn.Dropout(p=0.5),nn.Linear(4096, 10)
).to(device, non_blocking=True)
2.2.1.2. 完整代码
import os, torch
from tensorboardX import SummaryWriter
from rich.progress import track
from torchvision.datasets import FashionMNIST
from torchvision.transforms import Compose, ToTensor, Resize
from torch.utils.data import DataLoader
from torch import nn, optimdef load_dataset():"""加载数据集"""root = "./dataset"transform = Compose([Resize(224), ToTensor()])mnist_train = FashionMNIST(root, True, transform, download=True)mnist_test = FashionMNIST(root, False, transform, download=True)dataloader_train = DataLoader(mnist_train, batch_size, shuffle=True, num_workers=num_workers, pin_memory=True,)dataloader_test = DataLoader(mnist_test, batch_size, shuffle=False,num_workers=num_workers, pin_memory=True,)return dataloader_train, dataloader_testif __name__ == "__main__":# 全局参数设置num_epochs = 10batch_size = 128num_workers = 3lr = 0.01device = torch.device('cuda')# 创建记录器def log_dir():root = "runs"if not os.path.exists(root):os.mkdir(root)order = len(os.listdir(root)) + 1return f'{root}/exp{order}'writer = SummaryWriter(log_dir=log_dir())# 数据集配置dataloader_train, dataloader_test = load_dataset()# 模型配置AlexNet = nn.Sequential(nn.Conv2d(1, 96, kernel_size=11, stride=4, padding=1), nn.ReLU(),nn.MaxPool2d(kernel_size=3, stride=2),nn.Conv2d(96, 256, kernel_size=5, padding=2), nn.ReLU(),nn.MaxPool2d(kernel_size=3, stride=2),nn.Conv2d(256, 384, kernel_size=3, padding=1), nn.ReLU(),nn.Conv2d(384, 384, kernel_size=3, padding=1), nn.ReLU(),nn.Conv2d(384, 256, kernel_size=3, padding=1), nn.ReLU(),nn.MaxPool2d(kernel_size=3, stride=2),nn.Flatten(),nn.Linear(6400, 4096), nn.ReLU(),nn.Dropout(p=0.5),nn.Linear(4096, 4096), nn.ReLU(),nn.Dropout(p=0.5),nn.Linear(4096, 10)).to(device, non_blocking=True)criterion = nn.CrossEntropyLoss(reduction='none')optimizer = optim.SGD(AlexNet.parameters(), lr=lr)# 训练循环metrics_train = torch.zeros(3, device=device)  # 训练损失、训练准确度、样本数metrics_test = torch.zeros(2, device=device)   # 测试准确度、样本数for epoch in track(range(num_epochs), description='AlexNet'):AlexNet.train()for X, y in dataloader_train:X, y = X.to(device, non_blocking=True), y.to(device, non_blocking=True)loss = criterion(AlexNet(X), y)optimizer.zero_grad()loss.mean().backward()optimizer.step()AlexNet.eval()with torch.no_grad():metrics_train.zero_()for X, y in dataloader_train:X, y = X.to(device, non_blocking=True), y.to(device, non_blocking=True)y_hat = AlexNet(X)loss = criterion(y_hat, y)metrics_train[0] += loss.sum()metrics_train[1] += (y_hat.argmax(dim=1) == y).sum()metrics_train[2] += y.numel()metrics_train[0] /= metrics_train[2]metrics_train[1] /= metrics_train[2]metrics_test.zero_()for X, y in dataloader_test:X, y = X.to(device, non_blocking=True), y.to(device, non_blocking=True)y_hat = AlexNet(X)metrics_test[0] += (y_hat.argmax(dim=1) == y).sum()metrics_test[1] += y.numel()metrics_test[0] /= metrics_test[1]_metrics_train = metrics_train.cpu().numpy()_metrics_test = metrics_test.cpu().numpy()writer.add_scalars('metrics', {'train_loss': _metrics_train[0],'train_acc': _metrics_train[1],'test_acc': _metrics_test[0]}, epoch)writer.close()
2.2.1.3. 输出结果

AlexNet

2.2.2. 在FashionMNIST数据集上训练VGG-11

2.2.2.1. 主要代码
def vgg_block(num_convs, in_channels, out_channels):"""定义VGG块"""layers = []for _ in range(num_convs):layers.append(nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1))layers.append(nn.ReLU())in_channels = out_channelslayers.append(nn.MaxPool2d(kernel_size=2, stride=2))return nn.Sequential(*layers)def vgg(conv_arch):"""定义VGG网络"""vgg_blocks = []in_channels = 1out_channels = 0for num_convs, out_channels in conv_arch:vgg_blocks.append(vgg_block(num_convs, in_channels, out_channels))in_channels = out_channelsreturn nn.Sequential(*vgg_blocks, nn.Flatten(),nn.Linear(out_channels * 7 * 7, 4096), nn.ReLU(),nn.Dropout(p=0.5),nn.Linear(4096, 4096), nn.ReLU(),nn.Dropout(p=0.5),nn.Linear(4096, 10),)
2.2.2.2. 完整代码
import os
from tensorboardX import SummaryWriter
from rich.progress import track
import torch
from torch import nn, optim
from torch.utils.data import DataLoader
from torchvision.transforms import Compose, Resize, ToTensor
from torchvision.datasets import FashionMNISTdef load_dataset():"""加载数据集"""root = "./dataset"transform = Compose([Resize(224), ToTensor()])mnist_train = FashionMNIST(root, True, transform, download=True)mnist_test = FashionMNIST(root, False, transform, download=True)dataloader_train = DataLoader(mnist_train, batch_size, shuffle=True, num_workers=num_workers, pin_memory=True,)dataloader_test = DataLoader(mnist_test, batch_size, shuffle=False,num_workers=num_workers, pin_memory=True,)return dataloader_train, dataloader_testdef vgg_block(num_convs, in_channels, out_channels):"""定义VGG块"""layers = []for _ in range(num_convs):layers.append(nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1))layers.append(nn.ReLU())in_channels = out_channelslayers.append(nn.MaxPool2d(kernel_size=2, stride=2))return nn.Sequential(*layers)def vgg(conv_arch):"""定义VGG网络"""vgg_blocks = []in_channels = 1out_channels = 0for num_convs, out_channels in conv_arch:vgg_blocks.append(vgg_block(num_convs, in_channels, out_channels))in_channels = out_channelsreturn nn.Sequential(*vgg_blocks, nn.Flatten(),nn.Linear(out_channels * 7 * 7, 4096), nn.ReLU(),nn.Dropout(p=0.5),nn.Linear(4096, 4096), nn.ReLU(),nn.Dropout(p=0.5),nn.Linear(4096, 10),)if __name__ == '__main__':# 全局参数配置num_epochs = 10batch_size = 128num_workers = 3lr = 0.05device = torch.device('cuda')# 记录器配置def log_dir():root = "runs"if not os.path.exists(root):os.mkdir(root)order = len(os.listdir(root)) + 1return f'{root}/exp{order}'writer = SummaryWriter(log_dir=log_dir())# 数据集配置dataloader_train, dataloader_test = load_dataset()# 模型配置ratio = 4conv_arch = ((1, 64), (1, 128), (2, 256), (2, 512), (2, 512))small_conv_arch =  [(pair[0], pair[1] // ratio) for pair in conv_arch]net = vgg(small_conv_arch).to(device, non_blocking=True)criterion = nn.CrossEntropyLoss(reduction='none')optimizer = optim.SGD(net.parameters(), lr=lr)# 训练循环metrics_train = torch.zeros(3, device=device)  # 训练损失、训练准确度、样本数metrics_test = torch.zeros(2, device=device)   # 测试准确度、样本数for epoch in track(range(num_epochs), description='VGG'):net.train()for X, y in dataloader_train:X, y = X.to(device, non_blocking=True), y.to(device, non_blocking=True)loss = criterion(net(X), y)optimizer.zero_grad()loss.mean().backward()optimizer.step()net.eval()with torch.no_grad():metrics_train.zero_()for X, y in dataloader_train:X, y = X.to(device, non_blocking=True), y.to(device, non_blocking=True)y_hat = net(X)loss = criterion(y_hat, y)metrics_train[0] += loss.sum()metrics_train[1] += (y_hat.argmax(dim=1) == y).sum()metrics_train[2] += y.numel()metrics_train[0] /= metrics_train[2]metrics_train[1] /= metrics_train[2]metrics_test.zero_()for X, y in dataloader_test:X, y = X.to(device, non_blocking=True), y.to(device, non_blocking=True)y_hat = net(X)metrics_test[0] += (y_hat.argmax(dim=1) == y).sum()metrics_test[1] += y.numel()metrics_test[0] /= metrics_test[1]_metrics_train = metrics_train.cpu().numpy()_metrics_test = metrics_test.cpu().numpy()writer.add_scalars('metrics', {'train_loss': _metrics_train[0],'train_acc': _metrics_train[1],'test_acc': _metrics_test[0]}, epoch)writer.close()
2.2.2.3. 输出结果

VGG

2.2.2.1. 主要代码

2.2.2.2. 完整代码

2.2.2.3. 输出结果

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.rhkb.cn/news/214887.html

如若内容造成侵权/违法违规/事实不符,请联系长河编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

UEFI下Windows10和Ubuntu22.04双系统安装图解

目录 简介制作U盘启动盘并从U盘启动电脑安装系统安装Windows系统安装Ubuntu 附录双系统时间不一致 简介 传统 Legacy BIOS主板下的操作系统安装可参考本人博客 U盘系统盘制作与系统安装(详细图解) ,本文介绍UEFI主板下的双系统安装&#xff…

2023.12.9 关于 Spring Boot 事务传播机制详解

目录 事务传播机制 七大事务传播机制 支持当前调用链上的事务 Propagation.REQUIRED Propagation.SUPPORTS Propagation.MANDATORY 不支持当前调用链上的事务 Propagation.REQUIRES_NEW Propagation.NOT_SUPPORTED Propagation.NEVER 嵌套事务 Propagation.NESTED…

一个音乐能够做成二维码吗?音乐的活码制作技巧

一个音乐能够做成二维码后展示吗?现在以二维码为载体来储存内容的方式越来越常见,比如图片、文件、视频、音频都可以做成二维码展示,人们也更习惯去扫码获取内容。音频作为日常工作生活中常用的一种内容,可以用音频二维码生成器来…

Unity_ET-TimerComponent

Unity_ET-TimerComponent 源码&#xff1a; namespace ETModel {public struct Timer{public long Id { get; set; }public long Time { get; set; }public TaskCompletionSource<bool> tcs;}[ObjectSystem]public class TimerComponentUpdateSystem : UpdateSystem<…

phpstudy小皮(PHP集成环境)下载及使用

下载 https://www.xp.cn/download.html直接官网下载即可&#xff0c;下载完解压是个.exe程序&#xff0c;直接点击安装就可以&#xff0c;它会自动在D盘目录为D:\phpstudy_pro 使用 phpMyAdmin是集成的数据库可视化&#xff0c;这里需要下载一下&#xff0c;在软件管理-》网站程…

three.js 入门三:buffergeometry贴图属性(position、index和uvs)

环境&#xff1a; three.js 0.159.0 一、基础知识 geometry&#xff1a;决定物体的几何形状、轮廓&#xff1b;material&#xff1a;决定物体呈现的色彩、光影特性、贴图皮肤&#xff1b;mesh&#xff1a;场景中的物体&#xff0c;由geometry和materia组成&#xff1b;textu…

数字系统设计(EDA)实验报告【出租车计价器】

一、问题描述 题目九&#xff1a;出租车计价器设计&#xff08;平台实现&#xff09;★★ 完成简易出租车计价器设计&#xff0c;选做停车等待计价功能。 1、基本功能&#xff1a; &#xff08;1&#xff09;起步8元/3km&#xff0c;此后2元/km&#xff1b; &#xff08;2…

Unity中实现ShaderToy卡通火(一)

文章目录 前言一、准备好我们的后处理基础脚本1、C#&#xff1a;2、Shader&#xff1a; 二、开始逐语句对ShaderToy进行转化1、首先&#xff0c;找到我们的主函数 mainImage2、其余的方法全部都是在 mainImage 函数中调用的方法3、替换后的代码(已经没报错了&#xff0c;都是效…

微服务学习:Nacos配置中心

先打开Nacos&#xff08;详见微服务学习&#xff1a;Nacos微服务架构中的服务注册、服务发现和动态配置&Nacos下载&#xff09; 1.环境隔离&#xff1a; 新建命名空间&#xff1a; 记住命名空间ID&#xff1a; c82496fb-237f-47f7-91ed-288a53a63324 再配置 就可达成环…

张正友相机标定法原理与实现

张正友相机标定法是张正友教授1998年提出的单平面棋盘格的相机标定方法。传统标定法的标定板是需要三维的,需要非常精确,这很难制作,而张正友教授提出的方法介于传统标定法和自标定法之间,但克服了传统标定法需要的高精度标定物的缺点,而仅需使用一个打印出来的棋盘格就可…

iOS ------ UICollectionView

一&#xff0c;UICollectionView的简介 UICollectionView是iOS6之后引入的一个新的UI控件&#xff0c;它和UITableView有着诸多的相似之处&#xff0c;其中许多代理方法都十分类似。简单来说&#xff0c;UICollectionView是比UITbleView更加强大的一个UI控件&#xff0c;有如下…

VMALL 商城系统

SpringBoot MySQL Vue等技术实现 技术栈 核心框架&#xff1a;SpringBoot 持久层框架&#xff1a;MyBatis 模板框架&#xff1a;Vue 数据库&#xff1a;MySQL 阿里云短信&#xff0c;对象存储OSS 项目包含源码和数据库文件。 效果图如下&#xff1a;

如果将视频转化为gif格式图

1.选择视频转换GIF&#xff1a; 2.添加视频文件&#xff1a; 3.点击“开始”&#xff1a; 4.选择设置&#xff0c;将格式选择为1080P更加清晰&#xff1a; 5.输出后的效果图&#xff1a;

5键键盘的输出 - 华为OD统一考试

OD统一考试 题解&#xff1a; Java / Python / C 题目描述 有一个特殊的 5键键盘&#xff0c;上面有 a,ctrl-c,ctrl-x,ctrl-v,ctrl-a五个键。 a 键在屏幕上输出一个字母 a; ctrl-c 将当前选择的字母复制到剪贴板; ctrl-x 将当前选择的 字母复制到剪贴板&#xff0c;并清空选择…

14-1、IO流

14-1、IO流 lO流打开和关闭lO流打开模式lO流对象的状态 非格式化IO二进制IO读取二进制数据获取读长度写入二进制数据 读写指针 和 随机访问设置读/写指针位置获取读/写指针位置 字符串流 lO流打开和关闭 通过构造函数打开I/O流 其中filename表示文件路径&#xff0c;mode表示打…

Linux下c开发

编程环境 Linux 下的 C 语言程序设计与在其他环境中的 C 程序设计一样&#xff0c; 主要涉及到编辑器、编译链接器、调试器及项目管理工具。编译流程 编辑器 Linux 中最常用的编辑器有 Vi。编译连接器 编译是指源代码转化生成可执行代码的过程。在 Linux 中&#xff0c;最常用…

【PWN】学习笔记(二)【栈溢出基础】

目录 课程教学C语言函数调用栈ret2textPWN工具 课程教学 课程链接&#xff1a;https://www.bilibili.com/video/BV1854y1y7Ro/?vd_source7b06bd7a9dd90c45c5c9c44d12e7b4e6 课程附件&#xff1a; https://pan.baidu.com/s/1vRCd4bMkqnqqY1nT2uhSYw 提取码: 5rx6 C语言函数调…

六级高频词汇3

目录 单词 参考链接 单词 400. nonsense n. 胡说&#xff0c;冒失的行动 401. nuclear a. 核子的&#xff0c;核能的 402. nucleus n. 核 403. retail n. /v. /ad. 零售 404. retain vt. 保留&#xff0c;保持 405. restrict vt. 限制&#xff0c;约束 406. sponsor n. …

骁龙8 Gen 3 vs A17 Pro

骁龙8 Gen 3 vs A17 Pro——谁会更胜一筹&#xff1f; Geekbench、AnTuTu 和 3DMark 等基准测试在智能手机领域发挥着至关重要的作用。它们为制造商和手机爱好者提供了设备性能的客观衡量标准。这些测试有助于评估难以测量的无形方面。然而&#xff0c;值得注意的是&#xff0c…

有病但合理的 ChatGPT 提示语

ChatGPT 面世一年多了&#xff0c;如何让大模型输出高质量内容&#xff0c;让提示词工程成了一门重要的学科。以下是一些有病但合理的提示词技巧&#xff0c;大部分经过论文证明&#xff0c;有效提高 ChatGPT 输出质量&#xff1a; ​1️⃣ Take a deep breath. 深呼吸 ✨ 作用…