WandB使用笔记

最近看代码,发现代码中有wandb有关的内容,搜索了一下发现是一个模型训练工具,然后学习了一下,这里记录一下使用过程,方便以后查阅。

WandB使用笔记

  • 登录WandB 并 创建团队
  • 安装 WandB 并 登录
  • 模型训练过程跟踪
  • 模型版本管理
  • 自动调参
  • 不同的模型训练工具对比
  • 参考资料

作者自注:之前训练模型一直使用的是Visdom,感觉非常好用,然后现在学习了一下WandB,发现先各有优劣。Visdom的曲线实时跟踪效果好,但是功能简单。WandB曲线实时跟踪效果差(可能是我的网的问题),但是功能强大,可以保存每次模型调优的参数,这样就不用手动再记录了;可以实现模型的版本管理,这样就可以随便改代码,不用担心改坏了;可以进行参数分析,这样就可以有目的的进行参数调优;可以进行自动调参,这样可在完成粗调制后进行局部的参数寻优。感觉以后两个可以同时使用,提高模型调优的效率

登录WandB 并 创建团队

点击下面的网站进入WandB:https://wandb.ai/site,然后点击界面中的 LOGIN 进行登录。

在这里插入图片描述

如下需要选择登录的方式,这里我选择的是 GitHub 。

在这里插入图片描述

完成登陆后进入如下初始界面,点击图片中红框中的内容,创建一个新的 team

在这里插入图片描述

之后进入如下界面,输入团队名称,并点击 Create team ,完成团队的创建。

在这里插入图片描述

团队创建成功后出现如下界面,选择是否把自己的 runs 更新到 team ,这里选择 Update

在这里插入图片描述

如此就完成了登录和团建创建过程!

如果想要删除创建的团队,则在主界面点击创建的团队,如下图所示:

在这里插入图片描述

进入团队后,点击 Team settings ,如下图所示:

在这里插入图片描述

接着滑动到最下面,点击 Delete team

在这里插入图片描述

接着需要你输入 团队的名称 进行删除,这里的逻辑跟GitHub删除项目一样。

在这里插入图片描述

安装 WandB 并 登录

使用 pip 安装 WandB:

pip install wandb

在这里插入图片描述

验证安装是否成功:

wandb --version

在这里插入图片描述

首次使用 WandB 时,需要登录账户:

wandb login

在这里插入图片描述

登录后,WandB 会提示输入 API 密钥。可以从 WandB 的 API 密钥页面 获取密钥,点击图片中的红框部分,复制密钥,然后粘贴到上图的 3 标识的地方,并点击回车,如此就完成了登录过程。

在这里插入图片描述

如果你之前已经登陆过了,则会出现如下的内容:
在这里插入图片描述

然后在终端输入如下的命令即可重新登录:

wandb login --relogin

在这里插入图片描述

模型训练过程跟踪

将如下代码复制到PyCharm中,进行实验。

