VAE中的“变分”什么

写在前面

        VAE(Variational Autoencoder),中文译为变分自编码器。其中AE(Autoencoder)很好理解。那“变分”指的是什么呢?—其实是“变分推断”。变分推断主要用在VAE的损失函数中,那变分推断是什么,VAE的损失函数又是什么呢?下面我就来说一说!

       可以先看一下 这篇文章,介绍了VAE的代码实现。

一、通俗理解损失函数

        这篇文章已经整体介绍了VAE,这里我详细介绍一下VAE的损失函数:

\mathbf{LOSS=-E_{q(z|x)}\left [ \textit{log}p(x|z) \right ]+KL(q(z|x)||p(z))}

        每个变量的说明下面会有介绍,现在我们只关注VAE的损失函数有由两部分组成,第一部分是一个交叉熵,我们称之为“重构项”,其作用是确保训练时输入和输出间的相似性;第二部分是KL散度,叫做“KL散度项”,它其实是一个正则项,主要解决了两个AE模型的痛点,这也是VAE成功并流行的主要原因:

        1.潜在空间的结构化:AE的潜在空间往往是无规则的,这意味着编码器学到的表征可能杂乱无章,不便于后续操作。VAE通过添加KL散度项来惩罚潜在变量分布与预设先验分布(就是p(z),是一个标准高斯分布)之间的偏差,从而迫使潜在空间呈现出一定的结构,使潜在变量的分布更加合理和连贯。说人话就是:VAE可以输入标准高斯分布的采样数据,生成精美的图像。

        2.潜在空间的连续性:KL散度项要求潜在变量 z 的分布  q(z|x) 尽可能接近预设的先验分布  p(z) ,这个先验分布通常选择为标准正态分布。通过这种方式,潜在空间被组织成一个连续、平滑的多维空间,其中每一维上的值都能够自由变动而不产生剧烈变化。这种设计确保了在潜在空间中的小步长移动会导致解码结果的轻微变化,从而实现了连续性。说人话就是:VAE可以通过微调输入的采样数据,一定程度上修改生成图像的属性。这也是造成“抽卡”的原因之一。

        损失函数的这两项可以简单的这么理解,但是它其实是推导出来的,这就说来话长。感兴趣的小伙伴继续往下看。

二、边际似然

1.边际似然的定义

        VAE 是一种生成模型,生成模型的核心任务是计算在给定潜在变量 z 的情况下生成观测数据 x 的概率。我们希望模型能够生成与真实数据分布相似的新数据,这一目标可以通过边际似然 p(x) 来实现。

        其中z就是Latent;x是训练用的图像;p(x)是边际似然,也就是VAE的损失函数

        p(x)可以很好的衡量模型的生成能力。p(x)直接衡量了模型在生成数据方面的整体能力,因为它考虑了所有潜在的隐变量 z 对观测数据 x 的影响。高的p(x)意味着模型可以很好地解释数据,并且在生成新数据时表现出较强的能力。

        具体来说,如果模型的边际似然高,说明模型在所有可能的隐变量 z 下生成观测数据的概率累加起来后非常高,这意味着模型学到了数据的真实分布。

        边际似然 p(x) 表示给定模型情况下生成观测数据 x 的概率,定义为:

p(x)=\int p(x|z)p(z)dz  (1)

        其中,条件概率 p(x∣z):给定潜在变量 z 的情况下,生成观测数据 x 的概率。先验分布 p(z):潜在变量 z 的分布,反映了我们对 z 的先验知识。

2.边际似然的推导

        使用全概率公式,边际似然可以用全概率公式来定义,具体为:

p(x)=\int p(x,z)dz (2)

        这里 p(x,z)是 x 和 z 的联合分布。根据条件概率的定义,联合分布可以表示为:

p(x,z)=p(x|z)p(z) (3)

        因此,我们可以将边际似然表示为:

