论文阅读与分析:Few-Shot Graph Learning for Molecular Property Prediction
- 论文地址和代码地址
- 1 摘要
- 2 主要贡献
- 3 基础知识
- Meta Learning
- 1 介绍
- 2 学习算法
- Step 1: What is learnable in a learning algorithm?
- Step 2:Define loss function for learning algorithm F ϕ F_{\phi} Fϕ
- Step 3:Optimazation
- 框架总流程
- 3 ML和Meta的区别
- 二者目标
- 训练数据
- 框架
- MAML
- 4 细节问题
- 问题定义
- 分子图神经网络
- 5 META-MGNN
- 模型框架图
- Meta-learning Setup
- 元训练
- 梯度下降
- 损失函数(自监督模块)
- 任务感知注意力模块
- 元测试
- 6 实验
- 数据集
- 模型比较结果
论文地址和代码地址
论文地址:https://arxiv.org/pdf/2102.07916
代码地址:https://github.com/zhichunguo/Meta-MGNN
1 摘要
图神经网络最近的成功显着促进了分子特性预测, 推进了药物发现等活动。现有的深度神经网络方法 通常需要每个属性都需要大量的训练数据集,在实验数据有限的情况下(特别是新的分子属性)会损害其性能,这在现实情况中很常见。为此,我们提出了 Meta-MGNN,一种用于小样本分子特性预测的新模型。 Meta-MGNN应用分子图神经网络来学习分子表示,并构建用于模型优化的元学习框架。 为了利用未标记的分子信息并解决不同分子特性的 任务异质性,Meta-MGNN进一步结合了分子结构、基于属性的自学习-将监督模块和自注意力任务权重纳入到前一个框架中,强化了整个学习模型。 对两个公共多属性数据集的大量实验表明,MetaMGNN 优于各种最先进的方法。
2 主要贡献
(1)三大挑战
- 第一个挑战是设计一个深度神经网络,可以从少量数据中发现有效的分子表示。
- 由于只有有限数量的标记分子属性数据可用,第二个挑战是利用分子数据中有用的未标记信息,并进一步开发有效的学习程序来转移其他属性预测的知识,
- 以便模型可以用有限的数据快速适应新的分子特性 不同的分子性质可能代表完全不同的分子结构。因此,在知识转移过程中应该区别对待他们的数据。第三个挑战是在执行有效的学习过程时区分分子属性的不同重要性
(2)解决挑战
为了解决挑战,提出了一种称为 Meta-MGNN 的新模型,用于小样本分子属性预测
- 利用图神经网络与预训练过程融合异构分子图信息作为分子嵌入
- 开发了一个元学习框架来转移来自不同性质预测任务的知识,并获得一个初始化良好的模型,该模型可以快速适应有限数据的 新分子性质。
- 为了利用和捕获分子数据中的未标记信息,我们设计了一个自监督模块,该模块由键重构损失和原子类型预测损失组成,并伴有主要属性预测损失。
- 考虑到不同的属性预测任务对少样本学习器的贡献不同,我们进一步引入了自我关注任务权重 来衡量它们的重要性。自监督模块和自注意力任务权重都被纳入元学习过程中,以加强模型
(3)主要贡献
- 我们将分子性质预测问题定义为少量样例学习问题,它利用各种特性的丰富信息来解决每个特性缺乏实验室数据的问题。
- 为了应对小样本的挑战,我们提出了一种称为元的新模型MGNN 通过探索图神经网 络、自监督学习和任务权重感知元学习。
- 我们在两个公共数据集上进行了广泛的实验,评估结果表明 Meta-MGNN 的性能优于最先进的方法。每个模型组件的有效性也得到了验证。
3 基础知识
Meta Learning
1 介绍
- 工业界:大量GPU同时训练多组可能的超参数,找到结果较好的参数配置
- 学术界:定义一组好的参数
- Meta-learning:learn to learn,希望自己学会超参数、网络架构……,帮助学术界解决问题
对于机器学习,三大步骤如下:
“学习算法”可以看做一个函数F,它的输入是训练数据,输出是model
- 一般的ML中,F是人想出来的(hand-crafted)
- Meta Learning 自动学习“学习算法”F
2 学习算法
Step 1: What is learnable in a learning algorithm?
确定要学习的架构,从而让机器来学习,这个叫:learnable components,也称为 ϕ \phi ϕ,里面一般有:
- 网络架构
- 初始化参数 θ 0 \theta_{0} θ0
- 学习率 η
- ……
不同的Meta Learning,会学习不同的learnable components
Step 2:Define loss function for learning algorithm F ϕ F_{\phi} Fϕ
此时开始定义损失函数:
- 损失函数为: L ( ϕ ) L(\phi) L(ϕ),学习算法为: F ϕ F_{\phi} Fϕ
- L ( ϕ ) L(\phi) L(ϕ)的值越小,就说明学习算法 F ϕ F_{\phi} Fϕ越好
在Meta Learning中,所需的数据有:
- 训练数据→训练任务(训练任务里的训练数据+测试数据)
- 测试数据→测试任务(包含训练数据+测试数据)
例如,想要训练一个二元分类器,则需要准备二元分类任务,每个任务中有训练集和测试集,如下所示:
有了训练任务,接下来就要定义 L ( ϕ ) L(\phi) L(ϕ)。使用学习到的学习算法 F ϕ F_{\phi} Fϕ ,针对某一个特定任务的训练数据开始训练,得到模型 f θ 1 ∗ f_{\theta^{1*}} fθ1∗:
- 模型 f θ 1 ∗ f_{\theta^{1*}} fθ1∗性能越好,学习算法 F ϕ F_{\phi} Fϕ就越好, L ( ϕ ) L(\phi) L(ϕ)就越低
因此接下来需要评估模型 f θ 1 ∗ f_{\theta^{1*}} fθ1∗的性能,因此使用测试数据(带标签)来测试模型,计算输出跟正确结果之间的差异,从而算出Cross Entropy,最终加和Cross Entropy,最终得到 l 1 l^{1} l1
- l 1 l^{1} l1表示 f θ 1 ∗ f_{\theta^{1*}} fθ1∗的性能如何, l 1 l^{1} l1越小,性能越好
- 在一般的机器学习中,loss是根据“训练数据”来计算的;而在元学习中,loss根据训练任务中的“测试数据”进行计算。
要想测试其性能,还需拿其它二元分类的任务来测试它,将该“学习算法”在所有的“学习任务”上的损失求和,得到总的 L ( ϕ ) L(\phi) L(ϕ)
Step 3:Optimazation
此时已经得到了学习算法 F ϕ F_{\phi} Fϕ的损失函数 L ( ϕ ) L(\phi) L(ϕ),接下里进行优化:
- 若 ∂ L ( ϕ ) ∂ ϕ \frac{\partial L(\phi)}{\partial\phi} ∂ϕ∂L(ϕ),可导,则可以使用梯度下降法
- 若不可导,则使用强化学习或者进化算法进行训练
- 最终得到学习而来的学习算法 F ϕ F_{\phi} Fϕ
框架总流程
在Meta Learning中,所需的数据有:
- 训练数据→训练任务(训练任务里的训练数据+测试数据)
- 测试数据→测试任务(包含训练数据+测试数据)
框架总流程如下:
- 使用训练任务的数据按照上述三大步骤得到学习过的学习算法 F ϕ F_{\phi} Fϕ
- 使用测试任务的数据中的训练数据让学习过的学习算法 F ϕ F_{\phi} Fϕ**训练出模型 f θ ∗ f_{\theta^*} fθ∗
- 将模型 f θ ∗ f_{\theta^*} fθ∗用在测试任务的数据中的测试数据上,得到最终结果
所以我们可以得知:
- 训练任务是找出学习算法 F ϕ F_{\phi} Fϕ
- few-shot learning(小样本学习):利用meta-learning的技术,可以达到“few-shot learning”的目的。
3 ML和Meta的区别
二者目标
- ML:找到一个能完成任务的函数f
- Meta:找到一个学习算法F,能够找到f
训练数据
- ML:完成一个任务即可,使用这个任务中的“训练数据”进行训练
- Meta:使用若干个任务进行训练,每个“训练任务”中都有“训练数据+测试数据”
- Support set(支持集):任务里的训练数据
- Query set(查询集):任务里的测试数据
框架
训练时:
-
ML:学习算法是人工设定的,即Within-task Training(任务内训练)
-
Meta:学习算法是在多个任务上训练得到的,即Across-task Training(跨任务训练)
测试时: -
ML:直接使用训练得到的模型在任务中使用测试数据进行测试(Within-task Testing:任务内测试)
-
Meta:测试学习算法(Across-task Testing:跨任务测试)
- 将测试任务的训练数据给学习算法,训练得到该任务的模型(任务内训练)
- 将测试任务的测试数据给模型,得到最终结果(任务内测试)
- 这样一个流程叫Episode
对于损失函数:
- ML:对一个任务中所有的测试数据的损失之和
- Meta: l l l是一个任务的损失, L L L是所有任务的损失之和
此时,要注意的是:
- 内循环:计算一个 l l l需要经历即下面的Training Examples+Testing Examples
- 外循环:需要多个任务进行
MAML
MAML聚焦于学习一个最好的初始化参数 ϕ \phi ϕ,MAML的目标是找到最优的初始参数 ϕ \phi ϕ,是所有任务的测试损失值最小,这样在遇到新任务时,只需基于少量标签对初始化参数进行微调就可以获得很好的效果。
- MAML的N-way K-shot:任务的类别为N,每个类别的Support set为K,query
set人为进行选择
假设采样了一个batch的任务,对伪代码进行解析:
- 第5,6行对任务的支持集求损失并将参数 θ \theta θ更新为 θ ′ \theta^{'} θ′,假设一个batch有5个任务,则此时参数 θ ′ \theta^{'} θ′有10个
- 第8行,使用 θ ′ \theta^{'} θ′参数,对所有任务的查询集计算出各自的loss,将这些loss求和,计算出梯度,利用该梯度更新初始参数。( θ ′ \theta^{'} θ′和每个任务的查询集一一对应)
4 细节问题
问题定义
(1)图的定义
设 G = ( V , E ) G = (\mathcal{V},\mathcal{E}) G=(V,E)是分子图, V \mathcal{V} V是节点集, E \mathcal{E} E是边集。分子图中的节点代表化学原子,边代表两个原子之间的化学键。
(2)学习映射函数
给定一组分子图 G = { G 1 , ⋯ , G N } \mathcal{G}=\{G_{1},\cdots,G_{N}\} G={G1,⋯,GN}及其标签 Y = { y 1 , ⋯ , y N } \mathcal{Y}=\{y_{1},\cdots,y_{N}\} Y={y1,⋯,yN},分子性质预测的目标是学习一个分子表示向量预测每一个 G i ∈ G G_{i}∈G Gi∈G的标签(分子属性),也就是学习一个映射函数: f θ : G → Y f_{\theta}:\mathcal{G}\to\mathcal{Y} fθ:G→Y
(3)问题定义
少样本分子性质预测:给定分子性质 Y = { y 1 , ⋯ , y N } \mathcal{Y}=\{y_{1},\cdots,y_{N}\} Y={y1,⋯,yN}和相应的少样本分子图集 { G 1 ∈ y 1 , ⋯ , G N ∈ y N } \{G_{1}\in y_{1},\cdots,G_{N}\in y_{N}\} {G1∈y1,⋯,GN∈yN}(训练数据),任务是设计一个机器学习模型来预测只有少数样本(测试数据)的新特性的分子图。
分子图神经网络
在本节中,我们将介绍使用 GNN 获取分子表示的细节。GNN 模型能够利用图结构和节点/边特征信息来学习每个节点 v ∈ V v∈V v∈V的表示向量 h v h_{v} hv,经过 l l l次迭代后,节点表示 h v l h_{v}^{l} hvl能够捕获 l l l跳邻域内的信息。
作为GNN的输入层,首先使用分子图中的属性来初始化节点和边的表示。
- 节点属性包括原子数(AN)和手性标签 (CT)、
- 边属性包括键类型(BT)和键方向(BD)
我们将节点表示初始化: h v ( 0 ) = v A N ⊕ v C T \mathbf{h}_{v}^{(0)}=\mathbf{v}_{AN}\oplus\mathbf{v}_{CT} hv(0)=vAN⊕vCT;将边表示初始化: h e ( 0 ) = e B T ⊕ e B D \mathbf{h}_{e}^{(0)}=\mathbf{e}_{BT}\oplus\mathbf{e}_{BD} he(0)=eBT⊕eBD
- 其中 v , e v,e v,e分别表示节点和边
- ⊕是串联运算符
在第 l l l层GNN中,节点表示 h v l h_{v}^{l} hvl的公式为:
h N ( v ) ( l ) = A G G l ( { h u ( l − 1 ) : ∀ u ∈ N ( v ) } , { h e ( l − 1 ) : e = ( v , u ) } ) , ( 1 ) h v ( l ) = σ ( W ( l ) ⋅ C o N C A T ( h v ( l − 1 ) , h N ( v ) ( l ) ) ) , ( 2 ) \mathbf{h}_{\mathcal{N}(v)}^{(l)}=\mathrm{AGG}_{l}(\{\mathbf{h}_{u}^{(l-1)}:\forall u\in\mathcal{N}(v)\},\{\mathbf{h}_{e}^{(l-1)}:e=(v,u)\}),\quad(1)\\\mathbf{h}_{v}^{(l)}=\sigma(\mathbf{W}^{(l)}\cdot\mathrm{CoNCAT}(\mathbf{h}_{v}^{(l-1)},\mathbf{h}_{\mathcal{N}(v)}^{(l)})),\quad(2) hN(v)(l)=AGGl({hu(l−1):∀u∈N(v)},{he(l−1):e=(v,u)}),(1)hv(l)=σ(W(l)⋅CoNCAT(hv(l−1),hN(v)(l))),(2)
其中:
- N ( v ) N(v) N(v)是 v v v的邻居集合
- A G G AGG AGG是邻居聚合函数
有许多 A G G AGG AGG的架构,如GCN,GAT,GIN等。在这里,选择使用GIN。
我们可以学习到分子图中每个节点的表示:
h v = h v l / ∣ ∣ h v l ∣ ∣ h_{v}=h_{v}^{l} / ||h_{v}^{l}|| hv=hvl/∣∣hvl∣∣
最终使用最后一层的平均节点嵌入来代表整个分子图的图集表示 h G h_{G} hG,如下:
h G = MEAN ( { h v ( l ) : v ∈ V } ) , ( 3 ) \mathbf{h}_G=\text{MEAN}(\{\mathbf{h}_v^{(l)}:v\in V\}),\quad(3) hG=MEAN({hv(l):v∈V}),(3)
h G h_{G} hG,进一步输入分类器(例如,多层感知)进行分子属性预测。
5 META-MGNN
模型框架图
(a)META-MGNN的总体框架:
- 对一批训练任务进行采样。 对于每个任务,支持集中都有一些数据示例。这些示例被输入到由 θ \theta θ参数化的 GNN 中。
- 计算支持集损失 L s u p p o r t \mathcal{L}_{support} Lsupport并将其用于将 GNN 参数更新为 θ ′ \theta^{\prime} θ′。
- 相应查询集中的示例被输入到由 θ ′ \theta^{\prime} θ′参数化的 GNN中,并计算该任务的损失 L q u e r y ′ \mathcal{L}_{query}^{\prime} Lquery′。对于其他训练任务重复相同的过程。
- 计算所有采样任务的 L q u e r y ′ \mathcal{L}_{query}^{\prime} Lquery′ 的总和,并用它来进一步更新 GNN 参数以进行测试。
(b)自监督模块
包括键重构和原子类型预测:
- 橙色部分显示我们对两个原子进行采样并使用 GNN 来预测它们之间是否存在键。
- 绿色部分显示我们随机屏蔽了几个原子并使用 GNN来预测它们的类型。
(c)任务感知注意力
它计算同一任务的査询集中所有分子嵌入的平均值来表示该任务。通过嵌入每个任务,我们设计了一个自注意力层来计算每个任务的权重,然后将其合并到元训练过程中以更新模型参数 θ \theta θ。
Meta-learning Setup
构建了基于MAML的元学习框架。给定一个模型 f θ f_{\theta} fθ, θ \theta θ是可学习参数。该模型可将分子图映射到具体的属性(例如毒性),即: f θ : G → Y f_{\theta}:\mathcal{G}\to\mathcal{Y} fθ:G→Y
- 在𝑘-shot元学习中,对于从分布 p ( T ) p(\mathcal{T}) p(T)采样的每个任务 T τ \mathcal{T}_{\tau} Tτ,模型仅使用 k k k数据样本训练,并在 T τ \mathcal{T}_{\tau} Tτ的剩余样本中进行测试。
因此,将每个任务对应的训练集和测试集称为支持集和查询集,表示为 T τ = { G τ , Y τ , G τ ′ , Y τ ′ } \mathcal{T}_{\tau}=\{G_{\tau},\mathcal{Y}_{\tau},\mathcal{G}_{\tau}^{\prime},\mathcal{Y}_{\tau}^{\prime}\} Tτ={Gτ,Yτ,Gτ′,Yτ′}。
- G τ , Y τ G_{\tau},\mathcal{Y}_{\tau} Gτ,Yτ是支持集的输入分子图和属性标签
- G τ ′ , Y τ ′ \mathcal{G}_{\tau}^{\prime},\mathcal{Y}_{\tau}^{\prime} Gτ′,Yτ′是查询集的输入分子图和属性标签
元训练期间,其步骤如下:
- 模型首先使用每个任务的支持集更新为特定任务的模型
- 使用训练数据中所有任务的查询集的预测损失进一步优化为任务无关模型
元测试期间,其步骤如下:
- 充分的元训练
- 学习到的模型可以进一步利用 k k k数据样本作为支持集来预测新的任务(新的分子性质)
元训练
梯度下降
元训练的目标是获得初始化良好的参数 θ \theta θ,参数需具备如下要求:
- 该参数可以普遍适用于不同的任务
- 在新任务上使用少量数据进行少量梯度下降更新后表现良好
对于任务 T τ \mathcal{T}_{\tau} Tτ,首先将支持集输入模型并计算损失 L T τ \mathcal{L}_{T_{\tau}} LTτ,通过梯度下降将参数 θ \theta θ更新:
θ τ ′ = θ − α ∇ θ L T τ ( θ ) , ( 4 ) \theta_{\tau}'=\theta-\alpha\nabla_{\theta}\mathcal{L}_{\mathcal{T}_{\tau}}(\theta),\quad(4) θτ′=θ−α∇θLTτ(θ),(4)
- 其中 α \alpha α是步长
- 该式子仅显示了一步梯度更新,而实际上我们可以采取多步梯度更新
损失函数(自监督模块)
L T τ \mathcal{L}_{\mathcal{T}_{\tau}} LTτ由下游任务的监督信号计算。由于样本数量少,简单使用监督信号可能并不有效;分子的复杂性带来了结构和属性方面的有用的未标记信息。因此,需要利用分子图中的未标记信息(设计了自监督模块)。
- 自监督模块:键重建损失、原子类型预测损失和属性预测损失组成
(1)分子预测损失
对于图级分子表示 h \mathbf{h} h(式子3),引入了多层感知机,即 y ^ = M L P ( h ) \hat{y}=\mathrm{MLP}(\mathbf{h}) y^=MLP(h),预测损失定义为预测标签和真实标签之间的交叉熵损失:
L l a b e l ( θ ) = − 1 k ∑ i = 1 k C R o s s E N T R O P Y ( y i , y ^ i ) ( 5 ) \mathcal{L}_{label}(\theta)=-\frac{1}{k}\sum_{i=1}^{k}\mathrm{CRossENTROPY}(y_{i},\hat{y}_{i})\quad(5) Llabel(θ)=−k1i=1∑kCRossENTROPY(yi,y^i)(5)
- k k k是数据样本的数量
(2)键重建损失
在分子图中,通过有边的节点对采样一组正边(有键),然后通过没有边的节点对采集一组负边(没有键)。将 ε s \varepsilon_{s} εs表示为采样的正边和负边的并集。
- 设置 ε s = 10 \varepsilon_{s}=10 εs=10,表示5个正样本和5个负样本
对于选中的任意节点对,将节点映射为嵌入,分别为: h u \mathbf{h}_{u} hu和 h v \mathbf{h}_{v} hv,对其内积,根据其内积判断其是否有键,其公式为: e ^ u v = h v ⊤ ⋅ h u \hat{e}_{uv}=\mathbf{h}_{v}^{\top}\cdot\mathbf{h}_{u} e^uv=hv⊤⋅hu
- 若两节点越相似,则说明越可能存在键,内积值越高
键重建损失定义为:真实键和预测键之间的二元交叉熵损失。公式如下:
L e d g e ( θ ) = − 1 ∣ E s ∣ ∑ e u v ∈ E s BINARYCROssENTROPY ( e u v , e ^ u v ) ( 6 ) \mathcal{L}_{edge}(\theta)=-\frac{1}{|\mathcal{E}_{s}|}\sum_{e_{uv}\in\mathcal{E}_{s}}\text{BINARYCROssENTROPY}(e_{uv},\hat{e}_{uv})\quad(6) Ledge(θ)=−∣Es∣1euv∈Es∑BINARYCROssENTROPY(euv,e^uv)(6)
(3)原子类型预测损失
在分子中,不同的原子以一定的方式连接(例如碳-碳键、碳-氧键),导致不同的分子结构。原子类型决定了分子图中的节点如何与相邻节点连接。因此,利用节点(原子)的上下文子图来预测其类型。
首先对分子图的一组节点进行采样,表示为: V c t ⊆ V \mathcal{V}_{ct}\subseteq \mathcal{V} Vct⊆V。对于每一个在 V c t \mathcal{V}_{ct} Vct中的节点 v v v,其上下文子图被定义为在 l l l跳内的邻居。公式化为:
G s u b = ( U s u b , E s u b ) G_{sub}=(\mathcal{U}_{sub},\mathcal{E}_{sub}) Gsub=(Usub,Esub)
- 其中, U s u b = { v } ∪ N l ( v ) \mathcal{U}_{sub}=\{v\}\cup\mathcal{N}_{l}(v) Usub={v}∪Nl(v)
- E s u b ⊆ U s u b × U s u b \mathcal{E}_{sub}\subseteq\mathcal{U}_{sub}\times\mathcal{U}_{sub} Esub⊆Usub×Usub
- N l ( v ) N_{l}(v) Nl(v)代表节点 v v v的邻居节点集合
- 一般选择 l = 1 l=1 l=1
- 取 V c t \mathcal{V}_{ct} Vct为图中节点的15%
接着在上下文子图中所有节点的均值池化之上使用多层感知器(MLP),不包括中心节点。原子类型预测损失被表示为预测节点类型和真实节点类型之间的交叉熵损失:
v ^ i = MLP ( MEAN ( { h u : u ∈ N l ( v ) } ) ) , ( 7 ) \hat{v}_i=\text{MLP}\left(\text{MEAN}(\{\mathbf{h}_u:u\in\mathcal{N}_l(v)\})\right),\quad(7) v^i=MLP(MEAN({hu:u∈Nl(v)})),(7)
L n o d e ( θ ) = − 1 ∣ V c ∣ ∑ i = 1 ∣ V c ∣ CRossENTROPY ( v i , v ^ i ) , ( 8 ) \mathcal{L}_{node}(\theta)=-\frac{1}{|\mathcal{V}_{c}|}\sum_{i=1}^{|\mathcal{V}_{c}|}\text{CRossENTROPY}(v_{i},\hat{v}_{i}),\quad(8) Lnode(θ)=−∣Vc∣1i=1∑∣Vc∣CRossENTROPY(vi,v^i),(8)
(4)联合损失
任务 T τ \mathcal{T}_{\tau} Tτ的损失公式如下:
L T τ ( θ ) = L n o d e ( θ ) + λ 1 L e d g e ( θ ) + λ 2 L l a b e l ( θ ) ( 9 ) \mathcal{L}_{\mathcal T_{\tau}}(\theta)=\mathcal{L}_{node}(\theta)+\lambda_{1}\mathcal{L}_{edge}(\theta)+\lambda_{2}\mathcal{L}_{label}(\theta)\quad(9) LTτ(θ)=Lnode(θ)+λ1Ledge(θ)+λ2Llabel(θ)(9)
- 其中 λ 1 \lambda_{1} λ1和 λ 2 \lambda_{2} λ2是控制不同损失重要性的参数,设置为 λ 1 = λ 2 = 0.1. \lambda_{1}=\lambda_{2}=0.1. λ1=λ2=0.1.
任务感知注意力模块
根据公式 θ τ ′ = θ − α ∇ θ L T τ ( θ ) \theta_{\tau}'=\theta-\alpha\nabla_{\theta}\mathcal{L}_{\mathcal{T}_{\tau}}(\theta) θτ′=θ−α∇θLTτ(θ)获得了参数 θ τ ′ \theta_{\tau}' θτ′,接下来模型会进一步更新,如下式子:
θ ← θ − β ∇ θ ∑ T τ ∼ p ( T ) η ( T τ ) ⋅ L T τ ′ ( θ τ ′ ) , ( 10 ) \theta\leftarrow\theta-\beta\nabla_\theta\sum_{\mathcal{T}_\tau\sim p(\mathcal{T})}\eta(\mathcal{T}_\tau)\cdot\mathcal{L}_{\mathcal{T}_\tau}^{\prime}(\theta_\tau^{\prime}),\quad(10) θ←θ−β∇θTτ∼p(T)∑η(Tτ)⋅LTτ′(θτ′),(10)
- β \boldsymbol{\beta} β是元学习率
- L T τ ′ \mathcal{L}_{\mathcal{T}_\tau}^{\prime} LTτ′是任务 T τ \mathcal{T}_{\tau} Tτ的查询集的联合损失
创新点
元学习方法(例如,MAML)在优化元学习器时以相同的权重对待每个任务,即 η ( T τ ) \eta(\mathcal{T}_{\tau}) η(Tτ)对于所有任务都是相同的。但不同的属性预测任务对元学习器优化的贡献不同,因此使用自注意力机制来计算每个任务的重要性:
η ( T τ ) = exp ( M L P ( H T τ ) ) ∑ T τ ′ ∈ T exp ( M L P ( H T τ ′ ) ) , H T τ = M E A N ( { h T τ , i } i = 1 k ) , ( 11 ) \eta(\mathcal{T}_{\tau})=\frac{\exp(\mathrm{MLP}(\mathrm{H}_{\mathcal{T}_{\tau}}))}{\sum_{\mathcal{T}_{\tau^{\prime}}\in\mathcal{T}}\exp(\mathrm{MLP}(\mathrm{H}_{\mathcal{T}_{\tau^{\prime}}}))}, \mathrm{H}_{\mathcal{T}_{\tau}}=\mathrm{MEAN}(\{\mathbf{h}_{\mathcal{T}_{\tau},i}\}_{i=1}^{k}),\quad(11) η(Tτ)=∑Tτ′∈Texp(MLP(HTτ′))exp(MLP(HTτ)),HTτ=MEAN({hTτ,i}i=1k),(11)
- T \mathcal{T} T是所有任务的集合
- H T τ \mathrm{H}_{\mathcal{T}_{\tau}} HTτ是任务嵌入,通过对 T τ \mathcal{T}_{\tau} Tτ中所有的分子嵌入进行平均计算得出
元测试
在元测试期间,我们首先利用新任务的少样本支持集,通过使用式4来更新Meta-MGNN 的参数 θ \theta θ,然后评估查询集的功能。
6 实验
数据集
使用 Tox21 和 Sider 数据集,在每个任务中,分子被分为正实例和负实例(即二元标签)。正实例意味着分子具有特定属性,负实例意味着分子不具有该属性。
- Tox21:是一个用于毒性预测的生物信息学资源,有 7,831 个实例,有 12 种不同的任务(12个生物靶标的毒性)。
- Sider:是一个用于药物-副作用关系的生物信息学资源,有 1,427 个实例, 有 27 个不同的任务(27个系统器官类别)。
- 拆分了 Tox21 中的 3 个任务和 Sider 中 的 6 个任务用于元测试
模型比较结果
(1)评估标准
使用 ROC-AUC 评估每个模型的性能。我们将每个分子属性视为小样本学习的独立任务。我们分别使用3个和6个任务作 为Tox21和 Sider 的测试任务。每个任务都是一个二元标签分类任务。
(2)实验结果
- 最后一列报告了 MetaMGNN 在不同任务中相对于最好结果的平均改进
- 加深字体表示模型的最好结果
- 下划线表示用于与MetaMGNN 比较的模型中最佳的结果