【深度学习_TensorFlow】过拟合

写在前面

过拟合与欠拟合


欠拟合: 是指在模型学习能力较弱,而数据复杂度较高的情况下,模型无法学习到数据集中的“一般规律”,因而导致泛化能力弱。此时,算法在训练集上表现一般,但在测试集上表现较差,泛化性能不佳。

过拟合: 是指模型在训练数据上表现很好,但在测试数据上表现不佳。这是由于模型过于复杂,记住了训练数据中的噪声和模式,而没有学到一般规则。本文将探讨过拟合问题以及其解决方法。

过拟合的解决方法


方法描述
增加训练数据
(More data)通过增加训练数据量,简单粗暴最有效,可以减少过拟合现象。
正则化
(regularization)在损失函数中添加一项,以惩罚模型的复杂度。常用的正则化方法包括L1正则化、L2正则化和dropout。
早停法
(Early stopping)在训练过程中,每次迭代后都会评估模型在验证集上的性能。如果性能在连续若干次迭代中没有提高,就停止训练。
数据增强
(Data augmentation)通过改变原有数据减少过拟合。例如,可以通过旋转、缩放等方式对图像数据进行增强。

写在中间

可以看到随着网络层数增加,模型变得复杂,过拟合现象变得愈发严重。接下来,我们将介绍一系列方法来帮助检测并抑制过拟合现象。

在这里插入图片描述

1. 交叉验证

增加数据集是最有效的方法,但是代价往往是昂贵的,所以要充分利用好现有的数据集。前面我们介绍了数据集需要划分为训练集和测试集,但我们为了挑选模型超参数和检测过拟合线性,一般需要将原来的训练集再次切分为新的训练集和验证集(validation set)。最终数据集被切分为 训练集、验证集、测试集。这三部分数据集的功能如下:

类别描述
训练集用于训练模型的参数,通过学习训练数据集来进行模型训练
验证集用于评估训练过程中的模型表现,调整模型的超参数
测试集用于评估最终训练好的模型在真实数据上的表现,测试验证模型的性能,评估模型的预测能力和泛化能力

验证集和测试集的区分

验证集使命:根据验证集的表现来调整模型的各种超参数的设置,提升模型的泛化能力。

测试集使命:就是检验模型的能力,其表现不能用来反馈模型的调整(就如你不能拿着期末考试原题来练习,否则期末高分就不能体现出你平时学习的真实状况),我们的办法就是从平常的练习题中抽取几道题组成验证集来检验你的能力。

这是一个将mnist手写数字识别测试集切分的例子,将6万张图像的前5万张划分为训练集,后1万张划分为验证集

(x, y), (x_test, y_test) = datasets.mnist.load_data()# 60k训练集切分为 50k训练集和 10k验证集
# (x_train, y_train), (x_val, y_val)
x_train, x_val = tf.split(x, num_or_size_splits=[50000, 10000])
y_train, y_val = tf.split(y, num_or_size_splits=[50000, 10000])# 训练集
train_db = tf.data.Dataset.from_tensor_slices((x_train, y_train))
train_db = train_db.map(preprocess).shuffle(10000).batch(128)# 验证集
val_db = tf.data.Dataset.from_tensor_slices((x_val, y_val))
val_db = val_db.map(preprocess).shuffle(10000).batch(128)# 测试集
test_db = tf.data.Dataset.from_tensor_slices((x_test, y_test))
test_db = test_db.map(preprocess).shuffle(10000).batch(128)

但是这样切分还是有局限性的,训练集只能在前5万张图像中出现,验证集只能在后1万张图像中出现,是否有方法让验证集能使用前5万张的图像呢?还真有,这就是标题所提及的**交叉验证:**我们将训练集的6万张图像划成n份,每训练取其中的 n - 1 份作为训练集,取 1 份作为验证集,而非固定前5万张为训练集,后 1 万张为验证集。

# 创建一个范围为60000的索引数组
idx = tf.range(60000)
# 随机打乱索引数组
idx = tf.random.shuffle(idx)
# 使用索引数组从x和y中获取训练集的样本
# 训练集的样本数量为50000
x_train, y_train = tf.gather(x, idx[:50000]), tf.gather(y, idx[:50000])
# 使用索引数组从x和y中获取验证集的样本
# 验证集的样本数量为10000
x_val, y_val = tf.gather(x, idx[-10000:]), tf.gather(y, idx[-10000:])

