yolov8实战第四天——yolov8图像分类 ResNet50图像分类(保姆式教程)

yolov8实战第一天——yolov8部署并训练自己的数据集(保姆式教程)_yolov8训练自己的数据集-CSDN博客在前几天,我们使用yolov8进行了部署,并在目标检测方向上进行自己数据集的训练与测试,今天我们训练下yolov8的图像分类,看看效果如何,同时使用resnet50也训练一个分类模型,看看哪个效果好!

图像分类是指将输入的图像自动分类为不同的类别。它是计算机视觉领域的一个重要应用,可以用于人脸识别、物体识别、场景分类等任务。

通常情况下,图像分类的流程如下:

  1. 收集和准备数据集:收集与任务相关的图像数据,并将其打上标签。
  2. 定义模型:选择一种适合于你的任务的深度学习模型,例如卷积神经网络(CNN)。
  3. 训练模型:使用收集到的数据集对模型进行训练,通过反向传播算法来更新模型参数,使其可以根据输入图像进行正确的分类。
  4. 评估模型性能:使用测试集对已经训练好的模型进行评估,比较模型预测结果与真实标签之间的差异,从而评估模型的性能。
  5. 使用模型进行预测:使用已经训练好的模型对新的图像进行分类预测。

在实际应用中,可以使用各种深度学习框架(例如 TensorFlow、PyTorch、Keras 等)来构建图像分类模型,并使用各种数据增强技术(例如旋转、缩放、裁剪等)来增加数据集的多样性和数量。

如果你想学习如何使用深度学习框架来构建图像分类模型,可以参考一些在线教程、书籍或者 MOOC。

一、yolov8图像分类

1.模型选型

下载yolov8分类模型。

分别使用模型进行测试:

yolov8n-cls效果:

yolov8m-cls效果:

总结:n效果不咋地,还是得使用m进行后续训练工作。 

2.数据集准备

皮肤癌检测_数据集-飞桨AI Studio星河社区

同目标检测,还是放在datasets下。

直接改成这个,省去分数据集操作。 

 3.训练

yolo classify train data=./datasets/skin-cancer-detection model=yolov8n-cls.pt epochs=100

测试:

yolo classify predict model=runs/classify/train4/weights/best.pt source='./datasets/skin-cancer-detection/train/nevus'

  

label: 

 pred:

总结:数据集比较小,yolov8效果不太好。

、resnet50图像分类

Resnet50 网络中包含了 49 个卷积层、一个全连接层。如图下图所示,Resnet50网络结构可以分成七个部分,第一部分不包含残差块,主要对输入进行卷积、正则化、激活函数、最大池化的计算。第二、三、四、五部分结构都包含了残差块,图 中的绿色图块不会改变残差块的尺寸,只用于改变残差块的维度。在 Resnet50 网 络 结 构 中 , 残 差 块 都 有 三 层 卷 积 , 那 网 络 总 共 有1+3×(3+4+6+3)=49个卷积层,加上最后的全连接层总共是 50 层,这也是Resnet50 名称的由来。网络的输入为 224×224×3,经过前五部分的卷积计算,输出为 7×7×2048,池化层会将其转化成一个特征向量,最后分类器会对这个特征向量进行计算并输出类别概率。

运行train.py即可。

train.py

