【TensorFlow2 之014】在 TF 2.0 中实现 LeNet-5

一、说明

         在这篇文章中,我们将展示如何在 TensorFlow 中实现像 \(LeNet-5\) 这样的基础卷积神经网络。LeNet-5 架构由 Yann LeCun 于 1998 年发明,是第一个卷积神经网络。

 数据黑客变种rs    深度学习 机器学习 TensorFlow    2020 年 2 月 29 日  |  0

1.1 教程概述:

  1. 理论重述
  2. 在 TensorFlow 中的实现

1. 理论重述

\(LeNet-5 \) 的目标是识别手写数字。因此,它作为输入 \(32\times32\times1 \) 图像。它是灰度图像,因此通道数为 \(1 \)。下面我们可以看到该网络的架构。

LeNet5架构

LeNet-5架构

        I. 最初的 MNIST 现在被认为太简单了。在过去的二十年里,许多研究人员针对原始 MNIST 提出了成功的解决方案。您可以在此处查看直接比较结果。

        由于该网络主要是为 MNIST 数据集设计的,因此它的性能明显更好。通过微小的改变,它就可以在 Fashion MNIST 数据集上达到这种准确性。然而,在这篇文章中,我们将坚持网络的原始架构。

关于 LeNet-5 架构,您可以在此处阅读详细的理论文章。

        让我们总结一下 LeNet-5 架构的各层。

图层类型特征图尺寸内核大小跨步激活
图像132×32
卷积628×285×51正值
平均池化614×142×22
卷积1610×105×51正值
平均池化165×52×22
完全连接120正值
完全连接84正值
完全连接10软最大

二、TensorFlow中的实现

        交互式 Colab 笔记本可在以下链接找到  
        在 Google Colab 中运行

为了练习,您可以尝试通过将Fashion_mnist替换为mnist、cifar10或其他来更改数据集。


        让我们从导入所有必需的库开始。导入后,我们可以使用导入的模块来加载数据。load_data  () 函数将自动下载数据并将其拆分为训练集和测试集。

import datetime
import numpy as np
import tensorflow as tf
import matplotlib.pyplot as pltfrom tensorflow.keras import Model
from tensorflow.keras.models import Sequential
from tensorflow.keras.losses import categorical_crossentropy
from tensorflow.keras.layers import Dense, Flatten, Conv2D, AveragePooling2Dfrom tensorflow.keras import datasets
from tensorflow.keras.utils import to_categorical#from __future__ import absolute_import, division, print_function, unicode_literals
# The data, split between train and test sets:
(x_train, y_train), (x_test, y_test) = datasets.fashion_mnist.load_data()
Downloading data from https://storage.googleapis.com/tensorflow/tf-keras-datasets/train-labels-idx1-ubyte.gz
32768/29515 [=================================] - 0s 2us/step
Downloading data from https://storage.googleapis.com/tensorflow/tf-keras-datasets/train-images-idx3-ubyte.gz
26427392/26421880 [==============================] - 7s 0us/step
Downloading data from https://storage.googleapis.com/tensorflow/tf-keras-datasets/t10k-labels-idx1-ubyte.gz
8192/5148 [===============================================] - 0s 0us/step
Downloading data from https://storage.googleapis.com/tensorflow/tf-keras-datasets/t10k-images-idx3-ubyte.gz
4423680/4422102 [==============================] - 1s 0us/step

        我们可以检查新数据的形状,发现我们的图像是 28×28 像素,因此我们需要添加一个新轴,它将代表多个通道。此外,对标签进行 one-hot 编码和对输入图像进行归一化也很重要。

print('x_train shape:', x_train.shape)
print(x_train.shape[0], 'train samples')
print(x_test.shape[0], 'test samples')
print(x_train[0].shape, 'image shape')

x_train shape: (60000, 28, 28)
60000 train samples
10000 test samples
(28, 28) image shape
 

