Pytorch+PyG实现GAT(图注意力网络)

文章目录

  • 前言
  • 一、导入相关库
  • 二、加载Cora数据集
  • 三、定义GAT网络
  • 四、定义模型
  • 五、模型训练
  • 六、模型验证
  • 七、结果


前言

大家好,我是阿光。

本专栏整理了《图神经网络代码实战》,内包含了不同图神经网络的相关代码实现(PyG以及自实现),理论与实践相结合,如GCN、GAT、GraphSAGE等经典图网络,每一个代码实例都附带有完整的代码。

正在更新中~ ✨

在这里插入图片描述

🚨 我的项目环境:

  • 平台:Windows10
  • 语言环境:python3.7
  • 编译器:PyCharm
  • PyTorch版本:1.11.0
  • PyG版本:2.1.0

💥 项目专栏:【图神经网络代码实战目录】


本文我们将使用Pytorch + Pytorch Geometric来简易实现一个GAT(图注意力网络),让新手可以理解如何PyG来搭建一个简易的图网络实例demo。

一、导入相关库

本项目我们需要结合两个库,一个是Pytorch,因为还需要按照torch的网络搭建模型进行书写,第二个是PyG,因为在torch中并没有关于图网络层的定义,所以需要torch_geometric这个库来定义一些图层。

import torch
import torch.nn.functional as F
import torch.nn as nn
import torch_geometric.nn as pyg_nn
from torch_geometric.datasets import Planetoid

二、加载Cora数据集

本文使用的数据集是比较经典的Cora数据集,它是一个根据科学论文之间相互引用关系而构建的Graph数据集合,论文分为7类,共2708篇。

  • Genetic_Algorithms
  • Neural_Networks
  • Probabilistic_Methods
  • Reinforcement_Learning
  • Rule_Learning
  • Theory

这个数据集是一个用于图节点分类的任务,数据集中只有一张图,这张图中含有2708个节点,10556条边,每个节点的特征维度为1433。

# 1.加载Cora数据集
dataset = Planetoid(root='./data/Cora', name='Cora')

三、定义GAT网络

这里我们就不重点介绍GAT网络了,相信大家能够掌握基本原理,本文我们使用的是PyG定义网络层,在PyG中已经定义好了GATConv这个层,该层采用的就是GAT机制。

在这里插入图片描述

对于GATConv的常用参数:

  • in_channels:每个样本的输入维度,就是每个节点的特征维度
  • out_channels:经过注意力机制后映射成的新的维度,就是经过GAT后每个节点的维度长度
  • heads:是否采用多头注意力机制,默认是1
  • concat:是否拼接多头注意力机制的结果,如果为False,就会将多头注意力的结果平均化作为最终该节点的特征,如果为True,就会将多个结果进行拼接,形成heads*out_channels长度的向量,默认为True
  • dropout:按照一定概率放弃邻居的聚合操作,默认为0,使用所有邻居进行聚合
  • add_self_loops:为图添加自环,是否考虑自身节点的信息
  • bias:训练一个偏置b
# 2.定义GAT网络
class GAT(nn.Module):def __init__(self, num_node_features, num_classes):super(GAT, self).__init__()self.conv1 = pyg_nn.GATConv(in_channels=num_node_features,out_channels=16,heads=2)self.conv2 = pyg_nn.GATConv(in_channels=2*16,out_channels=num_classes,heads=1)def forward(self, data):x, edge_index = data.x, data.edge_indexx = self.conv1(x, edge_index)x = F.relu(x)x = F.dropout(x, training=self.training)x = self.conv2(x, edge_index)return F.log_softmax(x, dim=1)

上面网络我们定义了两个GATConv层,第一层的参数的输入维度就是初始每个节点的特征维度,输出维度是16,然后采用了多头注意力机制,那么经过该层之后每个节点的特征维度就变成了32,将两个头的结果拼接。

第二个层的输入维度为32,输出维度为分类个数,因为我们需要对每个节点进行分类,最终加上softmax操作。

四、定义模型

