【神经网络】基于CNN(卷积神经网络)构建猫狗分类模型

文章目录

    • 解决问题
    • 数据集
    • 探索性数据分析
    • 数据预处理
      • 数据集分割
      • 数据预处理
    • 构建模型并训练
      • 构建模型
      • 训练模型
    • 结果分析与评估
    • 模型保存
    • 结果预测
    • 经验总结

解决问题

针对经典猫狗数据集,基于卷积神经网络,构建猫狗二元分类模型,使用数据集进行参数训练,模型评估,然后使用模型进行分类预测,最后对模型进行保存,供后续使用。

数据集

数据集来源

猫狗数据集

探索性数据分析

查看待训练识别图片

from matplotlib import pyplot as plt
import os
import random# 获取文件名
_,_,cat_images = next(os.walk('../../dataset/kagglecatsanddogs_5340/PetImages/Cat'))# 准备3*3 图表
fig, ax = plt.subplots(3, 3, figsize=(20, 10))
# 随机选择一幅图像并绘制
for idx, img in enumerate(random.sample(cat_images, 9)):img_read = plt.imread('../../dataset/kagglecatsanddogs_5340/PetImages/Cat/' + img)ax[int(idx / 3), idx % 3].imshow(img_read)ax[int(idx / 3), idx % 3].set_title('cat/' + img)ax[int(idx / 3), idx % 3].axis('off')
plt.show()

查看狗图片类似,将Cat目录换成Dog即可

image-20240614224011275

数据预处理

数据集分割

由于下载的图片猫和狗各在一个文件夹内,如下:

image-20240613223710122

需要将数据按80%:20%进行分割,分为训练集和测试集。目录结构如下:

image-20240613223517479

下面进行数据拆分,核心代码(以猫图片为例)如下:

# 训练数据集80% 测试数据集20%
train_size = 0.8
# 获取猫图像数量
_, _, cat_images = next(os.walk(src_folder+'Cat/'))
num_cat_images = len(cat_images)
num_cat_images_train = int(train_size * num_cat_images)
num_cat_images_test = num_cat_images - num_cat_images_train
# 分割猫图像
cat_train_images = random.sample(cat_images, num_cat_images_train)
for img in cat_train_images:shutil.copy(src=src_folder+'Cat/'+img, dst=src_folder+'Train/Cat/')
cat_test_images  = [img for img in cat_images if img not in cat_train_images]
for img in cat_test_images:shutil.copy(src=src_folder+'Cat/'+img, dst=src_folder+'Test/Cat/')

数据预处理

这一步要将分割后的数据集转成和模型结构匹配的数据类型。使用keras提供的ImageDataGenerator类和flow_from_directory()方法

ImageDataGenerator类:图像增强类,可以进行图像旋转、图像平移、水平翻转、图像缩放等操作;

flow_from_directory()方法:ImageDataGenerator类的方法,支持以图像路径为输入,按批次加载图像到内存,防止训练数据量过大,机器内存不足问题;还支持对图像进行预处理操作,例如尺寸缩放和图像增强

# 训练数据预处理
training_data_generator = ImageDataGenerator(rescale=1./255)
training_set = training_data_generator.flow_from_directory('../../dataset/kagglecatsanddogs_5340/PetImages/train/',target_size=(32, 32),batch_size=16,class_mode='binary')# 测试数据预处理
testing_data_generator = ImageDataGenerator(rescale= 1./255)
testing_set = testing_data_generator.flow_from_directory('../../dataset/kagglecatsanddogs_5340/PetImages/test/',target_size=(32, 32),batch_size=16, class_mode='binary')

构建模型并训练

构建模型

# 定义超参数
# 特征滤波器尺寸
FILTER_SIZE = 3
# 特征滤波器数量
FILTER_NUM = 32
# 图片输入尺寸
INPUT_SIZE = 32
# 最大池化尺寸
MAXPOOL_SIZE = 2
# 批量处理图片的大小
BATCH_SIZE = 16
STEPS_PER_EPOCH = 20000 // BATCH_SIZE
# 训练轮次
EPOCHS = 10
# 定义模型
model = Sequential()
# 添加卷积、池化层 提取特征
model.add(Conv2D(FILTER_NUM, (FILTER_SIZE, FILTER_SIZE), input_shape=(INPUT_SIZE, INPUT_SIZE, 3), activation='relu'))
model.add(MaxPooling2D(pool_size=(MAXPOOL_SIZE,MAXPOOL_SIZE)))
# 再添加卷积、池化层 提取特征
model.add(Conv2D(FILTER_NUM, (FILTER_SIZE, FILTER_SIZE), input_shape=(INPUT_SIZE, INPUT_SIZE, 3), activation='relu'))
model.add(MaxPooling2D(pool_size=(MAXPOOL_SIZE,MAXPOOL_SIZE)))
# 对输出结果进行降维处理,转成一维张量
model.add(Flatten())
# 添加全链接层,根据特征进行分类预测
model.add(Dense(units=128, activation='relu'))
# 添加dropout层,随机将一部分输入设置为0,防止模型复杂,出现过拟合现象
model.add(Dropout(0.5))
# 添加输出层,一个节点
model.add(Dense(units=1, activation='sigmoid'))

