第64步 深度学习图像识别:多分类建模误判病例分析(Pytorch)

基于WIN10的64位系统演示

一、写在前面

上期我们基于TensorFlow环境介绍了多分类建模的误判病例分析。

本期以健康组、肺结核组、COVID-19组、细菌性(病毒性)肺炎组为数据集,基于Pytorch环境,构建SqueezeNet多分类模型,分析误判病例,因为它建模速度快。

同样,基于GPT-4辅助编程。

二、误判病例分析实战

使用胸片的数据集:肺结核病人和健康人的胸片的识别。其中,健康人900张,肺结核病人700张,COVID-19病人549张、细菌性(病毒性)肺炎组900张,分别存入单独的文件夹中。

直接分享代码:

######################################导入包###################################
import copy
import torch
import torchvision
import torchvision.transforms as transforms
from torchvision import models
from torch.utils.data import DataLoader
from torch import optim, nn
from torch.optim import lr_scheduler
import os
import matplotlib.pyplot as plt
import warnings
import numpy as npwarnings.filterwarnings("ignore")
plt.rcParams['font.sans-serif'] = ['SimHei']
plt.rcParams['axes.unicode_minus'] = False# 设置GPU
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")################################导入数据集#####################################
from torchvision import datasets, transforms
from torch.nn.functional import softmax
from PIL import Image
import pandas as pd
import torch.nn as nn
import timm
from torch.optim import lr_scheduler# 自定义的数据集类
class ImageFolderWithPaths(datasets.ImageFolder):def __getitem__(self, index):original_tuple = super(ImageFolderWithPaths, self).__getitem__(index)path = self.imgs[index][0]tuple_with_path = (original_tuple + (path,))return tuple_with_path# 数据集路径
data_dir = "./MTB-1"# 图像的大小
img_height = 256
img_width = 256# 数据预处理
data_transforms = {'train': transforms.Compose([transforms.RandomResizedCrop(img_height),transforms.RandomHorizontalFlip(),transforms.RandomVerticalFlip(),transforms.RandomRotation(0.2),transforms.ToTensor(),transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])]),'val': transforms.Compose([transforms.Resize((img_height, img_width)),transforms.ToTensor(),transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])]),
}# 加载数据集
full_dataset = ImageFolderWithPaths(data_dir, transform=data_transforms['train'])# 获取数据集的大小
full_size = len(full_dataset)
train_size = int(0.8 * full_size)  # 假设训练集占70%
val_size = full_size - train_size  # 验证集的大小# 随机分割数据集
torch.manual_seed(0)  # 设置随机种子以确保结果可重复
train_dataset, val_dataset = torch.utils.data.random_split(full_dataset, [train_size, val_size])# 应用数据增强到训练集和验证集
train_dataset.dataset.transform = data_transforms['train']
val_dataset.dataset.transform = data_transforms['val']# 创建数据加载器
batch_size = 8
train_dataloader = torch.utils.data.DataLoader(train_dataset, batch_size=batch_size, shuffle=True, num_workers=0)
val_dataloader = torch.utils.data.DataLoader(val_dataset, batch_size=batch_size, shuffle=True, num_workers=0)dataloaders = {'train': train_dataloader, 'val': val_dataloader}
dataset_sizes = {'train': len(train_dataset), 'val': len(val_dataset)}
class_names = full_dataset.classes# 获取数据集的类别
class_names = full_dataset.classes# 保存预测结果的列表
results = []###############################定义SqueezeNet模型################################
# 定义SqueezeNet模型
model = models.squeezenet1_1(pretrained=True)  # 这里以SqueezeNet 1.1版本为例
num_ftrs = model.classifier[1].in_channels# 根据分类任务修改最后一层
# 这里我们改变模型的输出层为4,因为我们做的是四分类
model.classifier[1] = nn.Conv2d(num_ftrs, 4, kernel_size=(1,1))# 修改模型最后的输出层为我们需要的类别数
model.num_classes = 4model = model.to(device)# 打印模型摘要
print(model)#############################编译模型#########################################
# 定义损失函数
criterion = nn.CrossEntropyLoss()# 定义优化器
optimizer = torch.optim.Adam(model.parameters())# 定义学习率调度器
exp_lr_scheduler = lr_scheduler.StepLR(optimizer, step_size=7, gamma=0.1)# 开始训练模型
num_epochs = 20# 初始化记录器
train_loss_history = []
train_acc_history = []
val_loss_history = []
val_acc_history = []for epoch in range(num_epochs):print('Epoch {}/{}'.format(epoch, num_epochs - 1))print('-' * 10)# 每个epoch都有一个训练和验证阶段for phase in ['train', 'val']:if phase == 'train':model.train()  # 设置模型为训练模式else:model.eval()   # 设置模型为评估模式running_loss = 0.0running_corrects = 0# 遍历数据for inputs, labels, paths in dataloaders[phase]:inputs = inputs.to(device)labels = labels.to(device)# 零参数梯度optimizer.zero_grad()# 前向with torch.set_grad_enabled(phase == 'train'):outputs = model(inputs)_, preds = torch.max(outputs, 1)loss = criterion(outputs, labels)# 只在训练模式下进行反向和优化if phase == 'train':loss.backward()optimizer.step()# 统计running_loss += loss.item() * inputs.size(0)running_corrects += torch.sum(preds == labels.data)epoch_loss = running_loss / dataset_sizes[phase]epoch_acc = (running_corrects.double() / dataset_sizes[phase]).item()# 记录每个epoch的loss和accuracyif phase == 'train':train_loss_history.append(epoch_loss)train_acc_history.append(epoch_acc)else:val_loss_history.append(epoch_loss)val_acc_history.append(epoch_acc)print('{} Loss: {:.4f} Acc: {:.4f}'.format(phase, epoch_loss, epoch_acc))print()# 保存模型
torch.save(model.state_dict(), 'SqueezeNet_model-m-s.pth')# 加载最佳模型权重
#model.load_state_dict(best_model_wts)
#torch.save(model, 'shufflenet_best_model.pth')
#print("The trained model has been saved.")
###########################误判病例分析#################################
import os
import pandas as pd
from collections import defaultdict# 判定组别的字典
group_dict = {("COVID-19", "Normal"): "B",("COVID-19", "Pneumonia"): "C",("COVID-19", "Tuberculosis"): "D",("Normal", "COVID-19"): "E",("Normal", "Pneumonia"): "F",("Normal", "Tuberculosis"): "G",("Pneumonia", "COVID-19"): "H",("Pneumonia", "Normal"): "I",("Pneumonia", "Tuberculosis"): "J",("Tuberculosis", "COVID-19"): "K",("Tuberculosis", "Normal"): "L",("Tuberculosis", "Pneumonia"): "M",
}# 创建一个字典来保存所有的图片信息
image_predictions = {}# 循环遍历所有数据集(训练集和验证集)
for phase in ['train', 'val']:# 设置模型的状态model.eval()# 遍历数据for inputs, labels, paths in dataloaders[phase]:inputs = inputs.to(device)labels = labels.to(device)# 计算模型的输出with torch.no_grad():outputs = model(inputs)_, preds = torch.max(outputs, 1)# 循环遍历每一个批次的结果for path, pred in zip(paths, preds):# 提取图片的类别actual_class = os.path.split(os.path.dirname(path))[-1] # 提取图片的名称image_name = os.path.basename(path)# 获取预测的类别predicted_class = class_names[pred]# 判断预测的分组类型if actual_class == predicted_class:group_type = 'A'elif (actual_class, predicted_class) in group_dict:group_type = group_dict[(actual_class, predicted_class)]else:group_type = 'Other'  # 如果没有匹配的条件,可以归类为其他# 保存到字典中image_predictions[image_name] = [phase, actual_class, predicted_class, group_type]# 将字典转换为DataFrame
df = pd.DataFrame.from_dict(image_predictions, orient='index', columns=['Dataset Type', 'Actual Class', 'Predicted Class', 'Group Type'])# 保存到CSV文件中
df.to_csv('result-m-s.csv')

