深度学习笔记_6经典预训练网络LeNet-18解决FashionMNIST数据集

1、 调用模型库,定义参数,做数据预处理

import numpy as np
import torch
from torchvision.datasets import FashionMNIST
import torchvision.transforms as transforms
from torch.utils.data import DataLoader
import torch.nn.functional as F
import torch.optim as optim
from torch import nn
from sklearn.metrics import confusion_matrix, accuracy_score, precision_score, recall_score, f1_score, roc_curve, auc
import matplotlib.pyplot as plt
from torchvision import models# 检查 GPU 可用性
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print("Using device:", device)# 设置超参数
train_batch_size = 64
test_batch_size = 64
learning_rate = 0.001
num_epochs = 50# 定义数据转换操作
transform = transforms.Compose([transforms.RandomRotation(degrees=[-30, 30]),   # 随机旋转transforms.RandomHorizontalFlip(),   # 随机水平翻转transforms.Resize((224, 224)),  # 调整图像大小transforms.ColorJitter(brightness=0.2, contrast=0.2, saturation=0.2, hue=0.2),   # 颜色抖动transforms.ToTensor(),  # 将图像转换为张量transforms.Normalize((0.5,), (0.5,))
])

2、下载FashionMNIST训练集

# 下载FashionMNIST训练集
trainset = FashionMNIST(root='data', train=True,download=True, transform=transform)# 下载FashionMNIST测试集
testset = FashionMNIST(root='data', train=False,download=True, transform=transform)# 创建 DataLoader 对象
train_loader = DataLoader(trainset, batch_size=train_batch_size, shuffle=True)
test_loader = DataLoader(testset, batch_size=test_batch_size, shuffle=False)

3、使用预训练的ResNet-18模型

# 使用预训练的ResNet-18模型
model = models.resnet18(pretrained=True)
# 修改最后一层,使其适应FashionMNIST的输出类别数
model.conv1 = nn.Conv2d(1, 64, kernel_size=7, stride=2, padding=3, bias=False)
model.fc = nn.Linear(model.fc.in_features, 10)
model = model.to(device)# 冻结预训练模型的参数
for param in model.parameters():param.requires_grad = False# 只训练模型的最后一层
for param in model.fc.parameters():param.requires_grad = True
# 初始化优化器和损失函数
optimizer = optim.Adam(model.fc.parameters(), lr=learning_rate)
criterion = nn.CrossEntropyLoss()

4、 训练循环

# 记录训练和测试过程中的损失和准确率
train_losses = []
test_losses = []
train_accuracies = []
test_accuracies = []
conf_matrix_list = []
accuracy_list = []
error_rate_list = []
precision_list = []
recall_list = []
f1_score_list = []
roc_auc_list = []# 训练循环
for epoch in range(num_epochs):model.train()train_loss = 0.0correct = 0total = 0for batch_idx, (data, target) in enumerate(train_loader):optimizer.zero_grad()data, target = data.to(device), target.to(device)  # 将数据移到 GPU 上output = model(data)loss = criterion(output, target)loss.backward()optimizer.step()train_loss += loss.item()# 计算训练准确率_, predicted = output.max(1)total += target.size(0)correct += predicted.eq(target).sum().item()# 计算平均训练损失和训练准确率train_loss /= len(train_loader)train_accuracy = 100. * correct / totaltrain_losses.append(train_loss)train_accuracies.append(train_accuracy)# 测试模型model.eval()test_loss = 0.0correct = 0all_labels = []all_preds = []with torch.no_grad():for data, target in test_loader:data, target = data.to(device), target.to(device)  # 将数据移到 GPU 上output = model(data)test_loss += criterion(output, target).item()pred = output.argmax(dim=1, keepdim=True)correct += pred.eq(target.view_as(pred)).sum().item()all_labels.extend(target.cpu().numpy())  # 将结果移到 CPU 上all_preds.extend(pred.cpu().numpy())  # 将结果移到 CPU 上# 计算平均测试损失和测试准确率test_loss /= len(test_loader)test_accuracy = 100. * correct / len(test_loader.dataset)test_losses.append(test_loss)test_accuracies.append(test_accuracy)# 计算额外的指标conf_matrix = confusion_matrix(all_labels, all_preds)conf_matrix_list.append(conf_matrix)accuracy = accuracy_score(all_labels, all_preds)accuracy_list.append(accuracy)error_rate = 1 - accuracyerror_rate_list.append(error_rate)precision = precision_score(all_labels, all_preds, average='weighted')recall = recall_score(all_labels, all_preds, average='weighted')f1 = f1_score(all_labels, all_preds, average='weighted')precision_list.append(precision)recall_list.append(recall)f1_score_list.append(f1)fpr, tpr, thresholds = roc_curve(all_labels, all_preds, pos_label=1)roc_auc = auc(fpr, tpr)roc_auc_list.append(roc_auc)# 打印每个 epoch 的指标print(f'Epoch [{epoch + 1}/{num_epochs}] -> Train Loss: {train_loss:.4f}, Train Accuracy: {train_accuracy:.2f}%, Test Loss: {test_loss:.4f}, Test Accuracy: {test_accuracy:.2f}%')

