pytorch迁移学习训练图像分类

pytorch迁移学习训练图像分类

  • 一、环境配置
  • 二、迁移学习关键代码
  • 三、完整代码
  • 四、结果对比

代码和图片等资源均来源于哔哩哔哩up主:同济子豪兄
讲解视频:Pytorch迁移学习训练自己的图像分类模型

一、环境配置

1,安装所需的包

pip install numpy pandas matplotlib seaborn plotly requests tqdm opencv-python pillow wandb -i https://pypi.tuna.tsinghua.edu.cn/simple

2,安装Pytorch

pip3 install torch torchvision torchaudio --extra-index-url https://download.pytorch.org/whl/cu113

3,创建目录

import os
# 存放训练得到的模型权重
os.mkdir('checkpoint')

4,下载数据集压缩包(下载之后需要解压数据集)

wget https://zihao-openmmlab.obs.cn-east-3.myhuaweicloud.com/20220716-mmclassification/dataset/fruit30/fruit30_split.zip

二、迁移学习关键代码

以下是迁移学习的三种选择,根据训练的需求选择不同的迁移方法:

  • 选择一:只微调训练模型最后一层(全连接分类层)
model = models.resnet18(pretrained=True) # 载入预训练模型
# 修改全连接层,使得全连接层的输出与 当前数据集类别数n_class 对应
model.fc = nn.Linear(model.fc.in_features, n_class)
# 只微调训练最后一层全连接层的参数,其它层冻结
optimizer = optim.Adam(model.fc.parameters())
  • 选择二:微调训练所有层。

适用于训练数据集与预训练模型相差大时,可以选择微调训练所有层,此时只使用预训练模型的部分权重和特征,例如原始模型为imageNet,而训练数据为医疗相关

model = models.resnet18(pretrained=True) # 载入预训练模型
model.fc = nn.Linear(model.fc.in_features, n_class)
optimizer = optim.Adam(model.parameters())
  • 选择三:随机初始化模型全部权重,从头训练所有层
model = models.resnet18(pretrained=False) # 只载入模型结构,不载入预训练权重参数
model.fc = nn.Linear(model.fc.in_features, n_class)
optimizer = optim.Adam(model.parameters())

三、完整代码

import time
import osimport numpy as np
from tqdm import tqdmimport torch
import torchvision
import torch.nn as nn# 忽略出现的红色提示
import warnings
warnings.filterwarnings("ignore")# 有 GPU 就用 GPU,没有就用 CPU
device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
print('device', device)from torchvision import transforms# 训练集图像预处理:缩放裁剪、图像增强、转 Tensor、归一化
train_transform = transforms.Compose([transforms.RandomResizedCrop(224),transforms.RandomHorizontalFlip(),transforms.ToTensor(),transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])])# 测试集图像预处理-RCTN:缩放、裁剪、转 Tensor、归一化
test_transform = transforms.Compose([transforms.Resize(256),transforms.CenterCrop(224),transforms.ToTensor(),transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])])# 数据集文件夹路径
dataset_dir = 'fruit30_split'
train_path = os.path.join(dataset_dir, 'train')	# 测试集路径
test_path = os.path.join(dataset_dir, 'val')	# 测试集路径from torchvision import datasets# 载入训练集
train_dataset = datasets.ImageFolder(train_path, train_transform)# 载入测试集
test_dataset = datasets.ImageFolder(test_path, test_transform)# 各类别名称
class_names = train_dataset.classes
n_class = len(class_names)# 定义数据加载器DataLoader
from torch.utils.data import DataLoaderBATCH_SIZE = 32# 训练集的数据加载器
train_loader = DataLoader(train_dataset,batch_size=BATCH_SIZE,shuffle=True,num_workers=4)# 测试集的数据加载器
test_loader = DataLoader(test_dataset,batch_size=BATCH_SIZE,shuffle=False,num_workers=4)from torchvision import models
import torch.optim as optim# 选择一:只微调训练模型最后一层(全连接分类层)
model = models.resnet18(pretrained=True) # 载入预训练模型
# 修改全连接层,使得全连接层的输出与当前数据集类别数对应
# 新建的层默认 requires_grad=True,指定张量需要梯度计算
model.fc = nn.Linear(model.fc.in_features, n_class)
model.fc	# 查看全连接层
# 只微调训练最后一层全连接层的参数,其它层冻结
optimizer = optim.Adam(model.fc.parameters())    # optim 是 PyTorch 的一个优化器模块,用于实现各种梯度下降算法的优化方法# 选择二:微调训练所有层
# 训练数据集与预训练模型相差大时,可以选择微调训练所有层,只使用预训练模型的部分权重和特征,例如原始模型为imageNet,训练数据为医疗相关
# model = models.resnet18(pretrained=True) # 载入预训练模型
# model.fc = nn.Linear(model.fc.in_features, n_class)
# optimizer = optim.Adam(model.parameters())# 选择三:随机初始化模型全部权重,从头训练所有层
# model = models.resnet18(pretrained=False) # 只载入模型结构,不载入预训练权重参数
# model.fc = nn.Linear(model.fc.in_features, n_class)
# optimizer = optim.Adam(model.parameters())# 训练配置
model = model.to(device)# 交叉熵损失函数
criterion = nn.CrossEntropyLoss()# 训练轮次 Epoch
EPOCHS = 30# 遍历每个 EPOCH
for epoch in tqdm(range(EPOCHS)):model.train()for images, labels in train_loader:  # 获取训练集的一个 batch,包含数据和标注images = images.to(device)labels = labels.to(device)outputs = model(images)           # 前向预测,获得当前 batch 的预测结果loss = criterion(outputs, labels) # 比较预测结果和标注,计算当前 batch 的交叉熵损失函数optimizer.zero_grad()loss.backward()                   # 损失函数对神经网络权重反向传播求梯度optimizer.step()                  # 优化更新神经网络权重# 测试集上初步测试
model.eval()
with torch.no_grad():correct = 0total = 0for images, labels in tqdm(test_loader): # 获取测试集的一个 batch,包含数据和标注images = images.to(device)labels = labels.to(device)outputs = model(images)              # 前向预测,获得当前 batch 的预测置信度_, preds = torch.max(outputs, 1)     # 获得最大置信度对应的类别,作为预测结果total += labels.size(0)correct += (preds == labels).sum()   # 预测正确样本个数print('测试集上的准确率为 {:.3f} %'.format(100 * correct / total))# 保存模型
torch.save(model, 'checkpoint/fruit30_pytorch_A1.pth') # 选择一:微调全连接层
# torch.save(model, 'checkpoint/fruit30_pytorch_A2.pth') # 选择二:微调所有层
# torch.save(model, 'checkpoint/fruit30_pytorch_A3.pth') # 选择三:随机权重

