【PyTorch】权重衰减

文章目录

  • 1. 理论介绍
  • 2. 实例解析
    • 2.1. 实例描述
    • 2.2. 代码实现

1. 理论介绍

  • 通过对模型过拟合的思考,人们希望能通过某种工具调整模型复杂度,使其达到一个合适的平衡位置。
  • 权重衰减(又称 L 2 L_2 L2正则化)通过为损失函数添加惩罚项,用来惩罚权重的 L 2 L_2 L2范数,从而限制模型参数值,促使模型参数更加稀疏或更加集中,进而调整模型的复杂度,即: L ( w , b ) + λ 2 ∥ w ∥ 2 L(\mathbf{w}, b) + \frac{\lambda}{2} \|\mathbf{w}\|^2 L(w,b)+2λw2其中 λ \lambda λ权重衰减的超参数
  • L p L_p Lp范数: ∥ x ∥ p = ( ∑ i = 1 n ∣ x i ∣ p ) 1 / p \|\mathbf{x}\|_p = \left(\sum_{i=1}^n \left|x_i \right|^p \right)^{1/p} xp=(i=1nxip)1/p
    p = 1 p=1 p=1时称为 L 1 L_1 L1范数;当 p = 2 p=2 p=2时称为 L 2 L_2 L2范数。
    惩罚 L 1 L_1 L1范数会导致模型将权重集中在一小部分特征上, 而将其他权重清除为零, 这称为特征选择;惩罚 L 2 L_2 L2范数会导致模型在大量特征上均匀分布权重,使得模型对单个变量的观测误差更为稳定。
  • 通常不建议对偏置进行正则化,因为偏置的取值并不像权值那样会随着训练过程而变化,因此对偏置进行正则化对于控制模型的复杂度影响较小;另外,对偏置进行正则化可能会导致对数据中的偏移进行过度拟合,而减弱了模型对其他特征的学习。

2. 实例解析

2.1. 实例描述

使用以下公式生成包含20个样本的小训练集和100个样本的测试集,并用线性网络进行拟合: y = 0.05 + ∑ i = 1 200 0.01 x i + ϵ where  ϵ ∼ N ( 0 , 0.0 1 2 ) . y = 0.05 + \sum_{i = 1}^{200} 0.01 x_i + \epsilon \text{ where } \epsilon \sim \mathcal{N}(0, 0.01^2). y=0.05+i=12000.01xi+ϵ where ϵN(0,0.012).

2.2. 代码实现

  • 主要代码
optimizer = optim.SGD([{"params": net.weight,"weight_decay": weight_decay},{"params": net.bias}], lr=lr)
  • 完整代码
import os
import torch
from torch import nn, optim
from torch.utils.data import TensorDataset, DataLoader
from tensorboardX import SummaryWriter
from rich.progress import trackdef data_generator(w, b, num):"""为线性模型生成数据"""X = torch.randn(num, len(w))y = torch.sum(X @ w, dim=1) + by += torch.normal(0, 0.01, y.shape)return X, y.reshape(-1, 1)def load_dataset(*tensors):"""加载数据集"""dataset = TensorDataset(*tensors)return DataLoader(dataset, batch_size, shuffle=True)def evaluate_loss(dataloader, net, criterion):"""评估模型在指定数据集上的损失"""num_examples = 0loss_sum = 0.0with torch.no_grad():for X, y in dataloader:X, y = X.cuda(), y.cuda()loss = criterion(net(X), y)num_examples += y.shape[0]loss_sum += loss.sum()return loss_sum / num_examplesif __name__ == '__main__':# 全局参数设置lr = 0.003num_epochs = 100batch_size = 5# 创建记录器def log_dir():root = "runs"if not os.path.exists(root):os.mkdir(root)order = len(os.listdir(root)) + 1return f'{root}/exp{order}'writer = SummaryWriter(log_dir=log_dir())# 合成数据集num_inputs = 200n_train, n_test = 20, 100true_w, true_b = torch.ones((num_inputs, 1)) * 0.01, 0.05X, y = data_generator(true_w, true_b, n_train + n_test)# 加载数据集dataloader_train = load_dataset(X[:n_train], y[:n_train])dataloader_test = load_dataset(X[n_train:], y[n_train:])def loop(weight_decay):# 定义模型net = nn.Linear(num_inputs, 1).cuda()nn.init.normal_(net.weight)nn.init.constant_(net.bias, 0)criterion = nn.MSELoss(reduction='none')optimizer = optim.SGD([{"params": net.weight,"weight_decay": weight_decay},{"params": net.bias}], lr=lr)# 训练循环for epoch in track(range(num_epochs), description=f'wd={weight_decay}'):for X, y in dataloader_train:X, y = X.cuda(), y.cuda()loss = criterion(net(X), y)optimizer.zero_grad()loss.mean().backward()optimizer.step()writer.add_scalars(f'wd={weight_decay}', {'train_loss': evaluate_loss(dataloader_train, net, criterion),'test_loss': evaluate_loss(dataloader_test, net, criterion),}, epoch)for weight_decay in [0, 3]:loop(weight_decay)writer.close()
  • 输出结果
    • weight_decay = 0
      0
    • weight_decay = 3
      3

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.rhkb.cn/news/211732.html

