神经网络入门实战:(二十三)使用本地数据集进行训练和验证

使用本地数据集进行训练和验证

训练过程和用官网的数据集没什么差别,只是需要从本地加载数据集。

本地的数据集需要弄好格式,分别创建 trainval 两个文件夹,用来存放训练和验证数据集,两个文件夹里面是多个以类别名命名的子文件夹,其中存放相应类别的图片。样式大概是这样:

data_directory/class1/img1.jpgimg2.jpg...class2/img1.jpgimg2.jpg...class3/img1.jpgimg2.jpg...

使用 datasets.ImageFolderDataLoader 加载本地数据集的代码如下:

import torch
import torchvision
from torchvision import transforms,datasets
from torch.utils.data import DataLoader# 加载本地训练数据集
train_data_dir = 'E:\\4_Data_sets\\respiratory waveform\\train'  # 数据集根目录路径
train_image_datasets = datasets.ImageFolder(train_data_dir, transform=dataclass_transform)
train_dataloader = DataLoader(train_image_datasets, batch_size=7, shuffle=True)
train_classes = train_image_datasets.classes  # 获取类别名称
# 加载本地验证数据集
test_data_dir = 'E:\\4_Data_sets\\respiratory waveform\\train'  # 数据集根目录路径
test_image_datasets = datasets.ImageFolder(test_data_dir, transform=dataclass_transform)
test_dataloader = DataLoader(test_image_datasets, batch_size=7, shuffle=True)
test_classes = test_image_datasets.classes  # 获取类别名称

完整训练代码如下:

这里以 resnet18 为例,不过模型保存在了 NN_models 文件里,并未给出。大家主要看看怎么使用本地数据集。

import torch
from torch import nn
import torchvision
from torchvision import transforms,datasets
from torch.utils.data import DataLoader
from torch.utils.tensorboard import SummaryWriter
from NN_models import *
from PIL import Image
import time# 检查CUDA是否可用,并设置设备为 GPU
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")dataclass_transform = transforms.Compose([transforms.ToTensor(),transforms.Resize(224),transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])  # 标准化
])# 加载本地训练数据集
train_data_dir = 'E:\\4_Data_sets\\respiratory waveform\\train'  # 数据集根目录路径,自定义
train_image_datasets = datasets.ImageFolder(train_data_dir, transform=dataclass_transform)
train_dataloader = DataLoader(train_image_datasets, batch_size=7, shuffle=True)
train_classes = train_image_datasets.classes  # 获取类别名称
# 加载本地验证数据集
test_data_dir = 'E:\\4_Data_sets\\respiratory waveform\\train'  # 数据集根目录路径,自定义
test_image_datasets = datasets.ImageFolder(test_data_dir, transform=dataclass_transform)
test_dataloader = DataLoader(test_image_datasets, batch_size=7, shuffle=True)
test_classes = test_image_datasets.classes  # 获取类别名称# 创建网络模型
resnet18_Instance = resnet18(3, 3).to(device)
# 定义损失函数
loss = nn.CrossEntropyLoss()
# 定义优化器
learning_rate = 0.01
optimizer = torch.optim.SGD(resnet18_Instance.parameters(), lr=learning_rate, momentum=0.9)# 开始训练
total_train_step = 0
first_train_step = 0
total_test_step = 0
epoch_sum = 15  # 迭代次数# 添加tensorboard
writer = SummaryWriter('logs')start_time = time.time()
last_epoch_time = time.time()  # 记录开始训练的时间for i in range(epoch_sum):print("----------------第 {} 轮训练开始了----------------:".format(i + 1))# 训练步骤开始for data in train_dataloader:imgs, labels = dataimgs, labels = imgs.to(device), labels.to(device)  # 将数据和目标移动到GPUoutputs = resnet18_Instance(imgs)loss_real = loss(outputs, labels)  # 这里的损失变量 loss_real,千万别和损失函数 loss 相同,否则会报错!optimizer.zero_grad()loss_real.backward()optimizer.step()  # 更新模型参数total_train_step += 1# 表示第一轮训练结束,取每一轮的第一个batch_size来看看训练效果if total_train_step % 33 == 0:first_train_step += 1print("训练次数为:{}, loss为:{}".format(total_train_step, loss_real))  # 此训练次数非训练轮次# 每轮测试结束之后,计算训练的准确率,就拿第一个batch_size来看看准确率outputs_B = torch.argmax(outputs, dim=1)outputs_C = (labels == outputs_B).sum()  # True默认为1,False默认为0accuracy = (outputs_C.item() / len(outputs_B)) * 100accuracy = round(accuracy, 2)  # 2表示保留两位小数(四舍五入)print(f"训练正确率为:{accuracy}%")writer.add_scalar('first_batch_size', loss_real.item(), first_train_step)writer.add_scalar('total_batch_size', loss_real.item(), total_train_step)# 每训练一轮,就使用测试集看看训练效果total_test_loss = 0with torch.no_grad():for data in test_dataloader:imgs, labels = dataimgs, labels = imgs.to(device), labels.to(device)outputs = resnet18_Instance(imgs)loss_fake = loss(outputs, labels)total_test_loss += loss_fake.item()print("整体测试集上的LOSS为:{}".format(total_test_loss))one_epoch_time = time.time()  # 记录训练一次的时间one_cost_time = one_epoch_time - last_epoch_timeprint(f"训练此轮需要的时间为:{one_cost_time}")last_epoch_time = one_epoch_timeone_epoch_time = 0end_time = time.time()
total_time = end_time - start_time
print(f"训练总计需要的时间为:{total_time}")writer.close()

