第T7周:咖啡豆识别

  • >- **🍨 本文为[🔗365天深度学习训练营](https://mp.weixin.qq.com/s/0dvHCaOoFnW8SCp3JpzKxg) 中的学习记录博客**
    >- **🍖 原作者:[K同学啊](https://mtyjkh.blog.csdn.net/)**
  • 难度:夯实基础⭐⭐
  • 语言:Python3、TensorFlow2
  • 🏡 我的环境:
  • 语言环境:Python3.8
  • 编译器:jupyter lab
  • 深度学习环境:TensorFlow2.4.1

一、前期工作

1. 设置GPU

如果使用的是CPU可以忽略这步

import tensorflow as tfgpus = tf.config.list_physical_devices("GPU")if gpus:tf.config.experimental.set_memory_growth(gpus[0], True)  #设置GPU显存用量按需使用tf.config.set_visible_devices([gpus[0]],"GPU")

2. 导入数据

from tensorflow       import keras
from tensorflow.keras import layers,models
import numpy             as np
import matplotlib.pyplot as plt
import os,PIL,pathlibdata_dir = "./49-data/"
data_dir = pathlib.Path(data_dir)
image_count = len(list(data_dir.glob('*/*.png')))print("图片总数为:",image_count)

 1200

二、数据预处理

1. 加载数据

使用image_dataset_from_directory方法将磁盘中的数据加载到tf.data.Dataset

from tensorflow import keras
from tensorflow.keras import layers, models
import os, PIL, pathlib
import matplotlib.pyplot as plt
import tensorflow as tf
import numpy as np
import tensorflow as tfbatch_size = 32
img_height = 224
img_width = 224
train_ds = tf.keras.preprocessing.image_dataset_from_directory(data_dir,validation_split=0.2,subset="training",seed = 123,image_size=(img_height, img_width),batch_size=batch_size
)
val_ds = tf.keras.preprocessing.image_dataset_from_directory(data_dir,validation_split=0.2,subset="validation",seed = 123,image_size=(img_height, img_width),batch_size=batch_size
)
class_names = train_ds.class_names
print(class_names)

 

2. 可视化数据

plt.figure(figsize=(10,4))
for images, labels in train_ds.take(1):for i in range(10):ax = plt.subplot(2, 5, i+1)plt.imshow(images[i].numpy().astype('uint8'))plt.title(class_names[np.argmax(labels[i])])plt.axis('off')
for image_batch, labels_batch in train_ds:print(image_batch.shape) print(labels_batch.shape)break

 

3. 配置数据集

AUTOTUNE = tf.data.AUTOTUNEtrain_ds = train_ds.cache().shuffle(1000).prefetch(buffer_size = AUTOTUNE)
val_ds = val_ds.cache().prefetch(buffer_size=AUTOTUNE)
normalization_layer = layers.experimental.preprocessing.Rescaling(1./255)train_ds = train_ds.map(lambda x, y: (normalization_layer(x), y))
val_ds   = val_ds.map(lambda x, y: (normalization_layer(x), y))
image_batch, labels_batch = next(iter(val_ds))
first_image = image_batch[0]
print(np.min(first_image), np.max(first_image))
0.0 1.0

 

三、构建VGG-16网络

在官方模型与自建模型之间进行二选一就可以了,选着一个注释掉另外一个。

VGG优缺点分析:

  • VGG优点

VGG的结构非常简洁,整个网络都使用了同样大小的卷积核尺寸(3x3)和最大池化尺寸(2x2)

  • VGG缺点

1)训练时间过长,调参难度大。2)需要的存储容量大,不利于部署。例如存储VGG-16权重值文件的大小为500多MB,不利于安装到嵌入式系统中。

自建模型

