AlphaFold3 data_transforms 模块的
squeeze_features 函数的作用去除 蛋白质特征张量中不必要的单维度(singleton dimensions)和重复维度,以使其适配 AlphaFold3 预期的输入格式。
源代码:
def squeeze_features(protein):"""Remove singleton and repeated dimensions in protein features."""protein["aatype"] = torch.argmax(protein["aatype"], dim=-1)for k in ["domain_name","msa","num_alignments","seq_length","sequence","superfamily","deletion_matrix","resolution","between_segment_residues","residue_index","template_all_atom_mask",]:if k in protein:final_dim = protein[k].shape[-1]if isinstance(final_dim, int) and final_dim == 1:if torch.is_tensor(protein[k]):protein[k] = torch.squeeze(protein[k], dim=-1)else:protein[k] = np.squeeze(protein[k], axis=-1)for k in ["seq_length", "num_alignments"]:if k in protein:protein[k] = protein[k][0]return protein
源码解读:
- 该函数接收
protein
(一个 包含蛋白质特征的字典)作为输入。 - 主要任务:
- 将 one-hot
aatype
转换为索引表示。 - 移除 shape 为
(N, ..., 1)
的单维度。 - 提取
seq_length
和num_alignments
的实际数值。
- 将 one-hot
Step 1: 处理 aatype
protein["aatype"] = torch.argmax(protein["aatype"], dim=-1)
- 输入
aatype
(氨基酸类型)通常是 one-hot 编码 - 通过
torch.argmax(..., dim=-1)
获取 索引 - 目的:简化
aatype
的数据表示,使其直接存储氨基酸索引,而不是 one-hot 矩阵。
Step 2: 移除单维度
for k in ["domain_name","msa","num_alignments","seq_length","sequence","superfamily","deletion_matrix","resolution","between_segment_residues","residue_index","template_all_atom_mask",
]:if k in protein:final_dim = protein[k].shape[-1] # 获取最后一维的大小if isinstance(final_dim, int) and final_dim == 1:if torch.is_tensor(protein[k]):protein[k] = torch.squeeze(protein[k], dim=-1) # 去掉单维度else:protein[k] = np.squeeze(protein[k], axis=-1)
- 遍历多个
protein
特征字段,检查它们是否存在。 - 如果最后一维
final_dim
为1
,说明这个维度是无意义的单维度,需要去除:- 如果是 PyTorch 张量(
torch.Tensor
),使用torch.squeeze(dim=-1)
。 - 如果是 NumPy 数组,使用
np.squeeze(axis=-1)
。
- 如果是 PyTorch 张量(
Step 3: 处理 seq_length
和 num_alignments
for k in ["seq_length", "num_alignments"]:if k in protein:protein[k] = protein[k][0]
seq_length
和 num_alignments
可能是 列表或张量,但它们的数值其实是一个单独的整数,因此需要转换成 标量值。
结论
1️⃣ 转换 aatype
: 从 one-hot 编码 转换成 索引表示。
2️⃣ 移除无用的单维度: 让 msa
, resolution
, deletion_matrix
等数据符合 AlphaFold3 预期格式。
3️⃣ 转换 seq_length
和 num_alignments
为标量: 确保它们不会以张量形式存在,而是整数。
💡 最终作用:保证输入数据的维度符合 AlphaFold3 训练时的输入要求,提高数据处理效率。