PyTorch Lightning:通过分布式训练扩展深度学习工作流

 

一、介绍

        欢迎来到我们关于 PyTorch Lightning 系列的第二篇文章!在上一篇文章中,我们向您介绍了 PyTorch Lightning,并探讨了它在简化深度学习模型开发方面的主要功能和优势。我们了解了 PyTorch Lightning 如何为组织和构建 PyTorch 代码提供高级抽象,使研究人员和从业者能够更多地关注模型设计和实验,而不是样板代码。

        在本文中,我们将深入研究 PyTorch Lightning,并探索它如何通过分布式训练实现深度学习工作流的扩展。分布式训练对于在海量数据集上训练大型模型至关重要,因为它允许我们利用多个 GPU 或机器的强大功能来加速训练过程。然而,分布式训练往往伴随着一系列挑战和复杂性。

二、安装 Pytorch Lightning & Torchvision

pip install torch torchvision pytorch-lightning 

三、实现

        首先,我们需要从 PyTorch 和 PyTorch Lightning 导入必要的模块:

import torch
from torch import nn
from torch.utils.data import DataLoader
from torchvision.datasets import CIFAR10
from torchvision import transformsimport pytorch_lightning as pl

        接下来,我们使用 PyTorch 的类定义我们的神经网络架构。在这个例子中,我们使用一个简单的卷积神经网络,其中包含两个卷积层和三个全连接层:nn.Module

class Net(pl.LightningModule):def __init__(self):super(Net, self).__init__()self.conv1 = nn.Conv2d(3, 6, 5)self.pool = nn.MaxPool2d(2, 2)self.conv2 = nn.Conv2d(6, 16, 5)self.fc1 = nn.Linear(16 * 5 * 5, 120)self.fc2 = nn.Linear(120, 84)self.fc3 = nn.Linear(84, 10)def forward(self, x):x = self.pool(nn.functional.relu(self.conv1(x)))x = self.pool(nn.functional.relu(self.conv2(x)))x = torch.flatten(x, 1)x = nn.functional.relu(self.fc1(x))x = nn.functional.relu(self.fc2(x))x = self.fc3(x)return x

        然后,我们为 .在该方法中,我们接收一批输入和标签,将它们通过我们的神经网络来获取 logits,计算交叉熵损失,并使用该方法记录训练损失。在该方法中,我们执行与 相同的操作,但不记录损失:LightningModuletraining_stepxyself.logvalidation_steptraining_step

    def training_step(self, batch, batch_idx):x, y = batchlogits = self(x)loss = nn.functional.cross_entropy(logits, y)self.log("train_loss", loss)return lossdef validation_step(self, batch, batch_idx):x, y = batchlogits = self(x)loss = nn.functional.cross_entropy(logits, y)self.log("val_loss", loss)return loss

        我们还在方法中定义了优化器和学习率调度器:configure_optimizers

    def configure_optimizers(self):optimizer = torch.optim.Adam(self.parameters(), lr=0.001)scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=1, gamma=0.1)return [optimizer], [scheduler]

        接下来,我们使用 PyTorch 和 定义数据加载和预处理步骤:DataLoadertransforms

    def prepare_data(self):transform = transforms.Compose([transforms.ToTensor(),transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))])CIFAR10(root='./data', train=True, download=True, transform=transform)CIFAR10(root='./data', train=False, download=True, transform=transform)def train_dataloader(self):transform = transforms.Compose([transforms.ToTensor(),transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))])train_dataset = CIFAR10(root='./data', train=True, download=False, transform=transform)return DataLoader(train_dataset, batch_size=64, shuffle=True, num_workers=8)def val_dataloader(self):transform = transforms.Compose([transforms.ToTensor(),transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))])val_dataset = CIFAR10(root='./data', train=False, download=False, transform=transform)return DataLoader(val_dataset, batch_size=64, shuffle=False, num_workers=8)
  1. prepare_data(self):此函数负责在训练模型之前准备数据。它首先使用该类定义一系列转换。转换包括将数据转换为张量并对其进行规范化。定义转换后,该函数将下载用于训练和测试拆分的 CIFAR10 数据集。数据集将下载到目录,并将指定的转换应用于数据。transforms.Compose'./data'
  2. train_dataloader(self):此函数为训练数据集创建数据加载器。它首先定义与函数中相同的转换。接下来,它为训练拆分创建 CIFAR10 数据集的实例。从目录中加载数据集,并应用指定的转换。最后,使用训练数据集创建一个对象。数据加载程序配置为 64 的批大小,对数据进行随机排序,并使用 8 个工作线程进行数据加载。它返回数据加载器。prepare_data'./data'DataLoader
  3. val_dataloader(self):此函数为验证数据集创建数据加载器。它遵循与函数类似的结构。它首先使用 定义转换,这些转换与前面的函数相同。然后,为验证拆分创建 CIFAR10 数据集的实例。从目录中加载数据集,并应用指定的转换。最后,使用验证数据集创建一个对象。数据加载器配置为 64 的批大小,无需随机处理数据,并使用 8 个工作线程进行数据加载。它返回数据加载器。train_dataloadertransforms.Compose'./data'DataLoader

        该函数将模型作为输入,并对测试数据集执行评估。它首先对测试数据应用转换,将其转换为张量并规范化。然后,它为测试数据集创建数据加载程序。模型将移动到相应的设备(GPU,如果可用)。评估标准设置为交叉熵损失。evaluate_model

