现代C++中的从头开始深度学习:【4/8】梯度下降

一、说明

        在本系列中,我们将学习如何仅使用普通和现代C++编写必须知道的深度学习算法,例如卷积、反向传播、激活函数、优化器、深度神经网络等。

        在这个故事中,我们将通过引入梯度下降算法来介绍数据中 2D 卷积核的拟合。我们将使用卷积和上一个故事中引入的成本函数概念,将所有内容编码为现代C++和特征。

这个故事是:C++的梯度下降,查看其他故事:

0 — 现代C++深度学习编程基础

1 — 在C++中编码 2D 卷积

2 — 使用 Lambda 的成本函数

4 — 激活函数

...更多内容即将推出。

二、函数逼近作为优化问题

        如果你读过我们之前的演讲,你已经知道,在机器学习中,我们大部分时间都在关注使用数据来寻找函数近似值。

        通常,我们通过找到最小化成本值的系数来获得函数近似。因此,我们的近似问题被转换为优化问题,我们试图最小化成本函数的值。

三、成本函数和梯度下降

        成本函数计算使用函数 H(X) 近似目标函数 F(X) 的开销。例如,如果 H(X) 是输入 X 和核 k 之间的卷积,则 MSE 成本函数由下式给出:

        我们通常做 Yn = F(Xn),结果是:

MSE是均方误差,是上一个故事中介绍的成本函数

因此,我们的目标是找到最小化MSE(k)的内核值km。找到 km 的最基本(但最强大)的算法是梯度下降。

梯度下降使用成本函数梯度来查找最小成本。为了理解什么是梯度,让我们谈谈成本表面。

四、绘制成本曲面

        为了更容易理解,让我们暂时假设内核仅由两个系数组成。如果我们为每个可能的组合绘制 MSE(k) 的值,我们最终会得到这样的表面:k[k00, k01][k00, k01]

在每个点上,曲面与0k₀₀轴有一个倾角,与0k₀₁轴有另一个倾角:(k00, k01, MSE(k00, k01))

偏导数

这两个斜率分别是 MSE 曲线相对于轴 O k₀₀ 和 Ok₀₁ 的偏导数。在微积分中,我们非常使用符号∂来表示偏导数:

这两个偏导数共同构成了MSE相对于O k₀₀和Ok₀₁的梯度。此梯度用于驱动梯度下降算法的执行,如下所示:

梯度下降的实际应用

在成本表面上执行此“导航”的算法称为梯度下降。

五、梯度下降

梯度下降伪代码描述如下:

gradient_descent:initialize k, learning_rate, epoch = 1repeatk = k - learning_rate x ∇Cost(k)until epoch <= max_epochreturn k

        learning_rate x ∇Cost(k) 的值通常称为权重更新。我们可以通过以下方式恢复梯度下降的行为:

for each iteration:calculate the weight updatesubtract it from the parameter k

顾名思义,Cost(k) 是配置 k 的成本函数。梯度下降的目的是找到成本(k)最小的k值。

learning_rate通常是像 0.1、0.01、0.001 左右这样的标量。此值控制优化过程中的步长。

该算法循环 max_epoch 次。有时,我们会更早地停止算法,即,即使纪元< max_epoch,在 Cost(k) 太小的情况下。

我们通常用超参数的名称来指代learning_ratemax_epoch参数

要实现梯度下降,我们需要知道的最后一件事是如何计算 C(k) 的梯度。幸运的是,在成本函数为 MSE 的情况下,如前所述,查找 ∇Cost(k) 非常简单。

六、查找 MSE 梯度

到目前为止,我们已经看到梯度的分量是每个轴 0kij 的成本面的斜率。我们还看到,MSEk) 相对于每个 i 个、核 k 的系数 j-的梯度由下式给出:

让我们记住,MSE(k) 由下式给出:

其中n是每对的索引(Yn,Tn),r&c是输出矩阵系数的索引:

输出布局