# Add a new axis
x_train = x_train[:, :, :, np.newaxis]
x_test = x_test[:, :, :, np.newaxis]print('x_train shape:', x_train.shape)
print(x_train.shape[0], 'train samples')
print(x_test.shape[0], 'test samples')
print(x_train[0].shape, 'image shape')
x_train shape: (60000, 28, 28, 1)
60000 train samples
10000 test samples
(28, 28, 1) image shape
# Convert class vectors to binary class matrices.num_classes = 10
y_train = to_categorical(y_train, num_classes)
y_test = to_categorical(y_test, num_classes)
# Data normalization
x_train = x_train.astype('float32')
x_test = x_test.astype('float32')
x_train /= 255
x_test /= 255


        现在,是时候开始使用 TensorFlow 2.0 来构建我们的卷积神经网络了。最简单的方法是使用 Sequential API。我们将其包装在一个名为LeNet的类中。输入是图像,输出是类概率向量。

在 tf.keras 中,顺序模型表示层的线性堆栈,在本例中它遵循网络架构。

# LeNet-5 model
class LeNet(Sequential):def __init__(self, input_shape, nb_classes):super().__init__()self.add(Conv2D(6, kernel_size=(5, 5), strides=(1, 1), activation='tanh', input_shape=input_shape, padding="same"))self.add(AveragePooling2D(pool_size=(2, 2), strides=(2, 2), padding='valid'))self.add(Conv2D(16, kernel_size=(5, 5), strides=(1, 1), activation='tanh', padding='valid'))self.add(AveragePooling2D(pool_size=(2, 2), strides=(2, 2), padding='valid'))self.add(Flatten())self.add(Dense(120, activation='tanh'))self.add(Dense(84, activation='tanh'))self.add(Dense(nb_classes, activation='softmax'))self.compile(optimizer='adam',loss=categorical_crossentropy,metrics=['accuracy'])
model = LeNet(x_train[0].shape, num_classes)
model.summary()
Model: "le_net"
_________________________________________________________________
Layer (type)                 Output Shape              Param #   
=================================================================
conv2d (Conv2D)              (None, 28, 28, 6)         156       
_________________________________________________________________
average_pooling2d (AveragePo (None, 14, 14, 6)         0         
_________________________________________________________________
conv2d_1 (Conv2D)            (None, 10, 10, 16)        2416      
_________________________________________________________________
average_pooling2d_1 (Average (None, 5, 5, 16)          0         
_________________________________________________________________
flatten (Flatten)            (None, 400)               0         
_________________________________________________________________
dense (Dense)                (None, 120)               48120     
_________________________________________________________________
dense_1 (Dense)              (None, 84)                10164     
_________________________________________________________________
dense_2 (Dense)              (None, 10)                850       
=================================================================
Total params: 61,706
Trainable params: 61,706
Non-trainable params: 0

        创建模型后,我们需要训练它的参数以使其强大。让我们训练给定数量的 epoch 模型。我们可以在TensorBoard中看到训练的进度。

# Place the logs in a timestamped subdirectory
# This allows to easy select different training runs
# In order not to overwrite some data, it is useful to have a name with a timestamp
log_dir="logs/fit/" + datetime.datetime.now().strftime("%Y%m%d-%H%M%S")
# Specify the callback object
tensorboard_callback = tf.keras.callbacks.TensorBoard(log_dir=log_dir, histogram_freq=1)# tf.keras.callback.TensorBoard ensures that logs are created and stored
# We need to pass callback object to the fit method
# The way to do this is by passing the list of callback objects, which is in our case just one
model.fit(x_train, y=y_train, epochs=20, validation_data=(x_test, y_test), callbacks=[tensorboard_callback],verbose=0)%tensorboard --logdir logs/fit

        仅 20 个 epoch 就已经不错了。

        之后,我们可以做出一些预测并将其可视化。使用predict_classes()函数,我们可以从网络中获取准确的类别而不是概率值。它与使用numpy.argmax()相同。我们将用红色显示错误的预测,用蓝色表示正确的预测。

class_names = ['T-shirt/top', 'Trouser', 'Pullover', 'Dress', 'Coat','Sandal', 'Shirt', 'Sneaker', 'Bag', 'Ankle boot']prediction_values = model.predict_classes(x_test)# set up the figure
fig = plt.figure(figsize=(15, 7))
fig.subplots_adjust(left=0, right=1, bottom=0, top=1, hspace=0.05, wspace=0.05)# plot the images: each image is 28x28 pixels
for i in range(50):ax = fig.add_subplot(5, 10, i + 1, xticks=[], yticks=[])ax.imshow(x_test[i,:].reshape((28,28)),cmap=plt.cm.gray_r, interpolation='nearest')if prediction_values[i] == np.argmax(y_test[i]):# label the image with the blue textax.text(0, 7, class_names[prediction_values[i]], color='blue')else:# label the image with the red textax.text(0, 7, class_names[prediction_values[i]], color='red')
时尚迷斯特

三、总结

        所以,在这里我们学习了如何在 Tensorflow 2.0 中开发和训练 LeNet-5。在下 一篇文章中 ,我们将继续实现流行的卷积神经网络,并学习如何在 TensorFlow 2.0 中实现AlexNet 。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.rhkb.cn/news/159086.html

如若内容造成侵权/违法违规/事实不符,请联系长河编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

GB28181平台简介

产品简介 LiveMedia视频中间件是支持部署到本地服务器或者云服务器的纯软件服务,也提供服务器、GPU一体机全包服务,提供视频设备管理、无插件、跨平台的实时视频、历史回放、语音对讲、设备控制等基础功能,支持视频协议有海康、大华私有协议…

竞赛 深度学习LSTM新冠数据预测

文章目录 0 前言1 课题简介2 预测算法2.1 Logistic回归模型2.2 基于动力学SEIR模型改进的SEITR模型2.3 LSTM神经网络模型 3 预测效果3.1 Logistic回归模型3.2 SEITR模型3.3 LSTM神经网络模型 4 结论5 最后 0 前言 🔥 优质竞赛项目系列,今天要分享的是 …

Idea创建springboot工程的时候,发现pom文件没有带<parent>标签

今天创建springboot工程,加载maven的时候报错: 这个问题以前遇到过,这是因为 mysql-connector-j 没有带版本号的原因,但是springboot的依赖的版本号不是都统一交给spring-boot-starter-parent管理了吗,为什么还会报错&…

华为云云耀云服务器L实例评测|华为云耀云服务器L实例评测用例(五)

六、华为云耀云服务器L实例评测用例: “兵马未动,粮草先行”,随着企业业务的快速发展,服务器在数字化建设体系至关重要,为了保证服务器的稳定性、可靠性,需要对服务器进行评测,以确保服务器能够…

kafka详解(三)

2.2 Kafka命令行操作 2.2.1 主题命令行操作 1)查看操作主题命令参数 [aahadoop102 kafka]$ bin/kafka-topics.sh2)查看当前服务器中的所有topic (配置了环境变量不需要写bin/) [aahadoop102 kafka]$ bin/kafka-topics.sh --bootstrap-server hadoop10…

