基于卷积神经网络的图像二分类检测模型训练与推理实现教程 | 幽络源

前言

对于本教程,说白了,就是期望能通过一个程序判断一张图片是否为某个物体,或者说判断一张图片是否为某个缺陷。因为本教程是针对二分类问题,因此主要处理 是 与 不是 的问题,比如我的模型是判断一张图片是否为苹果,那么拿一张图片给模型去推理,他会得出这张图是苹果的概率,如果概率大于0.5(这个概率在0~1之间),那么就判断为是苹果。

教程内容

使用了Python的 TensorFlow 和 Keras 库 构建卷积神经网络来完成二分类模型训练,以及使用模型完成对一张图片的推理。原文链接:基于卷积神经网络的图像二分类检测模型训练与推理实现教程 | 幽络源

大致步骤

1.确定环境与库

2.准备数据集并且划分

3.数据集的命名问题注意事项

4.编写训练代码完成模型训练

5.编写推理代码

6.测试二分类检测结果

7.根据结果优化数据集

步骤1.确定环境与库

Python环境是必备的,我这里所使用的Python版本为3.12.3

其次还需要以下库,依次执行如下命令即可

pip install tensorflow
pip install pillow
pip install scipy

如图

1

2

步骤2.准备数据集并且划分

我这里以判断图片是否为冲沟缺陷 来准备数据集,首先创建数据集的目录结构,结构如下

data/train/true_sample/ false_sample/  val/true_sample/false_sample/

QQ_1734065732662

目录解释:

data:作为数据集的根目录

train和val分别为训练集、验证集目录

true_sample:正类样本,也就是我这里需要把含有冲沟缺陷的图放到这个目录

false_sample:负类样本,也就是这里需要将不含有冲沟缺陷的图片放进这个目录

如图,我向train和val的true_sample目录加入了一些含有冲沟缺陷的图片

3

对于负类样本,也不是无脑的只要不是冲沟就往里面放,而是放置你认为训练出的模型可能会将什么识别为正类样本。比如滑坡和冲沟其实是有联系的,但不完全等同于,所以我需要将滑坡相关的,但是没有冲沟情况的图片放入false_sample中,期望模型不要误判。再比如一个苹果,你可能需要把红色气球作为父类样本,防止模型将红气球判断为是苹果,如图是我的负类样本

4

步骤3.数据集的命名问题注意事项

关于数据集的命名,这里其实有一个坑,但是先说避免坑的做法:就像步骤2一样,你的正类样本所放置的目录命名为true_sample、负类样本所放置的目录命名为false_sample就行了。(如果看不懂下面的解释,按照这里做法做就是了)

然后我来解释下是什么坑,对于这个二分类模型训练,训练出来的模型,无非是识别 是 与 不是 的问题,但是模型怎么区分我的哪个目录放置的为是,哪个目录放置的为不是呢,步骤4会给出训练代码,训练代码中的加载数据集时有一行如下代码

class_mode='binary'  # 二分类(冲沟缺陷 vs. 非冲沟缺陷)

这表示我们要做二分类模型训练,加上这行代码,在加载数据集时,Keras 会自动将这些文件夹的名称作为标签,分别命名为1 和 0,如果被命名为标签1 的目录,则在推理时,概率越接近于1,则越表示是标为1的目录的样本,反之概率越接近于0,则越表示是标为0的目录的样本。而keras自动命名标签1和0时是根据目录名首字母的顺序来的字,字母靠前的标为0,后者为1,true_sample的首字母为t,false_sample的首字母为f,因此false_sample标为0,true_sample标为1,这是符合我们的正常预期的。

反面例子:

如果我把正类样本放置于名为defect的目录,负类样本放置于no_defect目录会怎样呢,按照如上解释,defect目录会被标为0,no_defect目录会被标为1,这就和我们预期相反了,什么意思呢。我把正类样本放置defect目录中,其推理结果将会是越接近0,则越表示为正类了,因此这里特别需要注意(如果你要自定义目录名的话)。

步骤4.编写训练代码完成模型训练

先直接上训练代码

