torch_geometric使用手册-Creating Message Passing Networks(专题二)

创建消息传递网络 (Message Passing Networks)

在图神经网络中,将卷积操作推广到不规则域通常表现为一种邻域聚合 (neighborhood aggregation)消息传递 (message passing) 机制。
这一机制通过聚合节点的邻居信息,更新每个节点的特征。

以下公式描述了消息传递机制的基本形式:

公式解释

x i ( k ) = γ ( k ) ( x i ( k − 1 ) , ⨁ j ∈ N ( i ) ϕ ( k ) ( x i ( k − 1 ) , x j ( k − 1 ) , e j , i ) ) \mathbf{x}_i^{(k)} = \gamma^{(k)} \left( \mathbf{x}_i^{(k-1)}, \bigoplus_{j \in \mathcal{N}(i)} \, \phi^{(k)}\left(\mathbf{x}_i^{(k-1)}, \mathbf{x}_j^{(k-1)},\mathbf{e}_{j,i}\right) \right) xi(k)=γ(k) xi(k1),jN(i)ϕ(k)(xi(k1),xj(k1),ej,i)

  • x i ( k − 1 ) \mathbf{x}_i^{(k-1)} xi(k1): 表示第 k − 1 k-1 k1 层时节点 i i i 的特征。
  • e j , i \mathbf{e}_{j,i} ej,i: 表示从节点 j j j 到节点 i i i 的边特征(可选)。
  • N ( i ) \mathcal{N}(i) N(i): 节点 i i i 的邻居节点集合。
  • ϕ ( k ) \phi^{(k)} ϕ(k): 消息函数 (message function),生成从邻居节点 j j j 到节点 i i i 的消息。
  • ⨁ \bigoplus : 聚合函数 (aggregation function),例如加和 (sum)、均值 (mean) 或最大值 (max)。
  • γ ( k ) \gamma^{(k)} γ(k): 更新函数 (update function),结合节点本身的特征与聚合后的消息。

PyTorch Geometric (PyG) 提供了一个名为 MessagePassing 的基类,专门用于实现基于消息传递机制的图神经网络(GNN)。这个类封装了消息传递中的许多细节,开发者只需要定义核心函数,例如消息构造(message)、特征更新(update),以及选择合适的聚合方式(aggr),即可实现复杂的 GNN 算法。


核心概念与方法解析

1. 构造 MessagePassing 基类

MessagePassing(aggr="add", flow="source_to_target", node_dim=-2)
功能
  • 定义消息传递的聚合方式:

    • aggr: 表示如何将来自邻居节点的消息聚合到目标节点。
      • add(加和):计算邻居节点消息的加权和。
      • mean(平均):取邻居节点消息的加权平均值。
      • max(最大值):选择邻居节点消息的最大值。
  • 定义消息的传递方向:

    • flow:
      • "source_to_target":从源节点传递消息到目标节点。
      • "target_to_source":从目标节点向源节点传递消息。
  • node_dim:

    • 指定在哪一维度上传递节点特征。通常是倒数第二维(默认为 -2),适配节点特征张量。

2. 消息传递的入口:propagate 方法

MessagePassing.propagate(edge_index, size=None, **kwargs)
功能
  • 触发消息传递过程,从边索引和输入特征开始,依次执行:
    1. 消息构造message):生成从邻居节点传来的消息。
    2. 消息聚合aggregate,自动完成):将邻居节点的消息聚合到目标节点。
    3. 特征更新update):更新目标节点的最终特征。

注意: 这是入口函数,类似forward的操作,会调用messageaggregateupdate函数.

参数
  • edge_index:

    • 图的边索引,形状为 [2, num_edges]
    • 第一行表示源节点,第二行表示目标节点。
  • size:

    • 图中节点的数量或维度。
    • 对于普通图,默认为 [num_nodes, num_nodes];对于二分图(bipartite graph),可以传递 (N, M),分别表示源节点和目标节点数量。
  • kwargs:

    • 其他参数,如节点特征 x,边特征 edge_attr 等。

3. 消息生成:message 方法

MessagePassing.message(...)
功能
  • 根据每条边的两端节点特征(源节点和目标节点)以及边特征,生成要传递的消息。
参数
  • 默认情况下:
    • x_j: 源节点的特征。
    • x_i: 目标节点的特征。
    • edge_attr: 边的特征(如果存在)。
自动变量映射

propagate 内部,会根据 edge_index 自动将输入特征分为:

  • x_j:从源节点出发的特征。
  • x_i:传递到目标节点的特征。

4. 特征更新:update 方法

