深度学习笔记12

1.神经网络的代价函数

        神经网络可同时用于解决分类问题和回归问题,对于不同的问题会在输出层后,加上不同的变换函数。一般来说,回归问题使用恒等函数f(x)=x,分类问题使用sigmoid或softmax函数。而不同的变换函数,也对应不同的代价函数。

神经网络解决回归问题

        在使用神经网络解决回归问题时,会在输出层后,加上恒等函数,不会对输入值做任何修改。如果神经网络只预测一个回归值,那么使用的代价函数和线性回归中的代价函数一样,都是均方误差函数MSE。如果线性回归需要预测多个回归值,则需要将每个回归目标的均方误差计算出来,然后相加得到总误差。

总的代价函数:

MSE=\frac{1}{mn}\sum_{i=1}^{m}\sum_{j=1}^{n}(y_{ij}-\hat{y}_{ij})^2

m:样本个数

n:目标个数

y_{ij}:真实值

\hat{y}_{ij}:预测值

i:第i个样本

j:第j个目标

\sum_{j=1}^{n}:计算单独每个样本的n个目标的真实值和预测值的平方差

\sum_{i=1}^{m}:将m个样本对应的误差加到一起

神经网络解决分类问题 

        当使用神经网络解决分类问题时,最后一层的每个神经元都会对应一个类型,每个神经元的输出通过变换函数转换为类别对应的概率。如果每个类别之间是互斥的,就将神经元的输出值,输入到softmax函数中,将其转化为这几个互斥的类别对应的概率(p_1+p_2+...=1)。

        另一种情况是多标签分类,每个类别之间互不打扰,相互独立,这时要将神经元的输出值,分别输入到sigmoid函数中,经过sigmoid函数的计算,可以得到这些不同类别的概率,它们是多个无关联的、0-1之间的实数。

代价函数

         在解决分类问题时,一般使用交叉熵损失函数。对于互斥的分类问题,神经网络会使用与softmax回归形式完全相同的交叉熵损失函数。对于多标签分类问题,神经网络使用与逻辑回归形式相似的交叉熵损失函数。

互斥的多分类问题:E=-\frac{1}{m}\sum_{i=1}^{m}\sum_{k=1}^{n}y_k^{(i)}log(p_k^{(i)})

多标签分类问题:E=-\frac{1}{m}\sum_{i=1}^{m}\sum_{k=1}^{n}(y_k^{(i)}log(p_k^{(i)})+(1-y_k^{(i)})log(1-p_k^{(i)}))

2.小批量梯度下降算法        

        梯度下降算法有三种常见的形式:批量梯度下降、随机梯度下降和小批量梯度下降

  • 批量梯度下降:

         每次迭代中,批量梯度下降算法都会基于所有的训练样本,计算损失函数的梯度,因此可以得到一条平滑的收敛曲线。训练数据:100个样本,迭代轮数50,在每一轮迭代中都会一起使用这100个样本,计算整个训练集的梯度,更新模型参数,所以总更新次数:50次

  • 随机梯度下降:

        会在一轮完整的迭代过程中,遍历整个训练集,但是每次更新只基于一个样本计算梯度,这样会得到一条震荡的收敛曲线。训练数据:100个样本,迭代轮数50,每一轮迭代会遍历这100个样本,每次汇集孙某一个样本的梯度,更新模型参数,所以总更新次数:100*50=5000

  • 小批量梯度下降:

        结合批量梯度下降和随机梯度下降的优点,每次迭代会从训练集中,随机选择一个小批量,计算梯度,更新模型。训练数据:100个样本,迭代轮数50,小批量大小20,在每一轮迭代中会有5次小批量的迭代,所以总更新次数:(100/20)*50=250

梯度下降算法比较
优点缺点
批量梯度下降每次迭代会使用整个训练集计算梯度,可以得到准确的梯度方向如果数据集非常大时,就导致每次迭代的速度都非常慢,计算成本就会很高
随机梯度下降每次只用一个样本训练,所以迭代速度会非常快,迭代具有震荡属性,可以跳出局部最优解更新的方向会不稳定,可能永远都不会真正的收敛
小批量梯度下降结合随机梯度下降的高效性和批量梯度下降的稳定性,它比随机梯度下降有更稳定的收敛,同时又比批量梯度下降计算的更快
方法是否稳定迭代速度局部最优解
批量梯度下降稳定可能停留
随机梯度下降不稳定可跳出
小批量梯度下降较稳定较快可跳出

小批量梯度下降算法的实现 

1.小批量数据的准备