Linux gcc和make学习

文章目录 GCCgcc的安装gcc的工作流程 makefilemakefile的规则工作原理自动生成makefile的变量自定义变量预定义变量自动变量 模式匹配函数wildcard函数patsubst函数 伪声明 GCC gcc全程是(GNU compiler collection CNU编译器套件),是由GNU开发…

想要精通算法和SQL的成长之路 - 分割数组的最大值

想要精通算法和SQL的成长之路 - 分割数组的最大值 前言一. 分割数组的最大值1.1 二分法 前言 想要精通算法和SQL的成长之路 - 系列导航 一. 分割数组的最大值 原题链接 首先面对这个题目,我们可以捕获几个关键词: 非负整数。非空连续子数组。 那么我…

线性排序:如何根据年龄给100万用户数据排序?

文章来源于极客时间前google工程师−王争专栏。 桶排序、计数排序、基数排序时间复杂度是O(n),所以这类排序算法叫作线性排序。 线性的原因:三个算法是非基于比较的排序算法,都不涉及元素之间的比较操作。 三种排序对排序的数据要求苛刻&am…

CCF CSP认证 历年题目自练Day30

题目一 试题编号: 202203-1 试题名称: 未初始化警告 时间限制: 1.0s 内存限制: 512.0MB 问题描述: 题目背景 一个未经初始化的变量,里面存储的值可能是任意的。因此直接使用未初始化的变量,比…

