过拟合与欠拟合

一、模型选择

1、问题导入

2、训练误差与泛化误差

3、验证数据集和测试数据集

4、K-折交叉验证

       一般在没有足够多数据时使用。

二、过拟合与欠拟合

1、过拟合

过拟合的定义:

       当学习器把训练样本学的“太好”了的时候,很可能已经把训练样本自身的一些特点当作了所有潜在样本都会具有的一般性质,这样就会导致泛化性能下降,这种现象称为过拟合。具体表现就是最终模型在训练集上效果好;在测试集上效果差。模型泛化能力弱。

过拟合的原因:

  • 训练数据中噪音干扰过大,使得学习器认为部分噪音是特征从而扰乱学习规则。
  • 建模样本选取有误,例如训练数据太少,抽样方法错误,样本label错误等,导致样本不能代表整体。
  • 模型不合理,或假设成立的条件与实际不符。
  • 特征维度/参数太多,导致模型复杂度太高。

2、欠拟合

欠拟合的定义:

       欠拟合是指对训练样本的一般性质尚未学好。在训练集及测试集上的表现都不好。

欠拟合的原因:

  • 模型复杂度过低
  • 特征量过少

3、模型容量

4、数据复杂度

三、代码解释

       我们可以通过多项式拟合来探索这些概念。

import math
import numpy as np
import torch
from torch import nn
from d2l import torch as d2l

1、生成数据集

       给定$x$,我们将使用以下三阶多项式来生成训练和测试数据的标签:

$y = 5 + 1.2x - 3.4\frac{x^2}{2!} + 5.6 \frac{x^3}{3!} + \epsilon \text{ where } \epsilon \sim \mathcal{N}(0, 0.1^2).$

       噪声项$\epsilon$服从均值为0且标准差为0.1的正态分布。在优化的过程中,我们通常希望避免非常大的梯度值或损失值。这就是我们将特征从$x^i$调整为$\frac{x^i}{i!}$的原因,这样可以避免很大的$i$带来的特别大的指数值。我们将为训练集和测试集各生成100个样本。

