第T8周:使用TensorFlow实现猫狗识别

  • 🍨 本文为🔗365天深度学习训练营 中的学习记录博客
  • 🍖 原作者:K同学啊

    文章目录

    • 一、前期工作
      • 1.设置GPU(如果使用的是CPU可以忽略这步)
      • 2. 导入数据
    • 二、数据预处理
      • 1、加载数据
      • 2、再次检查数据
      • 3. 配置数据集
      • 4. 可视化数据
    • 三、构建CNN网络
    • 四、编译
    • 五、训练模型
    • 六、模型评估
    • 七、预测
    • 八、知识点
      • 1、训练方式
      • 2、tqdm
        • 2.1、基本用法:
        • 2.2、手动进度更新:

电脑环境:
语言环境:Python 3.8.0
编译器:Jupyter Notebook
深度学习环境:tensorflow 2.15.0

一、前期工作

1.设置GPU(如果使用的是CPU可以忽略这步)

import tensorflow as tfgpus = tf.config.list_physical_devices("GPU")if gpus:tf.config.experimental.set_memory_growth(gpus[0], True)  #设置GPU显存用量按需使用tf.config.set_visible_devices([gpus[0]],"GPU")# 打印显卡信息,确认GPU可用
print(gpus)

2. 导入数据

import matplotlib.pyplot as plt
# 支持中文
plt.rcParams['font.sans-serif'] = ['SimHei']  # 用来正常显示中文标签
plt.rcParams['axes.unicode_minus'] = False  # 用来正常显示负号import os,PIL,pathlib#隐藏警告
import warnings
warnings.filterwarnings('ignore')data_dir = "./365-7-data"
data_dir = pathlib.Path(data_dir)image_count = len(list(data_dir.glob('*/*')))print("图片总数为:",image_count)

二、数据预处理

1、加载数据

使用image_dataset_from_directory方法将磁盘中的数据加载到tf.data.Dataset中。

batch_size = 8
img_height = 224
img_width = 224"""
关于image_dataset_from_directory()的详细介绍可以参考文章:https://mtyjkh.blog.csdn.net/article/details/117018789
"""
train_ds = tf.keras.preprocessing.image_dataset_from_directory(data_dir,validation_split=0.2,subset="training",seed=12,image_size=(img_height, img_width),batch_size=batch_size)val_ds = tf.keras.preprocessing.image_dataset_from_directory(data_dir,validation_split=0.2,subset="validation",seed=12,image_size=(img_height, img_width),batch_size=batch_size)

我们可以通过class_names输出数据集的标签。标签将按字母顺序对应于目录名称。

class_names = train_ds.class_names
print(class_names)

输出:

[‘cat’, ‘dog’]

2、再次检查数据

for image_batch, labels_batch in train_ds:print(image_batch.shape)print(labels_batch.shape)break

输出:

(8, 224, 224, 3)
(8,)

3. 配置数据集

AUTOTUNE = tf.data.AUTOTUNEdef preprocess_image(image,label):return (image/255.0,label)# 归一化处理
train_ds = train_ds.map(preprocess_image, num_parallel_calls=AUTOTUNE)
val_ds   = val_ds.map(preprocess_image, num_parallel_calls=AUTOTUNE)train_ds = train_ds.cache().shuffle(1000).prefetch(buffer_size=AUTOTUNE)
val_ds   = val_ds.cache().prefetch(buffer_size=AUTOTUNE)

4. 可视化数据

plt.figure(figsize=(15, 10))  # 图形的宽为15高为10for images, labels in train_ds.take(1):for i in range(8):ax = plt.subplot(5, 8, i + 1) plt.imshow(images[i])plt.title(class_names[labels[i]])plt.axis("off")

三、构建CNN网络

