在深度学习中常见的初始化操作

目录

截断正态分布来初始化张量

逐行代码解释

相关理论解释

截断正态分布函数

截断正态分布的定义

截断正态分布的作用

计算截断点的作用

具体步骤

正态分布的累积分布函数(CDF)

 正态分布的累积分布函数与误差函数的关系

示例计算

误差函数

应用:

定义:

误差函数的性质

Python 中的误差函数

总结


截断正态分布来初始化张量

import math
import warnings
import torchdef _no_grad_trunc_normal_(tensor, mean, std, a, b):def norm_cdf(x):return (1. + math.erf(x / math.sqrt(2.))) / 2.if (mean < a - 2 * std) or (mean > b + 2 * std):warnings.warn("mean is more than 2 std from [a, b] in nn.init.trunc_normal_. ""The distribution of values may be incorrect.",stacklevel=2)with torch.no_grad():l = norm_cdf((a - mean) / std)u = norm_cdf((b - mean) / std)tensor.uniform_(2 * l - 1, 2 * u - 1)tensor.erfinv_()tensor.mul_(std * math.sqrt(2.))tensor.add_(mean)tensor.clamp_(min=a, max=b)return tensor

逐行代码解释

1、正态分布的累积分布函数(CDF)norm_cdf 函数计算标准正态分布的累积分布函数。

def norm_cdf(x):return (1. + math.erf(x / math.sqrt(2.))) / 2.

2、警告:检查均值是否在截断边界 [a, b] 的2个标准差范围内,如果不在,则发出警告。

if (mean < a - 2 * std) or (mean > b + 2 * std):warnings.warn("mean is more than 2 std from [a, b] in nn.init.trunc_normal_. ""The distribution of values may be incorrect.",stacklevel=2)

3、不跟踪梯度:以下代码块确保初始化时不跟踪梯度,这对于设置神经网络的初始权重很有用。

with torch.no_grad():l = norm_cdf((a - mean) / std)u = norm_cdf((b - mean) / std)tensor.uniform_(2 * l - 1, 2 * u - 1)tensor.erfinv_()tensor.mul_(std * math.sqrt(2.))tensor.add_(mean)tensor.clamp_(min=a, max=b)return tensor
  • lu 是截断点 ab 处的累积分布函数值。
  • tensor.uniform_(2 * l - 1, 2 * u - 1) 用从指定范围的均匀分布生成的值初始化张量。
  • tensor.erfinv_() 对张量应用误差函数的逆函数。
  • tensor.mul_(std * math.sqrt(2.)) 将张量的值缩放到期望的标准差。
  • tensor.add_(mean) 将张量的值平移到期望的均值。
  • tensor.clamp_(min=a, max=b) 确保张量中的所有值都在指定的截断范围 [a, b] 之内。

相关理论解释

截断正态分布函数

截断正态分布的定义

        给定一个均值为 μ、标准差为 σ 的正态分布 N(μ,σ2),截断正态分布在区间 [a,b] 上的定义如下:

         其中,ϕ(x) 是正态分布的概率密度函数(PDF),Φ(x)是正态分布的累积分布函数(CDF)。

截断正态分布的作用
  1. 限制范围:确保生成的随机变量值在某个指定范围内,这对于物理约束或特定应用场景非常重要。
  2. 防止异常值:避免生成不合实际或有害的极端值,例如在神经网络权重初始化时防止极端值导致的训练不稳定。
计算截断点的作用

在实现截断正态分布时,我们需要计算截断点 ab 对应的累积分布函数值 l 和 u,以便生成满足截断条件的随机数。

具体步骤
  1. 标准化:将截断点 ab 标准化为标准正态分布中的值。

  2. 计算标准正态分布的 CDF:计算标准正态分布在标准化后的截断点 lu 处的累积分布函数值。注意:此处有一个性质,就是随机变量Φ(l)和Φ(u)是满足[0,1]的均匀分布。

  3. 转换为均匀分布:生成的均匀分布随机数在 [2Φ(l)−1,2Φ(u)−1] 区间内。

  4. 逆误差函数:将均匀分布的值通过逆误差函数转换为标准正态分布的值。

    tensor.erfinv()

  5. 缩放和平移:将标准正态分布的值缩放到所需的标准差,并平移到所需的均值。

  6. 截断:确保所有值都在 [a,b] 区间内。

