pytorch图神经网络处理图结构数据

 人工智能例子汇总:AI常见的算法和例子-CSDN博客 

图神经网络(Graph Neural Networks,GNNs)是一类能够处理图结构数据的深度学习模型。图结构数据由节点(vertices)和边(edges)组成,其中节点表示实体,边表示实体之间的关系或连接。GNNs 通过在图的结构上进行信息传递和节点嵌入(node embedding)来学习节点或图的特征表示。

GNN的关键思想是通过消息传递机制(message passing)更新每个节点的表示,通常是基于其邻居节点的特征信息。GNNs 可以广泛应用于许多领域,如社交网络分析、推荐系统、知识图谱、分子图表示等。

以下是GNN的基本组成部分和工作原理:

  1. 节点表示更新:每个节点的表示通过其邻居节点的表示进行更新。常见的做法是通过聚合邻居节点的特征,然后与节点本身的特征进行结合

GNN的变种

  1. GCN(Graph Convolutional Networks):一种基于图卷积的GNN,通过聚合邻居节点的特征来更新节点表示,适用于无向图。

  2. GraphSAGE(Graph Sample and Aggregation):通过随机采样邻居节点来提高计算效率,尤其适用于大规模图。

  3. GAT(Graph Attention Networks):引入了注意力机制,使得不同邻居对节点更新的贡献不同,能够动态调整每个邻居的权重。

  4. Graph Isomorphism Network (GIN):通过强大的表征能力增强了图的判别性。

GNN的应用

  • 社交网络分析:预测用户之间的关系或用户的兴趣。
  • 推荐系统:基于用户和物品之间的图结构进行个性化推荐。
  • 生物信息学:如分子图表示,用于药物发现、蛋白质结构预测等。
  • 图像分割与语义分析:在视觉任务中处理图形数据,捕捉图像之间的关系。

例子:

import torch
import torch.nn.functional as F
from torch_geometric.nn import GCNConv
from torch_geometric.data import Data
import matplotlib.pyplot as plt# 1. 生成随机图数据
num_nodes = 100
x = torch.rand((num_nodes, 2))  # 100 个节点,每个节点有 2 维特征
y = (x[:, 0] + x[:, 1] > 1).long()  # 二分类标签(0 或 1)# 2. 生成图结构(邻接关系)
edge_index = []
for i in range(num_nodes):for j in range(i + 1, num_nodes):if (y[i] == y[j] and torch.rand(1).item() > 0.6) or (y[i] != y[j] and torch.rand(1).item() > 0.9):edge_index.append([i, j])edge_index.append([j, i])
edge_index = torch.tensor(edge_index, dtype=torch.long).t()# 3. 训练集和测试集
train_mask = torch.rand(num_nodes) < 0.8  # 80% 训练,20% 测试
test_mask = ~train_mask# 4. 构造 PyG 数据对象
data = Data(x=x, edge_index=edge_index, y=y, train_mask=train_mask, test_mask=test_mask)# 5. 定义 4 层 GCN 模型
class GCN(torch.nn.Module):def __init__(self):super(GCN, self).__init__()self.conv1 = GCNConv(2, 16)self.conv2 = GCNConv(16, 16)self.conv3 = GCNConv(16, 16)  # 将 conv3 输出改为与输入维度相同self.conv4 = GCNConv(16, 2)  # 输出类别数 2def forward(self, data):x, edge_index = data.x, data.edge_indexx = F.relu(self.conv1(x, edge_index))x = F.relu(self.conv2(x, edge_index))x = F.relu(self.conv3(x, edge_index)) + x  # 跳跃连接,维度一致x = self.conv4(x, edge_index)return F.log_softmax(x, dim=1)  # 输出对数概率# 6. 训练模型
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = GCN().to(device)
optimizer = torch.optim.Adam(model.parameters(), lr=0.001, weight_decay=5e-4)
scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=500, gamma=0.5)  # 学习率衰减data = data.to(device)
num_epochs = 2000  # 增加训练轮数for epoch in range(num_epochs):model.train()optimizer.zero_grad()out = model(data)loss = F.nll_loss(out[data.train_mask], data.y[data.train_mask])loss.backward()optimizer.step()scheduler.step()  # 逐步降低学习率if epoch % 200 == 0:print(f"Epoch {epoch}, Loss: {loss.item():.4f}")# 7. 评估模型
model.eval()
out = model(data)
pred = out.argmax(dim=1)  # 取最大值的索引作为类别
test_pred = pred[data.test_mask]
test_true = data.y[data.test_mask]# 8. 过滤低置信度预测
proba = torch.exp(out)  # 转换为 softmax
test_pred[proba[data.test_mask].max(dim=1)[0] < 0.6] = -1  # 低置信度设为 -1# 9. 可视化测试结果
test_mask_np = torch.arange(num_nodes)[data.test_mask].cpu().numpy()
test_pred_np = test_pred.cpu().numpy()
test_true_np = test_true.cpu().numpy()plt.figure(figsize=(10, 5))
plt.scatter(test_mask_np, test_pred_np, color='blue', alpha=0.5, label='Predicted')
plt.scatter(test_mask_np, test_true_np, color='red', alpha=0.5, label='True')
plt.xlabel('Test Node Index')
plt.ylabel('Node Class')
plt.title('Test Results vs True Results')
plt.legend()
plt.show()

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.rhkb.cn/news/12313.html

