盘一盘 Python 系列 10 - Keras (上)

640?wx_fmt=png

本文含  12119  字, 64 图表截屏
建议阅读 62  分钟

0
引言

本文是 Python 系列的第十三篇,也是深度学习框架的第一篇 - Keras。

  • 深度学习之 Keras

  • 深度学习之 TensorFlow

  • 深度学习之 PyTorch

  • 深度学习之 MXnet


Keras 是一个高级的 (high-level) 深度学习框架,作者是 François Chollet。Keras 可以以两种方法运行:

  1. 以 TensorFlow, CNTK, 或者 Theano 作为后端 (backend) 运行

  2. 在 TensorFlow 里面直接运行 tf.keras

640?wx_fmt=png

我们用的是 TensorFlow 下面的 Keras,不过在本贴不会涉及任何关于 TensorFlow 的内容,只单单讲解 tf.keras 下面的内容。首先引入 tensorflow 和 keras。

import tensorflow as tf	
import tensorflow.keras as keras

Keras 是深度学习框架,里面有各种深度学习模型,介绍它之前让我们先回忆下它的好兄弟 - 机器学习框架 Scikit-Learn。

Scikit-Learn

在 Scikit-Learn 里完整的一套流程如下:

640?wx_fmt=png

数据是不可缺少的,Scikit-Learn 里面也有不少自带数据集。大家应该还记得 Scikit-Learn 里面的三大核心 API 吧:估计器(estimator),预测器(predictor)和转换器(transformer)。丛上图看估计器用来构建模型和拟合模型,而预测器用来评估模型。而转换器一般用来做数据预处理得到干净的 X_trainy_train

除了数据和模型,要完成一个任务还需定义损失函数(loss function)和指定算法(algorithm),它们都隐藏在 Scikit-Learn 的具体模型中,比如

  • LinearRegression 模型用的是 mean_square_error 损失函数,用梯度下降算法

  • LogisticRegression 模型用的是 cross_entropy 损失函数,用梯度下降算法

损失函数和算法都会在 Keras 里面都会显性定义出来,带着上面 Scikit-Learn 的图,让我们来看看 Keras 的高层流程。

Keras

说白了,Keras 里面的模型都是神经网络,而神经网络都是一层一层(layer by layer)叠加起来的,在Keras 里完整的一套流程如下:

640?wx_fmt=png

总共分五步:

  1. 引入数据:和 Scikit-Learn 操作一样

    1. 用 numpy 数据

    2. 引用自带数据

构建模型:用 Keras 构建模型就类似把每层当积木连起来称为一个网络, 连接的方法有三种:

  1. 序列式(sequential)

  2. 函数式(functional)

  3. 子类化(subclassing)

编译模型:这是 Scikit-Learn 里面没有的,显性定义出损失函数(loss)、优化方法(optimizer)和监控指标(metrics)。

拟合模型:和 Scikit-Learn 里的估计器类似,但可以额外设定 epoch 数量、是否包含验证集、设定调用函数里面的指标,等等。

评估模型:和 Scikit-Learn 里的预测器类似。

本帖目录如下(由于内容太多,上帖只包括第一章关于 Keras 最基本的知识,下帖再讲怎么用 Keras 做一些有趣的事情):

目录

第一章 - Keras 简介

    1.1 Keras 数据

    1.2 Keras 里的神经网络

    1.3 构建模型

    1.4 编译模型

    1.5 拟合模型

    1.6 评估模型

    1.7 保存模型

第二章 - 用 Keras 画画

第三章 - 用 Keras 写作

第四章 - 用 Keras 作曲

总结

640?wx_fmt=png

1
Keras 简介

1.1

Keras 数据

Numpy 数据格式

不像 TensorFlow, PyTorch 和 MXNet 有自己特有的数据格式

  • Tensorflow 用 tf.Tensor

  • MXNet 用 ndarray

  • PyTorch里用 torch.tensor

Keras 的数据格式就是 numpy array。

机器学习 (深度学习) 中用到的数据,包括结构性数据 (数据表) 和非结构性数据 (序列、图片、视屏) 都是张量,总结如下:

  • 数据表-2D 形状 = (样本数,特征数)

  • 序列类-3D 形状 = (样本数,步长,特征数)

  • 图像类-4D 形状 = (样本数,宽,高,通道数)

  • 视屏类-5D 形状 = (样本数,帧数,宽,高,通道数)

机器学习,尤其深度学习,需要大量的数据,因此样本数肯定占一个维度,惯例我们把它称为维度 1。这样机器学习要处理的张量至少从 2 维开始。

2D 数据表

2 维张量就是矩阵,也叫数据表,一般用 csv 存储。

640?wx_fmt=png

