生成对抗网络DCGAN学习

在AI内容生成领域,有三种常见的AI模型技术:GAN、VAE、Diffusion。其中,Diffusion是较新的技术,相关资料较为稀缺。VAE通常更多用于压缩任务,而GAN由于其问世较早,相关的开源项目和科普文章也更加全面,适合入门学习。

博主从入门和学习角度用Tensorflow跑通了DCGAN,本文对其进行记录以及分享。

1.简介

GAN(Generative Adversarial Network)是一种用于生成模型的机器学习框架。其原理基于两个主要组件:生成器(Generator)和判别器(Discriminator),二者通过对抗学习的方式相互竞争和提升。

从2014年左右发展至今,GAN目前有很多分支:

  • GAN 朴素GAN,最原始版本
  • DCGAN 卷积神经网络GAN
  • CGAN 条件GAN,训练时传入额外条件,例如通过不同的mask区域生成不同内容,可控制的生成
  • SeqGAN 使用GAN生成某些风格的句子,但不能进行对答
  • Cycle GAN 可实现图像风格迁移,其实现略复杂
  • 省略

2.原理介绍

先来看图

梯度
判别
G
LeakyReLU
tanh
InputNoise
FullConnectLayer123
OutputImage
D
LeakyReLU
Sigmoid
InputImage
FullConnectLayer12
OutputOneValue

生成器(Generator)和判别器(Discriminator)是GAN的两个主要模型,生成器在上图中用缩写G表示,判别器用缩写D表示。
生成器G输入[N]的一维噪声,即InputNoise。输出[W * H * RGB](大致类似)的张量
判别器D输入一张图像,输出[1]的张量,即一个浮点数,通过0-1的值得到图像是真还是假

判别器需要尽可能的认出造假图片,生成器需要尽可能的骗过判别器,两者会在这2个目标上不断的通过反向传播进行学习,从而达到生成器和判别器的纳什均衡,最终输出质量很高的生成图像。

2.2 重点1

在训练中,判别器返回一个0-1区间的浮点数(如[0]=0.63,[0]=0.21)作为判断结果,值越高也越认为是真实图片。由于判别器也是一个神经网络模型,因此可以将输出层的梯度一直传递回输入层,然后将输入层的梯度作为生成器的梯度继续反向传播,从而完成一次训练。

然而,很多文章并没有提到这一点。如果没有接触过这种多模型梯度传递训练方法,可能会认为使用一个数学方法或者计算机视觉方法来构建判别器也可以让整个模型正常运行。但事实上,这种方法是不可行的(通常情况下)。

2.3 重点2

使用更多的层可以增强模型的推理能力。例如,在训练过程中,如果模型生成出眉毛 A 的特征,则有鼻子 B、C 和 D 相关的备选项;而如果生成出眉毛 E 的特征,则有鼻子 F 和 G 相关的备选项。

这也是为什么生成器需要使用三个隐层的原因(博主的观点)。通过增加隐层的数量,模型可以捕捉到更多的特征和抽象概念,从而提高生成器的表现能力和推理能力。更深层次的网络结构能够帮助模型学习更复杂的模式和关联,使其在生成结果时更加准确和多样化。

上图生成器部分的激活函数用的是LeakyReLU,实际上就单隐层神经网络来说,ReLU要比Sigmoid能多解决很多类型问题,Sigmoid更适合分类问题,遇到一些奇怪的问题不容易收敛,而LeakyReLU激活函数即和ReLU逻辑一样也可以返回负数信息,这是博主觉得采用这个激活函数的原因。
而至于tanH和Sigmoid的比较,它们在某种程度上相似。一般来说,网上普遍认为tanH比Sigmoid更好,主要原因是它具有较窄的数值边界范围。

2.4 重点3

对于2套样本比较损失这类问题,一般使用二分类交叉熵,这不同于分类问题。
而二分类交叉熵又是在只有2种结果(r和1-r),的情况下对公式进行的简化:
https://blog.csdn.net/grayrail/article/details/131619144

2.5 模式崩溃

训练时还会出现一种情况,即生成器始终卡在一个生成结果上,比如生成0-9数字,结果训练几轮后始终在生成数字3。
这种情况称为模式崩溃,一般增加训练样本数量并调节参数,没有比较好的办法。

3.实践准备

