生成对抗网络(GAN)简介以及Python实现

本篇博客简单介绍了生成对抗网络(Generative Adversarial Networks,GAN),并基于Keras实现深度卷积生成对抗网络(DCGAN)。

以往的生成模型都是预先假设生成样本服从某一分布族,然后用深度网络学习分布族的参数,最后从学习到的分布中采样生成新的样本。例如变分自编码器就是构建生成样本的密度函数 p ( x ∣ z , θ ) p(x|z,\theta) p(xz,θ),这种模型称为显示密度模型
GAN并不学习密度函数,而是基于随机噪声,通过深度神经网络的层层包装,直接输出服从原始样本分布的新样本,这种模型称为隐式密度模型

那么GAN如何保证生成的样本符合原始样本分布?这就需要用到对抗学习的思想。

GAN网络结构

GAN网络由生成网络判别网络组合而成,通常都由卷积层和全连接层构成。

  • 生成网络负责根据输入的随机噪声 z z z,构建映射函数 G ( z ) G(z) G(z),使得生成样本尽量接近真实样本
  • 判别网络负责比较真实样本 x x x与生成样本 G ( z ) G(z) G(z),判定出 G ( z ) G(z) G(z)不是来自真实样本分布,相当于一个二分类器

可见,生成网络与判别网络之间是对抗交互的。流程如下:

目标函数

判别网络输出样本来自真实分布的概率为 p ( y = 0 ∣ x ) = D ( x , ϕ ) p(y=0|x)=D(x, \phi) p(y=0x)=D(x,ϕ)式中, ϕ \phi ϕ为判别网络参数。注意:实际模型训练时,通常将真实样本标记为0,生成样本标记为1。
判别网络的目标函数用交叉熵表示:
D
生成网络目标与之相反:
G
式中, θ \theta θ为生成网络参数。
注:以上公式图片均截图自 邱锡鹏《神经网络与深度学习》,详细推到过程请见此书

算法流程

生成网络与判别网络的训练是交叉进行的,判别网络训练K次后,生成网络训练1次。

  • Input: 训练样本集D,最大迭代训练步数max_step,每次迭代判别网络训练次数K以及小批量样本数M
  • Output: 生成网络参数
  1. 采用随机梯度下降训练判别网络K次,每次分别从样本集和生成网络采集M个样本
  2. 生成网络生成M个样本,基于最新的判别网络参数,训练生成网络一次
  3. 若达到最大迭代步数,终止,反之转步骤1

代码示例