该模型结构分为,卷积池化层,卷积池化层,Flatten层,全链接层1,全链接层2(输出层)如下:

image-20240614221747896

其中,第一列是神经网络的层,第二列是每层的输出形状,第三层是每层训练的参数

可以看到,该模型图像输入尺寸是(32,32),经过一层卷积(32个特征过滤器)输出为(30,30,32),经过一层最大池化层,输出为(15,15,32);其中特征滤波器尺寸为3*3,所以滤波后的尺寸会是32-(3-1)=30,经过最大池化(2x2)尺寸减半,为15。

训练模型

# 模型训练
model.fit(training_set, steps_per_epoch=STEPS_PER_EPOCH, epochs=EPOCHS, verbose=1)
image-20240614123217943

结果分析与评估

model.evaluate(testing_set,steps=len(testing_set),verbose=1)
image-20240614222619817

准确度达到了0.7856

模型保存

from joblib import dump, load# 模型持久化 到磁盘
dump(model, './猫狗分类.onnx')

结果预测

引入保存模型,随机选取一张图片进行预测分类

from matplotlib import pyplot as plt
fig, ax = plt.subplots()
img = plt.imread('../../dataset/kagglecatsanddogs_5340/PetImages/Dog/6.jpg')
ax.imshow(img)
plt.show()
image-20240614223543720
from joblib import dump, load
model = load('./猫狗分类.onnx')from tensorflow.keras.preprocessing.image import img_to_array,load_imgimg = load_img('../../dataset/kagglecatsanddogs_5340/PetImages/Dog/6.jpg',target_size=(32,32))
img = img_to_array(img)
img /= 255
import numpy as np
img_array = np.expand_dims(img, axis=0)
print(img_array.shape)
model.predict(img_array)

在这里插入图片描述

由于是二元分类,0和1分别表示猫狗,输出概率接近表示是狗,接近0表示是猫狗。但具体为啥0表示猫1表示狗而不是反过来表示,还待研究。

经验总结

1 在使用next()加载图像时,要确保路径正确,否则会报StopIteration错误,原因是路径错误,找不到可迭代的数据。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.rhkb.cn/news/353534.html

如若内容造成侵权/违法违规/事实不符,请联系长河编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

驱动开发(四):Linux内核中断