import wandb
import torch
from torch import nn
import torchvision
from torchvision import transforms
import datetime
from argparse import Namespacedevice = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
config = Namespace(project_name='wandb_demo',batch_size=512,hidden_layer_width=64,dropout_p=0.1,lr=1e-4,optim_type='Adam',epochs=150,ckpt_path='checkpoint.pt'
)def create_dataloaders(config):transform = transforms.Compose([transforms.ToTensor()])ds_train = torchvision.datasets.MNIST(root="./mnist/", train=True, download=True, transform=transform)ds_val = torchvision.datasets.MNIST(root="./mnist/", train=False, download=True, transform=transform)ds_train_sub = torch.utils.data.Subset(ds_train, indices=range(0, len(ds_train), 5))dl_train = torch.utils.data.DataLoader(ds_train_sub, batch_size=config.batch_size, shuffle=True, drop_last=True)dl_val = torch.utils.data.DataLoader(ds_val, batch_size=config.batch_size, shuffle=False, drop_last=True)return dl_train, dl_valdef create_net(config):net = nn.Sequential()net.add_module("conv1", nn.Conv2d(in_channels=1, out_channels=config.hidden_layer_width, kernel_size=3))net.add_module("pool1", nn.MaxPool2d(kernel_size=2, stride=2))net.add_module("conv2", nn.Conv2d(in_channels=config.hidden_layer_width,out_channels=config.hidden_layer_width, kernel_size=5))net.add_module("pool2", nn.MaxPool2d(kernel_size=2, stride=2))net.add_module("dropout", nn.Dropout2d(p=config.dropout_p))net.add_module("adaptive_pool", nn.AdaptiveMaxPool2d((1, 1)))net.add_module("flatten", nn.Flatten())net.add_module("linear1", nn.Linear(config.hidden_layer_width, config.hidden_layer_width))net.add_module("relu", nn.ReLU())net.add_module("linear2", nn.Linear(config.hidden_layer_width, 10))net.to(device)return netdef train_epoch(model, dl_train, optimizer):model.train()for step, batch in enumerate(dl_train):features, labels = batchfeatures, labels = features.to(device), labels.to(device)preds = model(features)loss = nn.CrossEntropyLoss()(preds, labels)loss.backward()optimizer.step()optimizer.zero_grad()return modeldef eval_epoch(model, dl_val):model.eval()accurate = 0num_elems = 0for batch in dl_val:features, labels = batchfeatures, labels = features.to(device), labels.to(device)with torch.no_grad():preds = model(features)predictions = preds.argmax(dim=-1)accurate_preds = (predictions == labels)num_elems += accurate_preds.shape[0]accurate += accurate_preds.long().sum()val_acc = accurate.item() / num_elemsreturn val_acc
def train(config=config):dl_train, dl_val = create_dataloaders(config)model = create_net(config);optimizer = torch.optim.__dict__[config.optim_type](params=model.parameters(), lr=config.lr)# ======================================================================nowtime = datetime.datetime.now().strftime('%Y-%m-%d %H:%M:%S')wandb.init(project=config.project_name, config=config.__dict__, name=nowtime, save_code=True)model.run_id = wandb.run.id# ======================================================================model.best_metric = -1.0for epoch in range(1, config.epochs + 1):model = train_epoch(model, dl_train, optimizer)val_acc = eval_epoch(model, dl_val)if val_acc > model.best_metric:model.best_metric = val_acctorch.save(model.state_dict(), config.ckpt_path)nowtime = datetime.datetime.now().strftime('%Y-%m-%d %H:%M:%S')print(f"epoch【{epoch}】@{nowtime} --> val_acc= {100 * val_acc:.2f}%")# ======================================================================wandb.log({'epoch': epoch, 'val_acc': val_acc, 'best_val_acc': model.best_metric})# ======================================================================# ======================================================================wandb.finish()# ======================================================================return model

上述代码最关键的就是如下三个部分:

  1. 初始化部分:
wandb.init(project=config.project_name, config=config.__dict__, name=nowtime, save_code=True)
  1. 模型训练参数上传
wandb.log({'epoch': epoch, 'val_acc': val_acc, 'best_val_acc': model.best_metric})
  1. 模型训练完成关闭wandb:
wandb.finish()

最后在PyCharm中输入如下代码,即可运行上述代码:

model = train(config)

代码运行成功,即可出现如下的界面,点击下图中红框中的部分,即可跳转到曲线监视界面。

在这里插入图片描述

模型训练过程监视界面如下图所示:

在这里插入图片描述

点击下图中的红框部分,更改曲线的横坐标值。

在这里插入图片描述

如下图所示,将横坐标值更改为 epoch。

在这里插入图片描述

然后我们还可以增加一个 section

在这里插入图片描述

在新的 section 中添加新的显示模块,如下图所示:

在这里插入图片描述

此处我们添加了验证集的准确率,实现实时的监控。

在这里插入图片描述

模型训练结束,我们可以点击 runs 查看历史记录。

在这里插入图片描述

如下图可以看到,我们刚才监视的曲线,如图中的长方形红框所示。然后点击小红框中的 runs ,查看每一次训练过程的模型参数。

在这里插入图片描述

每一次模型训练的参数如下图所示,可以选择图中红框中的内容,选择需要的参数进行显示。

在这里插入图片描述

可选择的指标如下图所示:

在这里插入图片描述

对于某些我们比较关注的指标,我们可以将其固定显示:

在这里插入图片描述

固定后,我们回到 Workspace 界面,即可看到固定的参数。

在这里插入图片描述

模型版本管理

除了可以记录实验日志传递到 wandb 网站的云端服务器 并进行可视化分析。wandb还能够将实验关联的数据集,代码和模型 保存到 wandb 服务器。我们可以通过 wandb.log_artifact的方法来保存任务的关联的重要成果。例如 dataset, code,和 model,并进行版本管理。