5、绘制Loss、Accuracy曲线图, 计算混淆矩阵

import seaborn as sns
# 绘制Loss曲线图
plt.figure()
plt.plot(train_losses, label='Train Loss', color='blue')
plt.plot(test_losses, label='Test Loss', color='red')
plt.xlabel('Epoch')
plt.ylabel('Loss')
plt.legend()
plt.title('Loss Curve')
plt.grid(True)
plt.show()# 绘制Accuracy曲线图
plt.figure()
plt.plot(train_accuracies, label='Train Accuracy', color='red')  # 绘制训练准确率曲线
plt.plot(test_accuracies, label='Test Accuracy', color='green')
plt.xlabel('Epoch')
plt.ylabel('Accuracy')
plt.legend()
plt.title('Accuracy Curve')
plt.grid(True)
plt.show()# 计算混淆矩阵
confusion_mat = confusion_matrix(all_labels, all_preds)
class_labels = [str(i) for i in range(10)]
plt.figure()
sns.heatmap(confusion_mat, annot=True, fmt='d', cmap='Blues', cbar=False)
plt.xlabel('Predicted Labels')
plt.ylabel('True Labels')
plt.title('Confusion Matrix')
plt.savefig('confusion_matrix.png')
plt.show()

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.rhkb.cn/news/219670.html

如若内容造成侵权/违法违规/事实不符,请联系长河编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

Vue学习笔记-Vue3中的customRef

作用 创建一个自定义的ref&#xff0c;并对其依赖项的更新和触发进行显式控制 案例 描述&#xff1a;向输入框中输入内容&#xff0c;在下方延迟1秒展示输入内容 代码&#xff1a; <template><input type"text" v-model"keyword"><h3&…

关于“Python”的核心知识点整理大全24

目录 ​编辑 10.1.6 包含一百万位的大型文件 pi_string.py 10.1.7 圆周率值中包含你的生日吗 10.2 写入文件 10.2.1 写入空文件 write_message.py programming.txt 10.2.2 写入多行 10.2.3 附加到文件 write_message.py programming.txt 10.3 异常 10.3.1 处理 Ze…

数据科学知识库

​ 我的博客是一个技术分享平台&#xff0c;涵盖了机器学习、数据可视化、大数据分析、数学统计学、推荐算法、Linux命令及环境搭建&#xff0c;以及Kafka、Flask、FastAPI、Docker等组件的使用教程。 在这个信息时代&#xff0c;数据已经成为了一种新的资源&#xff0c;而机…

环境搭建及源码运行_java环境搭建_mysql安装

书到用时方恨少、觉知此时要躬行&#xff1b;拥有技术&#xff0c;成就未来&#xff0c;抖音视频教学地址&#xff1a;​​​​​​​​​​​​​​ 1、介绍 MySQL是一个关系型数据库管理系统&#xff0c;由瑞典MySQL AB 公司开发&#xff0c;属于 Oracle旗下产品。MySQL是最…

go原生http开发简易blog(一)项目简介与搭建

文章目录 一、项目简介二、项目搭建前置知识三、首页- - -前端文件与后端结构体定义四、配置文件加载五、构造假数据- - -显示首页内容 代码地址&#xff1a;https://gitee.com/lymgoforIT/goblog 一、项目简介 使用Go原生http开发一个简易的博客系统&#xff0c;包含一下功能…

UE5:Lumen 框架

1.Lumen渲染流程框架 2.Lumen基本概念 2.1 LumenCard & LumenMeshCards LumenMeshCards&#xff1a;一组带有方向性的模型简化代理&#xff0c;视模型复杂度不同可能包含6个及以上数量的LumenCard&#xff1b;用来提供光照采样的位置和方向。 2.2 LumenCardPage & Lu…

Pycharm 如何更改成中文版| Python循环语句| for 和 else 的搭配使用

&#x1f308;write in front&#x1f308; &#x1f9f8;大家好&#xff0c;我是Aileen&#x1f9f8;.希望你看完之后&#xff0c;能对你有所帮助&#xff0c;不足请指正&#xff01;共同学习交流. &#x1f194;本文由Aileen_0v0&#x1f9f8; 原创 CSDN首发&#x1f412; 如…

【Linux】在vim中批量注释与批量取消注释

在vim编辑器中&#xff0c;批量注释和取消注释的操作可以通过进入V-BLOCK模式、选择要注释或取消注释的内容、输入注释符号或选中已有的注释符号和按键完成。这些操作可以大大提高代码或文本的编写和修改效率&#xff0c;是vim编辑器中常用的操作之一。 1.在vim中批量注释的步…

常用网安渗透工具及命令(扫目录、解密爆破、漏洞信息搜索)

