Pytorch深度学习实践笔记3

🎬个人简介:一个全栈工程师的升级之路!
📋个人专栏:pytorch深度学习
🎀CSDN主页 发狂的小花
🌄人生秘诀:学习的本质就是极致重复!

视频来自【b站刘二大人】

目录

1 梯度下降(Gradient Descent)

2 随机梯度下降(SGD Stochastic Gradient Descent)

3 批量梯度下降(BGD Batch Gradient Descent)

4 小批量梯度下降(mini-Batch GD,mini-batch Gradient Descent)

5 代码


1 梯度下降(Gradient Descent)

  • 梯度:

方向导数在该点的最大值,建立cost与w的关系,优化使得w能够快速收敛

  • 引入:

可以发现上一章节我们寻找权重ω的时候,使用的是遍历 ω 的方法,显然在工程上这是不可行的,于是引入了梯度下降(Gradient Descent)算法。

  • 方案:

我们的目标是找出 ω∗ 最小化 cost(ω)函数 。梯度下降使用公式ω=ω−α∗∂cost∂ω,其中α是人为设定的学习率,显然 ω 总是往cost局部最小化的方向趋近(可以注意并不总是往全局最优的方向)

  • 局部最优:

我们经常担心模型在训练的过程中陷入局部最优的困境中,但实际上由于Mini-batch的存在在实际工程中模型陷入鞍点(局部最优)的概率是很小的

  • epoch:

轮次,一个epoch指的是所有的训练样本在模型中都进行了一次正向传播和一次反向传播


2 随机梯度下降(SGD Stochastic Gradient Descent)


随机梯度下降算法在梯度下降算法的基础上进行了一定的优化,其对于每一个实例都进行更新,也就是不是用MSE,而是

对 ω进行更新
优势:SGD更不容易陷入鞍点之中,同时其拥有更好的性能
优点:
由于不是在全部训练数据上的损失函数,而是在每轮迭代中,随机优化某一条训练数据上的损失函数,这样每一轮参数的更新速度大大加快。
缺点:
准确度下降。由于即使在目标函数为强凸函数的情况下,SGD仍旧无法做到线性收敛。
可能会收敛到局部最优,由于单个样本并不能代表全体样本的趋势.
不易于并行实现。

for i in range(number of epochs):np.random.shuffle(data)for each in data:weights_grad = evaluate_gradient(loss_function, each, weights)weights = weights - learning_rate * weights_grad


3 批量梯度下降(BGD Batch Gradient Descent)


BGD通常是取所有训练样本损失函数的平均作为损失函数,每次计算所有样本的梯度,进行求均值,计算量比较大,会陷入鞍点
优点:
一次迭代是对所有样本进行计算,此时利用矩阵进行操作,实现了并行。
由全数据集确定的方向能够更好地代表样本总体,从而更准确地朝向极值所在的方向。当目标函数为凸函数时,BGD一定能够得到全局最优。
缺点:
当样本数目 m 很大时,每迭代一步都需要对所有样本计算,训练过程会很慢。(有些样本被重复计算,浪费资源)

for i in range(number of epochs):np.random.shuffle(data)for each in data:weights_grad = evaluate_gradient(loss_function, each, weights)weights = weights - learning_rate * weights_grad


4 小批量梯度下降(mini-Batch GD,mini-batch Gradient Descent)


mini-batch GD采取了一个折中的方法,每次选取一定数目(mini-batch)的样本组成一个小批量样本,然后用这个小批量来更新梯度,这样不仅可以减少计算成本,还可以提高算法稳定性。

for i in range(number of epochs):np.random.shuffle(data)for batch in get_batches(data, batch_size = batch_size):weights_grad = evaluate_gradient(loss_function, batch, weights)weights = weights - learning_rate * weights_grad


优点:融合了BGD和SGD优点

  • 通过矩阵运算,每次在一个batch上优化神经网络参数并不会比单个数据慢太多。
  • 每次使用一个batch可以大大减小收敛所需要的迭代次数,同时可以使收敛到的结果更加接近梯度下降的效果。
  • 可实现并行化。

梯度下降:BGD、SGD、mini-batch GD介绍及其优缺点​

blog.csdn.net/qq_41375609/article/details/112913848​编辑


5 代码

  • BGD
