昇思MindSpore 应用学习-GAN图像生成-CSDN

模型简介

生成式对抗网络(Generative Adversarial Networks,GAN)是一种生成式机器学习模型,是近年来复杂分布上无监督学习最具前景的方法之一。
最初,GAN由Ian J. Goodfellow于2014年发明,并在论文Generative Adversarial Nets中首次进行了描述,其主要由两个不同的模型共同组成——生成器(Generative Model)和判别器(Discriminative Model):

  • 生成器的任务是生成看起来像训练图像的“假”图像;
  • 判别器需要判断从生成器输出的图像是真实的训练图像还是虚假的图像。

GAN通过设计生成模型和判别模型这两个模块,使其互相博弈学习产生了相当好的输出。
GAN模型的核心在于提出了通过对抗过程来估计生成模型这一全新框架。在这个框架中,将会同时训练两个模型——捕捉数据分布的生成模型 和估计样本是否来自训练数据的判别模型 。
在训练过程中,生成器会不断尝试通过生成更好的假图像来骗过判别器,而判别器在这过程中也会逐步提升判别能力。这种博弈的平衡点是,当生成器生成的假图像和训练数据图像的分布完全一致时,判别器拥有50%的真假判断置信度。
用 代表图像数据,用 表示判别器网络给出图像判定为真实图像的概率。在判别过程中, 需要处理作为二进制文件的大小为 的图像数据。当 来自训练数据时, 数值应该趋近于 ;而当 来自生成器时, 数值应该趋近于 。因此 也可以被认为是传统的二分类器。
用 代表标准正态分布中提取出的隐码(隐向量),用 :表示将隐码(隐向量) 映射到数据空间的生成器函数。函数 的目标是将服从高斯分布的随机噪声 通过生成网络变换为近似于真实分布 的数据分布,我们希望找到 使得 和 尽可能的接近,其中 代表网络参数。
表示生成器 生成的假图像被判定为真实图像的概率,如Generative Adversarial Nets中所述, 和 在进行一场博弈, 想要最大程度的正确分类真图像与假图像,也就是参数 ;而 试图欺骗 来最小化假图像被识别到的概率,也就是参数 。因此GAN的损失函数为:

从理论上讲,此博弈游戏的平衡点是,此时判别器会随机猜测输入是真图像还是假图像。下面我们简要说明生成器和判别器的博弈过程:

  1. 在训练刚开始的时候,生成器和判别器的质量都比较差,生成器会随机生成一个数据分布。
  2. 判别器通过求取梯度和损失函数对网络进行优化,将靠近真实数据分布的数据判定为1,将靠近生成器生成出来数据分布的数据判定为0。
  3. 生成器通过优化,生成出更加贴近真实数据分布的数据。
  4. 生成器所生成的数据和真实数据达到相同的分布,此时判别器的输出为1/2。


在上图中,蓝色虚线表示判别器,黑色虚线表示真实数据分布,绿色实线表示生成器生成的虚假数据分布, 表示隐码, 表示生成的虚假图像 。该图片来源于Generative Adversarial Nets。详细的训练方法介绍见原论文。

数据集

数据集简介

MNIST手写数字数据集是NIST数据集的子集,共有70000张手写数字图片,包含60000张训练样本和10000张测试样本,数字图片为二进制文件,图片大小为28*28,单通道。图片已经预先进行了尺寸归一化和中心化处理。
本案例将使用MNIST手写数字数据集来训练一个生成式对抗网络,使用该网络模拟生成手写数字图片。

数据集下载

使用download接口下载数据集,并将下载后的数据集自动解压到当前目录下。数据下载之前需要使用pip install download安装download包。
下载解压后的数据集目录结构如下:

./MNIST_Data/
├─ train
│ ├─ train-images-idx3-ubyte
│ └─ train-labels-idx1-ubyte
└─ test├─ t10k-images-idx3-ubyte└─ t10k-labels-idx1-ubyte

数据下载的代码如下:

# 数据下载
from download import download  # 导入下载模块url = "https://mindspore-website.obs.cn-north-4.myhuaweicloud.com/notebook/datasets/MNIST_Data.zip"  # 定义数据集的下载链接
download(url, ".", kind="zip", replace=True)  # 下载数据集到当前目录,指定文件类型为zip,且如果存在同名文件则进行替换

解析:

  1. from download import download:从download模块导入download函数,用于执行文件下载操作。
  2. url = "...":定义一个字符串变量url,存储需要下载的文件的网络地址。
  3. download(url, ".", kind="zip", replace=True)
    • url:传入待下载的文件链接。
    • ".":指定下载到当前工作目录。
    • kind="zip":指定下载文件的类型为zip压缩包。
    • replace=True:如果目标目录中已存在同名文件,则替换该文件。

数据加载

使用MindSpore自己的MnistDatase接口,读取和解析MNIST数据集的源文件构建数据集。然后对数据进行一些前处理。

import numpy as np  # 导入NumPy库,用于数值计算
import mindspore.dataset as ds  # 导入MindSpore的数据集模块batch_size = 128  # 定义批量大小
latent_size = 100  # 隐码的长度# 加载MNIST训练集和测试集
train_dataset = ds.MnistDataset(dataset_dir='./MNIST_Data/train')  # 加载训练数据集
test_dataset = ds.MnistDataset(dataset_dir='./MNIST_Data/test')  # 加载测试数据集def data_load(dataset):# 创建生成数据集,指定数据项和是否打乱dataset1 = ds.GeneratorDataset(dataset, ["image", "label"], shuffle=True, python_multiprocessing=False)# 数据增强mnist_ds = dataset1.map(operations=lambda x: (x.astype("float32"), np.random.normal(size=latent_size).astype("float32")),output_columns=["image", "latent_code"])  # 对每个样本进行变换,生成浮点型图像和随机隐码mnist_ds = mnist_ds.project(["image", "latent_code"])  # 只保留图像和隐码两列# 批量操作mnist_ds = mnist_ds.batch(batch_size, True)  # 将数据集分批,batch_size为每批样本大小return mnist_ds  # 返回处理后的数据集# 加载训练集数据
mnist_ds = data_load(train_dataset)# 获取数据集的迭代器大小
iter_size = mnist_ds.get_dataset_size()  # 获取数据集的总批次数
print('Iter size: %d' % iter_size)  # 打印迭代器大小

