【深度学习实验】卷积神经网络(八):使用深度残差神经网络ResNet完成图片多分类任务

目录

一、实验介绍

二、实验环境

1. 配置虚拟环境

2. 库版本介绍

三、实验内容

0. 导入必要的工具包

1. 构建数据集(CIFAR10Dataset)

 a. read_csv_labels()

b. CIFAR10Dataset

 2. 构建模型(FeedForward)

3.整合训练、评估、预测过程(Runner)

4. __main__

5. 代码整合


一、实验介绍

        本实验实现了实现深度残差神经网络ResNet,并基于此完成图像分类任务。

        残差网络(ResNet)是一种深度神经网络架构,用于解决深层网络训练过程中的梯度消失和梯度爆炸问题。通过引入残差连接(residual connection)来构建网络层与层之间的跳跃连接,使得网络可以更好地优化深层结构。

        残差网络的一个重要应用是在图像识别任务中,特别是在深度卷积神经网络(CNN)中。通过使用残差模块,可以构建非常深的网络,例如ResNet,其在ILSVRC 2015图像分类挑战赛中取得了非常出色的成绩。

        在ResNet中,每个残差块由一个或多个卷积层组成,其中包含了跳跃连接。跳跃连接将输入直接添加到残差块的输出中,从而使得网络可以学习残差函数,即残差块只需学习将输入的变化部分映射到输出,而不需要学习完整的映射关系。这种设计有助于减轻梯度消失问题,使得网络可以更深地进行训练。

二、实验环境

        本系列实验使用了PyTorch深度学习框架,相关操作如下:

1. 配置虚拟环境

conda create -n DL python=3.7 
conda activate DL
pip install torch==1.8.1+cu102 torchvision==0.9.1+cu102 torchaudio==0.8.1 -f https://download.pytorch.org/whl/torch_stable.html
conda install matplotlib
 conda install scikit-learn

2. 库版本介绍

软件包本实验版本目前最新版
matplotlib3.5.33.8.0
numpy1.21.61.26.0
python3.7.16
scikit-learn0.22.11.3.0
torch1.8.1+cu1022.0.1
torchaudio0.8.12.0.2
torchvision0.9.1+cu1020.15.2

三、实验内容

0. 导入必要的工具包

import torch
from torch import nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
from torchvision.io import read_image

1. 构建数据集(CIFAR10Dataset)

        CIFAR10数据集共有60000个样本,每个样本都是一张32*32像素的RGB图像(彩色图像),每个RGB图像又必定分为3个通道(R通道、G通道、B通道)。CIFAR10中有10类物体,标签值分别按照0~9来区分,他们分别是飞机( airplane )、汽车( automobile )、鸟( bird )、猫( cat )、鹿( deer )、狗( dog )、青蛙( frog )、马( horse )、船( ship )和卡车( truck )。为减小运行时间,本实验选取其中1000张作为训练集。

数据集链接:

CIFAR-10 and CIFAR-100 datasets (toronto.edu)icon-default.png?t=N7T8http://www.cs.toronto.edu/~kriz/cifar.html

 a. read_csv_labels()

        从CSV文件中读取标签信息并返回一个标签字典。

def read_csv_labels(fname):"""读取fname来给标签字典返回一个文件名"""with open(fname, 'r') as f:# 跳过文件头行(列名)lines = f.readlines()[1:]tokens = [l.rstrip().split(',') for l in lines]return dict(((name, label) for name, label in tokens))
  •  使用open函数打开指定文件名的CSV文件,并将文件对象赋值给变量f。这里使用'r'参数以只读模式打开文件。

  • 使用文件对象的readlines()方法读取文件的所有行,并将结果存储在名为lines的列表中。通过切片操作[1:],跳过了文件的第一行(列名),将剩余的行存储在lines列表中。

  • 列表推导式(list comprehension):对lines列表中的每一行进行处理。对于每一行,使用rstrip()方法去除行末尾的换行符,并使用split(',')方法将行按逗号分割为多个标记。最终,将所有行的标记组成的子列表存储在tokens列表中。

  • 使用字典推导式(dictionary comprehension)将tokens列表中的子列表转换为字典。对于tokens中的每个子列表,将子列表的第一个元素作为键(name),第二个元素作为值(label),最终返回一个包含这些键值对的字典。

b. CIFAR10Dataset