如若内容造成侵权/违法违规/事实不符,请联系长河编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

[mmdetection]fast-rcnn模型训练自己的数据集的详细教程

本篇博客是由本人亲自调试成功后的学习笔记。使用了mmdetection项目包进行fast-rcnn模型的训练&#xff0c;数据集是自制图像数据。废话不多说&#xff0c;下面进入训练步骤教程。 注&#xff1a;本人使用linux服务器进行展示&#xff0c;Windows环境大差不差。另外&#xff0…

对比uart iic spi 三种总线的使用

1.uart串口通信 1.1uart的通信总线方式 1.2查询开发板和数据手册对需要进行修改的串口进行设置 例如STM32MP157aaa 1.设置8bit数据位 2.设置无校验位 3.设置1bit停止位 4.设置波特率为115200 5.设置16倍过采样 7.使能发送器 TE 8.使能接收器 RE 9.使能串口 UE10.发送数据&…

【玩转 Postman 接口测试与开发2_016】第13章:在 Postman 中实现契约测试(Contract Testing)与 API 接口验证(上)

《API Testing and Development with Postman》最新第二版封面 文章目录 第十三章 契约测试与 API 接口验证1 契约测试的概念2 契约测试的工作原理3 契约测试的分类4 DeepSeek 给出的契约测试相关背景5 契约测试在 Postman 中的创建方法6 API 实例的基本用法7 API 实例的类型实…

java-(Oracle)-Oracle,plsqldev,Sql语法,Oracle函数

卸载好注册表,然后安装11g 每次在执行orderby的时候相当于是做了全排序,思考全排序的效率 会比较耗费系统的资源,因此选择在业务不太繁忙的时候进行 --给表添加注释 comment on table emp is 雇员表 --给列添加注释; comment on column emp.empno is 雇员工号;select empno,en…

尚硅谷课程【笔记】——大数据之Shell【一】

课程视频&#xff1a;【【尚硅谷】Shell脚本从入门到实战】 一、Shell概述 为什么要学习Shell&#xff1f; 1&#xff09;需要看懂运维人员的Shell程序 2&#xff09;偶尔编写一些简单的Shell程序来管理集群、提高开发效率 什么是Shell&#xff1f; 1&#xff09;Shell是一…

pytorch实现长短期记忆网络 (LSTM)

人工智能例子汇总&#xff1a;AI常见的算法和例子-CSDN博客 LSTM 通过 记忆单元&#xff08;cell&#xff09; 和 三个门控机制&#xff08;遗忘门、输入门、输出门&#xff09;来控制信息流&#xff1a; 记忆单元&#xff08;Cell State&#xff09; 负责存储长期信息&…

CDDIS从2025年2月开始数据迁移

CDDIS 将从 2025 年 2 月开始将我们的网站从 cddis.nasa.gov 迁移到 earthdata.nasa.gov&#xff0c;并于 2025 年 6 月结束。 期间可能对GAMIT联网数据下载造成影响。

【Redis】主从模式,哨兵,集群

主从复制 单点问题&#xff1a; 在分布式系统中&#xff0c;如果某个服务器程序&#xff0c;只有一个节点&#xff08;也就是一个物理服务器&#xff09;来部署这个服务器程序的话&#xff0c;那么可能会出现以下问题&#xff1a; 1.可用性问题&#xff1a;如果这个机器挂了…

华为云kubernetes部署deepseek r1、ollama和open-webui(已踩过坑)

1 概述 ollama是一个管理大模型的一个中间层&#xff0c;通过它你可以下载并管理deepseek R1、llama3等大模型。 open-webui是一个web界面&#xff08;界面设计受到chatgpt启发&#xff09;&#xff0c;可以集成ollama API、 OpenAI的 API。 用常见的web应用架构来类比&#x…