import torch
from torchvision import datasets, models, transforms
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader
import timeimport numpy as np
import matplotlib.pyplot as plt
import os
from tqdm import tqdm# 一、建立数据集
# animals-6
#   --train
#       |--dog
#       |--cat
#       ...
#   --valid
#       |--dog
#       |--cat
#       ...
#   --test
#       |--dog
#       |--cat
#       ...
# 我的数据集中 train 中每个类别60张图片,valid 中每个类别 10 张图片,test 中每个类别几张到几十张不等,一共 6 个类别。# 二、数据增强
# 建好的数据集在输入网络之前先进行数据增强,包括随机 resize 裁剪到 256 x 256,随机旋转,随机水平翻转,中心裁剪到 224 x 224,转化成 Tensor,正规化等。
image_transforms = {'train': transforms.Compose([transforms.RandomResizedCrop(size=256, scale=(0.8, 1.0)),transforms.RandomRotation(degrees=15),transforms.RandomHorizontalFlip(),transforms.CenterCrop(size=224),transforms.ToTensor(),transforms.Normalize([0.485, 0.456, 0.406],[0.229, 0.224, 0.225])]),'valid': transforms.Compose([transforms.Resize(size=256),transforms.CenterCrop(size=224),transforms.ToTensor(),transforms.Normalize([0.485, 0.456, 0.406],[0.229, 0.224, 0.225])])
}# 三、加载数据
# torchvision.transforms包DataLoader是 Pytorch 重要的特性,它们使得数据增加和加载数据变得非常简单。
# 使用 DataLoader 加载数据的时候就会将之前定义的数据 transform 就会应用的数据上了。
dataset = 'skin-cancer-detection'
train_directory = './skin-cancer-detection/train'
valid_directory = './skin-cancer-detection/val'batch_size = 32
num_classes = 9 #分类种类数
print(train_directory)
data = {'train': datasets.ImageFolder(root=train_directory, transform=image_transforms['train']),'valid': datasets.ImageFolder(root=valid_directory, transform=image_transforms['valid'])
}
print("训练集图片类别及其对应编号(种类名:编号):",data['train'].class_to_idx)
print("测试集图片类别及其对应编号:",data['valid'].class_to_idx)train_data_size = len(data['train'])
valid_data_size = len(data['valid'])train_data = DataLoader(data['train'], batch_size=batch_size, shuffle=True, num_workers=0)
valid_data = DataLoader(data['valid'], batch_size=batch_size, shuffle=True, num_workers=0)print("训练集图片数量:",train_data_size, "测试集图片数量:",valid_data_size)# 四、迁移学习
# 这里使用ResNet-50的预训练模型。
#resnet50 = models.resnet50(pretrained=True)
resnet50 = models.resnet50(weights=models.ResNet50_Weights.IMAGENET1K_V1)# 在PyTorch中加载模型时,所有参数的‘requires_grad’字段默认设置为true。这意味着对参数值的每一次更改都将被存储,以便在用于训练的反向传播图中使用。
# 这增加了内存需求。由于预训练的模型中的大多数参数已经训练好了,因此将requires_grad字段重置为false。
for param in resnet50.parameters():param.requires_grad = False# 为了适应自己的数据集,将ResNet-50的最后一层替换为,将原来最后一个全连接层的输入喂给一个有256个输出单元的线性层,接着再连接ReLU层和Dropout层,然后是256 x 6的线性层,输出为6通道的softmax层。
fc_inputs = resnet50.fc.in_features
resnet50.fc = nn.Sequential(nn.Linear(fc_inputs, 256),nn.ReLU(),nn.Dropout(0.4),nn.Linear(256, num_classes),nn.LogSoftmax(dim=1)
)# 用GPU进行训练。
resnet50 = resnet50.to('cuda:0')# 定义损失函数和优化器。
loss_func = nn.NLLLoss()
optimizer = optim.Adam(resnet50.parameters())# 五、训练
def train_and_valid(model, loss_function, optimizer, epochs=25):device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")history = []best_acc = 0.0best_epoch = 0for epoch in range(epochs):epoch_start = time.time()print("Epoch: {}/{}".format(epoch+1, epochs))model.train()train_loss = 0.0train_acc = 0.0valid_loss = 0.0valid_acc = 0.0for i, (inputs, labels) in enumerate(tqdm(train_data)):inputs = inputs.to(device)labels = labels.to(device)#因为这里梯度是累加的,所以每次记得清零optimizer.zero_grad()outputs = model(inputs)loss = loss_function(outputs, labels)print("标签值:",labels)print("输出值:",outputs)loss.backward()optimizer.step()train_loss += loss.item() * inputs.size(0)ret, predictions = torch.max(outputs.data, 1)correct_counts = predictions.eq(labels.data.view_as(predictions))acc = torch.mean(correct_counts.type(torch.FloatTensor))train_acc += acc.item() * inputs.size(0)with torch.no_grad():model.eval()for j, (inputs, labels) in enumerate(tqdm(valid_data)):inputs = inputs.to(device)labels = labels.to(device)outputs = model(inputs)loss = loss_function(outputs, labels)valid_loss += loss.item() * inputs.size(0)ret, predictions = torch.max(outputs.data, 1)correct_counts = predictions.eq(labels.data.view_as(predictions))acc = torch.mean(correct_counts.type(torch.FloatTensor))valid_acc += acc.item() * inputs.size(0)avg_train_loss = train_loss/train_data_sizeavg_train_acc = train_acc/train_data_sizeavg_valid_loss = valid_loss/valid_data_sizeavg_valid_acc = valid_acc/valid_data_sizehistory.append([avg_train_loss, avg_valid_loss, avg_train_acc, avg_valid_acc])if best_acc < avg_valid_acc:best_acc = avg_valid_accbest_epoch = epoch + 1epoch_end = time.time()print("Epoch: {:03d}, Training: Loss: {:.4f}, Accuracy: {:.4f}%, \n\t\tValidation: Loss: {:.4f}, Accuracy: {:.4f}%, Time: {:.4f}s".format(epoch+1, avg_valid_loss, avg_train_acc*100, avg_valid_loss, avg_valid_acc*100, epoch_end-epoch_start))print("Best Accuracy for validation : {:.4f} at epoch {:03d}".format(best_acc, best_epoch))torch.save(model, 'models/'+dataset+'_model_'+str(epoch+1)+'.pt')return model, historynum_epochs = 100 #训练周期数
trained_model, history = train_and_valid(resnet50, loss_func, optimizer, num_epochs)
torch.save(history, 'models/'+dataset+'_history.pt')history = np.array(history)
plt.plot(history[:, 0:2])
plt.legend(['Tr Loss', 'Val Loss'])
plt.xlabel('Epoch Number')
plt.ylabel('Loss')
plt.ylim(0, 1)
plt.savefig(dataset+'_loss_curve.png')
plt.show()plt.plot(history[:, 2:4])
plt.legend(['Tr Accuracy', 'Val Accuracy'])
plt.xlabel('Epoch Number')
plt.ylabel('Accuracy')
plt.ylim(0, 1)
plt.savefig(dataset+'_accuracy_curve.png')
plt.show()

