使用 Keras 训练一个卷积神经网络(CNN)(入门篇)

在上一篇文章中,我们介绍了如何使用 Keras 训练一个简单的全连接神经网络(MLP)。本文将带你深入学习如何使用 Keras 构建和训练一个卷积神经网络(CNN),用于图像分类任务。我们将继续使用 MNIST 数据集,但这次我们将采用更适合图像数据的 CNN 架构。

目录

  1. 环境准备
  2. 导入必要的库
  3. 加载和预处理数据
  4. 构建卷积神经网络模型
  5. 编译模型
  6. 训练模型
  7. 评估模型
  8. 保存和加载模型
  9. 可视化训练过程
  10. 总结

1. 环境准备

确保你已经安装了 Python(推荐 3.6 及以上版本)和 TensorFlow(Keras 已集成在 TensorFlow 中)。如果尚未安装,请运行以下命令:

pip install tensorflow

2. 导入必要的库

import tensorflow as tf
from tensorflow import keras
from tensorflow.keras import layers, models
import numpy as np
import matplotlib.pyplot as plt
  • tensorflow: 深度学习框架,Keras 已集成其中。
  • numpy: 用于数值计算。
  • matplotlib.pyplot: 用于数据可视化。

3. 加载和预处理数据

我们继续使用 Keras 自带的 MNIST 数据集。

# 加载 MNIST 数据集
(x_train, y_train), (x_test, y_test) = keras.datasets.mnist.load_data()# 查看数据形状
print(f"训练数据形状: {x_train.shape}, 训练标签形状: {y_train.shape}")
print(f"测试数据形状: {x_test.shape}, 测试标签形状: {y_test.shape}")# 数据预处理
# 归一化:将像素值缩放到 0-1 之间
x_train = x_train.astype("float32") / 255.0
x_test = x_test.astype("float32") / 255.0# CNN 需要添加通道维度
x_train = np.expand_dims(x_train, -1)  # 形状变为 (60000, 28, 28, 1)
x_test = np.expand_dims(x_test, -1)    # 形状变为 (10000, 28, 28, 1)# 将标签转换为分类编码
num_classes = 10
y_train = keras.utils.to_categorical(y_train, num_classes)
y_test = keras.utils.to_categorical(y_test, num_classes)# 可视化部分数据
plt.figure(figsize=(10,10))
for i in range(25):plt.subplot(5,5,i+1)plt.xticks([])plt.yticks([])plt.grid(False)plt.imshow(x_train[i].reshape(28, 28), cmap=plt.cm.binary)plt.xlabel(np.argmax(y_train[i]))
plt.show()

说明:

  • CNN 需要输入数据具有通道维度,因此使用 np.expand_dims 添加一个维度。
  • MNIST 数据集是灰度图像,因此通道维度为 1。

4. 构建卷积神经网络模型

我们将构建一个简单的 CNN 模型,包含两个卷积层和两个池化层,最后接上全连接层进行分类。

model = models.Sequential([layers.Conv2D(32, (3, 3), activation='relu', input_shape=(28, 28, 1)),  # 卷积层,32 个 3x3 卷积核layers.MaxPooling2D((2, 2)),  # 最大池化层,池化窗口 2x2layers.Conv2D(64, (3, 3), activation='relu'),  # 卷积层,64 个 3x3 卷积核layers.MaxPooling2D((2, 2)),  # 最大池化层layers.Flatten(),  # 展平层layers.Dense(64, activation='relu'),  # 全连接层,64 个神经元layers.Dense(num_classes, activation='softmax')  # 输出层,10 个神经元
])# 查看模型结构
model.summary()

说明:

  • Conv2D: 二维卷积层,用于提取图像特征。
  • MaxPooling2D: 最大池化层,用于下采样,减少参数数量。
  • Flatten: 将多维输入一维化,以便连接全连接层。
  • Dense: 全连接层,用于分类。

5. 编译模型

model.compile(optimizer='adam',loss='categorical_crossentropy',metrics=['accuracy'])

说明:

  • 使用 Adam 优化器和交叉熵损失函数。
  • 评估指标为准确率。

6. 训练模型

# 设置训练参数
batch_size = 128
epochs = 10# 训练模型
history = model.fit(x_train, y_train,batch_size=batch_size,epochs=epochs,validation_split=0.1)  # 使用 10% 的训练数据作为验证集

说明:

  • 使用 10% 的训练数据作为验证集,以监控模型在验证集上的性能。

7. 评估模型

test_loss, test_acc = model.evaluate(x_test, y_test, verbose=2)
print(f"\n测试准确率: {test_acc:.4f}")

8. 保存和加载模型

# 保存模型
model.save("mnist_cnn_model.h5")# 加载模型
new_model = keras.models.load_model("mnist_cnn_model.h5")

9. 可视化训练过程

