《动手学深度学习(PyTorch版)》笔记4.5

注:书中对代码的讲解并不详细,本文对很多细节做了详细注释。另外,书上的源代码是在Jupyter Notebook上运行的,较为分散,本文将代码集中起来,并加以完善,全部用vscode在python 3.9.18下测试通过。

Chapter4 Multilayer Perceptron

4.5 Weight Decay

前一节我们描述了过拟合的问题,本节我们将介绍一些正则化模型的技术。我们总是可以通过去收集更多的训练数据来缓解过拟合,但这可能成本很高,耗时颇多,或者完全超出我们的控制,因而在短期内不可能做到。假设我们已经拥有尽可能多的高质量数据,我们便可以将重点放在正则化技术上。

在多项式回归的例子中,我们可以通过调整拟合多项式的阶数来限制模型的容量。实际上,限制特征的数量是缓解过拟合的一种常用技术。然而,简单地丢弃特征对这项工作来说可能过于生硬。我们继续思考多项式回归的例子,考虑高维输入可能发生的情况。多项式对多变量数据的自然扩展称为单项式(monomials),也可以说是变量幂的乘积。单项式的阶数是幂的和。例如, x 1 2 x 2 x_1^2 x_2 x12x2 x 3 x 5 2 x_3 x_5^2 x3x52都是3次单项式。

注意,随着阶数 d d d的增长,带有阶数 d d d的项数迅速增加。 给定 k k k个变量,阶数为 d d d的项的个数为 ( k − 1 + d k − 1 ) {k - 1 + d} \choose {k - 1} (k1k1+d),即 C k − 1 + d k − 1 = ( k − 1 + d ) ! ( d ) ! ( k − 1 ) ! C^{k-1}_{k-1+d} = \frac{(k-1+d)!}{(d)!(k-1)!} Ck1+dk1=(d)!(k1)!(k1+d)!。因此即使是阶数上的微小变化,比如从 2 2 2 3 3 3,也会显著增加我们模型的复杂性。仅仅通过简单的限制特征数量(在多项式回归中体现为限制阶数),可能仍然使模型在过简单和过复杂中徘徊,我们需要一个更细粒度的工具来调整函数的复杂性,使其达到一个合适的平衡位置。

4.5.1 Norm and Weight Decay

权重衰减是最广泛使用的正则化的技术之一在训练参数化机器学习模型时,权重衰减(weight decay)是最广泛使用的正则化的技术之一,它通常也被称为 L 2 L_2 L2正则化。这项技术通过函数与零的距离来衡量函数的复杂度,因为在所有函数 f f f中,函数 f = 0 f = 0 f=0(所有输入都得到值 0 0 0)在某种意义上是最简单的。但是我们应该如何精确地测量一个函数和零之间的距离呢?没有一个准确的答案。
一种简单的方法是通过线性函数 f ( x ) = w ⊤ x f(\mathbf{x}) = \mathbf{w}^\top \mathbf{x} f(x)=wx中的权重向量的某个范数来度量其复杂性,例如 ∥ w ∥ 2 \| \mathbf{w} \|^2 w2。要保证权重向量比较小,最常用方法是将其范数作为惩罚项加到最小化损失的问题中。将原来的训练目标最小化训练标签上的预测损失,调整为最小化预测损失和惩罚项之和。现在,如果我们的权重向量增长的太大,我们的学习算法可能会更集中于最小化权重范数 ∥ w ∥ 2 \| \mathbf{w} \|^2 w2。这正是我们想要的。3.1节的线性回归例子里,损失由下式给出:

L ( w , b ) = 1 n ∑ i = 1 n 1 2 ( w ⊤ x ( i ) + b − y ( i ) ) 2 . L(\mathbf{w}, b) = \frac{1}{n}\sum_{i=1}^n \frac{1}{2}\left(\mathbf{w}^\top \mathbf{x}^{(i)} + b - y^{(i)}\right)^2. L(w,b)=n1i=1n21(wx(i)+by(i))2.

为了惩罚权重向量的大小,我们必须以某种方式在损失函数中添加 ∥ w ∥ 2 \| \mathbf{w} \|^2 w2,但是模型应该如何平衡这个新的额外惩罚的损失?实际上,我们通过正则化常数 λ \lambda λ来描述这种权衡,这是一个非负超参数,我们使用验证数据拟合:

L ( w , b ) + λ 2 ∥ w ∥ 2 , L(\mathbf{w}, b) + \frac{\lambda}{2} \|\mathbf{w}\|^2, L(w,b)+2λw2,

