昇思25天学习打卡营第7天|深度学习流程全解析:从模型训练到评估

目录

构建数据集

定义神经网络模型

定义超参、损失函数和优化器

超参

损失函数

优化器

训练与评估


构建数据集


        首先从数据集 Dataset加载代码,构建数据集。

        代码如下:

#引入了必要的库和模块,像 mindspore 以及相关的数据处理模块等等。  
import mindspore  
from mindspore import nn  
from mindspore.dataset import vision, transforms  
from mindspore.dataset import MnistDataset  
# Download data from open datasets  
#定义了一个下载函数,用于从特定的 url 下载 MNIST 数据集的压缩文件,并明确了保存路径。  
from download import download  
url = "https://mindspore-website.obs.cn-north-4.myhuaweicloud.com/" \  "notebook/datasets/MNIST_Data.zip"  
path = download(url, "./", kind="zip", replace=True)  
#定义了一个叫做 datapipe 的函数,这个函数是用来处理数据集的。  
def datapipe(path, batch_size):  #定义了一个名为 image_transforms 的列表  image_transforms = [  #将图像的像素值缩放到 0 到 1 的范围  vision.Rescale(1.0 / 255.0, 0),  #对图像进行标准化处理,使用给定的均值和标准差  vision.Normalize(mean=(0.1307,), std=(0.3081,)),  #改变图像的数据布局  vision.HWC2CHW()  ]  #定义了一个名为 label_transform 的操作,用于将标签转换为 mindspore.int32 类型。  label_transform = transforms.TypeCast(mindspore.int32)  #通过 MnistDataset 类读取指定路径的数据集。  dataset = MnistDataset(path)  #使用 map 方法对数据集中的图像应用 image_transforms 中的变换操作,对标签应用 label_transform 操作。  dataset = dataset.map(image_transforms, 'image')  dataset = dataset.map(label_transform, 'label')  #使用 batch 方法将数据集按照指定的 batch_size 进行分批处理。  dataset = dataset.batch(batch_size)  #函数返回处理后的数据集。  return dataset  

        运行结果:

        使用 datapipe 函数分别对训练集和测试集进行处理。为训练集和测试集指定了不同的路径,然而批大小均为 64 。处理结束后,将所得结果分别存放在 train_dataset 和 test_dataset 这两个变量当中,以便后续用于模型的训练与测试。

定义神经网络模型


        从网络构建中加载代码,构建一个神经网络模型。

        代码如下:

class Network(nn.Cell):  def __init__(self):  super().__init__()  self.flatten = nn.Flatten()  self.dense_relu_sequential = nn.SequentialCell(  nn.Dense(28*28, 512),  nn.ReLU(),  nn.Dense(512, 512),  nn.ReLU(),  nn.Dense(512, 10)  )  def construct(self, x):  x = self.flatten(x)  logits = self.dense_relu_sequential(x)  return logits  model = Network()  

        分析:这段代码定义了一个名为 Network 的类,它继承自 nn.Cell 。

        在 __init__ 方法(构造方法)中:

        调用了父类的构造方法。

        定义了一个 nn.Flatten 层用于展平输入数据。

        定义了一个名为 dense_relu_sequential 的序列层,其中包含了三个全连接层(nn.Dense)和两个 ReLU 激活函数层。

        在 construct 方法(前向传播方法)中,首先使用 flatten 层对输入 x 进行展平操作,然后将展平后的结果传入 dense_relu_sequential 序列层得到预测结果 logits 并返回。

        最后,创建了一个 Network 类的实例并将其赋值给 model 变量。

定义超参、损失函数和优化器