这套房屋 21,000 个数据包括其价格 (y),平方英尺,卧室数,楼层,日期,翻新年份等等 21 栏。该数据形状为 (21000, 21)。传统机器学习的线性回归可以来预测房价。

2 维张量的数据表示图如下:

640?wx_fmt=png

3D 序列数据

推特 (twitter) 的每条推文 (tweet) 规定只能发 280 个字符。在编码推文时,将 280 个字符的序列用独热编码 (one-hot encoding) 到包含 128 个字符的 ASCII 表,如下所示。

640?wx_fmt=png

这样,每条推文都可以编码为 2 维张量形状 (280, 128),比如一条 tweet 是 "I love python :)",这句话映射到 ASCII 表变成:

640?wx_fmt=png

如果收集到 1 百万条推文,那么整个数据集的形状为 (1000000, 280, 128)。传统机器学习的对率回归可以来做情感分析。

3 维张量的数据表示图如下:

640?wx_fmt=png

4D 图像数据

图像通常具有 3 个维度:宽度,高度和颜色通道。虽然是黑白图像 (如 MNIST 数字) 只有一个颜色通道,按照惯例,我们还是把它当成 4 维,即颜色通道只有一维。

  • 一组黑白照片可存成形状为 (样本数,宽,高,1) 的 4 维张量

  • 一组彩色照片可存成形状为 (样本数,宽,高,3) 的 4 维张量

640?wx_fmt=png

通常 0 代表黑色,255 代表白色。

4 维张量的数据表示图如下:

640?wx_fmt=png

5D 视屏数据

视频可以被分解成一幅幅帧 (frame)。

  • 每幅帧就是彩色图像,可以存储在形状是 (宽度,高度,通道) 的 3D 张量中

  • 视屏 (一个序列的帧) 可以存储在形状是 (帧数,宽度,高度,通道) 的 4D 张量中

  • 一批不同的视频可以存储在形状是 (样本数,帧数,宽度,高度,通道) 的 5D 张量中

下面一个 9:42 秒的 1280 x 720 油管视屏 (哈登三分绝杀勇士),被分解成 40 个样本数据,每个样本包括 240 帧。这样的视频剪辑将存储在形状为 (40, 240, 1280, 720, 3) 的张量中。

640?wx_fmt=png

5 维张量的数据表示图如下:

640?wx_fmt=png

对于以上用 numpy 自定义的各种维度的数据集 (X, y),用 Scikit-Learn 的子包 model_selection 里的 train_test_split 函数,代码如下:

from sklearn.model_selection import train_test_splitX_train, X_test, y_train, y_test = train_test_split( X, y, test_size=0.2 )

自带数据集

和 Scikit-Learn 一样,Keras 本身也自带数据集,从其官网中收集到 7 套。

  1. Boston housing price regression dataset

  2. CIFAR10 small image classification

  3. CIFAR100 small image classification

  4. IMDB Movie reviews sentiment classification

  5. Reuters newswire topics classification

  6. MNIST database of handwritten digits

  7. Fashion-MNIST database of fashion articles

想了解数据的具体描述,去 https://keras.io/datasets/ 链接。

想引进并划分它们可用以下代码:

from keras.datasets import data	
(x_train, y_train), (x_test, y_test) 	
= data.load_data()

这里 data 指的就是上面七套数据的统称,比如

  1. data = boston_housing

  2. data = cifar10

  3. data = cifar100

  4. data = imbd

  5. data = reuters

  6. data = mnist

  7. data = fashion_mnist

或者直接写 keras.dataset. + <tab> 来选取数据,展示如下

640?wx_fmt=gif

在本节后面介绍构建模型的三种方式时,我们用 fashion_mnist 数据来说明。Fashion-MNIST是一个替代 MNIST 手写数字集的图像数据集。它是由Zalando(一家德国的时尚科技公司)旗下的研究部门提供。

640?wx_fmt=png

Fashion-MNIST 的大小、格式和训练集/测试集划分与原始的 MNIST 完全一致。60000/10000 的训练测试数据划分,28x28 的灰度图片。

打印它们的形状确认一下。

(x_train, y_train),(x_test, y_test) = data.load_data()	
x_train, x_test = x_train / 255.0, x_test / 255.0	print( x_train.shape )	
print( x_test.shape )	
print( y_train.shape )	
print( y_test.shape )
(60000, 28, 28)
(10000, 28, 28)
(60000,)
(10000,)

每个训练和测试样本都按照以下类别(总共有 10 类标签)进行了标注:

640?wx_fmt=png

我们来验证一下标签和图片是不是一一对应。

class_names = [ "T-shirt/top", "Trouser", "Pullover", 	"Dress", "Coat", "Sandal", "Shirt",	"Sneaker", "Bag", "Ankle boot" ]	class_names[y_train[0]]

