T11:优化器对比实验

T11周:优化器对比实验

      • **一、前期工作**
        • 1.设置GPU,导入库
      • **二、数据预处理**
        • 1.导入数据
        • 2.检查数据
        • 3.配置数据集
        • 4.数据可视化
      • **三、构建模型**
      • **四、训练模型**
      • **五、模型评估**
      • 六、总结

  • 🍨 本文为🔗365天深度学习训练营 中的学习记录博客
  • 🍖 原作者:K同学啊

🍺
本次主要是探究不同优化器、以及不同参数配置对模型的影响,在论文当中我们也可以进行优化器的比对,以增加论文工作量。·

⛽ 我的环境

  • 语言环境:Python3.10.12
  • 编译器:Google Colab
  • 深度学习环境:
    • TensorFlow2.17.0

一、前期工作

1.设置GPU,导入库
#os提供了一些与操作系统交互的功能,比如文件和目录操作
import os
#提供图像处理的功能,包括打开和显示、保存、裁剪等
import PIL
from PIL import Image
#pathlib提供了一个面向对象的接口来处理文件系统路径。路径被表示为Path对象,可以调用方法来进行各种文件和目录操作。
import pathlib#用于绘制图形和可视化数据
import tensorflow as tf
import matplotlib.pyplot as plt
#用于数值计算的库,提供支持多维数组和矩阵运算
import numpy as np
#keras作为高层神经网络API,已被集成进tensorflow,使得训练更方便简单
from tensorflow import keras
#layers提供了神经网络的基本构建块,比如全连接层、卷积层、池化层等
#提供了构建和训练神经网络模型的功能,包括顺序模型(Sequential)和函数式模型(Functional API)
from tensorflow.keras import layers, models
#导入两个重要的回调函数:前者用于训练期间保存模型最佳版本;后者监测到模型性能不再提升时提前停止训练,避免过拟合
from tensorflow.keras.callbacks import ModelCheckpoint, EarlyStopping
import tensorflow as tf
gpus = tf.config.list_physical_devices("GPU")if gpus:gpu0 = gpus[0] #如果有多个GPU,仅使用第0个GPUtf.config.experimental.set_memory_growth(gpu0, True) #设置GPU显存用量按需使用tf.config.set_visible_devices([gpu0],"GPU")from tensorflow  import keras
import matplotlib.pyplot as plt
import pandas as pd
import numpy as np
import warnings,os,PIL,pathlibwarnings.filterwarnings("ignore")             #忽略警告信息
#plt.rcParams['font.sans-serif'] = ['SimHei']  # 用来正常显示中文标签
plt.rcParams['axes.unicode_minus'] = False    # 用来正常显示负号

二、数据预处理

