【Graph Net学习】GNN/GCN代码实战

一、简介

        GNN(Graph Neural Network)和GCN(Graph Convolutional Network)都是基于图结构的神经网络模型。本文目标就是打代码基础,未用PyG,来扒一扒Graph Net两个基础算法的原理。直接上代码。

二、代码

import time
import random
import os
import numpy as np
import math
from torch.nn.parameter import Parameter
from torch.nn.modules.module import Moduleimport torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optimimport scipy.sparse as sp#配置项
class configs():def __init__(self):# Dataself.data_path = r'E:\code\Graph\data'self.save_model_dir = r'\code\Graph'self.model_name = r'GCN' #GNN/GCNself.seed = 2023self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")self.batch_size = 64self.epoch = 200self.in_features = 1433  #core ~ feature:1433self.hidden_features = 16  # 隐层数量self.output_features = 8  # core~paper-point~ 8类self.learning_rate = 0.01self.dropout = 0.5self.istrain = Trueself.istest = Truecfg = configs()def seed_everything(seed=2023):random.seed(seed)os.environ['PYTHONHASHSEED']=str(seed)np.random.seed(seed)torch.manual_seed(seed)seed_everything(seed = cfg.seed)#数据
class Graph_Data_Loader():def __init__(self):self.adj, self.features, self.labels, self.idx_train, self.idx_val, self.idx_test = self.load_data()self.adj = self.adj.to(cfg.device)self.features = self.features.to(cfg.device)self.labels = self.labels.to(cfg.device)self.idx_train = self.idx_train.to(cfg.device)self.idx_val = self.idx_val.to(cfg.device)self.idx_test = self.idx_test.to(cfg.device)def load_data(self,path=cfg.data_path, dataset="cora"):"""Load citation network dataset (cora only for now)"""print('Loading {} dataset...'.format(dataset))idx_features_labels = np.genfromtxt(os.path.join(path,dataset,dataset+'.content'),dtype=np.dtype(str))features = sp.csr_matrix(idx_features_labels[:, 1:-1], dtype=np.float32)labels = self.encode_onehot(idx_features_labels[:, -1])# build graphidx = np.array(idx_features_labels[:, 0], dtype=np.int32)idx_map = {j: i for i, j in enumerate(idx)}edges_unordered = np.genfromtxt(os.path.join(path,dataset,dataset+'.cites'),dtype=np.int32)edges = np.array(list(map(idx_map.get, edges_unordered.flatten())),dtype=np.int32).reshape(edges_unordered.shape)adj = sp.coo_matrix((np.ones(edges.shape[0]), (edges[:, 0], edges[:, 1])),shape=(labels.shape[0], labels.shape[0]),dtype=np.float32)# build symmetric adjacency matrixadj = adj + adj.T.multiply(adj.T > adj) - adj.multiply(adj.T > adj)features = self.normalize(features)adj = self.normalize(adj + sp.eye(adj.shape[0]))idx_train = range(140)idx_val = range(200, 500)idx_test = range(500, 1500)features = torch.FloatTensor(np.array(features.todense()))labels = torch.LongTensor(np.where(labels)[1])adj = self.sparse_mx_to_torch_sparse_tensor(adj)idx_train = torch.LongTensor(idx_train)idx_val = torch.LongTensor(idx_val)idx_test = torch.LongTensor(idx_test)return adj, features, labels, idx_train, idx_val, idx_testdef encode_onehot(self,labels):classes = set(labels)classes_dict = {c: np.identity(len(classes))[i, :] for i, c inenumerate(classes)}labels_onehot = np.array(list(map(classes_dict.get, labels)),dtype=np.int32)return labels_onehotdef normalize(self,mx):"""Row-normalize sparse matrix"""rowsum = np.array(mx.sum(1))r_inv = np.power(rowsum, -1).flatten()r_inv[np.isinf(r_inv)] = 0.r_mat_inv = sp.diags(r_inv)mx = r_mat_inv.dot(mx)return mxdef sparse_mx_to_torch_sparse_tensor(self,sparse_mx):"""Convert a scipy sparse matrix to a torch sparse tensor."""sparse_mx = sparse_mx.tocoo().astype(np.float32)indices = torch.from_numpy(np.vstack((sparse_mx.row, sparse_mx.col)).astype(np.int64))values = torch.from_numpy(sparse_mx.data)shape = torch.Size(sparse_mx.shape)return torch.sparse.FloatTensor(indices, values, shape)#精度评价指标
def accuracy(output, labels):preds = output.max(1)[1].type_as(labels)correct = preds.eq(labels).double()correct = correct.sum()return correct / len(labels)#模型
#01-GNN
class GNNLayer(nn.Module):def __init__(self, in_features, output_features):super(GNNLayer, self).__init__()self.linear = nn.Linear(in_features, output_features)def forward(self, adj_matrix, features):hidden_features = torch.matmul(adj_matrix, features)  # GNN公式:H' = A * Hhidden_features = self.linear(hidden_features)  # 使用线性变换hidden_features = F.relu(hidden_features)  # 使用ReLU作为激活函数return hidden_features
class GNN(nn.Module):def __init__(self, in_features, hidden_features, output_features, num_layers=2):super(GNN, self).__init__()#输入维度in_features、隐藏层维度hidden_features、输出维度output_features、GNN的层数num_layersself.layers = nn.ModuleList([GNNLayer(in_features, hidden_features) if i == 0 else GNNLayer(hidden_features, hidden_features) for i inrange(num_layers)])self.output_layer = nn.Linear(hidden_features, output_features)def forward(self, adj_matrix, features):hidden_features = featuresfor layer in self.layers:hidden_features = layer(adj_matrix, hidden_features)output = self.output_layer(hidden_features)return F.log_softmax(output,dim=1)#02-GCN
class GraphConvolution(Module):"""Simple GCN layer, similar to https://arxiv.org/abs/1609.02907"""def __init__(self, in_features, out_features, bias=True):super(GraphConvolution, self).__init__()self.in_features = in_featuresself.out_features = out_featuresself.weight = Parameter(torch.FloatTensor(in_features, out_features))if bias:self.bias = Parameter(torch.FloatTensor(out_features))else:self.register_parameter('bias', None)self.reset_parameters()def reset_parameters(self):stdv = 1. / math.sqrt(self.weight.size(1))self.weight.data.uniform_(-stdv, stdv)if self.bias is not None:self.bias.data.uniform_(-stdv, stdv)def forward(self, input, adj):support = torch.mm(input, self.weight)output = torch.spmm(adj, support)if self.bias is not None:return output + self.biaselse:return outputdef __repr__(self):return self.__class__.__name__ + ' (' \+ str(self.in_features) + ' -> ' \+ str(self.out_features) + ')'class GCN(nn.Module):def __init__(self, in_features, hidden_features, output_features, dropout=cfg.dropout):super(GCN, self).__init__()self.gc1 = GraphConvolution(in_features, hidden_features)self.gc2 = GraphConvolution(hidden_features, output_features)self.dropout = dropoutdef forward(self, adj_matrix, features):x = F.relu(self.gc1(features, adj_matrix))x = F.dropout(x, self.dropout, training=self.training)x = self.gc2(x, adj_matrix)return F.log_softmax(x, dim=1)class graph_run():def train(self):t = time.time()#Create Train Processingall_data = Graph_Data_Loader()#创建一个模型model = eval(cfg.model_name)(in_features=cfg.in_features,hidden_features=cfg.hidden_features,output_features=cfg.output_features).to(cfg.device)optimizer = optim.Adam(model.parameters(),lr=cfg.learning_rate, weight_decay=5e-4)#Trainmodel.train()for epoch in range(cfg.epoch):optimizer.zero_grad()output = model(all_data.adj, all_data.features)loss_train = F.nll_loss(output[all_data.idx_train], all_data.labels[all_data.idx_train])acc_train = accuracy(output[all_data.idx_train], all_data.labels[all_data.idx_train])loss_train.backward()optimizer.step()loss_val = F.nll_loss(output[all_data.idx_val], all_data.labels[all_data.idx_val])acc_val = accuracy(output[all_data.idx_val], all_data.labels[all_data.idx_val])print('Epoch: {:04d}'.format(epoch + 1),'loss_train: {:.4f}'.format(loss_train.item()),'acc_train: {:.4f}'.format(acc_train.item()),'loss_val: {:.4f}'.format(loss_val.item()),'acc_val: {:.4f}'.format(acc_val.item()),'time: {:.4f}s'.format(time.time() - t))torch.save(model, os.path.join(cfg.save_model_dir, 'latest.pth'))  # 模型保存def infer(self):#Create Test Processingall_data = Graph_Data_Loader()model_path = os.path.join(cfg.save_model_dir, 'latest.pth')model = torch.load(model_path, map_location=torch.device(cfg.device))model.eval()output = model(all_data.adj,all_data.features)loss_test = F.nll_loss(output[all_data.idx_test], all_data.labels[all_data.idx_test])acc_test = accuracy(output[all_data.idx_test], all_data.labels[all_data.idx_test])print("Test set results:","loss= {:.4f}".format(loss_test.item()),"accuracy= {:.4f}".format(acc_test.item()))if __name__ == '__main__':mygraph = graph_run()if cfg.istrain == True:mygraph.train()if cfg.istest == True:mygraph.infer()