四、结果对比

调用不同迁移学习得到的模型对比测试集准确率

# 测试集导入和图像预处理等代码和上述完整代码中一致,此处省略……# 调用自己训练的模型
model = torch.load('checkpoint/fruit30_pytorch_A1.pth')# 测试集上进行测试
model.eval()
with torch.no_grad():correct = 0total = 0for images, labels in tqdm(test_loader): # 获取测试集的一个 batch,包含数据和标注images = images.to(device)labels = labels.to(device)outputs = model(images)              # 前向预测,获得当前 batch 的预测置信度_, preds = torch.max(outputs, 1)     # 获得最大置信度对应的类别,作为预测结果total += labels.size(0)correct += (preds == labels).sum()   # 预测正确样本个数print('测试集上的准确率为 {:.3f} %'.format(100 * correct / total))

结果如下:
对于微调全连接层的选择一,测试集准确率为 72.078%
在这里插入图片描述
而所有权重随机的选择三测试集准确率为 43.228%
43.228

总体而言,迁移学习能够利用已有的知识和经验,加速模型的训练过程,提高模型的性能。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.rhkb.cn/news/141634.html

如若内容造成侵权/违法违规/事实不符,请联系长河编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

【深度学习实验】卷积神经网络(一):卷积运算及其Pytorch实现(一维卷积:窄卷积、宽卷积、等宽卷积;二维卷积)

