动手学深度学习(Pytorch版)代码实践 -卷积神经网络-28批量规范化

28批量规范化

"""可持续加速深层网络的收敛速度"""
import torch
from torch import nn
import liliPytorch as lp
import matplotlib.pyplot as pltdef batch_norm(X, gamma, beta, moving_mean, moving_var, eps, momentum):"""实现一个具有张量的批量规范化层。"""# 如果启用了梯度计算,torch.is_grad_enabled() 返回 True;否则返回 False。if not torch.is_grad_enabled():# torch.no_grad() 是一个上下文管理器,用于临时禁用梯度计算# torch.enable_grad() 是一个上下文管理器,用于在禁用梯度计算的上下文中重新启用梯度计算。X_hat = (X - moving_mean) / torch.sqrt(moving_var + eps)else:assert len(X.shape) in (2, 4)if len(X.shape) == 2:# 使用全连接层的情况,计算特征维上的均值和方差mean = X.mean(dim=0) # 计算张量 X 沿着第 0 维的平均值# 维度 0 代表样本数量,即沿着每个特征计算所有样本的平均值。var = ((X - mean) ** 2).mean(dim=0)else:# 使用二维卷积层的情况,计算通道维上(axis=1)的均值和方差。# 这里我们需要保持X的形状以便后面可以做广播运算mean = X.mean(dim=(0, 2, 3), keepdim=True)var = ((X - mean) ** 2).mean(dim=(0, 2, 3), keepdim=True)# 训练模式下,用当前的均值和方差做标准化X_hat = (X - mean) / torch.sqrt(var + eps)# 更新移动平均的均值和方差moving_mean = momentum * moving_mean + (1.0 - momentum) * meanmoving_var = momentum * moving_var + (1.0 - momentum) * var# gamma 和 beta 的更新是通过反向传播和优化器自动完成的Y = gamma * X_hat + beta # 缩放和移位return Y, moving_mean.data, moving_var.dataclass BatchNorm(nn.Module):# num_features:完全连接层的输出数量或卷积层的输出通道数。# num_dims:2表示完全连接层,4表示卷积层def __init__(self, num_features, num_dims):super().__init__()if num_dims == 2:shape = (1, num_features)else:shape = (1, num_features, 1, 1)# 参与求梯度和迭代的拉伸和偏移参数,分别初始化成1和0self.gamma = nn.Parameter(torch.ones(shape))self.beta = nn.Parameter(torch.zeros(shape))# 非模型参数的变量初始化为0和1# 经过归一化处理后的数据均值接近于零。因此,将滑动均值初始化为0,是对数据初始均值的一种合理假设。self.moving_mean = torch.zeros(shape)# 方差表示数据的离散程度。将滑动方差初始化为1,意味着假设数据的初始方差为1,# 即数据分布接近标准正态分布。这样初始化可以避免初始阶段的数值不稳定。self.moving_var = torch.ones(shape)def forward(self, X):# 如果X不在内存上,将moving_mean和moving_var# 复制到X所在GPU上                              if self.moving_mean.device != X.device:self.moving_mean = self.moving_mean.to(X.device)self.moving_var = self.moving_var.to(X.device)# 保存更新过的moving_mean和moving_varY, self.moving_mean, self.moving_var = batch_norm(X, self.gamma, self.beta, self.moving_mean,self.moving_var, eps=1e-5, momentum=0.9)return Y#使用批量规范化层的 LeNet
net = nn.Sequential(nn.Conv2d(1, 6,  kernel_size=5, padding=2), # 卷积层1:输入通道数1,输出通道数6,卷积核大小5x5,填充2BatchNorm(num_features=6, num_dims=4),nn.ReLU(), # 激活函数nn.AvgPool2d(kernel_size=2, stride=2), # 平均池化层1:池化窗口大小2x2,步幅2nn.Conv2d(6, 16, kernel_size=5), # 卷积层2:输入通道数6,输出通道数16,卷积核大小5x5BatchNorm(num_features=16, num_dims=4),nn.ReLU(), nn.AvgPool2d(kernel_size=2, stride=2), # 平均池化层2:池化窗口大小2x2,步幅2nn.Flatten(), # 展平层:将多维输入展平为1维nn.Linear(16 * 5 * 5, 120), # 全连接层1:输入节点数16*5*5,输出节点数120BatchNorm(num_features=120, num_dims=2),nn.ReLU(),nn.Linear(120, 84), # 全连接层2:输入节点数120,输出节点数84BatchNorm(num_features=84, num_dims=2),nn.ReLU(), nn.Linear(84, 10) # 全连接层3:输入节点数84,输出节点数10(对应10个分类)
)lr, num_epochs, batch_size = 1.0, 10, 256
train_iter, test_iter = lp.loda_data_fashion_mnist(batch_size)
# lp.train_ch6(net, train_iter, test_iter, num_epochs, lr, lp.try_gpu())
# plt.show()# loss 0.200, train acc 0.925, test acc 0.812
# 34957.3 examples/sec on cuda:0# loss 0.189, train acc 0.928, test acc 0.894
# 33471.2 examples/sec on cuda:0#简明实现
net = nn.Sequential(nn.Conv2d(1, 6, kernel_size=5), nn.BatchNorm2d(6), nn.ReLU(),nn.AvgPool2d(kernel_size=2, stride=2),nn.Conv2d(6, 16, kernel_size=5), nn.BatchNorm2d(16), nn.ReLU(),nn.AvgPool2d(kernel_size=2, stride=2), nn.Flatten(),nn.Linear(256, 120), nn.BatchNorm1d(120), nn.ReLU(),nn.Linear(120, 84), nn.BatchNorm1d(84), nn.ReLU(),nn.Linear(84, 10)
)
lp.train_ch6(net, train_iter, test_iter, num_epochs, lr, lp.try_gpu())
plt.show()# nn.Sigmoid()
# loss 0.263, train acc 0.902, test acc 0.833
# 46935.0 examples/sec on cuda:0# nn.ReLU()
# loss 0.224, train acc 0.914, test acc 0.874
# 44479.2 examples/sec on cuda:0
"""
通常高级API变体运行速度快得多,因为它的代码已编译为C++或CUDA,而我们的自定义代码由Python实现。
"""