'Ankle boot'

第一张图片的标签是踝靴(Ankle boot),用 matplotlib.pyplot 里的 imshow 函数来把图片展示出来。

import matplotlib.pyplot as plt	
plt.imshow(x_train[0])

640?wx_fmt=png

是踝靴!

1.2

Keras 里的神经网络

组成神经网络的四个方面:

  1. (layers)和模型(models)

  2. 输入(input)和输出(output)

  3. 损失函数(loss)

  4. 优化器(optimizer)

多个链接在一起组成了模型,将输入数据映射为预测值。然后损失函数将这些预测值输出,并与目标进行比较,得到损失值,用于衡量网络预测值与预期结果的匹配程度。优化器使用这个损失值来更新网络的权重。

下图给出模型输入输出损失函数优化器之间的关系:

640?wx_fmt=png

神经网络里面的基本数据结构是层,而 Keras 里 layers 也是最基本的模块。

不同数据格式或不同数据处理类型需要用到不同的层,比如

  • 形状为 (样本数,特征数) 的 2D 数据用全连接层,对应 Keras 里面的 Dense

  • 形状为 (样本数,步长,特征数) 的 3D 序列数据用循环层,对应 Keras 里面的 RNN, GRU 或 LSTM

  • 形状为 (样本数,宽,高,通道数) 的 4D 图像数据用二维卷积层,对应 Keras 里面的 Conv2D 

等等。。。

模型

深度学习模型是层构成的有向无环图。最常见的例子就是层的线性堆叠,将单一输入映射为单一输出(single input to single output)。

此外,神经网络还有更复杂的结构,比如

多输入(multi-input)模型

用数值型和类别型的数据(下图左边),和图像数据(下图右边)一起来预测房价。

640?wx_fmt=png

多输出(multi-output)模型

根据图像数据来识别物体(下图左分支)和颜色(下图右分支)。

640?wx_fmt=png

等等。。。

损失函数

在 Keras 里将连成模型确定网络架构后,你还需要选择以下两个参数,选择损失函数和设定优化器

在训练过程中需要将最小化损失函数,这它是衡量当前任务是否已成功完成的标准。

对于分类、回归、序列预测等常见问题,你可以遵循一些简单的指导原则来选择正确的损失函数。

  • 对于二分类问题,用二元交叉熵(binary crossentropy)损失函数

  • 对于多分类问题,用分类交叉熵(categorical crossentropy)损失函数

  • 对于回归问题,用均方误差(mean-squared error)损失函数

  • 对于序列学习问题,用联结主义时序分类(CTC,connectionist temporal classification)损失函数

有时在面对真正全新的问题时,你还需要自主的设计损失函数,但这个超出本帖的范围了,以后再讲。

优化器

优化器决定如何基于损失函数对网络进行更新。它执行的是随机梯度下降(stochastic gradient descent,SGD)方法或其变体,目前 Keras 优化器包括

  • Adagrad

  • Adadelta

  • RMSprop

  • Adam

  • AdaMax

  • Nadam

  • AMSGrad

具体每个方法就不细讲了,对算法感兴趣的读者可参考链接 http://ruder.io/optimizing-gradient-descent/.

借用 Ruder 大神上面文章里的两幅动图对比各种优化算法的表现,图一对比他们在鞍点(saddle point)处的收敛到最优值的速度,SGD 没有收敛,图二从损失函数等值线(contour)看收敛速度,SGD 最慢。

640?wx_fmt=gif

640?wx_fmt=gif

1.3

构建模型

本节分别用序列式、函数书和子类化,配着 Fashion-MNIST 数据集构建模型,注意为了便于说明 Keras 语法特征,我故意只构建个简单模型,可能不实际,比如分类 Fashion-MNIST,用卷积效果网络好些,干嘛只用全连接网络举例?

一切只是便于解说基本核心概念。这些基本点弄清楚了,构建复杂模型和构建简单模型没任何区别。

序列式建模

序列式(sequential)建模有两种方式。

方式 1

用全连接网络(fully-connected neural network, FCNN)来建模,代码吐下:

640?wx_fmt=png

首先用 Sequential() 创建一个空模型 ,这个没办法,硬着记住吧。

接下来就像搭积木一样,用 add() 函数一层层加 layers,这个操作用代码写出来很自然,但是 layers 有很多种,这里用了两种:

  • Flatten:顾名思义,就是通过「Flatten 层」把高维数据打平成低维数据,做的就是下图的事。

640?wx_fmt=png

  • Dense:顾名思义,就是通过「Dense 层」把前一层每一个神经元和后一层神经元(除了偏置)两两相连,如下图:

