消息传递神经网络
- 一、引言
- 二、消息传递范式介绍
- 三、消息传递的实现(pyG)
- 1、MessagePassing基类
- 2、继承MessagePassing实现GCNConv
一、引言
为节点生成节点表征是图计算任务成功的关键,神经网络的生成节点表征的操作叫做节点嵌入(node embeddi ng)
二、消息传递范式介绍
基于消息传递范式的生成节点表征的过程:
我们从左往右来看此图。图的左边是我们输入的整张图(INPUT GRAPH),由ABCDEFG六个节点组成,现在目标是得到更新之后A节点(target node)表示。
再看右边,与A相邻的BCD三个点进行变换和聚合后就会得到更新后的A节点信息。同理所有的节点都与它相邻节点有关,更新后的信息为相邻节点变换和聚合后的特征信息。消息传递图神经网络是指遵循“消息传递范式”的图神经网络,此类图神经网络实现了上述的节点信息更新过程。
注:未经过训练的图神经网络生成的节点表征还不是好的节点表征,好的节点表征可用于衡量节点之间的相似性。
三、消息传递的实现(pyG)
1、MessagePassing基类
Pytorch Geometric(PyG)提供了 MessagePassing基类,它封装了“消息传递”的运行流程。通过继承 MessagePassing 基类,可以方便地构造消息传递图神经网络,构造一个最简单的消息传递图神经网络类,我们只需定义聚合和更新的方法即可。
-
MessagePassing(aggr=“add”, flow=“source_to_target”,node_dim=-2) :
aggr :定义要使用的聚合方案("add"、"mean "或 "max");flow :定义消息传递的流向("source_to_target "或"target_to_source");node_dim :定义沿着哪个轴线传播
-
MessagePassing.propagate(edge_index, size=None,**kwargs) :
开始传递消息的起始调用,在此方法中 message 、 update 等方法被调用。它以 edge_index (边的端点的索引)和 flow (消息的流向)以及一些额外的数据为参数。
-
MessagePassing.aggregate(…) :
将从源节点传递过来的消息聚合在目标节点上,一般可选的聚合方式 有 sum , mean 和 max 。
-
MessagePassing.message(…):
接收传递给MessagePassing.propagate(edge_index, size=None,**kwargs) 方法的所有参数
-
MessagePassing.update(aggr_out, …) :
每个节点更新节点表征。此方法以aggregate 方法的输出为第一个参 数,并接收所有传递给propagate() 方法的参数。
2、继承MessagePassing实现GCNConv
将上面的公式写成代码
import torch
from torch_geometric.nn import MessagePassing
from torch_geometric.utils import add_self_loops, degree
import warnings
warnings.filterwarnings("ignore")class GCNConv(MessagePassing):def __init__(self, in_channels, out_channels):super(GCNConv, self).__init__(aggr='add', flow='source_to_target')# "Add" aggregation (Step 5).# flow='source_to_target' 表示消息从源节点传播到目标节点self.lin = torch.nn.Linear(in_channels,out_channels)def forward(self, x, edge_index):# x has shape [N, in_channels]# edge_index has shape [2, E] edge_index (边的端点的索引)# Step 1: 在邻接矩阵中加入自循环的边edge_index, _ = add_self_loops(edge_index,num_nodes=x.size(0))# Step 2: 线性变换节点特征矩阵。x = self.lin(x)# Step 3: 计算归一化系数norm#归一化系数是由每个节点的节点度#return edge_index, edge_weightrow, col = edge_indexdeg = degree(col, x.size(0), dtype=x.dtype)deg_inv_sqrt = deg.pow(-0.5)norm = deg_inv_sqrt[row] * deg_inv_sqrt[col]## Step 4-5: 消息传递return self.propagate(edge_index, x=x, norm=norm)def message(self, x_j, norm):# x_j has shape [E, out_channels]# Step 4: 节点特征归一化.return norm.view(-1, 1) * x_j
GCNConv 继承了 MessagePassing 并以"求和"作为领域节点信息聚合方式。
该层的所有逻辑都发生在其 forward() 方法中。