Lnton羚通关于Optimization在【PyTorch】中的基础知识

OPTIMIZING MODEL PARAMETERS (模型参数优化)
现在我们有了模型和数据,是时候通过优化数据上的参数来训练了,验证和测试我们的模型。训练一个模型是一个迭代的过程,在每次迭代中,模型会对输出进行猜测,计算猜测数据与真实数据的误差(损失),收集误差对其参数的导数(正如前一节我们看到的那样),并使用梯度下降优化这些参数。

Prerequisite Code ( 先决代码 )
We load the code from the previous sections on

import torch 
from torch import nn
from torch.utils.data import DataLoader
from torchvision import datasets, transformstraining_data = datasets.FashionMNIST(root = "../../data/",train = True,download = True, transform = transforms.ToTensor()
)test_data = datasets.FashionMNIST(root = "../../data/",train = False,download = True, transform = transforms.ToTensor()
)train_dataloader = DataLoader(training_data, batch_size = 32, shuffle = True)
test_dataloader = DataLoader(test_data, batch_size = 32, shuffle = True)class NeuralNetwork(nn.Module):def __init__(self):super(NeuralNetwork, self).__init__()self.flatten = nn.Flatten()self.linear_relu_stack = nn.Sequential(nn.Linear(28 * 28, 512),nn.ReLU(),nn.Linear(512, 512),nn.ReLU(),nn.Linear(512, 10)  )def forward(self, x):out = self.flatten(x)out = self.linear_relu_stack(out)return outmodel = NeuralNetwork()

Hyperparameters ( 超参数 )
超参数是可调节的参数,允许控制模型优化过程,不同的超参数会影响模型的训练和收敛速度。read more

我们定义如下的超参数进行训练:

Number of Epochs: 遍历数据集的次数
Batch Size: 每一次使用的数据集大小,即每一次用于训练的样本数量
Learning Rate: 每个 batch/epoch 更新模型参数的速度,较小的值会导致较慢的学习速度,而较大的值可能会导致训练过程中不可预测的行为,例如训练抖动频繁,有可能会发散等。

learning_rate = 1e-3
batch_size = 32
epochs = 5

Optimization Loop ( 优化循环 )
我们设置完超参数后,就可以利用优化循环训练和优化模型;优化循环的每次迭代称为一个 epoch, 每个 epoch 包含两个主要部分:

The Train Loop: 遍历训练数据集并尝试收敛到最优参数。
The Validation/Test Loop: 验证/测试循环—遍历测试数据集以检查模型性能是否得到改善。
让我们简单地熟悉一下训练循环中使用的一些概念。跳转到前面以查看优化循环的完整实现。

Loss Function ( 损失函数 )
当给出一些训练数据时,我们未经训练的网络可能不会给出正确的答案。 Loss function 衡量的是得到的结果与目标值的不相似程度,是我们在训练过程中想要最小化的 Loss function。为了计算 loss ,我们使用给定数据样本的输入进行预测,并将其与真实的数据标签值进行比较。

常见的损失函数包括nn.MSELoss (均方误差)用于回归任务,nn.NLLLoss(负对数似然)用于分类神经网络。nn.CrossEntropyLoss 结合 nn.LogSoftmax 和 nn.NLLLoss 。

我们将模型的输出 logits 传递给 nn.CrossEntropyLoss ,它将规范化 logits 并计算预测误差。

# Initialize the loss function
loss_fn = nn.CrossEntropyLoss()

Optimizer ( 优化器 )
优化是在每个训练步骤中调整模型参数以减少模型误差的过程。优化算法定义了如何执行这个过程(在这个例子中,我们使用随机梯度下降)。所有优化逻辑都封装在优化器对象中。这里,我们使用 SGD 优化器; 此外,PyTorch 中还有许多不同的优化器,如 ADAM 和 RMSProp ,它们可以更好地用于不同类型的模型和数据。

optimizer = torch.optim.SGD(model.parameters(), lr=learning_rate)

在训练的循环中,优化分为3个步骤:

调用 optimizer.zero_grad() 重置模型参数的梯度,默认情况下,梯度是累加的。为了防止重复计算,我们在每次迭代中显式将他们归零。
通过调用 loss.backward() 反向传播预测损失, PyTorch 保存每个参数的损失梯度。
一旦我们有了梯度,我们调用 optimizer.step() 在向后传递中收集梯度调整参数。
Full Implementation (完整实现)
我们定义了遍历优化参数代码的 train loop, 以及根据测试数据定义了test loop。

