昇思25天学习打卡营第9天|生成式

昇思25天学习打卡营第9天


文章目录

  • 昇思25天学习打卡营第9天
    • CycleGAN图像风格迁移互换
    • 模型介绍
      • 模型简介
      • 模型结构
    • 数据集
      • 数据集下载
      • 数据集加载
      • 可视化
    • 构建生成器
    • 构建判别器
    • 优化器和损失函数
    • 前向计算
    • 计算梯度和反向传播
    • 模型训练
    • 模型推理
    • 参考
    • 打卡记录


CycleGAN图像风格迁移互换

本案例运行需要较大内存,建议在Ascend/GPU上运行。

模型介绍

模型简介

CycleGAN(Cycle Generative Adversarial Network) 即循环对抗生成网络,来自论文 Unpaired Image-to-Image Translation using Cycle-Consistent Adversarial Networks 。该模型实现了一种在没有配对示例的情况下学习将图像从源域 X 转换到目标域 Y 的方法。

该模型一个重要应用领域是域迁移(Domain Adaptation),可以通俗地理解为图像风格迁移。其实在 CycleGAN 之前,就已经有了域迁移模型,比如 Pix2Pix ,但是 Pix2Pix 要求训练数据必须是成对的,而现实生活中,要找到两个域(画风)中成对出现的图片是相当困难的,因此 CycleGAN 诞生了,它只需要两种域的数据,而不需要他们有严格对应关系,是一种新的无监督的图像迁移网络。

模型结构

CycleGAN 网络本质上是由两个镜像对称的 GAN 网络组成,其结构如下图所示(图片来源于原论文):
CycleGAN
为了方便理解,这里以苹果和橘子为例介绍。上图中 X X X 可以理解为苹果, Y Y Y 为橘子; G G G 为将苹果生成橘子风格的生成器, F F F 为将橘子生成的苹果风格的生成器, D X D_{X} DX D Y D_{Y} DY 为其相应判别器,具体生成器和判别器的结构可见下文代码。模型最终能够输出两个模型的权重,分别将两种图像的风格进行彼此迁移,生成新的图像。

该模型一个很重要的部分就是损失函数,在所有损失里面循环一致损失(Cycle Consistency Loss)是最重要的。循环损失的计算过程如下图所示(图片来源于原论文):

CycleGAN_1
图中苹果图片 x x x 经过生成器 G G G 得到伪橘子 Y ^ \hat{Y} Y^,然后将伪橘子 Y ^ \hat{Y} Y^ 结果送进生成器 F F F 又产生苹果风格的结果 x ^ \hat{x} x^,最后将生成的苹果风格结果 x ^ \hat{x} x^ 与原苹果图片 x x x 一起计算出循环一致损失,反之亦然。循环损失捕捉了这样的直觉,即如果我们从一个域转换到另一个域,然后再转换回来,我们应该到达我们开始的地方。详细的训练过程见下文代码。

数据集

本案例使用的数据集里面的图片来源于ImageNet,该数据集共有17个数据包,本文只使用了其中的苹果橘子部分。图像被统一缩放为256×256像素大小,其中用于训练的苹果图片996张、橘子图片1020张,用于测试的苹果图片266张、橘子图片248张。

这里对数据进行了随机裁剪、水平随机翻转和归一化的预处理,为了将重点聚焦到模型,此处将数据预处理后的结果转换为 MindRecord 格式的数据,以省略大部分数据预处理的代码。

数据集下载

使用 download 接口下载数据集,并将下载后的数据集自动解压到当前目录下。数据下载之前需要使用 pip install download 安装 download 包。

%%capture captured_output
# 实验环境已经预装了mindspore==2.2.14,如需更换mindspore版本,可更改下面mindspore的版本号
!pip uninstall mindspore -y
!pip install -i https://pypi.mirrors.ustc.edu.cn/simple mindspore==2.2.14
# 查看当前 mindspore 版本
!pip show mindspore
from download import downloadurl = "https://mindspore-website.obs.cn-north-4.myhuaweicloud.com/notebook/models/application/CycleGAN_apple2orange.zip"download(url, ".", kind="zip", replace=True)

数据集加载

使用 MindSpore 的 MindDataset 接口读取和解析数据集。

from mindspore.dataset import MindDataset# 读取MindRecord格式数据
name_mr = "./CycleGAN_apple2orange/apple2orange_train.mindrecord"
data = MindDataset(dataset_files=name_mr)
print("Datasize: ", data.get_dataset_size())batch_size = 1
dataset = data.batch(batch_size)
datasize = dataset.get_dataset_size()结果输出:
Datasize:  1019

可视化

通过 create_dict_iterator 函数将数据转换成字典迭代器,然后使用 matplotlib 模块可视化部分训练数据。