三、结果与讨论

        需要从网上下载cora数据集,数据组织形式如下图。

        测了下Params和GFLOPs,还是比较大的,发现若作为一个Net的Block还是需要优化的哈哈~

ModelParamsGFLOPs
GNN23.352K126.258M
ModelCora(/train/val/test)
GNN1.0000/0.7800/0.7620
GCN0.9714/0.7767/0.8290

四、展望

        未来可以考虑用PyG(PyTorch Geometric),毕竟PyG实现GAT等图网络、图的数据组织、加载会更加方便。Graph Net通常用可以用于属性数据的embedding模式,将属性数据可以作为一种补充特征加入Net去训练,看能不能发挥效能。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.rhkb.cn/news/135845.html

如若内容造成侵权/违法违规/事实不符,请联系长河编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

无涯教程-JavaScript - MDETERM函数

描述 MDETERM函数返回数组的矩阵行列式。 语法 MDETERM (array)争论 Argument描述Required/OptionalArrayA numeric array with an equal number of rows and columns.Required Notes 数组可以作为单元格范围给出,如A1:C3;作为数组常量,如{1,2,3; 4,5,6; 7,8,9}&#xff1…

【刷题】蓝桥杯

蓝桥杯2023年第十四届省赛真题-平方差 - C语言网 (dotcpp.com) 初步想法,x y2 − z2(yz)(y-z) 即xa*b,ayz,by-z 2yab 即ab是2的倍数就好了。 即x存在两个因数之和为偶数就能满足条件。 但时间是(r-l)*x&am…