640?wx_fmt=png

层的大方向弄清楚后,让我们看看里面的参数


  • Flatten( input_shape=[28,28] )

  • Dense( 100, activation='relu' )

  • Dense( 10, activation='softmax' ) 

每个层的第一个参数都是设定该层输出数据的维度。比如

  • Flatten 层输出形状 784 的一维数据

  • 第一个 Dense 层输出形状 100 的一维数据

  • 第二个 Dense 层输出形状 10 的一维数据

在 Keras 里不需要设定该层输入数据的维度,为什么呢?很简单,上一层的输出数据维度 = 该层的输入数据维度!Keras 会自动帮你连起来,那么

  • Flatten 层接受形状 28 × 28 的二维数据,输出形状 780 的一维数据

  • 第一个 Dense 层接受形状 100 的一维数据,输出形状 10 的一维数据

  • 第二个 Dense 层接受形状 10 的一维数据,输出形状 10 的一维数据

每个层(除了 Flatten 层)的第二个参数设定了激活函数的方式,比如

  • 第一个 Dense 层用 relu,防止梯度消失

  • 第二个 Dense 层用 softmax,因为 Fashion-MNIST 是个多分类问题

当构建完模型,我们可以打印出它的层的信息(用 model.layers)和概要信息(用 model.summary())。

model.layers

640?wx_fmt=png

整个模型有三层,按顺序它们的类别分别是 Flatten, Dense 和 Dense。

model.summary()

640?wx_fmt=png