import torch
from torch import nn
from torch.utils.data import DataLoader
from torchvision import datasets, transforms## 数据集
training_data = datasets.FashionMNIST(root="../../data/",train=True,download=True,transform=transforms.ToTensor()
)test_data = datasets.FashionMNIST(root="../../data/",train=False,download=True,transform=transforms.ToTensor()
)## dataloader
train_dataloader = DataLoader(training_data, batch_size=32, shuffle=True)
test_dataloader = DataLoader(test_data, batch_size=32, shuffle=True)## 定义神经网络
class NeuralNetwork(nn.Module):def __init__(self):super(NeuralNetwork, self).__init__()self.flatten = nn.Flatten()self.linear_relu_stack = nn.Sequential(nn.Linear(28 * 28, 512),nn.ReLU(),nn.Linear(512, 512),nn.ReLU(),nn.Linear(512, 10))def forward(self, x):out = self.flatten(x)out = self.linear_relu_stack(out)return out## 实例化模型
model = NeuralNetwork()## 损失函数
loss_fn = nn.CrossEntropyLoss()## 优化器
optimizer = torch.optim.SGD(model.parameters(), lr=learning_rate)## 超参数
learning_rate = 1e-3
batch_size = 32
epochs = 5## 训练循环
def train_loop(dataloader, model, loss_fn, optimizer):size = len(dataloader.dataset)for batch, (X, y) in enumerate(dataloader):# 计算预测和损失pred = model(X)loss = loss_fn(pred, y)## 反向传播optimizer.zero_grad()loss.backward()optimizer.step()if batch % 100 == 0:loss, current = loss.item(), batch * len(X)print(f"loss: {loss:>7f}  [{current:>5d}/{size:>5d}]")## 测试循环
def test_loop(dataloader, model, loss_fn):size = len(dataloader.dataset)num_batches = len(dataloader)test_loss, correct = 0, 0with torch.no_grad():for X, y in dataloader:pred = model(X)test_loss += loss_fn(pred, y).item()correct += (pred.argmax(1) == y).type(torch.float).sum().item()test_loss /= num_batchescorrect /= sizeprint(f"Test error: \n Accuracy: {(100*correct):>0.1f}%, Avg loss: {test_loss:>8f}\n")## 训练网络
for t in range(epochs):print(f"Epoch {t+1}\n-------------------------------")train_loop(train_dataloader, model, loss_fn, optimizer)test_loop(test_dataloader, model, loss_fn)
print("Done!")

Lnton羚通专注于音视频算法、算力、云平台的高科技人工智能企业。 公司基于视频分析技术、视频智能传输技术、远程监测技术以及智能语音融合技术等, 拥有多款可支持ONVIF、RTSP、GB/T28181等多协议、多路数的音视频智能分析服务器/云平台。

 

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.rhkb.cn/news/95755.html

如若内容造成侵权/违法违规/事实不符,请联系长河编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

工程项目管理系统源码+功能清单+项目模块+spring cloud +spring boot em

​ 工程项目管理软件(工程项目管理系统)对建设工程项目管理组织建设、项目策划决策、规划设计、施工建设到竣工交付、总结评估、运维运营,全过程、全方位的对项目进行综合管理 工程项目各模块及其功能点清单 一、系统管理 1、数据字典&#…

衣服材质等整理(时常更新)

参考文章&图片来源 https://zhuanlan.zhihu.com/p/390341736 00. 天然纤维 01. 化学纤维 02. 聚酯纤维(即,涤纶) 一种由有机二元酸和二元醇通过化学缩聚制成的合成纤维。具有出色的抗皱性和保形性,所制衣物在穿着过程中不容…

解决git reset --soft HEAD^撤销commit时报错

今天在使用git回退功能的时候,遇到以下错误: 解决git reset --soft HEAD^撤销commit时报错 问题: 在进行完commit后,想要撤销该commit,于是使用了git reset --soft HEAD^命令,但是出现如下报错&#xff1…

android 12系统加上TTS引擎

系统层修改&#xff1a; 1.frameworks/base/packages/SettingsProvider/res/values/defaults.xml <string name"def_tts"></string> 2.frameworks/base/packages/SettingsProvider/src/com/android/providers/settings/DatabaseHelper.java loadString…

206. 反转链表

给你单链表的头节点 head &#xff0c;请你反转链表&#xff0c;并返回反转后的链表。 示例 1&#xff1a; 输入&#xff1a;head [1,2,3,4,5] 输出&#xff1a;[5,4,3,2,1]示例 2&#xff1a; 输入&#xff1a;head [1,2] 输出&#xff1a;[2,1]示例 3&#xff1a; 输入&a…

学习左耳听风栏目90天——第七天 7/90(学习左耳朵耗子的工匠精神,对技术的热爱)【每个程序员都该知道的事】

每个程序员都该知道的事 每个程序员都应该要读的书每个搞计算机专业的学生应有的知识LinkedIn 高效的代码复查技巧编程语言和代码质量的研究报告 每个程序员都应该要读的书 每个搞计算机专业的学生应有的知识 LinkedIn 高效的代码复查技巧 编程语言和代码质量的研究报告

MySQL中的锁机制

抛砖引玉&#xff1a;多个查询需要在同一时刻进行数据的修改&#xff0c;就会产生并发控制的问题。我们需要如何避免写个问题从而保证我们的数据库数据不会被破坏。 锁的概念 读锁是共享的互相不阻塞的。多个事务在听一时刻可以同时读取同一资源&#xff0c;而相互不干扰。 写…

Spring Clould 注册中心 - Eureka,Nacos

