第62步 深度学习图像识别:多分类建模(Pytorch)

基于WIN10的64位系统演示

一、写在前面

上期我们基于TensorFlow环境做了图像识别的多分类任务建模。

本期以健康组、肺结核组、COVID-19组、细菌性(病毒性)肺炎组为数据集,基于Pytorch环境,构建SqueezeNet多分类模型,因为它建模速度快。

同样,基于GPT-4辅助编程,这次改写过程就不展示了。

二、多分类建模实战

使用胸片的数据集:肺结核病人和健康人的胸片的识别。其中,健康人900张,肺结核病人700张,COVID-19病人549张、细菌性(病毒性)肺炎组900张,分别存入单独的文件夹中。

(a)直接分享代码

######################################导入包###################################
# 导入必要的包
import copy
import torch
import torchvision
import torchvision.transforms as transforms
from torchvision import models
from torch.utils.data import DataLoader
from torch import optim, nn
from torch.optim import lr_scheduler
import os
import matplotlib.pyplot as plt
import warnings
import numpy as npwarnings.filterwarnings("ignore")
plt.rcParams['font.sans-serif'] = ['SimHei']
plt.rcParams['axes.unicode_minus'] = False# 设置GPU
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")################################导入数据集#####################################
import torch
from torchvision import datasets, transforms
import os# 数据集路径
data_dir = "./MTB-1"# 图像的大小
img_height = 100
img_width = 100# 数据预处理
data_transforms = {'train': transforms.Compose([transforms.RandomResizedCrop(img_height),transforms.RandomHorizontalFlip(),transforms.RandomVerticalFlip(),transforms.RandomRotation(0.2),transforms.ToTensor(),transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])]),'val': transforms.Compose([transforms.Resize((img_height, img_width)),transforms.ToTensor(),transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])]),
}# 加载数据集
full_dataset = datasets.ImageFolder(data_dir)# 获取数据集的大小
full_size = len(full_dataset)
train_size = int(0.7 * full_size)  # 假设训练集占70%
val_size = full_size - train_size  # 验证集的大小# 随机分割数据集
torch.manual_seed(0)  # 设置随机种子以确保结果可重复
train_dataset, val_dataset = torch.utils.data.random_split(full_dataset, [train_size, val_size])# 将数据增强应用到训练集
train_dataset.dataset.transform = data_transforms['train']# 创建数据加载器
batch_size = 32
train_dataloader = torch.utils.data.DataLoader(train_dataset, batch_size=batch_size, shuffle=True, num_workers=4)
val_dataloader = torch.utils.data.DataLoader(val_dataset, batch_size=batch_size, shuffle=True, num_workers=4)dataloaders = {'train': train_dataloader, 'val': val_dataloader}
dataset_sizes = {'train': len(train_dataset), 'val': len(val_dataset)}
class_names = full_dataset.classes###############################定义SqueezeNet模型################################
# 定义SqueezeNet模型
model = models.squeezenet1_1(pretrained=True)  # 这里以SqueezeNet 1.1版本为例
num_ftrs = model.classifier[1].in_channels# 根据分类任务修改最后一层
model.classifier[1] = nn.Conv2d(num_ftrs, len(class_names), kernel_size=(1,1))# 修改模型最后的输出层为我们需要的类别数
model.num_classes = len(class_names)model = model.to(device)# 打印模型摘要
print(model)#############################编译模型#########################################
# 定义损失函数
criterion = nn.CrossEntropyLoss()# 定义优化器
optimizer = optim.Adam(model.parameters())# 定义学习率调度器
exp_lr_scheduler = lr_scheduler.StepLR(optimizer, step_size=7, gamma=0.1)# 开始训练模型
num_epochs = 50# 初始化记录器
train_loss_history = []
train_acc_history = []
val_loss_history = []
val_acc_history = []for epoch in range(num_epochs):print('Epoch {}/{}'.format(epoch, num_epochs - 1))print('-' * 10)# 每个epoch都有一个训练和验证阶段for phase in ['train', 'val']:if phase == 'train':model.train()  # 设置模型为训练模式else:model.eval()   # 设置模型为评估模式running_loss = 0.0running_corrects = 0# 遍历数据for inputs, labels in dataloaders[phase]:inputs = inputs.to(device)labels = labels.to(device)# 零参数梯度optimizer.zero_grad()# 前向with torch.set_grad_enabled(phase == 'train'):outputs = model(inputs)_, preds = torch.max(outputs, 1)loss = criterion(outputs, labels)# 只在训练模式下进行反向和优化if phase == 'train':loss.backward()optimizer.step()# 统计running_loss += loss.item() * inputs.size(0)running_corrects += torch.sum(preds == labels.data)epoch_loss = running_loss / dataset_sizes[phase]epoch_acc = (running_corrects.double() / dataset_sizes[phase]).item()# 记录每个epoch的loss和accuracyif phase == 'train':train_loss_history.append(epoch_loss)train_acc_history.append(epoch_acc)else:val_loss_history.append(epoch_loss)val_acc_history.append(epoch_acc)print('{} Loss: {:.4f} Acc: {:.4f}'.format(phase, epoch_loss, epoch_acc))print()# 保存模型
torch.save(model.state_dict(), 'model.pth')# 加载最佳模型权重
#model.load_state_dict(best_model_wts)
#torch.save(model, 'shufflenet_best_model.pth')
#print("The trained model has been saved.")
###########################Accuracy和Loss可视化#################################
epoch = range(1, len(train_loss_history)+1)fig, ax = plt.subplots(1, 2, figsize=(10,4))
ax[0].plot(epoch, train_loss_history, label='Train loss')
ax[0].plot(epoch, val_loss_history, label='Validation loss')
ax[0].set_xlabel('Epochs')
ax[0].set_ylabel('Loss')
ax[0].legend()ax[1].plot(epoch, train_acc_history, label='Train acc')
ax[1].plot(epoch, val_acc_history, label='Validation acc')
ax[1].set_xlabel('Epochs')
ax[1].set_ylabel('Accuracy')
ax[1].legend()#plt.savefig("loss-acc.pdf", dpi=300,format="pdf")####################################混淆矩阵可视化#############################
from sklearn.metrics import classification_report, confusion_matrix
import math
import pandas as pd
import numpy as np
import seaborn as sns
from matplotlib.pyplot import imshow# 定义一个绘制混淆矩阵图的函数
def plot_cm(labels, predictions):# 生成混淆矩阵conf_numpy = confusion_matrix(labels, predictions)# 将矩阵转化为 DataFrameconf_df = pd.DataFrame(conf_numpy, index=class_names ,columns=class_names)  plt.figure(figsize=(8,7))sns.heatmap(conf_df, annot=True, fmt="d", cmap="BuPu")plt.title('Confusion matrix',fontsize=15)plt.ylabel('Actual value',fontsize=14)plt.xlabel('Predictive value',fontsize=14)def evaluate_model(model, dataloader, device):model.eval()   # 设置模型为评估模式true_labels = []pred_labels = []# 遍历数据for inputs, labels in dataloader:inputs = inputs.to(device)labels = labels.to(device)# 前向with torch.no_grad():outputs = model(inputs)_, preds = torch.max(outputs, 1)true_labels.extend(labels.cpu().numpy())pred_labels.extend(preds.cpu().numpy())return true_labels, pred_labels# 获取预测和真实标签
true_labels, pred_labels = evaluate_model(model, dataloaders['val'], device)# 计算混淆矩阵
cm_val = confusion_matrix(true_labels, pred_labels)
a_val = cm_val[0,0]
b_val = cm_val[0,1]
c_val = cm_val[1,0]
d_val = cm_val[1,1]# 计算各种性能指标
acc_val = (a_val+d_val)/(a_val+b_val+c_val+d_val)  # 准确率
error_rate_val = 1 - acc_val  # 错误率
sen_val = d_val/(d_val+c_val)  # 灵敏度
sep_val = a_val/(a_val+b_val)  # 特异度
precision_val = d_val/(b_val+d_val)  # 精确度
F1_val = (2*precision_val*sen_val)/(precision_val+sen_val)  # F1值
MCC_val = (d_val*a_val-b_val*c_val) / (np.sqrt((d_val+b_val)*(d_val+c_val)*(a_val+b_val)*(a_val+c_val)))  # 马修斯相关系数# 打印出性能指标
print("验证集的灵敏度为:", sen_val, "验证集的特异度为:", sep_val,"验证集的准确率为:", acc_val, "验证集的错误率为:", error_rate_val,"验证集的精确度为:", precision_val, "验证集的F1为:", F1_val,"验证集的MCC为:", MCC_val)# 绘制混淆矩阵
plot_cm(true_labels, pred_labels)# 获取预测和真实标签
train_true_labels, train_pred_labels = evaluate_model(model, dataloaders['train'], device)
# 计算混淆矩阵
cm_train = confusion_matrix(train_true_labels, train_pred_labels)  
a_train = cm_train[0,0]
b_train = cm_train[0,1]
c_train = cm_train[1,0]
d_train = cm_train[1,1]
acc_train = (a_train+d_train)/(a_train+b_train+c_train+d_train)
error_rate_train = 1 - acc_train
sen_train = d_train/(d_train+c_train)
sep_train = a_train/(a_train+b_train)
precision_train = d_train/(b_train+d_train)
F1_train = (2*precision_train*sen_train)/(precision_train+sen_train)
MCC_train = (d_train*a_train-b_train*c_train) / (math.sqrt((d_train+b_train)*(d_train+c_train)*(a_train+b_train)*(a_train+c_train))) 
print("训练集的灵敏度为:",sen_train, "训练集的特异度为:",sep_train,"训练集的准确率为:",acc_train, "训练集的错误率为:",error_rate_train,"训练集的精确度为:",precision_train, "训练集的F1为:",F1_train,"训练集的MCC为:",MCC_train)# 绘制混淆矩阵
plot_cm(train_true_labels, train_pred_labels)################################模型性能参数计算################################
from sklearn import metricsdef test_accuracy_report(model, dataloader, device):true_labels, pred_labels = evaluate_model(model, dataloader, device)print(metrics.classification_report(true_labels, pred_labels, target_names=class_names)) test_accuracy_report(model, dataloaders['val'], device)def train_accuracy_report(model, dataloader, device):true_labels, pred_labels = evaluate_model(model, dataloader, device)print(metrics.classification_report(true_labels, pred_labels, target_names=class_names)) train_accuracy_report(model, dataloaders['train'], device)################################AUC曲线绘制####################################
from sklearn import metrics
import numpy as np
import matplotlib.pyplot as plt
from matplotlib.pyplot import imshow
from sklearn.metrics import classification_report, confusion_matrix
import seaborn as sns
import pandas as pd
import math
from sklearn.metrics import roc_auc_score, auc
from sklearn.preprocessing import LabelBinarizerdef multiclass_roc_auc_score(y_test, y_pred, average="macro"):# 判断 y_test 是否需要进行标签二值化if len(np.unique(y_test)) > 2:  # 假设 y_test 是类别标签,且类别数大于 2lb = LabelBinarizer()lb.fit(y_test)y_test = lb.transform(y_test)return roc_auc_score(y_test, y_pred, average=average)def plot_roc(name, labels, predictions, **kwargs):lb = LabelBinarizer()labels = lb.fit_transform(labels)  # one-hot 编码# predictions 不需要进行标签二值化# 计算ROC曲线和AUC值fpr = dict()tpr = dict()roc_auc = dict()class_num = len(class_names)for i in range(class_num):  # class_num是类别数目fpr[i], tpr[i], _ = metrics.roc_curve(labels[:, i], predictions[:, i])roc_auc[i] = metrics.auc(fpr[i], tpr[i])for i in range(class_num):plt.plot(fpr[i], tpr[i], label='ROC curve of class {0} (area = {1:0.2f})' ''.format(i, roc_auc[i]))plt.plot([0, 1], [0, 1], 'k--')plt.xlim([0.0, 1.0])plt.ylim([0.0, 1.05])plt.xlabel('False Positive Rate')plt.ylabel('True Positive Rate')plt.title('Receiver operating characteristic example')plt.legend(loc="lower right")plt.show()# 确保模型处于评估模式
model.eval()def evaluate_model_pre(model, data_loader, device):model.eval()predictions = []labels = []with torch.no_grad():for inputs, targets in data_loader:inputs = inputs.to(device)targets = targets.to(device)outputs = model(inputs)# 使用 softmax 函数,转换成概率值prob_outputs = torch.nn.functional.softmax(outputs, dim=1)predictions.append(prob_outputs.detach().cpu().numpy())labels.append(targets.detach().cpu().numpy())return np.concatenate(predictions, axis=0), np.concatenate(labels, axis=0)val_pre_auc, val_label_auc = evaluate_model_pre(model, dataloaders['val'], device)
train_pre_auc, train_label_auc = evaluate_model_pre(model, dataloaders['train'], device)auc_score_val = multiclass_roc_auc_score(val_label_auc, val_pre_auc)
auc_score_train = multiclass_roc_auc_score(train_label_auc, train_pre_auc)plot_roc('validation AUC: {0:.4f}'.format(auc_score_val), val_label_auc, val_pre_auc, color="red", linestyle='--')
plot_roc('training AUC: {0:.4f}'.format(auc_score_train), train_label_auc, train_pre_auc, color="blue", linestyle='--')print("训练集的AUC值为:",auc_score_train, "验证集的AUC值为:",auc_score_val)