import numpy as np
import matplotlib.pyplot as pltmean = 0.5 * 255
std = 0.5 * 255plt.figure(figsize=(12, 5), dpi=60)
for i, data in enumerate(dataset.create_dict_iterator()):if i < 5:show_images_a = data["image_A"].asnumpy()show_images_b = data["image_B"].asnumpy()plt.subplot(2, 5, i+1)show_images_a = (show_images_a[0] * std + mean).astype(np.uint8).transpose((1, 2, 0))plt.imshow(show_images_a)plt.axis("off")plt.subplot(2, 5, i+6)show_images_b = (show_images_b[0] * std + mean).astype(np.uint8).transpose((1, 2, 0))plt.imshow(show_images_b)plt.axis("off")else:break
plt.show()

训练数据可视化

构建生成器

本案例生成器的模型结构参考的 ResNet 模型的结构,参考原论文,对于128×128大小的输入图片采用6个残差块相连,图片大小为256×256以上的需要采用9个残差块相连,所以本文网络有9个残差块相连,超参数 n_layers 参数控制残差块数。

生成器的结构如下所示:
CycleGAN_2
具体的模型结构请参照下文代码:

import mindspore.nn as nn
import mindspore.ops as ops
from mindspore.common.initializer import Normalweight_init = Normal(sigma=0.02)class ConvNormReLU(nn.Cell):def __init__(self, input_channel, out_planes, kernel_size=4, stride=2, alpha=0.2, norm_mode='instance',pad_mode='CONSTANT', use_relu=True, padding=None, transpose=False):super(ConvNormReLU, self).__init__()norm = nn.BatchNorm2d(out_planes)if norm_mode == 'instance':norm = nn.BatchNorm2d(out_planes, affine=False)has_bias = (norm_mode == 'instance')if padding is None:padding = (kernel_size - 1) // 2if pad_mode == 'CONSTANT':if transpose:conv = nn.Conv2dTranspose(input_channel, out_planes, kernel_size, stride, pad_mode='same',has_bias=has_bias, weight_init=weight_init)else:conv = nn.Conv2d(input_channel, out_planes, kernel_size, stride, pad_mode='pad',has_bias=has_bias, padding=padding, weight_init=weight_init)layers = [conv, norm]else:paddings = ((0, 0), (0, 0), (padding, padding), (padding, padding))pad = nn.Pad(paddings=paddings, mode=pad_mode)if transpose:conv = nn.Conv2dTranspose(input_channel, out_planes, kernel_size, stride, pad_mode='pad',has_bias=has_bias, weight_init=weight_init)else:conv = nn.Conv2d(input_channel, out_planes, kernel_size, stride, pad_mode='pad',has_bias=has_bias, weight_init=weight_init)layers = [pad, conv, norm]if use_relu:relu = nn.ReLU()if alpha > 0:relu = nn.LeakyReLU(alpha)layers.append(relu)self.features = nn.SequentialCell(layers)def construct(self, x):output = self.features(x)return outputclass ResidualBlock(nn.Cell):def __init__(self, dim, norm_mode='instance', dropout=False, pad_mode="CONSTANT"):super(ResidualBlock, self).__init__()self.conv1 = ConvNormReLU(dim, dim, 3, 1, 0, norm_mode, pad_mode)self.conv2 = ConvNormReLU(dim, dim, 3, 1, 0, norm_mode, pad_mode, use_relu=False)self.dropout = dropoutif dropout:self.dropout = nn.Dropout(p=0.5)def construct(self, x):out = self.conv1(x)if self.dropout:out = self.dropout(out)out = self.conv2(out)return x + outclass ResNetGenerator(nn.Cell):def __init__(self, input_channel=3, output_channel=64, n_layers=9, alpha=0.2, norm_mode='instance', dropout=False,pad_mode="CONSTANT"):super(ResNetGenerator, self).__init__()self.conv_in = ConvNormReLU(input_channel, output_channel, 7, 1, alpha, norm_mode, pad_mode=pad_mode)self.down_1 = ConvNormReLU(output_channel, output_channel * 2, 3, 2, alpha, norm_mode)self.down_2 = ConvNormReLU(output_channel * 2, output_channel * 4, 3, 2, alpha, norm_mode)layers = [ResidualBlock(output_channel * 4, norm_mode, dropout=dropout, pad_mode=pad_mode)] * n_layersself.residuals = nn.SequentialCell(layers)self.up_2 = ConvNormReLU(output_channel * 4, output_channel * 2, 3, 2, alpha, norm_mode, transpose=True)self.up_1 = ConvNormReLU(output_channel * 2, output_channel, 3, 2, alpha, norm_mode, transpose=True)if pad_mode == "CONSTANT":self.conv_out = nn.Conv2d(output_channel, 3, kernel_size=7, stride=1, pad_mode='pad',padding=3, weight_init=weight_init)else:pad = nn.Pad(paddings=((0, 0), (0, 0), (3, 3), (3, 3)), mode=pad_mode)conv = nn.Conv2d(output_channel, 3, kernel_size=7, stride=1, pad_mode='pad', weight_init=weight_init)self.conv_out = nn.SequentialCell([pad, conv])def construct(self, x):x = self.conv_in(x)x = self.down_1(x)x = self.down_2(x)x = self.residuals(x)x = self.up_2(x)x = self.up_1(x)output = self.conv_out(x)return ops.tanh(output)# 实例化生成器
net_rg_a = ResNetGenerator()
net_rg_a.update_parameters_name('net_rg_a.')net_rg_b = ResNetGenerator()
net_rg_b.update_parameters_name('net_rg_b.')