from tensorflow.keras import layers, models
from tensorflow.keras.preprocessing.image import ImageDataGenerator
import tensorflow as tftrain_dir='data/train'
val_dir='data/val'# 设置图像的尺寸和批量大小,不用改,保持150是最平衡的
IMG_HEIGHT = 150
IMG_WIDTH = 150
BATCH_SIZE = 12# 数据预处理与增强
train_datagen = ImageDataGenerator(rescale=1./255,  # 将像素值归一化到 [0, 1] 区间shear_range=0.2,zoom_range=0.2,horizontal_flip=True
)validation_datagen = ImageDataGenerator(rescale=1./255)# 加载训练和验证数据
train_generator = train_datagen.flow_from_directory(train_dir,  # 训练数据目录target_size=(IMG_HEIGHT, IMG_WIDTH),  # 图像尺寸batch_size=BATCH_SIZE,class_mode='binary'  # 二分类(冲沟缺陷 vs. 非冲沟缺陷)
)train_class_labels = train_generator.class_indices
print("训练集自动标签映射关系为:"+str(train_class_labels))validation_generator = validation_datagen.flow_from_directory(val_dir,  # 验证数据目录target_size=(IMG_HEIGHT, IMG_WIDTH),batch_size=BATCH_SIZE,class_mode='binary'
)val_class_labels = validation_generator.class_indices
print("测试集自动标签映射关系为:"+str(val_class_labels))# 将数据生成器转换为 tf.data.Dataset 并应用 repeat() 方法
train_dataset = tf.data.Dataset.from_generator(lambda: train_generator,output_signature=(tf.TensorSpec(shape=(None, IMG_HEIGHT, IMG_WIDTH, 3), dtype=tf.float32),tf.TensorSpec(shape=(None,), dtype=tf.int32))
)
train_dataset = train_dataset.repeat()  # 确保数据重复validation_dataset = tf.data.Dataset.from_generator(lambda: validation_generator,output_signature=(tf.TensorSpec(shape=(None, IMG_HEIGHT, IMG_WIDTH, 3), dtype=tf.float32),tf.TensorSpec(shape=(None,), dtype=tf.int32))
)
validation_dataset = validation_dataset.repeat()  # 确保数据重复# 构建模型
model = models.Sequential([layers.InputLayer(shape=(IMG_HEIGHT, IMG_WIDTH, 3)),  # 添加 Input 层layers.Conv2D(32, (3, 3), activation='relu'),layers.MaxPooling2D((2, 2)),layers.Conv2D(64, (3, 3), activation='relu'),layers.MaxPooling2D((2, 2)),layers.Conv2D(128, (3, 3), activation='relu'),layers.MaxPooling2D((2, 2)),layers.Flatten(),layers.Dense(128, activation='relu'),layers.Dense(1, activation='sigmoid')  # 输出层,二分类问题
])# 编译模型
model.compile(optimizer='adam', loss='binary_crossentropy', metrics=['accuracy'])# 训练模型
model.fit(train_dataset,steps_per_epoch=train_generator.samples // BATCH_SIZE,epochs=30,validation_data=validation_dataset,validation_steps=validation_generator.samples // BATCH_SIZE
)# 保存模型
model.save('defect_detector_model.keras')  # 使用 .keras 格式保存模型

使用这段代码训练数据集你唯一需要注意的是保持代码文件于数据集文件在同一目录,或者使用绝对路径,如图

QQ_1734070943655

我们启动训练代码,可以看到控制台在按照规定的轮次30在训练中,而且可以看到我在训练代码中加入了输出标签映射关系来确保正类与负类的映射关系正确,如图

QQ_1734071390702

训练后,你会得到一个名为defect_detector_nodel.keras的文件,推理时会使用该模型进行推理

步骤5.编写推理代码

代码如下:

import os
from tensorflow.keras.models import load_model
import numpy as np
from tensorflow.keras.preprocessing import image# 加载训练好的模型
model = load_model('defect_detector_model.keras')  # 注意加载的是 .keras 格式# 设置输入图像的目标尺寸(与训练时相同)
IMG_HEIGHT = 150
IMG_WIDTH = 150# 定义函数来加载并预测图像
def predict_image(img_path):# 加载图像并进行预处理img = image.load_img(img_path, target_size=(IMG_HEIGHT, IMG_WIDTH))img_array = image.img_to_array(img)  # 将图像转换为数组img_array = np.expand_dims(img_array, axis=0)  # 扩展维度,成为一个 batchimg_array = img_array / 255.0  # 归一化处理(与训练时一致)# 预测图像类别prediction = model.predict(img_array)  # 返回的是一个包含概率的数组return prediction[0][0]  # 提取预测的概率值picPath=r"测试图.jpg"
confidence = predict_image(picPath)
print("有冲沟缺陷的概率为:"+str(confidence))

这段推理代码中,我们加载了刚才训练出的模型,然后使用了一张名为测试图.jpg的图片来进行推理,然后输出他有缺陷的概率

步骤6.测试二分类检测结果

我这里就不用一张图片来测试了,我这里指定一个目录,进行整个目录来测试里面的图片,还是附上我这个推理代码吧

import os
from tensorflow.keras.models import load_model
import numpy as np
from tensorflow.keras.preprocessing import image# 加载训练好的模型
model = load_model('defect_detector_model.keras')  # 注意加载的是 .keras 格式# 设置输入图像的目标尺寸(与训练时相同)
IMG_HEIGHT = 150
IMG_WIDTH = 150# 定义函数来加载并预测图像
def predict_image(img_path):# 加载图像并进行预处理img = image.load_img(img_path, target_size=(IMG_HEIGHT, IMG_WIDTH))img_array = image.img_to_array(img)  # 将图像转换为数组img_array = np.expand_dims(img_array, axis=0)  # 扩展维度,成为一个 batchimg_array = img_array / 255.0  # 归一化处理(与训练时一致)# 预测图像类别prediction = model.predict(img_array)  # 返回的是一个包含概率的数组return prediction[0][0]  # 提取预测的概率值# 测试目录,包含要进行推理的图像
testDir = r"D:\virtualTemp\pythonProject\CNN分类检测\data\train\true_sample"
pics = os.listdir(testDir)
# 遍历目录中的所有图片并进行预测
for pic in pics:picPath = os.path.join(testDir, pic)  # 获取图片的完整路径# 获取预测结果的置信度confidence = predict_image(picPath)# 输出图像的置信度和类别print(f"{pic} 置信度: {confidence:.4f}, 预测结果: {'有缺陷' if confidence >= 0.5 else '无缺陷'}")

我先使用正类样本来测试,先看看拿训练的数据如何,然后再用另外的图片来测试

结果如下图,正类样本中只有一张图判定为了无冲沟,但是我正类样本中其实都应当是冲沟,而我有101张图,因此这里正确率为99.009%

QQ_1734071615033

拿训练的数据来说话可能没有说服力,现在我使用爬图器来批量的爬取一些图片,需要的可以这里拿=> 幽络源爬图器

如图我爬取了3轮桥梁破损图,2轮冲沟地貌图,对于冲沟图,最好是手动删一些莫名奇妙的图,便于验证

QQ_1734072068792

QQ_1734072170259

ok,然后先测试桥梁破损,如果足够符合预期,足够表示模型很好,那么推理出的有缺陷数量应该没有或者很少才对,结果如下

QQ_1734072462049

看起来结果并不好,90张图中,居然有44张判定为了有冲沟缺陷,正确率只有46/90=51.11%,再测试下正类检测呢,如图48张图中只有11张判定为了无,还是不错的。

步骤7.根据结果优化数据集

在步骤6的测试中可知,所训练的模型对正类比较适应,对负类的学习还有所欠缺,处理方法有如下

1.调整判定指标confidence,一般为0.5,可以调大以提高正确率,但是不推荐这么做

2.加大训练轮次

3.训练时的父类样本图片多加一些

ok,方法1我不是很推荐,现在首先加大训练次数到100,然后多爬取一些非冲沟图加入到负类样本之中,当然,桥梁破损的图也放进去一些,然后重新训练获取模型。

训练完后还是按照步骤6中来测试桥梁破损,如图,这一次,90张图中判定为有缺陷的只有7个了,非常不错,正确率提高到了82/90=91.11%

QQ_1734073419635

结语

