欢迎关注我的CSDN:https://spike.blog.csdn.net/
本文地址:https://spike.blog.csdn.net/article/details/134333419
在蛋白质复合物结构预测中,在 TemplatePairEmbedderMultimer 层中 ,构建 Template Pair 特征的源码,即:
- 将特征
template_dgram
、pseudo_beta_mask_2d
、aatype_one_hot
、backbone_mask_2d
、unit_vector(x/y/z)
特征,通过 linear 层累加到一起。 - 其中,都需要使用
multichain_mask_2d
进行固定掩码,选择单链区域。 - 输出维度:
([1, 1102, 1102, 64])
,linear层的输出c_out
维度是 64。
源码如下:
def forward(self,template_dgram: torch.Tensor,aatype_one_hot: torch.Tensor,query_embedding: torch.Tensor,pseudo_beta_mask: torch.Tensor,backbone_mask: torch.Tensor,multichain_mask_2d: torch.Tensor,unit_vector: geometry.Vec3Array,
) -> torch.Tensor:act = 0.0pseudo_beta_mask_2d = (pseudo_beta_mask[..., None] * pseudo_beta_mask[..., None, :])pseudo_beta_mask_2d = pseudo_beta_mask_2d * multichain_mask_2dtemplate_dgram = template_dgram * pseudo_beta_mask_2d[..., None]act += self.dgram_linear(template_dgram)act += self.pseudo_beta_mask_linear(pseudo_beta_mask_2d[..., None])aatype_one_hot = aatype_one_hot.to(template_dgram.dtype)act += self.aatype_linear_1(aatype_one_hot[..., None, :, :])act += self.aatype_linear_2(aatype_one_hot[..., None, :])backbone_mask_2d = backbone_mask[..., None] * backbone_mask[..., None, :]backbone_mask_2d = backbone_mask_2d * multichain_mask_2dx, y, z = [coord * backbone_mask_2d for coord in unit_vector]act += self.x_linear(x[..., None])act += self.y_linear(y[..., None])act += self.z_linear(z[..., None])act += self.backbone_mask_linear(backbone_mask_2d[..., None])query_embedding = self.query_embedding_layer_norm(query_embedding)act += self.query_embedding_linear(query_embedding)return act
template_dgram
特征:
template_dgram
特征与 multichain_mask_2d
:
backbone_mask_2d
特征:
backbone_mask_2d
特征与 multichain_mask_2d
:
写入特征,即:
tmp_dict = dict()
tmp_dict["pseudo_beta_mask_2d_prev"] = pseudo_beta_mask_2d.cpu().numpy()
tmp_dict["pseudo_beta_mask_2d_post"] = pseudo_beta_mask_2d.cpu().numpy()
tmp_dict["template_dgram_post"] = template_dgram.cpu().numpy()
tmp_dict["backbone_mask_2d_prev"] = backbone_mask_2d.cpu().numpy()
tmp_dict["backbone_mask_2d_post"] = backbone_mask_2d.cpu().numpy()import pickle
with open("template_pair_embedder_multimer.pkl", "wb") as f:pickle.dump(tmp_dict, f)
logger.info(f"[CL] saved template_pair_embedder_multimer!")
读取特征,即:
def load_tensor_dict(input_path):"""加载特征文件['template_dgram', 'z', 'pseudo_beta_mask', 'backbone_mask', 'multichain_mask_2d','unit_vector_x', 'unit_vector_y', 'unit_vector_z']"""import picklewith open(input_path, "rb") as f:obj = pickle.load(f)print(f"[Info] feat_dict: {obj.keys()}")return objdef process_template_pair_embedder_multimer_dict(feat_dict, output_dir):print(f"[Info] feat_dict.keys: {feat_dict.keys()}")draw_tensor_2d(feat_dict["pseudo_beta_mask_2d_prev"], os.path.join(output_dir, "pseudo_beta_mask_2d_prev.png"))draw_tensor_2d(feat_dict["pseudo_beta_mask_2d_post"], os.path.join(output_dir, "pseudo_beta_mask_2d_prev.png"))draw_template_dgram(feat_dict["template_dgram_post"], os.path.join(output_dir, "template_dgram_post.png"))draw_tensor_2d(feat_dict["backbone_mask_2d_prev"], os.path.join(output_dir, "backbone_mask_2d_prev.png"))draw_tensor_2d(feat_dict["backbone_mask_2d_post"], os.path.join(output_dir, "backbone_mask_2d_post.png"))def draw_tensor_2d(feat, output_path):"""backbone_mask: torch.Size([1, 1102])"""feat = np.squeeze(feat)f, ax_arr = plt.subplots(1, 1, figsize=(8, 5))im = ax_arr.imshow(feat)f.colorbar(im, ax=ax_arr)plt.savefig(output_path, bbox_inches='tight', format='png')plt.show()def draw_template_dgram(feat, output_path):"""template_dgram: torch.Size([1, 1102, 1102, 39])"""f, ax_arr = plt.subplots(6, 7, figsize=(24, 15))ax_arr = ax_arr.flatten()feat = np.squeeze(feat)print(f"[Info] feat: {feat.shape}")for i in range(0, 42):if i <= 38:im = ax_arr[i].imshow(feat[:, :, i], interpolation='none')f.colorbar(im, ax=ax_arr[i])else:ax_arr[i].set_axis_off()plt.savefig(output_path, bbox_inches='tight', format='png')plt.show()