PyG教程:MessagePassing基类

PyG教程:MessagePassing基类

  • 一、引言
  • 二、如何自定义消息传递网络
    • 1.构造函数
    • 2.propagate函数
    • 3.message函数
    • 4.aggregate函数
    • 5.update函数
  • 三、代码实战
    • 1.图数据定义
    • 2.实现GNN的消息传递过程
    • 3.完整代码
    • 4.完整代码的精简版本
  • 四、总结
    • 1.MessagePassing各个函数的执行顺序
    • 2.参考资料

一、引言

PyG框架中提供了一个消息传递基类torch_geometric.nn.MessagePassing,它实现了消息传递的自动处理,继承该类可以简单方便的构建自己的消息传播GNN。

二、如何自定义消息传递网络

要自定义GNN模型,首先需要继承MessagePassing类,然后重写如下方法:

  • message(...):构建要传递的消息;
  • aggregate(...):将从源节点传递过来的消息聚合到目标结点;
  • update(...):更新节点的消息。

上述方法并不是一定都要自定义,若MessagePassing类默认实现满足你的需求,则可以不重写。

1.构造函数

继承MessagePassing类后,在构造函数中可以通过super().__init__方法来向基类MessagePassing传递参数,来指定消息传递过程中的一些行为。MessagePassing类的初始化函数如下:
在这里插入图片描述
参数说明:

参数名参数说明
aggr消息传递中的消息聚合方式,常用的包括summeanminmaxmul等等。default: sum
flow消息传播的方向,其中source_to_targe表示从源节点到目标节点、target_to_source表示从目标节点到源节点。default:source_to_target
node_dim传播的维度,default:-2
decomposed_layers这个参数没用过,我也还不知道,后面会更新。

2.propagate函数

在具体介绍消息传递的三个相关函数之前,首先先介绍propagate函数,该函数是消息传递的启动函数,调用该函数后依次会执行messageaggregateudpate函数来完成消息的传递聚合更新。该函数的声明如下:
在这里插入图片描述
参数说明:

参数名参数说明
edge_index边索引
size这个参数目前我理解的不是很透彻,后面透彻了补一下
**kwargs构建、聚合和更新消息所需的额外数据,都可以传入propagate函数,这些参数可以在消息传递过程中的三个函数中接收。

该函数一般会传入edge_index和特征x

3.message函数

message函数是用来构建节点的消息的。传递给propagate函数的tensor可以映射到中心(target)节点邻居(source)节点上,只需要在相应变量名后加上_ior_j即可,通常称_i为中心(target)节点,称_j为邻居(source)节点。

source节点和target节点的关系:
在这里插入图片描述
message实现源码:
在这里插入图片描述

从源码的默认实现可以看出,message传递的消息就是邻居节点自身的特征向量。

示例:

def forward(self, data):out = self.propagate(edge_index, x=x)passdef message(self, x_i, x_j, edge_index_i, edge_index_j):pass

该例子中利用propagate函数传递了两个参数edge_indexx,则message函数可以根据propagate函数中的两个参数构造自己的参数,上述message函数中的构造参数为:

  • x_i:中心节点(target)的特征向量组成的矩阵,注意该矩阵与图节点的矩阵x是不同的;
  • x_j:邻居节点(source)的特征向量组成的矩阵;
  • edge_index_i:中心节点的索引;
  • edge_index_j:邻居节点的索引。

注意,若flow='source_to_target',则消息将由邻居节点传向中心节点,若flow='target_to_source'则消息将从中心节点传向邻居节点,默认为第一种情况

4.aggregate函数

消息聚合函数aggregate用来聚合来自邻居的消息,常用的包括summeanmaxmin等,可以通过super().__init__()中的参数aggr来设定。该函数的第一个参数为message函数的返回值。

  • aggr='sum' 表示 和聚合,它会对每个特征维度计算所有邻居节点的消息的总和。
  • aggr='mean' 表示 平均值值聚合,它会对每个特征维度计算所有邻居节点的消息的平均值。
  • aggr='max' 表示 最大值聚合,它会对每个特征维度选择所有邻居节点的消息中的最大值。
  • aggr='min' 表示 最小值聚合,它会对每个特征维度选择所有邻居节点的消息中的最小值。