1.导入数据
from google.colab import drive
drive.mount("/content/drive/")
%cd "/content/drive/My Drive/Colab Notebooks/jupyter notebook/data/"
Drive already mounted at /content/drive/; to attempt to forcibly remount, call drive.mount("/content/drive/", force_remount=True).
/content/drive/My Drive/Colab Notebooks/jupyter notebook/data
data_dir = "./T6"
data_dir = pathlib.Path(data_dir)# 使用glob方法获取当前目录的子目录里所有以'.jpg'为结尾的文件
# '*/*.jpg' 是一個通配符模式
# 第一个星号表示当前目录
# 第二个星号表示子目录
image_count = len (list(data_dir.glob("*/*.jpg")))
print("图片总数:", image_count)
图片总数: 1800
#设置批量大小,即每次训练模型时输入图像数量
#每次训练迭代时,模型需处理32张图像
batch_size = 16
#图像的高度,加载图像数据时,将所有的图像调整为相同的高度
img_height = 336
#图像的宽度,加载图像数据时,将所有的图像调整为相同的宽度
img_width = 336
"""
关于image_dataset_from_directory()的详细介绍可以参考文章:https://mtyjkh.blog.csdn.net/article/details/117018789
"""
train_ds = tf.keras.preprocessing.image_dataset_from_directory(data_dir,validation_split=0.2,#指定数据集中分割出多少比例数据当作验证集,0.1表示10%数据会被用来当验证集subset="training",#指定是用于训练还是验证的数据子集,这里设定为trainingseed=12,#用于设置随机数种子,以确保数据集划分的可重复性和一致性image_size=(img_height, img_width),batch_size=batch_size)
Found 1800 files belonging to 17 classes.
Using 1440 files for training.
"""
关于image_dataset_from_directory()的详细介绍可以参考文章:https://mtyjkh.blog.csdn.net/article/details/117018789
"""
val_ds = tf.keras.preprocessing.image_dataset_from_directory(data_dir,validation_split=0.2,subset="validation",seed=12,image_size=(img_height, img_width),batch_size=batch_size)
Found 1800 files belonging to 17 classes.
Using 360 files for validation.
class_names = train_ds.class_names
# 可以通过class_names输出数据集的标签。标签将按字母顺序对应于目录名称
class_names
['Angelina Jolie','Brad Pitt','Denzel Washington','Hugh Jackman','Jennifer Lawrence','Johnny Depp','Kate Winslet','Leonardo DiCaprio','Megan Fox','Natalie Portman','Nicole Kidman','Robert Downey Jr','Sandra Bullock','Scarlett Johansson','Tom Cruise','Tom Hanks','Will Smith']
2.检查数据
for image_batch, labels_batch in train_ds:print(image_batch.shape)print(labels_batch.shape)break
(16, 336, 336, 3)
(16,)
3.配置数据集
#自动调整数据管道性能
AUTOTUNE = tf.data.AUTOTUNE
# 使用 tf.data.AUTOTUNE 具体的好处包括:
#自动调整并行度:自动决定并行处理数据的最佳线程数,以最大化数据吞吐量。
#减少等待时间:通过优化数据加载和预处理,减少模型训练时等待数据的时间。
#提升性能:自动优化数据管道的各个环节,使整个训练过程更高效。
#简化代码:不需要手动调整参数,代码更简洁且易于维护。#使用cache()方法将训练集缓存到内存中,这样加快数据加载速度
#当多次迭代训练数据时,可以重复使用已经加载到内存的数据而不必重新从磁盘加载
#使用shuffle()对训练数据集进行洗牌操作,打乱数据集中的样本顺序
#参数1000指缓冲区大小,即每次从数据集中随机选择的样本数量
#prefetch()预取数据,节约在训练过程中数据加载时间def train_preprocessing(image,label):return(image/255.0,label)train_ds = train_ds.cache().shuffle(1000).map(train_preprocessing).prefetch(buffer_size=AUTOTUNE)
val_ds = val_ds.cache().shuffle(1000).map(train_preprocessing).prefetch(buffer_size=AUTOTUNE)
4.数据可视化
plt.figure(figsize=(10, 8))  # 图形的宽为10高为5
plt.suptitle("数据展示")for images, labels in train_ds.take(1):for i in range(15):plt.subplot(4, 5, i + 1)plt.xticks([])plt.yticks([])plt.grid(False)# 显示图片plt.imshow(images[i])# 显示标签plt.xlabel(class_names[labels[i]-1])plt.show()

在这里插入图片描述

三、构建模型

from tensorflow.keras.layers import Dropout,Dense,BatchNormalization
from tensorflow.keras.models import Modeldef create_model(optimizer):# 加载预训练模型vgg16_base_model = tf.keras.applications.vgg16.VGG16(weights='imagenet',include_top=False,input_shape=(img_width, img_height, 3),pooling='avg')for layer in vgg16_base_model.layers:layer.trainable = FalseX = vgg16_base_model.outputX = Dense(170, activation='relu')(X)X = BatchNormalization()(X)X = Dropout(0.5)(X)output = Dense(len(class_names), activation='softmax')(X)vgg16_model = Model(inputs=vgg16_base_model.input, outputs=output)vgg16_model.compile(optimizer=optimizer,loss='sparse_categorical_crossentropy',metrics=['accuracy'])return vgg16_modelmodel1 = create_model(optimizer=tf.keras.optimizers.Adam())
model2 = create_model(optimizer=tf.keras.optimizers.SGD())
model2.summary()

在这里插入图片描述
在这里插入图片描述

四、训练模型

NO_epochs = 50history_model1=model1.fit(train_ds,epochs=NO_epochs,validation_data=val_ds)
history_model2=model2.fit(train_ds,epochs=NO_epochs,validation_data=val_ds)

部分训练过程:
在这里插入图片描述

五、模型评估

from matplotlib.ticker import MultipleLocator
plt.rcParams['savefig.dpi'] = 300 #图片像素
plt.rcParams['figure.dpi']  = 300 #分辨率acc1     = history_model1.history['accuracy']
acc2     = history_model2.history['accuracy']
val_acc1 = history_model1.history['val_accuracy']
val_acc2 = history_model2.history['val_accuracy']loss1     = history_model1.history['loss']
loss2     = history_model2.history['loss']
val_loss1 = history_model1.history['val_loss']
val_loss2 = history_model2.history['val_loss']epochs_range = range(len(acc1))plt.figure(figsize=(16, 4))
plt.subplot(1, 2, 1)plt.plot(epochs_range, acc1, label='Training Accuracy-Adam')
plt.plot(epochs_range, acc2, label='Training Accuracy-SGD')
plt.plot(epochs_range, val_acc1, label='Validation Accuracy-Adam')
plt.plot(epochs_range, val_acc2, label='Validation Accuracy-SGD')
plt.legend(loc='lower right')
plt.title('Training and Validation Accuracy')
# 设置刻度间隔,x轴每1一个刻度
ax = plt.gca()
ax.xaxis.set_major_locator(MultipleLocator(1))plt.subplot(1, 2, 2)
plt.plot(epochs_range, loss1, label='Training Loss-Adam')
plt.plot(epochs_range, loss2, label='Training Loss-SGD')
plt.plot(epochs_range, val_loss1, label='Validation Loss-Adam')
plt.plot(epochs_range, val_loss2, label='Validation Loss-SGD')
plt.legend(loc='upper right')
plt.title('Training and Validation Loss')# 设置刻度间隔,x轴每1一个刻度
ax = plt.gca()
ax.xaxis.set_major_locator(MultipleLocator(1))plt.show()

在这里插入图片描述

def test_accuracy_report(model):score = model.evaluate(val_ds, verbose=1)print('Loss function: %s, accuracy:' % score[0], score[1])print("model1:")
test_accuracy_report(model1)
print("model2:")
test_accuracy_report(model2)
model1:
[1m23/23[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m3s[0m 140ms/step - accuracy: 0.5283 - loss: 2.5871
Loss function: 2.443842649459839, accuracy: 0.5249999761581421
model2:
[1m23/23[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m3s[0m 141ms/step - accuracy: 0.5863 - loss: 1.3692
Loss function: 1.4873623847961426, accuracy: 0.5555555820465088

六、总结

本周学习了调用vgg16并构建functional model来进行不同优化器的设置和训练,并在最后对训练参数过程进行可视化。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.rhkb.cn/news/444856.html

如若内容造成侵权/违法违规/事实不符,请联系长河编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

【前端碎片记录】大文件分片上传

大文件分片上传,主要是为了提高上传效率,避免网络问题或者其他原因导致整个上传失败。 HTML部分没什么特殊代码,这里只写js代码。用原生js实现,框架中可参考实现 // 获取上传文件的 input框 const ipt document.querySelector(…

aws(学习笔记第五课) AWS的firewall SecurityGroup,代理转发技术

aws(学习笔记第五课) AWS的firewall– SecurityGroup,代理转发技术 学习内容: AWS的firewall– SecurityGroup代理转发技术 1. AWS的filewall– SecurityGroup 控制进入虚拟服务器的网络流量 通常的firewall(防火墙)配置 AWS上使用安全组进行网络流量…

息肉检测数据集 yolov5 yolov8适用于目标检测训练已经调整为yolo格式可直接训练yolo网络

息肉检测数据集 yolov5 yolov8格式 息肉检测数据集介绍 数据集概述 名称:息肉检测数据集(基于某公开的分割数据集调整)用途:适用于目标检测任务,特别是内窥镜图像中的息肉检测格式:YOLO格式(边…

Transactional注解导致Spring Bean定时任务失效

背景 业务需要定时捞取数据库中新增的数据做数据处理及分析,更新状态,处理结束。而我们不能随意定义线程池,规定使用统一的标准规范来定义线程池。如在配置文件中配置线程池的属性:名称,线程核心数等,任务…

04-SpringBootWeb案例(中)

3. 员工管理 完成了部门管理的功能开发之后,我们进入到下一环节员工管理功能的开发。 基于以上原型,我们可以把员工管理功能分为: 分页查询(今天完成)带条件的分页查询(今天完成)删除员工&am…

Linux_kernel内核定时器14

一、内核定时器 1、内核定时器 使用方法: 2、系统时钟中断处理函数 1)更新时间 2)检查当前时间片是否耗尽 Linux操作系统是基于时间片轮询的,属于抢占式的内核 3)jiffies 3、基本概念 1)HZ HZ决定了1秒钟产…

OCP迎来新版本,让OceanBase的运维管理更高效

近期,OceanBase的OCP发布了新版本,全面支持 OceanBase 内核 4.3.2 及更低版本。新版本针对基础运维、性能监控、运维配置、外部集成等多个方面实现了 20余项的优化及强化措施,增强产品的易用性和稳定性,从而帮助用户更加高效地管理…

中国地级市生态韧性数据及城市生态韧性数据(2000-2022年)

一测算方式: 参考C刊《管理学刊》楚尔鸣(2023)老师的做法,城市生态韧性主要衡量一个城市在面临生态环境系统压力或突发冲击时,约束污染排放、维护生态环境状态和治理能力提升的综合水平。 参考郭海红和刘新民的研究&a…

Redis持久化机制(RDBAOF详解)

目录 一、Redis持久化介绍二、Redis持久化方式1、RDB持久化(1) 介绍(2) RDB持久化触发机制(3) RDB优点和缺点(4) RDB流程 2、AOF(append only file)持久化(1) 介绍(2) AOF优点和缺点(3) AOF文件重写(4) AOF文件重写流程 三、AOF和RDB持久化注意事项 一、Redis持久化介绍 Redis…

【小工具分享】下载保存指定网页的所有图片

一、保存百度首页所有的图片 先看一下保存的图片情况 二、思路 1、打开网页 2、获取所有图片 3、依次下载保存图片到指定路径 三、完整代码 from selenium import webdriver from selenium.webdriver.common.by import By b webdriver.Firefox() import urllib.request…

C++系统教程004-数据类型(03)

一 .变量 变量是指在程序运行期间其值可以发生改变的量。每个变量都必须有一个名称作为唯一的标识,且具有一个特定的数据类型。变量使用之前,一定要先进行声明或定义。 1.变量的声明和定义 C中,变量声明是指为变量提供一个名称&#xff0c…

嵌入式面试——FreeRTOS篇(七) 软件定时器

本篇为:FreeRTOS 软件定时器篇 一、软件定时器的简介 1、定时器介绍 答: 定时器:从指定的时刻开始,经过一个指定时间,然后触发一个超时事件,用户可以自定义定时器周期。 硬件定时器:芯片本…

基于差分进化灰狼混合优化的SVM(DE-GWO-SVM)数据预测算法matlab仿真

目录 1.程序功能描述 2.测试软件版本以及运行结果展示 3.核心程序 4.本算法原理 4.1 DE优化 4.2 GWO优化 5.完整程序 1.程序功能描述 基于差分进化灰狼混合优化的SVM(DE-GWO-SVM)数据预测算法matlab仿真,对比SVM和GWO-SVM。 2.测试软件版本以及运行结果展示…

论文阅读:Split-Aperture 2-in-1 Computational Cameras (二)

Split-Aperture 2-in-1 Computational Cameras (一) Coded Optics for High Dynamic Range Imaging 接下来,文章介绍了二合一相机在几种场景下的应用,首先是高动态范围成像,现有的快照高动态范围(HDR)成像工作已经证…

自然语言处理(NLP)论文数量的十年趋势:2014-2024

引言 近年来,自然语言处理(NLP)已成为人工智能(AI)和数据科学领域中的关键技术之一。随着数据规模的不断扩大和计算能力的提升,NLP技术从学术研究走向了广泛的实际应用。通过观察过去十年(2014…

处理 Vue3 中隐藏元素刷新闪烁问题

一、问题说明 页面刷新,原本隐藏的元素会一闪而过。 效果展示: 页面的导航栏通过路由跳转中携带的 meta 参数控制导航栏的 显示/隐藏,但在实践过程中发现,虽然元素隐藏了,但是刷新页面会出现闪烁的问题。 项目源码&…

ros2:从github上下载源码进行编译

首先,创建工作空间 # 1. 递归创建工作空间目录 mkdir -p catkin_ws/src # 2. 进入src目录 cd catkin_ws/src然后如果你没有安装git,需要 sudo apt install git然后输入。 git clone https://github.com/6-robot/wpr_simulation.git这时候,…

MYSQL 常见锁机制详解,常见锁问题排查及分析

1,锁分类 锁冲突是影响数据库性能的重要指标,本章节介绍MYSQL常见锁,及各种说的常用示例,mysql锁的分类如下: 从操作类型分类:读锁、写锁; 从操作粒度分类:表锁、页锁、行锁&#x…

文献阅读Prov-GigaPath模型--相关知识点罗列

文章链接:A whole-slide foundation model for digital pathology from real-world data | NatureDigital pathology poses unique computational challenges, as a standard gigapixel slide may comprise tens of thousands of image tiles1–3. Prior models hav…

Java中的二维数组

二维数组 使用方式1:动态初始化1.语法:2.比如:3.二维数组在内存的存在形式 使用方式2:动态初始化使用方法3:动态初始化--列数不确定使用方式4:静态初始化1.定义2.使用 使用方式1:动态初始化 1.…