解析:

  1. import numpy as np:导入NumPy库,通常用于高效的数值计算和数组操作。
  2. import mindspore.dataset as ds:导入MindSpore的dataset模块,用于处理数据集。
  3. batch_sizelatent_size:定义批量大小和隐码的长度,用于后续的数据处理。
  4. ds.MnistDataset(...):加载MNIST数据集的训练和测试部分。
  5. def data_load(dataset)::定义一个函数data_load,用于对输入的数据集进行处理。
  6. ds.GeneratorDataset(...):创建生成器数据集,指定数据项(图像和标签),设置是否打乱数据。
  7. dataset1.map(...):对每个样本进行映射操作,使用匿名函数(lambda)生成浮点型的图像和随机隐码。
  8. mnist_ds.project(...):选择需要保留的数据列,过滤掉不需要的列。
  9. mnist_ds.batch(...):对数据集进行批量处理,设置每批的数据量为batch_size
  10. mnist_ds.get_dataset_size():获取数据集的总批次数,即迭代器的大小。

数据集可视化

通过create_dict_iterator函数将数据转换成字典迭代器,然后使用matplotlib模块可视化部分训练数据。

import matplotlib.pyplot as plt  # 导入Matplotlib库用于数据可视化# 创建一个数据迭代器,并获取下一个数据批次
data_iter = next(mnist_ds.create_dict_iterator(output_numpy=True))# 设置图形的大小
figure = plt.figure(figsize=(3, 3))cols, rows = 5, 5  # 定义网格的列数和行数
for idx in range(1, cols * rows + 1):  # 循环绘制5x5的图像image = data_iter['image'][idx]  # 从数据迭代器中获取第idx个图像figure.add_subplot(rows, cols, idx)  # 在图形中添加子图plt.axis("off")  # 关闭坐标轴plt.imshow(image.squeeze(), cmap="gray")  # 显示图像,使用灰度色图plt.show()  # 显示图形

解析:

  1. import matplotlib.pyplot as plt:导入Matplotlib库的pyplot模块,用于绘图和数据可视化。
  2. data_iter = next(mnist_ds.create_dict_iterator(output_numpy=True)):创建一个字典形式的数据迭代器,并获取下一个批次的数据,设置output_numpy=True使输出为NumPy数组。
  3. figure = plt.figure(figsize=(3, 3)):创建一个新的图形,指定图形的大小为3x3英寸。
  4. cols, rows = 5, 5:定义绘图网格的列数和行数,这里设置为5列5行,共25个子图。
  5. for idx in range(1, cols * rows + 1)::循环遍历,从1到25,绘制25个图像。
  6. image = data_iter['image'][idx]:从数据迭代器中提取第idx个图像。
  7. figure.add_subplot(rows, cols, idx):在图形中添加一个子图,位置由idx指定。
  8. plt.axis("off"):关闭当前子图的坐标轴显示。
  9. plt.imshow(image.squeeze(), cmap="gray"):显示图像,使用squeeze()去掉单维度,设置色图为灰度(gray)。
  10. plt.show():渲染并显示图形。

image.png

隐码构造

为了跟踪生成器的学习进度,我们在训练的过程中的每轮迭代结束后,将一组固定的遵循高斯分布的隐码test_noise输入到生成器中,通过固定隐码所生成的图像效果来评估生成器的好坏。

import random  # 导入随机库,用于随机数生成和洗牌
import numpy as np  # 导入NumPy库,用于数值计算
from mindspore import Tensor  # 从MindSpore中导入Tensor类
from mindspore.common import dtype  # 导入数据类型模块# 利用随机种子创建一批隐码
np.random.seed(2323)  # 设置NumPy的随机种子,以便于结果可复现
test_noise = Tensor(np.random.normal(size=(25, 100)), dtype.float32)  # 生成一个形状为(25, 100)的正态分布隐码,并转换为MindSpore Tensor类型
random.shuffle(test_noise)  # 打乱生成的隐码顺序

解析:

  1. import random:导入Python的随机库,用于生成随机数和进行随机操作。
  2. import numpy as np:导入NumPy库,用于高效的数值计算和数组操作。
  3. from mindspore import Tensor:从MindSpore库导入Tensor类,用于创建张量对象。
  4. from mindspore.common import dtype:导入MindSpore中的数据类型模块,以便于后续指定张量的数据类型。
  5. np.random.seed(2323):设置NumPy的随机种子为2323,这样可以确保每次运行代码时生成相同的随机数。
  6. test_noise = Tensor(np.random.normal(size=(25, 100)), dtype.float32)
    • np.random.normal(size=(25, 100)):生成一个形状为(25, 100)的正态分布随机数数组。
    • Tensor(...):将生成的NumPy数组转换为MindSpore的Tensor对象,并指定数据类型为float32
  7. random.shuffle(test_noise):打乱test_noise数组中的元素顺序(注意:shuffle()方法通常用于列表,不适用于NumPy数组。如果需要打乱NumPy数组,应使用np.random.shuffle())。

模型构建