# 实现简单的DCGAN(深度卷积生成对抗网络)
from keras.layers import Conv2D, Dense, Flatten, LeakyReLU, Reshape, Conv2DTranspose, \BatchNormalization, Input, Dropout
from keras.models import Sequential, Model, load_model
from keras.optimizers import adam
from keras.utils import plot_model
import numpy as npdef uniform_sampling(n_sample, dim):# 均匀分布采样return np.random.uniform(0, 1, size=(n_sample, dim))def normal_sampling(n_sample, dim):# 均匀分布采样return np.random.randn(n_sample, dim)# 构建判别网络
d_model = Sequential()
d_model.add(BatchNormalization())
d_model.add(Dropout(0.3))
d_model.add(Conv2D(64, (3, 3), padding='same', input_shape=(28, 28, 1)))
d_model.add(LeakyReLU(0.2))
d_model.add(Dropout(0.3))
d_model.add(Conv2D(128, (3, 3), strides=2, padding='same'))  # 用带步长卷积层替代池化层
d_model.add(LeakyReLU(0.2))
d_model.add(Dropout(0.3))
d_model.add(Conv2D(256, (3, 3), padding='same'))
d_model.add(LeakyReLU(0.2))
d_model.add(Dropout(0.3))
d_model.add(Conv2D(512, (3, 3), strides=2, padding='same'))
d_model.add(LeakyReLU(0.2))
d_model.add(Flatten())
d_model.add(Dropout(0.3))
d_model.add(Dense(1, activation='sigmoid'))  # 输出样本标记为1,即假样本的概率# 构建生成网络
g_model = Sequential()
g_model.add(BatchNormalization())
g_model.add(Dense(7 * 7 * 256, activation='relu', input_dim=100))
g_model.add(Reshape((7, 7, 256)))
g_model.add(BatchNormalization())
g_model.add(Conv2DTranspose(128, 3, strides=2, padding='same', activation='relu'))  # 反卷积
g_model.add(BatchNormalization())
g_model.add(Conv2DTranspose(64, 3, strides=2, padding='same', activation='relu'))
g_model.add(BatchNormalization())
g_model.add(Conv2DTranspose(32, 3, strides=1, padding='same', activation='relu'))
g_model.add(BatchNormalization())
g_model.add(Conv2DTranspose(1, 3, strides=1, padding='same', activation='tanh'))class DCGAN:def __init__(self, d_model, g_model,input_dim=784, g_dim=100,max_step=100, sample_size=256, d_iter=3, kind='normal'):self.input_dim = input_dim  # 图像的展开维度,即判别网络的输入维度self.g_dim = g_dim  # 随机噪声维度,即生成网络的输入维度self.max_step = max_step  # 整个模型的迭代次数self.sample_size = sample_size  # 训练过程中小批量采样的个数的一半self.d_iter = d_iter  # 每次迭代,判别网络训练的次数self.kind = kind  # 随机噪声分布类型self.d_model = d_model  # 判别模型self.g_model = g_model  # 生成模型self.m_model = self.merge_model()  # 合并模型self.optimizer = adam(lr=0.0002, beta_1=0.5)self.d_model.compile(optimizer=self.optimizer, loss='binary_crossentropy')def merge_model(self):# 合并生成网络与判别网络noise = Input(shape=(self.g_dim,))gen_sample = self.g_model(noise)self.d_model.trainable = False  # 固定判别网络,训练合并网络等同与训练生成网络d_output = self.d_model(gen_sample)m_model = Model(noise, d_output)  # 模型输出生成样本的预测结果,越接近0越好m_model.compile(optimizer='adam', loss='binary_crossentropy')return m_modeldef gen_noise(self, num_sample):# 生成随机噪声数据if self.kind == 'normal':f = normal_samplingelif self.kind == 'uniform':f = uniform_samplingelse:raise ValueError('暂不支持分布{}'.format(self.kind))return f(num_sample, self.g_dim)def gen_real_data(self, train_data):# 真实样本采样n_samples = train_data.shape[0]inds = np.random.randint(0, n_samples, size=self.sample_size)real_data = train_data[inds]real_label = np.random.uniform(0, 0.3,size=(self.sample_size,))  # 用0-0.3随机数代替标记0return real_data, real_labeldef gen_fake_data(self):# 生成样本noise = self.gen_noise(self.sample_size)fake_data = g_model.predict(noise)  # 生成网络生成M个样本,标记为0fake_label = np.random.uniform(0.7, 1.2,size=(self.sample_size,))  # 用0.7-1.2随机数代替标记1return fake_data, fake_labeldef fit(self, train_data):# 轮流训练判别网络和生成网络for i in range(self.max_step):for _ in range(self.d_iter):  # 训练判别网络real_data, real_label = self.gen_real_data(train_data)d_model.train_on_batch(real_data, real_label)fake_data, fake_label = self.gen_fake_data()d_model.train_on_batch(fake_data, fake_label)# 训练生成网络noise = self.gen_noise(self.sample_size)expected_label = np.random.uniform(0, 0.3, size=(self.sample_size,))  # 期望输出0self.d_model.trainable = Falseself.m_model.compile(optimizer=self.optimizer, loss='binary_crossentropy')gan_loss = self.m_model.train_on_batch(noise, expected_label)print('第{0}次迭代训练损失值:{1:.3f}'.format(i + 1, gan_loss))returndef gen_samples(self, num):# 生成网络生成数据z = self.gen_noise(num)imgs = g_model.predict(z)return imgsdef save_model(self):# 保存训练后的模型self.d_model.save('d_model.hdf5')self.g_model.save('g_model.hdf5')returnif __name__ == '__main__':# d_model = load_model('D:\Machine_Learning\deep_learning_algorithm\gan\d_model.hdf5')# g_model = load_model('D:\Machine_Learning\deep_learning_algorithm\gan\g_model.hdf5')# plot_model(d_model, 'd_model.png')# plot_model(g_model, 'g_model.png')model = DCGAN(d_model, g_model, max_step=10, sample_size=1000, d_iter=2)# 导入数据input_dim = 28 * 28 * 1  # 单通道28像素的图像f = np.load(r'D:\Machine_Learning\deep_learning_algorithm\data\mnist.npz')x_train, y_train = f['x_train'], f['y_train']f.close()x_train = np.reshape(x_train, [-1, input_dim])x_train = (x_train.astype('float32') - 127.5) / 127.5  # 规范化到(-1,1)x_train = x_train.reshape((x_train.shape[0], 28, 28, 1))  # 转换成卷积网络层标准的数据格式# 训练model.fit(x_train)model.save_model()# 生成样本并可视化imgs = model.gen_samples(10)def plot_img(gen_imgs):# 对比重构前后的图像import matplotlib.pyplot as pltn = 10plt.figure(figsize=(20, 4))for i in range(n):ax = plt.subplot(1, n, i + 1)plt.imshow(gen_imgs[i].reshape(28, 28))plt.gray()ax.get_xaxis().set_visible(False)ax.get_yaxis().set_visible(False)plt.show()returnplot_img(imgs)

