文章目录
- 昇思MindSpore应用实践
- 基于MindSpore的生成对抗网络图像生成
- 1、生成对抗网络简介
- 零和博弈 vs 极大极小博弈
- GAN的生成对抗损失:
- 2、基于MindSpore的 Vanilla GAN
- 3、基于MindSpore的手写数字图像生成
- 导入数据
- 数据可视化
- 模型训练
- Reference
昇思MindSpore应用实践
本系列文章主要用于记录昇思25天学习打卡营的学习心得。
基于MindSpore的生成对抗网络图像生成
1、生成对抗网络简介
零和博弈 vs 极大极小博弈
生成对抗网络Generative adversarial networks (GANs)主要包括生成器网络(Generator)和判别器网络(Discriminator)
这两个网络在GAN的训练过程中相互竞争,形成了一种博弈论中的极大极小博弈(MinMax game)
零和博弈(Zero-sum game)是博弈论中的一个重要概念,指的是参与者的利益完全相反,即一方的利益的增加意味着另一方的利益的减少,总利益为零。在零和博弈中,参与者之间的利益是完全对立的,因此一个参与者的利益的增加必然导致其他参与者的利益减少。在非合作博弈中,纳什均衡是一种重要的解,纳什均衡代表每个玩家选择的策略都是其在对方策略给定的情况下的最优策略。在零和博弈中,寻找纳什均衡通常涉及找到使每个玩家的预期收益最大化的策略组合。
极大极小博弈(MinMax game)是一种博弈论中的解决方法,用于确定参与者的最佳决策策略,此外为人所熟知用于决策的方法还有强化学习。在极大极小博弈中,每个参与者都试图最大化自己的最小收益。也就是说,每个参与者都采取行动,以确保在对手选择其最优策略时自己的收益最大化。
假设GAN网络训练达到了纳什平衡状态,那么判别器无法准确地判断出输入样本是真样本还是假样本,此时判别器失效,生成器达到了巅峰状态,我们就无需使用判别器并终止训练了,得到的生成器就是我们用来生成数据的预训练模型。
从理论上讲,此博弈游戏的平衡点是 p G ( x ; θ ) = p d a t a ( x ) p_{G}(x;\theta) = p_{data}(x) pG(x;θ)=pdata(x),此时判别器会随机猜测输入是真图像还是假图像。下面我们简要说明生成器和判别器的博弈过程:
- 在训练刚开始的时候,生成器和判别器的质量都比较差,生成器会随机生成一个数据分布。
- 判别器通过求取梯度和损失函数对网络进行优化,将靠近真实数据分布的数据判定为1,将靠近生成器生成出来数据分布的数据判定为0。
- 生成器通过优化,生成出更加贴近真实数据分布的数据。
- 生成器所生成的数据和真实数据达到相同的分布,此时判别器的输出为1/2,如上图中的(d)所示。
GAN的生成对抗损失:
min G max D V ( G , D ) = E x ∼ p data ( x ) [ log D ( x ) ] + E z ∼ p z ( z ) [ log ( 1 − D ( G ( z ) ) ) ] \underset{G}{\min} \underset{D}{\max}V(G, D) = \mathbb{E}_{x \sim p{\text{data}}(x)}[\log D(x)] + \mathbb{E}_{z \sim p_z(z)}[\log(1 - D(G(z)))] GminDmaxV(G,D)=Ex∼pdata(x)[logD(x)]+Ez∼pz(z)[log(1−D(G(z)))]
GAN网络本身就是在训练一个能达到平衡状态的损失函数,生成对抗损失是GANs中最基本的损失函数。
近十年来著名的GAN网络结构:
2、基于MindSpore的 Vanilla GAN
生成器部分:生成器 Generator
的功能是将隐码映射到数据空间。通过五层 Dense
全连接层来完成的,每层都与 BatchNorm1d
批归一化层和 ReLU
激活层配对,输出数据会经过 Tanh
函数,使其返回 [-1,1] 的数据范围内,并返回一张28x28的图像作为生成结果。
from mindspore import nn
import mindspore.ops as opsimg_size = 28 # 训练图像长(宽)28x28class Generator(nn.Cell):def __init__(self, latent_size, auto_prefix=True):super(Generator, self).__init__(auto_prefix=auto_prefix)self.model = nn.SequentialCell()# [N, 100] -> [N, 128]# 输入一个100维的0~1之间的高斯分布,通过第一层线性变换将其映射到128维self.model.append(nn.Dense(latent_size, 128))self.model.append(nn.ReLU())# 通过第二层线性变换将其映射到256维# [N, 128] -> [N, 256]self.model.append(nn.Dense(128, 256))self.model.append(nn.BatchNorm1d(256))self.model.append(nn.ReLU())# [N, 256] -> [N, 512]self.model.append(nn.Dense(256, 512))self.model.append(nn.BatchNorm1d(512))self.model.append(nn.ReLU())# [N, 512] -> [N, 1024]self.model.append(nn.Dense(512, 1024))self.model.append(nn.BatchNorm1d(1024))self.model.append(nn.ReLU())# [N, 1024] -> [N, 784]# 经过线性变换将其变成784维self.model.append(nn.Dense(1024, img_size * img_size))# 经过Tanh激活函数是希望生成的假的图片数据分布能够在-1~1之间self.model.append(nn.Tanh())def construct(self, x):img = self.model(x)return ops.reshape(img, (-1, 1, 28, 28))net_g = Generator(latent_size)
net_g.update_parameters_name('generator')
判别器部分:判别器 Discriminator
是一个二分类网络模型,在训练时,判别器接收生成器的生成图像与对应的真实数据相对比,输出判定该图像为真实图的概率。主要通过一系列的 Dense
层和 LeakyReLU
层对其进行处理,最后通过 Sigmoid
激活函数,使其返回 [0, 1] 的数据范围内,得到最终概率。
# 判别器
class Discriminator(nn.Cell):def __init__(self, auto_prefix=True):super().__init__(auto_prefix=auto_prefix)self.model = nn.SequentialCell()# [N, 784] -> [N, 512]self.model.append(nn.Dense(img_size * img_size, 512)) # 输入特征数为784,输出为512self.model.append(nn.LeakyReLU()) # 默认斜率为0.2的非线性映射激活函数# [N, 512] -> [N, 256]self.model.append(nn.Dense(512, 256)) # 进行一个线性映射self.model.append(nn.LeakyReLU())# [N, 256] -> [N, 1]self.model.append(nn.Dense(256, 1))self.model.append(nn.Sigmoid()) # 二分类激活函数,将实数映射到[0,1]def construct(self, x):x_flat = ops.reshape(x, (-1, img_size * img_size))return self.model(x_flat)net_d = Discriminator()
net_d.update_parameters_name('discriminator')
3、基于MindSpore的手写数字图像生成
导入数据
import numpy as np
import mindspore.dataset as dsbatch_size = 128
latent_size = 100 # 潜在编码的长度train_dataset = ds.MnistDataset(dataset_dir='./MNIST_Data/train')
test_dataset = ds.MnistDataset(dataset_dir='./MNIST_Data/test')def data_load(dataset):dataset1 = ds.GeneratorDataset(dataset, ["image", "label"], shuffle=True, python_multiprocessing=False)# 数据增强mnist_ds = dataset1.map( # 通过map方法给每张图像映射一个潜在编码# 将图像数据转换为 float32 类型# 生成一个长度为 latent_size 的服从正态分布的随机向量,并将其转换为 float32 类型operations=lambda x: (x.astype("float32"), np.random.normal(size=latent_size).astype("float32")),output_columns=["image", "latent_code"])mnist_ds = mnist_ds.project(["image", "latent_code"])# 批量操作mnist_ds = mnist_ds.batch(batch_size, True)return mnist_dsmnist_ds = data_load(train_dataset)iter_size = mnist_ds.get_dataset_size()
print('Iter size: %d' % iter_size)
数据可视化
import matplotlib.pyplot as pltdata_iter = next(mnist_ds.create_dict_iterator(output_numpy=True))
figure = plt.figure(figsize=(3, 3))
cols, rows = 5, 5
for idx in range(1, cols * rows + 1):image = data_iter['image'][idx]figure.add_subplot(rows, cols, idx)plt.axis("off")plt.imshow(image.squeeze(), cmap="gray")
plt.show()
潜在编码(latent code)的构造:
为了跟踪生成器的学习进度,我们在训练的过程中的每轮迭代结束后,将一组固定的遵循高斯分布的隐码test_noise
输入到生成器中,通过这组固定的潜在编码(也叫隐码)所生成的图像效果来评估生成器的生成质量。
import random
import numpy as np
from mindspore import Tensor
from mindspore.common import dtype# 利用随机种子创建一批隐码
np.random.seed(2323)
test_noise = Tensor(np.random.normal(size=(25, 100)), dtype.float32)
random.shuffle(test_noise)
模型训练
定义损失函数和优化器:
lr = 0.0002 # 学习率# 损失函数
adversarial_loss = nn.BCELoss(reduction='mean')# 优化器
optimizer_d = nn.Adam(net_d.trainable_params(), learning_rate=lr, beta1=0.5, beta2=0.999)
optimizer_g = nn.Adam(net_g.trainable_params(), learning_rate=lr, beta1=0.5, beta2=0.999)
optimizer_g.update_parameters_name('optim_g')
optimizer_d.update_parameters_name('optim_d')
训练分为两个主要部分,也就是要训练两个网络:生成与对抗网络。
第一部分是训练判别器。训练判别器的目的是最大程度地提高判别图像真伪的概率。按照原论文的方法,通过提高其随机梯度来更新判别器,最大化 l o g D ( x ) + l o g ( 1 − D ( G ( z ) ) log D(x) + log(1 - D(G(z)) logD(x)+log(1−D(G(z)) 的值。
第二部分是训练生成器。如论文所述,最小化 l o g ( 1 − D ( G ( z ) ) ) log(1 - D(G(z))) log(1−D(G(z))) 来训练生成器,以产生更好的虚假图像。
在这两个部分中,分别获取训练过程中的损失,并在每轮迭代结束时进行测试,将固定隐码批量推送到生成器中,以直观地跟踪生成器 Generator
的训练效果。
import os
import time
import matplotlib.pyplot as plt
import mindspore as ms
from mindspore import Tensor, save_checkpointtotal_epoch = 24 # 训练周期数
batch_size = 64 # 用于训练的训练集批量大小# 加载预训练模型的参数
pred_trained = False
pred_trained_g = './result/checkpoints/Generator99.ckpt'
pred_trained_d = './result/checkpoints/Discriminator99.ckpt'checkpoints_path = "./result/checkpoints" # 结果保存路径
image_path = "./result/images" # 测试结果保存路径# 生成器计算损失过程
def generator_forward(test_noises):fake_data = net_g(test_noises)fake_out = net_d(fake_data)loss_g = adversarial_loss(fake_out, ops.ones_like(fake_out))return loss_g# 判别器计算损失过程
def discriminator_forward(real_data, test_noises):fake_data = net_g(test_noises)fake_out = net_d(fake_data)real_out = net_d(real_data)real_loss = adversarial_loss(real_out, ops.ones_like(real_out))fake_loss = adversarial_loss(fake_out, ops.zeros_like(fake_out))loss_d = real_loss + fake_lossreturn loss_d# 梯度方法
grad_g = ms.value_and_grad(generator_forward, None, net_g.trainable_params())
grad_d = ms.value_and_grad(discriminator_forward, None, net_d.trainable_params())def train_step(real_data, latent_code):# 计算判别器损失和梯度loss_d, grads_d = grad_d(real_data, latent_code)optimizer_d(grads_d)loss_g, grads_g = grad_g(latent_code)optimizer_g(grads_g)return loss_d, loss_g# 保存生成的test图像
def save_imgs(gen_imgs1, idx):for i3 in range(gen_imgs1.shape[0]):plt.subplot(5, 5, i3 + 1)plt.imshow(gen_imgs1[i3, 0, :, :] / 2 + 0.5, cmap="gray")plt.axis("off")plt.savefig(image_path + "/test_{}.png".format(idx))# 设置参数保存路径
os.makedirs(checkpoints_path, exist_ok=True)
# 设置中间过程生成图片保存路径
os.makedirs(image_path, exist_ok=True)net_g.set_train()
net_d.set_train()# 储存生成器和判别器loss
losses_g, losses_d = [], []for epoch in range(total_epoch):start = time.time()for (iter, data) in enumerate(mnist_ds):start1 = time.time()image, latent_code = dataimage = (image - 127.5) / 127.5 # [0, 255] -> [-1, 1]image = image.reshape(image.shape[0], 1, image.shape[1], image.shape[2])d_loss, g_loss = train_step(image, latent_code)end1 = time.time()if iter % 10 == 0:print(f"Epoch:[{int(epoch):>3d}/{int(total_epoch):>3d}], "f"step:[{int(iter):>4d}/{int(iter_size):>4d}], "f"loss_d:{d_loss.asnumpy():>4f} , "f"loss_g:{g_loss.asnumpy():>4f} , "f"time:{(end1 - start1):>3f}s, "f"lr:{lr:>6f}")end = time.time()print("time of epoch {} is {:.2f}s".format(epoch + 1, end - start))losses_d.append(d_loss.asnumpy())losses_g.append(g_loss.asnumpy())# 每个epoch结束后,使用生成器生成一组图片gen_imgs = net_g(test_noise)save_imgs(gen_imgs.asnumpy(), epoch)# 根据epoch保存模型权重文件if epoch % 1 == 0:save_checkpoint(net_g, checkpoints_path + "/Generator%d.ckpt" % (epoch))save_checkpoint(net_d, checkpoints_path + "/Discriminator%d.ckpt" % (epoch))import time
print(time.strftime("%Y-%m-%d %H:%M:%S", time.localtime()),'Wayn_Fan-sail')
使用cpu进行12个epoch的生成效果如下:
Reference
昇思官方文档-GAN图像生成
昇思大模型平台