7.5 详解批量规范化 对某个维度取平均值代码解读

一.举例计算均值、方差

假设我们有以下一组数据:[10, 15, 20, 25, 30]首先,我们计算均值,即将所有数据相加后除以数据的数量:
**均值** = (10 + 15 + 20 + 25 + 30) / 5 = 100 / 5 = 20

1.1标准差

接下来,我们计算标准差,用来衡量数据的离散程度。标准差的计算公式如下:
**标准差** = sqrt( ( (x1 - 平均值)^2 + (x2 - 平均值)^2 + ... + (xn - 平均值)^2 ) / n )标准差 = sqrt( ( (10 - 20)^2 + (15 - 20)^2 + (20 - 20)^2 + (25 - 20)^2 + (30 - 20)^2 ) / 5 )= sqrt( (100 + 25 + 0 + 25 + 100) / 5 )= sqrt( 250 / 5 )= sqrt( 50 )7.071

二.对例子中的数据标准化

现在,让我们对数据进行标准化。标准化是将数据转换为均值为0,标准差为1的标准正态分布。
对于每个数据点,我们可以使用以下公式进行标准化:

2.1公式

标准化数据 = (原始数据 - 均值) / 标准差

对于我们的数据集,标准化后的结果如下:
(10 - 20) / 7.071-1.414
(15 - 20) / 7.071-0.707
(20 - 20) / 7.071 = 0
(25 - 20) / 7.0710.707
(30 - 20) / 7.0711.414因此,经过标准化后的数据集为:[-1.414, -0.707, 0, 0.707, 1.414]

2.2 标准化后的数据特点

均值为0,方差为1

2.3 将数据标准化有什么好处?

  1. 消除数据相差太大的影响:标准化后的数据具有相同的量纲,消除了原始数据中不同变量之间因量纲不同而引起的影响,确保各个变量对分析结果的贡献相对均等。

  2. 提高模型性能:在许多机器学习算法中,输入数据的标准化可以提高模型的训练效果和预测准确性。标准化后的数据分布更接近于标准正态分布,可以降低模型对异常值的敏感性,使模型更加稳定和可靠。

  3. 加速优化过程:某些优化算法(如梯度下降法)在处理标准化后的数据时更加高效。标准化可以使优化算法收敛更快,并且更容易找到全局最优解或更接近最优解的解。

  4. 减少计算开销:标准化后的数据具有较小的数值范围,可以减少计算时的数据溢出或欠溢问题,提高计算的稳定性和准确性。

2.4 为什么数据可以标准化?可视化标准化效果

** BatchNorm(归一化/标准化)**

归一化/标准化实质是一种线性变换,线性变换有很多良好的性质,这些性质决定了对数据改变后不会造成“失效”,反而能提高数据的表现,这些性质是归一化/标准化的前提。比如有一个很重要的性质:线性变换不会改变原始数据的数值排序。

标准化前:
在这里插入图片描述
标准化后:
在这里插入图片描述
看到变化了吗,虽然各个点的相对位置看上去还是没变,但是坐标轴变了。均值是0,标准差为1。

参考文献:
数据预处理:标准化和归一化
https://zhuanlan.zhihu.com/p/63911364

三. 在模型中对小批次数据进行批量标准化

3.1 批量标准化公式

在这里插入图片描述

3.2 在模型的什么地方应用批量标准化?

  1. 在全连接层的激活函数之前

在这里插入图片描述

  1. 卷积层之后的非线性的激活函数之前

  2. 预测过程中的批量标准化
    我们需要对逐个样本预测,我们需要移动估算整个训练数据集的样
    本均值和⽅差。

四.代码

import torch
from torch import nn
from d2l import torch as d2l
import time
# eps是上面公式由于标准差在分母的位置,所以加入一个常量eps,防止分母为0
def batch_norm(X,gamma,beta,moving_mean,moving_var,eps,momentum):# 通过is_grad_enabled方法来判断当前模式是训练模式还是预测模式if not torch.is_grad_enabled():# 如果是预测模式,直接使用传入的移动平均所得的均值和方差X_hat = (X-moving_mean) / torch.sqrt(moving_var+eps)else:assert len(X.shape) in (2,4)if len(X.shape)==2: # 使用全连接层的情况,计算特征维上的均值和方差# 对每行求平均值mean = X.mean(dim=0)var = ((X-mean)**2).mean(dim=0)else:# 使用二维卷积层的情况,计算通道维上(axis=1)的均值和方差。这里我们需要保持X的形状以便后面可以做广播运算mean = X.mean(axis=(0, 2, 3), keepdims=True)var = ((X - mean) ** 2).mean(axis=(0, 2, 3), keepdims=True)# 训练模式下用当前的均值和方差做标准化X_hat = (X - mean) / torch.sqrt(var + eps)moving_mean = momentum * moving_mean + (1.0 - momentum) * meanmoving_var = momentum * moving_var + (1.0 - momentum) * varY = gamma * X_hat + beta # 缩放和移位return Y, moving_mean.data, moving_var.data

