Table of Contents
- 1. Key Classes
- 2. Saving a Model
- 3. Code Test
1. Key Classes
- container.py
- nn.Sequential
- nn.ModuleList (a usage sketch of both containers follows this list)
- state_dict
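As a quick reminder of how the two containers from container.py are used, here is a minimal sketch; the layer sizes and the names `seq` and `StackedLinear` are made up for illustration:

```python
import torch
from torch import nn

# nn.Sequential chains submodules and calls them in order in forward().
seq = nn.Sequential(
    nn.Linear(2, 3),
    nn.ReLU(),
    nn.Linear(3, 4),
)

# nn.ModuleList only registers submodules (so they show up in parameters()
# and state_dict()); the iteration in forward() is written by hand.
class StackedLinear(nn.Module):
    def __init__(self):
        super().__init__()
        self.layers = nn.ModuleList([nn.Linear(4, 4) for _ in range(3)])

    def forward(self, x):
        for layer in self.layers:
            x = layer(x)
        return x

print(seq(torch.randn(1, 2)).shape)                   # torch.Size([1, 4])
print(StackedLinear()(torch.randn(1, 4)).shape)       # torch.Size([1, 4])
```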
2. Saving a Model
See the official PyTorch tutorial on saving and loading models.
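The pattern recommended there is to save the model's state_dict with torch.save() and restore it with load_state_dict(); a minimal sketch (the file name my_model.pth is arbitrary):

```python
import torch
from torch import nn

model = nn.Linear(2, 3)

# Save only the state_dict (parameters + buffers), not the whole module.
torch.save(model.state_dict(), "my_model.pth")

# Load: rebuild the architecture first, then restore the weights.
restored = nn.Linear(2, 3)
restored.load_state_dict(torch.load("my_model.pth"))
restored.eval()  # switch to inference mode before evaluating
```

Saving the entire module object with torch.save(model) also works, but it ties the checkpoint to the exact class definition, which is why the state_dict form is generally preferred.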
3. Code Test
Written in a hurry; to be improved later.
```python
#!/usr/bin/env python
# -*- coding:utf-8 -*-
# @FileName :ToTest01.py
# @Time     :2024/11/24 10:37
# @Author   :Jason Zhang
import torch
from torch import nn
from torch.nn import Module


class MyModel(nn.Module):
    def __init__(self):
        super(MyModel, self).__init__()
        self.linear1 = nn.Linear(2, 3)
        self.linear2 = nn.Linear(3, 4)
        self.batch_norm4 = nn.BatchNorm2d(4)

    def forward(self, x):
        x = self.linear1(x)
        x = self.linear2(x)
        return x


if __name__ == "__main__":
    run_code = 0
    input_x = torch.randn((1, 2))
    test_model = MyModel()
    y = test_model(input_x)

    # _modules: OrderedDict of the submodules registered on this module
    model_modules = test_model._modules
    print("*" * 50)
    print(f"model_modules=\n{model_modules}")
    print("*" * 50)
    linear1 = model_modules['linear1']
    print("*" * 50)
    print(f"linear1={linear1}")
    print("*" * 50)
    print(f"linear1.weight=\n{linear1.weight}")
    print("*" * 50)
    print(f"linear1.weight.dtype={linear1.weight.dtype}")
    print("*" * 50)

    # .to(dtype) casts all parameters and buffers in place
    test_model.to(torch.double)
    print(f"linear1.weight.dtype={linear1.weight.dtype}")
    print("*" * 50)
    test_model.to(torch.float32)
    print(f"linear1.weight.dtype={linear1.weight.dtype}")
    print("*" * 50)

    # _parameters / _buffers only hold entries registered directly on test_model
    model_parameters = test_model._parameters
    print(f"model_parameters={model_parameters}")
    print("*" * 50)
    model_buffers = test_model._buffers
    print(f"model_buffer={model_buffers}")
    print("*" * 50)

    # state_dict() recursively collects the parameters and buffers of all submodules
    model_state_dict = test_model.state_dict()
    print(f"model_state_dict=\n{model_state_dict}")
    print("*" * 50)
    model_state_dict_linear2 = test_model.state_dict()['linear2.weight']
    print(f"model_state_dict_linear2=\n{model_state_dict_linear2}")
    print("*" * 50)

    # the named_* iterators also walk the submodule tree
    model_named_para = list(test_model.named_parameters())
    print(f"model_named_para=\n{model_named_para}")
    print("*" * 50)
    model_named_modules = list(test_model.named_modules())
    print(f"model_named_modules=\n{model_named_modules}")
    print("*" * 50)
    model_named_buffers = list(test_model.named_buffers())
    print(f"model_named_buffers=\n{model_named_buffers}")
    print("*" * 50)
    model_named_children = list(test_model.named_children())
    print(f"model_named_children=\n{model_named_children}")
```
- Result:
```
**************************************************
model_modules=
OrderedDict([('linear1', Linear(in_features=2, out_features=3, bias=True)), ('linear2', Linear(in_features=3, out_features=4, bias=True)), ('batch_norm4', BatchNorm2d(4, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True))])
**************************************************
**************************************************
linear1=Linear(in_features=2, out_features=3, bias=True)
**************************************************
linear1.weight=
Parameter containing:
tensor([[-0.5518, 0.0687],[-0.7013, 0.4869],[-0.1157, -0.1287]], requires_grad=True)
**************************************************
linear1.weight.dtype=torch.float32
**************************************************
linear1.weight.dtype=torch.float64
**************************************************
linear1.weight.dtype=torch.float32
**************************************************
model_parameters=OrderedDict()
**************************************************
model_buffer=OrderedDict()
**************************************************
model_state_dict=
OrderedDict([('linear1.weight', tensor([[-0.5518, 0.0687],[-0.7013, 0.4869],[-0.1157, -0.1287]])), ('linear1.bias', tensor([-0.2915, -0.4807, 0.0071])), ('linear2.weight', tensor([[ 0.4185, 0.1556, 0.1371],[ 0.4751, 0.2029, -0.0679],[ 0.1264, -0.0288, -0.3661],[ 0.4423, -0.5370, 0.3930]])), ('linear2.bias', tensor([ 0.2746, -0.1798, 0.0218, 0.5465])), ('batch_norm4.weight', tensor([1., 1., 1., 1.])), ('batch_norm4.bias', tensor([0., 0., 0., 0.])), ('batch_norm4.running_mean', tensor([0., 0., 0., 0.])), ('batch_norm4.running_var', tensor([1., 1., 1., 1.])), ('batch_norm4.num_batches_tracked', tensor(0))])
**************************************************
model_state_dict_linear2=
tensor([[ 0.4185, 0.1556, 0.1371],[ 0.4751, 0.2029, -0.0679],[ 0.1264, -0.0288, -0.3661],[ 0.4423, -0.5370, 0.3930]])
**************************************************
model_named_para=
[('linear1.weight', Parameter containing:
tensor([[-0.5518, 0.0687],[-0.7013, 0.4869],[-0.1157, -0.1287]], requires_grad=True)), ('linear1.bias', Parameter containing:
tensor([-0.2915, -0.4807, 0.0071], requires_grad=True)), ('linear2.weight', Parameter containing:
tensor([[ 0.4185, 0.1556, 0.1371],[ 0.4751, 0.2029, -0.0679],[ 0.1264, -0.0288, -0.3661],[ 0.4423, -0.5370, 0.3930]], requires_grad=True)), ('linear2.bias', Parameter containing:
tensor([ 0.2746, -0.1798, 0.0218, 0.5465], requires_grad=True)), ('batch_norm4.weight', Parameter containing:
tensor([1., 1., 1., 1.], requires_grad=True)), ('batch_norm4.bias', Parameter containing:
tensor([0., 0., 0., 0.], requires_grad=True))]
**************************************************
model_named_modules=
[('', MyModel((linear1): Linear(in_features=2, out_features=3, bias=True)(linear2): Linear(in_features=3, out_features=4, bias=True)(batch_norm4): BatchNorm2d(4, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
)), ('linear1', Linear(in_features=2, out_features=3, bias=True)), ('linear2', Linear(in_features=3, out_features=4, bias=True)), ('batch_norm4', BatchNorm2d(4, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True))]
**************************************************
model_named_buffers=
[('batch_norm4.running_mean', tensor([0., 0., 0., 0.])), ('batch_norm4.running_var', tensor([1., 1., 1., 1.])), ('batch_norm4.num_batches_tracked', tensor(0))]
**************************************************
model_named_children=
[('linear1', Linear(in_features=2, out_features=3, bias=True)), ('linear2', Linear(in_features=3, out_features=4, bias=True)), ('batch_norm4', BatchNorm2d(4, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True))]
```
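Note that `model_parameters` and `model_buffer` above print as empty `OrderedDict`s: `_parameters` and `_buffers` only hold entries registered directly on the module itself, while the weights of `linear1`, `linear2` and `batch_norm4` live inside those submodules. In contrast, `state_dict()`, `named_parameters()` and `named_buffers()` traverse the submodule tree recursively. A minimal sketch that registers state directly on the top-level module (the names `scale` and `running_sum` are made up for illustration):

```python
import torch
from torch import nn

class WithOwnState(nn.Module):
    def __init__(self):
        super().__init__()
        # registered directly on this module -> shows up in _parameters
        self.scale = nn.Parameter(torch.ones(3))
        # a buffer is saved in state_dict() but receives no gradient
        self.register_buffer("running_sum", torch.zeros(3))

m = WithOwnState()
print(m._parameters)          # OrderedDict([('scale', Parameter containing: ...)])
print(m._buffers)             # OrderedDict([('running_sum', tensor([0., 0., 0.]))])
print(m.state_dict().keys())  # odict_keys(['scale', 'running_sum'])
```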