深度学习笔记11-优化器对比实验(Tensorflow)

  • 🍨 本文为🔗365天深度学习训练营中的学习记录博客
  • 🍖 原作者:K同学啊

目录

一、导入数据并检查

二、配置数据集

三、数据可视化

四、构建模型

五、训练模型

六、模型对比评估

七、总结


一、导入数据并检查

import pathlib,PIL
import tensorflow as tf
from tensorflow import keras
import matplotlib.pyplot as plt
# 支持中文
plt.rcParams['font.sans-serif'] = ['SimHei']  # 用来正常显示中文标签data_dir    = pathlib.Path("./T6")
image_count = len(list(data_dir.glob('*/*')))
batch_size = 16
img_height = 336
img_width  = 336
"""
关于image_dataset_from_directory()的详细介绍可以参考文章:https://mtyjkh.blog.csdn.net/article/details/117018789
"""
train_ds = tf.keras.preprocessing.image_dataset_from_directory(data_dir,validation_split=0.2,subset="training",seed=12,image_size=(img_height, img_width),batch_size=batch_size)
val_ds = tf.keras.preprocessing.image_dataset_from_directory(data_dir,validation_split=0.2,subset="validation",seed=12,image_size=(img_height, img_width),batch_size=batch_size)

class_names = train_ds.class_names
print(class_names)

for image_batch, labels_batch in train_ds:print(image_batch.shape)print(labels_batch.shape)break

二、配置数据集

AUTOTUNE = tf.data.AUTOTUNE
#归一化处理
def train_preprocessing(image,label):return (image/255.0,label)train_ds = (train_ds.cache().shuffle(1000).map(train_preprocessing)    # 这里可以设置预处理函数
#     .batch(batch_size)           # 在image_dataset_from_directory处已经设置了batch_size.prefetch(buffer_size=AUTOTUNE)
)val_ds = (val_ds.cache().shuffle(1000).map(train_preprocessing)    # 这里可以设置预处理函数
#     .batch(batch_size)         # 在image_dataset_from_directory处已经设置了batch_size.prefetch(buffer_size=AUTOTUNE)
)

三、数据可视化

plt.figure(figsize=(10, 8))  # 图形的宽为10高为5
plt.suptitle("数据展示")for images, labels in train_ds.take(1):for i in range(15):plt.subplot(4, 5, i + 1)plt.xticks([])plt.yticks([])plt.grid(False)# 显示图片plt.imshow(images[i])# 显示标签plt.xlabel(class_names[labels[i]-1])plt.show()

四、构建模型

from tensorflow.keras.layers import Dropout,Dense,BatchNormalization
from tensorflow.keras.models import Modeldef create_model(optimizer='adam'):# 加载预训练模型vgg16_base_model = tf.keras.applications.vgg16.VGG16(weights='imagenet',include_top=False,#不包含顶层的全连接层input_shape=(img_width, img_height, 3),pooling='avg')#平均池化层替代顶层的全连接层for layer in vgg16_base_model.layers:layer.trainable = False  #将 trainable属性设置为 False 意味着在训练过程中,这些层的权重不会更新X = vgg16_base_model.outputX = Dense(170, activation='relu')(X)X = BatchNormalization()(X)X = Dropout(0.5)(X)output = Dense(len(class_names), activation='softmax')(X)#神经元数量等于类别数vgg16_model = Model(inputs=vgg16_base_model.input, outputs=output)vgg16_model.compile(optimizer=optimizer,loss='sparse_categorical_crossentropy',metrics=['accuracy'])return vgg16_modelmodel1 = create_model(optimizer=tf.keras.optimizers.Adam())
model2 = create_model(optimizer=tf.keras.optimizers.SGD())#随机梯度下降(SGD)优化器的
model2.summary()

五、训练模型

NO_EPOCHS = 20history_model1  = model1.fit(train_ds, epochs=NO_EPOCHS, verbose=1, validation_data=val_ds)
history_model2  = model2.fit(train_ds, epochs=NO_EPOCHS, verbose=1, validation_data=val_ds)

六、模型对比评估