构建判别器

判别器其实是一个二分类网络模型,输出判定该图像为真实图的概率。网络模型使用的是 Patch 大小为 70x70 的 PatchGANs 模型。通过一系列的 Conv2dBatchNorm2dLeakyReLU 层对其进行处理,最后通过 Sigmoid 激活函数得到最终概率。

# 定义判别器
class Discriminator(nn.Cell):def __init__(self, input_channel=3, output_channel=64, n_layers=3, alpha=0.2, norm_mode='instance'):super(Discriminator, self).__init__()kernel_size = 4layers = [nn.Conv2d(input_channel, output_channel, kernel_size, 2, pad_mode='pad', padding=1, weight_init=weight_init),nn.LeakyReLU(alpha)]nf_mult = output_channelfor i in range(1, n_layers):nf_mult_prev = nf_multnf_mult = min(2 ** i, 8) * output_channellayers.append(ConvNormReLU(nf_mult_prev, nf_mult, kernel_size, 2, alpha, norm_mode, padding=1))nf_mult_prev = nf_multnf_mult = min(2 ** n_layers, 8) * output_channellayers.append(ConvNormReLU(nf_mult_prev, nf_mult, kernel_size, 1, alpha, norm_mode, padding=1))layers.append(nn.Conv2d(nf_mult, 1, kernel_size, 1, pad_mode='pad', padding=1, weight_init=weight_init))self.features = nn.SequentialCell(layers)def construct(self, x):output = self.features(x)return output# 判别器初始化
net_d_a = Discriminator()
net_d_a.update_parameters_name('net_d_a.')net_d_b = Discriminator()
net_d_b.update_parameters_name('net_d_b.')

优化器和损失函数

根据不同模型需要单独的设置优化器,这是训练过程决定的。

对生成器 G G G 及其判别器 D Y D_{Y} DY ,目标损失函数定义为:

L G A N ( G , D Y , X , Y ) = E y − p d a t a ( y ) [ l o g D Y ( y ) ] + E x − p d a t a ( x ) [ l o g ( 1 − D Y ( G ( x ) ) ) ] L_{GAN}(G,D_Y,X,Y)=E_{y-p_{data}(y)}[logD_Y(y)]+E_{x-p_{data}(x)}[log(1-D_Y(G(x)))] LGAN(G,DY,X,Y)=Eypdata(y)[logDY(y)]+Expdata(x)[log(1DY(G(x)))]

其中 G G G 试图生成看起来与 Y Y Y 中的图像相似的图像 G ( x ) G(x) G(x) ,而 D Y D_{Y} DY 的目标是区分翻译样本 G ( x ) G(x) G(x) 和真实样本 y y y ,生成器的目标是最小化这个损失函数以此来对抗判别器。即 $ min_{G} max_{D_{Y}}L_{GAN}(G,D_{Y} ,X,Y )$ 。

单独的对抗损失不能保证所学函数可以将单个输入映射到期望的输出,为了进一步减少可能的映射函数的空间,学习到的映射函数应该是周期一致的,例如对于 X X X 的每个图像 x x x ,图像转换周期应能够将 x x x 带回原始图像,可以称之为正向循环一致性,即 x → G ( x ) → F ( G ( x ) ) ≈ x x→G(x)→F(G(x))\approx x xG(x)F(G(x))x 。对于 Y Y Y ,类似的 x → G ( x ) → F ( G ( x ) ) ≈ x x→G(x)→F(G(x))\approx x xG(x)F(G(x))x 。可以理解采用了一个循环一致性损失来激励这种行为。

循环一致损失函数定义如下:

L c y c ( G , F ) = E x − p d a t a ( x ) [ ∥ F ( G ( x ) ) − x ∥ 1 ] + E y − p d a t a ( y ) [ ∥ G ( F ( y ) ) − y ∥ 1 ] L_{cyc}(G,F)=E_{x-p_{data}(x)}[\Vert F(G(x))-x\Vert_{1}]+E_{y-p_{data}(y)}[\Vert G(F(y))-y\Vert_{1}] Lcyc(G,F)=Expdata(x)[F(G(x))x1]+Eypdata(y)[G(F(y))y1]

循环一致损失能够保证重建图像 F ( G ( x ) ) F(G(x)) F(G(x)) 与输入图像 x x x 紧密匹配。