如若内容造成侵权/违法违规/事实不符,请联系长河编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

【ArcGIS微课1000例】0078:创建点、线、面数据的最小几何边界

本实例为专栏系统文章:讲述在ArcMap10.6中创建点数据最小几何边界(范围),配套案例数据,持续同步更新! 文章目录 一、工具介绍二、实战演练三、注意事项一、工具介绍 创建包含若干面的要素类,用以表示封闭单个输入要素或成组的输入要素指定的最小边界几何。 工具位于:数…

HarmonyOS开发(十):通知

1、通知概述 1.1、简介 应用可以通过通知接口发送通知消息,终端用户可以通过通知栏查看通知内容,也可以点击通知来打开应用。 通知使用的的常见场景: 显示接收到的短消息、即使消息...显示应用推送消息显示当前正在进行的事件&#xff0c…

聚观早报 |东方甄选将上架文旅产品;IBM首台模块化量子计算机

【聚观365】12月6日消息 东方甄选将上架文旅产品 IBM首台模块化量子计算机 新思科技携手三星上新兴领域 英伟达与软银推动人工智能研发 苹果对Vision Pro供应商做出调整 东方甄选将上架文旅产品 东方甄选宣布12月10日将在东方甄选APP上线文旅产品,受这一消息影…

软件工程之需求分析

一、对需求的基本认识 1.需求分析简介 (1)什么是需求 用户需求:由用户提出。原始的用户需求通常是不能直接做成产品的,需要对其进行分析提炼,最终形成产品需求。 产品需求:产品经理针对用户需求提出的解决方案。 (2)为什么要…

Web前端JS如何获取 Video/Audio 视音频声道(左右声道|多声道)、视音频轨道、音频流数据

写在前面: 根据Web项目开发需求,需要在H5页面中,通过点击视频列表页中的任意视频进入视频详情页,然后根据视频的链接地址,主要是 .mp4 文件格式,在进行播放时实时的显示该视频的音频轨道情况,并…

短视频购物系统源码:构建创新购物体验的技术深度解析

短视频购物系统作为电商领域的新宠,其背后的源码实现是其成功的关键。本文将深入探讨短视频购物系统的核心技术和源码设计,以揭示其如何构建创新购物体验的技术奥秘。 1. 技术架构与框架选择 短视频购物系统的源码首先考虑的是其技术架构。常见的选择…

多传感器融合SLAM在自动驾驶方向的初步探索的记录

1. VIO的不可观问题 现有的VIO都是解决的六自由度的问题, 但是对于行驶在路面上的车来说, 通常情况下不会有roll与z方向的自由度, 而且车体模型限制了不可能有纯yaw的变换. 同时由于IMU在Z轴上与roll, pitch上激励不足, 会导致IMU在初始化过程中尺度不准以及重力方向估计错误,…

华为数通---BFD多跳检测示例

定义 双向转发检测BFD(Bidirectional Forwarding Detection)是一种全网统一的检测机制,用于快速检测、监控网络中链路或者IP路由的转发连通状况。 目的 为了减小设备故障对业务的影响,提高网络的可靠性,网络设备需要…

User: zhangflink is not allowed to impersonate zhangflink

使用hive2连接进行添加数据是报错: [08S01][1] Error while processing statement: FAILED: Execution Error, return code 1 from org.apache.hadoop.hive.ql.exec.mr.MapRedTask. User: zhangflink is not allowed to impersonate zhangflink 有些文章说需要修…

