【Graph Net学习】LINE实现Graph Embedding

一、简介

        LINE (Large-scale Information Network Embedding,2015) 是一种设计用于处理大规模信息网络的算法。它主要的目标是在给定的大规模信息网络中学习高质量的节点嵌入,并尽量保留网络中信息的丰富性。其具体的表现为在一个低 维空间里以向量形式表示网络中的节点,以便后续的机器学习任务可以更好地理解。

LINE算法根据两种相互关联的线性化策略去处理信息图,分别考虑了图节点的一阶邻居和二阶邻居。通过这种方式,LINE既能反映出网络的全局属性又能反映出网络的局部属性。

        调用算法流程如下:

  1. 首先,为图中的每个节点初始化一个随机向量。

  2. 接着,使用一阶邻居的优化原型函数进行训练。在一阶近邻策略中,若两个节点存在直接连接,则他们的向量应该尽可能相近。

  3. 然后,使用二阶邻居的优化原型函数进行训练。在二阶近邻策略中,考虑两节点间的间接联系。例如,若两节点存在共享的邻居,即使他们之间没有直接的联系,他们的向量也应该相近。

  4. 对每个节点,计算其在一阶和二阶优化下的损失函数值,并对其进行优化。

  5. 优化完成后,此时每个节点上的向量就是最终的嵌入表示。

  6. 基于得到的嵌入表示进行后续的分析或机器学习任务。

        接下来就是快乐的代码时间嘿嘿嘿

二、代码

