LeNet实验 四分类 与 四分类变为多个二分类

 

目录

1. 划分二分类

2. 训练独立的二分类模型

3. 二分类预测结果代码

4. 二分类预测结果

5 改进训练模型

6 优化后 预测结果代码

 7 优化后预测结果

8 训练四分类模型 

9 预测结果代码

10 四分类结果识别


1. 划分二分类

可以根据不同的类别进行多个划分,以实现NonDemented为例,划分为NonDemented和Demented两类,不属于NonDemented的全都属于Demented

2. 训练独立的二分类模型

import tensorflow as tf
from tensorflow.keras import layers, models
from tensorflow.keras.preprocessing.image import ImageDataGeneratorfrom 文件准备 import data_dir# 数据生成器
train_datagen = ImageDataGenerator(rescale=1./255,shear_range=0.2,zoom_range=0.2,horizontal_flip=True,validation_split=0.2  # 20%用于验证
)train_generator = train_datagen.flow_from_directory(data_dir,target_size=(28, 28),batch_size=32,class_mode='binary',subset='training'
)validation_generator = train_datagen.flow_from_directory(data_dir,target_size=(28, 28),batch_size=32,class_mode='binary',subset='validation'
)# 构建LeNet-5模型
model = models.Sequential()
model.add(layers.Conv2D(6, (5, 5), activation='relu', input_shape=(28, 28, 3), padding='same'))
model.add(layers.AveragePooling2D((2, 2)))
model.add(layers.Conv2D(16, (5, 5), activation='relu', padding='same'))
model.add(layers.AveragePooling2D((2, 2)))
model.add(layers.Conv2D(120, (5, 5), activation='relu', padding='same'))
model.add(layers.Flatten())
model.add(layers.Dense(84, activation='relu'))
model.add(layers.Dense(1, activation='sigmoid'))# 编译模型
model.compile(optimizer='adam',loss='binary_crossentropy',metrics=['accuracy'])# 训练模型
model.fit(train_generator,steps_per_epoch=train_generator.samples // train_generator.batch_size,epochs=10,validation_data=validation_generator,validation_steps=validation_generator.samples // validation_generator.batch_size
)# 保存模型
model.save('lenet_binary_classification_model.h5')

3. 预测结果代码

import tensorflow as tf
from tensorflow.keras.preprocessing import image
import numpy as np
import matplotlib.pyplot as plt# 加载模型
model = tf.keras.models.load_model('lenet_binary_classification_model.h5')# 预处理图像
def preprocess_image(img_path):img = image.load_img(img_path, target_size=(28, 28))img_array = image.img_to_array(img) / 255.0img_array = np.expand_dims(img_array, axis=0)return img_array# 预测图像
img_path = 'D:\Pycharm_workspace\LeNet实验_二分类\Demented\moderateDem24.jpg'  # 测试图像路径
img_array = preprocess_image(img_path)
prediction = model.predict(img_array)
predicted_class = 'Demented' if prediction[0][0] > 0.5 else 'NonDemented'print(f'The predicted class is: {predicted_class}')# 显示图像
img = image.load_img(img_path, target_size=(28, 28))
plt.imshow(img)
plt.title(f'Predicted: {predicted_class}')
plt.show()

4. 预测结果

Demented结果

 NonDemented结果没有。。。。。。

竟然全都没有。。。。因为预测的全部都是Demented

疯狂找原因中

猜测是像素太低使得训练的模型准确率太低

于是重新训练

5 改进训练模型

进行重新训练