max_degree = 20  # 多项式的最大阶数
n_train, n_test = 100, 100  # 训练和测试数据集大小
true_w = np.zeros(max_degree)  # 分配大量的空间
true_w[0:4] = np.array([5, 1.2, -3.4, 5.6])     # 注意true_w.shape = (max_degree,),只是后面16个都为0
features = np.random.normal(size=(n_train + n_test, 1))
np.random.shuffle(features)     # features.shape = (n_train + n_test, 1)
# 使用 numpy.power 函数将 features 的每个元素分别与 np.arange(max_degree)(从0到max_degree-1的数组)进行幂运算。
poly_features = np.power(features, np.arange(max_degree).reshape(1, -1))    # ploy_feature.shape = (n_train + n_test, max_degree)
for i in range(max_degree):poly_features[:, i] /= math.gamma(i + 1)  # gamma(n)=(n-1)! 第i列数据除以i的阶乘
# labels的维度:(n_train+n_test,)
labels = np.dot(poly_features, true_w)  # labels.shape = (n_train + n_test,)  通过将多项式特征 poly_features 与真实系数 true_w 相乘得到标签。
labels += np.random.normal(scale=0.1, size=labels.shape)    # 向标签数据添加服从正态分布的噪声
# 下面4个分别对应: w x x^n/n! y=w*(x^n/n!)
true_w, features, poly_features, labels = [torch.tensor(x, dtype=       # 多项式:"polynomial"torch.float32) for x in [true_w, features, poly_features, labels]]  # for训练遍历列表[true_w, features, poly_features, labels]中的4个元素,将NumPy ndarray转换为tensor

2、对模型进行训练和测试

       首先让我们实现一个函数来评估模型在给定数据集上的损失。

def evaluate_loss(net, data_iter, loss):"""评估给定数据集上模型的损失"""metric = d2l.Accumulator(2)  # 损失的总和,样本数量for X, y in data_iter:out = net(X)y = y.reshape(out.shape)l = loss(out, y)metric.add(l.sum(), l.numel())return metric[0] / metric[1]

现在定义训练函数。

def train(train_features, test_features, train_labels, test_labels,num_epochs=400):loss = nn.MSELoss(reduction='none')     # 使用均方根损失input_shape = train_features.shape[-1]  # input_shape=20# 不设置偏置,因为我们已经在多项式中实现了它net = nn.Sequential(nn.Linear(input_shape, 1, bias=False))batch_size = min(10, train_labels.shape[0])train_iter = d2l.load_array((train_features, train_labels.reshape(-1,1)),batch_size)test_iter = d2l.load_array((test_features, test_labels.reshape(-1,1)),batch_size, is_train=False)trainer = torch.optim.SGD(net.parameters(), lr=0.01)animator = d2l.Animator(xlabel='epoch', ylabel='loss', yscale='log',xlim=[1, num_epochs], ylim=[1e-3, 1e2],legend=['train', 'test'])for epoch in range(num_epochs):d2l.train_epoch_ch3(net, train_iter, loss, trainer)if epoch == 0 or (epoch + 1) % 20 == 0:animator.add(epoch + 1, (evaluate_loss(net, train_iter, loss),evaluate_loss(net, test_iter, loss)))print('weight:', net[0].weight.data.numpy())

3、拟合结果分析

(1)三阶多项式函数拟合(正常)

       我们将首先使用三阶多项式函数,它与数据生成函数的阶数相同。结果表明,该模型能有效降低训练损失和测试损失。学习到的模型参数也接近真实值$w = [5, 1.2, -3.4, 5.6]$

# 从多项式特征中选择前4个维度,即1,x,x^2/2!,x^3/3!
# poly_features每行的前4个值是线性相关的,因为此时已经进行了幂运算、阶乘运算,就差乘true_w了,也正因如此前面才能使用单线性层nn.Linear拟合
train(poly_features[:n_train, :4], poly_features[n_train:, :4],labels[:n_train], labels[n_train:])

(2)线性函数拟合(欠拟合)

       让我们再看看线性函数拟合,减少该模型的训练损失相对困难。在最后一个迭代周期完成后,训练损失仍然很高。当用来拟合非线性模式(如这里的三阶多项式函数)时,线性模型容易欠拟合。不过这里的欠拟合是因为数据没给全(因为数据没给全而导致特征不明显),前4项weight的后面2项模型学不到,人为强制使模型欠拟合了,而不是模型容量不够。

# 从多项式特征中选择前2个维度,即1和x
train(poly_features[:n_train, :2], poly_features[n_train:, :2],labels[:n_train], labels[n_train:])

(3)高阶多项式函数拟合(过拟合) 

       现在,让我们尝试使用一个阶数过高的多项式来训练模型。在这种情况下,没有足够的数据用于学到高阶系数应该具有接近于零的值。因此,这个过于复杂的模型会轻易受到训练数据中噪声的影响。虽然训练损失可以有效地降低,但测试损失仍然很高。结果表明,复杂模型对数据造成了过拟合。其实就是给的数据太多(后面16项多余了),模型把不该学的(后面16项的weight)给学到了,于是就造成了过拟合。

# 从多项式特征中选取所有维度
train(poly_features[:n_train, :], poly_features[n_train:, :],labels[:n_train], labels[n_train:], num_epochs=1500)

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.rhkb.cn/news/217532.html

如若内容造成侵权/违法违规/事实不符,请联系长河编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

Navicat16 无限试用 亲测有效

Navicat16 无限试用 亲测有效 亲测有效!!! 吐槽下,有的用不了,有的是图片,更甚者还有收费的,6的一批 粘贴下面的代码,保存到桌面,命名为 trial-navicat16.bat echo off…

探索GameFi:区块链与游戏的未来融合

在过去的几年里,区块链技术逐渐渗透到各个领域,为不同行业带来了前所未有的变革。其中,游戏行业成为了一个引人注目的焦点,而这种结合被称为GameFi,即游戏金融。GameFi不仅仅是一个概念,更是一场区块链和游…

宏景eHR SQL 注入漏洞复现(CVE-2023-6655)

0x01 产品简介 宏景eHR人力资源管理软件是一款人力资源管理与数字化应用相融合,满足动态化、协同化、流程化、战略化需求的软件。 0x02 漏洞概述 宏景eHR 中发现了一种被分类为关键的漏洞,该漏洞影响了Login Interface组件中/w_selfservice/oauthservlet/%2e../.%2e/genera…

关于“Python”的核心知识点整理大全19

目录 ​编辑 8.6.4 使用 as 给模块指定别名 8.6.5 导入模块中的所有函数 8.7 函数编写指南 8.8 小结 第9章 类 9.1 创建和使用类 9.1.1 创建 Dog 类 dog.py 1. 方法__init__() 2. 在Python 2.7中创建类 9.1.2 根据类创建实例 1. 访问属性 2. 调用方法 3. 创建多…

到底什么是DevOps

DevOps不是一组工具,也不是一个特定的岗位。在我看来DevOps更像是一种软件开发文化,一种实现快速交付能力的手段。 DevOps 强调的是高效组织团队之间如何通过自动化的工具协作和沟通来完成软件的生命周期管理,从而更快、更频繁地交付更稳定的…

宠物自助洗护小程序系统

提供给宠物的自助洗澡机, 集恒温清洗、浴液 护毛、吹干、消毒于一体,宠物主人只需用微信小程序源码,即可一键开启洗宠流程。 主要功能: 在线预约 在线支付 洗护记录 会员系统 宠物管理 设备管理 多商户加盟

大数据技术10:Flink从入门到精通

导语:前期入门Flink时,可以直接编写通过idea编写Flink程序,然后直接运行main方法,无需搭建环境。我碰到许多初次接触Flink的同学,被各种环境搭建、提交作业、复杂概念给劝退了。前期最好的入门方式就是直接上手写代码&…

win10 + vs2017 + cmake3.17编译OSG-3.4.1

1. 下载文件 主要用到4个文件 1)OSG-3.4.1源码2)OSG第三方依赖库3)OSG示例数据4)cmake-3.17 我已经准备好了,大家可以自行下载。下载路径: 链接:https://pan.baidu.com/s/1E3YESh0T9KPlJJe2…

