神经网络 torch.nn---优化器的使用

torch.optim - PyTorch中文文档 (pytorch-cn.readthedocs.io)

torch.optim — PyTorch 2.3 documentation

反向传播可以求出神经网路中每个需要调节参数的梯度(grad),优化器可以根据梯度进行调整,达到降低整体误差的作用。下面我们对优化器进行介绍。

构造优化器

optimizer = optim.SGD(model.parameters(), lr=0.01, momentum=0.9)
optimizer = optim.Adam([var1, var2], lr=0.0001)
  • 选择优化器的算法optim.SGD
  • 之后在优化器中放入模型参数model.parameters(),这一步是必备
  • 还可在函数中设置一些参数,如学习速率lr=0.01(这是每个优化器中几乎都会有的参数)

调用优化器中的step方法

step()方法就是利用我们之前获得的梯度,对神经网络中的参数进行更新。

for input, target in dataset:optimizer.zero_grad()output = model(input)loss = loss_fn(output, target)loss.backward()optimizer.step()
  • 步骤optimizer.zero_grad()是必选的。对每个参数的梯度进行清零

  • 我们的输入经过了模型,并得到了输出output

  • 之后计算输出和target之间的误差loss

  • 调用误差的反向传播loss.backwrd更新每个参数对应的梯度

  • 调用optimizer.step()对卷积核中的参数进行优化调整

  • 之后继续进入for循环,使用函数optimizer.zero_grad()对每个参数的梯度进行清零,防止上一轮循环中计算出来的梯度影响下一轮循环。

优化器的使用

优化器中算法共有的参数(其他参数因优化器的算法而异):

  • params: 传入优化器模型中的参数

  • lr: learning rate,即学习速率

关于学习速率

  • 一般来说,学习速率设置得太大,模型运行起来会不稳定

  • 学习速率设置得太小,模型训练起来会过

  • 建议在最开始训练模型的时候,选择设置一个较大的学习速率;训练到后面的时候,再选择一个较小的学习速率

代码实战:

选择CIFAR10

import torch
import torchvision
from torch import nn
from torch.nn import Conv2d, MaxPool2d, Linear, Sequential, Flatten
from torch.utils.data import DataLoader
from torch.utils.tensorboard import SummaryWriterdataset = torchvision.datasets.CIFAR10(root='./dataset', train=False, transform=torchvision.transforms.ToTensor(),download=True)
dataloader = DataLoader(dataset, batch_size=1)class Tudui(nn.Module):def __init__(self):super(Tudui, self).__init__()self.model = Sequential(Conv2d(3, 32, kernel_size=5, padding=2),MaxPool2d(kernel_size=2),Conv2d(32, 32, kernel_size=5, padding=2),MaxPool2d(kernel_size=2),Conv2d(32, 64, kernel_size=5, padding=2),MaxPool2d(kernel_size=2),Flatten(),Linear(64 * 4 * 4, 64),Linear(64, 10))def forward(self, x):x = self.model(x)return xloss = nn.CrossEntropyLoss()
tudui = Tudui()
optim = torch.optim.SGD(tudui.parameters(), lr=0.01)
for epoch in range(20):running_loss = 0.0for data in dataloader:imgs, targets = dataoutputs = tudui(imgs)# print(outputs)# print(targets)result_loss = loss(outputs, targets)optim.zero_grad()result_loss.backward()optim.step()# print(result_loss)running_loss = running_loss + result_lossprint(running_loss)

输出:

总结使用优化器训练的训练套路):

  • 设置损失函数loss function

  • 定义优化器optim

  • 从使用循环dataloader中的数据:for data in dataloder

    • 取出图片imgs,标签targets:imgs,targets=data

    • 将图片放入神经网络,并得到一个输出:output=model(imgs)

    • 计算误差:loss_result=loss(output,targets)

    • 使用优化器,初始化参数的梯度为0:optim.zero_grad()

    • 使用反向传播求出梯度:loss_result.backward()

    • 根据梯度,对每一个参数进行更新:optim.step()

  • 进入下一个循环,直到完成训练所需的循环次数

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.rhkb.cn/news/342693.html

如若内容造成侵权/违法违规/事实不符,请联系长河编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

uniapp内置的button组件的问题

问题描述 由于想要使用uniapp内置button组件的开放能力,所以就直接使用了button,但是他本身带着边框,而且使用 border:none;是没有效果的。 问题图片 解决方案 button::after {border: none;} 正确样式 此时的分享…

6.更复杂的光照

一、Unity的渲染路径 渲染路径决定了光照是如何应用到Unity Shader中的。我们需要为每个Pass指定它使用的渲染路径 如何设置渲染路径? Edit>Project Settings>Player>Other Settinigs>Rendering 如何使用多个渲染路径?如:摄像…

kafka-集群搭建(在docker中搭建)

文章目录 1、kafka集群搭建1.1、下载镜像文件1.2、创建zookeeper容器并运行1.3、创建3个kafka容器并运行1.3.1、9095端口1.3.2、9096端口1.3.3、9097端口 1.4、重启kafka-eagle1.5、查看 efak1.5.1、查看 brokers1.5.2、查看 zookeeper 1、kafka集群搭建 1.1、下载镜像文件 d…

makefile2

makefile的条件判断 运行make。 替换 make -c make-f …… 还可以 man make来查看其他的make命令。

【ARM Cache 及 MMU 系列文章 6.2 -- ARMv8/v9 Cache 内部数据读取方法详细介绍】