该模型自动被命名 sequential_8,接着一张表分别描述每层的名称类型(layer (type))、输出形状(Output Shape)和参数个数(Param #)。我们一层层来看

  1. Flatten 层被命名为 flatten_7

    1. 输出形状是 (None, 784),784 好理解,就是 28×28 打平之后的维度,这个 None 其实是样本数,更严谨的讲是一批 (batch) 里面的样本数。为了代码简洁,这个「0 维」的样本数在建模时通常不需要显性写出来。

    2. 参数个数为 0,因为打平只是重塑数组,不需要任何参数来完成重塑动作。

  2. 第一个 Dense 层被命名为 dense_5

    1. 输出形状是 (None, 100),好理解。

    2. 参数个数为 78500,为什么不是 784×100 = 78400 呢?别忘了偏置项(bias)哦,(784+1)×100 = 78500。

  3. 第二个 Dense 层被命名为 dense_6

    1. 输出形状是 (None, 10),好理解。

    2. 参数个数为 1010,考虑偏置项,(100+1)×10 = 1010。

最下面还列出总参数量 79510,可训练参数量 79510,不可训练参数量 0。为什么还有参数不需要训练呢?你想想迁移学习,把借过来的网络锁住开始的 n 层,只训练最后 1- 2 层,那前面 n 层的参数可不就不参与训练吗?


再回顾一下代码。

640?wx_fmt=png

如果你还是觉得上面代码太多,我们还可以做进一步精简,事先引用 Sequential, Flattern 和 Dense。

from tensorflow.keras.models import Sequential	
from tensorflow.keras.layers import Flatten	
from tensorflow.keras.layers import Dense

这样每次就不用重复写 keras.models 和 keras.layers 了,下面代码是不是简洁多了。

640?wx_fmt=png

事无巨细把最简单的序列式建模讲完,大家是不是觉得 Keras 很简单呢?

方式 2

让我们再回顾一下 model.layers。

model.layers

640?wx_fmt=png

仔细看看输出数据的格式,是个列表,那么有没有一种方法用列表而不用 model.add() 来构建模型么?有,代码如下:

640?wx_fmt=png

model.summary()

640?wx_fmt=png

同样的模型结果(输入形状和参数个数,名称不一样),但是又省掉几个 model.add() 的字节了,代码看起来又简洁些。

我们可以用 model.layers[n].name 来获取第 n 层的名称。

model.layers[1].name
'dense_11'

也可以用 get_weights() 来获取每层的权重矩阵 W 和偏置向量 b

weights, biases = model.layers[1].get_weights()	
weights

640?wx_fmt=png

biases

640?wx_fmt=png

当模型还没训练时,W 是随机初始化,而 b 是零初始化。最后检查一下它们的形状。

print( weights.shape )	
print( biases.shape )
(784, 100)
(100,)

小结

一张图总结「序列式建模」。

640?wx_fmt=png

函数式建模

上面的序列式只适用于线性堆叠层的神经网络,但这种假设过于死板,有些网络

  • 需要多个输入

  • 需要多个输出

  • 在层与层之间具有内部分支

这使得网络看起来像是层构成的图(graph),而不是层的线性堆叠,这是需要更加通用和灵活建模方式,函数式(functional)建模。

本小节还是用上面序列式的简单例子来说明函数式建模,目的只是阐明函数式建模的核心要点,更加实际的案例放在之后几章。

首先引入必要的模块,和序列式建模比,注意 Input 和 Model 是个新东西。

from tensorflow.keras.layers import Input, Flatten, Dense	
from tensorflow.keras.models import Model

代码如下。

640?wx_fmt=png

函数式建模只用记住一句话:把层当做函数用。有了这句在心,代码秒看懂。

第二行,把 Flatten() 当成函数 f,化简不就是 x = f(input)

第三行,把 Dense(100, activation='relu') 当成函数 g,化简不就是 x = g(x)

第三行,把 Dense(10, activation='softmax') 当成函数 h,化简不就是 output = h(x)

这样一层层(函数接着函数)把 input 传递到 output,最后再用 Model() 将他俩建立关系。

看看模型概要。

model.summary()

640?wx_fmt=png

概要包含的内容和序列式建模产生的一眼,除了多了一个 InputLayer。

序列式构建的模型都可以用函数式来完成,反之不行,如果在两者选一,建议只用函数式来构建模型。

小结

一张图对比「函数式建模」和「序列式建模」。

640?wx_fmt=png

子类化建模

序列式和函数式都是声明式编程(declarative programming),它描述目标的性质,让计算机明白目标,而非流程。

具体来说,它们都是声明哪些层应该按什么顺序来添加,层与层以什么样的方式连接,所有声明完成之后再给模型喂数据开始训练。这种方法有好有快。

  • 好处:模型很容易保存、复制和分享,模型结构也容易展示和分析,因此调试起来比较容易。

  • 坏处:是个静态模型,很多情况模型有循环(loops)和条件分支(conditional branching)。这是我们更需要命令式编程(imperative programming)了。

子类化(subclassing)建模登场了。

首先引入必要的模块

from tensorflow.keras.layers import Flatten, Dense	
from tensorflow.keras.models import Model

Model 是个类别,而子类化就是创建 Model 的子类,起名为 SomeModel。

640?wx_fmt=png

该类别里有一个构造函数 __init__() 和一个 call() 函数:

构造函数负责创建不同的层,在本例中创建了一个隐藏层 self.hidden 和一个输出层 self.main_output。

call() 函数负责各种计算,注意到该函数有个参数是 input。

咋一看子类化函数式非常像,但有个细微差别,构造函数里面只有各种层,没有 input,而做计算的地方全部在 call() 里进行。这样就把创建层和计算两者完全分开。

在 call() 你可以尽情发挥想象:用各种 for, if, 甚至低层的 Tensorflow 里面的操作。研究员比较喜欢用子类化构建模型,他们可以尝试不同的点子。

1.4

编译模型

当构建模型完毕,接着需要编译(compile)模型,需要设定三点:

  1. 根据要解决的任务来选择损失函数

  2. 选取理想的优化器

  3. 选取想监控的指标

代码如下:

640?wx_fmt=png

损失函数 loss

常见问题类型的最后一层激活和损失函数,可供选择:

  • 二分类问题:最后一层激活函数是 sigmoid,损失函数是 binary_crossentropy

  • 多分类问题:最后一层激活函数是 softmax,损失函数是 categorical_crossentropy

  • 多标签问题:最后一层激活函数是 sigmoid,损失函数是 binary_crossentropy

  • 回归问题:最后一层无激活函数是,损失函数是 mse

Fashion_MNIST 是一个十分类问题,因此损失函数是 categorical_crossentropy。

优化器 optimizer

大多数情况下,使用 adam 和 rmsprop 及其默认的学习率是稳妥的。本例中选择的是 adam。

除了通过名称来调用优化器 model.compile('名称'),我们还可以通过实例化对象来调用优化器 model.compile('优化器')。选取几个对比如下:

名称:SGD

对象:SGD(lr=0.01, momentum=0.0, decay=0.0, nesterov=False)

名称:RMSprop

对象:RMSprop(lr=0.001, rho=0.9, epsilon=None, decay=0.0)

名称:Adagrad

对象:Adagrad(lr=0.01, epsilon=None, decay=0.0)

名称:Adam

对象

Adam(lr=0.001, beta_1=0.9, beta_2=0.999, epsilon=None, decay=0.0, amsgrad=False)

这些优化器对象都在 keras.optimizer 命名空间下。使用优化器对象来编译模型的好处是可以调节里面的超参数比如学习率 lr,使用名称则来编译模型只能采用优化器的默认参数,比如用 Adam 里面的学习率 0.001。

指标 metrics

指标和损失函数一样,都可以通过用名称实例化对象来调用,在本例中的指标是精度,那么可写成

名称:metrics = ['acc']

对象:metrics = [metrics.categorical_accuracy])

