onnx模型修改:去掉Dropout层

文章目录

    • 尝试1:强行设置dropout层train mode为False
    • 尝试2:找到onnx模型中的dropout, train mode设置为False
    • 尝试3:直接删除dropout层,连接其输入输出
    • 结语

最近训练模型使用了tinyvit,性能挺强的:
image.png

但是导出onnx时,会提示dropout层的train mode被设置为True了。

UserWarning: ONNX export mode is set to TrainingMode.EVAL, but operator 'dropout' is set to train=True. Exporting with train=True.

这个警告如果只是使用onnxruntime去推理的话,可以不用处理,但是如果使用openvino则会在转换模型时失败。因为导出的onnx中出现了Dropout层,一般的推理框架是不支持推理的时候用dropout的。
image.png

尝试1:强行设置dropout层train mode为False

for m in torch_model.modules():if isinstance(m, torch.nn.Dropout):m.training = False

问题依旧

尝试2:找到onnx模型中的dropout, train mode设置为False

做这个尝试的本意是先设置为False, 再用onnx-simplify去优化一把,理论上会把dropout层去掉。


# 遍历模型的所有Dropout节点, 找到所有的training mode节点名称
training_mode_inputs=[]
for node in model.graph.node:if node.op_type == 'Dropout':# 获取Dropout节点的training_mode输入(假设是最后一个输入)training_mode_input = node.input[-1]# 检查这个输入是否指向之前找到的值为True的常量节点training_mode_inputs.append(training_mode_input)# 遍历所有初始化器
for initializer in model.graph.initializer:# 检查初始化器是否是我们要找的training_mode输入if initializer.name in training_mode_inputs:# 假设这个初始化器是一个布尔值,我们将其修改为False# 注意:ONNX中的布尔值是以int64类型存储的,0表示False,1表示True# initializer.data_type = onnx.TensorProto.INT64initializer.int64_data[:] = [0]  # 修改为False
from onnx import helper
new_initializers = []for initializer in model.graph.initializer:if initializer.name in training_mode_inputs:# 创建一个新的TensorProto对象,值为Falsenew_initializer = helper.make_tensor(name=initializer.name,  # 保持原来的名称data_type=onnx.TensorProto.BOOL,dims=initializer.dims,  # 保持原来的维度vals=[0]  # 设置值为False(在ONNX中用0表示))new_initializers.append(new_initializer)else:new_initializers.append(initializer)# 替换原来的初始化器列表
# Clear existing initializers
model.graph.ClearField('initializer')
# Add the new initializers
model.graph.initializer.extend(new_initializers)

理想很丰满,现实很骨感···并没有发生什么变化

尝试3:直接删除dropout层,连接其输入输出

dropout层在推理的时候也没什么用,直接删除,然后连接上原dropout的输入输出层就好了

import onnx
from onnx import helper# 加载模型
onnx_model = onnx.load(model_path)
graph = onnx_model.graph# 找到 Dropout 层
nodes_to_remove = [node for node in graph.node if node.op_type == 'Dropout']# 删除 Dropout 层并重新连接
for node in nodes_to_remove:input_name = node.input[0]output_name = node.output[0]# 找到所有使用 Dropout 输出作为输入的节点for next_node in graph.node:for i, input_name in enumerate(next_node.input):if input_name == node.output[0]:next_node.input[i] = node.input[0]# 从图中移除 Dropout 节点graph.node.remove(node)# 保存修改后的模型
# check if the model is valid
onnx.checker.check_model(onnx_model)
onnx.save(onnx_model, 'tinyvit_11m_sim_replace.onnx')

成功了,模型的dropout层都被删除了。
image.png

结语

虽然尝试了好几种方式···不过这些具体的代码我基本都是问的copilot,不得不说代码助手减轻了好多工作。
f77d79a3b79d6d9849231e64c8e1cdfa~tplv-dy-resize-origshort-autoq-75_330.jpeg

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.rhkb.cn/news/360955.html

