李沐深度学习-多层感知机、模型选择、过拟合、欠拟合

3.8.1 隐藏层

多层感知机在单层神经网络的基础上引入了一到多个隐藏层(hidden layer)。隐藏层位于输入层和输出层之间。图3.3展示了一个多层感知机的神经网络图,它含有一个隐藏层,该层中有5个隐藏单元。

                                                                    图3.3 带有隐藏层的多层感知机

在图3.3所示的多层感知机中,输入和输出个数分别为4和3,中间的隐藏层中包含了5个隐藏单元(hidden unit)。由于输入层不涉及计算,图3.3中的多层感知机的层数为2。由图3.3可见,隐藏层中的神经元和输入层中各个输入完全连接,输出层中的神经元和隐藏层中的各个神经元也完全连接。因此,多层感知机中的隐藏层和输出层都是全连接层。

3.8.2 激活函数

 一般来说,有了激活函数,就不可能再将我们的多层感知机退化成线性模型

上述问题的根源在于全连接层只是对数据做仿射变换(一种带有偏置项的线性变换),而多个仿射变换的叠加仍然是一个仿射变换。
解决问题的一个方法是引入非线性变换,例如对隐藏变量使用按元素运算的非线性函数进行变换,然后再作为下一个全连接层的输入。这个非线性函数被称为激活函数(activation function)。下面我们介绍几个常用的激活函数。

ps:softmax回归的模型架构通过单个仿射变换将我们的输入直接映射到输出,然后进行softmax操作

3.8.2.1 ReLU函数

ReLU(rectified linear unit)函数提供了一个很简单的非线性变换。给定元素$x$,该函数定义为

ReLU(x)=max(x,0).

可以看出,ReLU函数只保留正数元素,并将负数元素清零。

x = torch.arange(-8.0, 8.0, 0.1, requires_grad=True)
y = torch.relu(x)
d2l.plot(x.detach(), y.detach(), 'x', 'relu(x)', figsize=(5, 2.5))

显然,当输入为负数时,ReLU函数的导数为0;当输入为正数时,ReLU函数的导数为1。尽管输入为0时ReLU函数不可导,但是我们可以取此处的导数为0。下面绘制ReLU函数的导数。

y.backward(torch.ones_like(x), retain_graph=True)
d2l.plot(x.detach(), x.grad, 'x', 'grad of relu', figsize=(5, 2.5))

3.8.2.2 sigmoid函数

[对于一个定义域在ℝR中的输入, sigmoid函数将输入变换为区间(0, 1)上的输出]。

下面绘制了sigmoid函数。当输入接近0时,sigmoid函数接近线性变换。

y = torch.sigmoid(x)
d2l.plot(x.detach(), y.detach(), 'x', 'sigmoid(x)', figsize=(5, 2.5))

依据链式法则,sigmoid函数的导数

sigmoid′(x)=sigmoid(x)(1−sigmoid(x)).

下面绘制了sigmoid函数的导数。当输入为0时,sigmoid函数的导数达到最大值0.25;当输入越偏离0时,sigmoid函数的导数越接近0。

# 清除以前的梯度
x.grad.data.zero_()
y.backward(torch.ones_like(x),retain_graph=True)
d2l.plot(x.detach(), x.grad, 'x', 'grad of sigmoid', figsize=(5, 2.5))

3.8.2.3 tanh函数

 [tanh(双曲正切)函数也能将其输入压缩转换到区间(-1, 1)上]。

我们接着绘制tanh函数。当输入接近0时,tanh函数接近线性变换。虽然该函数的形状和sigmoid函数的形状很像,但tanh函数在坐标系的原点上对称。

y = torch.tanh(x)
d2l.plot(x.detach(), y.detach(), 'x', 'tanh(x)', figsize=(5, 2.5))

依据链式法则,tanh函数的导数

tanh′(x)=1−tanh2(x).

下面绘制了tanh函数的导数。当输入为0时,tanh函数的导数达到最大值1;当输入越偏离0时,tanh函数的导数越接近0。

# 清除以前的梯度
x.grad.data.zero_()
y.backward(torch.ones_like(x),retain_graph=True)
d2l.plot(x.detach(), x.grad, 'x', 'grad of tanh', figsize=(5, 2.5))

