文章目录
- 背景
- 原因
- 解决方案
背景
很简单,我网上下载了一个模型文件,现在想读取这个模型,然后将这个模型用在我的数据集上。
import torch
model=torch.load("model.pyt")#这步直接报错了。
output=model(mydata)
报错了。
ModuleNotFoundError: No module named ‘models‘
原因
我现在项目的目录结构和model.pyt
这个模型文件当初在保存torch.save
的时候的项目目录不一致,导致导入load
模型的时候有一些关键东西缺失。
啥意思呢?假设当初模型保存torch.save
的文件长成这样。
import torch
from A import Model
model=Model()
torch.save(model,"model.pyt")
保存在model.pyt
中的东西,大家都知道,有模型权重,模型结构等。但是大家想过这样一个问题没有,如果模型里面的一个函数引用了另外一个用户自定义的函数,在torch.save
之后,这个自定义函数会被保存吗?答案是不会被保存。也就是说,对于上面的代码,
import torch
from A import Model
torch
这个库不会被保存,A
这个文件也不会被保存。那么自然,等我们torch.load
的时候,A
就会找不到,torch
可以找到,因为我们本地肯定会导入torch
。
为啥Pytorch设计的时候不保存这些呢?很简单,就怕模型里面的一个函数引用了另外一个用户自定义的函数,然后这个自定义函数又引用另外一个,然后没玩没了。更怕的是,自定义函数里面还导入了一些非常大的数据,如果全部保存起来,model.pyt
得多么大呀!
解决方案
一种方法当然就是把他的原项目下载下来,这包括了他的代码文件,而不能像我一样只下载模型文件。
其实,在load的时候Pytorch已经提示我们了,虽然只是一个warning。
FutureWarning: You are using
torch.load
withweights_only=False
(the current default value), which uses the default pickle module implicitly. It is possible to construct malicious pickle data which will execute arbitrary code during unpickling (See https://github.com/pytorch/pytorch/blob/main/SECURITY.md#untrusted-models for more details). In a future release, the default value forweights_only
will be flipped toTrue
. This limits the functions that could be executed during unpickling. Arbitrary objects will no longer be allowed to be loaded via this mode unless they are explicitly allowlisted by the user viatorch.serialization.add_safe_globals
. We recommend you start settingweights_only=True
for any use case where you don’t have full control of the loaded file. Please open an issue on GitHub for any issues related to this experimental feature.
官方不建议torch.save(model,"model.pyt")
这种保存方式,比较推荐torch.save(model.state_dict(),"model.pyt")
,也就是只保存模型权重,其他的一律不保存。这样,强制了你去下载代码文件,不然下面第一行代码Model()
会报错。
model = Model()
state_dict = torch.load('model.pyt')
model.load_state_dict(state_dict)#载入训练好的权重。