def evaluate_model(model):transform = transforms.Compose([transforms.ToTensor(),transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))])test_dataset = CIFAR10(root='./data', train=False, download=True, transform=transform)test_loader = DataLoader(test_dataset, batch_size=64, shuffle=False, num_workers=8)device = torch.device("cuda" if torch.cuda.is_available() else "cpu")model = model.to(device)criterion = nn.CrossEntropyLoss()model.eval()test_loss = 0.0correct = 0total = 0with torch.no_grad():for data in test_loader:inputs, labels = datainputs = inputs.to(device)labels = labels.to(device)outputs = model(inputs)loss = criterion(outputs, labels)test_loss += loss.item()_, predicted = torch.max(outputs.data, 1)total += labels.size(0)correct += (predicted == labels).sum().item()accuracy = 100.0 * correct / totalaverage_loss = test_loss / len(test_loader)print(f"Test Loss: {average_loss:.4f}")print(f"Test Accuracy: {accuracy:.2f}%")

        将模型置于评估模式,并初始化测试损失、正确预测和总数据点的变量。在无梯度上下文中,该函数遍历测试数据加载器,通过模型转发成批的输入,计算损失并累积测试损失。它还计算正确预测的数量和数据点的总数。最后,它计算并打印平均测试损失和测试精度。

        最后,我们实例化我们的模型和来自 PyTorch Lightning,指定用于分布式训练的所需数量的 GPU 或机器:NetTrainer

net = Net()trainer = pl.Trainer(num_nodes=1,  # Change to the number of machines in your distributed setupaccelerator="auto",  # Distributed Data Parallel, Available names are: auto, cpu, cuda, hpu, ipu, mps, tpu.max_epochs=5, devices=1 # Change to the desired number of GPUs or use `None` for CPU training
)trainer.fit(net)evaluate_model(net)
  • num_nodes:它指定分布式设置中的计算机数量。在这种情况下,它设置为 ,表示单台计算机设置。1
  • accelerator:它确定训练的加速器类型。该值允许 PyTorch Lightning 根据硬件和软件环境自动选择适当的加速器。其他可能的值包括 、 和 ,它们对应于特定的硬件加速器。"auto""cpu""cuda""hpu""ipu""mps""tpu"
  • max_epochs:它设置用于训练模型的最大周期数(通过训练数据集的完整遍历)。在本例中,它设置为 。5
  • devices:它指定用于训练的 GPU 数量。将其设置为 表示使用单个 GPU 进行训练。如果要在 CPU 上进行训练,可以将其设置为 。1None

        这些选项允许您控制训练过程的各个方面,例如分布式训练、加速器选择以及用于训练的周期数和设备数。

        设置好所有内容后,我们只需调用对象的方法,传入我们的模型、训练数据加载器和验证数据加载器。fitTrainerNet

