CNN|ResNet-50

 导入数据

import matplotlib.pyplot as plt
# 支持中文
plt.rcParams['font.sans-serif'] = ['SimHei']  # 用来正常显示中文标签
plt.rcParams['axes.unicode_minus'] = False  # 用来正常显示负号import os,PIL,pathlib
import numpy as npfrom tensorflow import keras
from tensorflow.keras import layers,models
data_dir = "/Users/yueyishen/jupter/data/bird_photos"data_dir = pathlib.Path(data_dir)

查看数据

image_count = len(list(data_dir.glob('*/*')))print("图片总数为:",image_count)
图片总数为: 565

数据预处理

1. 加载数据

使用image_dataset_from_directory方法将磁盘中的数据加载到tf.data.Dataset中

batch_size = 8
img_height = 224
img_width = 224
"""
关于image_dataset_from_directory()的详细介绍可以参考文章:https://mtyjkh.blog.csdn.net/article/details/117018789
"""
train_ds = tf.keras.preprocessing.image_dataset_from_directory(data_dir,validation_split=0.2,subset="training",seed=123,image_size=(img_height, img_width),batch_size=batch_size)

Found 565 files belonging to 4 classes. Using 452 files for training. 

"""
关于image_dataset_from_directory()的详细介绍可以参考文章:https://mtyjkh.blog.csdn.net/article/details/117018789
"""
val_ds = tf.keras.preprocessing.image_dataset_from_directory(data_dir,validation_split=0.2,subset="validation",seed=123,image_size=(img_height, img_width),batch_size=batch_size)

Found 565 files belonging to 4 classes. Using 113 files for validation.

class_names = train_ds.class_names
print(class_names)

['Bananaquit', 'Black Skimmer', 'Black Throated Bushtiti', 'Cockatoo']

可视化数据

plt.figure(figsize=(10, 5))  # 图形的宽为10高为5
plt.suptitle("瓜牛")for images, labels in train_ds.take(1):for i in range(8):ax = plt.subplot(2, 4, i + 1)  plt.imshow(images[i].numpy().astype("uint8"))plt.title(class_names[labels[i]])plt.axis("off")

再次检查数据

for image_batch, labels_batch in train_ds:print(image_batch.shape)print(labels_batch.shape)break

(8, 224, 224, 3) (8,)

配置数据集

● shuffle() : 打乱数据

● prefetch() :预取数据,加速运行,其详细介绍可以参考我前两篇文章,里面都有讲解。

● cache() :将数据集缓存到内存当中,加速运行

AUTOTUNE = tf.data.AUTOTUNEtrain_ds = train_ds.cache().shuffle(1000).prefetch(buffer_size=AUTOTUNE)
val_ds = val_ds.cache().prefetch(buffer_size=AUTOTUNE)

构建ResNet-50网络模型

