PyTorch Lightning Trainer介绍

PyTorch Lightning 的 Trainer 是框架的核心类,负责自动化训练流程、分布式训练、日志记录、模型保存等复杂操作。通过配置参数即可快速实现高效训练,无需手动编写循环代码。以下是详细介绍和使用示例:

Trainer 的核心功能

  1. 自动化训练循环
    自动处理 training_stepvalidation_steptest_step 的调用,无需手动编写 for epoch in epochs 循环。

  2. 硬件加速支持
    支持 CPU/GPU/TPU、多卡训练(DDP、DeepSpeed)、混合精度训练等。

  3. 训练控制
    控制训练轮数 (max_epochs)、批次大小 (batch_size)、梯度裁剪 (gradient_clip_val) 等。

  4. 日志与监控
    集成 TensorBoard、W&B、MLFlow 等日志工具,监控损失、准确率等指标。

  5. 回调机制
    通过回调函数(如 EarlyStoppingModelCheckpoint)实现早停、模型保存等扩展功能。

Trainer 的常用参数

from pytorch_lightning import Trainertrainer = Trainer(# 基础配置max_epochs=10,            # 最大训练轮数accelerator="auto",       # 自动选择设备 (CPU/GPU/TPU)devices="auto",           # 使用所有可用设备(如多 GPU)precision="16-mixed",     # 混合精度训练(FP16)# 日志与调试logger=True,              # 默认使用 TensorBoardlog_every_n_steps=10,     # 每 10 个批次记录一次日志fast_dev_run=False,       # 快速运行一个批次(调试模式)# 回调函数callbacks=[pl.callbacks.EarlyStopping(monitor="val_loss", patience=3),pl.callbacks.ModelCheckpoint(monitor="val_loss", save_top_k=2)],# 分布式训练strategy="ddp",           # 分布式数据并行策略(多 GPU)num_nodes=1,              # 节点数量(多机器训练)
)

使用示例代码

步骤 1:定义 LightningModule
import torch
import torch.nn as nn
import torch.nn.functional as F
import pytorch_lightning as plclass LitModel(pl.LightningModule):def __init__(self):super().__init__()self.layer1 = nn.Linear(28*28, 128)self.layer2 = nn.Linear(128, 10)def forward(self, x):x = x.view(x.size(0), -1)  # 展平输入x = F.relu(self.layer1(x))x = self.layer2(x)return xdef training_step(self, batch, batch_idx):x, y = batchy_hat = self(x)loss = F.cross_entropy(y_hat, y)self.log("train_loss", loss)  # 自动记录日志return lossdef validation_step(self, batch, batch_idx):x, y = batchy_hat = self(x)loss = F.cross_entropy(y_hat, y)self.log("val_loss", loss)     # 自动记录验证损失def configure_optimizers(self):return torch.optim.Adam(self.parameters(), lr=0.001)
步骤 2:定义 DataModule
from torch.utils.data import DataLoader, random_split
from torchvision.datasets import MNIST
from torchvision.transforms import ToTensorclass MNISTDataModule(pl.LightningDataModule):def __init__(self, batch_size=32):super().__init__()self.batch_size = batch_sizedef prepare_data(self):MNIST(root="data", download=True)def setup(self, stage=None):full_dataset = MNIST(root="data", train=True, transform=ToTensor())self.train_data, self.val_data = random_split(full_dataset, [55000, 5000])def train_dataloader(self):return DataLoader(self.train_data, batch_size=self.batch_size, shuffle=True)def val_dataloader(self):return DataLoader(self.val_data, batch_size=self.batch_size)dm = MNISTDataModule(batch_size=32)

步骤 3:启动训练

model = LitModel()
trainer = Trainer(max_epochs=10,accelerator="auto",devices="auto",logger=True,callbacks=[pl.callbacks.ModelCheckpoint(monitor="val_loss")]
)# 开始训练与验证
trainer.fit(model, datamodule=dm)# 测试(可选)
trainer.test(model, datamodule=dm)

关键功能演示

1. 多 GPU 训练
# 使用 4 个 GPU 训练
trainer = Trainer(devices=4, strategy="ddp")
2. 混合精度训练

