经典卷积神经网络 - NIN

网络中的网络,NIN。

AlexNet和VGG都是先由卷积层构成的模块充分抽取空间特征,再由全连接层构成的模块来输出分类结果。但是其中的全连接层的参数量过于巨大,因此NiN提出用1*1卷积代替全连接层,串联多个由卷积层和“全连接”层构成的小网络来构建⼀个深层网络。

AlexNet和VGG对LeNet的改进主要在于如何扩大和加深这两个模块。
或者,可以想象在这个过程的早期使用全连接层。然而,如果使用了全连接层,可能会完全放弃表征的空间结构。

网络中的网络NiN)提供了一个非常简单的解决方案:在每个像素的通道上分别使用多层感知机。也就是使用了多个1*1的卷积核。同时他认为全连接层占据了大量的内存,所以整个网络结构中没有使用全连接层。

NIN块

image-20231023194711049

一个卷积层后跟两个全连接层。

  • 步幅为1,无填充,输出形状跟卷积层输出一样。
  • 起到全连接层的作用。

NIN网络结构
在这里插入图片描述
image-20231024090401226

  • 无全连接层

  • 交替使用NIN块和步幅为2的最大池化层

    逐步减小高宽和增大通道数

  • 最后使用全局平均池化层得到输出

    其输入通道数是类别数

此网络结构总计4层: 3mlpconv + 1global_average_pooling

优点:

  1. 提供了网络层间映射的一种新可能;
  2. 增加了网络卷积层的非线性能力。

总结:

  • NIN块使用卷积层加上个 1 × 1 1\times 1 1×1卷积,后者对每个像素增加了非线性性
  • NIN使用全局平均池化层来替代VGG和AlexNet中的全连接层,不容易过拟合,更少的参数个数

代码实现

使用CIFAR-10数据集。

maxpooling不改变通道数,只改变长和宽

model.py

import torch
from torch import nn# nin块
def nin_block(in_channels,out_channels,kernel_size,strides,padding):return nn.Sequential(nn.Conv2d(in_channels,out_channels,kernel_size,strides,padding),nn.ReLU(),nn.Conv2d(out_channels,out_channels,kernel_size=1),nn.ReLU(),nn.Conv2d(out_channels,out_channels,kernel_size=1),nn.ReLU(),)# 构建网络
class NIN(nn.Module):def __init__(self, *args, **kwargs) -> None:super().__init__(*args, **kwargs)self.model = nn.Sequential(nin_block(3,96,kernel_size=11,strides=4,padding=0),nn.MaxPool2d(3,stride=2),nin_block(96,256,kernel_size=5,strides=1,padding=2),nn.MaxPool2d(3,stride=2),nin_block(256,384,kernel_size=3,strides=1,padding=1),nn.MaxPool2d(3,stride=2),nn.Dropout(0.5),nin_block(384,10,kernel_size=3,strides=1,padding=1),nn.AdaptiveAvgPool2d((1,1)),nn.Flatten())def forward(self,x):return self.model(x)# 验证模型正确性
if __name__ == '__main__':nin = NIN()x = torch.ones((64,3,244,244))output = nin(x)print(output)

train.py