(b)输出结果:学习曲线

 (c)输出结果:混淆矩阵

 (d)输出结果:性能参数

 (e)输出结果:ROC曲线

三、数据

链接:https://pan.baidu.com/s/1rqu15KAUxjNBaWYfEmPwgQ?pwd=xfyn

提取码:xfyn

 

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.rhkb.cn/news/111413.html

如若内容造成侵权/违法违规/事实不符,请联系长河编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

Android Activity启动过程一:从Intent到Activity创建

关于作者:CSDN内容合伙人、技术专家, 从零开始做日活千万级APP。 专注于分享各领域原创系列文章 ,擅长java后端、移动开发、人工智能等,希望大家多多支持。 目录 一、概览二、应用内启动源码流程 (startActivity)2.1 startActivit…

ADRV9009子卡 设计原理图:FMCJ450-基于ADRV9009的双收双发射频FMC子卡 便携测试设备

FMCJ450-基于ADRV9009的双收双发射频FMC子卡 一、板卡概述 ADRV9009是一款高集成度射频(RF)、捷变收发器,提供双通道发射器和接收器、集成式频率合成器以及数字信号处理功能。北京太速科技,这款IC具备多样化的高性能和低功耗组合,FMC子…

基于亚马逊云科技无服务器服务快速搭建电商平台——部署篇