# 使用 FP16 混合精度
trainer = Trainer(precision="16-mixed")
3. 早停与模型保存
callbacks = [pl.callbacks.EarlyStopping(monitor="val_loss", patience=3),pl.callbacks.ModelCheckpoint(dirpath="checkpoints/",filename="best-model-{epoch:02d}-{val_loss:.2f}",save_top_k=2,monitor="val_loss")
]
trainer = Trainer(callbacks=callbacks)
4. 调试模式
# 快速验证代码正确性(仅运行一个批次)
trainer = Trainer(fast_dev_run=True)

常见问题

如何恢复训练?
使用 resume_from_checkpoint 参数:

trainer = Trainer(resume_from_checkpoint="path/to/checkpoint.ckpt")

如何限制训练时间?

trainer = Trainer(max_time="00:02:00")  # 最多训练 2 分钟

如何自定义学习率调度器?
在 自定义的 LightningDataModule继承类的 configure_optimizers 方法中返回优化器和调度器:

def configure_optimizers(self):optimizer = Adam(self.parameters())scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=1)return [optimizer], [scheduler]

总结

通过 Trainer,PyTorch Lightning 将训练流程的复杂性封装在几行配置中,开发者只需关注模型逻辑和数据加载。其灵活的参数和回调机制能够覆盖从实验到生产的全流程需求。

参考:

https://lightning.ai/docs/pytorch/stable/common/trainer.html

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.rhkb.cn/news/16722.html

如若内容造成侵权/违法违规/事实不符,请联系长河编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

逻辑回归不能解决非线性问题,而svm可以解决

