【笔记】打卡01 | 初学入门

初学入门:01-02

  • 01 基本介绍
  • 02 快速入门
    • 处理数据集
    • ==网络构建==
    • 模型训练
    • 保存模型
    • 加载模型
    • 打卡-时间

01 基本介绍

MindSpore Data(数据处理层)
ModelZoo(模型库)
MindSpore Science(科学计算),包含了业界领先的数据集、基础模型、预置高精度模型和前后处理工具
MindSpore Insight(可视化调试调优工具),能够可视化地查看训练过程、优化模型性能、调试精度问题、解释推理结果

02 快速入门

import mindspore
from mindspore import nn
from mindspore.dataset import vision, transforms
from mindspore.dataset import MnistDataset

处理数据集

下载Mnist数据集

# Download data from open datasets
from download import downloadurl = "https://mindspore-website.obs.cn-north-4.myhuaweicloud.com/" \"notebook/datasets/MNIST_Data.zip"
path = download(url, "./", kind="zip", replace=True)

在这里插入图片描述

训练集、测试集

train_dataset = MnistDataset('MNIST_Data/train')
test_dataset = MnistDataset('MNIST_Data/test')

在这里插入图片描述
列名:图片 和 对应标签(分类)

数据处理流水线(Data Processing Pipeline)

参数:数据集、batch_size

def datapipe(dataset, batch_size):image_transforms = [                    vision.Rescale(1.0 / 255.0, 0),vision.Normalize(mean=(0.1307,), std=(0.3081,)),vision.HWC2CHW()]label_transform = transforms.TypeCast(mindspore.int32)dataset = dataset.map(image_transforms, 'image')dataset = dataset.map(label_transform, 'label')dataset = dataset.batch(batch_size)return dataset

首先,数据变换(Transforms):1、对输入数据(即图片)2、对输出(即标签);
然后,map对图像数据及标签进行变换处理;
最后,将处理好的数据集打包为大小为64的batch

train_dataset = datapipe(train_dataset, 64)
test_dataset = datapipe(test_dataset, 64)

对数据集进行迭代访问

for data in test_dataset.create_dict_iterator():print(f"Shape of image [N, C, H, W]: {data['image'].shape} {data['image'].dtype}")print(f"Shape of label: {data['label'].shape} {data['label'].dtype}")break

网络构建

class Network(nn.Cell):def __init__(self):super().__init__()self.flatten = nn.Flatten()self.dense_relu_sequential = nn.SequentialCell(nn.Dense(28*28, 512),nn.ReLU(),nn.Dense(512, 512),nn.ReLU(),nn.Dense(512, 10))def construct(self, x):x = self.flatten(x)logits = self.dense_relu_sequential(x)return logitsmodel = Network()
print(model)

mindspore.nn类是构建所有网络的基类,也是网络的基本单元。

  • 自定义网络时,可以继承nn.Cell
  • __init__包含所有网络层的定义
  • construct(类似前向传播??)包含数据(Tensor)的变换过程。

模型训练

定义损失函数、优化器

loss_fn = nn.CrossEntropyLoss()
optimizer = nn.SGD(model.trainable_params(), 1e-2)

一个完整的训练过程(step)需要实现以下三步:

1. 正向计算:模型预测结果(logits),并与正确标签(label)求预测损失(loss)。
2. 反向传播:利用自动微分机制,自动求模型参数(parameters)对于loss的梯度(gradients)。
3. 参数优化:将梯度更新到参数上。

定义正向计算函数。

def forward_fn(data, label):logits = model(data)loss = loss_fn(logits, label)return loss, logits

使用value_and_grad通过函数变换获得梯度计算函数。

grad_fn = mindspore.value_and_grad(forward_fn, None, optimizer.parameters, has_aux=True)

one-step training

def train_step(data, label):(loss, _), grads = grad_fn(data, label)optimizer(grads)return loss

定义训练函数,使用set_train设置为训练模式,执行正向计算、反向传播和参数优化。

def train(model, dataset):size = dataset.get_dataset_size()model.set_train()for batch, (data, label) in enumerate(dataset.create_tuple_iterator()):loss = train_step(data, label)if batch % 100 == 0:loss, current = loss.asnumpy(), batchprint(f"loss: {loss:>7f}  [{current:>3d}/{size:>3d}]")

定义测试函数:用来评估模型的性能。

def test(model, dataset, loss_fn):num_batches = dataset.get_dataset_size()model.set_train(False)total, test_loss, correct = 0, 0, 0for data, label in dataset.create_tuple_iterator():pred = model(data)total += len(data)test_loss += loss_fn(pred, label).asnumpy()correct += (pred.argmax(1) == label).asnumpy().sum()test_loss /= num_batchescorrect /= totalprint(f"Test: \n Accuracy: {(100*correct):>0.1f}%, Avg loss: {test_loss:>8f} \n")