四、输出

 

五、结论

        PyTorch Lightning 通过分布式训练简化了扩展深度学习工作流的过程。通过抽象化分布式训练的复杂性,PyTorch Lightning 使我们能够专注于设计和实现我们的深度学习模型,而不必担心低级细节。在本文中,我们演练了一个使用 PyTorch Lightning 进行分布式训练的示例代码实现。通过利用多个GPU或机器的强大功能,我们可以显著减少大型深度学习模型的训练时间。

六、引用

  • PyTorch Lightning: Welcome to ⚡ PyTorch Lightning — PyTorch Lightning 2.1.0.rc0 documentation
  • PyTorch: PyTorch
  • torchvision.datasets.CIFAR10: Datasets — Torchvision 0.15 documentation
  • torch.utils.data.DataLoader: torch.utils.data — PyTorch 2.0 documentation
  • 火炬亚当:Adam — PyTorch 2.0 documentation
  • torch.optim.lr_scheduler。步长:StepLR — PyTorch 2.0 documentation
  • Torch.nn.CrossEntropyLoss: CrossEntropyLoss — PyTorch 2.0 documentation
  • torch.cuda.is_available:torch.cuda — PyTorch 2.0 documentation

阿奈·东格雷

皮托奇

分布式系统

深度学习
皮托奇闪电
计算机视觉

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.rhkb.cn/news/98355.html

如若内容造成侵权/违法违规/事实不符,请联系长河编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

数据结构 - 算法设计的基本要求

1、算法的描述: 自然语言:英语、中文流程图:传统流程图、NS流程图伪代码:类语言 - 类C语言程序代码:Java、C语言 2、算法的特性: 一个算法必须具备以下五个特性: 3、算法设计的要求 正确性可…

分代收集 + 垃圾回收算法

分代假说 1. 弱分代假说(Weak Generational Hypothesis):绝大多数对象都是朝生夕灭的 2. 强分代假说(Strong Generational Hypothesis):熬过越多次垃圾收集过程的对象就越难以消亡 3. 跨代引用假说&…

leetcode 125.验证回文串

⭐️ 题目描述 🌟 leetcode链接:https://leetcode.cn/problems/valid-palindrome/ 思路: 这道题只判断字符串中的字母与数字是否是回文。虽然小写大写字母可以互相转换,但是里面是含有数字字符的,所以先统一&#xff…

爬虫逆向实战(十八)--某得科技登录

一、数据接口分析 主页地址:某得科技 1、抓包 通过抓包可以发现数据接口是AjaxLogin 2、判断是否有加密参数 请求参数是否加密? 查看“载荷”模块可以发现有一个password加密参数和一个__RequestVerificationToken 请求头是否加密? 无…

JavaScript 快速入门手册

