CNN和MobileNetV2搭建的水果识别模型

一、 概述

1. 项目背景

水果是人们日常生活中重要的食品之一,其营养丰富、口感美味、色香俱佳,因此备受广大消费者的喜爱。 然而,在市场上,各种不同品种的水果琳琅满目,对于人类的肉眼识别来说并不容易实现。传统的检测方法需要人工参与,效率低下,成本高昂, 同时还容易出现误判和漏检等问题。基于此,利用计算机视觉技术开发水果识别系统,能够极大地提升水果检测的效率, 减少人工介入,为消费者提供更好的服务体验。

2. 研究意义

随着智能手机、平板电脑等移动设备的广泛应用,人们越来越需要将物理世界和数字世界相结合。在这个环境下,开发一款高效、精准的水果识别系统,有助于优化用户体验,提高生产效率,降低生鲜水果流通损失率,从而推动整个水果行业的数字化转型。

3. 问题定义

本项目旨在研究利用深度学习模型进行水果图像分类的方法,具体包括两个主要任务:一是使用卷积神经网络(CNN)模型进行水果图片的分类,二是探索轻量级神经网络模型MobileNetV2在水果图像分类中的应用。

二、 构建模型

1、 数据样本

使用百度飞桨-公共数据集

https://aistudio.baidu.com/aistudio/datasetdetail/193821

因为30种水果种类过多,不便于后续的热力图生成与结果分析,所以只取其中的15类水果数据,并按照4:1对数据集进行划分为训练集和测试集。 有以下15类:哈密瓜、柠檬、桂圆、梨、榴莲、火龙果、猕猴桃、胡萝卜、芒果、苦瓜、草莓、荔枝、菠萝、车厘子、黄瓜。 数据集中的图片的尺寸大小并不统一,所以在进行模型训练以及验证之前,定义了加载数据集的函数。

def data_load(data_dir, test_data_dir, img_height, img_width, batch_size)

通过传入的img_height, img_width参数,调用TensorFlow函数

def train_ds = tf.keras.preprocessing.image_dataset_from_directory(data_dir,label_mode='categorical',seed=123,image_size=(img_height, img_width),batch_size=batch_size)

将数据集图片全部处理成img_height * img_width的大小即224*224.

2、 模型结构

一是使用卷积神经网络(CNN)模型进行水果图片的分类,二是探索轻量级神经网络模型MobileNetV2在水果图像分类中的应用。

2.1 CNN结构

通过TensorFlow构建CNN模型

模型定义函数如下:

def model_load(IMG_SHAPE=(223, 224, 3), class_num=15):# 搭建模型model = tf.keras.models.Sequential([# 对模型做归一化的处理tf.keras.layers.experimental.preprocessing.Rescaling(1. / 255, input_shape=IMG_SHAPE),# 卷积层tf.keras.layers.Conv2D(32, (3, 3), activation='relu'),# 池化层tf.keras.layers.MaxPooling2D(2, 2),# 卷积层tf.keras.layers.Conv2D(64, (3, 3), activation='relu'),# 池化层tf.keras.layers.MaxPooling2D(2, 2),# 二维输出转化一维tf.keras.layers.Flatten(),tf.keras.layers.Dense(128, activation='relu'),tf.keras.layers.Dense(class_num, activation='softmax')])# 输出模型信息model.summary()opt = tf.keras.optimizers.SGD(learning_rate=0.005)model.compile(optimizer=opt, loss='categorical_crossentropy', metrics=['accuracy'])return model

 模型介绍:

首先将输入的图片进行归一化处理到0~1之间。然后就是两层卷积层,第一层将通道数由3进行升维到32,第二层则由32升维到64,两层卷积层的卷积核大小都是33,默认步长为1。每层卷积之后都使用Max最大池化,大小22,使用默认步长为2.然后通过Flatten将输出展开到一维长度,随后是一个全连接层,输出到128个神经元。激活函数全部使用的是ReLU。最后一层全连接层,输出映射到15个神经元,因为数据集中是15种水果,采用softMax激活函数,用来预测每个类别的概率。

2.2 MobileNetV2结构

MobileNet的基本单元是深度可分离卷积,实质是一种可分解卷积操作。可分为两个更小的操作:Depthwise convolution和Pointwise convoluton。 标准的卷积核DkDkM是对与输入通道数M进行卷积操作,N个卷积核。

而MobileNet的Depthwise是对每个输入通道进行分别的卷积 。因为这属于分组卷积,所以在进行卷积操作以后为了减少信息损失,然后再用pointwise convolution也就是1*1的卷积核进行卷积。