from matplotlib.ticker import MultipleLocator
plt.rcParams['savefig.dpi'] = 300 #图片像素
plt.rcParams['figure.dpi']  = 300 #分辨率acc1     = history_model1.history['accuracy']
acc2     = history_model2.history['accuracy']
val_acc1 = history_model1.history['val_accuracy']
val_acc2 = history_model2.history['val_accuracy']loss1     = history_model1.history['loss']
loss2     = history_model2.history['loss']
val_loss1 = history_model1.history['val_loss']
val_loss2 = history_model2.history['val_loss']epochs_range = range(len(acc1))plt.figure(figsize=(16, 4))
plt.subplot(1, 2, 1)plt.plot(epochs_range, acc1, label='Training Accuracy-Adam')
plt.plot(epochs_range, acc2, label='Training Accuracy-SGD')
plt.plot(epochs_range, val_acc1, label='Validation Accuracy-Adam')
plt.plot(epochs_range, val_acc2, label='Validation Accuracy-SGD')
plt.legend(loc='lower right')
plt.title('Training and Validation Accuracy')
# 设置刻度间隔,x轴每1一个刻度
ax = plt.gca()
ax.xaxis.set_major_locator(MultipleLocator(1))plt.subplot(1, 2, 2)
plt.plot(epochs_range, loss1, label='Training Loss-Adam')
plt.plot(epochs_range, loss2, label='Training Loss-SGD')
plt.plot(epochs_range, val_loss1, label='Validation Loss-Adam')
plt.plot(epochs_range, val_loss2, label='Validation Loss-SGD')
plt.legend(loc='upper right')
plt.title('Training and Validation Loss')# 设置刻度间隔,x轴每1一个刻度
ax = plt.gca()
ax.xaxis.set_major_locator(MultipleLocator(1))plt.show()

可以看出,在这个实例中,Adam优化器的效果优于SGD优化器

七、总结

      通过本次实验,学会了比较不同优化器(Adam和SGD)在训练过程中的性能表现,可视化训练过程的损失曲线和准确率等指标。这是一项非常重要的技能,在研究论文中,可以通过这些优化方法可以提高工作量。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.rhkb.cn/news/1245.html

如若内容造成侵权/违法违规/事实不符,请联系长河编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

HBuilderX打包ios保姆式教程

1、登录苹果开发者后台并登录已认证开发者账号ID Sign In - Apple 2、创建标识符(App ID)、证书,描述文件 3、首先创建标识符,用于新建App应用 3-1、App的话直接选择第一个App IDs,点击右上角继续 3-2、选择App&#x…

计算机网络 (39)TCP的运输连接管理

前言 TCP(传输控制协议)是一种面向连接的、可靠的传输协议,它在计算机网络中扮演着至关重要的角色。TCP的运输连接管理涉及连接建立、数据传送和连接释放三个阶段。 一、TCP的连接建立 TCP的连接建立采用三次握手机制,其过程如下&…

2 XDMA IP中断

三种中断 1. Legacy 定义:Legacy 中断是传统的中断处理方式,使用物理中断线(例如 IRQ)来传递中断信号。缺点: 中断线数量有限,通常为 16 条,限制了可连接设备的数量。中断处理可能会导致中断风…

22、PyTorch nn.Conv2d卷积网络使用教程

