参考资料：
https://www.mindspore.cn/tutorials/en/r1.3/save_load_model.html
https://github.com/mindspore-lab/mindcv/blob/main/docs/zh/tutorials/finetune.md
1、使用 MindSpore + MindCV 微调 DenseNet121 进行图像分类（Canidae 数据集：狗 / 狼二分类）
import os
from mindcv.utils.download import DownLoad
import os
import mindspore as ms

os.environ['DEVICE_ID'] = '0'
# Run on CPU.
ms.set_context(mode=ms.GRAPH_MODE, device_target="CPU", device_id=0)
# To use Ascend NPU acceleration instead, enable the line below (and remove the CPU one).
# ms.set_context(mode=ms.GRAPH_MODE, device_target="Ascend", device_id=0)

# Download and extract the Canidae dataset into ./data/Canidae unless it is already present.
# NOTE: this assignment must sit on its own line -- in the collapsed paste it was swallowed
# by the preceding comment, leaving `dataset_url` undefined.
dataset_url = "https://mindspore-website.obs.cn-north-4.myhuaweicloud.com/notebook/datasets/intermediate/Canidae_data.zip"
root_dir = "./"
if not os.path.exists(os.path.join(root_dir, 'data/Canidae')):
    DownLoad().download_and_extract_archive(dataset_url, root_dir)

# Load the data.
from mindcv.data import create_dataset, create_transforms, create_loader

num_workers = 8
# Dataset directory path.
data_dir = "./data/Canidae/"

# Load the custom dataset (train / val splits).
dataset_train = create_dataset(root=data_dir, split='train', num_parallel_workers=num_workers)
dataset_val = create_dataset(root=data_dir, split='val', num_parallel_workers=num_workers)

# Define the data processing / augmentation pipelines.
trans_train = create_transforms(dataset_name='ImageNet', is_training=True)
trans_val = create_transforms(dataset_name='ImageNet', is_training=False)

loader_train = create_loader(
    dataset=dataset_train,
    batch_size=16,
    is_training=True,
    num_classes=2,
    transform=trans_train,
    num_parallel_workers=num_workers,
)
# The validation loader is only used for evaluation and visualization,
# so it is built in non-training mode.
loader_val = create_loader(
    dataset=dataset_val,
    batch_size=5,
    is_training=False,
    num_classes=2,
    transform=trans_val,
    num_parallel_workers=num_workers,
)

# Model fine-tuning: DenseNet121 with pretrained=True and a 2-class output.
from mindcv.models import create_model

network = create_model(model_name='densenet121', num_classes=2, pretrained=True)

# Training
from mindcv.loss import create_loss
from mindcv.optim import create_optimizer
from mindcv.scheduler import create_scheduler
from mindspore import Model, LossMonitor, TimeMonitor# 定义优化器和损失函数
opt = create_optimizer(network.trainable_params(), opt='adam', lr=1e-4)
loss = create_loss(name='CE')

# Instantiate the high-level Model wrapping network, loss and optimizer.
model = Model(network, loss_fn=loss, optimizer=opt, metrics={'accuracy'})

# Train for 10 epochs without dataset sink mode, then evaluate on the validation loader.
model.train(10, loader_train, callbacks=[LossMonitor(5), TimeMonitor(5)], dataset_sink_mode=False)
res = model.eval(loader_val)
print(res)

import matplotlib.pyplot as plt
import mindspore as ms
import numpy as np


def visualize_model(model, val_dl, num_classes=2, class_names=None):
    """Predict the first validation batch and plot each image with its predicted class.

    The title of each subplot is drawn in blue when the prediction matches the
    label and in red otherwise.

    Args:
        model: A ``mindspore.Model``; ``model.predict(images)`` is expected to
            return per-class scores (argmax is taken over axis 1).
        val_dl: A dataset/loader exposing ``create_tuple_iterator()`` that yields
            ``(images, labels)`` batches; only the first batch is plotted.
        num_classes: Kept for backward compatibility; not used by the plot.
        class_names: Optional mapping from class index to display name.
            Defaults to ``{0: "dogs", 1: "wolves"}``.
    """
    if class_names is None:
        class_names = {0: "dogs", 1: "wolves"}

    # Load one batch of validation data and predict its classes.
    images, labels = next(val_dl.create_tuple_iterator())
    output = model.predict(images)
    pred = np.argmax(output.asnumpy(), axis=1)

    images = images.asnumpy()
    labels = labels.asnumpy()

    # Per-channel mean/std used to undo input normalization for display.
    # Presumably these match create_transforms(dataset_name='ImageNet') -- confirm.
    mean = np.array([0.485, 0.456, 0.406])
    std = np.array([0.229, 0.224, 0.225])

    # Keep the original 3x6 grid for small batches; add rows when the batch
    # holds more than 18 images (subplot would otherwise raise).
    ncols = 6
    nrows = max(3, (len(labels) + ncols - 1) // ncols)

    plt.figure(figsize=(15, 7))
    for i, (image, label, guess) in enumerate(zip(images, labels, pred)):
        plt.subplot(nrows, ncols, i + 1)
        # Blue title for a correct prediction, red for a wrong one.
        color = 'blue' if guess == label else 'red'
        plt.title('predict:{}'.format(class_names[guess]), color=color)
        # CHW -> HWC for imshow, undo normalization, clip to the valid [0, 1] range.
        picture_show = np.transpose(image, (1, 2, 0))
        picture_show = np.clip(std * picture_show + mean, 0, 1)
        plt.imshow(picture_show)
        plt.axis('off')
    plt.show()


visualize_model(model, loader_val)
2、模型保存与加载
## 保存模型
import mindspore as ms
from mindcv.models import create_model

# Save the fine-tuned `network` from section 1. Re-creating it here with
# create_model(...) would rebind `network` to a freshly built model and
# discard the trained weights before they are saved.
ms.save_checkpoint(network, "model1.ckpt")
## 加载模型
from mindspore import load_checkpoint, load_param_into_net
from mindspore import Model
from mindcv.models import create_model

# Build a fresh network with the same architecture (no pretrained download) and
# restore the saved parameters into it, so the load step is actually exercised.
restored_network = create_model(model_name='densenet121', num_classes=2, pretrained=False)
param_dict = load_checkpoint("model1.ckpt")
# Names of parameters that could not be loaded; an empty result means everything matched.
param_not_load = load_param_into_net(restored_network, param_dict)
print(param_not_load)

# Wrap the restored network for evaluation/inference (no optimizer needed).
model1 = Model(restored_network, loss, metrics={"accuracy"})