如若内容造成侵权/违法违规/事实不符,请联系长河编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

6.25作业

1.整理思维导图 2.终端输入两个数,判断两数是否相等,如果不相等,判断大小关系 #!/bin/bash read num1 read num2 if [ $num1 -eq $num2 ] then echo num1num2 elif [ $num1 -gt $num2 ] then echo "num1>num2" else echo &quo…

运行ChatGLM大模型时,遇到的各种报错信息及解决方法

①IMPORTANT: You are using gradio version 3.49.0, however version 4.29.0 is available, please upgrade 原因分析: 因为使用的gradio版本过高,使用较低版本。 pip install gradio3.49.0 会有提示IMPORTANT: You are using gradio version 3.49.…

Linux的shell语法

Linux的shell脚本 1.概述 shell解释器,介于操作系统内核与用户之间,充当了一个“命令解释器”的角色,负责接收用户输入的操作指令(命令)并进行解释,将需要执行的操作传递给内核执行,并输出执行…

计算机网络 VLAN间路由单臂路由

一、理论知识 VLAN是一种将物理网络划分成多个逻辑网络的方法。不同的VLAN属于不同的网段,因此互相通信需要通过路由器进行路由。通常情况下,在同一VLAN内的设备可以直接通信,而不同VLAN之间的设备则需要通过路由器转发数据。本实验利用单臂…

SQL连接与筛选:解析left join on和where的区别及典型案例分析

文章目录 前言数据库在运行时的执行顺序一、left join on和where条件的定义和作用left join on条件where条件 二、left join on和where条件的区别原理不同left join原理:where原理: 应用场景不同执行顺序不同(作用阶段不同)结果集…

【物联网】室内定位技术及定位方式简介

目录 一、概述 二、常用的室内定位技术 2.1 WIFI技术 2.2 UWB超宽带 2.3 蓝牙BLE 2.4 ZigBee技术 2.5 RFID技术 三、常用的室内定位方式 3.1 信号到达时间 3.2 信号到达时间差 3.3 信号到达角 3.4 接收信号强度 一、概述 GPS是目前应用最广泛的定位技术&#xff0…

Vue3 按钮根据屏幕宽度展示折叠按钮

文章目录 一、组件封装二、使用三、最终效果(参考)四、参考 一、组件封装 ButtonFold.vue 1、获取父组件的元素,根据元素创建动态插槽 2、插槽中插入父元素标签。默认效果和初始状态相同。 3、当屏幕宽度缩小时,部分按钮通过 dropdown 的方式展示出来&a…

vue elementui简易侧拉栏的使用

目的&#xff1a; 增加了侧拉栏&#xff0c;目的是可以选择多条数据展示数据 组件&#xff1a; celadon.vue <template><div class"LayoutMain"><el-aside :width"sidebarIsCollapse ? 180px : 0px" class"aside-wrap"><…

Tomcat 下载部署到 idea

一、下载Tomcat Tomcat 是Apache 软件基金会&#xff08;Apache Software Foundation&#xff09;下的一个核心项目&#xff0c;免费开源、并支持Servlet 和JSP 规范。属于轻量级应用服务器&#xff0c;在中小型系统和并发访问用户不是很多的场合下被普遍使用&#xff0c;是开发…

Talking Web

1. curl 1.1 http curl http://127.0.0.1:80 向目标主机端口发送http请求 1.2 httphead curl -H “Host: 18ed3df584cd48328b5839443aa7b42b” http://127.0.0.1:80 1.3 httppath curl http://127.0.0.1:80/853c64cd218f80d0a59665666fb2ab80 1.4 URL编码路径 &#xff0…

「2024中国数据要素产业图谱1.0版」重磅发布,景联文科技凭借高质量数据采集服务入选!

