【AI基础】pytorch lightning 基础学习

传统pytorch工作流是首先定义模型框架,然后写训练和验证,测试循环代码。训练,验证,测试代码写起来比较繁琐。这里介绍使用pytorch lightning 部署模型,加速模型训练和验证,记录。

准备工作

1 安装pytorch lightning 检查版本

$ conda create -n lightning python=3.9 -y
$ conda activate lightning
import lightning as L
import torchprint("Lightning version:", L.__version__)
print("Torch version:", torch.__version__)
print("CUDA is available:", torch.cuda.is_available())

2 加载基本库函数

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import lightning as L
from torch.utils.data import DataLoader
from torchvision import datasets, transforms
from lightning.pytorch.callbacks import ModelCheckpoint
from lightning.pytorch.loggers.tensorboard import TensorBoardLogger
from lightning.pytorch.callbacks.early_stopping import EarlyStopping

3 设置随机种子(可复现性)

L.seed_everything(1121218)

4 数据集下载和增强变换

这里以CIFAR10数据集为例子,该数据集包含 10 个类的 6 万张 32x32 彩色图像,每个类 6000 张图像。

from torchvision import datasets, transforms# Load CIFAR-10 dataset
train_dataset = datasets.CIFAR10(root="./data", train=True, download=True, transform=transform_train
)
val_dataset = datasets.CIFAR10(root="./data", train=False, download=True, transform=transform_test
)
# Data augmentation and normalization for training
transform_train = transforms.Compose([transforms.RandomCrop(32, padding=4),transforms.RandomHorizontalFlip(),transforms.ToTensor(),transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),],
)
transform_test = transforms.Compose([transforms.ToTensor(), transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))]
)

上面的增强变换包括以下四种基本变换: 

  • 裁剪(需要指定图像大小,在本例中为 32x32)。
  • 水平翻转。
  • 转换为张量数据类型,这是 PyTorch 所必需的。
  • 对图像的每个颜色通道进行归一化处理。

传统pytorch模型训练流

定义一个CNN模型

class CIFAR10CNN(nn.Module):def __init__(self):super(CIFAR10CNN, self).__init__()self.conv1 = nn.Conv2d(3, 32, 3, padding=1)self.conv2 = nn.Conv2d(32, 64, 3, padding=1)self.conv3 = nn.Conv2d(64, 64, 3, padding=1)self.pool = nn.MaxPool2d(2, 2)self.fc1 = nn.Linear(64 * 4 * 4, 512)self.fc2 = nn.Linear(512, 10)def forward(self, x):x = self.pool(torch.relu(self.conv1(x)))x = self.pool(torch.relu(self.conv2(x)))x = self.pool(torch.relu(self.conv3(x)))x = x.view(-1, 64 * 4 * 4)x = torch.relu(self.fc1(x))x = self.fc2(x)return x