class CIFAR10Dataset(Dataset):def __init__(self, folder_path, fname):self.labels = read_csv_labels(os.path.join(folder_path, fname))self.folder_path = os.path.join(folder_path, 'train')def __len__(self):return len(self.labels)def __getitem__(self, idx):img = read_image(self.folder_path + '/' + str(idx + 1) + '.png')label = self.labels[str(idx + 1)]return img, torch.tensor(int(label))
  • 构造函数:

    • 接受两个参数

      • folder_path表示数据集所在的文件夹路径

      • fname表示包含标签信息的文件名。

    • 调用read_csv_labels函数,传递folder_pathfname作为参数,以读取CSV文件中的标签信息,并将返回的标签字典存储在self.labels变量中。

    • 通过拼接folder_path和字符串'train'来构建数据集的文件夹路径,将结果存储在self.folder_path变量中。

  • def __len__(self)

    • 这是CIFAR10Dataset类的方法,用于返回数据集的长度,即样本的数量。

  • def __getitem__(self, idx): 这是CIFAR10Dataset类的方法,用于根据给定的索引idx获取数据集中的一个样本。它首先根据索引idx构建图像文件的路径,并调用read_image函数来读取图像数据,将结果存储在img变量中。然后,它通过将索引转换为字符串,并使用该字符串作为键来从self.labels字典中获取相应的标签,将结果存储在label变量中。最后,它返回一个元组,包含图像数据和经过torch.tensor转换的标签。

 2. 构建模型(FeedForward)

        参考前文:

【深度学习实验】卷积神经网络(七):实现深度残差神经网络ResNet-CSDN博客icon-default.png?t=N7T8https://blog.csdn.net/m0_63834988/article/details/133705834

3.整合训练、评估、预测过程(Runner)

        参考前文:

【深度学习实验】前馈神经网络(九):整合训练、评估、预测过程(Runner)_QomolangmaH的博客-CSDN博客icon-default.png?t=N7T8https://blog.csdn.net/m0_63834988/article/details/133219448?spm=1001.2014.3001.5501

4. __main__

if __name__ == '__main__':batch_size = 20# 构建训练集train_data = CIFAR10Dataset('cifar10_tiny', 'trainLabels.csv')train_iter = DataLoader(train_data, batch_size=batch_size)# 构建测试集test_data = CIFAR10Dataset('cifar10_tiny', 'trainLabels.csv')test_iter = DataLoader(test_data, batch_size=batch_size)# 模型训练num_classes = 10# 定义模型model = ResNet(num_classes)# 定义损失函数loss_fn = F.cross_entropy# 定义优化器optimizer = torch.optim.SGD(model.parameters(), lr=0.1)runner = Runner(model, optimizer, loss_fn, metric=None)runner.train(train_iter, num_epochs=10, save_path='chapter_5')# 模型预测runner.load_model('chapter_5.pth')x, label = next(iter(test_iter))predict = torch.argmax(runner.predict(x.float()), dim=1)print('predict:', predict)print('  label:', label)

5. 代码整合

