【WB 深度学习实验管理】使用 PyTorch Lightning 实现高效的图像分类实验跟踪

本文使用到的 Jupyter Notebook 可在GitHub仓库002文件夹找到,别忘了给仓库点个小心心~~~
https://github.com/LFF8888/FF-Studio-Resources
在这里插入图片描述

在机器学习项目中,实验跟踪和结果可视化是至关重要的环节。无论是调整超参数、优化模型架构,还是监控训练过程中的性能变化,清晰的记录和直观的可视化都能显著提升开发效率。然而,许多开发者在实际操作中往往忽视了这一点,导致实验结果难以复现,或者在项目协作中出现混乱。今天,笔者将介绍如何利用 PyTorch Lightning 和 Weights & Biases 这一强大的工具组合,轻松构建和训练一个图像分类模型。通过本文,你将学会如何高效地组织数据管道、定义模型架构,并利用 W&B 实现实验跟踪和结果可视化,让每一次实验都清晰可溯,每一次优化都有据可依。

使用 PyTorch Lightning ⚡️ 进行图像分类

我们将使用 PyTorch Lightning 构建一个图像分类管道。我们将遵循这个 风格指南 来提高代码的可读性和可重复性。这里有一个很酷的解释:使用 PyTorch Lightning 进行图像分类。

设置 PyTorch Lightning 和 W&B

对于本教程,我们需要 PyTorch Lightning(这不是很明显吗!)和 Weights and Biases。

!pip install lightning torchvision -q
# 安装 weights and biases
!pip install wandb -qU

你需要这些导入。

import lightning.pytorch as pl
# 你最喜欢的机器学习跟踪工具
from lightning.pytorch.loggers import WandbLoggerimport torch
from torch import nn
from torch.nn import functional as F
from torch.utils.data import random_split, DataLoaderfrom torchmetrics import Accuracyfrom torchvision import transforms
from torchvision.datasets import CIFAR10import wandb

现在你需要登录到你的 wandb 账户。

wandb.login()

🔧 DataModule - 我们应得的数据管道

DataModules 是一种将数据相关的钩子与 LightningModule 解耦的方式,以便你可以开发与数据集无关的模型。
它将数据管道组织成一个可共享和可重用的类。一个 datamodule 封装了 PyTorch 中数据处理的五个步骤:

  • 下载 / 分词 / 处理。
  • 清理并(可能)保存到磁盘。
  • 加载到 Dataset 中。
  • 应用转换(旋转、分词等)。
  • 包装到 DataLoader 中。

了解更多关于 datamodules 的信息 这里。让我们为 Cifar-10 数据集构建一个 datamodule。

class CIFAR10DataModule(pl.LightningDataModule):def __init__(self, batch_size, data_dir: str = './'):super().__init__()self.data_dir = data_dirself.batch_size = batch_sizeself.transform = transforms.Compose([transforms.ToTensor(),transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))])self.num_classes = 10def prepare_data(self):CIFAR10(self.data_dir, train=True, download=True)CIFAR10(self.data_dir, train=False, download=True)def setup(self, stage=None):# 为 dataloaders 分配训练/验证数据集if stage == 'fit' or stage is None:cifar_full = CIFAR10(self.data_dir, train=True, transform=self.transform)self.cifar_train, self.cifar_val = random_split(cifar_full, [45000, 5000])# 为 dataloader(s) 分配测试数据集if stage == 'test' or stage is None:self.cifar_test = CIFAR10(self.data_dir, train=False, transform=self.transform)def train_dataloader(self):return DataLoader(self.cifar_train, batch_size=self.batch_size, shuffle=True)def val_dataloader(self):return DataLoader(self.cifar_val, batch_size=self.batch_size)def test_dataloader(self):return DataLoader(self.cifar_test, batch_size=self.batch_size)

📱 Callbacks

回调是一个独立的程序,可以在项目之间重用。PyTorch Lightning 提供了一些 内置回调,这些回调经常被使用。
了解更多关于 PyTorch Lightning 中的回调 这里。

内置回调

在本教程中,我们将使用 Early Stopping 和 Model Checkpoint 内置回调。它们可以传递给 Trainer

自定义回调