参考资料

  1. 邱锡鹏《神经网络与深度学习》
  2. https://arxiv.org/pdf/1511.06434.pdf
  3. https://github.com/soumith/ganhacks
  4. https://blog.csdn.net/theonegis/article/details/80115340

注:代码未经严格测试训练,仅作示例,如有不当之处请指正。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.rhkb.cn/news/60837.html

如若内容造成侵权/违法违规/事实不符,请联系长河编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

生成对抗网络(GAN)教程 - 多图详解

一.生成对抗网络简介 1.生成对抗网络模型主要包括两部分:生成模型和判别模型。 生成模型是指我们可以根据任务、通过模型训练由输入的数据生成文字、图像、视频等数据。 [1]比如RNN部分讲的用于生成奥巴马演讲稿的RNN模型,通过输入开头词就能生成下来。…

对抗神经网络学习和实现(GAN)

一,GAN的原理介绍 \quad GAN的基本原理其实非常简单,这里以生成图片为例进行说明。假设我们有两个网络,G(Generator)和D(Discriminator)。正如它的名字所暗示的那样,它们的功能分别是…

生成对抗网络(GAN)

1 GAN基本概念 1.1 如何通俗理解GAN? ​ 生成对抗网络(GAN, Generative adversarial network)自从2014年被Ian Goodfellow提出以来,掀起来了一股研究热潮。GAN由生成器和判别器组成,生成器负责生成样本,判别器负责判断生成器生成…

基于图神经网络的对抗攻击 Nettack: Adversarial Attacks on Neural Networks for Graph Data

研究意义 随着GNN的应用越来越广,在安全非常重要的应用中应用GNN,存在漏洞可能是非常严重的。 比如说金融系统和风险管理,在信用评分系统中,欺诈者可以伪造与几个高信用客户的联系,以逃避欺诈检测模型;或者…

生成对抗网络(Generative Adversial Network,GAN)原理简介

