图神经网络GNN(一)GraphEmbedding

DeepWalk


使用随机游走采样得到每个结点x的上下文信息,记作Context(x)。
SkipGram优化的目标函数:P(Context(x)|x;θ)
θ = argmax P(Context(x)|x;θ)
DeepWalk这种GraphEmbedding方法是一种无监督方法,个人理解有点类似生成模型的Encoder过程,下面的代码中,node_proj是一个简单的线性映射函数,加上elu激活函数,可以看作Encoder的过程。Encoder结束后就得到了Embedding后的隐变量表示。其实GraphEmbedding要的就是这个node_proj,但是由于没有标签,只有训练数据的内部特征,怎么去训练呢?这就需要看我们的训练任务了,个人理解,也就是说,这种无监督的embedding后的结果取决于你的训练任务,也就是Decoder过程。Embedding后的编码对Decoder过程越有利,损失函数也就越小,编码做的也就越好。在word2vec中,有两种训练任务,一种是给定当前词,预测其前两个及后两个词发生的条件概率,采用这种训练任务做出的embedding就是skip-gram;还有一种是给定当前词前两个及后两个词,预测当前词出现的条件概率,采用这种训练任务做出的embedding就是CBOW.DeepWalk作者的论文中采用的是skip-gram。故复现也采用skip-gram进行复现。
针对skip-gram对应的训练任务,代码中的node_proj相当于编码器,h_o_1和h_o_2相当于解码器。Encoder和Decoder可以先联合训练,训练结束后,可以只保留Encoder的部分,舍弃Decoder的部分。当再来一个独热编码的时候,可以直接通过node_proj映射,即完成了独热编码的embedding过程。
(本代码假定在当前结点去往各邻接结点的可能性相同,即不考虑边的权重)

import pandas as pd
import torch
import torch.nn as nn
import numpy as np
import random
import torch.nn.functional as F
import networkx as nx
from torch.nn import CrossEntropyLoss
from torch.optim.lr_scheduler import CosineAnnealingLR
from torch.distributions import Categorical
import matplotlib.pyplot as pltclass MyGraph():def __init__(self,device):super(MyGraph, self).__init__()self.G = nx.read_edgelist(path='data/wiki/Wiki_edgelist.txt',create_using=nx.DiGraph(),nodetype=None,data=[('weight',int)])self.adj_matrix = nx.attr_matrix(self.G)self.edges = nx.edges(self.G)self.edges_emb = torch.eye(len(self.G.edges)).to(device)self.nodes_emb = torch.eye(len(self.G.nodes)).to(device)class GraphEmbedding(nn.Module):def __init__(self,nodes_num,edges_num,device,emb_dim = 10):super(GraphEmbedding, self).__init__()self.device = deviceself.nodes_proj = nn.Parameter(torch.randn(nodes_num,emb_dim))self.edges_proj = nn.Parameter(torch.randn(edges_num,emb_dim))self.h_o_1 = nn.Parameter(torch.randn(emb_dim,nodes_num * 2))self.h_o_2 = nn.Parameter(torch.randn(nodes_num * 2,nodes_num))def forward(self,G:MyGraph):self.nodes_proj,self.edges_proj = self.nodes_proj.to(self.device),self.edges_proj.to(device)self.h_o_1,self.h_o_2 = self.h_o_1.to(self.device),self.h_o_2.to(self.device)# Encoderedges_emb,nodes_emb = torch.matmul(G.edges_emb,self.edges_proj),torch.matmul(G.nodes_emb,self.nodes_proj)nodes_emb = F.elu_(nodes_emb)edges_emb,nodes_emb = edges_emb.to(device),nodes_emb.to(device)# Decoderpolicy = self.DeepWalk(G,gamma=5,window=2)outputs = torch.matmul(torch.matmul(nodes_emb[policy[:,0]],self.h_o_1),self.h_o_2)policy,outputs = policy.to(device),outputs.to(device)return policy,outputsdef DeepWalk(self,Graph:MyGraph,gamma:int,window:int,eps=1e-9):# Calculate transpose matrixadj_matrix = torch.tensor(Graph.adj_matrix[0], dtype=torch.float32)for i in range(adj_matrix.shape[0]):adj_matrix[i,:] /= (torch.sum(adj_matrix[i]) + eps)adj_nodes = Graph.adj_matrix[1].copy()random.shuffle(adj_nodes)nodes_idx, route_result = [],[]for node in adj_nodes:node_idx = np.where(np.array(Graph.adj_matrix[1]) == node)[0].item()node_list = self.Random_Walk(adj_matrix,window=window,node_idx=node_idx)route_result.append(node_list)return torch.tensor(route_result)def Random_Walk(self,adj_matrix:torch.Tensor,window:int,node_idx:int):node_list = [node_idx]for i in range(window):pi = self.HMM_process(adj_matrix,node_idx)if torch.sum(pi) == 0:pi += 1 / pi.shape[0]node_idx = Categorical(pi).sample().item()node_list.append(node_idx)return node_listdef HMM_process(self,adj_matrix:torch.Tensor,node_idx:int,eps=1e-9):pi = torch.zeros((1, adj_matrix.shape[0]), dtype=torch.float32)pi[:,node_idx] = 1.0pi = torch.matmul(pi,adj_matrix)pi = pi.squeeze(0) / (torch.sum(pi) + eps)return piif __name__ == "__main__":epochs = 200device = torch.device("cuda:1")cross_entrophy_loss = CrossEntropyLoss().to(device)Graph = MyGraph(device)Embedding = GraphEmbedding(nodes_num=len(Graph.G.nodes), edges_num=len(Graph.G.edges),device=device).to(device)optimizer = torch.optim.Adam(Embedding.parameters(),lr=1e-5)scheduler=CosineAnnealingLR(optimizer,T_max=50,eta_min=0.05)loss_list = []epoch_list = [i for i in range(1,epochs+1)]for epoch in range(epochs):policy,outputs = Embedding(Graph)outputs = outputs.unsqueeze(1).repeat(1,policy.shape[-1]-1,1).reshape(-1,outputs.shape[-1])optimizer.zero_grad()loss = cross_entrophy_loss(outputs, policy[:,1:].reshape(-1))loss.backward()optimizer.step()scheduler.step()loss_list.append(loss.item())print(f"Loss : {loss.item()}")plt.plot(epoch_list,loss_list)plt.xlabel('Epoch')plt.ylabel('CrossEntrophyLoss')plt.title('Loss-Epoch curve')plt.show()