# 导入必要的工具包
import torch
from torch import nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
from torchvision.io import read_image#  残差连接, 输入和输出的维度有时是相同的, 有时是不同的, 所以需要 use_1x1conv来判断是否需要
class Residual(nn.Module):def __init__(self, input_channels, num_channels, use_1x1conv=False, strides=1):super().__init__()self.conv1 = nn.Conv2d(input_channels, num_channels, kernel_size=3, padding=1, stride=strides)self.conv2 = nn.Conv2d(num_channels, num_channels, kernel_size=3, padding=1)if use_1x1conv:self.conv3 = nn.Conv2d(input_channels, num_channels, kernel_size=1, stride=strides)else:self.conv3 = None# 批量归一化层,将会在第7章讲到self.bn1 = nn.BatchNorm2d(num_channels)self.bn2 = nn.BatchNorm2d(num_channels)def forward(self, X):Y = F.relu(self.bn1(self.conv1(X)))Y = self.bn2(self.conv2(Y))if self.conv3:X = self.conv3(X)Y += Xreturn F.relu(Y)# 残差网络是由几个不同的残差块组成的
def resnet_block(input_channels, num_channels, num_residuals, first_block=False):blk = []for i in range(num_residuals):if i == 0 and not first_block:blk.append(Residual(input_channels, num_channels,use_1x1conv=True, strides=2))else:blk.append(Residual(num_channels, num_channels))return blkclass ResNet(nn.Module):def __init__(self, num_classes):super().__init__()self.b1 = nn.Sequential(nn.Conv2d(3, 64, kernel_size=7, stride=2, padding=3),nn.BatchNorm2d(64), nn.ReLU(),nn.MaxPool2d(kernel_size=3, stride=2, padding=1))self.b2 = nn.Sequential(*resnet_block(64, 64, 2, first_block=True))self.b3 = nn.Sequential(*resnet_block(64, 128, 2))self.b4 = nn.Sequential(*resnet_block(128, 256, 2))self.b5 = nn.Sequential(*resnet_block(256, 512, 2))self.head = nn.Sequential(nn.AdaptiveAvgPool2d((1, 1)), nn.Flatten(), nn.Linear(512, num_classes))def forward(self, x):net = nn.Sequential(self.b1, self.b2, self.b3, self.b4, self.b5, self.head)return net(x)import osdef read_csv_labels(fname):"""读取fname来给标签字典返回一个文件名"""with open(fname, 'r') as f:# 跳过文件头行(列名)lines = f.readlines()[1:]tokens = [l.rstrip().split(',') for l in lines]return dict(((name, label) for name, label in tokens))class CIFAR10Dataset(Dataset):def __init__(self, folder_path, fname):self.labels = read_csv_labels(os.path.join(folder_path, fname))self.folder_path = os.path.join(folder_path, 'train')def __len__(self):return len(self.labels)def __getitem__(self, idx):img = read_image(self.folder_path + '/' + str(idx + 1) + '.png')label = self.labels[str(idx + 1)]return img, torch.tensor(int(label))class Runner(object):def __init__(self, model, optimizer, loss_fn, metric=None):self.model = modelself.optimizer = optimizerself.loss_fn = loss_fn# 用于计算评价指标self.metric = metric# 记录训练过程中的评价指标变化self.dev_scores = []# 记录训练过程中的损失变化self.train_epoch_losses = []self.dev_losses = []# 记录全局最优评价指标self.best_score = 0# 模型训练阶段def train(self, train_loader, dev_loader=None, **kwargs):# 将模型设置为训练模式,此时模型的参数会被更新self.model.train()num_epochs = kwargs.get('num_epochs', 0)log_steps = kwargs.get('log_steps', 100)save_path = kwargs.get('save_path', 'best_model.pth')eval_steps = kwargs.get('eval_steps', 0)# 运行的step数,不等于epoch数global_step = 0if eval_steps:if dev_loader is None:raise RuntimeError('Error: dev_loader can not be None!')if self.metric is None:raise RuntimeError('Error: Metric can not be None')# 遍历训练的轮数for epoch in range(num_epochs):total_loss = 0# 遍历数据集for step, data in enumerate(train_loader):x, y = datalogits = self.model(x.float())loss = self.loss_fn(logits, y.long())total_loss += lossif step % log_steps == 0:print(f'loss:{loss.item():.5f}')loss.backward()self.optimizer.step()self.optimizer.zero_grad()# 每隔一定轮次进行一次验证,由eval_steps参数控制,可以采用不同的验证判断条件if eval_steps != 0:if (epoch + 1) % eval_steps == 0:dev_score, dev_loss = self.evaluate(dev_loader, global_step=global_step)print(f'[Evalute] dev score:{dev_score:.5f}, dev loss:{dev_loss:.5f}')if dev_score > self.best_score:self.save_model(f'model_{epoch + 1}.pth')print(f'[Evaluate]best accuracy performance has been updated: {self.best_score:.5f}-->{dev_score:.5f}')self.best_score = dev_score# 验证过程结束后,请记住将模型调回训练模式self.model.train()global_step += 1# 保存当前轮次训练损失的累计值train_loss = (total_loss / len(train_loader)).item()self.train_epoch_losses.append((global_step, train_loss))self.save_model(f'{save_path}.pth')print('[Train] Train done')# 模型评价阶段def evaluate(self, dev_loader, **kwargs):assert self.metric is not None# 将模型设置为验证模式,此模式下,模型的参数不会更新self.model.eval()global_step = kwargs.get('global_step', -1)total_loss = 0self.metric.reset()for batch_id, data in enumerate(dev_loader):x, y = datalogits = self.model(x.float())loss = self.loss_fn(logits, y.long()).item()total_loss += lossself.metric.update(logits, y)dev_loss = (total_loss / len(dev_loader))self.dev_losses.append((global_step, dev_loss))dev_score = self.metric.accumulate()self.dev_scores.append(dev_score)return dev_score, dev_loss# 模型预测阶段,def predict(self, x, **kwargs):self.model.eval()logits = self.model(x)return logits# 保存模型的参数def save_model(self, save_path):torch.save(self.model.state_dict(), save_path)# 读取模型的参数def load_model(self, model_path):self.model.load_state_dict(torch.load(model_path, map_location=torch.device('cpu')))if __name__ == '__main__':batch_size = 20# 构建训练集train_data = CIFAR10Dataset('cifar10_tiny', 'trainLabels.csv')train_iter = DataLoader(train_data, batch_size=batch_size)# 构建测试集test_data = CIFAR10Dataset('cifar10_tiny', 'trainLabels.csv')test_iter = DataLoader(test_data, batch_size=batch_size)# 模型训练num_classes = 10# 定义模型model = ResNet(num_classes)# 定义损失函数loss_fn = F.cross_entropy# 定义优化器optimizer = torch.optim.SGD(model.parameters(), lr=0.1)runner = Runner(model, optimizer, loss_fn, metric=None)runner.train(train_iter, num_epochs=15, save_path='chapter_5')# 模型预测runner.load_model('chapter_5.pth')x, label = next(iter(test_iter))predict = torch.argmax(runner.predict(x.float()), dim=1)print('predict:', predict)print('  label:', label)

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.rhkb.cn/news/157206.html