请阅读【ARM Cache 及 MMU/MPU 系列文章专栏导读】 及【嵌入式开发学习必备专栏】 文章目录 Direct access to internal memoryL1 cache encodingsL1 Cache Data 寄存器Cache 数据读取代码实现Direct access to internal memory 在ARMv8架构中,缓存(Cache)是用来加速数据访…

01_初识微服务

文章目录 一、微服务概述1.1 什么是微服务1.2 对比微服务架构与单体架构1.3 微服务设计原则1.4 微服务开发框架1.5 简单理解分布式部署与集群部署 二、微服务的核心概念2.1 服务注册与发现2.2 微服调用(通信)2.3 服务网关2.4 服务容错2.5 链路追踪参考链…

C语言小例程6/100

题目:输入三个整数x,y,z,请把这三个数由小到大输出。 程序分析:我们想办法把最小的数放到x上,先将x与y进行比较,如果x>y则将x与y的值进行交换,然后再用x与z进行比较,如果x>z则将x与z的值…

电商平台的消费增值策略

电商平台通过创新的消费增值策略,为用户提供了全新的激励体验。这种策略通过积分奖励和价值提升机制,鼓励用户持续参与并增强用户对平台的忠诚度。 消费积分的奖励机制 在电商平台,每一笔交易都会根据预设的比例回馈给消费者一部分资金&…

MATLAB format

在MATLAB中,format 是一个函数,用于控制命令窗口中数值的显示格式。这个函数可以设置数值的精度、显示的位数等。以下是一些常用的 format 命令: format long:以默认的长格式显示数值,通常显示15位有效数字。format s…

Mysql学习(三)——SQL通用语法之DML

提示:文章写完后,目录可以自动生成,如何生成可参考右边的帮助文档 文章目录 DML添加数据修改数据删除数据 总结 DML DML用来对数据库中表的数据记录进行增删改操作。 添加数据 -- 给指定字段添加数据 insert into 表名(字段1,字…

引用(C++)和内联函数

前言&#xff1a;本文主要讲解C语法中引用如何使用和使用时的一些技巧 基本语法 引用就是取别名 #include <iostream> using namespace std; int main() {int a 10;int& b a;//给a取别名为bcout << a << endl;cout << b << endl;return 0…

【python】OpenCV—Bitplane

学习来自&#xff1a; 位平面分割&#xff08;Bit-Plane Slicing&#xff09;使用OpenCVPython进行图像处理的初学者指南 位平面 位平面&#xff08;bitplane&#xff09;是一个在计算机科学中用于描述图像数据的概念&#xff0c;具体定义如下&#xff1a; 【定义】&#x…

SpringBoot实现发送邮件功能

目录 一、开启邮件服务 二、导入pom依赖 三、配置yml文件 四、发送邮件 4.1、发送文字邮件 4.2、发送html邮件 4.3、发送附件邮件 4.4、发送图片邮件 一、开启邮件服务 这里拿QQ邮箱举例。 翻到下面进行开启,之后获取授权码。 二、导入pom依赖 <dependency><…

kotlin 调用java的get方法Use of getter method instead of property access syntax

调用警告 Person.class public class Person {private String name;Person(String name) {this.name name.trim();}public String getName() {return name;}public void setName(String name) {this.name name;}public String getFullName() {return name " Wang&quo…

【漏洞复现】用友NC downCourseWare 任意文件读取漏洞

0x01 产品简介 用友NC是一款企业级ERP软件。作为一种信息化管理工具&#xff0c;用友NC提供了一系列业务管理模块&#xff0c;包括财务会计、采购管理、销售管理、物料管理、生产计划和人力资源管理等&#xff0c;帮助企业实现数字化转型和高效管理。 0x02 漏洞概述 用友NC …

使用python优雅的将PDF转为Word

使用python优雅的将PDF转为Word 先装这个优雅的库 pip install pdf2docx然后运行下面优雅的代码&#xff0c;将pdf路径和docx路径修改 from pdf2docx import Converter # path pdf_file C:\\Users\\phl\\Desktop\\软件工程期末\\软件工程模拟试题5.pdf docx_file C:\\User…

Python 识别图片形式pdf的尝试(未解决)

想识别出pdf页面右下角某处的编号。pdf是图片形式页面。查了下方法&#xff0c;有源码是先将页面提取成jpg&#xff0c;再用pytesseract提取图片文件中的内容。 直接用图片来识别。纯数字的图片&#xff0c;如条形码&#xff0c;可识别。带中文的不可以&#xff0c;很乱。 识别…

pESC-HIS是什么,怎么看?-实验操作系列-2

01 典型的pESC-HIS质粒遗传图谱 02 介绍 质粒类型&#xff1a;酿酒酵母蛋白表达载体 表达水平&#xff1a;高拷贝 诱导方法&#xff1a;半乳糖 启动子&#xff1a;GAL1和GAL10 克隆方法&#xff1a;多克隆位点&#xff0c;限制性内切酶 载体大小&#xff1a;6706bp 5 测…

国产工业级实时数据库

项目功能描述 Mars数据库的核心功能在于其能够高效地处理来自工业现场的大量传感器数据。它通过简化的可视化配置&#xff0c;允许用户轻松接入各种传感器&#xff0c;并进行数据记录和逻辑处理。Mars数据库在单机模式下支持高达120万个传感器信号的接入&#xff0c;而其分布式…