对于 λ = 0 \lambda = 0 λ=0,我们恢复了原来的损失函数。对于 λ > 0 \lambda > 0 λ>0,我们限制 ∥ w ∥ \| \mathbf{w} \| w的大小。这里我们仍然除以 2 2 2,是因为我们取一个二次函数的导数时, 2 2 2 1 / 2 1/2 1/2会抵消。通过平方 L 2 L_2 L2范数,我们去掉平方根,留下权重向量每个分量的平方和,这使得惩罚的导数很容易计算。

L 2 L_2 L2正则化线性模型构成经典的岭回归(ridge regression)算法, L 1 L_1 L1正则化线性回归是统计学中类似的基本模型,通常被称为套索回归(lasso regression)。我们通常使用 L 2 L_2 L2范数的一个原因是它对权重向量的大分量施加了巨大的惩罚。这使得我们的学习算法偏向于在大量特征上均匀分布权重的模型。在实践中,这可能使它们对单个变量中的观测误差更为稳定。相比之下, L 1 L_1 L1惩罚会导致模型将权重集中在一小部分特征上,而将其他权重清除为零。这称为特征选择(feature selection),这可能是其他场景下需要的。

L 2 L_2 L2正则化回归的小批量随机梯度下降更新如下式:

w ← ( 1 − η λ ) w − η ∣ B ∣ ∑ i ∈ B x ( i ) ( w ⊤ x ( i ) + b − y ( i ) ) . \begin{aligned} \mathbf{w} & \leftarrow \left(1- \eta\lambda \right) \mathbf{w} - \frac{\eta}{|\mathcal{B}|} \sum_{i \in \mathcal{B}} \mathbf{x}^{(i)} \left(\mathbf{w}^\top \mathbf{x}^{(i)} + b - y^{(i)}\right). \end{aligned} w(1ηλ)wBηiBx(i)(wx(i)+by(i)).

根据之前章节所讲的,我们根据估计值与观测值之间的差异来更新 w \mathbf{w} w。然而,我们同时也在试图将 w \mathbf{w} w的大小缩小到零。这就是为什么这种方法有时被称为权重衰减。我们仅考虑惩罚项,优化算法在训练的每一步衰减权重。与特征选择相比,权重衰减为我们提供了一种连续的机制来调整函数的复杂度。较小的 λ \lambda λ值对应较少约束的 w \mathbf{w} w,而较大的 λ \lambda λ值对 w \mathbf{w} w的约束更大。是否对相应的偏置 b 2 b^2 b2进行惩罚在不同的实践中会有所不同,在神经网络的不同层中也会有所不同。通常,网络输出层的偏置项不会被正则化。

4.5.2 High Dimensional Linear Regression

我们通过一个简单的例子来演示权重衰减。首先,我们像以前一样生成一些数据,生成公式如下:

y = 0.05 + ∑ i = 1 d 0.01 x i + ϵ where  ϵ ∼ N ( 0 , 0.0 1 2 ) . y = 0.05 + \sum_{i = 1}^d 0.01 x_i + \epsilon \text{ where } \epsilon \sim \mathcal{N}(0, 0.01^2). y=0.05+i=1d0.01xi+ϵ where ϵN(0,0.012).

我们选择标签是关于输入的线性函数。标签同时被均值为0,标准差为0.01高斯噪声破坏。为了使过拟合的效果更加明显,我们可以将问题的维数增加到 d = 200 d = 200 d=200,并使用一个只包含20个样本的小训练集。

