GAN对抗生成网络(二)——算法及Python实现

1 算法步骤

上一篇提到的GAN的最优化问题是G^{*}=\arg\min\limits_{G}\max\limits_{D}V(G,D),本文记录如何求解这一问题。

首先为了表示方便,记\max\limits_{D}V(G,D)=L(G),这里让V(G,D)最大的D=D^{*}可视作常量。

第一步,给定初始的G_{0},使用梯度上升找到 D_0^{*},最大化L(G_0)。关于梯度下降,可以参考笔者另一篇文章《BP神经网络原理-CSDN博客》误差反向传播的部分。

第二步,使用梯度下降法,找到G最佳的参数\theta_{G}.其中\eta为学习率。

\theta_{G}\leftarrow \theta_{G}-\eta\frac{\partial{L(G)}}{\theta_{G}}得到G_{1}

 之后这两步交替进行。

这里的L(G)是有max运算的,可以被微分吗?答案是可以的

引用李宏毅老师的例子,f(x)是有max运算的,相当于分段函数,在求微分的时候,根据当前x落在哪个区域决定微分的形式如何。

2 算法与JS散度的关系

上述算法第一步训练D时本质是增大JS散度,第二步训练G时看起来是减小JS散度,但实际上不完全等同。

如下图所示,左侧表示算法第一步根据G_{0}得到了最优的D_0^{*}。当进行到算法第二步,需要根据D_0^{*}找到一个更小的JS散度,如右图所示,G选择了G_{1}从而使得V(G_1, D)<V(G_0, D)。虽然此时JS散度更小,但是由于G_{0}更换成了G_{1}D_0^{*}将更新参数变成D_1^{*},此时JS散度更大了。只能说不能让G更新得太多,否则不能达到减小JS散度的目标。回到文物造假的例子,造假者收到鉴宝者的反馈后,应该微调技艺,而不是彻底更换技艺,否则只能从头来过。

从快速收敛的角度来说,G应该不能更新太过,但是如果太小也忽略了G更好的形式,可能陷入局部最优。

3 实际训练过程

实际训练时,我们是无法计算出真实数据或生成数据实际的期望的,只能通过抽样近似得到期望。因此实际的做法如下:

3.1 第一步,初始化

初始化生成器G和判别器D

3.2 第二步,固定G,训练D

从分布函数(如高斯分布)中随机抽样出m个样本\left \{ z^1,z^2,...z^m \right \}输入给G,输出m个样本\left \{ \tilde{x}^1,\tilde{x}^2,..\tilde{x}^m \right \}G本质上概率分布转化器——将高斯分布的噪声转变成样本的分布。从真实数据中随机抽样出m个样本\left \{ x^1,x^2,...x^m \right \},将二者输入给D。训练D的参数使其接收x_{1}时打出0分,接收x_{2}时打出1分,即最大化\tilde{V}=\frac{1}{m}\sum_{i=1}^m{log{D(x^i)}}+\frac{1}{m}\sum_{i=1}^m{log{(1-D(\tilde{x}^i))}}

建模成分类或回归问题均可。

使用梯度上升法,\theta_d\leftarrow \theta_d+\eta\nabla\tilde{V}(\theta_d)

实际中需要更新多次,使得V值最大。这一步实际上只找到了一个\max\limits_{D}V(G,D)的下限(lower bound),原因是:(1)训练次数不会非常大,没法训练到收敛;(2)即使能收敛,也可能只是一个局部最优解;(3)推导时假设了D可以是任意的函数,即针对不同的x都给出最高的值,但实际中这个假设不成立。

3.3 第三步,固定D,训练G

从分布函数(如高斯分布)中随机抽样出另外m个样本\left \{ z^1,z^2,...z^m \right \}

更新G的参数\theta_G使得下式最小:

\tilde{V}=\frac{1}{m}\sum_{i=1}^m{log{D(x^i)}}+\frac{1}{m}\sum_{i=1}^m{log{(1-D(G(z^i)))}}

其中第一项与G无关,因此只需要看第二项。

根据上文的讨论,这里一般只训练一次,避免G改变过多,无法收敛。

实践中是将GD合在一起作为一个大的神经网络,前几层是G,后几层是D,中间有一个隐含层是G的输出,就是GAN希望得到的输出。第二步和第三步可分别固定神经网络中的某几层参数不动,训练其它层参数。

4 Python实现

关于GAN的代码,参考了https://github.com/junqiangchen/GAN。项目可以产生数字图片和人脸图片,其中人脸图片的生成使用了GAN的变种——WGAN,之后会专门讨论,本文讨论最原始的GAN模型。

