卷积神经网络——LeNet——FashionMNIST

目录

  • 一、文件结构
  • 二、model.py
  • 三、model_train.py
  • 四、model_test.py

一、文件结构

在这里插入图片描述

二、model.py

import torch
from torch import nn
from torchsummary import summaryclass LeNet(nn.Module):def __init__(self):super(LeNet,self).__init__()self.c1 = nn.Conv2d(in_channels=1,out_channels=6,kernel_size=5,padding=2)self.sig = nn.Sigmoid()self.s2 = nn.AvgPool2d(kernel_size=2,stride=2)self.c3 = nn.Conv2d(in_channels=6,out_channels=16,kernel_size=5)self.s4 = nn.AvgPool2d(kernel_size=2,stride=2)self.flatten = nn.Flatten()self.f5 = nn.Linear(in_features=5*5*16,out_features=120)self.f6 = nn.Linear(in_features=120,out_features=84)self.f7 = nn.Linear(in_features=84,out_features=10)def forward(self,x):x = self.sig(self.c1(x))x = self.s2(x)x = self.sig(self.c3(x))x = self.s4(x)x = self.flatten(x)x = self.f5(x)x = self.f6(x)x = self.f7(x)return x# if __name__ =="__main__":
#     device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
#
#     model = LeNet().to(device)
#
#     print(summary(model,input_size=(1,28,28)))

三、model_train.py