本案例实现中所搭建的 GAN 模型结构与原论文中提出的 GAN 结构大致相同,但由于所用数据集 MNIST 为单通道小尺寸图片,可识别参数少,便于训练,我们在判别器和生成器中采用全连接网络架构和 ReLU 激活函数即可达到令人满意的效果,且省略了原论文中用于减少参数的 Dropout 策略和可学习激活函数 Maxout

生成器

生成器 Generator 的功能是将隐码映射到数据空间。由于数据是图像,这一过程也会创建与真实图像大小相同的灰度图像(或 RGB 彩色图像)。在本案例演示中,该功能通过五层 Dense 全连接层来完成的,每层都与 BatchNorm1d 批归一化层和 ReLU 激活层配对,输出数据会经过 Tanh 函数,使其返回 [-1,1] 的数据范围内。注意实例化生成器之后需要修改参数的名称,不然静态图模式下会报错。

from mindspore import nn  # 导入MindSpore的神经网络模块
import mindspore.ops as ops  # 导入MindSpore的操作模块img_size = 28  # 训练图像的长(宽)class Generator(nn.Cell):  # 定义生成器类,继承自nn.Celldef __init__(self, latent_size, auto_prefix=True):  # 初始化方法super(Generator, self).__init__(auto_prefix=auto_prefix)  # 调用父类构造函数self.model = nn.SequentialCell()  # 创建一个顺序模型容器# [N, 100] -> [N, 128]# 输入100维的高斯分布,通过线性层映射到128维self.model.append(nn.Dense(latent_size, 128))  # 添加第一层全连接层self.model.append(nn.ReLU())  # 添加ReLU激活函数# [N, 128] -> [N, 256]self.model.append(nn.Dense(128, 256))  # 添加第二层全连接层self.model.append(nn.BatchNorm1d(256))  # 添加批量归一化层self.model.append(nn.ReLU())  # 添加ReLU激活函数# [N, 256] -> [N, 512]self.model.append(nn.Dense(256, 512))  # 添加第三层全连接层self.model.append(nn.BatchNorm1d(512))  # 添加批量归一化层self.model.append(nn.ReLU())  # 添加ReLU激活函数# [N, 512] -> [N, 1024]self.model.append(nn.Dense(512, 1024))  # 添加第四层全连接层self.model.append(nn.BatchNorm1d(1024))  # 添加批量归一化层self.model.append(nn.ReLU())  # 添加ReLU激活函数# [N, 1024] -> [N, 784]# 线性变换将输出维度变为784self.model.append(nn.Dense(1024, img_size * img_size))  # 添加第五层全连接层# 经过Tanh激活函数将输出范围压缩到[-1, 1]self.model.append(nn.Tanh())  # 添加Tanh激活函数def construct(self, x):  # 定义前向传播方法img = self.model(x)  # 通过模型获取生成的图像return ops.reshape(img, (-1, 1, 28, 28))  # 将输出重塑为形状 (-1, 1, 28, 28)# 创建生成器实例,latent_size为隐码的维度
net_g = Generator(latent_size)  
net_g.update_parameters_name('generator')  # 更新网络参数名称为'generator'

解析:

  1. from mindspore import nn:导入MindSpore的神经网络模块(nn),用于构建深度学习模型。
  2. import mindspore.ops as ops:导入MindSpore的操作模块(ops),用于执行各种操作。
  3. img_size = 28:定义训练图像的长和宽为28。
  4. class Generator(nn.Cell)::定义一个生成器类Generator,继承自nn.Cell,表示一个可训练的模型。
  5. def __init__(self, latent_size, auto_prefix=True)::初始化生成器,接受隐码的维度latent_size作为参数。
  6. super(Generator, self).__init__(auto_prefix=auto_prefix):调用父类的构造方法,初始化模型。
  7. self.model = nn.SequentialCell():创建一个顺序模型容器,用于按顺序添加层。
  8. 添加多个全连接层(nn.Dense)和ReLU激活函数,并使用批量归一化(nn.BatchNorm1d)对中间层进行归一化处理。
  9. 最后一个全连接层将输出维度变为784(28*28),然后通过nn.Tanh()激活函数将输出值限制在[-1, 1]之间,使其适合生成图像数据。
  10. def construct(self, x)::定义前向传播方法,接收输入x并通过模型生成图像。
  11. ops.reshape(img, (-1, 1, 28, 28)):将生成的图像重塑为形状为(-1, 1, 28, 28)的张量。
  12. net_g = Generator(latent_size):创建生成器实例net_g,并传入隐码的维度。
  13. net_g.update_parameters_name('generator'):将生成器的参数名称更新为’generator’。

判别器

如前所述,判别器 Discriminator 是一个二分类网络模型,输出判定该图像为真实图的概率。主要通过一系列的 Dense 层和 LeakyReLU 层对其进行处理,最后通过 Sigmoid 激活函数,使其返回 [0, 1] 的数据范围内,得到最终概率。注意实例化判别器之后需要修改参数的名称,不然静态图模式下会报错。

# 判别器
class Discriminator(nn.Cell):  # 定义判别器类,继承自nn.Celldef __init__(self, auto_prefix=True):  # 初始化方法super().__init__(auto_prefix=auto_prefix)  # 调用父类构造函数self.model = nn.SequentialCell()  # 创建一个顺序模型容器# [N, 784] -> [N, 512]self.model.append(nn.Dense(img_size * img_size, 512))  # 输入特征数为784,输出为512self.model.append(nn.LeakyReLU())  # 添加LeakyReLU激活函数,默认斜率为0.2# [N, 512] -> [N, 256]self.model.append(nn.Dense(512, 256))  # 进行线性映射,将512维映射到256维self.model.append(nn.LeakyReLU())  # 添加LeakyReLU激活函数# [N, 256] -> [N, 1]self.model.append(nn.Dense(256, 1))  # 进行线性映射,将256维映射到1维self.model.append(nn.Sigmoid())  # 添加Sigmoid激活函数,将输出映射到[0, 1]def construct(self, x):  # 定义前向传播方法x_flat = ops.reshape(x, (-1, img_size * img_size))  # 将输入重塑为(-1, 784)的形状return self.model(x_flat)  # 通过模型计算并返回结果# 创建判别器实例
net_d = Discriminator()  
net_d.update_parameters_name('discriminator')  # 更新网络参数名称为'discriminator'