使用链式规则和线性组合规则,我们可以通过以下方式找到MSE梯度:

由于 NR、CYn 和 T n 的值是已知的,我们需要计算的只是 Tn 中每个系数相对于系数 kij 的偏导数。在带有填充 P 的卷积的情况下,此导数由下式给出:

如果我们展开 r 和 c 的总和,我们可以发现梯度由下式给出:

其中 δn 是矩阵:

以下代码实现此操作:

auto gradient = [](const std::vector<Matrix> &xs, std::vector<Matrix> &ys, std::vector<Matrix> &ts, const int padding)
{const int N = xs.size();const int R = xs[0].rows();const int C = xs[0].cols();const int result_rows = xs[0].rows() - ys[0].rows() + 2 * padding + 1;const int result_cols = xs[0].cols() - ys[0].cols() + 2 * padding + 1;Matrix result = Matrix::Zero(result_rows, result_cols);for (int n = 0; n < N; ++n) {const auto &X = xs[n];const auto &Y = ys[n];const auto &T = ts[n];Matrix delta = T - Y;Matrix update = Convolution2D(X, delta, padding);result = result + update;}result *= 2.0/(R * C);return result;
};

现在我们知道了如何获得梯度,让我们来实现梯度下降算法。

七、编码梯度下降

最后,我们的梯度下降的代码在这里:

auto gradient_descent = [](Matrix &kernel, Dataset &dataset, const double learning_rate, const int MAX_EPOCHS)
{std::vector<double> losses; losses.reserve(MAX_EPOCHS);const int padding = kernel.rows() / 2;const int N = dataset.size();std::vector<Matrix> xs; xs.reserve(N);std::vector<Matrix> ys; ys.reserve(N);std::vector<Matrix> ts; ts.reserve(N);int epoch = 0;while (epoch < MAX_EPOCHS){xs.clear(); ys.clear(); ts.clear();for (auto &instance : dataset) {const auto & X = instance.first;const auto & Y = instance.second;const auto T = Convolution2D(X, kernel, padding);xs.push_back(X);ys.push_back(Y);ts.push_back(T);}losses.push_back(MSE(ys, ts));auto grad = gradient(xs, ys, ts, padding);auto update = grad * learning_rate;kernel -= update;epoch++;}return losses;
};

This is the base code. We can improve it in several ways, for example:

  • using the loss of each instance to update the kernel. This is called Stochastic Gradient Descent (SGD), which is very useful in real-world scenarios;
  • grouping instances in batches and updating the kernel after each batch, which is called Minibatch;
  • 使用学习率时间表来降低各个时期的学习率;
  • 在这一行中,我们可以连接一个优化器,如MomentumRMSPropAdam。 我们将在接下来的故事中讨论优化器;kernel -= update;
  • 引入验证或使用某些交叉验证架构;
  • 通过矢量化替换嵌套循环以获得性能和 CPU 使用率(如上一个故事所述);for(auto &instance: dataset)
  • 添加回调和钩子以更轻松地自定义我们的训练循环。

我们可以暂时忘记这些改进。现在,重点是了解如何使用梯度来更新参数(在我们的例子中是内核)。这是当今机器学习的基本、核心概念,也是推进更高级主题的关键因素。

让我们通过说明性实验将其付诸行动,看看这段代码是如何工作的。

八、实际实验:修复索贝尔边缘探测器

        在上一个故事中,我们了解到我们可以应用 Sobel 滤波器 Gx 来检测垂直边缘:

        现在,问题是:给定原始图像和边缘图像,我们是否设法恢复了 Sobel 滤镜 Gx

换句话说,我们可以在给定输入 X 和预期输出 Y 的情况下拟合内核吗?

答案是肯定的,我们将使用梯度下降来做到这一点。

九、加载和准备数据

        首先,我们使用OpenCV从文件夹中读取一些图像。我们对它们应用 Gx 过滤器,并将它们成对存储在我们的数据集对象中:

auto load_dataset = [](std::string data_folder, const int padding) {Dataset dataset;std::vector<std::string> files;for (const auto & entry : fs::directory_iterator(data_folder)) {Mat image = cv::imread(data_folder + entry.path().c_str(), cv::IMREAD_GRAYSCALE);Mat formatted_image = resize_image(image, 640, 640);Matrix X;cv::cv2eigen(formatted_image, X);X /= 255.;auto Y = Convolution2D(X, Sobel.Gx, padding);auto pair = std::make_pair(X, Y);dataset.push_back(pair);}return dataset;
};auto dataset = load_dataset("../images/");

我们使用辅助实用程序 .resize_image 格式化每个输入图像以适合 640x640 网格

        如上图所示,将每个图像集中到黑色 640x640 网格中,而无需通过简单地调整图像大小来拉伸图像。resize_image

        我们使用 Gx 过滤器为每个图像生成真实输出 Y。现在,我们可以忘记这个过滤器了。我们将使用梯度下降和 2D 卷积从数据中恢复它。

十、运行实验       

通过连接所有部分,我们最终可以看到训练执行情况:

int main() {const int padding = 1;auto dataset = load_dataset("../images/", padding);const int MAX_EPOCHS = 1000;const double learning_rate = 0.1;auto history = gradient_descent(kernel, dataset, learning_rate, MAX_EPOCHS);std::cout << "Original kernel is:\n\n" << std::fixed << std::setprecision(2) << Sobel.Gx << "\n\n";std::cout << "Trained kernel is:\n\n" << std::fixed << std::setprecision(2) << kernel << "\n\n";plot_performance(history);return 0;
}

The following sequence illustrates the fitting process:

一开始,内核充满了随机数。因此,在第一个纪元中,输出图像通常是黑色输出。

然而,在几个纪元之后,梯度下降开始使核拟合到全局最小值。

最后,在最后一个纪元中,输出几乎等于基本事实。此时,损失值渐近移动到最低值。让我们检查一下各时期的损失表现:

训练表现

在机器学习中,这种损失曲线形状非常常见。事实证明,在第一个纪元中,参数基本上是随机值。这会导致初始损失很高:

成本面上的算法搜索表示

在最后一个时期,梯度下降终于完成了它的工作,将核拟合到合适的值,这使得损失收敛到最小值。

现在,我们可以将学习到的内核与原始 Gx Sobel 的过滤器进行比较:

正如我们所料,学习内核和原始内核非常接近。请注意,如果我们在更多的时期训练内核(并使用较小的学习率),这种差异仍然可以更小。

用于训练此内核的代码可以在此存储库中找到。

十一、关于差异化和autodiff

        在这个故事中,我们使用常见的微积分规则来查找MSE偏导数。然而,在某些情况下,为给定的复数成本函数找到代数导数可能具有挑战性。幸运的是,现代机器学习框架提供了一个神奇的功能,称为自动微分或简称。autodiff

   autodiff跟踪每个基本算术运算(如加法或乘法),将链式规则应用于它们以找到偏导数。因此,在使用时,我们不需要计算偏导数的代数公式,甚至不需要直接实现它们。autodiff

        由于这里我们使用的是简单的、众所周知的成本公式,因此不需要手动使用甚至解决复杂的微分。autodiff

更详细地涵盖导数、偏导数和自动微分值得一个新的故事!

十二、结论 

        在这个故事中,我们学习了如何使用梯度来拟合数据中的内核。我们介绍了梯度下降,它简单、强大,是推导出更复杂的算法(如反向传播)的基础。我们还使用梯度下降法进行了一项实际实验,从数据中恢复了Sobel滤波器。

参考书

机器学习,米切尔

Cálculo 3, Geraldo Ávila(巴西葡萄牙语)

神经网络:综合基础,Haykin

模式分类,杜达

计算机视觉:算法和应用,Szeliski。

Python machine learning, Raschka

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.rhkb.cn/news/80610.html

如若内容造成侵权/违法违规/事实不符,请联系长河编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

