RootNeighboursDataset(helpers.dataset_classes文件中的root_neighbours_dataset.py)

任务类型:回归
用途:在 `RootNeighboursDataset` 中,任务是给定一棵根树,预测根节点度数为6的邻居的特征平均值。因此,模型需要基于根节点的结构,找到度为6的邻居,并计算其特征的平均值。这属于回归问题,因为目标是预测连续值(特征的平均值)

from helpers.dataset_classes.root_neighbours_dataset import RootNeighboursDataset

import torch
from torch_geometric.data import Data, Batch
from typing import Dict, Tuple, List
from torch import Tensorclass RootNeighboursDataset(object):def __init__(self, seed: int, print_flag: bool = False):super().__init__()self.seed = seedself.plot_flag = print_flagself.generator = torch.Generator().manual_seed(seed)self.constants_dict = self.initialize_constants()self._data = self.create_data()def get(self) -> Data:return self._datadef create_data(self) -> Data:# train, val, testdata_list = []for num in range(self.constants_dict['NUM_COMPONENTS']):data_list.append(self.generate_component())return Batch.from_data_list(data_list)def mask_task(self, num_nodes_per_fold: List[int]) -> Tuple[Tensor, Tensor, Tensor]:num_nodes = sum(num_nodes_per_fold)train_mask = torch.zeros(size=(num_nodes,), dtype=torch.bool)val_mask = torch.zeros(size=(num_nodes,), dtype=torch.bool)test_mask = torch.zeros(size=(num_nodes,), dtype=torch.bool)train_mask[0] = Trueval_mask[num_nodes_per_fold[0]] = Truetest_mask[num_nodes_per_fold[0] + num_nodes_per_fold[1]] = Truereturn train_mask, val_mask, test_maskdef generate_component(self) -> Data:data_per_fold, num_nodes_per_fold = [], []for fold_idx in range(3):data = self.generate_fold(eval=(fold_idx != 0))num_nodes_per_fold.append(data.x.shape[0])data_per_fold.append(data)train_mask, val_mask, test_mask = self.mask_task(num_nodes_per_fold=num_nodes_per_fold)batch = Batch.from_data_list(data_per_fold)return Data(x=batch.x, edge_index=batch.edge_index, y=batch.y, train_mask=train_mask, val_mask=val_mask,test_mask=test_mask)def initialize_constants(self) -> Dict[str, int]:return {'NUM_COMPONENTS': 1000, 'MAX_HUBS': 3, 'MAX_1HOP_NEIGHBORS': 10, 'ADD_HUBS': 2, 'HUB_NEIGHBORS': 5,'MAX_2HOP_NEIGHBORS': 3, 'NUM_FEATURES': 5}def generate_fold(self, eval: bool) -> Data:constant_dict = self.initialize_constants()MAX_HUBS, MAX_1HOP_NEIGHBORS, ADD_HUBS, HUB_NEIGHBORS, MAX_2HOP_NEIGHBORS, NUM_FEATURES =\[constant_dict[key] for key in ['MAX_HUBS', 'MAX_1HOP_NEIGHBORS', 'ADD_HUBS', 'HUB_NEIGHBORS','MAX_2HOP_NEIGHBORS', 'NUM_FEATURES']]assert MAX_HUBS + ADD_HUBS <= MAX_1HOP_NEIGHBORSadd_hubs = ADD_HUBS if eval else 0num_hubs = torch.randint(1, MAX_HUBS + 1, size=(1,), generator=self.generator).item() + add_hubsnum_1hop_neighbors = torch.randint(MAX_HUBS + add_hubs, MAX_1HOP_NEIGHBORS + 1, size=(1,),generator=self.generator).item()assert num_hubs <= num_1hop_neighborslist_num_2hop_neighbors = torch.randint(1, MAX_2HOP_NEIGHBORS, size=(num_1hop_neighbors - num_hubs,),generator=self.generator).tolist()list_num_2hop_neighbors = [HUB_NEIGHBORS] * num_hubs + list_num_2hop_neighbors# 2 hop edge indexnum_nodes = 1  # root node is 0idx_1hop_neighbors = []list_edge_index = []for num_2hop_neighbors in list_num_2hop_neighbors:idx_1hop_neighbors.append(num_nodes)if num_2hop_neighbors > 0:clique_edge_index = torch.tensor([[0] * num_2hop_neighbors, list(range(1, num_2hop_neighbors + 1))])# clique_edge_index = torch.combinations(torch.arange(num_2hop_neighbors), r=2).Tlist_edge_index.append(clique_edge_index + num_nodes)num_nodes += num_2hop_neighbors + 1# 1 hop edge indexidx_0hop = torch.tensor([0] * num_1hop_neighbors)idx_1hop_neighbors = torch.tensor(idx_1hop_neighbors)hubs = idx_1hop_neighbors[:num_hubs]list_edge_index.append(torch.stack((idx_0hop, idx_1hop_neighbors), dim=0))edge_index = torch.cat(list_edge_index, dim=1)# undirectedge_index_other_direction = torch.stack((edge_index[1], edge_index[0]), dim=0)edge_index = torch.cat((edge_index_other_direction, edge_index), dim=1)# featuresx = 4 * torch.rand(size=(num_nodes, NUM_FEATURES), generator=self.generator) - 2# labelsy = torch.zeros_like(x)y[0] = torch.mean(x[hubs], dim=0)return Data(x=x, edge_index=edge_index, y=y)if __name__ == '__main__':data = RootNeighboursDataset(seed=0, print_flag=True)