正态分布的累积分布函数(CDF)

定义:用于计算正态分布从负无穷大到给定值 x的概率。具体而言,对于标准正态分布 N(0,1),CDF 表示为:

 正态分布的累积分布函数与误差函数的关系

在代码中,我们通过误差函数(erf)来计算标准正态分布的 CDF。误差函数与标准正态分布的 CDF 之间有如下关系:

 代码中的 norm_cdf 函数:

def norm_cdf(x):return (1. + math.erf(x / math.sqrt(2.))) / 2.

norm_cdf 函数的实现如下:

  1. 输入:函数接收一个参数 x,它是需要计算 CDF 的点。
  2. 计算误差函数math.erf(x / math.sqrt(2.)) 计算 \frac{x}{\sqrt{2}} 的误差函数值。
  3. 调整误差函数值:将误差函数的结果加 1,然后除以 2,得到标准正态分布在 x 点的 CDF 值。

以下是函数的具体步骤:

  1. math.erf(x / math.sqrt(2.))计算误差函数
  2. 1. + math.erf(x / math.sqrt(2.)):将误差函数的结果加 1。
  3. (1. + math.erf(x / math.sqrt(2.))) / 2.:结果除以 2 得到最终的 CDF 值。
示例计算

假设我们需要计算标准正态分布在 x=1处的 CDF 值:

import mathdef norm_cdf(x):return (1. + math.erf(x / math.sqrt(2.))) / 2.x = 1
cdf_value = norm_cdf(x)
print("CDF value at x = 1:", cdf_value)

 运行以上代码,会输出 x=1处的 CDF 值,即:

CDF value at x = 1: 0.8413447460685429

这意味着在标准正态分布中,小于等于 1 的值的概率大约为 0.8413。

误差函数

应用:

        数学上用于处理正态分布和概率问题的重要函数。误差函数用于计算某个值在标准正态分布中的概率,并且在统计学、概率论和许多应用数学领域中都有广泛应用。

定义:

         这个积分没有解析解,因此通常通过数值方法进行计算。

误差函数的性质
  • 对称性:误差函数是奇函数,即erf⁡(−x)=−erf⁡(x) 。
  • 值域:误差函数的值域在 −1 到 1 之间,即 −1≤erf⁡(x)≤1。
  • 边界值:当 x→∞ 时,erf⁡(x)→1;当 x→−∞时,erf⁡(−x)→−1。
Python 中的误差函数

在 Python 中,可以使用 math 模块中的 erf 函数来计算误差函数值。以下是一个示例:

import mathx = 1.0
erf_value = math.erf(x)
print("erf(1.0) =", erf_value)

运行结果是:

erf(1.0) = 0.8427007929497149

这意味着当x=1.0 时,erf(1.0)的值大约为 0.8427。

总结

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.rhkb.cn/news/329818.html

如若内容造成侵权/违法违规/事实不符,请联系长河编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

软件设计师-上午题-计算题汇总

一、存储系统 - 存储容量计算&#xff08;字节编址、位编址、芯片个数&#xff09; 内存地址是16进制 内存地址编址的单位是Byte&#xff0c;1K1024B 1B 8 bit 1.计算存储单元个数 存储单元个数 末地址 - 首地址 1 eg. 按字节编址&#xff0c;地址从 A4000H 到 CBFFFH&…

使用B2M 算法批量将可执行文件转为灰度图像

参考论文 基于二进制文件的 C 语言编译器特征提取及识别 本实验使用 B2M 算法将可执行文件转为灰度图像&#xff0c;可执行文件转为灰度图的流程如图 4-3 所示。将 可执行文件每 8 位读取为一个无符号的的整型常量&#xff0c;一个可执行文件得到一个一维向量&#xff0c; …

