Tensorflow预训练模型转PyTorch

深度学习领域是计算机科学中变化最快的领域之一。大约 5 年前,当我开始研究这个主题时,TensorFlow 被认为是主导框架。如今,大多数研究人员已经转向 PyTorch。

NSDT工具推荐: Three.js AI纹理开发包 - YOLO合成数据生成器 - GLTF/GLB在线编辑 - 3D模型格式在线转换 - 可编程3D场景编辑器 - REVIT导出3D模型插件 - 3D模型语义搜索引擎 - Three.js虚拟轴心开发包 - 3D模型在线减面 - STL模型在线切割 

虽然这种快节奏令人兴奋,但也带来了很多挑战。最近,我面临着继续完成 2018 年开展的一个项目的任务。一位同事在大量临床数据集上训练了一个分割模型,并报告了出色的性能。

今天,我们的目标是在称为迁移学习的过程中,将该训练好的模型用于类似的目标。这里的直觉是,与其从头开始,不如至少部分使用预训练权重来实例化新模型的权重,这将提供一个更好的起点。

1、收集 Tensorflow 1.x 权重

这听起来比实际容易。在 tensorflow 1.x 中,模型保存在四个单独的文件中 - 其中没有一个可以直接转换为 pytorch 的 state_dict。为了解决这个问题,我们必须手动创建一个字典并从 tensorflow 后端检索权重。

为了实现这一点,你需要了解 tensorflow 实现的命名方案。每个操作都可以在创建时分配一个名称。这个名称在稍后转换为 pytorch 时很重要。

import tensorflow as tf  # tensorflow 1.x
import pickle'''
<base_folder>
├───checkpoint
├───<model_name>.meta
├───<model_name>.data-00000-of-00001
└───<model_name>.index
'''# First let's load meta graph and restore weights
sess = tf.Session()
saver = tf.train.import_meta_graph(r'<base_folder>\<model_name>.meta')
saver.restore(sess, tf.train.latest_checkpoint(r'<base_folder>'))# get all trainable weights and save them in a dictionary
vars = sess.graph.get_collection('trainable_variables')
weights = {}
for v in vars:weights[v.name] = sess.run(v)  # retrieve the value from the tf backendwith open('weights.pickle', 'wb') as handle:pickle.dump(weights, handle, protocol=pickle.HIGHEST_PROTOCOL)

2、重建模型

遗憾的是,没有直接的方法将 TensorFlow 模型转换为 PyTorch。但是,尽管语法略有不同,但大多数层都存在于这两个框架中。例如,在 tf1 中,卷积层可以包含激活函数,而在 PyTorch 中,该函数需要按顺序添加。

此示例展示了 tf1 和 PyTorch 实现中流行的 UNet 架构的 upconv 块。

# >>> tf1 implementation (without encapsulating class)
import tensorflow as tfdef upconvcat(self, x1, x2, n_filter, name):x1 = tf.keras.layers.UpSampling2D((2, 2))(x1)x1 = tf.layers.conv2d(x1, filters=n_filter, kernel_size=(3, 3), padding='same', name="upsample_{}".format(name))return tf.concat([x1, x2], axis=-1, name="concat_{}".format(name))  # NHWC format# >>> pytorch implementation
import torchclass UpConvCat(nn.Module):def __init__(self, in_channels, out_channels):super().__init__()self.up = torch.nn.Upsample(scale_factor=2)self.conv = torch.nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1)def forward(self, x1, x2):x1 = self.up(x1)return torch.cat([x1, x2], dim=1)  # NCHW format

3、NHWC 与 NCHW

tensorflow 和 pytorch 之间的最后一个重要区别是关于轴的约定。

在旧版 tensorflow 中, data_format 属性可以指定为 channels_last 或 channels_first ,而前者是默认选项。然而,在 pytorch 中,只能使用 channels first。通常,这些格式表示为 NHWC 和 NCHW,分别表示批处理大小 (N)、高度 (H)、宽度 (W) 和通道 (C)。

np.transpose(kernel, (3, 2, 0, 1))

如果使用默认的 channels_lastoption 训练 tensorflow 中的预训练模型,则需要对内核轴进行置换才能与 torch 一起使用。为了弥补这一点,需要像这样调整 2d-conv 层权重。

4、初始化 pytorch 模型