目录 一、实验介绍 二、实验环境 1. 配置虚拟环境 2. 库版本介绍 三、实验内容 1. 一维卷积 a. 概念 b. 示例 c. 分类 窄卷积(Narrow Convolution) 宽卷积(Wide Convolution) 等宽卷积(Same Convolution&am…

Python开发与应用实验2 | Python基础语法应用

*本文是博主对学校专业课Python各种实验的再整理与详解,除了代码部分和解析部分,一些题目还增加了拓展部分(⭐)。拓展部分不是实验报告中原有的内容,而是博主本人自己的补充,以方便大家额外学习、参考。 &a…

Python3 如何实现 websocket 服务?

Python 实现 websocket 服务很简单,有很多的三方包可以用,我从网上大概找到三种常用的包:websocket、websockets、Flask-Sockets。 但这些包很多都“年久失修”, 比如 websocket 在 2010 年就不维护了。 而 Flask-Sockets 也在 2…

通信协议:Uart的Verilog实现(上)

1、前言 调制解调器是主机/设备与串行数据通路之间的接口,以串行单比特格式发送和接收数据。它也被称为通用异步收发器(Uart, Universal Asynchronous Receiver/Transmitter),这表明该设备能够接收和发送数据,并且发送和接收单元不同步。 本节…

递归算法讲解,深度理解递归

首先最重要的就是要说明递归思想的作用,在后面学习的高级数据接口,树和图中,都需要用到递归,即深度优先搜索,如果递归掌握的不好,后面的数据结构将举步为艰。 加油 首先看下如何下面两个方法有什么区别&a…

git revert 撤销之前的提交

git revert 用来撤销之前的提交,它会生成一个新的 commit id 。 输入 git revert --help 可以看到帮忙信息。 git revert commitID 不编辑新的 commit 说明 git log 找到需要撤销的 commitID , 然后执行 git revert commitID ,会提示如下…

Allegro如何将丝印文字Change到任意层面操作指导

Allegro如何将丝印文字Change到任意层面操作指导 在用Allegro进行PCB设计的时候,有时需要将丝印文字change到其它层面,如下图 可以看到丝印文字是属于REFDES这个Class的 如果需要把丝印文字change层面,只支持REFDES中以下的层面中来change

入门级制作电子期刊的网站推荐

随着数字化时代的到来,越来越多的人开始尝试制作自己的电子期刊。如果你也是其中的一员,那么这篇文章可以帮助你制作电子期刊。无论是初学者还是有一定经验的制作者,都能快速完成高质量的电子期刊制作 小编经常使用的工具是-----FLBOOK在线制…

95 # express 二级路由的实现

上一节实现了兼容老的路由写法,这一节来实现二级路由 二级路由实现核心: 进入中间件后,让对应的路由系统去进行匹配操作中间件进去匹配需要删除 path,存起来出去时在加上 示意图: 代码实现如下: rout…

ASCII码-对照表

ASCII 1> ASCII 控制字符2> ASCII 显示字符3> 常用ASCII码3.1> 【CR】\r 回车符3.2> 【LF】\n 换行符3.3> 不同操作系统,文件中换行 1> ASCII 控制字符 2> ASCII 显示字符 3> 常用ASCII码 3.1> 【CR】‘\r’ 回车符 CR Carriage Re…

安装OpenSearch

title: “安装opensearch” createTime: 2021-11-30T19:13:4508:00 updateTime: 2021-11-30T19:13:4508:00 draft: false author: “name” tags: [“es”,“安装”] categories: [“OpenSearch”] description: “测试的” 说明 基于Elasticsearch7.10.2 的 opensearch-1.1.…

arcgis搭建离线地图服务WMTS

Arcgis搭建离线地图服务WMTS 发布时间:2021-03-04 版权: ARCGIS搭建离线地图服务器,进行离线地图二次开发 2. 离线地图服务发布(WMTS服务) (详细教程:卫星地图_高清卫星地图_地图编辑_离线地…

商品秒杀系统整理

1、使用redis缓存商品信息 2、互斥锁解决缓存击穿问题,用缓存空值解决缓存穿透问题。 3、CAS乐观锁解决秒杀超卖的问题 4、使用redission实现一人一单。(分布式锁lua)脚本。 5、使用lua脚本进行秒杀资格判断(将库存和用户下单…

jupyter notebook进不去指定目录怎么办?

首先激活你要使用的虚拟环境 刚开始是现在 (base) C:\Users\lenovo>目录下 直接输入你想进入的盘 (base) C:\Users\lenovo>e:此时再cd (base) C:\Users\lenovo>cd E:\tim\learn_pytorch 就可以进入了 安装3.4.1.15问题 已经有了最新python版本的虚拟环境&#…

【实战项目之个人博客】

目录 项目背景 项目技术栈 项目介绍 项目亮点 项目启动 1.创建SSM(省略) 2.配置项目信息 3.将前端页面加入到项目中 4.初始化数据库 5.创建标准分层的目录 6.创建和编写项目中的公共代码以及常用配置 7.创建和编写业务的Entity、Mapper、…

GE IS220PVIBH1A 336A4940CSP16 电源模块

GE IS220PVIBH1A 336A4940CSP16 电源模块是通用电气(GE)的一种电源模块,用于工业控制和电力系统中,提供电源供应和保护功能。以下是这种类型电源模块的一般特点和功能: 电源供应:GE IS220PVIBH1A 336A4940C…

电脑WIFI突然消失

文章目录 1. 现象2. 解决办法1:重新启用无线网卡设置3. 解决办法2:更新无线网卡驱动4. 解决办法3:释放静电5. 解决办法4:拆机并重新插拔无线网卡 1. 现象 如下图:电脑在使用过程中WIFI消失 设备管理器中的无线网卡驱…

数据通信——应用层(域名系统)

引言 TCP到此就告一段落,这也意味着传输层结束了,紧随其后的就是TCP/IP五层架构的应用层。操作系统、编程语言、用户的可视化界面等等都要通过应用层来体现。应用层和我们息息相关,我们使用电子设备娱乐或办公时,接触到的就是应用…

19.组合模式(Composite)

意图:将对象组成树状结构以表示“部分-整体”的层次结构,使得Client对单个对象和组合对象的使用具有一致性。 上下文:在树型结构的问题中,Client必须以不同的方式处理单个对象和组合对象。能否提供一种封装&#xff0c…

基于jquery开发的Windows 12网页版

预览 https://win12.gitapp.cn 首页代码 <!DOCTYPE html> <html lang"en"> <head><meta charset"UTF-8"><meta http-equiv"refresh" content"0;urldesktop.html" /> <meta name"viewport&…