# 构建生成器,判别器优化器
optimizer_rg_a = nn.Adam(net_rg_a.trainable_params(), learning_rate=0.0002, beta1=0.5)
optimizer_rg_b = nn.Adam(net_rg_b.trainable_params(), learning_rate=0.0002, beta1=0.5)optimizer_d_a = nn.Adam(net_d_a.trainable_params(), learning_rate=0.0002, beta1=0.5)
optimizer_d_b = nn.Adam(net_d_b.trainable_params(), learning_rate=0.0002, beta1=0.5)# GAN网络损失函数,这里最后一层不使用sigmoid函数
loss_fn = nn.MSELoss(reduction='mean')
l1_loss = nn.L1Loss("mean")def gan_loss(predict, target):target = ops.ones_like(predict) * targetloss = loss_fn(predict, target)return loss

前向计算

搭建模型前向计算损失的过程,过程如下代码。

为了减少模型振荡[1],这里遵循 Shrivastava 等人的策略[2],使用生成器生成图像的历史数据而不是生成器生成的最新图像数据来更新鉴别器。这里创建 image_pool 函数,保留了一个图像缓冲区,用于存储生成器生成前的50个图像。

import mindspore as ms# 前向计算def generator(img_a, img_b):fake_a = net_rg_b(img_b)fake_b = net_rg_a(img_a)rec_a = net_rg_b(fake_b)rec_b = net_rg_a(fake_a)identity_a = net_rg_b(img_a)identity_b = net_rg_a(img_b)return fake_a, fake_b, rec_a, rec_b, identity_a, identity_blambda_a = 10.0
lambda_b = 10.0
lambda_idt = 0.5def generator_forward(img_a, img_b):true = Tensor(True, dtype.bool_)fake_a, fake_b, rec_a, rec_b, identity_a, identity_b = generator(img_a, img_b)loss_g_a = gan_loss(net_d_b(fake_b), true)loss_g_b = gan_loss(net_d_a(fake_a), true)loss_c_a = l1_loss(rec_a, img_a) * lambda_aloss_c_b = l1_loss(rec_b, img_b) * lambda_bloss_idt_a = l1_loss(identity_a, img_a) * lambda_a * lambda_idtloss_idt_b = l1_loss(identity_b, img_b) * lambda_b * lambda_idtloss_g = loss_g_a + loss_g_b + loss_c_a + loss_c_b + loss_idt_a + loss_idt_breturn fake_a, fake_b, loss_g, loss_g_a, loss_g_b, loss_c_a, loss_c_b, loss_idt_a, loss_idt_bdef generator_forward_grad(img_a, img_b):_, _, loss_g, _, _, _, _, _, _ = generator_forward(img_a, img_b)return loss_gdef discriminator_forward(img_a, img_b, fake_a, fake_b):false = Tensor(False, dtype.bool_)true = Tensor(True, dtype.bool_)d_fake_a = net_d_a(fake_a)d_img_a = net_d_a(img_a)d_fake_b = net_d_b(fake_b)d_img_b = net_d_b(img_b)loss_d_a = gan_loss(d_fake_a, false) + gan_loss(d_img_a, true)loss_d_b = gan_loss(d_fake_b, false) + gan_loss(d_img_b, true)loss_d = (loss_d_a + loss_d_b) * 0.5return loss_ddef discriminator_forward_a(img_a, fake_a):false = Tensor(False, dtype.bool_)true = Tensor(True, dtype.bool_)d_fake_a = net_d_a(fake_a)d_img_a = net_d_a(img_a)loss_d_a = gan_loss(d_fake_a, false) + gan_loss(d_img_a, true)return loss_d_adef discriminator_forward_b(img_b, fake_b):false = Tensor(False, dtype.bool_)true = Tensor(True, dtype.bool_)d_fake_b = net_d_b(fake_b)d_img_b = net_d_b(img_b)loss_d_b = gan_loss(d_fake_b, false) + gan_loss(d_img_b, true)return loss_d_b# 保留了一个图像缓冲区,用来存储之前创建的50个图像
pool_size = 50
def image_pool(images):num_imgs = 0image1 = []if isinstance(images, Tensor):images = images.asnumpy()return_images = []for image in images:if num_imgs < pool_size:num_imgs = num_imgs + 1image1.append(image)return_images.append(image)else:if random.uniform(0, 1) > 0.5:random_id = random.randint(0, pool_size - 1)tmp = image1[random_id].copy()image1[random_id] = imagereturn_images.append(tmp)else:return_images.append(image)output = Tensor(return_images, ms.float32)if output.ndim != 4:raise ValueError("img should be 4d, but get shape {}".format(output.shape))return output

计算梯度和反向传播

其中梯度计算也是分开不同的模型来进行的,详情见如下代码:

from mindspore import value_and_grad# 实例化求梯度的方法
grad_g_a = value_and_grad(generator_forward_grad, None, net_rg_a.trainable_params())
grad_g_b = value_and_grad(generator_forward_grad, None, net_rg_b.trainable_params())grad_d_a = value_and_grad(discriminator_forward_a, None, net_d_a.trainable_params())
grad_d_b = value_and_grad(discriminator_forward_b, None, net_d_b.trainable_params())# 计算生成器的梯度,反向传播更新参数
def train_step_g(img_a, img_b):net_d_a.set_grad(False)net_d_b.set_grad(False)fake_a, fake_b, lg, lga, lgb, lca, lcb, lia, lib = generator_forward(img_a, img_b)_, grads_g_a = grad_g_a(img_a, img_b)_, grads_g_b = grad_g_b(img_a, img_b)optimizer_rg_a(grads_g_a)optimizer_rg_b(grads_g_b)return fake_a, fake_b, lg, lga, lgb, lca, lcb, lia, lib# 计算判别器的梯度,反向传播更新参数
def train_step_d(img_a, img_b, fake_a, fake_b):net_d_a.set_grad(True)net_d_b.set_grad(True)loss_d_a, grads_d_a = grad_d_a(img_a, fake_a)loss_d_b, grads_d_b = grad_d_b(img_b, fake_b)loss_d = (loss_d_a + loss_d_b) * 0.5optimizer_d_a(grads_d_a)optimizer_d_b(grads_d_b)return loss_d

模型训练

