Pytorch从零开始实战——CycleGAN实战
本系列来源于365天深度学习训练营
原作者K同学
内容介绍
CycleGAN是一种无监督图像到图像转换模型,它的一个重要应用领域是域迁移,比如可以把一张普通的风景照转换成梵高风格的画作,或者将游戏画面转换成真实世界画面,将一匹普通的马转换为斑马等等。
CycleGAN 主要解决的问题是将一个域中的图像转换到另一个域中的图像,而无需成对的训练数据。这种转换是双向的,即可以从一个域转换到另一个域,也可以反过来转换。
生成器: CycleGAN 包含两个生成器,分别用于将两个不同域的图像进行转换。例如,在从马到斑马的转换中,一个生成器负责将马的图像转换为斑马的图像,另一个生成器负责将斑马的图像转换为马的图像。生成器学习将输入图像从一个域映射到另一个域的转换函数。
判别器: CycleGAN 包含两个判别器,分别对应两个图像域,用于区分生成的图像和真实的图像。判别器 D_A 用于区分生成的 A 域图像和真实的 A 域图像,判别器 D_B 用于区分生成的 B 域图像和真实的 B 域图像。判别器通过与生成器对抗,促使生成器生成更逼真的图像。
损失函数:CycleGAN的Loss由三部分组成,分别为LossGAN(对抗损失,保证生成器和判别器相互进化,进而保证生成器能产生更真实的图片)、LossCycle(循环一致性损失,保证生成器的输出图片与输入图片只是风格不同,而内容相同)和LossIdentity(恒等映射损失,即把真实的 A 输入到负责生成 A 的生成器中,检查生成器是否会原封不动地输出)。
数据集类
自定义的 PyTorch 数据集类,用于加载图像数据集并进行预处理。
import glob
import random
import os

from torch.utils.data import Dataset
from PIL import Image
import torchvision.transforms as transforms


def to_rgb(image):
    """Return an RGB version of ``image``.

    A new RGB image of the same size is created and ``image`` is pasted
    onto it.
    """
    rgb_image = Image.new("RGB", image.size)
    rgb_image.paste(image)
    return rgb_image


class ImageDataset(Dataset):
    """Dataset yielding one image from each of two folders per item.

    Files are read from ``<root>/<mode>A`` and ``<root>/<mode>B``; every item
    is a dict ``{"A": ..., "B": ...}`` holding the transformed images.

    Args:
        root: Directory containing the ``<mode>A`` and ``<mode>B`` folders.
        transforms_: List of transforms applied to both images. ``None``
            (the default) applies no transforms.
        unaligned: When true, the B image is drawn at random instead of
            being selected by ``index``.
        mode: Folder prefix, e.g. ``"train"`` or ``"test"``.
    """

    def __init__(self, root, transforms_=None, unaligned=False, mode="train"):
        # Compose(None) builds fine but raises on first call because it
        # iterates over its argument; fall back to an empty pipeline.
        self.transform = transforms.Compose(transforms_ or [])
        self.unaligned = unaligned

        self.files_A = sorted(glob.glob(os.path.join(root, "%sA" % mode) + "/*.*"))
        self.files_B = sorted(glob.glob(os.path.join(root, "%sB" % mode) + "/*.*"))

    def __getitem__(self, index):
        # Indices wrap around so the shorter folder is reused.
        image_A = Image.open(self.files_A[index % len(self.files_A)])

        if self.unaligned:
            image_B = Image.open(self.files_B[random.randint(0, len(self.files_B) - 1)])
        else:
            image_B = Image.open(self.files_B[index % len(self.files_B)])

        # Convert non-RGB images (e.g. grayscale) to RGB
        if image_A.mode != "RGB":
            image_A = to_rgb(image_A)
        if image_B.mode != "RGB":
            image_B = to_rgb(image_B)

        item_A = self.transform(image_A)
        item_B = self.transform(image_B)
        return {"A": item_A, "B": item_B}

    def __len__(self):
        # The dataset is as long as the larger of the two folders.
        return max(len(self.files_A), len(self.files_B))
模型实现
遍历模型中的每一层,初始化神经网络模型中的权重。
import torch.nn as nn
import torch.nn.functional as F
import torch


