CycleGAN模型解读(附源码+论文)

CycleGAN

论文链接:Unpaired Image-to-Image Translation using Cycle-Consistent Adversarial Networks

官方链接:pytorch-CycleGAN-and-pix2pix

老规矩,先看看效果
请添加图片描述

总体流程

先简单过一遍流程,细节在代码里说。CycleGAN有两个生成器和两个判别器。下图可以看到GeneratorA2BGeneratorB2A,假设A数据集里全是马,B数据集里全是斑马,那这两个分别就是将马生成斑马和斑马生成马的生成器。DiscriminatorADiscriminatorB就是用来判别马和斑马的判别器。

请添加图片描述

上面的流程图很清晰的表示了模型的整体过程。一般生成器将马生成斑马就结束了,但是CycleGAN表示nonono,它还将生成的斑马再还原回马。因为 CycleGAN 强调循环一致性(Cycle Consistency),它通过让生成的斑马再还原回马,确保生成的斑马仍然保留了输入马的基本特征和结构。也就是说,CycleGAN 不仅要生成目标域的图像,还要通过反向转换验证生成图像的合理性。
请添加图片描述

同样的流程还会再经历一次,只不过这里的马和斑马的位置对调了一下(输入斑马,生成马)。

生成模型的损失函数有三种,拆分为六部分,分别对应对抗损失循环一致性损失身份损失。容易想到的,对抗损失用于训练生成器 A2B 和B2A 生成能够欺骗判别器的图像,可以理解为将真马生成为假斑马能不能让判别器判断为真的。上面流程不是有一步生成了假图还要还原回去,没错,还原回去的图与真实的图的差异,即循环一致性损失身份损失是为了让生成器保留图像的风格或颜色特征,可以理解成除了马变成了斑马,其余背景部分不做变化。

将一匹马变成一匹斑马,我在没看这篇论文之前想的流程是,有一张马的图片,然后再做一张将马变为斑马的标签图,这样可以将生成的图片与这张标签图做损失计算进行训练。但CycleGAN根本不需要标签图,它只需要马的图片与斑马的图片就行,不需要什么人工处理的标签图,简直是福音,人工标注鬼知道有多麻烦。就像下图一样,素描鞋子作为输入,皮鞋作为标签图,需要一一对应。而现在,照片作为输入,标签图是毫不相干的油画图。
请添加图片描述

数据准备

这里说的详细一点,怎么准备数据。

公开数据集地址:Index of /cyclegan/datasets

请添加图片描述

下载自己想玩的数据集,比如horse2zebra就是马转斑马的。下载好的数据集放在datasets目录下。

预训练模型地址:Index of /cyclegan/pretrained_models

请添加图片描述

下载对应的预训练模型哦,比如horse2zebra.pth就是马转斑马的。下载好的模型放在checkpoints目录下。

别的配置看github上有详细说明。

预警:模型训练需要显卡显存至少8G及以上。

代码

数据处理

运行train.py。先进入unaligned_dataset.py看看数据集怎么处理的。

def __init__(self, opt):BaseDataset.__init__(self, opt)self.dir_A = os.path.join(opt.dataroot, opt.phase + 'A')  # create a path '/path/to/data/trainA'self.dir_B = os.path.join(opt.dataroot, opt.phase + 'B')  # create a path '/path/to/data/trainB'self.A_paths = sorted(make_dataset(self.dir_A, opt.max_dataset_size))  # load images from '/path/to/data/trainA'self.B_paths = sorted(make_dataset(self.dir_B, opt.max_dataset_size))  # load images from '/path/to/data/trainB'self.A_size = len(self.A_paths)  # get the size of dataset Aself.B_size = len(self.B_paths)  # get the size of dataset BbtoA = self.opt.direction == 'BtoA'input_nc = self.opt.output_nc if btoA else self.opt.input_nc  # get the number of channels of input imageoutput_nc = self.opt.input_nc if btoA else self.opt.output_nc  # get the number of channels of output imageself.transform_A = get_transform(self.opt, grayscale=(input_nc == 1))self.transform_B = get_transform(self.opt, grayscale=(output_nc == 1))