生成对抗网络(GAN)是深度学习中一类比较大的家族,主要功能是实现图像、音乐或文本等生成(或者说是创作),生成对抗网络的主要思想是:通过生成器(generator)与判别器(discriminator)不断对抗进行训练。最终使得判别器难以分辨生成器生成的数据(…

快讯|莫言用 ChatGPT 写《颁奖辞》;特斯拉人形机器人集体出街!已与FSD算法打通

一分钟速览新闻点 言用 ChatGPT 写《颁奖辞》孙其君研究员团队 Adv. Funct. Mater.:多功能离子凝胶纤维膜用于能量离电皮肤微软CEO反驳马斯克:我们没有控制OpenAI特斯拉人形机器人集体出街!已与FSD算法打通微软CEO称小型公司仍可在人工智能领…

从供应链到价值链:人形机器人产业链深入研究

原创 | 文 BFT机器人 01 人形机器人产业进展:AI赋能,人形机器人迭代有望加速 目前人形机器人产业所处从“0”到“1”的萌芽期,从现在到未来的时间里,人形机器人以其仿人外形、身体构成及其智能大脑,能极大解放生产力、…

人形机器人火出圈!OpenAI领投挪威人形机器人公司“1X”

文|牛逼的AI 编|猫猫咪子 源|AI源起 目前已经实现了对接GPT4技术的机器人Ameca,它拥有逼真的外观和丰富的表情。 随着人形机器人技术的飞速发展,未来可能不再适用阿西莫夫所提出的“机器人三定律”,因为超级…

chatgpt赋能python:Python登录界面制作指南

Python登录界面制作指南 介绍 登录界面是许多应用程序的关键组成部分之一。Python作为一种优秀的编程语言,拥有着强大的界面开发框架,能够帮助开发人员更轻松地创作出完美的登录界面。 在本文中,我们将向您介绍使用Python如何制作一个简单…

The Journal of Neuroscience: 珠心算训练有助于提高儿童的视觉空间工作记忆

《本文同步发布于“脑之说”微信公众号,欢迎搜索关注~~》 珠心算是指个体在熟练进行珠算操作后,可摆脱实际算盘,借助大脑中虚拟算盘进行数字计算的方式(图1)。早期行为学研究表明,珠心算个体的数字计算能力…

php珠心算源码,深度解析珠心算的“开智”功能

编者按:本文来自李绵军校长在廊坊智慧特训营演讲。李绵军校长通过十几年来对珠心算的钻研练习,详细解读了珠心算的开智功能,以及“一门深入”的作用。 珠心算的开智价值是在哪里?大家都说开发智力,我在这讲开发智力不是…

php珠心算源码,NOIP201401珠心算测验

珠心算测验 问题描述】 珠心算是一种通过在脑中模拟算盘变化来完成快速运算的一种计算技术。珠心算训练,既能够开发智力,又能够为日常生活带来很多便利,因而在很多学校得到普及。 某学校的珠心算老师采用一种快速考察珠心算加法能力的测验方…

雨课堂提交作业步骤 10步帮你弄好

1 2 3 4 5 6 7 中间计数记错了… 8 9 10 弹出对话框,点击确认即可 提交成功的截图:

2022李宏毅作业hw4 - 挫败感十足的一次作业。

系列文章: 2022李宏毅作业hw1—新冠阳性人员数量预测。_亮子李的博客-CSDN博客_李宏毅hw1 hw-2 李宏毅2022年作业2 phoneme识别 单strong-hmm详细解释。_亮子李的博客-CSDN博客_李宏毅hw2 2021李宏毅作业hw3 --食物分类。对比出来的80准确率。_亮子李的博客-CSDN博客…

php老师的一个作业展示

1.在ScanCode.php中 在judgeTrayCodeEnableIntoWarehouse方法中: 2在CommitTrayCodes.php中 访问时,直接访问CommitTrayCodes.php,这个CommitTrayCodes.php是要建在controller下,建议导一下,老师的数据库,以防止命名…

HCIA网络课程第七周作业

(1)请用自己的语言描述基本ACL和高级ACL的区别 (2)AAA支持的认证、授权和计费方式分别有哪几种? AAA支持的认证方式有不认证 本地认证 远端认证AAA支持的授权方式为不授权 本地授权 远端授权AAA支持计费方式为不计费…

如何获取抖音和快手直播间的直播流地址

如下是通过python代码脚本获取的方法: import requests import re def get_real_url(rid): try: if ‘v.douyin.com‘ in rid: room_id re.findall(r‘(\d{19})‘, requests.get(urlrid).url)[0] else: room_id rid room_url ‘https://webcast-hl.amemv.com/…

新手如何做抖音直播带货?新号如何快速获取直播推荐流量?

如果要做抖音直播带货的话,首先需要开通抖音直播带货权限。也就是我们经常说的解锁抖音直播购物车功能。解锁直播购物车后,我们就能在直播中售卖商品。 而在此之前,你要先开通商品橱窗并解锁视频购物车。开通商品橱窗,在完成新手…

抖音直播如何获取推流地址?不到1000粉也能直播啦。还能加热。2020年12月29日

抖音直播自从出了自己的pc客户端(直播伴侣)后,直播推流地址已不再对外暴漏。 正常情况下,粉丝大于1000,才可使用抖音官方的推流工具(直播伴侣)。但对于粉丝数不够1000,也想用第三方推流工具(如…

短视频、直播平台——电商直播源码第三方SDK接入教程

现在网络视频直播行业非常火爆,所以很多公司也希望开发直播平台,一般直播平台需要用户给主播送礼物来实现盈利,所以刷礼物的功能是必备的,另外为了增加视频的美感与炫酷等特效功能,也需要用到美颜与视频滤镜等功能&…