超参

        超参是可调节的参数,能掌控模型训练优化的进程,不同值可能影响模型训练与收敛速度。现今,深度学习模型多采用批量随机梯度下降算法优化。

        就优化来说,超参是影响模型性能收敛的关键。常见的训练超参有:

        训练轮次(epoch):指训练中遍历数据集的次数。

        批次大小(batch size):数据集分批训练,其每个批次数据的大小就是 batch size 。过小则耗时且梯度震荡,不利收敛;过大则梯度方向不变,易陷局部极小值。所以要选合适的 batch size ,以提升精度和实现全局收敛。

        学习率(learning rate):偏小会使收敛变慢,偏大可能导致训练不收敛等问题。梯度下降法常用于模型误差的参数优化,通过多次迭代和最小化损失函数预估参数,学习率控制着迭代中的学习进程。

        代码如下:

#训练轮次设置为 3 次。  
epochs = 3  
#批次大小设定为 64 。  
batch_size = 64  
#学习率设置为 0.01 (1e-2 表示 10 的 -2 次方,即 0.01 )  
learning_rate = 1e-2  

损失函数

        损失函数(loss function)用于衡量模型的预测值(logits)与目标值(targets)之间的偏差。在训练模型之初,随机初始化的神经网络模型往往会给出错误的预测结果。损失函数会评判预测结果和目标值的差异程度,模型训练的目的就是减小损失函数所计算出的误差。

        常见的损失函数有用于回归任务的 nn.MSELoss(均方误差)和用于分类的 nn.NLLLoss(负对数似然)等。nn.CrossEntropyLoss 融合了 nn.LogSoftmax 和 nn.NLLLoss,能够对 logits 进行标准化并计算预测误差。

        代码如下:

loss_fn = nn.CrossEntropyLoss()  

        分析:定义了一个损失函数变量 loss_fn ,并将其赋值为 nn.CrossEntropyLoss() ,即使用了 PyTorch 库中用于计算交叉熵损失的函数。在后续的模型训练中,会使用这个定义好的损失函数来计算模型预测结果与真实标签之间的误差。

优化器

        模型优化(Optimization)是于每个训练步骤中调整模型参数以降低模型误差的过程。MindSpore 提供多种优化算法的实现,称为优化器(Optimizer)。优化器内部界定了模型的参数优化流程(即梯度如何更新至模型参数),所有优化逻辑皆封装于优化器对象内。在此,我们运用 SGD(Stochastic Gradient Descent)优化器。

        我们借助 model.trainable_params()方法获取模型的可训练参数,并输入学习率超参来初始化优化器。

        代码如下:

optimizer = nn.SGD(model.trainable_params(), learning_rate=learning_rate)  

        分析:定义了一个优化器变量 optimizer ,使用了 PyTorch 中的随机梯度下降(Stochastic Gradient Descent,简称 SGD)优化器。它通过 model.trainable_params() 方法获取模型中可训练的参数,并将学习率设置为 learning_rate 这个变量所指定的值来初始化优化器。

训练与评估


        第一步:定义了模型训练的相关函数和训练循环的逻辑。包括前向传播计算损失、获取梯度、单步训练以及整个训练过程的循环,并定期打印损失信息。

        代码如下:

# 定义前向传播函数  
def forward_fn(data, label):  # 模型对输入数据进行预测得到预测值 logits  logits = model(data)  # 根据预测值和真实标签计算损失  loss = loss_fn(logits, label)  # 返回损失和预测值  return loss, logits  
# 获取梯度计算函数  
grad_fn = mindspore.value_and_grad(forward_fn, None, optimizer.parameters, has_aux=True)  
# 定义单步训练的函数  
def train_step(data, label):  # 调用梯度计算函数,得到损失和辅助信息,并计算梯度  (loss, _), grads = grad_fn(data, label)  # 优化器根据梯度更新模型参数  optimizer(grads)  # 返回损失值  return loss  
def train_loop(model, dataset):  # 获取数据集的大小  size = dataset.get_dataset_size()  # 设置模型为训练模式  model.set_train()  # 遍历数据集中的批次  for batch, (data, label) in enumerate(dataset.create_tuple_iterator()):  # 执行单步训练并获取损失值  loss = train_step(data, label)  # 每 100 个批次打印一次损失信息  if batch % 100 == 0:  loss, current = loss.asnumpy(), batch  print(f"loss: {loss:>7f}  [{current:>3d}/{size:>3d}]")  

        第二步:定义了一个测试循环的函数,用于在给定的数据集上对模型进行测试评估。计算了测试数据的平均损失和准确率,并打印出测试结果。

        代码如下:

