图神经网络:处理复杂关系结构与图分类任务的强大工具

创作不易,您的打赏、关注、点赞、收藏和转发是我坚持下去的动力!图神经网络

图神经网络(Graph Neural Network, GNN)是针对图数据的一类神经网络模型。图数据具有节点(节点代表实体)和边(边代表节点之间的关系),因此,GNN能够处理这种复杂的关系结构,提取图结构中有用的信息。GNN的基本思想是通过消息传递(message passing)机制将节点和它们的邻居进行特征融合,从而更新节点的表示。这种表示可以用来进行节点分类、边预测或者整个图的分类等任务。

1. GNN基础知识

GNN的核心机制是基于图的消息传递和特征聚合。对于每个节点,GNN会收集其邻居节点的信息,然后通过一定的聚合函数(例如求和或平均)生成新的特征表示。

1.1 图的定义
  • 节点(Node):图中的实体,记作 (v_i)。
  • 边(Edge):节点之间的关系,记作 (e_{ij}),表示从节点 (v_i) 到节点 (v_j) 的连接。
  • 邻居节点(Neighbors):节点 (v_i) 的直接相连节点集合,记作 (N(v_i))。
1.2 GNN的消息传递机制

GNN的基本操作包括两个步骤:

  1. 消息传递(Message Passing):从每个节点的邻居节点收集特征。
  2. 特征更新(Feature Update):将节点的特征与邻居的特征聚合,更新节点的表示。

假设节点 (v_i) 的初始特征为 (h_i^{(0)}),其第 (k) 次迭代时的特征表示为 (h_i^{(k)})。GNN通过以下两步进行更新:

  • 聚合邻居特征:将节点 (v_i) 的所有邻居节点的特征聚合起来,例如求和或平均:
    [
    m_i^{(k)} = \text{AGGREGATE}({ h_j^{(k-1)} : j \in N(v_i) })
    ]
  • 更新节点特征:将聚合的邻居特征与节点本身的特征结合起来,更新节点的表示:
    [
    h_i^{(k)} = \text{UPDATE}(h_i^{(k-1)}, m_i^{(k)})
    ]
1.3 GNN在图分类任务中的应用

图分类任务的目标是给定一张图,预测该图的类别。常见应用包括化学分子分类、社交网络分析等。在这种任务中,GNN的目标是通过学习图的全局结构信息来预测整张图的标签。

GNN处理图分类任务的流程一般如下:

  1. 特征初始化:给每个节点赋予初始特征(可以是节点的属性)。
  2. 消息传递与特征更新:通过多层GNN层,将节点特征与其邻居进行聚合和更新。
  3. 图的汇总(Readout):将所有节点的特征汇总为图的表示(例如通过求平均或全连接层)。
  4. 分类器:使用图的表示作为输入,通过一个分类器预测图的类别。

2. Python实现示例

我们可以使用PyTorch Geometric来实现一个简单的图分类任务。

2.1 安装依赖

首先,你需要安装PyTorchPyTorch Geometric库:

pip install torch
pip install torch-geometric
2.2 数据准备

我们使用PyTorch Geometric中的一个经典的图分类数据集MUTAG,这是一个小型化学分子数据集,每个分子作为一张图,目标是预测分子的类别。

import torch
import torch.nn.functional as F
from torch_geometric.datasets import TUDataset
from torch_geometric.loader import DataLoader
from torch_geometric.nn import GCNConv, global_mean_pool# 加载数据集
dataset = TUDataset(root='/tmp/MUTAG', name='MUTAG')# 划分训练集和测试集
train_dataset = dataset[:150]
test_dataset = dataset[150:]train_loader = DataLoader(train_dataset, batch_size=32, shuffle=True)
test_loader = DataLoader(test_dataset, batch_size=32, shuffle=False)
2.3 定义GNN模型

我们定义一个简单的图卷积网络(GCN)用于图分类任务。

class GCN(torch.nn.Module):def __init__(self):super(GCN, self).__init__()# 定义两个GCN层self.conv1 = GCNConv(dataset.num_node_features, 64)self.conv2 = GCNConv(64, 64)# 最后一个全连接层用于图分类self.fc = torch.nn.Linear(64, dataset.num_classes)def forward(self, data):x, edge_index, batch = data.x, data.edge_index, data.batch# 第一层GCN + ReLU激活x = self.conv1(x, edge_index)x = F.relu(x)# 第二层GCNx = self.conv2(x, edge_index)# 使用全局平均池化将节点特征聚合为图的特征x = global_mean_pool(x, batch)# 最后通过全连接层进行分类x = self.fc(x)return F.log_softmax(x, dim=1)
2.4 模型训练和测试

