【PyTorch】模型选择、欠拟合和过拟合

文章目录

  • 1. 理论介绍
  • 2. 实例解析
    • 2.1. 实例描述
    • 2.2. 代码实现
      • 2.2.1. 完整代码
      • 2.2.2. 输出结果

1. 理论介绍

  • 将模型在训练数据上拟合的比在潜在分布中更接近的现象称为过拟合, 用于对抗过拟合的技术称为正则化
  • 训练误差和验证误差都很严重, 但它们之间差距很小。 如果模型不能降低训练误差,这可能意味着模型过于简单(即表达能力不足),无法捕获试图学习的模式。 这种现象被称为欠拟合
  • 训练误差是指模型在训练数据集上计算得到的误差。
  • 泛化误差是指模型应用在同样从原始样本的分布中抽取的无限多数据样本时,模型误差的期望。我们永远不能准确地计算出泛化误差,在实际中,我们只能通过将模型应用于一个独立的测试集来估计泛化误差, 该测试集由随机选取的、未曾在训练集中出现的数据样本构成。
  • 影响模型泛化的因素
    • 可调整参数的数量。当可调整参数的数量(有时称为自由度)很大时,模型往往更容易过拟合。
    • 参数采用的值。当权重的取值范围较大时,模型可能更容易过拟合。
    • 训练样本的数量。即使模型很简单,也很容易过拟合只包含一两个样本的数据集,而过拟合一个有数百万个样本的数据集则需要一个极其灵活的模型。
  • 在机器学习中,我们通常在评估几个候选模型后选择最终的模型。 这个过程叫做模型选择。候选模型可能在本质上不同,也可能是不同的超参数设置下的同一类模型。
  • 为了确定候选模型中的最佳模型,我们通常会使用验证集。验证集与测试集十分相似,唯一的区别是验证集是用于确定最佳模型,测试集是用于评估最终模型的性能
  • K K K折交叉验证:当训练数据稀缺时,将原始训练数据分成 K K K个不重叠的子集。 然后执行 K K K次模型训练和验证,每次在 ( K − 1 ) (K-1) (K1)个子集上进行训练, 并在剩余的一个子集(在该轮中没有用于训练的子集)上进行验证。 最后,通过对 K K K次实验的结果取平均来估计训练和验证误差。
  • 引起过拟合的因素
    • 模型复杂度
      模型复杂度
    • 数据集大小
      • 训练数据集中的样本越少,我们就越有可能(且更严重地)过拟合。
      • 给出更多的数据,拟合更复杂的模型可能是有益的; 如果没有足够的数据,简单的模型可能更有用。

2. 实例解析

2.1. 实例描述

使用以下三阶多项式来生成训练和测试数据 y = 5 + 1.2 x − 3.4 x 2 2 ! + 5.6 x 3 3 ! + ϵ where  ϵ ∼ N ( 0 , 0. 1 2 ) . y = 5 + 1.2x - 3.4\frac{x^2}{2!} + 5.6 \frac{x^3}{3!} + \epsilon \text{ where } \epsilon \sim \mathcal{N}(0, 0.1^2). y=5+1.2x3.42!x2+5.63!x3+ϵ where ϵN(0,0.12).并用1阶(线性模型)、3阶、20阶多项式拟合。

2.2. 代码实现

2.2.1. 完整代码