import os
import pandas as pd
import numpy as np
import networkx as nx
import time
import scipy.sparse as sp
from torch_geometric.data import Data
from torch_geometric.transforms import ToSparseTensor
import torch_geometric.utils
from sklearn.preprocessing import LabelEncoderimport torch
import torch.nn as nn#配置项
class configs():def __init__(self):# Dataself.data_path = r'./data'self.save_model_dir = r'./'self.num_nodes = 2708self.embedding_dim = 128self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")self.learning_rate = 0.01self.epoch = 30self.criterion = nn.BCEWithLogitsLoss()self.istrain = Trueself.istest = Truecfg = configs()def load_cora_data(data_path = './data/cora'):content_df = pd.read_csv(os.path.join(data_path,"cora.content"), delimiter="\t", header=None)content_df.set_index(0, inplace=True)index = content_df.index.tolist()features = sp.csr_matrix(content_df.values[:,:-1], dtype=np.float32)# 处理标签labels = content_df.values[:,-1]class_encoder = LabelEncoder()labels = class_encoder.fit_transform(labels)# 读取引用关系cites_df = pd.read_csv(os.path.join(data_path,"cora.cites"), delimiter="\t", header=None)cites_df[0] = cites_df[0].astype(str)cites_df[1] = cites_df[1].astype(str)cites = [tuple(x) for x in cites_df.values]edges = [(index.index(int(cite[0])), index.index(int(cite[1]))) for cite in cites]edges = np.array(edges).T# 构造Data对象data = Data(x=torch.from_numpy(np.array(features.todense())),edge_index=torch.LongTensor(edges),y=torch.from_numpy(labels))idx_train = range(140)idx_val = range(200, 500)idx_test = range(500, 1500)# 读取Cora数据集 return geometric Data格式def index_to_mask(index, size):mask = np.zeros(size, dtype=bool)mask[index] = Truereturn maskdata.train_mask = index_to_mask(idx_train, size=labels.shape[0])data.val_mask = index_to_mask(idx_val, size=labels.shape[0])data.test_mask = index_to_mask(idx_test, size=labels.shape[0])def to_networkx(data):edge_index = data.edge_index.to(torch.device('cpu')).numpy()G = nx.DiGraph()for src, tar in edge_index.T:G.add_edge(src, tar)return Gnetworkx_data = to_networkx(data)return data,networkx_data
#获取数据:pyg_data:torch_geometric格式;networkx_data:networkx格式def generate_pairs(adj_matrix):# 根据邻接矩阵生成正例和负例pos_pairs = torch.nonzero(adj_matrix, as_tuple=True)pos_u = pos_pairs[0]pos_v = pos_pairs[1]mask = torch.ones_like(adj_matrix)for i in range(len(pos_u)):mask[pos_u[i]][pos_v[i]] = 0mask[pos_v[i]][pos_u[i]] = 0tmp = torch.nonzero(mask, as_tuple=True)#TODO 随机选取负例idx = torch.randperm(tmp[0].size(0))neg_u = tmp[0][idx][:pos_u.size(0)]neg_v = tmp[1][idx][:pos_v.size(0)]return pos_u, pos_v, neg_u, neg_v# 构建LINE网络
class LINE(nn.Module):def __init__(self, num_nodes, embed_dim):super(LINE, self).__init__()#num_nodes为Node个数 , embed_dim为描述Node的Embedding维度self.embed_dim = embed_dimself.num_nodes = num_nodesself.embeddings = nn.Embedding(self.num_nodes, self.embed_dim)self.reset_parameters()def reset_parameters(self):self.embeddings.weight.data.normal_(std=1 / self.embed_dim)def forward(self, pos_u, pos_v, neg_v):emb_pos_u = self.embeddings(pos_u)emb_pos_v = self.embeddings(pos_v)emb_neg_v = self.embeddings(neg_v)pos_scores = torch.sum(torch.mul(emb_pos_u, emb_pos_v), dim=1)neg_scores = torch.sum(torch.mul(emb_pos_u, emb_neg_v), dim=1)return pos_scores, neg_scoresclass LINE_run():def train(self):t = time.time()# 创建一个模型_, networkx_data = load_cora_data()adj_matrix = torch.tensor(nx.adjacency_matrix(networkx_data).toarray(), dtype=torch.float32)model = LINE(num_nodes=cfg.num_nodes, embed_dim=cfg.embedding_dim).to(cfg.device)optimizer = torch.optim.Adam(model.parameters(), lr=cfg.learning_rate)#Trainmodel.train()for epoch in range(cfg.epoch):optimizer.zero_grad()pos_u, pos_v, neg_u, neg_v = generate_pairs(adj_matrix)pos_u = pos_u.to(cfg.device)pos_v = pos_v.to(cfg.device)neg_v = neg_v.to(cfg.device)pos_scores, neg_scores = model(pos_u, pos_v, neg_v)pos_losses = cfg.criterion(pos_scores, torch.ones(len(pos_scores)).to(cfg.device))neg_losses = cfg.criterion(neg_scores, torch.zeros(len(neg_scores)).to(cfg.device))loss = pos_losses + neg_lossesloss.backward()optimizer.step()print('Epoch: {:04d}'.format(epoch + 1),'loss_train: {:.4f}'.format(loss.item()),'time: {:.4f}s'.format(time.time() - t))torch.save(model, os.path.join(cfg.save_model_dir, 'latest.pth'))  # 模型保存print('Embedding dim : ({},{})'.format(model.embeddings.weight.shape[0],model.embeddings.weight.shape[1]))def infer(self):# Create Test Processing_, networkx_data = load_cora_data()adj_matrix = torch.tensor(nx.adjacency_matrix(networkx_data).toarray(), dtype=torch.float32)model_path = os.path.join(cfg.save_model_dir, 'latest.pth')model = torch.load(model_path, map_location=torch.device(cfg.device))model.eval()_, networkx_data = load_cora_data()pos_u, pos_v, neg_u, neg_v = generate_pairs(adj_matrix)pos_u = pos_u.to(cfg.device)pos_v = pos_v.to(cfg.device)neg_v = neg_v.to(cfg.device)pos_scores, neg_scores = model(pos_u, pos_v, neg_v)pos_losses = cfg.criterion(pos_scores, torch.ones(len(pos_scores)).to(cfg.device))neg_losses = cfg.criterion(neg_scores, torch.zeros(len(neg_scores)).to(cfg.device))loss = pos_losses + neg_lossesprint("Test set results:","loss= {:.4f}".format(loss.item()),'Embedding dim : ({},{})'.format(model.embeddings.weight.shape[0], model.embeddings.weight.shape[1]))if __name__ == '__main__':mygraph = LINE_run()if cfg.istrain == True:mygraph.train()if cfg.istest == True:mygraph.infer()

