【深度学习笔记】09 权重衰减

09 权重衰减

    • 范数和权重衰减
    • 利用高维线性回归实现权重衰减
    • 权重衰减的简洁实现

范数和权重衰减

在训练参数化机器学习模型时,权重衰减(decay weight)是最广泛应用的正则化技术之一,它通常也被称为 L 2 L_2 L2正则化。这项技术通过函数与零的距离来衡量函数的复杂度,
因为在所有函数 f f f中,函数 f = 0 f = 0 f=0(所有输入都得到值 0 0 0
在某种意义上是最简单的。

一种简单的方法是通过线性函数
f ( x ) = w ⊤ x f(\mathbf{x}) = \mathbf{w}^\top \mathbf{x} f(x)=wx
中的权重向量的某个范数来度量其复杂性,
例如 ∥ w ∥ 2 \| \mathbf{w} \|^2 w2
要保证权重向量比较小,
最常用方法是将其范数作为惩罚项加到最小化损失的问题中。
将原来的训练目标最小化训练标签上的预测损失,
调整为最小化预测损失和惩罚项之和。

损失由下式给出:

L ( w , b ) = 1 n ∑ i = 1 n 1 2 ( w ⊤ x ( i ) + b − y ( i ) ) 2 . L(\mathbf{w}, b) = \frac{1}{n}\sum_{i=1}^n \frac{1}{2}\left(\mathbf{w}^\top \mathbf{x}^{(i)} + b - y^{(i)}\right)^2. L(w,b)=n1i=1n21(wx(i)+by(i))2.

x ( i ) \mathbf{x}^{(i)} x(i)是样本 i i i的特征,
y ( i ) y^{(i)} y(i)是样本 i i i的标签,
( w , b ) (\mathbf{w}, b) (w,b)是权重和偏置参数。

为了惩罚权重向量的大小,
必须以某种方式在损失函数中添加 ∥ w ∥ 2 \| \mathbf{w} \|^2 w2
我们通过正则化常数 λ \lambda λ来描述这种权衡,
这是一个非负超参数,我们使用验证数据拟合:

L ( w , b ) + λ 2 ∥ w ∥ 2 , L(\mathbf{w}, b) + \frac{\lambda}{2} \|\mathbf{w}\|^2, L(w,b)+2λw2,

对于 λ = 0 \lambda = 0 λ=0,我们恢复了原来的损失函数。
对于 λ > 0 \lambda > 0 λ>0,我们限制 ∥ w ∥ \| \mathbf{w} \| w的大小。
这里我们仍然除以 2 2 2:当我们取一个二次函数的导数时,
2 2 2 1 / 2 1/2 1/2会抵消。

通过平方 L 2 L_2 L2范数,我们去掉平方根,留下权重向量每个分量的平方和。
这使得惩罚的导数很容易计算:导数的和等于和的导数。

L 2 L_2 L2正则化回归的小批量随机梯度下降更新如下式:

w ← ( 1 − η λ ) w − η ∣ B ∣ ∑ i ∈ B x ( i ) ( w ⊤ x ( i ) + b − y ( i ) ) . \begin{aligned} \mathbf{w} & \leftarrow \left(1- \eta\lambda \right) \mathbf{w} - \frac{\eta}{|\mathcal{B}|} \sum_{i \in \mathcal{B}} \mathbf{x}^{(i)} \left(\mathbf{w}^\top \mathbf{x}^{(i)} + b - y^{(i)}\right). \end{aligned} w(1ηλ)wBηiBx(i)(wx(i)+by(i)).

我们根据估计值与观测值之间的差异来更新 w \mathbf{w} w
然而,我们同时也在试图将 w \mathbf{w} w的大小缩小到零。
这就是为什么这种方法有时被称为权重衰减
我们仅考虑惩罚项,优化算法在训练的每一步衰减权重。
与特征选择相比,权重衰减为我们提供了一种连续的机制来调整函数的复杂度。
较小的 λ \lambda λ值对应较少约束的 w \mathbf{w} w
而较大的 λ \lambda λ值对 w \mathbf{w} w的约束更大。

是否对相应的偏置 b 2 b^2 b2进行惩罚在不同的实践中会有所不同,
在神经网络的不同层中也会有所不同。
通常,网络输出层的偏置项不会被正则化。

利用高维线性回归实现权重衰减

%matplotlib inline
import torch
from torch import nn
from d2l import torch as d2l

首先生成数据,生成公式如下:

y = 0.05 + ∑ i = 1 d 0.01 x i + ϵ where  ϵ ∼ N ( 0 , 0.0 1 2 ) . y = 0.05 + \sum_{i = 1}^d 0.01 x_i + \epsilon \text{ where } \epsilon \sim \mathcal{N}(0, 0.01^2). y=0.05+i=1d0.01xi+ϵ where ϵN(0,0.012).

选择标签是关于输入的线性函数。
标签同时被均值为0,标准差为0.01高斯噪声破坏。
为了使过拟合的效果更加明显,我们可以将问题的维数增加到 d = 200 d = 200 d=200
并使用一个只包含20个样本的小训练集。

n_train, n_test, num_inputs, batch_size = 20, 100, 200, 5
true_w, true_b = torch.ones((num_inputs, 1)) * 0.01, 0.05
train_data = d2l.synthetic_data(true_w, true_b, n_train)
train_iter = d2l.load_array(train_data, batch_size)
test_data = d2l.synthetic_data(true_w, true_b, n_test)
test_iter = d2l.load_array(test_data, batch_size, is_train=False)

初始化模型参数

定义一个函数来随机初始化模型参数

def init_params():w = torch.normal(0, 1, size = (num_inputs, 1), requires_grad = True)b = torch.zeros(1, requires_grad = True)return [w, b]

定义 L 2 L_2 L2范数惩罚

def l2_penalty(w):return torch.sum(w.pow(2)) / 2

定义训练代码实现

下面的代码将模型拟合训练数据集,并在测试数据集上进行评估。

def train(lambd):w, b = init_params()net, loss = lambda X: d2l.linreg(X, w, b), d2l.squared_lossnum_epochs, lr = 100, 0.003animator = d2l.Animator(xlabel='epochs', ylabel='loss', yscale='log',xlim=[5, num_epochs], legend=['train', 'test'])for epoch in range(num_epochs):for X, y in train_iter:# 增加了L2范数惩罚项,# 广播机制使l2_penalty(w)成为一个长度为batch_size的向量l = loss(net(X), y) + lambd * l2_penalty(w)l.sum().backward()d2l.sgd([w, b], lr, batch_size)if (epoch + 1) % 5 == 0:animator.add(epoch + 1, (d2l.evaluate_loss(net, train_iter, loss),d2l.evaluate_loss(net, test_iter, loss)))print('w的L2范数是:', torch.norm(w).item())

忽略正则化直接训练

用lamdb=0禁用权重衰减后运行代码。此时训练误差有所减少,但测试误差没有减少,这意味着出现了严重的过拟合。

train(lambd = 0)
w的L2范数是: 14.971677780151367

在这里插入图片描述

使用权重衰减

使用权重衰减来运行代码。此时训练误差增大,但测试误差减小。这正是我们期望从正则化中得到的效果。

train(lambd = 3)
w的L2范数是: 0.34405317902565

在这里插入图片描述

权重衰减的简洁实现

在实例化优化器时直接通过weight_decay指定weight decay超参数。默认情况下,PyTorch同时衰减权重和便宜。这里只为权重设置了weight_decay,所以偏置参数 b b b不会衰减。

def train_concise(wd):net = nn.Sequential(nn.Linear(num_inputs, 1))for param in net.parameters():param.data.normal_()loss = nn.MSELoss(reduction='none')num_epochs, lr = 100, 0.003# 偏置参数没有衰减trainer = torch.optim.SGD([{"params":net[0].weight,'weight_decay': wd},{"params":net[0].bias}], lr=lr)animator = d2l.Animator(xlabel='epochs', ylabel='loss', yscale='log',xlim=[5, num_epochs], legend=['train', 'test'])for epoch in range(num_epochs):for X, y in train_iter:trainer.zero_grad()l = loss(net(X), y)l.mean().backward()trainer.step()if (epoch + 1) % 5 == 0:animator.add(epoch + 1,(d2l.evaluate_loss(net, train_iter, loss),d2l.evaluate_loss(net, test_iter, loss)))print('w的L2范数:', net[0].weight.norm().item())
train_concise(0)
w的L2范数: 13.416662216186523

在这里插入图片描述

train_concise(3)
w的L2范数: 0.39273694157600403

在这里插入图片描述

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.rhkb.cn/news/210737.html

如若内容造成侵权/违法违规/事实不符,请联系长河编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

【人体解剖学与组织胚胎学】练习一高度相联知识点整理及对应习题

文章目录 [toc]骨性鼻旁窦填空题问答题 关节填空题简答题 胸廓填空题简答题![在这里插入图片描述](https://img-blog.csdnimg.cn/direct/827e7d1db3af42858d8734bb81911fea.jpeg)补充 骨性鼻旁窦 填空题 问答题 关节 填空题 简答题 胸廓 填空题 简答题 补充 第二肋对应胸骨…

混音编曲软件tudio One 6.5.1 保姆级安装教程

根据软件大数据显示De-Esser驯服人声嘶嘶声和其他高频声音,和其他 Studio One 中新的去实体插件一样高效且直观易用,使用“收听”按钮查找有问题的频率,然后使用相关的旋钮和 S-Mon 功能拨入 S-Reduce 量即可。实际上我们可以这样讲工作流和协…

Linux进程间通信之共享内存

📟作者主页:慢热的陕西人 🌴专栏链接:Linux 📣欢迎各位大佬👍点赞🔥关注🚓收藏,🍉留言 本博客主要内容讲解共享内存原理和相关接口的介绍,以及一个…

SpringBoot+SSM项目实战 苍穹外卖(3)

继续上一节的内容,本节完成菜品管理功能,包括公共字段自动填充、新增菜品、菜品分页查询、删除菜品、修改菜品。 目录 公共字段自动填充新增菜品文件上传实现新增菜品实现 useGeneratedKeys 菜品分页查询删除菜品修改菜品根据id查询菜品实现修改菜品实现…

Redis中的缓存穿透、雪崩、击穿(详细)

目录 一、概念 1. 缓存穿透(Cache Penetration) 解决方案: 2. 缓存雪崩(Cache Avalanche) 解决方案: 3. 缓存击穿(Cache Breakdown) 解决方案: 二、三者出现的根本原…

为XiunoBBS4.0开启redis缓存且支持密码验证

修改模块文件1 xiunoPHP/cache_redis.class.php: <?phpclass cache_redis {public $conf array();public $link NULL;public $cachepre ;public $errno 0;public $errstr ;public function __construct($conf array()) {if(!extension_loaded(Redis)) {return $thi…

有趣的代码——有故事背景的程序设计3

这篇文章再和大家分享一些有“背景”的程序设计&#xff0c;希望能够让大家学到知识的同时&#xff0c;对编程学习更感兴趣&#xff0c;更能在这条路上坚定地走下去。 目录 1.幻方问题 2.用函数打印九九乘法表 3.鸡兔同笼问题 4.字数统计 5.简单选择排序 1.幻方问题 幻方又…

Mac苹果视频剪辑:Final Cut Pro Mac

Final Cut Pro是一款由Apple公司开发的专业视频非线性编辑软件&#xff0c;是业界著名的视频剪辑软件之一。它最初发布于1999年&#xff0c;是Mac电脑上的一款独占软件。Final Cut Pro具有先进的剪辑工具、丰富的特效和颜色分级、音频处理等功能&#xff0c;使得用户可以轻松地…

Linux之重谈文件和c语言文件接口

重谈文件 文件 内容 属性, 所有对文件的操作都是: a.对内容操作 b.对属性操作 关于文件 一&#xff1a; 即使文件的内容为空&#xff0c;该文件也会在磁盘上也会占空间&#xff0c;因为文件不仅仅只有内容还有文件对应的属性&#xff0c;文件的内容会占用空间, 文件的属性也…

【面试】Java最新面试题资深开发-JVM第一弹

问题一&#xff1a;Java中的垃圾回收机制 在Java中&#xff0c;垃圾回收是如何工作的&#xff0c;可以简要描述一下垃圾回收的算法有哪些吗&#xff1f; 在Java中&#xff0c;垃圾回收是一种自动管理内存的机制&#xff0c;它负责识别不再被程序引用的对象并释放其占用的内存…

HarmonyOS与AbilitySlice路由配置

上一章我有教到鸿蒙应用开发——Ability鸿蒙应用开发的基础知识&#xff0c;那么今天我们来讲一下AbilitySlice路由配置 AbilitySlice路由配置 虽然一个Page可以包含多个AbilitySlice&#xff0c;但是Page进入前台时界面默认只展示一个AbilitySlice。默认展示的AbilitySlice是…

Java+SSM springboot+MySQL家政服务预约网站设计en24b

随着社区居民对生活品质的追求以及社会老龄化的加剧&#xff0c;社区居民对家政服务的需求越来越多&#xff0c;家政服务业逐渐成为政府推动、扶持和建设的重点行业。家政服务信息化有助于提高社区家政服务的工作效率和质量。 本次开发的家政服务网站是一个面向社区的家政服务网…

TCP首部格式_基本知识

TCP首部格式 表格索引: 源端口目的端口 序号 确认号 数据偏移保留 ACK等 窗口检验和紧急指针 TCP报文段首部格式图 源端口与目的端口: 各占16位 序号:占32比特&#xff0c;取值范围0~232-1。当序号增加到最后一个时&#xff0c;下一个序号又回到0。用来指出本TCP报文段数据载…

面试 Java 框架八股文十问十答第二期

面试 Java 框架八股文十问十答第二期 作者&#xff1a;程序员小白条&#xff0c;个人博客 相信看了本文后&#xff0c;对你的面试是有一定帮助的&#xff01; ⭐点赞⭐收藏⭐不迷路&#xff01;⭐ 1.AOP的术语&#xff0c;以及两种动态代理实现方法&#xff0c;以及它们的区别…

Notepad++批量添加引号

工作中经常会遇到这样情景&#xff1a;业务给到一批订单号&#xff0c;需要查询这批订单的某些字段信息。在where条件中需要传入这些订单号的数组&#xff0c;并且订单号用引号引起&#xff0c;用引号隔开。 字符串之间长度相同 可以按住CtrlAlt和鼠标左键选中区域&#xff0…

手持式安卓主板_PDA安卓板_智能手持终端方案

手持式安卓主板方案是一种智能终端设备&#xff0c;具备自动对焦和闪光灯功能&#xff0c;可以在昏暗的环境下快速扫描二维码并轻松采集数据。该方案还提供多渠道支付和数据采集功能&#xff0c;为用户提供了便捷的体验。 该方案的产品基于手持式安卓主板&#xff0c;并搭载了八…

基于ROPNet项目训练modelnet40数据集进行3d点云的配置

项目地址&#xff1a; https://github.com/zhulf0804/ROPNet 在 MVP Registration Challenge (ICCV Workshop 2021)&#xff08;ICCV Workshop 2021&#xff09;中获得了第二名。项目可以在win10环境下运行。 论文地址&#xff1a; https://arxiv.org/abs/2107.02583 网络简介…

基于H5“汉函谷关起点新安县旅游信息系统”设计与实现

目 录 摘 要 1 ABSTRACT 2 第1章 绪论 3 1.1 系统开发背景及意义 3 1.2 系统开发的目标 3 第2章 主要开发技术介绍 5 2.1 H5技术介绍 5 2.2 Visual Studio 技术介绍 5 2.3 SQL Server数据库技术介绍 6 第3章 系统分析与设计 7 3.1 可行性分析 7 3.1.1 技术可行性 7 3.1.2 操作…

HTTP请求

前言 HTTP是应用层的一个协议。实际我们访问一个网页&#xff0c;都会像该网页的服务器发送HTTP请求&#xff0c;服务器解析HTTP请求&#xff0c;返回HTTP响应。如此就是我们获取资源或者上传资源的原理 HTTP请求报头格式 图片来自网络 HTTP请求报头总体有四部分&#xff1a;…

pycharm中绘制一个3D曲线

import numpy as np import matplotlib.pyplot as plt # 中文的设置 import matplotlib as mp1 from mpl_toolkits.mplot3d import Axes3D mp1.rcParams["font.sans-serif"] ["kaiti"] mp1.rcParams["axes.unicode_minus"] False # 数据创建 X…