下面就是定义了一些模型需要的参数,像学习率、迭代次数这些超参数,然后是模型的定义以及优化器及损失函数的定义,和pytorch定义网络是一样的。

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') # 设备
epochs = 200 # 学习轮数
lr = 0.0003 # 学习率
num_node_features = dataset.num_node_features # 每个节点的特征数
num_classes = dataset.num_classes # 每个节点的类别数
data = dataset[0].to(device) # Cora的一张图# 3.定义模型
model = GAT(num_node_features, num_classes).to(device)
optimizer = torch.optim.Adam(model.parameters(), lr=lr) # 优化器
loss_function = nn.NLLLoss() # 损失函数

五、模型训练

模型训练部分也是和pytorch定义网络一样,因为都是需要经过前向传播、反向传播这些过程,对于损失、精度这些指标可以自己添加。

# 训练模式
model.train()for epoch in range(epochs):optimizer.zero_grad()pred = model(data)loss = loss_function(pred[data.train_mask], data.y[data.train_mask]) # 损失correct_count_train = pred.argmax(axis=1)[data.train_mask].eq(data.y[data.train_mask]).sum().item() # epoch正确分类数目acc_train = correct_count_train / data.train_mask.sum().item() # epoch训练精度loss.backward()optimizer.step()if epoch % 20 == 0:print("【EPOCH: 】%s" % str(epoch + 1))print('训练损失为:{:.4f}'.format(loss.item()), '训练精度为:{:.4f}'.format(acc_train))print('【Finished Training!】')

六、模型验证

下面就是模型验证阶段,在训练时我们是只使用了训练集,测试的时候我们使用的是测试集,注意这和传统网络测试不太一样,在图像分类一些经典任务中,我们是把数据集分成了两份,分别是训练集、测试集,但是在Cora这个数据集中并没有这样,它区分训练集还是测试集使用的是掩码机制,就是定义了一个和节点长度相同纬度的数组,该数组的每个位置为True或者False,标记着是否使用该节点的数据进行训练。

# 模型验证
model.eval()
pred = model(data)# 训练集(使用了掩码)
correct_count_train = pred.argmax(axis=1)[data.train_mask].eq(data.y[data.train_mask]).sum().item()
acc_train = correct_count_train / data.train_mask.sum().item()
loss_train = loss_function(pred[data.train_mask], data.y[data.train_mask]).item()# 测试集
correct_count_test = pred.argmax(axis=1)[data.test_mask].eq(data.y[data.test_mask]).sum().item()
acc_test = correct_count_test / data.test_mask.sum().item()
loss_test = loss_function(pred[data.test_mask], data.y[data.test_mask]).item()print('Train Accuracy: {:.4f}'.format(acc_train), 'Train Loss: {:.4f}'.format(loss_train))
print('Test  Accuracy: {:.4f}'.format(acc_test), 'Test  Loss: {:.4f}'.format(loss_test))

七、结果

【EPOCH:1
训练损失为:1.9267 训练精度为:0.2429
【EPOCH:21
训练损失为:1.8147 训练精度为:0.6357
【EPOCH:41
训练损失为:1.6524 训练精度为:0.8643
【EPOCH:61
训练损失为:1.4964 训练精度为:0.8929
【EPOCH:81
训练损失为:1.2864 训练精度为:0.9500
【EPOCH:101
训练损失为:1.1380 训练精度为:0.9571
【EPOCH:121
训练损失为:0.9685 训练精度为:0.9500
【EPOCH:141
训练损失为:0.8542 训练精度为:0.9571
【EPOCH:161
训练损失为:0.7348 训练精度为:0.9571
【EPOCH:181
训练损失为:0.6463 训练精度为:0.9714
【Finished Training!】>>>Train Accuracy: 0.9929 Train Loss: 0.5155
>>>Test  Accuracy: 0.7810 Test  Loss: 1.0362
训练集测试集
Accuracy0.99290.7810
Loss0.51551.0362

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.rhkb.cn/news/48071.html

如若内容造成侵权/违法违规/事实不符,请联系长河编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

脑电信号特征提取方法与应用

