cGAN和pix2pix的基础概念
cGAN
cGAN是条件生成对抗网络(Conditional Generative Adversarial Networks)的简称。
它是一种基于基础GAN(Generative Adversarial Networks)架构的变体,通过给GAN模型引入额外的信息或条件来指导数据生成过程。
这种额外信息可以是类别标签、文本描述、图像特征等,使得生成的数据不仅真实,而且能够满足特定条件。
cGAN与GAN的区别
-
引入条件:最核心的区别在于cGAN为生成过程添加了一个明确的控制变量——条件。这意味着除了随机噪声外,生成器还需要接受一些附加输入作为条件,并根据这个条件去生成相应的输出。例如,在手写数字生成任务中,这个条件可能是想要生成的具体数字(0-9之间的某个数)。这样做的好处是可以更加精确地控制生成内容。
-
结构变化:为了处理这些额外的条件信息,cGAN通常会对原始GAN的架构做一些调整。比如,在生成器输入端加入条件信息的同时,也可能需要对判别器做出相应修改,使其能够利用这些条件更好地评估生成样本的质量。
-
应用场景扩展:由于增加了可控性,cGAN被广泛应用于更多领域,如图像到图像转换(Image-to-Image Translation)、文字转图片(Text-to-Image Synthesis)等复杂场景下,其中不仅要求生成高质量的图像,还希望这些图像是按照给定的文字描述或者样式指引来创建的。
pix2pix
pix2pix 模型是一种基于条件生成对抗网络(Conditional GAN, cGAN)的图像到图像转换框架。
由Phillip Isola等人在2016年的论文《Image-to-Image Translation with Conditional Adversarial Networks》中提出。
这个模型的主要目的是学习从一种类型的图像到另一种类型图像的映射,比如将边缘轮廓图转换成彩色照片、黑白照片上色、卫星图像转为地图等。
pix2pix的工作原理
生成器(Generator, G): 通常采用U-Net架构,这是一种编码-解码结构,具有跳跃连接(skip connections),可以更好地保留图像的空间信息。
判别器(Discriminator, D): 用于判断一对图像是否是真实的输入-输出对。判别器会接收两个输入:一个是来自真实数据集的输入-输出图像对,另一个是由生成器产生的输入-生成输出图像对。判别器的目标是区分这两个对中的哪一个更可能是真实的。
pix2pix模型中的判别器通常采用PatchGAN架构。PatchGAN是一种特殊的判别器设计,它不是试图对整个图像进行全局真伪判断,而是将输入图像分割成多个小块(patches),并对每个小块独立地做出局部真伪判断。
损失函数:除了传统的对抗损失(Adversarial Loss),即让判别器尽可能准确地分辨真假而生成器则尽量欺骗判别器外,pix2pix还使用了L1损失(也称为绝对误差损失)。
L1损失鼓励生成器产生与目标图像非常接近的结果,这有助于提高生成图像的质量和细节的准确性。
U-Net
U-Net的基本架构是一种对称的编码-解码结构,它通过一系列的卷积层和上采样层来处理图像,并在编码器和解码器之间使用跳跃连接(skip connections)来保留空间信息。这种设计特别适合于需要精确像素级输出的任务,如医学图像分割。
编码器:
- 编码器通常由一系列的卷积层组成,每个卷积层后面可能跟着ReLU激活函数。
- 每经过一个卷积块之后,会有一个池化操作(通常是最大池化),用于减小特征图的空间维度,同时增加通道数以提取更高级别的特征。
- 这个过程可以看作是逐步压缩输入图像的过程,从而获得更高层次、更抽象的特征表示。
瓶颈层:
- 在编码器和解码器之间的部分被称为瓶颈层。在这里,网络已经将输入压缩到了最低分辨率,但拥有最多的特征通道。
- 瓶颈层通常包含几个卷积层,用于进一步提炼特征。
解码器:
- 解码器通过一系列的反卷积(转置卷积)或上采样操作逐渐放大特征图,恢复到原始图像的空间尺寸。
- 在每个上采样步骤之后,解码器会与编码器中对应层的特征图进行拼接(concatenate),这些拼接的特征图就是所谓的“跳跃连接”。这样做的目的是为了结合低层次的细节信息和高层次的语义信息,帮助生成更加精细的输出。
输出层:
- 最后一层通常是卷积层,用来将特征图转换成所需的输出格式,比如单通道的概率图(对于二分类问题)或多通道的类别概率图(对于多分类问题)。
- 输出层也可能包括一个激活函数,如Sigmoid(二分类)或Softmax(多分类),以确保输出值落在适当的范围内。
PatchGAN
感受野:在PatchGAN的设计中,每个输出单元对应输入图像中的一个小区域(patch),这个小区域被称为该单元的感受野。例如,如果最终的PatchGAN输出是一个30x30的矩阵,则意味着它将原始图像分成了多个30x30大小的小块,并且对每个小块进行了独立的真伪评估。
全卷积网络:PatchGAN本质上是一个全卷积网络(Fully Convolutional Network, FCN)。这意味着它没有使用任何全连接层来直接决定整个图像的真实性,而是通过一系列的卷积操作来处理图像数据,并输出一个表示各个局部区域真实性的矩阵。
由于PatchGAN关注的是局部细节,它特别擅长于捕捉图像中的高频特征,如纹理和边缘等。这对于生成高质量、高分辨率的图像非常有帮助。
基于MindSpore的pix2pix
下载数据集
# 下载数据集
from download import downloadurl = "https://mindspore-website.obs.cn-north-4.myhuaweicloud.com/notebook/models/application/dataset_pix2pix.tar"download(url, "./dataset", kind="tar", replace=True)
定义U-Net Skip Connection Block
# 定义UNet Skip Connection Block
import mindspore
import mindspore.nn as nn
import mindspore.ops as ops
from mindspore.common import initializer as init
'''
outer_nc 和 inner_nc 分别表示输出通道数和中间层通道数。
in_planes 是输入通道数,如果未指定则默认为 outer_nc。
dropout 控制是否在模型中添加 Dropout 层。
submodule 是嵌套的子模块,通常是一个更深层的 UNetSkipConnectionBlock。
outermost 和 innermost 用于区分当前块是否是最外层或最内层。
alpha 是 LeakyReLU 的负斜率。
norm_mode 用于选择归一化方式,可以是 'batch' 或 'instance'。
'''
class UNetSkipConnectionBlock(nn.Cell):def __init__(self, outer_nc, inner_nc, in_planes=None, dropout=False, submodule=None, outermost=False, innermost=False, alpha=0.2, norm_mode='batch'):super(UNetSkipConnectionBlock, self).__init__()# 初始化归一化层,默认为BatchNorm2ddown_norm = nn.BatchNorm2d(inner_nc)up_norm = nn.BatchNorm2d(outer_nc)# 默认不使用偏置use_bias = False# 判断是否使用实例归一化'''实例归一化是对单个样本的所有特征通道进行归一化的操作。它通常应用于风格迁移任务和生成对抗网络(GAN)中。'''# 如果归一化模式是实例归一化,则使用InstanceNorm2dif norm_mode == 'instance':down_norm = nn.InstanceNorm2d(inner_nc, affine=False) # 注意这里应该是InstanceNorm2d而不是BatchNorm2dup_norm = nn.InstanceNorm2d(outer_nc, affine=False)use_bias = True # 实例归一化时使用偏置# 如果in_planes未指定,则默认为outer_ncif in_planes is None:in_planes = outer_nc# 定义下采样卷积层down_conv = nn.Conv2d(in_planes, inner_nc, kernel_size=4, stride=2, padding=1, has_bias=use_bias, pad_mode='pad')# 定义激活函数down_relu = nn.LeakyReLU(alpha)up_relu = nn.ReLU()# 根据是否是最外层或最内层来定义不同的结构if outermost:# 上采样转置卷积层up_conv = nn.Conv2dTranspose(inner_nc * 2, outer_nc, kernel_size=4, stride=2, padding=1, pad_mode='pad')# 下采样部分down = [down_conv]# 上采样部分up = [up_relu, up_conv, nn.Tanh()] # 使用Tanh作为输出激活函数# 组合模型model = down + [submodule] + upelif innermost: # 最内层up_conv = nn.Conv2dTranspose(inner_nc, outer_nc, kernel_size=4, stride=2, padding=1, has_bias=use_bias, pad_mode='pad')down = [down_relu, down_conv]up = [up_relu, up_conv, up_norm]model = down + upelse: # 中间层up_conv = nn.Conv2dTranspose(inner_nc * 2, outer_nc, kernel_size=4, stride=2, padding=1, has_bias=use_bias, pad_mode='pad')down = [down_relu, down_conv, down_norm]up = [up_relu, up_conv, up_norm]model = down + [submodule] + upif dropout:# 如果需要dropout,在模型中添加Dropout层model.append(nn.Dropout(p=0.5))# 封装模型self.model = nn.SequentialCell(model)# 跳跃连接只在非最外层时使用self.skip_connections = not outermostdef construct(self, x):# 执行前向传播out = self.model(x)if self.skip_connections:# 如果有跳跃连接,将输入x与输出out拼接起来out = ops.concat((out, x), axis=1)return out
基于Block定义生成器
# 基于 U-Net 结构的生成器网络。
'''
in_planes:输入通道数。
out_planes:输出通道数。
ngf:生成器的基础特征图数量,默认为 64。
n_layers:U-Net 的层数,默认为 8。
norm_mode:归一化模式,可以是 'bn' 或 'instance',默认为 'bn'。
dropout:是否使用 Dropout 层,默认为 False。
'''
class UNetGenerator(nn.Cell):def __init__(self, in_planes, out_planes, ngf=64, n_layers=8, norm_mode='bn', dropout=False):# 定义最内层的UNetSkipConnectionBlockunet_block = UNetSkipConnectionBlock(ngf * 8, ngf * 8, in_planes=None, submodule=None,norm_mode=norm_mode, innermost=True)# 添加中间层的UNetSkipConnectionBlockfor _ in range(n_layers - 5):unet_block = UNetSkipConnectionBlock(ngf * 8, ngf * 8, in_planes=None, submodule=unet_block,norm_mode=norm_mode, dropout=dropout)# 添加更外层的UNetSkipConnectionBlockunet_block = UNetSkipConnectionBlock(ngf * 4, ngf * 8, in_planes=None, submodule=unet_block,norm_mode=norm_mode)unet_block = UNetSkipConnectionBlock(ngf * 2, ngf * 4, in_planes=None, submodule=unet_block,norm_mode=norm_mode)unet_block = UNetSkipConnectionBlock(ngf * 1, ngf * 2, in_planes=None, submodule=unet_block,norm_mode=norm_mode)# 定义最外层的UNetSkipConnectionBlock并完成网络组装self.model = UNetSkipConnectionBlock(out_planes, ngf, in_planes=in_planes, submodule=unet_block,outermost=True, norm_mode=norm_mode)def construct(self, x):# 执行前向传播return self.model(x)
定义PatchGAN的基本模块
# 基于PatchGAN的判别器
# 包含卷积、归一化和激活函数的复合层
'''
in_planes # 输入通道数
out_planes # 输出通道数
kernel_size=4 # 卷积核大小,默认为4
stride=2 # 步长,默认为2
alpha=0.2 # LeakyReLU的负斜率,默认为0.2
norm_mode='batch' # 归一化模式,默认为'batch'
pad_mode='CONSTANT' # 填充模式,默认为'CONSTANT'
use_relu=True # 是否使用激活函数,默认为True
padding=None # 填充大小,默认为None
'''
class ConvNormRelu(nn.Cell):def __init__(self, in_planes, out_planes, kernel_size=4, stride=2, alpha=0.2, norm_mode='batch', pad_mode='CONSTANT', use_relu=True, padding=None):super(ConvNormRelu, self).__init__()# 初始化归一化层norm = nn.BatchNorm2d(out_planes)if norm_mode == 'instance':# 注意这里应该是InstanceNorm2d而不是BatchNorm2dnorm = nn.InstanceNorm2d(out_planes, affine=False)# 如果是实例归一化,则使用偏置has_bias = (norm_mode == 'instance')# 计算默认填充大小if not padding:padding = (kernel_size - 1) // 2# 根据填充模式选择不同的处理方式if pad_mode == 'CONSTANT':# 使用常量填充conv = nn.Conv2d(in_planes, out_planes, kernel_size, stride, pad_mode='pad',has_bias=has_bias, padding=padding)layers = [conv, norm]else:# 使用指定模式的填充paddings = ((0, 0), (0, 0), (padding, padding), (padding, padding))pad = nn.Pad(paddings=paddings, mode=pad_mode)conv = nn.Conv2d(in_planes, out_planes, kernel_size, stride, pad_mode='pad', has_bias=has_bias)layers = [pad, conv, norm]# 添加激活函数if use_relu:relu = nn.ReLU()if alpha > 0:relu = nn.LeakyReLU(alpha)layers.append(relu)# 组合网络self.features = nn.SequentialCell(layers)def construct(self, X):output = self.features(X)return output
定义判别器
'''
in_planes:输入通道数,默认为3(RGB图像)
ndf:基础特征图数量,默认为64
n_layers:层数,默认为3
alpha:LeakyReLU的负斜率,默认为0.2
norm_mode:归一化模式,默认为'batch'
'''
class Discriminator(nn.Cell):def __init__(self, in_planes, ndf=64, n_layers=3, alpha=0.2, norm_mode='batch'):super(Discriminator, self).__init__()# 初始化参数kernel_size = 4layers = [nn.Conv2d(in_planes, ndf, kernel_size, 2, pad_mode='pad', padding=1),nn.LeakyReLU(alpha)]# 初始化特征图倍增因子# 特征图倍增因子(feature map multiplier)# 是一种用于控制网络层数增加时特征图数量(即通道数)增长速度的设计参数。nf_mult = ndf# 添加中间层for i in range(1, n_layers):nf_mult_prev = nf_multnf_mult = min(2 ** i, 8) * ndflayers.append(ConvNormRelu(nf_mult_prev, nf_mult, kernel_size, 2, alpha, norm_mode, padding=1))# 添加最后一层之前的层nf_mult_prev = nf_multnf_mult = min(2 ** n_layers, 8) * ndflayers.append(ConvNormRelu(nf_mult_prev, nf_mult, kernel_size, 1, alpha, norm_mode, padding=1))# 添加输出层layers.append(nn.Conv2d(nf_mult, 1, kernel_size, 1, pad_mode='pad', padding=1))# 组合模型self.features = nn.SequentialCell(layers)def construct(self, x, y):# 将输入x和y在通道维度上拼接x_y = ops.concat((x, y), axis=1)# 对图像进行判别output = self.features(x_y)return output
参数定义
# 参数定义
g_in_planes = 3
g_out_planes = 3
g_ngf = 64
g_layers = 8
d_in_planes = 6
d_ndf = 64
d_layers = 3
alpha = 0.2
init_gain = 0.02
init_type = 'normal'
创建生成器对象
# 创建生成器对象
net_generator = UNetGenerator(in_planes=g_in_planes, out_planes=g_out_planes,ngf=g_ngf, n_layers=g_layers)
# 初始化一个U-Net生成器网络(UNetGenerator)的权重和偏置。
'''
用了cells_and_names()方法来遍历网络中的所有模块(cell)及其名称。
这个方法返回的是一个迭代器,每个元素是一个包含名称和对应模块的元组。
在这个循环中,我们只关心模块本身,所以使用了_作为名称变量,表示不使用该值。
'''
for _, cell in net_generator.cells_and_names():if isinstance(cell, (nn,Conv2d, nn.Conv2dTranspose)):if init_type == 'normal':cell.weight.set_data(init.initializer(init.Normal(init_gain), cell.weight.shape))elif init_type == 'xavier':cell.weight.set_data(init.initializer(init.XavierUniform(init_gain), cell.weight.shape))elif init_type == 'constant':cell.weight.set_data(init.initializer(0.001, cell.weight.shape))else:raise NotImplementedError('initialization method [%s] is not implemented' % init_type)# 批归一化层初始化部分elif isinstance(cell, nn.BarchNorm2d):cell.gamma.set_data(init.initializer('ones', cell.gamma.shape))cell.beta.set_data(init.initializer('zeros', cell.beta.shape))
创建判别器对象
net_discriminator = Discriminator(in_planes=d_in_planes, ndf=d_ndf, alpha=alpha, n_layers=d_layers)for _, cell in net_discriminator.cells_and_names():if isinstance(cell, (nn.Conv2d, nn.Conv2dTranspose)):if init_type == 'normal':cell.weight.set_data(init.initializer(init.Normal(init_gain), cell.weight.shape))elif init_type == 'xavier':cell.weight.set_data(init.initializer(init.XavierUniform(init_gain), cell.weight.shape))elif init_type == 'constant':cell.weight.set_data(init.initializer(0.001, cell.weight.shape))else:raise NotImplementedError('initialization method [%s] is not implemented' % init_type)elif isinstance(cell, nn.BatchNorm2d):cell.gamma.set_data(init.initializer('ones', cell.gamma.shape))cell.beta.set_data(init.initializer('zeros', cell.beta.shape))
组装pix2pix网络
# 组装pix2pix网络
class Pix2Pix(nn.Cell):def __init__(self, discriminator, generator):super(Pix2Pix, self).__init__(auto_prefix=True)self.net_discriminator = discriminatorself.net_generator = generatordef construct(self, reala):fakeb = self.net_generator(reala)return fakeb
模型训练
# 进行训练
import numpy as np
import os
import datetime
from mindspore import value_and_grad, Tensorepoch_num = 100
ckpt_dir = "results/ckpt"
dataset_size = 400
val_pic_size = 256
lr = 0.0002
n_epochs = 100
n_epochs_decay = 100# 用于生成学习率(learning rate, LR)的时间表。这个时间表是一个列表,其中每个元素对应一个训练批次的学习率。
# 学习率在训练过程中会逐渐衰减,以帮助模型更好地收敛。
def get_lr():# 初始化学习率列表,初始学习率为 lr,持续 dataset_size * n_epochs 个批次lrs = [lr] * dataset_size * n_epochs# 初始化当前学习率lr_epoch = 0# 在前 n_epochs_decay 个 epoch 中,学习率线性衰减for epoch in range(n_epochs_decay):lr_epoch = lr * (n_epochs_decay - epoch) / n_epochs_decaylrs += [lr_epoch] * dataset_size# 如果总 epoch 数大于 n_epochs + n_epochs_decay,则将最后一个 epoch 的学习率保持到最后lrs += [lr_epoch] * dataset_size * (epoch_num - n_epochs_decay - n_epochs)# 返回学习率列表,并转换为 MindSpore 的 Tensor 类型return Tensor(np.array(lrs).astype(np.float32))dataset = ds.MindDataset("./dataset/dataset_pix2pix/train.mindrecord", columns_list=["input_images", "target_images"], shuffle=True, num_parallel_workers=1)
steps_per_epoch = dataset.get_dataset_size()
loss_f = nn.BCEWithLogitsLoss()
l1_loss = nn.L1Loss()def forword_dis(reala, realb):# 设置判别器损失的权重'''在某些情况下,判别器可能比生成器学习得更快或更慢,导致两者之间的不平衡。通过调整判别器的损失权重,可以控制判别器的学习速度,使其与生成器保持同步。'''lambda_dis = 0.5# 通过生成器生成假图像fakeb = net_generator(reala)# 判别器对真实图像和生成的假图像进行预测pred0 = net_discriminator(reala, fakeb)pred1 = net_discriminator(reala, realb)# 计算判别器的损失loss_d = loss_f(pred1, ops.ones_like(pred1)) + loss_f(pred0, ops.zeros_like(pred0))loss_dis = loss_d * lambda_disreturn loss_disdef forword_gan(reala, realb):lambda_gan = 0.5lambda_l1 = 100fakeb = net_generator(reala)pred0 = net_discriminator(reala, fakeb)loss_1 = loss_f(pred0, ops.ones_like(pred0))loss_2 = l1_loss(fakeb, realb)loss_gan = loss_1 * lambda_gan + loss_2 * lambda_l1return loss_gan# 优化器
d_opt = nn.Adam(net_discriminator.trainable_params(), learning_rate=get_lr(),beta1=0.5, beta2=0.999, loss_scale=1)
g_opt = nn.Adam(net_generator.trainable_params(), learning_rate=get_lr(),beta1=0.5, beta2=0.999, loss_scale=1)# 计算梯度
grad_d = value_and_grad(forword_dis, None, net_discriminator.trainable_params())
grad_g = value_and_grad(forword_gan, None, net_generator.trainable_params())# 训练方法
def train_step(reala, realb):loss_dis, d_grads = grad_d(reala, realb)loss_gan, g_grads = grad_g(reala, realb)d_opt(d_grads)g_opt(g_grads)return loss_dis, loss_ganif not os.path.isdir(ckpt_dir):os.makedirs(ckpt_dir)g_losses = []
d_losses = []
data_loader = dataset.create_dict_iterator(output_numpy=True, num_epochs=epoch_num)for epoch in range(epoch_num):for i, data in enumerate(data_loader):start_time = datetime.datetime.now()# 输入图像input_image = Tensor(data["input_images"])# 目标图像target_image = Tensor(data["target_images"])# 进行训练dis_loss, gen_loss = train_step(input_image, target_image)end_time = datetime.datetime.now()delta = (end_time - start_time).microsecondsif i % 2 == 0:print("ms per step:{:.2f} epoch:{}/{} step:{}/{} Dloss:{:.4f} Gloss:{:.4f} ".format((delta / 1000), (epoch + 1), (epoch_num), i, steps_per_epoch, float(dis_loss), float(gen_loss)))d_losses.append(dis_loss.asnumpy())g_losses.append(gen_loss.asnumpy())if (epoch + 1) == epoch_num:mindspore.save_checkpoint(net_generator, ckpt_dir + "Generator.ckpt")
模型推理
from mindspore import load_checkpoint, load_param_into_net
# 模型推理
param_g = load_checkpoint(ckpt_dir + "Generator.ckpt")
load_param_into_net(net_generator, param_g)
dataset = ds.MindDataset("./dataset/dataset_pix2pix/train.mindrecord", columns_list=["input_images", "target_images"], shuffle=True)
data_iter = next(dataset.create_dict_iterator())
predict_show = net_generator(data_iter["input_images"])
plt.figure(figsize=(10, 3), dpi=140)
for i in range(10):plt.subplot(2, 10, i + 1)plt.imshow((data_iter["input_images"][i].asnumpy().transpose(1, 2, 0) + 1) / 2)plt.axis("off")plt.subplots_adjust(wspace=0.05, hspace=0.02)plt.subplot(2, 10, i + 11)plt.imshow((predict_show[i].asnumpy().transpose(1, 2, 0) + 1) / 2)plt.axis("off")plt.subplots_adjust(wspace=0.05, hspace=0.02)
plt.show()
训练结果如下: