python3+TensorFlow 2.x(五)CNN

目录

CNN理解

code实现人脸识别

数据集准备:

code实现

模型解析

结果展示

结果探讨

基于vgg16的以图搜图

数据准备

图库database

检索测试集datatest

code实现

code解析

结果展示


CNN理解

卷积神经网络(CNN)是深度学习中最强大的架构之一,广泛应用于图像处理、目标检测和计算机视觉任务。TensorFlow 2.x 提供了非常简便的接口来实现卷积神经网络。

神经网络(CNN)个人理解

code实现人脸识别

数据集准备:

人脸表情数据集fer2013.csv,python3+TensorFlow 2.x直接下载可能遇到问题,可以下载到本地然后加载。

百度网盘:链接:  人脸表情数据集fer2013 提取码: ph63

code实现

import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
from sklearn.model_selection import train_test_split
from sklearn.preprocessing import LabelBinarizer
import tensorflow as tf
from tensorflow.keras import layers, models
import random# 1. 加载数据集
data = pd.read_csv('fer2013.csv')# 2. 数据预处理
# 将像素值转换为 numpy 数组
X = np.array([np.fromstring(img, sep=' ') for img in data['pixels']])
X = X.reshape(-1, 48, 48, 1)  # 48x48 像素,单通道(灰度图像)
X = X.astype('float32') / 255.0  # 归一化# 标签处理
y = data['emotion'].values
y = LabelBinarizer().fit_transform(y)  # 标签二值化# 划分训练集和测试集
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2, random_state=42)# 3. 定义卷积神经网络模型
model = models.Sequential()
model.add(layers.Conv2D(32, (3, 3), activation='relu', input_shape=(48, 48, 1)))
model.add(layers.MaxPooling2D((2, 2)))
model.add(layers.Conv2D(64, (3, 3), activation='relu'))
model.add(layers.MaxPooling2D((2, 2)))
model.add(layers.Conv2D(128, (3, 3), activation='relu'))
model.add(layers.Flatten())
model.add(layers.Dense(128, activation='relu'))
model.add(layers.Dense(7, activation='softmax'))  # 7 类情感# 4. 编译模型
model.compile(optimizer='adam', loss='categorical_crossentropy', metrics=['accuracy'])# 5. 训练模型
history = model.fit(X_train, y_train, epochs=30, batch_size=64, validation_data=(X_test, y_test))# 6. 评估模型
test_loss, test_acc = model.evaluate(X_test, y_test, verbose=2)
print('\nTest accuracy:', test_acc)# 7. 可视化训练过程
plt.figure(figsize=(12, 4))# 绘制损失
plt.subplot(1, 2, 1)
plt.plot(history.history['loss'], label='Train Loss')
plt.plot(history.history['val_loss'], label='Validation Loss')
plt.title('Loss')
plt.xlabel('Epoch')
plt.ylabel('Loss')
plt.legend()# 绘制准确率
plt.subplot(1, 2, 2)
plt.plot(history.history['accuracy'], label='Train Accuracy')
plt.plot(history.history['val_accuracy'], label='Validation Accuracy')
plt.title('Accuracy')
plt.xlabel('Epoch')
plt.ylabel('Accuracy')
plt.legend()plt.show()# 8. 随机选择图像进行可视化识别结果
def visualize_predictions(X, y, model, num_images=5):# 随机选择图像索引random_indices = random.sample(range(X.shape[0]), num_images)plt.figure(figsize=(15, 6))for i, idx in enumerate(random_indices):plt.subplot(2, num_images//2, i + 1)# 显示图像plt.imshow(X[idx].reshape(48, 48), cmap='gray')plt.axis('off')# 进行预测prediction = model.predict(np.expand_dims(X[idx], axis=0))predicted_class = np.argmax(prediction)# 显示真实标签和预测标签plt.title(f'True: {np.argmax(y[idx])}, Pred: {predicted_class}')plt.show()# 可视化识别结果
visualize_predictions(X_test, y_test, model, num_images=6)# 9. 保存模型
model.save('fer_model.h5')

模型解析