文章目录 1. 卷积2. python 代码3. notes 1. 卷积 输入A张量为: A [ 0 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 ] \begin{equation} A\begin{bmatrix} 0&1&2&3\\\\ 4&5&6&7\\\\ 8&9&10&11\\\\ 12&13&14&15 \end{b…

基于微信小程序的汽车销售系统的设计与实现springboot+论文源码调试讲解

第4章 系统设计 一个成功设计的系统在内容上必定是丰富的,在系统外观或系统功能上必定是对用户友好的。所以为了提升系统的价值,吸引更多的访问者访问系统,以及让来访用户可以花费更多时间停留在系统上,则表明该系统设计得比较专…

【ROS2】☆ launch之Python

☆重点 ROS1和ROS2其中一个很大区别之一就是launch的编写方式。在ROS1中采用xml格式编写launch,而ROS2保留了XML 格式launch,还另外引入了Python和YAML 编写方式。选择哪种编写取决于每位开发人员的爱好,但是ROS2官方推荐使用Python方式编写…

Facebook 隐私变革之路:回顾与展望

在数字时代,个人隐私的保护一直是社交平台面临的重大挑战之一。作为全球最大的社交网络平台,Facebook(现为Meta)在处理用户隐私方面的变革,历经了多次调整与完善。本文将回顾Facebook在隐私保护方面的历程,…

P10打卡——pytorch实现车牌识别

🍨 本文为🔗365天深度学习训练营中的学习记录博客🍖 原作者:K同学啊 1.检查GPU from torchvision.transforms import transforms from torch.utils.data import DataLoader from torchvision import datasets import torchvisio…

Pycharm连接远程解释器

这里写目录标题 0 前言1 给项目添加解释器2 通过SSH连接3 找到远程服务器的torch环境所对应的python路径,并设置同步映射(1)配置服务器的系统环境(2)配置服务器的conda环境 4 进入到程序入口(main.py&#…

VS2015 + OpenCV + OnnxRuntime-Cpp + YOLOv8 部署

近期有个工作需求是进行 YOLOv8 模型的 C 部署,部署环境如下 系统:WindowsIDE:VS2015语言:COpenCV 4.5.0OnnxRuntime 1.15.1 0. 预训练模型保存为 .onnx 格式 假设已经有使用 ultralytics 库训练并保存为 .pt 格式的 YOLOv8 模型…

uniApp通过xgplayer(西瓜播放器)接入视频实时监控

🚀 个人简介:某大型国企资深软件开发工程师,信息系统项目管理师、CSDN优质创作者、阿里云专家博主,华为云云享专家,分享前端后端相关技术与工作常见问题~ 💟 作 者:码喽的自我修养&#x1f9…

fast-crud select下拉框 实现多选功能及下拉框数据动态获取(通过接口获取)

教程 fast-crud select示例配置需求:需求比较复杂 1. 下拉框选项需要通过后端接口获取 2. 实现多选功能 由于这个前端框架使用逻辑比较复杂我也是第一次使用,所以只记录核心问题 环境:vue3,typescript,fast-crud ,elementPlus 效果 代码 // crud.tsx文件(/.ts也行 js应…

【Go】:图片上添加水印的全面指南——从基础到高级特性

前言 在数字内容日益重要的今天,保护版权和标识来源变得关键。为图片添加水印有助于声明所有权、提升品牌认知度,并防止未经授权的使用。本文将介绍如何用Go语言实现图片水印,包括静态图片和带旋转、倾斜效果的文字水印,帮助您有…

Win10微调大语言模型ChatGLM2-6B

在《Win10本地部署大语言模型ChatGLM2-6B-CSDN博客》基础上进行,官方文档在这里,参考了这篇文章 首先确保ChatGLM2-6B下的有ptuning AdvertiseGen下载地址1,地址2,文件中数据留几行 模型文件下载地址 (注意&#xff1…

多模态论文笔记——CLIP

大家好,这里是好评笔记,公主号:Goodnote,专栏文章私信限时Free。本文详细介绍这几年AIGC火爆的隐藏功臣,多模态模型:CLIP。 文章目录 CLIP(Contrastive Language-Image Pre-training&#xff09…

【ArcGIS微课1000例】0137:色彩映射表转为RGB全彩模式

本文讲述ArcGIS中,将tif格式的影像数据从色彩映射表转为RGB全彩模式。 参考阅读:【GlobalMapper精品教程】093:将tif影像色彩映射表(调色板)转为RGB全彩模式 文章目录 一、色彩映射表预览二、色彩映射表转为RGB全彩模式一、色彩映射表预览 加载配套数据包中的0137.rar中的…

计算机网络(四)——网络层

目录 一、功能 二、IP数据报分片 三、DHCP动态主机配置协议 四、网络地址转换(NAT)技术 五、无分类编址CIDR 六、ARP地址解析协议 七、ICMP网际控制报文协议 八、IPv4和IPv6的区别 九、IPv4向IPv6的两种过渡技术——双栈协议和隧道技术 十、路由…

【大数据】机器学习 -----关于data.csv数据集分析案例

打开表 import pandas as pd df2 pd.read_csv("data.csv",encoding"gbk") df2.head()查看数据属性(列标题,表形状,类型,行标题,值) print("列标题:",df2.columns)Data…

在 Linux 下Ubuntu创建同权限用户

我是因为不小心把最开始创建的用户的文件夹颜色搞没了,再后来全白用习惯了,就不想卸载了,像创建一个和最开始创建的用户有一样的权限可以执行sudo -i进入root一样的用户 如图这是最原始的样子 第一步 创建新用户,我这里是因为之前…

toRef 和 toRefs 详解及应用

在 Vue 3 中,toRef 和 toRefs 是两个用于创建响应式引用的工具,主要用于组合式 API(Composition API)的场景中 1. toRef 定义 toRef 将某个对象的某个属性包装成一个响应式引用。这样可以直接对该引用进行操作,而不需…