四、改写过程

先说策略:首先,先把二分类的误判病例分析代码改成四分类的;其次,用咒语让GPT-4帮我们续写代码已达到误判病例分析。

提供咒语如下:

①改写{代码1},改变成4分类的建模。代码1为:{XXX};

在{代码1}的基础上改写代码,达到下面要求:

(1)首先,提取出所有图片的“原始图片的名称”、“属于训练集还是验证集”、“预测为分组类型”;文件的路劲格式为:例如,“MTB-1\Normal\XXX.png”属于Normal,“MTB-1\COVID-19\XXX.jpg”属于COVID-19,“MTB-1\Pneumonia\XXX.jpeg”属于Pneumonia,“MTB-1\Tuberculosis\XXX.png”属于Tuberculosis;

(2)其次,根据样本预测结果,把样本分为以下若干组:(a)预测正确的图片,全部判定为A组;(b)本来就是COVID-19的图片,预测为Normal,判定为B组;(c)本来就是COVID-19的图片,预测为Pneumonia,判定为C组;(d)本来就是COVID-19的图片,预测为Tuberculosis,判定为D组;(e)本来就是Normal的图片,预测为COVID-19,判定为E组;(f)本来就是Normal的图片,预测为Pneumonia,判定为F组;(g)本来就是Normal的图片,预测为Tuberculosis,判定为G组;(h)本来就是Pneumonia的图片,预测为COVID-19,判定为H组;(i)本来就是Pneumonia的图片,预测为Normal,判定为I组;(j)本来就是Pneumonia的图片,预测为Tuberculosis,判定为J组;(k)本来就是Tuberculosis的图片,预测为COVID-19,判定为H组;(l)本来就是Tuberculosis的图片,预测为Normal,判定为I组;(m)本来就是Tuberculosis的图片,预测为Pneumonia,判定为J组;