这个 RootNeighboursDataset通过随机生成的树状图数据来模拟一种节点关系,并基于图结构生成特征和标签。代码使用了 PyTorchPyTorch Geometric 的功能来处理图数据。下面逐块详细解释该代码实现:

1. RootNeighboursDataset 类构造器

import torch
from torch_geometric.data import Data, Batch
from typing import Dict, Tuple, List
from torch import Tensorclass RootNeighboursDataset(object):def __init__(self, seed: int, print_flag: bool = False):super().__init__()self.seed &#

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.rhkb.cn/news/455124.html

如若内容造成侵权/违法违规/事实不符,请联系长河编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

百度搜索推广和信息流推广的区别,分别适用于什么场景!

信息流推广和搜索广告&#xff0c;不仅仅是百度&#xff0c;是很多平台的两个核心推广方式。 1、搜索广告&#xff1a; 就是基于用户的搜索习惯&#xff0c;更多是用户有疑问、还有用户当下就要做出行动的广告。 比如上门服务、线上咨询服务、招商加盟、了解产品各种型号和信…

STM32G4系列MCU的低功耗模式介绍

目录 概述 1 认识低功耗模式 1.1 低功耗模式的应用 1.2 低功耗模式介绍 2 低功耗模式的状态关系 2.1 低功耗模式可能的转换状态图 2.2 低功耗模式总结 3 运行模式 3.1 减慢系统时钟 3.2 外围时钟门控 3.3 低功耗运行模式&#xff08;LP运行&#xff09; 概述 本文主…

react 基础学习笔记

1.react 语法 ①数据渲染 函数组件将HTML结构直接写在函数的返回值中 JSX只能有一个根元素 JSX插值写法 插值可以使用的位置 1.标签内容&#xff1b; 2.标签属性 JSX 条件渲染&#xff1a;三目运算符&#xff1b; JSX根据数据进行列表渲染&#xff1a;map()方法&#x…

QT 机器视觉 1.相机类型

本专栏从实际需求场景出发详细还原、分别介绍大型工业化场景、专业实验室场景、自动化生产线场景、各种视觉检测物体场景介绍本专栏应用场景 更适合涉及到视觉相关工作者、包括但不限于一线操作人员、现场实施人员、项目相关维护人员&#xff0c;希望了解2D、3D相机视觉相关操作…

微服务与多租户详解:架构设计与实现

引言 在现代软件开发领域&#xff0c;微服务架构和多租户架构是两个重要的概念。微服务架构通过将应用程序拆分为多个独立的服务&#xff0c;提升了系统的灵活性和可维护性。而多租户架构则通过共享资源来服务多个客户&#xff0c;提高了资源利用率和系统的经济性。 一、微服务…

OpenCV的常用与形状形状描述相关函数及用法示例

OpenCV提供了提供了多种用于形状描述和分析的函数。这些函数能够帮助你提取图像中的形状特征&#xff0c;进行形状匹配、识别和分析。下面介绍一些常用的形状描述函数&#xff1a; 轮廓检测函数findContours() findContours()函数用于在二值图像中查找轮廓。有两个原型函数&…

【zlm】 webrtc源码讲解(二)