近日&#xff0c;景联文科技入选数据猿和上海大数据联盟发布的《2024中国数据要素产业图谱1.0版》数据采集服务板块。 景联文科技是专业数据服务公司&#xff0c;提供从数据采集、清洗、标注的全流程数据解决方案&#xff0c;协助人工智能企业解决整个AI链条中数据采集和数据标…

一天跌20%,近500只下跌,低价可转债为何不香了?

6月以来&#xff0c;Wind可转债低价指数累计下跌7.3%&#xff0c;大幅跑输中价、高价转债。分析认为&#xff0c;市场调整的底层逻辑在于投资者对风险的重新评估和流动性的紧缩&#xff0c;宏观经济的波动和政策环境的不确定性、市场结构性的变化均对低价可转债市场产生了冲击。…

如何从零开始搭建成功的谷歌外贸网站?

先选择一个适合外贸网站的建站平台&#xff0c;如WordPress或Shopify。这些平台提供丰富的主题和插件&#xff0c;可以帮助你快速搭建和定制网站。设计网站时&#xff0c;注重用户体验&#xff0c;确保导航清晰、页面加载快速、移动端友好。确保网站的SEO优化。从关键词研究开始…

DAY14-力扣刷题

1.删除链表中的重复元素2 82. 删除排序链表中的重复元素 II - 力扣&#xff08;LeetCode&#xff09; 给定一个已排序的链表的头 head &#xff0c; 删除原始链表中所有重复数字的节点&#xff0c;只留下不同的数字 。返回 已排序的链表 。 class Solution {public ListNode …

MySQL报错Duplicate entry ‘0‘ for key ‘PRIMARY‘

报错现场 现象解释 因为你在插入时没有给 Customer.Id 赋值&#xff0c;MySQL 会倾向于赋值为 NULL。但是主键不能为 NULL&#xff0c;所以 MySQL 帮了你一个忙&#xff0c;将值转换为 0。这样&#xff0c;在第二次插入时就会出现冲突&#xff08;如果已经有一条记录为 0&…

RK3568平台开发系列讲解(I2C篇)利用逻辑分析仪进行I2C总线的全面分析

🚀返回专栏总目录 文章目录 1. 基础协议1.1. 协议简介1.2. 物理信号1.3. 总线连接沉淀、分享、成长,让自己和他人都能有所收获!😄 1. 基础协议 1.1. 协议简介 IIC-BUS(Inter-IntegratedCircuit Bus)最早是由PHilip半导体(现在被NXP收购)于1982年开发。 主要是用来方…

Netty中Reactor线程的运行逻辑

Netty中的Reactor线程主要干三件事情&#xff1a; 轮询注册在Reactor上的所有Channel感兴趣的IO就绪事件。 处理Channel上的IO就绪事件。 执行Netty中的异步任务。 正是这三个部分组成了Reactor的运行框架&#xff0c;那么我们现在来看下这个运行框架具体是怎么运转的~~ 这…

鸿蒙开发系统基础能力:【@ohos.inputMethod (输入法框架)】

输入法框架 说明&#xff1a; 本模块首批接口从API version 6开始支持。后续版本的新增接口&#xff0c;采用上角标单独标记接口的起始版本。 导入模块 import inputMethod from ohos.inputMethod;inputMethod8 常量值。 系统能力&#xff1a;以下各项对应的系统能力均为Sy…

lumbda常用操作

文章目录 lumbda的常用操作将List<String>转List<Integer>filter 过滤max 和min将List<Object>转为Map将List<Object>转为Map&#xff08;重复key&#xff09;将List<Object>转为Map&#xff08;指定Map类型&#xff09; lumbda的常用操作 将Li…

【机器学习】大模型驱动下的医疗诊断应用

摘要&#xff1a; 随着科技的不断发展&#xff0c;机器学习在医疗领域的应用日益广泛。特别是在大模型的驱动下&#xff0c;机器学习为医疗诊断带来了革命性的变化。本文详细探讨了机器学习在医疗诊断中的应用&#xff0c;包括疾病预测、图像识别、基因分析等方面&#xff0c;并…