0. 引言
前几天分几篇博文精细地讲述了《von Mises-Fisher 分布》, 以及相应的 PyTorch 实现《von Mises-Fisher Distribution (代码解析)》, 其中以 Uniform 分布为例简要介绍了 torch.distributions
包的用法. 本以为已经可以了, 但这两天看到论文 The Power Spherical distribution 的代码, 又被其实现分布的方式所吸引.
Power Spherical 分布与 von Mises Fisher 分布类似, 只不过将后者概率密度函数中的指数函数换成了多项式函数: f p ( x ; μ , κ ) ∝ e x p ( κ μ ⊺ x ) ⇓ f p ( x ; μ , κ ) ∝ ( 1 + μ ⊺ x ) κ \begin{aligned} f_p(\bm{x}; \bm{\mu}, \kappa) &\propto exp(\kappa \bm{\mu}^\intercal \bm{x}) \\ &\Downarrow\\ f_p(\bm{x}; \bm{\mu}, \kappa) &\propto (1+\bm{\mu}^\intercal \bm{x})^\kappa \\ \end{aligned} fp(x;μ,κ)fp(x;μ,κ)∝exp(κμ⊺x)⇓∝(1+μ⊺x)κ 采样框架基本一致, 且这么做可以使边缘 t t t 的线性变换 t + 1 2 ∼ B e t a ( p − 1 2 + κ , p − 1 2 ) \frac{t+1}{2} \sim Beta(\frac{p-1}{2}+\kappa, \frac{p-1}{2}) 2t+1∼Beta(2p−1+κ,2p−1), 从而避免了接受-拒绝采样过程.
当然, 按照之前的 VonMisesFisher
的写法, 这个 t
的采样大概是这样:
z = beta.sample(sample_shape)
t = 2 * z - 1
但现在我遇到了这种写法:
class MarginalTDistribution(tds.TransformedDistribution):arg_constraints = {'dim': constraints.positive_integer,'scale': constraints.positive,}has_rsample = Truedef __init__(self, dim, scale, validate_args=None):self.dim = dimself.scale = scalesuper().__init__(tds.Beta( # 用 Beta 分布转换, z 服从 Beta(α+κ,β)(dim - 1) / 2 + scale, (dim - 1) / 2, validate_args=validate_args),transforms=tds.AffineTransform(loc=-1, scale=2), # t=2z-1 是想要的边缘分布随机数)
然后就可以进行对 t t t 的采样了.
架构大概是这样的: 一个基本分布类 distributions.Beta
和一个转换 transforms.AffineTransform
, 输入到 TransformedDistribution
的子类 MarginalTDistribution
中, 通过对一个 B e t a Beta Beta 的线性转换, 实现边缘分布 t t t.
我们可以看到其基本架构, 本文将详细解析其内部的具体细节, 包括:
1. Distribution
在之前的 <von Mises-Fisher Distribution (代码解析)> 中, 已经通过 Uniform
简单介绍了 Distribution
的用法. 它是实现各种分布的抽象基类. 本文将以解析源码的方式详细介绍.
1.1 参数验证 validate_args
打开源码, 首先映入眼帘的是关于参数验证的代码:
# true if Python was not started with an -O option. See also the assert statement.
_validate_args = __debug__@staticmethod
def set_default_validate_args(value: bool) -> None:"""设置 validation 是否开启.validation 通常是耗时的, 所以最好在模型 work 后关闭它."""if value not in [True, False]:raise ValueErrorDistribution._validate_args = value
Distribution
有一个类属性叫 _validate_args
, 默认值是 __debug__
(见附录1), 可以通过类静态方法 set_default_validate_args(value: bool)
来修改此值.
构造方法 __init__(...)
中的验证逻辑:
def __init__(self, ..., validate_args: Optional[bool]=None):...if validate_args is not None:self._validate_args = validate_args
也就是说, 你可以在创建 Distribution
实例的时候设置是否进行参数验证. 如果不设置, 则按照类的属性 Distribution._validate_args
来.
if self._validate_args: # validate_args=False 就不用设置 arg_constraints 了try: # 尝试获取字典 arg_constraintsarg_constraints = self.arg_constraintsexcept NotImplementedError: # 如果没设置, 则设置为 {}, 抛出警告arg_constraints = {}warnings.warn(...)
如果需要验证参数, 那么首先要获取一个叫 arg_constraints
的参数验证字典, 它列出了需要验证哪些参数. 这个抽象类里面并没有给出, 需要用户继承该类时写在子类中. 以 Uniform
为例:
class Uniform(Distribution):...arg_constraints = {"low": constraints.dependent(is_discrete=False, event_dim=0),"high": constraints.dependent(is_discrete=False, event_dim=0),}...
至于 constraints.dependent
是啥, 后面会详细介绍. 值得注意的是, 如果你在创建实例时指定 validate_args=False
, 那么所有关于参数验证的事就都不用管了.
for param, constraint in arg_constraints.items():if constraints.is_dependent(constraint):continue # skip constraints that cannot be checkedif param not in self.__dict__ and isinstance(getattr(type(self), param), lazy_property):continue # skip checking lazily-constructed argsvalue = getattr(self, param) # 从当前对象获取参数 valuevalid = constraint.check(value) # 检查参数值if not valid.all(): # 检查不通过raise ValueError(...)
这一段就是验证过程了, 包括:
- skip constraints that cannot be checked, 由
constraints.is_dependent(constraint)
判断是否可验证; - skip checking lazily-constructed args, 即参数名不在
self.__dict__
中, 并属于lazy_property
的跳过; - 获得参数, 进行验证;
具体的验证细节将在后面介绍.
1.2 batch_shape
& event_shape
除了 validate_args
参数, __init__(...)
方法中的另外两个参数就是:
def __init__(self,batch_shape: torch.Size = torch.Size(),event_shape: torch.Size = torch.Size(),
):self._batch_shape = batch_shapeself._event_shape = event_shape...
这两个参数是啥? 在这个抽象类中, 我们看不到太多信息, 甚至 Uniform
中也只有 batch_shape = self.low.size()
的信息, 大概意思同时进行着一批的均匀分布, 如 low = torch.tensor([0.0, 1.0])
时, batch_shape = torch.Size([2])
, 表示一个二元的均匀分布. 看 MultivariateNormal
, 里面信息量较大:
batch_shape = torch.broadcast_shapes(covariance_matrix.shape[:-2], # [:-2]是去掉了协方差矩阵的维度, 剩下的可能是 batch 的维度loc.shape[:-1] # [:-1]是去掉了 envent 的维度, 剩下的可能是 batch 的维度
) # broadcast_shapes 意思是进行了广播, 如果 matrix 的 batch_shape 是 [2,1], loc 的 batch_shape 是 [1,2], 那么整个的 batch_shape 是广播后的 [2,2]
self.covariance_matrix = covariance_matrix.expand(batch_shape + (-1, -1)) # 之后 covariance_matrix 都被 expand 了
...
event_shape = self.loc.shape[-1:] # 看来就是样本的 shape
从这一段来看, batch_shape
是指创建的实例在进行多少个平行的基本分布, 而 event_shape
是指基本分布的事件(支撑点)维度. 如:
locs = torch.randn(2, 3)
matrixs = torch.randn(2, 3, 3)
covariance_matrixs = torch.bmm(matrixs, matrixs.transpose(1, 2))normal = distributions.MultivariateNormal(loc=locs, covariance_matrix=covariance_matrixs)
print(normal.batch_shape) # 2
print(normal.event_shape) # 3
print(normal.sample())
##### output #####
torch.Size([2])
torch.Size([3])
tensor([[ 1.8972, -0.3961, -0.1530],[-0.5018, -2.5110, 0.1293]])
batch 的意思还是那个 batch, 不过这里是指分布的 batch, 而不是数据的 batch. 采样时, 得到一批 samples, 对应每个分布.
还有一个 method 和这两个参数有关: expand
, 因为它是一个抽象 method, 基类中并没有实现, 那就直接看 MultivariateNormal
中的:
def expand(self, batch_shape: torch.Size, _instance=None):"""Args:batch_shape (torch.Size): the desired expanded size._instance: new instance provided by subclasses that need to override `.expand`.Returns:New distribution instance with batch dimensions expanded to `batch_size`."""new = self._get_checked_instance(MultivariateNormal, _instance)batch_shape = torch.Size(batch_shape)loc_shape = batch_shape + self.event_shapecov_shape = batch_shape + self.event_shape + self.event_shapenew.loc = self.loc.expand(loc_shape)new._unbroadcasted_scale_tril = self._unbroadcasted_scale_trilif "covariance_matrix" in self.__dict__:new.covariance_matrix = self.covariance_matrix.expand(cov_shape)if "scale_tril" in self.__dict__:new.scale_tril = self.scale_tril.expand(cov_shape)if "precision_matrix" in self.__dict__:new.precision_matrix = self.precision_matrix.expand(cov_shape)super(MultivariateNormal, new).__init__(batch_shape, self.event_shape, validate_args=False)new._validate_args = self._validate_argsreturn new
这个 method 会创建一个新的 instance 或调用的时候用户提供, 并设置 batch_shape
为参数提供的形状, 然后把参数 expand
到新的 batch_shape
. 用法:
mean = torch.randn(3)
matrix = torch.randn(3, 3)
covariance_matrix = torch.mm(matrix, matrix.t())mvn = MultivariateNormal(mean, covariance_matrix)
bmvn = mvn.expand(torch.Size([2]))print(bmvn.batch_shape)
print(bmvn.event_shape)
print(bmvn.sample())##### output #####
torch.Size([2])
torch.Size([3])
tensor([[-4.0891, -4.2424, 6.2574],[ 0.7656, -0.2199, -0.9836]])
1.3 一些属性
包括: m e a n mean mean, m o d e mode mode, s t d std std, v a r i a n c e variance variance, e n t r o p y entropy entropy 等基本属性, 都需要用户在子类中自己实现. 还有一些相关的函数:
- cumulative density/mass function
cdf(value)
; - inverse cumulative density/mass function
icdf(value)
;
这个函数非常有用, Inverse Transform Sampling 中用其进行采样. 从 U ( 0 , 1 ) U(0,1) U(0,1) 中采样一个 u u u, 然后令 x = F − 1 ( u ) x = F^{-1}(u) x=F−1(u) 就是所求随机变量 X X X 的一个采样. - log of the probability density/mass function
log_prob(value)
, 对数概率.
注意, 目前看到的只有 log_prob
, 并没有 prob
, 一些示例要么只算 log_prob
, 要么计算后通过 exp(log_prob)
得到 prob
.
2. constraints.Constraint
前面在1.1参数验证中已经遇到 constraints.dependent(is_discrete=False, event_dim=0)
和 constraint.check(value)
, 但没有讲具体细节. 本节将详细剖析.
2.1 抽象基类 Constraint
先看源码:
class Constraint:"""一个 constraint 对象, 表示变量在某区域内有效, 即变量可优化的范围."""is_discrete = False # Default to continuous.event_dim = 0 # Default to univariate.def check(self, value):"""结果的形状为"sample_shape + batch_shape", 指示 each event 值是否满足此限制."""raise NotImplementedError
这是抽象基类 Constraint
, 比较简单, 只有两个类属性和一个 method check(value)
. is_discrete
表示待验证值是否为离散; 联想前面的 event_shape
, 大概可以知道 event_dim
是指 len(event_shape)
.(不过目前看只是为了验证参数, 还能验证采样的 event?)
2.2 _Dependent()
不被验证
这个基类信息太少, 对我们理解前面的内容毫无用处, 还是直接观察一些子类吧. 从 dependent = _Dependent()
开始, 它是 constraints.py
中定义好的 placeholder(这个倒是可以学一学):
class _Dependent(Constraint): # 看"_", 应该是不希望用户直接创建实例"""Placeholder for variables whose support depends on other variables.These variables obey no simple coordinate-wise constraints."""def __init__(self, *, is_discrete=NotImplemented, event_dim=NotImplemented):self._is_discrete = is_discreteself._event_dim = event_dimsuper().__init__()def __call__(self, *, is_discrete=NotImplemented, event_dim=NotImplemented):"""Support for syntax to customize static attributes::constraints.dependent(is_discrete=True, event_dim=1)"""if is_discrete is NotImplemented: # 未提供就是默认is_discrete = self._is_discreteif event_dim is NotImplemented:event_dim = self._event_dimreturn _Dependent(is_discrete=is_discrete, event_dim=event_dim)def check(self, x):raise ValueError("Cannot determine validity of dependent constraint")
闹了半天, 我们并不能看到 constraints.dependent(is_discrete=False, event_dim=0)
有什么卵用, 只知道 “Cannot determine validity of dependent constraint”, 这也呼应了前面的:
if constraints.is_dependent(constraint):continue # skip constraints that cannot be checked
也就是说, dependent
类型的限制是不会执行参数验证的. 那这个 _Dependent
到底有何用处? 先不管了.
2.3 _IndependentConstraint
重新解释 event_dim
我们看点复杂的, MultivariateNormal.arg_constraints
:
arg_constraints = {"loc": constraints.real_vector,"covariance_matrix": constraints.positive_definite,"precision_matrix": constraints.positive_definite,"scale_tril": constraints.lower_cholesky,
}
这些都是 constraints.py
中定义好的实例, 对于大多情况, 这些预定义好的实例已经够用, 但如果需要, 你也可以自定义. 先看 real_vector
:
independent = _IndependentConstraint
real_vector = independent(real, 1)
class _IndependentConstraint(Constraint):"""封装一个 constraint, 通过 aggregating over ``reinterpreted_batch_ndims``-many dims in :meth:`check`,an event is valid 当且仅当它依赖的所有 entries 是 valid 的."""def __init__(self, base_constraint, reinterpreted_batch_ndims):self.base_constraint = base_constraintself.reinterpreted_batch_ndims = reinterpreted_batch_ndimssuper().__init__()@propertydef event_dim(self):# real.event_dim 是 0, + real_vector(reinterpreted_batch_ndims=1) = 1return self.base_constraint.event_dim + self.reinterpreted_batch_ndimsdef check(self, value):result = self.base_constraint.check(value) # 首先要符合 base.checkif result.dim() < self.reinterpreted_batch_ndims:# 给 batch 留够 dimexpected = self.base_constraint.event_dim + self.reinterpreted_batch_ndimsraise ValueError(f"Expected value.dim() >= {expected} but got {value.dim()}")result = result.reshape( # 减掉 eventresult.shape[: result.dim() - self.reinterpreted_batch_ndims] + (-1,))result = result.all(-1) # 减少一个 dimreturn result
意思很明了了, real_vector
是依赖于 real
(base_constraint) 的, reinterpreted_batch_ndims=1
是说把原来的 value
重新解释, event_dim
加上 reinterpreted_batch_ndims
, 比如
value = [[1, 2, 3],[4, 5, 6]]
本来 real
的 event_dim=0
, 验证结果为(sample_shape + batch_shape = (2,2)
):
value = [[True, True, True],[True, True, True]]
现在重新解释为 event_dim=1
, 验证结果为:
result = result.reshape( # 减掉 eventresult.shape[: result.dim() - self.reinterpreted_batch_ndims] + (-1,) # (-1,) 表示新 event 内的所有 entries 展平
)
result = result.all(-1) # 新 event 内的所有 entries 为 True, 则新 event 为 True
================>
value = [True, True]
3. Transform
& _InverseTransform
上一节介绍了 constraints.Constraint
, 明白了在构建 Distribution
实例时进行的参数验证, 以保证用户提供的参数符合要求. 但还留下了一个疑问: Constraint
中的 event_dim
是指 len(event_shape)
, 难道还能验证采样的 event? 再者, check(value)
返回值的形状是 sample_shape + batch_shape
, 进一步说明它是会被用于采样结果检查的. 让我们看一看能否在 Transform
中找到答案.
Transform
& _InverseTransform
是一对互逆的操作, 实现从一个分布到另一个分布的转换. 这很有用, 因为 distributions
包已经实现了很多常见分布和转换, 自由组合威力巨大. 本节将详细介绍它是如何实现对分布的转换的.
[注] 从 _InverseTransform
的_
来看, 是不需要用户了解它的.
3.1 抽象类 Transform
的基本信息
class Transform:"""变换的抽象基类, 子类应该实现 one or both of `_call` or `_inverse`.如果 `bijective=True`, 则必须实现 `log_abs_det_jacobian`.Args:cache_size (int): If one, the latest single value is cached.Only 0 and 1 are supported."""bijective = False # Transform 是否双射, 默认 Falsedomain: constraints.Constraint # 有效输入范围codomain: constraints.Constraint # 有效输出范围def __init__(self, cache_size=0):self._cache_size = cache_sizeself._inv = Noneif cache_size == 0:pass # default behaviorelif cache_size == 1:self._cached_x_y = None, Noneelse:raise ValueError("cache_size must be 0 or 1")super().__init__()
果然, Transform
中有 Constraint
的, 分别是 domain
和 codomain
, 用于其检查输入输出是否符合要求. 此外, 还有 bijective
和 cache_size
这两个信息, 等一下看后面怎么说.
3.2 AffineTransform
抽象类的基本信息不多, 还是要看一个简单的例子: AffineTransform
, 线性变换.
class AffineTransform(Transform):bijective = Truedef __init__(self, loc, scale, event_dim=0, cache_size=0):super().__init__(cache_size=cache_size)self.loc = locself.scale = scaleself._event_dim = event_dim
线性变换是可逆的, 可以看到它的 bijective = True
. 参数是 y = l o c + s c a l e × x y = loc + scale × x y = loc + scale × x 中的 loc
和 scale
; event_dim
则是用于构建 domain
和 codomain
:
@constraints.dependent_property(is_discrete=False)
def domain(self):if self.event_dim == 0:return constraints.realreturn constraints.independent(constraints.real, self.event_dim)@constraints.dependent_property(is_discrete=False)
def codomain(self):if self.event_dim == 0:return constraints.realreturn constraints.independent(constraints.real, self.event_dim)
即, domain
和 codomain
被限制为 event_dim
维向量, 默认是 0
, 输入输出皆为标量.
变换过程
def _call(self, x):"""Method to compute forward transformation."""return self.loc + self.scale * xdef _inverse(self, y):"""Method to compute inverse transformation."""return (y - self.loc) / self.scale
由于是双射, 还要实现:
def log_abs_det_jacobian(self, x, y):shape = x.shapescale = self.scaleif isinstance(scale, numbers.Real):result = torch.full_like(x, math.log(abs(scale)))else:result = torch.abs(scale).log()if self.event_dim:result_size = result.size()[: -self.event_dim] + (-1,)result = result.view(result_size).sum(-1)shape = shape[: -self.event_dim]return result.expand(shape)
计算结果的形状
@propertydef inv(self):"""Returns the inverse :class:`Transform` of this transform.This should satisfy ``t.inv.inv is t``."""inv = Noneif self._inv is not None:inv = self._inv()if inv is None:inv = _InverseTransform(self)self._inv = weakref.ref(inv)return invdef __call__(self, x):"""Computes the transform `x => y`."""if self._cache_size == 0:return self._call(x)x_old, y_old = self._cached_x_yif x is x_old:return y_oldy = self._call(x)self._cached_x_y = x, yreturn ydef _inv_call(self, y):"""Inverts the transform `y => x`."""if self._cache_size == 0:return self._inverse(y)x_old, y_old = self._cached_x_yif y is y_old:return x_oldx = self._inverse(y)self._cached_x_y = x, yreturn x
附录
1. __debug__
和 assert
(来自 Kimi)
__debug__
是一个内置变量,用于指示 Python 解释器是否处于调试模式。当 Python 以调试模式运行时,__debug__
被设置为 True
;否则,在优化模式下运行时,它被设置为 False
。
__debug__
可以用于条件性地执行调试代码,例如:
if __debug__:print("Debug mode is on, performing extra checks...")# 这里可以放一些只在调试模式下运行的代码,比如详细的日志记录# 或者复杂的验证逻辑
else:print("Debug mode is off.")
在上面的例子中,如果命令行执行:
python -O myscript.py
##### output #####
Debug mode is off.
------------------------------------------------------
python myscript.py
##### output #####
Debug mode is on, performing extra checks...
assert
语句受 __debug__
影响:
def calculate(a, b):# 这个 assert 在 __debug__ 为 True 时执行assert a > 0 and b > 0, "Both inputs must be positive."# 正常的函数逻辑return a * b# 在这里,assert 会检查输入是否为正数
result = calculate(5, 3)
print(result)# 如果我们改变条件使 assert 失败
# result = calculate(-1, 3) # 这会触发 AssertionError,除非运行时 __debug__ 为 False