import os
import numpy as np
import math, torch
import torch.nn as nn
from torch.utils.data import DataLoader, TensorDataset
from tensorboardX import SummaryWriter
from rich.progress import trackdef evaluate_loss(dataloader, net, criterion):"""评估模型在指定数据集上的损失"""num_examples = 0loss_sum = 0.0with torch.no_grad():for X, y in dataloader:X, y = X.cuda(), y.cuda()loss = criterion(net(X), y)num_examples += y.shape[0]loss_sum += loss.sum()return loss_sum / num_examplesdef load_dataset(*tensors):"""加载数据集"""dataset = TensorDataset(*tensors)return DataLoader(dataset, batch_size, shuffle=True)if __name__ == '__main__':# 全局参数设置num_epochs = 400batch_size = 10learning_rate = 0.01# 创建记录器def log_dir():root = "runs"if not os.path.exists(root):os.mkdir(root)order = len(os.listdir(root)) + 1return f'{root}/exp{order}'writer = SummaryWriter(log_dir())# 生成数据集max_degree = 20             # 多项式最高阶数n_train, n_test = 100, 100  # 训练集和测试集大小true_w = np.zeros(max_degree+1)true_w[0:4] = np.array([5, 1.2, -3.4, 5.6])features = np.random.normal(size=(n_train + n_test, 1))np.random.shuffle(features)poly_features = np.power(features, np.arange(max_degree+1).reshape(1, -1))for i in range(max_degree+1):poly_features[:, i] /= math.gamma(i + 1)  # gamma(n)=(n-1)!labels = np.dot(poly_features, true_w)labels += np.random.normal(scale=0.1, size=labels.shape)    # 加高斯噪声服从N(0, 0.01)poly_features, labels = [torch.as_tensor(x, dtype=torch.float32) for x in [poly_features, labels.reshape(-1, 1)]]def loop(model_degree):# 创建模型net = nn.Linear(model_degree+1, 1, bias=False).cuda()nn.init.normal_(net.weight, mean=0, std=0.01)criterion = nn.MSELoss(reduction='none')optimizer = torch.optim.SGD(net.parameters(), lr=learning_rate)# 加载数据集dataloader_train = load_dataset(poly_features[:n_train, :model_degree+1], labels[:n_train])dataloader_test = load_dataset(poly_features[n_train:, :model_degree+1], labels[n_train:])# 训练循环for epoch in track(range(num_epochs), description=f'{model_degree}-degree'):for X, y in dataloader_train:X, y = X.cuda(), y.cuda()loss = criterion(net(X), y)optimizer.zero_grad()loss.mean().backward()optimizer.step()writer.add_scalars(f"{model_degree}-degree", {"train_loss": evaluate_loss(dataloader_train, net, criterion),"test_loss": evaluate_loss(dataloader_test, net, criterion),}, epoch)print(f"{model_degree}-degree: weights =", net.weight.data.cpu().numpy())for model_degree in [1, 3, 20]:loop(model_degree)writer.close()

2.2.2. 输出结果

权重

  • 采用1阶多项式(线性模型)拟合
    1
  • 采用3阶多项式拟合
    3
  • 采用20阶多项式拟合
    20

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.rhkb.cn/news/212141.html

如若内容造成侵权/违法违规/事实不符,请联系长河编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

Java程序员,你掌握了多线程吗?(文末送书)

目录 01、多线程对于Java的意义02、为什么Java工程师必须掌握多线程03、Java多线程使用方式04、如何学好Java多线程送书规则 摘要:互联网的每一个角落,无论是大型电商平台的秒杀活动,社交平台的实时消息推送,还是在线视频平台的流…

@德人合科技 | 数据透明加密防泄密系统\文件文档加密\设计图纸加密|源代码加密防泄密软件系统,——防止内部办公终端核心文件数据/资料外泄!

一款专业的数据防泄密管理系统,它采用了多种加密模式,包括透明加密、半透明加密和落地加密等,可以有效地保护企业的核心数据安全。 PC端访问地址: https://isite.baidu.com/site/wjz012xr/2eae091d-1b97-4276-90bc-6757c5dfedee …

验证码的多种生成策略

&#x1f60a; 作者&#xff1a; 瓶盖子io &#x1f496; 主页&#xff1a; 瓶盖子io-CSDN博客 第一种 a.导入依赖 <dependency><groupId>org.apache.commons</groupId><artifactId>commons-lang3</artifactId><version>3.10</ver…

NAS外网访问方案

基础流程 路由器开启端口映射&#xff08;如果有猫则要配置猫为转发模式&#xff0c;由路由器直接拨号即可使用第三方程序让内网ip发布到公网上&#xff08;如果有云服务器&#xff09;需要开启防火墙端口 好用的第三方程序 FRP穿透 优点&#xff1a;开源免费&#xff0c;速…

C++ 指针进阶

目录 一、字符指针 二、指针数组 三、数组指针 数组指针的定义 &数组名 与 数组名 数组指针的使用 四、数组参数 一维数组传参 二维数组传参 五、指针参数 一级指针传参 二级指针传参 六、函数指针 七、函数指针数组 八、指向函数指针数组的指针 九、回调函…

Ubuntu宝塔面板本地部署轻论坛系统HadSky并远程访问

文章目录 前言1. 网站搭建1.1 网页下载和安装1.2 网页测试1.3 cpolar的安装和注册 2. 本地网页发布2.1 Cpolar临时数据隧道2.2 Cpolar稳定隧道&#xff08;云端设置&#xff09;2.3 Cpolar稳定隧道&#xff08;本地设置&#xff09;2.4 公网访问测试 总结 前言 经过多年的基础…

C++多态(详解)

一、多态的概念 1.1、多态的概念 多态&#xff1a;多种形态&#xff0c;具体点就是去完成某个行为&#xff0c;当不同的对象去完成时会产生出不同的状态。 举个例子&#xff1a;比如买票这个行为&#xff0c;当普通人买票时&#xff0c;是全价买票&#xff1b;学生买票时&am…

百度推送收录工具-免费的各大搜索引擎推送工具

