深度学习——图像分类(CNN)—训练模型

训练模型

    • 1.导入必要的库
    • 2.定义超参数
    • 3.读取训练和测试标签CSV文件
    • 4.确保标签是字符串类型
    • 5.显示两个数据框的前几行以了解它们的结构
    • 6.定义图像处理参数
    • 7.创建图像数据生成器
    • 8.设置目录路径
    • 9.创建训练和验证数据生成器
    • 10.构建模型
    • 11.编译模型
    • 12.训练模型并收集历史
    • 13.绘制损失和准确率曲线
    • 14.保存图表
    • 15.保存模型到本地

1.导入必要的库

pandas as pd: Pandas是一个强大的数据分析和处理库,它提供了数据结构(如DataFrame)和工具,用于数据操作和分析。
tensorflow.keras.preprocessing.image import ImageDataGenerator: ImageDataGenerator是Keras的一部分,它用于图像数据的预处理和增强,例如,随机裁剪、旋转、缩放等。
tensorflow.keras.models import Sequential: Sequential模型是Keras中的一种模型,它允许您顺序地堆叠层。
tensorflow.keras.layers: 包含了Keras中所有的层类型,如Conv2D、MaxPooling2D、Flatten、Dense等。
tensorflow.keras.optimizers: 包含了Keras中所有的优化器类型,如Adam、SGD等。
sklearn.model_selection import train_test_split: train_test_split是Scikit-Learn的一部分,它用于将数据集分割为训练集和测试集。
numpy as np: NumPy是一个用于科学计算的库,它提供了高效的数组处理能力,对于图像处理等任务非常有用。
sklearn.preprocessing import LabelBinarizer: LabelBinarizer是Scikit-Learn的一部分,它用于将类别标签转换为二进制数组。
matplotlib.pyplot as plt: Matplotlib是一个绘图库,pyplot是其中的一个模块,它提供了一个类似于MATLAB的绘图框架。
import pickle: pickle是Python的标准库,它用于序列化Python对象,以便将它们保存到文件或从文件中加载。

import pandas as pd
from tensorflow.keras.preprocessing.image import ImageDataGenerator
from tensorflow.keras.models import Sequential
from tensorflow.keras.layers import Conv2D, MaxPooling2D, Flatten, Dense, Dropout
from tensorflow.keras.optimizers import Adam
from sklearn.model_selection import train_test_split
import numpy as np
from sklearn.preprocessing import LabelBinarizer
import matplotlib.pyplot as plt
import pickle

2.定义超参数

INIT_LR = 0.01
EPOCHS = 30
BS = 32

3.读取训练和测试标签CSV文件

train_labels.csv和test_labels.csv在资源中。

# 读取训练标签CSV文件
train_labels_filename = 'train_labels.csv'
train_labels_df = pd.read_csv(train_labels_filename)# 读取测试标签CSV文件
test_labels_filename = 'test_labels.csv'
test_labels_df = pd.read_csv(test_labels_filename)

4.确保标签是字符串类型

train_labels_df[‘label’] = train_labels_df[‘label’].astype(str):

train_labels_df['label']:这是train_labels_df DataFrame中名为label的列。
.astype(str):这是Pandas中的一个方法,用于将列的数据类型转换为字符串类型。

test_labels_df[‘label’] = test_labels_df[‘label’].astype(str):

test_labels_df['label']:这是test_labels_df DataFrame中名为label的列。
.astype(str):这是Pandas中的一个方法,用于将列的数据类型转换为字符串类型。

train_labels_df['label'] = train_labels_df['label'].astype(str)
test_labels_df['label'] = test_labels_df['label'].astype(str)

5.显示两个数据框的前几行以了解它们的结构

print(train_labels_df.head())
print(test_labels_df.head())

6.定义图像处理参数

img_width:这是一个变量,用于存储图像的宽度。
img_height:这是一个变量,用于存储图像的高度。
= 150, 150:这行代码将img_width和img_height变量分别设置为150。

img_width, img_height = 150, 150

7.创建图像数据生成器

ImageDataGenerator:这是Keras中的一个类,用于创建一个数据生成器,用于图像数据的增强和预处理。
rescale=1./255:这是一个参数,用于将图像的像素值从0到255的范围转换为0到1的范围,这是常见的图像预处理步骤。
validation_split=0.2:这是一个参数,用于指定训练数据中用于验证的比例。在这里,20%的数据将用于验证,80%的数据将用于训练。
data_gen:这是生成的ImageDataGenerator对象,它将在后续的训练过程中用于生成增强的图像数据。