测试:图片名改下即可。

import torch
from torchvision import  models, transforms
import torch.nn as nn
import cv2
classes = ["1","2","3","4","5","6","7","8","9"] #识别种类名称(顺序要与训练时的数据导入编号顺序对应,可以使用datasets.ImageFolder().class_to_idx来查看)transf = transforms.ToTensor()
device = torch.device('cuda:0')
num_classes = 2
model_path = "models/skin-cancer-detection_model_3.pt"
image_input = cv2.imread("ISIC_0000019.jpg")
image_input = transf(image_input)
image_input = torch.unsqueeze(image_input,dim=0).cuda()
#搭建模型
resnet50 = models.resnet50(pretrained=True)
for param in resnet50.parameters():param.requires_grad = Falsefc_inputs = resnet50.fc.in_features
resnet50.fc = nn.Sequential(nn.Linear(fc_inputs, 256),nn.ReLU(),nn.Dropout(0.4),nn.Linear(256, num_classes),nn.LogSoftmax(dim=1)
)
resnet50 = torch.load(model_path)outputs = resnet50(image_input)
value,id =torch.max(outputs,1)
print(outputs,"\n","结果是:",classes[id])

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.rhkb.cn/news/227196.html

如若内容造成侵权/违法违规/事实不符,请联系长河编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

很实用的ChatGPT网站——httpchat-zh.com

很实用的ChatGPT网站——http://chat-zh.com/ 今天介绍一个好兄弟开发的ChatGPT网站&#xff0c;网址[http://chat-zh.com/]。这个网站功能模块很多&#xff0c;包含生活、美食、学习、医疗、法律、经济等很多方面。下面简单介绍一些部分功能与大家一起分享。 登录和注册页面…

nvm 的安装及使用 (Node版本管理器)

目录 1、nvm 介绍 2、nvm安装 3、nvm 使用 4、node官网可以查看node和npm对应版本 5、nvm安装指定版本node 6、安装cli脚手架 1、nvm 介绍 NVM 全称 node.js version management &#xff0c;专门针对 node 版本进行管理的工具&#xff0c;通过它可以安装和切换不同版本的…

二分查找--二分查找算法(朴素二分模板)

个人主页&#xff1a;Lei宝啊 愿所有美好如期而遇 本题题目链接https://leetcode.cn/problems/binary-search/description/ 算法原理 二段性&#xff0c;我们发现这个数组可以找到某种规律将其分为两段&#xff0c;不断划分下去&#xff0c;最终可以找到target 图示 我们分…

编程笔记 html5cssjs 004 我的第一个页面

编程笔记 html5&css&js 004 我的第一个页面 一、基本结构二、HTML标签三、HTML元素四、HTML属性五、编写第一个网页六、使用VSCODE小结 开始编写网页&#xff0c;并且使用第一个网页成为一个母板&#xff0c;用于完成后续内容的学习。有一个基本要求&#xff0c;显示结…

“产品经理必懂的关键术语“

产品经理是现代企业中非常重要的一个角色&#xff0c;他们负责制定产品策略、规划产品开发流程、管理产品质量和用户反馈等等。然而&#xff0c;对于产品经理来说&#xff0c;了解并掌握相关的专业术语是非常重要的。本篇文章会介绍一些产品经理需要掌握的专业术语&#xff0c;…

系统启动流程 - 理解modules加载流程

​编辑 Hacker_Albert    202 linux 启动流程module加载 1.启动过程分为三个部分 BIOS 上电自检&#xff08;POST&#xff09;引导装载程序 (GRUB2)内核初始化启动 systemd&#xff0c;其是所有进程之父。 1.1.BIOS 上电自检&#xff08;POST&#xff09; BIOS stands for…

详解Keras3.0 Layer API: Dropout layer

Dropout layer 图1 标准的神经网络 图2 加了Dropout临时删除部分神经元 Dropout层的作用是在神经网络中引入正则化&#xff0c;以防止过拟合。它通过随机丢弃一部分神经元&#xff08;如图2&#xff09;的输出来减少模型对训练数据的依赖性。这样可以提高模型的泛化能力&#x…

VSCode中的注释标签

2023年12月30日&#xff0c;周六上午 在软件开发中&#xff0c;开发者会使用这些标签来提供关于代码功能、版本信息、作者、API使用说明等方面的额外信息。 这些标签的含义通常是&#xff1a; apiNote: 提供有关API使用的注释或说明。author: 标识代码作者的信息。category: …

jvm实战之-常用jvm命令的使用

各命令的使用 JMAP 1、查看内存信息&#xff0c;对象实例数、对象占有大小 jmap -histo 进程号>./log.txt2、查看堆的配置信息和使用情况 jmap - heap 进程号3、将堆的快照信息dump下来&#xff0c;使用java自带的jvisualvm.exe打开分析 jmap -dump:formatb,filedump.h…

启封涂料行业ERP需求分析和方案分享

涂料制造业是一个庞大而繁荣的行业 它广泛用于建筑、汽车、电子、基础设施和消费品。涂料行业生产不同的涂料&#xff0c;如装饰涂料、工业涂料、汽车涂料和防护涂料。除此之外&#xff0c;对涂料出口的需求不断增长&#xff0c;这增加了增长和扩张的机会。近年来&#xff0c;…