BatchNorm类 & BatchNorm(卷积层的输出通道数, num_dims=2或4)

BatchNorm(卷积层的输出通道数, num_dims=24),由于BatchNorm只出现在全连接层的激活函数之前和卷积层的激活函数之前,所以规定2表示全连接层,4表示卷积层。
class BatchNorm(nn.Module):
# num_features:完全连接层的输出数量或卷积层的输出通道数。
# num_dims:2表示完全连接层,4表示卷积层def __init__(self, num_features, num_dims):super().__init__()if num_dims == 2:shape = (1, num_features)else:shape = (1, num_features, 1, 1)# 参与求梯度和迭代的拉伸和偏移参数,分别初始化成1和0self.gamma = nn.Parameter(torch.ones(shape))self.beta = nn.Parameter(torch.zeros(shape))# 非模型参数的变量初始化为0和1self.moving_mean = torch.zeros(shape)self.moving_var = torch.ones(shape)def forward(self, X):# 如果X不在内存上,将moving_mean和moving_var# 复制到X所在显存上if self.moving_mean.device != X.device:self.moving_mean = self.moving_mean.to(X.device)self.moving_var = self.moving_var.to(X.device)# 保存更新过的moving_mean和moving_varY, self.moving_mean, self.moving_var = batch_norm(X, self.gamma, self.beta, self.moving_mean,self.moving_var, eps=1e-5, momentum=0.9)return Y

sigmoid激活函数版本模型

net = nn.Sequential(nn.Conv2d(1, 6, kernel_size=5), BatchNorm(6, num_dims=4), nn.Sigmoid(), nn.AvgPool2d(kernel_size=2, stride=2),nn.Conv2d(6, 16, kernel_size=5), BatchNorm(16, num_dims=4), nn.Sigmoid(),nn.AvgPool2d(kernel_size=2, stride=2), nn.Flatten(),nn.Linear(16*4*4, 120), BatchNorm(120, num_dims=2), nn.Sigmoid(),nn.Linear(120, 84), BatchNorm(84, num_dims=2), nn.Sigmoid(),nn.Linear(84, 10)
)

库中的训练函数 train_ch6 没有取最优的准确率,自己实现一个

def train_ch6(net, train_iter, test_iter, num_epochs, lr, device):"""Train a model with a GPU (defined in Chapter 6).Defined in :numref:`sec_lenet`"""def init_weights(m):if type(m) == nn.Linear or type(m) == nn.Conv2d:nn.init.xavier_uniform_(m.weight)net.apply(init_weights)print('training on', device)net.to(device)optimizer = torch.optim.SGD(net.parameters(), lr=lr)loss = nn.CrossEntropyLoss()animator = d2l.Animator(xlabel='epoch', xlim=[1, num_epochs],legend=['train loss', 'train acc', 'test acc'])timer, num_batches = d2l.Timer(), len(train_iter)best_test_acc = 0for epoch in range(num_epochs):# Sum of training loss, sum of training accuracy, no. of examplesmetric = d2l.Accumulator(3)net.train()for i, (X, y) in enumerate(train_iter):timer.start()optimizer.zero_grad()X, y = X.to(device), y.to(device)y_hat = net(X)l = loss(y_hat, y)l.backward()optimizer.step()with torch.no_grad():metric.add(l * X.shape[0], d2l.accuracy(y_hat, y), X.shape[0])timer.stop()train_l = metric[0] / metric[2]train_acc = metric[1] / metric[2]if (i + 1) % (num_batches // 5) == 0 or i == num_batches - 1:animator.add(epoch + (i + 1) / num_batches,(train_l, train_acc, None))test_acc = d2l.evaluate_accuracy_gpu(net, test_iter)if test_acc>best_test_acc:best_test_acc = test_accanimator.add(epoch + 1, (None, None, test_acc))print(f'loss {train_l:.3f}, train acc {train_acc:.3f}, 'f'test acc {test_acc:.3f}, best test acc {best_test_acc:.3f}')# 取的好像是平均准备率print(f'{metric[2] * num_epochs / timer.sum():.1f} examples/sec 'f'on {str(device)}')