def test_loop(model, dataset, loss_fn):  # 获取数据集中的批次数  num_batches = dataset.get_dataset_size()  # 设置模型为评估模式(非训练模式)  model.set_train(False)  # 初始化一些统计变量  total, test_loss, correct = 0, 0, 0  # 遍历数据集中的数据和标签  for data, label in dataset.create_tuple_iterator():  # 模型对输入数据进行预测  pred = model(data)  # 累计数据的数量  total += len(data)  # 累计损失值  test_loss += loss_fn(pred, label).asnumpy()  # 计算预测正确的数量  correct += (pred.argmax(1) == label).asnumpy().sum()  # 计算平均损失  test_loss /= num_batches  # 计算准确率  correct /= total  # 打印测试结果  print(f"Test: \n Accuracy: {(100*correct):>0.1f}%, Avg loss: {test_loss:>8f} \n")  

        第三步:进行了模型的训练和测试。首先定义了损失函数和优化器,然后按照设定的轮次数进行训练和测试,每一轮都打印轮次信息,最后打印训练完成的提示。

        代码如下:

loss_fn = nn.CrossEntropyLoss()  # 定义交叉熵损失函数  
optimizer = nn.SGD(model.trainable_params(), learning_rate=learning_rate)  # 定义随机梯度下降优化器,并传入模型的可训练参数和学习率  
for t in range(epochs):  # 进行多个训练轮次  print(f"Epoch {t+1}\n-------------------------------")  # 打印当前轮次信息  train_loop(model, train_dataset)  # 执行训练循环  test_loop(model, test_dataset, loss_fn)  # 执行测试循环  
print("Done!")  # 打印训练完成的提示  

        运行结果:

Epoch 1  
-------------------------------  
loss: 0.250805  [  0/938]  
loss: 0.130063  [100/938]  
loss: 0.074891  [200/938]  
loss: 0.330714  [300/938]  
loss: 0.298072  [400/938]  
loss: 0.177415  [500/938]  
loss: 0.469457  [600/938]  
loss: 0.380078  [700/938]  
loss: 0.225529  [800/938]  
loss: 0.200035  [900/938]  
Test:   Accuracy: 93.9%, Avg loss: 0.207253   Epoch 2  
-------------------------------  
loss: 0.289100  [  0/938]  
loss: 0.328313  [100/938]  
loss: 0.138099  [200/938]  
loss: 0.096204  [300/938]  
loss: 0.162835  [400/938]  
loss: 0.335097  [500/938]  
loss: 0.134196  [600/938]  
loss: 0.332896  [700/938]  
loss: 0.261795  [800/938]  
loss: 0.154485  [900/938]  
Test:   Accuracy: 94.6%, Avg loss: 0.181880   Epoch 3  
-------------------------------  
loss: 0.338207  [  0/938]  
loss: 0.171585  [100/938]  
loss: 0.223193  [200/938]  
loss: 0.174970  [300/938]  
loss: 0.246406  [400/938]  
loss: 0.149053  [500/938]  
loss: 0.281349  [600/938]  
loss: 0.109779  [700/938]  
loss: 0.261625  [800/938]  
loss: 0.060637  [900/938]  
Test:   Accuracy: 95.2%, Avg loss: 0.158948   Done!  

      运行截图:

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.rhkb.cn/news/371499.html

如若内容造成侵权/违法违规/事实不符,请联系长河编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

使用WinSCP工具连接Windows电脑与Ubuntu虚拟机实现文件共享传输