三、输出结果

        跑的是Cora数据,共2708个Node,设置的Embedding维度是128维。上面代码运行完就是长下面这个样子。

Epoch: 0001 loss_train: 1.3863 time: 3.0867s
Epoch: 0002 loss_train: 1.3832 time: 3.7739s
Epoch: 0003 loss_train: 1.3768 time: 4.4471s
...
Epoch: 0028 loss_train: 0.7739 time: 21.3568s
Epoch: 0029 loss_train: 0.7694 time: 22.0310s
Epoch: 0030 loss_train: 0.7663 time: 22.7042s
Embedding dim : (2708,128)
Test set results: loss= 0.7609 Embedding dim : (2708,128)

        效果未知,没有用下游聚类测一下,反正看起来BCE loss是降了哈哈,这期就到这里。
        

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.rhkb.cn/news/138319.html

如若内容造成侵权/违法违规/事实不符,请联系长河编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

华为云云耀云服务器L实例评测|华为云上安装监控服务Prometheus三件套安装

文章目录 华为云云耀云服务器L实例评测|华为云上试用监控服务Prometheus一、监控服务Prometheus三件套介绍二、华为云主机准备三、Prometheus安装四、Grafana安装五、alertmanager安装六、三个服务的启停管理1. Prometheus、Alertmanager 和 Grafana 启动顺序2. 使用…

mysql workbench常用操作

1、No database selected Select the default DB to be used by double-clicking its name in the SCHEMAS list in the sidebar 方法一:双击你要使用的库 方法二:USE 数据库名 2、复制表名,字段名 3、保存链接

折线图geom_line()参数选项

往期折线图教程 图形复现| 使用R语言绘制折线图折线图指定位置标记折线图形状更改 | 绘制动态折线图跟着NC学作图 | 使用python绘制折线图 前言 我们折线的专栏推出一段时间,但是由于个人的原因,一直未进行更新。那么今天,我们也参考《R语…

Kafka-UI

有多款kafka管理应用,目前选择的是github上star最多的UI for Apache Kafka。 关于 To run UI for Apache Kafka, you can use either a pre-built Docker image or build it (or a jar file) yourself. UI for Apache Kafka is a versatile, fast, and lightweight…

关于ABB机器人的IO创建和设置

首先要链接网线,请求写权限 关于ABB机器人的默认地址位10有的是63.31看你硬接线 ABB机器人分配好信号机器人控制柜要重启 可以把机器人分配的信号导出为EIO 类似与发那科机器人IO

Spring底层原理之 BeanFactory 与 ApplicationContext

🐌个人主页: 🐌 叶落闲庭 💨我的专栏:💨 c语言 数据结构 javaEE 操作系统 Redis 石可破也,而不可夺坚;丹可磨也,而不可夺赤。 Spring底层原理 一、 BeanFactory 与 Appli…

ChatGPT追祖寻宗:GPT-3技术报告要点解读

论文地址:https://arxiv.org/abs/2005.14165 往期相关文章: ChatGPT追祖寻宗:GPT-1论文要点解读_五点钟科技的博客-CSDN博客ChatGPT追祖寻宗:GPT-2论文要点解读_五点钟科技的博客-CSDN博客 本文的标题之所以取名技术报告而不是论文…

无法删除目录(设备或资源忙)

文章目录 无法删除目录(设备或资源忙)问题原因解决方案步骤一:首先找到挂载的位置步骤二:取消挂载步骤三:查看挂在情况 无法删除目录(设备或资源忙) 问题 原因 网络共享挂载导致无法删除 解决…

C++实现观察者模式(包含源码)

文章目录 观察者模式一、基本概念二、实现方式三、角色四、过程五、结构图六、构建思路七、完整代码 观察者模式 一、基本概念 观察者模式(又被称为模型(Model)-视图(View)模式)是软件设计模式的一种。在…

Twitter图片数据优化的细节