训练和验证结果如下:

----------------1 轮训练开始了----------------:
训练次数为:33, loss为:1.8306987285614014
训练正确率为:57.14%
整体测试集上的LOSS为:50.43653497334162
训练此轮需要的时间为:3.3512418270111084
----------------2 轮训练开始了----------------:
训练次数为:66, loss为:1.7029897492193413e-08
训练正确率为:100.0%
整体测试集上的LOSS为:36.150528442618
训练此轮需要的时间为:2.1926515102386475
----------------3 轮训练开始了----------------......
----------------15 轮训练开始了----------------:
训练次数为:495, loss为:1.5667459365431569e-06
训练正确率为:100.0%
整体测试集上的LOSS为:1.8008886913487672
训练此轮需要的时间为:2.1712865829467773
训练总计需要的时间为:32.333725929260254

上一篇下一篇
神经网络入门实战(二十二)待发布

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.rhkb.cn/news/501019.html

如若内容造成侵权/违法违规/事实不符,请联系长河编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

试题转excel;word转excel;大风车excel(1.1更新)

最近更新了大风车excel1.1版本 主要优化在算法层面: 1.0版本试题解析的成功率为95%,现在1.1版本已经优化到解析成功率为99% 一、问题描述 一名教师朋友,偶尔会需要整理一些高质量的题目到excel中 以往都是手动复制搬运,几百道…

python实现自动登录12306抢票 -- selenium

提示:文章写完后,目录可以自动生成,如何生成可参考右边的帮助文档 python实现自动登录12306抢票 -- selenium 前言其实网上也出现了很多12306的代码,但是都不是最新的,我也是从网上找别人的帖子,看B站视频&…

机器学习之正则化惩罚和K折交叉验证调整逻辑回归模型

机器学习之正则化惩罚和K折交叉验证调整逻辑回归模型 目录 机器学习之正则化惩罚和K折交叉验证调整逻辑回归模型1 过拟合和欠拟合1.1 过拟合1.2 欠拟合 2 正则化惩罚2.1 概念2.2 函数2.3 正则化种类 3 K折交叉验证3.1 概念3.2 图片理解3.3 函数导入3.4 参数理解 4 训练模型K折交…

文件本地和OSS上传

这里写目录标题 前端传出文件后端本地存储阿里云OSS存储上传Demo实现上传ConfigurationProperties 前端传出文件 <!DOCTYPE html> <html lang"en"> <head><meta charset"UTF-8"><title>上传文件</title> </head&g…

《Vue3实战教程》37:Vue3生产部署

如果您有疑问&#xff0c;请观看视频教程《Vue3实战教程》 生产部署​ 开发环境 vs. 生产环境​ 在开发过程中&#xff0c;Vue 提供了许多功能来提升开发体验&#xff1a; 对常见错误和隐患的警告对组件 props / 自定义事件的校验响应性调试钩子开发工具集成 然而&#xff…

python制作打字小游戏

import pygame # 导入游戏模块 安装pygame import sys # 导入系统指令模块 import random # 导入随机数模块 pygame.init() #初始化游戏环境 wndpygame.display.set_mode((800,565)) #指定窗口大小 pygame.mixer.music.load(素材/SurvivalGame.mp3) #素…

抖音短视频矩阵系统源码开发全流程解析

在项目开发过程中&#xff0c;调整配置文件至关重要&#xff0c;这些文件包括数据库连接、API密钥及全局参数等。通过正确配置这些信息&#xff0c;可确保应用程序的稳定性和安全性。灵活调整配置以适应具体需求有助于短视频矩阵系统项目的顺利推进。 在开发环境中&#xff0c…

前端路由layout布局处理以及菜单交互(三)

上篇介绍了前端项目部署以及基本依赖的应用&#xff0c;这次主要对于路由以及布局进行模块化处理 一、 创建layout模块 1、新建src/layout/index.vue <template><el-container class"common-layout"><!-- <el-aside class"aside">&l…

戴尔/Dell 电脑按什么快捷键可以进入 Bios 设置界面?