先获得训练数据集A和B的路径,然后看看路径下有多少图片,数据集A和B的图片数量可以不一样哦。因为就像我上面说的,不需要一一对应的标签图,所以取出一张A可以从数据集B中随机抽一张作为标签图。然后会通过transform做一些图片处理,比如resize、裁剪、翻转、标准化之类的。

model

进入cycle_gan_model.py看看模型怎么初始化和训练的。

self.netG_A = networks.define_G(opt.input_nc, opt.output_nc, opt.ngf, opt.netG, opt.norm,not opt.no_dropout, opt.init_type, opt.init_gain, self.gpu_ids)
self.netG_B = networks.define_G(opt.output_nc, opt.input_nc, opt.ngf, opt.netG, opt.norm,not opt.no_dropout, opt.init_type, opt.init_gain, self.gpu_ids)if self.isTrain:  # define discriminatorsself.netD_A = networks.define_D(opt.output_nc, opt.ndf, opt.netD,opt.n_layers_D, opt.norm, opt.init_type, opt.init_gain, self.gpu_ids)self.netD_B = networks.define_D(opt.input_nc, opt.ndf, opt.netD,opt.n_layers_D, opt.norm, opt.init_type, opt.init_gain, self.gpu_ids)

在初始化时,最重要的是生成器和判别器的定义,我们去networks.py里看看define_G和define_D的定义。

define_G

def define_G(input_nc, output_nc, ngf, netG, norm='batch', use_dropout=False, init_type='normal', init_gain=0.02, gpu_ids=[]):net = Nonenorm_layer = get_norm_layer(norm_type=norm)if netG == 'resnet_9blocks':net = ResnetGenerator(input_nc, output_nc, ngf, norm_layer=norm_layer, use_dropout=use_dropout, n_blocks=9)elif netG == 'resnet_6blocks':net = ResnetGenerator(input_nc, output_nc, ngf, norm_layer=norm_layer, use_dropout=use_dropout, n_blocks=6)elif netG == 'unet_128':net = UnetGenerator(input_nc, output_nc, 7, ngf, norm_layer=norm_layer, use_dropout=use_dropout)elif netG == 'unet_256':net = UnetGenerator(input_nc, output_nc, 8, ngf, norm_layer=norm_layer, use_dropout=use_dropout)else:raise NotImplementedError('Generator model name [%s] is not recognized' % netG)return init_net(net, init_type, init_gain, gpu_ids)

图片卷积三件套Conv->BN->relu。不过这里的BN换了一下,上面代码用了norm_layer代替,即nn.BatchNorm2d转为nn.InstanceNorm2d,不同点在于InstanceNorm2d对每个样本的特征通道分别进行归一化。构建网络时用的是resnet_9blocks,我们进ResnetGenerator看一看。

ResnetGenerator

def __init__(self, input_nc, output_nc, ngf=64, norm_layer=nn.BatchNorm2d, use_dropout=False, n_blocks=6,padding_type='reflect'):assert (n_blocks >= 0)super(ResnetGenerator, self).__init__()if type(norm_layer) == functools.partial:use_bias = norm_layer.func == nn.InstanceNorm2delse:use_bias = norm_layer == nn.InstanceNorm2dmodel = [nn.ReflectionPad2d(3),nn.Conv2d(input_nc, ngf, kernel_size=7, padding=0, bias=use_bias),norm_layer(ngf),nn.ReLU(True)]n_downsampling = 2for i in range(n_downsampling):  # add downsampling layersmult = 2 ** imodel += [nn.Conv2d(ngf * mult, ngf * mult * 2, kernel_size=3, stride=2, padding=1, bias=use_bias),norm_layer(ngf * mult * 2),nn.ReLU(True)]mult = 2 ** n_downsamplingfor i in range(n_blocks):  # add ResNet blocksmodel += [ResnetBlock(ngf * mult, padding_type=padding_type, norm_layer=norm_layer, use_dropout=use_dropout,use_bias=use_bias)]for i in range(n_downsampling):  # add upsampling layersmult = 2 ** (n_downsampling - i)model += [nn.ConvTranspose2d(ngf * mult, int(ngf * mult / 2),kernel_size=3, stride=2,padding=1, output_padding=1,bias=use_bias),norm_layer(int(ngf * mult / 2)),nn.ReLU(True)]model += [nn.ReflectionPad2d(3)]model += [nn.Conv2d(ngf, output_nc, kernel_size=7, padding=0)]model += [nn.Tanh()]self.model = nn.Sequential(*model)