我们定义训练和测试的函数,分别用于训练模型和评估模型的性能。

# 定义设备
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model = GCN().to(device)
optimizer = torch.optim.Adam(model.parameters(), lr=0.01)def train():model.train()total_loss = 0for data in train_loader:data = data.to(device)optimizer.zero_grad()output = model(data)loss = F.nll_loss(output, data.y)loss.backward()optimizer.step()total_loss += loss.item()return total_loss / len(train_loader)def test(loader):model.eval()correct = 0for data in loader:data = data.to(device)output = model(data)pred = output.argmax(dim=1)correct += pred.eq(data.y).sum().item()return correct / len(loader.dataset)# 训练模型
for epoch in range(1, 201):loss = train()test_acc = test(test_loader)print(f'Epoch: {epoch:03d}, Loss: {loss:.4f}, Test Acc: {test_acc:.4f}')
2.5 解释代码
  • GCNConv:图卷积层,用于将节点的特征与其邻居的特征进行聚合。
  • global_mean_pool:对图中的所有节点特征进行全局池化,将节点特征汇总为图的特征表示。
  • forward:定义了模型的前向传播,输入图的特征和结构,输出图的类别预测。

通过上述代码,你可以用GNN进行图分类任务。这个模型会对每张图中的所有节点进行特征更新,并最终通过全连接层进行分类。

大家有技术交流指导、论文及技术文档写作指导、课程知识点讲解、项目开发合作的需求可以搜索关注我私信我

在这里插入图片描述

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.rhkb.cn/news/436765.html

如若内容造成侵权/违法违规/事实不符,请联系长河编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

Webstorm 中对 Node.js 后端项目进行断点调试

首先,肯定需要有一个启动服务器的命令脚本。 然后,写一个 debug 的配置: 然后,debug 模式 启动项目和 启动调试服务: 最后,发送请求,即可调试: 这几个关键按钮含义: 重启…

【STM32-HAL库】自发电型风速传感器(使用STM32F407ZGT6)(附带工程下载链接)

一、自发电型风速传感器介绍 自发电型风速传感器,也称为风力发电型风速传感器或无源风速传感器,是一种不需要外部电源即可工作的风速测量设备。这种传感器通常利用风力来驱动内部的发电机构,从而产生电能来供电测量风速的传感器部分。以下是自…

Footprint Growthly Quest 工具:赋能 Telegram 社区实现 Web3 飞速增长

作者:Stella L (stellafootprint.network) 在 Web3 的快节奏世界里,社区互动是关键。而众多 Web3 社区之所以能够蓬勃发展,很大程度上得益于 Telegram 平台。正因如此,Footprint Analytics 精心打造了 Growthly —— 一款专为 Tel…

Miniforge详细安装教程(macOs和Windows)

(注:主要是解决商业应用anaconda收费问题,这是轻量级的代替,个人完全可以使用anaconda和miniconda) Miniforge 是一个轻量级的包管理器,类似于 Anaconda 和 Miniconda。它主要用于安装基于 conda 的 Python 环境,专注于…

opencv - Fourier Transform 傅里叶变换

目标 在本节中,我们将学习 使用 OpenCV 查找图像的傅里叶变换利用 Numpy 中可用的 FFT 函数傅里叶变换的一些应用我们将看到以下函数:cv.dft()、cv.idft() 等 理论 傅里叶变换用于分析各种滤波器的频率特性。对于图像,2D 离散傅里叶变换 …

Android Studio 新版本 Logcat 的使用详解

点击进入官方Logcat介绍 一个好的Android程序员要会使用AndroidStudio自带的Logcat查看日志,会Log定位也是查找程序bug的第一关键。同时Logcat是一个查看和处理日志消息的工具,它可以更快的帮助开发者调试应用程序。 步入正题,看图说话。 点…

sql-server【bcp工具】

目录 1.查看bcp是否可用 2.bcp 命令的基本语法 3.数据导出 4.数据导入 bcp(Bulk Copy Program)是 SQL Server 提供的一个命令行工具,用于在 SQL Server 实例与用户指定格式的数据文件之间批量复制表或视图数据。bcp 工具非常适合进行大量…

基于Spark的汽车行业大数据分析及可视化系统

作者:计算机学姐 开发技术:SpringBoot、SSM、Vue、MySQL、JSP、ElementUI、Python、小程序等,“文末源码”。 专栏推荐:前后端分离项目源码、SpringBoot项目源码、Vue项目源码、SSM项目源码、微信小程序源码 精品专栏:…