python库下载使用国内镜像源:
https://zhuanlan.zhihu.com/p/477179822

使用方式:

pip install -i https://pypi.tuna.tsinghua.edu.cn/simple pyspider

github库下载耽误时间,可以缓存到gitee:
在这里插入图片描述

而gitee也有自己缓存好的镜像库,可以先去这里查:
https://gitcode.net/mirrors

python库查找:
https://pypi.org/

在pip中查找python库:
先 pip install pip-search 再使用命令 pip_search 搜索

4.实践

全连接神经网络版本的朴素GAN效果相对较差,而DCGAN(Deep Convolutional GAN)是卷积神经网络版本的GAN,下面以DCGAN为例使用Tensorflow进行实现:

import tensorflow as tf
import matplotlib.pyplot as plt
from tensorflow.keras import layers# 定义生成器模型
def build_generator():model = tf.keras.Sequential()model.add(layers.Dense(7*7*256, use_bias=False, input_shape=(100,)))model.add(layers.BatchNormalization())model.add(layers.LeakyReLU())model.add(layers.Reshape((7, 7, 256)))assert model.output_shape == (None, 7, 7, 256)  # 注意:batch size 没有限制model.add(layers.Conv2DTranspose(128, (5, 5), strides=(1, 1), padding='same', use_bias=False))assert model.output_shape == (None, 7, 7, 128)model.add(layers.BatchNormalization())model.add(layers.LeakyReLU())model.add(layers.Conv2DTranspose(64, (5, 5), strides=(2, 2), padding='same', use_bias=False))assert model.output_shape == (None, 14, 14, 64)model.add(layers.BatchNormalization())model.add(layers.LeakyReLU())model.add(layers.Conv2DTranspose(1, (5, 5), strides=(2, 2), padding='same', use_bias=False, activation='tanh'))assert model.output_shape == (None, 28, 28, 1)return model# 定义判别器模型
def build_discriminator():model = tf.keras.Sequential()model.add(layers.Conv2D(64, (5, 5), strides=(2, 2), padding='same',input_shape=[28, 28, 1]))model.add(layers.LeakyReLU())model.add(layers.Dropout(0.3))model.add(layers.Conv2D(128, (5, 5), strides=(2, 2), padding='same'))model.add(layers.LeakyReLU())model.add(layers.Dropout(0.3))model.add(layers.Flatten())model.add(layers.Dense(1))return model# 定义生成器和判别器
generator = build_generator()
discriminator = build_discriminator()# 定义损失函数
loss_fn = tf.keras.losses.BinaryCrossentropy(from_logits=True)# 定义生成器和判别器的优化器
generator_optimizer = tf.keras.optimizers.Adam(1e-4)
discriminator_optimizer = tf.keras.optimizers.Adam(1e-4)def generator_loss(fake_output):return loss_fn(tf.ones_like(fake_output), fake_output)def discriminator_loss(real_output, fake_output):real_loss = loss_fn(tf.ones_like(real_output), real_output)fake_loss = loss_fn(tf.zeros_like(fake_output), fake_output)total_loss = real_loss + fake_lossreturn total_loss# 定义训练循环
@tf.function  #这个是tensorflow的装饰器,标记后可提升性能,不加此标记也可
def train_step(images):# 生成噪声向量noise = tf.random.normal([BATCH_SIZE, 100])with tf.GradientTape() as gen_tape, tf.GradientTape() as disc_tape:# 使用生成器生成假图片generated_images = generator(noise, training=True)# 使用判别器判断真假图片real_output = discriminator(images, training=True)fake_output = discriminator(generated_images, training=True)# 计算损失函数gen_loss = generator_loss(fake_output)disc_loss = discriminator_loss(real_output, fake_output)# 计算梯度并更新生成器和判别器的参数gradients_of_generator = gen_tape.gradient(gen_loss, generator.trainable_variables)gradients_of_discriminator = disc_tape.gradient(disc_loss, discriminator.trainable_variables)generator_optimizer.apply_gradients(zip(gradients_of_generator, generator.trainable_variables))discriminator_optimizer.apply_gradients(zip(gradients_of_discriminator, discriminator.trainable_variables))def generate_and_save_images(model, epoch, test_input):predictions = model(test_input, training=False)print("predictions.shape:", predictions.shape)num_images = predictions.shape[0]rows = int(num_images ** 0.5) # 计算行数cols = num_images // rows # 计算列数fig = plt.figure(figsize=(8, 8))for i in range(num_images):plt.subplot(rows, cols, i+1)plt.imshow(predictions[i, :, :, 0] * 127.5 + 127.5, cmap='gray')plt.axis('off')plt.savefig('image_at_epoch_{:04d}.png'.format(epoch))#plt.show()# 加载MNIST数据集
(train_images, train_labels), (_, _) = tf.keras.datasets.mnist.load_data()# 标准化数据
train_images = train_images.reshape(train_images.shape[0], 28, 28, 1).astype('float32')
train_images = (train_images - 127.5) / 127.5# 批量大小与训练次数
BATCH_SIZE = 256
EPOCHS = 50# 数据集切分为批次并进行训练
dataset = tf.data.Dataset.from_tensor_slices(train_images).shuffle(60000).batch(BATCH_SIZE)for epoch in range(EPOCHS):for i,image_batch in enumerate(dataset):print("sub i",i)train_step(image_batch)print("------------------------------------------------------epoch:", epoch)# 每个 epoch 结束后生成并保存一组图像if (epoch + 1) % 5 == 0:seed = tf.random.normal([BATCH_SIZE, 100])generate_and_save_images(generator, epoch + 1, seed)