编写训练、验证循环代码

  • 需要初始化模型,损失函数和优化器
  • 管理模型和数据在机器上的运行(CPU 与 GPU)
  • 训练步骤:前向传播、损失计算、反向传播和优化
  • 验证步骤:计算准确性和损失
  • tensorboard日志记录,训练损失,准确率,其他相关指标记录等
  • 模型保存
  • # Initialize the model, loss function, and optimizer
    model = CIFAR10CNN().to(device)
    criterion = nn.CrossEntropyLoss()
    optimizer = optim.Adam(model.parameters(), lr=learning_rate)
    scheduler = optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode='min', factor=0.1, patience=5)# TensorBoard setup
    writer = SummaryWriter('runs/cifar10_cnn_experiment')# Training loop
    total_step = len(train_loader)
    for epoch in range(num_epochs):model.train()train_loss = 0.0for i, (images, labels) in enumerate(train_loader):images = images.to(device)labels = labels.to(device)# Forward passoutputs = model(images)loss = criterion(outputs, labels)# Backward and optimizeoptimizer.zero_grad()loss.backward()optimizer.step()train_loss += loss.item()if (i+1) % 100 == 0:print(f'Epoch [{epoch+1}/{num_epochs}], Step [{i+1}/{total_step}], Loss: {loss.item():.4f}')# Calculate average training loss for the epochavg_train_loss = train_loss / len(train_loader)writer.add_scalar('training loss', avg_train_loss, epoch)# Validationmodel.eval()with torch.no_grad():correct = 0total = 0val_loss = 0.0for images, labels in test_loader:images = images.to(device)labels = labels.to(device)outputs = model(images)loss = criterion(outputs, labels)val_loss += loss.item()_, predicted = torch.max(outputs.data, 1)total += labels.size(0)correct += (predicted == labels).sum().item()accuracy = 100 * correct / totalavg_val_loss = val_loss / len(test_loader)print(f'Validation Accuracy: {accuracy:.2f}%')writer.add_scalar('validation loss', avg_val_loss, epoch)writer.add_scalar('validation accuracy', accuracy, epoch)# Learning rate schedulingscheduler.step(avg_val_loss)# Final test
    model.eval()
    with torch.no_grad():correct = 0total = 0for images, labels in test_loader:images = images.to(device)labels = labels.to(device)outputs = model(images)_, predicted = torch.max(outputs.data, 1)total += labels.size(0)correct += (predicted == labels).sum().item()print(f'Test Accuracy: {100 * correct / total:.2f}%')writer.close()# Save the model
    torch.save(model.state_dict(), 'cifar10_cnn.pth')

     在上面的代码示例,有一些需要特别注意繁琐的细节:

    训练和验证模式之间可以手动切换。
    有梯度计算的手动规范。
    使用较差的 SummaryWriter 类进行日志记录。
    有一个学习率调度程序。

Pytorch lightning 工作流

1 使用LightningModule 类定义模型结构

class CIFAR10CNN(L.LightningModule):def __init__(self):super().__init__()self.conv1 = nn.Conv2d(3, 32, 3, padding=1)self.conv2 = nn.Conv2d(32, 64, 3, padding=1)self.conv3 = nn.Conv2d(64, 64, 3, padding=1)self.pool = nn.MaxPool2d(2, 2)self.fc1 = nn.Linear(64 * 4 * 4, 512)self.fc2 = nn.Linear(512, 10)def forward(self, x):x = self.pool(F.relu(self.conv1(x)))x = self.pool(F.relu(self.conv2(x)))x = self.pool(F.relu(self.conv3(x)))x = x.view(-1, 64 * 4 * 4)x = F.relu(self.fc1(x))x = self.fc2(x)return x

唯一的区别是,我们是从LightningModule类继承,而不是从继承nn.Module。是类LightningModule的扩展nn.Module。它将 PyTorch 工作流的训练、验证、测试、预测和优化步骤组合到一个没有循环的单一界面中。 当你开始使用时LightningModule,它被组织成六个部分:

  • 初始化(__init__和setup()方法)
  • 训练循环(training_step()方法)
  • 验证循环(validation_step()方法)
  • 测试循环(test_step()方法)
  • 预测循环(prediction_step()方法)
  • 优化器和 LR 调度程序(configure_optimizers())

我们已经看到了初始化部分。让我们继续进行训练步骤。

2 编写训练过程代码

在模型类中,复写training_step()方法

# Add the method inside the class
def training_step(self, batch, batch_idx):x, y = batchy_hat = self(x)loss = F.cross_entropy(y_hat, y)self.log('train_loss', loss)return loss

此方法将整个训练循环压缩为几行代码。首先,从数据batch中读取模型输入和模型输出。然后,我们运行前向传递self(x)并计算损失。然后,我们只需使用内置的 Lightning 记录器函数记录训练损失即可self.log()。

还可以在此方法中记录其他指标,例如训练准确性:

def training_step(self, batch, batch_idx):x, y = batchy_hat = self(x)loss = F.cross_entropy(y_hat, y)acc = (y_hat.argmax(1) == y).float().mean()self.log("train_loss", loss)self.log("train_acc", acc)return loss

log()方法可以自动计算每个epoch的模型的各个指标,比如准确性,F1-score等等。该方法里面有一些参数是可以额外设置的,比如记录每个batch和epoch下的模型指标,模型训练和验证时创建进度条,还有将模型的各个指标输出到本地文件中。