当我们跑出一个相对不错的结果时,我们希望把这个结果给保存下来,此时我们就可以使用该功能。

我们先使用run_id 恢复 run任务,以便继续记录。

import wandb
# resume the run
run = wandb.init(project='wandb_demo', id='6h5xkv16', resume='allow')

上述代码中的 id 是用来关联我们训练的 runs 的,参数的值来自下图红框中的内容,想搞关联某一次的训练过程,就把某一次训练的 ID 写入上述代码。

在这里插入图片描述
保存数据集的代码:

# save dataset
arti_dataset = wandb.Artifact(name='mnist', type='dataset')
arti_dataset.add_dir('mnist/')
wandb.log_artifact(arti_dataset)

保存模型文件的代码:

# save code
arti_code = wandb.Artifact(name='py', type='code')
arti_code.add_file('./wandb_test.py')
wandb.log_artifact(arti_code)

保存模型权重的代码:

# save model
arti_model = wandb.Artifact(name='cnn', type='model')
arti_model.add_file(config.ckpt_path)
wandb.log_artifact(arti_model)

最后结束时要使用一下代码:

# finish时会提交保存
wandb.finish()

上传后的效果如图所示:

在这里插入图片描述

自动调参

sweep采用类似master-workers的controller-agents架构,controller在wandb的服务器机器上运行,agents在用户机器上运行,controller和agents之间通过互联网进行通信。同时启动多个agents即可轻松实现分布式超参搜索。

在这里插入图片描述

使用Sweep的3步骤:

  1. 配置 sweep_config
# 配置 Sweep config
sweep_config = {'method': 'random',  # 选择调优算法,超参数搜索方法:随机搜索'metric': {          # 定义调优目标'name': 'val_acc','goal': 'maximize'},'parameters': {     # 定义超参空间'project_name': {'value': 'wandb_demo'},    # 固定不变的超参'epochs': {'value': 10},'ckpt_path': {'value': 'checkpoint.pt'},'optim_type': {                             # 离散型分布超参'values': ['Adam', 'SGD', 'AdamW']},'hidden_layer_width': {'values': [16, 32, 48, 64, 80, 96, 112, 128]},'lr': {                                     # 连续型分布超参'distribution': 'log_uniform_values','min': 1e-6,'max': 0.1},'batch_size': {'distribution': 'q_uniform','q': 8,'min': 32,'max': 256,},'dropout_p': {'distribution': 'uniform','min': 0,'max': 0.6,}},# 'early_terminate': {    # 定义剪枝策略 (可选)#     'type': 'hyperband',    # 使用 HyperBand 作为早停策略#     'min_iter': 3,          # 最小评估迭代次数(第 3 次迭代后开始考虑剪枝)#     'eta': 2,               # 成倍增长的资源分配比例(每次迭代中仅保留约 1/eta 的实验)#     's': 3                  # HyperBand 的最大阶数,影响资源分配的层级# }
}
from pprint import pprint
pprint(sweep_config)

Sweep支持如下3种调优算法:

(1)网格搜索:grid. 遍历所有可能得超参组合,只在超参空间不大的时候使用,否则会非常慢。

(2)随机搜索:random. 每个超参数都选择一个随机值,非常有效,一般情况下建议使用。

(3)贝叶斯搜索:bayes.
创建一个概率模型估计不同超参数组合的效果,采样有更高概率提升优化目标的超参数组合。对连续型的超参数特别有效,但扩展到非常高维度的超参数时效果不好。

  1. 初始化 sweep controller
# 初始化 sweep controller
sweep_id = wandb.sweep(sweep_config, project=config.project_name)
  1. 启动 sweep agents
# 启动 Sweep agent
# 该agent 随机搜索 尝试5次
wandb.agent(sweep_id, train, count=5)

等代码跑完我们就有了一个 sweep,如下图所示:

在这里插入图片描述

进入 sweep 之后就可以添加 Parallel coordinatesParameter importance 进行参数分析。

在这里插入图片描述
在这里插入图片描述

在这里插入图片描述

不同的模型训练工具对比

