【pytorch】weight_norm和spectral_norm

apply_parametrization_normspectral_norm是 PyTorch 中用于对模型参数进行规范化的方法,但它们在实现和使用上有显著的区别。以下是它们的主要区别和对比:

实现方式

weight_norm:

weight_norm 是一种参数重参数化技术,将权重分解为两个部分:方向(v) 和 大小(g)。
具体来说,权重 w 被重参数化为:
w=g⋅
∥v∥
v

其中,g 是标量,表示权重的大小;v 是向量,表示权重的方向。
这种方法通过分离权重的大小和方向,使得优化过程更加稳定。

归一化后的权重向量实际上是 v 的归一化形式,即 v_normalized = v / ||v||,而 weight 的值为 g * v_normalized

spectral_norm:

spectral_norm 是一种基于谱范数的规范化方法,谱范数定义为矩阵 M 的最大奇异值:

在这里插入图片描述

具体来说,谱范数是矩阵 M 作用在单位向量上时的最大放大因子。

作用:通过限制矩阵的最大奇异值,控制矩阵的放大能力,从而提高模型的稳定性和泛化能力

通过幂迭代法(Power Iteration)计算矩阵的最大奇异值。具体步骤如下:
初始化两个向量 u 和 v。
迭代计算:
在这里插入图片描述

最大奇异值 在这里插入图片描述

然后将权重规范化为
在这里插入图片描述

使用场景

weight_norm:

主要用于规范化权重,特别适用于需要控制权重大小的场景。
例如,在某些生成模型或自注意力机制中,权重的大小对模型的稳定性和性能有重要影响。
weight_norm 提供了一种简单且直接的方式来实现权重的规范化。

spectral_norm:

常用于生成对抗网络(GAN)中,通过限制生成器和判别器的最大奇异值,提高模型的稳定性和泛化能力。
优点:能够有效控制矩阵的放大能力,适用于需要限制模型输出范围的场景

代码样例

weight_norm:

import random
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.utils.weight_norm as weight_norm# 创建一个简单的线性层
linear_layer = nn.Linear(2, 3)# Override the weights of linear_layer
linear_layer.weight.data = torch.tensor([[0, -2e-2], [0.3, 1.0], [1e-2, 0]], dtype=torch.float32)  # 3x2
linear_layer.bias.data = torch.tensor([-0.3, 0, -2e-2], dtype=torch.float32)  # 3x1# Apply weight normalization
linear_layer = weight_norm(linear_layer, name='weight')optimizer = optim.Adam(linear_layer.parameters(), lr=2e-3)# 打印原始权重
print("################原始权重:################")
print("linear_layer.weight_v: ", linear_layer.weight_v)
print("linear_layer.weight_g: ", linear_layer.weight_g)
print("linear_layer.weight: ", linear_layer.weight)
print("linear_layer.bias: ", linear_layer.bias)# 前向传播
for i in range(1000):input_tensor = torch.randn(10, 2) * random.random() * 10output = linear_layer(input_tensor)loss = (1 - output.mean())optimizer.zero_grad()loss.backward()  # 反向传播,触发梯度计算optimizer.step()# 打印规范化后的权重
print("################规范化后的权重:#################")
print("linear_layer.weight_v: ", linear_layer.weight_v)
print("linear_layer.weight_g: ", linear_layer.weight_g)
print("linear_layer.bias: ", linear_layer.bias)# Optionally, print the effective weight
print("linear_layer.weight (effective): ", linear_layer.weight)# ... existing code...# 前向传播
#input_tensor = torch.tensor([[1.0, 2.0]], dtype=torch.float32)  # 1x2
for i in range(1000):input_tensor = torch.randn(10, 2)*random.random()*10output = linear_layer(input_tensor)optimizer.zero_grad()(1-output.mean()).backward()  # 反向传播,触发梯度计算optimizer.step()  # 更新权重output = linear_layer(input_tensor)optimizer.zero_grad()(1-output.mean()).backward()  # 反向传播,触发梯度计算optimizer.step()  # 更新权重# 打印规范化后的权重
print("################规范化后的权重:#################")
print("linear_layer.weight: ", linear_layer.weight)  # 权重的大小
print("linear_layer.bias: ", linear_layer.bias) 
print("linear_layer.weight_v: ", linear_layer.weight_v)  
print("linear_layer.weight_g: ", linear_layer.weight_g) 

