AlphaFold3 rigid_utils 模块Rotation类的
map_tensor_fn方法
主要作用是对旋转矩阵或四元数上的最后一维应用一个函数 (fn
) ,并返回一个新的 Rotation
对象。
源代码:
def map_tensor_fn(self, fn: Callable[torch.Tensor, torch.Tensor]) -> Rotation:"""Apply a Tensor -> Tensor function to underlying rotation tensors,mapping over the rotation dimension(s). Can be used e.g. to sum outa one-hot batch dimension.Args:fn:A Tensor -> Tensor function to be mapped over the Rotation Returns:The transformed Rotation object""" if(self._rot_mats is not None):rot_mats = self._rot_mats.view(self._rot_mats.shape[:-2] + (9,))rot_mats = torch.stack(list(map(fn, torch.unbind(rot_mats, dim=-1))), dim=-1)rot_mats = rot_mats.view(rot_mats.shape[:-1] + (3, 3))return Rotation(rot_mats=rot_mats, quats=None)elif(self._quats is not None):quats = torch.stack(list(map(fn, torch.unbind(self._quats, dim=-1))), dim=-1)return Rotation(rot_mats=None, quats=quats, normalize_quats=False)else:raise ValueError("Both rotations are None")
代码解读:
方法签名
def map_tensor_fn(self, fn: Callable[torch.Tensor, torch.Tensor]) -> Rotation:
-
fn
:接收一个Tensor
,返回一个Tensor
,典型用途是对旋转的某个维度做变换,比如求和、加权平均等。 -
返回值:一个新的
Rotation
对象,里面装着变换后的旋转矩阵 (rot_mats
) 或四元数 (quats
)。
处理旋转矩阵 (_rot_mats
)
如果 self._rot_mats
存在,就走这条分支:
if self._rot_mats is not None:# 把 (batch_size, ..., 3, 3) reshape 成 (batch_size, ..., 9)rot_mats = self._rot_mats.view(self._rot_mats.shape[:-2] + (9,))
✅ 解释:
view()
是为了把 3x3
的旋转矩阵摊平成 9 维向量,方便对最后一维应用函数。
rot_mats = torch.stack(list(map(fn, torch.unbind(rot_mats, dim=-1))), dim=-1
)
✅ 解释:
-
torch.unbind()
:沿最后一维解开成 9 个独立的张量。 -
map(fn, ...)
:对每个解开的张量应用fn
。 -
torch.stack()
:把变换后的 9 个张量重新堆叠回去。
注: torch.unbind 维度 -1 ,torch.stack 维度 +1, 并且都处理相同的维度(-1)。
rot_mats = rot_mats.view(rot_mats.shape[:-1] + (3, 3))
return Rotation(rot_mats=rot_mats, quats=None)
✅ 解释:
把 9 维向量重新 reshaped 成 (3, 3)
矩阵,并用它创建一个新的 Rotation
对象。
处理四元数 (_quats
)
如果矩阵不存在,走四元数分支:
elif self._quats is not None:quats = torch.stack(list(map(fn, torch.unbind(self._quats, dim=-1))), dim=-1)return Rotation(rot_mats=None, quats=quats, normalize_quats=False)
✅ 解释:
-
逻辑和矩阵类似,先
unbind()
分解四元数的最后一维,对每个部分应用fn()
,再stack()
堆叠回来。 -
创建新
Rotation
对象时加了normalize_quats=False
,说明这一步不需要再归一化。
防错处理
如果两个旋转表示都没有,抛出异常:
else:raise ValueError("Both rotations are None")
总结
map_tensor_fn()
是一种 高阶函数,它能灵活地对旋转矩阵或四元数的最后一维执行各种操作(比如求和、加权、归一化、剪裁等)。
核心逻辑:
-
矩阵路径 → reshape(9维) → 分解 → 应用函数 → 堆叠 → 恢复3x3
-
四元数路径 → 分解 → 应用函数 → 堆叠