def weights_init_normal(m):
    """Initialize the parameters of a single module in place.

    Intended to be passed to ``Module.apply`` so it visits every layer.
    Modules are matched by class name:

    * names containing ``"Conv"``: weights drawn from N(0.0, 0.02),
      bias (when present) set to 0;
    * names containing ``"BatchNorm2d"``: weights drawn from N(1.0, 0.02),
      bias set to 0.

    Any other module is left untouched.

    Args:
        m: The module to initialize.
    """
    classname = m.__class__.__name__
    if "Conv" in classname:
        torch.nn.init.normal_(m.weight.data, 0.0, 0.02)
        if getattr(m, "bias", None) is not None:
            torch.nn.init.constant_(m.bias.data, 0.0)
    elif "BatchNorm2d" in classname:
        # With affine=False the layer has weight/bias set to None, so there
        # is nothing to initialize and `.data` would raise AttributeError.
        if m.weight is not None:
            torch.nn.init.normal_(m.weight.data, 1.0, 0.02)
        if m.bias is not None:
            torch.nn.init.constant_(m.bias.data, 0.0)
定义了一个残差块。每个残差块包含两个卷积层:每个卷积层之前先使用反射填充(ReflectionPad2d)进行填充,卷积之后进行实例归一化,其中第一个卷积层之后还接有 ReLU 激活。最后通过残差连接将输入和残差块的输出相加得到最终的输出。
class ResidualBlock(nn.Module):
    """Residual unit: two padded 3x3 conv stages plus a skip connection.

    Each stage is reflection padding (1 pixel), a 3x3 convolution that keeps
    the channel count, and instance normalization. A ReLU sits between the
    two stages.

    Args:
        in_features: Number of input (and output) channels.
    """

    def __init__(self, in_features):
        super().__init__()

        def conv_stage():
            # Pad by 1 so the unpadded 3x3 convolution keeps the spatial size.
            return [
                nn.ReflectionPad2d(1),
                nn.Conv2d(in_features, in_features, 3),
                nn.InstanceNorm2d(in_features),
            ]

        layers = conv_stage()
        layers.append(nn.ReLU(inplace=True))
        layers.extend(conv_stage())
        self.block = nn.Sequential(*layers)

    def forward(self, x):
        """Return ``x`` plus the output of the convolutional branch."""
        residual = self.block(x)
        return x + residual
定义了基于 ResNet 结构的生成器。它通过堆叠多个残差块、卷积层和上采样层来生成图像。首先是一个初始的卷积块,然后进行下采样、残差块、上采样,最后输出目标图像。
class GeneratorResNet(nn.Module):
    """ResNet-style image-to-image generator.

    Layout: an initial 7x7 convolution block, two stride-2 downsampling
    convolutions (doubling the channels each time), ``num_residual_blocks``
    residual blocks, two upsample+convolution stages (halving the channels
    each time) and a final 7x7 convolution followed by ``Tanh``.

    Args:
        input_shape: Shape of one input image; only the first element
            (the channel count) is read.
        num_residual_blocks: Number of residual blocks in the bottleneck.
    """

    def __init__(self, input_shape, num_residual_blocks):
        super(GeneratorResNet, self).__init__()

        channels = input_shape[0]

        # A 7x7 convolution needs kernel // 2 = 3 pixels of padding to keep
        # the spatial size. This must not depend on the channel count,
        # otherwise non-3-channel inputs come out with a different size.
        edge_kernel = 7
        edge_pad = edge_kernel // 2

        # Initial convolution block
        out_features = 64
        model = [
            nn.ReflectionPad2d(edge_pad),
            nn.Conv2d(channels, out_features, edge_kernel),
            nn.InstanceNorm2d(out_features),
            nn.ReLU(inplace=True),
        ]
        in_features = out_features

        # Downsampling
        for _ in range(2):
            out_features *= 2
            model += [
                nn.Conv2d(in_features, out_features, 3, stride=2, padding=1),
                nn.InstanceNorm2d(out_features),
                nn.ReLU(inplace=True),
            ]
            in_features = out_features

        # Residual blocks
        for _ in range(num_residual_blocks):
            model += [ResidualBlock(out_features)]

        # Upsampling
        for _ in range(2):
            out_features //= 2
            model += [
                nn.Upsample(scale_factor=2),
                nn.Conv2d(in_features, out_features, 3, stride=1, padding=1),
                nn.InstanceNorm2d(out_features),
                nn.ReLU(inplace=True),
            ]
            in_features = out_features

        # Output layer
        model += [
            nn.ReflectionPad2d(edge_pad),
            nn.Conv2d(out_features, channels, edge_kernel),
            nn.Tanh(),
        ]

        self.model = nn.Sequential(*model)

    def forward(self, x):
        """Run ``x`` through the generator and return the result."""
        return self.model(x)
