【实验13】使用预训练ResNet18进行CIFAR10分类

目录

1 数据处理

1.1 数据集介绍

1.2数据处理与划分

 2 模型构建- Pytorch高层API中的Resnet18

3 模型训练

4 模型评价

5 比较“使用预训练模型”和“不使用预训练模型”的效果:

6 模型预测

7 完整代码

8 参考链接


1 数据处理

1.1 数据集介绍

  •  数据规模: CIFAR10数据集共有60000个样本,每个样本都是一张32*32像素的RGB图像(彩色图像)。
  • 数据集划分:60000个样本被分成了50000个训练样本和10000个测试样本
  • 类别内容:    CIFAR10中有10类物体,标签值分别按照0~9来区分,他们分别是飞机( airplane )、汽车( automobile )、鸟( bird )、猫( cat )、鹿( deer )、狗( dog )、青蛙( frog )、马( horse )、船( ship )和卡车( truck )

           

  • 数据来源:是从一个叫做【the 80 million tiny images dataset】(“8000 万张小图” 数据集)中精炼剥离出来的一部分,由 Hinton 的学生 Alex Krizhevsky 和 Ilya Sutskever 整理 。

1.2数据处理与划分

        本实验中,对数据集进行归一化处理后,将原始训练集拆分成了train_set、dev_set两个部分,分别包括40 000条和10 000条样本。