import tensorflow as tf
from tensorflow.keras.models import Sequential
from tensorflow.keras.layers import Conv2D, MaxPooling2D, Flatten, Dense
from tensorflow.keras.preprocessing.image import ImageDataGenerator
import matplotlib.pyplot as plt# 定义LeNet模型
def create_lenet_model(input_shape):model = Sequential([Conv2D(6, (5, 5), activation='relu', input_shape=input_shape, padding='same'),MaxPooling2D((2, 2), strides=2),Conv2D(16, (5, 5), activation='relu'),MaxPooling2D((2, 2), strides=2),Flatten(),Dense(120, activation='relu'),Dense(84, activation='relu'),Dense(1, activation='sigmoid')])model.compile(optimizer='adam', loss='binary_crossentropy', metrics=['accuracy'])return model# 数据增强和数据生成器
train_datagen = ImageDataGenerator(rescale=1./255, validation_split=0.2)# 训练数据生成器
train_generator = train_datagen.flow_from_directory('D:\Pycharm_workspace\LeNet实验_二分类\image',target_size=(176, 208),batch_size=32,class_mode='binary',subset='training'
)# 验证数据生成器
validation_generator = train_datagen.flow_from_directory('D:\Pycharm_workspace\LeNet实验_二分类\image',target_size=(176, 208),batch_size=32,class_mode='binary',subset='validation'
)# 创建并训练模型
input_shape = (176, 208, 3)
model = create_lenet_model(input_shape)
history = model.fit(train_generator, epochs=10, validation_data=validation_generator)# 保存模型
model.save('dementia_classification_model.h5')# 绘制训练和验证损失
plt.plot(history.history['loss'], label='训练损失')
plt.plot(history.history['val_loss'], label='验证损失')
plt.title('训练和验证损失')
plt.xlabel('时期')
plt.ylabel('损失')
plt.legend()
plt.show()# 绘制训练和验证准确率
plt.plot(history.history['accuracy'], label='训练准确率')
plt.plot(history.history['val_accuracy'], label='验证准确率')
plt.title('训练和验证准确率')
plt.xlabel('时期')
plt.ylabel('准确率')
plt.legend()
plt.show()

 这里还有图形画loss与准确率但是我忘记保存了,就用控制台的输出

 可以看到loss值非常小而且准确率是100

6 优化后 预测结果代码

import tensorflow as tf
from tensorflow.keras.models import load_model
from tensorflow.keras.preprocessing import image
import numpy as np
import matplotlib.pyplot as plt
import os# 加载模型
model = load_model('dementia_classification_model.h5')# 定义类别标签
class_labels = ['Demented', 'NonDemented']# 预测函数
def predict_image(img_path):img = image.load_img(img_path, target_size=(176, 208))img_array = image.img_to_array(img)img_array = np.expand_dims(img_array, axis=0)img_array /= 255.0prediction = model.predict(img_array)predicted_class = class_labels[int(prediction[0] > 0.5)]# 显示图像和预测结果plt.imshow(image.load_img(img_path))plt.title(f'Predicted: {predicted_class}')plt.axis('off')plt.show()# 预测并展示结果
img_path = r'D:\Pycharm_workspace\LeNet实验_二分类\image\NonDemented\nonDem1.jpg'  # 替换为你的图片路径
predict_image(img_path)

 7 优化后预测结果

 图片与预测结果对应上了(右侧是图片链接可以看到是Dem的类型)

 NonDem的也是对应上了

 就此训练完成

 

8 训练四分类模型 

import tensorflow as tf
from tensorflow.keras.models import Sequential
from tensorflow.keras.layers import Conv2D, MaxPooling2D, Flatten, Dense
from tensorflow.keras.preprocessing.image import ImageDataGenerator
import matplotlib.pyplot as plt# 定义LeNet模型
def create_lenet_model(input_shape):model = Sequential([Conv2D(6, (5, 5), activation='relu', input_shape=input_shape, padding='same'),MaxPooling2D((2, 2), strides=2),Conv2D(16, (5, 5), activation='relu'),MaxPooling2D((2, 2), strides=2),Flatten(),Dense(120, activation='relu'),Dense(84, activation='relu'),Dense(4, activation='softmax')])model.compile(optimizer='adam', loss='categorical_crossentropy', metrics=['accuracy'])return model# 数据增强和数据生成器
train_datagen = ImageDataGenerator(rescale=1./255, validation_split=0.2)# 训练数据生成器
train_generator = train_datagen.flow_from_directory('D:\Pycharm_workspace\LeNet实验_四分类\image',target_size=(176, 208),batch_size=32,class_mode='categorical',subset='training'
)# 验证数据生成器
validation_generator = train_datagen.flow_from_directory('D:\Pycharm_workspace\LeNet实验_四分类\image',target_size=(176, 208),batch_size=32,class_mode='categorical',subset='validation'
)# 创建并训练模型
input_shape = (176, 208, 3)
model = create_lenet_model(input_shape)
history = model.fit(train_generator, epochs=10, validation_data=validation_generator)# 保存模型
model.save('dementia_classification_model.h5')# 绘制训练和验证损失
plt.plot(history.history['loss'], label='训练损失')
plt.plot(history.history['val_loss'], label='验证损失')
plt.title('训练和验证损失')
plt.xlabel('时期')
plt.ylabel('损失')
plt.legend()
plt.show()# 绘制训练和验证准确率
plt.plot(history.history['accuracy'], label='训练准确率')
plt.plot(history.history['val_accuracy'], label='验证准确率')
plt.title('训练和验证准确率')
plt.xlabel('时期')
plt.ylabel('准确率')
plt.legend()
plt.show()

 