服务网格和微服务架构的关系:理解服务网格在微服务架构中的角色和作用

🌷🍁 博主猫头虎 带您 Go to New World.✨🍁 🦄 博客首页——猫头虎的博客🎐 🐳《面试题大全专栏》 文章图文并茂🦕生动形象🦖简单易学!欢迎大家来踩踩~🌺 &a…

郑州大学图书馆许少辉《乡村振兴战略下传统村落文化旅游设计》中文文献——2023学生开学季辉少许

郑州大学图书馆许少辉《乡村振兴战略下传统村落文化旅游设计》中文文献——2023学生开学季辉少许

六、串口通信

六、串口通信 串口接口介绍使用串口向电脑发送数据电脑发送数据控制LED灯 串口接口介绍 SBUF是串口数据缓存器,物理上是两个独立的寄存器,但占用相同的地址。写操作时,写入的是发送寄存器;读操作时,读出的是接收寄存器…

【uniapp】Dcloud的uni手机号一键登录,具体实现及踩过的坑,调用uniCloud.getPhoneNumber(),uni.login()等

一键登录Dcloud官网请戳这里,感兴趣的可以看看官网,有很详细的示例,选择App一键登录,可以看到一些常用的概述 比如: 1、调用uni.login就能弹出一键登录的页面 2、一键登录的流程,可以选择先预登录uni.prelo…

数据库----数据查询

1.6 查询语句 语法:select [选项] 列名 [from 表名] [where 条件] [group by 分组] [order by 排序][having 条件] [limit 限制]1.6.1 字段表达式 mysql> select 锄禾日当午; ------------ | 锄禾日当午 | ------------ | 锄禾日当午 | ---…

C++---多态

多态 前言多态的概念多态的定义及实现多态的构成条件虚函数虚函数的重写虚函数重写的两个例外协变(基类与派生类虚函数返回值类型不同)析构函数的重写 override和final 虚函数的默认参数 抽象基类 前言 在买火车票的时候,如果你是学生,是买半价票&#…