# Log the loss at each training step and epoch, create a progress bar
self.log("train_loss", loss, on_step=True, on_epoch=True, prog_bar=True, logger=True)

3 编写验证和测试步骤代码

def validation_step(self, batch, batch_idx):x, y = batchy_hat = self(x)loss = F.cross_entropy(y_hat, y)acc = (y_hat.argmax(1) == y).float().mean()self.log('val_loss', loss)self.log('val_acc', acc)
def test_step(self, batch, batch_idx):x, y = batchy_hat = self(x)loss = F.cross_entropy(y_hat, y)acc = (y_hat.argmax(1) == y).float().mean()self.log('test_loss', loss)self.log('test_acc', acc)

唯一的区别是不需要返回计算出的指标。Lightning模块会自动将正确的数据加载器分配给验证和测试步骤,并在后台创建循环。

尽管validation_step()和test_step()看起来相同,但它们有一个关键的区别:

  • validation_step()在训练期间,直接参与模型验证。
  • test_step()在测试期间,需要调用训练器对象的.test()方法,才能执行此操作。

4 配置优化器和优化器scheduler程序

为了定义优化器和学习率调度器,需要重写configure_optimizers()类的方法。

def configure_optimizers(self):optimizer = torch.optim.Adam(self.parameters(), lr=1e-3)scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode="min", factor=0.1, patience=5)return {"optimizer": optimizer,"lr_scheduler": {"scheduler": scheduler,"monitor": "val_loss",},}

上面,创建了一个Adam优化器,传入超参数和学习率。还定义了一个ReduceLROnPlateau调度函数,用于在验证损失稳定时降低学习率。返回对象字典是最灵活的选项,因为它允许定义需要额外参数的scheduler。

https://lightning.ai/docs/pytorch/stable/common/lightning_module.html#configure-optimizers

5 定义callbacks和记录器

模型类和附带的训练,验证,优化器,学习率调度器和指标计算都已经完成,模型可以实现前向和反向传播,模型更新,验证,记录模型的各个指标。此时,还需要定义一系列的callbacks和记录器类型。这里定义一个checkpoint callback和记录器。

checkpoint_callback = ModelCheckpoint(dirpath="checkpoints",monitor="val_loss",filename="cifar10-{epoch:02d}-{val_loss:.2f}-{val_acc:.2f}",save_top_k=3,mode="min",
)

ModelCheckpoint是一个强大的回调,用于在监控给定指标的同时定期保存模型。每个模型检查点都记录到dirpath中。

定义一个tensorboardlogger() 记录方法

logger = TensorBoardLogger(save_dir="lightning_logs", name="cifar10_cnn")

定义一个early_stopping callback

early_stopping = EarlyStopping(monitor="val_loss", patience=5, mode="min", verbose=False)

6 创建一个trainer类

在将模型LightningModule类和callback, 记录器全部定义完以后,就可以定义一个Trainer 类来实现模型的数据读取,自动训练,验证,模型自动保存,比较简洁。可以定义最大epoch数,使用gpu训练和gpu个数,记录器,callback,训练精度,训练数据比例(默认100%),验证数据比例(默认100%),多少个epoch 模型做一次验证,多少个epoch后记录一次模型指标,记录和模型地址,单gpu训练还是分布式训练。

# Initialize the Trainer
trainer = L.Trainer(max_epochs=50,callbacks=[checkpoint_callback, early_stopping],logger=logger,accelerator="gpu" if torch.cuda.is_available() else "cpu",devices="auto",
)
GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
HPU available: False, using: 0 HPUs

7 训练和测试模型

# Train and test the modeltrainer.fit(model, train_loader, test_loader)trainer.test(model, test_loader)

8 pytorch lightning 训练模型的基本流程总结

  •   创建应用转换的训练、验证和测试数据加载器。
  • 将代码组织到一个LightningModule类中:
  • 定义初始化。
  • 定义训练、验证和(可选)测试步骤。
  • 定义优化器和学习率调度器。
  • 定义回调和记录器。
  • 创建一个训练类trainer
  • 初始化模型类。
  • 拟合并测试模型。  

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.rhkb.cn/news/433920.html