Android--Jetpack--Navigation详解

须知少日拏云志,曾许人间第一流 一,定义 Navigation 翻译成中文就是导航的意思。它是谷歌推出的Jetpack的一员,其目的主要就是来管理页面的切换和导航。 Activity 嵌套多个 Fragment 的 UI 架构模式已经非常普遍,但是对 Fragmen…

机器人制作开源方案 | 智能助老机器人

作者:刘颖、王浩宇、党玉娟 单位:北京科技大学 指导老师:刘新洋、栗琳 1. 项目背景 1.1 行业背景 随着越来越多的服务机器人进入家庭,应用场景呈现多元化和专业化,机器人产业生态体系正在不断完善,服务…

【MySQL】MySQL库的增删查改

文章目录 1.库的操作1.1创建数据库1.2创建数据库案例 2.字符集和校验规则2.1查看系统默认字符集以及校验规则2.2查看数据库支持的字符集2.3查看数据库支持的字符集校验规则2.4校验规则对数据库的影响 3.操纵数据库3.1查看数据库3.2显示创建语句3.3修改数据库3.4数据库删除3.5备…

2023年医疗器械行业分析(京东医疗器械运营数据分析):10月销额增长53%

随着我国整体实力的增强、国民生活水平的提高、人口老龄化、医疗保障体系不断完善等因素的驱动,我国的医疗器械市场增长迅速。 根据鲸参谋电商数据分析平台的相关数据显示,今年10月份,京东平台上医疗器械市场的销量将近1200万,环比…

1+X大数据平台运维职业技能等级证书中级

该部分是选择题部分,实操题在主页的另一篇文章 考试名称:“1X”大数据平台运维职业技能等级证书(中级) 1X 大数据平台运维中级测试题一、单选题 以下哪种情况容易引发 HDFS 负载不均问题?( C&#xff09…

windows禁用系统更新

1.在winr运行框中输入services.msc,打开windows服务窗口。 services.msc 2.在服务窗口中,我们找到Windows update选项,如下图所示: 3.双击windows update服务,我们把启动类型改为禁用,如下图所示&#xff…

AI浪潮下,大模型如何在音视频领域运用与实践?

视频云大模型算法「方法论」。 刘国栋|演讲者 在AI技术发展如火如荼的当下,大模型的运用与实践在各行各业以千姿百态的形式展开。音视频技术在多场景、多行业的应用中,对于智能化和效果性能的体验优化有较为极致的要求。如何运用好人工智能提…

STM32--中断使用(超详细!)

写在前面:前面的学习中,我们接触了STM32的第一个外设GPIO,这也是最常用的一个外设;而除了GPIO外,中断也是一个十分重要且常用的外设;只有掌握了中断,再处理程序时才能掌握好解决实际问题的逻辑思…

Arris VAP2500 list_mac_address未授权RCE漏洞复现

0x01 产品简介 Arris VAP2500是美国Arris集团公司的一款无线接入器产品。 0x02 漏洞概述 Arris VAP2500 list_mac_address接口处命令执行漏洞,未授权的攻击者可通过该漏洞在服务器端任意执行代码,写入后门,获取服务器权限,进而控制整个web服务器。 0x03 复现环境 FOFA…

ZLMediaKit 编译以及测试(Centos 7.9 环境)

文章目录 一、前言二、编译器1、获取代码2、编译器2.1 编译器版本要求2.2 安装编译器 3、安装cmake4、依赖库4.1 依赖库列表4.2 安装依赖库4.2.1 安装libssl-dev和libsdl-dev4.2.2 安装 ffmpeg-devel依赖和ffmpeg依赖 三、构建和编译项目(启用WebRTC功能&#xff09…

计算机网络:物理层(奈氏准则和香农定理,含例题)

带你速通计算机网络期末 文章目录 一、码元和带宽 1、什么是码元 2、数字通信系统数据传输速率的两种表示方法 2.1、码元传输速率 2.2、信息传输速率 3、例题 3.1、例题1 3.2、例题2 4、带宽 二、奈氏准则(奈奎斯特定理) 1、奈氏准则简介 2、…

Docker安全性:最佳实践和常见安全考虑

Docker 的快速发展和广泛应用使其成为现代应用开发的热门选择,然而,容器环境的安全性也受到关注。本文将深入研究 Docker 安全性的最佳实践,包括容器镜像安全、容器运行时安全、网络安全等方面,并提供丰富的示例代码,帮…