基于Byzer-LLM和ChatGLM-6B快速搭建一款免费的语言大模型助力电商企业

假设有一家电商企业&#xff0c;员工大概20-30人&#xff0c;企业是在淘宝等电商平台买衣服&#xff0c;目前在淘宝上已经上架十万种服饰, 之前淘宝限制服饰的标题描述字数&#xff0c;所以写的特别精简。以该公司售卖的阔腿裤为例&#xff0c;目前标题都是这样的&#xff1a; …

云道资本:2023中国氢能源产业-氢制备深度研究报告(附下载)

关于报告的所有内容&#xff0c;公众【营销人星球】获取下载查看 核心观点 中国可再生能源消纳能力提升远远滞后于发电占比的提升。大规模的可再生能源发电是实现碳中和的关键一步&#xff0c;但风电、光伏发电间歌性、波动性强&#xff0c;电网消纳压力较大&#xff0c;且电…

十三、高光谱图像基础

1、各种图像 1.1 高光谱图像 高光谱成像技术的原理基于物体的光谱吸收和反射特性。当光线通过或反射于物体表面时,被物体吸收或反射的光波将发生变化。高光谱成像系统通过对各个波段的频谱进行连续测量,可以获取到物体在不同波段下的光谱信息。通过分析这些光谱数据,我们…