解析:

  1. class Discriminator(nn.Cell)::定义一个判别器类Discriminator,继承自nn.Cell,表示一个可训练的模型。
  2. def __init__(self, auto_prefix=True)::初始化判别器,接受一个可选参数auto_prefix
  3. super().__init__(auto_prefix=auto_prefix):调用父类的构造方法,初始化模型。
  4. self.model = nn.SequentialCell():创建一个顺序模型容器,用于按顺序添加层。
  5. 添加多个全连接层(nn.Dense)和LeakyReLU激活函数:
    • 第一层将输入的784维特征映射到512维。
    • 第二层将512维特征映射到256维。
    • 最后一层将256维特征映射到1维,输出为判别结果。
  6. self.model.append(nn.Sigmoid()):添加Sigmoid激活函数,将输出值限制在[0, 1]之间,表示输入为假的概率。
  7. def construct(self, x)::定义前向传播方法,接收输入x
  8. x_flat = ops.reshape(x, (-1, img_size * img_size)):将输入重塑为形状为(-1, 784)的张量,以适应模型输入。
  9. return self.model(x_flat):通过模型计算并返回判别结果。
  10. net_d = Discriminator():创建判别器实例net_d
  11. net_d.update_parameters_name('discriminator'):将判别器的参数名称更新为’discriminator’。

损失函数和优化器

定义了 GeneratorDiscriminator 后,损失函数使用MindSpore中二进制交叉熵损失函数BCELoss ;这里生成器和判别器都是使用Adam优化器,但是需要构建两个不同名称的优化器,分别用于更新两个模型的参数,详情见下文代码。注意优化器的参数名称也需要修改。

lr = 0.0002  # 学习率设置为0.0002# 损失函数
adversarial_loss = nn.BCELoss(reduction='mean')  # 使用二元交叉熵损失函数,计算生成器和判别器的损失# 优化器
optimizer_d = nn.Adam(net_d.trainable_params(), learning_rate=lr, beta1=0.5, beta2=0.999)  # 创建判别器的Adam优化器
optimizer_g = nn.Adam(net_g.trainable_params(), learning_rate=lr, beta1=0.5, beta2=0.999)  # 创建生成器的Adam优化器
optimizer_g.update_parameters_name('optim_g')  # 更新生成器的优化器参数名称为'optim_g'
optimizer_d.update_parameters_name('optim_d')  # 更新判别器的优化器参数名称为'optim_d'

解析:

  1. lr = 0.0002:定义学习率,设置为0.0002,这个值通常用于训练生成对抗网络(GAN)。
  2. adversarial_loss = nn.BCELoss(reduction='mean'):定义损失函数为二元交叉熵损失(Binary Cross Entropy Loss),用于计算生成器和判别器的损失值,reduction='mean'表示将损失值取平均。
  3. optimizer_d = nn.Adam(net_d.trainable_params(), learning_rate=lr, beta1=0.5, beta2=0.999)
    • 创建判别器的Adam优化器,传入判别器可训练参数net_d.trainable_params()
    • learning_rate=lr设置学习率为0.0002。
    • beta1=0.5beta2=0.999是Adam优化器的超参数,通常用于控制动量和加速收敛。
  4. optimizer_g = nn.Adam(net_g.trainable_params(), learning_rate=lr, beta1=0.5, beta2=0.999)
    • 创建生成器的Adam优化器,传入生成器可训练参数net_g.trainable_params(),其余参数与判别器优化器相同。
  5. optimizer_g.update_parameters_name('optim_g'):将生成器的优化器参数名称更新为’optim_g’,便于管理和调试。
  6. optimizer_d.update_parameters_name('optim_d'):将判别器的优化器参数名称更新为’optim_d’,便于管理和调试。

模型训练

训练分为两个主要部分。
第一部分是训练判别器。训练判别器的目的是最大程度地提高判别图像真伪的概率。按照原论文的方法,通过提高其随机梯度来更新判别器,最大化 的值。
第二部分是训练生成器。如论文所述,最小化 来训练生成器,以产生更好的虚假图像。
在这两个部分中,分别获取训练过程中的损失,并在每轮迭代结束时进行测试,将隐码批量推送到生成器中,以直观地跟踪生成器 Generator 的训练效果。

import os  # 导入操作系统模块,用于文件和目录操作
import time  # 导入时间模块,用于时间相关功能
import matplotlib.pyplot as plt  # 导入matplotlib.pyplot用于绘图
import mindspore as ms  # 导入MindSpore框架
from mindspore import Tensor, save_checkpoint  # 从MindSpore导入Tensor类和保存检查点的函数total_epoch = 200  # 训练周期数,设置为200
batch_size = 128  # 用于训练的批量大小,设置为128# 加载预训练模型的参数
pred_trained = False  # 预训练标志,初始化为False
pred_trained_g = './result/checkpoints/Generator99.ckpt'  # 预训练生成器的检查点路径
pred_trained_d = './result/checkpoints/Discriminator99.ckpt'  # 预训练判别器的检查点路径checkpoints_path = "./result/checkpoints"  # 结果保存路径,用于保存模型检查点
image_path = "./result/images"  # 测试结果保存路径,用于保存生成的图像

