卷积神经网络实现运动鞋识别 - P5

  • 🍨 本文为🔗365天深度学习训练营 中的学习记录博客
  • 🍦 参考文章:Pytorch实战 | 第P5周:运动鞋识别
  • 🍖 原作者:K同学啊 | 接辅导、项目定制
  • 🚀 文章来源:K同学的学习圈子

目录

  • 环境
  • 步骤
    • 环境设置
      • 包引用
      • 训练设备
    • 数据准备
      • 图像解压后的路径
      • 打印图像的参数
      • 展示图像
      • 图像的预处理
      • 创建数据集
      • 获取数据集的分类
      • 打乱数据的顺序,生成批次
    • 模型设计
    • 模型训练
      • 训练函数
      • 评估函数
      • 循环迭代部分
    • 模型效果展示
      • 训练过程图表展示
      • 载入最佳模式,随机选择图像进行预测
  • 总结与心得体会


环境

  • 系统: Linux
  • 语言: Python3.8.10
  • 深度学习框架: Pytorch2.0.0+cu118

步骤

环境设置

包引用

import torch
import torch.nn as nn 
import torch.optim as optim # 优化器
import torch.nn.functional as F # 可以静态调用的方法from torchvision import datasets, transforms # 数据集创建、数据预处理方法
from torch.utils.data import DataLoader # DataLoader可以将数据集封装成批次数据import matplotlib.pyplot as plt
import numpy as np
from PIL import Image # 加载图片预览使用的库
from torchinfo import summary # 可以打印模型实际运行时的图
import copy # 深拷贝使用的库
import pathlib, random # 文件夹遍历和随机数

训练设备

# 声明一个全局设备对象,方便后面将数据和模型拷贝到设备中
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

数据准备

图像解压后的路径

train_path = 'train'
test_path = 'test'

打印图像的参数

train_pathlib = pathlib.Path(train_path)
train_image_list = list(train_pathlib.glob('*/*'))
for _ in range(5):print(np.array(Image.open(str(random.choice(train_image_list)))).shape)

图片的参数
重复执行了多次,返回结果都是(240, 240, 3),可以确定图像的大小统一为240,240,在数据加载的过程中可以不对图像做缩放处理。

展示图像

plt.figure(figsize=(20, 4))
for i in range(20):image = random.choice(train_image_list)plt.subplot(2, 10, i+1)plt.axis('off')plt.imshow(Image.open(str(image)))plt.title(image.parts[-2])

数据集预览
至此我们对数据集中的图像有了一个初步的了解。接下来就是准备训练数据。

图像的预处理

定义一些图像的预处理方法,例如将图像读取并转为pytorch的tensor对象,然后对图像的数值做归一化处理

transform = transforms.Compose([transforms.ToTensor(),transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
])

创建数据集

train_dataset = datasets.ImageFolder(train_path, transform=transform)
test_dataset = datasets.ImageFolder(test_path, transform=transform)

获取数据集的分类

class_names = [key for key in train_dataset.class_to_idx]
print(class_names)

数据分类

打乱数据的顺序,生成批次

batch_size = 32
train_loader = DataLoader(train_dataset, shuffle=True, batch_size=batch_size)
test_loader = DataLoader(test_dataset, batch_size=batch_size)

模型设计

使用3x3的卷积核,最大的通道数到256,每次卷积操作后,就紧跟一个池化层,一共使用了4个卷积层和4个池化层。最后使用了三层全连接网络来做分类器。
模型结构图

class Network(nn.Module):def __init__(self, num_classes):super().__init__()self.conv1 = nn.Conv2d(3, 64, 3)self.bn1  = nn.BatchNorm2d(64)self.conv2 = nn.Conv2d(64, 128, 3)self.bn2 = nn.BatchNorm2d(128)self.conv3 = nn.Conv2d(128, 256, 3)self.bn3 = nn.BatchNorm2d(256)self.conv4 = nn.Conv2d(256, 256, 3)self.bn4 = nn.BatchNorm2d(256)self.maxpool = nn.MaxPool2d(2)self.fc1 = nn.Linear(13*13*256, 128)self.fc2 = nn.Linear(128, 128)self.fc3 = nn.Linear(128, num_classes)self.dropout = nn.Dropout(0.5)def forward(self, x):# 240 -> 238x = F.relu(self.bn1(self.conv1(x)))# 238 -> 119x = self.maxpool(x)# 119 -> 117x = F.relu(self.bn2(self.conv2(x)))# 117 -> 58x = self.maxpool(x)# 58 -> 56x = F.relu(self.bn3(self.conv3(x)))# 56 -> 28x = self.maxpool(x)# 28 -> 26x = F.relu(self.bn4(self.conv4(x)))# 26 -> 13x = self.maxpool(x)x = x.view(x.size(0), -1)x = self.dropout(x)x = F.relu(self.dropout(self.fc1(x)))x = F.relu(self.dropout(self.fc2(x)))x = self.fc3(x)return x
model = Network(len(class_names)).to(device)summary(model, input_size=(32, 3, 240, 240))

