nnUNet V2代码——nnUNetv2_train命令

完成数据预处理命令后,开始训练

本文目录

  • 训练代码入口
    • nnUNetv2_train命令行参数
    • run_training函数
  • 训练代码
  • 训练结束

训练代码入口

nnU-Net V2 的训练命令是nnUNetv2_train

nnUNetv2_train命令行参数

参数名称是否必填默认值描述
dataset_name_or_id用于训练的数据集名称或 ID
configuration需要训练的配置
fold5 折交叉验证的折数。应为 0 到 4 之间的整数
-trnnUNetTrainer指定自定义训练器。默认值为 nnUNetTrainer
-pnnUNetPlans指定自定义计划标识符。默认值为 nnUNetPlans
-pretrained_weightsNone用于预训练模型的 nnU-Net checkpoint文件路径。仅在实际训练时使用。测试版,请谨慎使用
-num_gpus1指定训练时使用的 GPU 数量
–use_compressedFalse如果设置此标志,训练数据(预处理后生成的压缩版数据集)将不会被解压缩。读取压缩数据会消耗更多 CPU 和(可能)内存,仅在您知道自己在做什么时使用
–npzFalse将最终验证的 softmax 预测(数值为概率值,不是类别)保存为 npz 文件(除了预测的分割结果)。这对于找到最佳集成是必需的
–cFalse从上次训练结束处开始(已经完成训练则不用)
–valFalse设置此标志以仅运行验证。需要训练已完成
–val_bestFalse如果设置,验证将使用 checkpoint_best 而不是 checkpoint_final。与 --disable_checkpointing 不兼容!警告:这将使用与常规验证相同的“validation”文件夹,无法区分两者!(val数据集没法区分checkpoint_best 还是checkpoint_final)
–disable_checkpointingFalse设置此标志以禁用checkpoint保存。适合测试时使用,避免硬盘被checkpoint文件填满
-devicecuda设置训练运行的设备。可用选项为 ‘cuda’(GPU)、‘cpu’(CPU)和 ‘mps’(Apple M1/M2)。不要用此参数设置 GPU ID!请使用 CUDA_VISIBLE_DEVICES=X nnUNetv2_train […] 代替!

确定执行该命令后,首先调用run_training_entry函数,该函数会收集用户在命令行输入的参数,调用同文件下的run_training函数,并将收集的命令行参数传递给它。

run_training_entry函数和run_training函数代码均在nnUNet \ nnunetv2 \ run \ run_training.py文件中。

run_training函数

run_training函数在检查必要参数后,判断GPU数量,多GPU需要配置环境,单GPU不需要。

nnUNetv2_train命令有多处关于多GPU训练的代码,之后会集中一篇文章阅读。🏃🏃🏃

无论哪种情况,run_training函数都有如下操作:

1️⃣首先调用get_trainer_from_args函数,获取用于训练的nnunet_trainer变量,默认是实例化后的nnUNetTrainer类。该函数依次完成查询类、配置文件、实例化,代码结构清晰,不做粘贴:

################# run_training函数部分代码
# 实例化的nnUNetTrainer类
nnunet_trainer = get_trainer_from_args(dataset_name_or_id, configuration, fold, tr, p,use_compressed)

2️⃣之后完成一些训练前的设置:

################# run_training函数部分代码
# 是否保存网络训练后的权重
if disable_checkpointing:nnunet_trainer.disable_checkpointing = disable_checkpointingassert not (continue_training and only_run_validation), f'Cannot set --c and --val flag at the same time. Dummy.'# 加载预训练权重
maybe_load_checkpoint(nnunet_trainer, continue_training, only_run_validation, pretrained_weights)if torch.cuda.is_available():cudnn.deterministic = False	# 允许 cuDNN 选择最快的卷积算法,从而加速训练过程cudnn.benchmark = True	# 启用 cuDNN 的自动调优功能,找到最适合当前输入大小和硬件的算法,从而加速训练

3️⃣之后运行nnunet_trainer . run_training函数(重名,注意区分 ❗️❗️❗️)和nnunet_trainer . perform_actual_validation函数,完成 train 和 validate:

################# run_training函数部分代码
# 开启训练
if not only_run_validation:nnunet_trainer.run_training()# 是否使用best权重
# nnU-Net V2在训练过程中会生成三个checkpoint.pth
# 分别是checkpoint_best.pth、checkpoint_final.pth、checkpoint_latest.pth
# 由名称可以看出,分别是最佳、最终、最新训练权重
# checkpoint_final.pth会在训练结束时生成,读者如果需要在训练过程中predict,
# 可以在同文件夹下复制checkpoint_best.pth或checkpoint_latest.pth,更改名称后predict
# 上文的参数–val_best也涉及这一点。
if val_with_best:nnunet_trainer.load_checkpoint(join(nnunet_trainer.output_folder, 'checkpoint_best.pth'))
# 开始测试val数据集
nnunet_trainer.perform_actual_validation(npz)

训练代码