开始训练

'''开始计时'''
start_time = time.time()# 配置参数
lr, num_epochs, batch_size = 1.0, 10, 256
train_iter, test_iter = d2l.load_data_fashion_mnist(batch_size)
train_ch6(net, train_iter, test_iter, num_epochs, lr, d2l.try_gpu())'''计时结束'''
end_time = time.time()
run_time = end_time - start_time
# 将输出的秒数保留两位小数
if int(run_time)<60:print(f'{round(run_time,2)}s')
else:print(f'{round(run_time/60,2)}minutes')

sigmoid激活函数版本结果

在这里插入图片描述

将sigmoid换成relu的版本结果

# 换成Relu()
net = nn.Sequential(nn.Conv2d(1, 6, kernel_size=5), BatchNorm(6, num_dims=4), nn.ReLU(),nn.AvgPool2d(kernel_size=2, stride=2),nn.Conv2d(6, 16, kernel_size=5), BatchNorm(16, num_dims=4), nn.ReLU(),nn.AvgPool2d(kernel_size=2, stride=2), nn.Flatten(),nn.Linear(16*4*4, 120), BatchNorm(120, num_dims=2), nn.ReLU(),nn.Linear(120, 84), BatchNorm(84, num_dims=2), nn.Sigmoid(),nn.Linear(84, 10)
)

在这里插入图片描述

将sigmoid激活函数换成relu和把平均池化AvgPool2d换成最大值池化MaxPool2d版本结果

学习率用的是1

# 换成Relu()+最大值池化
net = nn.Sequential(nn.Conv2d(1, 6, kernel_size=5), BatchNorm(6, num_dims=4), nn.ReLU(),nn.MaxPool2d(kernel_size=2, stride=2),nn.Conv2d(6, 16, kernel_size=5), BatchNorm(16, num_dims=4), nn.ReLU(),nn.MaxPool2d(kernel_size=2, stride=2), nn.Flatten(),nn.Linear(16*4*4, 120), BatchNorm(120, num_dims=2), nn.ReLU(),nn.Linear(120, 84), BatchNorm(84, num_dims=2), nn.Sigmoid(),nn.Linear(84, 10)

)在这里插入图片描述

保持如上网络不变,学习率换成1.3的结果

在这里插入图片描述

五.补充

5.1 方差和标准差公式

在这里插入图片描述

5.2 对某个维度取平均值

X.mean(dim=0),表示对每个列取平均值,保留行

import torch
# X = torch.rand(size=(2,2))
X = torch.tensor([[1.0, 1.0],[-1.0, -1.0]])
mean = X.mean(dim=0)
mean,mean.shape

输出结果:

(tensor([0., 0.]), torch.Size([2]))

X.mean(dim=1),表示对每个行取平均值,保留列

mean = X.mean(dim=1)
mean,mean.shape
(tensor([ 1., -1.]), torch.Size([2]))

X.mean(dim=(0,2,3),keepdim=True),表示对维度0(样本),维度2(行数),维度3(列数)求平均。保留维度1(通道)
keepdim=True,表示保留维度

X = torch.tensor([[[[1.0, 2.0],[0.0, 4.0]]]])
X.shape
# 按通道数求均值就是 把二维矩阵求和/矩阵大小。 如上方 1+2+0+4=7,7/4=1.75
mean = X.mean(dim=(0,2,3),keepdim=True)
mean,mean.shape
(tensor([[[[1.7500]]]]), torch.Size([1, 1, 1, 1]))

怎么理解使用大小为 1 的小批量应用批量归一化,我们将无法学到任何东西

1.请注意,如果我们尝试使用大小为 1 的小批量应用批量归一化,我们将无法学到任何东西。 这是因为在减去均值之后,每个隐藏单元将为 0。 所以,只有使用足够大的小批量,批量归一化这种方法才是有效且稳定的。 请注意,在应用批量归一化时,批量大小的选择可能比没有批量归一化时更重要”,请问怎么理解呢?这里的批量请问是网络训练的batch吗?

