第T2周:彩色图片分类

  • 🍨 本文为🔗365天深度学习训练营 中的学习记录博客
  • 🍖 原作者:K同学啊

👉 要求:

  • 学习如何编写一个完整的深度学习程序
  • 了解分类彩色图片会灰度图片有什么区别
  • 测试集accuracy到达72%

🦾我的环境:

  • 语言环境:Python3.8
  • 编译器:Jupyter Lab
  • 深度学习环境:
    • TensorFlow2

一、 前期准备

1.1. 设置GPU

  • 如果设备上支持GPU就使用GPU,否则使用CPU
  • Mac上的GPU使用mps
import tensorflow as tfgpus = tf.config.list_physical_devices("GPU")if gpus:gpu0 = gpus[0] #如果有多个GPU,仅使用第0个GPUtf.config.experimental.set_memory_growth(gpu0, True) #设置GPU显存用量按需使用tf.config.set_visible_devices([gpu0],"GPU")gpu0
PhysicalDevice(name='/physical_device:GPU:0', device_type='GPU')

1.2. 导入数据

使用dataset下载MNIST数据集,并划分好训练集与测试集

import tensorflow as tf
from tensorflow.keras import datasets, layers, models
import matplotlib.pyplot as plt(train_images, train_labels), (test_images, test_labels) = datasets.cifar10.load_data()
A local file was found, but it seems to be incomplete or outdated because the auto file hash does not match the original value of 6d958be074577803d12ecdefd02955f39262c83c16fe9348329d7fe0b5c001ce so we will re-download the data.
Downloading data from https://www.cs.toronto.edu/~kriz/cifar-10-python.tar.gz
170498071/170498071 [==============================] - 8500s 50us/step

1.3. 归一化

数据归一化作用

● 使不同量纲的特征处于同一数值量级,减少方差大的特征的影响,使模型更准确。
● 加快学习算法的收敛速度。

更详解的介绍请参考文章:🔗归一化与标准化

# 将像素的值标准化至0到1的区间内。(对于灰度图片来说,每个像素最大值是255,每个像素最小值是0,也就是直接除以255就可以完成归一化。)
train_images, test_images = train_images / 255.0, test_images / 255.0# 查看数据维数信息
train_images.shape,test_images.shape,train_labels.shape,test_labels.shape
((50000, 32, 32, 3), (10000, 32, 32, 3), (50000, 1), (10000, 1))

1.4. 可视化图片

class_names = ['airplane', 'automobile', 'bird', 'cat', 'deer','dog', 'frog', 'horse', 'ship', 'truck']plt.figure(figsize=(20,10))
for i in range(20):plt.subplot(5,10,i+1)plt.xticks([])plt.yticks([])plt.grid(False)plt.imshow(train_images[i], cmap=plt.cm.binary)plt.xlabel(class_names[train_labels[i][0]])
plt.show()

在这里插入图片描述

二、构建简单的CNN网络

⭐池化层

池化层对提取到的特征信息进行降维,一方面使特征图变小,简化网络计算复杂度;另一方面进行特征压缩,提取主要特征,增加平移不变性,减少过拟合风险。但其实池化更多程度上是一种计算性能的一个妥协,强硬地压缩特征的同时也损失了一部分信息,所以现在的网络比较少用池化层或者使用优化后的如SoftPool。

池化层包括最大池化层(MaxPooling)和平均池化层(AveragePooling),均值池化对背景保留更好,最大池化对纹理提取更好)。同卷积计算,池化层计算窗口内的平均值或者最大值。例如通过一个 2*2 的最大池化层,其计算方式如下:
在这里插入图片描述

我们即将构建模型的结构图,我以分别二维和三维的形式展示出来方便大家理解。

  • 平面结构图
    在这里插入图片描述

  • 立体结构图
    在这里插入图片描述

