深度学习(36)—— 图神经网络GNN(1)

深度学习(36)—— 图神经网络GNN(1)

这个系列的所有代码我都会放在git上,欢迎造访

文章目录

  • 深度学习(36)—— 图神经网络GNN(1)
    • 1. 基础知识
    • 2.使用场景
    • 3. 图卷积神经网络GCN
      • (1)基本思想
    • 4. GNN基本框架——pytorch_geometric
      • (1)数据
      • (2)可视化
      • (3)网络定义
      • (4)训练模型(semi-supervised)

1. 基础知识

  • GNN考虑的事当前的点和周围点之间的关系

  • 邻接矩阵是对称的稀疏矩阵,表示图中各个点之间的关系

  • 图神经网络的输入是每个节点的特征和邻接矩阵

  • 文本数据可以用图的形式表示吗?文本数据也可以表示图的形式,邻接矩阵表示连接关系

  • 邻接矩阵中并不是一个N* N的矩阵,而是一个source,target的2* N的矩阵
    在这里插入图片描述

  • 信息传递神经网络:每个点的特征如何更新??——考虑他们的邻居,更新的方式可以自己设置:最大,最小,平均,求和等

  • GNN可以有多层,图的结构不发生改变,即当前点所连接的点不发生改变(邻接矩阵不发生变化)【卷积中存在感受野的概念,在GNN中同样存在,GNN的感受野也随着层数的增大变大】

  • GNN输出的特征可以干什么?

    • 各个节点的特征组合,对图分类【graph级别任务】
    • 对各个节点分类【node级别任务】
    • 对边分类【edge级别任务】
    • 利用图结构得到特征,最终做什么自定义!

2.使用场景

  • 为什么CV和NLP中不用GNN?
    因为图像和文本的数据格式很固定,传统神经网络格式是固定的,输入的东西格式是固定的
  • 化学、医疗
  • 分子、原子结构
  • 药物靶点
  • 道路交通,动态流量预测
  • 社交网络——研究人
    GNN输入格式比较随意,是不规则的数据结构, 主要用于输入数据不规则的时候

3. 图卷积神经网络GCN

  • 图卷积和卷积完全不同
  • GCN不是单纯的有监督学习,多数是半监督,有的点是没有标签的,在计算损失的时候只考虑有标签的点。针对数据量少的情况也可以训练

(1)基本思想

  • 网络层次:第一层对于每个点都要做更新,最后输出每个点对应的特征向量【一般不会做特别深层的】
  • 图中的基本组成:G(原图)A(邻接)D(度)F(特征)
  • 度矩阵的倒数* 邻接矩阵 *度矩阵的倒数——>得到新的邻接矩阵【左乘对行做归一化,右乘对列做归一化】
  • 两到三层即可,太多效果不佳

4. GNN基本框架——pytorch_geometric

它实现了各种GNN的方法
注意:安装过程中不要pip install,会失败!根据自己的device和python版本去下载scatter,pattern等四个依赖,先安装他们然后再pip install torch_geometric==2.0
这里记得是2.0版本否则会出现 TypeError: Expected ‘Iterator‘ as the return annotation for __iter__ of SMILESParser, but found ty
献上github地址:这里

下面是一个demo

(1)数据

这里使用的是和这个package提供的数据,具体参考:club
在这里插入图片描述

from torch_geometric.datasets import KarateClubdataset = KarateClub()
print(f'Dataset: {dataset}:')
print('======================')
print(f'Number of graphs: {len(dataset)}')
print(f'Number of features: {dataset.num_features}')
print(f'Number of classes: {dataset.num_classes}')data = dataset[0]  # Get the first graph object.

在torch_geometric中图用Data的格式,Data的对象:可以在文档中详细了解在这里插入图片描述
其中的属性

  • edge_index:表示图的连接关系(start,end两个序列)
  • node features:每个点的特征
  • node labels:每个点的标签
  • train_mask:有的节点没有标签(用来表示哪些节点要计算损失)

(2)可视化

from torch_geometric.utils import to_networkxG = to_networkx(data, to_undirected=True)
visualize_graph(G, color=data.y)

在这里插入图片描述

(3)网络定义

GCN layer的定义:在这里插入图片描述
可以在官网的文档做详细了解

在这里插入图片描述
卷积层就有很多了:
在这里插入图片描述