如若内容造成侵权/违法违规/事实不符,请联系长河编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

计算机视觉(Computer Vision, CV)是什么?

什么是计算机视觉 近年来,计算机视觉 (Computer Vision,简称CV) 不断普及,已成为人工智能 (AI) 增长最快的领域之一。计算机视觉致力于使计算机能够识别和理解图像和视频中的物体和人。 计算机视觉应用程序使用来自传感设备、人工智能、机器…

kafka 开启认证授权

前言 1、前面自己写了一篇关于各个环境各个模式的安装的文章,大家可以去看看 kafka各种环境安装(window,linux,docker,k8s),包含KRaft模式 2、使用版本 kafka_2.13-3.4.1 3、kafka验证方式,有两大类如下,文档内容在 kafka官方文档的 第七节…

WAF绕过-信息收集之反爬虫延时代理池 46

老师用的阿里云的服务器,装了宝塔和安全狗, 演示案例 Safedog-默认拦截机制分析绕过-未开CC 没有打开防止流量攻击的安全狗, 而这里,get请求可以直接看到返回结果,而head就不行。 我们就给工具换成get请求 在没有c…

c++视觉处理----绘制直方图,H—S直方图,二维H—S直方图,RGB三色直方图

直方图:cv::calcHist() cv::calcHist() 是 OpenCV 中用于计算直方图的函数。直方图是一种用于可视化图像亮度或颜色分布的工具。这函数通常应用于灰度图像或彩色图像的各个通道。以下是 cv::calcHist() 函数的基本语法和参数: void cv::calcHist(const…

Docker在边缘计算中的崭露头角:探索容器技术如何驱动边缘计算的新浪潮

文章目录 第一部分:边缘计算和Docker容器边缘计算的定义Docker容器的崭露头角1. 可移植性2. 资源隔离3. 自动化部署和伸缩 第二部分:应用案例1. 边缘分析2. 工业自动化3. 远程办公 第三部分:挑战和解决方案1. 网络延迟2. 安全性3. 管理和部署…

树莓派部署.net core控制台程序

1、在自己的电脑上使用VS写一个Net6.0的控制台程序,我假设我就写个Helloworld。 发布项目 使用mobaxterm上传程序 就传三个文件就行 回到在mobaxterm中,进入目录输入:cd consolepublish,运行程序: dotnet ConsoleApp1.dll 输出h…

入行CSGO游戏搬砖项目前,这些问题一定要了解

最近咨询的人也不少,针对大家平时问到的问题,在这里做一个统一汇总和解答。 1、什么是国外steam游戏装备汇率差项目? 通俗易懂的理解就是,从国外steam游戏平台购买装备,再挂到国内网易buff平台上进行售卖。充值汇率差…

MyCat管理及监控

MyCat原理 在 MyCat 中,当执行一条 SQL 语句时, MyCat 需要进行 SQL 解析、分片分析、路由分析、读写分离分析等操作,最终经过一系列的分析决定将当前的SQL 语句到底路由到那几个 ( 或哪一个 ) 节点数据库,数据库将数据执行完毕后…

WSL+vscode配置miniob环境

1.配置WSL Windows Subsystem for Linux入门:安装配置图形界面中文环境vscode wu-kan 2.获取源码 找个位置Git Bash然后拉取代码 git clone https://github.com/oceanbase/miniob.git 3.安装相关依赖 https://gitee.com/liangcha-xyy/source/blob/master/how…

软件设计之工厂方法模式

工厂方法模式指定义一个创建对象的接口,让子类决定实例化哪一个类。 结构关系如下: 可以看到,客户端创建了两个接口,一个AbstractFactory,负责创建产品,一个Product,负责产品的实现。ConcreteF…

如何生成SSH服务器的ed25519公钥SHA256指纹

最近搭建ubuntu服务器,远程登录让确认指纹,研究一番搞懂了,记录一下。 1、putty 第一次登录服务器,出现提示: 让确认服务器指纹是否正确。 其中:箭头指向的 ed25519 :是一种非对称加密的签名方法&#xf…

【智能家居项目】裸机版本——网卡设备接入输入子系统 | 业务子系统 | 整体效果展示

🐱作者:一只大喵咪1201 🐱专栏:《智能家居项目》 🔥格言:你只管努力,剩下的交给时间! 目录 🥞网卡设备接入输入子系统🍔测试 🥞业务子系统&#…

PostGIS导入shp文件报错:dbf file (.dbf) can not be opened.

一、报错 刚开始以为是SRID输入错误,反复尝试SRID的输入,还是报错! 后来看到了这篇博客,解决了!https://blog.csdn.net/Fama_Q/article/details/117381378 二、导致报错的原因 导入的shp文件路径太深,换…

Jackson+Feign反序列化问题排查

概述 本文记录在使用Spring Cloud微服务开发时遇到的一个反序列化问题,RPC/HTTP框架使用的是Feign,JSON序列化反序列化工具是Jackson。 问题 测试环境的ELK告警日志如下: - [43f42bf7] 500 Server Error for HTTP POST "/api/open/d…

洗地机哪款最好用?口碑最好的家用洗地机推荐

洗地机方便快捷的清洁方式,如今融入到我们的日常生活需求中来了,然而,在市面上琳琅满目的洗地机品牌中,究竟哪款洗地机比较好用呢?今天,笔者将向大家推荐四款口碑最好的家用洗地机,让你在挑选时…

Java实现防重复提交,使用自定义注解的方式

目录 1.背景 2.思路 3.实现 创建自定义注解 编写拦截器 4.使用 5.验证 6.总结 1.背景 在进行添加操作时,防止恶意点击,后端进行请求接口的防重复提交 2.思路 通过拦截器搭配自定义注解的方式进行实现,拦截器拦截请求,使…

JS加密/解密之webpack打包代码逆向

Webpack 是一个强大的打包工具,能够将多个文件打包成一个或多个最终的文件。然而,将已经经过打包的代码还原回原始源代码并不是一件直接的事情,因为 webpack 打包的过程通常会对代码进行压缩、混淆和优化,丢失了部分变量名和代码结…

Gralloc ION DMABUF in Camera Display

目录 Background knowledge Introduction ia pa va and memory addressing Memory Addressing Page Frame Management Memory area management DMA IOVA and IOMMU Introduce DMABUF What is DMABUF DMABUF 关键概念 DMABUF APIS –The Exporter DMABUF APIS –The…

Flutter - 波浪动画和lottie动画的使用

demo 地址: https://github.com/iotjin/jh_flutter_demo 代码不定时更新,请前往github查看最新代码 波浪动画三方库wave lottie动画 Lottie 是 Airbnb 开发的一款能够为原生应用添加动画效果的开源工具。具有丰富的动画效果和交互功能。 # 波浪动画 https://pub-web…

高并发下的服务容错

在微服务架构中,我们将业务拆分成一个个的服务,服务与服务之间可以相互调用,但是由于网络 原因或者自身的原因,服务并不能保证服务的100%可用,如果单个服务出现问题,调用这个服务就会 出现网络延迟&#xf…