如果你熟悉自定义 Keras 回调,那么在 PyTorch 管道中实现相同功能的能力只是锦上添花。
由于我们正在进行图像分类,能够可视化模型对一些样本图像的预测可能很有帮助。这种形式的回调可以帮助在早期阶段调试模型。

class ImagePredictionLogger(pl.callbacks.Callback):def __init__(self, val_samples, num_samples=32):super().__init__()self.num_samples = num_samplesself.val_imgs, self.val_labels = val_samplesdef on_validation_epoch_end(self, trainer, pl_module):# 将张量带到 CPUval_imgs = self.val_imgs.to(device=pl_module.device)val_labels = self.val_labels.to(device=pl_module.device)# 获取模型预测logits = pl_module(val_imgs)preds = torch.argmax(logits, -1)# 将图像记录为 wandb Imagetrainer.logger.experiment.log({"examples":[wandb.Image(x, caption=f"Pred:{pred}, Label:{y}")for x, pred, y in zip(val_imgs[:self.num_samples],preds[:self.num_samples],val_labels[:self.num_samples])]})

🎺 LightningModule - 定义系统

LightningModule 定义了一个系统,而不是一个模型。在这里,系统将所有研究代码分组到一个类中,使其自包含。LightningModule 将你的 PyTorch 代码组织成 5 个部分:

  • 计算 (__init__)。
  • 训练循环 (training_step)
  • 验证循环 (validation_step)
  • 测试循环 (test_step)
  • 优化器 (configure_optimizers)

因此,可以构建一个与数据集无关的模型,并且可以轻松共享。让我们为 Cifar-10 分类构建一个系统。

class LitModel(pl.LightningModule):def __init__(self, input_shape, num_classes, learning_rate=2e-4):super().__init__()# 记录超参数self.save_hyperparameters()self.learning_rate = learning_rateself.conv1 = nn.Conv2d(3, 32, 3, 1)self.conv2 = nn.Conv2d(32, 32, 3, 1)self.conv3 = nn.Conv2d(32, 64, 3, 1)self.conv4 = nn.Conv2d(64, 64, 3, 1)self.pool1 = torch.nn.MaxPool2d(2)self.pool2 = torch.nn.MaxPool2d(2)n_sizes = self._get_conv_output(input_shape)self.fc1 = nn.Linear(n_sizes, 512)self.fc2 = nn.Linear(512, 128)self.fc3 = nn.Linear(128, num_classes)self.accuracy = Accuracy(task="multiclass", num_classes=num_classes)# 返回从卷积块进入线性层的输出张量的大小。def _get_conv_output(self, shape):batch_size = 1input = torch.autograd.Variable(torch.rand(batch_size, *shape))output_feat = self._forward_features(input)n_size = output_feat.data.view(batch_size, -1).size(1)return n_size# 返回卷积块的特征张量def _forward_features(self, x):x = F.relu(self.conv1(x))x = self.pool1(F.relu(self.conv2(x)))x = F.relu(self.conv3(x))x = self.pool2(F.relu(self.conv4(x)))return x# 将在推理期间使用def forward(self, x):x = self._forward_features(x)x = x.view(x.size(0), -1)x = F.relu(self.fc1(x))x = F.relu(self.fc2(x))x = F.log_softmax(self.fc3(x), dim=1)return xdef training_step(self, batch, batch_idx):x, y = batchlogits = self(x)loss = F.nll_loss(logits, y)# 训练指标preds = torch.argmax(logits, dim=1)acc = self.accuracy(preds, y)self.log('train_loss', loss, on_step=True, on_epoch=True, logger=True)self.log('train_acc', acc, on_step=True, on_epoch=True, logger=True)return lossdef validation_step(self, batch, batch_idx):x, y = batchlogits = self(x)loss = F.nll_loss(logits, y)# 验证指标preds = torch.argmax(logits, dim=1)acc = self.accuracy(preds, y)self.log('val_loss', loss, prog_bar=True)self.log('val_acc', acc, prog_bar=True)return lossdef test_step(self, batch, batch_idx):x, y = batchlogits = self(x)loss = F.nll_loss(logits, y)# 验证指标preds = torch.argmax(logits, dim=1)acc = self.accuracy(preds, y)self.log('test_loss', loss, prog_bar=True)self.log('test_acc', acc, prog_bar=True)return lossdef configure_optimizers(self):optimizer = torch.optim.Adam(self.parameters(), lr=self.learning_rate)return optimizer

