深度学习笔记11-优化器对比实验(Tensorflow)

  • 🍨 本文为🔗365天深度学习训练营中的学习记录博客
  • 🍖 原作者:K同学啊

目录

一、导入数据并检查

二、配置数据集

三、数据可视化

四、构建模型

五、训练模型

六、模型对比评估

七、总结


一、导入数据并检查

import pathlib,PIL
import tensorflow as tf
from tensorflow import keras
import matplotlib.pyplot as plt
# 支持中文
plt.rcParams['font.sans-serif'] = ['SimHei']  # 用来正常显示中文标签data_dir    = pathlib.Path("./T6")
image_count = len(list(data_dir.glob('*/*')))
batch_size = 16
img_height = 336
img_width  = 336
"""
关于image_dataset_from_directory()的详细介绍可以参考文章:https://mtyjkh.blog.csdn.net/article/details/117018789
"""
train_ds = tf.keras.preprocessing.image_dataset_from_directory(data_dir,validation_split=0.2,subset="training",seed=12,image_size=(img_height, img_width),batch_size=batch_size)
val_ds = tf.keras.preprocessing.image_dataset_from_directory(data_dir,validation_split=0.2,subset="validation",seed=12,image_size=(img_height, img_width),batch_size=batch_size)

class_names = train_ds.class_names
print(class_names)

for image_batch, labels_batch in train_ds:print(image_batch.shape)print(labels_batch.shape)break

二、配置数据集

AUTOTUNE = tf.data.AUTOTUNE
#归一化处理
def train_preprocessing(image,label):return (image/255.0,label)train_ds = (train_ds.cache().shuffle(1000).map(train_preprocessing)    # 这里可以设置预处理函数
#     .batch(batch_size)           # 在image_dataset_from_directory处已经设置了batch_size.prefetch(buffer_size=AUTOTUNE)
)val_ds = (val_ds.cache().shuffle(1000).map(train_preprocessing)    # 这里可以设置预处理函数
#     .batch(batch_size)         # 在image_dataset_from_directory处已经设置了batch_size.prefetch(buffer_size=AUTOTUNE)
)

三、数据可视化

plt.figure(figsize=(10, 8))  # 图形的宽为10高为5
plt.suptitle("数据展示")for images, labels in train_ds.take(1):for i in range(15):plt.subplot(4, 5, i + 1)plt.xticks([])plt.yticks([])plt.grid(False)# 显示图片plt.imshow(images[i])# 显示标签plt.xlabel(class_names[labels[i]-1])plt.show()

四、构建模型

from tensorflow.keras.layers import Dropout,Dense,BatchNormalization
from tensorflow.keras.models import Modeldef create_model(optimizer='adam'):# 加载预训练模型vgg16_base_model = tf.keras.applications.vgg16.VGG16(weights='imagenet',include_top=False,#不包含顶层的全连接层input_shape=(img_width, img_height, 3),pooling='avg')#平均池化层替代顶层的全连接层for layer in vgg16_base_model.layers:layer.trainable = False  #将 trainable属性设置为 False 意味着在训练过程中,这些层的权重不会更新X = vgg16_base_model.outputX = Dense(170, activation='relu')(X)X = BatchNormalization()(X)X = Dropout(0.5)(X)output = Dense(len(class_names), activation='softmax')(X)#神经元数量等于类别数vgg16_model = Model(inputs=vgg16_base_model.input, outputs=output)vgg16_model.compile(optimizer=optimizer,loss='sparse_categorical_crossentropy',metrics=['accuracy'])return vgg16_modelmodel1 = create_model(optimizer=tf.keras.optimizers.Adam())
model2 = create_model(optimizer=tf.keras.optimizers.SGD())#随机梯度下降(SGD)优化器的
model2.summary()

五、训练模型

NO_EPOCHS = 20history_model1  = model1.fit(train_ds, epochs=NO_EPOCHS, verbose=1, validation_data=val_ds)
history_model2  = model2.fit(train_ds, epochs=NO_EPOCHS, verbose=1, validation_data=val_ds)

六、模型对比评估