加载数据集:使用 pandas 加载 fer2013.csv 文件。如果文件不在当前工作目录中,你需要提供完整的文件路径。
数据预处理:将 pixels 列中的字符串转换为 NumPy 数组,并重塑为适合 CNN 的形状(48x48 像素,单通道)。将像素值归一化到 [0, 1] 范围。使用 LabelBinarizer 对情感标签进行二值化处理。
划分数据集:使用 train_test_split 将数据集划分为训练集和测试集。
定义模型:构建一个简单的卷积神经网络,包括卷积层、池化层和全连接层。
编译模型:使用 Adam 优化器和交叉熵损失函数编译模型。根据你的计算资源和数据集大小,调整 epochs 和 batch_size 的值,以获得更好的训练效果
训练模型:使用 fit() 方法训练模型,并保存训练过程中的历史记录。
评估模型:在测试集上评估模型的准确性,并输出结果。
可视化识别结果:定义 visualize_predictions 函数,随机选择图像进行预测,并显示真实标签和预测标签。
保存模型:将训练好的模型保存到文件中,以便后续使用。

结果展示

 

保存了模型可以直接加载保存的训练模型进行测试 (前提测试集已经划分)

loaded_model = tf.keras.models.load_model('fer_model.h5')
# 选择要测试的图像数量
num_images = 4# 随机选择图像索引
random_indices = np.random.choice(X_test.shape[0], num_images, replace=False)plt.figure(figsize=(15, 6))for i, idx in enumerate(random_indices):plt.subplot(2, num_images // 2, i + 1)# 显示图像plt.imshow(X_test[idx].reshape(48, 48), cmap='gray')plt.axis('off')# 进行预测prediction = loaded_model.predict(np.expand_dims(X_test[idx], axis=0))predicted_class = np.argmax(prediction)# 显示真实标签和预测标签plt.title(f'True: {np.argmax(y_test[idx])}, Pred: {predicted_class}')plt.show()

结果探讨

准确率只有50多,需要调节参数(学习率,迭代次数)或模型结构(卷积层数量,卷积核大小)等进行多次训练测试找到合适的模型。

基于vgg16的以图搜图

VGG16是一种深度卷积神经网络(CNN),由牛津大学​Visual Geometry Group​在2014年提出,主要用于图像分类任务。‌ VGG16以其深度和简洁性而闻名,是深度学习领域中的经典模型之一。该网络由多个卷积层和池化层交替堆叠而成,最后使用全连接层进行分类。VGG16的特点是使用了连续的小卷积核(3x3)和池化层,这使得网络可以构建得更深,通常可以达到16层或19层。

数据准备

图库database

检索测试集datatest

code实现

import os
import numpy as np
import cv2
import matplotlib.pyplot as plt
from tensorflow.keras.applications import VGG16
from tensorflow.keras.applications.vgg16 import preprocess_input
from tensorflow.keras.preprocessing import image
from sklearn.metrics.pairwise import cosine_similarity# 加载 VGG16 模型
model = VGG16(weights='imagenet', include_top=False, pooling='avg')# 图像预处理
def preprocess_image(image_path):img = image.load_img(image_path, target_size=(224, 224))img_array = image.img_to_array(img)img_array = np.expand_dims(img_array, axis=0)img_array = preprocess_input(img_array)return img_array# 提取特征
def extract_features(image_path):img_array = preprocess_image(image_path)features = model.predict(img_array)return features.flatten()# 计算相似度
def find_similar_images(query_image_path, image_folder, top_k=9):query_features = extract_features(query_image_path)similarities = []for image_name in os.listdir(image_folder):image_path = os.path.join(image_folder, image_name)if os.path.isfile(image_path):features = extract_features(image_path)sim = cosine_similarity(query_features.reshape(1, -1), features.reshape(1, -1))similarities.append((image_name, sim[0][0]))# 按相似度排序similarities.sort(key=lambda x: x[1], reverse=True)return similarities[:top_k]# 可视化搜索结果
def visualize_results(query_image_path, similar_images, image_folder):plt.figure(figsize=(10, 5))# 显示查询图像plt.subplot(1, len(similar_images) + 1, 1)plt.imshow(cv2.cvtColor(cv2.imread(query_image_path), cv2.COLOR_BGR2RGB))plt.title("Query Image")plt.axis('off')# 显示相似图像for i, (image_name, similarity) in enumerate(similar_images):image_path = os.path.join(image_folder, image_name)plt.subplot(1, len(similar_images) + 1, i + 2)plt.imshow(cv2.cvtColor(cv2.imread(image_path), cv2.COLOR_BGR2RGB))plt.title(f"{image_name}\nSim: {similarity:.4f}")plt.axis('off')plt.tight_layout()plt.show()# 主程序
if __name__ == "__main__":query_image = "datatest/cat.jpg"  # 替换为你的查询图像路径image_folder = "database"  # 替换为你的图像文件夹路径similar_images = find_similar_images(query_image, image_folder)print("与查询图像相似的图像:")for image_name, similarity in similar_images:print(f"{image_name}: 相似度 {similarity:.4f}")visualize_results(query_image, similar_images, image_folder)

code解析

环境设置:使用 TensorFlow 和 Keras 加载 VGG16 模型,并设置为不包含顶部分类层,使用平均池化。
图像预处理:使用 Keras 的 image 模块读取图像并进行预处理,包括调整大小和标准化。
特征提取:使用 VGG16 模型提取图像特征,并将其展平为一维数组。
相似度计算:使用余弦相似度计算查询图像与文件夹中每个图像的相似度。
可视化搜索结果:使用 Matplotlib 可视化查询图像和相似图像,并显示相似度。

结果展示

图像库里只有5张图,当query_image = "datatest/girl.jpg" top_k=3时,结果如下

当query_image = "datatest/cat.jpg", top_k=9时

确保你有足够的图像数据以进行相似度搜索。VGG16 模型较大,特征提取可能需要一些时间,具体取决于图像数量和计算资源。调整 top_k 参数以获取更多或更少的相似图像。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.rhkb.cn/news/8005.html

如若内容造成侵权/违法违规/事实不符,请联系长河编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

(一)HTTP协议 :请求与响应

前言 爬虫需要基础知识,HTTP协议只是个开始,除此之外还有很多,我们慢慢来记录。 今天的HTTP协议,会有助于我们更好的了解网络。 一、什么是HTTP协议 (1)定义 HTTP(超文本传输协议&#xff…

FPGA实现任意角度视频旋转(完结)视频任意角度旋转实现

本文主要介绍如何基于FPGA实现视频的任意角度旋转,关于视频180度实时旋转、90/270度视频无裁剪旋转,请见本专栏前面的文章,旋转效果示意图如下: 为了实时对比旋转效果,采用分屏显示进行处理,左边代表旋转…

如何移植ftp服务器到arm板子?

很多厂家提供的sdk,一般都不自带ftp服务器功能, 需要要发人员自己移植ftp服务器程序。 本文手把手教大家如何移植ftp server到arm板子。 环境 sdk:复旦微 Buildroot 2018.02.31. 解压 $ mkdir ~/vsftpd $ cp vsftpd-3.0.2.tar.gz ~/vs…

【阅读笔记】基于整数+分数微分的清晰度评价算子

本文介绍的是一种新的清晰度评价算子,整数微分算子分数微分算子 一、概述 目前在数字图像清晰度评价函数中常用的评价函数包括三类:灰度梯度评价函数、频域函数和统计学函数,其中灰度梯度评价函数具有计算简单,评价效果好等优点…

LabVIEW 保存文件 生产者/消费者设计

LabVIEW 保存文件 生产者/消费者设计 简介生产消费模式设计结构 简介 主从模式的数据通信是利用全局变量、局域变量或共享变量实现的,由于这些变量的每次复制都是原始数据的一个副本,占据了大量的空间。实际上,只需要使用一部分缓冲区作为数…

网络安全 | F5-Attack Signatures-Set详解

关注:CodingTechWork 创建和分配攻击签名集 可以通过两种方式创建攻击签名集:使用过滤器或手动选择要包含的签名。  基于过滤器的签名集仅基于在签名过滤器中定义的标准。基于过滤器的签名集的优点在于,可以专注于定义用户感兴趣的攻击签名…

宏_wps_宏修改word中所有excel表格的格式_设置字体对齐格式_删除空行等

需求: 将word中所有excel表格的格式进行统一化,修改其中的数字类型为“宋体, 五号,右对齐, 不加粗,不倾斜”,其中的中文为“宋体, 五号, 不加粗,不倾斜” 数…

项目集成RabbitMQ

文章目录 1.common-rabbitmq-starter1.创建common-rabbitmq-starter2.pom.xml3.自动配置1.RabbitMQAutoConfiguration.java2.spring.factories 2.测试使用1.创建common-rabbitmq-starter-demo2.目录结构3.pom.xml4.application.yml5.TestConfig.java 配置交换机和队列6.TestCon…

Shotcut新版来袭,新增HSL滤镜、硬件编码,剪辑更流畅

Shotcut 是一款功能强大、完全免费且开源的多平台视频编辑工具,适用于 Windows、macOS 和 Linux 系统。作为一款专业的视频编辑软件,它不仅支持数百种音频和视频格式的直接编辑,还提供了无需导入即可进行原生编辑的便捷功能。Shotcut 的核心优…

K8s运维管理平台 - xkube体验:功能较多

目录 简介Lic安装1、需要手动安装MySQL,**建库**2、启动命令3、[ERROR] GetNodeMetric Fail:the server is currently unable to handle the request (get nodes.metrics.k8s.io qfusion-1) 使用总结优点优化 补充1:layui、layuimini和beego的详细介绍1.…

BAHD酰基转移酶对紫草素的手性催化-文献精读105

Two BAHD Acyltransferases Catalyze the Last Step in the Shikonin/Alkannin Biosynthetic Pathway 两个BAHD酰基转移酶催化了紫草素/左旋紫草素生物合成途径中的最后一步 一个BAHD酰基转移酶专门催化紫草素的酰基化,而另一个BAHD酰基转移酶则仅催化紫草素的对映…

STM32完全学习——RT-thread在STM32F407上移植

一、写在前面 关于源码的下载,以及在KEIL工程里面添加操作系统的源代码,这里就不再赘述了。需要注意的是RT-thread默认里面是会使用串口的,因此需要额外的进行串口的初始化,有些人可能会问,为什么不直接使用CubMAX直接…

单片机内存管理剖析

一、概述 在单片机系统中,内存资源通常是有限的,因此高效的内存管理至关重要。合理地分配和使用内存可以提高系统的性能和稳定性,避免内存泄漏和碎片化问题。单片机的内存主要包括程序存储器(如 Flash)和数据存储器&a…

“AI质量评估系统:智能守护,让品质无忧

嘿,各位小伙伴们!今天咱们来聊聊一个在现代社会中越来越重要的角色——AI质量评估系统。你知道吗?在这个快速发展的时代,产品质量已经成为企业生存和发展的关键。而AI质量评估系统,就像是我们的智能守护神,…

人工智能:从基础到前沿

目录 目录 1. 引言 2. 人工智能基础 2.1 什么是人工智能? 2.2 人工智能的历史 2.3 人工智能的分类 3. 机器学习 3.1 机器学习概述 3.2 监督学习 3.3 无监督学习 3.4 强化学习 4. 深度学习 4.1 深度学习概述 4.2 神经网络基础 4.3 卷积神经网络&#…

Centos7系统php8编译安装ImageMagick/Imagick扩展教程整理

Centos7系统php8编译安装ImageMagick/Imagick扩展教程整理 安装php8安装ImageMagick1、下载ImageMagick2、解压并安装3、查看是否安装成功 安装imagick扩展包 安装php8 点我安装php8 安装ImageMagick 1、下载ImageMagick wget https://www.imagemagick.org/download/ImageMa…

基于微信阅读网站小程序的设计与实现(LW+源码+讲解)

专注于大学生项目实战开发,讲解,毕业答疑辅导,欢迎高校老师/同行前辈交流合作✌。 技术范围:SpringBoot、Vue、SSM、HLMT、小程序、Jsp、PHP、Nodejs、Python、爬虫、数据可视化、安卓app、大数据、物联网、机器学习等设计与开发。 主要内容:…

51单片机入门_02_C语言基础0102

C语言基础部分可以参考我之前写的专栏C语言基础入门48篇 以及《从入门到就业C全栈班》中的C语言部分,本篇将会结合51单片机讲差异部分。 课程主要按照以下目录进行介绍。 文章目录 1. 进制转换2. C语言简介3. C语言中基本数据类型4. 标识符与关键字5. 变量与常量6.…

【Numpy核心编程攻略:Python数据处理、分析详解与科学计算】1.2 ndarray解剖课:多维数组的底层实现

1.2 《ndarray解剖课:多维数组的底层实现》 内容介绍 NumPy 的 ndarray 是其核心数据结构,用于高效处理多维数组。在这篇文章中,我们将深入解析 ndarray 的底层实现,探讨其内存结构、维度、数据类型、步长等关键概念&#xff0c…

C++——list的了解和使用

目录 引言 forward_list与list 标准库中的list 一、list的常用接口 1.list的迭代器 2.list的初始化 3.list的容量操作 4.list的访问操作 5.list的修改操作 6.list的其他操作 二、list与vector的对比 结束语 引言 本篇博客要介绍的是STL中的list。 求点赞收藏评论…