from tensorflow.keras import layers, models, Input
from tensorflow.keras.models import Model
from tensorflow.keras.layers import Conv2D, MaxPooling2D, Dense, Flatten, Dropoutdef VGG16(nb_classes, input_shape):input_tensor = Input(shape=input_shape)# 1st blockx = Conv2D(64, (3,3), activation='relu', padding='same',name='block1_conv1')(input_tensor)x = Conv2D(64, (3,3), activation='relu', padding='same',name='block1_conv2')(x)x = MaxPooling2D((2,2), strides=(2,2), name = 'block1_pool')(x)# 2nd blockx = Conv2D(128, (3,3), activation='relu', padding='same',name='block2_conv1')(x)x = Conv2D(128, (3,3), activation='relu', padding='same',name='block2_conv2')(x)x = MaxPooling2D((2,2), strides=(2,2), name = 'block2_pool')(x)# 3rd blockx = Conv2D(256, (3,3), activation='relu', padding='same',name='block3_conv1')(x)x = Conv2D(256, (3,3), activation='relu', padding='same',name='block3_conv2')(x)x = Conv2D(256, (3,3), activation='relu', padding='same',name='block3_conv3')(x)x = MaxPooling2D((2,2), strides=(2,2), name = 'block3_pool')(x)# 4th blockx = Conv2D(512, (3,3), activation='relu', padding='same',name='block4_conv1')(x)x = Conv2D(512, (3,3), activation='relu', padding='same',name='block4_conv2')(x)x = Conv2D(512, (3,3), activation='relu', padding='same',name='block4_conv3')(x)x = MaxPooling2D((2,2), strides=(2,2), name = 'block4_pool')(x)# 5th blockx = Conv2D(512, (3,3), activation='relu', padding='same',name='block5_conv1')(x)x = Conv2D(512, (3,3), activation='relu', padding='same',name='block5_conv2')(x)x = Conv2D(512, (3,3), activation='relu', padding='same',name='block5_conv3')(x)x = MaxPooling2D((2,2), strides=(2,2), name = 'block5_pool')(x)# full connectionx = Flatten()(x)x = Dense(4096, activation='relu',  name='fc1')(x)x = Dense(4096, activation='relu', name='fc2')(x)output_tensor = Dense(nb_classes, activation='softmax', name='predictions')(x)model = Model(input_tensor, output_tensor)return modelmodel=VGG16(1000, (img_width, img_height, 3))
model.summary()

四、编译

model.compile(optimizer="adam",loss     ='sparse_categorical_crossentropy',metrics  =['accuracy'])

五、训练模型

from tqdm import tqdm
import tensorflow.keras.backend as Kepochs = 10
lr     = 1e-4# 记录训练数据,方便后面的分析
history_train_loss     = []
history_train_accuracy = []
history_val_loss       = []
history_val_accuracy   = []for epoch in range(epochs):train_total = len(train_ds)val_total   = len(val_ds)"""total:预期的迭代数目ncols:控制进度条宽度mininterval:进度更新最小间隔,以秒为单位(默认值:0.1)"""with tqdm(total=train_total, desc=f'Epoch {epoch + 1}/{epochs}',mininterval=1,ncols=100) as pbar:lr = lr*0.92K.set_value(model.optimizer.lr, lr)for image,label in train_ds:   """训练模型,简单理解train_on_batch就是:它是比model.fit()更高级的一个用法想详细了解 train_on_batch 的同学,可以看看我的这篇文章:https://www.yuque.com/mingtian-fkmxf/hv4lcq/ztt4gy"""history = model.train_on_batch(image,label)train_loss     = history[0]train_accuracy = history[1]pbar.set_postfix({"loss": "%.4f"%train_loss,"accuracy":"%.4f"%train_accuracy,"lr": K.get_value(model.optimizer.lr)})pbar.update(1)history_train_loss.append(train_loss)history_train_accuracy.append(train_accuracy)print('开始验证!')with tqdm(total=val_total, desc=f'Epoch {epoch + 1}/{epochs}',mininterval=0.3,ncols=100) as pbar:for image,label in val_ds:      history = model.test_on_batch(image,label)val_loss     = history[0]val_accuracy = history[1]pbar.set_postfix({"loss": "%.4f"%val_loss,"accuracy":"%.4f"%val_accuracy})pbar.update(1)history_val_loss.append(val_loss)history_val_accuracy.append(val_accuracy)print('结束验证!')print("验证loss为:%.4f"%val_loss)print("验证准确率为:%.4f"%val_accuracy)

