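The sample_step method of the Sampler class produces the inputs for timestep t-1. Starting from the diffusion model's prediction at timestep t, it generates the next 3D structure, the sequence, and related features such as the torsion angles and the predicted lDDT listed in its docstring. It is one of the core steps of the diffusion sampling process.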
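The pattern suggested by sample_step's outputs (first px0, the model's prediction of the clean structure x0, then x_t_1) is the "x0-prediction" form of reverse diffusion. The minimal sketch below illustrates it on coordinates: at each timestep a model predicts px0, and x_{t-1} is drawn from the standard DDPM posterior q(x_{t-1} | x_t, x0 = px0). It is not RFdiffusion's implementation. The linear beta schedule, the helper names (linear_schedule, posterior_step, sample) and the "perfect" toy denoiser are all hypothetical. RFdiffusion's step also updates residue orientations, which this sketch omits.

```python
import numpy as np

def linear_schedule(T, beta_start=1e-4, beta_end=0.02):
    # Assumed linear beta schedule; returns betas and their cumulative products alpha_bar.
    betas = np.linspace(beta_start, beta_end, T)
    alpha_bars = np.cumprod(1.0 - betas)
    return betas, alpha_bars

def posterior_step(x_t, px0, t, betas, alpha_bars, rng):
    """Draw x_{t-1} from q(x_{t-1} | x_t, x0=px0); t is 1-indexed (t = T, ..., 1)."""
    beta_t = betas[t - 1]
    abar_t = alpha_bars[t - 1]
    abar_prev = alpha_bars[t - 2] if t > 1 else 1.0  # alpha_bar_0 is defined as 1
    # DDPM posterior mean: a weighted mix of the predicted x0 and the current x_t.
    coef_x0 = np.sqrt(abar_prev) * beta_t / (1.0 - abar_t)
    coef_xt = np.sqrt(1.0 - beta_t) * (1.0 - abar_prev) / (1.0 - abar_t)
    mean = coef_x0 * px0 + coef_xt * x_t
    if t == 1:
        return mean  # last step: return the mean, no noise added
    var = beta_t * (1.0 - abar_prev) / (1.0 - abar_t)
    return mean + np.sqrt(var) * rng.standard_normal(x_t.shape)

def sample(model, x_T, T, seed=0):
    # Reverse loop t = T, ..., 1; each iteration plays the role of one sample_step call.
    rng = np.random.default_rng(seed)
    betas, alpha_bars = linear_schedule(T)
    x_t = x_T
    for t in range(T, 0, -1):
        px0 = model(x_t, t)  # the model's prediction of the clean structure
        x_t = posterior_step(x_t, px0, t, betas, alpha_bars, rng)
    return x_t

# Hypothetical "perfect" denoiser that always predicts the same toy CA trace.
target = np.array([[0.0, 0.0, 0.0], [3.8, 0.0, 0.0], [7.6, 0.0, 0.0]])

def toy_model(x_t, t):
    return target

x_T = 10.0 * np.random.default_rng(1).standard_normal(target.shape)
x_0 = sample(toy_model, x_T, T=50)
```

At t = 1 the posterior mean's weight on x_t becomes zero and its weight on px0 becomes one, so with this toy denoiser the loop ends on the predicted structure. At earlier steps the posterior variance re-injects noise, which keeps sampling stochastic.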
Source code:
```python
def sample_step(self, *, t, x_t, seq_init, final_step):
    '''Generate the next pose that the model should be supplied at timestep t-1.

    Args:
        t (int): The timestep that has just been predicted
        seq_t (torch.tensor): (L,22) The sequence at the beginning of this timestep
        x_t (torch.tensor): (L,14,3) The residue positions at the beginning of this timestep
        seq_init (torch.tensor): (L,22) The initialized sequence used in updating the sequence.

    Returns:
        px0: (L,14,3) The model's prediction of x0.
        x_t_1: (L,14,3) The updated positions of the next step.
        seq_t_1: (L,22) The updated sequence of the next step.
        tors_t_1: (L, ?) The updated torsion angles of the next step.
        plddt: (L, 1) Predicted lDDT of x0.
    '''
    msa_masked, msa_full, seq_in, xt_in, idx_pdb, t1d, t2d, xyz_t, alpha_t = self._preprocess(
        seq_init, x_t, t)

    N,L = msa_masked.shape[:2]

    if self.symmetry is not None:
        idx_pdb, self.chain_idx = self.symmetry.res_idx_procesing(res_idx=idx_pdb)

    msa_prev = None
    pair_prev = None
    state_prev = None

    with torch.no_grad():
        msa_prev, pair_prev, px0, state_prev, alpha, logits, plddt = self.model(msa_masked,
                            msa_full,
                            seq_in,
                            xt_in,
```
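The docstring describes x_t and px0 as (L,14,3) arrays, meaning 14 atom slots per residue. RFdiffusion treats each residue as a rigid frame (a rotation plus the CA position) built from its backbone atoms. The sketch below derives such frames from an (L,14,3) array using the Gram-Schmidt construction from AlphaFold2's rigidFrom3Points. It assumes the first three atom slots hold N, CA and C, and backbone_frames is a hypothetical helper, not a function from the code above.

```python
import numpy as np

def backbone_frames(xyz):
    """Build per-residue frames from an (L, 14, 3) atom array.

    Assumption: atom slots 0, 1, 2 hold N, CA, C.
    Returns rotations R with shape (L, 3, 3) and translations (CA positions) with shape (L, 3).
    """
    n, ca, c = xyz[:, 0], xyz[:, 1], xyz[:, 2]
    # First axis: unit vector from CA towards C.
    e1 = c - ca
    e1 = e1 / np.linalg.norm(e1, axis=-1, keepdims=True)
    # Second axis: CA->N with its component along e1 removed (Gram-Schmidt).
    u2 = n - ca
    u2 = u2 - np.sum(u2 * e1, axis=-1, keepdims=True) * e1
    e2 = u2 / np.linalg.norm(u2, axis=-1, keepdims=True)
    # Third axis completes a right-handed orthonormal basis.
    e3 = np.cross(e1, e2)
    R = np.stack([e1, e2, e3], axis=-1)  # columns of each 3x3 matrix are the frame axes
    return R, ca

# Two toy residues whose backbones lie in the xy-plane.
xyz = np.zeros((2, 14, 3))
xyz[:, 0] = [[-0.5, 1.4, 0.0], [3.3, 1.4, 0.0]]  # N
xyz[:, 1] = [[0.0, 0.0, 0.0], [3.8, 0.0, 0.0]]   # CA
xyz[:, 2] = [[1.5, 0.0, 0.0], [5.3, 0.0, 0.0]]   # C
R, trans = backbone_frames(xyz)
```

Here both residues have CA->C along +x and N in the +y half-plane, so each rotation is the identity and the translations are the CA coordinates. A diffusion step on such a representation moves the translations and rotations, after which the full atom coordinates can be rebuilt from the frames.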