model = models.Sequential([layers.Conv2D(32, (3, 3), activation='relu', input_shape=(32, 32, 3)), #卷积层1,卷积核3*3layers.MaxPooling2D((2, 2)),                   #池化层1,2*2采样layers.Conv2D(64, (3, 3), activation='relu'),  #卷积层2,卷积核3*3layers.MaxPooling2D((2, 2)),                   #池化层2,2*2采样layers.Conv2D(64, (3, 3), activation='relu'),  #卷积层3,卷积核3*3layers.Flatten(),                      #Flatten层,连接卷积层与全连接层layers.Dense(64, activation='relu'),   #全连接层,特征进一步提取layers.Dense(10)                       #输出层,输出预期结果
])model.summary()  # 打印网络结构
Model: "sequential"
_________________________________________________________________Layer (type)                Output Shape              Param #   
=================================================================conv2d (Conv2D)             (None, 30, 30, 32)        896       max_pooling2d (MaxPooling2  (None, 15, 15, 32)        0         D)                                                              conv2d_1 (Conv2D)           (None, 13, 13, 64)        18496     max_pooling2d_1 (MaxPoolin  (None, 6, 6, 64)          0         g2D)                                                            conv2d_2 (Conv2D)           (None, 4, 4, 64)          36928     flatten (Flatten)           (None, 1024)              0         dense (Dense)               (None, 64)                65600     dense_1 (Dense)             (None, 10)                650       =================================================================
Total params: 122570 (478.79 KB)
Trainable params: 122570 (478.79 KB)
Non-trainable params: 0 (0.00 Byte)
_________________________________________________________________2024-06-23 22:16:01.054779: I metal_plugin/src/device/metal_device.cc:1154] Metal device set to: Apple M2
2024-06-23 22:16:01.054802: I metal_plugin/src/device/metal_device.cc:296] systemMemory: 16.00 GB
2024-06-23 22:16:01.054811: I metal_plugin/src/device/metal_device.cc:313] maxCacheSize: 5.33 GB
2024-06-23 22:16:01.054984: I tensorflow/core/common_runtime/pluggable_device/pluggable_device_factory.cc:303] Could not identify NUMA node of platform GPU ID 0, defaulting to 0. Your kernel may not have been built with NUMA support.
2024-06-23 22:16:01.055316: I tensorflow/core/common_runtime/pluggable_device/pluggable_device_factory.cc:269] Created TensorFlow device (/job:localhost/replica:0/task:0/device:GPU:0 with 0 MB memory) -> physical PluggableDevice (device: 0, name: METAL, pci bus id: <undefined>)

三、编译模型

model.compile(optimizer='adam',loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True),metrics=['accuracy'])

四、训练模型

history = model.fit(train_images, train_labels, epochs=10, validation_data=(test_images, test_labels))
Epoch 1/102024-06-23 22:16:41.825293: I tensorflow/core/grappler/optimizers/custom_graph_optimizer_registry.cc:114] Plugin optimizer for device_type GPU is enabled.1563/1563 [==============================] - ETA: 0s - loss: 1.5781 - accuracy: 0.42422024-06-23 22:16:54.304550: I tensorflow/core/grappler/optimizers/custom_graph_optimizer_registry.cc:114] Plugin optimizer for device_type GPU is enabled.1563/1563 [==============================] - 13s 8ms/step - loss: 1.5781 - accuracy: 0.4242 - val_loss: 1.3528 - val_accuracy: 0.5133
Epoch 2/10
1563/1563 [==============================] - 12s 8ms/step - loss: 1.2892 - accuracy: 0.5464 - val_loss: 1.2880 - val_accuracy: 0.5617
Epoch 3/10
1563/1563 [==============================] - 12s 8ms/step - loss: 1.3585 - accuracy: 0.5521 - val_loss: 1.6484 - val_accuracy: 0.5155
Epoch 4/10
1563/1563 [==============================] - 12s 8ms/step - loss: 2.0448 - accuracy: 0.5044 - val_loss: 3.0545 - val_accuracy: 0.4380
Epoch 5/10
1563/1563 [==============================] - 12s 8ms/step - loss: 5.7139 - accuracy: 0.4563 - val_loss: 20.7035 - val_accuracy: 0.2908
Epoch 6/10
1563/1563 [==============================] - 12s 8ms/step - loss: 45.9029 - accuracy: 0.3672 - val_loss: 109.2576 - val_accuracy: 0.3624
Epoch 7/10
1563/1563 [==============================] - 12s 8ms/step - loss: 504.0281 - accuracy: 0.2838 - val_loss: 1375.9681 - val_accuracy: 0.2399
Epoch 8/10
1563/1563 [==============================] - 12s 8ms/step - loss: 3719.2263 - accuracy: 0.2359 - val_loss: 6212.4688 - val_accuracy: 0.2268
Epoch 9/10
1563/1563 [==============================] - 12s 8ms/step - loss: 11472.0957 - accuracy: 0.2238 - val_loss: 20005.8828 - val_accuracy: 0.1773
Epoch 10/10
1563/1563 [==============================] - 12s 8ms/step - loss: 25618.4004 - accuracy: 0.2182 - val_loss: 31095.4336 - val_accuracy: 0.2160

五、预测

通过模型进行预测得到的是每一个类别的概率,数字越大该图片为该类别的可能性越大

plt.imshow(test_images[1])

在这里插入图片描述

输出测试集中第一张图片的预测结果