from matplotlib.ticker import MultipleLocator
plt.rcParams['savefig.dpi'] = 300 #图片像素
plt.rcParams['figure.dpi']  = 300 #分辨率acc1     = history_model1.history['accuracy']
acc2     = history_model2.history['accuracy']
val_acc1 = history_model1.history['val_accuracy']
val_acc2 = history_model2.history['val_accuracy']loss1     = history_model1.history['loss']
loss2     = history_model2.history['loss']
val_loss1 = history_model1.history['val_loss']
val_loss2 = history_model2.history['val_loss']epochs_range = range(len(acc1))plt.figure(figsize=(16, 4))
plt.subplot(1, 2, 1)plt.plot(epochs_range, acc1, label='Training Accuracy-Adam')
plt.plot(epochs_range, acc2, label='Training Accuracy-SGD')
plt.plot(epochs_range, val_acc1, label='Validation Accuracy-Adam')
plt.plot(epochs_range, val_acc2, label='Validation Accuracy-SGD')
plt.legend(loc='lower right')
plt.title('Training and Validation Accuracy')
# 设置刻度间隔,x轴每1一个刻度
ax = plt.gca()
ax.xaxis.set_major_locator(MultipleLocator(1))plt.subplot(1, 2, 2)
plt.plot(epochs_range, loss1, label='Training Loss-Adam')
plt.plot(epochs_range, loss2, label='Training Loss-SGD')
plt.plot(epochs_range, val_loss1, label='Validation Loss-Adam')
plt.plot(epochs_range, val_loss2, label='Validation Loss-SGD')
plt.legend(loc='upper right')
plt.title('Training and Validation Loss')# 设置刻度间隔,x轴每1一个刻度
ax = plt.gca()
ax.xaxis.set_major_locator(MultipleLocator(1))plt.show()

可以看出,在这个实例中,Adam优化器的效果优于SGD优化器

七、总结

      通过本次实验,学会了比较不同优化器(Adam和SGD)在训练过程中的性能表现,可视化训练过程的损失曲线和准确率等指标。这是一项非常重要的技能,在研究论文中,可以通过这些优化方法可以提高工作量。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.rhkb.cn/news/505695.html

如若内容造成侵权/违法违规/事实不符,请联系长河编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

FreeROTS学习 内存管理

内存管理是一个系统基本组成部分,FreeRTOS 中大量使用到了内存管理,比如创建任务、信号量、队列等会自动从堆中申请内存,用户应用层代码也可以 FreeRTOS 提供的内存管理函数来申请和释放内存 FreeRTOS 内存管理简介 FreeRTOS 创建任务、队列…

【设计模式】介绍常见的设计模式

🥰🥰🥰来都来了,不妨点个关注叭! 👉博客主页:欢迎各位大佬!👈 文章目录 ✨ 介绍一下常见的设计模式✨ Spring 中常见的设计模式 这期内容主要是总结一下常见的设计模式,可…

6 分布式限流框架

限流的作用 在API对外互联网开放的情况下,是无法控制调用方的行为的。当遇到请求激增或者黑客攻击的情况下,会导致接口占用大量的服务器资源,使得接口响应效率的降低或者超时,更或者导致服务器宕机。 限流是指对应用服务进行限制…

【动态规划篇】欣赏概率论与镜像法融合下,别出心裁探索解答括号序列问题

本篇鸡汤:没有人能替你承受痛苦,也没有人能拿走你的坚强. 欢迎拜访:羑悻的小杀马特.-CSDN博客 本篇主题:带你解答洛谷的括号序列问题(绝对巧解) 制作日期:2025.01.10 隶属专栏:C/C题…

数据库高安全—角色权限:权限管理权限检查

目录 3.3 权限管理 3.4 权限检查 书接上文数据库高安全—角色权限:角色创建角色管理,从角色创建和角色管理两方面对高斯数据库的角色权限进行了介绍,本篇将从权限管理和权限检查方面继续解读高斯数据库的角色权限。 3.3 权限管理 &#x…

深入浅出负载均衡:理解其原理并选择最适合你的实现方式

负载均衡是一种在多个计算资源(如服务器、CPU核心、网络链接等)之间分配工作负载的技术,旨在优化资源利用率、提高系统吞吐量和降低响应时间。负载均衡的实现方式多种多样,以下是几种常见的实现方式: 1. 硬件负载均衡&…

Training-free regional prompting for diffusion transformers

通过语言模型来构建位置关系的,omnigen combine来做位置生成,其实可以通过大模型来做,不错。 1.introduction 文生图模型在准确处理具有复杂空间布局的提示时仍然面临挑战,1.通过自然语言准确描述特定的空间布局非常困难,特别是当对象数量增加或需要精确的位置控制时,2.…

麦田物语学习笔记:背包物品选择高亮显示和动画

如题,本篇文章没讲动画效果 基本流程 1.代码思路 (1)先用点击事件的接口函数去实现,点击后反转选择状态(isSelected),以及设置激活状态(SetActive),并且还需要判断该格子是否为空,空格子是点不动的,完成后以上后,出现的问题是高亮应该是有且仅有一个格子是高亮的,而现在可以让…