import matplotlib.pyplot as plt
import numpy as np# BGD 批量梯度下降x_data = np.arange(1.0,200.0,1.0)
y_data = np.arange(2.0,400.0,2.0)def forward(x,w):return x*wdef cost(x,y,w):cost = 0for x_val,y_true in zip(x,y):y_pred = forward(x_val,w)loss_val = (y_true - y_pred)**2cost = cost + loss_valreturn cost/len(x)def gradient(x,y,w):gradient = 0for x_val,y_true in zip(x,y):gradient_temp = 2 * x_val *(x_val * w - y_true)gradient = gradient + gradient_tempreturn gradient/len(x)w = 1.0
lr = 0.00001epoch_list = []
cost_list = []print("Before train 4: ",forward(400,w))
for epoch in range(100):cost_val = cost(x_data,y_data,w)gradient_val = gradient(x_data,y_data,w)w = w - lr * gradient_valprint("epoch: ",epoch," loss: ",cost_val," w: ",w)epoch_list.append(epoch)cost_list.append(cost_val)if (cost_val < 1e-5):breakprint("After train 4: ",forward(400,w))plt.plot(epoch_list,cost_list)
plt.xlabel("epoch")
plt.ylabel("loss")
plt.savefig("./data/pytorch2.png")

  • SGD
import matplotlib.pyplot as pltimport numpy as np# SGD随机梯度下降x_data = np.arange(1.0,200.0,1.0)
y_data = np.arange(2.0,400.0,2.0)def forward(x,w):return x * wdef loss(x,y_true,w):y_pred = forward(x,w)return (y_pred-y_true)**2def gradient(x,y,w):return 2 *x *(x *w-y)w = 1.0
lr = 0.00001epoch_list = []
loss_list = []print("Before train 4: ",forward(400,w))
for epoch in range(1000):seed = np.random.choice(range(len(x_data)))loss_val = loss(x_data[seed],y_data[seed],w)gradient_val = gradient(x_data[seed],y_data[seed],w)w -= lr*gradient_valprint("epoch: ",epoch," loss: ",loss_val," w: ",w)epoch_list.append(epoch)loss_list.append(loss_val)if (loss_val < 1e-7):break
print("After train 4: ",forward(400,w))plt.plot(epoch_list,loss_list)
plt.xlabel("epoch")
plt.ylabel("loss")
plt.savefig("./data/pytorch2_1.png")

  • mini-Batch GD
import matplotlib.pyplot as pltimport numpy as np
import random# mini-batch GD 小批量随机梯度下降x_data = np.arange(1.0,200.0,1.0)
y_data = np.arange(2.0,400.0,2.0)def forward(x,w):return x*wdef cost(x,y,w):cost = 0for x_val,y_true in zip(x,y):y_pred = forward(x_val,w)loss_val = (y_true - y_pred)**2cost = cost + loss_valreturn cost/len(x)def gradient(x,y,w):gradient = 0for x_val,y_true in zip(x,y):gradient_temp = 2 * x_val *(x_val * w - y_true)gradient = gradient + gradient_tempreturn gradient/len(x)def get_seed_two(nums):# 从数组中随机取两个索引index1 = random.randrange(len(nums))index2 = random.randrange(len(nums))while index2 == index1:index2 = random.randrange(len(nums))return index1, index2w = 1.0
lr = 0.00001
batch_size = 2epoch_list = []
loss_list = []# 存储随机取出的数的索引
seed = []# 存储一个batch的数据
x_data_mini = []
y_data_mini = []print("Before train 4: ",forward(400,w))
for epoch in range(1000):# 设定Batchsize 大小为2,每次随机取所有数据中的两个,作为一个batch,进行训练idx1,idx2= get_seed_two(x_data)seed.append(idx1)seed.append(idx2)for i in range(batch_size):x_data_mini.append(x_data[seed[i]])y_data_mini.append(y_data[seed[i]])loss_val = cost(x_data_mini,y_data_mini,w)gradient_val = gradient(x_data_mini,y_data_mini,w)w -= lr*gradient_valprint("epoch: ",epoch," loss: ",loss_val," w: ",w)epoch_list.append(epoch)loss_list.append(loss_val)if (loss_val < 1e-7):break
print("After train 4: ",forward(400,w))plt.plot(epoch_list,loss_list)
plt.xlabel("epoch")
plt.ylabel("loss")
plt.savefig("./data/pytorch2_2.png")