# 导入所需的Python库
from torchvision.datasets import FashionMNIST
from torchvision import transforms
import torch.utils.data as Data
import torch
from torch import nn
import time
import copy
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
from model import LeNet  # model.py中定义了LeNet模型
from tqdm import tqdm  # 导入tqdm库,用于显示进度条# 定义数据加载和处理函数
def train_val_data_process():# 加载FashionMNIST数据集,Resize到28x28尺寸,并转换为Tensortrain_data = FashionMNIST(root="./data",train=True,transform=transforms.Compose([transforms.Resize(size=28), transforms.ToTensor()]),download=True)# 将加载的数据集分为80%的训练数据和20%的验证数据train_data, val_data = Data.random_split(train_data, lengths=[round(0.8 * len(train_data)), round(0.2 * len(train_data))])# 为训练数据和验证数据创建DataLoader,设置批量大小为32,洗牌,2个进程加载数据train_dataloader = Data.DataLoader(dataset=train_data,batch_size=32,shuffle=True,num_workers=2)val_dataloader = Data.DataLoader(dataset=val_data,batch_size=32,shuffle=True,num_workers=2)# 返回训练和验证的DataLoaderreturn train_dataloader, val_dataloader# 定义模型训练和验证过程的函数
def train_model_process(model, train_dataloader, val_dataloader, num_epochs):# 设置使用CUDA如果可用device = torch.device("cuda" if torch.cuda.is_available() else "cpu")# 打印使用的设备dev = "cuda" if torch.cuda.is_available() else "cpu"print(f'当前模型训练设备为: {dev}')# 初始化Adam优化器和交叉熵损失函数optimizer = torch.optim.Adam(model.parameters(), lr=0.001)criterion = nn.CrossEntropyLoss()# 将模型移动到选定的设备上model = model.to(device)# 复制模型权重用于后续更新最佳模型best_model_wts = copy.deepcopy(model.state_dict())best_acc = 0.0  # 初始化最佳准确度# 初始化用于记录训练和验证过程中损失和准确度的列表train_loss_all = []val_loss_all = []train_acc_all = []val_acc_all = []# 记录训练开始时间start_time = time.time()# 迭代指定的训练轮数for epoch in range(1, num_epochs + 1):# 记录每个epoch开始的时间since = time.time()# 打印分隔符和当前epoch信息print("-" * 10)print(f"Epoch: {epoch}/{num_epochs}")# 初始化训练和验证过程中的损失和正确预测数量train_loss = 0.0train_corrects = 0val_loss = 0.0val_corrects = 0# 初始化批次计数器train_num = 0val_num = 0# 创建训练进度条progress_train_bar = tqdm(total=len(train_dataloader), desc=f'Training {epoch}', unit='batch')# 训练数据集的遍历for step, (b_x, b_y) in enumerate(train_dataloader):# 将数据移动到相应的设备上b_x = b_x.to(device)b_y = b_y.to(device)# 训练模型model.train()# 前向传播output = model(b_x)# 计算预测标签pre_label = torch.argmax(output, dim=1)# 计算损失loss = criterion(output, b_y)# 清空梯度optimizer.zero_grad()# 反向传播loss.backward()# 更新权重optimizer.step()# 累加损失和正确预测数量train_loss += loss.item() * b_x.size(0)train_corrects += torch.sum(pre_label == b_y.data)# 更新批次计数器train_num += b_x.size(0)# 更新训练进度条progress_train_bar.update(1)# 关闭训练进度条progress_train_bar.close()# 创建验证进度条progress_val_bar = tqdm(total=len(val_dataloader), desc=f'Validation {epoch}', unit='batch')# 验证数据集的遍历for step, (b_x, b_y) in enumerate(val_dataloader):# 将数据移动到相应的设备上b_x = b_x.to(device)b_y = b_y.to(device)# 评估模型model.eval()# 前向传播output = model(b_x)# 计算预测标签pre_label = torch.argmax(output, dim=1)# 计算损失loss = criterion(output, b_y)# 累加损失和正确预测数量val_loss += loss.item() * b_x.size(0)val_corrects += torch.sum(pre_label == b_y.data)# 更新批次计数器val_num += b_x.size(0)# 更新验证进度条progress_val_bar.update(1)# 关闭验证进度条progress_val_bar.close()# 计算并记录epoch的平均损失和准确度train_loss_all.append(train_loss / train_num)train_acc_all.append(train_corrects.double().item() / train_num)val_loss_all.append(val_loss / val_num)val_acc_all.append(val_corrects.double().item() / val_num)# 打印训练和验证的损失与准确度print(f'{epoch} Train Loss: {train_loss_all[-1]:.4f} Train Acc: {train_acc_all[-1]:.4f}')print(f'{epoch} Val Loss: {val_loss_all[-1]:.4f} Val Acc: {val_acc_all[-1]:.4f}')# 计算并打印epoch训练耗费的时间time_use = time.time() - sinceprint(f'第 {epoch} 个 epoch 训练耗费时间: {time_use // 60:.0f}m {time_use % 60:.0f}s')# 若当前epoch的验证准确度为最佳,则更新最佳模型权重if val_acc_all[-1] > best_acc:best_acc = val_acc_all[-1]best_model_wts = copy.deepcopy(model.state_dict())# 训练结束,保存最佳模型权重torch.save(best_model_wts, 'D:/Pycharm/deepl/LeNet/weight/best_model.pth')# 如果当前epoch为总epoch数,则保存最终模型权重if epoch == num_epochs:torch.save(model.state_dict(), f'D:/Pycharm/deepl/LeNet/weight/{num_epochs}_model.pth')# 将训练过程中的统计数据整理成DataFrametrain_process = pd.DataFrame(data={"epoch": range(1, num_epochs + 1),"train_loss_all": train_loss_all,"val_loss_all": val_loss_all,"train_acc_all": train_acc_all,"val_acc_all": val_acc_all})# 打印总训练时间consume_time = time.time() - start_timeprint(f'总耗时:{consume_time // 60:.0f}m {consume_time % 60:.0f}s')# 返回包含训练过程统计数据的DataFramereturn train_process# 定义绘制训练和验证过程中损失与准确度的函数
def matplot_acc_loss(train_process):# 创建图形和子图plt.figure(figsize=(12, 4))# 绘制训练和验证损失plt.subplot(1, 2, 1)plt.plot(train_process["epoch"], train_process["train_loss_all"], 'ro-', label="train_loss")plt.plot(train_process["epoch"], train_process["val_loss_all"], 'bs-', label="val_loss")plt.legend()plt.xlabel("epoch")plt.ylabel("loss")# 保存损失图像plt.savefig('./result_picture/training_loss_accuracy.png', bbox_inches='tight')# 绘制训练和验证准确度plt.subplot(1, 2, 2)plt.plot(train_process["epoch"], train_process["train_acc_all"], 'ro-', label="train_acc")plt.plot(train_process["epoch"], train_process["val_acc_all"], 'bs-', label="val_acc")plt.legend()plt.xlabel("epoch")plt.ylabel("accuracy")# 保存准确率曲线图plt.savefig('./result_picture/training_accuracy.png', bbox_inches='tight')plt.show()if __name__ == "__main__":model = LeNet()train_dataloader, val_dataloader = train_val_data_process()train_process = train_model_process(model, train_dataloader, val_dataloader, num_epochs=20)matplot_acc_loss(train_process)

四、model_test.py

