过去的蛋白质语言模型以单个序列为输入,MSA Transformer以多序列比对的形式将一组序列作为输入。该模型将行和列注意力交织在输入序列中,并在许多蛋白质家族中使用mask语言建模目标进行训练。模型的性能远超过了当时最先进的无监督学习方法,其参数效率远高于当时最先进的蛋白质语言模型。
来自:MSA Transformer, ICML2021
目录
- ESM-MSA-1b概述
- 输入
- 绑定行注意力
- 预训练
- 结果
- 无监督接触图预测
- 有监督接触图预测
ESM-MSA-1b概述
ESM-MSA-1b是一个在UR50上训练的掩码语言模型,ESM-MSA-1b在三种下游任务:无监督氨基酸接触图预测、有监督氨基酸接触图预测和二级结构预测上均达到了当时的SOTA水平。
模型采用随机mask,其中,Transformer做了改进,使内存占用降低。模型与普通Transformer的区别是,其特殊的轴向注意力机制,见图1所示。对于蛋白质通用规律的学习,MSA信息中同源序列不同位置上的氨基酸对于当前的氨基酸的权重信息并不大,当前氨基酸最重要的关注点还是同一序列其他氨基酸(行)和同源不同序列同一位置的其他氨基酸(列),因此注意力限制在横纵这两条轴向范围就可以了,可以大大降低时间复杂度。
- 图1左,注意力稀疏结构。通过将注意力限制在行和列上,计算成本从 O ( ( L M ) 2 ) O((LM)^{2}) O((LM)2)降低到 O ( L M 2 ) + O ( L 2 M ) O(LM^{2})+O(L^{2}M) O(LM2)+O(L2M),其中 M M M是MSA中的行数, L L L是列数。
- 中间:未绑定的行注意力对MSA中的每个序列使用不同的注意力。绑定行注意力对MSA中的所有序列使用单个注意力图,从而约束了接触结构。
- 右:一个MSA Transformer块。所描绘的架构来自最终模型。
输入
Transformer是强大的序列模型,能够将信息从任何位置传递到任何其他位置。然而,它们并不适用于一组对齐的序列。在MSA中简单地连接长度为 L L L的 M M M个序列将允许跨所有序列的注意力,但 ( M L ) 2 (ML)^{2} (ML)2的自注意力map将占用大量内存。MSA Transformer主要贡献是将Transformer预训练扩展到在MSA上运行,同时将其结构视为 M × L M\times L M×L特征矩阵。
作者将输入MSA描述为矩阵 x ∈ R M × L x\in\mathbb{R}^{M\times L} x∈RM×L,其中行对应MSA中的序列,列对应对齐序列中的位置,条目 x m i x_{mi} xmi取整数值1,编码序列 m m m在位置 i i i处的氨基酸同一性(与参考氨基酸相同)。编码输入后,每一层都有一个 R M × L × d \R^{M\times L\times d} RM×L×d状态作为输入和输出。对于Transformer的核心,作者采用了Ho等人(2019)和Child等人(2019年)的轴向注意力方法。这种方法将注意力交替放在2D状态的行和列上(见图1)。MSA上注意力的这种稀疏模式使列注意力的注意力成本为 O ( L M 2 ) O(LM^2) O(LM2),行注意力的注意力为 O ( M L 2 ) O(ML^2) O(ML2)。
对于token embedding,将不同氨基酸用整数表示,形成一个整型向量。词库包括20种标准氨基酸、5种非标准氨基酸和4种特殊字符,共29种氨基酸token。
对于position embedding,标准Transformer位置嵌入是添加到序列中每个位置的1D信号。最常用的是固定正弦或可学习位置嵌入。Rives等人发现,学习位置嵌入通常会使蛋白质语言模型的下游性能更好。MSA是一个2D输入,因此必须考虑两种类型的位置嵌入。对于所有训练过的模型,作者提供了一个1D序列位置嵌入,它独立地添加到MSA的每一行。这使得模型能够区分不同的对齐位置。对于一个模型,作者还为MSA的每一列独立添加了一个位置嵌入,这允许模型区分不同的序列(如果没有这个,模型会将输入序列视为一个无序集)。作者还确保序列中的第一个位置始终是参考,这样就可以通过位置嵌入来唯一识别它。最后发现,引入列位置嵌入会略微提高性能。
绑定行注意力
轴向注意力的标准实现允许输入的每一行和每一列都有独立的注意力图。然而,在MSA中,每个序列都通常应该具有相似的结构。为了利用这种共享结构,作者假设将MSA中的序列之间的行注意力图绑定起来是有益的。另一个好处是,绑定注意力将行注意力的内存占用从 O ( M L 2 ) O(ML^2) O(ML2)减少到 O ( L 2 ) O(L^2) O(L2)。
预训练
模型在2600万MSA的数据集上进行训练。通过使用HHblits搜索UniClust30,为每个UniRef50序列生成MSA。MSA的平均深度为1192。
模型使用masked training方法进行训练,直接输出的是每一个masked token处为各种氨基酸的概率。然而主要目标并不是得到这些概率,而是要通过训练后的attention map预测蛋白质二、三级结构。
- 二级结构:基于MSA Transformer的特征表示向量(representation)预测8种折叠,准确率为72.9%。
- 三级结构:基于MSA Transformer各层、各注意力头的attention map,训练logistic回归模型,对蛋白质三级结构进行预测。
结果
无监督接触图预测
Rao等人表明,蛋白语言模型在没有监督的情况下能够捕捉蛋白质结构信息。可以通过在有限数量的蛋白质结构上训练一个小的逻辑回归来实现,同时可以根据注意力头的残基之间的注意力来预测残基 i i i和 j j j之间接触的概率。
使用相同的验证方法。对trRosetta数据集中的20个训练结构进行逻辑回归拟合。然后,这被用来预测trRosetta数据集中另外14842个结构上的蛋白质接触概率(不包括训练结构)。本质其实还是有监督。
有监督接触图预测
作者将MSA Transformer作为监督结构预测管道的一个组件进行评估。根据Rives等人的研究,使用0.001的学习率训练了一个具有32个激活块的深度残差网络。使用15051个MSA和结构的trRosetta训练集对网络进行分箱成对距离分布(distogram)的监督(给接触距离做了更细粒度的分箱标签)。输入残差网络的特征是,查询序列中,两个氨基酸embedding的concat。