深度学习之基于Tensorflow+Keras+CNN模型实时对手写数字进行分类

欢迎大家点赞、收藏、关注、评论啦 &#xff0c;由于篇幅有限&#xff0c;只展示了部分核心代码。 文章目录 一项目简介 二、功能三、系统四. 总结 一项目简介 一、项目背景与意义 随着深度学习和计算机视觉技术的快速发展&#xff0c;手写数字识别已成为一个重要的应用场景。…

装备制造项目管理软件:奥博思PowerProject项目管理系统

数字化正逐步改变着制造方式和企业组织模式。某制造企业领导层透露&#xff0c;在采用数字化项目管理模式后&#xff0c;企业的发展韧性更加强劲&#xff0c;构筑起了竞争新优势&#xff0c;企业产品研制周期缩短25%&#xff0c;生产效率提升18%。 随着全球经济的发展&#xf…

SpringBootWeb 篇-深入了解 Mybatis 删除、新增、更新、查询的基础操作与 SQL 预编译解决 SQL 注入问题

&#x1f525;博客主页&#xff1a; 【小扳_-CSDN博客】 ❤感谢大家点赞&#x1f44d;收藏⭐评论✍ 文章目录 1.0 Mybatis 的基础操作 2.0 基础操作 - 环境准备 3.0 基础操作 - 删除操作 3.1 SQL 预编译 3.2 SQL 预编译的优势 3.3 参数占位符 4.0 基础操作 - 新增 4.1 主键返回…

深度学习之基于Pytorch框架多人多摄像头摔倒跌倒坠落检测

欢迎大家点赞、收藏、关注、评论啦 &#xff0c;由于篇幅有限&#xff0c;只展示了部分核心代码。 文章目录 一项目简介 二、功能三、系统四. 总结 一项目简介 一、项目背景 随着智能监控技术的广泛应用&#xff0c;对于公共场合的安全监控需求日益增加。摔倒跌倒坠落是常见的…

基于深度学习的Tensorflow卷积神经网络(CNN)车牌识别

欢迎大家点赞、收藏、关注、评论啦 &#xff0c;由于篇幅有限&#xff0c;只展示了部分核心代码。 文章目录 一项目简介 二、功能三、系统四. 总结 一项目简介 一、项目背景 车牌识别&#xff08;License Plate Recognition, LPR&#xff09;是智能交通系统&#xff08;ITS&a…

解锁产品迭代新速度:A/B测试在AI大模型时代的应用

本文作者为火山引擎A/B测试平台DataTester的资深研发工程师刘明瑶。作为火山引擎数智平台VeDI旗下的核心产品&#xff0c;DataTester源于字节跳动长期的技术和业务沉淀&#xff0c;目前已经服务了数百家企业&#xff0c;助力企业在业务增长、用户转化、产品迭代、策略优化以及运…

深度学习之Tensorflow卷积神经网络手势识别

欢迎大家点赞、收藏、关注、评论啦 &#xff0c;由于篇幅有限&#xff0c;只展示了部分核心代码。 文章目录 一项目简介 二、功能三、系统四. 总结 一项目简介 一、项目背景与意义 手势识别是计算机视觉和人工智能领域的重要应用之一&#xff0c;具有广泛的应用前景&#xff…

抖音视频怎么去水印保存部分源码|短视频爬虫提取收集下载工具

抖音视频怎么去水印保存部分源码|短视频爬虫提取收集下载工具 抖音视频去水印保存部分源码&#xff1a; 通过使用Python中的requests、re和os等库&#xff0c;可以编写如下代码来实现抖音视频去水印保存的功能。 短视频爬虫提取手机下载工具的使用方法&#xff1a; 该工具主…

【Linux学习】进程地址空间与写时拷贝

文章目录 Linux进程内存布局图&#xff1a;内存布局的验证 进程地址空间写时拷贝 Linux进程内存布局图&#xff1a; 地址空间的范围&#xff0c;在32位机器上是2^32比特位,也就是[0,4G]。 内存布局的验证 代码验证内存布局&#xff1a; 验证代码&#xff1a; #include<s…

