【深度学习】如何一步步实现SGD随机梯度下降算法

如何一步步实现SGD随机梯度下降算法

文章目录

  • 如何一步步实现SGD随机梯度下降算法
    • SGD随机梯度下降算法的作用
    • MNIST_SAMPLE数据集
    • SGD算法的七大步骤
      • Step1. 初始化模型参数
      • Step2. 计算预测值predictions
      • Step3. 计算损失loss
      • Step4. 计算梯度gradients
      • Step5. 更新模型参数
      • Step6. 重复Step2-5
      • Step7. 停止
    • 在MNIST_SAMPLE数据集上训练linear_model
      • 把7个步骤的代码封装成类
      • 衡量指标metric
        • 验证精度validation accuracy
        • 验证函数
        • 训练linear_model
    • 使用learner.fit函数训练模型

SGD随机梯度下降算法的作用

它是一种优化算法,自动调整模型参数,提升模型的性能。
我们今天要在MNIST_SAMPLE数据集上实现SGD算法。

MNIST_SAMPLE数据集

MNIST_SAMPLE数据集只有数字3和数字7的图片。

from fastai.vision.all import *
path = untar_data(URLs.MNIST_SAMPLE)
from fastbook import *
import torch
matplotlib.rc('image', cmap='Greys')threes = (path/'train'/'3').ls().sorted()
sevens = (path/'train'/'7').ls().sorted()seven_tensors = [tensor(Image.open(o)) for o in sevens]
three_tensors = [tensor(Image.open(o)) for o in threes]stacked_sevens = torch.stack(seven_tensors).float()/255
stacked_threes = torch.stack(three_tensors).float()/255valid_3_tens = torch.stack([tensor(Image.open(o)) for o in (path/'valid'/'3').ls()])
valid_3_tens = valid_3_tens.float() / 255
valid_7_tens = torch.stack([tensor(Image.open(o)) for o in (path/'valid'/'7').ls()])
valid_7_tens = valid_7_tens.float() / 255train_x = torch.cat([stacked_threes, stacked_sevens]).view(-1, 28*28)
train_y = tensor([1]*len(threes) + [0]*len(sevens)).unsqueeze(1)
dset = list(zip(train_x, train_y))
dl = DataLoader(dset, 128)valid_x = torch.cat([valid_3_tens, valid_7_tens]).view(-1, 28*28)
valid_y = tensor([1]*len(valid_3_tens) + [0]*len(valid_7_tens)).unsqueeze(1)
valid_dset = list(zip(valid_x, valid_y))
valid_dl = DataLoader(valid_dset, 128)

SGD算法的七大步骤

Step1. 初始化模型参数

使用torch.randn()随机初始化参数,然后使用require_grad_方法表示需要追踪模型参数的梯度。

def init_params(size, std=1.0):return (torch.randn(size)*std).requires_grad_()bias = init_params(1)
weights = init_params((28*28,1))

Step2. 计算预测值predictions

先定义一个简单的网络模型,只有一个全连接层。然后使用这个网络模型计算预测值。
为什么需要bias?
y=w*x, 如果x=0是输入,那么预测值始终为0,不利于模型的训练;
y=w*x+b, 加入了bias,将使得模型更加灵活。

我们用训练集的前4个图片数据,作为样例测试这个函数。在pytorch中,@表示矩阵乘法运算符。

def linear1(xb):return xb@weights + bias
batch = train_x[:4]
preds = linear1(batch)
>>> preds

在这里插入图片描述

Step3. 计算损失loss

loss:基于预测值(predictions)和目标值(targets),使用某种损失函数loss_function,计算两者有多相近。
先定义损失函数:把预测值变成0-1之间的值,如果预测的目标是数字3,就计算预测值和1之间的距离;如果预测的目标是数字7,那么预测值本身就是它到0之间的距离。
sigmoid方法的作用:把任何输入值(无论正负)变成0-1之间的值。
在这里插入图片描述

def mnist_loss(predictions, targets):predictions = predictions.sigmoid()return torch.where(targets==1, 1-predictions, predictions).mean()
loss = mnist_loss(preds, train_y[:4])

在这里插入图片描述

Step4. 计算梯度gradients