在model初始构建时,我们发现卷积三件套上多了一件nn.ReflectionPad2d(3),这是padding的一种方式,不过不同于简单的填充,周围都填充为0,而是通过在输入特征图边界处 镜像反射 填充,如下图所示。
请添加图片描述

模型构建下面就比较简单,就是先通过卷积给图像size越卷越小,不过特征图越卷越多(3->64->128->256)。然后来点残差块ResnetBlock。生成器最后生成的图得跟原图一样大小,所以通过反卷积nn.ConvTranspose2d,给图卷回原来大小,就类似上采样。同理特征图数量也得越卷越少,卷回原来的3通道(256->128->64->3)。

define_D

我们再看看判别器怎么定义的。

def define_D(input_nc, ndf, netD, n_layers_D=3, norm='batch', init_type='normal', init_gain=0.02, gpu_ids=[]):net = Nonenorm_layer = get_norm_layer(norm_type=norm)if netD == 'basic':  # default PatchGAN classifiernet = NLayerDiscriminator(input_nc, ndf, n_layers=3, norm_layer=norm_layer)elif netD == 'n_layers':  # more optionsnet = NLayerDiscriminator(input_nc, ndf, n_layers_D, norm_layer=norm_layer)elif netD == 'pixel':     # classify if each pixel is real or fakenet = PixelDiscriminator(input_nc, ndf, norm_layer=norm_layer)else:raise NotImplementedError('Discriminator model name [%s] is not recognized' % netD)return init_net(net, init_type, init_gain, gpu_ids)

跟生成器一样吼,先定义一个norm_layer,然后进入NLayerDiscriminator看看。

NLayerDiscriminator

def __init__(self, input_nc, ndf=64, n_layers=3, norm_layer=nn.BatchNorm2d):super(NLayerDiscriminator, self).__init__()if type(norm_layer) == functools.partial:  # no need to use bias as BatchNorm2d has affine parametersuse_bias = norm_layer.func == nn.InstanceNorm2delse:use_bias = norm_layer == nn.InstanceNorm2dkw = 4padw = 1sequence = [nn.Conv2d(input_nc, ndf, kernel_size=kw, stride=2, padding=padw), nn.LeakyReLU(0.2, True)]nf_mult = 1nf_mult_prev = 1for n in range(1, n_layers):  # gradually increase the number of filtersnf_mult_prev = nf_multnf_mult = min(2 ** n, 8)sequence += [nn.Conv2d(ndf * nf_mult_prev, ndf * nf_mult, kernel_size=kw, stride=2, padding=padw, bias=use_bias),norm_layer(ndf * nf_mult),nn.LeakyReLU(0.2, True)]nf_mult_prev = nf_multnf_mult = min(2 ** n_layers, 8)sequence += [nn.Conv2d(ndf * nf_mult_prev, ndf * nf_mult, kernel_size=kw, stride=1, padding=padw, bias=use_bias),norm_layer(ndf * nf_mult),nn.LeakyReLU(0.2, True)]sequence += [nn.Conv2d(ndf * nf_mult, 1, kernel_size=kw, stride=1, padding=padw)]  # output 1 channel prediction mapself.model = nn.Sequential(*sequence)

构建过程也没啥要说的,就是不停卷积给特征图越卷越多,最后输出一个值,判定是真是假(3->64->128->256->512->1)。

初始化好模型后进行训练,主要来看一下model.optimize_parameters()这里的内容。

optimize_parameters

到cycle_gan_model.py看一下optimize_parameters。