import torch
from torch import nn
from torch.utils.data import DataLoader
from torch.utils.tensorboard import SummaryWriter
from torchvision import datasets
from torchvision.transforms import transforms
from model import NIN# 扫描数据次数
epochs = 3
# 分组大小
batch = 64
# 学习率
learning_rate = 0.01
# 训练次数
train_step = 0
# 测试次数
test_step = 0# 定义图像转换
transform = transforms.Compose([transforms.Resize(224),transforms.ToTensor()
])
# 读取数据
train_dataset = datasets.CIFAR10(root="./dataset",train=True,transform=transform,download=True)
test_dataset = datasets.CIFAR10(root="./dataset",train=False,transform=transform,download=True)
# 加载数据
train_dataloader = DataLoader(train_dataset,batch_size=batch,shuffle=True,num_workers=0)
test_dataloader = DataLoader(test_dataset,batch_size=batch,shuffle=True,num_workers=0)
# 数据大小
train_size = len(train_dataset)
test_size = len(test_dataset)
print("训练集大小:{}".format(train_size))
print("验证集大小:{}".format(test_size))# GPU
device = torch.device("mps" if torch.backends.mps.is_available() else "cpu")
print(device)
# 创建网络
net = NIN()
net = net.to(device)
# 定义损失函数
loss = nn.CrossEntropyLoss()
loss = loss.to(device)
# 定义优化器
optimizer = torch.optim.SGD(net.parameters(),lr=learning_rate)writer = SummaryWriter("logs")
# 训练
for epoch in range(epochs):print("-------------------第 {} 轮训练开始-------------------".format(epoch))net.train()for data in train_dataloader:train_step = train_step + 1images,targets = dataimages = images.to(device)targets = targets.to(device)outputs = net(images)loss_out = loss(outputs,targets)optimizer.zero_grad()loss_out.backward()optimizer.step()if train_step%100==0:writer.add_scalar("Train Loss",scalar_value=loss_out.item(),global_step=train_step)print("训练次数:{},Loss:{}".format(train_step,loss_out.item()))# 测试net.eval()total_loss = 0total_accuracy = 0with torch.no_grad():for data in test_dataloader:test_step = test_step + 1images, targets = dataimages = images.to(device)targets = targets.to(device)outputs = net(images)loss_out = loss(outputs, targets)total_loss = total_loss + loss_outaccuracy = (targets == torch.argmax(outputs,dim=1)).sum()total_accuracy = total_accuracy + accuracy# 计算精确率print(total_accuracy)accuracy_rate = total_accuracy / test_sizeprint("第 {} 轮,验证集总损失为:{}".format(epoch+1,total_loss))print("第 {} 轮,精确率为:{}".format(epoch+1,accuracy_rate))writer.add_scalar("Test Total Loss",scalar_value=total_loss,global_step=epoch+1)writer.add_scalar("Accuracy Rate",scalar_value=accuracy_rate,global_step=epoch+1)torch.save(net,"./model/net_{}.pth".format(epoch+1))print("模型net_{}.pth已保存".format(epoch+1))

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.rhkb.cn/news/171340.html

如若内容造成侵权/违法违规/事实不符,请联系长河编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

电脑定时关机

电脑定时关机 1.右键 管理 2. 3. 4. 5. shutdown.exe/s /f /t 06.点击完成就好了 7.这里面可以 看到定时任务和启动 右键有运行 结束 禁用

Flask 上传文件,requests通过接口上传文件

这是一个使用 Flask 框架实现文件上传功能的示例代码。该代码定义了两个路由: /upload:处理文件上传请求。在该路由中,我们首先从请求中获取上传的文件,然后将文件保存到本地磁盘上,并返回一个字符串表示上传成功。 /…

LLM系列 | 22 : Code Llama实战(下篇):本地部署、量化及GPT-4对比

引言 模型简介 依赖安装 模型inference 代码补全 4-bit版模型 代码填充 指令编码 Code Llama vs ChatGPT vs GPT4 小结 引言 青山隐隐水迢迢,秋尽江南草未凋。 小伙伴们好,我是《小窗幽记机器学习》的小编:卖热干面的小女孩。紧接…

记录--vue3实现excel文件预览和打印

这里给大家分享我在网上总结出来的一些知识,希望对大家有所帮助 前言 在前端开发中,有时候一些业务场景中,我们有需求要去实现excel的预览和打印功能,本文在vue3中如何实现Excel文件的预览和打印。 预览excel 关于实现excel文档在…

android——自定义控件(编辑框)、悬浮窗