🌈我的分享也就到此结束啦🌈
如果我的分享也能对你有帮助,那就太好了!
若有不足,还请大家多多指正,我们一起学习交流!
📢未来的富豪们:点赞👍→收藏⭐→关注🔍,如果能评论下就太惊喜了!
感谢大家的观看和支持!最后,☺祝愿大家每天有钱赚!!!欢迎关注、关注!

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.rhkb.cn/news/331195.html

如若内容造成侵权/违法违规/事实不符,请联系长河编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

STM32 学习——1. STM32最小系统

这是一个最小系统的测试&#xff0c;LED灯会进行闪烁。选用PC13口&#xff0c;因为STM32F103C8T6 硬件开发板中&#xff0c;这个端口是一个LED 1. proteus8.15 原理图 2. cubemx 新建工程 3. keil 代码 while (1){HAL_GPIO_TogglePin(LED_GPIO_Port, LED_Pin);HAL_Delay(100);…

linux文件权限常用知识点,基于Linux(openEuler、CentOS8)

目录 知识点常用实例 知识点 真实环境文件显示 解读 常用实例 文件所有者 chown -R nginx:nginx /home/source目录权限(R选填必须大写<遍历子文件夹及文件>) chmod -R 755 /home/sourcechmod -R 777 /home/source

为什么推荐前端用WebStorm软件编程?

一、介绍 WebStorm是由JetBrains公司开发的一款JavaScript开发工具&#xff0c;被广大中国JS开发者誉为“Web前端开发神器”、“最强大的HTML5编辑器”、“最智能的JavaScript IDE”等。它支持JavaScript、ECMAScript 6、TypeScript、CoffeeScript、Dart和Flow等多种语言的代码…

【vue】封装的天气展示卡片,在线获取天气信息

源码 <template><div class"sen_weather_wrapper"><div class"sen_top_box"><div class"sen_left_box"><div class"sen_top"><div class"sen_city">山东</div><qctc-time cl…

unreal engine 5.0.3 创建游戏项目

根据虚幻官网介绍&#xff0c;虚幻引擎5可免费用于创建线性内容、定制项目和内部项目。你可以免费用它开发游戏&#xff0c;只有当你的产品营收超过100万美元时&#xff0c;才收取5%的分成费用。所以目前国内也有许多游戏厂商在使用UE制作游戏。UE5源码也已开源&#xff0c;有U…

多线程(C++11)

多线程&#xff08;C&#xff09; 文章目录 多线程&#xff08;C&#xff09;前言一、std::thread类1.线程的创建1.1构造函数1.2代码演示 2.公共成员函数2.1 get_id()2.2 join()2.3 detach()2.4 joinable()2.5 operator 3.静态函数4.类的成员函数作为子线程的任务函数 二、call…

解释JAVA语言中关于方法的重载

在JAVA语言中&#xff0c;方法的重载指的是在同一个类中可以存在多个同名方法&#xff0c;但它们的参数列表不同。具体来说&#xff0c;重载的方法必须满足以下至少一条条件: 1. 参数个数不同。 2. 参数类型不同。 3. 参数顺序不同。 当调用一个重载方法时&#xff0c;编译器…

MyBatis 学习笔记(一)

MyBatis 封装 JDBC :连接、访问、操作数据库中的数据 MyBatis 是一个持久层框架。 MyBatis 提供的持久层框架包括 SQLMaps 和 Data Access Objects&#xff08;DAO&#xff09; SQLMaps&#xff1a;数据库中的数据和 Java数据的一个映射关系 封装 JDBC 的过程Data Access Ob…

东哥一句兄弟,你还当真了?

关注卢松松&#xff0c;会经常给你分享一些我的经验和观点。 你还真把自己当刘强东兄弟了?谁跟你是兄弟了?你在国外的房子又不给我住&#xff0c;你出去旅游也不带上我!都成人年了&#xff0c;东哥一句客套话&#xff0c;别当真! 今天&#xff0c;东哥在高管会上直言&…

计算机网络套接字知识(非常详细)从零基础入门到精通

