生成对抗网络

目录

0. Abstract

1. Introduction

2. Relatedwork

3.Experiments

4.Advantages and disadvantages

5.Conclusions and future work(idea)

6. 网络训练源代码


0. Abstract

我们提出了一个新的框架,通过一个对抗的过程来估计生成模型,在此过程中我们同时训练两个模型:一个生成模型G捕获数据分布,和一种判别模型D,它估计样本来自训练数据而不是G的概率。G的训练程序是最大化D犯错的概率,这个框架对应于一个极小极大的双人游戏。在任意函数G和D的空间中,存在唯一解,G可以重现训练数据分布,D处处等于1/2。在G和D由多层感知器定义的情况下,整个系统可以通过反向传播进行训练。在训练或生成样本的过程中,不需要任何马尔科夫链或展开的近似推理网络。通过对生成的样本进行定性和定量评估,实验证明了该框架的潜力。

1. Introduction

深度学习的前景是发现丰富的分层模型,它代表人工智能应用中遇到的各种数据的概率分布,如自然图像、包含语音的音频波形和自然语言语料库中的符号。到目前为止,在深度学习中最显著的成功涉及到判别模型,通常是那些将高维、丰富的感官输入映射到类标签的模型。这些惊人的成功主要是基于反向传播和dropout算法,使用分段线性单元,具有特别良好的梯度。由于在极大似然估计和相关策略中出现的许多难以处理的概率计算的近似性,以及由于难以在生成环境中利用分段线性单元的优点,深度生成模型的影响较小。我们提出了一种新的生成模型估计方法来克服这些困难。

在提出的对抗网框架中,生成模型与对手进行了比较:一个学习确定样本是来自模型分布还是来自数据分布的判别模型。生成模型可以被认为类似于一组伪造者,他们试图制造假币并在不被发现的情况下使用它,而判别模型则类似于警察,试图发现假币,这个游戏的竞争促使两队改进他们的方法,直到仿冒品无法从真品中辨别出来。

该框架可以生成针对多种模型的特定训练算法和优化算法,在这篇文章中,我们探讨了生成模型通过一个多层感知器传递随机噪声来生成样本的特殊情况,而判别模型也是一个多层感知器,我们把这种特殊情况称为对抗网络。在这种情况下,我们可以只使用非常成功的反向传播和dropout算法来训练这两个模型,并且只使用正向传播来训练生成模型的样本,不需要近似推论或马尔科夫链。

2. Relatedwork

有潜在变量的有向图形模型的另一种选择是有潜在变量的无向图形模型,如限制玻尔兹曼机(RBMs),深玻尔兹曼机(DBMs)及其众多变体。这些模型中的相互作用被表示为未归一化势函数的乘积,由随机变量所有状态的全局求和/积分进行归一化。这个数量(配分函数)和它的梯度是棘手的,但最琐碎的情况下,虽然他们可以由马尔可夫链蒙特卡罗(MCMC)方法估计。对于依赖于MCMC的学习算法来说,混合是一个很重要的问题。

深度置信网络(DBNs)[16]是包含一个无向层和多个有向层的混合模型。虽然存在一种快速的分层近似训练准则,但DBNs存在与无向和有向模型相关的计算困难。

也有人提出了不近似或不限制对数似然的替代标准,如分数匹配和噪声对比估计(NCE),这两种方法都要求所学习的概率密度被解析指定为一个归一化常数。请注意,在许多具有多层潜在变量(如DBNs和DBMs)的有趣生成模型中,甚至不可能导出可处理的非规范化概率密度,一些模型,如去噪自动编码器[30]和收缩自动编码器的学习规则非常类似于分数匹配应用于RBMs。在NCE中,与本文一样,使用了判别训练准则来拟合生成模型。然而,生成模型本身用于从固定噪声分布的样本中区分生成的数据,而不是拟合一个单独的判别模型。由于NCE使用一个固定的噪声分布,当模型学习到即使是在观察变量的一个小子集上的一个近似正确的分布之后,学习速度也会显著减慢。

