[GAN] 使用GAN网络进行图片生成的“调参人”入门指南——生成向日葵图片

[GAN] 使用GAN网络进行图片生成的“炼丹人”日志——生成向日葵图片

文章目录

  • [GAN] 使用GAN网络进行图片生成的“炼丹人”日志——生成向日葵图片
    • 1. 写在前面:
      • 1.1 应用场景:
      • 1.2 数据集情况:
      • 1.3 实验原理讲解和分析(简化版,到时候可以出一期深入的PaperReading)
      • 1.4 一些必要的介绍
    • 2. 重要实验代码:
      • 2.1 一些相关的数据预处理
      • 2.2 生成器和判别器
      • 2.3 损失函数计算
      • 2.4 训练和反向传播
    • 3. 实验结果分析:
      • 3.0 baseline
        • 3.0.1 损失函数:
        • 3.0.2 last picture:
        • 3.0.3 gif picture:
      • 3.1 epoch不变的情况下提高学习率:
        • 3.1.1 损失函数:
        • 3.1.2 last picture:
        • 3.1.3 gif picture:
      • 3.2 试试增加epoch?:
        • 3.2.1 损失函数:
        • 3.2.2 last picture:
        • 3.2.3 gif picture:
    • 4. 目前比较不错的效果展示
    • 5. 一些其它问题和小小的总结
    • 参考资料

1. 写在前面:

1.1 应用场景:

为了支撑人工智能落地,为人们的生活带来更多的便利,充足的数据尤为重要。而在实际的应用中常常会面临专业数据匮乏,数据不均衡的问题,所以利用神经网络根据已有的数据生成新的数据,进行数据扩充,成为了助力人工智能落地的新思路。

1.2 数据集情况:

我所使用的数据集是总量为256张的彩色的向日葵的图片。

在这里插入图片描述

1.3 实验原理讲解和分析(简化版,到时候可以出一期深入的PaperReading)

在这里插入图片描述

  • GAN网络俗称生成式对抗网络,该网络训练了两个模型(即生成器G和判别器D)来进行相互博弈,而博弈的目的是为了得到一个性能较好的可以用于生成我们想要的图片的生成器G。
  • 其中生成器网络G是为了生成可以用来迷惑判别器网络D的"假"图像。按数学语言来理解就是要最大化判别器D犯错的概率。
  • 而判别器网络D则是为了判别一个样本是不是来自于真实数据。按数学语言来理解就是它用于估计出一个样本是来源于真实的数据而非来自于G的概率。
  • 因此,不难得出这个模型的训练的过程大抵就是一个生成器G和判别器D之间的左右互博的过程。
  • 不过,值得注意的是这里对G和D的模型的构建使用的是多层感知机MLP(Multilayer perceptrons),也就是在网络上主要是使用全连接层
    在这里插入图片描述
  • 从这里我们可以看到GAN网络的损失函数为:
    在这里插入图片描述
  • 这个估值函数中由两个部分的数学期望所组成,第一部分是当输入是来自真实样本数据的期望,而第二部分则是当输入是来自生成器生成的样本时的期望。
  • 判别器输出的值是一个概率值,这个概率表示输出值是来自真实数据而非来自生成器的程度。
  • 这个值越接近1就越表明当前的输入来自真实数据,而越接近0就表示这个输入来自生成器。
  • 这样们就可以理解D(x)的目的是为了更好地区分二者,这样能是的D函数输出的值是合理的(更接近1或0)。
  • 而G的目的是为了让G(z)更像数据样本,这样可以使得第二个期望中的D(G(z))能被误判为1,这样就可以达到让第二个期望的值尽可能小的效果。
  • 再反过来看D的训练,D能更好判别真假,就更加使得第二个期望中的D(G(z))能被正确判为0,这样就可以达到让第二个期望的值尽可能大的效果。
  • 所以综合地来看,判别器D就是为了让整个损失(价值)函数尽量大,而生成器则反之,它想让损失函数足够小。这样也就符合我们训练一个网络的指标是让损失值减小,而我们也就可以沿着想办法让损失减小的方向去优化我们的模型从而达到训练出一个较好的生成器。