梯度:指的是我们将怎么改变模型参数,它指明了具体的方向。
我们不需要手动计算梯度,因为深度学习库会自动帮助我们计算,只需要在初始化模型参数的时候,指明require_grad=True,它会自动保存模型参数的梯度。
backward()方法指的是反向传播算法,调用此方法会帮助我们自动计算每一层参数的梯度。相应地,因为设置了requires_grad=True, 新的梯度也会被自动保存。
当我们计算一个神经网络的导数的时候,这被称为向后传播过程。

loss.backward()
weights.grad.shape, weights.grad.mean(), bias.grad

在这里插入图片描述

把代码封装到calc_grad方法中,便于模块化地调用:

def calc_grad(xb, yb, model):preds = model(xb)loss = mnist_loss(preds, yb)loss.backward()
calc_grad(batch, train_y[:4], linear1)
weights.grad.mean(), bias.grad

在这里插入图片描述

为什么两次的输出结果不一样呢,因为pytorch自动将第一次计算的梯度保存了,第二次会在第一次的基础上再计算梯度,所以当然就结果不一样了,因此我们需要再下一次梯度前,将模型参数的梯度置为0.

weights.grad.zero_()
bias.grad.zero_();

Step5. 更新模型参数

学习率learning rate决定了每次更新模型参数的大小程度(也称为步长)。通常都设置得很小。

lr = 1.
weights.grad -= weights.grad * lr
bias.grad -= bias.grad * lr

Step6. 重复Step2-5

在这里我们需要将整个训练数据集分成mini_batches,然后将一个个batch喂入网络,为什么这样子做?

  • 一次性预测整个训练数据集会花费太长时间和太多内存
  • 如果一张张图片训练的话,梯度将变得不稳定和不精确

所以我们迭代整个训练集的子集来完成训练,即将数据集分成mini_batches

现在我们要在整个训练集的基础上更新参数。

将训练集分成很多mini batches,然后训练。

def train_epoch(model, lr, params):for xb,yb in dl:calc_grad(xb, yb, model)for p in params:p.data -= p.grad*lrp.grad.zero_()
for i in range(10):train_epoch(model, lr, params)

Step7. 停止

这是最基本的SGD算法。
在fastai深度学习库中,已经被封装成一个类了,我们只需要在创建Learner的时候指明loss_func=SGD.

在MNIST_SAMPLE数据集上训练linear_model

把7个步骤的代码封装成类

class SGD_Optim:def __init__(self,params,lr): self.params,self.lr = list(params),lrdef step(self, *args, **kwargs):for p in self.params: p.data -= p.grad.data * self.lrdef zero_grad(self, *args, **kwargs):for p in self.params: p.grad = Noneopt = SGD_Optim(linear_model.parameters(), lr)

衡量指标metric

loss主要是便于模型的训练,现在介绍的metric是便于我们在验证集上直观地了解模型的性能。

验证精度validation accuracy

在这里我们使用预测正确的平均值(即精度)来作为衡量指标。
模型验证过程同模型训练过程一样,我们将验证集分成一个个mini_batch,然后让模型去计算预测值preds,大于0.5的表示是数字3,否则表示数字7,然后计算预测正确的平均值,表示验证精度。

def batch_accuracy(xb, yb):preds = xb.sigmoid()correct = (preds>0.5) == ybreturn correct.float().mean()
验证函数
def validate_epoch(model):accs = [batch_accuracy(model(xb), yb)for xb,yb in valid_dl]# combine all the acc in this list into a single 1-dimensional tensorreturn round(torch.stack(accs).mean().item(), 4)
训练linear_model

通常情况下,训练一次模型包括一个训练周期train_epoch和一个验证周期validate_epoch。
在这里我们采用linear_model,训练20次。

linear_model = nn.Linear(28*28,1)
def train_model(model, epochs):for i in range(epochs):train_epoch(model)print(validate_epoch(model), end=' ')
train_model(linear_model, 20)

在这里插入图片描述

使用learner.fit函数训练模型

在fastai深度学习库中,内置函数learner.fit已经实现了train_model函数,为了使用此函数,我们需要先创建一个learner,而使用learner需要传入参数dataloaders,所以我们先创建dataloaders, 然后初始化一个Learner对象,然后调用fit函数。

dls = DataLoaders(dl, valid_dl)
learn = Learner(dls, nn.Linear(28*28,1), opt_func=SGD,loss_func=mnist_loss, metrics=batch_accuracy)
learn.fit(10, lr=lr)                

在fastai深度学习库中我们已经实现了SGD类的代码,所以我们只需要添加参数opt_func=SGD,便可以使用SGD优化算法。
在这里插入图片描述