import torch
from torch.nn import Linear
from torch_geometric.nn import GCNConvclass GCN(torch.nn.Module):def __init__(self):super().__init__()torch.manual_seed(1234)self.conv1 = GCNConv(dataset.num_features, 4) # 只需定义好输入特征和输出特征即可self.conv2 = GCNConv(4, 4)self.conv3 = GCNConv(4, 2)self.classifier = Linear(2, dataset.num_classes)def forward(self, x, edge_index):h = self.conv1(x, edge_index) # 输入特征与邻接矩阵(注意格式,上面那种)h = h.tanh()h = self.conv2(h, edge_index)h = h.tanh()h = self.conv3(h, edge_index)h = h.tanh()  # 分类层out = self.classifier(h)return out, hmodel = GCN()
print(model)_, h = model(data.x, data.edge_index)
print(f'Embedding shape: {list(h.shape)}')# 输出最后分类前的中间特征shapevisualize_embedding(h, color=data.y)

这时很分散
在这里插入图片描述

(4)训练模型(semi-supervised)

import timemodel = GCN()
criterion = torch.nn.CrossEntropyLoss()  # Define loss criterion.
optimizer = torch.optim.Adam(model.parameters(), lr=0.01)  # Define optimizer.def train(data):optimizer.zero_grad()  out, h = model(data.x, data.edge_index) #h是两维向量,主要是为了画图方便 loss = criterion(out[data.train_mask], data.y[data.train_mask])  # semi-supervisedloss.backward()  optimizer.step()  return loss, hfor epoch in range(401):loss, h = train(data)if epoch % 10 == 0:visualize_embedding(h, color=data.y, epoch=epoch, loss=loss)time.sleep(0.3)

然后就可以看到一系列图,看点的变化情况了

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.rhkb.cn/news/88497.html

如若内容造成侵权/违法违规/事实不符,请联系长河编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

UnityWebGL移动端兼容性说明

测试时间2023.8.10 官方文档说明 依据Unity官方最新版本文档(2021.3LTS),关于WebGL的兼容性说明为"Unity WebGL不支持移动设备。它可能适用于高端设备,但当前的设备通常不够强大,并且没有足够的内存来支持Unity …

【c语言】字符函数与字符串函数(上)

大家好呀,今天给大家分享一下字符函数和字符串函数,说起字符函数和字符串函数大家会想到哪些呢??我想到的只有求字符串长度的strlen,拷贝字符串的strcpy,字符串比较相同的strcmp,今天,我要分享给大家的是我们一些其他的…

SQL-每日一题【1517. 查找拥有有效邮箱的用户】

题目 表: Users 编写一个解决方案,以查找具有有效电子邮件的用户。 一个有效的电子邮件具有前缀名称和域,其中: 前缀 名称是一个字符串,可以包含字母(大写或小写),数字,下划线 _ &…

详细讲解如何在github上编辑个人主页?

在 GitHub 上编辑个人主页可以让您展示您的项目、技能和个人信息,以及与其他开发者互动。以下是详细的步骤来在 GitHub 上编辑个人主页: 创建 GitHub 账户 如果您还没有 GitHub 账户,首先需要注册一个。 登录到 GitHub 使用您的用户名和密…

【TypeScript】进阶之路语法细节,类型和函数

进阶之路 类型别名(type)的使用接口(interface)的声明的使用二者区别: 联合类型和交叉类型联合类型交叉类型 类型断言获取DOM元素 非空类型断言字面量类型的使用类型缩小(类型收窄)TypeScript 函数类型函数类型表达式内部规则检测函数的调用签…

置信域策略优化Trust Region Policy Optimization (TRPO)

1. 置信域方法(Trust Region Methods) [1]将置信域方法用到强化学习中,并取到了非常好的结果. 1.1 优化问题 1.2 置信域 1.3 置信域方法的过程 References [1] Schulman J, Levine S, Abbeel P, et al. Trust region policy optimization[C]//International conf…

【K8S系列】深入解析k8s网络插件—Weave Net

序言 做一件事并不难,难的是在于坚持。坚持一下也不难,难的是坚持到底。 文章标记颜色说明: 黄色:重要标题红色:用来标记结论绿色:用来标记论点蓝色:用来标记论点 Kubernetes (k8s) 是一个容器编…

构建Docker容器监控系统(cadvisor+influxDB+grafana)