模型结构图

模型训练

模型训练过程中,每个epoch都会对全部的训练集进行一次完整的遍历,所以可以封装一些训练和评估方法,将业务逻辑和循环分开

训练函数

def train(train_loader, model, loss_fn, optimizer):size = len(train_loader.dataset)num_batches = len(train_loader)train_loss, train_acc = 0, 0for x, y in train_loader:x, y = x.to(device), y.to(device)pred = model(x)loss = loss_fn(pred, y)optimizer.zero_grad()loss.backward()optimizer.step()train_loss += loss.item()train_acc += (pred.argmax(1) == y).type(torch.float).sum().item()train_loss /= num_batchestrain_acc /= sizereturn train_loss, train_acc

评估函数

def test(test_loader, model, loss_fn):size = len(test_loader.dataset)num_batches = len(test_loader)test_loss, test_acc = 0, 0for x, y in test_loader:x, y = x.to(device), y.to(device)pred = model(x)loss = loss_fn(pred, y)test_loss += loss.item()test_acc += (pred.argmax(1) == y).type(torch.float).sum().item()test_loss /= num_batchestest_acc /= sizereturn test_loss, test_acc

循环迭代部分

loss_fn = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=1e-4)
scheduler = optim.lr_scheduler.LambdaLR(optimizer=optimizer, lr_lambda=lambda epoch:0.92**(epoch //2)) 
# 创建学习率的衰减
epochs = 50train_loss, train_acc = [], []
test_loss, test_acc = [], []
best_acc = 0
for epoch in range(epochs):model.train()epoch_train_loss, epoch_train_acc = train(train_loader, model, loss_fn, optimizer)model.eval()with torch.no_grad():epoch_test_loss, epoch_test_acc = test(test_loader, model, loss_fn)scheduler.step() # 每次迭代调用一次,自动做学习率衰减# 如果当前评估的学习率更好,就保存当前模型if best_acc < epoch_test_acc:best_acc = epoch_test_accbest_model = copy.deepcopy(model)# 记录历史记录train_loss.append(epoch_train_loss)train_acc.append(epoch_train_acc)test_loss.append(epoch_test_loss)test_acc.append(epoch_test_acc)# 打印每个迭代的数据print(f"Epoch:{epoch+1}, TrainLoss: {epoch_train_loss:.3f}, TrainAcc: {epoch_train_acc*100:.1f}, TestLoss: {epoch_test_loss:.3f}, TestAcc: {epoch_test_acc*100:.1f}")# 打印本次训练的最佳正确率
print(f'training finished, best_acc is {best_acc*100:.1f}')# 将最佳模型保存到文件中
torch.save(model.state_dict(), 'best_model.pth')

模型训练过程

模型效果展示

训练过程图表展示

画一个拆线图,观察训练过程中损失函数和正确率的变化趋势

plt.figure(figsize=(20,5))epoch_range = range(epochs)plt.subplot(1,2, 1)
plt.plot(epoch_range, train_loss, label='train loss')
plt.plot(epoch_range, test_loss, label='validation loss')
plt.legend(loc='upper right')
plt.title('Loss')plt.subplot(1,2,2)
plt.plot(epoch_range, train_acc, label='train accuracy')
plt.plot(epoch_range, test_acc, label='validation accuracy')
plt.legend(loc='lower right')
plt.title('Accuracy')

训练过程图示
可以看出模型在最后基本已经收敛,最佳准确率是88.2%,满足了挑战任务。

载入最佳模式,随机选择图像进行预测

model.load_state_dict(torch.load('best_model.pth'))
model = model.to(device)test_pathlib = pathlib.Path(test_path)image_list = list(test_pathlib.glob('*/*'))image_path = random.choice(image_list)
image = transform(Image.open(str(image_path)))
image = image.unsqueeze(0)
image = image.to(device)pred = model(image)plt.figure(figsize=(5,5))
plt.axis('off')
plt.imshow(Image.open(str(image_path)))
plt.title(f'real: {image_path.parts[-2]}, predict: {class_names[pred.argmax(1).item()]}')

预测结果
上次运行上面的预测任务,发现正确率还不错。

总结与心得体会

  1. 整个模型设计的思路其实是模仿了vgg16模型,在卷积层的数量和通道上做了简化。轻量级的任务可以首先试着减少池化层间的卷积次数,减少模型中最大的特征图的通道数
  2. 对图像的归一化操作很重要。在没有归一化前,模型的最佳正确率只能达到80%,推测可能是因为未做归一化的图像值域范围太大,不方便收敛,归一化后,原始图像中的输入特征值范围变成0~1,模型的权重变化更易作用到特征上。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.rhkb.cn/news/115952.html

如若内容造成侵权/违法违规/事实不符,请联系长河编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

沐风老师3DMAX厨房橱柜生成器KitchenCabinetGenerator教程

3DMAX厨房橱柜生成器插件使用方法 3DMAX橱柜生成器KitchenCabinetGenerator是一个在3dMax中自动创建三维橱柜模型的高效脚本。它有多种风格的台面、门和橱柜&#xff0c;可以灵活地应用于Archviz项目&#xff0c;同时为3D艺术家节省大量时间。 【适用版本】 1.3dMax2018 – 20…

YOLO数据集划分(训练集、验证集、测试集)

1.将训练集、验证集、测试集按照7:2:1随机划分 1.项目准备 1.在项目下新建一个py文件&#xff0c;名字就叫做splitDataset1.py 2.将自己需要划分的原数据集就放在项目文件夹下面 以我的为例&#xff0c;我的原数据集名字叫做hatDataXml 里面的JPEGImages装的是图片 Annota…

设计模式-适配器

文章目录 一、简介二、适配器模式基础1. 适配器模式定义与分类2. 适配器模式的作用与优势3.UML图 三、适配器模式实现方式1. 类适配器模式2. 对象适配器模式3.类适配器模式和对象适配器模式对比 四、适配器模式应用场景1. 继承与接口的适配2. 跨平台适配 五、适配器模式与其他设…

C++之std::distance应用实例(一百八十八)

简介&#xff1a; CSDN博客专家&#xff0c;专注Android/Linux系统&#xff0c;分享多mic语音方案、音视频、编解码等技术&#xff0c;与大家一起成长&#xff01; 优质专栏&#xff1a;Audio工程师进阶系列【原创干货持续更新中……】&#x1f680; 人生格言&#xff1a; 人生…

论文解读 | 三维点云深度学习的综述

原创 | 文 BFT机器人 KITTI 是作为基准测试是自动驾驶中最具影响力的数据集之一&#xff0c;在学术界和工业界都被广泛使用。现有的三维对象检测器存在着两个限制。第一是现有方法的远程检测能力相对较差。其次&#xff0c;如何充分利用图像中的纹理信息仍然是一个开放性的问题…

uniapp授权小程序隐私弹窗效果demo(整理)

9月15号前要配置这句话 "__usePrivacyCheck__": true,官方“小程序隐私协议开发指南”文档 <template> <view class"dealBox"><view class"txtBox padding10"><!-- 查看协议 -->在您使用施工现场五星计划小程序之前&am…

解决D盘的类型不是基本,而是动态的问题

一、正确的图片 1.1图片 1.2本人遇到的问题 二、将动态磁盘 转为基本盘 2.1 基本概念&#xff0c;动态无法转化为基本&#xff0c;不是双向的&#xff0c;借助软件 网址&#xff1a;转换动态磁盘到普通磁盘_检测到计算机本地磁盘为动态分区_卫水金波的博客-CSDN博客 2.2分区…

我开课了!《机器学习》公益课9月4日开课

我是黄海广&#xff0c;大学老师&#xff0c;我上的一门课叫《机器学习》&#xff0c;本科生学机器学习有点难&#xff0c;但也不是没有可能&#xff0c;我在摸索中&#xff0c;设计适合本科生的机器学习课程&#xff0c;写了教材&#xff0c;录了视频&#xff0c;做了课件。我…

安装使用electron

一、安装node和npm 运行cmd查看是否安装及版本号 npm -v node -v 二、安装electron npm直接安装会报错缺少什么文件&#xff0c;使用cnpm进行安装 直接安装cnmp后&#xff0c;再用cnmp命令安装可能会报错Error: Cannot find module ‘node:util’ 原因是npm版本与cnpm版本…

MySQL官网下载安装包

MySQL官网&#xff1a; MySQL MySQL 8.0官网下载地址&#xff1a; MySQL :: Download MySQL Community Server 2023-07-18 MySQL 8.1.0 发布&#xff0c;这是 MySQL 变更发版模型后的第一个创新版本 (Innovation Release) 。 如果在官网中找不到下载位置&#xff0c;点击第二个…

在Visual Studio 2017上配置并使用OpenGL

1 在Visual Studio 2017上配置并使用OpenGL 在GLUT - The OpenGL Utility Toolkit&#xff1a;GLUT - The OpenGL Utility Toolkit中点击“GLUT for Microsoft Windows 95 & NT users”&#xff0c;选择“If you want just the GLUT header file, the .LIB, and .DLL file…

elementplus实现左侧菜单栏收缩与展开

1.页面结构 Home.vue下包含aside.vue和menu.vue 2.TAside.vue el-menu左侧菜单栏显示 注意&#xff1a; 要使用收缩与展开&#xff0c;el-aside必须设置width"collapse"&#xff0c;否则收缩展开会出现收缩后&#xff0c;el-aside宽度不变窄需要使用动态改变展开收…

使用boost::geometry::union_ 合并边界(内、外)- 方案一

使用boost::geometry::union_ 合并边界&#xff08;内、外&#xff09;&#xff1a;方案一 结合 boost::geometry::read_wkt() 函数 #include <iostream> #include <vector>#include <boost/geometry.hpp> #include <boost/geometry/geometries/point_x…

C++ 文件和流

iostream 标准库提供了 cin 和 cout 方法&#xff0c;用于从标准输入读取流和向标准输出写入流。而从文件中读取流或向文件写入流&#xff0c;需要用到fstream标准库。在 C 中进行文件处理时&#xff0c;须在源代码文件中包含头文件 <iostream> 和 <fstream>。fstr…

Python小知识 - 一致性哈希算法

一致性哈希算法 一致性哈希算法&#xff08;Consistent Hashing Algorithm&#xff09;是用于解决分布式系统中节点增减比较频繁的问题。它的思想是&#xff0c;将数据映射到0~2^64-1的哈希空间中&#xff0c;并通过哈希函数对数据进行映射&#xff0c;计算出数据所在的节点。当…

Hadoop依赖环境配置与安装部署

目录 什么是Hadoop&#xff1f;一、Hadoop依赖环境配置1.1 设置静态IP地址1.2 重启网络1.3 再克隆两台服务器1.4 修改主机名1.5 安装JDK1.6 配置环境变量1.7 关闭防火墙1.8 服务器之间互传资料1.9 做一个host印射1.10 免密传输 二、Hadoop安装部署2.1 解压hadoop的tar包2.2 切换…

【笔记】常用 js 函数

数组去重 Array.from(new Set()) 对象合并 Object.assign . 这里有个细节&#xff1a;当两个对象中含有key相同value不同时&#xff0c;会以 后面对象的key&#xff1a;value为准 保留小数点后几位 toFixed 注意&#xff1a; Number型&#xff0c;用该方法处理完&#xff0c;会…

4、DVWA——文件包含

文章目录 一、文件包含概述二、low2.1 源码分析2.2 通关分析 三、medium3.1 源码分析3.2 通关思路 四、high4.1 源码分析4.2 通关思路 五、impossible 一、文件包含概述 文件包含是指当服务器开启allow_url_include选项时&#xff0c;就可以通过php的某些特性函数&#xff08;i…

【Vue3 知识第二讲】Vue3新特性、vue-devtools 调试工具、脚手架搭建

文章目录 一、Vue3 新特性1.1 重写双向数据绑定1.1.1 Vue2 基于Object.defineProperty() 实现1.1.2 Vue3 基于Proxy 实现 1.2 优化 虚拟DOM1.3 Fragments1.4 Tree shaking1.5 Composition API 二、 vue-devtools 调试工具三、环境配置四、脚手架目录介绍五、SFC 语法规范解析附…

dvwa xss通关

反射型XSS通关 low难度 选择难度&#xff1a; 直接用下面JS代码尝试&#xff1a; <script>alert(/xss/)</script>通关成功&#xff1a; medium难度 直接下面代码尝试后失败 <script>alert(/xss/)</script>发现这段代码直接被输出&#xff1a; 尝试…