训练过程需多轮(epoch)训练数据集

epochs = 3
for t in range(epochs):print(f"Epoch {t+1}\n-------------------------------")train(model, train_dataset)test(model, test_dataset, loss_fn)
print("Done!")

在这里插入图片描述

保存模型

模型训练完成后,需要保存其参数。

mindspore.save_checkpoint(model, "model.ckpt")
print("Saved Model to model.ckpt")

加载模型

加载保存的权重

# 1、重新实例化模型对象,构造模型
model = Network()
# 加载模型参数,并将其加载至模型上。
param_dict = mindspore.load_checkpoint("model.ckpt")
param_not_load, _ = mindspore.load_param_into_net(model, param_dict)
print(param_not_load)

param_not_load是未被加载的参数列表,为空时代表所有参数均加载成功。

打卡-时间

from datetime import datetime
import pytz
# 设置时区为北京时区
beijing_tz = pytz.timezone('Asia/shanghai')
# 获取当前时间,并转为北京时间
current_beijing_time = datetime.now(beijing_tz)
# 格式化时间输出
formatted_time = current_beijing_time.strftime('%Y-%m-%d %H:%M:%S')
print("当前北京时间:",formatted_time,'your name')

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.rhkb.cn/news/355437.html

如若内容造成侵权/违法违规/事实不符,请联系长河编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

视频融合共享平台LntonCVS视频监控安防系统运用多视频协议建设智慧园区方案

智慧园区,作为现代化城市发展的重要组成部分,不仅推动了产业的升级转型,也成为了智慧城市建设的核心力量。随着产业园区之间的竞争日益激烈,如何打造一个功能完善、智能化程度高的智慧园区,已经成为了业界广泛关注的焦…

MacOS设备远程登录配置结合内网穿透实现异地ssh远程连接

文章目录 前言1. MacOS打开远程登录2. 局域网内测试ssh远程3. 公网ssh远程连接MacOS3.1 MacOS安装配置cpolar3.2 获取ssh隧道公网地址3.3 测试公网ssh远程连接MacOS 4. 配置公网固定TCP地址4.1 保留一个固定TCP端口地址4.2 配置固定TCP端口地址 5. 使用固定TCP端口地址ssh远程 …

在 Visual Studio 2022 (Visual C++ 17) 中使用 Visual Leak Detector

在 Visual C 2022 中使用 Visual Leak Detector 1 问题描述1.1 内存泄漏的困扰和解决之道1.2 内存泄漏检测工具的选择1.3 VLD的现状 2 安装和设置VLD的环境变量2.1 安装VLD文件2.2 VLD安装后的目录和文件说明2.2.1 include子目录说明2.2.2 lib子目录说明2.2.2.1 目录整理 2.2.3…

感恩的力量!美洲杯魔幻提前预告 阿根廷 ——早读(逆天打工人爬取热门微信文章解读)

梅西还能不能提? 引言Python 代码第一篇 洞见 感恩的力量(深度好文)第二篇 视频新闻结尾 引言 早上早起 昨天晚上1点多才睡 这几天都是 明明很早就准备上床睡觉 但是就是忍不住 吃根雪糕 喝个小饮料 看看最近的欧洲杯比赛 卒 真的是拖延症十…

【总结】ui自动化selenium知识点总结

1. 大致原理 首页安装第三方库selenium库, 其次要下载好浏览器驱动文件,比如谷歌的 chromedriver.exe,配置上环境变量。 使用selenium的webdriver类去创建一个浏览器驱动对象赋值叫driver,一个浏览器驱动对象就可以 实现 对浏…

python爬虫之selenium自动化操作

python爬虫之selenium自动化操作 需求:操作淘宝去掉弹窗广告搜索物品后进入百度回退又前进 selenium模块的基本使用 问题:selenium模块和爬虫之间具有怎样的关联? 1、便捷的获取网站中动态加载的数据 2、便捷实现模拟登录 什么是selenium模块&#x…

【python】PyQt5初体验,窗口等组件开发技巧,面向对象方式开发流程实战

✨✨ 欢迎大家来到景天科技苑✨✨ 🎈🎈 养成好习惯,先赞后看哦~🎈🎈 🏆 作者简介:景天科技苑 🏆《头衔》:大厂架构师,华为云开发者社区专家博主,…

文件扫描工具哪个好?便捷的文件扫描工具推荐

对于初入职场的大学毕业生,申请就业补贴是一项不可忽视的福利。 它不仅能够为新生活带来经济上的缓解,也有助于职业生涯的顺利起步。面对申请过程中需提交的文件,如纸质劳动合同,不必烦恼。市面上众多文件扫描软件能助你一臂之力…

