-
本文为🔗365天深度学习训练营 中的学习记录博客
原作者:K同学啊 -
要求:
1.本地读取并加载数据。
2.了解循环神经网络(RNN)的构建过程
3.测试集accuracy到达87%拔高:
1.测试集accuracy到达89%
我的环境:
●语言环境:Python3.6.5
●编译器:Jupyter Lab
●深度学习框架:TensorFlow2.4.1
●数据:心脏病数据
代码流程图如下所示:
一、RNN是什么
传统神经网络的结构比较简单:输入层 – 隐藏层 – 输出层
RNN 跟传统神经网络最大的区别在于每次都会将前一次的输出结果,带到下一次的隐藏层中,一起训练。如下图所示:
这里用一个具体的案例来看看 RNN 是如何工作的:用户说了一句“what time is it?”,我们的神经网络会先将这句话分为五个基本单元(四个单词+一个问号)
然后,按照顺序将五个基本单元输入RNN网络,先将 “what”作为RNN的输入,得到输出 01
随后,按照顺序将“time”输入到RNN网络,得到输出02。
这个过程我们可以看到,输入 “time” 的时候,前面“what” 的输出也会对02的输出产生了影响(隐藏层中有一半是黑色的)。
以此类推,我们可以看到,前面所有的输入产生的结果都对后续的输出产生了影响(可以看到圆形中包含了前面所有的颜色)
当神经网络判断意图的时候,只需要最后一层的输出 05,如下图所示:
二、前期准备
- 设置GPU
import tensorflow as tf
gpus = tf.config.list_physical_devices("GPU")if gpus:gpu0 = gpus[0]tf.config.experimental.set_memory_growth(gpu0,true)tf.config.set_visible_devices([gpu0],"GPU")print("GPU: ",gpus)
else:print("CPU:")# 确认当前可见的设备列表
print(tf.config.list_physical_devices())
代码输出:
CPU:
[PhysicalDevice(name='/physical_device:CPU:0', device_type='CPU')]
- 导入数据
数据介绍:
- ●age:1) 年龄
●sex:2) 性别
●cp:3) 胸痛类型 (4 values)
●trestbps:4) 静息血压
●chol:5) 血清胆甾醇 (mg/dl
●fbs:6) 空腹血糖 > 120 mg/dl
●restecg:7) 静息心电图结果 (值 0,1 ,2)
●thalach:8) 达到的最大心率
●exang:9) 运动诱发的心绞痛
●oldpeak:10) 相对于静止状态,运动引起的ST段压低
●slope:11) 运动峰值 ST 段的斜率
●ca:12) 荧光透视着色的主要血管数量 (0-3)
●thal:13) 0 = 正常;1 = 固定缺陷;2 = 可逆转的缺陷
●target:14) 0 = 心脏病发作的几率较小 1 = 心脏病发作的几率更大
import pandas as pd
import numpy as npdf = pd.read_csv("./R1/heart.csv")
df
代码输出 :
age | sex | cp | trestbps | chol | fbs | restecg | thalach | exang | oldpeak | slope | ca | thal | target | |
---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
0 | 63 | 1 | 3 | 145 | 233 | 1 | 0 | 150 | 0 | 2.3 | 0 | 0 | 1 | 1 |
1 | 37 | 1 | 2 | 130 | 250 | 0 | 1 | 187 | 0 | 3.5 | 0 | 0 | 2 | 1 |
2 | 41 | 0 | 1 | 130 | 204 | 0 | 0 | 172 | 0 | 1.4 | 2 | 0 | 2 | 1 |
3 | 56 | 1 | 1 | 120 | 236 | 0 | 1 | 178 | 0 | 0.8 | 2 | 0 | 2 | 1 |
4 | 57 | 0 | 0 | 120 | 354 | 0 | 1 | 163 | 1 | 0.6 | 2 | 0 | 2 | 1 |
... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... |
298 | 57 | 0 | 0 | 140 | 241 | 0 | 1 | 123 | 1 | 0.2 | 1 | 0 | 3 | 0 |
299 | 45 | 1 | 3 | 110 | 264 | 0 | 1 | 132 | 0 | 1.2 | 1 | 0 | 3 | 0 |
300 | 68 | 1 | 0 | 144 | 193 | 1 | 1 | 141 | 0 | 3.4 | 1 | 2 | 3 | 0 |
301 | 57 | 1 | 0 | 130 | 131 | 0 | 1 | 115 | 1 | 1.2 | 1 | 1 | 3 | 0 |
302 | 57 | 0 | 1 | 130 | 236 | 0 | 0 | 174 | 0 | 0.0 | 1 | 1 | 2 | 0 |
303 rows × 14 columns
- 检查数据
# 检查是否有空值
df.isnull().sum()
代码输出 :
age 0
sex 0
cp 0
trestbps 0
chol 0
fbs 0
restecg 0
thalach 0
exang 0
oldpeak 0
slope 0
ca 0
thal 0
target 0
dtype: int64
三、数据预处理
- 划分训练集与测试集
测试集与验证集的关系:
- 1.验证集并没有参与训练过程梯度下降过程的,狭义上来讲是没有参与模型的参数训练更新的。
2.但是广义上来讲,验证集存在的意义确实参与了一个“人工调参”的过程,我们根据每一个epoch训练之后模型在valid data上的表现来决定是否需要训练进行early stop,或者根据这个过程模型的性能变化来调整模型的超参数,如学习率,batch_size等等。
3.我们也可以认为,验证集也参与了训练,但是并没有使得模型去overfit验证集。
from sklearn.preprocessing import StandardScaler
from sklearn.model_selection import train_test_splitX = df.iloc[:,:-1]
y = df.iloc[:,-1]X_train, X_test, y_train, y_test = train_test_split(X, y, test_size = 0.1, random_state = 1)
X_train.shape, y_train.shape
代码输出:
((272, 13), (272,))
检查并重置索引:
# 重置索引以确保索引从 0 开始连续排列
X_train = X_train.reset_index(drop=True)
y_train = y_train.reset_index(drop=True)
X_test = X_test.reset_index(drop=True)
y_test = y_test.reset_index(drop=True)
- 标准化
# 将每一列特征标准化为标准正太分布,注意,标准化是针对每一列而言的
sc = StandardScaler()
X_train = sc.fit_transform(X_train)
X_test = sc.transform(X_test)X_train = X_train.reshape(X_train.shape[0], X_train.shape[1], 1)
X_test = X_test.reshape(X_test.shape[0], X_test.shape[1], 1)
四、构建RNN模型
函数原型:
tf.keras.layers.SimpleRNN(units,activation=‘tanh’,use_bias=True,kernel_initializer=‘glorot_uniform’,recurrent_initializer=‘orthogonal’,bias_initializer=‘zeros’,kernel_regularizer=None,recurrent_regularizer=None,bias_regularizer=None,activity_regularizer=None,kernel_constraint=None,recurrent_constraint=None,bias_constraint=None,dropout=0.0,recurrent_dropout=0.0,return_sequences=False,return_state=False,go_backwards=False,stateful=False,unroll=False,**kwargs)
关键参数说明:
- ●units: 正整数,输出空间的维度。
●activation: 要使用的激活函数。 默认:双曲正切(tanh)。 如果传入 None,则不使用激活函数 (即 线性激活:a(x) = x)。
●use_bias: 布尔值,该层是否使用偏置向量。
●kernel_initializer: kernel 权值矩阵的初始化器, 用于输入的线性转换 (详见 initializers)。
●recurrent_initializer: recurrent_kernel 权值矩阵 的初始化器,用于循环层状态的线性转换 (详见 initializers)。
●bias_initializer:偏置向量的初始化器 (详见initializers).
●dropout: 在 0 和 1 之间的浮点数。 单元的丢弃比例,用于输入的线性转换。
import tensorflow
from tensorflow.keras.models import Sequential
from tensorflow.keras.layers import Dense,LSTM,SimpleRNNmodel = Sequential()
model.add(SimpleRNN(200, input_shape= (13,1), activation='relu'))
model.add(Dense(100, activation='relu'))
model.add(Dense(1, activation='sigmoid'))
model.summary()
代码输出 :
Model: "sequential_15"
_________________________________________________________________
Layer (type) Output Shape Param #
=================================================================
simple_rnn_15 (SimpleRNN) (None, 200) 40400
_________________________________________________________________
dense_30 (Dense) (None, 100) 20100
_________________________________________________________________
dense_31 (Dense) (None, 1) 101
=================================================================
Total params: 60,601
Trainable params: 60,601
Non-trainable params: 0
_________________________________________________________________
五、编译模型
opt = tf.keras.optimizers.Adam(learning_rate=1e-4)model.compile(loss='binary_crossentropy',optimizer=opt,metrics="accuracy")
六、训练模型
epochs = 100history = model.fit(X_train, y_train, epochs=epochs, batch_size=128, validation_data=(X_test, y_test),verbose=1)
代码输出 :
WARNING:tensorflow:Keras is training/fitting/evaluating on array-like data. Keras may not be optimized for this format, so if your input data format is supported by TensorFlow I/O (https://github.com/tensorflow/io) we recommend using that to load a Dataset instead.
Epoch 1/100
1/3 [=========>....................] - ETA: 1s - loss: 0.6790 - accuracy: 0.7422WARNING:tensorflow:Keras is training/fitting/evaluating on array-like data. Keras may not be optimized for this format, so if your input data format is supported by TensorFlow I/O (https://github.com/tensorflow/io) we recommend using that to load a Dataset instead.
3/3 [==============================] - 1s 140ms/step - loss: 0.6789 - accuracy: 0.7189 - val_loss: 0.6590 - val_accuracy: 0.8387
Epoch 2/100
3/3 [==============================] - 0s 42ms/step - loss: 0.6704 - accuracy: 0.7557 - val_loss: 0.6463 - val_accuracy: 0.8710
Epoch 3/100
3/3 [==============================] - 0s 41ms/step - loss: 0.6628 - accuracy: 0.7954 - val_loss: 0.6344 - val_accuracy: 0.9032
Epoch 4/100
3/3 [==============================] - 0s 38ms/step - loss: 0.6555 - accuracy: 0.8049 - val_loss: 0.6226 - val_accuracy: 0.8710
Epoch 5/100
3/3 [==============================] - 0s 39ms/step - loss: 0.6492 - accuracy: 0.7883 - val_loss: 0.6111 - val_accuracy: 0.9032
Epoch 6/100
3/3 [==============================] - 0s 80ms/step - loss: 0.6384 - accuracy: 0.8204 - val_loss: 0.5999 - val_accuracy: 0.9032
Epoch 7/100
3/3 [==============================] - 0s 50ms/step - loss: 0.6321 - accuracy: 0.8135 - val_loss: 0.5884 - val_accuracy: 0.9032
Epoch 8/100
3/3 [==============================] - 0s 39ms/step - loss: 0.6233 - accuracy: 0.8280 - val_loss: 0.5771 - val_accuracy: 0.8710
Epoch 9/100
3/3 [==============================] - 0s 39ms/step - loss: 0.6179 - accuracy: 0.8077 - val_loss: 0.5656 - val_accuracy: 0.8710
Epoch 10/100
3/3 [==============================] - 0s 38ms/step - loss: 0.6106 - accuracy: 0.8039 - val_loss: 0.5531 - val_accuracy: 0.8710
Epoch 11/100
3/3 [==============================] - 0s 39ms/step - loss: 0.6013 - accuracy: 0.8059 - val_loss: 0.5393 - val_accuracy: 0.8710
Epoch 12/100
3/3 [==============================] - 0s 40ms/step - loss: 0.5936 - accuracy: 0.8030 - val_loss: 0.5243 - val_accuracy: 0.8710
Epoch 13/100
3/3 [==============================] - 0s 39ms/step - loss: 0.5828 - accuracy: 0.8059 - val_loss: 0.5081 - val_accuracy: 0.8710
Epoch 14/100
3/3 [==============================] - 0s 40ms/step - loss: 0.5703 - accuracy: 0.8105 - val_loss: 0.4901 - val_accuracy: 0.8710
Epoch 15/100
3/3 [==============================] - 0s 38ms/step - loss: 0.5624 - accuracy: 0.8055 - val_loss: 0.4698 - val_accuracy: 0.8387
Epoch 16/100
3/3 [==============================] - 0s 39ms/step - loss: 0.5496 - accuracy: 0.8093 - val_loss: 0.4463 - val_accuracy: 0.8387
Epoch 17/100
3/3 [==============================] - 0s 41ms/step - loss: 0.5265 - accuracy: 0.8173 - val_loss: 0.4202 - val_accuracy: 0.8710
Epoch 18/100
3/3 [==============================] - 0s 40ms/step - loss: 0.5125 - accuracy: 0.8170 - val_loss: 0.3936 - val_accuracy: 0.9032
Epoch 19/100
3/3 [==============================] - 0s 39ms/step - loss: 0.4855 - accuracy: 0.8211 - val_loss: 0.3687 - val_accuracy: 0.9032
Epoch 20/100
3/3 [==============================] - 0s 38ms/step - loss: 0.4711 - accuracy: 0.8280 - val_loss: 0.3483 - val_accuracy: 0.9032
Epoch 21/100
3/3 [==============================] - 0s 40ms/step - loss: 0.4655 - accuracy: 0.8115 - val_loss: 0.3291 - val_accuracy: 0.9032
Epoch 22/100
3/3 [==============================] - 0s 40ms/step - loss: 0.4729 - accuracy: 0.7844 - val_loss: 0.3127 - val_accuracy: 0.9032
Epoch 23/100
3/3 [==============================] - 0s 39ms/step - loss: 0.4467 - accuracy: 0.7952 - val_loss: 0.2997 - val_accuracy: 0.9032
Epoch 24/100
3/3 [==============================] - 0s 39ms/step - loss: 0.4421 - accuracy: 0.8115 - val_loss: 0.2922 - val_accuracy: 0.8710
Epoch 25/100
3/3 [==============================] - 0s 42ms/step - loss: 0.4430 - accuracy: 0.7943 - val_loss: 0.2860 - val_accuracy: 0.9032
Epoch 26/100
3/3 [==============================] - 0s 38ms/step - loss: 0.4230 - accuracy: 0.8261 - val_loss: 0.2849 - val_accuracy: 0.9032
Epoch 27/100
3/3 [==============================] - 0s 38ms/step - loss: 0.4347 - accuracy: 0.7925 - val_loss: 0.2902 - val_accuracy: 0.9032
Epoch 28/100
3/3 [==============================] - 0s 98ms/step - loss: 0.4299 - accuracy: 0.8077 - val_loss: 0.2838 - val_accuracy: 0.9032
Epoch 29/100
3/3 [==============================] - 0s 41ms/step - loss: 0.4166 - accuracy: 0.7954 - val_loss: 0.2777 - val_accuracy: 0.9032
Epoch 30/100
3/3 [==============================] - 0s 42ms/step - loss: 0.4076 - accuracy: 0.8104 - val_loss: 0.2762 - val_accuracy: 0.8710
Epoch 31/100
3/3 [==============================] - 0s 43ms/step - loss: 0.3831 - accuracy: 0.8327 - val_loss: 0.2770 - val_accuracy: 0.8710
Epoch 32/100
3/3 [==============================] - 0s 40ms/step - loss: 0.3778 - accuracy: 0.8424 - val_loss: 0.2783 - val_accuracy: 0.8710
Epoch 33/100
3/3 [==============================] - 0s 40ms/step - loss: 0.3856 - accuracy: 0.8238 - val_loss: 0.2784 - val_accuracy: 0.8710
Epoch 34/100
3/3 [==============================] - 0s 38ms/step - loss: 0.3829 - accuracy: 0.8201 - val_loss: 0.2814 - val_accuracy: 0.9032
Epoch 35/100
3/3 [==============================] - 0s 39ms/step - loss: 0.3859 - accuracy: 0.8305 - val_loss: 0.2809 - val_accuracy: 0.9032
Epoch 36/100
3/3 [==============================] - 0s 40ms/step - loss: 0.3814 - accuracy: 0.8266 - val_loss: 0.2782 - val_accuracy: 0.8710
Epoch 37/100
3/3 [==============================] - 0s 40ms/step - loss: 0.3708 - accuracy: 0.8325 - val_loss: 0.2790 - val_accuracy: 0.8710
Epoch 38/100
3/3 [==============================] - 0s 39ms/step - loss: 0.3541 - accuracy: 0.8565 - val_loss: 0.2850 - val_accuracy: 0.9032
Epoch 39/100
3/3 [==============================] - 0s 38ms/step - loss: 0.3645 - accuracy: 0.8448 - val_loss: 0.2861 - val_accuracy: 0.9032
Epoch 40/100
3/3 [==============================] - 0s 40ms/step - loss: 0.3611 - accuracy: 0.8535 - val_loss: 0.2850 - val_accuracy: 0.9032
Epoch 41/100
3/3 [==============================] - 0s 38ms/step - loss: 0.3544 - accuracy: 0.8389 - val_loss: 0.2882 - val_accuracy: 0.8387
Epoch 42/100
3/3 [==============================] - 0s 39ms/step - loss: 0.3550 - accuracy: 0.8325 - val_loss: 0.2930 - val_accuracy: 0.8387
Epoch 43/100
3/3 [==============================] - 0s 39ms/step - loss: 0.3457 - accuracy: 0.8471 - val_loss: 0.2887 - val_accuracy: 0.8387
Epoch 44/100
3/3 [==============================] - 0s 39ms/step - loss: 0.3643 - accuracy: 0.8465 - val_loss: 0.2881 - val_accuracy: 0.9032
Epoch 45/100
3/3 [==============================] - 0s 39ms/step - loss: 0.3435 - accuracy: 0.8543 - val_loss: 0.2881 - val_accuracy: 0.9032
Epoch 46/100
3/3 [==============================] - 0s 41ms/step - loss: 0.3316 - accuracy: 0.8533 - val_loss: 0.2873 - val_accuracy: 0.9032
Epoch 47/100
3/3 [==============================] - 0s 41ms/step - loss: 0.3347 - accuracy: 0.8637 - val_loss: 0.2876 - val_accuracy: 0.9032
Epoch 48/100
3/3 [==============================] - 0s 43ms/step - loss: 0.3504 - accuracy: 0.8520 - val_loss: 0.2900 - val_accuracy: 0.9032
Epoch 49/100
3/3 [==============================] - 0s 40ms/step - loss: 0.3467 - accuracy: 0.8502 - val_loss: 0.2938 - val_accuracy: 0.9032
Epoch 50/100
3/3 [==============================] - 0s 39ms/step - loss: 0.3439 - accuracy: 0.8531 - val_loss: 0.2978 - val_accuracy: 0.9032
Epoch 51/100
3/3 [==============================] - 0s 40ms/step - loss: 0.3161 - accuracy: 0.8648 - val_loss: 0.2958 - val_accuracy: 0.9032
Epoch 52/100
3/3 [==============================] - 0s 39ms/step - loss: 0.3193 - accuracy: 0.8532 - val_loss: 0.2974 - val_accuracy: 0.8387
Epoch 53/100
3/3 [==============================] - 0s 40ms/step - loss: 0.3136 - accuracy: 0.8724 - val_loss: 0.3034 - val_accuracy: 0.8387
Epoch 54/100
3/3 [==============================] - 0s 39ms/step - loss: 0.3545 - accuracy: 0.8557 - val_loss: 0.3048 - val_accuracy: 0.8710
Epoch 55/100
3/3 [==============================] - 0s 38ms/step - loss: 0.3196 - accuracy: 0.8703 - val_loss: 0.3099 - val_accuracy: 0.9032
Epoch 56/100
3/3 [==============================] - 0s 39ms/step - loss: 0.3278 - accuracy: 0.8655 - val_loss: 0.3147 - val_accuracy: 0.9032
Epoch 57/100
3/3 [==============================] - 0s 38ms/step - loss: 0.3299 - accuracy: 0.8673 - val_loss: 0.3199 - val_accuracy: 0.9032
Epoch 58/100
3/3 [==============================] - 0s 39ms/step - loss: 0.3090 - accuracy: 0.8639 - val_loss: 0.3238 - val_accuracy: 0.9032
Epoch 59/100
3/3 [==============================] - 0s 39ms/step - loss: 0.3138 - accuracy: 0.8675 - val_loss: 0.3208 - val_accuracy: 0.9032
Epoch 60/100
3/3 [==============================] - 0s 39ms/step - loss: 0.3226 - accuracy: 0.8759 - val_loss: 0.3214 - val_accuracy: 0.8710
Epoch 61/100
3/3 [==============================] - 0s 43ms/step - loss: 0.3142 - accuracy: 0.8837 - val_loss: 0.3211 - val_accuracy: 0.9032
Epoch 62/100
3/3 [==============================] - 0s 40ms/step - loss: 0.3081 - accuracy: 0.8799 - val_loss: 0.3200 - val_accuracy: 0.9032
Epoch 63/100
3/3 [==============================] - 0s 41ms/step - loss: 0.3048 - accuracy: 0.8666 - val_loss: 0.3192 - val_accuracy: 0.9032
Epoch 64/100
3/3 [==============================] - 0s 40ms/step - loss: 0.3239 - accuracy: 0.8557 - val_loss: 0.3170 - val_accuracy: 0.8710
Epoch 65/100
3/3 [==============================] - 0s 42ms/step - loss: 0.2841 - accuracy: 0.8866 - val_loss: 0.3203 - val_accuracy: 0.8710
Epoch 66/100
3/3 [==============================] - 0s 40ms/step - loss: 0.2978 - accuracy: 0.8856 - val_loss: 0.3230 - val_accuracy: 0.8710
Epoch 67/100
3/3 [==============================] - 0s 40ms/step - loss: 0.3093 - accuracy: 0.8689 - val_loss: 0.3256 - val_accuracy: 0.8710
Epoch 68/100
3/3 [==============================] - 0s 40ms/step - loss: 0.3024 - accuracy: 0.8760 - val_loss: 0.3253 - val_accuracy: 0.9032
Epoch 69/100
3/3 [==============================] - 0s 39ms/step - loss: 0.2891 - accuracy: 0.8676 - val_loss: 0.3261 - val_accuracy: 0.9032
Epoch 70/100
3/3 [==============================] - 0s 39ms/step - loss: 0.3024 - accuracy: 0.8608 - val_loss: 0.3207 - val_accuracy: 0.9032
Epoch 71/100
3/3 [==============================] - 0s 41ms/step - loss: 0.3002 - accuracy: 0.8645 - val_loss: 0.3120 - val_accuracy: 0.8710
Epoch 72/100
3/3 [==============================] - 0s 40ms/step - loss: 0.2882 - accuracy: 0.8857 - val_loss: 0.3095 - val_accuracy: 0.8710
Epoch 73/100
3/3 [==============================] - 0s 43ms/step - loss: 0.2855 - accuracy: 0.8828 - val_loss: 0.3080 - val_accuracy: 0.8710
Epoch 74/100
3/3 [==============================] - 0s 83ms/step - loss: 0.3003 - accuracy: 0.8826 - val_loss: 0.3050 - val_accuracy: 0.8710
Epoch 75/100
3/3 [==============================] - 0s 40ms/step - loss: 0.2957 - accuracy: 0.8787 - val_loss: 0.3048 - val_accuracy: 0.8710
Epoch 76/100
3/3 [==============================] - 0s 39ms/step - loss: 0.2909 - accuracy: 0.8827 - val_loss: 0.3125 - val_accuracy: 0.8710
Epoch 77/100
3/3 [==============================] - 0s 39ms/step - loss: 0.2875 - accuracy: 0.8768 - val_loss: 0.3230 - val_accuracy: 0.8710
Epoch 78/100
3/3 [==============================] - 0s 39ms/step - loss: 0.2838 - accuracy: 0.8807 - val_loss: 0.3291 - val_accuracy: 0.9032
Epoch 79/100
3/3 [==============================] - 0s 39ms/step - loss: 0.2824 - accuracy: 0.8750 - val_loss: 0.3308 - val_accuracy: 0.9032
Epoch 80/100
3/3 [==============================] - 0s 40ms/step - loss: 0.2851 - accuracy: 0.8864 - val_loss: 0.3253 - val_accuracy: 0.8710
Epoch 81/100
3/3 [==============================] - 0s 39ms/step - loss: 0.2930 - accuracy: 0.8861 - val_loss: 0.3217 - val_accuracy: 0.8387
Epoch 82/100
3/3 [==============================] - 0s 39ms/step - loss: 0.2958 - accuracy: 0.8976 - val_loss: 0.3173 - val_accuracy: 0.8710
Epoch 83/100
3/3 [==============================] - 0s 39ms/step - loss: 0.2781 - accuracy: 0.9015 - val_loss: 0.3110 - val_accuracy: 0.8710
Epoch 84/100
3/3 [==============================] - 0s 39ms/step - loss: 0.2806 - accuracy: 0.8930 - val_loss: 0.3118 - val_accuracy: 0.8710
Epoch 85/100
3/3 [==============================] - 0s 39ms/step - loss: 0.2711 - accuracy: 0.8836 - val_loss: 0.3157 - val_accuracy: 0.8710
Epoch 86/100
3/3 [==============================] - 0s 39ms/step - loss: 0.2770 - accuracy: 0.8826 - val_loss: 0.3178 - val_accuracy: 0.8710
Epoch 87/100
3/3 [==============================] - 0s 39ms/step - loss: 0.2728 - accuracy: 0.8750 - val_loss: 0.3221 - val_accuracy: 0.8710
Epoch 88/100
3/3 [==============================] - 0s 39ms/step - loss: 0.2804 - accuracy: 0.8777 - val_loss: 0.3304 - val_accuracy: 0.8710
Epoch 89/100
3/3 [==============================] - 0s 40ms/step - loss: 0.2568 - accuracy: 0.8911 - val_loss: 0.3385 - val_accuracy: 0.8710
Epoch 90/100
3/3 [==============================] - 0s 42ms/step - loss: 0.2698 - accuracy: 0.9044 - val_loss: 0.3411 - val_accuracy: 0.8710
Epoch 91/100
3/3 [==============================] - 0s 40ms/step - loss: 0.2606 - accuracy: 0.9112 - val_loss: 0.3381 - val_accuracy: 0.8710
Epoch 92/100
3/3 [==============================] - 0s 39ms/step - loss: 0.2694 - accuracy: 0.8997 - val_loss: 0.3303 - val_accuracy: 0.8710
Epoch 93/100
3/3 [==============================] - 0s 39ms/step - loss: 0.2747 - accuracy: 0.8975 - val_loss: 0.3261 - val_accuracy: 0.8710
Epoch 94/100
3/3 [==============================] - 0s 40ms/step - loss: 0.2721 - accuracy: 0.8778 - val_loss: 0.3288 - val_accuracy: 0.8710
Epoch 95/100
3/3 [==============================] - 0s 39ms/step - loss: 0.2532 - accuracy: 0.8828 - val_loss: 0.3207 - val_accuracy: 0.8710
Epoch 96/100
3/3 [==============================] - 0s 40ms/step - loss: 0.2644 - accuracy: 0.8921 - val_loss: 0.3097 - val_accuracy: 0.8387
Epoch 97/100
3/3 [==============================] - 0s 39ms/step - loss: 0.2713 - accuracy: 0.9091 - val_loss: 0.3076 - val_accuracy: 0.8387
Epoch 98/100
3/3 [==============================] - 0s 39ms/step - loss: 0.2727 - accuracy: 0.9052 - val_loss: 0.3059 - val_accuracy: 0.8387
Epoch 99/100
3/3 [==============================] - 0s 32ms/step - loss: 0.2571 - accuracy: 0.9102 - val_loss: 0.3182 - val_accuracy: 0.8710
Epoch 100/100
3/3 [==============================] - 0s 31ms/step - loss: 0.2526 - accuracy: 0.8952 - val_loss: 0.3311 - val_accuracy: 0.8710
七、模型评估
import matplotlib.pyplot as pltacc = history.history['accuracy']
val_acc = history.history['val_accuracy']loss = history.history['loss']
val_loss = history.history['val_loss']epochs_range = range(epochs)plt.figure(figsize=(14, 4))
plt.subplot(1, 2, 1)plt.plot(epochs_range, acc, label='Training Accuracy')
plt.plot(epochs_range, val_acc, label='Validation Accuracy')
plt.legend(loc='lower right')
plt.title('Training and Validation Accuracy')plt.subplot(1, 2, 2)
plt.plot(epochs_range, loss, label='Training Loss')
plt.plot(epochs_range, val_loss, label='Validation Loss')
plt.legend(loc='upper right')
plt.title('Training and Validation Loss')
plt.show()
代码输出 :
八、总结:
学习了什么是RNN,怎么处理数据,如何构建简单的RNN模型。本次训练的测试集最大值为0.9112,要想达到更大,可以修改学习率、batch_size、构建更复杂的模型网络等等。