昇思25天学习打卡营第08天 | 模型训练

昇思25天学习打卡营第08天 | 模型训练

文章目录

  • 昇思25天学习打卡营第08天 | 模型训练
    • 超参数
    • 损失函数
    • 优化器
      • 优化过程
    • 训练与评估
    • 总结
    • 打卡

模型训练一般遵循四个步骤:

  1. 构建数据集
  2. 定义神经网络模型
  3. 定义超参数、损失函数和优化器
  4. 输入数据集进行训练和评估

构建数据集和网络模型在之前的内容在已经涉及,不再赘述。

超参数

超参数(Hyperparameters)是可以调整的参数,可以控制模型训练的过程。

深度学习模型多采用随机梯度下降算法SGD进行优化:
w t + 1 = w t − η 1 n ∑ x ∈ B ∇ l ( x , w t ) w_{t+1}=w_t- \eta\frac1n\sum_{x\in B}\nabla l(x,w_t) wt+1=wtηn1xBl(x,wt)
其中, η \eta η是学习率, n n n是batch大小,都是超参数,这两个参数是直接影响模型性能收敛的重要参数。
一般会定义三个超参数:

  • epoch:遍历数据集的次数
  • batch size:每个批次数据的大小。size 过小导致花费时间多,梯度震荡严重,不利于收敛;size 过大容易陷入局部极小值。
  • learning rate:学习率国小会导致收敛速度变慢;过大则可能会导致训练不收敛。

损失函数

损失函数用于评估模型预测值和目标值之间的误差。
常见的损失函数包括:

  • nn.MSELoss:均方误差,用于回归
  • nn.NLLLoss:负对数似然,用于分类
  • nn.CrossEntropyLoss:结合了nn.LogSoftmaxnn.NLLLoss,可以对logits进行归一化并计算预测误差
loss_fn = nn.CrossEntropyLoss()

优化器

优化器内部定义了模型参数的优化过程,所有的优化逻辑都封装在优化器对象中。

optimizer = nn.SGD(model.trainable_params(), learning_rate=learning_rate)

优化过程

通过自动微分获得的微分函数,计算参数对应的梯度,并传入优化器中,即可实现参数优化。

grads = grad_fn(inputs)
optimizer(grads)

训练与评估

遍历一次数据集被称为一轮(epoch),每轮执行训练时包含两个步骤:

  1. 训练:迭代训练数据集,并尝试收敛到最佳参数。
  2. 验证/测试:迭代测试数据集,检查模型性能是否提升。
# Define forward function
def forward_fn(data, label):logits = model(data)loss = loss_fn(logits, label)return loss, logits# Get gradient function
grad_fn = mindspore.value_and_grad(forward_fn, None, optimizer.parameters, has_aux=True)# Define function of one-step training
def train_step(data, label):(loss, _), grads = grad_fn(data, label)optimizer(grads)return lossdef train_loop(model, dataset):size = dataset.get_dataset_size()model.set_train()for batch, (data, label) in enumerate(dataset.create_tuple_iterator()):loss = train_step(data, label)if batch % 100 == 0:loss, current = loss.asnumpy(), batchprint(f"loss: {loss:>7f}  [{current:>3d}/{size:>3d}]")def test_loop(model, dataset, loss_fn):num_batches = dataset.get_dataset_size()model.set_train(False)total, test_loss, correct = 0, 0, 0for data, label in dataset.create_tuple_iterator():pred = model(data)total += len(data)test_loss += loss_fn(pred, label).asnumpy()correct += (pred.argmax(1) == label).asnumpy().sum()test_loss /= num_batchescorrect /= totalprint(f"Test: \n Accuracy: {(100*correct):>0.1f}%, Avg loss: {test_loss:>8f} \n")

训练过程一般为:

loss_fn = nn.CrossEntropyLoss()
optimizer = nn.SGD(model.trainable_params(), learning_rate=learning_rate)for t in range(epochs):print(f"Epoch {t+1}\n-------------------------------")train_loop(model, train_dataset)test_loop(model, test_dataset, loss_fn)

总结

