目录
截断正态分布来初始化张量
逐行代码解释
相关理论解释
截断正态分布函数
截断正态分布的定义
截断正态分布的作用
计算截断点的作用
具体步骤
正态分布的累积分布函数(CDF)
正态分布的累积分布函数与误差函数的关系
示例计算
误差函数
应用:
定义:
误差函数的性质
Python 中的误差函数
总结
截断正态分布来初始化张量
import math
import warnings
import torchdef _no_grad_trunc_normal_(tensor, mean, std, a, b):def norm_cdf(x):return (1. + math.erf(x / math.sqrt(2.))) / 2.if (mean < a - 2 * std) or (mean > b + 2 * std):warnings.warn("mean is more than 2 std from [a, b] in nn.init.trunc_normal_. ""The distribution of values may be incorrect.",stacklevel=2)with torch.no_grad():l = norm_cdf((a - mean) / std)u = norm_cdf((b - mean) / std)tensor.uniform_(2 * l - 1, 2 * u - 1)tensor.erfinv_()tensor.mul_(std * math.sqrt(2.))tensor.add_(mean)tensor.clamp_(min=a, max=b)return tensor
逐行代码解释
1、正态分布的累积分布函数(CDF):norm_cdf
函数计算标准正态分布的累积分布函数。
def norm_cdf(x):return (1. + math.erf(x / math.sqrt(2.))) / 2.
2、警告:检查均值是否在截断边界 [a, b]
的2个标准差范围内,如果不在,则发出警告。
if (mean < a - 2 * std) or (mean > b + 2 * std):warnings.warn("mean is more than 2 std from [a, b] in nn.init.trunc_normal_. ""The distribution of values may be incorrect.",stacklevel=2)
3、不跟踪梯度:以下代码块确保初始化时不跟踪梯度,这对于设置神经网络的初始权重很有用。
with torch.no_grad():l = norm_cdf((a - mean) / std)u = norm_cdf((b - mean) / std)tensor.uniform_(2 * l - 1, 2 * u - 1)tensor.erfinv_()tensor.mul_(std * math.sqrt(2.))tensor.add_(mean)tensor.clamp_(min=a, max=b)return tensor
l
和u
是截断点a
和b
处的累积分布函数值。tensor.uniform_(2 * l - 1, 2 * u - 1)
用从指定范围的均匀分布生成的值初始化张量。tensor.erfinv_()
对张量应用误差函数的逆函数。tensor.mul_(std * math.sqrt(2.))
将张量的值缩放到期望的标准差。tensor.add_(mean)
将张量的值平移到期望的均值。tensor.clamp_(min=a, max=b)
确保张量中的所有值都在指定的截断范围[a, b]
之内。
相关理论解释
截断正态分布函数
截断正态分布的定义
给定一个均值为 μ、标准差为 σ 的正态分布 N(μ,σ2),截断正态分布在区间 [a,b] 上的定义如下:
其中,ϕ(x) 是正态分布的概率密度函数(PDF),Φ(x)是正态分布的累积分布函数(CDF)。
截断正态分布的作用
- 限制范围:确保生成的随机变量值在某个指定范围内,这对于物理约束或特定应用场景非常重要。
- 防止异常值:避免生成不合实际或有害的极端值,例如在神经网络权重初始化时防止极端值导致的训练不稳定。
计算截断点的作用
在实现截断正态分布时,我们需要计算截断点 a
和 b
对应的累积分布函数值 l 和 u,以便生成满足截断条件的随机数。
具体步骤
-
标准化:将截断点
a
和b
标准化为标准正态分布中的值。 -
计算标准正态分布的 CDF:计算标准正态分布在标准化后的截断点
l
和u
处的累积分布函数值。注意:此处有一个性质,就是随机变量Φ(l)和Φ(u)是满足[0,1]的均匀分布。 -
转换为均匀分布:生成的均匀分布随机数在 [2Φ(l)−1,2Φ(u)−1] 区间内。
-
逆误差函数:将均匀分布的值通过逆误差函数转换为标准正态分布的值。
tensor.erfinv()
-
缩放和平移:将标准正态分布的值缩放到所需的标准差,并平移到所需的均值。
-
截断:确保所有值都在 [a,b] 区间内。
正态分布的累积分布函数(CDF)
定义:用于计算正态分布从负无穷大到给定值 x的概率。具体而言,对于标准正态分布 N(0,1),CDF 表示为:
正态分布的累积分布函数与误差函数的关系
在代码中,我们通过误差函数(erf)来计算标准正态分布的 CDF。误差函数与标准正态分布的 CDF 之间有如下关系:
代码中的 norm_cdf
函数:
def norm_cdf(x):return (1. + math.erf(x / math.sqrt(2.))) / 2.
norm_cdf
函数的实现如下:
- 输入:函数接收一个参数
x
,它是需要计算 CDF 的点。 - 计算误差函数:
math.erf(x / math.sqrt(2.))
计算 \frac{x}{\sqrt{2}} 的误差函数值。 - 调整误差函数值:将误差函数的结果加 1,然后除以 2,得到标准正态分布在
x
点的 CDF 值。
以下是函数的具体步骤:
math.erf(x / math.sqrt(2.))
:计算误差函数 。1. + math.erf(x / math.sqrt(2.))
:将误差函数的结果加 1。(1. + math.erf(x / math.sqrt(2.))) / 2.
:结果除以 2 得到最终的 CDF 值。
示例计算
假设我们需要计算标准正态分布在 x=1处的 CDF 值:
import mathdef norm_cdf(x):return (1. + math.erf(x / math.sqrt(2.))) / 2.x = 1
cdf_value = norm_cdf(x)
print("CDF value at x = 1:", cdf_value)
运行以上代码,会输出 x=1处的 CDF 值,即:
CDF value at x = 1: 0.8413447460685429
这意味着在标准正态分布中,小于等于 1 的值的概率大约为 0.8413。
误差函数
应用:
数学上用于处理正态分布和概率问题的重要函数。误差函数用于计算某个值在标准正态分布中的概率,并且在统计学、概率论和许多应用数学领域中都有广泛应用。
定义:
这个积分没有解析解,因此通常通过数值方法进行计算。
误差函数的性质
- 对称性:误差函数是奇函数,即 。
- 值域:误差函数的值域在 −1 到 1 之间,即 −1≤erf(x)≤1。
- 边界值:当 x→∞ 时,erf(x)→1;当 x→−∞时,erf(−x)→−1。
Python 中的误差函数
在 Python 中,可以使用 math
模块中的 erf
函数来计算误差函数值。以下是一个示例:
import mathx = 1.0
erf_value = math.erf(x)
print("erf(1.0) =", erf_value)
运行结果是:
erf(1.0) = 0.8427007929497149
这意味着当x=1.0 时,erf(1.0)的值大约为 0.8427。