MessagePassing.update(aggr_out, ...)
功能
  • 根据聚合后的结果 aggr_out,计算目标节点的最终特征。
参数
  • aggr_out: 聚合后的邻居节点消息。
  • 可以使用其他参数,例如目标节点本身的初始特征。

5. 应用流程总结

  1. 消息生成

    • 根据边和节点特征,生成从邻居节点传递的消息(通过 message 方法)。
  2. 消息聚合

    • 使用选定的聚合方式(aggr 参数,如加和或平均),将消息聚合到目标节点。
  3. 特征更新

    • 在目标节点上应用更新规则,生成最终的节点特征(通过 update 方法)。

示例:实现经典的 GCN 和 EdgeConv

实现 GCN 层(Graph Convolutional Layer)

GCN 层的数学定义如下:

x i ( k ) = ∑ j ∈ N ( i ) ∪ { i } 1 deg ⁡ ( i ) ⋅ deg ⁡ ( j ) ⋅ ( W ⊤ ⋅ x j ( k − 1 ) ) + b \mathbf{x}_i^{(k)} = \sum_{j \in \mathcal{N}(i) \cup \{ i \}} \frac{1}{\sqrt{\deg(i)} \cdot \sqrt{\deg(j)}} \cdot \left( \mathbf{W}^{\top} \cdot \mathbf{x}_j^{(k-1)} \right) + \mathbf{b} xi(k)=jN(i){i}deg(i) deg(j) 1(Wxj(k1))+b

  • 邻居节点特征通过一个权重矩阵 W \mathbf{W} W 进行变换。
  • 然后,按照节点度进行归一化。
  • 最后,对邻居节点特征进行聚合并添加偏置项 b \mathbf{b} b

这个公式可以拆解为以下几个步骤:

  1. 为邻接矩阵添加自环(self-loops)
  2. 对节点特征矩阵进行线性变换
  3. 计算归一化系数
  4. 对特征进行归一化处理
  5. 聚合邻居节点特征(使用"加和"操作,"add" 聚合)。
  6. 对聚合结果加上最终的偏置项

在实现过程中:

  • 步骤 1-3 通常在消息传递(message passing)前完成。
  • 步骤 4-5 使用 MessagePassing 基类轻松实现。

以下是完整的 GCN 层实现代码:

import torch
from torch.nn import Linear, Parameter
from torch_geometric.nn import MessagePassing
from torch_geometric.utils import add_self_loops, degreeclass GCNConv(MessagePassing):def __init__(self, in_channels, out_channels):super().__init__(aggr='add')  # "加和"聚合 (Step 5)self.lin = Linear(in_channels, out_channels, bias=False)self.bias = Parameter(torch.empty(out_channels))self.reset_parameters()def reset_parameters(self):self.lin.reset_parameters()self.bias.data.zero_()def forward(self, x, edge_index):# Step 1: 添加自环到邻接矩阵edge_index, _ = add_self_loops(edge_index, num_nodes=x.size(0))# Step 2: 对节点特征进行线性变换x = self.lin(x)# Step 3: 计算归一化系数row, col = edge_indexdeg = degree(col, x.size(0), dtype=x.dtype)deg_inv_sqrt = deg.pow(-0.5)deg_inv_sqrt[deg_inv_sqrt == float('inf')] = 0norm = deg_inv_sqrt[row] * deg_inv_sqrt[col]# Step 4-5: 开始消息传递out = self.propagate(edge_index, x=x, norm=norm)# Step 6: 添加最终的偏置项out = out + self.biasreturn outdef message(self, x_j, norm):# 对节点特征进行归一化 (Step 4)return norm.view(-1, 1) * x_j

实现 EdgeConv(边卷积)

边卷积用于处理图结构或点云,其数学定义为:

x i ( k ) = max ⁡ j ∈ N ( i ) h Θ ( x i ( k − 1 ) , x j ( k − 1 ) − x i ( k − 1 ) ) \mathbf{x}_i^{(k)} = \max_{j \in \mathcal{N}(i)} h_{\mathbf{\Theta}} \left( \mathbf{x}_i^{(k-1)}, \mathbf{x}_j^{(k-1)} - \mathbf{x}_i^{(k-1)} \right) xi(k)=jN(i)maxhΘ(xi(k1),xj(k1)xi(k1))

其中, h Θ h_{\mathbf{\Theta}} hΘ 是一个多层感知机(MLP)。
与 GCN 类似,EdgeConv 层也基于 MessagePassing 实现,但使用的是 "max" 聚合方式。

以下是 EdgeConv 的实现代码:

import torch
from torch.nn import Sequential as Seq, Linear, ReLU
from torch_geometric.nn import MessagePassingclass EdgeConv(MessagePassing):def __init__(self, in_channels, out_channels):super().__init__(aggr='max')  # "最大值" 聚合self.mlp = Seq(Linear(2 * in_channels, out_channels),ReLU(),Linear(out_channels, out_channels))def forward(self, x, edge_index):return self.propagate(edge_index, x=x)def message(self, x_i, x_j):# 计算相对特征并输入到 MLPtmp = torch.cat([x_i, x_j - x_i], dim=1)return self.mlp(tmp)

EdgeConv 实际上是一个动态卷积,每一层都在特征空间中根据最近邻重新计算图。
PyG 提供了一个 GPU 加速的 k-NN 图生成方法 knn_graph

from torch_geometric.nn import knn_graphclass DynamicEdgeConv(EdgeConv):def __init__(self, in_channels, out_channels, k=6):super().__init__(in_channels, out_channels)self.k = kdef forward(self, x, batch=None):edge_index = knn_graph(x, self.k, batch, loop=False, flow=self.flow)return super().forward(x, edge_index)

DynamicEdgeConv 动态生成 k-NN 图,然后调用 EdgeConvforward 方法。


练习题翻译

关于 GCNConv:

  1. rowcol 包含什么信息?
  2. degree 方法的作用是什么?
  3. 为什么用 degree(col, ...) 而不是 degree(row, ...)
  4. deg_inv_sqrt[col]deg_inv_sqrt[row] 的作用是什么?
  5. message 方法中,x_j 包含什么信息?如果 self.lin 是恒等函数,x_j 的内容具体是什么?
  6. 添加一个 update 方法,使其将变换后的中心节点特征添加到聚合输出中。

关于 EdgeConv:

  1. x_ix_j - x_i 是什么?
  2. torch.cat([x_i, x_j - x_i], dim=1) 的作用是什么?为什么是 dim=1

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.rhkb.cn/news/477252.html

如若内容造成侵权/违法违规/事实不符,请联系长河编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

C++ASCII码表和字符操作

目录 1. 引言 2. ASCII码表 2.1 控制字符 2.2 可显示字符 3. 字符操作 3.1 记住几个字符规律 3.2 打印能够显示的ASCII码 3.3 字母大小写转换 3.4 数字转数字字符 1. 引言 在电子计算机中,只能识别由 0 和 1 组成的一串串的二进制数字,为了将人类…

git使用(二)

git使用(二) git常用基本操作命令git clonegit loggit remotegit statusgit addgit commitgit pushgit branchgit pull git常用基本操作命令 git clone 项目开发中项目负责人会在github上创建一个远程仓库,我们需要使用git clone将远程仓库…

密码学11

概论 计算机安全的最核心三个关键目标(指标)/为:保密性 Confidentiality、完整性 Integrity、可用性 Availability ,三者称为 CIA三元组 数据保密性:确保隐私或是秘密信息不向非授权者泄漏,也不被非授权者使…

netstat -tuln | grep 27017(显示所有监听状态的 TCP 和 UDP 端口,并且以数字形式显示地址和端口号)

文章目录 1. 确定占用端口的进程使用 lsof 命令使用 fuser 命令 2. 结束占用端口的进程3. 修改 MongoDB 配置文件4. 检查 MongoDB 日志文件5. 重新启动 MongoDB 服务6. 检查 MongoDB 服务状态总结 [rootlocalhost etc]# netstat -tuln | grep 27017 tcp 0 0 127.0.…

ElasticSearch7.x入门教程之集群安装(一)

文章目录 前言一、es7.x版本集群安装二、elasticsearch-head安装三、Kibana安装总结 前言 在工作中遇到了,便在此记录一下,以防后面会再次遇到。第一次使用是在2020年末,过了很久了,忘了些许部分了。 在工作当中,如果…

I.MX6U 裸机开发18.GPT定时器实现高精度延时

