使用亚马逊针对 PyTorch 和 MinIO 的 S3 连接器进行模型检查点处理

2023 年 11 月,Amazon 宣布推出适用于 PyTorch 的 S3 连接器。适用于 PyTorch 的 Amazon S3 连接器提供了专为 S3 对象存储构建的 PyTorch 数据集基元(数据集和数据加载器)的实现。它支持用于随机数据访问模式的地图样式数据集和用于流式处理顺序数据访问模式的可迭代样式数据集。适用于 PyTorch 的 S3 连接器还包括一个检查点接口,用于将检查点直接保存和加载到 S3 存储桶,而无需先保存到本地存储。如果您还没有准备好采用正式的 MLOps 工具,而只需要一种简单的方法来保存模型,那么这是一个非常好的选择。这就是我将在这篇文章中介绍的内容。S3 连接器的文档仅展示了如何将其与 Amazon S3 一起使用 - 我将在此处向您展示如何将其用于 MinIO。让我们先执行此作 - 让我们设置 S3 连接器,以便它从 MinIO 写入和读取检查点。

将 S3 连接器连接到 MinIO

将 S3 连接器连接到 MinIO 就像设置环境变量一样简单。之后,一切都会顺利进行。诀窍是以正确的方式设置正确的环境变量。

本文的代码下载使用 .env 文件来设置环境变量,如下所示。此文件还显示了我用于使用 MinIO Python SDK 直接连接到 MinIO 的环境变量。请注意,AWS_ENDPOINT_URL 需要 protocol,而 MinIO 变量不需要。

AWS_ACCESS_KEY_ID=admin
AWS_ENDPOINT_URL=http://172.31.128.1:9000
AWS_REGION=us-east-1
AWS_SECRET_ACCESS_KEY=password
MINIO_ENDPOINT=172.31.128.1:9000
MINIO_ACCESS_KEY=admin
MINIO_SECRET_KEY=password
MINIO_SECURE=false

写入和读取 Checkpoint

我从一个简单的例子开始。下面的代码段创建了一个 S3Checkpointing 对象,并使用其 writer() 方法将模型的状态字典发送到 MinIO。我还使用 Torchvision 创建了一个 ResNet-18(18 层)模型,用于演示目的。

import osfrom dotenv import load_dotenv
from s3torchconnector import S3Checkpoint
import torchvision
import torch# Load the credentials and connection information.
load_dotenv()model = torchvision.models.resnet18()
model_name = 'resnet18.pth'
bucket_name = 'checkpoints'checkpoint_uri = f's3://{bucket_name}/{model_name}'
s3_checkpoint = S3Checkpoint(os.environ['AWS_REGION'])# Save checkpoint to S3
with s3_checkpoint.writer(checkpoint_uri) as writer:torch.save(model.state_dict(), writer)

请注意,该区域有一个强制参数。从技术上讲,访问 MinIO 时没有必要,但如果为此变量选择错误的值,内部检查可能会失败。此外,您的存储桶必须存在,上述代码才能正常工作。如果 writer() 方法不存在,它将引发错误。不幸的是,无论出了什么问题,writer() 方法都会引发相同的错误。例如,如果您的存储桶不存在,您将收到如下所示的错误。如果 writer() 方法不喜欢您指定的区域,您也会收到相同的错误。希望未来的版本将提供更具描述性的错误消息。

S3Exception: Client error: Request canceled

将以前保存的模型读取到内存中的代码类似于写入 MinIO。使用 reader() 方法,而不是 writer() 方法。下面的代码显示了如何执行此作。

import osfrom dotenv import load_dotenv
from s3torchconnector import S3Checkpoint
import torchvision
import torch# Load the credentials and connection information.
load_dotenv()model_name = 'resnet18.pth'
bucket_name = 'checkpoints'checkpoint_uri = f's3://{bucket_name}/{model_name}'
s3_checkpoint = S3Checkpoint(os.environ['AWS_REGION'])# Load checkpoint from S3
with s3_checkpoint.reader(checkpoint_uri) as reader:state_dict = torch.load(reader, weights_only=True)model.load_state_dict(state_dict)

接下来,让我们看看模型训练期间检查点的一些实际注意事项。

在模型训练期间编写检查点

如果您使用大型数据集训练大型模型,请考虑在每个 epoch 后设置检查点。这些训练运行可能需要数小时甚至数天才能完成,因此在发生故障时能够从上次中断的地方继续非常重要。此外,我们假设您必须使用共享存储桶来保存来自多个团队的多个模型的模型检查点。MLOps 约定是按试验组织训练运行。例如,如果您正在研究具有四个隐藏层的架构,那么在寻找各种超参数的最佳值时,您将使用此架构进行多次运行。如果同事使用五层体系结构运行实验,则需要一种方法来防止名称冲突。这可以通过模拟如下所示的层次结构的对象路径来解决。