解析:

  1. import os:导入操作系统模块,提供与操作系统交互的功能,例如文件和目录的操作。
  2. import time:导入时间模块,用于获取当前时间、延时操作等。
  3. import matplotlib.pyplot as plt:导入Matplotlib库的绘图库,用于可视化数据。
  4. import mindspore as ms:导入MindSpore框架,便于构建和训练深度学习模型。
  5. from mindspore import Tensor, save_checkpoint:从MindSpore导入Tensor类(用于表示张量)和save_checkpoint函数(用于保存模型检查点)。
  6. total_epoch = 200:定义训练的总周期数为200。
  7. batch_size = 128:定义训练过程中的批量大小为128,即每次训练使用128个样本。
  8. pred_trained = False:初始化预训练标志为False,表示尚未加载预训练模型。
  9. pred_trained_g = './result/checkpoints/Generator99.ckpt':定义预训练生成器模型的检查点路径。
  10. pred_trained_d = './result/checkpoints/Discriminator99.ckpt':定义预训练判别器模型的检查点路径。
  11. checkpoints_path = "./result/checkpoints":定义保存模型检查点的目录路径。
  12. image_path = "./result/images":定义保存生成测试结果图像的目录路径。
# 生成器计算损失过程
def generator_forward(test_noises):fake_data = net_g(test_noises)  # 使用生成器生成假数据fake_out = net_d(fake_data)  # 判别器对假数据的预测loss_g = adversarial_loss(fake_out, ops.ones_like(fake_out))  # 计算生成器损失return loss_g  # 返回生成器损失# 判别器计算损失过程
def discriminator_forward(real_data, test_noises):fake_data = net_g(test_noises)  # 使用生成器生成假数据fake_out = net_d(fake_data)  # 判别器对假数据的预测real_out = net_d(real_data)  # 判别器对真实数据的预测real_loss = adversarial_loss(real_out, ops.ones_like(real_out))  # 计算真实数据损失fake_loss = adversarial_loss(fake_out, ops.zeros_like(fake_out))  # 计算假数据损失loss_d = real_loss + fake_loss  # 总判别器损失return loss_d  # 返回判别器损失# 梯度方法
grad_g = ms.value_and_grad(generator_forward, None, net_g.trainable_params())  # 生成器的梯度计算
grad_d = ms.value_and_grad(discriminator_forward, None, net_d.trainable_params())  # 判别器的梯度计算def train_step(real_data, latent_code):# 计算判别器损失和梯度loss_d, grads_d = grad_d(real_data, latent_code)  # 计算判别器损失及其梯度optimizer_d(grads_d)  # 更新判别器参数loss_g, grads_g = grad_g(latent_code)  # 计算生成器损失及其梯度optimizer_g(grads_g)  # 更新生成器参数return loss_d, loss_g  # 返回损失# 保存生成的图片
def save_imgs(gen_imgs1, idx):for i3 in range(gen_imgs1.shape[0]):plt.subplot(5, 5, i3 + 1)  # 创建5x5的子图plt.imshow(gen_imgs1[i3, 0, :, :] / 2 + 0.5, cmap="gray")  # 显示生成的图像plt.axis("off")  # 不显示坐标轴plt.savefig(image_path + "/test_{}.png".format(idx))  # 保存生成的图像# 设置参数保存路径
os.makedirs(checkpoints_path, exist_ok=True)  # 创建检查点保存路径
# 设置中间过程生成图片保存路径
os.makedirs(image_path, exist_ok=True)  # 创建生成图片保存路径net_g.set_train()  # 设置生成器为训练模式
net_d.set_train()  # 设置判别器为训练模式# 储存生成器和判别器损失
losses_g, losses_d = [], []  # 初始化损失列表for epoch in range(total_epoch):  # 进行训练周期start = time.time()  # 记录开始时间for (iter, data) in enumerate(mnist_ds):  # 遍历训练数据集start1 = time.time()  # 记录每一步开始时间image, latent_code = data  # 获取真实图像和潜在编码image = (image - 127.5) / 127.5  # 归一化图像到[-1, 1]image = image.reshape(image.shape[0], 1, image.shape[1], image.shape[2])  # 重塑图像形状d_loss, g_loss = train_step(image, latent_code)  # 进行训练步骤end1 = time.time()  # 记录每一步结束时间if iter % 10 == 0:  # 每10步输出一次信息print(f"Epoch:[{int(epoch):>3d}/{int(total_epoch):>3d}], "f"step:[{int(iter):>4d}/{int(iter_size):>4d}], "f"loss_d:{d_loss.asnumpy():>4f} , "f"loss_g:{g_loss.asnumpy():>4f} , "f"time:{(end1 - start1):>3f}s, "f"lr:{lr:>6f}")end = time.time()  # 记录每个epoch结束时间print("time of epoch {} is {:.2f}s".format(epoch + 1, end - start))  # 输出每个epoch耗时losses_d.append(d_loss.asnumpy())  # 保存判别器损失losses_g.append(g_loss.asnumpy())  # 保存生成器损失# 每个epoch结束后,使用生成器生成一组图片gen_imgs = net_g(test_noise)  # 生成测试图像save_imgs(gen_imgs.asnumpy(), epoch)  # 保存生成的图像# 根据epoch保存模型权重文件if epoch % 1 == 0:  # 每个epoch保存一次save_checkpoint(net_g, checkpoints_path + "/Generator%d.ckpt" % (epoch))  # 保存生成器检查点save_checkpoint(net_d, checkpoints_path + "/Discriminator%d.ckpt" % (epoch))  # 保存判别器检查点