3.8.3 多层感知机

多层感知机就是含有至少一个隐藏层的由全连接层组成的神经网络,且每个隐藏层的输出通过激活函数进行变换。多层感知机的层数和各隐藏层中隐藏单元个数都是超参数。

在回归问题中,我们将输出层的输出个数设为1,并将输出直接提供给线性回归中使用的平方损失函数。

小结

  • 多层感知机在输出层与输入层之间加入了一个或多个全连接隐藏层,并通过激活函数对隐藏层输出进行变换。
  • 常用的激活函数包括ReLU函数、sigmoid函数和tanh函数。

3.10 多层感知机的简洁实现

import torch
from torch import nn
from d2l import torch as d2l

3.10.1 定义模型

和softmax回归唯一的不同在于,添加了2个全连接层。 第一层是[隐藏层],它(包含256个隐藏单元,并使用了ReLU激活函数)。 第二层是输出层。

net = nn.Sequential(nn.Flatten(),nn.Linear(784, 256),nn.ReLU(),nn.Linear(256, 10))def init_weights(m):if type(m) == nn.Linear:nn.init.normal_(m.weight, std=0.01)net.apply(init_weights);

3.10.2 读取数据并训练模型

[训练过程]的实现与我们实现softmax回归时完全相同, 这种模块化设计使我们能够将与模型架构有关的内容独立出来。

batch_size, lr, num_epochs = 256, 0.1, 10
loss = nn.CrossEntropyLoss()
trainer = torch.optim.SGD(net.parameters(), lr=lr)train_iter, test_iter = d2l.load_data_fashion_mnist(batch_size)
d2l.train_ch3(net, train_iter, test_iter, loss, num_epochs, trainer)

1.2  小结

  • 我们可以使用高级API更简洁地实现多层感知机。
  • 对于相同的分类问题,多层感知机的实现与softmax回归的实现相同,只是多层感知机的实现里增加了带有激活函数的隐藏层。

3.11 模型选择、欠拟合和过拟合

如果你改变过实验中的模型结构或者超参数,你也许发现了:当模型在训练数据集上更准确时,它在测试数据集上却不一定更准确。这是为什么呢?

3.11.1 训练误差和泛化误差

训练误差(training error):模型在训练数据集上表现出的误差

泛化误差(generalization error):模型在任意一个测试数据样本上表现出的误差的期望,并常常通过测试数据集上的误差来近似。

计算训练误差和泛化误差可以使用线性回归用到的平方损失函数和softmax回归用到的交叉熵损失函数。

让我们以高考为例来直观地解释训练误差和泛化误差这两个概念。训练误差可以认为是做往年高考试题(训练题)时的错误率,泛化误差则可以通过真正参加高考(测试题)时的答题错误率来近似。假设训练题和测试题都随机采样于一个未知的依照相同考纲的巨大试题库。如果让一名未学习中学知识的小学生去答题,那么测试题和训练题的答题错误率可能很相近。但如果换成一名反复练习训练题的高三备考生答题,即使在训练题上做到了错误率为0,也不代表真实的高考成绩会如此。

模型的参数是通过在训练数据集上训练模型而学习出的,参数的选择依据了最小化训练误差(高三备考生)。所以,训练误差的期望小于或等于泛化误差。

由于无法从训练误差估计泛化误差,一味地降低训练误差并不意味着泛化误差一定会降低。

机器学习模型应关注降低泛化误差。

3.11.2 模型选择

在机器学习中,通常需要评估若干候选模型的表现并从中选择模型。这一过程称为模型选择(model selection)。

可供选择的候选模型可以是有着不同超参数的同类模型。

以多层感知机为例,我们可以选择隐藏层的个数,以及每个隐藏层中隐藏单元个数和激活函数。为了得到有效的模型,我们通常要在模型选择上下一番功夫。下面,我们来描述模型选择中经常使用的验证数据集(validation data set)。

3.11.2.1 验证数据集

从严格意义上讲,测试集只能在所有超参数和模型参数选定后使用一次。不可以使用测试数据选择模型,如调参。

我们可以从给定的训练集中随机选取一小部分作为验证集,而将剩余部分作为真正的训练集。

3.11.2.3 K折交叉验证

通过对𝐾次实验的结果取平均来估计训练和验证误差。