前言 脑电图(EEG)信号在理解与脑功能和脑相关疾病的电活动方面发挥着重要作用。典型的脑电信号分析流程如下:(1)数据采集;(2)数据预处理;(3)特征提取;(4)特征选择;(5)模型训练与分类;(6)性能评估。当信号分…

基础2-用卷积神经网络进行颅内和头皮脑电图数据分析的广义癫痫预测

A Generalised Seizure Prediction with Convolutional Neural Networks for Intracranial and Scalp Electroencephalogram Data Analysis 为了改善耐药癫痫和强直性癫痫患者的生活,癫痫预测作为最具挑战性的预测数据分析工作之一已引起越来越多的关注。许多杰出的…

异质图神经网络(持续更新ing...)

诸神缄默不语-个人CSDN博文目录 本文将对异质图神经网络(HGNN, heterogeneous graph neural networks)的方法演变进行梳理和介绍。 最近更新时间:2023.5.10 最早更新时间:2022.10.31 文章目录 1. 异质图2. 处理为同质图3. 知识图…

Python画棵圣诞树 ~ Merry Christmas ~

圣诞节快到了,用python、turtle画棵圣诞树吧~_Ding2langdang的博客-CSDN博客 转载于Ding2langdang 最近圣诞节快到啦,CSDN的热搜也变成了”代码画颗圣诞树“,看了几篇博客,发现原博主把一些圣诞树给融合在了一起。 我更喜欢树叶…

圣诞树网页和圣诞树应用程序

圣诞树网页和圣诞树应用程序 1、圣诞树网页 效果图 代码 <!DOCTYPE HTML PUBLIC "-//W3C//DTD HTML 4.01 Frameset//EN" "http://www.w3.org/TR/html4/frameset.dtd"> <html> <head> <title>写给xxx的的圣诞树</title> …

【圣诞节限定】教你用Python画圣诞树,做个浪漫的程序员

最近在各大社交平台看到好多圣诞树,看到大佬们画的圣诞树一个比一个精致,我也特别想尝试画一棵特别的圣诞树。下面是我画的一棵简易的圣诞树,虽然和网络上大佬们的圣诞树相比不是很精致,但是对于萌新们来说,画这样一棵简易的圣诞树还是非常轻松的。 ps:重要的不是圣诞树,…

浪漫的turtle,送给程序员自己的圣诞树

前几天一直在整 Pyqt5 相关的知识&#xff0c;在 Python UI 的世界里 Pyqt5 只是其中的一种用来做应用程序比较 nice。要在一个画布上面呈现我们需要的东西还是得依赖 turtle 比较靠谱&#xff0c;什么组件就做什么事、没有谁比谁厉害&#xff0c;只是在合适的地方用合适的组件…

用代码画两棵圣诞树送给你【附详细代码】

大家好&#xff0c;我是宁一 代码的魔力之处在于&#xff0c;可以帮我们实现许多奇奇怪怪、有趣的想法。 比如&#xff0c;用Python的Turtle库&#xff0c;可以帮我们在电脑上画出好看的图像。 下面这张樱花图就是用Turtle库实现的。 这不圣诞节快到啦。 那么就用代码来画一…

最近比较火的圣诞树HTML代码

<!DOCTYPE html> <html lang"en"><head><meta charset"UTF-8"><title>张洋</title><link rel"stylesheet" href"https://cdnjs.cloudflare.com/ajax/libs/normalize/5.0.0/normalize.min.css"…

【圣诞树代码】

新建一个HTML文件&#xff0c;直接复制粘贴就行。 <!DOCTYPE html> <html lang"en" ><head><meta charset"UTF-8"><title></title><link rel"stylesheet" href"https://cdnjs.cloudflare.com/aja…

【圣诞节】简单代码实现圣诞树|圣诞贺卡 | 快来为心爱的她送上专属的圣诞礼物叭~