🚋 训练和评估

现在我们已经使用 DataModule 组织了数据管道,并使用 LightningModule 组织了模型架构和训练循环,PyTorch Lightning Trainer 为我们自动化了其他所有内容。

Trainer 自动化了以下内容:

  • Epoch 和 batch 迭代
  • 调用 optimizer.step()backwardzero_grad()
  • 调用 .eval(),启用/禁用梯度
  • 保存和加载权重
  • Weights and Biases 日志记录
  • 多 GPU 训练支持
  • TPU 支持
  • 16 位训练支持
dm = CIFAR10DataModule(batch_size=32)
# 要访问 x_dataloader,我们需要调用 prepare_data 和 setup。
dm.prepare_data()
dm.setup()# 自定义 ImagePredictionLogger 回调所需的样本,用于记录图像预测。
val_samples = next(iter(dm.val_dataloader()))
val_imgs, val_labels = val_samples[0], val_samples[1]
val_imgs.shape, val_labels.shape
model = LitModel((3, 32, 32), dm.num_classes)# 初始化 wandb logger
wandb_logger = WandbLogger(project='wandb-lightning', job_type='train')# 初始化 Callbacks
early_stop_callback = pl.callbacks.EarlyStopping(monitor="val_loss")
checkpoint_callback = pl.callbacks.ModelCheckpoint()# 初始化一个 trainer
trainer = pl.Trainer(max_epochs=2,logger=wandb_logger,callbacks=[early_stop_callback,ImagePredictionLogger(val_samples),checkpoint_callback],)# 训练模型 ⚡🚅⚡
trainer.fit(model, dm)# 在保留的测试集上评估模型 ⚡⚡
trainer.test(dataloaders=dm.test_dataloader())# 关闭 wandb run
wandb.finish()

最终想法

我来自 TensorFlow/Keras 生态系统,发现 PyTorch 虽然是一个优雅的框架,但有点让人不知所措。这只是我的个人经验。在探索 PyTorch Lightning 时,我意识到几乎所有让我远离 PyTorch 的原因都得到了解决。以下是我兴奋的快速总结:

  • 过去:传统的 PyTorch 模型定义通常分散在各个地方。模型在某个 model.py 脚本中,训练循环在 train.py 文件中。需要来回查看才能理解管道。
  • 现在:LightningModule 作为一个系统,模型定义与 training_stepvalidation_step 等一起定义。现在它是模块化的且可共享的。
  • 过去:TensorFlow/Keras 最棒的部分是输入数据管道。他们的数据集目录丰富且不断增长。PyTorch 的数据管道曾经是最大的痛点。在普通的 PyTorch 代码中,数据下载/清理/准备通常分散在许多文件中。
  • 现在:DataModule 将数据管道组织成一个可共享和可重用的类。它只是 train_dataloaderval_dataloader(s)、test_dataloader(s) 以及匹配的转换和数据处理/下载步骤的集合。
  • 过去:使用 Keras,可以调用 model.fit 来训练模型,调用 model.predict 来运行推理。model.evaluate 提供了一个简单而有效的测试数据评估。这在 PyTorch 中不是这样。通常会找到单独的 train.pytest.py 文件。
  • 现在:有了 LightningModuleTrainer 自动化了一切。只需调用 trainer.fittrainer.test 来训练和评估模型。
  • 过去:TensorFlow 喜欢 TPU,PyTorch…嗯!
  • 现在:使用 PyTorch Lightning,可以轻松地在多个 GPU 上训练相同的模型,甚至在 TPU 上。哇!
  • 过去:我是回调的忠实粉丝,更喜欢编写自定义回调。像 Early Stopping 这样简单的事情曾经是传统 PyTorch 的讨论点。
  • 现在:使用 PyTorch Lightning,使用 Early Stopping 和 Model Checkpointing 是小菜一碟。我甚至可以编写自定义回调。

🎨 结论和资源