Linux:深入了解fd文件描述符

目录 1. 文件分类 2. IO函数 2.1 fopen读写模式 2.2 重定向 2.3 标准文件流 3. 系统调用 3.1 open函数认识 3.2 open函数使用 3.3 close函数 3.4 write函数 3.5 read函数 4. fd文件描述符 4.1 标准输入输出 4.2 什么是文件描述符 4.3 语言级文件操作 1. 文件分类…

数据结构:栈(Stack)和队列(Queue)—面试题(一)

目录 1、括号匹配 2、逆波兰表达式求值 3、栈的压入、弹出序列 4、最小栈 1、括号匹配 习题链接https://leetcode.cn/problems/valid-parentheses/description/ 描述: 给定一个只包括 (,),{,},[,] …

51单片机(一) keil4工程与小灯实验

直接开始 新建一个工程 在这里插入图片描述 添加文件 另存为 添加文件到组 写下一个超循环系统代码 调整编译项编译 可以在工程目录找到编译好的led_fst.hex 自行烧写到各自的开发板。 会看到什么都没有。 现在定义一个GPIO端口与小灯的连接,再点亮小灯…

基于 Python 和 OpenCV 的人脸识别上课考勤管理系统

博主介绍:✌程序员徐师兄、7年大厂程序员经历。全网粉丝12w、csdn博客专家、掘金/华为云/阿里云/InfoQ等平台优质作者、专注于Java技术领域和毕业项目实战✌ 🍅文末获取源码联系🍅 👇🏻 精彩专栏推荐订阅👇…

Vue2:el-table中的文字根据内容改变颜色

想要实现的效果如图,【级别】和【P】列的颜色根据文字内容变化 1、正常创建表格 <template><el-table:data="tableData"style="width: 100%"><el-table-column prop="id" label="ID"/> <el-table-column …

案例研究:UML用例图中的结账系统

在软件工程和系统分析中&#xff0c;统一建模语言&#xff08;UML&#xff09;用例图是一种强有力的工具&#xff0c;用于描述系统与其用户之间的交互。本文将通过一个具体的案例研究&#xff0c;详细解释UML用例图的关键概念&#xff0c;并说明其在设计结账系统中的应用。 用…

73.矩阵置零 python

矩阵置零 题目题目描述示例 1&#xff1a;示例 2&#xff1a;提示&#xff1a; 题解思路分析Python 实现代码代码解释提交结果 题目 题目描述 给定一个 m x n 的矩阵&#xff0c;如果一个元素为 0 &#xff0c;则将其所在行和列的所有元素都设为 0 。请使用 原地 算法。 示例…

C++【深入底层,从零模拟实现string类】

在学习了类和对象、模板等前期的C基础知识之后&#xff0c;我们可以尝试根据C标准库中所提供的接口类型&#xff0c;来搭建我们自己的string类型。这个过程有助于初学者掌握C的基础语法及底层逻辑。 框架的搭建 首先搭建模型的基础框架&#xff0c;需要建立my_string.h和my_st…

切忌 SELECT *,就算表只有一列

原文地址 尽量避免 SELECT *&#xff0c;即使在单列表上也是如此 – 如果你现在不同意这一点&#xff0c;读完这篇文章&#xff0c;你可能就要动摇了。 2012年的一个故事 这是我 12 年前&#xff08;约 2012-2013 年&#xff09;在客户后台应用程序中遇到的一个真实故事。 当…

DEV C++软件下载

一、进入网站 https://sourceforge.net/projects/orwelldevcpp/ 二、点击下载 三、安装步骤 1、点击 “OK” 2、点击“I agree” 3、点击“Next” 4、按步骤切换路径&#xff0c;本文选在D盘&#xff0c;可自行选取文件路径 5、等待安装 6、点击完成 7、选择语言 8、点击“N…

OpenBSD之安装指南

安装介质下载 OpenBSD的官网下载地址&#xff1a;https://www.openbsd.org/faq/faq4.html#Download&#xff0c;同时也是《OpenBSD FAQ - Installation Guide》。长篇大论了很多&#xff0c;每一个章节都能看懂是干嘛的&#xff0c;连起来就容易晕。并且是英文的&#xff0c;要…

Vue.config.productionTip = false 不起作用的问题及解决

文章目录 一、问题描述二、解决方法 一、问题描述 当我们在代码页面上引入Vue.js(开发版本)时&#xff0c;运行代码会出现以下提示&#xff0c;这句话的意思是&#xff1a;您正在开发模式下运行Vue&#xff0c;在进行生产部署时&#xff0c;请确保打开生产模式 You are runni…