I.MX6U 裸机开发18.GPT定时器实现高精度延时 一、GPT定时器简介1. GPT 功能2. 时钟源3. 框图4. 运行模式(1)Restart mode(2)Free-Run Mode 5. 中断类型(1)溢出中断 Rollover Interrupt(2&#x…

key-value存储实现

文章目录 一、项目简介二、项目流程图三、网络3.1、epoll实现3.2、io_uring实现 四、协议五、存储5.1、array实现5.2、rbtree实现5.3、hash实现 六、测试 一、项目简介 key-value存储其实是一个小型的redis,用户在客户端输入存储相关的指令发送给服务器端&#xff…

大公司如何实现打印机共享的?如何对打印机进行管控或者工号登录后进行打印?异地打印机共享的如何实现可以帮助用户在不同地理位置使用同一台打印机完成打印任务?

大公司如何实现打印机共享的?如何对打印机进行管控或者工号登录后进行打印?异地打印机共享的如何实现可以帮助用户在不同地理位置使用同一台打印机完成打印任务? 如果在局域网内,可以不需要进行二次开发,通过对打印机进…

微软发布Win11 24H2系统11月可选更新KB5046740!

系统之家11月22日报道,微软针对Win11 24H2系统推出2024年11月最新可选更新补丁KB5046740,更新后系统版本后升至26100.2454,此次更新后修复当应用程序以PDF和XLSX格式导出图表对象时停止响应、无法使用API查找旋转信息等问题。以下小编将给大家…

探索 RocketMQ:企业级消息中间件的选择与应用

一、关于RocketMQ RocketMQ 是一个高性能、高可靠、可扩展的分布式消息中间件,它是由阿里巴巴开发并贡献给 Apache 软件基金会的一个开源项目。RocketMQ 主要用于处理大规模、高吞吐量、低延迟的消息传递,它是一个轻量级的、功能强大的消息队列系统&…

李宏毅机器学习课程知识点摘要(6-13集)

pytorch简单的语法和结构 dataset就是数据集,dataloader就是分装好一堆一堆的 他们都是torch.utils.data里面常用的函数,已经封装好了 下面的步骤是把数据集读进来 这里是读进来之后,进行处理 声音信号,黑白照片,红…

Wekan看板安装部署与使用介绍

Wekan看板安装部署与使用介绍 1. Wekan简介 ​ Wekan 是一个开源的看板式项目管理工具,它的配置相对简单,因为大多数功能都是开箱即用的。它允许用户以卡片的形式组织和跟踪任务,非常适合敏捷开发和日常任务管理。Wekan 的核心功能包括看板…

【Mysql】开窗聚合函数----SUM,AVG, MIN,MAX

1、概念 在窗口中,每条记录动态地应用聚合函数(如:SUM(),AVG(),MAX(),MIN(),COUNT(),)可以动态计算在指定的窗口内的各种聚合函数值。 2、操作 以下操作将基于employee表进行操作。 sum() 进行sum的时候,没有order …

EWA Volume Splatting

摘要 本文提出了一种基于椭圆高斯核的直接体绘制新框架,使用了一种投影方法(splatting approach)。为避免混叠伪影(aliasing artifacts),我们引入了一种重采样滤波器的概念,该滤波器结合了重建核…

Vue实训---0-完成Vue开发环境的搭建

1.在官网下载和安装VS Code编辑器 完成中文语言扩展(chinese),安装成功后,需要重新启动VS Code编辑器,中文语言扩展才可以生效。 安装Vue-Official扩展,步骤与安装中文语言扩展相同(专门用于为“…

C# 超链接控件LinkLabel无法触发Alt快捷键

在C#中,为控件添加快捷键的方式有两种,其中一种就是Windows中较为常见的Alt快捷键,比如运行对话框,记事本菜单等。只需要按下 Alt 框号中带下划线的字母即可触发该控件的点击操作。如图所示 在C#开发中,实现类似的操作…

赛氪媒体支持“2024科普中国青年之星创作交流活动”医学专场落幕

2024年11月15日下午,由中国科普作家协会、科普中国发展服务中心主办,什刹海文化展示中心承办,并携手国内产学研一体融合领域的领军者——赛氪网共同支持的“2024科普中国青年之星创作交流活动”医学科普专场,在什刹海文化展示中心…

《现代制造技术与装备》是什么级别的期刊?是正规期刊吗?能评职称吗?

​问题解答 问:《现代制造技术与装备》是不是核心期刊? 答:不是,是知网收录的第二批认定学术期刊。 问:《现代制造技术与装备》级别? 答:省级。主管单位:齐鲁工业大学&#xff0…

(十一)Python字符串常用操作

一、访问字符串值 Python访问子字符串变量,可以使用方括号来截取字符串。与列表的索引一样,字符串索引从0开始。 hh"LaoTie 666" hh[2] mm"床前明月光" mm[3] 字符串的索引值可以为负值。若索引值为负数,则表示由字符…

数据结构(初阶6)---二叉树(遍历——递归的艺术)(详解)

二叉树的遍历与练习 一.二叉树的基本遍历形式1.前序遍历(深度优先遍历)2.中序遍历(深度优先遍历)3.后序遍历(深度优先遍历)4.层序遍历!!(广度优先遍历) 二.二叉树的leetcode小练习1.判断平衡二叉树1)正常解法2)优化解法 2.对称二叉…