def optimize_parameters(self):# forwardself.forward()  # compute fake images and reconstruction images.# G_A and G_B  训练生成器 判别器不工作self.set_requires_grad([self.netD_A, self.netD_B], False)  # Ds require no gradients when optimizing Gsself.optimizer_G.zero_grad()  # set G_A and G_B's gradients to zeroself.backward_G()  # calculate gradients for G_A and G_Bself.optimizer_G.step()  # update G_A and G_B's weights# D_A and D_B  训练判别器 生成器不工作self.set_requires_grad([self.netD_A, self.netD_B], True)self.optimizer_D.zero_grad()  # set D_A and D_B's gradients to zeroself.backward_D_A()  # calculate gradients for D_Aself.backward_D_B()  # calculate graidents for D_Bself.optimizer_D.step()  # update D_A and D_B's weights

简简单单包含了前向传播和反向传播,先看看前向传播forward。

forward

def forward(self):self.fake_B = self.netG_A(self.real_A)  # G_A(A)self.rec_A = self.netG_B(self.fake_B)  # G_B(G_A(A))self.fake_A = self.netG_B(self.real_B)  # G_B(B)self.rec_B = self.netG_A(self.fake_A)  # G_A(G_B(B))

四行代码很好理解。将真实图A输入生成器A2B,输出假B;将假B输入生成器B2A,还原回A;同理,将真实图B输入生成器B2A,输出假A;将假A输入生成器A2B,还原回B。

backward_G

看看生成器的反向传播,怎么做损失计算的。

def backward_G(self):"""Calculate the loss for generators G_A and G_B"""lambda_idt = self.opt.lambda_identity  # 一些缩放因子lambda_A = self.opt.lambda_Alambda_B = self.opt.lambda_B# Identity lossif lambda_idt > 0:# G_A should be identity if real_B is fed: ||G_A(B) - B||self.idt_A = self.netG_A(self.real_B)  # 将真实B传入A->B网络 也能生成假Bself.loss_idt_A = self.criterionIdt(self.idt_A, self.real_B) * lambda_B * lambda_idt# G_B should be identity if real_A is fed: ||G_B(A) - A||self.idt_B = self.netG_B(self.real_A)  # 将真实A传入B->A网络 也能生成假Aself.loss_idt_B = self.criterionIdt(self.idt_B, self.real_A) * lambda_A * lambda_idtelse:self.loss_idt_A = 0self.loss_idt_B = 0# GAN loss D_A(G_A(A))self.loss_G_A = self.criterionGAN(self.netD_A(self.fake_B), True)  # 希望判别器将假B判为真# GAN loss D_B(G_B(B))self.loss_G_B = self.criterionGAN(self.netD_B(self.fake_A), True)# Forward cycle loss || G_B(G_A(A)) - A||self.loss_cycle_A = self.criterionCycle(self.rec_A, self.real_A) * lambda_A  # 还原A与真实A的差异损失# Backward cycle loss || G_A(G_B(B)) - B||self.loss_cycle_B = self.criterionCycle(self.rec_B, self.real_B) * lambda_B# combined loss and calculate gradientsself.loss_G = self.loss_G_A + self.loss_G_B + self.loss_cycle_A + self.loss_cycle_B + self.loss_idt_A + self.loss_idt_Bself.loss_G.backward()

这里的self.idt_Aself.idt_B就是我上面说的身份损失,公式如下。

请添加图片描述

式中,G和F是生成器A2B和B2A。

lambda_idt就是个加权参数​,缩放损失用的。生成器A2B的作用不是输入A生成B嘛,不过这里做损失是通过输入B,输出B,这两B之间的差异做损失。论文中说引入这个损失有助于促进映射以保持输入和输出之间的颜色组成。比如莫奈的画作与Flickr照片之间的映射时,生成器经常将 白天 的画作映射到 日落 时拍摄的照片,如下图所示。
请添加图片描述

self.loss_G_Aself.loss_G_B是我上面说的对抗损失,公式如下。
请添加图片描述

这个损失很好理解,就是我生成器生成的假图,我希望判别器判为真的。这里的判别是对每个像素点都判别是真是假,然后取平均。如果输入图片size是[256,256],经过判别器输出图片size是[30,30],对这[30,30]的每个点判断是真是假。具体代码可以自己去networks.py的GANLoss看一下。

self.criterionGAN = networks.GANLoss(opt.gan_mode).to(self.device)  # define GAN loss.