轮式机器人Swiss-Mile城市机动性大提升:强化学习引领未来城市物流

喜好儿小斥候消息,苏黎世联邦理工学院的研究团队成功开发了一款革命性的机器人控制系统,该系统采用强化学习技术,使轮式四足机器人在城市环境中的机动性和速度得到了显著提升。 喜好儿网 这款专为轮腿四足动物设计的控制系统,能…

一种基于图卷积创新的电场强度监测模型,原创未发表!!!

声明:文章是从本人公众号中复制而来,因此,想最新最快了解各类算法的家人,可关注我的VX公众号:python算法小当家,不定期会有很多免费代码分享~ 一种基于图卷积创新的电场强度监测模型,原创未发表…

环境配置02:CUDA安装

1. CUDA安装 Nvidia官网下载对应版本CUDA Toolkit CUDA Toolkit 12.1 Downloads | NVIDIA Developer CUDA Toolkit 12.5 Downloads | NVIDIA Developer 安装配置步骤参考:配置显卡cuda与配置pytorch - 知乎 (zhihu.com) 2. 根据CUDA版本,安装cudnn …

内容安全复习 2 - 网络信息内容的获取与表示

文章目录 信息内容的获取网络信息内容的类型网络媒体信息获取方法 信息内容的表示视觉信息视觉特征表达文本特征表达音频特征表达 信息内容的获取 网络信息内容的类型 网络媒体信息 传统意义上的互联网网站公开发布信息,网络用户通常可以基于网络浏览器获得。网络…

mysql8.0找不到my.ini

报错问题解释: MySQL 8.0 在Windows系统中通常不需要 my.ini 文件,因为安装程序会在 %PROGRAMDATA%\MySQL\MySQL Server 8.0\ (通常是 C:\ProgramData\MySQL\MySQL Server 8.0\)创建默认的配置文件。如果你的系统中找不到 my.ini…

ArcGIS查找相同图斑、删除重复图斑

​ 点击下方全系列课程学习 点击学习—>ArcGIS全系列实战视频教程——9个单一课程组合系列直播回放 点击学习——>遥感影像综合处理4大遥感软件ArcGISENVIErdaseCognition 这次是上次 今天分享一下,很重要却被大家忽略的两个工具 这两个工具不仅可以找出属性…

解决电脑关机难题:电脑关不了机的原因以及方法

在使用电脑的日常生活中,有时会遇到一些烦人的问题,其中之一就是电脑关不了机。当您尝试关闭电脑时,它可能会停留在某个界面,或者根本不响应关机指令。这种情况不仅令人困惑,还可能导致数据丢失或系统损坏。 在本文中…

编译xlnt开源库源码, 使用c++读写excel文件

编译xlnt开源库源码,在linux平台使用c读写excel文件 下载xnlt源码 官方网站https://tfussell.gitbooks.io/xlnt/content/ 下载地址https://github.com/tfussell/xlnt 下载libstudxml开源库源码 下载地址https://github.com/kamxgal/libstudxml 下载xnlt源码 官方网站https://…

Walrus:去中心化存储和DA协议,可以基于Sui构建L2和大型存储

Walrus是为区块链应用和自主代理提供的创新去中心化存储网络。Walrus存储系统今天以开发者预览版的形式发布,面向Sui开发者征求反馈意见,并预计很快会向其他Web3社区广泛推广。 通过采用纠删编码创新技术,Walrus能够快速且稳健地将非结构化数…

C++的动态内存分配

使用new/delete操作符在堆中分配/释放内存 //使用new操作符在堆中分配内存int* p1 new int;*p1 2234;qDebug() << "数字是&#xff1a;" << *p1;//使用delete操作符在堆中释放内存delete p1;在分配内存的同时初始化 //在分配内存的时初始化int* p2 n…

海外云手机自动化管理,高效省力解决方案

不论是企业还是个人&#xff0c;对于海外社媒的营销都是需要自动化管理的&#xff0c;因为自动化管理不仅省时省力&#xff0c;而且还节约成本&#xff1b; 海外云手机的自动化管理意味着什么&#xff1f;那就是企业无需再投入大量的人力和时间去逐一操作和监控每一台设备。 通…

k8s学习--OpenKruise详细解释以及原地升级及全链路灰度发布方案

文章目录 OpenKruise简介OpenKruise来源OpenKruise是什么&#xff1f;核心组件有什么&#xff1f;有什么特性和优势&#xff1f;适用于什么场景&#xff1f; 什么是OpenKruise的原地升级原地升级的关键特性使用原地升级的组件原地升级的工作原理 应用环境一、OpenKruise部署1.安…