from keras import layersfrom keras.layers import Input,Activation,BatchNormalization,Flatten
from keras.layers import Dense,Conv2D,MaxPooling2D,ZeroPadding2D,AveragePooling2D
from keras.models import Modeldef identity_block(input_tensor, kernel_size, filters, stage, block):filters1, filters2, filters3 = filtersname_base = str(stage) + block + '_identity_block_'x = Conv2D(filters1, (1, 1), name=name_base + 'conv1')(input_tensor)x = BatchNormalization(name=name_base + 'bn1')(x)x = Activation('relu', name=name_base + 'relu1')(x)x = Conv2D(filters2, kernel_size,padding='same', name=name_base + 'conv2')(x)x = BatchNormalization(name=name_base + 'bn2')(x)x = Activation('relu', name=name_base + 'relu2')(x)x = Conv2D(filters3, (1, 1), name=name_base + 'conv3')(x)x = BatchNormalization(name=name_base + 'bn3')(x)x = layers.add([x, input_tensor] ,name=name_base + 'add')x = Activation('relu', name=name_base + 'relu4')(x)return x# 在残差网络中,广泛地使用了BN层;但是没有使用MaxPooling以便减小特征图尺寸,
# 作为替代,在每个模块的第一层,都使用了strides = (2, 2)的方式进行特征图尺寸缩减,
# 与使用MaxPooling相比,毫无疑问是减少了卷积的次数,输入图像分辨率较大时比较适合
# 在残差网络的最后一级,先利用layer.add()实现H(x) = x + F(x)
def conv_block(input_tensor, kernel_size, filters, stage, block, strides=(2, 2)):filters1, filters2, filters3 = filtersres_name_base = str(stage) + block + '_conv_block_res_'name_base = str(stage) + block + '_conv_block_'x = Conv2D(filters1, (1, 1), strides=strides, name=name_base + 'conv1')(input_tensor)x = BatchNormalization(name=name_base + 'bn1')(x)x = Activation('relu', name=name_base + 'relu1')(x)x = Conv2D(filters2, kernel_size, padding='same', name=name_base + 'conv2')(x)x = BatchNormalization(name=name_base + 'bn2')(x)x = Activation('relu', name=name_base + 'relu2')(x)x = Conv2D(filters3, (1, 1), name=name_base + 'conv3')(x)x = BatchNormalization(name=name_base + 'bn3')(x)shortcut = Conv2D(filters3, (1, 1), strides=strides, name=res_name_base + 'conv')(input_tensor)shortcut = BatchNormalization(name=res_name_base + 'bn')(shortcut)x = layers.add([x, shortcut], name=name_base+'add')x = Activation('relu', name=name_base+'relu4')(x)return xdef ResNet50(input_shape=[224,224,3],classes=1000):img_input = Input(shape=input_shape)x = ZeroPadding2D((3, 3))(img_input)x = Conv2D(64, (7, 7), strides=(2, 2), name='conv1')(x)x = BatchNormalization(name='bn_conv1')(x)x = Activation('relu')(x)x = MaxPooling2D((3, 3), strides=(2, 2))(x)x =     conv_block(x, 3, [64, 64, 256], stage=2, block='a', strides=(1, 1))x = identity_block(x, 3, [64, 64, 256], stage=2, block='b')x = identity_block(x, 3, [64, 64, 256], stage=2, block='c')x =     conv_block(x, 3, [128, 128, 512], stage=3, block='a')x = identity_block(x, 3, [128, 128, 512], stage=3, block='b')x = identity_block(x, 3, [128, 128, 512], stage=3, block='c')x = identity_block(x, 3, [128, 128, 512], stage=3, block='d')x =     conv_block(x, 3, [256, 256, 1024], stage=4, block='a')x = identity_block(x, 3, [256, 256, 1024], stage=4, block='b')x = identity_block(x, 3, [256, 256, 1024], stage=4, block='c')x = identity_block(x, 3, [256, 256, 1024], stage=4, block='d')x = identity_block(x, 3, [256, 256, 1024], stage=4, block='e')x = identity_block(x, 3, [256, 256, 1024], stage=4, block='f')x =     conv_block(x, 3, [512, 512, 2048], stage=5, block='a')x = identity_block(x, 3, [512, 512, 2048], stage=5, block='b')x = identity_block(x, 3, [512, 512, 2048], stage=5, block='c')x = AveragePooling2D((7, 7), name='avg_pool')(x)x = Flatten()(x)x = Dense(classes, activation='softmax', name='fc1000')(x)model = Model(img_input, x, name='resnet50')# 加载预训练模型model.load_weights("/Users/yueyishen/jupter/data/resnet50_weights_tf_dim_ordering_tf_kernels.h5")return modelmodel = ResNet50()
model.summary()

编译

在准备对模型进行训练之前,还需要再对其进行一些设置。以下内容是在模型的编译步骤中添加的:

● 损失函数(loss):用于衡量模型在训练期间的准确率。

● 优化器(optimizer):决定模型如何根据其看到的数据和自身的损失函数进行更新。

● 指标(metrics):用于监控训练和测试步骤。以下示例使用了准确率,即被正确分类的图像的比率

model.compile(optimizer="adam",loss='sparse_categorical_crossentropy',metrics=['accuracy'])

训练模型

epochs = 10history = model.fit(train_ds,validation_data=val_ds,epochs=epochs
)

 模型评估

acc = history.history['accuracy']
val_acc = history.history['val_accuracy']loss = history.history['loss']
val_loss = history.history['val_loss']epochs_range = range(epochs)plt.figure(figsize=(12, 4))
plt.subplot(1, 2, 1)
plt.suptitle("微信公众号:K同学啊")plt.plot(epochs_range, acc, label='Training Accuracy')
plt.plot(epochs_range, val_acc, label='Validation Accuracy')
plt.legend(loc='lower right')
plt.title('Training and Validation Accuracy')plt.subplot(1, 2, 2)
plt.plot(epochs_range, loss, label='Training Loss')
plt.plot(epochs_range, val_loss, label='Validation Loss')
plt.legend(loc='upper right')
plt.title('Training and Validation Loss')
plt.show()

 预测