from tensorflow.keras import layers, models, Input
from tensorflow.keras.models import Model
from tensorflow.keras.layers import Conv2D, MaxPooling2D, Dense, Flatten, Dropoutdef VGG16(nb_classes, input_shape):input_tensor = Input(shape=input_shape)# 1st blockx = Conv2D(64, (3,3), activation='relu', padding='same',name='block1_conv1')(input_tensor)x = Conv2D(64, (3,3), activation='relu', padding='same',name='block1_conv2')(x)x = MaxPooling2D((2,2), strides=(2,2), name = 'block1_pool')(x)# 2nd blockx = Conv2D(128, (3,3), activation='relu', padding='same',name='block2_conv1')(x)x = Conv2D(128, (3,3), activation='relu', padding='same',name='block2_conv2')(x)x = MaxPooling2D((2,2), strides=(2,2), name = 'block2_pool')(x)# 3rd blockx = Conv2D(256, (3,3), activation='relu', padding='same',name='block3_conv1')(x)x = Conv2D(256, (3,3), activation='relu', padding='same',name='block3_conv2')(x)x = Conv2D(256, (3,3), activation='relu', padding='same',name='block3_conv3')(x)x = MaxPooling2D((2,2), strides=(2,2), name = 'block3_pool')(x)# 4th blockx = Conv2D(512, (3,3), activation='relu', padding='same',name='block4_conv1')(x)x = Conv2D(512, (3,3), activation='relu', padding='same',name='block4_conv2')(x)x = Conv2D(512, (3,3), activation='relu', padding='same',name='block4_conv3')(x)x = MaxPooling2D((2,2), strides=(2,2), name = 'block4_pool')(x)# 5th blockx = Conv2D(512, (3,3), activation='relu', padding='same',name='block5_conv1')(x)x = Conv2D(512, (3,3), activation='relu', padding='same',name='block5_conv2')(x)x = Conv2D(512, (3,3), activation='relu', padding='same',name='block5_conv3')(x)x = MaxPooling2D((2,2), strides=(2,2), name = 'block5_pool')(x)# full connectionx = Flatten()(x)x = Dense(4096, activation='relu',  name='fc1')(x)x = Dense(4096, activation='relu', name='fc2')(x)output_tensor = Dense(nb_classes, activation='softmax', name='predictions')(x)model = Model(input_tensor, output_tensor)return modelmodel=VGG16(len(class_names), (img_width, img_height, 3))
model.summary()

四、编译


在准备对模型进行训练之前,还需要再对其进行一些设置。以下内容是在模型的编译步骤中添加的:
●损失函数(loss):用于衡量模型在训练期间的准确率。
●优化器(optimizer):决定模型如何根据其看到的数据和自身的损失函数进行更新。
●指标(metrics):用于监控训练和测试步骤。以下示例使用了准确率,即被正确分类的图像的比率。

# 设置初始学习率
initial_learning_rate = 1e-4lr_schedule = tf.keras.optimizers.schedules.ExponentialDecay(initial_learning_rate, decay_steps=30,      decay_rate=0.92,     staircase=True)# 设置优化器
opt = tf.keras.optimizers.Adam(learning_rate=initial_learning_rate)model.compile(optimizer=opt,loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True),metrics=['accuracy'])

五、训练模型

🔊注:从本周开始,网络越来越复杂,对算力要求也更高,CPU训练模型时间会很长,建议尽可能的使用GPU来跑。

epochs = 20history = model.fit(train_ds,validation_data=val_ds,epochs=epochs
)

六、可视化结果

acc = history.history['accuracy']
val_acc = history.history['val_accuracy']loss = history.history['loss']
val_loss = history.history['val_loss']epochs_range = range(epochs)plt.figure(figsize=(12, 4))
plt.subplot(1, 2, 1)
plt.plot(epochs_range, acc, label='Training Accuracy')
plt.plot(epochs_range, val_acc, label='Validation Accuracy')
plt.legend(loc='lower right')
plt.title('Training and Validation Accuracy')plt.subplot(1, 2, 2)
plt.plot(epochs_range, loss, label='Training Loss')
plt.plot(epochs_range, val_loss, label='Validation Loss')
plt.legend(loc='upper right')
plt.title('Training and Validation Loss')
plt.show()