一。环境配置 1.首先你的Windows电脑上安装了VMware虚拟机,虚拟机装有Ubuntu系统; 2.在你的windows电脑安装了WinSCP工具; 3.打开WinSCP工具默认是这样 二。设置WinSCP连接 打开WinSCP,点击新标签页,进入到如下图的…

【持续集成_03课_Jenkins生成Allure报告及Sonar静态扫描】

1、 一、构建之后的配置 1、安装allure插件 安装好之后,可以在这里搜到已经安装的 2、配置allure的allure-commandline 正常配置,是要么在工具里配置,要么在系统里配置 allure-commandline是在工具里进行配置 两种方式进行配置 1&#xff…

关闭vue3中脑瘫的ESLine

在创建vue3的时候脑子一抽选了ESLine,然后这傻卵子ESLine老是给我报错 博主用的idea开发前端 ,纯粹是用不惯vscode 关闭idea中的ESLine,这个只是取消红色波浪线, 界面中的显示 第二步,在vue.config.js中添加 lintOnSave: false 到这里就ok了,其他的我试过了一点用没有

STM32-ADC+DMA

本内容基于江协科技STM32视频学习之后整理而得。 文章目录 1. ADC模拟-数字转换器1.1 ADC模拟-数字转换器1.2 逐次逼近型ADC1.3 ADC框图1.4 ADC基本结构1.5 输入通道1.6 规则组的转换模式1.6.1 单次转换,非扫描模式1.6.2 连续转换,非扫描模式1.6.3 单次…

Python28-7.4 独立成分分析ICA分离混合音频

独立成分分析(Independent Component Analysis,ICA)是一种统计与计算技术,主要用于信号分离,即从多种混合信号中提取出独立的信号源。ICA在处理盲源分离(Blind Source Separation,BSS&#xff0…

Spring源码十七:Bean实例化入口探索

上一篇Spring源码十六:Bean名称转化我们讨论doGetBean的第一个方法transformedBeanName方法,了解Spring是如何处理特殊的beanName(带&符号前缀)与Spring的别名机制。今天我们继续往方法下面看: doGetBean 这个方法…

按键控制LED流水灯模式定时器时钟

目录 1.定时器 2. STC89C52定时器资源 3.定时器框图 4. 定时器工作模式 5.中断系统 1)介绍 2)流程图:​编辑 3)STC89C52中断资源 4)定时器和中断系统 5)定时器的相关寄存器 6.按键控制LED流水灯模…

对话大模型Prompt是否需要礼貌点?

大模型相关目录 大模型,包括部署微调prompt/Agent应用开发、知识库增强、数据库增强、知识图谱增强、自然语言处理、多模态等大模型应用开发内容 从0起步,扬帆起航。 基于Dify的QA数据集构建(附代码)Qwen-2-7B和GLM-4-9B&#x…

【机器学习】机器学习与时间序列分析的融合应用与性能优化新探索

文章目录 引言第一章:机器学习在时间序列分析中的应用1.1 数据预处理1.1.1 数据清洗1.1.2 数据归一化1.1.3 数据增强 1.2 模型选择1.2.1 自回归模型1.2.2 移动平均模型1.2.3 长短期记忆网络1.2.4 卷积神经网络 1.3 模型训练1.3.1 梯度下降1.3.2 随机梯度下降1.3.3 A…

平台稳定性里程碑 | Android 15 Beta 3 已发布

作者 / 产品管理副总裁、Android 开发者 Matthew McCullough 从近期发布的 Beta 3 开始,Android 15 达成了平台稳定性里程碑版本,这意味着开发者 API 和所有面向应用的行为都已是最终版本,您可以查阅它们并将其集成到您的应用中,并…

Pandas 入门 15 题