与预处理命令涉及的类不同,nnUNetTrainer类的各个函数代码较长且嵌套较深,因此,在阅读nnUNetTrainer类时,我们不再采用单个函数一个接一个的方式阅读代码,而是按照训练过程依次阅读,以便更好地理解其整体流程和设计思路。

根据上文,训练过程主要由nnunet_trainer变量的run_training函数(重名,注意区分 ❗️❗️❗️)完成,该变量默认是实例化的nnUNetTrainer类(nnU-Net V2文档介绍到,读者可以自定义该类)。

以下将用run_training函数指称nnunet_trainer . run_training函数,不再和上文的run_training函数区分

run_training函数代码如下:

######################## run_training函数代码
def run_training(self):### 训练开始self.on_train_start()for epoch in range(self.current_epoch, self.num_epochs):### epoch开始self.on_epoch_start()### epoch train 开始self.on_train_epoch_start()train_outputs = []### 一个epoch会train 250次(默认值,在nnUNetTrainer类的__init__函数中会讲到)for batch_id in range(self.num_iterations_per_epoch):### 250 次的一次,one steptrain_outputs.append(self.train_step(next(self.dataloader_train)))### epoch train 结束self.on_train_epoch_end(train_outputs)with torch.no_grad():### epoch val 开始self.on_validation_epoch_start()val_outputs = []### 一个epoch会val 50次(默认值,在nnUNetTrainer类的__init__函数中会讲到)for batch_id in range(self.num_val_iterations_per_epoch):### 50 次的一次,one stepval_outputs.append(self.validation_step(next(self.dataloader_val)))### epoch val 结束self.on_validation_epoch_end(val_outputs)### epoch结束self.on_epoch_end()### 训练结束self.on_train_end()

流程如下:

训练开始
self.on_train_start函数
epoch开始
self.on_epoch_start
epoch train 开始
self.on_train_epoch_start
train一次 step
是否完成 250 次训练?
epoch train 结束
self.on_train_epoch_end
epoch val 开始
self.on_validation_epoch_start
val一次 step
是否完成 50 次val?
epoch val 结束
self.on_validation_epoch_end
epoch结束
self.on_epoch_end
是否完成所有 epoch?
训练结束
self.on_train_end

整合其中部分步骤后,阅读顺序如下:

  1. 训练开始(包含dataloader):暂留坑
  2. epoch开始:暂留坑
  3. epoch train开始:暂留坑
  4. 一次train:暂留坑
  5. epoch train结束:暂留坑
  6. epoch val开始:暂留坑
  7. 一次val:暂留坑
  8. epoch val结束:暂留坑
  9. epoch结束:暂留坑
  10. 训练结束:暂留坑

训练结束

训练结束后,nnU-Net会用checkpoint_final.pth(除非用户指定使用best版,否则是final版,上文参数有说明)对val数据集测试,得出本折指标

暂留坑

至此训练结束

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.rhkb.cn/news/24067.html

如若内容造成侵权/违法违规/事实不符,请联系长河编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

使用 Open3D 批量渲染并导出固定视角点云截图

一、前言 在三维点云处理与可视化中,固定视角批量生成点云渲染截图是一个常见的需求。例如,想要将同一系列的点云(PCD 文件)在同样的视角下生成序列图片,以便后续合成为视频或进行其他可视化演示。本文将介绍如何使用…

c++的继承