训练分为两个主要部分:训练判别器和训练生成器,在前文的判别器损失函数中,论文采用了最小二乘损失代替负对数似然目标。

  • 训练判别器:训练判别器的目的是最大程度地提高判别图像真伪的概率。按照论文的方法需要训练判别器来最小化 E y − p d a t a ( y ) [ ( D ( y ) − 1 ) 2 ] E_{y-p_{data}(y)}[(D(y)-1)^2] Eypdata(y)[(D(y)1)2]

  • 训练生成器:如 CycleGAN 论文所述,我们希望通过最小化 E x − p d a t a ( x ) [ ( D ( G ( x ) − 1 ) 2 ] E_{x-p_{data}(x)}[(D(G(x)-1)^2] Expdata(x)[(D(G(x)1)2] 来训练生成器,以产生更好的虚假图像。

下面定义了生成器和判别器的训练过程:

import os
import time
import random
import numpy as np
from PIL import Image
from mindspore import Tensor, save_checkpoint
from mindspore import dtype# 由于时间原因,epochs设置为1,可根据需求进行调整
epochs = 1
save_step_num = 80
save_checkpoint_epochs = 1
save_ckpt_dir = './train_ckpt_outputs/'print('Start training!')for epoch in range(epochs):g_loss = []d_loss = []start_time_e = time.time()for step, data in enumerate(dataset.create_dict_iterator()):start_time_s = time.time()img_a = data["image_A"]img_b = data["image_B"]res_g = train_step_g(img_a, img_b)fake_a = res_g[0]fake_b = res_g[1]res_d = train_step_d(img_a, img_b, image_pool(fake_a), image_pool(fake_b))loss_d = float(res_d.asnumpy())step_time = time.time() - start_time_sres = []for item in res_g[2:]:res.append(float(item.asnumpy()))g_loss.append(res[0])d_loss.append(loss_d)if step % save_step_num == 0:print(f"Epoch:[{int(epoch + 1):>3d}/{int(epochs):>3d}], "f"step:[{int(step):>4d}/{int(datasize):>4d}], "f"time:{step_time:>3f}s,\n"f"loss_g:{res[0]:.2f}, loss_d:{loss_d:.2f}, "f"loss_g_a: {res[1]:.2f}, loss_g_b: {res[2]:.2f}, "f"loss_c_a: {res[3]:.2f}, loss_c_b: {res[4]:.2f}, "f"loss_idt_a: {res[5]:.2f}, loss_idt_b: {res[6]:.2f}")epoch_cost = time.time() - start_time_eper_step_time = epoch_cost / datasizemean_loss_d, mean_loss_g = sum(d_loss) / datasize, sum(g_loss) / datasizeprint(f"Epoch:[{int(epoch + 1):>3d}/{int(epochs):>3d}], "f"epoch time:{epoch_cost:.2f}s, per step time:{per_step_time:.2f}, "f"mean_g_loss:{mean_loss_g:.2f}, mean_d_loss:{mean_loss_d :.2f}")if epoch % save_checkpoint_epochs == 0:os.makedirs(save_ckpt_dir, exist_ok=True)save_checkpoint(net_rg_a, os.path.join(save_ckpt_dir, f"g_a_{epoch}.ckpt"))save_checkpoint(net_rg_b, os.path.join(save_ckpt_dir, f"g_b_{epoch}.ckpt"))save_checkpoint(net_d_a, os.path.join(save_ckpt_dir, f"d_a_{epoch}.ckpt"))save_checkpoint(net_d_b, os.path.join(save_ckpt_dir, f"d_b_{epoch}.ckpt"))print('End of training!')结果输出:
Start training!
Epoch:[  1/  1], step:[   0/1019], time:277.880039s,
loss_g:20.03, loss_d:0.97, loss_g_a: 0.98, loss_g_b: 0.99, loss_c_a: 5.37, loss_c_b: 6.66, loss_idt_a: 2.68, loss_idt_b: 3.35
Epoch:[  1/  1], step:[  80/1019], time:0.460943s,
loss_g:10.59, loss_d:0.30, loss_g_a: 0.69, loss_g_b: 0.63, loss_c_a: 2.70, loss_c_b: 3.69, loss_idt_a: 1.20, loss_idt_b: 1.68
Epoch:[  1/  1], step:[ 160/1019], time:0.452814s,
loss_g:7.94, loss_d:0.30, loss_g_a: 0.65, loss_g_b: 0.54, loss_c_a: 2.66, loss_c_b: 2.00, loss_idt_a: 1.16, loss_idt_b: 0.93
Epoch:[  1/  1], step:[ 240/1019], time:0.649191s,
loss_g:8.74, loss_d:0.50, loss_g_a: 0.39, loss_g_b: 0.25, loss_c_a: 2.90, loss_c_b: 2.83, loss_idt_a: 1.25, loss_idt_b: 1.11
Epoch:[  1/  1], step:[ 320/1019], time:0.452768s,
loss_g:5.46, loss_d:0.51, loss_g_a: 0.41, loss_g_b: 0.18, loss_c_a: 1.47, loss_c_b: 1.80, loss_idt_a: 0.72, loss_idt_b: 0.87
Epoch:[  1/  1], step:[ 400/1019], time:0.449897s,
loss_g:5.78, loss_d:0.78, loss_g_a: 0.08, loss_g_b: 0.24, loss_c_a: 2.14, loss_c_b: 1.74, loss_idt_a: 0.76, loss_idt_b: 0.82
Epoch:[  1/  1], step:[ 480/1019], time:0.457811s,
loss_g:5.69, loss_d:0.28, loss_g_a: 0.57, loss_g_b: 0.52, loss_c_a: 1.40, loss_c_b: 1.98, loss_idt_a: 0.48, loss_idt_b: 0.73
Epoch:[  1/  1], step:[ 560/1019], time:0.457280s,
loss_g:5.23, loss_d:0.57, loss_g_a: 0.23, loss_g_b: 0.33, loss_c_a: 2.01, loss_c_b: 1.26, loss_idt_a: 0.81, loss_idt_b: 0.59
Epoch:[  1/  1], step:[ 640/1019], time:0.467980s,
loss_g:5.19, loss_d:0.51, loss_g_a: 0.43, loss_g_b: 0.28, loss_c_a: 1.42, loss_c_b: 1.74, loss_idt_a: 0.55, loss_idt_b: 0.76
Epoch:[  1/  1], step:[ 720/1019], time:0.452623s,
loss_g:4.36, loss_d:0.57, loss_g_a: 0.42, loss_g_b: 0.31, loss_c_a: 1.90, loss_c_b: 0.66, loss_idt_a: 0.74, loss_idt_b: 0.32
Epoch:[  1/  1], step:[ 800/1019], time:0.451772s,
loss_g:4.84, loss_d:0.55, loss_g_a: 0.27, loss_g_b: 0.60, loss_c_a: 1.23, loss_c_b: 1.69, loss_idt_a: 0.37, loss_idt_b: 0.67
Epoch:[  1/  1], step:[ 880/1019], time:0.464954s,
loss_g:4.82, loss_d:0.47, loss_g_a: 0.27, loss_g_b: 0.42, loss_c_a: 1.57, loss_c_b: 1.27, loss_idt_a: 0.82, loss_idt_b: 0.47
Epoch:[  1/  1], step:[ 960/1019], time:0.464330s,
loss_g:4.65, loss_d:0.42, loss_g_a: 0.37, loss_g_b: 0.32, loss_c_a: 1.54, loss_c_b: 1.16, loss_idt_a: 0.73, loss_idt_b: 0.54
Epoch:[  1/  1], epoch time:761.31s, per step time:0.75, mean_g_loss:6.92, mean_d_loss:0.45
End of training!

模型推理

下面我们通过加载生成器网络模型参数文件来对原图进行风格迁移,结果中第一行为原图,第二行为对应生成的结果图。

%%time
import os
from PIL import Image
import mindspore.dataset as ds
import mindspore.dataset.vision as vision
from mindspore import load_checkpoint, load_param_into_net# 加载权重文件
def load_ckpt(net, ckpt_dir):param_GA = load_checkpoint(ckpt_dir)load_param_into_net(net, param_GA)g_a_ckpt = './CycleGAN_apple2orange/ckpt/g_a.ckpt'
g_b_ckpt = './CycleGAN_apple2orange/ckpt/g_b.ckpt'load_ckpt(net_rg_a, g_a_ckpt)
load_ckpt(net_rg_b, g_b_ckpt)# 图片推理
fig = plt.figure(figsize=(11, 2.5), dpi=100)
def eval_data(dir_path, net, a):def read_img():for dir in os.listdir(dir_path):path = os.path.join(dir_path, dir)img = Image.open(path).convert('RGB')yield img, dirdataset = ds.GeneratorDataset(read_img, column_names=["image", "image_name"])trans = [vision.Resize((256, 256)), vision.Normalize(mean=[0.5 * 255] * 3, std=[0.5 * 255] * 3), vision.HWC2CHW()]dataset = dataset.map(operations=trans, input_columns=["image"])dataset = dataset.batch(1)for i, data in enumerate(dataset.create_dict_iterator()):img = data["image"]fake = net(img)fake = (fake[0] * 0.5 * 255 + 0.5 * 255).astype(np.uint8).transpose((1, 2, 0))img = (img[0] * 0.5 * 255 + 0.5 * 255).astype(np.uint8).transpose((1, 2, 0))fig.add_subplot(2, 8, i+1+a)plt.axis("off")plt.imshow(img.asnumpy())fig.add_subplot(2, 8, i+9+a)plt.axis("off")plt.imshow(fake.asnumpy())eval_data('./CycleGAN_apple2orange/predict/apple', net_rg_a, 0)
eval_data('./CycleGAN_apple2orange/predict/orange', net_rg_b, 4)
plt.show()

风格迁移结果

参考

[1] I. Goodfellow. NIPS 2016 tutorial: Generative ad-versarial networks. arXiv preprint arXiv:1701.00160,2016. 2, 4, 5

[2] A. Shrivastava, T. Pfister, O. Tuzel, J. Susskind, W. Wang, R. Webb. Learning from simulated and unsupervised images through adversarial training. In CVPR, 2017. 3, 5, 6, 7

打卡记录

打卡记录

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.rhkb.cn/news/379595.html

如若内容造成侵权/违法违规/事实不符,请联系长河编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

【PostgreSQL】PostgreSQL 教程

博主介绍&#xff1a;✌全网粉丝20W&#xff0c;CSDN博客专家、Java领域优质创作者&#xff0c;掘金/华为云/阿里云/InfoQ等平台优质作者、专注于Java技术领域✌ 技术范围&#xff1a;SpringBoot、SpringCloud、Vue、SSM、HTML、Nodejs、Python、MySQL、PostgreSQL、大数据、物…

C++【OpenCV】图片亮度色度归一化

#include <opencv2/highgui.hpp> #include <opencv2/imgproc.hpp> #include <iostream>using namespace cv; using namespace std;int main() {Mat image imread("SrcMF.jpg");// 灰度、Gamma归一化亮度cv::Mat m_gray;cv::cvtColor(image, m_gra…

Richteck立锜科技电源管理芯片简介及器件选择指南

一、电源管理简介 电源管理组件的选择和应用本身的电源输入和输出条件是高度关联的。 输入电源是交流或直流&#xff1f;需求的输出电压比输入电压高或是低&#xff1f;负载电流多大&#xff1f;系统是否对噪讯非常敏感&#xff1f;也许系统需要的是恒流而不是稳压 (例如 LED…

Docker-compose单机容器集群编排

传统的容器管理&#xff1a;Dockerfile文件 -> 手动执行 docker build 一个个镜像的构建 -> 手动执行 docker run 一个个容器的创建和启动 容器编排管理&#xff1a;Dockerfile文件 -> 在docker-compose.yml配置模板文件里定义容器启动参数和依赖关系 -> 执行dock…

vue echarts 柱状图表,点击柱子,路由代参数(X轴坐标)跳转

一 myChart.on(click, (params) > {if (params.componentType series && params.dataIndex ! undefined) {const months this.month_htqd[params.dataIndex]; // 获取点击柱状图的 X 轴坐标值alert(点击了柱状图&#xff0c;值为: ${months});// 根据点击的柱状图…

基于PHP+MYSQL开发制作的趣味测试网站源码

基于PHPMYSQL开发制作的趣味测试网站源码。可在后台提前设置好缘分&#xff0c; 自己手动在数据库里修改数据&#xff0c;数据库里有就会优先查询数据库的信息&#xff0c; 没设置的话第一次查询缘分都是非常好的 95-99&#xff0c;第二次查就比较差 &#xff0c; 所以如果要…

深度解析:如何优雅地删除GitHub仓库中的特定commit历史

&#x1f49d;&#x1f49d;&#x1f49d;欢迎莅临我的博客&#xff0c;很高兴能够在这里和您见面&#xff01;希望您在这里可以感受到一份轻松愉快的氛围&#xff0c;不仅可以获得有趣的内容和知识&#xff0c;也可以畅所欲言、分享您的想法和见解。 推荐:「stormsha的主页」…

第二证券:电影暑期档持续升温 农机自动驾驶驶入快车道

农机自动驾驶打开驶入快车道 得益于农机补贴、土地流通、高标准农田制造等方针引导&#xff0c;叠加技术突围和用户降本增效的内生需求&#xff0c;我国正处于农业2.0向农业3.0的过渡阶段。其间农机自动驾驶系统是结束农业3.0&#xff08;即自动化&#xff09;的要害并迎来快速…

互联网行业的产品方向(二)

数字与策略产品 大数据时代&#xff0c;数据的价值越来越重要。大多数公司开始对内外全部数据进行管理与挖掘&#xff0c;将业务数据化&#xff0c;数据资产化&#xff0c;资产业务化&#xff0c;将数据产品赋能业务&#xff0c;通过数据驱动公司业务发展&#xff0c;支撑公司战…

QT5:简单显示百度页面

目录 前言 一、环境 二、实现过程 1.引入模块 2.环境构建 三、代码示例 总结 参考博客 前言 使用qt5 QT WebEngine 模块实现在Designer 上展示百度页面。 一、环境 qt版本&#xff1a;5.12.7 windows 11 下的 Qt Designer &#xff08;已搭建&#xff09; 编译器&a…

地图项目涉及知识点总结

序&#xff1a;最近做了一个在地图上标记点的项目&#xff0c;用户要求是在地图上显示百万量级的标记点&#xff0c;并且地图仍要可用&#xff08;能拖拽&#xff0c;能缩放&#xff09;。调研了不少方法和方案&#xff0c;最终实现了相对流畅的地图系统&#xff0c;加载耗时用…

STM32全栈嵌入式人脸识别考勤系统:融合OpenCV、Qt和SQLite的解决方案

1. 项目概述 本项目旨在设计并实现一个基于STM32的全栈人脸识别考勤系统。该系统结合了嵌入式开发、计算机视觉和数据库技术&#xff0c;实现了自动人脸检测、识别和考勤记录功能。 主要特点: 使用STM32F4系列微控制器作为主控制器采用OpenCV进行人脸检测和识别Qt开发跨平台…

数据结构 day3

目录 思维导图&#xff1a; 学习内容&#xff1a; 1. 顺序表 1.1 概念 1.2 有关顺序表的操作 1.2.1 创建顺序表 1.2.2 顺序表判空和判断满 1.2.3 向顺序表中添加元素 1.2.4 遍历顺序表 1.2.5 顺序表按位置进行插入元素 1.2.6 顺序表任意位置删除元素 1.2.7 按值进…

Xcode 16 beta3 真机调试找不到 Apple Watch 的尝试解决

很多小伙伴们想用 Xcode 在 Apple Watch 真机上调试运行 App 时却发现&#xff1a;在 Xcode 设备管理器中压根找不到对应的 Apple Watch 设备。 大家是否已将 Apple Watch 和 Mac 都重启一万多遍了&#xff0c;还是束手无策。 Apple Watch not showing in XCodeApple Watch wo…

PHP手边酒店多商户版平台小程序系统源码

&#x1f3e8;【旅行新宠】手边酒店多商户版小程序&#xff0c;一键解锁住宿新体验&#xff01;&#x1f6cc; &#x1f308;【开篇&#xff1a;旅行新伴侣&#xff0c;尽在掌握】&#x1f308; 还在为旅行中的住宿选择而纠结吗&#xff1f;是时候告别繁琐的搜索和比价过程&a…

电脑屏幕录制怎么弄?分享3个简单的电脑录屏方法

在信息爆炸的时代&#xff0c;屏幕上的每一个画面都可能成为我们生活中不可或缺的记忆。作为一名年轻男性&#xff0c;我对于录屏软件的需求可以说是既挑剔又实际。今天&#xff0c;我就为大家分享一下我近期体验的三款录屏软件&#xff1a;福昕录屏大师、转转大师录屏大师和OB…

TikTok账号矩阵运营怎么做?

这几年&#xff0c;聊到出海避不过海外抖音&#xff0c;也就是TikTok&#xff0c;聊到TikTok电商直播就离不开账号矩阵&#xff1b; 在TikTok上&#xff0c;矩阵养号已经成为了出海电商人的流行策略&#xff0c;归根结底还是因为矩阵养号可以用最小的力&#xff0c;获得更大的…

FastAPI 学习之路(五十)WebSockets(六)聊天室完善

我们这次只是对于之前的功能做下优化&#xff0c;顺便利用下之前的操作数据的接口&#xff0c;使用下数据库的练习。 在聊天里会有一个上线的概念。上线要通知大家&#xff0c;下线也要通知大家谁离开了&#xff0c;基于此功能我们完善下代码。 首先&#xff0c;我们的登录用…

初识langchain[1]:Langchain实战教学,利用qwen2.1与GLM-4大模型构建智能解决方案[含Agent、tavily面向AI搜索]

初识langchain[1]&#xff1a;Langchain实战教学&#xff0c;利用qwen2.1与GLM-4大模型构建智能解决方案 1.大模型基础知识 大模型三大重点&#xff1a;算力、数据、算法&#xff0c;ReAct &#xff08;reason推理act行动&#xff09;–思维链 Langchain会把上述流程串起来&a…

[Maven] 打包编译本地Jar包报错的几种解决办法

目录 方式1&#xff1a;通过scope指定 方式2&#xff1a;通过新建lib 方式3&#xff1a;通过build节点打包依赖​​​​​​​ 方式4&#xff1a;安装Jar包到本地 方式5&#xff1a;发布到远程私有仓库 方式6&#xff1a;删除_remote.repositories 方式7&#xff1a;打包…