输出结果

################原始权重:################
linear_layer.weight_v:  Parameter containing:
tensor([[ 0.0000, -0.0200],[ 0.3000,  1.0000],[ 0.0100,  0.0000]], requires_grad=True)
linear_layer.weight_g:  Parameter containing:
tensor([[0.0200],[1.0440],[0.0100]], requires_grad=True)
linear_layer.weight:  tensor([[ 0.0000, -0.0200],[ 0.3000,  1.0000],[ 0.0100,  0.0000]], grad_fn=<WeightNormInterfaceBackward0>)
linear_layer.bias:  Parameter containing:
tensor([-0.3000,  0.0000, -0.0200], requires_grad=True)
################规范化后的权重:#################
linear_layer.weight_v:  Parameter containing:
tensor([[0.0599, 0.0568],[0.4113, 0.8878],[0.0544, 0.0556]], requires_grad=True)
linear_layer.weight_g:  Parameter containing:
tensor([[0.1038],[1.1324],[0.0897]], requires_grad=True)
linear_layer.bias:  Parameter containing:
tensor([1.7000, 2.0000, 1.9800], requires_grad=True)
linear_layer.weight (effective):  tensor([[0.0751, 0.0716],[0.4758, 1.0276],[0.0625, 0.0643]], grad_fn=<WeightNormInterfaceBackward0>)

spectral_norm:

import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.utils.weight_norm as spectral_norm# 创建一个简单的线性层
linear_layer = nn.Linear(2, 3)
# 覆盖线性层的权重
linear_layer.weight.data = torch.tensor([[0, -2e2], [0.3, 1.0], [1e2, 0]], dtype=torch.float32)  # 3x2
linear_layer.bias.data = torch.tensor([-0.3, 0, -2e2], dtype=torch.float32)  # 3x1# 应用 weight_norm
linear_layer = spectral_norm(linear_layer, name='weight')# 定义优化器
optimizer = optim.Adam(linear_layer.parameters(), lr=1e-4)# 打印原始权重
print("\n################ 原始权重:################\n")
print("linear_layer.weight: ", linear_layer.weight)
print("linear_layer.bias: ", linear_layer.bias)# 前向传播及训练
for i in range(1000):input_tensor = torch.randn(10, 2)output = linear_layer(input_tensor)loss = (1 - output.mean())  # 损失函数optimizer.zero_grad()loss.backward()  # 反向传播optimizer.step()  # 更新权重# 打印规范化后的权重
print("\n################ 规范化后的权重:################\n")
print("linear_layer.weight: ", linear_layer.weight)  # 权重的大小
print("linear_layer.bias: ", linear_layer.bias)
print("linear_layer.weight_v: ", linear_layer.weight_v)  
print("linear_layer.weight_g: ", linear_layer.weight_g)

输出结果

################ 原始权重:################linear_layer.weight:  tensor([[   0.0000, -200.0000],[   0.3000,    1.0000],[ 100.0000,    0.0000]], grad_fn=<WeightNormInterfaceBackward0>)
linear_layer.bias:  Parameter containing:
tensor([  -0.3000,    0.0000, -200.0000], requires_grad=True)################ 规范化后的权重:################linear_layer.weight:  tensor([[-1.7618e-04, -2.0000e+02],[ 2.9977e-01,  9.9956e-01],[ 1.0000e+02, -3.6516e-04]], grad_fn=<WeightNormInterfaceBackward0>)
linear_layer.bias:  Parameter containing:
tensor([-2.0001e-01,  1.0000e-01, -1.9989e+02], requires_grad=True)
linear_layer.weight_v:  Parameter containing:
tensor([[-2.0177e-04, -2.0001e+02],[ 2.9990e-01,  1.0001e+00],[ 1.0001e+02, -3.7132e-04]], requires_grad=True)
linear_layer.weight_g:  Parameter containing:
tensor([[200.0006],[  1.0435],[ 99.9998]], requires_grad=True)

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.rhkb.cn/news/17314.html