受疫情影响消费者习惯发生改变,刺激了全球电商行业的快速发展。除了依托第三方电商平台将产品销售给消费者之外,企业通过品牌官网或者自有电商平台销售商品也是近几年电商领域快速发展的商业模式。独立站电商模式可以进行多方面、全渠道的互联网市场拓展…

Git分布式版本控制系统与github

第四阶段提升 时 间:2023年8月29日 参加人:全班人员 内 容: Git分布式版本控制系统与github 目录 一、案例概述 二、版本控制系统 (一) 本地版本控制 (二)集中化的版本控制系统 &…

DP读书:鲲鹏处理器 架构与编程(十三)操作系统内核与云基础软件

操作系统内核与云基础软件 鲲鹏软件构成硬件特定软件 鲲鹏软件构成硬件特定软件1. Boot Loader2. SBSA 与 SBBR3. UEFI4. ACPI 操作系统内核Linux系统调用Linux进程调度Linux内存管理Linux虚拟文件系统Linux网络子系统Linux进程间通信Linux可加载内核模块Linux设备驱动程序Linu…

Vue 项目性能优化 — 实践指南

前言 Vue 框架通过数据双向绑定和虚拟 DOM 技术,帮我们处理了前端开发中最脏最累的 DOM 操作部分, 我们不再需要去考虑如何操作 DOM 以及如何最高效地操作 DOM;但 Vue 项目中仍然存在项目首屏优化、Webpack 编译配置优化等问题,所…