这一节的内容对深度学习模型训练的一般过程进行了详细的介绍,从数据集构建到模型定义,接着定义超参数并选择合适的值,创建损失函数和优化器对象完成训练前的准备。通过封装一个模型调用和loss计算的前向计算函数并自动微分,在每个epoch中计算loss并优化参数,从而完成模型的训练。

打卡

在这里插入图片描述

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.rhkb.cn/news/370228.html

如若内容造成侵权/违法违规/事实不符,请联系长河编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

【Excel、RStudio计算T检测的具体操作步骤】

目录 一、基础知识1.1 显著性检验1.2 等方差T检验、异方差T检验1.3 单尾p、双尾p1.3.1 检验目的不同1.3.2 用法不同1.3.3 如何选择 二、Excel2.1 统计分析工具2.1.1 添加统计分析工具2.1.2 数据分析 2.2 公式 -> 插入函数 -> T.TEST 三、RStudio 一、基础知识 参考: 1.…

无人机运营合格证及无人机驾驶员合格证(AOPA)技术详解

无人机运营合格证及无人机驾驶员合格证(AOPA)技术详解如下: 一、无人机运营合格证 无人机运营合格证是无人机运营企业或个人必须获得的证书,以确保无人机在运营过程中符合相关法规和标准。对于无人机运营合格证的具体要求和申请…

【React】React18 Hooks之useState

目录 useState案例1(直接修改状态)案例2(函数式更新)案例3(受控表单绑定)注意事项1:set函数不会改变正在运行的代码的状态注意事项2:set函数自动批量处理注意事项3:在下次…

【反悔堆 优先队列 临项交换 决策包容性】630. 课程表 III

本文涉及知识点 贪心 反悔堆 优先队列 临项交换 Leetcode630. 课程表 III 这里有 n 门不同的在线课程,按从 1 到 n 编号。给你一个数组 courses ,其中 courses[i] [durationi, lastDayi] 表示第 i 门课将会 持续 上 durationi 天课,并且必…

E4.【C语言】练习:while和getchar的理解

#include <stdio.h> int main() {int ch 0;while ((ch getchar()) ! EOF){if (ch < 0 || ch>9)continue;putchar(ch);}return 0; } 理解上述代码 0-->48 9-->57 if行判断是否为数字&#xff0c;打印数字&#xff0c;不打印非数字

Selenium 切换 frame/iframe

环境&#xff1a; Python 3.8 selenium3.141.0 urllib31.26.19说明&#xff1a; driver.switch_to.frame() # 将当前定位的主体切换为frame/iframe表单的内嵌页面中 driver.switch_to.default_content() # 跳回最外层的页面# 判断元素是否在 frame/ifame 中 # 126 邮箱为例 # …

k8s公网集群安装(1.23.0)

网上搜到的公网搭建k8s都不太一致, 要么说的太复杂, 要么镜像无法下载, 所以写了一个简洁版,小白也能一次搭建成功 使用的都是centos7,k8s版本为1.23.0 使用二台机器搭建的, 三台也是一样的思路1.所有节点分别设置对应主机名 hostnamectl set-hostname master hostnamectl set…

javaSwing图书管理系统

一、 引言 图书管理系统是一个用于图书馆或书店管理图书信息、借阅记录和读者信息的应用程序。本系统使用Java Swing框架进行开发&#xff0c;提供直观的用户界面&#xff0c;方便图书馆管理员或书店工作人员对图书信息进行管理。以下是系统的设计、功能和实现的详细报告。 二…

Matplotlib Artist 1 概览

Matplotlib API中有三层 matplotlib.backend_bases.FigureCanvas&#xff1a;绘制区域matplotlib.backend_bases.Renderer&#xff1a;控制如何在FigureCanvas上绘制matplotlib.artist.Artist&#xff1a;控制render如何进行绘制 开发者95%的时间都是在使用Artist。Artist有两…

【C语言】自定义类型:联合和枚举

前言 前面我们学习了一种自定义类型&#xff0c;结构体&#xff0c;现在我们学习另外两种自定义类型&#xff0c;联合 和 枚举。 目录 一、联合体 1. 联合体类型的声明 2. 联合体的特点 3. 相同成员联合体和结构体对比 4. 联合体大小的计算 5. 用联合体判断当前机…