将权重转换为正确的格式后,我们可以将它们加载到 pytorch 模型中。为此,我们随机实例化一个模型并遍历命名参数列表。然后我们使用来自 tensorflow 的权重就地修改参数。

    # set new weights from loaded tf valueswith torch.no_grad():for (name, param), (tf_name, tf_param) in zip(m.named_parameters(), tf_weights.items()):# convert NHWC to NCHW format and copy to change memory layouttf_param = np.transpose(tf_param, (3, 2, 0, 1)).copy() if len(tf_param.shape) == 4 else tf_paramassert tf_param.shape == param.detach().numpy().shape, name# https://discuss.pytorch.org/t/how-to-assign-an-arbitrary-tensor-to-models-parameter/44082/3param.copy_(torch.tensor(tf_param, requires_grad=True, dtype=param.dtype))

5、结束语

按照这些步骤,可以提取在 tensorflow 1.x 中训练的模型并将其转换为 pytorch 模型。我希望这对与我处境相似的人有所帮助。


原文链接:TF预训练模型转PyTorch - BimAnt

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.rhkb.cn/news/393601.html

如若内容造成侵权/违法违规/事实不符,请联系长河编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

无人机无线电监测设备技术分析

随着无人机技术的飞速发展&#xff0c;其在民用、军事、科研及娱乐等领域的广泛应用&#xff0c;对无线电频谱资源的有效管理和监测提出了更高要求。无人机无线电监测设备作为保障空域安全、维护无线电秩序的重要工具&#xff0c;集成了高精度定位、频谱扫描、信号分析、数据处…

《Token-Label Alignment for Vision Transformers》ICCV2023

摘要 这篇论文探讨了数据混合策略&#xff08;例如CutMix&#xff09;在提高卷积神经网络&#xff08;CNNs&#xff09;性能方面的有效性&#xff0c;并指出这些策略在视觉Transformer&#xff08;ViTs&#xff09;上同样有效。然而&#xff0c;发现了一个“token fluctuation…

Axure RP界面设计初探:基础操作与实用技巧

Axure RP是目前流行的设计精美的用户界面和交互软件。Axure RP提供了一组丰富的RP。 UI 控件&#xff0c;这些控件根据它们的应用领域进行分类。作为Axure的国产替代品&#xff0c;它可以在线协同工作&#xff0c;浏览器可以在不下载客户端的情况下立即打开和使用。如果以前用A…

OpenCV专栏介绍

在当今人工智能和计算机视觉领域&#xff0c;OpenCV作为一个功能强大的开源库&#xff0c;已经成为实现各种视觉算法的基石。本“OpenCV”专栏致力于帮助读者深入理解并掌握OpenCV的使用&#xff0c;从而在计算机视觉项目中发挥关键作用。 专栏导读 随着技术的不断进步&#…

e6.利用 docker 快速部署自动化运维平台

利用 docker 快速部署自动化运维平台 1. 安装docker2. 拉取镜像3. 启动容器4. 初始化5. 访问测试 Spug 面向中小型企业设计的轻量级无 Agent 的自动化运维平台&#xff0c;整合了主机管理、主机批量执行、主 机在线终端、文件在线上传下载、应用发布部署、在线任务计划、配置中…

Java程序员接单分享

作为一名Java程序员&#xff0c;这阵子通过承接些小型项目&#xff0c;我顺利跨过了月薪破万的门槛。这些项目虽小&#xff0c;却如同磨刀石般&#xff0c;让我在实战中发现了自身技术栈的棱角与不足&#xff0c;尤其是意识到了在Java这一浩瀚技术海洋中的诸多未知领域。我深知…

pytorch和deep learning技巧和bug解决方法短篇收集

有一些几句话就可以说明白的观点或者解决的的问题&#xff0c;小虎单独收集到这里。 torch.hub.load how does it work 下载预训练模型再载入&#xff0c;用程序下载链接可能失效。 model torch.hub.load(ultralytics/yolov5, yolov5s)model torch.hub.load(ultralytics/y…

WPF学习(10)-Label标签+TextBlock文字块+TextBox文本框+RichTextBox富文本框

Label标签 Label控件继承于ContentControl控件&#xff0c;它是一个文本标签&#xff0c;如果您想修改它的标签内容&#xff0c;请设置Content属性。我们曾提过ContentControl的Content属性是object类型&#xff0c;意味着Label的Content也是可以设置为任意的引用类型的。 案…

我的cesium for UE踩坑之旅(蓝图、UI创建)

我的小小历程 过程创建对应目录&#xff0c;并将要用到的图片、资源放入对应目录下内容浏览器 窗口中右键&#xff0c;创建一个控件蓝图&#xff0c;用来编辑界面UI绘制画布面板&#xff08;canvas&#xff09;调整整体布局加入对应的控件将UI加入到关卡中 备注搜索不到 Add To…