视频地址&#xff1a;微服务&#xff08;SpringCloudRabbitMQDockerRedis搜索分布式&#xff09; Eureka 微服务技术栈导学&#xff08;P1、P2&#xff09; 微服务涉及的的知识 认识微服务-服务架构演变&#xff08;P3、P4&#xff09; 总结&#xff1a; 认识微服务-微服务技…

mysql全文检索使用

数据库数据量10万左右&#xff0c;使用like %test%要耗费30秒左右&#xff0c;放弃该办法 使用mysql的全文检索 第一步:建立索引 首先修改一下设置: my.ini中ngram_token_size 1 可以通过 show variables like %token%;来查看 接下来建立索引:alter table 表名 add f…

【Unity】坐标转换经纬度方法(应用篇)

【Unity】坐标转换经纬度方法&#xff08;应用篇&#xff09; 解决地图中经纬度坐标转换与unity坐标互转的问题。使用线性变换的方法&#xff0c;理论上可以解决小范围内所以坐标转换的问题。 之前有写过[Unity]坐标转换经纬度方法&#xff08;原理篇),在实际使用中&#xff0c…

SD WebUI 扩展:prompt-all-in-one

sd-webui-prompt-all-in-one 是一个基于 Stable Diffusion WebUI 的扩展&#xff0c;旨在提高提示词/反向提示词输入框的使用体验。它拥有更直观、强大的输入界面功能&#xff0c;它提供了自动翻译、历史记录和收藏等功能&#xff0c;它支持多种语言&#xff0c;满足不同用户的…

规则的加载与管理者——KieContainer的获取与其类型的区别(虽然标题是KieContainer,其实说的还是KieServices)

之前梳理了一下有关KieServices的获取&#xff0c;与获取中的代码走向&#xff0c;详情请见&#xff1a; “万恶”之源的KieServices&#xff0c;获取代码就一行&#xff0c;表面代码越少里面东西就越多&#xff0c;本以为就是个简单的工厂方法&#xff0c;没想到里面弯弯绕绕…

AIGC:【LLM(六)】——Dify:一个易用的 LLMOps 平台

文章目录 一.简介1.1 LLMOps1.2 Dify 二.核心能力三.Dify安装3.1 快速启动3.2 配置 四.Dify使用五.调用开源模型六.接通闭源模型七.在 Dify.AI 探索不同模型潜力7.1 快速切换&#xff0c;测验不同模型的表现7.2 降低模型能力对比和选择的成本 一.简介 1.1 LLMOps LLMOps&…

打印技巧——word中A4排版打印成A3双面对折翻页

在进行会议文件打印时&#xff0c;我们常会遇到需要将A4排版的文件&#xff0c;在A3纸张上进行双面对折翻页打印&#xff0c;本文对设置方式进行介绍&#xff1a; 1、在【布局】选项卡中&#xff0c;点击右下角小箭头&#xff0c;打开页面设置选项卡 1.1在【页边距】中将纸张…

使用EasyExcel实现Excel表格的导入导出

使用EasyExcel实现Excel表格的导入导出 文章目录 使用EasyExcel实现Excel表格的导入导出1.集成easyExcel2.简单导出示例实体与excel列的映射导出excel的代码 3.Excel复杂表头导出与实体的映射导出代码 3.Excel导入 Java解析、生成Excel比较有名的框架有Apache poi、jxl。但他们…

nginx代理webSocket链接响应403

一、场景 使用nginx代理webSocket链接&#xff0c;nginx响应403 1、nginx访问日志响应403 [18/Aug/2023:09:56:36 0800] "GET /FS_WEB_ASS/webim_api/socket/message HTTP/1.1" 403 5 "-" "Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit…

SpringBoot 微人事 职称管理模块(十三)

职称管理前端页面设计 在职称管理页面添加输入框 export default {name: "JobLevelMarna",data(){return{Jl:{name:""}}}}效果图 添加一个下拉框 v-model的值为当前被选中的el-option的 value 属性值 <el-select v-model"Jl.titlelevel" …

opencv 矩阵运算

1.矩阵乘&#xff08;*&#xff09; Mat mat1 Mat::ones(2,3,CV_32FC1);Mat mat2 Mat::ones(3,2,CV_32FC1);Mat mat3 mat1 * mat2; //矩阵乘 结果 2.元素乘法或者除法&#xff08;mul&#xff09; Mat m Mat::ones(2, 3, CV_32FC1);m.at<float>(0, 1) 3;m.at…

面试题-React(三):什么是JSX?它与常规JavaScript有什么不同?

在React的世界中&#xff0c;JSX是一项引人注目的技术&#xff0c;它允许开发者在JavaScript中嵌套类似HTML的标签&#xff0c;用于描述UI组件的结构。本篇博客将通过丰富的代码示例&#xff0c;深入探索JSX语法&#xff0c;解析其在React中的用法和优势。 一、JSX基础语法 在…

Springboot 实践(8)springboot集成Oauth2.0授权包,对接spring security接口

此文之前&#xff0c;项目已经添加了数据库DAO服务接口、资源访问目录、以及数据访问的html页面&#xff0c;同时项目集成了spring security&#xff0c;并替换了登录授权页面&#xff1b;但是&#xff0c;系统用户存储代码之中&#xff0c;而且只注册了admin和user两个用户。在…