通过Depthwise convolution和Pointwise convoluton深度可分离卷积以后的整体效果和一个标准卷积差不多,但因为是对不同的通道进行分别卷积,相较于常规的对整体所有通道进行卷积,可以显著的减少计算量,通过pointwise convolution又不损失信息不减少精度,速度更快。

而MobileNetV2相较于V1的改进,是使用了反向线性残差结构。

先采用了1 * 1卷积进行了升维,然后采用3 * 3深度可分离卷积进行特征提取,最后用1 * 1卷积进行降维,降维时不采用激活函数。V2比V1的参数量和计算量会更小、准确率会更高。 模型定义函数如下:

def model_load(IMG_SHAPE=(224, 224, 3), class_num=15):#加载预训练的mobilenet模型base_model = tf.keras.applications.MobileNetV2(input_shape=IMG_SHAPE,include_top=False,weights='imagenet')base_model.trainable = Falsemodel = tf.keras.models.Sequential([# 进行归一化的处理tf.keras.layers.experimental.preprocessing.Rescaling(1. / 127.5, offset=-1, input_shape=IMG_SHAPE),# 主干模型base_model,#全局平均池化tf.keras.layers.GlobalAveragePooling2D(),# 全连接层tf.keras.layers.Dense(class_num, activation='softmax')])model.summary()model.compile(optimizer='adam', loss='categorical_crossentropy', metrics=['accuracy'])return model

 模型介绍:

迁移学习调用了在ImageNet上预训练后的MobileNetV2模型,并去除了顶部的全连接层只留下里面的卷积层和池化层,作为我们的主干模型。冻结了主干模型的参数以适应我们后面自己添加的全连接层的训练,可以加快训练速度。

整个模型先进行归一化,映射到-1~1之间,然后通过我们的主干模型,接着是全局平均池化转化为固定长度的向量。然后就是一个全连接层,映射到class_num个神经元上,也就是我们的水果种类的数量15,通过softmax激活函数预测每个水果类别的概率。

三、 实验结果

1、 CNN训练过程及分析

除了定义数据集加载函数data_load和模型构建函数model_load外,还定义了showAccuracyAndLoss(history)用来从history中提取模型训练集和验证集的准确率和误差损失,绘制训练过程中的loss和accuracy曲线图。

def data_load(data_dir, test_data_dir, img_height, img_width, batch_size)
def model_load(IMG_SHAPE=(224, 224, 3), class_num=15)
def show_loss_acc(history)

在train(epochs)函数中调用history = model.fit(train_ds, validation_data=val_ds, epochs=epochs)进行训练,通过model.save("models/cnn_fv.h5")保存为模型文件。

定义了test_cnn()函数通过保存的模型文件对验证集进行验证,并通过showHM绘制heatmap热力图。  先使用sgd随机梯度下降优化器和categorical_crossentropy 多分类交叉熵损失函数,epcoh=10进行训练,默认学习率0.01。

在第10轮时训练集的准确率只有74%,明显训练轮次过少,调整epoch=25重新训练。

 

观察曲线,在18轮以后,训练集的准确率就已经达到了100%,而测试集上的准确率只有50%,随着轮次的增加,测试集上的交叉熵损失值也在增加,发生过拟合。

测试集的热力图表现出来的准确率也比较差。 调整学习率,由0.01降为0.05,其它不变,重新训练。

效果不佳。继续降低学习率为0.001,epoch增加到40,重新训练。

相较于最初的0.01学习率,测试集上的交叉熵损失在2一下,更低了一点,但是测试集上的准确率还是没有得到很大的提高。

改变优化器,使用Adam优化器,epoch=40,其它保持不变,学习率默认0.01

在第10轮时对于训练集的准确率就已经100%,而测试集的交叉熵损失反而达到4以上,比使用sgd优化器时的过拟合更加严重。

尝试调整CNN网络结构。

首尾增加2个卷积层池化层,Flatten展开一维后增加1个全连接层,训练70轮,使用sgd优化器和多分类交叉熵损失函数,效果不理想。

排查原因,首要原因是数据集的问题,对于像荔枝的数据集,在我们这个模型中预测出来的草莓的概率反而比荔枝更高。查看数据集图片发现这个数据集样本量不够大,荔枝只有156张图片,而且有剥开皮的、还没熟透绿色的、照片调色过艳的,类型过杂图片过少造成预测准确率低。

利用tensorFlow的ImageDataGenerator对训练集进行数据增强,加上原来的数据集部分,扩大为原来的5倍。