# 绘制训练 & 验证的准确率和损失值
plt.figure(figsize=(12,4))# 准确率
plt.subplot(1,2,1)
plt.plot(history.history['accuracy'], label='训练准确率')
plt.plot(history.history['val_accuracy'], label='验证准确率')
plt.xlabel('Epoch')
plt.ylabel('准确率')
plt.legend(loc='lower right')
plt.title('训练与验证准确率')# 损失值
plt.subplot(1,2,2)
plt.plot(history.history['loss'], label='训练损失')
plt.plot(history.history['val_loss'], label='验证损失')
plt.xlabel('Epoch')
plt.ylabel('损失')
plt.legend(loc='upper right')
plt.title('训练与验证损失')plt.show()

说明:

  • 通过可视化训练过程中的准确率和损失值,可以帮助我们了解模型的训练情况,判断是否存在过拟合或欠拟合。

10. 本节回顾

本节介绍了如何使用 Keras 构建和训练一个简单的卷积神经网络(CNN),用于手写数字识别任务。主要步骤包括:

  1. 环境准备和库导入: 确保安装了必要的库,并导入所需模块。
  2. 数据加载和预处理: 加载 MNIST 数据集,进行归一化,并添加通道维度。
  3. 构建 CNN 模型: 使用 Conv2D、MaxPooling2D、Flatten、Dense 等层构建模型。
  4. 编译模型: 指定优化器、损失函数和评估指标。
  5. 训练模型: 使用训练数据训练模型,并使用验证集监控性能。
  6. 评估模型: 在测试集上评估模型性能。
  7. 保存和加载模型: 将训练好的模型保存到磁盘,并可加载进行预测。
  8. 可视化训练过程: 通过绘制准确率和损失值曲线,了解模型的训练情况。

通过这个基础教程,你可以开始自行探索更复杂的 CNN 模型和更深入的应用,如图像分类、目标检测、图像分割等。

导师简介

前腾讯电子签的前端负责人,现 whentimes tech CTO,专注于前端技术的大咖一枚!一路走来,从小屏到大屏,从 Web 到移动,什么前端难题都见过。热衷于用技术打磨产品,带领团队把复杂的事情做到极简,体验做到极致。喜欢探索新技术,也爱分享一些实战经验,帮助大家少走弯路!

温馨提示:可搜老码小张公号联系导师

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.rhkb.cn/news/471174.html

如若内容造成侵权/违法违规/事实不符,请联系长河编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

【C语言】值传递和地址传递

值传递 引用传递(传地址,传引用)的区别 传值,是把实参的值赋值给行参 ,那么对行参的修改,不会影响实参的值。 传地址,是传值的一种特殊方式,只是他传递的是地址,不是普通…

C语言入门到精通(第六版)——第十六章

16、网络套接字编程 16.1、计算机网络基础 计算机网络技术是计算机技术和通信技术相结合的产物,代表计算机的一个重要发展方向。了解计算机的网络结构,有助于用户开发网络应用程序。 16.1.1、IP地址 为了使网络上的计算机能够彼此识别对方,…

Cyberchef配合Wireshark提取并解析HTTP/TLS流量数据包中的文件

本文将介绍一种手动的轻量级的方式,还原HTTP/TLS协议中传输的文件,为流量数据包中的文件分析提供帮助。 如果捕获的数据包中存在非文本类文件,例如png,jpg等图片文件,或者word,Excel等office文件异或是其他类型的二进…

记录使用documents4j来将word文件转化为pdf文件

本文记录使用documents4j来将word文件转化为pdf文件 文章目录 程序实例maven导入代码实现程序结果 本文小结 程序实例 maven导入 <!--word转pdf--><dependency><groupId>com.documents4j</groupId><artifactId>documents4j-local</artifactI…

SQL面试题——奔驰SQL面试题 车辆在不同驾驶模式下的时间