跑一阵子MNIST数据集后,结果如下:
在这里插入图片描述


参考:

论文精读: https://www.bilibili.com/video/BV1rb4y187vD

同济子豪兄精读版本: https://www.bilibili.com/video/BV1oi4y1m7np

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.rhkb.cn/news/74017.html

如若内容造成侵权/违法违规/事实不符,请联系长河编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

海康摄像头开发笔记(一):连接防爆摄像头、配置摄像头网段、设置rtsp码流、播放rtsp流、获取rtsp流、调优rtsp流播放延迟以及录像存储

文为原创文章,转载请注明原文出处 本文章博客地址:https://hpzwl.blog.csdn.net/article/details/131679108 红胖子(红模仿)的博文大全:开发技术集合(包含Qt实用技术、树莓派、三维、OpenCV、OpenGL、ffmpeg、OSG、单片机、软硬结…

C++如何用OpenCV中实现图像的边缘检测和轮廓提取?

最近有个项目需要做细孔定位和孔距测量,需要做边缘检测和轮廓提取,先看初步效果图: 主要实现代码: int MainWindow::Test() {// 2.9 单个像素长度um 5倍double dbUnit 2.9/(1000*5);// 定义显示窗口namedWindow("src"…

【使用机器学习和深度学习对城市声音进行分类】基于两种技术(ML和DL)对音频数据(城市声音)进行分类(Matlab代码实现)

💥💥💞💞欢迎来到本博客❤️❤️💥💥 🏆博主优势:🌞🌞🌞博客内容尽量做到思维缜密,逻辑清晰,为了方便读者。 ⛳️座右铭&a…

【ASP.NET MVC】生成页面(6)

本应该继续数据库访问的问题进行探讨,前文确实比较LOW。但本人认为:初学者需要解决的是快速了解知识路线的问题,所谓“生存问题”,所以,干脆把流程先走完。 一、页面生成 下面这张图在前面已经介绍过: 前…

网工内推 | 云计算工程师专场,CCNP/HCIP认证优先

01 弧聚科技 招聘岗位:网络工程师(云计算方向) 职责描述: 1、作为H3C初级云计算交付工程资源培养对象,需配合完成相关华三产品及服务规范培训。 2、培训赋能后,安排到H3C云项目交付中进行项目交付及驻场支…

【从零开始学习JAVA | 第三十七篇】初识多线程

目录 前言: ​编辑 引入: 多线程: 什么是多线程: 多线程的意义: 多线程的应用场景: 总结: 前言: 本章节我们将开始学习多线程,多线程是一个很重要的知识点&#xff…

【Spring Boot】请求参数传json数组,后端采用(pojo)新增案例(103)

请求参数传json数组,后端采用(pojo)接收的前提条件: 1.pom.xml文件加入坐标依赖:jackson-databind 2.Spring Boot 的启动类加注解:EnableWebMvc 3.Spring Boot 的Controller接受参数采用:Reque…

openCV C++环境配置

文章目录 一、openCV 安装二、新建项目三、配置环境变量四、测试使用 编译器:vs2017 OpenCV:4.5.4 一、openCV 安装 将openCV安装到一个路径下,我安装到了D盘根目录下 二、新建项目 在vs2017新建控制台空项目,打开项目属性 在VC目录 -> 包含目录下…

【Spring Boot】请求参数传json对象,后端采用(map)CRUD案例(101)

请求参数传json对象,后端采用(map)接收的前提条件: 1.Spring Boot 的Controller接受参数采用:RequestBody 2.需要一个Json工具类,将json数据转成Map; 工具类:Json转Map import com…

面试热题(接雨水问题)

给定 n 个非负整数表示每个宽度为 1 的柱子的高度图,计算按此排列的柱子,下雨之后能接多少雨水。 我们看到题的第一步,永远是对入参进行判断 public int trap(int[] height) {if (height null) {return 0;}...} 但是我们想想看,接…

Springboot -- 按照模板生成docx、pdf文件,docx转pdf格式

使用 poi-tl 根据模板生成 word 文件。 使用 xdocreport 将 docx 文件转换为 pdf 文件。 xdocreport 也支持根据模板导出 word ,但是 poi-tl 的功能更齐全,操作更简单,文档清晰。 poi-tl 、xdocreport 内部均依赖了 poi ,要注意两…

伊语IM即时通讯源码/im商城系统/纯源码IM通讯系统安卓+IOS前端纯原生源码

伊语IM即时通讯源码/im商城系统/纯源码IM通讯系统安卓IOS前端纯原生源码, 后端是java源码。

利用鸿鹄快速构建公司IT设备管理方案

需求描述 相信应该有一部分朋友跟我们一样,公司内部有很多各种各样的系统,比如资产管理、CRM、issue管理等等。这篇文章介绍下,鸿鹄是如何让我们的资产系统,按照我们的需求展示数据的。 我们的资产管理系统,是使用开源…

机器学习---线性回归、多元线性回归、代价函数

1. 线性回归 回归属于有监督学习中的一种方法。该方法的核心思想是从连续型统计数据中得到数学模型,然后 将该数学模型用于预测或者分类。该方法处理的数据可以是多维的。 回归是由达尔文的表兄弟Francis Galton发明的。Galton于1877年完成了第一次回归预测&…

51:电机(ULN2003D)

1:介绍 我们51单片机使用的是直流电机 直流电机是一种将电能转换为机械能的装置。一般的直流电机有两个电极,当电极正接时,电机正转,当电极反接时,电机反转 直流电机主要由永磁体(定子)、线圈(转…

SpringCloudAlibaba之Ribbon

Ribbon是nacos自带的负载均衡器,属于客户端的负载均衡 但是在Spring高级版本中让LoadBalancer替代了 本人用的是2.1.0的nacos,ribbon还没有被替换。 使用: 在配置类中:LoadBalanced BeanLoadBalancedpublic RestTemplate restT…

MacOS Monterey VM Install ESXi to 7 U2

一、MacOS Monterey ISO 准备 1.1 下载macOS Monterey 下载🔗链接 一定是 ISO 格式的,其他格式不适用: https://www.mediafire.com/file/4fcx0aeoehmbnmp/macOSMontereybyTechrechard.com.iso/file 1.2 将 Monterey ISO 文件上传到数据…

编辑接口和新增接口的分别调用

在后台管理系统中,有时候会碰到新增接口和编辑接口共用一个弹窗的时候. 一.场景 在点击新增或者编辑的时候都会使用这个窗口,新增直接调用接口进行增加即可,编辑则是打开这个窗口显示当前行的数据,然后调用编辑接口。 二.处理方法 在默认的情况下,这个窗口用来处理…

【从零开始学习JAVA | 第三十八篇】应用多线程

目录 前言: 多线程的实现方式: Thread常见的成员方法: 总结: 前言: 多线程的引入不仅仅是提高计算机处理能力的技术手段,更是适应当前时代对效率和性能要求的必然选择。在本文中,我们将深入…

使用贝叶斯滤波器通过运动模型和嘈杂的墙壁传感器定位机器人研究(Matlab代码实现)

💥💥💞💞欢迎来到本博客❤️❤️💥💥 🏆博主优势:🌞🌞🌞博客内容尽量做到思维缜密,逻辑清晰,为了方便读者。 ⛳️座右铭&a…