常见算法和Lambda

常见算法和Lambda 文章目录 常见算法和Lambda常见算法查找算法基本查找&#xff08;顺序查找&#xff09;二分查找/折半查找插值查找斐波那契查找分块查找扩展的分块查找&#xff08;无规律的数据&#xff09; 常见排序算法冒泡排序选择排序插入排序快速排序递归快速排序 Array…

hdu物联网硬件实验2 GPIO亮灯

学院 班级 学号 姓名 日期 成绩 实验题目 GPIO亮灯 实验目的 点亮三个灯闪烁频率为一秒 硬件原理 无 关键代码及注释 const int ledPin1 GREEN_LED; // the number of the LED pin const int ledPin2 YELLOW_LED; const int ledPin3 RED…

githup开了代理push不上去

你们好&#xff0c;我是金金金。 场景 git push出错 解决 cmd查看 git config --global http.proxy git config --global https.proxy 如果什么都没有&#xff0c;代表没设置全局代理&#xff0c;此时如果你开了代理&#xff0c;则执行如下&#xff0c;设置代理 git con…

【深度学习】图形模型基础(5):线性回归模型第四部分:预测与贝叶斯推断

1.引言 贝叶斯推断超越了传统估计方法&#xff0c;它包含三个关键步骤&#xff1a;结合数据和模型形成后验分布&#xff0c;通过模拟传播不确定性&#xff0c;以及利用先验分布整合额外信息。本文将通过实际案例阐释这些步骤&#xff0c;展示它们在预测和推断中的挑战和应用。…

CSS中 实现四角边框效果

效果图 关键代码 border-radius:10rpx ;background: linear-gradient(#fff, #fff) left top,linear-gradient(#fff, #fff) left top,linear-gradient(#fff, #fff) right top,linear-gradient(#fff, #fff) right top,linear-gradient(#fff, #fff) left bottom,linear-gradient(…

12 Dockerfile详解

目录 1. Dockerfile 2. Dockerfile构建过程 2.1. Dockerfile编写规则&#xff1a; 2.2. Docker执行Dockerfile的大致流程 2.3. 总结 3. Dockerfile指令 3.1. FROM 3.2. MAINTAINER 3.3. RUN 3.4. EXPOSE 3.5. WORKDIR 3.6. USER 3.7. ENV 3.8. VOLUME 3.9. ADD …

电商利器——淘宝商品月销量API接口解析

在电商时代&#xff0c;数据就是金钱。对于淘宝商家而言&#xff0c;掌握商品的销量数据无异于掌握了市场的脉搏。如今&#xff0c;淘宝商品月销量API接口的出现&#xff0c;联讯数据让商家如虎添翼&#xff0c;能够更加精准地把握市场动态&#xff0c;优化商品策略。 淘宝商…

利用redis数据库管理代理库爬取cosplay网站-cnblog

爬取cos猎人 数据库管理主要分为4个模块&#xff0c;代理获取模块&#xff0c;代理储存模块&#xff0c;代理测试模块&#xff0c;爬取模块 cos猎人已经倒闭&#xff0c;所以放出爬虫源码 api.py 为爬虫评分提供接口支持 import requests import concurrent.futures import …

鸿蒙开发设备管理:【@ohos.vibrator (振动)】

振动 说明&#xff1a; 开发前请熟悉鸿蒙开发指导文档&#xff1a;gitee.com/li-shizhen-skin/harmony-os/blob/master/README.md点击或者复制转到。 本模块首批接口从API version 8开始支持。后续版本的新增接口&#xff0c;采用上角标单独标记接口的起始版本。 导入模块 imp…

查询某个县区数据,没有的数据用0补充。

加油&#xff0c;新时代打工人&#xff01; 思路&#xff1a; 先查出有数据的县区&#xff0c;用县区编码判断&#xff0c;不存在县区里的数据。然后&#xff0c;用union all进行两个SQL拼接起来。 SELECTt.regionCode,t.regionName,t.testNum,t.sampleNum,t.squareNum,t.crop…