LightningDataModule
是 PyTorch Lightning 提供的数据模块,用于统一管理数据加载流程(包括数据准备、预处理、拆分、批量加载等)。它的核心作用是将数据处理逻辑与模型解耦,提高代码的可复用性和可读性。
1. LightningDataModule
的作用
✅ 封装数据预处理:数据下载、清理、转换等步骤都可以在 LightningDataModule
中完成。
✅ 统一数据加载流程:确保训练、验证、测试和推理数据集使用相同的数据预处理逻辑。
✅ 简化 Trainer
代码:LightningDataModule
使 Trainer.fit()
更加简洁和模块化。
✅ 支持多 GPU、TPU 训练:可以轻松适配不同计算设备的 Dataloader 设定。
2. LightningDataModule
的基本结构
LightningDataModule
主要包含以下关键方法:
方法 | 作用 |
---|---|
prepare_data() | 仅在主进程中运行一次,用于下载数据、处理静态数据(如数据去重) |
setup(stage) | 在每个 GPU/TPU 设备上运行,用于数据拆分( |