指标不会用于训练过程,只是让我们监控模型训练时的表现,常见的指标如下:

640?wx_fmt=gif

除了 Keras 自带指标,我们还可以自定指标,下列的 mean_pred 就是自定义指标(该指标计算预测的平均值)。

def mean_pred(y_true, y_pred):	return K.mean(y_pred)	model.compile(optimizer='sgd',	loss='binary_crossentropy',	metrics=['acc', mean_pred])

1.5

拟合模型

基本操作

和 Scikit-Learn 一样,Keras 里也用 model.fit() 函数;和 Scikit-Learn 不一样,Keras 会设置要遍历训练数据多少遍,即 epochs,先用 20 遍。

640?wx_fmt=png

640?wx_fmt=png

发现 loss 逐渐减少,acc 逐渐提高,这么个简单的单层全连接神经网络在 Fashion_MNIST 上精度做到 92.82% 也可以了。

调用函数

如果项目只要求精度达到 90% 即可,那么我们不用浪费资源把程序跑到底。这是用调用函数(callback)来控制,代码如下:

640?wx_fmt=png

回调函数是一个函数的合集,会在训练的阶段中所使用。你可以使用回调函数来查看训练模型的内在状态和统计。你可以传递一个列表的回调函数(作为 callbacks 关键字参数)到 Sequential 或 Model 类型的 .fit() 方法。在训练时,相应的回调函数的方法就会被在各自的阶段被调用。 

在本例中,我们定义的是 on_epoch_end(),在每期结束式,一旦精度超过 90%,模型就停止训练。

最常见的回调函数是

  • ModelCheckpoint

  • EarlyStopping

此外而且具体情况,我们可以自定义

  • on_train_begin()

  • on_train_end()

  • on_epoch_begin()

  • on_epoch_end()

  • on_batch_begin()

  • on_batch_end()

定义完 callbacks,我们只用把它当做参数传到 model.fit() 里。

640?wx_fmt=png

640?wx_fmt=png

在 Epoch = 8 时,训练精度达到 90.17%,停止训练。

1.6

预测模型

Keras 预测模型和 Scikit-Learn 里一样,都用是 model.predict()。

prob = model.predict( x_test[0:1] )	
prob

640?wx_fmt=png

在测试集上第一张图上做预测,输出是一个数组,里面 10 个数值代表每个类别预测的概率。看上去是第 10 类(索引为 9)概率最大。用 argmax 验证一下果然是的,而且把真正标签打印出来也吻合,第一张图预测对了。

import numpy as np	
print( np.argmax(prob) )	
print( y_test[0] )
9
9

前面讲了 Fashion_MNIST 第 10 类是踝靴(Ankle boot),画出来看看。

plt.imshow(x_test[0])

640?wx_fmt=png

最后用 model.evaluate() 来看看模型在所有测试集上的表现。

model.evaluate( x_test, y_test )

640?wx_fmt=png

训练精度 90.17% 但是测试精度 87.73%,有过拟合的征兆。这是需要用验证集了。

验证集

我们将原来训练集前 5000 个当验证集,剩下了当训练集。

640?wx_fmt=png

这时来用回调函数关注验证精度 val_acc,一旦超过 90% 就停止训练。

640?wx_fmt=png

代码基本和上面一样,唯一区别是把 (x_valid, y_valid) 传到 model.fit() 中。

640?wx_fmt=png

640?wx_fmt=png

但是验证精度适中没有超过 90%,模型从头训练到完。

难道是我们的单层全连接模型太简单?现在数据集可不是 MNIST 而是 Fashion_MNIST 啊,服装的特征还是数字的特征要丰富多了吧,再怎么样也要弄到卷积神经网络吧。

首先引进二维卷积层 Conv2D 和二维最大池化层 MaxPooling2D。在全连接层前我们放了两组 Conv2D + MaxPooling2D。

640?wx_fmt=png

640?wx_fmt=png

效果一下子出来了,训练精度 98.71% 但是验证精度只有 91.36%,明显的过拟合。画个图看的更明显。

640?wx_fmt=png

640?wx_fmt=png

怎么办?用 Dropout 试试?

代码和上面一摸一样,在第一个全连接层前加一个 Dropout 层(高亮强调出)。

640?wx_fmt=png

640?wx_fmt=png

虽然训练精度降到 93.89% 但是验证精度提高到 92.26%,Dropout 有效地抑制了过拟合。继续上图。

640?wx_fmt=png

1.7

保存模型

花费很长时间辛苦训练的模型不保存下次再从头开始训练太傻了。

对于用序列式函数式构建的模型可以用 model.save() 来保存:

model.save("my_keras_model.h5")

加载可用 models 命名空间里面的 load_model() 函数:

model = keras.models.load_model("my_keras_model.h5")

子类化构建的模型不能用上面的 save 和 load 来保存和加载,它对应的方式是

  • save_weights()

  • load_weights()

虽然没能保存模型所有的东西,但是保存了最重要的参数,这就够了。


一篇 12000 字的长文才讲清了 Keras 的基本内容,一开始本想一贴写完,结果太低估 Keras 和太高估自己了,啥也别说了,敬请期待 Keras(下)!

Stay Tuned!

640?

公众号:AI蜗牛车

保持谦逊、保持自律、保持进步

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.rhkb.cn/news/63375.html

如若内容造成侵权/违法违规/事实不符,请联系长河编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

这届全明星,把NBA又燃回来了

第一个罚球&#xff0c;戴维斯出手后&#xff0c;听到哐当医生&#xff0c;皮球掉了出来。我又紧张了。微信群了很多人开始发消息&#xff0c;说詹姆斯队又要输了。回到比赛。戴维斯当时没有任何微笑&#xff0c;我估计他内心也是紧张的&#xff0c;他有点埋怨哈登&#xff0c;…

CTR---DIN原理,及deepctr组网实现DIN

文章目录 原理小结deepctr实现DIN&#xff08;基于df的数据格式&#xff09; 原理小结 Candidate Ad item&#xff0c;在这指广告特征。 User profile features 代表用户的特征。 Context Features 代表跟场景有关的特征&#xff0c;比如时间戳之类的。 User Behaviors 代表…

第02章 PyTorch基础知识

文章目录 第02章 Pytorch基础知识2.1 张量2.2 自动求导2.3 并行计算简介2.3.1 为什么要做并行计算2.3.2 CUDA是个啥2.3.3 做并行的方法 补充&#xff1a;通过股票数据感受张量概念。 本图文是Datawhale组队学习Pytorch的学习笔记&#xff0c;主要内容包括张量的概念&#xff08…

大胆预测NBA2011-2012季后赛形势

以下都是我个人的看法,纯属猜测。希望大家不喜勿喷,有相同喜好的爱好者也可以把自己的想法写在评论上,大家一起讨论哦哈哈 NBA季后赛第一轮正在如火如荼的进行当中,除了雷霆已经以4:0的大比分淘汰了卫冕冠军小牛队之外,其他14支球队的争夺也已经到了最后的阶段了,现在先…

通过KNN算法预测数据所属NBA球员——Python实现

项目介绍 通过得分&#xff0c;篮板&#xff0c;助攻&#xff0c;出场时间四个数据来预测属于哪位球员。 选取了LeBron James,Chris Paul,James Harden,Kevin Love,Dwight Howard五位球员单场数据。 数据来源 本文使用数据全部来自于科赛网 &#xff0c;字段解释如下&#xff1…

java的后撤建_后撤步难学?做好这几点,你的后撤步也能像哈登一样强!

原标题&#xff1a;后撤步难学&#xff1f;做好这几点&#xff0c;你的后撤步也能像哈登一样强&#xff01; 后撤步投篮的一些要点&#xff0c; 可得好好学学&#xff0c; 没有一招拿得出手的后撤步&#xff0c; 如何在球场上立足。 中国孔子说"性相近也&#xff0c;习相远…

利用Python从数据分析的角度告诉你NBA2018-2019常规赛季为什么字母哥比哈登强?

目录 基于NBA2018-2019赛季常规赛球员数据进行数据挖掘 1. 挖掘背景与目标 1.1 挖掘背景 1.2 挖掘目标 2. 分析方法与过程 2.1 分析方法&#xff08;主成分分析&#xff09; 2.1 分析过程 3. 获取数据 4. 数据探索性分析与预处理 4.1探索性分析 4.1.1 条形图分析 4…

AI篮球裁判火了,走步算得特别准,就问哈登慌不慌

Alex 发自 凹非寺量子位 | 公众号 QbitAI 打篮球的友友们应该知道&#xff0c;走步是比赛中最常见的违规之一。 为了更好地监测篮球比赛中球员是否出现走步行为&#xff0c;一位网名叫Ayush Pai的小哥&#xff08;我们就叫他AP哥吧&#xff09;搞出了一个AI裁判。 如你所见&…

预测2019-2020赛季常规赛MVP

受新冠肺炎影响&#xff0c;2019-2020赛季NBA已经处于停摆状态&#xff0c;是否以及何时能复赛还不清楚。相关的各项评选如常规赛MVP、最佳阵容、最佳防守等也由于疫情暂停了。按照往年的赛程节奏&#xff0c;此时也应该进入常规赛收官阶段了。本文利用历史数据和本赛季常规赛已…