data_gen = ImageDataGenerator(rescale=1./255, validation_split=0.2)

8.设置目录路径

train和test压缩文件在资源中

# 并且数据集应该存储在环境可访问的路径中
train_dir = 'D:/rgzn/face/DATASET/train'  # 包含子文件夹的父目录
test_dir = 'D:/rgzn/face/DATASET/test'    # 包含子文件夹的父目录

9.创建训练和验证数据生成器

#flow_from_dataframe:这是Keras中的一个方法,用于创建一个数据生成器,它可以从DataFrame中加载图像和标签。
train_data_gen = data_gen.flow_from_dataframe(#要加载的数据源
dataframe=train_labels_df,
#包含图像文件的目录
directory=train_dir,  
#DataFrame中包含图像路径的列名。
x_col='image',
#DataFrame中包含标签的列名。
y_col='label',
#目标图像的大小
target_size=(img_width, img_height),
#每次迭代中从数据生成器中获取的样本数量。
batch_size=32,
#随机种子,用于确保每次运行时生成相同的数据增强
seed=42,
#数据集的子集,用于训练。subset='training',
)
validation_data_gen = data_gen.flow_from_dataframe(dataframe=test_labels_df,directory=test_dir,  # 包含子文件夹的父目录x_col='image',y_col='label',target_size=(img_width, img_height),batch_size=32,
seed=42,
#数据集的子集,用于验证。subset='validation',
)

10.构建模型

# 构建模型
model = Sequential()
model.add(Conv2D(32, (3, 3), activation='relu', input_shape=(150, 150, 3)))
model.add(MaxPooling2D(pool_size=(2, 2)))# 新增的卷积层
model.add(Conv2D(64, (3, 3), activation='relu'))
model.add(MaxPooling2D(pool_size=(2, 2)))model.add(Conv2D(128, (3, 3), activation='relu'))
model.add(MaxPooling2D(pool_size=(2, 2)))# 展平层
model.add(Flatten())# 全连接层
model.add(Dense(128, activation='relu'))
model.add(Dropout(0.5))# 输出层
model.add(Dense(7, activation='softmax'))

11.编译模型

model.compile(loss='categorical_crossentropy',optimizer='adam',metrics=['accuracy'])

model:这是之前创建和配置的Keras模型。
compile:这是Keras中的一个方法,用于编译模型,指定训练过程中使用的损失函数、优化器和评估指标。
loss='categorical_crossentropy':这是模型使用的损失函数,适用于多类分类问题。
optimizer='adam':这是模型使用的优化器,用于调整模型的权重以最小化损失函数。
metrics=['accuracy']:这是模型使用的评估指标,用于评估模型在训练数据上的性能。

12.训练模型并收集历史

history = model.fit(train_data_gen, epochs=EPOCHS, validation_data=validation_data_gen, batch_size=BS)

fit:这是Keras中的一个方法,用于训练模型。
train_data_gen:这是之前创建的训练数据生成器。
epochs=EPOCHS:这是训练过程中重复训练数据的次数。
validation_data=validation_data_gen:这是用于验证模型的数据。
batch_size=BS:这是每次迭代中从数据生成器中获取的样本数量。
history:这是训练过程中记录的性能指标,如损失和准确率。

13.绘制损失和准确率曲线

N = np.arange(0, EPOCHS)
#设置图表的样式
plt.style.use('ggplot')
plt.figure()plt.plot(N, history.history['loss'], label='train_loss')
plt.plot(N, history.history['val_loss'], label='val_loss')
plt.plot(N, history.history['accuracy'], label='train_acc')
plt.plot(N, history.history['val_accuracy'], label='val_acc')plt.title("Training Loss And Accuracy (CNN)")
plt.xlabel('Epoch #')
plt.ylabel('Loss/Accuracy')
plt.legend()
plt.axis([0, EPOCHS, 0, 2])

14.保存图表

plt.savefig('plot.png')

15.保存模型到本地

print('[INFO] 正在保存模型')
model.save('model.h5')