import numpy as np
import matplotlib.pyplot as plt
import torch
from torch.utils.data import DataLoader
from torch.utils.data import TensorDataset
#设置一个固定的随机种子,确保每次运行得到相同的数据
np.random.seed(0)
#随机生成100个横坐标x,范围在0-2
x=2*np.random.rand(100,1)
#生成带有噪音的纵坐标y,数据基本分布在y=2x+3附近
y=3+2*x+np.random.randn(100,1)*0.5
plt.scatter(x, y,marker='x',color='red')#将训练数据x、y转为张量
x=torch.from_numpy(x).float()
y=torch.from_numpy(y).float()#使用TensorDataset,将x和y组成训练集
dataset=TensorDataset(x,y)
#使用DataLoader,构造随机的小批量数据
dataloader=DataLoader(dataset,#使用一个小批量的数据规模为20batch_size=20,#随机打乱数据的顺序shuffle=True)
print("dataloader len=%d"%(len(dataloader)))
for index,(data,label) in enumerate(dataloader):print("index = %d num=%d"%(index,len(data )))

x=2*np.random.rand(100,1) 

  • np.random.rand(100,1):使用 NumPyrand 函数生成一个形状为 (100,1)(100,1)(100,1) 的随机数组,里面的数值均匀分布在 000 到 111 之间。

  • 2 * np.random.rand(100,1):将生成的数组乘以 2,使得数值范围扩展到 000 到 222 之间。

y=3+2*x+np.random.randn(100,1)*0.5 

np.random.randn(100,1) * 0.5:生成形状为 (100,1)(100,1)(100,1) 的随机噪声项,使 y 中的值带有一些波动。np.random.randn(100,1) 使用标准正态分布(均值为 0,标准差为 1)生成随机数,将其乘以 0.5 缩小标准差,使得噪声更小、波动更平滑。

np.random.rand 生成的随机数均匀分布在 [0,1),而 np.random.randn 生成的随机数服从标准正态分布,中心在 0。 

DataLoader(dataset, batch_size=20, shuffle=True)

  • DataLoaderPyTorch 中的一个类,用于将数据集分成小批量并进行迭代读取。这样可以有效地处理大量数据而不必一次性全部加载到内存中,特别是在训练深度学习模型时非常重要。
  • dataset

    • 这是要加载的数据集,即前面用 TensorDataset(x, y) 定义的 datasetDataLoader 将使用该数据集来加载数据。
  • batch_size=20

    • 设定每个小批量的大小为 20,意味着每次加载器会返回 20 个样本。
    • 在训练过程中,小批量(batch)数据能够加速模型的梯度计算,同时可以让模型在批量数据的平均梯度上进行更新,有助于稳定训练。
  • shuffle=True

    • 设置 shuffle=True 会在每个 epoch(轮次)开始前打乱数据集的顺序,以避免模型训练受样本顺序的影响,有助于提升模型的泛化能力。
    • 在训练时,数据的随机性能够帮助模型更好地学习到数据的真实分布,防止过拟合。

w = torch.randn(1, requires_grad=True)

        初始化了模型的权重参数 w,并启用了自动求导功能(requires_grad=True),表示 w 将在训练过程中计算其梯度。 尽管我们在定义时启用了 requires_grad=True,只是告诉 PyTorch 这个张量(比如 w)需要计算和记录梯度,但这并不会主动计算出梯度。实际上,梯度计算只在调用 loss.backward() 时才会触发

dataloader len=5
index = 0 num=20
index = 1 num=20
index = 2 num=20
index = 3 num=20
index = 4 num=20 

2.小批量梯度下降算法迭代 

#带迭代的参数w和b
w=torch.randn(1,requires_grad=True)
b=torch.randn(1,requires_grad=True)
#进入模型的迭代循环
for epoch in range(1,51):#迭代轮数#在一个迭代轮次中,以小批量的方式,使用dataloader对数据进行遍历#batch_idx表示当前遍历的批次#data和label表示这个批次的训练数据和标记for batch_idx,(data,label) in enumerate(dataloader):h=x*w+b#计算当前直线的预测值,保存到h#计算预测值h和真实值y之间的均方误差,保存到loss中loss=torch.mean((h-y)**2)loss.backward()#计算代价loss关于参数w和b的偏导数#进行梯度下降,沿着梯度下降的反方向,更新w和b的值w.data-=0.01*w.grad.datab.data-=0.01*b.grad.data#清空张量w和b中的梯度信息,为下一次迭代做准备w.grad.zero_()b.grad.zero_()#每次迭代,都打印当前迭代的轮数epoch#数据的批次batch_idx和loss损失值print("epoch (%d) batch (%d) loss=%.3lf"%(epoch,batch_idx,loss.item()))

for batch_idx, (data, label) in enumerate(dataloader):

  • enumerate(dataloader) 是一个迭代器,它会遍历 dataloader 并在每次迭代时返回当前批次的索引 batch_idx 和该批次的数据 (data, label)
  • dataloader 是由 DataLoader 创建的对象,负责将数据集 dataset 分割为小批量,以便模型可以逐批次读取数据进行训练

 

3.图像绘制 

#打印w和b的值,并绘制直线
print('w=%.3lf,b=%.3lf'%(w.item(),b.item()))
w=w.item()
b=b.item()
x=np.linspace(0,2,100)
h=w*x+b
plt.plot(x,h)
plt.show()

 

4.结果分析

        每次运行代码时返回的 wb 不一样,主要是由于以下几个原因:

  • 随机初始化:

w = torch.randn(1, requires_grad=True) b = torch.randn(1, requires_grad=True)

        这里的 torch.randn 会从标准正态分布(均值 0,标准差 1)中随机生成 wb 的初始值。每次运行代码时,wb 的初始值通常是不同的,这会导致训练过程中的更新路径不同,从而影响最终的值。

  • 数据中的随机噪声

y = 3 + 2 * x + np.random.randn(100, 1) * 0.5

        虽然设置了 np.random.seed(0),使得每次运行生成的 xy 都相同,但在模型初始化时 wb 是随机的,每次不同的初始 wb 会影响模型在不同批次上的更新路径,进而影响训练结果。

  • 小批量数据的随机顺序:

dataloader = DataLoader(dataset, batch_size=20, shuffle=True)

  DataLoadershuffle=True 参数会在每个 epoch 开始时随机打乱数据的顺序。每次运行代码时,小批量数据的顺序会不同,因此参数更新的路径也会不同。

 

 

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.rhkb.cn/news/468839.html

如若内容造成侵权/违法违规/事实不符,请联系长河编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

RabbitMQ队列详细属性(重要)

RabbitMQ队列详细属性 1、队列的属性介绍1.1、Type:队列类型1.2、Name:队列名称1.3、Durability:声明队列是否持久化1.4、Auto delete: 是否自动删除1.5、Exclusive:1.6、Arguments:队列的其他属性&#xf…

json即json5新特性,idea使用json5,fastjson、gson、jackson对json5支持

文章目录 1.新特性1.1.JSON&JSON5官网2.示例2.1. IntelliJ IDEA2.1.1.支持.json5文件2.1.2.md支持json5代码块 2.9. 示例源码 1.新特性 【通用】 注释尾随逗号key无需引号(或单引号) 【字符串】 字符串可以用单引号引起来。字符串可以通过转…

【NOIP普及组】摆花

【NOIP普及组】摆花 C语言代码C 代码Java代码Python代码 💐The Begin💐点点关注,收藏不迷路💐 小明的花店新开张,为了吸引顾客,他想在花店的门口摆上一排花,共 m 盆。通过调 查顾客的喜好&am…

pdf转excel;pdf中表格提取

一、问题描述 在工作中或多或少会遇到:需要将某份pdf中的表格数据提取出来,以便能够“修改使用”数据 可将pdf中的表格提取出来,解决办法还有点复杂 尤其涉及“pdf中表格不是标准的单元格”的时候,提取数据到excel不太容易 比…

Qt中 QWidget 和 QMainWindow 区别

QWidget 用来构建简单窗口 QMainWindow 用来构建更复杂的窗口,QMainWindow 继承自QWidget,在QWidget 的基础上提供了菜单栏、工具栏、状态栏等功能 菜单栏(QMenuBar)工具栏(QToolBar)状态栏(Q…

《深入浅出Apache Spark》系列③:Spark SQL解析层优化策略与案例解析

导读:本系列是Spark系列分享的第三期。第一期分享了Spark Core的一些基本原理和一些基本概念,包括一些核心组件。Spark的所有组件都围绕Spark Core来运转,其中最活跃的一个上层组件是Spark SQL。第二期分享则专门介绍了Spark SQL的基本架构和…

安全的时钟启动

Note:文章内容以 Xilinx 系列 FPGA 进行讲解 1、什么是安全启动时钟 通常情况下,在MMCM/PLL的LOCKED信号抬高之后(由0变为1),MMCM/PLL就处于锁定状态,输出时钟已保持稳定。但在此之前,输出时钟会…

【mongodb】数据库的安装及连接初始化简明手册

NoSQL(NoSQL Not Only SQL ),意即"不仅仅是SQL"。 在现代的计算系统上每天网络上都会产生庞大的数据量。这些数据有很大一部分是由关系数据库管理系统(RDBMS)来处理。 通过应用实践证明,关系模型是非常适合于客户服务器…

丹韵红墙成红毯至美背景!冠珠华脉「雍华京韵」于M essential大秀绽放京韵时尚

东方美学代表品牌M essential近日于上海科学会堂举办十周年大秀,并发布品牌全新2024/25冬春系列。冠珠瓷砖作为国风新韵合作品牌,以高定岩板华脉「雍华京韵」系列的宫墙丹韵打造红毯背景墙,中国高定岩板与中国高级时装作品碰撞着“中国美”的…

工程认证与Spring Boot:计算机课程管理的新探索

摘要 随着信息技术在管理上越来越深入而广泛的应用,管理信息系统的实施在技术上已逐步成熟。本文介绍了基于工程教育认证的计算机课程管理平台的开发全过程。通过分析基于工程教育认证的计算机课程管理平台管理的不足,创建了一个计算机管理基于工程教育认…

excel功能

统计excel中每个名字出现的次数 在Excel中统计每个名字出现的次数,您可以使用COUNTIF函数或数据透视表。以下是两种方法的详细步骤: 方法一:使用COUNTIF函数 准备数据:确保您的姓名列表位于一个连续的单元格区域,例如…

【flask开启进程,前端内容图片化并转pdf-会议签到补充】

flask开启进程,前端内容图片化并转pdf-会议签到补充 flask及flask-socketio开启threading页面内容转图片转pdf流程前端主js代码内容转图片-browser端browser端的同步编程flask的主要功能route,def 总结 用到了pdf,来回数据转发和合成,担心flask卡顿,响应差,于是刚好看到threadi…

聊一聊Spring中的自定义监听器

前言 通过一个简单的自定义的监听器,从源码的角度分一下Spring中监听的整个过程,分析监听的作用。 一、自定义监听案例 1.1定义事件 package com.lazy.snail;import lombok.Getter; import org.springframework.context.ApplicationEvent;/*** Class…

VMWareTools安装及文件无法拖拽解决方案

文章目录 1 安装VMWare Tools2 安装vmware tools之后还是无法拖拽文件解决方案2.1 确认vmware tools安装2.2 客户机隔离2.3 修改自定义配置文件2.4 安装open-vm-tools-desktop软件 1 安装VMWare Tools 打开虚拟机VMware Workstation,启动Ubuntu系统,菜单…

ADC前端控制与处理模块--AD7606_Module

总体框架 AD7606_Module主要由3个模块组成组成,AD7606_Data_Pkt和AD7606_Drive以及AD7606_ctrl。 1.AD7606_Data_Pkt主要作用是把AD芯片数据组好数据包,然后发送给上位机; 2.AD7606_Drive主要负责和芯片的交互部分 3.AD7606_ctrl控制模块的作…

Unity 插件 - Project窗口资源大小显示

Unity 插件 - Project窗口资源大小显示 🍔功能🌭安装 🍔功能 💡.显示Project Assets 和Packages下所有文件的大小(右侧显示) 💡.统计选中文件夹及其子文件夹下所有文件的大小并显示&#xff08…

HTB:Photobomb[WriteUP]

目录 连接至HTB服务器并启动靶机 使用nmap对靶机进行端口开放扫描 再次使用nmap对靶机开放端口进行脚本、服务扫描 使用ffuf进行简单的子域名扫描 使用浏览器直接访问该域名 选取一个照片进行下载,使用Yakit进行抓包 USER_FLAG:a9afd9220ae2b5731…

ssm教室信息管理系统+vue

系统包含:源码论文 所用技术:SpringBootVueSSMMybatisMysql 免费提供给大家参考或者学习,获取源码看文章最下面 需要定制看文章最下面 目 录 目 录 III 1 绪论 1 1.1 研究背景 1 1.2目的和意义 1 1.3 论文结构安排 2 2 相关技术 3 …

详解Java之Spring MVC篇二

目录 获取Cookie/Session 理解Cookie 理解Session Cookie和Session的区别 获取Cookie 获取Session 获取Header 获取User-Agent 获取Cookie/Session 理解Cookie HTTP协议自身是“无状态”协议,但是在实际开发中,我们很多时候是需要知道请求之间的…

量子计算及其在密码学中的应用

💓 博客主页:瑕疵的CSDN主页 📝 Gitee主页:瑕疵的gitee主页 ⏩ 文章专栏:《热点资讯》 量子计算及其在密码学中的应用 量子计算及其在密码学中的应用 量子计算及其在密码学中的应用 引言 量子计算概述 定义与原理 发展…