3.11.3 欠拟合和过拟合

接下来,我们将探究模型训练中经常出现的两类典型问题:

欠拟合 :模型无法得到较低的训练误差的现象;

过拟合 :模型的训练误差远小于它在测试数据集上的误差的现象。

3.11.3.1 模型复杂度

高阶多项式函数模型参数更多,模型函数的选择空间更大,所以高阶多项式函数比低阶多项式函数的复杂度更高。因此,高阶多项式函数比低阶多项式函数更容易在相同的训练数据集上得到更低的训练误差。

给定训练数据集,如果模型的复杂度过低,很容易出现欠拟合;如果模型复杂度过高,很容易出现过拟合。

图3.4 模型复杂度对欠拟合和过拟合的影响

3.11.3.2 训练数据集大小

一般来说,如果训练数据集中样本数过少,过拟合更容易发生。

此外,泛化误差不会随训练数据集里样本数量增加而增大。

3.11.4 多项式函数拟合实验

我们现在可以(通过多项式拟合来探索这些概念)。

import math
import numpy as np
import torch
from torch import nn
from d2l import torch as d2l

3.11.4.1 生成数据集

[使用以下三阶多项式来生成训练和测试数据的标签:]

其中噪声项ϵ服从均值为0、标准差为0.1的正态分布。训练数据集和测试数据集的样本数都设为100。

max_degree = 20  # 多项式的最大阶数
n_train, n_test = 100, 100  # 训练和测试数据集大小
true_w = np.zeros(max_degree)  # 分配大量的空间
true_w[0:4] = np.array([5, 1.2, -3.4, 5.6])features = np.random.normal(size=(n_train + n_test, 1))
np.random.shuffle(features)
poly_features = np.power(features, np.arange(max_degree).reshape(1, -1))
for i in range(max_degree):poly_features[:, i] /= math.gamma(i + 1)  # gamma(n)=(n-1)!
# labels的维度:(n_train+n_test,)
labels = np.dot(poly_features, true_w)
labels += np.random.normal(scale=0.1, size=labels.shape)

看一看生成的数据集的前两个样本。第一个值是与偏置相对应的常量特征。

# NumPy ndarray转换为tensor
true_w, features, poly_features, labels = [torch.tensor(x, dtype=torch.float32) for x in [true_w, features, poly_features, labels]]
features[:2], poly_features[:2, :], labels[:2]

3.11.4.2 定义、训练和测试模型

[实现一个函数来评估模型在给定数据集上的损失]。

def evaluate_loss(net, data_iter, loss):  #@save"""评估给定数据集上模型的损失"""metric = d2l.Accumulator(2)  # 损失的总和,样本数量for X, y in data_iter:out = net(X)y = y.reshape(out.shape)l = loss(out, y)metric.add(l.sum(), l.numel())return metric[0] / metric[1]

[定义训练函数]。

def train(train_features, test_features, train_labels, test_labels,num_epochs=400):loss = nn.MSELoss()input_shape = train_features.shape[-1]# 不设置偏置,因为我们已经在多项式中实现了它net = nn.Sequential(nn.Linear(input_shape, 1, bias=False))batch_size = min(10, train_labels.shape[0])train_iter = d2l.load_array((train_features, train_labels.reshape(-1,1)),batch_size)test_iter = d2l.load_array((test_features, test_labels.reshape(-1,1)),batch_size, is_train=False)trainer = torch.optim.SGD(net.parameters(), lr=0.01)animator = d2l.Animator(xlabel='epoch', ylabel='loss', yscale='log',xlim=[1, num_epochs], ylim=[1e-3, 1e2],legend=['train', 'test'])for epoch in range(num_epochs):d2l.train_epoch_ch3(net, train_iter, loss, trainer)if epoch == 0 or (epoch + 1) % 20 == 0:animator.add(epoch + 1, (evaluate_loss(net, train_iter, loss),evaluate_loss(net, test_iter, loss)))print('weight:', net[0].weight.data.numpy())

3.11.4.3 三阶多项式函数拟合(正常)

我们先使用与数据生成函数同阶的三阶多项式函数拟合。实验表明,这个模型的训练误差和在测试数据集的误差都较低。训练出的模型参数也接近真实值:𝑤=[5,1.2,−3.4,5.6]w=[5,1.2,−3.4,5.6]。