如若内容造成侵权/违法违规/事实不符,请联系长河编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

基于Hive和Hadoop的保险分析系统

本项目是一个基于大数据技术的保险分析系统,旨在为用户提供全面的汽车保险信息和深入的保险价格分析。系统采用 Hadoop 平台进行大规模数据存储和处理,利用 MapReduce 进行数据分析和处理,通过 Sqoop 实现数据的导入导出,以 Spark…

使用 Git 帮助文档

聊聊如何更好地查阅官方文档。 ‍ git help 学习某个工具&#xff0c;官方文档是少不了的&#xff0c;也是最权威的。我们可以使用 git help 来查看帮助&#xff0c;该命令会列举出常用的命令和介绍&#xff1a; > git help usage: git [--version] [--help] [-C <pa…

如何利用 Kafka,实时挖掘企业数据的价值?

首先&#xff0c;问读者老爷们一个简单的问题&#xff0c;如果你需要为你的数据选择一个同时具备高吞吐 、数据持久化、可扩展的数据传递系统&#xff0c;你会选择什么样的工具或架构呢&#xff1f; 答案非常显而易见&#xff0c;那就是 Kafka&#xff0c;不妨再次套用一个被反…

关于Chrome浏览器F12调试,显示未连接到互联网的问题

情况说明 最近笔者更新下电脑的Chrome浏览器&#xff0c;在调试前端代码的时候&#xff0c;遇到下面一个情况&#xff1a; 发现打开调试面板后&#xff0c;页面上显示未连接到互联网&#xff0c;但实际电脑网络是没有问题的&#xff0c;关闭调试面板后&#xff0c;网页又能正…

基于大数据的亚健康人群数据分析及可视化系统

作者&#xff1a;计算机学姐 开发技术&#xff1a;SpringBoot、SSM、Vue、MySQL、JSP、ElementUI、Python、小程序等&#xff0c;“文末源码”。 专栏推荐&#xff1a;前后端分离项目源码、SpringBoot项目源码、Vue项目源码、SSM项目源码 精品专栏&#xff1a;Java精选实战项目…

asp.net core grpc快速入门

环境 .net 8 vs2022 创建 gRPC 服务器 一定要勾选Https 安装Nuget包 <PackageReference Include"Google.Protobuf" Version"3.28.2" /> <PackageReference Include"Grpc.AspNetCore" Version"2.66.0" /> <PackageR…

通过 OBD Demo 体验 OceanBase 4.3 社区版

本文作者&#xff1a;马顺华 引言 OceanBase 4.3 是一个专为实时分析 AP 业务设计的重大更新版本。它基于LSM-Tree架构&#xff0c;引入了列存引擎&#xff0c;实现了行存与列存数据存储的无缝整合。这一版本不仅显著提升了AP场景的查询性能&#xff0c;同时也确保了TP业务场景…

看Threejs好玩示例,学习创新与技术(React-three-fiber)

什么&#xff0c;竟有人把ThreeJS和React绑定在一起&#xff0c;混着用&#xff1f; 1、VUE劫持问题 暂先把今天的问题先放一边&#xff0c;先简单回顾下vue劫持的情况。vue会把data里面的数据自动转换为属性&#xff0c;方便界面与数据交互。这本身是没有任何问题&#xff0…

人工智能 实验1 Python语法

我发现了有些人喜欢静静看博客不聊天呐&#xff0c; 但是ta会点赞。 这样的人呢帅气低调有内涵&#xff0c; 美丽大方很优雅。 说的就是你&#xff0c; 不用再怀疑哦 实验1 Python语言基础一 【实验目的】掌握Python及其集成开发环境的下载安装及其简单应用 【实验内容…

vue3中< keep-alive >页面实现缓存及遇到的问题

vue3中< keep-alive >页面实现缓存及遇到的问题 实现原理&#xff1a;keep-alive 是 Vue 的内置组件&#xff0c;当它包裹动态组件时&#xff0c;会缓存不活动的组件实例&#xff0c;而不是销毁它们。实现不同路由是否缓存只需要设置对应路由参数keepAlive为true&#xf…