如何在Zoom中集成自己的app?一个简单的例子

一、注册zoom 账号、以便在zoom app maketplace创建app。 二、安装git、node.js、vscode开发环境&#xff08;略&#xff09;。 三、注册ngrok账号&#xff0c;获得一个免费的https静态域名。 四、配置zoom app(wxl)&#xff0c;设置上一步获得的https静态域名&#xff0c;验证…

进阶学习-----练习线程思维解决实际问题

线程在IT行业的实际应用 1. 多线程编程 在软件开发中&#xff0c;多线程编程是一种常见的技术&#xff0c;它允许程序同时执行多个任务。以下是多线程编程的一些具体应用&#xff1a; 任务分解&#xff1a;将一个大的任务分解为多个小任务&#xff0c;每个小任务由一个线程执…

C#基础——泛型

泛型 C# 中的泛型是一种强大的编程特性&#xff0c;它允许你编写类型安全且灵活的代码。泛型允许你定义类、结构体、接口、方法和委托&#xff0c;而不必在编译时指定具体的数据类型。相反&#xff0c;你可以使用类型参数来定义泛型类型或方法&#xff0c;然后在使用时指定具体…

springboot高校实验室安全管理系统-计算机毕业设计源码73839

目 录 摘要 1 绪论 1.1 研究背景 1.2 选题意义 1.3研究方案 1.4论文章节安排 2相关技术介绍 2.1 B/S结构 2.2 Spring Boot框架 2.3 Java语言 2.4 MySQL数据库 3系统分析 3.1 可行性分析 3.2 系统功能性分析 3.3.非功能性分析 3.4 系统用例分析 3.5系统流程分析…

算法板子:最短路问题——包含朴素Dijkstra算法、堆优化版的Dijkstra算法、SPFA算法、Floyd算法

目录 1. 几种算法的用途2. Dijkstra算法——求源点到其他所有点的最短路径(不能处理负边权)&#xff08;1&#xff09;朴素Dijkstra算法——适用于稠密图&#xff08;2&#xff09;堆优化版的Dijkstra算法——适用于稀疏图 4. SPFA算法——求源点到其他所有点的最短路径、判断是…

WordPress原创插件:disable-gutenberg禁用古腾堡编辑器和小工具

WordPress原创插件&#xff1a;disable-gutenberg禁用古腾堡编辑器和小工具 disable-gutenberg插件下载:https://download.csdn.net/download/huayula/89616495

SpringBoot快速学习

目录 SpringBoot配置文件 多环境配置 SpringBoot整合junit SpringBoot整合mybatis 1.在创建时勾选需要的模块 2.定义实体类 3.定义dao接口 4.编写数据库配置 5.使用Druid数据源 SpringBoot 是对 Spring 开发进行简化的。 那我们先来看看SpringMVC开发中的一些必须流…

翻译: 梯度下降 深度学习神经网络如何学习一

在上一节影片里我讲解了神经网络的结构 首先我们来快速回顾一下 在本节影片里&#xff0c;我们有两个目标 首介绍梯度下降的概念 它不仅是神经网络工作的基础 也是很多其他机器学习方法的基础 然后我们会研究一下这个特别的网络是如何工作的 以及这些隐藏的神经元层究竟在寻找什…

使用Openvino部署C++的Yolov5时类别信息混乱问题记录

使用Openvino部署C的Yolov5时类别信息混乱问题记录 简单记录一下。 一、问题描述 问题描述&#xff1a;在使用Yolov5的onnx格式模型进行C的Openvino进行模型部署时&#xff0c;通过读取classes.txt获得类别信息时&#xff0c;出现模型类别混乱&#xff0c;或者说根本就不给图…

如何将avi格式转换为flv格式呢?

FLV是随着FLASH MX的推出发展而来的一种视频格式&#xff0c;目前被众多新一代视频分享网站所采用&#xff0c;是目前增长较快&#xff0c;也较为广泛的视频传播格式。 FLV格式可以轻松导入FLASH播放器中&#xff0c;另外它还能起到保护版权的作用&#xff0c;非常受欢迎。那么…

在优化微信、支付宝小程序用户体验时有哪些关键指标

在优化小程序用户体验时&#xff0c;有几个关键指标需要特别关注&#xff0c;这些指标不仅能够帮助评估当前的用户体验状况&#xff0c;还能为后续的优化工作提供明确的方向。以下是一些关键指标及其解释&#xff1a; 1. 日活跃用户&#xff08;DAU&#xff09; 是指每天使用…