import matplotlib.pyplot as plt
import torch
from d2l import torch as d2l
from torch import nnn_train, n_test, num_inputs, batch_size = 20, 100, 200, 5
true_w, true_b = torch.ones((num_inputs, 1)) * 0.01, 0.05
train_data = d2l.synthetic_data(true_w, true_b, n_train)
train_iter = d2l.load_array(train_data, batch_size)
test_data = d2l.synthetic_data(true_w, true_b, n_test)
test_iter = d2l.load_array(test_data, batch_size, is_train=False)#初始化模型参数
def init_params():w = torch.normal(0, 1, size=(num_inputs, 1), requires_grad=True)b = torch.zeros(1, requires_grad=True)return [w, b]#定义L2范数惩罚
def l2_penalty(w):return torch.sum(w.pow(2)) / 2#训练
def train(lambd):w, b = init_params()net, loss = lambda X: d2l.linreg(X, w, b), d2l.squared_lossnum_epochs, lr = 100, 0.003animator = d2l.Animator(xlabel='epochs', ylabel='loss', yscale='log',xlim=[5, num_epochs], legend=['train', 'test'])for epoch in range(num_epochs):for X, y in train_iter:# 增加了L2范数惩罚项,# 广播机制使l2_penalty(w)成为一个长度为batch_size的向量l = loss(net(X), y) + lambd * l2_penalty(w)l.sum().backward()d2l.sgd([w, b], lr, batch_size)if (epoch + 1) % 5 == 0:animator.add(epoch + 1, (d2l.evaluate_loss(net, train_iter, loss),d2l.evaluate_loss(net, test_iter, loss)))print('w的L2范数是:', torch.norm(w).item())#忽略正则化直接训练
train(lambd=0)
plt.show()#使用正则化后进行训练
train(lambd=3)
plt.show()#简洁实现
def train_concise(wd):net = nn.Sequential(nn.Linear(num_inputs, 1))for param in net.parameters():param.data.normal_()loss = nn.MSELoss(reduction='none')num_epochs, lr = 100, 0.003# 偏置参数没有衰减trainer = torch.optim.SGD([{"params":net[0].weight,'weight_decay': wd},{"params":net[0].bias}], lr=lr)animator = d2l.Animator(xlabel='epochs', ylabel='loss', yscale='log',xlim=[5, num_epochs], legend=['train', 'test'])for epoch in range(num_epochs):for X, y in train_iter:trainer.zero_grad()l = loss(net(X), y)l.mean().backward()trainer.step()if (epoch + 1) % 5 == 0:animator.add(epoch + 1,(d2l.evaluate_loss(net, train_iter, loss),d2l.evaluate_loss(net, test_iter, loss)))print('w的L2范数:', net[0].weight.norm().item())train_concise(0)
train_concise(3)
plt.show()

无正则化效果(过拟合):
在这里插入图片描述

正则化效果:
在这里插入图片描述

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.rhkb.cn/news/246113.html

如若内容造成侵权/违法违规/事实不符,请联系长河编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

ES文档索引、查询、分片、文档评分和分析器技术原理

技术原理 索引文档 索引文档分为单个文档和多个文档。 单个文档 新建单个文档所需要的步骤顺序: 客户端向 Node 1 发送新建、索引或者删除请求。节点使用文档的 _id 确定文档属于分片 0 。请求会被转发到 Node 3,因为分片 0 的主分片目前被分配在 …

微信小程序(十七)自定义组件生命周期(根据状态栏自适配)

