【昇思初学入门】第七天打卡-模型训练

训练模型

学习心得

  1. 构建数据集。这通常包括训练集、验证集(可选)和测试集。训练集用于训练模型,验证集用于调整超参数和监控过拟合,测试集用于评估模型的泛化能力。
    (mindspore提供数据集https://www.mindspore.cn/docs/zh-CN/r2.3.0rc2/api_python/mindspore.dataset.html)
  2. 定义神经网络模型。这通常涉及到选择适当的网络架构(如卷积神经网络CNN、循环神经网络RNN、全连接网络等)和激活函数。
    创建模型类:使用mindspore.nn.Cell作为基类,创建一个自定义的神经网络模型类。
    义网络层:定义所需的网络,如卷积层、全连接层、激活函数和池化层等
    实现construct方法:在construct方法中,使用定义好的网络层构建前向网络
  3. 定义超参、损失函数和优化器。
    设置超参数:设置超参数,如学习率、批次大小、训练轮数等。
    定义损失函数:选择适当的损失函数,如均方误差(MSE)用于回归问题,交叉熵损失(Cross-Entropy Loss)用于分类问题等。
    设置优化器:选择合适的优化器,如随机梯度下降(SGD)、Adam等,用于根据损失函数的梯度更新模型参数。
  4. 训练和评估。
    循环输入数据来训练模型。一次数据集的完整迭代循环称为一轮(epoch)。每轮执行训练时包括两个步骤:
    训练:迭代训练数据集,并尝试收敛到最佳参数。
    验证/测试:迭代测试数据集,以检查模型性能是否提升。

笔记

import mindspore
from mindspore import nn
from mindspore.dataset import vision, transforms
from mindspore.dataset import MnistDataset# Download data from open datasets
from download import downloadurl = "https://mindspore-website.obs.cn-north-4.myhuaweicloud.com/" \"notebook/datasets/MNIST_Data.zip"
path = download(url, "./", kind="zip", replace=True)def datapipe(path, batch_size):image_transforms = [vision.Rescale(1.0 / 255.0, 0),vision.Normalize(mean=(0.1307,), std=(0.3081,)),vision.HWC2CHW()]label_transform = transforms.TypeCast(mindspore.int32)dataset = MnistDataset(path)dataset = dataset.map(image_transforms, 'image')dataset = dataset.map(label_transform, 'label')dataset = dataset.batch(batch_size)return datasettrain_dataset = datapipe('MNIST_Data/train', batch_size=64)
test_dataset = datapipe('MNIST_Data/test', batch_size=64)class Network(nn.Cell):def __init__(self):super().__init__()self.flatten = nn.Flatten()self.dense_relu_sequential = nn.SequentialCell(nn.Dense(28*28, 512),nn.ReLU(),nn.Dense(512, 512),nn.ReLU(),nn.Dense(512, 10))def construct(self, x):x = self.flatten(x)logits = self.dense_relu_sequential(x)return logitsmodel = Network()epochs = 3
batch_size = 64
learning_rate = 1e-2loss_fn = nn.CrossEntropyLoss()
optimizer = nn.SGD(model.trainable_params(), learning_rate=learning_rate)# Define forward function
def forward_fn(data, label):logits = model(data)loss = loss_fn(logits, label)return loss, logits# Get gradient function
grad_fn = mindspore.value_and_grad(forward_fn, None, optimizer.parameters, has_aux=True)# Define function of one-step training
def train_step(data, label):(loss, _), grads = grad_fn(data, label)optimizer(grads)return lossdef train_loop(model, dataset):size = dataset.get_dataset_size()model.set_train()for batch, (data, label) in enumerate(dataset.create_tuple_iterator()):loss = train_step(data, label)if batch % 100 == 0:loss, current = loss.asnumpy(), batchprint(f"loss: {loss:>7f}  [{current:>3d}/{size:>3d}]")def test_loop(model, dataset, loss_fn):num_batches = dataset.get_dataset_size()model.set_train(False)total, test_loss, correct = 0, 0, 0for data, label in dataset.create_tuple_iterator():pred = model(data)total += len(data)test_loss += loss_fn(pred, label).asnumpy()correct += (pred.argmax(1) == label).asnumpy().sum()test_loss /= num_batchescorrect /= totalprint(f"Test: \n Accuracy: {(100*correct):>0.1f}%, Avg loss: {test_loss:>8f} \n")loss_fn = nn.CrossEntropyLoss()
optimizer = nn.SGD(model.trainable_params(), learning_rate=learning_rate)for t in range(epochs):print(f"Epoch {t+1}\n-------------------------------")train_loop(model, train_dataset)test_loop(model, test_dataset, loss_fn)
print("Done!")

结果
训练结果

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.rhkb.cn/news/360610.html

如若内容造成侵权/违法违规/事实不符,请联系长河编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

深度学习:从理论到应用的全面解析

引言 深度学习作为人工智能(AI)的核心技术之一,在过去的十年中取得了显著的进展,并在许多领域中展示了其强大的应用潜力。本文将从理论基础出发,探讨深度学习的最新进展及其在各领域的应用,旨在为读者提供全…

web自动化(一)selenium安装环境搭建、DrissionPage安装

selenium 简介 selenium是企业广泛应用的web自动化框架 selenium 三大组件 selenium IDE 浏览器插件 实现脚本录制 webDriver 实现对浏览器进行各种操作 Grid 分布式执行 用例同时在多个浏览器执行,提高测试效率 问题:环境搭建复杂,浏览器版…

PPT的精细化优化与提升策略

👏👏👏欢迎来到我的博客 ! 亲爱的朋友们,欢迎您们莅临我的博客!这是一个分享知识、交流想法、记录生活的温馨角落。在这里,您可以找到我对世界独特视角的诠释,也可以与我一起探讨各种话题&#…

STL——常用算法(二)

一、常用拷贝和替换算法 1.copy #include <iostream> #include <vector> #include <algorithm> using namespace std; void printVector(int val) {cout << val << " "; } void test01() {vector<int>v1;for (int i 0; i <…

【大数据】—谁是世界上最富的人?

引言 在2024年&#xff0c;全球财富的分布再次成为公众和经济学家关注的焦点。随着经济的波动和新兴市场的崛起&#xff0c;亿万富翁的名单也在不断变化。本文将深入探讨这一现象&#xff0c;通过最新的数据可视化分析&#xff0c;揭示世界上最富有的人在2024年的财富状况和趋…

【Linux】进程信号_1

文章目录 八、进程信号1.信号 未完待续 八、进程信号 1.信号 信号和信号量之间没有任何关系。信号是Linux系统提供的让用户/进程给其他进程发送异步信息的一种方式。 常见信号&#xff1a; 当信号产生时&#xff0c;可选的处理方式有三种&#xff1a;①忽略此信号。②执行该…

(七)React:useEffect的理解和使用

1. useEffect的概念理解 useEffect是一个React Hook函数&#xff0c;用于React组件中创建不是由事件引起而是由渲染本身引起的操作&#xff0c;比如发送AJAX请求&#xff0c;更改DOM等等 说明&#xff1a;上面的组件中没有发生任何的用户事件&#xff0c;组件渲染完毕之后就需…

Ollama模型部署工具在Linux平台的部署

1.新建普通用户dmx&#xff08;可选&#xff09; [rootnode3 ~]$ useradd dmx2.切换普通用户dmx环境(可选) [dmxnode3 ~]$ su - dmx3.下载ollama-linux-amd64服务 下载ollama-linux-amd64到 ~/server目录&#xff0c;并将ollama-linux-amd64服务重命名为ollamaEED curl -L …

圈复杂度.

圈复杂度是衡量代码的重要标准 配置&#xff1a; eslint里面&#xff1a;rules&#xff1a;complexity&#xff1a;[error,10]

Linux-笔记 全志T113移植正点4.3寸RGB屏幕笔记

目录 前言 线序整理 软件 显示调试 触摸调试 背光调试 前言 由于手头有一块4.3寸的RGB屏幕(触摸IC为GT1151)&#xff0c;正好开发板上也有40Pin的RGB接口&#xff0c;就想着给移植一下&#xff0c;前期准备工作主要是整理好线序&#xff0c;然后用转接板与杜邦线连接验证好…

大模型该如何和医疗方面结合创造出更大的价值?

前言 在数字化与智能化浪潮的推动下&#xff0c;大模型技术正以其强大的数据处理和学习能力&#xff0c;成为引领新一轮科技革命的重要力量。而医疗领域&#xff0c;作为与人类健康息息相关的重要领域&#xff0c;与大模型的结合无疑将释放出巨大的价值&#xff0c;为人类的健…

Java三层框架的解析

引言&#xff1a;欢迎各位点击收看本篇博客&#xff0c;在历经很多的艰辛&#xff0c;我也是成功由小白浅浅进入了入门行列&#xff0c;也是收货到很多的知识&#xff0c;每次看黑马的JavaWeb课程视频&#xff0c;才使一个小菜鸡见识到了Java前后端是如何进行交互访问的&#x…

游戏服务器研究二:大世界的 scale 问题

这是一个非常陈旧的话题了&#xff0c;没什么新鲜的&#xff0c;但本人对 scale 比较感兴趣&#xff0c;所以研究得比较多。 本文不会探讨 MMO 类的网游提升单服承载人数有没有意义&#xff0c;只单纯讨论技术上如何实现。 像 moba、fps、棋牌、体育竞技等 “开房间类型的游戏…

如何挑选洗地机?盘点口碑最好的四大洗地机

在购买洗地机这种智能家电时&#xff0c;大家都应该格外谨慎。毕竟&#xff0c;洗地机价格不菲&#xff0c;精打细算&#xff0c;确保物尽其用才是最重要的。谁都不想花了高价买回来却让它闲置在墙角落灰尘。买之前我们还是需要对自己的需求做一个清晰的判断&#xff0c;实用性…

gitee添加别人的仓库后,在该仓库里添加文件夹/文件

一、在指定分支里添加文件夹&#xff08;如果库主没有创建分支&#xff0c;自己还要先创建分支&#xff09; eg:以在一个项目里添加视图文件为例&#xff0c;用Echarts分支在usr/views目录下添加Echarts文件夹&#xff0c;usr/views/Echarts目录下添加index.vue 1.切换为本地仓…

基于PHP的奶茶商城系统

有需要请加文章底部Q哦 可远程调试 基于PHP的奶茶商城系统 一 介绍 此奶茶商城系统基于原生PHP开发&#xff0c;数据库mysql&#xff0c;ajax实现数据交换。系统角色分为用户和管理员。系统在原有基础上添加了糖度的选择。 技术栈 phpmysqlajaxphpstudyvscode 二 功能 用户…

[20] Opencv_CUDA应用之 关键点检测器和描述符

Opencv_CUDA应用之 关键点检测器和描述符 本节中会介绍找到局部特征的各种方法&#xff0c;也被称为关键点检测器关键点(key-point)是表征图像的特征点&#xff0c;可用于准确定义对象 1. 加速段测试特征功能检测器 FAST算法用于检测角点作为图像的关键点&#xff0c;通过对…

2-16 基于matlab的动载荷简支梁模态分析程序

基于matlab的动载荷简支梁模态分析程序&#xff0c;可调节简支梁参数&#xff0c;包括截面宽、截面高、梁长度、截面惯性矩、弹性模量、密度。输出前四阶固有频率&#xff0c;任意时刻、位置的响应结果。程序已调通&#xff0c;可直接运行。 2-16 matlab 动载荷简支梁模态分析 …

什么是营销翻译?为什么要使用它?

营销翻译是将营销活动和宣传品翻译成不同语言的过程。它可能涉及翻译您的&#xff1a; 网站营销文案&#xff0c;社交媒体帖子&#xff0c;演示文稿&#xff0c;新闻稿&#xff0c;产品包装&#xff0c;产品说明&#xff0c;海报&#xff0c;宣传册&#xff0c;以及 虽然企业…