指batch_size,notebook上设置为256。如果batch_size为1,那么这个小批量的均值就是这个样本值本身,则样本值减去均值就为0了,也即文中所说的”每个隐藏单元将为 0“,所以batch_size要足够大,但也不要太大

参考链接:https://blog.csdn.net/qq_60567866/article/details/125608162

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.rhkb.cn/news/83496.html

如若内容造成侵权/违法违规/事实不符,请联系长河编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

笔记本电脑如何把sd卡数据恢复

在使用笔记本电脑过程中&#xff0c;如果不小心将SD卡里面的重要数据弄丢怎么办呢&#xff1f;别着急&#xff0c;本文将向您介绍SD卡数据丢失常见原因和恢复方法。 ▌一、SD卡数据丢失常见原因 - 意外删除&#xff1a;误操作或不小心将文件或文件夹删除。 - 误格式化&#…

跨境干货|TikTok变现的9种方法

在这个流量为王的时代&#xff0c;哪里有流量&#xff0c;哪里就有商机。TikTok作为近几年最火爆的社媒平台之一&#xff0c;在全球范围都具有一定的影响力。随着TikTok Shop等商务功能加持上线&#xff0c;更是称为跨境电商的新主场之一。 在这样的UGC平台&#xff0c;想要变…

适配器模式-java实现

意图 复用已经存在的接口&#xff0c;与所需接口不一致的类。即将一个类&#xff08;通常是旧系统中的功能类&#xff09;&#xff0c;通过适配器转化成另一个接口的实现。&#xff08;简单来说&#xff0c;就是复用旧系统的功能&#xff0c;去实现新的接口&#xff09; 我们举…

API 测试 | 了解 API 接口概念|电商平台 API 接口测试指南

什么是 API&#xff1f; API 是一个缩写&#xff0c;它代表了一个 pplication P AGC 软件覆盖整个房间。API 是用于构建软件应用程序的一组例程&#xff0c;协议和工具。API 指定一个软件程序应如何与其他软件程序进行交互。 例行程序&#xff1a;执行特定任务的程序。例程也称…

深度学习:使用卷积神经网络CNN实现MNIST手写数字识别

引言 本项目基于pytorch构建了一个深度学习神经网络&#xff0c;网络包含卷积层、池化层、全连接层&#xff0c;通过此网络实现对MINST数据集手写数字的识别&#xff0c;通过本项目代码&#xff0c;从原理上理解手写数字识别的全过程&#xff0c;包括反向传播&#xff0c;梯度…

【UE4 RTS】04-Camera Pan

前言 本篇实现了CameraPawn的旋转功能。 效果 步骤 1. 打开项目设置&#xff0c;添加两个操作映射 2. 打开玩家控制器“RTS_PlayerController_BP”&#xff0c;新建一个浮点型变量&#xff0c;命名为“PanSpeed” 在事件图表中添加如下节点 此时运行游戏可以发现当鼠标移动…

ReSharper C++ 2023 Crack

ReSharper C 2023 Crack ReSharper的AI助手会考虑项目中使用的语言和技术。这种上下文感知可以一开始就调整其响应&#xff0c;为您节省时间和精力。 您可以在查询中包含部分源代码。ReSharper将检测你发送或粘贴到聊天中的代码&#xff0c;并正确格式化&#xff0c;而人工智能…

【数据结构OJ题】合并两个有序数组

原题链接&#xff1a;https://leetcode.cn/problems/merge-sorted-array/ 目录 1. 题目描述 2. 思路分析 3. 代码实现 1. 题目描述 2. 思路分析 看到这道题&#xff0c;我们注意到nums1[ ]和nums2[ ]两个数组都是非递减的。所以我们很容易想到额外开一个数组tmp[ ]&#x…

重试框架入门:Spring-RetryGuava-Retry

前言 在日常工作中&#xff0c;随着业务日渐庞大&#xff0c;不可避免的涉及到调用远程服务&#xff0c;但是远程服务的健壮性和网络稳定性都是不可控因素&#xff0c;因此&#xff0c;我们需要考虑合适的重试机制去处理这些问题&#xff0c;最基础的方式就是手动重试&#xf…

C语言函数详解(1)