4.1 使用新版tensorflow需要修改的地方

原始的代码直接运行是不通的,需要做一些调整;原始代码采用的是旧版Tensorflow(V1),如果安装了新版TensorFlow(V2)也需要做调整;有些包如果安装的新版同样不支持部分API,需要替换。具体如下表所示

问题调整方法备注
部分文件路径不对

调整路径,例如

from GAN.face_model import WGAN_GPModel

调整为

from GAN.genface.face_model import WGAN_GPModel
其他几处不再赘述
imresize报错

 例如

from scipy.misc import imresize

调整为

from skimage.transform import resize
最新版本scipy不支持此函数,将
imresize(test_image, (init_width * scale_factor, init_height * scale_factor))

替换为

resize(
Image.fromarray(test_image).resize(init_width * scale_factor, init_height * scale_factor))
imsave
报错

例如

scipy.misc.imsave(path, merge_img)

调整为

import cv2cv2.imwrite(path, merge_img * 255)

最新版本scipy不支持此函数,替换成cv2。个人认为最后应该乘255,因为原始数据是0~1的数据,直接存会存成几乎黑白的图片,需要还原
使用新版tensorflow的问题
import tensorflow as tf

替换为

import tensorflow.compat.v1 as tftf.compat.v1.disable_eager_execution()
新版tensorflow提供了向下兼容的compat.v1的使用方式,统一替换即可。同时要取消eager_execution模式,新版默认是“即时计算”模式,如果兼容旧版则应取消该模式。

4.2 GAN的代码解析

代码位置:gan/GAN/genmnist/mnist_model.py, class名为GANModel

4.2.1 Generator

定义在_GAN_generator函数中,总结为以下要点:

(1)含有五层网络,除最后一层,其他层在进入下一层之前都用batch_normalization归一化+relu激活函数

g4 = tf.contrib.layers.batch_norm(g4, epsilon=1e-5, is_training=self.phase, scope='bn4')
g4 = tf.nn.relu(g4)

(2)每一层都定义w和b,使用truncated_normal,即截断异常值的正态分布

tf.truncated_normal_initializer

(3)第1~2层使用全连接层,即使用w乘输入,并加上b偏置

tf.matmul(g1, g_w2) + g_b2

(4)第3~4层使用反卷积运算。是卷积运算的逆过程,关于反卷积的介绍笔者正在整理

tf.nn.conv2d_transpose(x, W, output_shape, strides=[1, strides, strides, 1], padding='SAME')

(5)第5层使用卷积运算,并使用tanh激活函数

g5 = convolution_2d(g4, g_w5)

4.2.2 Discriminator

与Generator类似,简述如下:

(1)共4层,其中1、2层使用卷积,3、4层使用全连接

(2)卷积后使用平均池化

d1 = average_pool_2x2(d1)

(3)最后一层使用sigmoid将输出控制在0~1之间

out = tf.nn.sigmoid(out_logit)

4.2.3 损失函数

Generator的损失函数为

-tf.reduce_mean(tf.log(self.D_fake))

对应前文提到的\frac{1}{m}\sum_{i=1}^m{log{(1-D(G(z^i)))}}。注意这里是用-\frac{1}{m}\sum_{i=1}^m{log{D(G(z^i))}},方向是一样的,之后笔者会讨论他们的区别。

Discriminator的损失函数为

-tf.reduce_mean(tf.log(self.D_real) + tf.log(1 - self.D_fake))

对应前文提到的\tilde{V}=\frac{1}{m}\sum_{i=1}^m{log{D(x^i)}}+\frac{1}{m}\sum_{i=1}^m{log{(1-D(\tilde{x}^i))}}

4.2.5 训练

定义D和G的训练函数,使得各自损失函数最小化。

trainD_op = tf.train.AdamOptimizer(learning_rate, beta1).minimize(self.d_loss, var_list=D_vars)
trainG_op = tf.train.AdamOptimizer(learning_rate, beta1).minimize(self.g_loss, var_list=G_vars)

先让D预训练30次,然后D和G交替训练。为什么先让D预训练30次?笔者认为D本质上就是个图片分类器,可以不依赖于G,比较好训练,预训练可以加快收敛速度。

训练时使用feed“喂数据”

feed_dict={self.X: batch_xs, self.Z: z_batch, self.phase: 1}

其中self.X表示真实的图片,self.Z表示噪声,self.phase表示batchnorm训练阶段还是测试阶段。