datagen = ImageDataGenerator(rotation_range=40,  # 随机旋转角度范围width_shift_range=0.2,  # 随机水平平移范围(相对于图片宽度)height_shift_range=0.2,  # 随机竖直平移范围(相对于图片高度)shear_range=0.2,  # 随机裁剪zoom_range=0.2,  # 随机缩放horizontal_flip=True,  # 随机水平翻转vertical_flip=True,  # 随机竖直翻转fill_mode='nearest')  # 填充模式

使用最早定义的CNN网络,epoch=15,sgd优化器和多分类交叉熵损失函数,训练结果如下:

使用原始的测试集测试后的热力图:

对数据增强后的测试集进行测试,测试结果热力图:

2、 MobileNetV2训练过程及分析

使用adam优化器和sgd随机梯度下降优化器和categorical_cr ossentropy多分类交叉熵损失函数,默认学习率0.01,epoch=10进行训练。

 

可以观察到在第6轮训练时,训练集上的准确率就已经达到了100%,而且测试集上的准确率也有90%以上,交叉熵损失达到0.5以下。训练效果非常好。 对原始测试集测试热力图如下:

对数据增强后的测试集,测试热力图如下:

 

相比原测试集,准确率只有几类水果稍微下降。

由此得知,相比较于从头开始训练一个自己的CNN模型,利用迁移学习使用预训练过的MobileNetV2作为主干,利用它在ImageNet上学到的特征,在此基础上进行微调适应自己的数据集,可以显著降低训练时间和成本,大大提高准确度。

四、总结

在本项目中着重探索了利用深度学习模型进行水果图像分类的方法。具体而言包括使用卷积神经网络(CNN)模型进行水果图片的分类和探索轻量级神经网络模型MobileNetV2在水果图像分类中的应用。

在第一项任务中,使用TensorFlow构建了一个简单的CNN模型,并通过调整模型参数来提高准确率。在实验过程中发现由于数据集的问题,训练结果并不理想,测试集上的准确率低于预期,同时出现了过拟合的情况。针对这个问题,从优化器、学习率和训练轮次等方面入手,对模型进行了改进和调整。但是由于数据集本身的局限性,改进效果并不显著。后续对数据集进行数据增强,效果相对右改善。因此使用迁移学习中的MobileNetV2模型进行图像分类。

在第二项任务中,使用预训练的MobileNetV2模型作为主干模型,并对其进行微调以适应自己的数据集。通过这种方法成功地提高了分类准确率。 迁移学习对于解决小规模数据集上的图像分类问题具有重要意义。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.rhkb.cn/news/467280.html

如若内容造成侵权/违法违规/事实不符,请联系长河编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

GEE 图表——ATom:气象测量系统(MMS)的测量数据,2016-2018 年

目录 简介 摘要 代码 引用 网址推荐 知识星球 机器学习 ATom: Measurements from Meteorological Measurement System (MMS), 2016-2018 简介 该数据集包含气象测量系统(MMS)仪器在四次 ATom 活动中的测量数据。 气象测量系统是一种最先进的仪器…

一文解秘Rust如何与Java互操作

本博客所有文章除特别声明外,均采用CC BY-NC-SA 4.0许可协议。转载请注明来自 唯你 使用场景 JAVA 与 Rust 互操作让 Rust 可以背靠 Java 大生态来做更多事情,而 Java 也可以享受 Rust 语言特性的内存安全,所有权机制,无畏并发。…

信息安全工程师(82)操作系统安全概述

一、操作系统安全的概念 操作系统安全是指操作系统在基本功能的基础上增加了安全机制与措施,从而满足安全策略要求,具有相应的安全功能,并符合特定的安全标准。在一定约束条件下,操作系统安全能够抵御常见的网络安全威胁&#xff…

微服务系列四:热更新措施与配置共享

目录 前言 一、基于Nacos的管理中心整体方案 二、配置共享动态维护 2.1 分析哪些配置可拆,需要动态提供哪些参数 2.2 在nacos 分别创建共享配置 创建jdbc相关配置文件 创建日志相关配置文件 创建接口文档配置文件 2.3 拉取本地合并配置文件 2.3.1 拉取出现…

线程级耗时统计工具类TimeWatcher

线程级耗时统计工具类TimeWatcher 先看效果 假设我们的业务代码逻辑是这样的 那么最终的日志打印效果为(注:此为美化输出,也可设置为常规一行输出,还可自定义) 2024-11-08T23:48:53.39008:00 INFO 31472 --- [nio-…

arkUI:Flex弹性布局的各个属性

arkUI:Flex弹性布局的简单使用 1 主要内容说明2 相关内容2.1 Flex弹性布局的方向2.1.1 源码1的简答说明2.1.2 源码1 (Flex弹性布局的方向)2.1.3 源码1运行效果2.1.3.1 当direction: FlexDirection.RowReverse2.1.3.2 当direction: FlexDirect…

详解Gemini API的使用:在国内实现大模型对话与目标检测教程

摘要:本博客介绍了如何利用Gemini API实现多轮对话和图像目标检测识别功能,在Python中快速搭建自己的大模型完成实际任务。通过详细的步骤解析,介绍了如何申请Gemini API密钥,调用API、对话实现的代码,给出了上传图片识…

5G时代已来:我们该如何迎接超高速网络?

内容概要 随着5G技术的普及,我们的生活似乎变得更加“科幻”了。想象一下,未来的智能家居将不仅仅是能够听你说“开灯”;它们可能会主动询问你今天心情如何,甚至会推荐你一杯“维他命C芒果榨汁”,帮助你抵御夏天的炎热…

算法每日练 -- 双指针篇(持续更新中)

介绍: 常见的双指针有两种形式,一种是对撞指针(左右指针),一种是快慢指针(前后指针)。需要注意这里的双指针不是 int* 之类的类型指针,而是使用数组下标模拟地址来进行遍历的方式。 …

理解鸿蒙app 开发中的 context

是什么 Context是应用中对象的上下文,其提供了应用的一些基础信息,例如resourceManager(资源管理)、applicationInfo(当前应用信息)、dir(应用文件路径)、area(文件分区…

贝尔不等式,路径积分与AB(Aharonov-Bohm)效应

贝尔不等式、路径积分与Aharonov-Bohm(AB)效应 这些概念分别源于量子力学不同的理论分支和思想实验,但它们都揭示了量子力学的奇异性质,包括非局域性、相位效应和波粒二象性。以下详细解析每一概念,并探讨其相互联系。…

python 爬虫 入门 六、Selenium

Selenium本来是一个自动测试工具,用于模拟用户对网站进行操作。在爬虫领域也有其用处。 一、下载安装Selenium及附属插件 pip install Selenium 安装完成后还需要安装一个浏览器驱动,来让python能启动浏览器。 如果是Edge或者其他基于Chromium的浏览器…

Linux(CentOS)yum update -y 事故

CentOS版本:CentOS 7 事情经过: 1、安装好CentOS 7,系统自带JDK8,版本为:1.8.0_181 2、安装好JDK17,版本为:17.0.13 3、为了安装MySQL执行了 yum update -y(这个时候不知道该命令的…

uniapp uni-calendar日历实现考勤统计功能

根据日历组件代码结构 构成相应结构的状态统计数据 list 再遍历到每日的子组件中 <view class"uni-calendar__weeks-item" v-for"(weeks,weeksIndex) in item" :key"weeksIndex"><calendar-item class"uni-calendar-item--hook&q…

环境配置与搭建

安装pytorch 官网连链接&#xff1a;https://pytorch.org/ 特殊包名 cv2 pip install opencv-python sklearn pip install scikit-learnPIL pip install Pillow使用jupyter notebook pip install jupyter安装显卡驱动 Windows Linux 视频教程&#xff1a; 【ubuntu2…

【数据库实验一】数据库及数据库中表的建立实验

目录 实验1 学习RDBMS的使用和创建数据库 一、 实验目的 二、实验内容 三、实验环境 四、实验前准备 五、实验步骤 六、实验结果 七、评价分析及心得体会 实验2 定义表和数据库完整性 一、 实验目的 二、实验内容 三、实验环境 四、实验前准备 五、实验步骤 六…

SpringBoot健身房管理:技术与实践

2相关技术 2.1 MYSQL数据库 MySQL是一个真正的多用户、多线程SQL数据库服务器。 是基于SQL的客户/服务器模式的关系数据库管理系统&#xff0c;它的有点有有功能强大、使用简单、管理方便、安全可靠性高、运行速度快、多线程、跨平台性、完全网络化、稳定性等&#xff0c;非常…

CulturalBench :一个旨在评估大型语言模型在全球不同文化背景下知识掌握情况的基准测试数据集

2024-10-04&#xff0c;为了提升大型语言模型在不同文化背景下的实用性&#xff0c;华盛顿大学、艾伦人工智能研究所等机构联合创建了CulturalBench。这个数据集包含1,227个由人类编写和验证的问题&#xff0c;覆盖了包括被边缘化地区在内的45个全球区域。CulturalBench的推出&…

python登录功能实现

一.用python实现基本的登录功能 #-----------------1.基本登录功能------------------- nameinput("qq账号&#xff1a;") if name"jc":passwdinput("密码&#xff1a;")if passwd"123456":print("登录成功")else:print(&q…

如何使用Python管理环境变量

文章目录 📖 介绍 📖🏡 演示环境 🏡📒 环境变量 📒📝 环境变量简介📝 Python 中的环境变量操作📝 获取环境变量📝 设置环境变量🔖 临时设置🔖 永久设置📝 删除环境变量📝 临时删除📝 永久删除📝 小结⚓️ 相关链接 ⚓️📖 介绍 📖 环境变量…