太强了,真的太强了!

国庆之后gpt4上线了很多强大的功能,有超级强大的数据分析和挖掘的功能,有可以比肩AI绘图神器Midjourney的绘图功能(前面写了一篇泰酷辣!目前最强的AI绘画神器!文生图模型 DALLE 3来啦!)&#xf…

Python正则表达式

正则表达式 当处理文本数据时,正则表达式是一种强大的工具,它允许我们根据特定的模式来匹配、搜索和处理字符串。 正则表达式由一系列字符和特殊字符组成,用于描述文本模式。这些模式可以包含普通字符(如字母、数字和标点符号&a…

【TensorFlow2 之012】TF2.0 中的 TF 迁移学习

#012 TensorFlow 2.0 中的 TF 迁移学习 一、说明 在这篇文章中,我们将展示如何在不从头开始构建计算机视觉模型的情况下构建它。迁移学习背后的想法是,在大型数据集上训练的神经网络可以将其知识应用于以前从未见过的数据集。也就是说,为什么…

蓝桥杯 第 1 场算法双周赛 第1题 三带一 c++ map 巧解 加注释

题目 三带一【算法赛】https://www.lanqiao.cn/problems/5127/learning/?contest_id144 问题描述 小蓝和小桥玩斗地主,小蓝只剩四张牌了,他想知道是否是“三带一”牌型。 所谓“三带一”牌型,即四张手牌中,有三张牌一样&#…

CSS学习基础知识

CSS学习笔记 <!DOCTYPE html> <html lang"en"> <head><meta charset"UTF-8"><meta http-equiv"X-UA-Compatible" content"IEedge"><meta name"viewport" content"widthdevice-width,…

独立式三相无源逆变电源设计

摘要 面对全球日趋严重的能源危机问题&#xff0c;可再生能源的开发和利用得到了人们的高度重视。其中辐射到地球太阳能资源是十分富饶的&#xff0c;绿色清洁的太阳能不会危害我们的生存环境&#xff0c;因而受到了人们的广泛利用。光伏发电作为可再生能源被广泛的应用&#x…

RabbitMq启用TLS

Windows环境 查看配置文件的位置 选择使用的节点 查看当前节点配置文件的配置 配置TLS 将证书放到同配置相同目录中 编辑配置文件添加TLS相关配置 [{ssl, [{versions, [tlsv1.2]}]},{rabbit, [{ssl_listeners, [5671]},{ssl_options, [{cacertfile,"C:/Users/17126…

如何定制化跑腿小程序源码

跑腿小程序源码为您提供了一个强大的起点&#xff0c;但要创建一个成功的本地服务平台&#xff0c;您通常需要对源码进行定制化。这篇文章将介绍如何定制化跑腿小程序源码&#xff0c;包括添加新功能、修改界面和优化用户体验。 选择合适的跑腿小程序源码 首先&#xff0c;您…

Linux查看端口号及进程信息

Linux查看端口号及进程 Linux查看端口号 netstat netstat -tuln显示当前正在监听的端口号以及相关的进程信息 ss ss -tuln与netstat类似&#xff0c;ss也可以用于显示当前监听的端口以及相关信息 isof isof -i :端口号端口号替换为具体要查找的端口号&#xff0c;显示该端…

Leetcode 75——1768.交替合并字符串 解题思路与具体代码【C++】

一、题目描述与要求 1768. 交替合并字符串 - 力扣&#xff08;LeetCode&#xff09; 题目描述 给你两个字符串 word1 和 word2 。请你从 word1 开始&#xff0c;通过交替添加字母来合并字符串。如果一个字符串比另一个字符串长&#xff0c;就将多出来的字母追加到合并后字符…

DOSBox和MASM汇编开发环境搭建

DOSBox和MASM汇编开发环境搭建 1 安装DOSBox2 安装MASM3 编译测试代码4 运行测试代码5 调试测试代码 本文属于《 X86指令基础系列教程》之一&#xff0c;欢迎查看其它文章。 1 安装DOSBox 下载DOSBox和MASM&#xff1a;https://download.csdn.net/download/u011832525/884180…