4.2.6 预测

生成随机噪声Z之后,喂给G,即可生成图片

outimage = self.Gen.eval(feed_dict={self.Z: z_batch, self.phase: 1}, session=sess)

不过笔者对这里的phase有些疑问,是否应该设置为0?恕笔者对Tensorflow不熟,代码解析有些走马观花,没有深究细节以及为什么这么写,等功力提高再回过头来优化。

至此,原始GAN的算法以及Python实现已介绍完毕,下一篇笔者将拓展讨论一些细节并介绍GAN的变种。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.rhkb.cn/news/500733.html

如若内容造成侵权/违法违规/事实不符,请联系长河编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

多线程访问FFmpegFrameGrabber.start方法阻塞问题

一、背景 项目集成网络摄像头实现直播功能需要用到ffmpeg处理rtmp视频流进行web端播放 通过网上资源找到大神的springboot项目实现了rtmp视频流转为http请求进行视频中转功能&#xff0c;其底层利用javacv的FFmpegFrameGrabber进行拉流、推流&#xff0c;进而实现了视频中转。 …

X86、X64、64位、32位归纳总结

梳理一下位数区别和命名规范。 操作系统的位数决定了内存寻址空间的大小 X86_32的简称是X86。32位&#xff0c;最多只能识别4GB的内存。 X86_64的简称是X64。64位&#xff0c;最多能识别数十TB内存。 由于历史发展原因&#xff0c;形成了这样的简称&#xff0c;简称很关键。…

redis的学习(二)

4 哈希表 哈希类型中的映射关系通常称为field-value&#xff0c;⽤于区分Redis整体的键值对&#xff08;key-value&#xff09;&#xff0c; 注意这⾥的value是指field对应的值&#xff0c;不是键&#xff08;key&#xff09;对应的值&#xff0c; 4.1 操作命令 hset&#xff…

前端编码技巧与规范

当我们完成项目的构建&#xff0c;进入开发阶段的时候&#xff0c;除了你需要了解框架本身的知识点外&#xff0c;我们还需要提前掌握一些项目的编码技巧与规范&#xff0c;在根源上解决之后因编码缺陷而导致的项目维护困难、性能下降等常见问题&#xff0c;为项目多人开发提供…

谷歌广告关键词出价根据什么来判断?

投放广告的目的是为了盈利&#xff0c;而关键字的出价直接关系到广告费用的支出。因此&#xff0c;设定出价上限时&#xff0c;不仅要参考关键字规划师的建议&#xff0c;还需结合广告的盈利表现来合理判断。 可以在谷歌广告账户的后台查看“平均每次点击费用”和“每次点击的…

《我在技术交流群算命》(二):QGraphicsItem怎么写自定义信号啊(QObject多继承顺序问题)

某位群友突然无征兆的抛出以下问题&#xff1a; QGraphicsItem怎么写自定义信号啊 看到这个问题的时候我是比较疑惑的&#xff0c;按鄙人对 Qt 的了解&#xff0c;自定义信号只需: 继承QObject类中加入Q_OBJECT宏声明一个信号并使用 但该群友毕竟也不是一个Qt新手&#xff0…

filebeat采集应用程序日志和多行匹配

1 filebeat采集nginx json日志 01 修改nginx的日志为json格式 elk93节点安装nginx&#xff0c;注释掉默认的nginx日志格式&#xff1a;# access_log /var/log/nginx/access.log;&#xff0c;在下方增加以下配置。然后重启nginx log_format wzy_nginx_json {"timestamp&…

大语言模型提示技巧(二)-给模型时间思考

在与大语言模型交互的时候&#xff0c;如果模型给出了错误的结论&#xff0c;不要着急否定大模型的能力&#xff0c;我们应当尝试重新构建查询&#xff0c;请求模型在提供它的最终答案之前进行一系列相关的推理。也就是说&#xff0c;如果给模型一个在短时间或用少量文字无法完…

深入剖析MySQL数据库架构:核心组件、存储引擎与优化策略(一)

sql语句分为两大类&#xff1a;查询&#xff08;select&#xff09;、增删改----修改&#xff08;update&#xff09; select语句的执行流程 执行sql语句的流程&#xff1a;连接数据库、缓存查询、解析器、优化器、执行器、存储引擎操作数据 客户端&#xff1a;图形界面工具…

【AimRT】现代机器人通信中间件 AimRT