目录 webrtc播放 MultiMediaSourceMuxer里的_ring webrtc播放 > MediaServer.exe!mediakit::WebRtcPlayer::onStartWebRTC() 行 60 CMediaServer.exe!mediakit::WebRtcTransport::OnDtlsTransportConnected(const RTC::DtlsTransport * dtlsTransport, RTC::SrtpSession::…

tomcat部署war包部署运行,IDEA一键运行启动tomacat服务,maven打包为war包并部署到tomecat

tomcat部署war包前端访问 在Java Web开发中&#xff0c;Tomcat是一个非常流行的开源Web服务器和Servlet容器。它实现了Java Servlet和JavaServer Pages (JSP) 技术&#xff0c;提供了一个纯Java的Web应用环境。本文将介绍如何在Tomcat中部署运行WAR包&#xff0c;让你的应用快…

vue2 使用环境变量

一. 在根目录下创建.env.xxx文件 .env 基础系统变量&#xff0c;无论何种环境&#xff0c;都可使用其中配置的值&#xff0c;其他环境中的变量会覆盖.env中的同名变量。 .env.development 开发环境 .env.production 生产环境 .env.staging 测试环境 二. 内容格式 vue2 使用是以…

GRU神经网络理解

全文参考以下B站视频及《神经网络与深度学习》邱锡鹏&#xff0c;侧重对GPU模型的理解&#xff0c;初学者入门自用记录&#xff0c;有问题请指正【重温经典】GRU循环神经网络 —— LSTM的轻量级版本&#xff0c;大白话讲解_哔哩哔哩_bilibili 更新门、重置门、学习与输出 注&a…

STM32(二十一):看门狗

WDG&#xff08;Watchdog&#xff09;看门狗&#xff0c;手动重装寄存器的操作就是喂狗。 看门狗可以监控程序的运行状态&#xff0c;当程序因为设计漏洞、硬件故障、电磁干扰等原因&#xff0c;出现卡死或跑飞现象时&#xff0c;看门狗能及时复位程序&#xff0c;避免程序陷入…

数学建模微分方程模型——传染病模型

病毒也疯狂&#xff1a;细说传染病微分方程模型的那些事儿 “数学是打开科学大门的钥匙&#xff0c;而微分方程则是理解世界变化的密码。” 大家好&#xff01;今天我们要聊一聊一个既严肃又有趣的话题——传染病微分方程模型。别急&#xff0c;听起来高大上&#xff0c;其实一…

亚信安全DeepSecurity中标知名寿险机构云主机安全项目

近日&#xff0c;亚信安全DeepSecurity成功中标国内知名寿险机构的云主机安全项目。亚信安全凭借在云主机安全防护领域的突出技术优势&#xff0c;结合安全运营的能力&#xff0c;以“实战化”为指导&#xff0c;为用户提供无惧威胁攻击、无忧安全运营的一站式云安全体系&#…

【论文翻译】ICLR 2018 | DCRNN:扩散卷积递归神经网络:数据驱动的交通预测

论文题目Diffusion Convolutional Recurrent Neural Network: Data-Driven Traffic Forecasting论文链接https://arxiv.org/abs/1707.01926源码地址https://github.com/liyaguang/DCRNN发表年份-会议/期刊2018 ICLR关键词交通预测&#xff0c;扩散卷积&#xff0c;递归神经网络…

数字+文旅:AI虚拟数字人如何焕发传统文旅景区新活力?

​​引言&#xff1a; 据《2024年中国数字文旅行业市场研究报告》显示&#xff0c;截至2022年&#xff0c;中国数字文旅市场规模已达到约9698.1亿元人民币&#xff0c;相较于2017年的7870.5亿元&#xff0c;实现了57.89%的显著增长。这一行业涵盖了数字化的文化遗产旅游、虚拟…

JVM、字节码文件介绍

目录 初识JVM 什么是JVM JVM的三大核心功能 JVM的组成 字节码文件的组成 基础信息 Magic魔数 主副版本号 其它基础信息 常量池 字段 方法 属性 字节码常用工具 javap jclasslib插件 阿里Arthas 初识JVM 什么是JVM JVM的三大核心功能 1. 解释和运行虚拟机指…

【性能优化】安卓性能优化之CPU优化

【性能优化】安卓性能优化之CPU优化 CPU优化及常用工具原理与文章参考常用ADB常用原理、监控手段原理监控手段多线程并发解决耗时UI相关 常见场景排查CPU占用过高常用系统/开源分析工具AndroidStudio ProfilerSystraceBtracePerfettoTraceView和 Profile ANR相关ANR原理及常见场…

使用 VSCode 通过 Remote-SSH 连接远程服务器详细教程

使用 VSCode 通过 Remote-SSH 连接远程服务器详细教程 在日常开发中&#xff0c;许多开发者需要远程连接服务器进行代码编辑和调试。Visual Studio Code&#xff08;VSCode&#xff09;提供了一个非常强大的扩展——Remote-SSH&#xff0c;它允许我们通过 SSH 协议直接连接远程…

YOLO V3 网络构架解析

YOLO V3&#xff08;You Only Look Once version 3&#xff09;是由Joseph Redmon等人于2018年提出的一种基于深度学习的目标检测算法。它在速度和精度上相较于之前的版本有了显著提升&#xff0c;成为计算机视觉领域的一个重要里程碑。本文将详细解析YOLO V3的网络架构&#x…

【信息论基础第六讲】离散无记忆信源等长编码包括典型序列和等长信源编码定理

一、信源编码的数学模型 我们知道信源的输出是消息序列&#xff0c;对于信源进行编码就是用码字集来表示消息集&#xff0c;也就是要进行从消息集到码字集的映射。 根据码字的特征我们又将其分为D元码&#xff0c;等长码&#xff0c;不等长码&#xff0c;唯一可译码。 我们通过…