解析:

  1. def generator_forward(test_noises)::定义生成器前向传播函数,计算生成器损失。
    • fake_data = net_g(test_noises):生成器生成假数据。
    • fake_out = net_d(fake_data):判别器对假数据进行分类。
    • loss_g = adversarial_loss(fake_out, ops.ones_like(fake_out)):计算生成器损失。
    • return loss_g:返回损失值。
  2. def discriminator_forward(real_data, test_noises)::定义判别器前向传播函数,计算判别器损失。
    • fake_data = net_g(test_noises):生成器生成假数据。
    • fake_out = net_d(fake_data):判别器对假数据进行分类。
    • real_out = net_d(real_data):判别器对真实数据进行分类。
    • real_loss = adversarial_loss(real_out, ops.ones_like(real_out)):计算真实数据损失。
    • fake_loss = adversarial_loss(fake_out, ops.zeros_like(fake_out)):计算假数据损失。
    • loss_d = real_loss + fake_loss:总判别器损失。
    • return loss_d:返回损失值。
  3. grad_g = ms.value_and_grad(generator_forward, None, net_g.trainable_params()):计算生成器的损失和梯度。
  4. grad_d = ms.value_and_grad(discriminator_forward, None, net_d.trainable_params()):计算判别器的损失和梯度。
  5. def train_step(real_data, latent_code)::定义训练步骤,计算损失并更新模型。
    • loss_d, grads_d = grad_d(real_data, latent_code):计算判别器损失和梯度。
    • optimizer_d(grads_d):更新判别器参数。
    • loss_g, grads_g = grad_g(latent_code):计算生成器损失和梯度。
    • optimizer_g(grads_g):更新生成器参数。
    • return loss_d, loss_g:返回损失。
  6. def save_imgs(gen_imgs1, idx)::定义保存生成图像的函数。
    • 使用plt.imshow显示图像,并使用plt.savefig保存图像。
  7. os.makedirs(checkpoints_path, exist_ok=True)os.makedirs(image_path, exist_ok=True):创建保存检查点和图像的目录。
  8. net_g.set_train()net_d.set_train():设置生成器和判别器为训练模式。
  9. 训练循环:进行训练,计算损失,并每隔一定步数输出信息。
    • losses_d.append(d_loss.asnumpy())losses_g.append(g_loss.asnumpy()):记录损失值。
    • 每个epoch结束后,生成并保存一组图像。
    • 根据epoch保存模型权重。

效果展示

运行下面代码,描绘DG损失与训练迭代的关系图:

plt.figure(figsize=(6, 4))  # 创建一个新的图形,设置大小为6x4英寸
plt.title("Generator and Discriminator Loss During Training")  # 设置图形标题
plt.plot(losses_g, label="G", color='blue')  # 绘制生成器损失曲线,标签为"G",颜色为蓝色
plt.plot(losses_d, label="D", color='orange')  # 绘制判别器损失曲线,标签为"D",颜色为橙色
plt.xlim(-20, 220)  # 设置x轴的范围
plt.ylim(0, 3.5)  # 设置y轴的范围
plt.xlabel("iterations")  # 设置x轴标签
plt.ylabel("Loss")  # 设置y轴标签
plt.legend()  # 显示图例
plt.show()  # 展示图形

解析:

  1. plt.figure(figsize=(6, 4)):创建一个新的图形,并设置其大小为6英寸宽、4英寸高。
  2. plt.title("Generator and Discriminator Loss During Training"):设置图形的标题为“训练过程中生成器和判别器的损失”。
  3. plt.plot(losses_g, label="G", color='blue'):绘制生成器损失的曲线,设置标签为"G"并使用蓝色。
  4. plt.plot(losses_d, label="D", color='orange'):绘制判别器损失的曲线,设置标签为"D"并使用橙色。
  5. plt.xlim(-20, 220):设置x轴的显示范围为-20到220。
  6. plt.ylim(0, 3.5):设置y轴的显示范围为0到3.5。
  7. plt.xlabel("iterations"):设置x轴的标签为“迭代次数”。
  8. plt.ylabel("Loss"):设置y轴的标签为“损失”。
  9. plt.legend():显示图例,标识不同的曲线。
  10. plt.show():展示绘制的图形。

image.png
可视化训练过程中通过隐向量生成的图像。

import cv2  # 导入OpenCV库用于图像处理
import matplotlib.animation as animation  # 导入Matplotlib的动画模块# 将训练过程中生成的测试图转为动态图
image_list = []  # 初始化图像列表
for i in range(total_epoch):  # 遍历每个epochimage_list.append(cv2.imread(image_path + "/test_{}.png".format(i), cv2.IMREAD_GRAYSCALE))  # 读取图像并添加到列表中,使用灰度模式show_list = []  # 初始化显示列表
fig = plt.figure(dpi=70)  # 创建图形,设置分辨率为70 DPI
for epoch in range(0, len(image_list), 5):  # 每5个epoch选择一个图像plt.axis("off")  # 关闭坐标轴show_list.append([plt.imshow(image_list[epoch], cmap='gray')])  # 将图像添加到显示列表中,使用灰度色图# 创建动画
ani = animation.ArtistAnimation(fig, show_list, interval=1000, repeat_delay=1000, blit=True)  # 设置动画参数
ani.save('train_test.gif', writer='pillow', fps=1)  # 保存动画为GIF文件,帧率为1 FPS

解析:

  1. import cv2:导入OpenCV库,用于处理图像文件。
  2. import matplotlib.animation as animation:导入Matplotlib的动画模块,用于制作动态图像。
  3. image_list = []:初始化一个空列表,用于存储读取的图像。
  4. for i in range(total_epoch)::遍历每个训练周期。
    • image_list.append(cv2.imread(image_path + "/test_{}.png".format(i), cv2.IMREAD_GRAYSCALE)):读取指定路径下的图像文件并以灰度模式存储到image_list中。
  5. show_list = []:初始化一个空列表,用于存储将要显示的图像帧。
  6. fig = plt.figure(dpi=70):创建一个新的图形,设置图形的分辨率为70 DPI。
  7. for epoch in range(0, len(image_list), 5)::每隔5个epoch选择一个图像进行显示。
    • plt.axis("off"):关闭图形的坐标轴。
    • show_list.append([plt.imshow(image_list[epoch], cmap='gray')]):将当前选择的图像以灰度色图形式添加到show_list中。
  8. ani = animation.ArtistAnimation(fig, show_list, interval=1000, repeat_delay=1000, blit=True):创建动画对象,设置每帧间隔为1000毫秒,重复延迟为1000毫秒,启用blit优化。
  9. ani.save('train_test.gif', writer='pillow', fps=1):将创建的动画保存为名为train_test.gif的GIF文件,帧率设置为1帧每秒。