目录 一、什么是AimRT二、AimRT与ROS22.1 定位与设计2.2 组成与通信方式对比 三、AimRT基本概念3.1 Node、Pkg 和 Module3.2 Protocol、Channel、Rpc 和 Filter3.3 App模式 和 Pkg模式3.4 Executor3.5 Plugin 一、什么是AimRT AimRT 是智元机器人公司自主研发的一款机器人通信…

SSM-Spring-AOP

目录 1 AOP实现步骤&#xff08;以前打印当前系统的时间为例&#xff09; 2 AOP工作流程 3 AOP核心概念 4 AOP配置管理 4-1 AOP切入点表达式 4-1-1 语法格式 4-1-2 通配符 4-2 AOP通知类型 五种通知类型 AOP通知获取数据 获取参数 获取返回值 获取异常 总结 5 …

idea( 2022.3.2)打包报错总结

一 报错 class lombok.javac.apt.LombokProcessor (in unnamed module 0x4fe64d23) cannot access class com.sun.tools.javac.processing.JavacProcessingEnvironment (in module jdk.compiler) because module jdk.compiler does not export com.sun.tools.javac.processing …

攻防靶场(29):目录权限和文件权限 ICMP

目录 1. 侦查 1.1 收集目标网络信息&#xff1a;IP地址 1.2 主动扫描&#xff1a;扫描IP地址段 1.3 搜索目标网站 2. 初始访问 2.1 利用面向公众的应用 3. 权限提升 3.1 有效账户&#xff1a;本地账户 3.2 滥用特权控制机制&#xff1a;Sudo和Sudo缓存 靶场下载地址&#xff1a…

C++ 面向对象编程:多态、虚函数原理

多态的通用描述便是&#xff0c;使用父类指针调用函数&#xff0c;可以根据对象类型来调用对应类型函数&#xff0c;我们分几个步骤来理解&#xff0c;先看下类的占用空间&#xff0c;然后拓展到虚函数对应数组&#xff0c;最后理解多态的原理。 我们先来看下在多态中没有任何…

王老吉药业SRM系统上线 携手隆道共启战略合作新篇章

12月27日&#xff0c;广州王老吉药业股份有限公司&#xff08;简称“王老吉药业”&#xff09;SRM项目上线启动会&#xff0c;在王老吉科普教育基地——“吉园”隆重举行。广药集团纪委主任陈耕、王老吉药业总工程师黄晓丹、隆道公司总裁吴树贵、项目经理赵耀、供应商代表郭伟及…

JavaScript基础 -- 变量、作用域与内存

1 原始值与引用值 原始值就是最简单的数据&#xff0c;引用值则是由多个值构成的对象。在把一个值赋给变量时&#xff0c;JavaScript引擎必须要确定这个值是原始值还是引用值 原始值大小固定&#xff0c;保存在栈内存上&#xff1b;引用值是对象&#xff0c;存储在堆内存上 它…

SQL—替换字符串—replace函数用法详解

SQL—替换字符串—replace函数用法详解 REPLACE() 函数——查找一个字符串中的指定子串&#xff0c;并将其替换为另一个子串。 REPLACE(str, old_substring, new_substring)str&#xff1a;要进行替换操作的原始字符串。old_substring&#xff1a;要被替换的子串。new_substri…

[极客大挑战 2019]Http 1

进入环境&#xff1a; 检查源码发现有一个链接&#xff0c;但是这里没有绑定&#xff0c;需要手动跳转&#xff0c;打开后&#xff0c;发现提示&#xff1a; 这里就是需要我们从https://Sycsecret.buuoj.cn来访问它 因此我们抓包&#xff0c;使用referer&#xff1a;服务器伪造…

吾杯网络安全技能大赛——Misc方向WP

吾杯网络安全技能大赛——Misc方向WP Sign 题目介绍: 浅浅签个到吧 解题过程&#xff1a; 57754375707B64663335376434372D333163622D343261382D616130632D3634333036333464646634617D 直接使用赛博橱子秒了 flag为 WuCup{df357d47-31cb-42a8-aa0c-6430634ddf4a} 原神启动…

如何查看下载到本地的大模型的具体大小?占了多少存储空间:Llama-3.1-8B下载到本地大概15GB

这里介绍一下tree命令&#xff0c;可以方便的查看文件目录结构和文件大小。 命令行tree的具体使用&#xff0c;请参考笔者的另一篇博客&#xff1a;深入了解 Linux tree 命令及其常用选项&#xff1a;Linux如何显示目录结构和文件大小&#xff0c;一言以蔽之&#xff0c;sudo a…