p(x)=\int p(x|z)p(z)dz (4)

        我们要做的就是最大化p(x),这里多说一句,最大化p(x)的目标是使得模型生成的总体概率分布 p(x) 更接近于真实数据分布。这样,模型生成的新样本就会与训练数据的分布一致。

        直观理解:假设我们在训练一个模型生成手写数字图片。如果真实的数据集中 80% 是“1”,20% 是“2”,那么一个好的生成模型应该能够生成 80% 的“1”和 20% 的“2”。而不是让p(x)趋近于1.

3.边际似然的挑战

        但是计算边际似然通常是一个复杂且困难的任务,原因包括:

        (1)高维积分:在实际的应用中,潜在变量 z 通常是高维的。例如,如果 z 是 100 维的向量,那么积分就需要在 100 维的空间上进行。这种高维积分是非常复杂的,解析解几乎不可能得到。

        (2)分布形式复杂:在生成模型中,条件分布 p(x∣z)和先验分布 p(z) 可能并不是简单的概率分布。例如,p(x∣z) 可能由一个深度神经网络参数化,计算时需要经过非线性激活函数和复杂的网络结构,这会让这个积分无法直接求解。

        (3)数值计算的困难:计算边际似然时,需要对 z 的所有可能值进行积分,也就是计算出在所有潜在表示 z 上,生成数据 x 的所有可能性。现实中,z 的范围非常大,即使是连续的,也可能取值无穷多个,直接求解所有 z 的可能性几乎是不可能的。

        举个例子,假设我们有一个简单的生成模型,其中:p(z) 是标准正态分布N(0,I)。p(x∣z) 是由一个深度神经网络生成的图像。直接计算边际似然意味着我们需要知道所有 z 的取值如何影响 x。如果 z 是 100 维向量,那么在 R^{100} 空间上对 z 进行积分(或采样)需要极大的计算资源。神经网络的非线性使得每个 p(x∣z) 的计算都很复杂,最终让直接计算积分变得不可行。

        为了解决上面的问题,让模型可以正常训练,我们引入变分推断。

三、变分推断

1.变分推断的定义

        变分推断是一种通过引入近似分布来解决无法直接计算复杂积分的问题的方法。在生成模型中,我们的目标是最大化观测数据的边际似然 p(x):