loss值与准确率的变化图

可以看到才第四轮准确率就已经很高了 

9 预测结果代码

import tensorflow as tf
from tensorflow.keras.models import load_model
from tensorflow.keras.preprocessing import image
import numpy as np
import matplotlib.pyplot as plt# 加载模型
model = load_model('dementia_classification_model.h5')# 定义类别标签
class_labels = ['MildDemented', 'ModerateDemented', 'NonDemented', 'VeryMildDemented']# 预测函数
def predict_image(img_path):img = image.load_img(img_path, target_size=(176, 208))img_array = image.img_to_array(img)img_array = np.expand_dims(img_array, axis=0)img_array /= 255.0prediction = model.predict(img_array)predicted_class = class_labels[np.argmax(prediction)]# 显示图像和预测结果plt.imshow(image.load_img(img_path))plt.title(f'Predicted: {predicted_class}')plt.axis('off')plt.show()# 预测并展示结果
img_path = r'D:\Pycharm_workspace\LeNet实验_四分类\image\VeryMildDemented\verymildDem0.jpg'  # 你的图片路径
predict_image(img_path)

 

10 四分类结果识别

1 MildDem成功识别(右侧有图片名称)

2 ModerateDem 成功识别

3 NonDem成功识别

 4 VeryMildDem成功识别

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.rhkb.cn/news/380649.html

如若内容造成侵权/违法违规/事实不符,请联系长河编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

新时代多目标优化【数学建模】领域的极致探索——数学规划模型

目录 例1 1.问题重述 2.基本模型 变量定义: 目标函数: 约束条件: 3.模型分析与假设 4.模型求解 5.LINGO代码实现 6.结果解释 ​编辑 7.敏感性分析 8.结果解释 例2 奶制品的销售计划 1.问题重述 ​编辑 2.基本模型 3.模…

Python酷库之旅-第三方库Pandas(036)

目录 一、用法精讲 111、pandas.Series.item方法 111-1、语法 111-2、参数 111-3、功能 111-4、返回值 111-5、说明 111-6、用法 111-6-1、数据准备 111-6-2、代码示例 111-6-3、结果输出 112、pandas.Series.xs方法 112-1、语法 112-2、参数 112-3、功能 112-…

基于Centos7搭建rsyslog服务器

一、配置rsyslog可接收日志 1、准备新的Centos7环境 2、部署lnmp环境 # 安装扩展源 wget -O /etc/yum.repos.d/epel.repo http://mirrors.aliyun.com/repo/epel-7.repo# 安装扩展源 yum install nginx -y# 安装nginx yum install -y php php-devel php-fpm php-mysql php-co…

mac二进制安装operator-sdk

0. 前置条件 1. 安装go 安装步骤略。 1. 下载operator-sdk源码包 https://github.com/operator-framework/operator-sdk 1.1 选择适合当前go版本的operator版本,在operator-sdk/go.mod文件中可以查看Operator-sdk使用的go版本。 2. 编译 源码包下载后&#x…

【从零开始实现stm32无刷电机FOC】【实践】【5/7 stm32 adc外设的高级用法】

目录 采样时刻触发采样同步采样 点击查看本文开源的完整FOC工程 本节介绍的adc外设高级用法用于电机电流控制。 从前面几节可知,电机力矩来自于转子的q轴受磁力,而磁场强度与电流成正比,也就是说电机力矩与q轴电流成正相关,控制了…

UDP网口(1)概述

文章目录 1.计算机网络知识在互联网中的应用2.认识FPGA实现UDP网口通信3.FPGA实现UDP网口通信的方案4.FPGA实现UDP网口文章安排5.传送门 1.计算机网络知识在互联网中的应用 以在浏览器中输入淘宝网为例,介绍数据在互联网是如何传输的。我们将要发送的数据包称作A&a…

SpringAI简单使用(本地模型+自定义知识库)

Ollama 简介 Ollama是一个开源的大型语言模型服务工具,它允许用户在本地机器上构建和运行语言模型,提供了一个简单易用的API来创建、运行和管理模型,同时还提供了丰富的预构建模型库,这些模型可以轻松地应用在多种应用场景中。O…

为 android编译 luajit库、 交叉编译

时间:20200719 本机环境:iMac2017 macOS11.4 参考: 官方的文档:Use the NDK with other build systems 写在前边:交叉编译跟普通编译类似,无非是利用特殊的编译器、链接器生成动态或静态库; make 本质上是按照 Make…

哈默纳科HarmonicDrive减速机组装注意事项

