Follow my CSDN blog: https://spike.blog.csdn.net/
Article URL: https://spike.blog.csdn.net/article/details/132597659
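OpenFold Multimer is a deep-learning method for predicting the structures and interactions of protein multimers. It learns protein representations and features from large-scale protein sequence and structure data using an advanced neural-network architecture. It can handle different types of multimers, including homo- and hetero-multimers, as well as complex protein-protein interaction networks. The goal of OpenFold Multimer is to give biologists a fast, accurate, and easy-to-use tool for exploring the functions and mechanisms of protein multimers.
Training command: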
python3 train_openfold.py \
    --train_data_dir [your folder]/af2-data-v230/pdb_mmcif/mmcif_files/ \
    --train_alignment_dir mydata/alignment_dir/ \
    --train_mmcif_data_cache_path mmcif_cache.json \
    --template_mmcif_dir [your folder]/af2-data-v230/pdb_mmcif/mmcif_files/ \
    --output_dir mydata/output_dir/ \
    --max_template_date "2021-10-10" \
    --config_preset "model_1_multimer_v3" \
    --template_release_dates_cache_path mmcif_cache.json \
    --precision bf16 \
    --gpus 1 \
    --replace_sampler_ddp=True \
    --seed 42 \
    --deepspeed_config_path deepspeed_config.json \
    --checkpoint_every_epoch \
    --train_chain_data_cache_path chain_data_cache.json \
    --obsolete_pdbs_file_path [your folder]/af2-data-v230/pdb_mmcif/obsolete.dat
1. train_alignment_dir
The key parameter is train_alignment_dir, which points to the cached, preprocessed features. The call path is as follows:
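- The args of train_openfold.py are passed into the OpenFoldMultimerDataModule class.
- They are then received by the dataset_gen() method, which is effectively the OpenFoldSingleMultimerDataset class.
- The parameter is renamed via alignment_dir=self.train_alignment_dir, i.e. train_alignment_dir becomes alignment_dir.
- Finally, the OpenFoldMultimerDataModule class instantiates the OpenFoldSingleMultimerDataset class.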
That is:
# train_openfold.py
# ...
if "multimer" in args.config_preset:
    data_module = OpenFoldMultimerDataModule(
        config=config.data,
        batch_seed=args.seed,
        **vars(args),
    )
# ...

# openfold/data/data_modules.py#OpenFoldMultimerDataModule
# ...
if self.training_mode:
    train_dataset = dataset_gen(
        data_dir=self.train_data_dir,
        mmcif_data_cache_path=self.train_mmcif_data_cache_path,
        alignment_dir=self.train_alignment_dir,
        filter_path=self.train_filter_path,
        max_template_hits=self.config.train.max_template_hits,
        shuffle_top_k_prefiltered=self.config.train.shuffle_top_k_prefiltered,
        treat_pdb_as_distillation=False,
        mode="train",
        alignment_index=self.alignment_index,
    )
# ...
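To make the call path concrete, here is a simplified sketch. The classes ToyMultimerDataModule and ToySingleMultimerDataset are hypothetical stand-ins, and building dataset_gen with functools.partial is an assumption about how shared arguments could be pre-bound; only the renaming of train_alignment_dir to alignment_dir follows the code above.

```python
import argparse
from functools import partial


class ToySingleMultimerDataset:
    """Hypothetical stand-in for OpenFoldSingleMultimerDataset: it only stores its inputs."""

    def __init__(self, data_dir, alignment_dir, mode, template_mmcif_dir=None):
        self.data_dir = data_dir
        self.alignment_dir = alignment_dir
        self.mode = mode
        self.template_mmcif_dir = template_mmcif_dir


class ToyMultimerDataModule:
    """Hypothetical stand-in for OpenFoldMultimerDataModule."""

    def __init__(self, train_data_dir, train_alignment_dir, template_mmcif_dir, **kwargs):
        # Command-line flags arrive as keyword arguments, as with **vars(args).
        self.train_data_dir = train_data_dir
        self.train_alignment_dir = train_alignment_dir
        self.template_mmcif_dir = template_mmcif_dir

    def setup(self):
        # Assumption: dataset_gen is the dataset class with shared arguments pre-bound.
        dataset_gen = partial(ToySingleMultimerDataset, template_mmcif_dir=self.template_mmcif_dir)
        # train_alignment_dir is renamed to alignment_dir at this call.
        self.train_dataset = dataset_gen(
            data_dir=self.train_data_dir,
            alignment_dir=self.train_alignment_dir,
            mode="train",
        )
        return self.train_dataset


args = argparse.Namespace(
    train_data_dir="mmcif_files/",
    train_alignment_dir="mydata/alignment_dir/",
    template_mmcif_dir="mmcif_files/",
    seed=42,  # extra flags are simply absorbed by **kwargs
)
ds = ToyMultimerDataModule(**vars(args)).setup()
print(ds.alignment_dir)  # mydata/alignment_dir/
```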
In the OpenFoldSingleMultimerDataset class, alignment_dir is used to populate _chain_ids:
if alignment_index is not None:
    self._chain_ids = list(alignment_index.keys())
else:
    self._chain_ids = list(os.listdir(alignment_dir))
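A small sketch of this branch, assuming a hypothetical layout in which alignment_dir holds one sub-directory per chain named <pdb_id>_<chain_id>; the helper collect_chain_ids is illustrative, not part of OpenFold. Because os.listdir returns entries in arbitrary order, the result is sorted before printing. Note also that every entry in the directory, including stray files, would be treated as a chain ID.

```python
import os
import tempfile


def collect_chain_ids(alignment_dir, alignment_index=None):
    """Hypothetical mirror of the branch above: prefer the index keys, else list the directory."""
    if alignment_index is not None:
        return list(alignment_index.keys())
    return list(os.listdir(alignment_dir))


with tempfile.TemporaryDirectory() as tmp:
    # Hypothetical layout: one alignment sub-directory per chain.
    for chain in ("4ewn_D", "5m9r_A", "5m9r_B"):
        os.makedirs(os.path.join(tmp, chain))
    from_dir = sorted(collect_chain_ids(tmp))

# With an index, the directory is never scanned, so the path does not even need to exist.
fake_index = {"4ewn_D": {}, "5m9r_A": {}}
from_index = collect_chain_ids("/nonexistent", alignment_index=fake_index)

print(from_dir)    # ['4ewn_D', '5m9r_A', '5m9r_B']
print(from_index)  # ['4ewn_D', '5m9r_A']
```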
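alignment_index_path can optionally be passed in as a parameter; it is empty by default. The relevant description is quoted below. The core idea is to first compile the alignments into a single file and then read from it, which can improve efficiency: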
In cases where it may be burdensome to create separate files for each chain’s alignments, alignment directories can be consolidated using the scripts in scripts/alignment_db_scripts/. First, run create_alignment_db.py to consolidate an alignment directory into a pair of database and index files. Once all alignment directories (or shards of a single alignment directory) have been compiled, unify the indices with unify_alignment_db_indices.py. The resulting index, super.index, can be passed to the training script flags containing the phrase alignment_index. In this scenario, the alignment_dir flags instead represent the directory containing the compiled alignment databases. Both the training and distillation datasets can be compiled in this way. Anecdotally, this can speed up training in I/O-bottlenecked environments.
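The idea behind the consolidated database can be sketched as follows: concatenate many small alignment files into one binary file, record a byte offset and size for each, and later seek directly to the needed slice. The index shape {chain_id: {"db": ..., "files": [[file_name, start, size], ...]}} and the helper names build_alignment_db and read_alignment are assumptions for illustration; this is not the actual code of create_alignment_db.py.

```python
import json
import os
import tempfile


def build_alignment_db(alignments, db_path):
    """Hypothetical: write all alignment texts into one file and return an offset index.

    alignments: {chain_id: {file_name: text}}
    """
    index = {}
    offset = 0
    with open(db_path, "wb") as db:
        for chain_id, files in alignments.items():
            entries = []
            for file_name, text in files.items():
                data = text.encode("utf-8")
                db.write(data)
                entries.append([file_name, offset, len(data)])
                offset += len(data)
            index[chain_id] = {"db": os.path.basename(db_path), "files": entries}
    return index


def read_alignment(alignment_dir, index, chain_id, file_name):
    """Seek into the single .db file instead of opening a separate per-chain file."""
    entry = index[chain_id]
    with open(os.path.join(alignment_dir, entry["db"]), "rb") as db:
        for name, start, size in entry["files"]:
            if name == file_name:
                db.seek(start)
                return db.read(size).decode("utf-8")
    raise KeyError(f"{file_name} not found for chain {chain_id}")


with tempfile.TemporaryDirectory() as tmp:
    # Toy alignment texts; not real MSAs.
    toy = {
        "4ewn_D": {"uniref90_hits.a3m": ">4ewn_D\nMLAKRI\n"},
        "5m9r_A": {"uniref90_hits.a3m": ">5m9r_A\nMQDNS\n"},
    }
    index = build_alignment_db(toy, os.path.join(tmp, "toy.db"))
    # The index is plain JSON, analogous in spirit to the super.index file described above.
    index = json.loads(json.dumps(index))
    msa = read_alignment(tmp, index, "5m9r_A", "uniref90_hits.a3m")
    print(msa)
```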
Here, self._chain_ids covers the entire training set:
def __len__(self):
    return len(self._chain_ids)
Set up the logger:
import logging
logging.basicConfig()
logger = logging.getLogger(__file__)
logger.setLevel(level=logging.INFO)
The lookup performed when iterating over the training data:
def __getitem__(self, idx):
    mmcif_id = self.idx_to_mmcif_id(idx)
    chains = self.mmcif_data_cache[mmcif_id]['chain_ids']
    # ...
Based on the printed output, the training data is organized as follows:
mmcif_id is: 5ykn, idx: 8580 and has 1 chains
mmcif_id is: 2lna, idx: 3848 and has 1 chains
mmcif_id is: 7rrp, idx: 8447 and has 24 chains
mmcif_id is: 6k8h, idx: 7870 and has 2 chains
...
2. OpenFoldSingleMultimerDataset
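Now let's analyze the OpenFoldSingleMultimerDataset class in detail. The __getitem__ method iterates over the training samples; the key points are:
- The self.idx_to_mmcif_id() function looks up self._mmcifs[idx].
- There are two key variables, self._mmcifs and self.mmcif_data_cache, and their keys must stay consistent.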
That is:
def __getitem__(self, idx):
    mmcif_id = self.idx_to_mmcif_id(idx)
    chains = self.mmcif_data_cache[mmcif_id]['chain_ids']
    print(f"mmcif_id is: {mmcif_id}, idx: {idx} and has {len(chains)} chains")
    # ...
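The consistency requirement between self._mmcifs and self.mmcif_data_cache can be illustrated with a stripped-down, hypothetical class (ToyMultimerDataset is not part of OpenFold). Because _mmcifs is derived from the cache keys, every index resolves to a cache entry; if an ID ends up in _mmcifs without a matching cache entry, __getitem__ fails with a KeyError.

```python
class ToyMultimerDataset:
    """Hypothetical mirror of the indexing logic in OpenFoldSingleMultimerDataset."""

    def __init__(self, mmcif_data_cache):
        self.mmcif_data_cache = mmcif_data_cache
        # Built from the cache keys, so both start out consistent.
        self._mmcifs = list(self.mmcif_data_cache.keys())

    def idx_to_mmcif_id(self, idx):
        return self._mmcifs[idx]

    def __getitem__(self, idx):
        mmcif_id = self.idx_to_mmcif_id(idx)
        chains = self.mmcif_data_cache[mmcif_id]["chain_ids"]
        print(f"mmcif_id is: {mmcif_id}, idx: {idx} and has {len(chains)} chains")
        return mmcif_id, chains


cache = {"4ewn": {"chain_ids": ["D"]}, "5m9r": {"chain_ids": ["A", "B"]}}
ds = ToyMultimerDataset(cache)
ds[1]  # prints: mmcif_id is: 5m9r, idx: 1 and has 2 chains

# Break the invariant: an ID in _mmcifs that the cache does not contain.
ds._mmcifs.append("7rrp")
try:
    ds[2]
except KeyError as err:
    print("missing from mmcif_data_cache:", err)  # missing from mmcif_data_cache: '7rrp'
```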
The self._mmcifs data is derived along the chain mmcif_data_cache_path -> self.mmcif_data_cache -> self._mmcifs, where the file at mmcif_data_cache_path is produced by the preprocessing step. That is:
# ...
logger.info(f"[CL] mmcif_data_cache_path: {mmcif_data_cache_path}")
if mmcif_data_cache_path is not None:
    with open(mmcif_data_cache_path, "r") as infile:
        self.mmcif_data_cache = json.load(infile)
    assert isinstance(self.mmcif_data_cache, dict)
# ...
if self.mmcif_data_cache is not None:
    self._mmcifs = list(self.mmcif_data_cache.keys())
    self._mmcif_id_to_idx_dict = {
        mmcif: i for i, mmcif in enumerate(self._mmcifs)
    }
The mmcif_cache.json file stores per-entry PDB metadata, for example:
{
  "4ewn": {
    "release_date": "2012-12-05",
    "chain_ids": ["D"],
    "seqs": ["MLAKRI..."],
    "no_chains": 1,
    "resolution": 1.9
  },
  "5m9r": {
    "release_date": "2017-02-22",
    "chain_ids": ["A", "B"],
    "seqs": ["MQDNS...", "MQDNS..."],
    "no_chains": 2,
    "resolution": 1.44
  },
# ...
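A self-contained sketch of the loading steps above: write the two sample records to a temporary mmcif_cache.json, read it back, and derive the ID list and the reverse lookup. The helper load_mmcif_cache is hypothetical, the sequences are the truncated placeholders from the example, and the check that no_chains matches the lengths of chain_ids and seqs is an assumption suggested by the two sample records, not a documented invariant.

```python
import json
import os
import tempfile

# The two sample records from above; the sequences are truncated placeholders.
example_cache = {
    "4ewn": {"release_date": "2012-12-05", "chain_ids": ["D"], "seqs": ["MLAKRI..."],
             "no_chains": 1, "resolution": 1.9},
    "5m9r": {"release_date": "2017-02-22", "chain_ids": ["A", "B"], "seqs": ["MQDNS...", "MQDNS..."],
             "no_chains": 2, "resolution": 1.44},
}


def load_mmcif_cache(mmcif_data_cache_path):
    """Hypothetical helper following the same steps as the dataset code above."""
    with open(mmcif_data_cache_path, "r") as infile:
        mmcif_data_cache = json.load(infile)
    assert isinstance(mmcif_data_cache, dict)
    mmcifs = list(mmcif_data_cache.keys())  # json.load keeps the file's key order
    mmcif_id_to_idx = {mmcif: i for i, mmcif in enumerate(mmcifs)}
    return mmcif_data_cache, mmcifs, mmcif_id_to_idx


with tempfile.TemporaryDirectory() as tmp:
    path = os.path.join(tmp, "mmcif_cache.json")
    with open(path, "w") as outfile:
        json.dump(example_cache, outfile)
    cache, mmcifs, id_to_idx = load_mmcif_cache(path)

# Assumed sanity check based on the sample records.
for mmcif_id, entry in cache.items():
    assert entry["no_chains"] == len(entry["chain_ids"]) == len(entry["seqs"]), mmcif_id

print(mmcifs, id_to_idx)  # ['4ewn', '5m9r'] {'4ewn': 0, '5m9r': 1}
```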
BugFix: add the train_mmcif_data_cache_path argument:
--train_mmcif_data_cache_path mmcif_cache.json
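Judging from the dataset code above, self._mmcifs is only built when mmcif_data_cache_path is provided, so omitting this argument leaves idx_to_mmcif_id() with nothing to index. A fail-fast check can surface the problem before training starts. The function check_multimer_cache_args below is a hypothetical helper, not part of OpenFold, and it assumes the argparse attribute names match the flags used in the command above.

```python
import argparse
import os


def check_multimer_cache_args(args):
    """Hypothetical pre-flight check: fail early if the mmCIF cache needed by
    the multimer dataset was not supplied or does not exist."""
    if "multimer" not in args.config_preset:
        return
    path = getattr(args, "train_mmcif_data_cache_path", None)
    if path is None:
        raise ValueError(
            "multimer presets need --train_mmcif_data_cache_path (e.g. mmcif_cache.json)"
        )
    if not os.path.isfile(path):
        raise FileNotFoundError(f"mmCIF cache not found: {path}")


# Reproduce the failure mode: a multimer preset without the cache argument.
args = argparse.Namespace(config_preset="model_1_multimer_v3", train_mmcif_data_cache_path=None)
caught = None
try:
    check_multimer_cache_args(args)
except ValueError as err:
    caught = err
    print(err)
```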