import torch
import torch.utils.data as Data
from torchvision import transforms
from torchvision.datasets import FashionMNIST
from model import LeNet
from sklearn.metrics import confusion_matrix
import seaborn as sns
import matplotlib.pyplot as plt
# t代表testdef t_data_process():test_data = FashionMNIST(root="./data",train=False,transform=transforms.Compose([transforms.Resize(size=28), transforms.ToTensor()]),download=True)test_dataloader = Data.DataLoader(dataset=test_data,batch_size=1,shuffle=True,num_workers=0)return test_dataloaderdef t_model_process(model, test_dataloader):if model is not None:print('Successfully loaded the model.')device = "cuda" if torch.cuda.is_available() else "cpu"model = model.to(device)# 初始化参数test_corrects = 0.0test_num = 0all_preds = []  # 存储所有预测标签all_labels = []  # 存储所有实际标签# 只进行前向传播,不计算梯度with torch.no_grad():for test_x, test_y in test_dataloader:test_x = test_x.to(device)test_y = test_y.to(device)# 设置模型为验证模式model.eval()# 前向传播得到一个batch的结果output = model(test_x)# 查找最大值对应的行标pre_lab = torch.argmax(output, dim=1)# 收集预测和实际标签all_preds.extend(pre_lab.tolist())all_labels.extend(test_y.tolist())# 计算准确率test_corrects += torch.sum(pre_lab == test_y.data)# 将所有的测试样本进行累加test_num += test_x.size(0)# 计算准确率test_acc = test_corrects.double().item() / test_numprint(f'测试的准确率:{test_acc}')# 绘制混淆矩阵conf_matrix = confusion_matrix(all_labels, all_preds)sns.heatmap(conf_matrix, annot=True, fmt='d', cmap='Blues')plt.xlabel('Predicted labels')plt.ylabel('True labels')plt.title('Confusion Matrix')plt.show()plt.savefig('./result_picture/Confusion_Matrix.png', bbox_inches='tight')if __name__=="__main__":# 加载模型model = LeNet()print('loading model')# 加载权重model.load_state_dict(torch.load('D:/Pycharm/deepl/LeNet/weight/best_model.pth'))# 加载测试数据test_dataloader = t_data_process()# 加载模型测试的函数t_model_process(model,test_dataloader)device = "cuda" if torch.cuda.is_available() else "cpu"model = model.to(device)classes = ['T-shirt/top','Trouser','Pullover','Dress','coat','Sandal','Shirt','Sneaker','Bag','Ankle boot']with torch.no_grad():for b_x,b_y in test_dataloader:b_x = b_x.to(device)b_y = b_y.to(device)model.eval()output = model(b_x)pre_lab = torch.argmax(output,dim=1)result = pre_lab.item()label = b_y.item()print(f'预测值:{classes[result]}',"-----------",f'真实值:{classes[label]}')

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.rhkb.cn/news/376206.html

如若内容造成侵权/违法违规/事实不符,请联系长河编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

一文了解MySQL的表级锁

文章目录 ☃️概述☃️表级锁❄️❄️介绍❄️❄️表锁❄️❄️元数据锁❄️❄️意向锁⛷️⛷️⛷️ 介绍 ☃️概述 锁是计算机协调多个进程或线程并发访问某一资源的机制。在数据库中,除传统的计算资源(CPU、RAM、I/O)的争用以外&#xff0…

【深度学习基础】MacOS PyCharm连接远程服务器

目录 一、需求描述二、建立与服务器的远程连接1. 新版Pycharm的界面有什么不同?2. 创建远程连接3. 建立本地项目与远程服务器项目之间的路径映射4.设置保存自动上传文件 三、设置解释器总结 写在前面,本人用的是Macbook Pro, M3 MAX处理器&am…

开发个人Ollama-Chat--6 OpenUI

开发个人Ollama-Chat–6 OpenUI Open-webui Open WebUI 是一种可扩展、功能丰富且用户友好的自托管 WebUI,旨在完全离线运行。它支持各种 LLM 运行器,包括 Ollama 和 OpenAI 兼容的 API。 功能 由于总所周知的原由,OpenAI 的接口需要密钥才…

创建地形——笔记