逻辑回归和支持向量机(SVM)是两种常用的分类算法,它们在处理数据时有一些不同的特点,特别是在面对非线性问题时。 1. 逻辑回归 逻辑回归本质上是一个线性分类模型。它的目的是寻找一个最适合数据的直线(或超平面&…

23页PDF | 国标《GB/T 44109-2024 信息技术 大数据 数据治理实施指南 》发布

一、前言 《信息技术 大数据 数据治理实施指南》是中国国家标准化管理委员会发布的关于大数据环境下数据治理实施的指导性文件,旨在为组织开展数据治理工作提供系统性的方法和框架。报告详细阐述了数据治理的实施过程,包括规划、执行、评价和改进四个阶…

ESM3(1)-介绍:用语言模型模拟5亿年的进化历程

超过30亿年的进化在天然蛋白质空间中编码形成了一幅生物学图景。在此,作者证明在进化数据上进行大规模训练的语言模型,能够生成与已知蛋白质差异巨大的功能性蛋白质,并推出了ESM3,这是一款前沿的多模态生成式语言模型,…

在大型语言模型(LLM)框架内Transformer架构与混合专家(MoE)策略的概念整合

文章目录 传统的神经网络框架存在的问题一. Transformer架构综述1.1 transformer的输入1.1.1 词向量1.1.2 位置编码(Positional Encoding)1.1.3 编码器与解码器结构1.1.4 多头自注意力机制 二.Transformer分步详解2.1 传统词向量存在的问题2.2 详解编解码…

【黑马点评】 使用RabbitMQ实现消息队列——3.批量获取1k个用户token,使用jmeter压力测试

【黑马点评】 使用RabbitMQ实现消息队列——3.批量获取用户token,使用jmeter压力测试 3.1 需求3.2 实现3.2.1 环境配置3.2.2 修改登录接口UserController和实现类3.2.3 测试类 3.3 使用jmeter进行测试3.4 测试结果3.5 将用户登录逻辑修改回去3.6 批量删除生成的用户…

【安全靶场】信息收集靶场

靶场:https://app.hackinghub.io/hubs/prison-hack 信息收集 子域名收集 1.subfinder files.jabprisons.com staging.jabprisons.com cobrowse.jabprisons.com a1.top.jabprisons.com cf1.jabprisons.com va.cobrowse.jabprisons.com vs.jabprisons.com c…

springboot239-springboot在线医疗问答平台(源码+论文+PPT+部署讲解等)

💕💕作者: 爱笑学姐 💕💕个人简介:十年Java,Python美女程序员一枚,精通计算机专业前后端各类框架。 💕💕各类成品Java毕设 。javaweb,ssm&#xf…

(一)获取数据和读取数据

获取公开数据 下载、爬虫、API 一些公开数据集网站: 爬虫: 发送请求获取网页源代码——解析网页源代码内容,提取数据 通过公开API获取: API定义了两个程序之间的服务合约,即双方是如何使用请求和响应来进行通讯的…

在MacBook Air上本地部署大模型deepseek指南

随着大模型技术的兴起,越来越多的人开始关注如何在本地部署这些强大的AI模型。如果你也想体验大模型的魅力,那么这篇文章将指导你如何在你的MacBook Air上本地部署大模型. 工具准备 为了实现本地部署,你需要以下工具: Ollama&a…

Windows中使用Docker安装Anythingllm,基于deepseek构建自己的本地知识库问答大模型,可局域网内多用户访问、离线运行

文章目录 Windows中使用Docker安装Anythingllm,基于deepseek构建自己的知识库问答大模型1. 安装 Docker Desktop2. 使用Docker拉取Anythingllm镜像2. 设置 STORAGE_LOCATION 路径3. 创建存储目录和 .env 文件.env 文件的作用关键配置项 4. 运行 Docker 命令docker r…

git学习【个人记录b站尚硅谷】

git学习 Git基本命令操作设置用户签名初始化本地库添加文件从工作区到暂存区将文件从暂存区添加到本地库修改文件重新提交 Git分支Github操作创建远程库上传到远程库克隆到本地文件夹拉取远程库最新版本到本地 总结 Git基本命令操作 设置用户签名 git config --global user.n…

【R语言】t检验

一、基本介绍 t检验(t-test)是用于比较两个样本均值是否存在显著差异的一种统计方法。 t.test()函数的调用格式: t.test(x, yNULL, alternativec("two.sided", "less", "greater"), mu0, pairFALSE, var.eq…

TDengine 产品由哪些组件构成

目 录 背景产品生态taosdtaosctaosAdaptertaosKeepertaosExplorertaosXtaosX Agent应用程序或第三方工具 背景 了解一个产品,最好从了解产品包括哪些内容开始,我这里整理了一份儿 TDegnine 产品包括有哪些组件,每个组件作用是什么的说明&a…

实现限制同一个账号最多只能在3个客户端(有电脑、手机等)登录(附关键源码)

如上图,我的百度网盘已登录设备列表,有一个手机,2个windows客户端。手机设备有型号、最后登录时间、IP等。windows客户端信息有最后登录时间、操作系统类型、IP地址等。这些具体是如何实现的?下面分别给出android APP中采集手机信…

使用 Docker 安装 Open WebUI 并集成 Ollama 的 DeepSeek 模型

文章目录 使用 Docker 安装 Open WebUI 并集成 Ollama 的 DeepSeek 模型前提条件1. 安装ollama2. 拉取deepseek的模型3. Open-WebUI 说明4. 启动容器文档的方法如下优化命令(可选)1. 增加了健康检查机制(--health-cmd)2. 使 WebUI…

Untiy3d 铰链、弹簧,特殊的物理关节

(一)铰链组件 1.创建一个立方体和角色胶囊 2.给角色胶囊挂在控制脚本和刚体 using System.Collections; using System.Collections.Generic; using UnityEngine;public class plyer : MonoBehaviour {// Start is called once before the first execut…

【NLP 21、实践 ③ 全切分函数切分句子】

当无数个自己离去,我便日益坦然 —— 25.2.9 一、jieba分词器 Jieba 是一款优秀的 Python 中文分词库,它支持多种分词模式,其中全切分方式会将句子中所有可能的词语都扫描出来。 1.原理 全切分方式会找出句子中所有可能的词语组合。对于一…

团结引擎 OpenHarmony 平台全面支持 UAAL,实现引擎能力嵌入原生应用

团结引擎1.4版本已于近日正式发布!在这一版本中,OpenHarmony 平台迎来了一个具有里程碑意义的更新:全面支持 Used as a Library(UAAL)。UAAL 这一技术方案,具有将引擎嵌入原生应用的独特能力,其…

自己部署DeepSeek 助力 Vue 开发:打造丝滑的标签页(Tabs)

前言:哈喽,大家好,今天给大家分享一篇文章!并提供具体代码帮助大家深入理解,彻底掌握!创作不易,如果能帮助到大家或者给大家一些灵感和启发,欢迎收藏关注哦 💕 目录 自己…

DeepSeek渣机部署编程用的模型,边缘设备部署模型

DeepSeek渣机部署编程用的模型,边缘设备部署模型 文章目录 DeepSeek渣机部署编程用的模型,边缘设备部署模型前言一、python代码二、构建一个简单的前端来接入接口2.读入数据 总结 前言 也许大家伙都想完成一些部署DeepSeek的东西,不过部署并…