也可以不用手动划分,在训练函数中增加参数,即可自动化操作

# 使用60k训练集对网络进行训练,设置训练周期数为6
# 设置验证集占比为0.1,即数据集中10%的数据将作为验证集
# 设置每个2个训练周期进行一次验证
network.fit(train_db, epochs=6, validation_split=0.1, validation_freq=2)

2. 正则化

之所以出现过拟合的现象,是因为模型太过复杂,这里通过限制网络参数的稀疏性来约束网络的实际容量,这种约束一般通过在损失函数上添加额外的参数稀疏性惩罚项实现,常用的正则化的方式有L0、L1、L2正则化,dropout正则化。


简单介绍

关于正则化的数学原理,本人也搞不明白,也就不滥竽充数了。

有关实现简单概括就是在损失函数中引入模型权重参数的L范数,使学习到的权重参数稀疏化。

正则化的方式可以手动实现,也可以调用API实现:其中手动实现主要在计算loss值的后面,调用API主要在创建层的时候。

import tensorflow as tf
from tensorflow.keras import layers, regularizers# 网络构建
# 这会在模型损失函数中加入权重参数的L2范数作为惩罚项,力度由0.001控制。
network = Sequential([layers.Dense(256, kernel_regularizer=regularizers.l2(0.001), activation='relu'),layers.Dense(128, kernel_regularizer=regularizers.l2(0.001), activation='relu'),layers.Dense(64, kernel_regularizer=regularizers.l2(0.001), activation='relu'),layers.Dense(32, kernel_regularizer=regularizers.l2(0.001), activation='relu'),layers.Dense(10)])
# 参数构建
network.build(input_shape=(None, 28*28))
# 模型展示
network.summary()
# 截取手动前向计算的代码
for step, (x, y) in enumerate(train_db):# 创建一个 GradientTape,用于记录计算过程with tf.GradientTape() as tape:x = tf.reshape(x, (-1, 28*28))  # [b, 28, 28] => [b, 784]out = network(x)  # [b, 784] => [b, 10]y_onehot = tf.one_hot(y, depth=10)  # [b] => [b, 10]# 使用交叉熵损失函数计算 lossloss = tf.reduce_mean(tf.losses.categorical_crossentropy(y_onehot, out, from_logits=True))loss_regularization = []for p in network.trainable_variables:  # 重点在这里,遍历网络中所有的可训练参数(network.trainable_variables)。loss_regularization.append(tf.nn.l2_loss(p))  # # 对每个参数计算L2正则化项(tf.nn.l2_loss(p)),这会返回一个标量。loss_regularization = tf.reduce_sum(tf.stack(loss_regularization))  # # 将所有参数的L2正则化项求和,得到正则化损失loss_regularization。# 将损失函数定义为交叉熵损失和 L2 正则化损失的和loss = loss + 0.0001 * loss_regularization# 使用 tape 计算损失函数关于网络参数的梯度,并应用优化器进行反向传播更新参数grads = tape.gradient(loss, network.trainable_variables)optimizer.apply_gradients(zip(grads, network.trainable_variables))

正则化效果

在这里插入图片描述

Dropout


通过随机断开神经网络的连接,减少每次训练时实际参与计算的模型的参数量,但是在测试时,Dropout会恢复所有的连接,保证模型测试时获得最好的性能。

在这里插入图片描述

我们以层方式来实现以上功能

import tensorflow as tf
from tensorflow.keras import layers, regularizers# 网络构建network = Sequential([layers.Dense(256, activation='relu'),layers.Dropout(0.5),  # 有0.5的概率断开与下一层神经元的连接layers.Dense(128, activation='relu'),layers.Dropout(0.5),layers.Dense(64, activation='relu'),layers.Dense(32, activation='relu'),layers.Dense(10)])
# 参数构建
network.build(input_shape=(None, 28*28))
# 模型展示
network.summary()for step, (x, y) in enumerate(train_db):# 训练时with tf.GradientTape() as tape:out = network(x, training=True)# 测试时out = network(x, training=False)