从上面的图像可以看出,随着训练次数的增多,图像质量也越来越好。如果增大训练周期数,当 epoch 达到100以上时,生成的手写数字图片与数据集中的较为相似。下面我们通过加载生成器网络模型参数文件来生成图像,代码如下:

模型推理

下面我们通过加载生成器网络模型参数文件来生成图像,代码如下:

import mindspore as ms  # 导入MindSpore库# 指定要加载的检查点路径
test_ckpt = './result/checkpoints/Generator199.ckpt'# 加载模型参数
parameter = ms.load_checkpoint(test_ckpt)  # 从指定路径加载检查点
ms.load_param_into_net(net_g, parameter)  # 将参数加载到生成器网络net_g中# 模型生成结果
test_data = Tensor(np.random.normal(0, 1, (25, 100)).astype(np.float32))  # 生成随机噪声作为输入
images = net_g(test_data).transpose(0, 2, 3, 1).asnumpy()  # 使用生成器生成图像并转换为NumPy数组# 结果展示
fig = plt.figure(figsize=(3, 3), dpi=120)  # 创建一个新的图形,设置大小和分辨率
for i in range(25):  # 遍历生成的25张图像fig.add_subplot(5, 5, i + 1)  # 创建5x5的子图布局plt.axis("off")  # 关闭坐标轴plt.imshow(images[i].squeeze(), cmap="gray")  # 显示图像,使用灰度色图
plt.show()  # 展示图形

解析:

  1. import mindspore as ms:导入MindSpore库,用于深度学习框架的操作。
  2. test_ckpt = './result/checkpoints/Generator199.ckpt':指定要加载的生成器模型的检查点文件路径。
  3. parameter = ms.load_checkpoint(test_ckpt):使用ms.load_checkpoint函数从指定路径加载模型参数。
  4. ms.load_param_into_net(net_g, parameter):将加载的参数导入到生成器网络net_g中,使其可以使用这些参数生成图像。
  5. test_data = Tensor(np.random.normal(0, 1, (25, 100)).astype(np.float32)):生成一个形状为(25, 100)的随机噪声张量,符合标准正态分布,并转换为浮点数类型,以作为生成器的输入。
  6. images = net_g(test_data).transpose(0, 2, 3, 1).asnumpy():将随机噪声传入生成器net_g,生成图像,然后转换图像的维度并转为NumPy数组格式。
  7. fig = plt.figure(figsize=(3, 3), dpi=120):创建一个新的图形,设置大小为3x3英寸,分辨率为120 DPI。
  8. for i in range(25)::遍历生成的25张图像。
    • fig.add_subplot(5, 5, i + 1):在5x5的子图布局中添加子图。
    • plt.axis("off"):关闭子图的坐标轴。
    • plt.imshow(images[i].squeeze(), cmap="gray"):显示当前图像,使用灰度色图,并去掉多余的维度。
  9. plt.show():展示生成的图像。

image.png

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.rhkb.cn/news/385622.html

如若内容造成侵权/违法违规/事实不符,请联系长河编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

超逼真AI生成电影来了!《泰坦尼克号》AI重生!浙大阿里发布MovieDreamer,纯AI生成电影引爆热议!

视频生成领域的最新进展主要利用了短时内容的扩散模型。然而,这些方法往往无法对复杂的叙事进行建模,也无法在较长时间内保持角色的一致性,而这对于电影等长篇视频制作至关重要。 对此,浙大&阿里发布了一种新颖的分层框架Mov…

图解分布式事务中的2PC与Seata方案

文章目录 文章导图什么是2PC解决传统2PC方案XA方案DTP模型举例:新用户注册送积分总结: Seata方案设计思想执行流程举例:新用户注册送积分 Seata实现2PC事务(AT模式)前提整体机制写隔离读隔离实际案例理解要点说明核心代…

uniapp小程序中富文本内容渲染图片不展示的问题

文章目录 1.从后端请求的数据中图片是这样的2.前端我是用Uview中的u-parse组件3.这样修改去掉富文本中的所有反斜杠4.完美解决 1.从后端请求的数据中图片是这样的 <p><img src\\\"https://zhangsanfengcode.cn:8084/images/2024-06-28a257befe.jpg\\\" alt…

如何使用 SQLite ?

SQLite 是一个轻量级、嵌入式的关系型数据库管理系统&#xff08;RDBMS&#xff09;。它是一种 C 库&#xff0c;实现了自给自足、无服务器、零配置、事务性 SQL 数据库引擎。SQLite 的源代码是开放的&#xff0c;完全在公共领域。它被广泛用于各种应用程序&#xff0c;包括浏览…

关于 OSPF 序列号范围 0x80000001-0x7FFFFFFF 正本清源

注&#xff1a;机翻&#xff0c;未校对。 正本&#xff1a;RFC 2328 OSPF Version 2 中相关解释 April 1998 12.1.6. LS sequence number 12.1.6. 序列号 The sequence number field is a signed 32-bit integer. It is used to detect old and duplicate LSAs. The space …

【OSS对象存储】Springboot集成阿里云OSS + 私有化部署Minio