目录 函数是什么 C语言中函数的分类 库函数 自定义函数 函数的参数 实际参数&#xff08;实参&#xff09; 形式参数&#xff08;形参&#xff09; 函数的调用 传值调用 传址调用 练习 函数的嵌套调用和链式访问 嵌套调用 链式访问 函数是什么 数学中我们常见到函…

node笔记——调用免费qq的smtp发送html格式邮箱

文章目录 ⭐前言⭐smtp授权码获取⭐nodemailer⭐postman验证接口⭐结束 ⭐前言 大家好&#xff0c;我是yma16&#xff0c;本文分享关于node调用免费qq的smtp发送邮箱。 node系列往期文章 node_windows环境变量配置 node_npm发布包 linux_配置node node_nvm安装配置 node笔记_h…

从零实现深度学习框架——Transformer从菜鸟到高手(一)

引言 &#x1f4a1;本文为&#x1f517;[从零实现深度学习框架]系列文章内部限免文章&#xff0c;更多限免文章见 &#x1f517;专栏目录。 本着“凡我不能创造的&#xff0c;我就不能理解”的思想&#xff0c;系列文章会基于纯Python和NumPy从零创建自己的类PyTorch深度学习框…

js 正则表达式

js 正则表达式 http://tool.oschina.net/regex https://developer.mozilla.org/zh-CN/docs/Web/JavaScript/Guide/Regular_Expressions 11 22 333

磁盘的管理

一、磁盘的分区 查看磁盘 lsblk fdisk -l 2、分区 没有e扩展&#xff0c;则都是主分区&#xff0c;已经有三个主分区了&#xff0c;剩下的全设置为扩展 查看分区结果&#xff1a; 二、格式化 三、挂载

JVM、JRE、JDK三者之间的关系

JVM、JRE和JDK是与Java开发和运行相关的三个重要概念。 再了解三者之前让我们先来了解下java源文件的执行顺序&#xff1a; 使用编辑器或IDE(集成开发环境)编写Java源文件.即demo.java程序必须编译为字节码文件&#xff0c;javac(Java编译器)编译源文件为demo.class文件.类文…

Web-WebApp Vue.js 目录结构

WebApp Vue.js 目录结构 目录解析 目录/文件 说明 build 最终发布的代码存放位置。config 配置目录&#xff0c;包括端口号等。我们初学可以使用默认的。node_modules npm 加载的项目依赖模块 src 这里是我们要开发的目录&#xff0c;基本上要做的事情都在这个目录里。里面包…

Pycharm如何打断点进行调试?

断点调试&#xff0c;是编写程序中一个很重要的步骤&#xff0c;有些简单的程序使用print语句就可看出问题&#xff0c;而比较复杂的程序&#xff0c;函数和变量较多的情况下&#xff0c;这时候就需要打断点了&#xff0c;更容易定位问题。 一、添加断点 在代码的行标前面&…

ATF(TF-A)安全通告 TFV-6 (CVE-2017-5753, CVE-2017-5715, CVE-2017-5754)

ATF(TF-A)安全通告汇总 目录 一、ATF(TF-A)安全通告 TFV-6 (CVE-2017-5753, CVE-2017-5715, CVE-2017-5754) 二、Variant 1 (CVE-2017-5753) 三、Variant 2 (CVE-2017-5715) 四、Variant 3 (CVE-2017-5754) 一、ATF(TF-A)安全通告 TFV-6 (CVE-2017-5753, CVE-2017-5715, C…

15-矩阵转置的拓展延伸

&#x1f52e;矩阵的转置✨ 前言 在很多时候我们拿到的数据本身可能并不会把点的坐标按列的方向排列起来&#xff0c;对于我们人类来说&#xff0c;更方便的方式依然是把这个点的坐标按行的方向排列&#xff0c;我们比较熟悉把矩阵看作为一个数据&#xff0c;在这里&#xff0…

06-3_Qt 5.9 C++开发指南_多窗体应用程序的设计(主要的窗体类及其用途;窗体类重要特性设置;多窗口应用程序设计)

文章目录 1. 主要的窗体类及其用途2. 窗体类重要特性的设置2.1 setAttribute()函数2.2 setWindowFlags()函数2.3 setWindowState()函数2.4 setWindowModality()函数2.5 setWindowOpacity()函数 3. 多窗口应用程序设计3.1 主窗口设计3.2 QFormDoc类的设计3.3 QFormDoc类的使用3.…