最后,一些技术不涉及明确定义概率分布,而是训练生成机器从期望的分布中抽取样本,这种方法的优点是可以通过反向传播来训练这些机器。近期主要的工作包括生成随机网络(GSN)框架:它扩展了广义去噪自动编码器:两者都可以看作是定义一个参数化的马尔科夫链,即一个人学习机器的参数,执行一个步骤的生成马尔科夫链。与GSNs相比,对抗网的采样不需要马尔科夫链,由于反求网络在生成过程中不需要反馈环,所以它们能够更好地利用分段线性单元,这提高了反向传播的性能,但在使用反馈环时存在无限制激活的问题。通过反向传播训练生成机器的最新例子包括自动编码变分贝叶斯和随机反向传播。

当模型都是多层感知器时,对抗性建模框架最容易应用。为了学习生成器在数据x上的分布pg,我们定义了一个输入噪声变量pz (z), G (z;θg)表示将噪声变量映射到数据空间, G是一个可微函数,表示为一个参数为θg的多层感知器。我们还定义一个多层感知器D (x;θd)输出一个标量,D(x)表示x来自数据集而不是pg的概率。我们训练D最大限度地将正确的标签分配给训练样本和来自G的样本的概率,我们同时训练G,使得 log(1 - D(G(z))) 最小化。

换句话说,D和G玩了一个具有值函数V (G,D)的二人极大极小博弈:

在下一节中,我们将对对抗网进行理论分析,主要说明当G和D具有足够的容量时,训练准则允许恢复数据生成分布,例如在非参数极限下。请参见图1,其中对该方法进行了不太正式的、更具教育性的解释。在实践中,我们必须使用迭代的数值方法来实现游戏。优化完成内环的训练在计算上是禁止的,对于有限的数据集会导致过度拟合。相反,我们在优化D的k个步骤和优化G的一个步骤之间交替进行,只要G变化足够慢,D就会保持在其最优解附近,这种策略类似于SML/PCD:训练从一个学习步骤到下一个学习步骤保持来自马尔可夫链的样本,该过程在算法1中正式给出。