在这里插入图片描述

Node2Vec

在这里插入图片描述
在这里插入图片描述
修改Random_Walk函数如下:

    def Random_Walk(self,adj_matrix:torch.Tensor,window:int,node_idx:int):node_list = [node_idx]for i in range(window):pi = self.HMM_process(adj_matrix,node_idx)if torch.sum(pi) == 0:pi += 1 / pi.shape[0]if i > 0:v,t = node_list[-1],node_list[-2]x_list = torch.nonzero(adj_matrix[v]).squeeze(-1)for x in x_list:if t == x:  # 0pi[x] *= 1/self.pelif adj_matrix[t][x] == 1:  # 1pi[x] *= 1else:   # 2pi[x] *= 1/self.qnode_idx = Categorical(pi).sample().item()node_list.append(node_idx)return node_list

结果如下,这里令p=2,q=3,即1/p=0.5,1/q=0.33,会相对保守周围。结果似乎好了那么一点点。
在这里插入图片描述

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.rhkb.cn/news/149285.html

如若内容造成侵权/违法违规/事实不符,请联系长河编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

图像和视频上传平台Share Me

本文完成于 6 月,所以反代中,域名演示还是使用的 laosu.ml,不过版本并没有什么变化; 什么是 Share Me ? Share Me 是使用 Next.js 和 PocketBase 的自托管图像和视频上传平台,具有丰富的嵌入支持和 API&…

【全3D打印坦克——基于Arduino履带式机器人】

【全3D打印坦克——基于Arduino履带式机器人】 1. 概述2. 设计机器人平台3. 3D 模型和 STL 下载文件3.1 3D打印3.2 组装 3D 打印坦克 – 履带式机器人平台3.3 零件清单 4. 机器人平台电路图4.1 定制电路板设计4.2 完成 3D 打印储罐组件 5. 机器人平台编程6. 测试3D打印机器人 -…

小谈设计模式(20)—组合模式

小谈设计模式(20)—组合模式 专栏介绍专栏地址专栏介绍 组合模式对象类型叶节点组合节点 核心思想应用场景123 结构图结构图分析 Java语言实现首先,我们需要定义一个抽象的组件类 Component,它包含了组合节点和叶节点的公共操作&a…

在2023年使用Unity2021从Built-in升级到Urp可行么

因为最近在做WEbgl平台,所以某些不可抗力原因,需要使用Unity2021开发,又由于不可明说原因,想用Urp,怎么办? 目录 创建RenderAsset 关联Asset 暴力转换(Menu->Edit) 单个文件…

editplus如何批量删除包含某个字符串的行

在EditPlus中批量删除包含某个字符串的行的步骤如下: 打开EditPlus并打开您想要编辑的文件。 按下 Ctrl H 打开查找/替换对话框。 在 “Find what” 框中,输入您想要删除的字符串的正则表达式。例如,如果您想要删除包含 “testtest” 的行…

企业征信牌照9月末盘点:149家机构荣获上榜,西藏等地机构待批