目录 dirsearch&#xff1a; dirmap&#xff1a; 输入目标 文件读取 ciphey&#xff08;很强的一个自动解密工具&#xff09;&#xff1a; john(破解密码)&#xff1a; whatweb指纹识别&#xff1a; searchsploit&#xff1a; 例1&#xff1a; 例2&#xff1a; 例3&…

<JavaEE> 网络编程 -- 网络通信基础(协议和协议分层、数据封装和分用)

目录 一、IP地址 1&#xff09;IP地址的概念 2&#xff09;IP地址的格式 二、端口号 1&#xff09;端口号的概念 2&#xff09;端口号的格式 3&#xff09;什么是知名端口号&#xff1f; 三、协议 1&#xff09;协议的概念 2&#xff09;协议的作用 3&#xff09;TC…

【idea】解决sprintboot项目创建遇到的问题

目录 一、报错Plugin ‘org.springframework.boot:spring-boot-maven-plugin:‘ not found 二、报错java: 错误: 无效的源发行版&#xff1a;17 三、java: 无法访问org.springframework.web.bind.annotation.CrossOrigin 四、整合mybatis的时候&#xff0c;报java.lang.Ill…

人工智能中的核心概念

1 概述 人工智能英文缩写为AI&#xff0c;是一种由人制造出来的机器&#xff0c;该机器可以模仿人的思想和行为&#xff0c;从而体现出一种智能的反应。 人工智能的产业链分为基础层、技术层、应用层三个层次。 基础层包括&#xff1a;芯片、大数据、算法系统、网络等多项基础…

基于Tkinter和OpenCV的目标检测程序源码+权重文件,实现摄像头和视频文件的实时目标检测采用YOLOv8模型进行目标检测

基于Tkinter和OpenCV的目标检测程序源码权重文件&#xff0c;实现摄像头和视频文件的实时目标检测采用YOLOv8模型进行目标检测 项目描述 本项目是一个基于Tkinter和OpenCV的目标检测应用程序&#xff0c;实现了摄像头和视频文件的实时目标检测。通过YOLOv8模型进行目标检测&a…

ShuffleNet V1+V2(pytorch)

V1 V1根本思想&#xff1a; 1.GConv替换resnet的普通1*1Conv 2.GConv后加channel shuffle模块 对GConv的不同组进行重新组合。channel_shuffle a是resnet模块&#xff0c;b&#xff0c;c是ShuffleNetV1的block&#xff0c;在V1版中&#xff0c;两模块branch2的第一个1*1卷积…

JavaScript 事件冒泡与捕获机制 --- 带动态图理解

&#xff08;1&#xff09;.事件捕获 从根元素往上传递 --- ---&#xff08;由外到内&#xff09; &#xff08;2&#xff09;.事件冒泡 从元素传递到它的根源素 --- --- &#xff08;由内到外&#xff09; 代码&#xff1a; <!DOCTYPE html> <html lang"en&q…

在 Windows 11/10 上恢复已删除文本文件的 4 种方法

您是否不小心从桌面上删除了文本文件&#xff1f;不用担心。您可以在 Windows 上恢复已删除的文本文件&#xff01; 对于那些有大量工作要做的人来说&#xff0c;便利贴无疑是一种福音。便利贴能够立即记下任何内容并使其可见&#xff0c;而不会占用太多屏幕空间&#xff0c;因…

vue写了这么久了您是否知道:为什么data属性是一个函数而不是一个对象?

一、实例和组件定义data的区别 vue实例的时候定义data属性既可以是一个对象&#xff0c;也可以是一个函数 const app new Vue({el:"#app",// 对象格式data:{foo:"foo"},// 函数格式data(){return {foo:"foo"}} })组件中定义data属性&#xff…

Vue 项目关于在生产环境下调试

前言 开发项目时&#xff0c;在本地调试没问题&#xff0c;但是部署到生产会遇到一些很奇怪的问题&#xff0c;本地又没法调&#xff0c;就需要在生产环境/域名下进行调试。 在这里介绍一个插件Vue force dev ,浏览器扩展里下载 即便是设置了Vue.config.devtoolsfalse 只要安…

【Spark面试】Spark面试题答案

目录 1、spark的有几种部署模式&#xff0c;每种模式特点&#xff1f;&#xff08;☆☆☆☆☆&#xff09; 2、Spark为什么比MapReduce块&#xff1f;&#xff08;☆☆☆☆☆&#xff09; 3、简单说一下hadoop和spark的shuffle相同和差异&#xff1f;&#xff08;☆☆☆☆☆…

音视频技术开发周刊 | 324

每周一期&#xff0c;纵览音视频技术领域的干货。 新闻投稿&#xff1a;contributelivevideostack.com。 467亿参数MoE追平GPT-3.5&#xff01;爆火开源Mixtral模型细节首公开&#xff0c;中杯逼近GPT-4 今天&#xff0c;Mistral AI公布了Mixtral 8x7B的技术细节&#xff0c;不…