SQL面试题——奔驰SQL面试题 我们的表大致如下 CREATE TABLE signal_log( vin STRING COMMENTvehicle frame id, signal_name STRING COMMENTfunction name, signal_value STRING COMMENT signal value , ts BIGINT COMMENTevent timestamp, dt STRING COMMENTformat yyyy-mm…

使用 unicorn 和 capstone 库来模拟 ARM Thumb 指令的执行(一)

import binascii import unicorn import capstonedef printArm32Regs(mu):for i in range(66,78):print("R%d,value:%x"%(i-66,mu.reg_read(i)))def testhumb():CODE b\x1C\x00\x0A\x46\x1E\x00"""MOV R3, R0 的机器码&#xff1a;0x1C 0x00&#xf…

WordPress 6.7 “Rollins”发布

每个 WordPress 版本都会向一位在音乐界留下不可磨灭印记的艺术家致敬。WordPress 6.7 的代号为“Rollins”&#xff0c;旨在向传奇爵士萨克斯演奏家桑尼罗林斯致敬。罗林斯是爵士乐界最伟大的即兴演奏家和先驱之一&#xff0c;他以精湛的技术、创新精神和无畏的音乐表达方式影…

844.比较含退格的字符串

java用 O&#xff08;1&#xff09;空间这个方法&#xff0c;容易挺多bug的… O&#xff08;1&#xff09;空间 #&#xff1a;删除前一个字符 》 从后面开始判断&#xff08;这样可以用跳过的思想&#xff09;不能使用两次 i- - 来处理 # 的操作&#xff0c;会造成误删了前面…

WLAN消失或者已连接但是访问不了互联网

目录 1、WLAN已连接但是访问不了互联网 2、WLAN图标消失 今晚电脑突然连不上网了&#xff0c;重启试了好多种办法都没有用。 1、WLAN已连接但是访问不了互联网 这个的问题很多&#xff0c;建议直接网络重置&#xff0c;即将网络驱动全部删除&#xff0c;然后重新安装。 首先…

Linux源码阅读笔记-V4L2框架基础介绍

V4L2视频设备驱动基础 V4L2 是专门为 Linux 设备设计的整套视频框架&#xff08;其主要核心在 Linux 内核&#xff0c;相当于 Linux 操作系统上层的视频源捕获驱动框架&#xff09;。为上层访问系统底层的视频设备提供一个统一的标准接口。V4L2 驱动框架能够支持多种类型设备&…

C 语言 【模拟实现内存库函数】

1、memcpy memcpy函数是C/C语言中的一个用于内存复制的函数&#xff0c;声明在 string.h 中&#xff08;C是 cstring&#xff09;。其原型是&#xff1a; void * memcpy ( void * destination, const void * source, size_t num ); 其中&#xff0c;destination表示的是要拷贝…

【大数据学习 | flume】flume的概述与组件的介绍

1. flume概述 Flume是cloudera(CDH版本的hadoop) 开发的一个分布式、可靠、高可用的海量日志收集系统。它将各个服务器中的数据收集起来并送到指定的地方去&#xff0c;比如说送到HDFS、Hbase&#xff0c;简单来说flume就是收集日志的。 Flume两个版本区别&#xff1a; ​ 1&…

01:(手撸HAL+CubeMX)时钟篇

&#xff08;手撸HALCubeMX&#xff09;时钟篇 1、对SystemInit函数的分析2、使用HSI将总线时钟配置为最高频率3、使用HSE将总线时钟配置为最高频率4、使用Cube配置时钟树的参数5、对HAL_Init函数分析6、对系统定时器中断服务函数分析 有关时钟树和上电/复位的基础知识请参考“…

大数据新视界 -- 大数据大厂之 Impala 存储格式转换:从原理到实践,开启大数据性能优化星际之旅(下)(20/30)

&#x1f496;&#x1f496;&#x1f496;亲爱的朋友们&#xff0c;热烈欢迎你们来到 青云交的博客&#xff01;能与你们在此邂逅&#xff0c;我满心欢喜&#xff0c;深感无比荣幸。在这个瞬息万变的时代&#xff0c;我们每个人都在苦苦追寻一处能让心灵安然栖息的港湾。而 我的…

Spring Boot实现文件上传与OSS集成:从基础到应用

目录 前言1. 文件上传的基础实现1.1 前端文件上传请求1.2 后端文件接收与保存 2. 集成第三方OSS服务2.1 准备工作2.2 编写OSS集成代码2.3 修改Controller实现文件上传至OSS 3. 文件上传的扩展&#xff1a;多文件上传与权限控制结语 前言 随着互联网应用的快速发展&#xff0c;…

微服务各组件整合

nacos 第一步&#xff0c;引入依赖 <dependency><groupId>com.alibaba.cloud</groupId><artifactId>spring-cloud-starter-alibaba-nacos-discovery</artifactId></dependency> 第二步&#xff0c;增加配置 spring:application:name: …

机器学习总结

机器学习按照模型类型分为监督学习模型&#xff0c;无监督学习模型和概率模型三大类&#xff1a; 下图是机器学习笔记思维导图&#xff0c;&#xff1a; 一.什么是机器学习 从本质上讲&#xff0c;可以认为机器学习就是在数据中寻找一种合适的函数来描述输入与输出之间的关系。…

WEB攻防-通用漏洞SQL注入sqlmapOracleMongodbDB2等

SQL注入课程体系&#xff1a; 1、数据库注入-access mysql mssql oracle mongodb postgresql 2、数据类型注入-数字型 字符型 搜索型 加密型&#xff08;base64 json等&#xff09; 3、提交方式注入-get post cookie http头等 4、查询方式注入-查询 增加 删除 更新 堆叠等 …

三、损失函数

损失函数 前言一、分类问题的损失函数1.1 二分类损失函数1.1.1 数学定义1.1.2 函数解释&#xff1a;1.1.3 性质1.1.4 计算演示1.1.5 代码演示 1.2 多分类损失函数1.1.1 数学定义1.1.2 性质与特点1.1.3 计算演示1.1.4 代码演示 二、回归问题的损失函数2.1 MAE损失2.2 MSE损失2.3…

PNG图片批量压缩exe工具+功能纯净+不改变原始尺寸

小编最近有一篇png图片要批量压缩&#xff0c;大小都在5MB之上&#xff0c;在网上找了半天要么就是有广告&#xff0c;要么就是有毒&#xff0c;要么就是功能复杂&#xff0c;整的我心烦意乱。 于是我自己用python写了一个纯净工具&#xff0c;只能压缩png图片&#xff0c;没任…