# 采用加载的模型(new_model)来看预测结果plt.figure(figsize=(10, 5))  # 图形的宽为10高为5
plt.suptitle("微信公众号:K同学啊")for images, labels in val_ds.take(1):for i in range(8):ax = plt.subplot(2, 4, i + 1)  # 显示图片plt.imshow(images[i].numpy().astype("uint8"))# 需要给图片增加一个维度img_array = tf.expand_dims(images[i], 0) # 使用模型预测图片中的人物predictions = model.predict(img_array)plt.title(class_names[np.argmax(predictions)])plt.axis("off")

 

总结:

  • 数据导入与预处理:首先导入必要的库,设置数据目录,查看数据总数为 565 张图片。使用 image_dataset_from_directory 方法将磁盘中的数据加载为训练集和验证集,进行数据预处理,包括设置图像大小、批次大小等,并对数据集进行打乱、预取和缓存等操作。
  • 可视化数据:通过 plt.figure 和循环展示了训练集中的部分图片,并标注了图片的类别名称。再次检查数据,打印出图像批次的形状和标签批次的形状。
  • 构建 ResNet-50 网络模型:定义了 identity_block 和 conv_block 函数,用于构建 ResNet-50 模型。该模型接收输入形状为 [224,224,3] 的图像,经过一系列卷积、批归一化、激活和残差连接等操作,最后输出分类结果。
  • 编译模型:在编译模型时,设置损失函数为 sparse_categorical_crossentropy,优化器为 adam,指标为准确率。
  • 训练模型:使用训练集和验证集对模型进行训练,设置训练轮数为 10 轮。
  • 模型评估:绘制训练和验证的准确率及损失曲线,以评估模型的性能。
  • 预测:对验证集中的部分图片进行预测,展示预测结果的类别名称

 

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.rhkb.cn/news/17338.html

如若内容造成侵权/违法违规/事实不符,请联系长河编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

基于微型5G网关的石化厂区巡检机器人应用

石化工业属于高风险产业,由于涉及易燃易爆、有毒有害工业原料,为了保障企业的安全生产与持续运营,因此相比其它行业需要进行更高频次、更全面细致的安全巡检和监测。由于传统的人工巡检监测存在诸多不便,例如工作强度大、现场环境…

Docker+Jenkins自动化部署SpringBoot项目【详解git,jdk,maven,ssh配置等各种配置,附有示例+代码】

文章目录 DockerJenkins部署SpringBoot项目一.准备工作1.1安装jdk111.2安装Maven 二.Docker安装Jenkins2.1安装Docker2.2 安装Jenkins2.3进入jenkins 三.Jenkins设置3.1安装jenkins插件3.2全局工具配置全局配置jdk全局配置maven全局配置git 3.3 系统配置安装 Publish Over SSH …

知识图谱数据库 Neo4j in Docker笔记

下载 docker pull neo4j:community官方说明 https://neo4j.com/docs/operations-manual/2025.01/docker/introduction/ 启动 docker run \--restart always \--publish7474:7474 --publish7687:7687 \--env NEO4J_AUTHneo4j/your_password \--volumeD:\files\knowledgegrap…

前缀和算法篇:解决子数组累加和问题