5.update函数

update函数用来更新节点的消息,aggregate函数的返回值作为该函数的第一个参数。

默认实现:
在这里插入图片描述

从默认实现可以看出update函数没有进行任何的操作,只是将raggregate函数的返回值返回了而已。

实际写代码的过程中,我们也不会去重写这个方法,而是,在forward函数中调用完propagate(…)函数后编写代码,代替update函数的功能。

三、代码实战

假设我们设计一个GNN模型,其中消息传递过程用公式表示如下:
X i ( k ) = X i ( k − 1 ) + ∑ j ∈ N ( i ) X j ( k − 1 ) (1) X_i^{(k)} = X_i^{(k-1)} + \sum _{j\in {\mathcal {N(i)}}} X_j^{(k-1) }\tag {1} Xi(k)=Xi(k1)+jN(i)Xj(k1)(1)

  • message生成的消息就是中心节点的邻居节点的特征向量。
  • aggregaet聚合消息的方式是sum,即把所有邻居节点的特征向量加起来。
  • update更新中心节点的方式是:将聚合得到的消息和中心节点自身的特征向量相加。

1.图数据定义

我们有如下数据:

import torch
from torch_geometric.data import Dataedge_index = torch.tensor([[0, 1],[1, 0]], dtype=torch.long)
x = torch.tensor([[-1, 1], [0, 1], [1, 1]], dtype=torch.float)
data = Data(x=x, edge_index=edge_index.contiguous())

在这里插入图片描述

2.实现GNN的消息传递过程

class MyConv(MessagePassing):def __init__(self):super().__init__(aggr='sum')def forward(self, data):out = self.propagate(data.edge_index, x=data.x)# out = out + x return outdef message(self, x_i, x_j, edge_index_i, edge_index_j):# 生成的消息就是邻居节点的特征向量,直接使用 x_j 访问获取就行return x_jdef aggregate(self, message, edge_index_i):# 这里只是写的样例,实际上一般不会重写这个方法,直接使用默认的就好了,只需要自己选择一下聚合的方式即可return super().aggregate(message, edge_index_i, dim_size=len(x))def update(self, aggregate, x):# 一般也不会重写这个方法的,update阶段可以在forward函数中调用完propagate(...)函数后编写代码。return x + aggregate

3.完整代码

import torch
from torch_geometric.data import Data
from torch_geometric.nn import MessagePassingclass MyConv(MessagePassing):def __init__(self):super().__init__(aggr='sum')def forward(self, data):out = self.propagate(data.edge_index, x=data.x)out = out + data.xreturn outdef message(self, x_i, x_j, edge_index_i, edge_index_j):# 生成的消息就是邻居节点的特征向量,直接使用 x_j 访问获取就行return x_j# def aggregate(self, message, edge_index_i):# 	return super().aggregate(message, edge_index_i, dim_size=len(x))# def update(self, aggregate, x):# 	return x + aggregateif __name__ == '__main__':edge_index = torch.tensor([[0, 1],[1, 0]], dtype=torch.long)x = torch.tensor([[-1, 1], [0, 1], [1, 1]], dtype=torch.float)data = Data(x=x, edge_index=edge_index.contiguous())myConv = MyConv()print(myConv(data))

4.完整代码的精简版本

import torch
from torch_geometric.data import Data
from torch_geometric.nn import MessagePassing
from torch_geometric.utils import add_self_loopsclass MyConv(MessagePassing):def __init__(self):super().__init__(aggr='sum')def forward(self, data):edge_index, _ = add_self_loops(data.edge_index, num_nodes=len(data.x))out = self.propagate(edge_index, x=data.x)return outif __name__ == '__main__':edge_index = torch.tensor([[0, 1],[1, 0]], dtype=torch.long)x = torch.tensor([[-1, 1], [0, 1], [1, 1]], dtype=torch.float)data = Data(x=x, edge_index=edge_index.contiguous())myConv = MyConv()print(myConv(data))

思考:大家可以根据上面讲解的细节,理解一下这个精简版本的代码的实现逻辑和过程。

四、总结

1.MessagePassing各个函数的执行顺序

在这里插入图片描述

2.参考资料

  • PyG: MessagePassing
  • PyG: Creating Message Passing Networks

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.rhkb.cn/news/481118.html