驱动开发系列文章: 驱动开发(一):驱动代码的基本框架 驱动开发(二):创建字符设备驱动 驱动开发(三):内核层控制硬件层 驱动开发(四&#xf…

【数据结构C++】表达式求值(多位数)课程设计

📚博客主页:Zhui_Yi_ 🔍:上期回顾:图 ❤️感谢大家点赞👍🏻收藏⭐评论✍🏻,您的三连就是我持续更新的动力❤️ 🎇追当今朝天骄,忆顾往昔豪杰。 …

房地产房型展示信息小程序的内容是什么

地产业规模之大且品牌众多,还有房屋租赁、中介等,无论开发商公司还是衍生行业商家都需要多渠道宣传品牌和客户触达沟通转化,除了线下各种传单,线上也是主要场景,通过各种连接来达到相应目标。 也因此需符合平台生态开…

人工智能--自然语言处理NLP概述

欢迎来到 Papicatch的博客 目录 🍉引言 🍈基本概念 🍈核心技术 🍈常用模型和方法 🍈应用领域 🍈挑战和未来发展 🍉案例分析 🍈机器翻译中的BERT模型 🍈情感分析在…

iCopy for Mac 剪切板 粘贴工具 历史记录 安装(保姆级教程,新手小白轻松上手)

Mac分享吧 文章目录 效果可留存文本、图片、文件等复制历史记录也可根据关键字进行历史记录检索点击一下,可复制双击两下,复制内容,并将信息粘贴至鼠标指针处 一、准备工作二、开始安装1、双击运行软件,将其从左侧拖入右侧文件夹…

【包管理】Node.JS与Ptyhon安装

文章目录 Node.JSPtyhon Node.JS Node.js的安装通常包括以下几个步骤: 访问Node.js官网: 打开Node.js的官方网站(如:https://nodejs.org/zh-cn/download/)。 下载安装包: 根据你的操作系统选择对应的Node…

Kotlin编程实践-【Java如何调用Kotlin中带默认值参数的函数】

问题 如果你有一个带有默认参数值的 Kotlin 函数,如何从 Java 调用它而无须为每个参数显式指定值? 方案 为函数添加注解JvmOverloads。 也就是为Java添加重载方法,这样Java调用Kotlin的方法时就不用传递全部的参数了。 示例 在 Kotlin …

python中scrapy

安装环境 pip install scrapy 发现Twisted版本不匹配 卸载pip uninstall Twisted 安装 pip install Twisted22.10.0 新建scrapy项目 scrapy startproject 项目名 注意:项目名称不允许使用数字开头,也不能包含中文 eg: scrapy startproject scrapy_baidu_…

【vue baidu-map】解决更新数据,bm-marker显示不完全问题

实现效果&#xff1a; 问题&#xff1a;切换上面基地tab键&#xff0c;导致地图图标展示不完全&#xff1b;刷新页面就可以正常展示。判断是<bm-marker>标记元素没有动态刷新dom元素引起的问题。 方案&#xff1a;this.$nextTick({}) this.$nextTick(()>{this.equipm…

Python复数的加、减、乘、除运算

一、复数 复数由实部和虚部组成&#xff0c;形如(a,b均为实数)的数为复数&#xff0c;其中&#xff0c;a被称为实部&#xff0c;b被称为虚部&#xff0c;i为虚数单位&#xff0c;。复数通常用z表示&#xff0c;即zabi&#xff0c;当z的虚部b&#xff1d;0时&#xff0c;则z为实…

云电脑有多好用?适合哪些人使用?

云电脑作为一种新型的计算模式&#xff0c;其应用场景广泛且多样&#xff0c;适合各类人群使用。云电脑适合什么人群使用&#xff1f;云电脑有哪些应用场景&#xff1f;有什么好的云电脑推荐&#xff1f;以下本文将详细探讨云电脑的主要应用场景及其适用人群的相关内容&#xf…

怎样搭建serveru ftp个人服务器

首先说说什么是ftp&#xff1f; FTP协议是专门针对在两个系统之间传输大的文件这种应用开发出来的&#xff0c;它是TCP/IP协议的一部分。FTP的意思就是文件传输协议&#xff0c;用来管理TCP/IP网络上大型文件的快速传输。FTP早也是在Unix上开发出来的&#xff0c;并且很长一段…

【Android】基于webView打造富文本编辑器(H5)

目录 前言一、实现效果二、具体实现1. 导入网页资源2. 页面设计3. 功能调用4. 完整代码 总结 前言 HTML5是构建Web内容的一种语言描述方式。HTML5是Web中核心语言HTML的规范&#xff0c;用户使用任何手段进行网页浏览时看到的内容原本都是HTML格式的&#xff0c;在浏览器中通过…

Boosting原理代码实现

1&#xff0e;提升方法是将弱学习算法提升为强学习算法的统计学习方法。在分类学习中&#xff0c;提升方法通过反复修改训练数据的权值分布&#xff0c;构建一系列基本分类器&#xff08;弱分类器&#xff09;&#xff0c;并将这些基本分类器线性组合&#xff0c;构成一个强分类…

什么是拷贝?我:Ctrl + C ...

前言 当谈及拷贝&#xff0c;你的第一印象会不会和我一样&#xff0c;ctrl c ctrl v ... &#xff1b;虽然效果和拷贝是一样的&#xff0c;但是你知道拷贝的原理以及它的实现方法吗&#xff1f;今天就让我们一起探究一下拷贝中深藏的知识点吧。 拷贝 首先来看下面一段代码…

Vue10-实战快速上手

实战快速上手 我们采用实战教学模式并结合ElementUI组件库&#xff0c;将所需知识点应用到实际中&#xff0c;以最快速度带领大家掌握Vue的使用&#xff1b; 1、创建工程 注意&#xff1a;命令行都要使用管理员模式运行 1、创建一个名为hello-vue的工程vue init webpack hel…

iview 组件里面的(任何一个月)整月日期全部选中_iview时间轴选中有历史记录日期

iview 组件里面的整月日期全部选中&#xff1a; ①&#xff1a;第一种是当前月的日期全部选中&#xff1a; 先上效果图&#xff1a;当前月分 获取到的值&#xff1a; 当前月的方法&#xff1a; // getDateStr() {// var curDate new Date();// var curMonth curDate.ge…

【HTML01】HTML基础-基本元素-附带案例-作业

文章目录 HTML 概述学HTML到底学什么HTML的基本结构HTML的注释的作用html的语法HTML的常用标签&#xff1a;相关单词参考资料 HTML 概述 英文全称&#xff1a;Hyper Text Markup Language 中文&#xff1a;超文本标记语言&#xff0c;就将常用的50多个标记嵌入在纯文本中&…

python pytest 参数化的几种方式

一、使用pytest.mark.parametrize装饰器&#xff1a; 可以使用pytest提供的pytest.mark.parametrize装饰器来指定参数化测试的参数。下面是一个示例&#xff1a; import pytest# pytest.mark.parametrize装饰器 # 其中num expected&#xff0c;分别对应(1, 1),(2, 4),(3, 9)&…

c语言中的字符函数

1.字符分类函数 c语言中有一系列函数是专门做字符分类的&#xff0c;也就是一个字符属于什么类型的字符。这些函数的使用需要包含一个头文件是ctype.h 可能你看这些感觉很懵&#xff0c;我以islower举例 #include<ctype.h> int main() {int retislower(A);printf("…