在实际应用中,公式1可能无法为G提供足够的梯度来学习。在学习的早期,当G较差时,D可以很有信心地拒绝样本,因为它们与训练数据明显不同。在这种情况下,log(1 - D(G(z)))饱和,与其训练G去最小化log(1 - D(G(z))不如训练G去最大化logD(G(z))这一目标函数的结果与动态函数相同,但在学习中提供了更强的学习效果。

注:图中的黑色虚线表示真实的样本的分布情况,蓝色虚线表示判别器判别概率的分布情况,绿色实线表示生成样本的分布。Z表示噪声,Z到x表示通过生成器之后的分布的映射情况。
我们的目标是使用生成样本分布(绿色实线)去拟合真实的样本分布(黑色虚线),来达到生成以假乱真样本的目的。
可以看到在(a)状态处于最初始的状态的时候,生成器生成的分布和真实分布区别较大,并且判别器判别出样本的概率不是很稳定,因此会先训练判别器来更好地分辨样本。
通过多次训练判别器来达到(b)样本状态,此时判别样本区分得非常显著和良好。然后再对生成器进行训练。
训练生成器之后达到(c)样本状态,此时生成器分布相比之前,逼近了真实样本分布。
经过多次反复训练迭代之后,最终希望能够达到(d)状态,生成样本分布拟合于真实样本分布,并且判别器分辨不出样本是生成的还是真实的(判别概率均为0.5)。也就是说我们这个时候就可以生成出非常真实的样本啦,目的达到。[2]

3.Experiments

包括MNIST, theTorontoFace Database (TFD),和CIFAR-10一系列数据集上训练了对抗网络。生成网络使用rectifier linear and sigmoid两种激活函数,而判别器使用maxout激活。应用dropout训练判别器网络。虽然我们的理论框架允许在生成器的中间层使用dropout和其他噪声,但我们只使用噪声作为生成器网络最底层的输入。

4.Advantages and disadvantages

与以前的建模框架相比,这个新框架有优点也有缺点。缺点主要是没有显式表示的pg (x),在训练时D必须与G同步。它的优点是不需要使用马尔科夫链,只使用backprop来获得梯度,在学习过程中不需要推理,可以将多种函数合并到模型中。

5.Conclusions and future work(idea)

  1. 将c作为G和D的输入,可以得到条件生成模型p(x | c)。
  2. 学习近似推理:可以利用一个辅助网络在给定x时来预测z。这与wake-sleep算法训练的推理网络类似,但具有在生成器网络完成训练后,可以对固定生成器网络进行推理网络训练的优点。
  3. 通过训练一系列共享参数的条件模型,可以近似地对所有条件p(xS | x)进行建模,其中s是x指标的子集。本质上,我们可以使用对抗网来实现确定性MP-DBM[11]的随机扩展。
  4. 半监督学习:当有限的标记数据可用时,鉴别器或推理器的特性可能会降低分类器的性能。
  5. 效率改进:在培训过程中,通过划分更好的方法来协调G和D,或者确定更好的z分布,可以大大加快训练的速度。

6. 网络训练源代码

import torch.nn as nn
from torchvision import transforms
import torch
import torch.optim as op
from torchvision import datasets
from torch.utils.data import DataLoaderbatch_size = 64
transform = transforms.Compose([transforms.ToTensor(), transforms.Normalize((0.1307,), (0.3081,))])dataset = datasets.MNIST(root='../dataset/mnist/', train=True, download=True, transform=transform)
data_loader = DataLoader(dataset, shuffle=True, batch_size=batch_size)
# test_dataset = datasets.MNIST(root='../dataset/mnist/', train=False, download=True, transform=transform)
# test_loader = DataLoader(test_dataset, shuffle=False, batch_size=batch_size)
"生成器的输入是一组噪声"
class Generator(nn.Module):def __init__(self, in_features=64, out_features=784):""":param in_features: 生成器的in_features,一般输入z的维度z_dim,该值可自定义:param out_features: 生成器的out_features,需要与真实数据的维度一致"""super().__init__()"nn.Tanh() #用于归一化数据"self.gen = nn.Sequential(nn.Linear(in_features, 256),nn.LeakyReLU(0.1),nn.Linear(256, out_features),nn.Tanh())def forward(self, z):gz = self.gen(z)return gz"判别器"
class Discriminator(nn.Module):def __init__(self, in_features=784):""":param in_features: 真实数据的维度、同时也是生成的假数据的"""super().__init__()"使用非饱和激活函数nn.LeakyReLU(0.1),防止梯度下降""nn.Tanh() 是双曲正切函数,通常用于确保生成的输出处于特定的值范围内,例如在 -1 到 1 之间"self.disc = nn.Sequential(nn.Linear(in_features, 128),nn.LeakyReLU(0.1),nn.Linear(128, 1),nn.Sigmoid())def forward(self, data):""":param data: 输入的data可以是真实数据时,Disc输出dx。输入的data是gz时,Disc输出dgz:return:"""return self.disc(data)    # 输出结果为置信度z_dim = 64
real_data_dim = 784
lr = 0.1
"判断是否有GPU存在"
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
"实例化判别器与生成器"
gen = Generator(in_features=z_dim, out_features=real_data_dim)
gen.to(device)
disc = Discriminator(in_features=real_data_dim).to(device)
disc.to(device)
"定义判别器与生成器所使用的优化算法"
op_disc = op.Adam(disc.parameters(), lr=lr, betas=(0.9, 0.999))
op_gen = op.Adam(gen.parameters(), lr=lr, betas=(0.9, 0.999))
"定义损失函数"
criterion = nn.BCELoss(reduction="mean")
if __name__ == "__main__":for epoch in range(10):for batch_idx, (x, _) in enumerate(data_loader):x = x.view(-1, 784).to(device)batch_size = x.shape[0]# 判别器反向传播=========================================================================="------------------------判别器对真实数据的预测概率------------------------"dx = disc(x).view(-1)"所有真实数据的损失均值"loss_real = criterion(dx, torch.ones_like(dx))loss_real.backward()"判别器对真实数据的预测概率 dx 的平均值,然后使用 .item() 方法将其转换为标量值,并将结果存储在 D_x 变量中"D_x = dx.mean().item()"------------------------判别器对生成数据的预测概率------------------------"noise = torch.randn((batch_size, z_dim)).to(device)"将随机噪声 noise 通过生成器模型 gen 生成假数据 gz,这些假数据模拟真实数据的特征"gz = gen(noise)"使用 gz.detach() 是为了阻止生成数据进入判别器的计算图,以确保在这里只计算判别器对生成数据的预测概率"dgz1 = disc(gz.detach())"所有生成数据的损失均值,在训练生成对抗网络(GAN)的判别器时,对于生成的数据,我们希望判别器的输出接近零,表示生成数据被正确分类为假数据。因此,我们将目标设置为与生成数据对应的标签,通常是零"loss_fake = criterion(dgz1, torch.zeros_like(dgz1))loss_fake.backward()"判别器对生成数据的预测概率 dx 的平均值,然后使用 .item() 方法将其转换为标量值,并将结果存储在 D_G_Z1 变量中"D_G_z1 = dgz1.mean().item()"判别器对真实数据的损失和对生成数据的损失之和。这个总损失通常用于衡量判别器的性能"errorD = loss_real + loss_fake"errorD.backward() #直接对errorD反向传播,也可分别对loss_real,loss_fake执行反向传播""更新判别器上的权重"op_disc.step()"清零判别器迭代后的梯度"disc.zero_grad()# 生成器反向传播*=========================================================================="注意,由于在此时判别器上的权重已经被更新过了,所以dgz的值会变化,需要重新生成""得到判别器对生成数据的输出 dgz2"dgz2 = disc(gz)"计算了生成器的损失。与判别器的损失不同,这里我们希望生成器生成的数据被判别器识别为真实数据,所以我们使用目标值为1的损失函数来计算生成器的损失"Gloss = criterion(dgz2, torch.ones_like(dgz2))"反向传播"Gloss.backward()"更新生成器上的权重"op_gen.step()"清零生成器更新后梯度"gen.zero_grad()D_G_z2 = dgz2.mean().item()# print(f"第{ epoch+1 }次训练")

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.rhkb.cn/news/168242.html

如若内容造成侵权/违法违规/事实不符,请联系长河编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

每日一题 2678. 老人的数目(简单)

简单题,不多说 class Solution:def countSeniors(self, details: List[str]) -> int:ans 0for l in details:if int(l[11:13]) > 60:ans 1return ans

数据结构与算法-(10)---列表(List)

🌈write in front🌈 🧸大家好,我是Aileen🧸.希望你看完之后,能对你有所帮助,不足请指正!共同学习交流. 🆔本文由Aileen_0v0🧸 原创 CSDN首发🐒 如…

WebService SOAP1.1 SOAP1.12 HTTP PSOT方式调用

Visual Studio 2022 新建WebService项目 创建之后启动运行 设置默认文档即可 经过上面的创建WebService已经创建完成,添加HelloWorld3方法, [WebMethod] public string HelloWorld3(int a, string b) { //var s a b; return $"Hello World ab{a …

java基础面试题

java后端面试题大全 1.java基础1.1 java中和equals的区别1.2 String、StringBuffer、StringBuilder的区别1.3 intern方法的作用及原理1.4 String不可变的含义1.5 static用法、使用位置、实例1.6 为什么静态方法不能调用非静态方法和变量1.7 异常/Exception1.7 try/catch/finall…

请求转发和响应重定向

请求转发与响应重定向是什么? 请求转发和响应重定向是两种在HTTP协议中常见的操作,用于在服务器和客户端之间传递数据。 请求转发(RequestDispatcher)是服务器收到请求后,从一个资源跳转到另一个资源的操作。这种操作…

QCC 音频输入输出

QCC 音频输入输出 QCC蓝牙芯片(QCC3040 QCC3083 QCC3084 QCC5181 等等)支持DAC、I2S、SPDIF输出,AUX、I2S、SPDIF、A2DP 输入 蓝牙音频输入,模拟输出是最常见的方式。 也可以再此基础上动态切换输入方式。 输入方式切换参考 sta…

SOLIDWORKS 2024新功能 3D CAD三维机械设计10大新功能

SOLIDWORKS 2024新增功能 - 3D CAD三维机械设计 10大新增功能 1. 先前版本的兼容性 •利用您订阅的 SOLIDWORKS,可将您的 SOLIDWORKS 设计作品保存为旧版本,与使用旧版本 SOLIDWORKS 的供应商无缝协作。 •可将零件、装配体和工程图保存为最新版本…

redis 宕机恢复

1.集群现在状态 6个进程 主从分配如下 2. 关闭其中一个主节点 可以看到从节点转换成了主节点,7002主节点处在失败状态: 3.重新启动失败节点 可以看到启动后成为从节点: 另外,如果主节点宕机,从节点转换为主节点…

【1.总纲】

目录 知识框架No.0 总纲安排No.1课程安排一、目标二、内容三、 学到 No.2 深度学习介绍一、AI地图二、图片分类三、物体检测和分割四、样式迁移五、人脸合成六、文字生成图片七、文字生成-GPT八、无人驾驶九、广告点击 No.3 安装No.3 安装 知识框架 No.0 总纲安排 B站网址&…

中文编程开发语言编程实际案例:程序控制灯电路以及桌球台球室用这个程序计时计费

中文编程开发语言编程实际案例:程序控制灯电路以及桌球台球室用这个程序计时计费 上图为:程序控制的硬件设备电路图 上图为:程序控制灯的开关软件截图,适用范围比如:台球厅桌球室的计时计费管理,计时的时候…

Shell动态条进度

代码: #!/bin/bashfunction dongtai(){ i0 bar index0 arr( "|" "/" "-" "\\" )while [ $i -le 100 ] dolet indexindex%4printf "[]准备开始:[%-100s][%d%%][\e[43;46;1m%c\e[0m]\r" "$bar" "$…

深度学习---卷积神经网络

卷积神经网络概述 卷积神经网络是深度学习在计算机视觉领域的突破性成果。在计算机视觉领域。往往输入的图像都很大,使用全连接网络的话,计算的代价较高。另外图像也很难保留原有的特征,导致图像处理的准确率不高。 卷积神经网络&#xff0…

嵌入式linux系统设备树实例分析

前言 我们可以从LED程序中榨取很多知识:基本的驱动框架、驱动的简单分层、驱动的分层分离思想、总线设备驱动模型、设备树等。这大多都是结合韦老师的教程学的。 这篇笔记结合第6个demo(基于设备树)来学习、分析: 框图 下面是L…

JMeter添加插件

一、前言 ​ 在我们的工作中,我们可以利用一些插件来帮助我们更好的进行性能测试。今天我们来介绍下Jmeter怎么添加插件? 二、插件管理器 ​ 首先我们需要下载插件管理器jar包 下载地址:Install :: JMeter-Plugins.org 然后我们将下载下来…

《红蓝攻防对抗实战》一. 隧道穿透技术详解

一.隧道穿透技术详解 从技术层面来讲,隧道是一种通过互联网的基础设施在网络之间传递数据的方式,其中包括数据封装、传输和解包在内的全过程,使用隧道传递的数据(或负载)可以使用不同协议的数据帧或包。 假设我们获取到一台内网主机的权限,…

如何制作.exe免安装绿色单文件程序,将源代码打包成可独立运行的exe文件

环境: rustdesk编译文件和文件夹 文件程序制作工具 问题描述: 如何制作.exe免安装绿色单文件程序,将源代码打包成可独立运行的exe文件,像官网那种呢? 将下面编译好的rustdesk文件夹制作成一个.exe免安装绿色单文件程序,点击exe就可以运行 在github上找了半天也没有…

AIGC笔记--基于DDPM实现图片生成

目录 1--扩散模型 2--训练过程 3--损失函数 4--生成过程 5--参考 1--扩散模型 完整代码:ljf69/DDPM 扩散模型包含两个过程,前向扩散过程和反向生成过程。 前向扩散过程对一张图像逐渐添加高斯噪声,直至图像变为随机噪声。 反向生成过程…

推荐微软的开源课程《AI-For-Beginners》

今天给大家推荐一个对新手非常友好的AI入门课程《AI-For-Beginners》。 该课程由微软推出,为期12周,共24课时,对比Google的AI入门课更通俗易懂一些,强烈推荐刚入门的AI小白们学习!而且是免费!课程资源看文…

SQL UPDATE 语句(更新表中的记录)

SQL UPDATE 语句 UPDATE 语句用于更新表中已存在的记录。 还可以使用AND或OR运算符组合多个条件。 SQL UPDATE 语法 具有WHERE子句的UPDATE查询的基本语法如下所示: UPDATE table_name SET column1 value1, column2 value2, ... WHERE conditi…

【第三天】C++类和对象进阶指南:从堆区空间操作到友元的深度掌握

一、new和delete 堆区空间操作 1、new和delete操作基本类型的空间 new与C语言中malloc、delete和C语言中free 作用基本相同 区别: new 不用强制类型转换 new在申请空间的时候可以 初始化空间内容 2、 new申请基本类型的数组 3、new和delete操作类的空间 4、new申请…