【周末推荐】替换SwitchyOmega的Chrome浏览器插件

SwitchyOmega插件在我们这个圈子里应该无人不知无人不晓了吧&#xff0c;最近有很多朋友反馈自己的SwitchyOmega不工作了&#xff0c;今天我们将聊聊为什么SwitchyOmega不工作了&#xff0c;并推荐2款实用的Chrome浏览器插件解决这个问题。 为什么要替换SwitchyOmega&#xff…

【有啥问啥】深度理解主动学习:机器学习的高效策略

深度理解主动学习&#xff1a;机器学习的高效策略 在大数据时代&#xff0c;数据量的爆炸性增长与有限的标注资源之间的矛盾日益凸显。如何高效地利用标注资源来训练高质量的模型&#xff0c;成为了机器学习领域亟待解决的问题。主动学习&#xff08;Active Learning, AL&…

Vmware VC登录报错:Vmware报错 HTTP状态 500 - 内部服务器错误

问题现象&#xff1a; 登录Vmware VC系统报错&#xff1a;Vmware报错 HTTP状态 500 - 内部服务器错误、 然后登录管理服务&#xff08;访问端口&#xff1a;5480&#xff09;重启一下异常服务&#xff0c;结果提示证书过期。 初步判断VC SSL证书到期 判定方法&#xff1a; 1…

基于微信小程序爱心领养小程序设计与实现(源码+定制+开发)

博主介绍&#xff1a; ✌我是阿龙&#xff0c;一名专注于Java技术领域的程序员&#xff0c;全网拥有10W粉丝。作为CSDN特邀作者、博客专家、新星计划导师&#xff0c;我在计算机毕业设计开发方面积累了丰富的经验。同时&#xff0c;我也是掘金、华为云、阿里云、InfoQ等平台…

C++编程基础:内联函数、auto关键字、基于范围的for循环和nullptr

内联函数 概念 以inline修饰的函数叫做内联函数,编译时C++编译器会在调用内联函数的地方展开,没有函数调用建立栈帧的开销,内联函数提升程序运行的效率。 如果在函数前增加inline关键字将其改成内联函数,在编译期间编译器会用函数体替换函数的调用。 特性 1.我们可以这…

深入浅出MySQL事务处理:从基础概念到ACID特性及并发控制

1、什么是事务 在实际的业务开发中&#xff0c;有些业务操作要多次访问数据库。一个业务要发送多条SQL语句给数据库执行。需要将多次访问数据库的操作视为一个整体来执行&#xff0c;要么所有的SQL语句全部执行成功。如果其中有一条SQL语句失败&#xff0c;就进行事务的回滚&a…

第五部分:5---三张信号表,信号表的系统调用

目录 信号的递达、未决、阻塞&#xff1a; 进程维护的三张信号表&#xff1a; 普通信号与实时信号的记录&#xff1a; 信号结构的系统调用&#xff1a; bolck表的系统调用&#xff1a; 实例&#xff1a;设置屏蔽信号集中的所有信号都频闭 pending表读取&#xff1a; 信号…

html TAB切换按钮变色、自动生成table

<!DOCTYPE html> <head> <meta charset"UTF-8"> <title>Dynamic Tabs with Table Data</title> <style> /* 简单的样式 */ .tab-content { display: none; border: 1px solid #ccc; padding: 1px; marg…

OFDM通信系统发射端需要做ifftshift的原因分析

对频率为15Hz的正弦波信号进行FFT分析&#xff0c;并且直接画图&#xff0c;matlab代码如下&#xff1a; fs 100; % sampling frequency t 0:(1/fs):(10-1/fs); % time vector S cos(2*pi*15*t); n length(S); X fft(S); f (0:n-1)*(fs/n); %frequenc…

Django 数据库配置以及字段设置详解

配置PostGre 要在 Django 中配置连接 PostgreSQL 数据库&#xff0c;并创建一个包含“使用人”和“车牌号”等字段的 Car 表 1. 配置 PostgreSQL 数据库连接 首先&#xff0c;在 Django 项目的 settings.py 中配置 PostgreSQL 连接。 修改 settings.py 文件&#xff1a; …