1、创建地面 (1) 3D Object-Terrain (2) 导入资源 (3) 选中Terrain,绘制贴图 (4) 新建一个沙土层 (5) 编辑沙土层——选中Inspector中的新建沙土层,出现编辑面板 依次点击Nomal Map和Mask Map右侧的Slect,增加法线贴图(紫&…

Run LoongArch64 Alpine VM on x86_64

一、Build from source(build on x86_64) Obtain the latest libvirt, virt-manager, and qemu source code, compile and install them. 1.1 Build libvirt from source sudo apt-get update sudo apt-get install augeas-tools bash-completion debhelper-compat dh-apparm…

深入理解FFmpeg--libavformat接口使用(一)

libavformat(lavf)是一个用于处理各种媒体容器格式的库。它的主要两个目的是去复用(即将媒体文件拆分为组件流)和复用的反向过程(以指定的容器格式写入提供的数据)。它还有一个I/O模块,支持多种…

加密与安全_密钥体系的三个核心目标之完整性解决方案

文章目录 Pre机密性完整性1. 哈希函数(Hash Function)定义特征常见算法应用散列函数常用场景散列函数无法解决的问题 2. 消息认证码(MAC)概述定义常见算法工作原理如何使用 MACMAC 的问题 不可否认性数字签名(Digital …

Objective-C 自定义渐变色Slider

文章目录 一、前情概要二、具体实现 一、前情概要 系统提供UISlider,但在开发过程中经常需要自定义,本次需求内容是实现一个拥有渐变色的滑动条,且渐变色随着手指touch的位置不同改变区域,类似如下 可以使用CAGradientLayer实现渐…

SpringBoot相关

SpringBoot 1. what springboot也是spring公司开发的一款框架。为了简化spring项目的初始化搭建的。 spring项目搭建的缺点: 配置麻烦依赖繁多tomcat启动慢 2 .springboot的特点(why) 自动配置 springboot的自动配置是一个运行时(更准确地说,是应用程…

基于JAVA+SpringBoot+Vue+uniApp小程序的心理健康测试平台

✌全网粉丝20W,csdn特邀作者、博客专家、CSDN新星计划导师、java领域优质创作者,博客之星、掘金/华为云/阿里云/InfoQ等平台优质作者、专注于Java技术领域和毕业项目实战✌ 🍅文末获取项目下载方式🍅 一、项目背景介绍: 该系统由三个核心角色…

数据库MySQL下载安装

MySQL下载安装地址如下: MySQL :: Download MySQL Community Server 1、下载界面 2、点击下载 3、解压记住目录 4、配置my.ini文件 未完..

走进linux

1、为什么要使用linux 稳定性和可靠性: Linux内核以其稳定性而闻名,能够持续运行数月甚至数年而不需要重新启动。这对于服务器来说至关重要,因为它们需要保持长时间的稳定运行,以提供持续的服务 安全性: Linux系统…

【python算法学习1】用递归和循环分别写下 fibonacci 斐波拉契数列,比较差异

问题: fibonacci 斐波拉契数列,用递归和循环的方法分别写,比较递归和循环的思路和写法的差别 最直接的思路,是写递归方法 循环方法的稍微有点绕,我觉得问题主要是出在,总结循环的通项公式更麻烦,难在数学…

《Linux系统编程篇》vim的使用 ——基础篇

引言 上节课我们讲了,如何将虚拟机的用户目录映射到自己windows的z盘,虽然这样之后我们可以用自己的编译器比如说Visual Studio Code,或者其他方式去操作里面的文件,但是这是可搭建的情况下,在一些特殊情况下&#xf…

【深度学习基础】MAC pycharm 专业版安装与激活

文章目录 一、pycharm专业版安装二、激活 一、pycharm专业版安装 PyCharm是一款专为Python开发者设计的集成开发环境(IDE),旨在帮助用户在使用Python语言开发时提高效率。以下是对PyCharm软件的详细介绍,包括其作用和主要功能&…

力扣-排序算法

排序算法,一般都可以使用std::sort()来快速排序。 这里介绍一些相关的算法,巩固记忆。 快速排序 跟二分查找有一丢丢像。 首先选择一个基准元素,一般就直接选择第一个。然后两个指针&#xff0c…

使用python获取城市经纬度以及城市间的距离、火车时间、所需成本等

这里写自定义目录标题 1 获取城市地理坐标2 获取交通数据3 数据存储4 代码整合 本案例研究选择了中国的五个中心城市(上海市、深圳市、北京市、广州市、杭州市)和25个边境城市(如巴彦淖尔市、白山市等)作为研究对象。通过调用高德…

Go泛型详解

引子 如果我们要写一个函数分别比较2个整数和浮点数的大小&#xff0c;我们就要写2个函数。如下&#xff1a; func Min(x, y float64) float64 {if x < y {return x}return y }func MinInt(x, y int) int {if x < y {return x}return y }2个函数&#xff0c;除了数据类…

vue实现a-model弹窗拖拽移动

通过自定义拖拽指令实现 实现效果 拖动顶部&#xff0c;可对整个弹窗实施拖拽&#xff08;如果需要拖动底部、中间内容实现拖拽&#xff0c;把下面的ant-modal-header对应改掉就行&#xff09; 代码实现 编写自定义指令 新建一个ts / js文件&#xff0c;用ts举例 import V…

前端的页面代码

根据老师教的前端页面的知识&#xff0c;加上我也是借鉴了老师上课所说的代码&#xff0c;马马虎虎的写出了页面。如下代码&#xff1a; <!DOCTYPE html> <html lang"en"> <head><meta charset"UTF-8"><title>Title</ti…