第10轮的精度为0.967615,和上面的第10轮的精度差不多。也就是说,fastai只是有一些内置类和函数,让我们少写了一些代码,训练模型的速度便快了一些,精度上并没有太大提升。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.rhkb.cn/news/20423.html

如若内容造成侵权/违法违规/事实不符,请联系长河编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

Flutter 3.29.0 新特性 CupertinoNavigationBar 可配置bottom属性

Flutter 3.29版本优化了开发流程并提升了性能,对 Impeller、Cupertino、DevTools 等进行了更新。 CupertinoNavigationBar和CupertinoSliverNavigationBar现在接受底部小部件,通常是搜索字段或分段控件。 例如本小节内容就是放置了一个输入框&#xff…

Vue 3最新组件解析与实践指南:提升开发效率的利器

目录 引言 一、Vue 3核心组件特性解析 1. Composition API与组件逻辑复用 2. 内置组件与生命周期优化 3. 新一代UI组件库推荐 二、高级组件开发技巧 1. 插件化架构设计 2. 跨层级组件通信 三、性能优化实战 1. 惰性计算与缓存策略 2. 虚拟滚动与列表优化 3. Tree S…

数据结构----哈希表的插入与输出

#include <stdio.h> #include <string.h> #include <stdlib.h> #include <math.h> typedef int datatype;typedef struct Node {struct Node *next;datatype data; }*Linklist;//创建节点 Linklist Create_node() {Linklist p(Linklist)malloc(sizeof(…

QT QLabel加载图片等比全屏自适应屏幕大小显示

最近在工作项目中,遇到一个需求: 1.使用QLabel显示一张图片; 2.当点击这个QLabel时,需要全屏显示;但不能改变原来的尺寸; 3.当点击放大后的QLabel时,恢复原有大小. 于是乎,就有了本篇博客,介绍如何实现这样的功能. 一、演示效果 在一个水平布局中&#xff0c;添加两个Lable用…

eNSP防火墙综合实验

一、实验拓扑 二、ip和安全区域配置 1、防火墙ip和安全区域配置 新建两个安全区域 ip配置 Client1 Client2 电信DNS 百度web-1 联通DNS 百度web-2 R2 R1 三、DNS透明代理相关配置 1、导入运营商地址库 2、新建链路接口 3、配置真实DNS服务器 4、创建虚拟DNS服务器 5、配置D…

ios苹果手机使用AScript应用程序实现UI自动化操作,非常简单的一种方式

现在要想实现ios的ui自动化还是非常简单的&#xff0c;只需要安装AScript这个自动化工具就可以了&#xff0c;而且安卓&#xff0c;iso还有windows都支持&#xff0c;非常好用。 在ios端安装之后&#xff0c;需要使用mac电脑或者windows电脑激活一下 使用Windows电脑激活​ 激…

CommonLang3-使用介绍

摘自&#xff1a;https://www.cnblogs.com/haicheng92/p/18721636 学习要带着目的&#xff0c;参照现实问题 本次目标&#xff1a; 了解 CommonsLang3 API 文档&#xff0c;找对路后以后开发直接查询 API 文档&#xff0c;摈弃盲目的百度掌握基础的字符串、日期、数值等工具…

Qt:多元素控件

目录 多元素控件介绍 QListWidget QTableWidget QTreeWidget 多元素控件介绍 多元素控件表示这个控件中包含了很多的元素&#xff0c;元素可能指的是字符串&#xff0c;也可以指的是更加复杂的数据结构、图片等等 Qt 中提供的多元素控件有: QListWidgetQListViewQTableW…

基于YOLO11深度学习的心脏超声图像间隔壁检测分割与分析系统【python源码+Pyqt5界面+数据集+训练代码】深度学习实战、目标分割、人工智能

《------往期经典推荐------》 一、AI应用软件开发实战专栏【链接】 项目名称项目名称1.【人脸识别与管理系统开发】2.【车牌识别与自动收费管理系统开发】3.【手势识别系统开发】4.【人脸面部活体检测系统开发】5.【图片风格快速迁移软件开发】6.【人脸表表情识别系统】7.【…

二叉树链式结构:数据结构中的灵动之舞

目录 前言 一、 前置说明 二、二叉树的遍历 2.1前序遍历 2.2中序遍历 2.3 后序遍历 2.4层序遍历 三、二叉树的遍历的应用 3.1二叉树节点个数&#xff1a; 3.2二叉树的高度 3.3 二叉树第k层的节点的个数 3.4二叉树的查找 总结 前言 在数据结构的世界里&#xff0c;二叉…