1.前缀和原理 那么在介绍前缀和的原理之前,那么我们先来说下前缀和最基本的一个应用场景,那么就是如我们标题所说的子数组累加和问题,那么假设我们现在有一个区间为[L,R]的数组,那么我们要求的其中子数组比如[L,i]或者[i,m] (L&l…

Notepad++ 中删除所有以 “pdf“ 结尾的行

Notepad 中删除所有以 “pdf” 结尾的行 操作步骤 1.打开文件: 在 Notepad 中打开你需要处理的文本文件。 2.打开查找和替换对话框: 按快捷键 Ctrl F,打开“查找和替换”对话框。 3.启用正则表达式模式: 在对话框的底部&#xf…

知识管理成功:关键指标和策略,研究信息的投资回报率

信息过载会影响生产力。没有人工智能的帮助,信息过载会影响生产力。大量的可用信息,知识工作者不仅仅是超负荷工作;他们感到不知所措,他们倾向于浪费时间(和脑细胞)来应付他们被大量的数据抛向他们&#xf…

Golang 进阶训练营

一、Golang 的 slice、map、channel 1.1 slice vs array a : make([]int, 100) //切片 b : [100]int{} //数组array需指明长度,长度为常量且不可改变 array长度为其类型中的组成部分(给参数为长度100的数组的方法传长度为101的会报错) array在…

Oracle临时表空间(基础操作)

临时表空间 临时表空间:用来存放用户的临时数据,临时数据在需要时被覆盖,关闭数据库后自动删除,其中不能存放永久性数据。 用户进程和服务器进程是一对一的叫做专用连接。 任何一个用户连到oracle数据库,oracle都会…

AI时代的前端开发:对抗压力的利器

在飞速发展的AI时代,前端开发工程师们面临着前所未有的挑战。项目周期不断缩短,需求变化日新月异,交付压力更是与日俱增,这使得开发人员承受着巨大的压力。如何提升对抗压能力,成为摆在每一位前端工程师面前的重要课题…

如何使用DHTMLX Scheduler的拖放功能,在 JS 日程安排日历中创建一组相同的事件

DHTMLX Scheduler 是一个全面的调度解决方案,涵盖了与规划事件相关的广泛需求。假设您在我们的 Scheduler 文档中找不到任何功能,并且希望在我们的 Scheduler 文档中看到您的项目。在这种情况下,很可能可以使用自定义解决方案来实现此类功能。…

计算机网络-八股-学习摘要

一:HTTP的基本概念 全称: 超文本传输协议 从三个方面介绍HTTP协议 1,超文本:我们先来理解「文本」,在互联网早期的时候只是简单的字符文字,但现在「文本」的涵义已经可以扩展为图片、视频、压缩包等&am…

【pytorch】weight_norm和spectral_norm

apply_parametrization_norm 和spectral_norm是 PyTorch 中用于对模型参数进行规范化的方法,但它们在实现和使用上有显著的区别。以下是它们的主要区别和对比: 实现方式 weight_norm: weight_norm 是一种参数重参数化技术,将权…

回归预测 | Matlab实现PSO-HKELM粒子群算法优化混合核极限学习机多变量回归预测

回归预测 | Matlab实现PSO-HKELM粒子群算法优化混合核极限学习机多变量回归预测 目录 回归预测 | Matlab实现PSO-HKELM粒子群算法优化混合核极限学习机多变量回归预测效果一览基本介绍程序设计参考资料 效果一览 基本介绍 1.回归预测 | Matlab实现PSO-HKELM粒子群算法优化混合核…

多媒体软件安全与授权新范例,用 CodeMeter 实现安全、高效的软件许可管理

背景概述 Reason Studios 成立于 1994 年,总部位于瑞典斯德哥尔摩,是全球领先的音乐制作软件开发商。凭借创新的软件产品和行业标准技术,如 ReWire 和 REX 文件格式,Reason Studios 为全球专业音乐人和业余爱好者提供了一系列高质…

C++,STL容器适配器,stack:栈深入解析

文章目录 一、容器概览与核心特性核心特性速览二、底层实现原理1. 容器适配器设计2. 默认容器对比三、核心操作详解1. 容器初始化2. 元素操作接口3. 自定义栈实现四、实战应用场景1. 括号匹配校验2. 浏览器历史记录管理五、性能优化策略1. 底层容器选择基准2. 内存预分配技巧六…

互联网大厂中面试的高频计算机网络问题及详解

前言 哈喽各位小伙伴们,本期小梁给大家带来了互联网大厂中计算机网络部分的高频面试题,本文会以通俗易懂的语言以及图解形式描述,希望能给大家的面试带来一点帮助,祝大家offer拿到手软!!! 话不多说,我们立刻进入本期正题! 一、计算机网络基础部分 1 …

「软件设计模式」工厂方法模式 vs 抽象工厂模式

前言 在软件工程领域,设计模式是解决常见问题的经典方案。本文将深入探讨两种创建型模式:工厂方法模式和抽象工厂模式,通过理论解析与实战代码示例,帮助开发者掌握这两种模式的精髓。 一、工厂方法模式(Factory Metho…

Docker部署Alist网盘聚合管理工具完整教程

Docker部署Alist网盘聚合管理工具完整教程 部署alist初始化修改密码添加存储!联通网盘阿里云盘百度网盘 部署alist 本文以Linux Docker部署,假设你已经安装好Docker docker run -d --restartalways \-v /your/data:/opt/alist/data \-p 5244:5244 \-e …

Excel常用操作

Excel常用操作 学习资源 37_电子表格处理考点精讲_设置数据格式_哔哩哔哩_bilibili 快速输入数据与编辑数据 一个工作簿可以包含多个工作表 特殊数据的添加格式 输入负数, 例如-3、-5 常规输入, 直接输入-3、-5;使用(), 例如在单元格中输入(3)回车即可变为-3;上述括号不区分中…

SpringMVC环境搭建

文章目录 1.模块创建1.创建一个webapp的maven项目2.目录结构 2.代码1.HomeController.java2.home.jsp3.applicationContext.xml Spring配置文件4.spring-mvc.xml SpringMVC配置文件5.web.xml 配置中央控制器以及Spring和SpringMVC配置文件的路径6.index.jsp 3.配置Tomcat1.配置…