mindspore mindcv图像分类算法;模型保存与加载

参考:
https://www.mindspore.cn/tutorials/en/r1.3/save_load_model.html
https://github.com/mindspore-lab/mindcv/blob/main/docs/zh/tutorials/finetune.md

1、mindspore mindcv图像分类算法

import os
from mindcv.utils.download import DownLoad
import os
import mindspore as msos.environ['DEVICE_ID']='0'
ms.set_context(mode=ms.GRAPH_MODE, device_target="CPU", device_id=0)  ##指定cpu
#ms.set_context(mode=ms.GRAPH_MODE, device_target="Ascend", device_id=0)  ##需要使用才能npu加速dataset_url = "https://mindspore-website.obs.cn-north-4.myhuaweicloud.com/notebook/datasets/intermediate/Canidae_data.zip"
root_dir = "./"if not os.path.exists(os.path.join(root_dir, 'data/Canidae')):DownLoad().download_and_extract_archive(dataset_url, root_dir)##加载数据from mindcv.data import create_dataset, create_transforms, create_loadernum_workers = 8# 数据集目录路径
data_dir = "./data/Canidae/"# 加载自定义数据集
dataset_train = create_dataset(root=data_dir, split='train', num_parallel_workers=num_workers)
dataset_val = create_dataset(root=data_dir, split='val', num_parallel_workers=num_workers)# 定义和获取数据处理及增强操作
trans_train = create_transforms(dataset_name='ImageNet', is_training=True)
trans_val = create_transforms(dataset_name='ImageNet',is_training=False)loader_train = create_loader(dataset=dataset_train,batch_size=16,is_training=True,num_classes=2,transform=trans_train,num_parallel_workers=num_workers,
)
loader_val = create_loader(dataset=dataset_val,batch_size=5,is_training=True,num_classes=2,transform=trans_val,num_parallel_workers=num_workers,
)#模型微调from mindcv.models import create_modelnetwork = create_model(model_name='densenet121', num_classes=2, pretrained=True)#训练
from mindcv.loss import create_loss
from mindcv.optim import create_optimizer
from mindcv.scheduler import create_scheduler
from mindspore import Model, LossMonitor, TimeMonitor# 定义优化器和损失函数
opt = create_optimizer(network.trainable_params(), opt='adam', lr=1e-4)
loss = create_loss(name='CE')# 实例化模型
model = Model(network, loss_fn=loss, optimizer=opt, metrics={'accuracy'})
model.train(10, loader_train, callbacks=[LossMonitor(5), TimeMonitor(5)], dataset_sink_mode=False)res = model.eval(loader_val)
print(res)import matplotlib.pyplot as plt
import mindspore as ms
import numpy as npdef visualize_model(model, val_dl, num_classes=2):# 加载验证集的数据进行验证images, labels= next(val_dl.create_tuple_iterator())# 预测图像类别output = model.predict(images)pred = np.argmax(output.asnumpy(), axis=1)# 显示图像及图像的预测值images = images.asnumpy()labels = labels.asnumpy()class_name = {0: "dogs", 1: "wolves"}plt.figure(figsize=(15, 7))for i in range(len(labels)):plt.subplot(3, 6, i + 1)# 若预测正确,显示为蓝色;若预测错误,显示为红色color = 'blue' if pred[i] == labels[i] else 'red'plt.title('predict:{}'.format(class_name[pred[i]]), color=color)picture_show = np.transpose(images[i], (1, 2, 0))mean = np.array([0.485, 0.456, 0.406])std = np.array([0.229, 0.224, 0.225])picture_show = std * picture_show + meanpicture_show = np.clip(picture_show, 0, 1)plt.imshow(picture_show)plt.axis('off')plt.show()visualize_model(model, loader_val)

在这里插入图片描述

在这里插入图片描述

2、模型保存与加载