七、总结

VGG-16也存在一些局限性,如参数量较大导致训练和推理时间较长,且需要大量资源;对小尺寸图像和资源有限的环境可能不理想等。在实际应用中,需要根据具体任务和资源条件进行权衡和选择。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.rhkb.cn/news/452509.html

如若内容造成侵权/违法违规/事实不符,请联系长河编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

Qt5.14.2 安装详细教程(图文版)

Qt 是一个跨平台的 C 应用程序开发框架,主要用于开发图形用户界面(GUI)程序,但也支持非 GUI 程序的开发。Qt 提供了丰富的功能库和工具,使开发者能够在不同平台上编写、编译和运行应用程序,而无需修改代码。…

如何通过CDN优化网站服务器访问速度?

CDN,即内容分发网络(Content Delivery Network),在现代互联网中起着重要作用。它可以显著提升网站服务器的访问速度。以下是CDN在加速网站访问方面的主要优势及其工作原理。 1. 全球分布的服务器节点 CDN通过在全球范围内布设多个…

CANoe与C#联合仿真方案

引言 CANoe作为一款强大的网络仿真工具,能够模拟各种通信协议,尤其是在汽车领域的CAN、LIN、Ethernet等协议。而C#作为一种广泛使用的编程语言,能够为CANoe提供灵活的用户界面和逻辑控制。本文将探讨如何将CANoe与C#结合,实现高效的联合仿真方案。 1. 系统架构 联合仿真…

Python学习的自我理解和想法(16)

学的是b站的课程(千锋教育),跟老师写程序,不是自创的代码! 今天是学Python的第16天,从今天开始,每天一到两个常用模块,更完恢复到原来的。开学了,时间不多,写…

Bug:通过反射修改@Autowired注入Bean的字段,明确存在,报错 NoSuchFieldException

【BUG】通过Autowired注入了一个Bean SeqNo,测试的时候需要修改其中的字段。通过传统的反射,无论如何都拿不到信息,关键是一方面可以通过IDEA跳转,一方面debug也确实能看到这个字段。但是每次调用set方法报错:NoSuchFi…

Nuxt.js 应用中的 app:templatesGenerated 事件钩子详解

title: Nuxt.js 应用中的 app:templatesGenerated 事件钩子详解 date: 2024/10/19 updated: 2024/10/19 author: cmdragon excerpt: app:templatesGenerated 是 Nuxt.js 的一个生命周期钩子,在模板编译到虚拟文件系统(Virtual File System, VFS)之后被调用。这个钩子允许…

【数据分享】1901-2023年我国省市县三级逐月最低气温(免费获取/Shp/Excel格式)

之前我们分享过1901-2023年1km分辨率逐月最低气温栅格数据(可查看之前的文章获悉详情),该数据来源于国家青藏高原科学数据中心,很多小伙伴拿到数据后反馈栅格数据不太方便使用,问我们能不能把数据处理为更方便使用的Sh…

Android15之解决gdb:Remote register badly formatted问题(二百三十六)

简介: CSDN博客专家、《Android系统多媒体进阶实战》一书作者 新书发布:《Android系统多媒体进阶实战》🚀 优质专栏: Audio工程师进阶系列【原创干货持续更新中……】🚀 优质专栏: 多媒体系统工程师系列【…

C++临时变量的常量性

C 临时变量的常量性-CSDN博客 #include <iostream> using namespace std; void print(string& str) {cout << str << endl; }int main() {print("hello world");//print(string("hello world"));return 0; }编译器根据字符串"…

探索 Coconut: Python 的新篇章

文章目录 探索 Coconut: Python 的新篇章背景&#xff1a;为何选择 Coconut&#xff1f;Coconut 是什么&#xff1f;如何安装 Coconut&#xff1f;简单的库函数使用方法1. 惰性列表2. 模式匹配3. 函数组合4. 协程5. 模式匹配数据类型 场景应用1. Web 开发2. 数据处理3. 异步编程…

【C++】类的默认成员函数:深入剖析与应用(下)