注释很详细,直接上代码 上一篇 新增内容: 1.获取手机状态栏的高度 2.验证attached可以修改数据 3.动态绑定样式数值 源码: myNav.js Component({lifetimes:{//相当于vue的created,因为无法更新数据被打入冷宫created(){},//相当于vue的mount…

【JS基础】事件对象event、环境对象this、事件的高级操作

文章目录 一、事件对象1.1 事件对象是什么?1.2 使用方法 二、环境对象this以及回调函数2.1 它是什么?2.2 演示示例 三、事件的高级操作3.1 事件流3.2 事件捕获3.3 事件冒泡以及阻止冒泡3.4 事件解绑3.5 mouseover和mouseenter事件的区别3.6 事件委托它是…

HTML新手教程

HTML入门 教程:【狂神说Java】HTML5完整教学通俗易懂_哔哩哔哩_bilibili 一.初识HTML HyperTextMarkupLanguage(超文本标记语言) 超文本包括:文字、图片、音频、视频、动画。 HTML5的优势 世界知名浏览器厂商对HTML5的支持市场的…

解决WinForms跨线程操作控件的问题

解决WinForms跨线程操作控件的问题 介绍 在构建Windows窗体应用程序时,我们通常会遇到需要从非UI线程更新UI元素的场景。由于WinForms控件并不是线程安全的,直接这样做会抛出一个异常:“控件’control name’是从其他线程创建的,…

每日OJ题_算法_二分查找⑦_力扣153. 寻找旋转排序数组中的最小值

目录 力扣153. 寻找旋转排序数组中的最小值 解析代码 力扣153. 寻找旋转排序数组中的最小值 153. 寻找旋转排序数组中的最小值 - 力扣(LeetCode) 难度 中等 已知一个长度为 n 的数组,预先按照升序排列,经由 1 到 n 次 旋转 后…

node学习过程中的终端命令

冷的哥们手真tm冷,打字都是僵的,屮 目录 一、在学习nodejs过程中用到的终端命令总结 一、在学习nodejs过程中用到的终端命令 node -v nvm install 20.11.0 nvm list nvm list available nvm on nvm -v nvm use 20.11.0 node加要运行的js文件路径 ps&a…

Keycloak - docker 运行 前端集成

Keycloak - docker 运行 & 前端集成 这里的记录主要是跟我们的项目相关的一些本地运行/测试,云端用的 keycloak 版本不一样,不过本地我能找到的最简单的配置是这样的 docker 配置 & 运行 keycloak keycloak 有官方(Red Hat Inc.)的镜像&#…

搭建Redis集群

一 应用场景 为什么需要redis集群? 当主备复制场景,无法满足主机的单点故障时,需要引入集群配置。 一般数据库要处理的读请求远大于写请求 ,针对这种情况,我们优化数据库可以采用读写分离的策略。我们可以部 署一台…

数据结构与算法——队列

概述 计算机科学中,queue 是以顺序的方式维护的一组数据集合,在一端添加数据,从另一端移除数据。添加的一端称为尾,移除的一端称为头。 功能 插入offer(value : E) : boolean  取值并移除poll() : E  取值peek() : E  判断…

项目中日历管理学习使用

一些项目中会有日历或日期设置,最基本的会显示工作日,休息日,节假日等等,下面就是基于项目中的日历管理功能,要显示工作日,休息日,节假日 效果图 获取国家法定节假日工具类 public class Holi…

「QT」QString类的详细说明

✨博客主页何曾参静谧的博客📌文章专栏「QT」QT5程序设计📚全部专栏「VS」Visual Studio「C/C++」C/C++程序设计「UG/NX」BlockUI集合「Win」Windows程序设计「

Qt编写手机端视频播放器/推流工具/Onvif工具

一、视频播放器 同时支持多种解码内核,包括qmedia内核(Qt4/Qt5/Qt6)、ffmpeg内核(ffmpeg2/ffmpeg3/ffmpeg4/ffmpeg5/ffmpeg6)、vlc内核(vlc2/vlc3)、mpv内核(mpv1/mp2)、…

第十七讲_HarmonyOS应用开发Stage模型应用组件

HarmonyOS应用开发Stage模型应用组件 1. 应用级配置2. Module级配置3. Stage模型的组件3.1 AbilityStage3.1.1 AbilityStage的创建和配置3.1.2 AbilityStage的生命周期回调3.1.3 AbilityStage的事件回调: 3.2 UIAbility3.2.1 UIAbility生命周期3.2.3 UIAbility启动模…

修复idea,eclipse ,clion控制台中文乱码

控制台乱码问题主要原因并不在编译器IDE身上,还主要是Windows的控制台默认编码问题。。。 Powershell,cmd等默认编码可能不是UTF-8,无需改动IDE的settings或者properties(这治标不治本),直接让Windows系统…

初识MQRabbitMQ快速入门

一、同步和异步通讯 微服务间通讯有同步和异步两种方式: 同步通讯:就像打电话,需要实时响应。 异步通讯:就像发邮件,不需要马上回复。 两种方式各有优劣,打电话可以立即得到响应,但是你却不能…

老旧小区火灾频发,LoRa无线系统筑牢安全防线

近日,全国各地多个老旧小区火灾事故频发,从安微合肥南二环一老旧小区居民楼起火、上海金山区一小区居民楼火灾,到1月24日江西新余市特大火灾......都造成了不同程度的人员伤亡和财产损失,令人扼腕痛惜,教训十分深刻。 …

【图像分割】【深度学习】Windows10下UNet代码Pytorch实现与源码讲解

【图像分割】【深度学习】Windows10下UNet代码Pytorch实现与源码讲解 提示:最近开始在【医学图像分割】方面进行研究,记录相关知识点,分享学习中遇到的问题已经解决的方法。 文章目录 【图像分割】【深度学习】Windows10下UNet代码Pytorch实现与源码讲解前言UNet模型运行环境搭…

前端工程化之:webpack1-6(编译过程)

一、webpack编译过程 webpack 的作用是将源代码编译(构建、打包)成最终代码。 整个过程大致分为三个步骤: 初始化编译输出 1.初始化 初始化时我们运行的命令 webpack 为核心包, webpack-cli 提供了 webpack 命令,通过…

第八篇 交叉编译华为云Iot SDK到Orangepi3B

本篇主要内容: 一、交叉编译华为云Iot SDK依赖1.宿主机安装交叉编译工具链(1)选择下载交叉编译工具链(2)解压、添加环境变量、重启2.交叉编译依赖库(0) 准备工作(1) 交叉…