## 保存模型
import mindspore as msfrom mindcv.models import create_modelnetwork = create_model(model_name='densenet121', num_classes=2, pretrained=True)ms.save_checkpoint(network, "model1.ckpt")
## 加载模型
from mindspore import load_checkpoint, load_param_into_net
from mindspore import Modelparam_dict = load_checkpoint("model1.ckpt")
param_not_load = load_param_into_net(network, param_dict)
print(param_not_load)model1 = Model(network, loss, metrics={"accuracy"})

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.rhkb.cn/news/185505.html

如若内容造成侵权/违法违规/事实不符,请联系长河编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

一文搞定多端开发,做全栈大牛 附三大企业实战项目

一个功能三套代码 一改需求就是加不完的班? 不存在的,告别改改改 拥抱多端开发 一套代码搞定多个平台 高效开发:一套代码,多端通用 根据统计数据,全球移动设备用户数已经超过了50亿。随着智能手机、平板电脑等移动…

AI:68-基于深度学习的身份证号码识别

🚀 本文选自专栏:AI领域专栏 从基础到实践,深入了解算法、案例和最新趋势。无论你是初学者还是经验丰富的数据科学家,通过案例和项目实践,掌握核心概念和实用技能。每篇案例都包含代码实例,详细讲解供大家学习。 📌📌📌在这个漫长的过程,中途遇到了不少问题,但是…

剑指JUC原理-15.ThreadLocal

👏作者简介:大家好,我是爱吃芝士的土豆倪,24届校招生Java选手,很高兴认识大家📕系列专栏:Spring源码、JUC源码🔥如果感觉博主的文章还不错的话,请👍三连支持&…

【Mac开发环境搭建】安装HomeBrew、HomeBrew安装Docker、Docker安装Mysql5.7和8

文章目录 HomeBrew安装相关命令安装包卸载包查询可用的包更新所有包更新指定包查看已经安装的包查看包的信息清理包查看brew的版本更新brew获取brew的帮助信息 Brew安装DockerDocker常用命令镜像相关查看已经拉取的所有镜像删除镜像 容器相关停止运行容器启动容器重启容器删除容…

OSPF下的MGRE实验

一、实验要求 1、R1-R3-R4构建全连的MGRE环境 2、R1-R5-R6建立hub-spoke的MGRE环境,其中R1为中心 3、R1-R3...R6均存在环回网段模拟用户私网,使用OSPF使全网可达 4、其中R2为ISP路由器,仅配置IP地址 二、实验拓扑图 三、实验配置 1、给各路…

自然语言处理中的文本聚类:揭示模式和见解

一、介绍 在自然语言处理(NLP)领域,文本聚类是一种基本且通用的技术,在信息检索、推荐系统、内容组织和情感分析等各种应用中发挥着关键作用。文本聚类是将相似文档或文本片段分组为簇或类别的过程。这项技术使我们能够发现隐藏的…

企业级低代码开发,科技赋能让企业具备“驾驭软件的能力”

科技作为第一生产力,其强大的影响力在各个领域中都有所体现。数字技术,作为科技领域中的一股重要力量,正在对传统的商业模式进行深度的变革,为各行业注入新的生命力。随着数字技术的不断发展和应用,企业数字化转型的趋…

10-27 maven概念

maven maven的概念模型: 项目对象模型(POM: Project object Model),一组标准集合: pom.xml 依赖管理系统(Dependency Management System) 项目生命周期(Project Lifecycle) 项目对象模型: 把项目当成一个对象,描述这个项目,使用p…

计算机网络实验

计算机网络实验 使用软件PT7.0按照上面的拓扑结构建立网络,进行合理配置,使得所有计算机之间能够互相通信。并且修改各交换机的系统名称为:学号_编号,如你的学号为123,交换机Switch0的编号为0,则系统名称为…

Flutter利用GridView创建网格布局实现优美布局