【leetcode】121.买卖股票的最佳时机

思路&#xff1a; 找到后面与前面的差值最大即可。 代码&#xff1a; int maxProfit(int* prices, int pricesSize) {int i 0, j 0;//i是后一个最大的&#xff0c;j是前面最小的int max 0, temp 0;//表示最大值for (i 1; i < pricesSize; i){if (prices[j] < pr…

产品经理的学习

初学 接需求 画原型 写文档 日常产出 流程图 举例购物的流程 结构图 一个应用的全部功能&#xff0c;用思维导图的方式去罗列出来 竞品分析文档 竞品分类 竞品选择 竞品采集 竞品文档书写 也可以做一个产品的产品结构图 需求文档 干系人 需求方 记录人 产品经理 其他项目干系人…

【办公类-48-03】20240930每月电子屏台账汇总成docx-3(三园区合并EXCLE,批量生成3份word)

背景需求&#xff1a; 前期电子屏汇总是“总园”用“”问卷星”、“一分园”用“腾讯文档”&#xff0c;二分园“用“手写word”” 【办公类-48-02】20240407每月电子屏台账汇总成docx-2&#xff08;腾讯文档xlsx导入docx&#xff0c;每页20条&#xff09;【办公类-48-02】20…

腾讯云新开端口

检查防火墙设置 890 2024-09-30 20:47:18 netstat -tuln | grep 1213891 2024-09-30 20:47:49 ping 110.40.130.231892 2024-09-30 20:48:38 sudo firewall-cmd --zonepublic --add-port1213/tcp --permanent893 2024-09-30 20:48:51 sudo firewall-cmd --reload894 2024-…

汽车线束之故障诊断方案-TDR测试

当前&#xff0c;在汽车布局中的线束的性能要求越来越高。无法通过简单的通断测试就能满足性能传输要求。早起对智能化要求不高&#xff0c;比如没有激动雷达、高清摄像、中央CPU等。 近几年的智能驾驶对网络传输要求越来越高&#xff0c;不但是高速率&#xff0c;还需要高稳定…

常见的RTSP播放器有哪些?

VLC播放器 特点&#xff1a;VLC 是一款功能强大、跨平台的多媒体播放器&#xff0c;支持多种音频和视频格式以及流媒体协议&#xff0c;包括 RTSP。它具有广泛的解码器支持&#xff0c;能播放大多数常见的视频和音频格式。其开源特性使得它拥有活跃的开发者社区&#xff0c;不断…

HCIP--以太网交换安全(一)

目录 端口隔离 MAC地址表安全 以太网交换安全概述&#xff1a;以太网交换安全是一系列技术和策略的集合&#xff0c;旨在保护以太网交换机免受各种网络攻击和威胁。 端口隔离 一、端口隔离概述&#xff1a; 作用&#xff1a;可以实现同一个VLAN内端口的隔离 优势&#xff1a…

modelsim仿真 wave视图里 数据位宽和进制怎么显示

在modelsim 某些版本安装后&#xff0c;如ModelSim SE-64 2020.4版本&#xff0c;重置布局等情况下&#xff0c; 解决方案其实很简单&#xff1a; 点击中间的按钮 在Wave Windows Preferences 勾选Display-Show Radix Base -> Waveforms

Lj视频下载器 1.1.37 简洁高效的视频下载工具

Lj视频下载器是一个功能强大的视频下载器&#xff0c;支持直接添加视频地址或 m3u8 资源地址&#xff0c;可以从网页中自动提取视频进行下载。支持多种视频格式&#xff0c;包括 m3u8&#xff0c;并能自动检测并移除广告片段。 大小&#xff1a;19M 百度网盘&#xff1a;https…

音悦 1.5.1 完全免费,无广告,纯净听歌体验

音悦是一款完全免费的听歌应用&#xff0c;汇聚全网多平台曲库&#xff0c;拥有排行榜、MV、个性电台、我的歌单、收藏喜欢等功能。无需会员&#xff0c;没有广告&#xff0c;免费听歌下歌&#xff0c;是一款非常纯净小巧但功能齐全的听歌神器。 大小&#xff1a;27.6M 百度网…

YOLOv11尝鲜测试五分钟极简配置

ultralytics团队在最近又推出了YOLOv11&#xff0c;不知道在有生之年能不能看到YOLOv100呢哈哈。 根据官方文档&#xff0c;在 Python>3.8并且PyTorch>1.8的环境下即可安装YOLOv11&#xff0c;因此之前YOLOv8的环境是可以直接用的。 安装YOLOv11&#xff1a; pip instal…