在这里插入图片描述
在这里插入图片描述
在这里插入图片描述

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.rhkb.cn/news/329516.html

如若内容造成侵权/违法违规/事实不符,请联系长河编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

excel转pdf并且加水印,利用ByteArrayOutputStream内存流不产生中间文件

首先先引入包&#xff1a;加水印和excel转PDF的 <dependency><groupId>com.itextpdf</groupId><artifactId>itextpdf</artifactId><version>5.5.12</version></dependency><dependency><groupId>org.apache.poi&l…

jenkins插件之xunit

安装jenkins插件 搜索xunit并安装 项目配置 配置 - Build Steps 您的项目 - 配置 - Build Steps, 新增 Run with timeout 超时时间根据实际情况配置 Build Step选择 执行SHELL 填写一下命令&#xff0c;这个命令是docker中执行phpunit单元测试&#xff0c;请根据你的实际…

FPGA学习笔记之Nios II(一)简单介绍及新建工程及下载

系列文章目录 文章目录 系列文章目录前言QsysNios IIhello world 实例Platform DesignNios II程序设计 前言 利用Quartus中的Qsys工具&#xff0c;可以实现在FPGA里面跑嵌入式的功能 Qsys Altera 公司将主控制器、数字信号处理模块、存储器及其控制模块、各种接口协议等模块&…

亚马逊测评还能做吗?

只能说测评不是唯一的手段&#xff0c;但是推销量的一把好手。首先测评能让listing快速成长&#xff0c;短期内有望成为爆款&#xff0c;速度快&#xff0c;利润高&#xff0c;回款快。相对其他推广&#xff0c;测评无疑是有效&#xff0c;省培养listing的方法。其次新品前期太…

聊聊 JSON Web Token (JWT) 和 jwcrypto 的使用

哈喽大家好&#xff0c;我是咸鱼。 最近写的一个 Python 项目用到了 jwcrypto 这个库&#xff0c;这个库是专门用来处理 JWT 的&#xff0c;JWT 全称是 JSON Web Token &#xff0c;JSON 格式的 Token。 今天就来简单入门一下 JWT。 官方介绍&#xff1a;https://jwt.io/intr…

RH850F1KM-S4-100Pin_ R7F7016453AFP MCAL Gpt 配置

1、Gpt组件包含的子配置项 GptDriverConfigurationGptDemEventParameterRefsGptConfigurationOfOptApiServicesGptChannelConfigSet2、GptDriverConfiguration 2.1、GptAlreadyInitDetCheck 该参数启用/禁用Gpt_Init API中的GPT_E_ALREADY_INITIALIZED Det检查。 true:开启Gpt_…

JS核心语法【流程控制语句、函数】;DOM【查找元素、操作元素、事件】--学习JavaEE的day48

day48 JS核心技术 JS核心语法 继day47 注意&#xff1a;用到控制台输出、弹窗 流程控制语句 If else、For、For-in(遍历数组时&#xff0c;跟Java是否一样【java没有】)、While、Do while、break、continue 案例&#xff1a; 1.求1-100之间的偶数之和 <!DOCTYPE html> …

Android消息机制回顾(Handler、Looper、MessageQueue源码解析)

回顾&#xff1a; Android消息机制 Android消息机制主要指的是Handler的运行机制以及Handler所附带的MessageQueue和Looper的工作机制。 介绍 通过Handler 消息机制来解决线程之间通信问题&#xff0c;或者用来切换线程。特别是在更新UI界面时&#xff0c;确保了线程间的数…

5.23 学习总结

一.项目优化&#xff08;语音通话&#xff09; 实现步骤&#xff1a; 1.用户发送通话申请&#xff0c;并处理通话请求&#xff0c;如果同意&#xff0c;为两个用户之间进行连接。 2.获取到电脑的麦克风和扬声器&#xff0c;将获取到的语音信息转换成以字节数组的形式传递。 …

基于FPGA的图像直方图均衡化处理verilog实现,包含tb测试文件和MATLAB辅助验证

目录 1.算法运行效果图预览 2.算法运行软件版本 3.部分核心程序 4.算法理论概述 5.算法完整程序工程 1.算法运行效果图预览 FPGA的仿真图如下&#xff1a; 将数据导入MATLAB&#xff0c;对比结果如下&#xff1a; 2.算法运行软件版本 MATLAB2022a vivado2019.2 3.部分…