以上是幽络源的基于卷积神经网络的图像二分类检测模型训练与推理实现教程,对Python、Java感兴趣的小伙伴可加群交流

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.rhkb.cn/news/489946.html

如若内容造成侵权/违法违规/事实不符,请联系长河编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

【razor】echo搭配relay功能分析

echo 要搭配relay 实现作者说relay在linux上跑,可以模拟丢包、延迟目前没看到如何模拟。relay监听9200,有俩作用 echopeer1 发relay,replay 把peer1的包给peer2 ,实现p2p能力。 接收端:采集后发送发给relay的 接收端的地址就是自己,的地址就是本地的9200,因此是让relay接…

轩凯生物被警示,财务内控不规范,华泰证券又被处罚

作者:Tracy 来源:IPO魔女 11月21日,南京轩凯生物科技股份有限公司(简称“轩凯生物”)被交易所下达书面警示的自律监管函。同时其保荐机构华泰联合证券和会计师事务所天衡,均受到监管处罚。这是今年来&…

【C++习题】19.数组中第K个大的元素

题目&#xff1a;数组中第K个大的元素 链接&#x1f517;&#xff1a;数组中第K个大的元素 题目&#xff1a; 代码&#xff1a; class Solution { public:int findKthLargest(vector<int>& nums, int k) {// 将数组中的元素先放入优先级队列中priority_queue<i…

一键学懂BurpSuite(7)

声明&#xff01; 学习视频来自B站up主 泷羽sec 有兴趣的师傅可以关注一下&#xff0c;如涉及侵权马上删除文章&#xff0c;笔记只是方便各位师傅的学习和探讨&#xff0c;文章所提到的网站以及内容&#xff0c;只做学习交流&#xff0c;其他均与本人以及泷羽sec团队无关&#…

工业大数据分析算法实战-day04

文章目录 day04统计分析概率分布参数估计假设检验 统计分布拟合1.基于核函数的非参数方法2. 单概率分布的参数化拟合3. 混合概率分布估计 线性回归模型1. OLS模型&#xff08;普通最小二乘法&#xff09;2. OLS模型检验3. 鲁棒线性回归4. 结构复杂度惩罚&#xff08;正则化&…

【Golang】Go语言编程思想(六):Channel,第四节,Select

使用 Select 如果此时我们有多个 channel&#xff0c;我们想从多个 channel 接收数据&#xff0c;谁来的快先输出谁&#xff0c;此时应该怎么做呢&#xff1f;答案是使用 select&#xff1a; package mainimport "fmt"func main() {var c1, c2 chan int // c1 and …

Python中的OpenCV详解

文章目录 Python中的OpenCV详解一、引言二、OpenCV基础操作1、OpenCV简介2、安装OpenCV3、图像读取与显示 三、图像处理技术1、边缘检测2、滤波技术 四、使用示例1、模板匹配 五、总结 Python中的OpenCV详解 一、引言 在当今数字化社会中&#xff0c;图像处理和计算机视觉技术…

基于python的Selenium webdriver环境搭建(笔记)

一、PyCharm安装配置Selenium环境 本文使用环境&#xff1a;windows11、Python 3.8.1、PyCharm 2019.3.3、Selenium 3.141.0 测试开发环境搭建综述 安装python和pycharm安装浏览器安装selenium安装浏览器驱动测试环境是否正确 这里我们直接从第三步开始 1.1 Seleium安装 …

LLMC:大语言模型压缩工具的开发实践

关注&#xff1a;青稞AI&#xff0c;学习最新AI技术 青稞Talk主页&#xff1a;qingkelab.github.io/talks 大模型的进步&#xff0c;正推动我们向通用人工智能迈进&#xff0c;然而庞大的计算和显存需求限制了其广泛应用。模型量化作为一种压缩技术&#xff0c;虽然可以用来加速…

【渗透测试】信息收集二

其他信息收集 在渗透测试中&#xff0c;历史漏洞信息收集是一项重要的工作&#xff0c;以下是相关介绍&#xff1a; 历史漏洞信息收集的重要性 提高效率&#xff1a;通过收集目标系统或应用程序的历史漏洞信息&#xff0c;可以快速定位可能存在的安全问题&#xff0c;避免重复…