一、自定义编辑框 效果图: 主要的代码为: class EditLayout JvmOverloads constructor(context: Context, attrs: AttributeSet? null, defStyleAttr: Int 0 ) : ConstraintLayout(context, attrs, defStyleAttr) {private var editTitle: Stringpr…

五、W5100S/W5500+RP2040树莓派Pico<UDP Client数据回环测试>

文章目录 1. 前言2. 协议简介2.1 简述2.2 优点2.3 应用 3. WIZnet以太网芯片4. UDP Client回环测试4.1 程序流程图4.2 测试准备4.3 连接方式4.4 相关代码4.5 测试现象 5. 注意事项6. 相关链接 1. 前言 UDP是一种无连接的网络协议,它提供了一种简单的、不可靠的方式来…

线框图软件:Balsamiq Wireframes mac中文介绍

Balsamiq Wireframes mac是一款用于创建线框图的软件工具。它旨在帮助用户快速制作出清晰、简洁的界面原型,以便在设计和开发过程中进行协作和沟通。 Balsamiq Wireframes具有简单直观的用户界面,使用户能够快速添加和编辑各种用户界面元素,如…

Java采集传感器数据,亲测有效!

背景 先说背景, 最近公司项目需要用到传感器,采集设备温湿度,倾斜角,电流…,公司采购采购了一个温湿度传感器给我们开发测试使用,如下图: 看着还挺精致有没有。 进入正题 有了这个温湿度传感器…

Spring | Spring Cache 缓存框架

Spring Cache 缓存框架: Spring Cache功能介绍Spring Cache的Maven依赖Spring Cache的常用注解EnableCaching注解CachePut注解Cacheable注解CacheEvict注解 Spring Cache功能介绍 Spring Cache是Spring的一个框架,实现了基于注解的缓存功能。只需简单加一…

ubuntu 22.04安装百度网盘

百度网盘 客户端下载 (baidu.com) 下载地址 sudo dpkg -i baidunetdisk_4.17.7_amd64.deb

高防CDN:保卫您的网站免受攻击之利与弊

在当今数字化时代,网络安全对于网站经营者至关重要。高防CDN(Content Delivery Network)技术旨在提供强大的安全性,以保护网站免受恶意攻击。本文将探讨高防CDN为普通网站带来的优势与不足之处,并分析国内外高防CDN的发…

谷歌真的不喜欢 Node.js ?

有人在 Quora 上提问,为什么谷歌不喜欢 Node.js 呢,Google 的 UX 工程师和来自 Node.js 团队的开发者分别回答了他们对这个问题的看法,对于编程语言来说,每一门语言都有它自己的优势,重要的是如何用它去解决问题。 谷…

SpringBoot集成Redis主从架构实现读写分离(哨兵模式)

一、前言 这里会使用到spring-boot-starter-data-redis包,spring boot 2的spring-boot-starter-data-redis中,默认使用的是lettuce作为redis客户端,也推荐使用lettuce,Redis使用哨兵集群,这里会通过lettuce连接到哨兵…

Flume基本使用--mysql数据输出

MySQL数据输出 在MySQL中建立数据库school,在数据库中建立表student。SQL语句如下: create database school; use school; create table student(id int not null,name varchar(40),age int,grade int,primary key(id) ); 请使用Flume实时捕…

1. 两数之和、Leetcode的Python实现

博客主页:🏆看看是李XX还是李歘歘 🏆 🌺每天分享一些包括但不限于计算机基础、算法等相关的知识点🌺 💗点关注不迷路,总有一些📖知识点📖是你想要的💗 ⛽️今…

【Elasticsearch】es脚本编程使用详解

目录 一、es脚本语言介绍 1.1 什么是es脚本 1.2 es脚本支持的语言 1.3 es脚本语言特点 1.4 es脚本使用场景 二、环境准备 2.1 docker搭建es过程 2.1.1 拉取es镜像 2.1.2 启动容器 2.1.3 配置es参数 2.1.4 重启es容器并访问 2.2 docker搭建kibana过程 2.2.1 拉取ki…

Kafka - 深入了解Kafka基础架构:Kafka的基本概念

文章目录 Kafka的基本概念 Kafka的基本概念 我们首先了解一些Kafka的基本概念。 1)Producer :消息生产者,就是向kafka broker发消息的客户端2)Consumer :消息消费者,向kafka broker获取消息的客户端3&…

【代码随想录】算法训练计划03

1、203. 移除链表元素 题目: 给你一个链表的头节点 head 和一个整数 val ,请你删除链表中所有满足 Node.val val 的节点,并返回 新的头节点 。 输入:head [1,2,6,3,4,5,6], val 6 输出:[1,2,3,4,5] 思路&#xf…

hdlbits系列verilog解答(模块按名字)-22

文章目录 一、问题描述二、verilog源码三、仿真结果 一、问题描述 此问题类似于模块。您将获得一个名为的 mod_a 模块,该模块按某种顺序具有 2 个输出和 4 个输入。您必须按名称将 6 个端口连接到顶级模块的端口: module mod_a ( output out1, output …

DVWA-Cross Site Request Forgery (CSRF)

大部分网站都会要求用户登录后,使用相应的权限在网页中进行操作,比如发邮件、购物或者转账等都是基于特定用户权限的操作。浏览器会短期或长期地记住用户的登录信息,但是,如果这个登录信息被恶意利用呢?就有可能发生CSRF CSRF的英文全称为Cross Site Request Forgery,中文…