【Android安全】AOSP版本对应编号| AOSP版本适配Pixel或Nexus型号 | 驱动脚本下载地址

AOSP版本对应编号 https://source.android.com/docs/setup/about/build-numbers?hlzh-cn#source-code-tags-and-builds 例如android-8.1.0_r1 对应的编号是OPM1.171019.011 可以适配Pixel 2 XL AOSP驱动脚本下载 编译AOSP时&#xff0c;需要Google的驱动&#xff0c;后面才…

Jenkins 构建 Maven 项目:项目和服务器在一起的情况

bash.sh内容 #!/bin/bash#删除历史数据 rm -rf ruoyi-admin.jar# appname$1 appnamevideo.xxxxx.com #获取传入的参数 echo "arg:$appname"#获取正在运行的jar包pid # pidps -ef | grep $1 | grep java -jar | awk {printf $2} pidps -ef | grep $appname | grep ja…

高铁VR虚拟全景展示提升企业实力和形象

步入VR的神奇世界&#xff0c;感受前所未有的汽车展示体验。VR虚拟现实技术以其独特的沉浸式模拟&#xff0c;让你仿佛置身于真实展厅之中&#xff0c;尽情探索汽车的每一处细节。 一、定制化展示&#xff0c;随心所欲 VR汽车虚拟展厅打破空间束缚&#xff0c;让汽车制造商能够…

从零开始傅里叶变换

从零开始傅里叶变换 1 Overview2 傅里叶级数2.1 基向量2.2 三角函数系表示 f ( t ) f(t) f(t)2.2.1 三角函数系的正交性2.2.2 三角函数系的系数 2.3 复指数函数系表示 f ( t ) f(t) f(t)2.3.1 复指数函数系的系数2.3.2 复指数函数系的正交性 2.4 傅里叶级数总结 3 傅里叶变换…

基于轻量级神经网络GhostNet开发构建CIFAR100数据集场景下的图像识别分析系统,对比不同分辨路尺度下模型的性能情况

Cifar100数据集是一个经典的图像分类数据集&#xff0c;常用于计算机视觉领域的研究和算法测试。以下是关于Cifar100数据集的详细介绍&#xff1a; 数据集构成&#xff1a;Cifar100数据集包含60000张训练图像和10000张测试图像。其中&#xff0c;训练图像分为100个类别&#x…

webgl入门-绘制三角形

绘制三角形 前言 三角形是一个最简单、最稳定的面&#xff0c;webgl 中的三维模型都是由三角面组成的。咱们这一篇就说一下三角形的绘制方法。 课堂目标 理解多点绘图原理。可以绘制三角形&#xff0c;并将其组合成多边形。 知识点 缓冲区对象点、线、面图形 第一章 web…

C# run Node.js

C# run nodejs Inter-Process Communication&#xff0c;IPC Process类 启动Node.js进程&#xff0c;通过标准输入输出与其进行通信。 // n.js// 监听来自标准输入的消息 process.stdin.on(data, function (data) {// 收到消息后&#xff0c;在控制台输出并回复消息console.l…

C++设计模式---面向对象原则

面向对象设计原则 原则的目的&#xff1a;高内聚&#xff0c;低耦合 1. 单一职责原则 类的职责单一&#xff0c;对外只提供一种功能&#xff0c;而引起类变化的原因都应该只有一个。 2. 开闭原则 对扩展开放&#xff0c;对修改关闭&#xff1b;增加功能是通过增加代码来实现的&…

探索 Rust 语言的精髓:深入 Rust 标准库

探索 Rust 语言的精髓&#xff1a;深入 Rust 标准库 Rust&#xff0c;这门现代编程语言以其内存安全、并发性和性能优势而闻名。它不仅在系统编程领域展现出强大的能力&#xff0c;也越来越多地被应用于WebAssembly、嵌入式系统、分布式服务等众多领域。Rust 的成功&#xff0…

计算机网络数据链路层知识点总结

3.1 数据链路层功能概述 &#xff08;1&#xff09;知识总览 &#xff08;2&#xff09;数据链路层的研究思想 &#xff08;3&#xff09;数据链路层基本概念 &#xff08;4&#xff09;数据链路层基本功能 3.1 封装成帧和透明传输 &#xff08;1&#xff09;数据链路层功能…