3. Early stopping

早停法

那么如何选择合适的 Epoch 就提前停止训练(Early Stopping),避免出现过拟合现象呢?我们可以通过观察验证指标的变化,来预测最适合的 Epoch 可能的位置。具体地,对于分类问题,我们可以记录模型的验证准确率,并监控验证准确率的变化,当发现验证准确率连续𝑛个 Epoch 没有下降时,可以预测可能已经达到了最适合的 Epoch 附近,从而提前终止训练。

from tensorflow.keras.callbacks import EarlyStopping# 数据集读取···# 定义早停法回调函数
early_stopping = EarlyStopping(monitor='val_loss',  # 监视验证集losspatience=3,  # 当验证集loss在3个epoch内都没有改善则停止训练mode='min',  # 监测loss时一般设置为min,监测准确值时一般设置为maxverbose=1,  # 检测值改善时打印一条信息restore_best_weights=True  # 将权重恢复到最好的一个epoch)# 网络构建# 参数构建# 模型装配# 模型训练,添加参数network.fit(train_db, epochs=100,validation_data=val_db, validation_steps=10,callbacks=[early_stopping])

这里我们会对手写数字识别的代码再次进行修改,来使用上面提及的方法,你可以通过修改repeat的方式来复制数据集使训练数据增多,更改epochs的方式来增加训练次数,经过测试,这段代码在10 epochs 之后便达到了过拟合

import tensorflow as tf
from tensorflow.keras import datasets, layers, optimizers, Sequential, metrics, regularizers
from tensorflow.keras.callbacks import EarlyStopping
# 处理每一张图像
def preprocess(x, y):x = tf.cast(x, dtype=tf.float32) / 255.x = tf.reshape(x, [28 * 28])y = tf.cast(y, dtype=tf.int32)y = tf.one_hot(y, depth=10)return x, y# 数据集读取
(x, y), (x_test, y_test) = datasets.mnist.load_data()# # 固定切分
# # 60k训练集切分为 50k训练集和 10k验证集
# # (x_train, y_train), (x_val, y_val)
# x_train, x_val = tf.split(x, num_or_size_splits=[50000, 10000])
# y_train, y_val = tf.split(y, num_or_size_splits=[50000, 10000])# 交叉验证切分
idx = tf.range(60000)
idx = tf.random.shuffle(idx)
x_train, y_train = tf.gather(x, idx[:50000]), tf.gather(y, idx[:50000])
x_val, y_val = tf.gather(x, idx[-10000:]), tf.gather(y, idx[-10000:])# 训练集
train_db = tf.data.Dataset.from_tensor_slices((x_train, y_train))
train_db = train_db.map(preprocess).shuffle(10000).batch(128).repeat(10)# 验证集
val_db = tf.data.Dataset.from_tensor_slices((x_val, y_val))
val_db = val_db.map(preprocess).shuffle(10000).batch(128)# 测试集
test_db = tf.data.Dataset.from_tensor_slices((x_test, y_test))
test_db = test_db.map(preprocess).shuffle(10000).batch(128)# 定义早停法回调函数
early_stopping = EarlyStopping(monitor='val_loss',  # 监视验证集losspatience=3,  # 当验证集loss在2个epoch内都没有改善则停止训练mode='min',  # 监测loss时一般设置为min,监测准确值时一般设置为maxverbose=1,  # 检测值改善时打印一条信息restore_best_weights=True  # 将权重恢复到最好的一个epoch)# 网络构建
# 正则化:在模型损失函数中加入权重参数的L2范数作为惩罚项,力度由0.001控制。
# Dropout:添加dropout层来随机断开连接
network = Sequential([layers.Dense(256, kernel_regularizer=regularizers.l2(0.001), activation='relu'),layers.Dropout(0.5),layers.Dense(128, kernel_regularizer=regularizers.l2(0.001), activation='relu'),layers.Dropout(0.5),layers.Dense(64, kernel_regularizer=regularizers.l2(0.001), activation='relu'),layers.Dense(32, kernel_regularizer=regularizers.l2(0.001), activation='relu'),layers.Dense(10)])
# 参数构建
network.build(input_shape=(None, 28*28))
# 模型展示
network.summary()
# 模型装配
network.compile(optimizer=optimizers.Adam(learning_rate=0.01),loss=tf.losses.CategoricalCrossentropy(from_logits=True),metrics=['accuracy'])
# 模型训练
network.fit(train_db, epochs=50,validation_data=val_db, validation_steps=10,callbacks=[early_stopping])
# 模型评估
print('模型评估:')
network.evaluate(test_db)

写在最后

👍🏻点赞,你的认可是我创作的动力!
⭐收藏,你的青睐是我努力的方向!
✏️评论,你的意见是我进步的财富!

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.rhkb.cn/news/110695.html

如若内容造成侵权/违法违规/事实不符,请联系长河编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

亿发浙江生产工厂信息化建设管理平台,实现生产智能化、数字化

在全球化、科技深刻变革的时代,浙江省信息化建设正迎来新的发展机遇。以物联网、人工智能大数据、为代表的新技术应用,为人类社会带来了智能、便捷,也标志着新一代信息化浪潮已经到来。特别是在生产型企业中,智能制造是生产型企业…

运用Python解析HTML页面获取资料

在网络爬虫的应用中,我们经常需要从HTML页面中提取图片、音频和文字资源。本文将介绍如何使用Python的requests库和BeautifulSoup解析HTML页面,获取这些资源。 一、环境准备 首先,确保您已经安装了Python环境。接下来,我们需要安…

HUT23级训练赛

目录 A - tmn学长的字符串1 B - 帮帮神君先生 C - z学长的猫 D - 这题用来防ak E - 这题考察FFT卷积 F - 这题考察二进制 G - 这题考察高精度 H - 这题考察签到 I - 爱派克斯,启动! J - tmn学长的字符串2 K - 秋奕来买瓜 A - tmn学长的字符串1 思路&#x…

CSS中如何实现背景图片的平铺和定位?

聚沙成塔每天进步一点点 ⭐ 专栏简介⭐ 平铺背景图片⭐ 背景图片定位⭐ 同时设置平铺和定位⭐ 写在最后 ⭐ 专栏简介 前端入门之旅:探索Web开发的奇妙世界 记得点击上方或者右侧链接订阅本专栏哦 几何带你启航前端之旅 欢迎来到前端入门之旅!这个专栏是…

3D点云处理:基于2D边缘提取的方法提取3D点云边缘(占位待补充)

文章目录 0. 实现效果 微信:dhlddx B站演示视频 0. 实现效果

1.1 数据库系统简介

思维导图: 1.1.数据库系统简介 前言: 数据库系统是一个软件系统,用于管理和操作数据库。它提供了一个组织良好、高效并能够方便存取的数据存储机制,并且能够支持各种数据操作、事务管理、并发控制和恢复功能。以下是数据库系统的…

基于PIC单片机温度-脉搏-DS18B20温度-液晶12864显示(proteus仿真+源程序)

一、系统方案 1、上电初始化液晶第一行显示脉搏,第二行显示温度,第三行显示模式,第四行显示强度;按下K1按键可以选择模式,催眼模式或治疗模式。 2、治疗模块下,可以通过K2、K3修改强度。 二、硬件设计 原理…

c++之指针

总结性质 我们如何在一个函数中获取数组的长度: 我们都知道,在main函数中我们获得数组的长度只需要使用sizeof(a)/sizeof(a【0】)即可获得,但当我们把一个数组传入到方法时,c默认把…

W5500-EVB-PICO进行UDP组播数据回环测试(九)

前言 上一章我们用我们的开发板作为UDP客户端连接服务器进行数据回环测试,那么本章我们进行UDP组播数据回环测试。 什么是UDP组播? 组播是主机间一对多的通讯模式, 组播是一种允许一个或多个组播源发送同一报文到多个接收者的技术。组播源将…

完美解决Ubuntu网络故障,连接异常,IP地址一直显示127.0.0.1

终端输入ifconfig显示虚拟机IP地址为127.0.0.1&#xff0c;具体输出内容如下&#xff1a; wxyubuntu:~$ ifconfig lo: flags73<UP,LOOPBACK,RUNNING> mtu 65536inet 127.0.0.1 netmask 255.0.0.0inet6 ::1 prefixlen 128 scopeid 0x10<host>loop txqueuelen …

Java【手撕滑动窗口】LeetCode 209. “长度最小子数组“, 图文详解思路分析 + 代码

文章目录 前言一、长度最小子数组1, 题目2, 思路分析3, 代码 前言 各位读者好, 我是小陈, 这是我的个人主页, 希望我的专栏能够帮助到你: &#x1f4d5; JavaSE基础: 基础语法, 类和对象, 封装继承多态, 接口, 综合小练习图书管理系统等 &#x1f4d7; Java数据结构: 顺序表, 链…

苹果新健康专利:利用 iPhone、Apple Watch 来分析佩戴者的呼吸情况

根据美国商标和专利局&#xff08;USPTO&#xff09;公示的清单&#xff0c;苹果获得了一项健康相关的技术专利&#xff0c;可以利用 iPhone、Apple Watch 来分析佩戴者的呼吸系统。 苹果在专利中概述了一种测量用户呼吸功能的系统&#xff0c;通过 iPhone 上的光学感测单元&am…

中央发文:提高青年人才资助比例, 放宽学历、年龄限制 (附2023国自然资助比例统计)~

8 月 27 日&#xff0c;中共中央办公厅、国务院办公厅印发《关于进一步加强青年科技人才培养和使用的若干措施》&#xff08;以下简称《若干措施》&#xff09;&#xff0c;明确提出包括提高国家自然科学基金对青年科技人才的资助比例&#xff0c;放宽学历、年龄限制等措施&…

stm32之DHT11

今天&#xff0c;记录一下DHT11&#xff0c;涉及到了单总线协议&#xff0c;所以先花点时间谈论一下单总线协议&#xff08;DS18B20也是用的单总线&#xff09;。 单总线协议 单总线技术的通信协议 可能这时序图就是个例子&#xff0c;ds18b20的时序图与DHT11的时序图也是不一…

我想开通期权?如何开通期权账户?

场内期权的合约由交易所统一标准化定制&#xff0c;大家面对的同一个合约对应的价格都是一致的&#xff0c;比较公开透明&#xff0c;期权开户当天不能交易的&#xff0c;期权开户需要满足20日日均50万及半年交易经验即可操作&#xff0c;下文科普我想开通期权&#xff1f;如何…

软件工程(十七) 行为型设计模式(三)

1、观察者模式 简要说明 定义对象间的一种一对多的依赖关系,当一个对象的状态发生改变时,所有依赖于它的对象都得到通知并自动更新 速记关键字 联动,广播消息 类图如下 基于上面的类图,我们来实现一个监听器。类图中的Subject对应我们的被观察对象接口(IObservable),…

开源与云计算:新的合作模式

&#x1f337;&#x1f341; 博主猫头虎 带您 Go to New World.✨&#x1f341; &#x1f984; 博客首页——猫头虎的博客&#x1f390; &#x1f433;《面试题大全专栏》 文章图文并茂&#x1f995;生动形象&#x1f996;简单易学&#xff01;欢迎大家来踩踩~&#x1f33a; &a…

SQL注入之宽字节注入

文章目录 宽字节注入是什么&#xff1f;注入练习让转义符失效联合查询 代码审计 宽字节注入是什么&#xff1f; 宽字节注入准确来说不是注入手法&#xff0c;而是另外一种比较特殊的情况。宽字节注入的目的是绕过单双引号转义。 宽字节注入是一种绕过单双引号转义的手段&#x…

HQL解决连续三天登陆问题

1.背景 统计连续登录天数超过3天的用户&#xff0c;输出信息包括&#xff1a;用户id&#xff0c;登录天数&#xff0c;起始时间&#xff0c;结束时间&#xff1b; 2.准备数据 -- 建表 create table if not exists user_login_3days(user_id STRING,login_date date );--插入…

基于PIC单片机篮球计分计时器

一、系统方案 本设计采用PIC单片机作为主控制器&#xff0c;矩阵键盘控制&#xff0c;比分&#xff0c;计时控制&#xff0c;24秒&#xff0c;液晶12864显示。 二、硬件设计 原理图如下&#xff1a; 三、单片机软件设计 1、首先是系统初始化 2、液晶显示程序 /*************…