输出:

Epoch 1/10: 100%|████████| 340/340 [01:53<00:00,  2.99it/s, loss=0.8901, accuracy=0.1250, lr=9.2e-5]
开始验证!
Epoch 1/10: 100%|█████████████████████| 85/85 [00:03<00:00, 23.67it/s, loss=0.6123, accuracy=0.6250]
结束验证!
验证loss为:0.6123
验证准确率为:0.6250
Epoch 2/10: 100%|███████| 340/340 [00:22<00:00, 15.12it/s, loss=0.1449, accuracy=1.0000, lr=8.46e-5]
开始验证!
Epoch 2/10: 100%|█████████████████████| 85/85 [00:03<00:00, 25.99it/s, loss=0.2008, accuracy=0.8750]
结束验证!
验证loss为:0.2008
验证准确率为:0.8750
Epoch 3/10: 100%|███████| 340/340 [00:22<00:00, 15.23it/s, loss=0.0083, accuracy=1.0000, lr=7.79e-5]
开始验证!
Epoch 3/10: 100%|█████████████████████| 85/85 [00:03<00:00, 25.47it/s, loss=0.0298, accuracy=1.0000]
结束验证!
验证loss为:0.0298
验证准确率为:1.0000
Epoch 4/10: 100%|███████| 340/340 [00:22<00:00, 14.86it/s, loss=0.0321, accuracy=1.0000, lr=7.16e-5]
开始验证!
Epoch 4/10: 100%|█████████████████████| 85/85 [00:03<00:00, 25.84it/s, loss=0.0092, accuracy=1.0000]
结束验证!
验证loss为:0.0092
验证准确率为:1.0000
Epoch 5/10: 100%|███████| 340/340 [00:22<00:00, 15.03it/s, loss=0.3167, accuracy=0.8750, lr=6.59e-5]
开始验证!
Epoch 5/10: 100%|█████████████████████| 85/85 [00:03<00:00, 26.73it/s, loss=0.0381, accuracy=1.0000]
结束验证!
验证loss为:0.0381
验证准确率为:1.0000
Epoch 6/10: 100%|███████| 340/340 [00:22<00:00, 15.38it/s, loss=0.0323, accuracy=1.0000, lr=6.06e-5]
开始验证!
Epoch 6/10: 100%|█████████████████████| 85/85 [00:03<00:00, 25.85it/s, loss=0.0002, accuracy=1.0000]
结束验证!
验证loss为:0.0002
验证准确率为:1.0000
Epoch 7/10: 100%|███████| 340/340 [00:22<00:00, 15.04it/s, loss=0.0005, accuracy=1.0000, lr=5.58e-5]
开始验证!
Epoch 7/10: 100%|█████████████████████| 85/85 [00:03<00:00, 26.34it/s, loss=0.0040, accuracy=1.0000]
结束验证!
验证loss为:0.0040
验证准确率为:1.0000
Epoch 8/10: 100%|███████| 340/340 [00:21<00:00, 15.47it/s, loss=0.0018, accuracy=1.0000, lr=5.13e-5]
开始验证!
Epoch 8/10: 100%|█████████████████████| 85/85 [00:03<00:00, 26.12it/s, loss=0.0171, accuracy=1.0000]
结束验证!
验证loss为:0.0171
验证准确率为:1.0000
Epoch 9/10: 100%|███████| 340/340 [00:22<00:00, 15.38it/s, loss=0.0000, accuracy=1.0000, lr=4.72e-5]
开始验证!
Epoch 9/10: 100%|█████████████████████| 85/85 [00:03<00:00, 26.08it/s, loss=0.0009, accuracy=1.0000]
结束验证!
验证loss为:0.0009
验证准确率为:1.0000
Epoch 10/10: 100%|██████| 340/340 [00:21<00:00, 15.49it/s, loss=0.0050, accuracy=1.0000, lr=4.34e-5]
开始验证!
Epoch 10/10: 100%|████████████████████| 85/85 [00:03<00:00, 26.46it/s, loss=0.0001, accuracy=1.0000]
结束验证!
验证loss为:0.0001
验证准确率为:1.0000