圣诞节马上就要到了&#xff0c;不知道给自己喜欢的人准备什么样的惊喜吗&#xff1f;作为一名程序员&#xff0c;当然是用编程制作专属于她or他的圣诞树&#xff01; 目录 &#x1f384;圣诞树 ✨3D圣诞树 代码块 打开方式 修改位置 效果展示 ✨音乐律动圣诞树 代码块…

圣诞节来了,怎能还没有圣诞树呢 快来为心爱的她送上专属的圣诞礼物叭~

&#x1f4e2;&#x1f4e2;&#x1f4e2;&#x1f4e3;&#x1f4e3;&#x1f4e3; &#x1f33b;&#x1f33b;&#x1f33b;Hello&#xff0c;大家好我叫是Dream呀&#xff0c;一个有趣的Python博主&#xff0c;小白一枚&#xff0c;多多关照&#x1f61c;&#x1f61c;&…

用Python画圣诞树 ‘‘遇见’’ 圣诞老人

这是雪程序的1.1版本。 上个版本的文章---看这里&#xff1a; 忙活半天只为了看雪--送给大家的冬至礼物https://blog.csdn.net/qq_54554848/article/details/121873955?spm1001.2014.3001.5501&#xff08;下述代码基于上个版本&#xff09; 上次我发布了--冬至礼物的博客&…

圣诞节快到了,程序员应该给女友送一个线上圣诞树

我们把下载的压缩包解压&#xff0c;把exe文件放到桌面&#xff0c;双击打开即可。 桌面效果图&#xff1a; 可以打开多个圣诞树&#xff0c;如果关闭圣诞树需要鼠标右键点击exit即可。 代码如下&#xff1a; <!DOCTYPE HEML PUBLIC> <html> <head> <me…

快要圣诞节啦,快去给小伙伴们分享漂亮的圣诞树吧

最近翻到一篇知乎&#xff0c;上面有不少用Python&#xff08;大多是turtle库&#xff09;绘制的树图&#xff0c;感觉很漂亮&#xff0c;我整理了一下&#xff0c;挑了一些我觉得不错的代码分享给大家&#xff08;这些我都测试过&#xff0c;确实可以生成喔~&#xff09; 重中…

python画圣诞树【方块圣诞树、线条圣诞树、豪华圣诞树】

文章目录 前言【便捷源码下载处】1.方块圣诞树2.线条圣诞树3.豪华圣诞树 这篇文章主要介绍了使用Python画了一棵圣诞树的实例代码,本文通过实例代码给大家介绍的非常详细&#xff0c;对大家的学习或工作具有一定的参考借鉴价值&#xff0c;需要的朋友可以参考下 前言【便捷源码…

圣诞节都到了,快使用代码画棵圣诞树吧

&#x1f4d2; 博客首页&#xff1a;✎﹏ℳ๓敬坤的博客 &#x1f388; &#x1f60a; 我只是一个代码的搬运工 &#x1f383; &#x1f389; 欢迎来访的读者关注、点赞和收藏 &#x1f91e; &#x1f609; 有问题可以私信交流 &#x1f606; &#x1f4c3; 文章标题&#xff1…

圣诞树代码 html

新建txt文档 <!DOCTYPE html> <html lang"en" > <head> <meta charset"UTF-8"> <title></title> <link rel"stylesheet" href"https://cdnjs.cloudflare.com/ajax/libs/normalize/5.0.0/normaliz…

圣诞节来了,用Python Turtle画棵圣诞树吧

如何实现上图效果呢&#xff1f;让我们开始吧&#xff01; 首先&#xff0c;导入turtle和random from turtle import * import random as rd 然后&#xff0c;写一个待会儿要用到的函数&#xff0c;用于随机生成True和False def true_or_false(percent50):nrd.randint(1,10…

大家都在画圣诞树,我们用代码敲一颗吧~圣诞树

前段时间发布的文章很多人问怎么操作的&#xff0c;今天具体说明一下&#xff1a; PS&#xff1a;如果需要下载可以点击左下角阅读全文下载代码使用更方便 具体步骤如下&#xff1a; 复制下面代码在电脑里面新建一个记事本&#xff0c;将代码复制到新建的记事本里保存记事本&am…