&#x1f4af;前言 回顾上篇文章&#x1f449;【C】类的默认成员函数&#xff1a;深入剖析与应用&#xff08;上&#xff09;中对构造函数、拷贝构造函数和析构函数的讨论&#xff0c;强调这些默认成员函数在类的创建、初始化和销毁过程中的重要性。 ✍引出本篇将继续探讨剩余…

【工具变量】A股上市企业大数据应用(2001-2023年)-参考柏淑嫄实践

数据简介&#xff1a;企业数字化转型的浪潮孕育出大数据&#xff0c;大数据技术是在数据处理和应用中释放大数据多元价值的必要手段。大数据作为企业发展的战略资源和生产要素对企业转型发展具有重要意义。对上市企业大数据应用程度进行测算不仅有助于了解大数据相关技术在企业…

Triton Inference Server 架构原理

文章目录 TensorRT-LLM & Triton Server 部署回顾部署梳理Triton 架构 为什么要使用 backend &#xff1f;triton_model_repo 目录结构Ensemble 模式BLS 模式 上篇文章进行了 TensorRT-LLM & Triton Server 部署 &#xff0c;本篇简单讲讲 Triton Inference Server 的架…

ECCV2024 Tracking 汇总

一、OneTrack: Demystifying the Conflict Between Detection and Tracking in End-to-End 3D Trackers paper&#xff1a; https://www.ecva.net/papers/eccv_2024/papers_ECCV/papers/01174.pdf 二、VETRA: A Dataset for Vehicle Tracking in Aerial Imagery paper&#…

基于ECS和NAS搭建个人网盘

前言 在数字化时代&#xff0c;数据已成为我们生活中不可或缺的一部分。个人文件、照片、视频等数据的积累&#xff0c;使得我们需要一个安全、可靠且便捷的存储解决方案。传统的物理存储设备&#xff08;如硬盘、U盘&#xff09;虽然方便&#xff0c;但存在易丢失、损坏和数据…

2013 lost connection to MySQL server during query

1.问题 使用navicat连接doris&#xff0c;会有这个错误。 2.解决 换低版本的navicat比如navicat11。

【LVGL快速入门(二)】LVGL开源框架入门教程之框架使用(UI界面设计)

零.前置篇章 本篇前置文章为【LVGL快速入门(一)】LVGL开源框架入门教程之框架移植 一.UI设计 介绍使用之前&#xff0c;我们要学习一款LVGL官方的UI设计工具SquareLine Studio&#xff0c;使用图形化设计方式设计出我们想要的界面&#xff0c;然后生成对应源文件导入工程使用…

人工智能公司未达到欧盟人工智能法案标准

关注公众号网络研究观获取更多内容。 据路透社获得的数据显示&#xff0c;领先的人工智能&#xff08;AI&#xff09;模型在网络安全弹性和防止歧视性输出等领域未能满足欧洲关键监管标准。 《欧盟人工智能法案》将在未来两年分阶段实施&#xff0c;旨在解决人们对这些技术在…

【计网】从零开始理解TCP协议 --- 拥塞控制机制,延迟应答机制,捎带应答,面向字节流

时间就是性命。 无端的空耗别人的时间&#xff0c; 其实是无异于谋财害命的。 --- 鲁迅 --- 从零开始理解TCP协议 1 拥塞控制2 延迟应答3 捎带应答4 面向字节流5 TCP异常情况TCP小结 1 拥塞控制 尽管TCP拥有滑动窗口这一高效的数据传输机制&#xff0c;能够确保在对方接收…

倍福TwinCAT程序中遇到的bug

文章目录 问题描述&#xff1a;TwinCAT嵌入式控制器CX5140在上电启动后&#xff0c;X001网口接网线通讯灯不亮&#xff0c;软件扫描不到硬件网口 解决方法&#xff1a;硬件断电重启后&#xff0c;X001网口恢复正常 问题描述&#xff1a;TwinCAT软件点击激活配置后&#xff0c;…