基于FPGA的VGA协议实现----条纹-文字-图片

基于FPGA的VGA协议实现----条纹-文字-图片 引言&#xff1a; ​ 随着数字电子技术的飞速发展&#xff0c;现场可编程门阵列&#xff08;FPGA&#xff09;因其高度的灵活性和并行处理能力&#xff0c;在数字系统设计中扮演着越来越重要的角色。FPGA能够实现复杂的数字逻辑&#…

字节面试:百亿级数据存储,怎么设计?只是分库分表吗?

尼恩&#xff1a;百亿级数据存储架构起源 在40岁老架构师 尼恩的读者交流群(50)中&#xff0c;经常性的指导小伙伴们改造简历。 经过尼恩的改造之后&#xff0c;很多小伙伴拿到了一线互联网企业如得物、阿里、滴滴、极兔、有赞、希音、百度、网易、美团的面试机会&#xff0c…

基于Tensorflow卷积神经网络垃圾智能分类系统

欢迎大家点赞、收藏、关注、评论啦 &#xff0c;由于篇幅有限&#xff0c;只展示了部分核心代码。 文章目录 一项目简介 二、功能三、系统四. 总结 一项目简介 一、项目背景与意义 随着城市化进程的加速&#xff0c;垃圾问题日益严重&#xff0c;垃圾分类成为解决这一问题的关…

(全面)Nginx格式化插件,Nginx生产工具,Nginx常用命令

目录 &#x1f3ab; 前言 &#x1f389; 开篇福利 &#x1f381; 开篇福利 x2 Double happiness # 介绍 # 地址 # 下载 &#x1f4bb; 命令及解析 # 整个文件系统中搜索名为nginx.conf的文件 # 编辑nginx.conf文件 # 重新加载配置文件 # 快速查找nginx.conf文件并使…

Android和flutter交互,maven库的形式导入aar包

记录遇到的问题&#xff0c;在网上找了很多资料&#xff0c;都是太泛泛了&#xff0c;使用后&#xff0c;还不能生效&#xff0c;缺少详细的说明&#xff0c;或者关键代码缺失&#xff0c;我遇到的问题用红色的标注了 导入aar包有两种模式 1.比较繁琐的&#xff0c;手动将aar…

Linux应用入门(二)

1. 输入系统应用编程 1.1 输入系统介绍 常见的输入设备有键盘、鼠标、遥控杆、书写板、触摸屏等。用户经过这些输入设备与Linux系统进行数据交换。这些设备种类繁多&#xff0c;如何去统一它们的接口&#xff0c;Linux为了统一管理这些输入设备实现了一套能兼容所有输入设备的…

【真人Q版手办风】线稿手绘+ AI绘图 Stable Diffusion 完整制作过程分享

大家好&#xff0c;我是设计师阿威。 今天给大家分享一篇【真人Q版卡通手办】风格的制作过程&#xff0c;话不多说&#xff0c;进入正题。 成品预览 手绘线稿 首先&#xff0c;我使用的是老款手绘软件【SAI】&#xff0c;用[钢笔工具]进行了人物的线稿Q版描绘。&#x1f447…

最大负载1kg!高度模块化设计!大象机器人智能遥控操作机械臂组合myArm MC

引入 近年来&#xff0c;市面上涌现了许多类似于斯坦福大学的 Alopha 机器人项目&#xff0c;这些项目主要通过模仿人类的运动轨迹来进行学习&#xff0c;实现了仿人类的人工智能。Alopha 机器人通过先进的算法和传感技术&#xff0c;能够精确复制人类的动作&#xff0c;并从中…

二、使用Django创建一个基础应用

职位管理系统 - 建模 职位名称类别工作地点职位职责职位要求发布人发布日期修改日期 安装django pip install django5.0查看django版本 python -m django --version创建项目 django-admin startproject recruitment启动服务 python manage.py runserver 0.0.0.0:8000创建…