剩下最后的self.loss_cycle_Aself.loss_cycle_B就是循环一致性损失了,公式如下。

请添加图片描述

这个更好理解,就是看看还原回去的图与输入的图的差异损失。下图是论文中,输入->输出->还原(重建)的过程。
请添加图片描述

====================================================================================

train需要电脑配置还挺高的,大家可以试试test,配置一下参数就行,例如下面

python test.py --dataroot datasets/horse2zebra/testA --name horse2zebra_pretrained --model test --no_dropout

数据集放在datasets/horse2zebra/testA目录下,模型放在checkpoints/horse2zebra_pretrained目录下,最后的结果会生成一个result目录下。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.rhkb.cn/news/8895.html

如若内容造成侵权/违法违规/事实不符,请联系长河编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

ue5 GAS制作一个技能,技能冷却,给剑添加碰撞预设,打击敌人

总结: 新建文件夹 ability 取名BP_BaseAbility 新建一个技能GAB_Melee 上面技能GAB_Melee和技能基类BP_BaseAbility 进入技能GAB_Melee,添加打印火云掌 给这个技能添加标签 点这个号 这样命名,小心这个点(.&#xff09…

工作总结:git篇

文章目录 前言基础Gerrit1.克隆2.新建本地分支和checkout3.添加到暂存区新增文件到暂存区修改已经添加到暂存区的文件取消添加到暂存区的文件 4.提交到本地仓库在不重复提交的情况下,修改本次提交 5.提交到远程仓库6.评审其他辅助命令 前言 目前也算是工作一段时间…

ESP32 I2S音频总线学习笔记(二):I2S读取INMP441音频数据

简介 在这个系列的上一篇文章中,我们介绍了ESP32 I2S音频总线的相关知识,简要了解了什么是I2S总线、它的通信格式,以及相关的底层API函数。没有看过上篇文章的可以点击文章进行回顾: ESP32 I2S音频总线学习笔记(一&a…

(学习总结21)C++11 异常与智能指针

C11 异常与智能指针 异常异常的概念异常的抛出和捕获栈展开查找匹配的处理代码异常重新抛出异常安全问题异常规范标准库的异常 智能指针RAII 和智能指针的设计思路智能指针的使用场景分析C标准库智能指针的使用weak_ptr 和 shared_ptr循环引用weak_ptrshared_ptr 循环引用问题 …

智能调度体系与自动驾驶技术优化运输配送效率的研究——兼论开源AI智能名片2+1链动模式S2B2C商城小程序的应用潜力

摘要:随着全球化和数字化进程的加速,消费者需求日益呈现出碎片化和个性化的趋势,这对物流运输行业提出了前所未有的挑战。传统的物流调度体系与调度方式已难以满足当前复杂多变的物流需求,因此,物流企业必须积极引入大…

AndroidCompose Navigation导航精通1-基本页面导航与ViewPager

文章目录 前言基本页面导航库依赖导航核心部件简单NavHost实现ViewPagerPager切换逻辑图阐述Pager导航实战前言 在当今的移动应用开发中,导航是用户与应用交互的核心环节。随着 Android Compose 的兴起,它为开发者提供了一种全新的、声明式的方式来构建用户界面,同时也带来…

noteboolm 使用笔记

今天心血来潮,想要体验下AInotebook,看看最新的软件能够做到什么程度。 于是来到了notebooklm,这是一个google推出的AI笔记本的网站,我想知道我们能在上面做一些怎么样有趣的事情! 网址:https://notebookl…

JAVA 接口、抽象类的关系和用处 详细解析

接口 - Java教程 - 廖雪峰的官方网站 一个 抽象类 如果实现了一个接口,可以只选择实现接口中的 部分方法(所有的方法都要有,可以一部分已经写具体,另一部分继续保留抽象),原因在于: 抽象类本身…

ReactNative react-devtools 夜神模拟器连调

目录 一、安装react-devtools 二、在package.json中配置启动项 三、联动 一、安装react-devtools yarn add react-devtools5.3.1 -D 这里选择5.3.1版本,因为高版本可能与夜神模拟器无法联动,导致部分功能无法正常使用。 二、在package.json中配置启…

关于使用Mybatis-plus的TableNameHandler动态表名处理器实现分表业务的详细介绍

引言 随着互联网应用的快速发展,数据量呈爆炸式增长。传统的单表设计在面对海量数据时显得力不从心,容易出现性能瓶颈、查询效率低下等问题。为了提高数据库的扩展性和响应速度,分表(Sharding)成为了一种常见的解决方案…

【开源免费】基于Vue和SpringBoot的在线文档管理系统(附论文)

本文项目编号 T 038 ,文末自助获取源码 \color{red}{T038,文末自助获取源码} T038,文末自助获取源码 目录 一、系统介绍二、演示录屏三、启动教程四、功能截图五、文案资料5.1 选题背景5.2 国内外研究现状5.3 可行性分析 六、核心代码6.1 查…

智慧园区系统分类及其在提升企业管理效率中的创新应用探讨

内容概要 智慧园区的概念已经逐渐深入人心,成为现代城市发展中不可或缺的一部分。随着信息技术的飞速发展和数字化转型的不断推进,一系列智慧园区管理系统应运而生。这些系统不仅帮助企业提高了管理效率,还在多个方面激发了创新。 首先&…

图片上传实现图片预览的功能

文章目录 图片上传实现图片预览的功能一、引言二、拖拽上传实现预览1、HTML结构与样式2、JavaScript实现拖拽逻辑 三、选择文件上传实现预览1、HTML结构2、JavaScript实现预览逻辑 四、使用示例五、总结 图片上传实现图片预览的功能 一、引言 在现代网页设计中,图片…

电力晶体管(GTR)全控性器件

电力晶体管(Giant Transistor,GTR)是一种全控性器件,以下是关于它的详细介绍:(模电普通晶体管三极管进行对比学习) 基本概念 GTR是一种耐高电压、大电流的双极结型晶体管(BJT&am…

Linux - 进程间通信(2)

目录 2、进程池 1)理解进程池 2)进程池的实现 整体框架: a. 加载任务 b. 先描述,再组织 I. 先描述 II. 再组织 c. 创建信道和子进程 d. 通过channel控制子进程 e. 回收管道和子进程 问题1: 解答1&#xff…

【阅读笔记】New Edge Diected Interpolation,NEDI算法,待续

一、概述 由Li等提出的新的边缘指导插值(New Edge—Di-ected Interpolation,NEDI)算法是一种具有良好边缘保持效果的新算法,它利用低分辨率图像与高分辨率图像的局部协方差问的几何对偶性来对高分辨率图像进行自适应插值。 2001年Xin Li和M.T. Orchard…

Windows安装Miniconda和PySide6以及配置PyCharm

目录 1. 选择Miniconda 2. 下载Miniconda 3. 安装Miniconda 4. 在base环境下创建pyside6环境 5. 安装pyside6环境 6. 配置PyCharm环境 7. 运行第一个程序效果 1. 选择Miniconda 选择Miniconda而没有选择Anaconda,是因为它是一个更小的Anaconda发行版&#x…

Linux之内存管理前世今生(一)

一个程序(如王者荣耀)平常是存储在硬盘上的,运行时才把这个程序载入内存,CPU才能执行。 问题: 这个程序载入内存的哪个位置呢?载入内核所在的空间吗?系统直接挂了。 一、虚拟内存 1.1 内存分…

Java基于SSM框架的互助学习平台小程序【附源码、文档】

博主介绍:✌IT徐师兄、7年大厂程序员经历。全网粉丝15W、csdn博客专家、掘金/华为云//InfoQ等平台优质作者、专注于Java技术领域和毕业项目实战✌ 🍅文末获取源码联系🍅 👇🏻 精彩专栏推荐订阅👇&#x1f3…

【Rust自学】16.3. 共享状态的并发

喜欢的话别忘了点赞、收藏加关注哦(加关注即可阅读全文),对接下来的教程有兴趣的可以关注专栏。谢谢喵!(・ω・) 16.3.1. 使用共享来实现并发 还记得Go语言有一句名言是这么说的:Do not commun…