在机械行业中,精密传动设备HarmonicDrive减速机对于维持机械运作的稳定性和高效性起着至关重要的作用。然而在减速机的组装过程中,任何一个细微的错误都可能导致其运转时出现振动、异响等不良现象,严重时甚至可能影响整机的性能。因此&#x…

Python+Django+MySQL的新闻发布管理系统【附源码,运行简单】

PythonDjangoMySQL的新闻发布管理系统【附源码,运行简单】 总览 1、《新闻发布管理系统》1.1 方案设计说明书设计目标工具列表 2、详细设计2.1 登录2.2 程序主页面2.3 新闻新增界面2.4 文章编辑界面2.5 新闻详情页2.7 其他功能贴图 3、下载 总览 自己做的项目&…

Flink调优详解:案例解析(第42天)

系列文章目录 一、Flink-任务参数配置 二、Flink-SQL调优 三、阿里云Flink调优 文章目录 系列文章目录前言一、Flink-任务参数配置1.1 运行时参数1.2 优化器参数1.3 表参数 二、Flink-SQL调优2.1 mini-batch聚合2.2 两阶段聚合2.3 分桶2.4 filter去重(了解&#xf…

代码解读:Diffusion Models中的长宽桶技术(Aspect Ratio Bucketing)

Diffusion Models专栏文章汇总:入门与实战 前言:自从SDXL提出了长宽桶技术之后,彻底解决了不同长宽比的图像输入问题,现在已经成为训练扩散模型必选的方案。这篇博客从代码详细解读如何在模型训练的时候运用长宽桶技术(Aspect Rat…

UNiapp 微信小程序渐变不生效

开始用的一直是这个,调试一直没问题,但是重新启动就没生效,经查询这个不适合小程序使用:不适合没生效 background-image:linear-gradient(to right, #33f38d8a,#6dd5ed00); 正确使用下面这个: 生效,适合…

Python list comprehension (列表推导式 - 列表解析式 - 列表生成式)

Python list comprehension {列表推导式 - 列表解析式 - 列表生成式} 1. Python list comprehension (列表推导式 - 列表解析式 - 列表生成式)2. Example3. ExampleReferences Python 中的列表解析式并不是用来解决全新的问题,只是为解决已有问题提供新的语法。 列…

(10)深入理解pandas的核心数据结构:DataFrame高效数据清洗技巧

目录 前言1. DataFrame数据清洗1.1 处理缺失值(NaNs)1.1.1 数据准备1.1.2 读取数据1.1.3 查找具有 null 值或缺失值的行和列1.1.4 计算每列缺失值的总数1.1.5 删除包含 null 值或缺失值的行1.1.6 利用 .fillna() 方法用Portfolio …

Windows搭建RTMP视频流服务器

参考了一篇文章,见文末。 博客中nginx下载地址失效,附上一个有效的地址: Index of /download/ 另外,在搭建过程中,遇到的问题总结如下: 1 两个压缩包下载解压并重命名后,需要 将nginx-rtmp…

如何使用简鹿水印助手或 Photoshop 给照片添加文字

在社交媒体中,为照片添加个性化的文字已经成为了一种流行趋势。无论是添加注释、引用名言还是表达情感,文字都能够为图片增添额外的意义和风格。本篇文章将使用“简鹿水印助手”和“Adobe Photoshop”这两种工具给照片添加文字的详细步骤。 使用简鹿水印…

【python基础】组合数据类型:元组、列表、集合、映射

文章目录 一. 序列类型1. 元组类型2. 列表类型(list)2.1. 列表创建2.2 列表操作2.3. 列表元素遍历 ing元素列表求平均值删除散的倍数 二. 集合类型(set)三. 映射类型(map)1. 字典创建2. 字典操作3. 字典遍历…

【EI检索】第二届机器视觉、图像处理与影像技术国际会议(MVIPIT 2024)

一、会议信息 大会官网:www.mvipit.org 官方邮箱:mvipit163.com 会议出版:IEEE CPS 出版 会议检索:EI & Scopus 检索 会议地点:河北张家口 会议时间:2024 年 9 月 13 日-9 月 15 日 二、征稿主题…

【香橙派开发板测试】:在黑科技Orange Pi AIpro部署YOLOv8深度学习纤维分割检测模型

文章目录 🚀🚀🚀前言一、1️⃣ Orange Pi AIpro开发板相关介绍1.1 🎓 核心配置1.2 ✨开发板接口详情图1.3 ⭐️开箱展示 二、2️⃣配置开发板详细教程2.1 🎓 烧录镜像系统2.2 ✨配置网络2.3 ⭐️使用SSH连接主板 三、…