/lib/x86_64-linux-gnu/libc.so.6: version `GLIBC_2.28‘ not found

某项目中&#xff0c;我要给别人封装一个深度学习算法的SDK接口&#xff0c;运行在RK3588平台上&#xff0c;然后客户给我的交叉编译工具链是 然后我用他们给我的交叉编译工具链报下面的错误&#xff1a; aarch64-buildroot-linux-gnu-gcc --version /data/chw/aarch64/bin/cca…

Verilog求log10和log2近似

Verilog求log10和log2近似 Verilog求10对数近似方法&#xff0c;整数部分用位置index代替&#xff0c;小数部分用查找表实现 参考&#xff1a; Verilog写一个对数计算模块Log2(x) FPGA实现对数log2和10*log10

ArcGIS在洪水灾害普查、风险评估及淹没制图中应用教程

详情点击链接&#xff1a;ArcGIS在洪水灾害普查、风险评估及淹没制图中应用教程 一&#xff1a;洪水普查技术规范 1.1 全国水旱灾害风险普查实施方案 1.2 洪水风险区划及防治区划编制技术要求 1.3 山丘区中小河流洪水淹没图编制技术要求 二&#xff1a;ArcGIS及数据管理 …

python爬虫之scrapy框架介绍

一、Scrapy框架简介 Scrapy 是一个开源的 Python 库和框架&#xff0c;用于从网站上提取数据。它为自从网站爬取数据而设计&#xff0c;也可以用于数据挖掘和信息处理。Scrapy 可以从互联网上自动爬取数据&#xff0c;并将其存储在本地或在 Internet 上进行处理。Scrapy 的目标…

Demystifying Prompts in Language Models via Perplexity Estimation

Demystifying Prompts in Language Models via Perplexity Estimation 原文链接 Gonen H, Iyer S, Blevins T, et al. Demystifying prompts in language models via perplexity estimation[J]. arXiv preprint arXiv:2212.04037, 2022. 简单来说就是作者通过在不同LLM和不同…

HTML5 Canvas和Svg:哪个简单且好用?

HTML5 Canvas 和 SVG 都是基于标准的 HTML5 技术&#xff0c;可用于创建令人惊叹的图形和视觉体验。 首先&#xff0c;让我们花几句话介绍HTML5 Canvas和SVG。 什么是Canvas? Canvas&#xff08;通过 标签使用&#xff09;是一个 HTML 元素&#xff0c;用于在用户计算机屏幕…

基于EIoT能源物联网的工厂智能照明系统应用改造-安科瑞黄安南

【摘要】&#xff1a;随着物联网技术的发展&#xff0c;许多场所针对照明合理应用物联网照明系统&#xff0c;照明作为工厂的重要能耗之一&#xff0c;工厂的照明智能化控制&#xff0c;如何优化控制、提高能源的利用率&#xff0c;达到节约能源的目的。将互联网的技术应用到工…

如何离线安装ModHeader - Modify HTTP headers Chrome插件?

如何离线安装ModHeader - Modify HTTP headers Chrome插件&#xff1f; 1.1 前言1.2 打开Chrome浏览器的开发者模式1.3 下载并解压打包好的插件1.4 解压下载好的压缩包1.5 加载插件1.6 如何使用插件? 1.1 前言 ModHeader 是一个非常好用的Chrome浏览器插件&#xff0c;可以用…

并发——线程与进程的关系,区别及优缺点?

文章目录 1. 图解进程和线程的关系2.程序计数器为什么是私有的?3. 虚拟机栈和本地方法栈为什么是私有的?4. 一句话简单了解堆和方法区5. 说说并发与并行的区别? 从 JVM 角度说进程和线程之间的关系 1. 图解进程和线程的关系 下图是 Java 内存区域&#xff0c;通过下图我们…

Uniapp基于微信小程序以及web端文件、图片下载,带在线文件测试地址

一、效果 传送门 二、UI视图 <scroll-view scroll-x="true" scroll-y="true" :style

python数据分析报告 范文,python数据分析报告+代码

大家好&#xff0c;本文将围绕python数据分析期末大作业报告展开说明&#xff0c;python数据分析期末大作业是一个很多人都想弄明白的事情&#xff0c;想搞清楚python数据分析报告怎么写需要先了解以下几个事情。 背景 虽然用Python开发爬虫脚本&#xff0c;顺利把某房产网站的…

国产水声功率放大器ATA-L50在水下通信领域中的应用

水下通信是指在水下环境中进行信息交流和传递的技术。由于水下环境的特殊性&#xff0c;水下通信面临着诸多挑战&#xff0c;如水压、水体的吸收和散射等。然而&#xff0c;随着科技的发展&#xff0c;水下通信技术已经取得了长足的进步&#xff0c;并广泛应用于海洋资源开发、…

鉴源实验室丨汽车网络安全攻击实例解析(二)

作者 | 田铮 上海控安可信软件创新研究院项目经理 来源 | 鉴源实验室 社群 | 添加微信号“TICPShanghai”加入“上海控安51fusa安全社区” 引言&#xff1a;汽车信息安全事件频发使得汽车行业安全态势愈发紧张。这些汽车网络安全攻击事件&#xff0c;轻则给企业产品发布及产品…

Mageia 9 RC1 正式发布,Mandriva Linux 发行版的社区分支

导读Mageia 9 首个 RC 已发布。公告写道&#xff0c;自 2023 年 5 月发布 beta 2 以来&#xff0c;Mageia 团队一直致力于解决许多顽固问题并提供安全修复和新特性。 新版本的控制中心添加了用于删除旧内核的新功能&#xff0c;该功能在 Mageia 9 中默认自动启用&#xff0c;用…

大麦订单生成器 大麦一键生成订单

后台一键生成链接&#xff0c;独立后台管理 教程&#xff1a;修改数据库config/Conn.php 不会可以看源码里有教程 下载源码程序&#xff1a;https://pan.baidu.com/s/16lN3gvRIZm7pqhvVMYYecQ?pwd6zw3

【小沐学NLP】在线AI绘画网站(百度:文心一格)

文章目录 1、简介2、文心一格2.1 功能简介2.2 操作步骤2.3 使用费用2.4 若干示例2.4.1 女孩2.4.2 昙花2.4.3 山水画2.4.4 夜晚2.4.5 古诗2.4.6 二次元2.4.7 帅哥 结语 1、简介 当下&#xff0c;越来越多AI领域前沿技术争相落地&#xff0c;逐步释放出极大的产业价值&#xff0…