本节重点 认识IP地址, 端口号, 网络字节序等网络编程中的基本概念; 学习socket api的基本用法; 一、预备知识 1.理解源IP地址和目的IP地址 ⭐在IP数据包头部中&#xff0c;有两个IP地址&#xff0c;分别叫做源IP地址和目的IP地址。 思考: 我们光有IP地址就可以完成通信了…

深入理解NumPy与Pandas【numpy模块及Pandas模型使用】

二、numpy模块及Pandas模型使用 numpy模块 1.ndarray的创建 import numpy as np anp.array([1,2,3,4]) bnp.array([[1,2,3,4],[5,6,7,8]]) print(a) #[1 2 3 4] print(b) #[[1 2 3 4][5 6 7 8]] 1.1使用array()函数创建 numpy.array(object, dtype None, copy True, ord…

CentOS 7安装/卸载Grafana

说明&#xff1a;本文介绍CentOS 7操作系统如何安装/卸载Grafana&#xff1b; 安装 Step1&#xff1a;下载rpm文件 敲下面的命令&#xff0c;下载grafana的rpm文件 wget https://dl.grafana.com/oss/release/grafana-7.3.7-1.x86_64.rpmStep2&#xff1a;安装grafana 敲下…

Redis常见数据类型(6)-set, zset

目录 Set 命令小结 内部编码 使用场景 用户画像 其它 Zset有序集合 普通指令 zadd zcard zcount zrange zrevrange ​编辑 zrangebyscore zpopmax/zpopmin bzpopmax/bzpopmin zrank/zrevrank zscore zrem zremrangebyrank zremrangebyscore Set 命令小结 …

图像上下文学习|多模态基础模型中的多镜头情境学习

【原文】众所周知&#xff0c;大型语言模型在小样本上下文学习&#xff08;ICL&#xff09;方面非常有效。多模态基础模型的最新进展实现了前所未有的长上下文窗口&#xff0c;为探索其执行 ICL 的能力提供了机会&#xff0c;并提供了更多演示示例。在这项工作中&#xff0c;我…

以太坊(3)——智能合约

智能合约 首先明确一下几个说法&#xff08;说法不严谨&#xff0c;为了介绍清晰才说的&#xff09;&#xff1a; 全节点矿工 节点账户 智能合约是基于Solidity语言编写的 学习Solidity语言可以到WFT学院官网&#xff08;Hello from WTF Academy | WTF Academy&#xff09;…

Go语言的内存泄漏如何检测和避免?

文章目录 Go语言内存泄漏的检测与避免一、内存泄漏的检测1. 使用性能分析工具2. 使用内存泄漏检测工具3. 代码审查与测试 二、内存泄漏的避免1. 使用defer关键字2. 使用垃圾回收机制3. 避免循环引用4. 使用缓冲池 Go语言内存泄漏的检测与避免 在Go语言开发中&#xff0c;内存泄…

Linux基础(五):常用基本命令

从本节开始&#xff0c;我们正式进入Linux的学习&#xff0c;通过前面的了解&#xff0c;我们知道我们要以命令的形式使用操作系统&#xff08;使用操作系统提供的各类命令&#xff0c;以获得字符反馈的形式去使用操作系统。&#xff09;&#xff0c;因此&#xff0c;我们是很有…

win32-鼠标消息、键盘消息、计时器消息、菜单资源

承接前文&#xff1a; win32窗口编程windows 开发基础win32-注册窗口类、创建窗口win32-显示窗口、消息循环、消息队列 本文目录 键盘消息键盘消息的分类WM_CHAR 字符消息 鼠标消息鼠标消息附带信息 定时器消息 WM_TIMER创建销毁定时器 菜单资源资源相关菜单资源使用命令消息的…

634 · 单词矩阵

链接&#xff1a;LintCode 炼码 - ChatGPT&#xff01;更高效的学习体验&#xff01; . - 力扣&#xff08;LeetCode&#xff09; 题解&#xff1a; class Solution { public: struct Trie {Trie() {next.resize(26, nullptr);end false;} std::vector<Trie*> next; b…

Python高级进阶--dict字典

dict字典⭐⭐ 1. 字典简介 dictionary&#xff08;字典&#xff09; 是 除列表以外 Python 之中 最灵活 的数据类型&#xff0c;类型为dict 字典同样可以用来存储多个数据字典使用键值对存储数据 2. 字典的定义 字典用{}定义键值对之间使用,分隔键和值之间使用:分隔 d {中…