p(x)=\int p(x|z)p(z)dz (5)

        如前所述,这个积分通常很难直接计算,因此我们引入一个 近似后验分布(也叫变分分布,就是训练时模型的输出 q(z∣x),来代替无法直接求解的真实后验 p(z∣x)。变分推断的目标是让 q(z∣x) 尽可能地接近真实的 p(z∣x)。

\mathbf{p(x)=\int p(x|z)p(z)dz=\int p(z|x)\frac{p(x|z)p(z)}{q(z|x)}dz} (6)

        通过这种重写,我们引入了 q(z∣x) 作为一个权重,这样我们可以在期望的形式下进行优化。我们现在有一个可以计算的表达式:

\mathbf{\mathit{log}p(x)=\mathit{log}\int p(z|x)\frac{p(x|z)p(z)}{q(z|x)}dz} (7)

        尽管重写了表达式,计算 p(x)依然困难,因为积分本身依然难解。因此,我们应用 Jensen 不等式(log是凸函数),将对数操作从积分外移到期望内部(这里的期望是由积分转化来的):

\mathbf{\mathit{log}p(x)=\mathit{log}\int p(z|x)\frac{p(x|z)p(z)}{q(z|x)}dz\geq E_{q(z|x)} \left [ log\frac{p(x|z)p(z)}{q(z|x)} \right ] } (8)

        其中,Eq(z∣x)[⋅]表示在 q(z∣x) 分布下对 z 取期望。这一不等式说明,我们得到了一个对数边际似然的下界,即变分下界 (ELBO)。

2.变分下界ELBO

        式子(8)右边的表达式即为变分下界(Evidence Lower Bound,),通常记作 ELBO,至此我们的目标也变成了最大化ELBO,从而间接地最大化边际似然 p(x)。式子(8)可以写成:

\mathbf{ELBO=E_{q(z|x)} \left [ log\frac{p(x|z)p(z)}{q(z|x)} \right ] } (9)

        式子(9)右边可以展开成:

\mathbf{ELBO=E_{q(z|x)}\left [ \textit{log}p(x|z) \right ]+E_{q(z|x)}\left [ \textit{log}p(z) \right ]-E_{q(z|x)}\left [ \textit{log}q(z|x) \right ]} (10)

        因为KL散度公式:

\mathbf{KL(q(z|x)||p(z))=E_{q(z|x)}[log\frac{q(z|x)}{p(z)}]=E_{q(z|x)}[\textit{log}q(z|x)]-E_{q(z|x)}[\textit{log}p(z)]}(11)

        可以看到,式子(10)右边的第二项和第三项可以用KL散度代替:

\mathbf{-KL(q(z|x)||p(z))=E_{q(z|x)}[\textit{log}p(z)]-E_{q(z|x)}[\textit{log}q(z|x)]}(12)

        最终,ELBO 可以写成如下式子,这也是VAE需要优化的损失函数:

\mathbf{ELBO=E_{q(z|x)}\left [ \textit{log}p(x|z) \right ]-KL(q(z|x)||p(z))} (13)

        ELBO 公式展示了两个部分:

        重构项:表示模型生成数据的能力。

        KL 散度项:作为正则化项,控制 q(z∣x) 和 p(z) 之间的差异。最小化这个项有助于使近似后验 q(z∣x) 尽量接近先验 p(z),从而促进模型的泛化能力。p(z)一般被设置成标准高斯分布。

最大化 ELBO 的意义:

        优化目标:最大化 ELBO 实际上是希望在重构能力和潜在分布的正则化之间取得平衡。通过调整这两个部分,可以确保模型既能够良好地重构输入数据,又能够学习到有意义的潜在空间。

        间接最大化边际似然:由于 ELBO 是边际似然的下界,最大化 ELBO 也会使得边际似然 p(x) 的值增加。

        ELBO 在 VAE 中扮演着至关重要的角色,它将生成模型的目标与优化过程结合起来,使得模型能够在重构能力和潜在空间的正则化之间找到最佳平衡。通过最大化 ELBO,VAE 能够学习到有效的潜在表示,从而生成新样本。

四、代码实现中的公式

        这篇文章介绍了VAE的代码实现,其中的损失函数是ELBO的具体实现,我们来看一下,具体是怎么实现的。

        我们的目标是最大化ELBO,相当于最小化其负值,因此 VAE 的损失函数可以表示为:

\mathbf{LOSS=-E_{q(z|x)}\left [ \textit{log}p(x|z) \right ]+KL(q(z|x)||p(z))}   (14)

1.重构项

        交叉熵的定义为:

H(p,q)=-E_p[log\textbf{q}]   (15)

        如果我们将 p(x∣z) 视为模型生成 x 的概率分布(对应代码中的recon_x,即模型的输出),而将真实数据的分布视为 q(x)(对应代码中的x,即GT),则ELBO的第一项可以写成:

\mathbf{E_{q(z|x)}[\textit{log}p(x|z)]=-H(x,q(z|x))}  (16)

        最大化 ELBO 的第一项(重构项)实际上是最小化交叉熵损失,代码如下:

BCE = nn.functional.binary_cross_entropy(recon_x, x, reduction='sum')

2.KL 散度项

        对于高斯分布 q(z|x)=N(\mu ,\sigma ^2)和标准正态分布p(x)=N(0,1),我们可以将 KL散度计算分解为以下几个步骤:

        (1)KL散度 的公式为:

\mathbf{KL\left [ q(z|x)||p(z) \right ]=\int q(z|x)log(\frac{q(z|x)}{p(z)})dz}(17)

        解释一下变量的意义:

        q(z∣x):这是给定输入 x时隐变量 z 的后验分布,通常由编码器生成。

        p(z):这是隐变量 z 的先验分布,通常是标准高斯分布 N(0,1)。

        比率 \mathbf{\frac{q(z|x)}{p(z)}}:这个比率表示后验分布与先验分布的相对关系,反映了后验分布相较于先验分布的“信息量”。

        对数项\mathbf{log(\frac{q(z|x)}{p(z)})}:量化了 q(z∣x) 相较于 p(z) 的信息增益。正值表示后验分布相对于先验分布的增加的信息,而负值则表示信息的损失。

        积分:通过对所有可能的 z进行积分,KL散度 计算了整个后验分布与先验分布之间的差异。

        (2)将q(z|x)=N(\mu ,\sigma ^2)p(z)=N(0,1)带入(17

KL\left [ q(z|x)||p(z) \right ]=\int N(\mu ,\sigma ^2)log(\frac{N(\mu ,\sigma ^2)}{N(0,1)})dz(18)

        (3)高斯分布的公式: 高斯分布的概率密度函数为:

\mathbf{N(z;\mu ,\sigma ^2)=\frac{1}{\sqrt{2\pi \sigma ^2}}exp[-\frac{(z-\mu )^2}{2\sigma ^2}]} (19)

        而标准正态分布为:

\mathbf{N(z;0,1)=\frac{1}{\sqrt{2\pi }}exp(-\frac{z^2}{2})}    (20)

        (4)计算 KL散度: 将这些代入 K散度的公式中,最终可以简化得到:

KL(q(z|x)||p(z))=-\frac{1}{2}(1+log(\sigma ^2)-\mu ^2-\sigma ^2)  (21)

        (5)简化: 进一步简化后,得到:

KL(q(z|x)||p(z))=-0.5(log(\sigma ^2)+1-\mu ^2-\sigma ^2)  (22)

        (6)用对数方差表示: 在实现中,通常使用对数方差 log(\sigma ^2) 来计算,这样可以避免数值稳定性问题,最终得到的 KL散度公式是:

KL(q(z|x)||p(z))=-0.5(1+log(\sigma ^2)-\mu ^2-\sigma ^2)(23)

        KL散度代码实现:在代码实现的时候编码器的输出其实是均值mu和对数方差log_var

KLD = -0.5 * torch.sum(1 + log_var - mu.pow(2) - log_var.exp())

        其中log_var 是对数方差,使用对数方差的形式可以保证数值稳定性、避免负值以及计算便利性,这种做法在许多深度学习模型中都得到了广泛应用,尤其是在处理概率分布时。;mu 是均值;\sigma ^2=exp(log\sigma ^2)

五、总结

        1.VAE中的“变分”指的是“变分推断”;

        2.VAE的损失函数值最大化边际似然;

        3.最大化边际似然几乎做不到,所以使用变分推断来简化计算;

        4.使用变分推断后,训练通过最大化ELBO实现;

        5.ELBO有两项:重构项和KL散度项。重构项的作用是确保训练时输入和输出间的相似性,就是传统的损失函数常用的东西;KL散度项是一个正则项,能确保潜在空间的结构化和连续性。

        VAE就介绍到这,关注不迷路(*^__^*) 

  

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.rhkb.cn/news/458398.html

如若内容造成侵权/违法违规/事实不符,请联系长河编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

C++ | Leetcode C++题解之第514题自由之路

题目&#xff1a; 题解&#xff1a; class Solution { public:int findRotateSteps(string ring, string key) {int n ring.size(), m key.size();vector<int> pos[26];for (int i 0; i < n; i) {pos[ring[i] - a].push_back(i);}vector<vector<int>>…

linux指令笔记

bash命令行讲解 lyt &#xff1a;是用户名 iZbp1i65rwtrfbmjetete2b2Z :这个是主机名 ~ &#xff1a;这个是当前目录 $ &#xff1a;这个是命令行提示符 每个指令都有不同的功能&#xff0c;大部分指令都可以带上选项来实现不同的效果。 一般指令和选项的格式&#xff1a;…

Linux 重启命令全解析:深入理解与应用指南

Linux 重启命令全解析&#xff1a;深入理解与应用指南 在 Linux 系统中&#xff0c;掌握正确的重启命令是确保系统稳定运行和进行必要维护的关键技能。本文将深入解析 Linux 中常见的重启命令&#xff0c;包括功能、用法、适用场景及注意事项。 一、reboot 命令 功能简介 re…

洛谷 P3130 [USACO15DEC] Counting Haybale P

原题链接 题目本质&#xff1a;线段树 感觉我对线段树稍有敏感&#xff0c;线段树一眼就看出来了&#xff0c;思路出来得也快&#xff0c;这道题也并不是很难。 解题思路&#xff1a; 这道题能看出来是线段树就基本成功一半了&#xff0c;区间修改区间查询&#xff0c;就基…

深入探索:深度学习在时间序列预测中的强大应用与实现

引言&#xff1a; 时间序列分析是数据科学和机器学习中一个重要的研究领域&#xff0c;广泛应用于金融市场、天气预报、能源管理、交通预测、健康监控等多个领域。时间序列数据具有顺序相关性&#xff0c;通常展示出时间上较强的依赖性&#xff0c;因此简单的传统回归模型往往…

使用微信免费的内容安全识别接口,UGC场景开发检测违规内容功能

大家好&#xff0c;我是小悟。 内容安全识别主要针对的是有UGC即用户生成内容的功能场景&#xff0c;通过结合内容安全的审核能力&#xff0c;应对文本、图片、音频内容类型下的敏感内容识别、涉黄内容识别、暴恐内容识别、辱骂内容识别等违规问题&#xff0c;可以提高审核效率…

【Docker大揭秘】

Docker 调试一天的血与泪的教训&#xff1a;设备条件&#xff1a;对应的build preparation相应的报错以及修改 作为记录 构建FASTLIO2启动docker获取镜像列出镜像运行containerdocker中实现宿主机与container中的文件互传 调试一天的血与泪的教训&#xff1a; 在DOCKER中跑通F…

ubuntu-开机黑屏问题快速解决方法

开机黑屏一般是由于显卡驱动出现问题导致。 快速解决方法&#xff1a; 通过ubuntu高级选项->recovery模式->resume->按esc即可进入recovery模式&#xff0c;进去后重装显卡驱动&#xff0c;重启即可解决。附加问题&#xff1a;ubuntu的默认显示管理器是gdm3,如果重…

海洋生物图像分割系统:算法改进策略

海洋生物图像分割系统源码&#xff06;数据集分享 [yolov8-seg-C2f-DiverseBranchBlock&#xff06;yolov8-seg-C2f-Faster-EMA等50全套改进创新点发刊_一键训练教程_Web前端展示] 1.研究背景与意义 项目参考ILSVRC ImageNet Large Scale Visual Recognition Challenge 项目…

PHP-FPM 性能配置优化

4 核 8 G 服务器大约可以开启 500 个 PHP-FPM&#xff0c;极限吞吐量在 580 qps &#xff08;Query Per Second 每秒查询数&#xff09;左右。 Nginx php-fpm 是怎么工作的&#xff1f; php-fpm 全称是 PHP FastCGI Process Manager 的简称&#xff0c;从名字可得知&#xff…

第十七周:机器学习

目录 摘要 Abstract 一、MCMC 1、马尔科夫链采样 step1 状态设定 step2 转移矩阵 step3 马尔科夫链的生成 step4 概率分布的估计 2、蒙特卡洛方法 step1 由一个分布产生随机变量 step2 用这些随机变量做实验 3、MCMC算法 4、参考文章 二、flow-based GAN 1、引…

【Linux网络】Linux网络基础入门:初识网络,理解网络协议

&#x1f4dd;个人主页&#x1f339;&#xff1a;Eternity._ ⏩收录专栏⏪&#xff1a;Linux “ 登神长阶 ” &#x1f339;&#x1f339;期待您的关注 &#x1f339;&#x1f339; ❀Linux网络 &#x1f4d2;1. 计算机网络背景发展历程"协议" &#x1f4dc;2. 网络协…

UML外卖系统报告(包含具体需求分析)

1、系统背景 随着互联网技术的快速发展&#xff0c;外卖订餐服务逐渐成为人们生活中的一部分。传统的电话订餐方式面临诸多不便和限制&#xff0c;而基于互联网的外卖订餐系统则提供了更加便捷、快速和高效的订餐服务。这种系统通过将餐厅、顾客和配送人员连接起来&#xff0c…

Sentinel详解

参考博客&#xff1a; SpringCloud Sentinel集成到微服务项目中&#xff08;保姆级教程&#xff09; 什么是Sentinel Sentinel 是面向分布式服务架构的轻量级流量控制产品&#xff0c;主要以流量为切入点&#xff0c;从流量控制、熔断降级、系统负载保护等多个维度来保护服务…

Vue学习记录之二十五 Vue3中Web Componets的使用

一、webcomponets介绍 在Vue 3中使用Web Components可以通过多种方式实现。Web Components是一组允许你创建可重用、封装良好的自定义元素的标准技术。它们包括Custom Elements、Shadow DOM、HTML Templates等。 Vue3 支持原生模式&#xff0c;可以让单个文件的js,css,html以h…

移植rv1106SDK的ipcweb到ubuntu

移植minilogger 在sdk中找到minilogger&#xff0c;复制到任意的文件夹&#xff0c;执行 cmake ./ make make install把minilogger 安装到系统 修改Makefile 在上次那个基础上&#xff0c;修改Makefile #* 这里原来要包含../Makefile.param&#xff0c;但含有sdk的很多参数…

w003基于Springboot的图书个性化推荐系统的设计与实现

&#x1f64a;作者简介&#xff1a;拥有多年开发工作经验&#xff0c;分享技术代码帮助学生学习&#xff0c;独立完成自己的项目或者毕业设计。 代码可以私聊博主获取。&#x1f339;赠送计算机毕业设计600个选题excel文件&#xff0c;帮助大学选题。赠送开题报告模板&#xff…

Mysql(十) --- 用户权限和管理

文章目录 前言1. 应用场景2.用户2.1. 查看用户2.2. 创建用户2.2.1 语法2.2.2. 注意事项 2.2.3.示例2.3. 修改密码2.3.1. 语法2.3.2. 示例 2.4.删除用户2.4.1.语法2.4.2.示例 3. 权限和授权MySQL内置支持的权限列表3.1.给用户授权3.1.1.语法3.1.2. 示例 3.2.回收权限3.2.1.语法3…

Golang Agent 可观测性的全面升级与新特性介绍

作者&#xff1a;张海彬&#xff08;古琦&#xff09; 背景 自 2024 年 6 月 26 日&#xff0c;ARMS 发布了针对 Golang 应用的可观测性监控功能以来&#xff0c;阿里云 ARMS 团队与程序语言与编译器团队一直致力于不断优化和提升该系统的各项功能&#xff0c;旨在为开发者提…

基于SpringBoot的中药材进存销管理系统设计与实现

摘要 中药材进存销管理系统是为了满足中药材生产和销售企业的高效管理需求&#xff0c;涵盖了药材采购、库存管理和销售跟踪等主要功能。本系统采用Spring Boot框架进行开发&#xff0c;结合了前端和数据库设计&#xff0c;构建了一个实用的中药材管理平台&#xff0c;为企业提…