import numpy as nppre = model.predict(test_images)
print(class_names[np.argmax(pre[1])])
 75/313 [======>.......................] - ETA: 0s2024-06-23 22:20:12.257425: I tensorflow/core/grappler/optimizers/custom_graph_optimizer_registry.cc:114] Plugin optimizer for device_type GPU is enabled.313/313 [==============================] - 1s 3ms/step
ship

六、模型评估

import matplotlib.pyplot as pltplt.plot(history.history['accuracy'], label='accuracy')
plt.plot(history.history['val_accuracy'], label = 'val_accuracy')
plt.xlabel('Epoch')
plt.ylabel('Accuracy')
plt.ylim([0.5, 1])
plt.legend(loc='lower right')
plt.show()test_loss, test_acc = model.evaluate(test_images,  test_labels, verbose=2)

在这里插入图片描述

print(test_acc)
0.6845156432345124

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.rhkb.cn/news/357854.html

如若内容造成侵权/违法违规/事实不符,请联系长河编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

前端下载文件流,axios设置responseType: arraybuffer/blob无效

项目中调用后端下载文件接口&#xff0c;设置responseType: arraybuffer,实际拿到的数据data是字符串 axios({method: post,url: /api/v1/records/recording-file/play,// 如果有需要发送的数据&#xff0c;可以放在这里data: { uuid: 06e7075d-4ce0-476f-88cb-87fb0a1b4844 }…

使用VisualBox+Vagrant搭建Centos虚拟机环境

1.下载并安装VisualBox&#xff1b; 2.下载并安装Vagrant; 3.打开cmd窗口&#xff0c;执行命令vagrant init centos/7&#xff0c;初始化centos环境&#xff0c;该步骤受网络带宽影响&#xff0c;可能挂级30分钟到1个小时&#xff1b; 4.启动虚拟机&#xff1a;vagrant up&…

基于CDMA的多用户水下无线光通信(2)——系统模型和基于子空间的延时估计

本文首先介绍了基于CDMA的多用户UOWC系统模型&#xff0c;并给出了多用户收发信号的数学模型。然后介绍基于子空间的延时估计算法&#xff0c;该算法只需要已知所有用户的扩频码&#xff0c;然后根据扩频波形的循环移位在观测空间的信号子空间上的投影进行延时估计。 1、基于C…

【Linux进程】进程的 切换 与 调度(图形化解析,小白一看就懂!!!)

目录 &#x1f525;前言&#x1f525; &#x1f4a7;进程切换&#x1f4a7; &#x1f4a7;进程调度&#x1f4a7; &#x1f525;总结与提炼&#x1f525; &#x1f525;共勉&#x1f525; &#x1f525;前言&#x1f525; 在 Linux 操作系统中&#xff0c;进程的 调度 与 …

【STM32-新建工程-寄存器版本】

STM32-新建工程-寄存器版本 ■ 下载相关STM32Cube官方固件包&#xff08;F1&#xff0c;F4&#xff0c;F7&#xff0c;H7&#xff09;■ 1. ST官方搜索STM32Cube■ 2. 搜索 STM32Cube■ 3. 点击获取软件■ 4. 选择对应的版本下载■ 5. 输入账号信息■ 6. 出现下载弹框&#xff…

React@16.x(34)动画(中)

目录 3&#xff0c;SwitchTransition3.1&#xff0c;原理3.1.2&#xff0c;key3.1.2&#xff0c;mode 3.2&#xff0c;举例3.3&#xff0c;结合 animate.css 4&#xff0c;TransitionGroup4.1&#xff0c;其他属性4.1.2&#xff0c;appear4.1.2&#xff0c;component4.1.3&…

MFC学习--CListCtrl复选框以及选择

如何展示复选框 //LVS_EX_CHECKBOXES每一行的最前面带个复选框//LVS_EX_FULLROWSELECT整行选中//LVS_EX_GRIDLINES网格线//LVS_EX_HEADERDRAGDROP列表头可以拖动m_listctl.SetExtendedStyle(LVS_EX_FULLROWSELECT | LVS_EX_CHECKBOXES | LVS_EX_GRIDLINES); 全选&#xff0c;全…

如何获得一个Oracle 23ai数据库(vagrant box)

准确的说&#xff0c;是Oracle 23ai Free Developer版&#xff0c;因为企业版目前只在云上&#xff08;OCI和Azure&#xff09;和ECC上提供。 前面我博客介绍了3种方法&#xff1a; Virtual ApplianceRPM安装Docker 今天介绍最近新出的一种方法&#xff0c;也是我最为推荐的…

探索CSS clip-path: polygon():塑造元素的无限可能

在CSS的世界里&#xff0c;clip-path 属性赋予了开发者前所未有的能力&#xff0c;让他们能够以非传统的方式裁剪页面元素&#xff0c;创造出独特的视觉效果。其中&#xff0c;polygon() 函数尤其强大&#xff0c;它允许你使用多边形来定义裁剪区域的形状&#xff0c;从而实现各…