1.4 一些必要的介绍

  • 在我个人的实践中,我所使用的深度学习框架为华为昇腾AI系列的mindspore-1.9深度学习框架。
  • 所使用的笔记本的操作系统为Windows10
  • 我使用的是AMD的CPU来进行训练,因为本身该demo的数据量并不是很大。

2. 重要实验代码:

2.1 一些相关的数据预处理

import numpy as np
import matplotlib.pyplot as plt
from PIL import Image  # 一个读取图片和对图片做基础操作的类
# 数据转换
image_size = 64
input_images = np.asarray([np.asarray  # 将Python的数组转化成npArray(Image.open(input_data_dir + "/" + file).resize((image_size, image_size))  # 将图片的尺寸转化为 64* 64.convert("L"))  # 将图片转化为灰度图,这样就简化了运算,只需要考虑一个颜色通道了。(可拓展点对RGB三个颜色的通道都进行处理。)for file in filename])
# 数据预处理
input_images = input_images.reshape(256, 4096)  # 将256张图片展平为一维向量
# input_images = input_images.astype('float32')/255 # 把图片的值放缩到(0,1)之间
input_images = (input_images.astype('float32') - 127.5) / 127.5  # 把图片的值放缩到(-1,1)之间
# input_images = (input_images.astype('float32')-mean)/std # 把数据样本转化为均值为0,方差为1的标准化数据(未完成)

2.2 生成器和判别器

# 构建生成器
img_size = 64  # 训练图像长(宽)class Generator(nn.Cell):def __init__(self, latent_size, auto_prefix=True):super(Generator, self).__init__(auto_prefix=auto_prefix)self.model = nn.SequentialCell()# [N, 100] -> [N, 128]# 输入一个100维的0~1之间的高斯分布,然后通过第一层线性变换将其映射到256维self.model.append(nn.Dense(latent_size, 128))self.model.append(nn.ReLU())# [N, 128] -> [N, 256]self.model.append(nn.Dense(128, 256))self.model.append(nn.BatchNorm1d(256))self.model.append(nn.ReLU())# [N, 256] -> [N, 512]self.model.append(nn.Dense(256, 512))self.model.append(nn.BatchNorm1d(512))self.model.append(nn.ReLU())# [N, 512] -> [N, 1024]self.model.append(nn.Dense(512, 1024))self.model.append(nn.BatchNorm1d(1024))self.model.append(nn.ReLU())# [N, 1024] -> [N, 4096]# 经过线性变换将其变成4096维self.model.append(nn.Dense(1024, img_size * img_size))# 经过Tanh激活函数是希望生成的假的图片数据分布能够在-1~1之间self.model.append(nn.Tanh())def construct(self, x):img = self.model(x)return ops.reshape(img, (-1, 1, 64, 64))latent_size = 100  # 隐码的长度
net_g = Generator(latent_size)
net_g.update_parameters_name('generator')
# 构建判别器class Discriminator(nn.Cell):def __init__(self, auto_prefix=True):super().__init__(auto_prefix=auto_prefix)self.model = nn.SequentialCell()# [N, 4096] -> [N, 1024]self.model.append(nn.Dense(img_size * img_size, 1024))  # 输入特征数为4096,输出为1024self.model.append(nn.LeakyReLU())  # 默认斜率为0.2的非线性映射激活函数# [N, 1024] -> [N, 256]self.model.append(nn.Dense(1024, 256))  # 进行一个线性映射self.model.append(nn.LeakyReLU())# [N, 256] -> [N, 1]self.model.append(nn.Dense(256, 1))self.model.append(nn.Sigmoid())  # 二分类激活函数,将实数映射到[0,1]def construct(self, x):x_flat = ops.reshape(x, (-1, img_size * img_size))return self.model(x_flat)net_d = Discriminator()
net_d.update_parameters_name('discriminator')

2.3 损失函数计算

# 损失函数
adversarial_loss = nn.BCELoss(reduction='mean')# 损失及梯度计算函数
# 生成器计算损失过程
def generator_forward(test_noises):fake_data = net_g(test_noises)fake_out = net_d(fake_data)loss_g = adversarial_loss(fake_out, ops.ones_like(fake_out))return loss_g# 判别器计算损失过程
def discriminator_forward(real_data, test_noises):fake_data = net_g(test_noises)fake_out = net_d(fake_data)real_out = net_d(real_data)real_loss = adversarial_loss(real_out, ops.ones_like(real_out))fake_loss = adversarial_loss(fake_out, ops.zeros_like(fake_out))loss_d = real_loss + fake_lossreturn loss_d

在这里插入图片描述
在这里插入图片描述
在这里插入图片描述

2.4 训练和反向传播

def train_step(real_data, latent_code):# 计算判别器损失和梯度# 前向计算 => 得到损失函数和梯度参数# 反向传播 => 使用梯度参数进行权重参数更新loss_d, grads_d = grad_d(real_data, latent_code)optimizer_d(grads_d)loss_g, grads_g = grad_g(latent_code)optimizer_g(grads_g)return loss_d, loss_g

3. 实验结果分析:

  • 写在前面——在正式进行实验前还有一些随机性的探索。

其中值得一提的是,比起直接把256张照片一整个当成一个批次epoch来训练的话,在一个epoch内将整个数据集分成几个batch效果会好得多,下面的所有的实验都是在这种情况下进行的训练。

在这里插入图片描述
在这里插入图片描述

3.0 baseline

  • 以下是使用SGD优化器学习率lr=0.01并且训练100个epoch后的结果。

3.0.1 损失函数:

在这里插入图片描述

3.0.2 last picture:

在这里插入图片描述

3.0.3 gif picture:

在这里插入图片描述

  • 学习率是我们进行超参数调节中非常经常用来调节的一个参数,而lr=0.01是一个很常用的经验值,所以这次我们就i用这个值来作为一个实验的起始的参考值。
  • 从上面的损失函数的趋势可以看出,在一个数值比较小的lr下,损失函数的曲线是相对很平滑的。
  • 从上面的损失函数的曲线我们也可以看到一个健康的GAN网络训练的过程生成器G的损失和判别器D的损失一般是呈现为在某个区间内相互对峙波动发展的过程。
  • 而从上面的结果图来看,现在当前的模型是尚未收敛的状态,需要 “ 去做更多的学习来让自己收敛。
  • 那么怎么往下去学得更多呢?
  • 我们知道学习的过程是一个反向传播的过程,而控制这个过程的一个重要的参数是学习率,也就是说,我们可以考虑让学习率高一些,这样就可以学得更快一些。
  • 从另外一个角度来说我们也可以考虑“学得久一些”,比如增大epoch看看效果会怎么样?
  • 而这就是我们本文所研究的两条调参路线

3.1 epoch不变的情况下提高学习率:

3.1.1 损失函数:

  • SGD优化器100个epoch学习率lr=0.05
    请添加图片描述
  • SGD优化器100个epoch学习率lr=0.10
    请添加图片描述
  • SGD优化器100个epoch学习率lr=0.20
    请添加图片描述

3.1.2 last picture:

  • SGD优化器100个epoch学习率lr=0.05
    请添加图片描述
  • SGD优化器100个epoch学习率lr=0.10
    请添加图片描述
  • SGD优化器100个epoch学习率lr=0.20
    请添加图片描述

3.1.3 gif picture:

  • SGD优化器100个epoch学习率lr=0.05
    请添加图片描述
  • SGD优化器100个epoch学习率lr=0.10
    请添加图片描述
  • SGD优化器100个epoch学习率lr=0.20
    在这里插入图片描述
  • 从上面的部分结果来看的话,在只变动学习率的情况下,对于当前的例子,使用更大的学习率确实能够加速模型的收敛,让生成器最后的效果呈现出一种比较不错的效果,至少整个图片看起来已经是很像一张向日葵的图片。这个是一个不错的进步。
  • 但是依然产生了一些新的问题,比如因为学习率变大,虽然收敛的速度变快了,但是损失函数不是很平滑,充满了各种爆炸的毛刺的气息,这让我想到了过拟合不稳定

3.2 试试增加epoch?:

3.2.1 损失函数:

  • SGD优化器200个epoch学习率lr=0.05
    在这里插入图片描述
  • SGD优化器200个epoch学习率lr=0.10
    在这里插入图片描述
  • SGD优化器200个epoch学习率lr=0.20
    在这里插入图片描述

3.2.2 last picture:

  • SGD优化器200个epoch学习率lr=0.05
    在这里插入图片描述
  • SGD优化器200个epoch学习率lr=0.10
    在这里插入图片描述
  • SGD优化器200个epoch学习率lr=0.20
    在这里插入图片描述

3.2.3 gif picture:

  • SGD优化器200个epoch学习率lr=0.05
    在这里插入图片描述
  • SGD优化器200个epoch学习率lr=0.10
    在这里插入图片描述
  • SGD优化器200个epoch学习率lr=0.20
    在这里插入图片描述
  • 从最后的效果来看,把epoch增多,最后生成的照片的细腻程度远比仅有100个epoch的最后的成片的效果好了很多。由此可见,在学习率合理的情况下,去增大训练的epoch量也确实是能比较不错地提升GAN网络最后生成的图片的效果。
  • 不过也产生了许多新的问题,从上面的这些损失函数可以找到一个共性,那就是在初期的epoch中,生成器G的损失值是在判别器的损失值的之下的,而随着训练的epoch的量足够大之后,在中后期,会出现判别器D的损失值不断下降,而生成器的损失值则开始上升的情况。这其实直接说明了在这些阶段中继续增大epoch可能并不能很好地朝着我们想要的训练出一个效果更好的生成器的方向演变了。
  • 从部分实验结果中我们可以发现:当判别器D的能力相比生成器G更强的时候,G为了能够继续优化,往往就会向模式崩塌的方向走去,它会开始投机取巧,使得最后生成出来的图片会普遍有某种类似,在个性上就不够有好效果了。我们称其为泛化能力不够。
  • 这里我以我训练了500个epoch的一些过程性的截图来展示:
  • SGD优化器1个epoch学习率lr=0.25
    在这里插入图片描述
  • SGD优化器50个epoch学习率lr=0.25
    在这里插入图片描述
  • SGD优化器100个epoch学习率lr=0.25
    在这里插入图片描述
  • SGD优化器150个epoch学习率lr=0.25
    在这里插入图片描述
  • SGD优化器200个epoch学习率lr=0.25
    在这里插入图片描述
  • SGD优化器250个epoch学习率lr=0.25
    在这里插入图片描述
  • SGD优化器300个epoch学习率lr=0.25
    在这里插入图片描述
  • SGD优化器350个epoch学习率lr=0.25
    在这里插入图片描述
  • SGD优化器400个epoch学习率lr=0.25
    在这里插入图片描述
  • SGD优化器450个epoch学习率lr=0.25
    在这里插入图片描述
  • SGD优化器500个epoch学习率lr=0.25
    在这里插入图片描述
  • 特别指出这个例子的原因是我发现epoch增大越到后期,生成出来的向日葵就基本都是怼脸向日葵居多,而前面还能看到的苗条向日葵,则其实基本偏少了,更不用说其他更有特性的向日葵了。
  • 当我返回去看这256张向日葵的数据集的时候,我发现其实原始的相册中,其实居多的也主要是怼脸向日葵,其次是苗条向日葵,最后是一些零散的各类较有个性的向日葵。
  • 尤次可见,最后的最后,我们导向的结果依然是最后影响一个模型的质量的,还是回到了训练这个模型的数据集的质量。高质量的数据处理对模型的训练是非常非常非常重要的!
  • 数据集照片情况概览:
    在这里插入图片描述
    在这里插入图片描述

4. 目前比较不错的效果展示

  • 以下是使用SGD优化器,学习率为0.25,训练了500个epoch的一个演变效果。
    在这里插入图片描述

5. 一些其它问题和小小的总结

  • 总得来说经过本次实验的探究,其实我所在对抗的主要是两个问题
    • "生成的图片不像我的目的图像"的问题。(欠拟合,未收敛)
    • ”生成的图片大多长得类似,或者甚至一模一样!“(过拟合,模式崩塌)
  • 结合做了以上那么多的实验来看,我现在对GAN网络的两个模型的损失函数的理解是正常的情况G和D应该是两条有波动,但整体上是对峙者推进的一上一下的趋势,其中最好是G在下,而D在上。这样的状态持续得越多个epoch,最终我们得到的生成器的综合效果就会越佳,而一旦打破了这个平衡,生成器的质量就会往某一个方向偏移,一般是模式崩塌即判别器不断在进化,使得判别器太强,而生成器只能通过投机取巧的方式来精学某一类来保持它能继续保持能骗过生成器。所以如何达到平衡是一个值得深入研究的方向。

参考资料

  • [1] GOODFELLOW I, POUGET-ABADIE J, MIRZA M, et al. Generative Adversarial Nets[J/OL]. Journal of Japan Society for Fuzzy Theory and Intelligent Informatics, 2017: 177-177. http://dx.doi.org/10.3156/jsoft.29.5_177_2. DOI:10.3156/jsoft.29.5_177_2.
  • GAN图像生成-mindspore

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.rhkb.cn/news/92221.html

如若内容造成侵权/违法违规/事实不符,请联系长河编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

Flink CDC系列之:TiDB CDC 导入 Elasticsearch

Flink CDC系列之:TiDB CDC 导入 Elasticsearch 一、通过docker 来启动 TiDB 集群二、下载 Flink 和所需要的依赖包三、在TiDB数据库中创建表和准备数据四、启动Flink 集群,再启动 SQL CLI五、在 Flink SQL CLI 中使用 Flink DDL 创建表六、Kibana查看Ela…

H3C QoS打标签和限速配置案例

EF:快速转发 AF:确保转发 CS:给各种协议用的 BE:默认标记(尽力而为) VSR-88-2 出口路由配置: [H3C]dis current-configuration version 7.1.075, ESS 8305 vlan 1 traffic classifier vlan10 operator and if-match a…

关于consul的下载方法

linux下 sudo yum install -y yum-utils sudo yum-config-manager --add-repo https://rpm.releases.hashicorp.com/RHEL/hashicorp.repo sudo yum -y install consulwindow下 https://developer.hashicorp.com/consul/downloads 然后把里面的exe文件放在gopath下就行了 验证…

苍穹外卖day11笔记

今日首先介绍前端技术Apache ECharts,说明后端需要准备的数据,然后讲解具体统计功能的实现,包括营业额统计、用户统计、订单统计、销量排名。 一、ECharts 是什么 ECharts是一款基于 Javascript 的数据可视化图表库。我们用它来展示图表数…

说一下什么是tcp的2MSL,为什么客户端在 TIME-WAIT 状态必须等待 2MSL 的时间?

1.TCP之2MSL 1.1 MSL MSL:Maximum Segment Lifetime报文段最大生存时间,它是任何报文段被丢弃前在网络内的最长时间 1.2为什么存在MSL TCP报文段以IP数据报在网络内传输,而IP数据报则有限制其生存时间的TTL字段,并且TTL的限制是基于跳数 1.3…

ceph相关概念和部署

Ceph 可用于向云提供 Ceph 对象存储 平台和 Ceph 可用于提供 Ceph 块设备服务 到云平台。Ceph 可用于部署 Ceph 文件 系统。所有 Ceph 存储集群部署都从设置 每个 Ceph 节点,然后设置网络。 Ceph 存储集群需要满足以下条件:至少一个 Ceph 监控器&#x…

继承和多态C++

这里写目录标题 继承public、protected、private 修饰类的成员public、protected、private 指定继承方式改变访问权限 C继承时的名字遮蔽问题基类成员函数和派生类成员函数不构成重载C基类和派生类的构造函数构造函数的调用顺序基类构造函数调用规则 C基类和派生类的析构函数C多…

Android app专项测试之耗电量测试

前言 耗电量指标 待机时间成关注目标 提升用户体验 通过不同的测试场景,找出app高耗电的场景并解决 01、需要的环境准备 1、python2.7(必须是2.7,3.X版本是不支持的) 2、golang语言的开发环境 3、Android SDK 此三个的环境搭建这里就不详细说了&am…

无涯教程-Perl - send函数

描述 此函数在SOCKET上发送消息(与recv相反)。如果Socket未连接,则必须提供一个目标以与TO参数进行通信。在这种情况下,将使用sendto系统功能代替系统发送功能。 FLAGS参数由按位或0以及MSG_OOB和MSG_DONTROUTEoptions中的一个或多个形成。 MSG_OOB允许您在支持此概念的Socke…

炬芯科技发布全新第二代智能手表芯片,引领腕上新趋势!

2023年7月,炬芯科技宣布全新第二代智能手表芯片正式发布。自2021年底炬芯科技推出第一代的智能手表芯片开始便快速获得了市场广泛认可和品牌客户的普遍好评。随着技术的不断创新和突破,为了更加精准地满足市场多元化的变幻和用户日益增长的体验需求&…

C语言入门 Day_5 四则运算

目录 前言 1.四则运算 2.其他运算 3.易错点 4.思维导图 前言 图为世界上第一台通用计算机ENIAC,于1946年2月14日在美国宾夕法尼亚大学诞生。发明人是美国人莫克利(JohnW.Mauchly)和艾克特(J.PresperEckert)。 计算机的最开始…

轻量级 Spring Task 任务调度可视化管理

Spring Task/Spring Scheduler 傻傻分不清 首先做一下“名词解释”,分清楚这两者的区别: Spring Task Spring Task 是 Spring 框架自带的一个任务调度模块,提供了基本的任务调度功能。它是通过 Java 的 Timer 和 TimerTask 类来实现的&…

ReactDOM模块react-dom/client没有默认导出报错解决办法

import ReactDOM 模块“"E:/Dpandata/Shbank/rt-pro/node_modules/.pnpm/registry.npmmirror.comtypesreact-dom18.2.7/node_modules/types/react-dom/client"”没有默认导出。 解决办法 只需要在tsconfig.json里面添加配置 "esModuleInterop": true 即…

数据结构介绍

1、什么是数据结构呢? 计算机底层存储、组织数据的方式。是指数据相互之间是以什么方式排列在一起的。数据结构是为了更方便的管理和使用数据,需要结合具体的业务来进行选择。一般情况下,精心选择的数据结构可以带来更高的运行或者存储效率。…

6.利用matlab完成 符号矩阵的秩和 符号方阵的逆矩阵和行列式 (matlab程序)

1.简述 利用M文件建立矩阵 对于比较大且比较复杂的矩阵,可以为它专门建立一个M文件。下面通过一个简单例子来说明如何利用M文件创建矩阵。 例2-2 利用M文件建立MYMAT矩阵。(1) 启动有关编辑程序或MATLAB文本编辑器,并输入待建矩阵:(2) 把…

【2023年11月第四版教材】《第5章-信息系统工程之软件工程(第一部分)》

《第5章-信息系统工程(第一部分)》 章节说明1 软件工程1.1 架构设计1.2 需求分析 章节说明 65%为新增内容, 预计选择题考5分,案例和论文不考;本章与第三版教材一样的内容以楷体字进行标注! 1 软件工程 1.1 架构设计 1、软件架…

Linux文件权限一共10位长度,分成四段

Linux文件权限一共10位长度,分成四段 Linux文件权限 1、 文件aaa的访问权限为rw-r--r--,现要增加所有用户的执行权限和同组用户的写权限,下列哪些命令是正确的? a) chmod ax gw aaa √ b) chmod 764 aaa c) chmod 775 aaa √ d)…

vue3+vite+pinia

目录 一、项目准备 1.1、Vite搭建项目 1.2、vue_cli创建项目 二、组合式API(基于setup) 2.1、ref 2.2、reactive 2.3、toRefs 2.4、watch和watchEffect 2.5、computed 2.6、生命周期钩子函数 2.7、setup(子组件)的第一个参数-props 2.8、setup(子组件)的第二个参数…

贴吧照片和酷狗音乐简单爬取

爬取的基本步骤 很简单,主要是两大步 向url发起请求 这里注意找准对应资源的url,如果对应资源不让程序代码访问,这里可以伪装成浏览器发起请求。 解析上一步返回的源代码,从中提取想要的资源 这里解析看具体情况,一…

什么是前端框架?怎么学习? - 易智编译EaseEditing

前端框架是一种用于开发Web应用程序界面的工具集合,它提供了一系列预定义的代码和结构,以简化开发过程并提高效率。 前端框架通常包括HTML、CSS和JavaScript的库和工具,用于构建交互式、动态和响应式的用户界面。 学习前端框架可以让您更高效…