工具实验管理数据版本控制模型部署团队协作离线支持特点
TensorBoard轻量级工具,适合快速原型开发
WandB功能全面,支持超参数调优和实时协作
Comet简单易用,支持离线模式
MLflow实验管理与模型部署一体化
Neptune强大的可视化功能
Sacred极简实验管理工具
Polyaxon分布式训练与大规模实验管理支持
DVC专注于数据和模型版本控制
ClearML全面的 MLOps 功能

参考资料

30分钟吃掉wandb模型训练可视化

wandb我最爱的炼丹伴侣操作指南

30分钟吃掉wandb可视化自动调参

wandb可视化调参完全指南

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.rhkb.cn/news/504093.html

如若内容造成侵权/违法违规/事实不符,请联系长河编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

中国科技统计年鉴EXCEL版(2021-2023年)-社科数据

中国科技统计年鉴EXCEL版(2021-2023年)-社科数据https://download.csdn.net/download/paofuluolijiang/90028724 https://download.csdn.net/download/paofuluolijiang/90028724 中国科技统计年鉴提供了从2021至2023年的详尽数据,覆盖了科技…

Cursor无限续杯——解决Too many free trials.

前情提要 我们都知道Cursor对新用户是有14天且500条免费限制的。 一般情况下,当14天过期,是可以注销账户再重新注册,这样就可以继续拥有14天的体验时长。 但是!!如果使用超过500次,Cusor就会把你的电脑I…

深入学习RabbitMQ的Direct Exchange(直连交换机)

RabbitMQ作为一种高性能的消息中间件,在分布式系统中扮演着重要角色。它提供了多种消息传递模式,其中Direct Exchange(直连交换机)是最基础且常用的一种。本文将深入介绍Direct Exchange的原理、应用场景、配置方法以及实践案例&a…

Mysql--基础篇--事务(ACID特征及实现原理,事务管理模式,隔离级别,并发问题,锁机制,行级锁,表级锁,意向锁,共享锁,排他锁,死锁,MVCC)

在MySQL中,事务(Transaction)是一组SQL语句的集合,这些语句一起被视为一个单一的工作单元。事务具有ACID特性,确保数据的一致性和完整性。通过事务,可以保证多个操作要么全部成功执行,要么全部不…

Linux下文件重定向

文章目录 一 重定向的基本使用1. 标准输出重定向2. 标准错误输出重定向3. 同时重定向标准输出和标准错误输出4. 输入重定向&#xff08;<&#xff09; 二 重定向基本原理1. 文件描述符概念2.什么是文件描述符3. 文件描述符的分配规则初始分配与默认对应关系动态分配规则 4. …

Android车载音频系统目录

目录 第一章 1.1 Android Automotive&#xff08;一&#xff09; 1.2 Android Automotive&#xff08;二&#xff09; 1.3 Android Automotive&#xff08;三&#xff09; 第二章 2.1 Android车载音频系统概览 2.2 车载音频焦点 2.3 车载音频配置 2.4 Audio control HAL…

【Linux】深入理解文件系统(超详细)

目录 一.磁盘 1-1 磁盘、服务器、机柜、机房 &#x1f4cc;补充&#xff1a; &#x1f4cc;通常网络中用高低电平&#xff0c;磁盘中用磁化方向来表示。以下是具体说明&#xff1a; &#x1f4cc;如果有一块磁盘要进行销毁该怎么办&#xff1f; 1-2 磁盘存储结构 ​编辑…

Spring项目创建流程及配置文件bean标签参数简介

1. 项目搭建流程 1. pom.xml中引入依赖Spring-webMVC <!-- https://mvnrepository.com/artifact/org.springframework/spring-webmvc --><dependency><groupId>org.springframework</groupId><artifactId>spring-webmvc</artifactId><…

PHP进阶-在Ubuntu上搭建LAMP环境教程

本文将为您提供一个在Ubuntu服务器上搭建LAMP&#xff08;Linux, Apache, MySQL, PHP&#xff09;环境的完整指南。通过本文&#xff0c;您将学习如何安装和配置Apache、MySQL、PHP&#xff0c;并将您的PHP项目部署到服务器上。本文适用于Ubuntu 20.04及更高版本。 一、系统更新…

Web应用安全-漏洞扫描器设计与实现

摘 要 随着Web2.0、社交网络、微博等一系列新型的互联网产品的诞生&#xff0c;基于Web环境的互联网应用越来越广泛&#xff0c;企业信息化的过程中各种应用都架设在Web平台上。Web应用的迅速发展也引起黑客们的强烈关注&#xff0c;接踵而至的就是Web安全威胁的凸显&#xff…