【OSS对象存储】Springboot集成阿里云OSS 私有化部署Minio 一、摘要二、POM依赖三、配置文件四、表结构设计五、代码实现5.1 代码包结构5.2 API封装5.3 增删改查 六、扩展6.1 Minio配置https访问 一、摘要 掌握阿里云OSS、私有化部署Minio两种对象存储的使用方式运用工厂策略…

【C++指南】内存管理(上)

&#x1f493; 博客主页&#xff1a;倔强的石头的CSDN主页 &#x1f4dd;Gitee主页&#xff1a;倔强的石头的gitee主页 ⏩ 文章专栏&#xff1a;《数据结构与算法》 期待您的关注

vue上传Excel文件并直接点击文件列表进行预览

本文主要内容&#xff1a;用elementui的Upload 组件上传Excel文件&#xff0c;上传后的列表采用xlsx插件实现点击预览表格内容效果。 在项目中可能会有这样的需求&#xff0c;有很多种方法实现。但是不想要跳转外部地址&#xff0c;所以用了xlsx插件来解析表格&#xff0c;并展…

总结一些vue3小知识3

总结一些vue3小知识1&#xff1a;http://t.csdnimg.cn/C5vER 总结一些vue3小知识2&#xff1a;http://t.csdnimg.cn/sscid 1.限制时间选择器只能选择后面的日期 说明&#xff1a;disabled-date属性是一个用来判断该日期是否被禁用的函数&#xff0c;接受一个 Date 对象作为参…

科普文:分布式架构中的三高:高并发、高性能、高可用

关于高并发 高并发场景 互联网应用以及云计算的普及&#xff0c;使得架构设计和软件技术的关注点从如何实现复杂的业务逻 辑&#xff0c;转变为如何满足大量用户的高并发访问请求。 一个简单的计算处理过程&#xff0c;如果一旦面对大量的用户访问&#xff0c;整个技术挑战就…

DP 整数拆分不同的二叉搜索树 DAY21

整数拆分&#xff1f; 给定一个正整数 n &#xff0c;将其拆分为 k 个 正整数 的和&#xff08; k > 2 &#xff09;&#xff0c;并使这些整数的乘积最大化。 返回 你可以获得的最大乘积。 示例 1: 输入: n 2 输出: 1 解释: 2 1 1, 1 1 1。示例 2: 输入: n 10 输…

实验2-2-1 温度转换

#include<stdio.h> #include <math.h> int main(){int c,f150;c5*(f-32)/9;printf("fahr 150, celsius %d",c); }

sqlilabs解题方法

Lass1 查询id为1的用户名和密码 查询id为2的用户名和密码 没有回显&#xff0c;不含id-1的行 判断字段数&#xff0c;字段数为3 查询数据库用户名&#xff0c;和数据库名 查询时id必须超出数据库以外&#xff0c;一般用-1 用户名&#xff1a;user() 数据库名&#xff1a;databa…

redis:清除缓存的最简单命令示例

清除redis缓存命令(执行命令列表见截图) 1.打开cmd窗口&#xff0c;并cd进入redis所在目录 2.登录redis redis-cli 3.查询指定队列当前的记录数 llen 队列名称 4.清除指定队列所有记录 ltrim 队列名称 1 0 5.再次查询&#xff0c;确认队列的记录数是否已清除

opencascade AIS_Manipulator源码学习

前言 AIS_Manipulator 是 OpenCASCADE 库中的一个类&#xff0c;用于在3D空间中对其他交互对象或一组对象进行局部变换。该类提供了直观的操控方式&#xff0c;使用户可以通过鼠标进行平移、缩放和旋转等操作。 详细功能 交互对象类&#xff0c;通过鼠标操控另一个交互对象…

【Django】 读取excel文件并在前端以网页形式显示-安装使用Pandas

文章目录 安装pandas写views写urls安装openpyxl重新调试 安装pandas Pandas是一个基于NumPy的Python数据分析库&#xff0c;可以从各种文件格式如CSV、JSON、SQL、Excel等导入数据&#xff0c;并支持多种数据运算操作&#xff0c;如归并、再成形、选择等。 更换pip源 pip co…

在 Postman 中设置全局 token

目录 问题描述解决方案 问题描述 在使用 Postman 进行接口测试时&#xff0c;经常会遇到在 Header 中添加 token 的情况。当接口数量较多时&#xff0c;需要为每个接口进行设置&#xff0c;而且当 token 失效时需要重新获取并设置&#xff0c;这样一来效率较低。 解决方案 下…

MSPM0G3507学习笔记1:开发环境_引脚认识与点灯

今日速通一款Ti的单片机用于电赛&#xff1a;MSPM0G3507 这里默认已经安装好了Keil5_MDK 首先声明一下: 因为是速成&#xff0c;所以需要一定单片机学习基础&#xff0c;然后我写的也不会详细&#xff0c;这个专栏的笔记也就是自己能看懂就行的目标~~~ 文章提供测试代码解…

Matlab进阶绘图第65期—带分组折线段的柱状图

带分组折线段的柱状图是在原始柱状图的基础上&#xff0c;在每组柱状图位置处分别添加折线段&#xff0c;以进行对比或添加额外信息。 由于Matlab中未收录带分组折线段的柱状图的绘制函数&#xff0c;因此需要大家自行设法解决。 本文使用自制的BarwithGroupedLine小工具进行…

「Ant Design」Antd 中卡片如何完全不展示内容区域、按需展示内容区域、不展示标题

前言 下面是默认的 Antd 卡片&#xff0c;由以下区域组成 处理 Antd 的 Card 展示形式大致有下面三种 卡片完全不展示内容区域 const App () > (<Card title"Default size card" extra{<a href"#">More</a>} style{{ width: 300 }}b…