自然语言处理(四):全局向量的词嵌入(GloVe)

全局向量的词嵌入(GloVe) 全局向量的词嵌入(Global Vectors for Word Representation),通常简称为GloVe,是一种用于将词语映射到连续向量空间的词嵌入方法。它旨在捕捉词语之间的语义关系和语法关系&#…

Python小知识 - Python中的多线程

Python中的多线程 线程是进程中的一个执行单元,是轻量级的进程。一个进程可以创建多个线程,线程之间共享进程的资源,比如内存、文件句柄等。 在Python中,使用threading模块实现线程。 下面的代码创建了两个线程,一个输…

2023.8.29 关于性能测试

目录 什么是性能测试? 性能测试常见术语及其性能测试衡量指标 并发 用户数 响应时间 事务 点击率 吞吐量 思考时间 资源利用率 性能测试分类 基准性能测试 负载性能测试 压力性能测试 可靠性性能测试 性能测试执行流程 什么是性能测试? 性…

实用的面试经验分享:程序员们谈论他们的面试历程

🌷🍁 博主猫头虎 带您 Go to New World.✨🍁 🦄 博客首页——猫头虎的博客🎐 🐳《面试题大全专栏》 文章图文并茂🦕生动形象🦖简单易学!欢迎大家来踩踩~🌺 &a…