封装、继承和多态是c的三大特性,他们的关系甚为紧密 封装的概念简单易懂,其实就是将数据和操作数据的方法结合在一起,形成一个独立的单元(类),通过访问控制符(如private、protected和public&…

3dtiles平移旋转工具制作

3dtiles平移旋转缩放原理及可视化工具实现 背景 平时工作中,通过cesium平台来搭建一个演示场景是很常见的事情。一般来说,演示场景不需要多完善的功能,但是需要一批三维模型搭建,如厂房、电力设备、园区等。在实际搭建过程中&…

我是如何从 0 到 1 找到 Web3 工作的?

作者:Lotus的人生实验 关于我花了一个月的时间,从 0 到 1 学习 Web3 相关的知识和编程知识。然后找到了一个 Web3 创业公司实习的远程工作。 👇👇👇 我的背景: 计算机科班,学历还可以(大厂门槛水平) 毕业工…

进程状态(R|S|D|t|T|X|Z)、僵尸进程及孤儿进程

文章目录 一.进程状态进程排队状态:运行、阻塞、挂起 二.Linux下的进程状态R 运行状态(running)S 睡眠状态(sleeping)D 磁盘休眠状态(Disk sleep)t 停止、暂停状态(tracing stopped)T 停止、暂停状态(stopp…

为什么要将PDF转换为CSV?CSV是Excel吗?

在企业和数据管理的日常工作中,PDF文件和CSV文件承担着各自的任务。PDF通常用于传输和展示静态的文档,而CSV因其简洁、易操作的特性,广泛应用于数据存储和交换。如果需要从PDF中提取、分析或处理数据,转换为CSV格式可能是一个高效…

Starlink卫星动力学系统仿真建模第十讲-基于SMC和四元数的卫星姿态控制示例及Python实现

基于四元数与滑模控制的卫星姿态控制 一、基本原理 1. 四元数姿态表示 四元数运动学方程: 3. 滑模控制设计 二、代码实现(Python) 1. 四元数运算工具 import numpy as npdef quat_mult(q1, q2):"""四元数乘法""…

CSS—引入方式、选择器、复合选择器、文字控制属性、CSS特性

目录 CSS 1.引入方式 2.选择器 3.复合选择器 4.文字控制属性 5.CSS特性 CSS 层叠样式表,是一种样式表语言,用来描述HTML文档的呈现 书写时一般按照顺序:盒子模型属性—>文字样式—>圆角、阴影等修饰属性 1.引入方式 引入方式方…

OpenHarmony-4.基于dayu800 GPIO 实践(2)

基于dayu800 GPIO 进行开发 1.DAYU800开发板硬件接口 LicheePi 4A 板载 2x10pin 插针,其中有 16 个原生 IO,包括 6 个普通 IO,3 对串口,一个 SPI。TH1520 SOC 具有4个GPIO bank,每个bank最大有32个IO:  …

win11 24h2 远程桌面 频繁断开 已失去连接 2025

一、现象 Windows11自升级2025年2月补丁后版本号为系统版本是26100.3194,远程桌面频繁断开连接,尝试连接,尤其在连接旧的server2012 二、临时解决方案 目前经测试,在组策略中,远程桌面连接客户端,关闭客户…

rust学习笔记6-数组练习704. 二分查找

上次说到rust所有权看看它和其他语言比有什么优势,就以python为例 # Python3 def test():a [1, 3, -4, 7, 9]print(a[4])b a # 所有权没有发生转移del b[4]print(a[4]) # 由于b做了删除,导致a再度访问报数组越界if __name__ __main__:test() 运行结…

Windows安装NVIDIA显卡CUDAD调用GPU,适用于部署deepseek r1

显卡、显卡驱动、CUDA之间的关系 显卡:(GPU),主流是NVIDIA的GPU,因为深度学习本身需要大量计算。GPU的并行计算能力,在过去几年里恰当地满足了深度学习的需求。AMD的GPU基本没有什么支持,可以不…

基于无人机遥感的烟株提取和计数研究

一.研究的背景、目的和意义 1.研究背景及意义 烟草作为我国重要的经济作物之一,其种植面积和产量的准确统计对于烟草产业的发展和管理至关重要。传统的人工烟株计数方法存在效率低、误差大、难以覆盖大面积烟田等问题,已无法满足现代烟草种植管理的需求…

《深度学习实战》第3集:循环神经网络(RNN)与序列建模

第3集:循环神经网络(RNN)与序列建模 引言 在深度学习领域,处理序列数据(如文本、语音、时间序列等)是一个重要的研究方向。传统的全连接网络和卷积神经网络(CNN)难以直接捕捉序列中…

【前沿探索篇七】【DeepSeek自动驾驶:端到端决策网络】

第一章 自动驾驶的"感官革命":多模态神经交响乐团 1.1 传感器矩阵的量子纠缠 我们把8路摄像头+4D毫米波雷达+128线激光雷达的融合称为"传感器交响乐",其数据融合公式可以简化为: def sensor_fusion(cam, radar, lidar):# 像素级特征提取 (ResNet-152…

可狱可囚的爬虫系列课程 13:Requests使用代理IP

一、什么是代理 IP 代理 IP(Proxy IP)是一个充当“中间人”的服务器IP地址,用于代替用户设备(如电脑、手机等)直接与目标网站或服务通信。用户通过代理IP访问互联网时,目标网站看到的是代理服务器的IP地址&…

https:原理

目录 1.数据的加密 1.1对称加密 1.2非对称加密 2.数据指纹 2.1数据指纹实际的应用 3.数据加密的方式 3.1只使用对称加密 3.2只使用非对称加密 3.3双方都使用对称加密 3.4非对称加密和对称加密一起使用 4.中间人攻击 5.CA证书 5.1什么是CA证书 CA证书的验证 6.https的原理 1.数据…

Github项目管理之 其余分支同步main分支

文章目录 方法:通过 Pull Request 同步分支1. **创建一个从 main 到目标分支的 Pull Request**2. **合并 Pull Request** 注意事项总结 在 GitHub 网页上,你可以通过 Pull Request 的方式将一个分支(例如 main 分支)的修改同步到…

Aseprite绘画流程案例(5)——花盆

1.最终图片效果 参考素材来源于:手绘像素画第三课:像素画盆花示范(无参考图)_哔哩哔哩_bilibili 2.流程 1.新建画布40X27的画布,打开显示网格,背景色为白色 2.画出梯形的盆 3.给盆进行亮暗对比上色 4.添…

【模板】csdn markdown语法演示

这里写自定义目录标题 欢迎使用Markdown编辑器新的改变功能快捷键合理的创建标题,有助于目录的生成如何改变文本的样式插入链接与图片如何插入一段漂亮的代码片生成一个适合你的列表创建一个表格设定内容居中、居左、居右SmartyPants 创建一个自定义列表如何创建一个…