本篇文章学习: 菜鸟教程、尚硅谷。 JavaScript 快速入门手册 💯 前言: 本人目前算是一个Java程序员,但是目前环境… ε(ο`*))) 一言难尽啊,blog也好久好久没有更新了,一部分工作原因吧(外包真…

window安裝python2.7.0

官网下载安装 https://www.python.org/downloads/release/python-270/ 选中所有用户,然后点击next 切换安装位置,最好不要选择c盘 点击next 等待安装 安装完成 配置环境变量 将python安装路径添加到系统环境变量 cmd窗口输入python,会打开应用商…

Idea中隐藏指定文件或指定类型文件

Setting ->Editor ->Code Style->File Types → Ignored Files and Folders输入要隐藏的文件名,支持*号通配符回车确认添加

基于51单片机直流电机转速数码管显示控制系统

一、系统方案 本文主要研究了利用MCS-51系列单片机控制PWM信号从而实现对直流电机转速进行控制的方法。本文中采用了三极管组成了PWM信号的驱动系统,并且对PWM信号的原理、产生方法以及如何通过软件编程对PWM信号占空比进行调节,从而控制其输入信号波形等…

LNMP及论坛搭建

目录 安装Nginx服务 1.安装依赖包 2.创建运行用户、组(Nginx 服务程序默认以 nobody 身份运行,建议为其创建专门的用户账号,以便更准确地控制其访问权限) 3.编译安装Nginx 4.优化路径 5.添加 Nginx 系统服务 安装 MySQL 服…

正则表达式在PHP8中的应用案例-PHP8知识详解

正则表达式在php8中有许多应用案例。以下是一些常见的应用场景:如数据验证、数据提取、数据替换、url路由、文本搜索和过滤等。 1、数据验证 使用正则表达式可以对用户输入的数据进行验证,例如验证邮箱地址、手机号码、密码强度等。 下面是一个用正则表…

BDA初级分析——SQL多表连接应用

一、用SQL拼接数据 三个初始数据 问题1:在所有的数据里,销售额最高的产品品类名是什么? 问题2:是否有什么产品是在所观测的时间里没有被购买过的? 拼接数据:JOIN join,加入 作用&#xff1…

uniapp 上传比较大的视频文件就超时

uni.uploadFile,上传超过10兆左右的文件就报错err:uploadFile:fail timeout,超时 解决: 在manifest.json文件中做超时配置 uni.uploadFile({url: this.action,method: "POST",header: {Authorization: uni.getStorage…

【es6】中的Generator

Generator 一、Generator 是什么?1.1 与普通函数写法不一样,有两个不同 二、Generator 使用2.1 书写方法 三、yield语句3.1 yield和return3.2 注意事项3.3 yield*语句3.4 yield*应用 四、next方法4.1参数 总结 一、Generator 是什么? Genera…

[JavaWeb]【一】入门JavaWeb开发总概及HTML、CSS、JavaScript

目录 一 特色 二 收获​编辑 三 什么是web? 四 网站的工作流程 五 web网站的开发模式​编辑 六 web开发课程学习安排 七、初始web前端 八 HTML、CSS 8.1 什么是HTNL\CSS(w3cschool) 8.2 HTML快速入门 8.3 VS Code开发工具 8.3.1 插件 8.3.2 主题(改变颜色&…

【C++】做一个飞机空战小游戏(八)——生成敌方炮弹(rand()和srand()函数应用)

[导读]本系列博文内容链接如下: 【C】做一个飞机空战小游戏(一)——使用getch()函数获得键盘码值 【C】做一个飞机空战小游戏(二)——利用getch()函数实现键盘控制单个字符移动【C】做一个飞机空战小游戏(三)——getch()函数控制任意造型飞机图标移动 【C】做一个飞…

solidwords(5)

我们打算从上面画出总体,再从上面、侧面切除 最后成品

blender 发射体粒子

发射体粒子的基础设置 选择需要添加粒子的物体,点击右侧粒子属性,在属性面板中,点击加号,物体表面会出现很多小点点,点击空格键,粒子会自动运动,像下雨一样; bender 粒子系统分为两…

抖音火山引擎推出免费域名DNS和公共DNS服务

抖音旗下的云计算服务火山引擎最近推出了"TrafficRoute DNS 套件"服务,其中包括两款产品,对软希网来说非常有用。 1.域名DNS: 这是一个用于网站域名的DNS服务,可以加速域名解析速度,从而提升网站的速度。如…

【3Ds Max】挤出命令的简单使用(实现二维变三维)

简介 在3ds Max中,"挤出"(Extrude)是一种常用的建模操作,用于在平面或曲面上创建立体几何形状。以下是使用3ds Max中的挤出命令的基本步骤: 创建基本几何形状: 在3ds Max中创建一个基本的几何形…

自动驾驶——车辆动力学模型

/*lat_controller.cpp*/ namespace apollo { namespace control {using apollo::common::ErrorCode;//故障码 using apollo::common::Status;//状态码 using apollo::common::TrajectoryPoint;//轨迹点 using apollo::common::VehicleStateProvider;//车辆状态信息 using Matri…