Twitter个人数据优化:吸引更多关注和互动 头像照片在Twitter上,头像照片是最快识别一个账号的方法之一。因此,请务必使用公司的标志或与品牌相关的图片。建议尺寸为400x400像素。 为了建立强大的品牌形象和一致性,强烈建议在所有…

Spring MVC常见面试题

Spring MVC简介 Spring MVC框架是以请求为驱动,围绕Servlet设计,将请求发给控制器,然后通过模型对象,分派器来展示请求结果视图。简单来说,Spring MVC整合了前端请求的处理及响应。 Servlet 是运行在 Web 服务器或应用…

Xcode 15 运行<iOS 14, 启动崩溃问题

如题. Xcode 15 启动 < iOS 14(没具体验证过, 我的问题设备是iOS 13.7)真机设备 出现启动崩溃 解决方案: Build Settings -> Other Linker Flags -> Add -> -ld64

第八天:gec6818arm开发板和Ubuntu中安装并且编译移植mysql驱动连接QT执行程序

一、Ubuntu18.04中安装并且编译移植mysql驱动程序连接qt执行程序 1 、安装Mysql sudo apt-get install mysql-serverapt-get isntall mysql-clientsudo apt-get install libmysqlclient-d2、查看是否安装成功&#xff0c;即查看MySQL版本 mysql --version 3、MySQL启动…

STP生成树协议基本配置示例---STP逻辑树产生和修改

STP是用来避免数据链路层出现逻辑环路的协议&#xff0c;运行STP协议的设备通过交互信息发现环路&#xff0c;并通过阻塞特定端口&#xff0c;最终将网络结构修剪成无环路的树形结构。在网络出现故障的时候&#xff0c;STP能快速发现链路故障&#xff0c;并尽快找出另外一条路径…

Selenium-介绍下其他骚操作

Chrome DevTools 简介 Chrome DevTools 是一组直接内置在基于 Chromium 的浏览器&#xff08;如 Chrome、Opera 和 Microsoft Edge&#xff09;中的工具&#xff0c;用于帮助开发人员调试和研究网站。 借助 Chrome DevTools&#xff0c;开发人员可以更深入地访问网站&#xf…

SpringMVC初级

文章目录 一、SpringMVC 概述二、springMVC步骤1、新建maven的web项目2、导入maven依赖3、创建controller4、创建spring-mvc.xml配置文件&#xff08;本质就是spring的配置件&#xff09;5、web.xml中配置前端控制器6、新建a.jsp文件7、配置tomcat8、启动测试 三、工作流程分析…

【Vue】如何搭建SPA项目--详细教程

&#x1f3ac; 艳艳耶✌️&#xff1a;个人主页 &#x1f525; 个人专栏 &#xff1a;《Spring与Mybatis集成整合》《springMvc使用》 ⛺️ 生活的理想&#xff0c;为了不断更新自己 ! 目录 1.什么是vue-cli 2.安装 2.1.创建SPA项目 2.3.一问一答模式答案 3.运行SPA项目 3…

UML活动图

在UML中&#xff0c;活动图本质上就是流程图&#xff0c;它描述系统的活动、判定点和分支等&#xff0c;因此它对开发人员来说是一种重要工具。 活动图 活动是某件事情正在进行的状态&#xff0c;既可以是现实生活中正在进行的某一项工作&#xff0c;也可以是软件系统中某个类…

【element-ui】form表单动态修改rules校验项

在项目开发过程中&#xff0c;该页面有暂存和提交两个按钮&#xff0c;其中暂存和提交必填项校验不一样&#xff0c;此时需要动态增减必填项校验 &#xff0c;解决方法如下&#xff1a; 增加rules校验项 this.$set(this.formRules,name,[{required:true,message:请输入名称,t…

润和软件HopeStage与华宇信息TAS应用中间件完成产品兼容性互认证

近日&#xff0c;江苏润和软件股份有限公司&#xff08;以下简称“润和软件”&#xff09;HopeStage 操作系统与北京华宇信息技术有限公司&#xff08;以下简称“华宇信息”&#xff09;TAS应用中间件软件完成产品兼容性测试。 测试结果表明&#xff0c;企业级通用操作系统Hope…