在Mac mini M4上部署DeepSeek R1本地大模型

在Mac mini M4上部署DeepSeek R1本地大模型 安装ollama 本地部署&#xff0c;我们可以通过Ollama来进行安装 Ollama 官方版&#xff1a;【点击前往】 Web UI 控制端【点击安装】 如何在MacOS上更换Ollama的模型位置 默认安装时&#xff0c;OLLAMA_MODELS 位置在"~/.o…

CSS 背景与边框:从基础到高级应用

CSS 背景与边框&#xff1a;从基础到高级应用 1. CSS 背景样式1.1 背景颜色示例代码&#xff1a;设置背景颜色 1.2 背景图像示例代码&#xff1a;设置背景图像 1.3 控制背景平铺行为示例代码&#xff1a;控制背景平铺 1.4 调整背景图像大小示例代码&#xff1a;调整背景图像大小…

数据思维错题知识点整理(复习)

小的知识点整理 目前常见的数据采集方案有什么。 埋点、可视化埋点、无埋点&#xff08;无埋点并不是字面意思不埋点&#xff0c;其实也是一种埋点&#xff0c;只是让开发人员完全无感知&#xff0c;直接嵌入sdk&#xff0c;然后每个元素都能查看他们的情况&#xff0c;后续开…

PyQt4学习笔记2】QMainWindow

目录 一、创建 QMainWindow 组件 1. 创建工具栏 2. 创建停靠窗口 3. 设置状态栏 4. 设置中央窗口部件 二、QMainWindow 的主要方法 1. addToolBar() 2. addDockWidget() 3. setStatusBar() 4. setCentralWidget() 5. menuBar() 6. saveState() 和 restoreState() 三、QMainWind…

Linux:文件系统(软硬链接)

目录 inode ext2文件系统 Block Group 超级块&#xff08;Super Block&#xff09; GDT&#xff08;Group Descriptor Table&#xff09; 块位图&#xff08;Block Bitmap&#xff09; inode位图&#xff08;Inode Bitmap&#xff09; i节点表&#xff08;inode Tabl…

ubuntu22.40安装及配置静态ip解决重启后配置失效

遇到这种错误&#xff0c;断网安装即可&#xff01; 在Ubuntu中配置静态IP地址的步骤如下。根据你使用的Ubuntu版本&#xff08;如 Netplan 或传统的 ifupdown&#xff09;&#xff0c;配置方法有所不同。以下是基于 Netplan 的配置方法&#xff08;适用于Ubuntu 17.10及更高版…

手写MVVM框架-实现简单的数据代理

MVVM框架最显著的特点就是虚拟dom和响应式的数据、我们以Vue为例&#xff0c;分别实现data、computed、created、methods以及虚拟dom。 这一章我们先实现简单的响应式&#xff0c;修改数据之后在控制台打印。 我们将该框架命名为MiniVue。 首先我们需要创建MiniVue的类(src/co…

ESLint

ESLint ESLint 是一个针对 JS 的代码风格检查工具&#xff0c;当不满足其要求的风格时&#xff0c;会给予警告或错误。 官网&#xff1a;https://eslint.org/ 中文网&#xff1a;https://eslint.nodejs.cn/ 安装使用 在你的项目中安装 ESLint 包&#xff1a; npm install -…

kaggle视频行为分析1st and Future - Player Contact Detection

这次比赛的目标是检测美式橄榄球NFL比赛中球员经历的外部接触。您将使用视频和球员追踪数据来识别发生接触的时刻&#xff0c;以帮助提高球员的安全。两种接触&#xff0c;一种是人与人的&#xff0c;另一种是人与地面&#xff0c;不包括脚底和地面的&#xff0c;跟我之前做的这…

Chapter 6 -Fine-tuning for classification

Chapter 6 -Fine-tuning for classification 本章内容涵盖 引入不同的LLM微调方法准备用于文本分类的数据集修改预训练的 LLM 进行微调微调 LLM 以识别垃圾邮件评估微调LLM分类器的准确性使用微调的 LLM 对新数据进行分类 现在&#xff0c;我们将通过在大语言模型上对特定目标任…

【从零开始的LeetCode-算法】922. 按奇偶排序数组 II

给定一个非负整数数组 nums&#xff0c; nums 中一半整数是 奇数 &#xff0c;一半整数是 偶数 。 对数组进行排序&#xff0c;以便当 nums[i] 为奇数时&#xff0c;i 也是 奇数 &#xff1b;当 nums[i] 为偶数时&#xff0c; i 也是 偶数 。 你可以返回 任何满足上述条件的…