六、模型评估

epochs_range = range(epochs)plt.figure(figsize=(12, 4))
plt.subplot(1, 2, 1)plt.plot(epochs_range, history_train_accuracy, label='Training Accuracy')
plt.plot(epochs_range, history_val_accuracy, label='Validation Accuracy')
plt.legend(loc='lower right')
plt.title('Training and Validation Accuracy')plt.subplot(1, 2, 2)
plt.plot(epochs_range, history_train_loss, label='Training Loss')
plt.plot(epochs_range, history_val_loss, label='Validation Loss')
plt.legend(loc='upper right')
plt.title('Training and Validation Loss')
plt.show()

在这里插入图片描述

七、预测

import numpy as np# 采用加载的模型(new_model)来看预测结果
plt.figure(figsize=(18, 3))  # 图形的宽为18高为5
plt.suptitle("预测结果展示")for images, labels in val_ds.take(1):for i in range(8):ax = plt.subplot(1,8, i + 1)  # 显示图片plt.imshow(images[i].numpy())# 需要给图片增加一个维度img_array = tf.expand_dims(images[i], 0) # 使用模型预测图片中的人物predictions = model.predict(img_array)plt.title(class_names[np.argmax(predictions)])plt.axis("off")

输出:

1/1 [==============================] - 0s 247ms/step
1/1 [==============================] - 0s 19ms/step
1/1 [==============================] - 0s 21ms/step
1/1 [==============================] - 0s 23ms/step
1/1 [==============================] - 0s 20ms/step
1/1 [==============================] - 0s 21ms/step
1/1 [==============================] - 0s 22ms/step
1/1 [==============================] - 0s 19ms/step

在这里插入图片描述

八、知识点

1、训练方式

这是我们之前的训练方法。

history = model.fit(train_ds,validation_data=val_ds,epochs=epochs
)

本次使用的训练函数是model.train_on_batch
函数原型:

Model.train_on_batch(x, y=None, sample_weight=None, class_weight=None, return_dict=False)

  • sample_weight:与x长度相同的可选数组,包含适用于每个样本的模型损失的权重。在时态数据的情况下,您可以传递一个具有形状(samples, sequence_length)的2D数组,以便对每个样本的每个时间步应用不同的权重。
  • class_weight:可选的字典。将类索引(整数)映射到权值(浮点数),以应用于训练期间该类样本的模型损失。这对于告诉模型“更多地关注”来自代表性不足的类的样本是有用的。
  • return_dict:如果为True,则损失和度量结果将作为字典返回,其中每个键是度量的名称。如果为False,它们将作为列表返回。

2、tqdm

tqdm是一个用于在终端中显示进度条的Python库。它提供了一种简单的方式来跟踪迭代过程的进度,无论是在循环中处理大量数据还是在长时间运行的任务中。

2.1、基本用法:

  • 在for循环中使用:
from tqdm import tqdm
import timefor i in tqdm(range(10)):time.sleep(1)# 模拟任务执行时间
100%|██████████| 10/10 [00:10<00:00,  1.00s/it]
  • 自定义进度条样式

desc:设置进度条的前缀文本;ncols:设置进度条的长度

from tqdm import tqdm
import time
for i in tqdm(range(10), desc="Processing", ncols=80):time.sleep(0.5)   
Processing: 100%|███████████████████████████████| 10/10 [00:05<00:00,  1.99it/s]

2.2、手动进度更新:

tqdm可以手动更新,将其对象赋给一个变量,然后调用.update(N)方法来更新进度,tqdm()有个可选的参数设置迭代总数,然后通过update方法进行累加,每次执行update都会打印一次当前进度。

示例:新建一个tqdm实例,total=100表示迭代总数为100

percent = tqdm(total=100)

输出:

  0%|          | 0/100 [00:03<?, ?it/s]

调用update(N)方法,表示完成N次迭代,进度条则会显示对应的百分比

percent.update(1)

输出:

  1%|          | 1/100 [00:47<1:18:17, 47.45s/it]