定义了判别器(PatchGAN)。这个判别器由多个卷积层组成,逐渐减小特征图的大小,最后输出一个单通道的特征图,其中每个位置的值表示输入图像对应区域为真实图像的评分。
class Discriminator(nn.Module):
    """Convolutional (PatchGAN) discriminator producing a one-channel map.

    Four stride-2 4x4 convolutions (64, 128, 256, 512 filters) each followed
    by ``LeakyReLU(0.2)``; all but the first also use instance
    normalization. A zero-padded 4x4 convolution then reduces the result to
    a single channel.

    Args:
        input_shape: ``(channels, height, width)`` of the input images.

    Attributes:
        output_shape: ``(1, height // 16, width // 16)``, the per-image
            shape of the discriminator output.
    """

    def __init__(self, input_shape):
        super().__init__()

        channels, height, width = input_shape

        # Four stride-2 stages shrink each spatial dimension by 2 ** 4.
        self.output_shape = (1, height // 2 ** 4, width // 2 ** 4)

        layers = []
        previous = channels
        for stage, filters in enumerate((64, 128, 256, 512)):
            layers.append(nn.Conv2d(previous, filters, 4, stride=2, padding=1))
            # The first stage is not normalized.
            if stage > 0:
                layers.append(nn.InstanceNorm2d(filters))
            layers.append(nn.LeakyReLU(0.2, inplace=True))
            previous = filters

        layers.append(nn.ZeroPad2d((1, 0, 1, 0)))
        layers.append(nn.Conv2d(previous, 1, 4, padding=1))

        self.model = nn.Sequential(*layers)

    def forward(self, img):
        """Return the discriminator output map for ``img``."""
        return self.model(img)
开始训练
Util工具类。ReplayBuffer 用于创建一个缓冲区,存储生成器之前生成的图像;训练判别器时,会从缓冲区中随机取出历史生成图像与当前生成图像混合使用。LambdaLR 则用于在训练过程中根据指定的规则调整学习率:从 decay_start_epoch 开始,学习率系数线性衰减到 0。
import random
import time
import datetime
import sys

from torch.autograd import Variable
import torch
import numpy as np

from torchvision.utils import save_image


class ReplayBuffer:
    """Fixed-capacity pool of previously seen samples.

    Args:
        max_size: Maximum number of samples kept in the pool; must be
            positive.
    """

    def __init__(self, max_size=50):
        assert max_size > 0, "Empty buffer or trying to create a black hole. Be careful."
        self.max_size = max_size
        self.data = []

    def push_and_pop(self, data):
        """Store the samples of ``data`` and return a batch of equal size.

        While the pool is not full, every incoming sample is stored and
        returned as is. Once it is full, each incoming sample either
        replaces a randomly chosen stored sample (a clone of the stored one
        is returned instead) or is returned directly without being stored.

        Args:
            data: Batch whose ``.data`` is iterated sample by sample.

        Returns:
            The selected samples concatenated along the first dimension.
        """
        batch_out = []
        for sample in data.data:
            # Restore the leading batch dimension dropped by iteration.
            sample = torch.unsqueeze(sample, 0)

            if len(self.data) < self.max_size:
                self.data.append(sample)
                batch_out.append(sample)
                continue

            if random.uniform(0, 1) > 0.5:
                slot = random.randint(0, self.max_size - 1)
                batch_out.append(self.data[slot].clone())
                self.data[slot] = sample
            else:
                batch_out.append(sample)

        return Variable(torch.cat(batch_out))


class LambdaLR:
    """Learning-rate multiplier that decays linearly after a given epoch.

    Args:
        n_epochs: Total number of training epochs.
        offset: Value added to the epoch passed to ``step``.
        decay_start_epoch: Epoch at which the decay begins; must be smaller
            than ``n_epochs``.
    """

    def __init__(self, n_epochs, offset, decay_start_epoch):
        assert (n_epochs - decay_start_epoch) > 0, "Decay must start before the training session ends!"
        self.n_epochs = n_epochs
        self.offset = offset
        self.decay_start_epoch = decay_start_epoch

    def step(self, epoch):
        """Return the multiplier for ``epoch``.

        The value is 1.0 until ``epoch + offset`` passes
        ``decay_start_epoch`` and then falls linearly, reaching 0.0 when
        ``epoch + offset`` equals ``n_epochs``.
        """
        decay_span = self.n_epochs - self.decay_start_epoch
        epochs_into_decay = max(0, epoch + self.offset - self.decay_start_epoch)
        return 1.0 - epochs_into_decay / decay_span
设置训练参数,包括 epoch 数、数据集名称、批量大小、学习率。接着定义模型和优化器,包括生成器、判别器、损失函数和优化器。加载数据集并进行数据预处理,设置训练和测试数据加载器。
import argparse
import itertools
from torchvision.utils import save_image, make_grid
from torch.utils.data import DataLoader
from models import *
from datasets import *
from utils import *
import torch

# Command-line options for the training script.
parser = argparse.ArgumentParser()
parser.add_argument("--epoch", type=int, default=0, help="epoch to start training from")
parser.add_argument("--n_epochs", type=int, default=200, help="number of epochs of training")
parser.add_argument("--dataset_name", type=str, default="monet2photo", help="name of the dataset")
parser.add_argument("--batch_size", type=int, default=1, help="size of the batches")
parser.add_argument("--lr", type=float, default=0.0002, help="adam: learning rate")
parser.add_argument("--b1", type=float, default=0.5, help="adam: decay of first order momentum of gradient")
# b2 is passed as the second Adam beta, i.e. the second-order moment decay.
parser.add_argument("--b2", type=float, default=0.999, help="adam: decay of second order momentum of gradient")
parser.add_argument("--decay_epoch", type=int, default=100, help="epoch from which to start lr decay")
parser.add_argument("--n_cpu", type=int, default=8, help="number of cpu threads to use during batch generation")
parser.add_argument("--img_height", type=int, default=256, help="size of image height")
parser.add_argument("--img_width", type=int, default=256, help="size of image width")
parser.add_argument("--channels", type=int, default=3, help="number of image channels")
parser.add_argument("--sample_interval", type=int, default=100, help="interval between saving generator outputs")
parser.add_argument("--checkpoint_interval", type=int, default=1, help="interval between saving model checkpoints")
parser.add_argument("--n_residual_blocks", type=int, default=9, help="number of residual blocks in generator")
parser.add_argument("--lambda_cyc", type=float, default=10.0, help="cycle loss weight")
parser.add_argument("--lambda_id", type=float, default=5.0, help="identity loss weight")
opt = parser.parse_args()
print(opt)

# Create sample and checkpoint directories
os.makedirs("images/%s" % opt.dataset_name, exist_ok=True)
os.makedirs("saved_models/%s" % opt.dataset_name, exist_ok=True)

# Losses
criterion_GAN = torch.nn.MSELoss()
criterion_cycle = torch.nn.L1Loss()
criterion_identity = torch.nn.L1Loss()

cuda = torch.cuda.is_available()

input_shape = (opt.channels, opt.img_height, opt.img_width)

# Initialize generators and discriminators
G_AB = GeneratorResNet(input_shape, opt.n_residual_blocks)
G_BA = GeneratorResNet(input_shape, opt.n_residual_blocks)
D_A = Discriminator(input_shape)
D_B = Discriminator(input_shape)

if cuda:
    G_AB = G_AB.cuda()
    G_BA = G_BA.cuda()
    D_A = D_A.cuda()
    D_B = D_B.cuda()
    criterion_GAN.cuda()
    criterion_cycle.cuda()
    criterion_identity.cuda()

if opt.epoch != 0:
    # Load pretrained models. map_location="cpu" lets checkpoints saved on a
    # GPU be read on a CPU-only machine; load_state_dict then copies the
    # values into the parameters wherever the model currently lives.
    G_AB.load_state_dict(
        torch.load("saved_models/%s/G_AB_%d.pth" % (opt.dataset_name, opt.epoch), map_location="cpu")
    )
    G_BA.load_state_dict(
        torch.load("saved_models/%s/G_BA_%d.pth" % (opt.dataset_name, opt.epoch), map_location="cpu")
    )
    D_A.load_state_dict(
        torch.load("saved_models/%s/D_A_%d.pth" % (opt.dataset_name, opt.epoch), map_location="cpu")
    )
    D_B.load_state_dict(
        torch.load("saved_models/%s/D_B_%d.pth" % (opt.dataset_name, opt.epoch), map_location="cpu")
    )
else:
    # Initialize weights
    G_AB.apply(weights_init_normal)
    G_BA.apply(weights_init_normal)
    D_A.apply(weights_init_normal)
    D_B.apply(weights_init_normal)

# Optimizers (both generators share a single optimizer)
optimizer_G = torch.optim.Adam(
    itertools.chain(G_AB.parameters(), G_BA.parameters()), lr=opt.lr, betas=(opt.b1, opt.b2)
)
optimizer_D_A = torch.optim.Adam(D_A.parameters(), lr=opt.lr, betas=(opt.b1, opt.b2))
optimizer_D_B = torch.optim.Adam(D_B.parameters(), lr=opt.lr, betas=(opt.b1, opt.b2))

# Learning rate update schedulers
lr_scheduler_G = torch.optim.lr_scheduler.LambdaLR(
    optimizer_G, lr_lambda=LambdaLR(opt.n_epochs, opt.epoch, opt.decay_epoch).step
)
lr_scheduler_D_A = torch.optim.lr_scheduler.LambdaLR(
    optimizer_D_A, lr_lambda=LambdaLR(opt.n_epochs, opt.epoch, opt.decay_epoch).step
)
lr_scheduler_D_B = torch.optim.lr_scheduler.LambdaLR(
    optimizer_D_B, lr_lambda=LambdaLR(opt.n_epochs, opt.epoch, opt.decay_epoch).step
)

Tensor = torch.cuda.FloatTensor if cuda else torch.Tensor

# Buffers of previously generated samples
fake_A_buffer = ReplayBuffer()
fake_B_buffer = ReplayBuffer()

# Image transformations
transforms_ = [
    transforms.Resize(int(opt.img_height * 1.12), Image.BICUBIC),
    transforms.RandomCrop((opt.img_height, opt.img_width)),
    transforms.RandomHorizontalFlip(),
    transforms.ToTensor(),
    transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
]

# Training data loader
dataloader = DataLoader(
    ImageDataset("./data/%s/" % opt.dataset_name, transforms_=transforms_, unaligned=True),
    batch_size=opt.batch_size,
    shuffle=True,
    num_workers=opt.n_cpu,
)
# Test data loader
val_dataloader = DataLoader(
    ImageDataset("./data/%s/" % opt.dataset_name, transforms_=transforms_, unaligned=True, mode="test"),
    batch_size=5,
    shuffle=True,
    num_workers=1,
)


def sample_images(batches_done):
    """Saves a generated sample from the test set

    One batch is taken from ``val_dataloader`` and translated in both
    directions. The result is written to
    ``images/<dataset_name>/<batches_done>.png`` as four stacked rows:
    real A, fake B, real B, fake A.

    Args:
        batches_done: Number of batches processed so far; used as the
            output file name.
    """
    imgs = next(iter(val_dataloader))
    G_AB.eval()
    G_BA.eval()
    # Sampling only needs forward passes, so skip building autograd graphs.
    with torch.no_grad():
        real_A = Variable(imgs["A"].type(Tensor))
        fake_B = G_AB(real_A)
        real_B = Variable(imgs["B"].type(Tensor))
        fake_A = G_BA(real_B)
    # Arrange images along x-axis
    real_A = make_grid(real_A, nrow=5, normalize=True)
    real_B = make_grid(real_B, nrow=5, normalize=True)
    fake_A = make_grid(fake_A, nrow=5, normalize=True)
    fake_B = make_grid(fake_B, nrow=5, normalize=True)
    # Arrange images along y-axis
    image_grid = torch.cat((real_A, fake_B, real_B, fake_A), 1)
    save_image(image_grid, "images/%s/%s.png" % (opt.dataset_name, batches_done), normalize=False)


# ----------
# Training
# ----------
if __name__ == '__main__':
    prev_time = time.time()
    for epoch in range(opt.epoch, opt.n_epochs):
        for i, batch in enumerate(dataloader):

            # Set model input
            real_A = Variable(batch["A"].type(Tensor))
            real_B = Variable(batch["B"].type(Tensor))

            # Adversarial ground truths, shaped like the discriminator output
            valid = Variable(Tensor(np.ones((real_A.size(0), *D_A.output_shape))), requires_grad=False)
            fake = Variable(Tensor(np.zeros((real_A.size(0), *D_A.output_shape))), requires_grad=False)

            # ------------------
            #  Train Generators
            # ------------------

            G_AB.train()
            G_BA.train()

            optimizer_G.zero_grad()

            # Identity loss: a generator fed an image of its own output
            # domain is compared against that same image
            loss_id_A = criterion_identity(G_BA(real_A), real_A)
            loss_id_B = criterion_identity(G_AB(real_B), real_B)

            loss_identity = (loss_id_A + loss_id_B) / 2

            # GAN loss
            fake_B = G_AB(real_A)
            loss_GAN_AB = criterion_GAN(D_B(fake_B), valid)
            fake_A = G_BA(real_B)
            loss_GAN_BA = criterion_GAN(D_A(fake_A), valid)

            loss_GAN = (loss_GAN_AB + loss_GAN_BA) / 2

            # Cycle loss
            recov_A = G_BA(fake_B)
            loss_cycle_A = criterion_cycle(recov_A, real_A)
            recov_B = G_AB(fake_A)
            loss_cycle_B = criterion_cycle(recov_B, real_B)

            loss_cycle = (loss_cycle_A + loss_cycle_B) / 2

            # Total loss
            loss_G = loss_GAN + opt.lambda_cyc * loss_cycle + opt.lambda_id * loss_identity

            loss_G.backward()
            optimizer_G.step()

            # -----------------------
            #  Train Discriminator A
            # -----------------------

            optimizer_D_A.zero_grad()

            # Real loss
            loss_real = criterion_GAN(D_A(real_A), valid)
            # Fake loss (on batch of previously generated samples)
            fake_A_ = fake_A_buffer.push_and_pop(fake_A)
            loss_fake = criterion_GAN(D_A(fake_A_.detach()), fake)
            # Total loss
            loss_D_A = (loss_real + loss_fake) / 2

            loss_D_A.backward()
            optimizer_D_A.step()

            # -----------------------
            #  Train Discriminator B
            # -----------------------

            optimizer_D_B.zero_grad()

            # Real loss
            loss_real = criterion_GAN(D_B(real_B), valid)
            # Fake loss (on batch of previously generated samples)
            fake_B_ = fake_B_buffer.push_and_pop(fake_B)
            loss_fake = criterion_GAN(D_B(fake_B_.detach()), fake)
            # Total loss
            loss_D_B = (loss_real + loss_fake) / 2

            loss_D_B.backward()
            optimizer_D_B.step()

            loss_D = (loss_D_A + loss_D_B) / 2

            # --------------
            #  Log Progress
            # --------------

            # Determine approximate time left
            batches_done = epoch * len(dataloader) + i
            batches_left = opt.n_epochs * len(dataloader) - batches_done
            time_left = datetime.timedelta(seconds=batches_left * (time.time() - prev_time))
            prev_time = time.time()

            # Print log
            sys.stdout.write(
                "\r[Epoch %d/%d] [Batch %d/%d] [D loss: %f] [G loss: %f, adv: %f, cycle: %f, identity: %f] ETA: %s"
                % (
                    epoch,
                    opt.n_epochs,
                    i,
                    len(dataloader),
                    loss_D.item(),
                    loss_G.item(),
                    loss_GAN.item(),
                    loss_cycle.item(),
                    loss_identity.item(),
                    time_left,
                )
            )

            # If at sample interval save image
            if batches_done % opt.sample_interval == 0:
                sample_images(batches_done)

        # Update learning rates
        lr_scheduler_G.step()
        lr_scheduler_D_A.step()
        lr_scheduler_D_B.step()

        if opt.checkpoint_interval != -1 and epoch % opt.checkpoint_interval == 0:
            # Save model checkpoints into the same "saved_models" directory
            # that is created at start-up and read when resuming with --epoch.
            torch.save(G_AB.state_dict(), "saved_models/%s/G_AB_%d.pth" % (opt.dataset_name, epoch))
            torch.save(G_BA.state_dict(), "saved_models/%s/G_BA_%d.pth" % (opt.dataset_name, epoch))
            torch.save(D_A.state_dict(), "saved_models/%s/D_A_%d.pth" % (opt.dataset_name, epoch))
            torch.save(D_B.state_dict(), "saved_models/%s/D_B_%d.pth" % (opt.dataset_name, epoch))
本次实验设备较差,算力不够。请读者在GPU机器上自行运行。
总结
CycleGAN 可以用于学习两个不同图像域之间的映射关系,使得在两个域之间进行图像转换成为可能。通过训练,模型可以学习到如何将一个图像从一个域转换到另一个域,而无需配对的训练数据,降低了数据收集和标注的成本。其提出的不同角度的损失函数,也是值得我们去学习。