# 从多项式特征中选择前4个维度,即1,x,x^2/2!,x^3/3!
train(poly_features[:n_train, :4], poly_features[n_train:, :4],labels[:n_train], labels[n_train:])

3.11.4.4 线性函数拟合(欠拟合)

 欠拟合(该模型的训练误差在迭代早期下降后便很难继续降低。在完成最后一次迭代周期后,训练误差依旧很高。)非线性模型容易出现欠拟合。

# 从多项式特征中选择前2个维度,即1和x
train(poly_features[:n_train, :2], poly_features[n_train:, :2],labels[:n_train], labels[n_train:])

3.11.4.5 训练样本不足(过拟合)

# 过拟合(训练误差较低,测试误差很高。样本数量太少,模型显得过于复杂,容易被训练数据中的噪声影响)。

# 从多项式特征中选取所有维度
train(poly_features[:n_train, :], poly_features[n_train:, :],labels[:n_train], labels[n_train:], num_epochs=1500)

小结

  • 由于无法从训练误差估计泛化误差,一味地降低训练误差并不意味着泛化误差一定会降低。机器学习模型应关注降低泛化误差。
  • 可以使用验证数据集来进行模型选择。
  • 欠拟合指模型无法得到较低的训练误差,过拟合指模型的训练误差远小于它在测试数据集上的误差。
  • 应选择复杂度合适的模型并避免使用过少的训练样本。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.rhkb.cn/news/433085.html

如若内容造成侵权/违法违规/事实不符,请联系长河编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

Windows环境部署Oracle 11g

Windows环境部署Oracle 11g 1.安装包下载2. 解压安装包3. 数据库安装3.1 执行安装脚本3.2 电子邮件设置3.3 配置安装选项3.4 配置系统类3.5 选择数据库安装类型3.6 选择安装类型3.7 数据库配置3.8 确认安装信息3.9 设置口令 Oracle常用命令 2023年10月中旬就弄出大致的文章&…

铨顺宏科技携RTLS+RFID技术亮相工博会!

中国国际工业博览会盛大开幕! 铨顺宏科技展亮点速递 铨顺宏科技展位号:F117 中国国际博览会今日开幕,铨顺宏科技携创新产品亮相,吸引众多参观者。 我们珍视此次国际盛会,将全力以赴确保最佳体验。 工作人员热情解答…

深度剖析OnlyFans:超越AI的盈利模式与未来挑战

引言 近年来,OnlyFans以其惊人的收入水平震惊了硅谷,2022年的66亿美元营收远超OpenAI的34亿美元。本文将深入探讨OnlyFans的成功原因、商业模式以及面临的AI挑战,试图揭示其在付费内容生态中的独特地位。 OnlyFans的商业模式 OnlyFans成立…

Solidity——抽象合约和接口详解

🚀本系列文章为个人学习笔记,目的是巩固知识并记录我的学习过程及理解。文笔和排版可能拙劣,望见谅。 Solidity中的抽象合约和接口详解 目录 什么是抽象合约?抽象合约的语法接口(Interface)的定义接口的语…

MySQL_插入、更新和删除数据

课 程 推 荐我 的 个 人 主 页:👉👉 失心疯的个人主页 👈👈入 门 教 程 推 荐 :👉👉 Python零基础入门教程合集 👈👈虚 拟 环 境 搭 建 :&#x1…

使用Electron打包一个Vue3项目全步骤

1.创建一个Vue3项目 2.使用 WebStorm打开项目,并安装依赖项 npm install 等待完成后, 安装electron npm install --save-dev electron 等待完成后, 安装electron 打包依赖项(打包成可执行文件) npm install electron-packager --save-dev 3…

【大模型-驯化】成功解决载cuda-11.8配置下搭建swift框架

【大模型-驯化】成功解决载cuda-11.8配置下搭建swift框架 本次修炼方法请往下查看 🌈 欢迎莅临我的个人主页 👈这里是我工作、学习、实践 IT领域、真诚分享 踩坑集合,智慧小天地! 🎇 相关内容文档获取 微信公众号 &…

油气田可视化管理:精准监测与高效生产

通过图扑可视化技术实时监测油气田运行数据,优化生产流程,提高资源利用率和安全性,实现精准化管理。

