DenseDataLoader
和 DataLoader
在处理数据的方式上有所不同。DenseDataLoader
是专门用于处理稠密图数据的,而 DataLoader
通常用于处理稀疏图数据。在你的案例中,如果所有图的节点数和边数是固定的,可以使用 DenseDataLoader
进行更高效的批处理。
目前发现不同之处在于,Dense输入是一个邻接矩阵,打印出来向量是展开格式(batch_size, num_nodes, num_node_features),直接使用dataloader是(batch_size*num_nodes, num_node_features),这两种方式返回一个包含图数据的 Data 对象,包含特征矩阵和邻接矩 return Data(x=x, adj=self.adj_matrix, y=y)和return Data(x=x, edge_index=edge_index, y=y)
使用 DenseDataLoader
以下是如何使用 DenseDataLoader
来加载数据的完整代码示例:
1. 导入必要的库
import torch
from torch_geometric.data import Data
from torch_geometric.loader import DenseDataLoader
2. 定义 MyDataset
类
class MyDataset(torch.utils.data.Dataset):def __init__(self, num_samples, num_nodes, num_node_features):super(MyDataset, self).__init__()self.num_samples = num_samplesself.num_nodes = num_nodesself.num_node_features = num_node_features# 创建固定的邻接矩阵,这里简单使用环形图的邻接矩阵self.adj_matrix = torch.zeros((num_nodes, num_nodes), dtype=torch.float)for i in range(num_nodes):self.adj_matrix[i, (i + 1) % num_nodes] = 1self.adj_matrix[(i + 1) % num_nodes, i] = 1def __len__(self):return self.num_samplesdef __getitem__(self, idx):# 创建随机特征和标签,这里仅作示例x = torch.randn((self.num_nodes, self.num_node_features))y = torch.randn((self.num_nodes, 1)) # 每个节点一个标签# 返回一个包含图数据的 Data 对象,包含特征矩阵和邻接矩阵return Data(x=x, adj=self.adj_matrix, y=y)
3. 创建数据集和封装数据
# 参数设置
num_samples = 100 # 样本数
num_nodes = 10 # 每个图中的节点数
num_node_features = 8 # 每个节点的特征数# 创建数据集
dataset = MyDataset(num_samples, num_nodes, num_node_features)# 封装数据
data_list = [dataset[i] for i in range(num_samples)]
4. 使用 DenseDataLoader
使用 DenseDataLoader
加载数据集,并检查批次数据的形状。
# 创建 DenseDataLoader
loader = DenseDataLoader(data_list, batch_size=32, shuffle=True)# 从 DenseDataLoader 中获取一个批次的数据并查看其形状
for batch in loader:x = batch.x # 形状为 (batch_size, num_nodes, num_node_features)adj = batch.adj # 形状为 (batch_size, num_nodes, num_nodes)y = batch.y # 形状为 (batch_size, num_nodes, 1)print("Batch node features shape:", x.shape) # 期望输出形状为 (32, 10, 8)print("Batch adjacency matrix shape:", adj.shape) # 期望输出形状为 (32, 10, 10)print("Batch labels shape:", y.shape) # 期望输出形状为 (32, 10, 1)break # 仅查看第一个批次的形状
解释
-
导入库:
- 导入
torch
和torch_geometric.data
中的Data
。 - 导入
DenseDataLoader
。
- 导入
-
定义
MyDataset
类:__init__
方法初始化数据集参数,创建固定的邻接矩阵。__len__
方法返回数据集的样本数量。__getitem__
方法生成每个样本的随机节点特征和标签,并返回一个Data
对象。
-
创建数据集和封装数据:
- 使用
MyDataset
类创建一个包含 100 个样本的数据集,每个样本包含 10 个节点,每个节点有 8 个特征。 - 使用
data_list = [dataset[i] for i in range(num_samples)]
封装数据。
- 使用
-
使用
DenseDataLoader
:- 使用
DenseDataLoader
加载data_list
,设置批次大小为 32,并进行随机打乱。 - 在获取一个批次的数据时,检查
batch.x
、batch.adj
和batch.y
的形状,以确保其符合期望的三维形状。
- 使用
通过这个完整的示例代码,你可以生成、封装和加载稠密图数据,并确保每个批次的数据形状保持正确。
- 使用
data_list = [dataset[i] for i in range(num_samples)]
封装数据这一步其实不需要
完整代码
import torch
from torch_geometric.data import Data
from torch_geometric.loader import DenseDataLoaderclass MyDataset(torch.utils.data.Dataset):def __init__(self, num_samples, num_nodes, num_node_features):super(MyDataset, self).__init__()self.num_samples = num_samplesself.num_nodes = num_nodesself.num_node_features = num_node_features# 创建固定的邻接矩阵,这里简单使用环形图的邻接矩阵self.adj_matrix = torch.zeros((num_nodes, num_nodes), dtype=torch.float)for i in range(num_nodes):self.adj_matrix[i, (i + 1) % num_nodes] = 1self.adj_matrix[(i + 1) % num_nodes, i] = 1def __len__(self):return self.num_samplesdef __getitem__(self, idx):# 创建随机特征和标签,这里仅作示例x = torch.randn((self.num_nodes, self.num_node_features))y = torch.randn((self.num_nodes, 1)) # 每个节点一个标签# 返回一个包含图数据的 Data 对象,包含特征矩阵和邻接矩阵return Data(x=x, adj=self.adj_matrix, y=y)# 参数设置
num_samples = 100 # 样本数
num_nodes = 10 # 每个图中的节点数
num_node_features = 8 # 每个节点的特征数# 创建数据集
dataset = MyDataset(num_samples, num_nodes, num_node_features)# 创建 DenseDataLoader
loader = DenseDataLoader(dataset, batch_size=32, shuffle=True)# 从 DenseDataLoader 中获取一个批次的数据并查看其形状
for batch in loader:x = batch.x # 形状为 (batch_size, num_nodes, num_node_features)adj = batch.adj # 形状为 (batch_size, num_nodes, num_nodes)y = batch.y # 形状为 (batch_size, num_nodes, 1)print("Batch node features shape:", x.shape) # 期望输出形状为 (32, 10, 8)print("Batch adjacency matrix shape:", adj.shape) # 期望输出形状为 (32, 10, 10)print("Batch labels shape:", y.shape) # 期望输出形状为 (32, 10, 1)break # 仅查看第一个批次的形状