(3)居于以上计算的结果,生成一个名为result-m.csv表格文件。列名分别为:“原始图片的名称”、“属于训练集还是验证集”、“预测为分组类型”、“判定的组别”。其中,“原始图片的名称”为所有图片的图片名称;“属于训练集还是验证集”为这个图片属于训练集还是验证集;“预测为分组类型”为模型预测该样本是哪一个分组;“判定的组别”为根据步骤(2)判定的组别,从A到J一共十组选择一个。

(4)需要把所有的图片都进行上面操作,注意是所有图片,而不只是一个批次的图片。

代码1为:{XXX}

③还需要根据报错做一些调整即可,自行调整。

最后,看看结果:

模型只运行了2次,所以效果很差哈,全部是预测成了COVID-19。

四、数据

链接:https://pan.baidu.com/s/1rqu15KAUxjNBaWYfEmPwgQ?pwd=xfyn

提取码:xfyn

五、结语

深度学习图像分类的教程到此结束,洋洋洒洒29篇,涉及到的算法和技巧也够发一篇SCI了。当然,图像识别还有图像分割和目标识别两块内容,就放到最后再说了。下一趴,我们来介绍时间序列建模!!!

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.rhkb.cn/news/126487.html

如若内容造成侵权/违法违规/事实不符,请联系长河编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

【漏洞复现】EnjoySCM存在文件上传漏洞

漏洞描述 EnjoySCM是一款适应于零售企业的供应链管理软件,主要为零售企业的供应商提供服务。EnjoySCM的目的是通过信息技术,实现供应商和零售企业的快速、高效、准确的信息沟通、管理信息交流。。 该系统存在任意文件上传漏洞,攻击者通过漏洞可以获取服务器的敏感信息。 …

【监控系统】Promethus整合Alertmanager监控告警邮件通知

【监控系统】Promethus整合Alertmanager监控告警邮件通知 Alertmanager是一种开源软件,用于管理和报警监视警报。它与Prometheus紧密集成,后者是一种流行的开源监视和警报系统。Alertmanager从多个源接收警报和通知,并根据一组配置规则来决定…

【Python】环境的搭建

前言 要想能够进行 Python 开发, 就需要搭建好 Python 的环境. 需要安装的环境主要是两个部分: 运行环境: Python开发环境: PyCharm 一、安装 Python 1.找到官方网站 官网:Welcome to Python.org 2.找到下载页面 点击download中的Windows 3.选择稳定版中的Win…

2023年MySQL实战核心技术第一篇

目录 四 . 基础架构:一条SQl查询语句是如何执行的? 4.1 MySQL逻辑架构图: 4.2 MySQL的Server层和存储引擎层 4.2.1 连接器 4.2.1.1 解释 4.2.1.2 MySQL 异常重启 解决方案: 4.2.1.2.1. 定期断开长连接: 4.2.1.2.2. 初始…

滑动谜题 -- BFS