再次调用会进行累加:

percent.update(90)

输出:

 91%|█████████ | 91/100 [01:35<00:08,  1.12it/s] 

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.rhkb.cn/news/401937.html

如若内容造成侵权/违法违规/事实不符,请联系长河编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

mysql写个分区表

因为表量已经达到1个亿了。现在想做个优化&#xff0c;先按照 create_time 时间进行分区吧。 create_time 是varchar类型。 CREATE TABLE orders (id varchar(40) NOT NULL ,order_no VARCHAR(20) NOT NULL,create_time VARCHAR(20) NOT NULL,amount DECIMAL(10,2) NOT NULL,…

springboot使用aop或Jackson进行数据脱敏

1.aop 启动类加EnableAspectJAutoProxy 自定义注解&#xff0c;在实体类中使用表示被脱敏字段 建立aop切面类 可能这里gpt会建议你用Pointcut("execution(public * com.xx.aop..*.get*(..))")这种方式拦截&#xff0c;这种我试了&#xff0c;拦截不住。猜测在mvc返…

FPGA开发——UART回环实现之接收模块的设计

一、简介 因为我们本次进行串口回环的实验的对象是FPGA开发板和PC端&#xff0c;所以在接收和发送模块中先编写接收模块&#xff0c;这样可以在后面更好的进行发送模块的验证。&#xff08;其实这里先编写哪个模块&#xff09;都不影响&#xff0c;这里看自己心情&#xff0c;反…

【SpringBoot】【autopoi】java生成word,基于模版生成(文本、图片、表格)

基于模版生成word 1、引入maven2、word模版编写3、java代码4、效果5、word转pdf AutoPoi的主要特点 参考文献 https://help.jeecg.com/autopoi/autopoi/prequel/test.html 1.设计精巧,使用简单 2.接口丰富,扩展简单 3.默认值多,write less do more 4.spring mvc支持,web导出可以…

【ubuntu24.04】远程开发:微软RDP;ssh远程root登录;clion以root远程

本地配置了一台ubutnu服务器,运行各种服务。偶尔会远程过去,做一些UI操作。感觉nomachine的就是会模糊一些,可能是默认的编码比较均衡?RDP更清晰? RDP 与nomachine比,更清晰,但是貌似不支持自动缩放窗口?默认的配置就比较高:GPT的建议 安装xrdp还要配置session:1. 安…

Git 课程任务

安装好git 写自我介绍 配置完git&#xff0c;进行提交 创建个人仓库 添加链接 本地提交到远程仓库

leetcode198打家劫舍

题目描述 LeetCode 第 198 题——打家劫舍&#xff08;House Robber&#xff09; 你是一个职业小偷&#xff0c;计划偷窃沿街的房屋。每间房内都藏有一定的现金&#xff0c;这个地方所有的房屋都围成一圈&#xff0c;并且相邻的房屋有安全系统会相连&#xff0c;如果两间相邻的…

【C++高阶】哈希—— 位图 | 布隆过滤器 | 哈希切分

✨ 人生如梦&#xff0c;朝露夕花&#xff0c;宛若泡影 &#x1f30f; &#x1f4c3;个人主页&#xff1a;island1314 &#x1f525;个人专栏&#xff1a;C学习 ⛺️ 欢迎关注&#xff1a;&#x1f44d;点赞 &#x1f442;&am…

C++竞赛初阶L1-11-第五单元-for循环(25~26课)519: T454430 人口增长问题

题目内容 假设目前的世界人口有 x 亿&#xff0c;按照每年 0.1% 的增长速度&#xff0c;n 年后将有多少人&#xff1f; 输入格式 一行两个正整数 x 和 n&#xff0c;之间有一个空格。其中&#xff0c;1≤x≤100,1≤n≤100。 输出格式 一行一个数&#xff0c;表示答案。以亿…

RK3576 芯片介绍

RK3576 芯片介绍 RK3576瑞芯微第二代8nm高性能AIOT平台&#xff0c;它集成了独立的6TOPS&#xff08;Tera Operations Per Second&#xff0c;每秒万亿次操作&#xff09;NPU&#xff08;神经网络处理单元&#xff09;&#xff0c;用于处理人工智能相关的任务。此外&#xff0…