今天nba预测分析_焰神体育【NBA】赛事推荐预测分析:1月15日《开拓者》vs《步行者》...

波特兰开拓者(主) VS 印第安纳步行者 比赛时间&#xff1a;2021 1月15日 11:00 印第安纳步行者队 周四的大新闻是詹姆斯哈登在连续几周表现不佳后终于如愿以获&#xff0c;被交易到布鲁克林篮网队。 印第安纳步行者队用奥拉迪波交换莱弗里特到火箭。 凯文-普理查德可以说是今天…

前端图片显示不出来

原来的代码是 <img src"Release/warn.png">给路径加上 / 就可以了 <img src"/Release/warn.png"> 然后就正常显示了

页面加载微信聊天记录图片不显示问题

今天在做微信客服功能的时候页面通过异步请求微信的聊天记录&#xff0c;并把获取的结果appendchild页面中&#xff0c;发现微信的图片无权加载。如下图&#xff1a; 经过查阅资料得知&#xff0c;因为微信加载图片是通过一个地址请求然后返回的真是的图片地址&#xff0c;在请…

为什么计算机没有桌面显示不出来,​为什么电脑图片显示不出来

我们日常使用的电脑中&#xff0c;往往会有一些图片保存下来&#xff0c;用户想要打开自己需要的图片时&#xff0c;也可以通过显示的缩略图来查找&#xff0c;然而最近有用户的电脑桌面上的图片总是不显示出来&#xff0c;这让我们需要一张一张的进行查看&#xff0c;那么为什…

为什么html中图片显示不出来,网页图片不能显示 网页图片显示不出来的解决办法...

很多朋友上网遇到这样一种情况在浏览网页的时候发现网页中德图片不显示&#xff0c;(电脑百事网)一般现象是要门图像位置是空白&#xff0c;要么图像位置显示一个红叉&#xff0c;如下图所示&#xff0c;一般来说网页图片不显示主要影响页面美观&#xff0c;对我们影响相对不大…

在html中图片不显示不出来,网页图片显示不出来

很多小伙伴在打开网页的时候&#xff0c;发现网页的图片加载不出来&#xff0c;显示一个的标志&#xff0c;这是怎么一回事呢?可能是你的网速过低&#xff0c;等待一会就可以了&#xff0c;也可能是设置里面没有把显示图片打勾&#xff0c;具体的解决方法下面一起来看看吧。 显…

为什么html中图片显示不出来,网页图片显示不出来是什么原因?

原标题&#xff1a;网页图片显示不出来是什么原因&#xff1f; 在平时生活上网的过程&#xff0c;我们常常会遇到网页虽然是正常打开了&#xff0c;但网页上的图片却无法显示出来&#xff0c;无论怎么刷新也无法显示呢&#xff1f;一个网页打开正常与否&#xff0c;其实由很多因…

html浏览器图片不显示图片,教你网页图片显示不出来怎么办

网页是构成网站的基本元素&#xff0c;是一个包含HTML标签的纯文本文件&#xff0c;而文字与图片是构成一个网页的最基本的元素。今天&#xff0c;小编就给大家介绍一下网页图片显示不出来的解决方法&#xff0c;有需要就来了解一下吧 在查看网页的时候最重要的就是图片&#x…

流利阅读 2019.2.27 How sky-high rents forced people into imaginative alternatives

下载 笔记版/无笔记版 pdf资料&#xff1a; GitHub - zhbink/LiuLiYueDu: 流利阅读pdf汇总 本文内容全部来源于流利阅读。流利阅读对每期内容均有很好的文章讲解&#xff0c;向您推荐。 您可以关注微信公众号&#xff1a;流利阅读 了解详情。 How sky-high rents forced people…

V-Net 《Multi-Passage Machine Reading Comprehension with Cross-Passage Answer Verification》阅读理解笔记

V-Net 《Multi-Passage Machine Reading Comprehension with Cross-Passage Answer Verification》 这篇文章是发表在2018年ACL上的&#xff0c;是抽取式的。在微软发布的MS MARCO数据集和百度发布的中文数据集DuReader上得到了SOTA效果。 分以下四部分介绍&#xff1a; Mot…

掌握这15个可视化图表,小白也能轻松玩转数据分析

大数据时代&#xff0c;数据驱动决策。处理不好庞大、复杂的数据&#xff0c;其价值将大打折扣。 那如何缩短数据与用户的距离&#xff1f;让用户一眼Get到重点&#xff1f;让老板为你的汇报方案鼓掌&#xff1f; 本文通过连环15关&#xff0c;层层深入&#xff0c;传你数据匹…