Tomcat下载,安装,配置终极版(2024)

Tomcat下载&#xff0c;安装&#xff0c;配置终极版&#xff08;2024&#xff09; 1. Tomcat下载和安装 进入Apache Tomcat官网&#xff0c;我们可以看到这样一个界面。 现在官网目前最新版是Tomcat11&#xff0c;我用的是Java17&#xff0c;在这里我们选择Tomcat10即可。Tom…

Android Studio - Android Studio 查看项目的 Android SDK 版本(4 种方式)

一、通过项目级 build.gradle 文件 1、基本介绍 在项目级 build.gradle 文件中&#xff0c;查看 compileSdk、minSdk、targetSdk 字段 或者是 compileSdkVersion、minSdkVersion、targetSdkVersion 字段 // 看到的可能是android {compileSdk 32defaultConfig {minSdk 21tar…

linux云服务器部署deepseek,并通过网页访问

参考视频&#xff1a;https://www.douyin.com/root/search/linux%E5%AE%89%E8%A3%85%20deepseek?aid3aa2527c-e4f2-4059-b724-ab81a140fa8b&modal_id7468518885570940214&typegeneral 修改ollama配置文件 vim /etc/systemd/system/ollama.service 我的电脑硬盘只有4…

【AI】mac 本地部署 Dify 实现智能体

下载 Ollama 访问 Ollama 下载页&#xff0c;下载对应系统 Ollama 客户端。或者参考文章【实战AI】macbook M1 本地ollama运行deepseek_m1 max可以跑deepseek吗-CSDN博客 dify 开源的 LLM 应用开发平台。提供从 Agent 构建到 AI workflow 编排、RAG 检索、模型管理等能力&am…

Jenkins介绍

什么是Jenkins Jenkins 是一个开源的自动化服务器&#xff0c;主要用于持续集成和持续交付&#xff08;CI/CD&#xff09;。它帮助开发团队自动化构建、测试和部署软件&#xff0c;从而提高开发效率和软件质量。 如果一个系统是前后端分离的开发模式&#xff0c;在集成阶段会需…

如何使用 vxe-table grid 全配置式给单元格字段格式化内容,格式化下拉选项内容

如何使用 vxe-table grid 全配置式给单元格字段格式化内容&#xff0c;格式化下拉选项内容 公司的业务需求是自定义配置好的数据源&#xff0c;通过在列中配置好数据&#xff0c;全 json 方式直接返回给前端渲染&#xff0c;不需要写任何格式化方法。 官网&#xff1a;https:/…

【弹性计算】IaaS 和 PaaS 类计算产品

《弹性计算产品》系列&#xff0c;共包含以下文章&#xff1a; 云服务器&#xff1a;实例、存储、网络、镜像、快照容器、裸金属云上运维IaaS 和 PaaS 类计算产品 &#x1f60a; 如果您觉得这篇文章有用 ✔️ 的话&#xff0c;请给博主一个一键三连 &#x1f680;&#x1f680…

【Spring详解二】容器的基本实现

二、容器的基本实现 2.1 容器的基本用法 package com.xxx; public class Hello {public void sayHello() {System.out.println("Hello, spring");} } public static void main(String[] args) {//XmlBeanFactory 在 Spring3.1 以后废弃BeanFactory beanFactory ne…

计算机毕业设计Python考研院校推荐系统 考研分数线预测 考研推荐系统 考研可视化(代码+LW文档+PPT+讲解视频)

温馨提示&#xff1a;文末有 CSDN 平台官方提供的学长联系方式的名片&#xff01; 温馨提示&#xff1a;文末有 CSDN 平台官方提供的学长联系方式的名片&#xff01; 温馨提示&#xff1a;文末有 CSDN 平台官方提供的学长联系方式的名片&#xff01; 作者简介&#xff1a;Java领…

Ubuntu 系统 LVM 逻辑卷扩容教程

Ubuntu 系统 LVM 逻辑卷扩容教程 前言 在 Linux 系统中&#xff0c;LVM&#xff08;Logical Volume Manager&#xff09;是一种逻辑卷管理工具&#xff0c;允许管理员动态调整磁盘空间&#xff0c;而无需重启系统。 本文将详细介绍如何使用 LVM 扩容逻辑卷&#xff0c;以实现…