使用ITextRenderer导出PDF后无法打开问题,提示‘无法打开此文件‘

依赖如下 <!-- https://mvnrepository.com/artifact/org.xhtmlrenderer/flying-saucer-pdf --> <dependency><groupId>org.xhtmlrenderer</groupId><artifactId>flying-saucer-pdf</artifactId><version>9.1.22</version> &l…

6.MySQL的增删改查

目录 Create 单行插入数据 全列插入 多行数据指定列插入 插入否则更新 主键冲突 唯一键冲突 &#xff08;☆&#xff09; 替换数据 Retrieve Select列 全列查询 指定列查询 查询字段为表达式 where条件 NULL 的查询 NULL 和 NULL 的比较&#xff0c; 和 <>…

如何选择图片和视频

文章目录 1. 概念介绍2. 方法与细节2.1 实现方法2.2 具体细节 3. 示例代码4. 内容总结 我们在上一章回中介绍了"如何选择视频文件"相关的内容&#xff0c;本章回中将介绍如何混合选择图片和视频文件.闲话休提&#xff0c;让我们一起Talk Flutter吧。 1. 概念介绍 我…

Vue3学习 Day01

创建第一个vue项目 1.安装node.js cmd输入node查看是否安装成功 2.vscode开启一个终端&#xff0c;配置淘宝镜像 # 修改为淘宝镜像源 npm config set registry https://registry.npmmirror.com 输入如下命令创建第一个Vue项目 3.下载依赖&#xff0c;启动项目 访问5173端口 …

在线考试系统源码开发

在线考试系统开发需求与功能架构概览可以归纳为以下几个方面&#xff1a; 一、系统开发需求&#xff1a; 1、安全保障&#xff1a;系统需要提供完善的安全措施&#xff0c;这包括但不限于用户身份验证、数据加密技术&#xff0c;以及防止作弊的功能&#xff0c;确保考试的公平…

C语言程序设计-[23] 数组应用(续)

1、输入一行字符,统计其中有多少个单词。 根据以上分析&#xff0c;代码与结果如下&#xff1a; #include "stdio.h"int main ( ) { char c,pre,str[81];int i, n0;gets (str);pre ;for (i0; cstr[i]; i){if (c ! && pre ){ n;}pre c;}printf("…

谷歌发布会回顾:Gemini Live 与 Pixel 9 系列重磅亮相!

在 2024 年的 Made by Google 大会 上&#xff0c;谷歌重磅发布了全新 AI 产品 Gemini Live 和新一代硬件设备 Pixel 9 系列。这场发布会的亮点不只是 AI 的进步&#xff0c;还在于其硬件与 AI 的深度融合。本文将从技术角度回顾此次发布的重点内容&#xff0c;深入解析 Gemini…

Python爬虫——爬取某网站的视频

爬取视频 本次爬取&#xff0c;还是运用的是requests方法 首先进入此网站中&#xff0c;选取你想要爬取的视频&#xff0c;进入视频播放页面&#xff0c;按F12&#xff0c;将网络中的名称栏向上拉找到第一个并点击&#xff0c;可以在标头中&#xff0c;找到后续我们想要的一些…

WebGIS开发中一些常见的概念

0. 坐标系投影 地理坐标系和投影坐标系是两种常用的坐标系统&#xff0c;它们各自有着独特的特性和应用场景。 0.1 地理坐标系 地理坐标系(Geographic Coordinate System&#xff0c; 简称 GCS)是以地球椭球体面为参考面&#xff0c;以法线为依据&#xff0c;用经纬度表示地…

Knowledge-Adaptive Contrastive Learning for Recommendation

Knowledge-Adaptive Contrastive Learning for Recommendation&#xff08;WSDM2023&#xff09; 摘要 通过对用户-项目交互和知识图&#xff08;KG&#xff09;信息进行联合建模&#xff0c;基于知识图谱的推荐系统在缓解数据稀疏和冷启动问题方面表现出了优越性。 近年来&a…