在使用 timm 库加载预训练模型权重时,可以通过多种方式加载模型并指定权重。以下是几个常见的操作方法:
- 加载具有预训练权重的模型
timm 提供了大量预训练模型,可以直接通过 pretrained=True 参数加载模型及其对应的权重。
import timm
加载预训练的模型
model = timm.create_model('resnet50', pretrained=True)
查看模型结构
print(model)
- 加载无预训练权重的模型
如果你希望只加载模型的结构,而不加载预训练的权重:
加载模型,不加载预训练权重
model = timm.create_model('resnet50', pretrained=False)
- 自定义加载权重
如果你有自己的权重文件,可以通过 load_state_dict 方法加载。以下是具体步骤:
步骤 1: 加载模型结构
model = timm.create_model('resnet50', pretrained=False)
步骤 2: 加载自定义权重
假设权重文件路径为 path_to_weights.pth:
# 加载本地权重文件
state_dict = torch.load('path_to_weights.pth', map_location='cpu')# 将权重加载到模型中
model.load_state_dict(state_dict)# 加载本地权重文件
state_dict = torch.load('path_to_weights.pth', map_location='cpu')# 将权重加载到模型中
model.load_state_dict(state_dict)
注意: 如果权重文件的结构与模型不匹配,可能会出现错误。可以通过 strict=False 参数忽略部分不匹配的权重:
model.load_state_dict(state_dict, strict=False)
- 查看支持的模型和权重
timm 支持多种模型和预训练权重,可以通过以下方法查看可用模型列表及其支持的权重:
列出支持的模型
from timm.models import list_models
列出所有可用模型
all_models = list_models()
print(all_models)
列出具有预训练权重的模型
pretrained_models = list_models(pretrained=True)
print(pretrained_models)查询特定模型的权重选项
查询特定模型支持的预训练权重
weights = timm.models.list_pretrained('resnet50')
print(weights)
- 指定具体权重版本
timm 中部分模型可能提供多个预训练权重,可以通过 pretrained_cfg 指定:
model = timm.create_model('resnet50', pretrained=True, pretrained_cfg='resnet50.imagenet')
- 检查权重加载是否成功
加载权重后,可以检查模型的参数是否正确加载:
for name, param in model.named_parameters():print(name, param.shape)
完整示例
以下是一个加载自定义权重的完整示例:
import timm
import torch# 加载模型结构
model = timm.create_model('resnet50', pretrained=False)# 加载自定义权重
weights_path = 'path_to_weights.pth'
state_dict = torch.load(weights_path, map_location='cpu')
model.load_state_dict(state_dict, strict=False)# 检查模型
print(model)