PyTorch中加载模型权重 A匹配B|A不匹配B

在做深度学习项目时,从头训练一个模型是需要大量时间和算力的,我们通常采用加载预训练权重的方法,而我们往往面临以下几种情况:
在这里插入图片描述

未修改网络,A与B一致

很简单,直接.load_state_dict()

net = ANet(num_classses = 5,init_weights=True)
net.to(device)
net.load_state_dict(torch.load('weight/B_weight.pth'))

修改了网络,A与B不一致

[pytorch官方文档](Search — PyTorch master documentation):

load_state_dict(state_dict, strict=True)

将 state_dict 中的参数和缓冲区复制到此模块及其后代中。如果 strict 为 True,则 state_dict 的键必须与该模块的 state_dict() 函数返回的键完全匹配。

state_dict是包含参数和持久缓冲区的字典,可以看出 strict默认为True,所以默认状态下是严格要求state_dict中的key与torch.nn.Module.state_dict返回的key完全一致的

load_state_dict()函数有两个返回值:

missing_keys 是包含缺失键的 str 列表
unexpected_keys 是包含意外键的 str 列表

方法一:

将strict改为false,加载键值相同的部分。

model = NET2()
state_dict = model.state_dict()
weights = torch.load(weights_path)['model_state_dict']	#读取预训练模型权重
model.load_state_dict(weights, strict=False)	#strict

但是此时还存在一种情况:键值相同但shape不同,故应进行if…in…的判断:

ANet = torch.load('ANet.pt')  # 加载预训练权重模型(.pt文件)参数
#现成的模型的话,如resnet50 = models.resnet50(pretrained=True)
#采用:pretrained_dict = resnet50().state_dict()  
model = Model() # 创建模型
model_dict = model.state_dict() # 得到模型的参数字典# 判断预训练模型中网络的模块是否修改后的网络中也存在,并且shape相同,如果相同则取出
pretrained_dict = {k: v for k, v in ANet.items() if k in model_dict and (v.shape == model_dict[k].shape)}# 更新修改之后的 model_dict
model_dict.update(pretrained_dict)# 加载我们真正需要的 state_dict
model.load_state_dict(model_dict, strict=False)

方法二:

1.将权重导入原模型,之后在加载后的原模型基础上进行修改。
2.修改权重文件参数,再进行导入
适用于改动不大的模型

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.rhkb.cn/news/76723.html

如若内容造成侵权/违法违规/事实不符,请联系长河编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

小鱼深度产品测评之:阿里云容器服务器ASK,一款不需购买节点,即可直接部署容器应用。

容器服务器ASK测评 1、引言2、帮助文档3、集群3.1集群列表3.1.1 详情3.1.1.1概览 tab3.1.1.2基本信息 tab3.1.1.4集群资源 tab3.1.1.5 集群日志 tab3.1.1.6 集群任务 tab 3.1.2 应用管理3.1.2.1 详情3.1.2.2 详情3.1.2.3 伸缩3.1.2.4 监控 3.1.3 查看日志3.1.3.1 集群日志3.1.3…

嵌入式开发学习(STC51-3-点亮led)

内容 点亮第一个led; led闪烁; led流水灯; led简介 led即发光二极管,它具有单向导电性,通过5mA左右电流即可发光,电流越大,其亮度越强,但若电流过大,会烧毁二极管&…

【云原生】k8s组件架构介绍与K8s最新版部署

个人主页:征服bug-CSDN博客 kubernetes专栏:kubernetes_征服bug的博客-CSDN博客 目录 1 集群组件 1.1 控制平面组件(Control Plane Components) 1.2 Node 组件 1.3 插件 (Addons) 2 集群架构详细 3 集群搭建[重点] 3.1 mi…

快速压缩PDF文件的方法:这两种方法一定要学会!

随着PDF文件的增加,文件大小也会逐渐增大,给共享和存储带来了一定的挑战。为了解决这个问题,本文将介绍几个简单而有效的方法,即压缩PDF文件,以减小文件大小,提高共享和存储的效率。 使用在线压缩工具 在…

Dockerfile定制Tomcat镜像

Dockerfile中的打包命令 FROM : 以某个基础镜像作为此镜像的基础 RUN : RUN后面跟着linux常用命令,如RUN echo xxx >> xxx,注意,RUN 不能用于执行命令,因为每个RUN都是独立运行的,RUN 的cd对镜像中的…

我的128创作纪念日

机缘 写CSDN博客的时候,应该纯属一个巧合,还记得当初是和一个班上的同学一起记录学习笔记,最初是在博客园的平台上记录笔记,可以在以后复习时使用,后来我的同学开始推荐使用CSDN平台,于是我们两就转战CSDN…

第十二届金博奖启动,促进大小企业优势互联

8月3日,由博士科技联合深圳证券交易所、广东博士创新发展促进会共同举办的“走进深交所之国家级专精特新‘小巨人’企业融资路演暨第二十届金博奖启动仪式”成功举办。 华南理工大学教授,俄罗斯工程院外籍院士,广东博士创新发展促进会会长、国…