前端的 js

js 点击按钮修改文字 <!DOCTYPE html> <html> <head></head><body><h2>Head 中的 JavaScript</h2><p id"demo">一个段落。</p><button type"button" onclick"myFunction()">试一…

HUAWEI华为笔记本电脑MateBook D 14 2022款 i5 集显 非触屏(NbDE-WFH9)原装出厂Windows11系统21H2

链接&#xff1a;https://pan.baidu.com/s/1-tCCFwZ0RggXtbWYBVyhFg?pwdmcgv 提取码&#xff1a;mcgv 华为MageBookD14原厂WIN11系统自带所有驱动、出厂状态主题壁纸、Office办公软件、华为电脑管家、华为应用市场等预装软件程序 文件格式&#xff1a;esd/wim/swm 安装方式…

Solidworks学习笔记

本内容为solidworks的学习笔记&#xff0c;根据自己的理解进行记录&#xff0c;部分可能不正确&#xff0c;请自行判断。 学习视频参考&#xff1a;【SolidWorks2018视频教程 SW2018中文版软件基础教学知识 SolidWorks自学教程软件操作教程 sw视频教程 零基础教程 视频教程】 h…

JY901S 9轴姿态角度传感器模块

JY901S 9轴姿态角度传感器模块 JY901S 简介模块特性引脚说明IIC通讯IIC读写寄存器代码示例 JY901S 简介 模块集成高精度的陀螺仪、加速度计、地磁场传感器&#xff0c;采用高性能的微处理器和先进的动力学解算与卡尔曼动态滤波算法&#xff0c;能够快速求解出模块当前的实时运…

Nacos2.1.2改造适配达梦数据库7.0

出于业务需求&#xff0c;现将Nacos改造适配达梦数据库7.0&#xff0c;记录本次改造过程。 文章目录 一、前期准备二、适配流程1、项目初始化2、引入驱动3、源码修改 三、启动测试四、打包测试 一、前期准备 Nacos源码&#xff0c;版本&#xff1a;2.1.2&#xff1a;源码下载…

Python圣诞树代码

Python圣诞树代码 # 小黄 2023/12/25import turtle as t # as就是取个别名&#xff0c;后续调用的t都是turtle from turtle import * import random as rn 100.0speed(20) # 定义速度 pensize(5) # 画笔宽度 screensize(800, 800, bgblack) # 定义背景颜色&#xff0c;可…

Selenium在vue框架下求生存

vue框架下面&#xff0c;没有id、没有name&#xff0c;vue帮开发做了很多脏活累活&#xff0c;却委屈了写页面自动化测试的人&#xff08;当然&#xff0c;也给爬信息的也带来了一定的难处&#xff09;。这里只能靠总结&#xff0c;用一些歪门邪道&#xff1a; 一、跟开发商量…

RabbitMQ之快速入门、上手

前言 学习一样新技术、新框架&#xff0c;最重要的是学习其思想、原理。即原理性思维。 如果是因为工作原因&#xff0c;需要快速上手RabbitMQ&#xff0c;本篇或许适合你。 核心概念 Connection&#xff1a;publisher&#xff0f;consumer 和 broker 之间的 TCP 连接Channel…

Zookeeper的使用场景

统一命名服务 利用ZooKeeper节点的树形分层结构和子节点的顺序维护能力&#xff0c;来为分布式系统中的资源命名。 例&#xff1a;分布式节点命名 分布式消息队列 1.在Zookeeper中创建一个持久节点&#xff0c;用作队列的根节点。队列元素的节点放在这个根节点下。 2.入队:…

信息网络协议基础_IP移动网络管理

文章目录 概述移动IPv6待解决的问题关键词基本过程分组拦截技术移动检测和转交地址自动配置到家乡代理绑定注册通信对端不支持IPv6通信对端支持移动IPv6对IP以上层屏蔽移动性移动IPv6存在的问题移动IPv6优化代理移动IP概述原理基本过程初始接入切换概述