定时器-前端使用定时器3s轮询状态接口,2min为接口超时

背景 众所周知&#xff0c;后端是处理不了复杂的任务的&#xff0c;所以经过人家的技术讨论之后&#xff0c;把业务放在前端来实现。记录一下这次的离大谱需求吧。 如图所示&#xff0c;这个页面有5个列表&#xff0c;默认加载计划列表。但是由于后端的种种原因&#xff0c;这…

【C#】使用数字和时间方法ToString()格式化输出字符串显示

在C#编程项目开发中&#xff0c;几乎所有对象都有格式化字符串方法&#xff0c;其中常见的是数字和时间的格式化输出多少不一样&#xff0c;按实际需要而定吧&#xff0c;现记录如下&#xff0c;以后会用得上。 文章目录 数字格式化时间格式化 数字格式化 例如&#xff0c;保留…

WPF三方UI库全局应用MessageBox样式(.NET6版本)

一、问题场景 使用HandyControl简写HC 作为基础UI组件库时&#xff0c;希望系统中所有的MessageBox 样式都使用HC的MessageBox&#xff0c;常规操作如下&#xff1a; 在对应的xxxx.cs 顶部使用using 指定特定类的命名空间。 using MessageBox HandyControl.Controls.Message…

快去复习吧+++常用算法及参考算法 递推法++穷举法++排序(冒泡、选择)++查找(顺序、折半)++字符串处理++方程求根++无穷级数求和

接上&#xff1a;常用算法及参考算法 &#xff08;1&#xff09;累加 &#xff08;2&#xff09;累乘 &#xff08;3&#xff09;素数 &#xff08;4&#xff09;最大公约数 &#xff08;5&#xff09;最值问题 &#xff08;6&#xff09;迭代法 常用算法及参考算法 7. 递推法…

【LocalAI】(13):LocalAI最新版本支持Stable diffusion 3,20亿参数图像更加细腻了,可以继续研究下

最新版本v2.17.1 https://github.com/mudler/LocalAI/releases Stable diffusion 3 You can use Stable diffusion 3 by installing the model in the gallery (stable-diffusion-3-medium) or by placing this YAML file in the model folder: Stable Diffusion 3 Medium 正…

Git使用过程中涉及的几个区域

一. 简介 Git 是一个开源的分布式版本控制系统&#xff0c;可以有效、高速的处理从很小到非常大的项目版本管理&#xff0c;也是 Linus Torvalds 为了帮助管理 Linux内核开发而开发的一个开放源码的版本控制软件。 本文简单了解一下 git涉及的几个部分&#xff0c;以及git 常…

Django 模版过滤器

Django模版过滤器是一个非常有用的功能&#xff0c;它允许我们在模版中处理数据。过滤器看起来像这样&#xff1a;{{ name|lower }}&#xff0c;这将把变量name的值转换为小写。 1&#xff0c;创建应用 python manage.py startapp app5 2&#xff0c;注册应用 Test/Test/sett…

安卓中使用ttf字体文件

官方文档中提供的方法要设备能访问google&#xff1f; 官方方法 直接下载字体的fft文件 我要使用的是lexend 需要的格式可以在里面搜索 使用下载的ttf文件 解压出来 可以单独使用static里面的&#xff0c;里面是直接的lexend的各种格式 但是我这里直接使用Lexend-Vari…

IDEA Plugins中搜索不到插件解决办法

IDEA中搜不到插件有三种解决方案&#xff1a; 设置HTTP选项&#xff0c;可以通过File->Settings->Plugins->⚙->HTTP Proxy Settings进行设置 具体可参考这篇博文&#xff1a;IDEA Plugins中搜索不到插件解决办法本地安装&#xff0c;ile->Settings->Plugin…

【python】python股票量化交易策略分析可视化(源码+数据集+论文)【独一无二】

&#x1f449;博__主&#x1f448;&#xff1a;米码收割机 &#x1f449;技__能&#x1f448;&#xff1a;C/Python语言 &#x1f449;公众号&#x1f448;&#xff1a;测试开发自动化【获取源码商业合作】 &#x1f449;荣__誉&#x1f448;&#xff1a;阿里云博客专家博主、5…

【Leetcode】520. 检测大写字母

文章目录 题目思路代码复杂度分析时间复杂度空间复杂度 结果总结 题目 题目链接&#x1f517;我们定义&#xff0c;在以下情况时&#xff0c;单词的大写用法是正确的&#xff1a; 全部字母都是大写&#xff0c;比如 “USA” 。单词中所有字母都不是大写&#xff0c;比如 “le…