TQ15EG开发板教程:使用SSH登录petalinux

本例程在上一章“创建运行petalinux2019.1”基础上进行&#xff0c;本例程将实现使用SSH登录petalinux。 将上一章生成的BOOT.BIN与imag.ub文件放入到SD卡中启动。给开发板插入电源与串口&#xff0c;注意串口插入后会识别出两个串口号&#xff0c;都需要打开&#xff0c;查看串…

微信小程序5-图片实现点击动作和动态加载同类数据

搜索 微信小程序 “动物觅踪” 观看效果 感谢阅读&#xff0c;初学小白&#xff0c;有错指正。 一、功能描述 a. 原本想通过按钮加载背景图片&#xff0c;来实现一个可以点击的搜索button&#xff0c;但是遇到两个难点&#xff0c;一是按钮大小调整不方便&#xff08;网上搜索…

学习笔记:从ncsi/nc-si协议和代码了解网络协议的设计范式

学习笔记&#xff1a;从ncsi/nc-si协议和代码了解网络协议的设计范式 参考文档&#xff1a; https://www.dmtf.org/standards/published_documents https://www.dmtf.org/dsp/DSP0222 https://www.dmtf.org/sites/default/files/standards/documents/DSP0222_1.2.0.pdf参考代…

3D 生成重建030-SV3D合成环绕视频以生成3D

3D 生成重建030-SV3D合成环绕视频以生成3D 文章目录 0 论文工作1 论文方法2 实验结果 0 论文工作 论文提出了Stable Video 3D (SV3D)——一个用于生成围绕三维物体的高分辨率图像到多视角视频的潜在视频扩散模型。最近关于三维生成的文献提出了将二维生成模型应用于新视图合成…

3D 生成重建035-DiffRF直接生成nerf

3D 生成重建035-DiffRF直接生成nerf 文章目录 0 论文工作1 论文方法2 实验结果 0 论文工作 本文提出了一种基于渲染引导的三维辐射场扩散新方法DiffRF&#xff0c;用于高质量的三维辐射场合成。现有的方法通常难以生成具有细致纹理和几何细节的三维模型&#xff0c;并且容易出…

Spark执行计划解析后是如何触发执行的?

在前一篇Spark SQL 执行计划解析源码分析中&#xff0c;笔者分析了Spark SQL 执行计划的解析&#xff0c;很多文章甚至Spark相关的书籍在讲完执行计划解析之后就开始进入讲解Stage切分和调度Task执行&#xff0c;每个概念之间没有强烈的关联&#xff0c;因此这中间总感觉少了点…

探索Python的魔法工具箱:functools

文章目录 探索Python的魔法工具箱&#xff1a;functools背景库介绍安装简单库函数使用方法lru_cachepartialreducecmp_to_keytotal_ordering 场景应用缓存数据库查询结果固定函数参数计算序列的累积和自动补全比较方法将比较函数转换为key函数 常见Bug及解决方案Bug 1: lru_cac…

leetcode 3266 K次乘运算后的最终数组II 题解

题目大意 原题面 给你一个数组 nums&#xff0c;然后进行 k 轮游戏&#xff0c;每轮游戏都会选择数组当中最小的元素然后乘上一个数 multiplier&#xff08;题目给出&#xff09;&#xff0c;问你 k 轮游戏结束之后&#xff0c;这个数组长什么样子&#xff0c;所有的元素要对 …

事务管理与锁机制

title: 事务管理与锁机制 date: 2024/12/14 updated: 2024/12/14 author: cmdragon excerpt: 在数据库系统中,事务管理至关重要,它确保多个数据库操作能够作为一个单一的逻辑单元来执行,从而维护数据的一致性和完整性。一个良好的事务管理系统能够解决并发操作带来的问题…

各种消息中间件介绍

消息中间件是一种在分布式系统中实现消息传递的软件架构&#xff0c;它允许不同的应用程序或系统组件之间异步地交换信息。 1. Apache Kafka Kafka是一个分布式流处理平台&#xff0c;能够处理高吞吐量的数据。它主要用于构建实时数据管道和流应用程序。 • Broker&#xff1a;…