如何在谷歌浏览器上玩大型多人在线游戏

在如今的数字时代,谷歌浏览器已经成为了许多人上网冲浪的首选工具。除了浏览网页、观看视频之外,你还可以在谷歌浏览器上畅玩各种大型多人在线游戏。本文将为你详细介绍如何在谷歌浏览器上玩大型多人在线游戏的步骤。 (本文由https://chrome…

AWS Network Firewall - 配置只应许白名单域名出入站

参考链接 https://repost.aws/zh-Hans/knowledge-center/network-firewall-configure-domain-ruleshttps://aws.amazon.com/cn/blogs/networking-and-content-delivery/deployment-models-for-aws-network-firewall/ 1. 创建防火墙 选择防火墙的归属子网(选择公有…

WinForm程序嵌入Web网页

文章目录 前言一、三方库或控件的选择测试二、Microsoft Edge WebView2安装、使用步骤1.安装2.使用 前言 由于此项目需要winform客户端嵌入web网页并于JAVA端交互数据,所以研究了一下嵌入web网页这部分,趟了一遍雷,这里做下记录。 一、三方库…

C# 委托(Delegate)二

一.委托的多播(Multicasting of a Delegate): 委托对象,使用 "" 运算符进行合并,一个合并委托调用它所合并的两个委托。使用"-" 运算符从合并的委托中移除组件委托。 注:只有相同类型…

微服务-流量染色

1. 功能目的 通过设置请求头的方式将http请求优先打到指定的服务上,为微服务开发调试工作提供便利 请求报文难模拟:可以直接在测试环境页面上操作,流量直接打到本地IDEA进行debug请求链路较长:本地开发无需启动所有服务&#xf…

[附源码]网上订餐系统+SpringBoot+前后端分离

今天带来一款优秀的项目:网上订餐系统源码 。 系统采用的流行的前后端分离结构,包含了“管理端”,“商家管理端”,“用户购买端” 如果您有任何问题,也请联系小编,小编是经验丰富的程序员! 一.…

【Python语言初识(五)】

一、文件和异常 在Python中实现文件的读写操作其实非常简单,通过Python内置的open函数,我们可以指定文件名、操作模式、编码信息等来获得操作文件的对象,接下来就可以对文件进行读写操作了。这里所说的操作模式是指要打开什么样的文件&#…

SpringSecurity -- 入门使用

文章目录 什么是 SpringSesurity ?细节使用方法 什么是 SpringSesurity ? 在我们的开发中,安全还是有些必要的 用 拦截器 和 过滤器 写代码还是比较麻烦。 SpringSecurity 是 SpringBoot 的底层安全默认选型。一般我们需要认证和授权&#xf…

程序编译的四个阶段

程序编译的四个阶段 #include <stdio.h>int main(){printf("Hello World~");return 0; } hello.c程序的生命周期从一个高级C语言程序开始&#xff0c;这种形式容易被人读懂。 但这无法直接被计算机读懂。为了在系统上运行hello.c程序&#xff0c;每条C语言都…

mysql数据库的基本管理

目录 一.数据库的介绍 二.mariadb的安装 三.软件基本信息 四.数据库开启 五.数据库的安全初始化 六.数据库的基本管理 七.数据密码管理 八.用户授权 九.数据库的备份 十.web控制器 一.数据库的介绍 1.什么是数据库 数据库就是个高级的表格软件 2.常见数据库 Mysql Oracl…

[大语言模型] 情感认知在大型语言模型中的近期进展-2024-09-26

[大语言模型] 情感认知在大型语言模型中的近期进展-2024-09-26 论文信息 Title: Recent Advancement of Emotion Cognition in Large Language Models Authors: Yuyan Chen, Yanghua Xiao https://arxiv.org/abs/2409.13354 情感认知在大型语言模型中的近期进展 《Recent A…

JVM 垃圾回收算法细节

目录 前言 GC Root 可达性分析 根节点枚举 安全点 安全区域 记忆集与卡表 写屏障 并行的可达性分析 前言 学习了几种垃圾收集算法之后&#xff0c; 我们再来看看它们在具体实现上有什么细节之处&#xff0c;我们所能看到的理论很简单&#xff0c;但是实现起来那…