在互联网时代&#xff0c;网站收录是网站建设的重要一环。百度推送工具作为一种提高网站收录速度的方式备受关注。在这个信息爆炸的时代&#xff0c;对于网站管理员和站长们来说&#xff0c;了解并使用一些百度推送工具是非常重要的。本文将重点分享百度批量域名推送工具和百度…

深入理解JVM内存空间的担保策略

Java虚拟机&#xff08;JVM&#xff09;的内存管理是Java性能调优中最重要的方面之一&#xff0c;特别是在处理大型应用和服务时。JVM内存管理的一个关键组成部分是垃圾回收&#xff08;GC&#xff09;。在GC过程中&#xff0c;JVM需要确保有足够的内存来创建新对象&#xff0c…

景联文科技解读《2023人工智能基础数据服务产业发展白皮书》,助力解决数据标注挑战

前段时间&#xff0c;国家工业信息安全发展研究中心发布《2023人工智能基础数据服务产业发展白皮书》&#xff08;以下简称“白皮书”&#xff09;。 《白皮书》指出&#xff0c;2022年&#xff0c;中国人工智能基础数据服务产业的市场规模为45亿元&#xff0c;预计今年将达到5…

一步解决 java.io.FileNotFoundException: 找不到文件异常

1.问题描述 java.io.FileNotFoundException: C:\Users\Administrator\AppData\Local\Temp\localhost\uploads\image\20231206\2843cb16-9654-4e52-a757-76e3ca1f80ff.png (系统找不到指定的路径。) 2.原因分析 文件路径中的文件目录不存在 3.解决方案 方案一&#xff1a;如果…

最优化方法复习——线性规划之对偶问题

一、线性规划对偶问题定义 原问题&#xff1a; 对偶问题&#xff1a; &#xff08;1&#xff09;若一个模型为目标求 “极大”&#xff0c;约束为“小于等于” 的不等式&#xff0c;则它的对偶模型为目标求“极小”&#xff0c;约束是“大于等于”的不等式。即“Max&#xff0…

docker基本管理和概念

1、定义&#xff1a;一个开源的应用容器引擎&#xff0c;基于go语言开发&#xff0c;运行在liunx系统中的开源的、轻量级的“虚拟机” docker的容器技术可以在一台主机上轻松的为任何应用创建一个轻量级的、可移植的、自给自足的容器 docker的宿主机是liunx系统&#xff0c;集…

Maven——使用Nexus创建私服

私服不是Maven的核心概念&#xff0c;它仅仅是一种衍生出来的特殊的Maven仓库。通过建立自己的私服&#xff0c;就可以降低中央仓库负荷、节省外网带宽、加速Maven构建、自己部署构件等&#xff0c;从而高效地使用Maven。 有三种专门的Maven仓库管理软件可以用来帮助大家建立…

六个自媒体写作方法,提升自媒体创作收益

在自媒体时代&#xff0c;写作成为了一个不可或缺的技能。特别是对于新手来说&#xff0c;掌握一些有效的写作方法&#xff0c;可以事半功倍&#xff0c;更好地展现个人创意和观点。在这里&#xff0c;我将分享六个适合新手的自媒体写作方法&#xff0c;希望能够为你在写作之路…

class039 嵌套类问题的递归解题套路【算法】

class039 嵌套类问题的递归解题套路 算法讲解039【必备】嵌套类问题的递归解题套路 Code1 772. 基本计算器 III 实现一个基本的计算器来计算简单的表达式字符串。 表达式字符串只包含非负整数&#xff0c;算符、-、*、&#xff0c; ,左括号(和右括号)。整数除法需要向下截…

unity学习笔记

一、射线检测 如何让鼠标点击某个位置&#xff0c;游戏角色就能移动到该位置&#xff1f; 实现的原理分析&#xff1a;我们能看见游戏的东西就是摄像机拍摄到的东西&#xff0c;所以摄像机的镜平面就是当前能看到的了。 那接下来我们可以让摄像机发射一条射线&#xff0c;鼠标…

【网络编程】-- 01 概述、IP

网络编程 1 概述 1.1 计算机网络 (连接分散计算机设备以实现信息传递的系统) 计算机网络是指将地理位置不同的具有独立功能的多台计算机及其外部设备&#xff0c;通过通信线路连接起来&#xff0c;在网络操作系统&#xff0c;网络管理软件及网络通信协议的管理和协调下&…

06 硬件知识入门(MOSS管)

1 简介 MOS管和三极管的驱动方式完全不一样&#xff0c;以NPN型三极管为例&#xff0c;base极以小电流打开三极管&#xff0c;此时三极管的集电极被打开&#xff0c;发射极的高电压会导入&#xff0c;此时电流&#xff1a;Ic IbIe &#xff1b;电压&#xff1a;Ue>Uc>Ub…