App的测试,和传统软件测试有哪些区别?应该增加哪些方面的测试用例?

从上图可知,测试人员所测项目占比中,App测试占比是最高的。 这就意味着学习期间,我们要花最多的精力去学App的各类测试。也意味着我们找工作前,就得知道,App的测试点是什么,App功能我们得会测试&#xff0…

2023 IoTDB 用户大会成功举办,深入洞察工业互联网数据价值

2023 年 12 月 3 日,中国通信学会作为指导单位,Apache IoTDB Community、清华大学软件学院、中国通信学会开源技术委员会联合主办,“科创中国”开源产业科技服务团和天谋科技(北京)有限公司承办的 2023 IoTDB 用户大会…

学习极市开发平台

这是官网的链接:极市开发者平台-计算机视觉算法开发落地平台-极市科技 (cvmart.net) 第一次用这个平台有很多问题,首先在使用这个平台之前,我大部分时候使用的是百度的飞浆平台,也就是BML,去训练一些深度学习的模型。 …

配置端口安全示例

组网需求 如图1所示,用户PC1、PC2、PC3通过接入设备连接公司网络。为了提高用户接入的安全性,将接入设备Switch的接口使能端口安全功能,并且设置接口学习MAC地址数的上限为接入用户数,这样其他外来人员使用自己带来的PC无法访问公…

AI 绘画 | Stable Diffusion 动漫人物真人化

前言 如何让一张动漫人物变成真实系列人物?Stable Diffusion WebUI五步即可实现。快来使用AI绘画打开异世界的大门吧!!! 动漫真人化 首先在图生图里上传一张二次元动漫人物图片,然后选择一个真实系人物画风的大模型,最后点击DeepBooru 反推,自动填充提示词,调整重绘…

vue中实现数字+英文字母组合键盘

完整代码 <template><div class"login"><div click"setFileClick">欢迎使用员工自助终端</div><el-dialog title"初始化设置文件打印消耗品配置密码" :visible.sync"dialogSetFile" width"600px&quo…

C语言之联合和枚举

C语言之联合和枚举 文章目录 C语言之联合和枚举1. 联合体1.1 联合体的声明1.2 联合体的特点1.3 结构体和联合体对比1.4 联合体大小的计算1.5 联合体小练习 2. 枚举2.1 枚举类型的声明2.2 枚举类型的优点2.3 枚举类型的使用 1. 联合体 1.1 联合体的声明 像结构体⼀样&#xff…

苹果OS X系统介绍(Mac OS --> Mac OS X --> OS X --> macOS)

文章目录 OS X系统介绍历史与版本架构内核与低级系统图形&#xff0c;媒体和用户界面应用程序和服务 特性用户友好强大的命令行安全性集成与互操作性 总结 OS X系统介绍 OS X是由苹果公司为Macintosh计算机系列设计的基于UNIX的操作系统。其界面友好&#xff0c;易于使用&…

CleanMyMac x4.15软件应用程序永久使用

许多刚从Windows系统转向Mac系统怀抱的用户&#xff0c;一开始难免不习惯&#xff0c;因为Mac系统没有像Windows一样的C盘、D盘&#xff0c;分盘分区明显。因此这也带来了一些问题&#xff0c;关于Mac的磁盘的清理问题&#xff0c;怎么进行清理&#xff1f;怎么确保清理的干净&…

Linux系统调试课:网络性能工具总结

文章目录 一、网络性能指标二、netstat三、route四、iptables沉淀、分享、成长,让自己和他人都能有所收获!😄 📢本篇章一起了解下网络性能工具。 一、网络性能指标 从网络性能指标出发,你更容易把性能工具同系统工作原理关联起来,对性能问题有宏观的认识和把握。这样,…

学习记录---Kubernetes的资源指标管道-metrics api的安装部署

一、简介 Metrics API&#xff0c;为我们的k8s集群提供了一组基本的指标(资源的cpu和内存)&#xff0c;我们可以通过metrics api来对我们的pod开展HPA和VPA操作(主要通过在pod中对cpu和内存的限制实现动态扩展)&#xff0c;也可以通过kubectl top的方式&#xff0c;获取k8s中n…