在这里插入图片描述

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.rhkb.cn/news/363829.html

如若内容造成侵权/违法违规/事实不符,请联系长河编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

算法题 — 接雨水

给定 n 给非负整数,表示每个宽度为 1 的柱子的高度图,计算按照此排列的柱子,下雨之后能能接到多少雨水。 输入:height [0, 1, 0, 2, 1, 0, 1, 3, 2, 1, 2, 1] 输出:6 解释:上面是由数组 [0, 1, 0, 2, 1,…

算法基础--------【图论】

图论(待完善) DFS:和回溯差不多 BFS:进while进行层序遍历 定义: 图论(Graph Theory)是研究图及其相关问题的数学理论。图由节点(顶点)和连接这些节点的边组成。图论的研究范围广泛,涉及路径、…

【日记】现在的孩子真是不怕大人呢(1975 字)

正文 时间太晚了,而且想写的内容有点多,就不写在日记本上了。 不过说内容多,其实也只有两件事情。其他的就一笔带过吧。一件关于灵,另一件事关于遇见的孩子。 首先说说工作,今天真的如昨天预料的那样,特别忙…

基于Pico和MicroPython点亮ws2812彩色灯带

基于Pico和MicroPython点亮ws2812彩色灯带 文章目录 基于Pico和MicroPython点亮ws2812彩色灯带IntroductionPracticeConclusion Introduction 点亮发光的LED灯是简单有趣的实验,点亮多个ws2812小灯串联起来的灯带,可对多个彩色小灯进行编程,…

软件测试之接口测试(Postman/Jmeter)

🍅 视频学习:文末有免费的配套视频可观看 🍅 点击文末小卡片,免费获取软件测试全套资料,资料在手,涨薪更快 一、什么是接口测试 通常做的接口测试指的是系统对外的接口,比如你需要从别的系统来…

cartographer从入门到精通(一):cartographer介绍

一、cartographer重要文档 有关cartographer的资料有2个比较重要的网站,我们的介绍也是基于这两个网站,其中会加入自己的一些理解,后续也有一些对代码的修改,来实现我们想完善的功能。 1-Cartographer 2-Cartographer ROS 第1个…

融资担保行业数字化转型探索与实践

融资担保行业数字化转型探索与实践 随着全球经济的快速发展和科技的不断进步,数字化转型已成为各行各业提升竞争力和实现可持续发展的必然选择。融资担保行业作为金融体系中的重要组成部分,也在积极探索和实践数字化转型,以更好地服务中小微企…

小时候的子弹击中了现在的我-hive进阶:案例解析(第18天)

系列文章目录 一、Hive表操作 二、数据导入和导出 三、分区表 四、官方文档(了解) 五、分桶表(熟悉) 六、复杂类型(熟悉) 七、Hive乱码解决(操作。可以不做,不影响) 八、…

图像大模型中的注意力和因果掩码

AIM — 图像领域中 LLM 的对应物。尽管 iGPT 已经存在 2 年多了,但自回归尚未得到充分探索。在本文中,作者表明,当使用 AIM 对网络进行预训练时,一组图像数据集上的下游任务的平均准确率会随着数据和参数的增加而线性增加。 要运…

已解决javax.xml.bind.MarshalException:在RMI中,参数或返回值无法被编组的正确解决方法,亲测有效!!!

已解决javax.xml.bind.MarshalException:在RMI中,参数或返回值无法被编组的正确解决方法,亲测有效!!! 目录 问题分析 出现问题的场景 服务器端代码 客户端代码 报错原因 解决思路 解决方法 1. 实现…

Redis 7.x 系列【11】数据类型之位图(Bitmap)

有道无术,术尚可求,有术无道,止于术。 本系列Redis 版本 7.2.5 源码地址:https://gitee.com/pearl-organization/study-redis-demo 文章目录 1. 概述2. 基本命令2.1 SETBIT2.2 GETBIT2.3 BITCOUNT2.4 BITPOS2.5 BITFIELD2.6 BITF…

OverTheWire Bandit 靶场通关解析(下)

介绍 OverTheWire Bandit 是一个针对初学者设计的网络安全挑战平台,旨在帮助用户掌握基本的命令行操作和网络安全技能。Bandit 游戏包含一系列的关卡,每个关卡都需要解决特定的任务来获取进入下一关的凭证。通过逐步挑战更复杂的问题,用户可…

绝了!Stable Diffusion做AI治愈图片视频,用来做副业简直无敌!10分钟做一个爆款视频保姆教程

一 项目分析 这个治愈类视频的玩法是通过AI生成日常生活场景,制作的vlog,有这样的一个号,发布了几条作品,就涨粉了2000多,点赞7000多,非常的受欢迎。 下面给大家看下这种作品是什么样的,如图所…

Python面试宝典第1题:两数之和

题目 给定一个整数数组 nums 和一个目标值 target,找出数组中和为目标值的两个数的索引。可以假设每个输入只对应唯一的答案,且同样的元素不能被重复利用。比如:给定 nums [2, 7, 11, 15] 和 target 9,返回 [0, 1],因…

基于Java的蛋糕预定系统【附源码+LW】

摘 要 当今社会进入了科技进步、经济社会快速发展的新时代。国际信息和学术交流也不断加强,计算机技术对经济社会发展和人民生活改善的影响也日益突出,人类的生存和思考方式也产生了变化。传统购物方式采取了人工的管理方法,但这种管理方法存…

springboot系列七: Lombok注解,Spring Initializr,yaml语法

老韩学生 LombokLombok介绍Lombok常用注解Lombok应用实例代码实现idea安装lombok插件 Spring InitializrSpring Initializr介绍Spring Initializr使用演示需求说明方式1: IDEA创建方式2: start.spring.io创建 注意事项和说明 yaml语法yaml介绍使用文档yaml基本语法数据类型字面…

黑芝麻科技A1000简介

文章目录 1. A1000 简介2. 感知能力评估3. 竞品对比4. 系统软件1. A1000 简介

【R语言】plot输出窗口大小的控制

如果需要输出png格式的图片并设置dpi,可采用以下代码 png("A1.png",width 10.09, height 10.35, units "in",res 300) 为了匹配对应的窗口大小,在输出的时候保持宽度和高度一致即可,步骤如下: 如上的“10…

【递归、搜索与回溯】记忆化搜索

记忆化搜索 1.记忆化搜索2.不同路径3.最长递增子序列4. 猜数字大小 II5.矩阵中的最长递增路径 点赞👍👍收藏🌟🌟关注💖💖 你的支持是对我最大的鼓励,我们一起努力吧!😃😃…

最近写javaweb出现的一个小bug---前端利用 form 表单传多项数据,后端 Servlet 取出的各项数据均为空

目录: 一. 问题引入二 解决问题 一. 问题引入 近在写一个 java web 项目时,遇到一个让我头疼了晚上的问题:前端通过 post 提交的 form 表单数据可以传到后端,但当我从 Servlet 中通过 request.getParameter(“name”) 拿取各项数…