如若内容造成侵权/违法违规/事实不符,请联系长河编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

Linux—进程学习—04(进程地址空间学习)

目录 Linux—进程学习—41.程序地址空间1.1虚拟地址空间的现象1.2虚拟地址空间的理解(感性) 2.进程地址空间2.0 mm_struct结构体2.1 mm_struct结构体的源代码2.2分页&虚拟地址空间解释前面的实验现象 2.3进程地址空间存在的原因2.3.1第一个原因2.3.2第二个原因2.3.3第三个原…

信息安全实验--密码学实验工具:CrypTool

1. CrypTool介绍💭 CrypTool 1的开源教育工具,用于密码学研究。通过CrypTool 1,可以实现加密和解密操作,数字签名。CrypTool1和2有很多区别的。 2. CrpyTool下载🔧 在做信息安全实验--密码学相关实验时,发…

nodejs30: CSS 剪辑路径clip-path导致伪元素不可见问题及解决方法

相关问题 应用圆角裁剪时无法显示::after 取消clip-path设置&#xff1a; 完整问题代码 <!DOCTYPE html> <html lang"en"> <head><meta charset"UTF-8"><meta name"viewport" content"widthdevice-width, i…

三、计算机视觉_08YOLO目标检测

0、前言 YOLO作为目前CV领域的扛把子&#xff0c;分类、检测等任务样样精通&#xff0c;本文将基于两个小案例&#xff0c;用YOLO做检测任务&#xff0c;看看效果如何 1、对图片内容做检测 假设我有一张名为picture.jpeg的图片&#xff0c;其内容如下 我将图片和代码放到了同…

STM32 ADC --- 知识点总结

STM32 ADC — 知识点总结 文章目录 STM32 ADC --- 知识点总结cubeMX中配置注解单次转换模式、连续转换模式、扫描模式单通道采样的情况单次转换模式&#xff1a;连续转换模式&#xff1a; 多通道采样的情况禁止扫描模式&#xff08;单次转换模式或连续转换模式&#xff09;单次…

SQL Server 实战 - 多种连接

目录 背景 一、多种连接 1. 复合连接条件 2. 跨数据库连接 3. 隐连接 4. 自连接 5. 多表外连接 6. UNION ALL 二、一个对比例子 背景 本专栏文章以 SAP 实施顾问在实施项目中需要掌握的 sql 语句为偏向进行选题&#xff1a; 用例&#xff1a;SAP B1 的数据库工具&am…

Nginx:ssl

目录 部署ssl前提 nginx部署ssl证书 部署ssl部署建议 部署ssl前提 网站有域名根据域名申请到ssl证书&#xff0c;并下载证书部署到nginx中 部署了ssl证书后&#xff0c;访问的流量是加密的。 nginx部署ssl证书 #80端口跳转到443 server {listen 80;return 302 https://1…

MySQL之单行函数

目录 1. 函数的理解 单行函数 2. 数值函数 2.1 基本函数 2.2 角度与弧度互换函数 2.3 三角函数 2.4 指数与对数 2.5 进制间的转换 3. 字符串函数 4. 日期和时间函数 4.1 获取日期、时间 4.2 日期与时间戳的转换​编辑 4.3 获取月份、星期、星期数、天数等函数 4.4 …

Next.js-样式处理

#题引&#xff1a;我认为跟着官方文档学习不会走歪路 Next.js 支持多种为应用程序添加样式的方法&#xff0c;包括&#xff1a; CSS Modules&#xff1a;创建局部作用域的 CSS 类&#xff0c;避免命名冲突并提高可维护性。全局 CSS&#xff1a;使用简单&#xff0c;对于有传统…

Leetcode 每日一题 104.二叉树的最大深度

目录 问题描述 示例 示例 1&#xff1a; 示例 2&#xff1a; 约束条件 题解 方法一&#xff1a;广度优先搜索&#xff08;BFS&#xff09; 步骤 代码实现 方法二&#xff1a;递归 步骤 代码实现 结论 问题描述 给定一个二叉树 root&#xff0c;我们需要返回其最大…

SQL基础入门——SQL基础语法