我希望你觉得这份报告有帮助。我鼓励你玩一下代码,并使用你选择的数据集训练一个图像分类器。

以下是一些学习更多关于 PyTorch Lightning 的资源:

  • 逐步演练 - 这是官方教程之一。他们的文档写得非常好,我强烈推荐它作为学习资源。
  • 使用 PyTorch Lightning 与 Weights & Biases - 这是一个快速 colab,你可以通过它学习如何使用 W&B 与 PyTorch Lightning。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.rhkb.cn/news/14734.html

如若内容造成侵权/违法违规/事实不符,请联系长河编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

人工智能入门 数学基础 线性代数 笔记

必备的数学知识是理解人工智能不可或缺的要素,今天的种种人工智能技术归根到底都建立在数学模型之上,而这些数学模型又都离不开线性代数(linear algebra)的理论框架。 线性代数的核心意义:世间万事万物都可以被抽象成某…

5 计算机网络

5 计算机网络 5.1 OSI/RM七层模型 5.2 TCP/IP协议簇 5.2.1:常见协议基础 一、 TCP是可靠的,效率低的; 1.HTTP协议端口默认80,HTTPSSL之后成为HTTPS协议默认端口443。 2.对于0~1023一般是默认的公共端口不需要注册,1024以后的则需…

unity碰撞的监测和监听

1.创建一个地面 2.去资源商店下载一个火焰素材 3.把procedural fire导入到自己的项目包管理器中 4.给magic fire 0 挂在碰撞组件Rigidbody , Sphere Collider 5.创建脚本test 并挂在magic fire 0 脚本代码 using System.Collections; using System.Collections.Generic; usi…

使用云效解决docker官方镜像拉取不到的问题

目录 前言原文地址测试jenkins构建结果:后续使用说明 前言 最近经常出现docker镜像进行拉取不了,流水线挂掉的问题,看到一个解决方案: 《借助阿里个人版镜像仓库云效实现全免费同步docker官方镜像到国内》 原文地址 https://developer.aliyun.com/artic…

element-plus+vue3前端如何根据name进行搜索查到符合条件的数据

界面如图&#xff0c;下面的区域是接口给的所有的&#xff0c;希望前端根据输入的内容自己去匹配。 我是使用的element-plusvue3ts的写法。 <el-input v-model"filters.region" placeholder"输入区域搜索" keyup"filterRegion(filters.region)&q…

电路研究9.3——合宙Air780EP中的AT开发指南(含TCP 示例)

根据合宙的AT研发推荐&#xff0c; AT指令基本上也简单看完了&#xff0c;这里开始转到AT的开发了。 AT 命令采用标准串口进行数据收发&#xff0c;将以前复杂的设备通讯方式转换成简单的串口编程&#xff0c; 大大简化了产品的硬件设计和软件开发成本&#xff0c;这使得几乎所…

cursor指令工具

Cursor 工具使用指南与实例 工具概览 Cursor 提供了一系列强大的工具来帮助开发者提高工作效率。本指南将通过具体实例来展示这些工具的使用方法。 1. 目录文件操作 1.1 查看目录内容 (list_dir) 使用 list_dir 命令可以查看指定目录下的文件结构: 示例: list_dir log…

AI安全最佳实践:AI应用开发安全评估矩阵(上)

生成式AI开发安全范围矩阵简介 生成式AI目前可以说是当下最热门的技术&#xff0c;吸引各大全球企业的关注&#xff0c;并在全球各行各业中带来浪潮般的编个。随时AI能力的飞跃&#xff0c;大语言模型LLM参数达到千亿级别&#xff0c;它和Transformer神经网络共同驱动了我们工…

Java继承简介

继承的本质&#xff1a;是代码的复用&#xff0c;重复使用已经定义好的方法和域&#xff08;即全局变量&#xff09; 要掌握继承首先要了解Java方法的重载和重写 方法的重载和重写 方法的重载 当前方法名相同&#xff0c;但是参数类型不同&#xff0c;发生重载 类比数学函…

【redis】缓存设计规范