滑动谜题 输入:board [[4,1,2],[5,0,3]] 输出:5 解释: 最少完成谜板的最少移动次数是 5 , 一种移动路径: 尚未移动: [[4,1,2],[5,0,3]] 移动 1 次: [[4,1,2],[0,5,3]] 移动 2 次: [[0,1,2],[4,5,3]] 移动 3 次: [[1,0,2],[4,5,3]…

计算机网络的故事——了解Web及网络基础

了解Web及网络基础 文章目录 了解Web及网络基础一、使用 HTTP 协议访问 Web二、HTTP 的诞生三、网络基础 TCP/IP四、与 HTTP 关系密切的协议 : IP、TCP 和 DNS 一、使用 HTTP 协议访问 Web 根据Web浏览器指定的URL,从对应的服务器中获取文件资源,从而显…

Java发送(QQ)邮箱、验证码发送

前言 使用Java应用程序发送 E-mail 十分简单,但是首先需要在项目中导入 JavaMail API 和Java Activation Framework (JAF) 的jar包。 菜鸟教程提供的下载链接: JavaMail mail.jar 1.4.5JAF(版本 1.1.1) activation.jar 1、准备…

Mojo-SDK详细安装教程

Mojo-SDK安装 运行环境:windows11wsl2(ubuntu1804) 截至20230909,windows,mac系统暂时不支持 step1: Install VS Code, the WSL extension, and the Mojo extension. step2: Install Ubuntu 22.04 for WSL and open it. step…

openpnp - 底部相机高级矫正后,底部相机看不清吸嘴的解决方法

文章目录 openpnp - 底部相机高级矫正后,底部相机看不清吸嘴的解决方法概述解决思路备注补充 - 新问题 - N1吸嘴到底部相机十字中心的位置差了很多END openpnp - 底部相机高级矫正后,底部相机看不清吸嘴的解决方法 概述 自从用openpnp后, 无论版本(dev/test), 都发现一个大概…

C++多态案例-设计计算器类

1.前置知识点 多态是面向对象的三大特性之一 多态分为两类 静态多态:函数重载和运算符重载都属于静态多态,复用函数名动态多态:派生类和虚函数实现运行时多态 静态多态和动态多态的区别 静态多态的函数地址早绑定-----编译阶段确定函数地…

【JVM】垃圾收集算法

文章目录 分代收集理论标记-清除算法标记-复制算法标记-整理算法 分代收集理论 当前商业虚拟机的垃圾收集器,大多数都遵循了“分代收集”(Generational Collection)[1]的理论进 行设计,分代收集名为理论,实质是一套符…

学会用命令行创建uni-app项目并用vscode开放项目

(创作不易,感谢有你,你的支持,就是我前行的最大动力,如果看完对你有帮助,请留下您的足迹) 目录 创建 uni-app 项目 命令行创建 uni-app 项目 编译和运行 uni-app 项目: 用 VS Code 开发 uni…

pytest笔记2: fixture

1. fixture 通常是对测试方法和测试函数,测试类整个测试文件进行初始化或是还原测试环境 # 功能函数 def multiply(a, b):return a * b # ------------ fixture---------------def setup_module(module):print("setup_module 在当前文件中所有测试用例之前&q…

项目(智慧教室)第三部分,人机交互在stm32上的实现

一。使用软件 1.stm32cubemx中针对汉字提供的软件 2.对数据进行处理 2.上面点击ok--》这里选择确定 3.这里选择保存即可由字符库,但是需要占用内存太大,需35M,但是stm32只有几百k,所以需要自己删减。 生成中文字符(用…

微服务-OpenFeign基本使用

一、前言 二、OpenFeign基本使用 1、OpenFeign简介 OpenFeign是一种声明式、模板化的HTTP客户端,它使得调用RESTful网络服务变得简单。在Spring Cloud中使用OpenFeign,可以做到像调用本地方法一样使用HTTP请求访问远程服务,开发者无需关注…

调教 文心一言 生成 AI绘画 提示词(Midjourney)

文章目录 第一步第二步第三步第四步第五步第六步第七步第八步 文心一言支持连续对话 我瞎玩的非专业哈哈 第一步 你好,今天我们要用扩散模型创建图像。我会给你提供一些信息。行吗? 第二步 这是Midjourney的工作原理:Midjourney是另一个基于ai的工具,能…

微服务-sentinel详解

文章目录 一、前言二、知识点主要构成1、sentinel基本概念1.1、资源1.2、规则 2、sentinel的基本功能2.1、流量控制2.2、熔断降级 3、控制台安装3.1、官网下载jar包3.2、启动控制台 4、项目集成 sentinel4.1、依赖配置4.2、配置文件中配置sentinel控制台地址信息4.3、配置流控4…

JVM:JIT实时编译器

一、相关 ⾼级编程语⾔按照程序的执⾏⽅式分为两种 编译型:一次性将代码编译为机器码解释型:通过解释器一句一句的将代码解释为机器码之后,再运行。每个语句都是执行的时候才翻译。 JAVA代码执行过程 (编译阶段)首先将…

必须收藏 | 如何完全卸载ArcGIS

好多小伙伴在卸载ArcGIS过程都遇到了卸载不彻底无法重新安装新版本,卸载残留的注册表找不到等一系列问题,今天小编为大家整理了几个如何完全卸载ArcGIS的方法,希望能够帮到大家! #1快捷版 1、开始>控制面板>添加删除程序&…

Flink实时计算中台Kubernates功能改造点

背景 平台为数据开发人员提供基本的实时作业的管理功能,其中包括jar、sql等作业的在线开发;因此中台需要提供一个统一的SDK支持平台能够实现flink jar作业的发布;绝大多数情况下企业可能会考虑Flink On Yarn的这个发布模式,但是伴随云原生的呼声越来越大,一些企业不希望部…