微服务保护-Sentinel

初识Sentinel 雪崩问题及解决方案 雪崩问题 微服务中,服务间调用关系错综复杂,一个微服务往往依赖于多个其它微服务。 如图,如果服务提供者I发生了故障,当前的应用的部分业务因为依赖于服务I,因此也会被阻塞。此时&a…

基于SpringBoot的旅游系统

基于SpringBootVue的旅游系统、前后端分离 开发语言:Java数据库:MySQL技术:SpringBoot、Vue、Mybaits Plus、ELementUI工具:IDEA/Ecilpse、Navicat、Maven 【主要功能】 角色:管理员、用户 用户:浏览旅游…

Rockchip RK3399 - USB触摸屏接口驱动

---------------------------------------------------------------------------------------------------------------------------- 开发板 :NanoPC-T4开发板eMMC :16GBLPDDR3 :4GB 显示屏 :15.6英寸HDMI接口显示屏u-boot &…

三维模型3DTile格式轻量化压缩文件大小的技术方法研究

三维模型3DTile格式轻量化压缩文件大小的技术方法研究 倾斜摄影三维模型,由于数据量大、复杂度高,轻量化压缩成为其在网络传输和实时渲染中必不可少的环节。以下是几种常用的3DTile格式轻量化压缩技术方法: 几何简化:这是一种最…

K8s(Kubernetes)学习(五)——Service:ClusterIP、NodePort、LoadBalancer、 ExternalName

第五章 Service 什么是 Service为什么需要 ServiceService 特性Service 与 Pod 关联Service type 类型如何使用 Service多端口配置 1 什么是 Service 1.1 定义 官网地址: https://kubernetes.io/zh-cn/docs/concepts/services-networking/service/ 将运行在一个或一组 Pod…

uniapp视频播放功能

UniApp提供了多种视频播放组件,包括视频播放器(video)、多媒体组件(media)、WebView(内置Video标签)等。其中,video和media组件是最常用的。 video组件 video组件是基于HTML5 vide…

iOS-砸壳篇(两种砸壳方式)

CrackerXI砸壳呢,当时你要是使用 frida-ios-dump 也是可以的; https://github.com/AloneMonkey/frida-ios-dump frida-ios-dump: 代码中需要更改的:手机中的内网ip 密码 等 最后放到我的砸壳路径里: python dump.py -l查看应用…

libevent数据结构——TAILQ_结构体

TAILQ_结构体 TAILQ_结构体在文件event2/event_struct.h和文件event2/keyvalq_struct.h中都有定义,并且他们的定义都是一样的,定义了TAILQ_ENTRY、TAILQ_HEAD结构体: #ifndef TAILQ_ENTRY #define EVENT_DEFINED_TQENTRY_ #define TAILQ_EN…

matlab根轨迹绘制

绘制根轨迹目的就是改变系统的闭环极点,使得系统由不稳定变为稳定或者使得稳定的系统变得更加稳定。 在使用PID控制器的时候,首先要确定的参数是Kp,画成框图的形式如下: 也就是想要知道Kp对系统性能有哪些影响,此时就…

展会预告 | 图扑邀您共聚 IOTE 国际物联网展·深圳站

参展时间:9 月 20 日- 22 日 图扑展位:9 号馆 9B 35-1 参展地址:深圳国际会展中心(宝安新馆) IOTE 2023 第二十届国际物联网展深圳站,将于 9 月 20 日- 22 日在深圳国际会展中心(宝安&#xff0…

【面试经典150 | 双指针】判断子序列

文章目录 写在前面Tag题目来源题目解题解题思路方法一:双指针方法二:动态规划 写在最后 写在前面 本专栏专注于分析与讲解【面试经典150】算法,两到三天更新一篇文章,欢迎催更…… 专栏内容以分析题目为主,并附带一些对…

语言建模的发展阶段以及大规模语言模型的背景介绍

语言本质上是一个由语法规则控制的复杂、精密的人类表达系统,开发能够理解和掌握语言的AI 算法是一个重大挑战。作为一种主要方法,语言建模在过去两十年中已被广泛研究,从统计语言模型发展到神经语言模型,用于语言理解和生成。从技…