目录 一、部署 1、安装docker-cd 2、阿里云镜像加速 3、下载组件镜像 4、创建自定义网络 5、创建influxdb容器 6、创建Cadvisor 容器 7、创建granafa容器 一、部署 1、安装docker-cd [rootlocalhost ~]# iptables -F [rootlocalhost ~]# setenforce 0 setenforce: SELi…

BGP的工作过程及报文

IGP核心:路由的计算。OSPF,ISIS等 BGP核心:路由的传递,不产生路由,只是路由的搬运工,一般用于规模特别大的网络中,只要TCP可达就可以建立邻居。 大型企业分支间采用BGP进行路由传递,不同的分支属于不同的BGP的AS,它们通过BGP进行路由交互。企业与运营商之间可使用BGP进行…

解决nvm安装后,node生效但npm无效

问题描述 nvm安装后,node生效但npm无效 清除缓存 C:\Users\cc\AppData\Roaming cc是我的用户名改成你自己的就行删除 npm和npm-cache

Rx.NET in Action 中文介绍 前言及序言

Rx 处理器目录 (Catalog of Rx operators) 目标可选方式Rx 处理器(Operator)创建 Observable Creating Observables直接创建 By explicit logicCreate Defer根据范围创建 By specificationRangeRepeatGenerateTimerInterval Return使用预设 Predefined primitivesThrow …

软件测试(功能、接口、性能、自动化)详解

一、软件测试功能测试 测试用例编写是软件测试的基本技能;也有很多人认为测试用例是软件测试的核心;软件测试中最重要的是设计和生成有效的测试用例;测试用例是测试工作的指导,是软件测试的必须遵守的准则。 黑盒测试常见测试用…

Gartner发布2023年的存储技术成熟曲线

技术路线说明 Gartner自1995年起开始采用技术成熟度曲线,它描述创新的典型发展过程,即从过热期发展到幻灭低谷期,再到人们最终理解创新在市场或领域内的意义和角色。 一项技术 (或相关创新)在发展到最终成熟期的过程中经历多个阶段&#xff1…

二十二、策略模式

目录 1、项目需求2、传统方案解决鸭子问题的分析和代码实现3、传统方式实现存在的问题分析和解决方案4、策略模式基本介绍5、使用策略模式解决鸭子问题6、策略模式的注意事项和细节7、策略模式的使用场景 以具体项目来演示为什么需要策略模式,策略模式的优点&#x…

微信小程序--原生

1:数据绑定 1:数据绑定的基本原则 2:在data中定义页面的数据 3:Mustache语法 4:Mustache的应用场景 1:常见的几种场景 2:动态绑定内容 3:动态绑定属性 4:三元运算 4&am…

python_day19_正则表达式

正则表达式re模块 导包 import res "python java c c python2 python python3"match 从头匹配 res re.match("python", s) res_2 re.match("python2", s) print("res:", res) print(res.span()) print(res.group()) print("…

Python-OpenCV中的图像处理-傅里叶变换

Python-OpenCV中的图像处理-傅里叶变换 傅里叶变换Numpy中的傅里叶变换Numpy中的傅里叶逆变换OpenCV中的傅里叶变换OpenCV中的傅里叶逆变换 DFT的性能优化不同滤波算子傅里叶变换对比 傅里叶变换 傅里叶变换经常被用来分析不同滤波器的频率特性。我们可以使用 2D 离散傅里叶变…

【分布式系统】聊聊高性能设计

每个程序员都应该知道的数字 高性能 对于以上的数字,其实每个程序员都应该了解,因为只有了解这些基本的数字,才能知道对于CPU、内存、磁盘、网络之间数据读写的时间。1000ms 1S。毫秒->微秒->纳秒-秒->分钟 为什么高性能如此重要的…

单体版ruoyi代码生成增删改查

目录 拉取代码 打开代码,新建一个模块,模块放我们的项目后台数据库的curd代码。 我们的新模块引入ruoyi的通用模块 ruoyi的adm引入我们的项目依赖,引用我们的模型、service、mapper 将我们的模块注入父项目 打开ruoyi-adm配置MyBatis&…

Spannable配合AnimationDrawable实现TextView中展示Gif图片

辣的原理解释,反正大家也不爱看,所以直接上代码了 长这样,下面两个图是gif,会动的。 package com.example.myapplication;import android.content.Context; import android.graphics.Bitmap; import android.graphics.drawable…