1. 数据库、表、列的创建与管理 在SQL中&#xff0c;数据库是一个数据的集合&#xff0c;包含了多个表、视图、索引、存储过程等对象。每个表由若干列&#xff08;字段&#xff09;组成&#xff0c;表中的数据行代表记录。管理数据库和表的结构是SQL的基础操作。 1.1 创建数据…

IP与“谷子”齐飞,阅文“乘势而上”?

爆火的“谷子经济”&#xff0c;又捧出一只“潜力股”。 近日&#xff0c;阅文集团股价持续上涨&#xff0c;5日累计涨幅达13.20%。这其中&#xff0c;周三股价一度大涨约15%至29.15港元&#xff0c;强势突破20日、30日、120日等多根均线&#xff0c;市值突破280亿港元关口。 …

EXCEL截取某一列从第一个字符开始到特定字符结束的字符串到新的一列

使用EXCEL中的公式进行特定截取 假设列A是一组产品的编码&#xff0c;我们需要的数据是“-”之前的字段。 我们需要在B1单元格输入公式“LEFT(A1,SEARCH("-",A1)-1)”然后选中B1至B4单元格&#xff0c;按“CTRLD”向下填充&#xff0c;就可以得出其它几行“-”之前的…

重塑视频新语言,让每一帧都焕发新生——Video-Retalking,开启数字人沉浸式交流新纪元!

模型简介 Video-Retalking 模型是一种基于深度学习的视频再谈话技术&#xff0c;它通过分析视频中的音频和图像信息&#xff0c;实现视频角色口型、表情乃至肢体动作的精准控制与合成。这一技术的实现依赖于强大的技术架构和核心算法&#xff0c;特别是生成对抗网络&#xff0…

多头注意力机制:从原理到应用的全面解析

目录 什么是多头注意力机制&#xff1f; 原理解析 1. 注意力机制的核心公式 2. 多头注意力的扩展 为什么使用多头注意力&#xff1f; 实际应用 1. Transformer中的应用 2. NLP任务 3. 计算机视觉任务 PyTorch 实现示例 总结 近年来&#xff0c;“多头注意力机制&…

力扣637. 二叉树的层平均值

给定一个非空二叉树的根节点 root , 以数组的形式返回每一层节点的平均值。与实际答案相差 10-5 以内的答案可以被接受。 提示&#xff1a; 树中节点数量在 [1, 104] 范围内-231 < Node.val < 231 - 1 代码&#xff1a; /*** Definition for a binary tree node.* stru…

Opencv+ROS实现摄像头读取处理画面信息

一、工具 ubuntu18.04 ROSopencv2 编译器&#xff1a;Visual Studio Code 二、原理 图像信息 ROS数据形式&#xff1a;sensor_msgs::Image OpenCV数据形式&#xff1a;cv:Mat 通过cv_bridge()函数进行ROS向opencv转换 cv_bridge是在ROS图像消息和OpenCV图像之间进行转…

Perforce SAST专家详解:自动驾驶汽车的安全与技术挑战,Klocwork、Helix QAC等静态代码分析成必备合规性工具

自动驾驶汽车安全吗&#xff1f;现代汽车的软件包含1亿多行代码&#xff0c;支持许多不同的功能&#xff0c;如巡航控制、速度辅助和泊车摄像头。而且&#xff0c;这些嵌入式系统中的代码只会越来越复杂。 随着未来汽车的互联程度越来越高&#xff0c;这一趋势还将继续。汽车越…

架构-微服务-服务配置

文章目录 前言一、配置中心介绍1. 什么是配置中心2. 解决方案 二、Nacos Config入门三、Nacos Config深入1. 配置动态刷新2. 配置共享 四、nacos服务配置的核心概念 前言 服务配置--Nacos Config‌ 微服务架构下关于配置文件的一些问题&#xff1a; 配置文件相对分散。在一个…

攻防世界GFSJ1193 cat_theory

题目编号&#xff1a;GFSJ1193 附件下载后是一个jpg文件和一个sage文件&#xff08;python&#xff09;&#xff1a; 1. 分析图片&#xff08;.jpg文件&#xff09; 这个交换图展示的是一个加密系统的 同态加密 性质&#xff0c;其核心思想是&#xff1a;加密前的操作与加密后…