孟凡富 笔者根据7年帮助20多家企业征信机构备案的经验,以及对于征信政策和知识的深入了解,整理了这篇文章。 2013年1月21日,国务院颁布了《征信业管理条例》(国务院令第631号),自2013年3月15日起开始实施。…

【C语言】字符函数和字符串函数(1)

#国庆发生的那些事儿# 大家好,我是苏貝,本篇博客带大家了解字符函数和字符串函数,如果你觉得我写的还不错的话,可以给我一个赞👍吗,感谢❤️ 目录 1.本章重点2. strlen2.1函数介绍2.2 模拟实现 3. strcpy3…

节日灯饰灯串灯出口欧洲CE认证办理

灯串(灯带),这个产品的形状就象一根带子一样,再加上产品的主要原件就是LED,因此叫做灯串或者灯带。2022年,我国灯具及相关配件产品出口总额超过460亿美元。其中北美是最大的出口市场。其次是欧洲市场&#…

Firefly-LLaMA2-Chinese - 开源中文LLaMA2大模型

文章目录 关于模型列表 & 数据列表训练细节增量预训练 & 指令微调数据格式 & 数据处理逻辑增量预训练指令微调模型推理权重合并模型推理部署关于 github : https://github.com/yangjianxin1/Firefly-LLaMA2-Chinese本项目与Firefly一脉相承,专注于低资源增量预训练…

RDP协议流程详解(二)Basic Settings Exchange 阶段

RDP连接建立过程,在Connection Initiation后,RDP客户端和服务端将进行双方基础配置信息交换,也就是basic settings exchange阶段。在此阶段,将包含两条消息Client MCS Connect Initial PDU和Server MCS Connect Response PDU&…

mysql-sql执行流程

sql执行流程 MYSQL 中的执行流程 MYSQL 中的执行流程 sql 执行流程如下图

网络爬虫中的代理技术:socks5代理和HTTP代理

网络爬虫是一种非常重要的数据采集工具,但是在进行网络爬虫时,我们经常会遇到一些限制,比如IP封锁、反爬虫机制等,这些限制会影响我们的数据采集效果。为了解决这些问题,我们可以使用代理服务器,其中socks5…

travel总结:

目录 1、前期准备: 2、项目期间: (1)注册功能的实现: 1、前端: 1、表单数据的校验:(js) 2、使用ajax完成表单提交 3、注册成功跳转页面 2、web: 1、获取表单数据、封装数据 2、调…

字符串函数的模拟实现

引言:对于字符串来说,我们通常想要对其完成各种各样的目的,不管是排序还是查找都是最普遍的功能,而我们的C语言中也包含着一系列函数是为了实现对字符串的一些功能,今天我们就来介绍他们。 strlen函数: 求字…

正则表达式 Regular Expression学习

该文章内容为以下视频的学习笔记: 10分钟快速掌握正则表达式_哔哩哔哩_bilibili正则表达式在线测试工具:https://regex101.com/, 视频播放量 441829、弹幕量 1076、点赞数 19330、投硬币枚数 13662、收藏人数 26242、转发人数 2768, 视频作者 奇乐编程学…

【visual studio 小技巧】项目属性->生成->事件

需求 我们有时会用到一些dll,需要把这些dll和我们生成的exe放到一起,一般我们是手动自己copy, 这样发布的时候,有时会忘记拷贝这个dll,导致程序运行出错。学会这个小技巧,就能实现自动copy,非…

AWS Lambda Golang HelloWorld 快速入门

操作步骤 以下测试基于 WSL2 Ubuntu 22.04 环境 # 下载最新 golang wget https://golang.google.cn/dl/go1.21.1.linux-amd64.tar.gz# 解压 tar -C ~/.local/ -xzf go1.21.1.linux-amd64.tar.gz# 配置环境变量 PATH echo export PATH$PATH:~/.local/go/bin >> ~/.bashrc …

Ubuntu20配置Mysql常用操作

文章目录 版权声明ubuntu更换软件源Ubuntu设置静态ipUbuntu防火墙ubuntu安装ssh服务Ubuntu安装vmtoolsUbuntu安装mysql5.7Ubuntu安装mysql8.0Ubuntu卸载mysql 版权声明 本博客的内容基于我个人学习黑马程序员课程的学习笔记整理而成。我特此声明,所有版权属于黑马程…

STM32复习笔记(四):看门狗

目录 (一)简介 (二)IWDG IWDG的CUBEMX工程配置 IWDG相关函数(非常少,所以直接贴上来): (三)WWDG (一)简介 看门狗分为独立看门…

【图像处理GIU】图像分割(Matlab代码实现)

💥💥💞💞欢迎来到本博客❤️❤️💥💥 🏆博主优势:🌞🌞🌞博客内容尽量做到思维缜密,逻辑清晰,为了方便读者。 ⛳️座右铭&a…