【漏洞工具】小米路由器任意文件读取漏洞python图形化框架利用工具(poc|exp)

there is no tomorrow 工具利用 漏洞扫描 漏洞利用 poc 本文工具、源码获取 因本工具涉及到源码分享&#xff0c;如有需求&#xff0c;请私聊圈主 文笔生疏&#xff0c;措辞浅薄&#xff0c;望各位大佬不吝赐教&#xff0c;万分感谢。 免责声明&#xff1a;由于传播或利用…

【Logstash03】企业级日志分析系统ELK之Logstash 过滤 Filter 插件

Logstash 过滤 Filter 插件 数据从源传输到存储库的过程中&#xff0c;Logstash 过滤器能够解析各个事件&#xff0c;识别已命名的字段以构建结构&#xff0c; 并将它们转换成通用格式&#xff0c;以便进行更强大的分析和实现商业价值。 Logstash 能够动态地转换和解析数据&a…

游戏关卡设计的常用模式

游戏关卡分为很多种&#xff0c;但常用的有固定套路&#xff0c;分为若干种类型。 关卡是主角与怪物、敌方战斗的场所&#xff0c;包括装饰物、通道。 单人游戏的关卡较小&#xff0c;偏线性&#xff1b; 联机/MMO的关卡较大&#xff0c;通道多&#xff0c;自由度高&#xf…

用OpenCV实现UVC视频分屏

分屏 OpencvUVC代码验证后话 用OpenCV实现UVC摄像头的视频分屏。 Opencv opencv里有很多视频图像的处理功能。 UVC Usb 视频类&#xff0c;免驱动的。视频流格式有MJPG和YUY2。MJPG是RGB三色通道的。要对三通道进行分屏显示。 代码 import cv2 import numpy as np video …

用户界面软件02

基于表单的用户界面 在“基于表单的用户界面”里面&#xff0c;用户开始时选中某个业务处理&#xff08;模块&#xff09;&#xff0c;然后应用程序就使用一系列的表单来引导用户完成整个处理过程。大型机系统上的大部分用户界面都是这样子的。[Cok97]中有更为详细的讨论。 面…

使用Registry explore实现法医检查练习

Windows Forensics 1&#xff08;windows 取证&#xff09; 第一题&#xff1a; 关于用户的基本都在sam注册表中&#xff0c;所以使用Registry explore&#xff0c;添加一个sam进来检查&#xff0c;通常sam注册表都是在C:\Windows\System32\config中 接着就可以开始我们的检验…

Linux服务器网络不通问题排查及常用命令使用

在PVE主机上创建虚拟机&#xff0c;并配置静态ip和dns后&#xff0c;主机可以正常访问网络&#xff0c;但是在宿主机或者其他机器上都无法访问该虚拟机。 检查ip是否联通且端口是否开启 如果ip无法ping通&#xff0c;可能是静态ip配置、网卡或桥接设置问题。 [k8slocalhost …

道品科技智慧农业与云平台:未来农业的变革之路

随着全球人口的不断增长&#xff0c;农业面临着前所未有的挑战。如何在有限的土地和资源上提高农业生产效率&#xff0c;成为了各国政府和农业从业者亟待解决的问题。智慧农业的兴起&#xff0c;结合云平台的应用&#xff0c;为农业的可持续发展提供了新的解决方案。 ## 一、智…

《C++11》右值引用深度解析:性能优化的秘密武器

C11引入了一个新的概念——右值引用&#xff0c;这是一个相当深奥且重要的概念。为了理解右值引用&#xff0c;我们需要先理解左值和右值的概念&#xff0c;然后再理解左值引用和右值引用。本文将详细解析这些概念&#xff0c;并通过实例进行说明&#xff0c;以揭示右值引用如何…

【OJ刷题】同向双指针问题

这里是阿川的博客&#xff0c;祝您变得更强 ✨ 个人主页&#xff1a;在线OJ的阿川 &#x1f496;文章专栏&#xff1a;OJ刷题入门到进阶 &#x1f30f;代码仓库&#xff1a; 写在开头 现在您看到的是我的结论或想法&#xff0c;但在这背后凝结了大量的思考、经验和讨论 目录 1…