docker 学习-- 04 实践2 (lnpmr环境)

docker 学习 系列文章目录 docker 学习-- 01 基础知识 docker 学习-- 02 常用命令 docker 学习-- 03 环境安装 docker 学习-- 04 实践 1(宝塔) docker 学习-- 04 实践 2 (lnpmr环境) 文章目录 docker 学习 系列文章目录1. 配…

财务数据分析?奥威BI数据可视化工具很擅长

BI数据可视化工具通常是可以用户各行各业,用于不同主题的数据可视化分析,但面对财务数据分析这块难啃的骨头,能够好好地完成的,还真不多。接下来要介绍的这款BI数据可视化工具不仅拥有内存行列计算模型这样的智能财务指标计算功能…

flutter 上传图片并裁剪

1.首先在pubspec.yaml文件中新增依赖pub.dev image_picker: ^0.8.75 image_cropper: ^4.0.1 2.在Android的AndroidManifest.xml文件里面添加权限 <activityandroid:name"com.yalantis.ucrop.UCropActivity"android:screenOrientation"portrait"andro…

Django基础7——用户认证系统、Session管理、CSRF安全防护机制

文章目录 一、用户认证系统二、案例&#xff1a;登陆认证2.1 平台登入2.2 平台登出2.3 login_required装饰器 三、Django Session管理3.1 Django使用Session3.1.1 Cookie用法3.1.2 Session用法 3.2 案例&#xff1a;用户登录认证 四、Django CSRF安全防护机制 一、用户认证系统…

3d激光slam建图与定位(2)_aloam代码阅读

1.常用的几种loam算法 aloam 纯激光 lego_loam 纯激光 去除了地面 lio_sam imu激光紧耦合 lvi_sam 激光视觉 2.代码思路 2.1.特征点提取scanRegistration.cpp&#xff0c;这个文件的目的是为了根据曲率提取4种特征点和对当前点云进行预处理 输入是雷达点云话题 输出是 4种特征点…

透过源码理解Flutter InheritedWidget

InheritedWidget的核心是保存值和保存使用这个值的widget&#xff0c;通过对比值的变化&#xff0c;来决定是否要通知那些使用了这个值的widget更新自身。 1 updateShouldNotify和notifyClients InheritedWidget通过updateShouldNotify函数控制依赖其的子组件是否在Inherited…

vue cli构建的项目出现 Uncaught runtime errors

使用 vue/cli 脚手架构建的项目&#xff0c;在 npm run dev 运行后&#xff0c;页面出现 Uncaught runtime errors 报错遮罩层&#xff0c;如下图所示。 报错原因 这种错误通常是运行时出的问题&#xff0c;可能是网络错误&#xff0c;可能是变量未定义等等。 这种错误默认在开…

七、Kafka-Kraft 模式

目录 7.1 Kafka-Kraft 架构7.2 Kafka-Kraft 集群部署 7.1 Kafka-Kraft 架构 左图为 Kafka 现有架构&#xff0c;元数据在 zookeeper 中&#xff0c;运行时动态选举 controller&#xff0c;由controller 进行 Kafka 集群管理 右图为 kraft 模式架构&#xff08;实验性&#xff…

华为数通方向HCIP-DataCom H12-821题库(单选题:161-180)

第161题 以下关于 URPF(Unicast Reverse Path Forwarding) 的描述&#xff0c; 正确的是哪一项 A、部署了严格模式的 URPF&#xff0c;也能够可以同时部署允许匹配缺省路由模式 B、如果部署松散模式的 URPF&#xff0c;默认情况下不需要匹配明细路由 C、如果部署松散模式的…

Java IDEA Web 项目 1、创建

环境&#xff1a; IEDA 版本&#xff1a;2023.2 JDK&#xff1a;1.8 Tomcat&#xff1a;apache-tomcat-9.0.58 maven&#xff1a;尚未研究 自行完成 IDEA、JDK、Tomcat等安装配置。 创建项目&#xff1a; IDEA -> New Project 选择 Jakarta EE Template&#xff1a;选择…