本文是 Redis 键值设计的 14 个核心规范与最佳实践&#xff0c;按重要程度分层说明&#xff1a; 一、通用数据类型选择 这里我们先给出常规的选择路径图。 以下是对每个步骤的分析&#xff1a; 是否需要排序&#xff1f;&#xff1a; zset&#xff08;有序集合&#xff09;用…

Unity抖音云启动测试:如何用cmd命令行启动exe

相关资料&#xff1a;弹幕云启动&#xff08;原“玩法云启动能力”&#xff09;_直播小玩法_抖音开放平台 1&#xff0c;操作方法 在做云启动的时候&#xff0c;接完发现需要命令行模拟云环境测试启动&#xff0c;所以研究了下。 首先进入cmd命令&#xff0c;CD进入对应包的文件…

Android studio怎么创建assets目录

在Android Studio中创建assets文件夹是一个简单的步骤&#xff0c;通常用于存储不需要编译的资源文件&#xff0c;如文本文件、图片、音频等 main文件夹&#xff0c;邮件new->folder-assets folder

第26场蓝桥入门赛

5.扑克较量【算法赛】 - 蓝桥云课 C&#xff1a; #include <iostream> #include <algorithm> using namespace std;int a[100005];int main() {int n,k;cin>>n>>k;for (int i1; i<n; i)cin>>a[i], a[i] % k;sort(a1, a1n);int mx a[1]k-a…

公司配置内网穿透方法笔记

一、目的 公司内部有局域网&#xff0c;局域网上有ftp服务器&#xff0c;有windows桌面服务器&#xff1b; 在内网环境下&#xff0c;是可以访问ftp服务器以及用远程桌面登录windows桌面服务器的&#xff1b; 现在想居家办公时&#xff0c;也能访问到公司内网的ftp服务器和win…

c++:list

1.list的使用 1.1构造 1.2迭代器遍历 &#xff08;1&#xff09;迭代器是算法和容器链接起来的桥梁 容器就是链表&#xff0c;顺序表等数据结构&#xff0c;他们有各自的特点&#xff0c;所以底层结构是不同的。在不用迭代器的前提下&#xff0c;如果我们的算法要作用在容器上…

《Wiki.js知识库部署实践 + CNB Git数据同步方案解析》

一、wiki.js 知识库简介 基本概述 定义 &#xff1a;Wiki.js 是一个开源、现代、轻量且功能强大的 Wiki 应用程序&#xff0c;基于 Node.js 构建&#xff0c;旨在帮助个人和团队轻松创建、管理和共享知识。开源性质 &#xff1a;它遵循 AGPLv3 许可证&#xff0c;任何人都可以…

ip地址是手机号地址还是手机地址

在数字化生活的浪潮中&#xff0c;IP地址、手机号和手机地址这三个概念如影随形&#xff0c;它们各自承载着网络世界的独特功能&#xff0c;却又因名称和功能的相似性而时常被混淆。尤其是“IP地址”这一术语&#xff0c;经常被错误地与手机号地址或手机地址划上等号。本文旨在…

微服务 day01 注册与发现 Nacos OpenFeign

目录 1.认识微服务&#xff1a; 单体架构&#xff1a; 微服务架构&#xff1a; 2.服务注册和发现 1.注册中心&#xff1a; 2.服务注册&#xff1a; 3.服务发现&#xff1a; 发现并调用服务&#xff1a; 方法1&#xff1a; 方法2&#xff1a; 方法3:OpenFeign OpenFeig…

网络安全:挑战、技术与未来发展

&#x1f4dd;个人主页&#x1f339;&#xff1a;一ge科研小菜鸡-CSDN博客 &#x1f339;&#x1f339;期待您的关注 &#x1f339;&#x1f339; 1. 引言 在数字化时代&#xff0c;网络安全已成为全球关注的焦点。随着互联网的普及和信息技术的高速发展&#xff0c;网络攻击的…

PostgreSql-COALESCE函数、NULLIF函数、NVL函数使用

COALESCE函数 COALESCE函数是返回参数中的第一个非null的值&#xff0c;它要求参数中至少有一个是非null的; select coalesce(1,null,2),coalesce(null,2,1),coalesce(null,null,null); NULLIF(ex1,ex2)函数 如果ex1与ex2相等则返回Null&#xff0c;不相等返回第一个表达式的值…