最后,为了确保您在每个 epoch 中获得新版本的模型,请确保在用于保存检查点的存储桶上启用版本控制。下面的训练函数使用上述路径结构在每个 epoch 后对模型进行检查点作。(可以在本文的代码下载中找到此训练函数的更强大版本。

def train_model(model: nn.Module, loader: DataLoader, training_parameters: Dict[str, Any]) -> List[float]:if training_parameters['checkpoint']:checkpoint_uri = f's3://{training_parameters["checkpoint_bucket"]} \/{training_parameters["project_name"]} \/{training_parameters["experiment_name"]} \/{training_parameters["run_id"]} \/{training_parameters["model_name"]}'s3_checkpoint = S3Checkpoint(region=os.environ['AWS_REGION'])loss_func = nn.NLLLoss()optimizer = optim.SGD(model.parameters(), lr=training_parameters['lr'], momentum=training_parameters['momentum'])# Epoch loopcompute_time_by_epoch = []for epoch in range(training_parameters['epochs']):# Batch loopfor images, labels in loader:# Flatten MNIST images into a 784 long vector.# shape = [32, 784]images = images.view(images.shape[0], -1)# Training passoptimizer.zero_grad()output = model(images)loss = loss_func(output, labels)loss.backward()optimizer.step()# Save checkpoint to S3if training_parameters['checkpoint']:with s3_checkpoint.writer(checkpoint_uri) as writer:torch.save(model.state_dict(), writer)

请注意,模型名称不包含指示纪元的子字符串。如前所述,我使用了启用了版本控制的存储桶 - 换句话说,版本号表示纪元。这种方法的优点在于,您无需知道引用最新模型的 epoch 数。在上述训练代码运行了 10 个 epoch 后,我的检查点存储桶如下面的屏幕截图所示。

此培训演示可被视为 DIY MLOps 解决方案的开始。

结论

适用于 PyTorch 的 S3 连接器易于使用,工程师在使用时编写的数据访问代码行数会更少。在本文中,我展示了如何将其配置为使用环境变量连接到 MinIO。配置完成后,工程师可以分别使用 writer() 和 reader() 方法将检查点写入和读取 MinIO。在本文中,我展示了如何配置 S3 Connect 以连接到 MinIO。我还演示了 S3Checkpoint 类及其 reader() 和 writer() 方法的基本用法。最后,我展示了一种在实际训练函数中针对启用了版本的检查点存储桶使用这些检查点功能的方法。在这篇文章中,我没有介绍在分布式训练期间检查点所需的技术和工具,这可能有点棘手。分布式训练期间的检查点设置会有所不同,具体取决于您使用的框架(PyTorch、Ray 或 DeepSpeed 等)和您正在进行的分布式训练类型:数据并行(每个工作程序都有模型的完整副本)或模型并行(每个工作程序只有一个模型分片)。在以后的文章中,我将介绍其中的一些技术。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.rhkb.cn/news/16257.html

如若内容造成侵权/违法违规/事实不符,请联系长河编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

基于 SpringBoot 和 Vue 的智能腰带健康监测数据可视化平台开发(文末联系,整套资料提供)

基于 SpringBoot 和 Vue 的智能腰带健康监测数据可视化平台开发 一、系统介绍 随着人们生活水平的提高和健康意识的增强,智能健康监测设备越来越受到关注。智能腰带作为一种新型的健康监测设备,能够实时采集用户的腰部健康数据,如姿势、运动…

【cocos creator】拖拽排序列表

DEMO下载 GameCtrl.ts import ItemCtrl from "./ItemCtrl";const { ccclass, property } cc._decorator;ccclass export default class GameCtrl extends cc.Component {property(cc.Node)content: cc.Node null;property(cc.Node)prefab: cc.Node null;arr []…

Vision Transformer:打破CNN垄断,全局注意力机制重塑计算机视觉范式

目录 引言 一、ViT模型的起源和历史 二、什么是ViT? 图像处理流程 图像切分 展平与线性映射 位置编码 Transformer编码器 分类头(Classification Head) 自注意力机制 注意力图 三、Coovally AI模型训练与应用平台 四、ViT与图像…

国产编辑器EverEdit - 编辑辅助功能介绍

1 编辑辅助功能 1.1 各编辑辅助选项说明 1.1.1 行号 打开该选项时,在编辑器主窗口左侧显示行号,如下图所示: 1.1.2 文档地图 打开该选项时,在编辑器主窗口右侧靠近垂直滚动条的地方显示代码的缩略图,如下图所示&…

Spring AI 介绍

文章来源:AI 概念 (AI Concepts) _ Spring AI1.0.0-SNAPSHOT中文文档(官方文档中文翻译)|Spring 教程 —— CADN开发者文档中心 本节介绍 Spring AI 使用的核心概念。我们建议仔细阅读它,以了解 Spring AI 是如何实现的。 模型 AI 模型是旨在处理和生成…

Spring MVC 拦截器(Interceptor)与过滤器(Filter)的区别?

1、两者概述 拦截器(Interceptor): 只会拦截那些被 Controller 或 RestController 标注的类中的方法处理的请求,也就是那些由 Spring MVC 调度的请求。过滤器(Filter): 会拦截所有类型的 HTTP …

qt QCommandLineOption 详解

1、概述 QCommandLineOption类是Qt框架中用于解析命令行参数的类。它提供了一种方便的方式来定义和解析命令行选项,并且可以与QCommandLineParser类一起使用,以便在应用程序中轻松处理命令行参数。通过QCommandLineOption类,开发者可以更便捷…

Flink KafkaConsumer offset是如何提交的

一、fllink 内部配置 client.id.prefix,指定用于 Kafka Consumer 的客户端 ID 前缀partition.discovery.interval.ms,定义 Kafka Source 检查新分区的时间间隔。 请参阅下面的动态分区检查一节register.consumer.metrics 指定是否在 Flink 中注册 Kafka…

从Word里面用VBA调用NVIDIA的免费DeepSeekR1

看上去能用而已。 选中的文字作为输入,运行对应的宏即可;会先MSGBOX提示一下,然后相关内容追加到word文档中。 需要自己注册生成好用的apikey Option ExplicitSub DeepSeek()Dim selectedText As StringDim apiKey As StringDim response A…

网络工程师 (29)CSMA/CD协议

前言 CSMA/CD协议,即载波监听多路访问/碰撞检测(Carrier Sense Multiple Access with Collision Detection)协议,是一种在计算机网络中,特别是在以太网环境下,用于管理多个设备共享同一物理传输介质的重要…

WPS中如何批量上下居中对齐word表格中的所有文字

大家好,我是小鱼。 在日常制作Word表格时,经常需要对表格中的内容进行排版。经常会把文字设置成左对齐、居中对齐或者是右对齐,这些对齐方式都比较好设置,有时制作的表格需要把文字批量上下居中对齐,轻松几步就可以搞…

GeekPad智慧屏编程控制

前面通过homeassistant和emqx一番折腾,已经可以控制GeekPad智慧屏的开关了。但是这中间用到的软件对环境依赖非常高,想再优化一下,把这两个工具都装到手机上,最后勉强实现了,但是还得借用模拟器和容器,稳定…

【DeepSeek】在本地计算机上部署DeepSeek-R1大模型实战(完整版)

【作者主页】Francek Chen 【专栏介绍】 ⌈ ⌈ ⌈人工智能与大模型应用 ⌋ ⌋ ⌋ 人工智能(AI)通过算法模拟人类智能,利用机器学习、深度学习等技术驱动医疗、金融等领域的智能化。大模型是千亿参数的深度神经网络(如ChatGPT&…

可编程网卡芯片在京东云网络的应用实践【BGW边界网关篇】

目录导览 文章背景 一.网关问题分析 BGW专线网关机器运维变更困难 BGW专线网关故障收敛链路复杂且长 BGW专线网关不具备异构架构下的灾备能力 BGW专线网关硬件资源成本居高不下 二.技术方案设计实现 网络拓扑规划与VIP架构升级 硬件实现与N-Tb流量平滑迁移 三.落地…

接口测试Day12-持续集成、git简介和安装、Gitee远程仓库、jenkins集成

持续集成 概念: 团队成员将自己的工作成果,持续集成到一个公共平台的过程。成员可以每天集成一次,也可以一天集成多 次。 相关工具: 本地代码管理:git远程代码管理:gitee(国内)、github(国外)、gitlib(公司…

前端快速生成接口方法

大家好,我是苏麟,今天聊一下OpenApi。 官网 : umijs/openapi - npm 安装命令 npm i --save-dev umijs/openapi 在根目录(项目目录下)创建文件 openapi.config.js import { generateService } from umijs/openapi// 自…

三角测量——用相机运动估计特征点的空间位置

引入 使用对极约束估计了相机运动后,接下来利用相机运动估计特征点的空间位置,使用的方法就是三角测量。 三角测量 和对极几何中的对极几何约束描述类似: z 2 x 2 R ( z 1 x 1 ) t z_2x_2R(z_1x_1)t z2​x2​R(z1​x1​)t 经过对极约束…

WPS计算机二级•文档的文本样式与编号

听说这是目录哦 标题级别❤️新建文本样式 快速套用格式🩷设置标题样式 自定义设置多级编号🧡使用自动编号💛取消自动编号💚设置 页面边框💙添加水印🩵排版技巧怎么分栏💜添加空白下划线&#x…

【编程实践】vscode+pyside6环境部署

1 PySide6简介 PySide6是Qt for Python的官方版本,支持Qt6,提供Python访问Qt框架的接口。优点包括官方支持、LGPL许可,便于商业应用,与Qt6同步更新,支持最新特性。缺点是相比PyQt5,社区资源较少。未来发展…

soular基础教程-使用指南

soular是TikLab DevOps工具链的统一帐号中心,今天来介绍如何使用 soular 配置你的组织、工作台,快速入门上手。  1. 账号管理 可以对账号信息进行多方面管理,包括分配不同的部门、用户组等,从而确保账号权限和职责…