Pandas 入门 15 题 1. 相关知识点1.1 修改DataFrame列名1.2 获取行列数1.3 显示前n行1.4 条件数据选取值1.5 创建新列1.6 删去重复的行1.7 删除空值的数据1.9 修改列名1.10 修改数据类型1.11 填充缺失值1.12 数据上下合并1.13 pivot_table透视表的使用1.14 melt透视表的使用1.1…

使用Vue实现前后端分离 spring框架返回json数据中文乱码

java json数据返回值中文乱码 出现&#xff1f;&#xff1f;&#xff1f; - _xkoko - 博客园 (cnblogs.com) 引入js的script标签到底是放在head还是body中_html页面中用<script>标签引入js代码,该标签放在<head>标签中和放在<body>标签-CSDN博客 vue.js 的问…

golang结合neo4j实现权限功能设计

neo4j 是非关系型数据库之图形数据库&#xff0c;这里不再赘述。 传统关系数据库基于rbac实现权限, user ---- role ------permission,加上中间表共5张表。 如果再添上部门的概念&#xff1a;用户属于部门&#xff0c;部门拥有 角色&#xff0c;则又多了一层&#xff1a; user-…

MySQL之备份与恢复(七)

备份与恢复 文件系统快照 规划LVM备份 LVM快照备份也是有开销的。服务器写到原始卷的越多&#xff0c;引发的额外开销也越多。当服务器随机修改许多不同块时&#xff0c;磁头需要需要自写时复制空间来来回回寻址&#xff0c;并且将数据的老版本写到写时复制空间。从快照中读…

网络基础:IS-IS协议

IS-IS&#xff08;Intermediate System to Intermediate System&#xff09;是一种链路状态路由协议&#xff0c;最初由 ISO&#xff08;International Organization for Standardization&#xff09;为 CLNS&#xff08;Connectionless Network Service&#xff09;网络设计。…

Windows电脑下载、安装VS Code的方法

本文介绍Visual Studio Code&#xff08;VS Code&#xff09;软件在Windows操作系统电脑中的下载、安装、运行方法。 Visual Studio Code&#xff08;简称VS Code&#xff09;是一款由微软开发的免费、开源的源代码编辑器&#xff0c;支持跨平台使用&#xff0c;可在Windows、m…

apk反编译修改教程系列-----修改apk 解除软件限制功能 实例操作步骤解析_3【二十二】

在前面的几期博文中有过解析去除apk中功能权限的反编译步骤。另外在以往博文中也列举了修改apk中选项功能权限的操作方法。今天以另外一款apk作为演示修改反编译去除软件功能限制的步骤。兴趣的友友可以参考其中的修改过程。 课程的目的是了解apk中各个文件的具体作用以及简单…

【经验篇】Spring Data JPA开启批量更新时乐观锁失效问题

乐观锁机制 什么是乐观锁&#xff1f; 乐观锁的基本思想是&#xff0c;认为在大多数情况下&#xff0c;数据访问不会导致冲突。因此&#xff0c;乐观锁允许多个事务同时读取和修改相同的数据&#xff0c;而不进行显式的锁定。在提交事务之前&#xff0c;会检查是否有其他事务…

浏览器插件利器-allWebPluginV2.0.0.14-stable版发布

allWebPlugin简介 allWebPlugin中间件是一款为用户提供安全、可靠、便捷的浏览器插件服务的中间件产品&#xff0c;致力于将浏览器插件重新应用到所有浏览器。它将现有ActiveX插件直接嵌入浏览器&#xff0c;实现插件加载、界面显示、接口调用、事件回调等。支持谷歌、火狐等浏…

【音视频 | RTSP】RTSP协议详解 及 抓包例子解析(详细而不赘述)

&#x1f601;博客主页&#x1f601;&#xff1a;&#x1f680;https://blog.csdn.net/wkd_007&#x1f680; &#x1f911;博客内容&#x1f911;&#xff1a;&#x1f36d;嵌入式开发、Linux、C语言、C、数据结构、音视频&#x1f36d; &#x1f923;本文内容&#x1f923;&a…