如若内容造成侵权/违法违规/事实不符,请联系长河编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

回归预测 | Matlab实现PSO-HKELM粒子群算法优化混合核极限学习机多变量回归预测

回归预测 | Matlab实现PSO-HKELM粒子群算法优化混合核极限学习机多变量回归预测 目录 回归预测 | Matlab实现PSO-HKELM粒子群算法优化混合核极限学习机多变量回归预测效果一览基本介绍程序设计参考资料 效果一览 基本介绍 1.回归预测 | Matlab实现PSO-HKELM粒子群算法优化混合核…

多媒体软件安全与授权新范例,用 CodeMeter 实现安全、高效的软件许可管理

背景概述 Reason Studios 成立于 1994 年&#xff0c;总部位于瑞典斯德哥尔摩&#xff0c;是全球领先的音乐制作软件开发商。凭借创新的软件产品和行业标准技术&#xff0c;如 ReWire 和 REX 文件格式&#xff0c;Reason Studios 为全球专业音乐人和业余爱好者提供了一系列高质…

C++,STL容器适配器,stack:栈深入解析

文章目录 一、容器概览与核心特性核心特性速览二、底层实现原理1. 容器适配器设计2. 默认容器对比三、核心操作详解1. 容器初始化2. 元素操作接口3. 自定义栈实现四、实战应用场景1. 括号匹配校验2. 浏览器历史记录管理五、性能优化策略1. 底层容器选择基准2. 内存预分配技巧六…

互联网大厂中面试的高频计算机网络问题及详解

前言 哈喽各位小伙伴们,本期小梁给大家带来了互联网大厂中计算机网络部分的高频面试题,本文会以通俗易懂的语言以及图解形式描述,希望能给大家的面试带来一点帮助,祝大家offer拿到手软!!! 话不多说,我们立刻进入本期正题! 一、计算机网络基础部分 1 …

「软件设计模式」工厂方法模式 vs 抽象工厂模式

前言 在软件工程领域&#xff0c;设计模式是解决常见问题的经典方案。本文将深入探讨两种创建型模式&#xff1a;工厂方法模式和抽象工厂模式&#xff0c;通过理论解析与实战代码示例&#xff0c;帮助开发者掌握这两种模式的精髓。 一、工厂方法模式&#xff08;Factory Metho…

Docker部署Alist网盘聚合管理工具完整教程

Docker部署Alist网盘聚合管理工具完整教程 部署alist初始化修改密码添加存储&#xff01;联通网盘阿里云盘百度网盘 部署alist 本文以Linux Docker部署&#xff0c;假设你已经安装好Docker docker run -d --restartalways \-v /your/data:/opt/alist/data \-p 5244:5244 \-e …

Excel常用操作

Excel常用操作 学习资源 37_电子表格处理考点精讲_设置数据格式_哔哩哔哩_bilibili 快速输入数据与编辑数据 一个工作簿可以包含多个工作表 特殊数据的添加格式 输入负数, 例如-3、-5 常规输入, 直接输入-3、-5;使用(), 例如在单元格中输入(3)回车即可变为-3;上述括号不区分中…

SpringMVC环境搭建

文章目录 1.模块创建1.创建一个webapp的maven项目2.目录结构 2.代码1.HomeController.java2.home.jsp3.applicationContext.xml Spring配置文件4.spring-mvc.xml SpringMVC配置文件5.web.xml 配置中央控制器以及Spring和SpringMVC配置文件的路径6.index.jsp 3.配置Tomcat1.配置…

常见的排序算法:插入排序、选择排序、冒泡排序、快速排序

1、插入排序 步骤&#xff1a; 1.从第一个元素开始&#xff0c;该元素可以认为已经被排序 2.取下一个元素tem&#xff0c;从已排序的元素序列从后往前扫描 3.如果该元素大于tem&#xff0c;则将该元素移到下一位 4.重复步骤3&#xff0c;直到找到已排序元素中小于等于tem的元素…

Golang的容器化部署流程