回归预测 | MATLAB实现基于SVM-RFE-BP支持向量机递归特征消除特征选择算法结合BP神经网络的多输入单输出回归预测

回归预测 | MATLAB实现基于SVM-RFE-BP支持向量机递归特征消除特征选择算法结合BP神经网络的多输入单输出回归预测 目录 回归预测 | MATLAB实现基于SVM-RFE-BP支持向量机递归特征消除特征选择算法结合BP神经网络的多输入单输出回归预测预测效果基本介绍研究内容程序设计参考资料…

Last-Mile Embodied Visual Navigation 论文阅读

论文阅读 题目:Last-Mile Embodied Visual Navigation 作者:JustinWasserman, Karmesh Yadav 来源:CoRL 时间:2023 代码地址:https://jbwasse2.github.io/portfolio/SLING Abstract 现实的长期任务(例如…

9.物联网操作系统之软件定时器

一。软件定时器概念及应用 1.软件定时器定义 就是软件实现定时器。 2.FreeRTOS软件定时器介绍 如上图所示,Times的左边为设置定时器时间,设置方式可以为任务设置或者中断设置;Times的右边为定时器的定时相应,使用CalBack相应。 …

2023年华数杯数学建模C题思路代码分析 - 母亲身心健康对婴儿成长的影响

# 1 赛题 C 题 母亲身心健康对婴儿成长的影响 母亲是婴儿生命中最重要的人之一,她不仅为婴儿提供营养物质和身体保护, 还为婴儿提供情感支持和安全感。母亲心理健康状态的不良状况,如抑郁、焦虑、 压力等,可能会对婴儿的认知、情…

为Stable Diffusion web UI开发自己的插件实战

最近,Stable Diffusion AI绘画受到了广泛的关注和热捧。它的Web UI提供了了一系列强大的功能,其中特别值得一提的是对插件的支持,尤其是Controlnet插件的加持,让它的受欢迎程度不断攀升。那么,如果你有出色的创意&…

数学建模-元胞自动机

clc clear n 300; % 定义表示森林的矩阵大小 Plight 5e-6; Pgrowth 1e-2; % 定义闪电和生长的概率 UL [n,1:n-1]; DR [2:n,1]; % 定义上左,下右邻居 vegzeros(n,n); % 初始化表示森林的矩阵 imh ima…

【代码源每日一题div2 】简单的异或问题

简单的异或问题 - 题目 - Daimayuan Online Judge 题意: 思路: 首先这有一个结论:0~2^m-1的所有数进行XOR运算后,得到的结果是0。我们来证明一下这个结论: 比如m3时,一共是0 1 2 3 4 5 6 7,八…

【iOS RunLoop】

文章目录 前言-什么是RunLoop?默认情况下主线程的RunLoop原理 1. RunLoop对象RunLoop对象的获取 CFRunLoopRef源码部分(引入线程相关) 2. RunLoop和线程3. RunLoop相关的类RunLoop相关类的实现CFRunLoopModeRef五种运行模式CommonModes CFRun…

网站无法访问的常见原因

有多种问题可能会阻止用户访问您的网站。本文将解决无法访问网站,且没有错误消息指示确切问题的情况,希望对您有所帮助。 无法访问网站的常见原因有: (1)DNS 设置不正确。 (2)域名已过期。 (3)空白或没有索引文件。 (4)网络连接问题。 DNS 设…

Qt开发,编译报错:error: C2001: 常量中有换行符

一、问题描述 Qt开发,编译报错:error: C2001: 常量中有换行符 E:\work\xxx.cpp:1: warning: C4819: 该文件包含不能在当前代码页(936)中表示的字符。请将该文件保存为 Unicode 格式以防止数据丢失 E:\work\xxx.cpp:66: error: C2001: 常量中有换行符 E…

中国信通院发布《高质量数字化转型产品及服务全景图(2023)》

2023年7月27日,由中国信息通信研究院主办的2023数字生态发展大会暨中国信通院铸基计划年中会议在北京成功召开。 本次大会发布了中国信通院《高质量数字化转型产品及服务全景图(2023)》,中新赛克海睿思受邀出席本次大会并成功入选…

一文了解JavaScript 与 TypeScript的区别

TypeScript 和 JavaScript 是两种互补的技术,共同推动前端和后端开发。在本文中,我们将带您快速了解JavaScript 与 TypeScript的区别。 一、TypeScript 和 JavaScript 之间的区别 JavaScript 和 TypeScript 看起来非常相似,但有一个重要的区…

14-5_Qt 5.9 C++开发指南_基于HTTP 协议的网络应用程序

文章目录 1. 实现高层网络操作的类2. 基于HTTP协议的网络文件下载3.源码3.1 可是化UI设计3.2 mainwindow.h3.3 mainwindow.cpp 1. 实现高层网络操作的类 Qt 网络模块提供一些类实现 OSI 7 层网络模型中高层的网络协议,如 HTTP、FTP、SNMP等,这些类主要是…