文章目录 简介使用详解导入依赖项创建一个基本的 GridView一些参数说明使用GridView.count来构造 其他控制总结 简介 GridView 是 Flutter 中用于创建网格布局的强大小部件。它允许你在行和列中排列子小部件,非常适合显示大量项目,例如图像、文本、卡片…

【Royalty in Wind 2.0.0】个人体测计算、资料分享小程序

前言 Royalty in Wind 是我个人制作的一个工具类小程序。主要涵盖体测计算器、个人学习资料分享等功能。这个小程序在2022年第一次发布,不过后来因为一些原因暂时搁置。现在准备作为我个人的小程序重新投入使用XD PS:小程序开发部分我是在21年跟随郄培…

C# 继承,抽象,接口,泛型约束,扩展方法

文章目录 前言模拟需求场景模拟重复性高的需求初始类结构继承优化抽象类 需求1:打印CreateTime方法1:使用重载方法2:基类函数方法3:泛型约束方法3.1:普通泛型方法方法3.2:高级泛型约束,扩展方法…

uniapp u-tabs表单如何默认选中

首先先了解该组件;该组件,是一个tabs标签组件,在标签多的时候,可以配置为左右滑动,标签少的时候,可以禁止滑动。 该组件的一个特点是配置为滚动模式时,激活的tab会自动移动到组件的中间位置。 …

npm发布自己的包

npm发布自己的包 1. 首先在npm官网注册一个自己的账户(有账号的可以直接登录) 注册地址 2. 创建一个自己的项目(如果已有自己的项目, 跳过这一步) npm init -y3. 确认自己的npm下载源, 只能使用npm官方的地址 npm config get registry修改地址源 npm config set registr…

Spring Cloud学习(二)【Eureka注册中心】

文章目录 Eureka 注册中心Eureka 的作用 动手实践搭建 EurekaServer服务注册服务发现 Ribbon 负载均衡负载均衡原理IRule 接口(负载均衡策略)饥饿加载 Eureka 注册中心 服务调用出现的问题 不能采用硬编码服务消费者该如何获取服务提供者的地址信息&am…

苹果Mac电脑fcpx视频剪辑:Final Cut Pro中文最新 for mac

Final Cut Pro是苹果公司开发的一款专业视频剪辑软件,它为原生64位软件,基于Cocoa编写,支持多路多核心处理器,支持GPU加速,支持后台渲染。Final Cut Pro在Mac OS平台上运行,适用于进行后期制作。 Final Cu…

小白学安全-KunLun-M静态白盒扫描工具

一、KunLun-M简介 KunLun-M是一个完全开源的静态白盒扫描工具,支持PHP、JavaScript的语义扫描,基础安全、组件安全扫描,Chrome Ext\Solidity的基础扫描。开源地址:https://github.com/LoRexxar/Kunlun-M Cobra是一款源代码安全审计…

箭头函数 跟匿名函数this的指向问题

var id 10; function foo() {// 创建时 this->windowthis.id 20; // 等价于 window.id 20let c () > {console.log("id1:", this.id); // 创建时父级 创建时 this->window};let d function () {console.log("id2:", this.id); // 执行时本…

计算机毕业设计 基于SpringBoot的私人西服定制系统的设计与实现 Java实战项目 附源码+文档+视频讲解

博主介绍:✌从事软件开发10年之余,专注于Java技术领域、Python人工智能及数据挖掘、小程序项目开发和Android项目开发等。CSDN、掘金、华为云、InfoQ、阿里云等平台优质作者✌ 🍅文末获取源码联系🍅 👇🏻 精…

Spring Cloud LoadBalancer 负载均衡策略与缓存机制

目录 1. 什么是 LoadBalancer ? 2. 负载均衡策略的分类 2.1 常见的负载均衡策略 3. 为什么要学习 Spring Cloud Balancer ? 4. Spring Cloud LoadBalancer 内置的两种负载均衡策略 4.1 轮询负载均衡策略(默认的) 4.2 随机负…