# ==================数据处理================
transforms = transforms.Compose([transforms.Resize((32,32)), # 重新设置图片尺寸transforms.ToTensor(), # 转换为tensor格式transforms.Normalize([0.485,0.456,0.406],[0.229,0.224,0.225])] # 归一化
)trainset = torchvision.datasets.CIFAR10(root='./cifar10', train=True, download=False, transform=transforms)
testset = torchvision.datasets.CIFAR10(root='./cifar10', train=False, download=False, transform=transforms)classes = ('plane', 'car', 'bird', 'cat','deer', 'dog', 'frog', 'horse', 'ship', 'truck')# 数据划分
train_size = int(0.8 * len(trainset))  # 80%的数据作为训练集
dev_size = len(trainset) - train_size  # 剩余的20%作为验证集
test_size =len(testset)
print(f"train_size: {train_size},dev_size :{dev_size},test_szie :{test_size}")
# 随机划分训练集和验证集
train_data, dev_data = random_split(trainset, [train_size, dev_size])

运行结果: 

torch.Size([3, 32, 32])
train_size: 40000,dev_size :10000,test_szie :10000

可视化观察其中的一张样本图像和对应的标签:

# 可视化第一张图像
image, label = trainset[0]
print(image.size())
image, label = np.array(image), int(label)
plt.imshow(image.transpose(1, 2, 0))
plt.show()
print(classes[label])

frog

 2 模型构建- Pytorch高层API中的Resnet18

        Pytorch 提供 torchvision.models 接口,里面包含了一些常用用的网络结构,并提供了预训练模型。预训练模型可以通过设置pretrained=True来构建。只需要网络结构,不加载参数来初始化,可以将pretrained = False

什么是“预训练模型”?

        预训练模型是指在大规模数据集上预先进行训练好的神经网络模型,通常在通用任务上学习到的特征可以被迁移到其他特定任务中。预训练模型的思想是利用大规模数据的信息来初始化模型参数,然后通过微调或迁移学习,将模型适应在特定的目标任务上。即在训练结束时结果比较好的一组权重值,研究人员分享出来供其他人使用。

        “预训练模型”可以理解为一个“已经学过一部分知识”的模型。举个例子,如果学习英语,先会通过一段时间学习基础的语法和词汇,这段时间就像是模型的“预训练”。然后,在这个基础上,你可能会学习更具体的内容,比如写作文、翻译等。这时,你可以用预训练的知识来加速你的学习过程。(在“GPT”中,P代表的是“Pre-trained”(预训练)的意思)

  • 预训练模型期望的输入是RGB图像的mini-batch:(batch_size, 3, H, W),并且H和W不能低于224。
  • 图像的像素值必须在范围[0,1]间,并且用均值mean=[0.485, 0.456, 0.406]和方差std=[0.229, 0.224, 0.225]进行归一化。

什么是“迁移学习”?

        迁移学习(Transfer Learning)通俗来讲就是学会举一反三的能力,通过运用已有的知识来学习新的知识,其核心是找到已有知识和新知识之间的相似性,通过这种相似性的迁移达到迁移学习的目的。

        迁移学习就像是学会了弹钢琴后,去学电子琴,不需要从头开始学,因为已经掌握了很多钢琴的技巧和知识,这些可以直接迁移到电子琴的学习上。同样,模型从某个任务中学到的知识,可以直接用来帮助解决另一个任务。

        过程:选择预训练模型--->冻结预训练模型参数--->在新数据集上训练新增加的层--->微调预训练模型的层--->评估和测试。

        【冻结参数:设置参数的requires_grad属性为False,表示在训练过程中这些参数不需要计算梯度】

使用预训练模型

# ======================模型构建=====================
resnet18_model = resnet18(pretrained=True)
#resnet18_model = resnet18(pretrained=False)

不使用预训练模型

# ======================模型构建=====================
#resnet18_model = resnet18(pretrained=True)
resnet18_model = resnet18(pretrained=False)

3 模型训练

        Adam优化器综合了Momentum的更新方向策略和RMSProp 的计算衰减系数策略。

# ======================模型训练======================
import torch.nn.functional as F
import torch.optim as opt
from Runner import RunnerV3,Accuracy,plot# 指定运行设备
torch.cuda.set_device('cuda:0')
# 学习率大小
lr = 0.001
# 批次大小
batch_size = 64
# 创建 DataLoader
train_loader = DataLoader(train_data, batch_size=batch_size, shuffle=True)
test_loader = DataLoader(testset, batch_size=batch_size)
dev_loader = DataLoader(dev_data, batch_size=batch_size)
# 定义网络
model = resnet18_model
# 定义优化器,这里使用Adam优化器以及l2正则化策略,相关内容在7.3.3.2和7.6.2中会进行详细介绍
optimizer = opt.Adam(lr=lr, params=model.parameters(), weight_decay=0.005)
# 定义损失函数
loss_fn = F.cross_entropy
# 定义评价指标
metric = Accuracy()
# 实例化RunnerV3
runner = RunnerV3(model, optimizer, loss_fn, metric)
# 启动训练
log_steps = 3000
eval_steps = 3000
runner.train(train_loader, dev_loader, num_epochs=30, log_steps=log_steps,eval_steps=eval_steps, save_path="best_model.pdparams")
# 加载最优模型
runner.load_model('best_model.pdparams')
plot(runner, fig_name='cnn-loss4.pdf')

使用预训练模型

[Train] epoch: 0/30, step: 0/18750, loss: 13.31647
[Train] epoch: 4/30, step: 3000/18750, loss: 0.84583
[Evaluate]  dev score: 0.66950, dev loss: 0.95410
[Evaluate] best accuracy performance has been updated: 0.00000 --> 0.66950
[Train] epoch: 9/30, step: 6000/18750, loss: 0.61820
[Evaluate]  dev score: 0.73190, dev loss: 0.79586
[Evaluate] best accuracy performance has been updated: 0.66950 --> 0.73190
[Train] epoch: 14/30, step: 9000/18750, loss: 0.64871
[Evaluate]  dev score: 0.74370, dev loss: 0.77339
[Evaluate] best accuracy performance has been updated: 0.73190 --> 0.74370
[Train] epoch: 19/30, step: 12000/18750, loss: 0.59597
[Evaluate]  dev score: 0.72890, dev loss: 0.83410
[Train] epoch: 24/30, step: 15000/18750, loss: 0.29033
[Evaluate]  dev score: 0.75560, dev loss: 0.74253
[Evaluate] best accuracy performance has been updated: 0.74370 --> 0.75560
[Train] epoch: 28/30, step: 18000/18750, loss: 0.80589
[Evaluate]  dev score: 0.74720, dev loss: 0.77720
[Evaluate]  dev score: 0.75970, dev loss: 0.73110
[Evaluate] best accuracy performance has been updated: 0.75560 --> 0.75970
[Train] Training done!

        验证集上的准确率(dev score)随着训练轮次的推进呈现出不断上升的趋势,从最初的0.00000逐渐提升到最后的0.75970 。

 不使用预训练模型 

[Train] epoch: 0/30, step: 0/18750, loss: 7.30914
[Train] epoch: 4/30, step: 3000/18750, loss: 0.87599
[Evaluate]  dev score: 0.66180, dev loss: 0.99899
[Evaluate] best accuracy performance has been updated: 0.00000 --> 0.66180
[Train] epoch: 9/30, step: 6000/18750, loss: 0.83667
[Evaluate]  dev score: 0.68640, dev loss: 0.93156
[Evaluate] best accuracy performance has been updated: 0.66180 --> 0.68640
[Train] epoch: 14/30, step: 9000/18750, loss: 0.73321
[Evaluate]  dev score: 0.71090, dev loss: 0.87719
[Evaluate] best accuracy performance has been updated: 0.68640 --> 0.71090
[Train] epoch: 19/30, step: 12000/18750, loss: 0.81280
[Evaluate]  dev score: 0.71270, dev loss: 0.89412
[Evaluate] best accuracy performance has been updated: 0.71090 --> 0.71270
[Train] epoch: 24/30, step: 15000/18750, loss: 0.36561
[Evaluate]  dev score: 0.71520, dev loss: 0.86391
[Evaluate] best accuracy performance has been updated: 0.71270 --> 0.71520
[Train] epoch: 28/30, step: 18000/18750, loss: 0.31370
[Evaluate]  dev score: 0.72560, dev loss: 0.84174
[Evaluate] best accuracy performance has been updated: 0.71520 --> 0.72560
[Evaluate]  dev score: 0.72780, dev loss: 0.82727
[Evaluate] best accuracy performance has been updated: 0.72560 --> 0.72780
[Train] Training done!

         验证集上的准确率(dev score)随着训练轮次的推进呈现出逐步上升的趋势,从最初的0.00000逐渐提升到最后的0.72780

4 模型评价

# ======================模型评价=====================
score, loss = runner.evaluate(test_loader)
print("[Test] accuracy/loss: {:.4f}/{:.4f}".format(score, loss))

使用预训练模型

[Test] accuracy/loss: 0.7488/0.7621

不使用预训练模型 

[Test] accuracy/loss: 0.7320/0.8259

5 比较“使用预训练模型”和“不使用预训练模型”的效果:

        (1)训练初期,由于预训练模型已经学习到了一些通用特征,模型能较快地适应新任务的数据分布,损失函数往往下降得相对较快。而不使用预训练模型,需要从零开始学习数据中的所有特征,所以损失函数下降速度通常相对较慢。【在未使用预训练模型的训练结果中,epoch: 0/30, step: 0/18750, loss: 7.30914到epoch: 9/30, step: 6000/18750, loss: 0.83667,而t同样轮数使用预训练模型的从epoch: 0/30, step: 0/18750, loss: 13.31647到epoch: 9/30, step: 6000/18750, loss: 0.61820。观察前期损失变化图像也能发现使用预训练模型的斜率比未使用的斜率要大】

        (2)由最终在测试集上的评价可以看出,使用预训练模型的泛化能力要好一点。(但这个实验中好像没有太大的差别)

        (3)使用预训练模型的训练时间相对较短,对计算资源的需求也相对较少。因为主要是进行微调操作,只需在预训练模型的基础上,根据新任务的数据对部分参数进行调整,不需要从头开始训练整个模型架构。微调过程相比于从头训练一个复杂的模型,计算量大幅降低,不需要长时间占用大量的 GPU 等计算资源。

  

6 模型预测

#=========================模型预测===================# 获取测试集中的一个batch的数据
for X, label in test_loader:logits = runner.predict(X)# 多分类,使用softmax计算预测概率pred = F.softmax(logits)# 获取概率最大的类别pred_class = torch.argmax(pred[2]).cpu().numpy()label = label[2].data.numpy()# 输出真实类别与预测类别print("The true category is {} and the predicted category is {}".format(classes[label], classes[pred_class]))# 可视化图片X = np.array(X)X = X[1]plt.imshow(X.transpose(1, 2, 0))plt.show()break

The true category is ship and the predicted category is ship

可见模型成功预测了。 

7 完整代码

'''
@Function: 使用预训练resnet18(调用API)实现CIFAR-10分类
@Author: lxy
@date: 2024/11/27
'''
import torch
from torchvision.transforms import transforms
import torchvision
from torch.utils.data import DataLoader,random_split
import numpy as np
from torchvision.models import resnet18
import matplotlib.pyplot as plt# ==================数据处理================
transforms = transforms.Compose([transforms.Resize((32,32)), # 重新设置图片尺寸transforms.ToTensor(), # 转换为tensor格式transforms.Normalize([0.485,0.456,0.406],[0.229,0.224,0.225])] # 归一化
)trainset = torchvision.datasets.CIFAR10(root='./cifar10', train=True, download=False, transform=transforms)
testset = torchvision.datasets.CIFAR10(root='./cifar10', train=False, download=False, transform=transforms)classes = ('plane', 'car', 'bird', 'cat','deer', 'dog', 'frog', 'horse', 'ship', 'truck')
# 可视化第一张图像
image, label = trainset[0]
print(image.size())
image, label = np.array(image), int(label)
plt.imshow(image.transpose(1, 2, 0))
plt.show()
print(classes[label])# ======================模型构建=====================
resnet18_model = resnet18(pretrained=True)
#resnet18_model = resnet18(pretrained=False)# ======================模型训练======================
import torch.nn.functional as F
import torch.optim as opt
from Runner import RunnerV3,Accuracy,plot# 指定运行设备
torch.cuda.set_device('cuda:0')
# 学习率大小
lr = 0.001
# 批次大小
batch_size = 64
# 加载数据
train_size = int(0.8 * len(trainset))  # 80%的数据作为训练集
dev_size = len(trainset) - train_size  # 剩余的20%作为验证集
test_size =len(testset)
print(f"train_size: {train_size},dev_size :{dev_size},test_szie :{test_size}")
# 随机划分训练集和验证集
train_data, dev_data = random_split(trainset, [train_size, dev_size])
# 创建 DataLoader
train_loader = DataLoader(train_data, batch_size=batch_size, shuffle=True)
test_loader = DataLoader(testset, batch_size=batch_size)
dev_loader = DataLoader(dev_data, batch_size=batch_size)
# 定义网络
model = resnet18_model
# 定义优化器,这里使用Adam优化器以及l2正则化策略,相关内容在7.3.3.2和7.6.2中会进行详细介绍
optimizer = opt.Adam(lr=lr, params=model.parameters(), weight_decay=0.005)
# 定义损失函数
loss_fn = F.cross_entropy
# 定义评价指标
metric = Accuracy()
# 实例化RunnerV3
runner = RunnerV3(model, optimizer, loss_fn, metric)
# 启动训练
log_steps = 3000
eval_steps = 3000
runner.train(train_loader, dev_loader, num_epochs=30, log_steps=log_steps,eval_steps=eval_steps, save_path="best_model.pdparams")
# 加载最优模型
runner.load_model('best_model.pdparams')
plot(runner, fig_name='cnn-loss4.pdf')# ======================模型评价=====================
score, loss = runner.evaluate(test_loader)
print("[Test] accuracy/loss: {:.4f}/{:.4f}".format(score, loss))
#=========================模型预测===================# 获取测试集中的一个batch的数据
for X, label in test_loader:logits = runner.predict(X)# 多分类,使用softmax计算预测概率pred = F.softmax(logits)# 获取概率最大的类别pred_class = torch.argmax(pred[2]).cpu().numpy()label = label[2].data.numpy()# 输出真实类别与预测类别print("The true category is {} and the predicted category is {}".format(classes[label], classes[pred_class]))# 可视化图片X = np.array(X)X = X[1]plt.imshow(X.transpose(1, 2, 0))plt.show()break

8 参考链接

参考链接:

CIFAR-10数据集简介_cifar10-CSDN博客

测试集可用作验证集

迁移学习之——什么是迁移学习(Transfer Learning)

torchvision.models_torchvision models

一文读懂预训练模型(非常详细)

Adam优化器(理论、公式、代码)

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.rhkb.cn/news/482318.html

如若内容造成侵权/违法违规/事实不符,请联系长河编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

Java之链表1

文章目录 1. 链表1.11.2 链表的概念及其结构1.3 自己实现一个链表 1. 链表 1.1 之前我们学习了 顺序表ArrayList,并自己实现了 ArrayList ,发现它在删除元素和添加元素时很麻烦,最坏的情况时,需要将所有的元素移动,因…

二分搜索(三)x的平方根

69. x 的平方根 给定一个排序数组和一个目标值,在数组中找到目标值,并返回其索引。如果目标值不存在于数组中,返回它将会被按顺序插入的位置。 请必须使用时间复杂度为 O(log n) 的算法。 示例 1: 输入: nums [1,3,5,6], target 5 输出: 2…

AI开发-数据可视化库-Seaborn

1 需求 概述 Seaborn 是一个基于 Python 的数据可视化库,它建立在 Matplotlib 之上。其主要目的是使数据可视化更加美观、方便和高效。它提供了高层次的接口和各种美观的默认主题,能够帮助用户快速创建出具有吸引力的统计图表,用于数据分析和…

使用docker-compose部署搜索引擎ElasticSearch6.8.10

背景 Elasticsearch 是一个开源的分布式搜索和分析引擎,基于 Apache Lucene 构建。它被广泛用于实时数据搜索、日志分析、全文检索等应用场景。 Elasticsearch 支持高效的全文搜索,并提供了强大的聚合功能,可以处理大规模的数据集并进行快速…

LeetCode—74. 搜索二维矩阵(中等)

仅供个人学习使用 题目描述: 给你一个满足下述两条属性的 m x n 整数矩阵: 每行中的整数从左到右按非严格递增顺序排列。 每行的第一个整数大于前一行的最后一个整数。 给你一个整数 target ,如果 target 在矩阵中,返回 true…

Cento7 紧急模式无法正常启动,修复home挂载问题

Centos 7 开机失败进入紧急模式[emergency mode],解决方案。 通过journalctl -xb查看启动日志,定位发现/home目录无法正常挂载。 退出启动日志检查,进行修复。 进行问题修复 # 修复挂载问题 mkdir /home mount /dev/mapper/centos-home /ho…

Matlab mex- setup报错—错误使用 mex,未检测到支持的编译器...

错误日志: 在使用mex编译时报错提示:错误使用 mex,未检测到支持的编译器。您可以安装免费提供的 MinGW-w64 C/C 编译器;请参阅安装 MinGW-w64 编译器。有关更多选项,请访问https://www.mathworks.com/support/compile…

【C语言】二叉树(BinaryTree)的创建、3种递归遍历、3种非递归遍历、结点度的实现

代码主要实现了以下功能: 二叉树相关数据结构定义 定义了二叉树节点结构体 BiTNode,包含节点数据值(字符类型)以及指向左右子树的指针。 定义了顺序栈结构体 SqStack,用于存储二叉树节点指针,实现非递归遍历…

Android -- 简易音乐播放器

Android – 简易音乐播放器 播放器功能:* 1. 播放模式:单曲、列表循环、列表随机;* 2. 后台播放(单例模式);* 3. 多位置同步状态回调;处理模块:* 1. 提取文件信息:音频文…

Python语法基础(四)

🌈个人主页:羽晨同学 💫个人格言:“成为自己未来的主人~” 高阶函数之map 高阶函数就是说,A函数作为B函数的参数,B函数就是高阶函数 map:映射 map(func,iterable) 这个是map的基本语法,…

大模型时代的人工智能基础与实践——基于OmniForce的应用开发教程

《大模型时代的人工智能基础与实践——基于 OmniForce 的应用开发教程》由京东探索研究院及京东教育联袂撰写,图文并茂地介绍传统人工智能和新一代人工智能(基于大模型的通用人工智能技术),展示人工智能广阔的应用场景。同时&…

ESP8266 (ESP-01S)烧录固件 和 了解与单片机通信必需的AT指令

ESP8266(ESP-01s)烧录固件 工具: 需要安装的原装出厂固件库: ESP8266 --接线-- VCC 3.3(外接开发板) GND GND(外接开发板) IO0 GND(外接…

【操作文档】mysql分区操作步骤.docx

1、建立分区表 执行 tb_intercept_notice表-重建-添加分区.sql 文件; DROP TABLE IF EXISTS tb_intercept_notice_20241101_new; CREATE TABLE tb_intercept_notice_20241101_new (id char(32) NOT NULL COMMENT id,number varchar(30) NOT NULL COMMENT 号码,cre…

S4 UPA of AA :新资产会计概览

通用并行会计(Universal Parallel Accounting)可以支持每个独立的分类账与其他模块集成,UPA主要是为了支持平行评估、多货币类型、财务合并、多准则财务报告的复杂业务需求 在ML层面UPA允许根据不同的分类账规则对物料进行评估,并…

ScribblePrompt 医学图像分割工具,三种标注方式助力图像处理

ScribblePrompt 的主要目标是简化医学图像的分割过程,这在肿瘤检测、器官轮廓描绘等应用中至关重要。相比依赖大量人工标注数据,该工具允许用户通过少量输入(例如简单的涂鸦或点位)来引导模型优化分割结果。这种方式减少了医学专家…

jdk各个版本介绍

Java Development Kit(JDK)是Java平台的核心组件,它包含了Java编程语言、Java虚拟机(JVM)、Java类库以及用于编译、调试和运行Java应用程序的工具。 JDK 1.0-1.4(经典时代) • JDK 1.0&#xff…

【Python爬虫五十个小案例】爬取猫眼电影Top100

博客主页:小馒头学python 本文专栏: Python爬虫五十个小案例 专栏简介:分享五十个Python爬虫小案例 🐍引言 猫眼电影是国内知名的电影票务与资讯平台,其中Top100榜单是影迷和电影产业观察者关注的重点。通过爬取猫眼电影Top10…

Doge东哥wordpress主题

Doge东哥wordpress主题是一款专为中小型企业设计的WordPress外贸网站模板,它以其现代、专业且用户友好的界面,为企业提供了一个展示产品和服务的理想平台。以下是对该模板的详细描述: 首页设计概览 首页的设计简洁而不失大气,顶…

【力扣】541.反转字符串2

问题描述 思路解析 每当字符达到2*k的时候,判断,同时若剩余字符>k,只对前k个进行判断(这是重点)因为字符串是不可变变量,所以将其转化为字符串数组,最后才将结果重新转变为字符串 字符串->字符数组 …

C++练级计划-> 《IO流》iostream fstream sstream详解

如果是想全部过一遍就看完,如果想具体的了解某一个请点目录。因为有三种流的使用可能内容多 目录 流是什么? CIO流(iostream) io流的注意事项 cin和cout为什么能直接识别出类型和数据 fstream fstream的使用方法&#xff…