BIOS&#xff08;基本输入输出系统&#xff09;是计算机硬件与操作系统之间的桥梁&#xff0c;它负责初始化和测试系统硬件组件&#xff0c;并加载启动操作系统。在某些情况下&#xff0c;如调整启动顺序、更改系统时间或日期、修改硬件配置等&#xff0c;您可能需要进入BIOS进…

《从入门到精通:蓝桥杯编程大赛知识点全攻略》(一)-递归实现指数型枚举、递归实现排列型枚举

本篇博客将聚焦于通过递归来实现两种经典的枚举方法&#xff1a;指数型枚举和排列型枚举。这两种枚举方式在计算机科学和算法竞赛中都有广泛应用&#xff0c;无论是在解题中&#xff0c;还是在实际工作中都极具价值。 目录 前言 斐波那契数列递归 递归实现指数型枚举 算法思…

idea 的 springboot项目spring-boot-devtools 自动编译 配置热部署

1&#xff0c;设置一 2&#xff0c;设置二 设置二&#xff08;旧版本&#xff09; CtrlShiftAlt/ 点击弹出框中Registry... 引入&#xff08;如果报错&#xff0c;换不同的版本&#xff09; <dependency><groupId>org.springframework.boot</groupId><a…

低代码开发:开启企业数智化转型“快捷键”

一、低代码开发浪潮来袭&#xff0c;企业转型正当时 在当今数字化飞速发展的时代&#xff0c;低代码开发已如汹涌浪潮&#xff0c;席卷全球。从国际市场来看&#xff0c;诸多企业巨头纷纷布局低代码领域&#xff0c;像微软的 PowerApps、OutSystems 等平台&#xff0c;凭借强大…

C#二维数组详解

目录 1&#xff0c;什么是二维数组&#xff1f; 2&#xff0c;创建二维数组的几种方式 &#xff08;1&#xff09;使用[,]声明数组&#xff08;常见方式&#xff09; &#xff08;2&#xff09;声明数组时指定元素 &#xff08;3&#xff09;使用new创建数组 &#xff08;…

STM32--超声波模块(HC—SR04)(标准库+HAL库)

一、HC-SR04工作原理 1&#xff09;采用IO触发测距&#xff0c;给至少10us的高电平信号。 2&#xff09;模块自动发送8个40KHz的方波&#xff0c;自动检测是否有信号返回。 3&#xff09;有信号返回&#xff0c;通过IO输出一高电平&#xff0c;高电平持续时间就是超声波从发…

DDD(一)—— Authentication with JWT

文章目录 项目地址一、项目结构梳理1.1 Domain层1.1.1 Entities文件夹1.2 Contracts层1.2.1 Authentication文件夹1.3 Appliaction层1.3.1Common文件夹1. Interfaces文件夹Authentication 权限接口Persistence 数据库接口Services 常用服务接口1.3.2 Services文件夹1. Authenti…

GPU 进阶笔记(一):高性能 GPU 服务器硬件拓扑与集群组网

记录一些平时接触到的 GPU 知识。由于是笔记而非教程&#xff0c;因此内容不求连贯&#xff0c;有基础的同学可作查漏补缺之用 1 术语与基础 1.1 PCIe 交换芯片1.2 NVLink 定义演进&#xff1a;1/2/3/4 代监控1.3 NVSwitch1.4 NVLink Switch1.5 HBM (High Bandwidth Memory) 由…

自由学习记录(31)

Java连接MySQL 找到那个关键jar包然后导入选中&#xff0c;就配置好MySQL的JDBC&#xff08;Java Database Connectivity&#xff09;了 菜单--文件--项目结构 项目设置--模块--选择要附着的项目--选择依赖--选中模块源--选中加号添加jar包 解压之后在里面可以看到这个最关键…

第十四届蓝桥杯Scratch省赛中级组—智能计价器

智能计价器 背景信息&#xff1a; A城市的出租车计价&#xff1a;3公里以内13元&#xff0c;基本单价每公里2.3元(超过3公里的部分&#xff0c;不满1公里按照1公里收费&#xff09;&#xff0c;燃油附加费每运次1元。 例如&#xff1a; 3.2公里的打车费用&#xff1a;132.3…

游戏引擎学习第69天

回顾碰撞响应时我们停留的位置 从昨天的讨论开始&#xff0c;我们正准备处理碰撞响应的复杂性。具体来说&#xff0c;我们讨论的是&#xff0c;当两个实体在屏幕上发生碰撞时&#xff0c;如何回应这种情况。碰撞本身并不复杂&#xff0c;但要处理其后的反应和规则则更具挑战性…

全新免押租赁系统助力商品流通高效安全

内容概要 全新免押租赁系统的推出&#xff0c;可以说是一场商品流通领域的小革命。想象一下&#xff0c;不再为押金烦恼&#xff0c;用户只需通过一个简单的信用评估&#xff0c;就能快速租到所需商品&#xff0c;这种体验简直令人惊喜&#xff01;这个系统利用代扣支付技术&a…