# Golang的容器化部署流程 什么是容器化部署 容器化部署是将应用程序、运行环境及其依赖项打包在一起&#xff0c;以便可以在任何环境中快速、一致地运行的技术。它提供了更高效的资源利用、更便捷的部署和更稳定的环境。 的容器化支持 天生支持跨平台编译&#xff0c;使得将Go…

前缀树算法篇:前缀信息的巧妙获取

前缀树算法篇&#xff1a;前缀信息的巧妙获取 那么前缀树算法是一个非常常用的算法&#xff0c;那么在介绍我们前缀树具体的原理以及实现上&#xff0c;我们先来说一下我们前缀树所应用的一个场景&#xff0c;那么在一个字符串的数据集合当中&#xff0c;那么我们查询我们某个字…

tomcat html乱码

web tomcat html中文乱码 将html文件改成jsp <% page language"java" contentType"text/html; charsetUTF-8" pageEncoding"UTF-8"%>添加 <meta charset"UTF-8">

安全测试|SSRF请求伪造

前言 SSRF漏洞是一种在未能获取服务器权限时&#xff0c;利用服务器漏洞&#xff0c;由攻击者构造请求&#xff0c;服务器端发起请求的安全漏洞&#xff0c;攻击者可以利用该漏洞诱使服务器端应用程序向攻击者选择的任意域发出HTTP请求。 很多Web应用都提供了从其他的服务器上…

【笛卡尔树】

笛卡尔树 笛卡尔树定义构建性质 习题P6453 [COCI 2008/2009 #4] PERIODNICF1913D Array CollapseP4755 Beautiful Pair[ARC186B] Typical Permutation Descriptor 笛卡尔树 定义 笛卡尔树是一种二叉树&#xff0c;每一个节点由一个键值二元组 ( k , w ) (k,w) (k,w) 构成。要…

java测试题

String str2 "he""llo" xx.java--->xx.class----->内存 在由.java文件通过javac命令变为.class文件的过程中已经自动拼接变为“hello” String str2 "he"new String"llo" 在编译为,class文件时还是两个字符串“he”和“llo”…

SQLite 数据库:优点、语法与快速入门指南

文章目录 一、引言二、SQLite 的优点 &#x1f4af;三、SQLite 的基本语法3.1 创建数据库3.2 创建表3.3 插入数据3.4 查询数据3.5 更新数据3.6 删除数据3.7 删除表 四、快速入门指南4.1 安装 SQLite4.2 创建数据库4.3 创建表4.4 插入数据4.5 查询数据4.6 更新数据4.7 删除数据4…

无人机之无线传输技术!

一、Lightbridge和OcuSync图传技术 Lightbridge技术&#xff1a;这是大疆自主研发的一种专用通信链路技术&#xff0c;使用单向图像数据传输&#xff0c;类似于电视广播塔的数据传输形式。它主要采用2.4GHz频段进行传输&#xff0c;并且可以实现几乎“零延时”的720p高清图像传…

逻辑分析仪的使用-以STM32C8T6控制SG90舵机为例

STM32C8T6控制SG90舵机 1.逻辑分析仪作用 逻辑分析仪在嵌入式开发中的作用非常重要&#xff0c;它是开发、调试和排错过程中的一个不可或缺的工具。具体来说&#xff0c;逻辑分析仪的作用包括以下几个方面&#xff1a; 1.信号捕获和分析&#xff1a; 逻辑分析仪能够实时捕获多个…

线性代数 第七讲 二次型_标准型_规范型_坐标变换_合同_正定二次型详细讲解_重难点题型总结

文章目录 1.二次型1.1 二次型、标准型、规范型、正负惯性指数、二次型的秩1.2 坐标变换1.3 合同1.4 正交变换化为标准型1.5 可逆线性变换和正交变换1.6 二次型化标准形&#xff0c;二次型化规范形的联系思考1.8 两个二次型联系的思考1.9 对于配方法问题的深入思考 2.二次型的主…

vue学习9

1.文章分类页面-element-plus表格 基本架子-PageContainer封装 按需引入的彩蛋&#xff0c;components里面的内容都会自动注册 用el-card组件&#xff0c;里面使用插槽或具名插槽 文章分类渲染 & loading处理 序号&#xff1a; <el-table-column type"index"…