2、StarGAN V2

2、StarGAN V2

StarGAN 论文链接:StarGAN

StarGAN V2 论文链接:StarGAN V2

在介绍StarGAN V2之前,我们先对StarGAN有一定的了解,StarGAN V2只是在StarGAN的基础上做出了改进,基本的架构是没有变的,只是将风格编码做成了向量的形式,使得风格编码也是可以学习的。

StarGAN
StarGAN的出发点

StarGAN(Star Generative Adversarial Network)是一种生成对抗网络(GAN)的变体,用于图像到图像的多域转换任务。StarGAN 的核心特点是,它可以在单一模型中实现多域图像转换,而不需要为每个领域的转换训练不同的模型。其实就是来解决在CycleGAN中转化一种风格就需要训练一个模型的问题,设计一种编码来实现一个生成器和一个判别器能够生成多种风格,解决了CycleGAN的弊端。

在这里插入图片描述

StarGAN架构图
  1. StarGAN为了解决CycleGAN每一个风格需要训练一个模型,并且需要多个生成器和判别器的问题,StarGAN采用了风格编码来实现只需要一个生成器和一个判别器,但是总体思想仍然采用CycleGAN的思想来设计损失函数。

    • 生成器

      • 在StarGAN中,生成器的输入不仅是图像,还包含目标域的域标签(即风格编码)。生成器会根据该标签生成属于目标域的图像。
      • 生成器同时使用了循环一致性损失(cycle-consistency loss),这是借鉴了CycleGAN的思想。通过将生成的图像转换回原始域,以确保生成图像保留了输入图像的关键信息。
      • 目标是通过风格编码使得生成器能够将一张图片从一个域(如人脸图片)转换为多个目标域(如不同表情、发型或年龄),并在多个域之间进行切换。

      判别器

      • StarGAN的判别器不仅需要判断图像的真假(真实图像 vs. 生成图像),还需要判别该图像属于哪个域(风格编码)。
      • 判别器会输出多个域的分类信息,并在真假分类的同时,判断生成的图像是否符合指定的域标签。

      损失函数

      • 对抗性损失:用于保证生成器生成的图像能够欺骗判别器。
      • 域分类损失:用于确保生成的图像与目标域标签匹配。
      • 循环一致性损失:用于确保生成图像能够还原回原始域,以保持输入的主要特征。

在这里插入图片描述

StarGAN V2
StarGAN V2出发点

StarGAN V2的出发点来自于StarGAN中使用的编码是一些固定的01编码,是不可学习,而StarGAN V2则在风格编码做出来改进,将风格编码初始化成向量,同时也可以通过原始输入图像来生成风格编码,而生成风格编码的网络是可学习的,使的风格更加的差异化,并且生成的图像风格更加准确。模型设计主要流程上并没有做出改动,主要在于损失函数的改动。理解损失函数也是掌握对抗生成网络的关键。

模型架构图

1. 生成器(Generator)

StarGAN V2 的生成器与传统的 GAN 不同,它融合了风格编码和图像转换的思想。生成器的主要目标是将输入图像转换为不同风格的图像。

生成器的核心组成部分:

  • 输入:生成器的输入不仅包括要转换的图像,还包括目标风格编码(可以是从风格编码器得到的风格向量,或者是随机采样的向量)。
  • 风格编码器:StarGAN V2 引入了一个风格编码器,它可以从目标图像中提取出风格信息,将其表示为风格向量。这样,生成器可以利用不同的风格编码生成对应风格的图像。
  • 结构设计:生成器采用了基于卷积的网络架构,但通过风格向量来调控生成过程中的特征图。这使得生成器可以生成具有不同风格特征的图像。
  • 多样性建模:生成器能够通过不同的风格编码生成多个同一源图像的多样化风格变化。这依赖于生成器对风格编码的处理,使得输出图像既能够保持输入的语义信息,又能够呈现目标风格。

2. 判别器(Discriminator)

StarGAN V2 的判别器不仅要判断图像的真假,还要判断生成图像是否符合目标风格。它负责区分生成器生成的图像和真实图像,并检测生成图像的风格是否与目标域匹配。

判别器的核心组成部分:

  • 输入:判别器接收图像输入,同时附带目标风格标签。它的任务是判断输入的图像是否来自真实的目标域,并判断生成器生成的图像是否匹配目标风格。
  • 多域分类:判别器输出的是多分类结果。除了判断图像是真实还是生成的,它还需要对图像的风格域进行分类,确保生成的图像符合目标风格。
  • PatchGAN 设计:判别器通常采用 PatchGAN(局部感知)的设计,它对图像的每个局部区域进行真假和风格分类。这种设计有助于判别器更好地捕捉图像的局部特征,尤其是风格特征,从而在视觉上确保生成的图像看起来自然

损失函数的改进

  • 对抗损失依然是生成对抗网络的核心,用于确保生成图像能欺骗判别器。

  • 风格一致性损失:StarGAN V2通过风格一致性损失来确保生成的图像能保持输入图像的关键信息,并且使风格变化是自然且符合目标域的。

  • 循环一致性损失:与StarGAN类似,StarGAN V2依然采用了循环一致性损失来保证生成图像在转换回原始域时能保持输入图像的主要特征。

  • 多样性损失: StarGAN V2还通过引入多样性损失,确保生成的图像在同一目标域内保持足够的多样性,而不仅仅是简单的风格映射。通过学习不同的风格编码,生成器可以在同一个目标域中生成多个不同风格的图像。

在这里插入图片描述

生成器损失

包含下面四种:对抗损失风格一致性损失多样性损失循环一致性损失

对抗损失

生成对抗网络的核心,用于确保生成图像能欺骗判别器。

公式:在这里插入图片描述

风格一致性损失

风格一致性损失,就是保证模型生成的图片的风格和需要生成的风格越接近越好。首先使用x和风格s生成一张图片,然后再用Style encoder进行编码,获得生成后图片的风格编码,计算它和需要生成风格编码之间的差距作为风格一致性损失。

公式:在这里插入图片描述

多样性损失

多样性损失是确保生成的图像在同一目标域内保持足够的多样性,而不仅仅是简单的风格映射。通过学习不同的风格编码,生成器可以在同一个目标域中生成多个不同风格的图像。简单来说,就是两者的标签一样的,同时采用同样的Mapping network进行编码,但是要使编码出来风格编码差异性越大越好,这样采用生成多种不同风格的图像,学习的是Mapping network。

公式:在这里插入图片描述

循环一致性损失

循环一致性损失和CycleGAN的思想是一样的,要求我们生成出来的图片必须经过还原后还是能够与原来的图像越接近越好。从公式中可以看出,先对x和某中风格编码s生成图像,在使用x经过style encoder生成s1,然后将s1和生成的图像输入生成器,得到图片与原来的图片做比较,这样就得到原始图像和还原后图像之间的差异作为循环一致性损失。

公式:在这里插入图片描述

最终Loss值公式:

Ladv 是对抗损失,Lds前面的负号,说明他们之间的差异越大越好。

在这里插入图片描述

生气器损失计算源码
def compute_g_loss(nets, args, x_real, y_org, y_trg, z_trgs=None, x_refs=None, masks=None):# 确保 z_trgs 和 x_refs 其中一个不为空assert (z_trgs is None) != (x_refs is None)# 当 z_trgs 不为空时,解包 z_trg 和 z_trg2if z_trgs is not None:z_trg, z_trg2 = z_trgs# 当 x_refs 不为空时,解包 x_ref 和 x_ref2if x_refs is not None:x_ref, x_ref2 = x_refs# 对抗损失(adversarial loss)if z_trgs is not None:s_trg = nets.mapping_network(z_trg, y_trg)  # 通过映射网络生成目标风格编码else:s_trg = nets.style_encoder(x_ref, y_trg)  # 通过风格编码器生成目标风格编码x_fake = nets.generator(x_real, s_trg, masks=masks)  # 使用生成器生成假图像out = nets.discriminator(x_fake, y_trg)  # 判别器判断生成的假图像loss_adv = adv_loss(out, 1)  # 对抗损失,目标是真# 风格重构损失(style reconstruction loss)s_pred = nets.style_encoder(x_fake, y_trg)  # 从生成的假图像中提取风格编码loss_sty = torch.mean(torch.abs(s_pred - s_trg))  # 风格重构损失,比较生成和目标风格编码的差异# 多样性敏感损失(diversity sensitive loss)if z_trgs is not None:s_trg2 = nets.mapping_network(z_trg2, y_trg)  # 生成第二个风格编码else:s_trg2 = nets.style_encoder(x_ref2, y_trg)  # 从参考图像中提取第二个风格编码x_fake2 = nets.generator(x_real, s_trg2, masks=masks)  # 生成第二个假图像x_fake2 = x_fake2.detach()  # 停止梯度计算loss_ds = torch.mean(torch.abs(x_fake - x_fake2))  # 计算两个假图像之间的差异,鼓励多样性# 循环一致性损失(cycle-consistency loss)masks = nets.fan.get_heatmap(x_fake) if args.w_hpf > 0 else None  # 使用 FAN 模型获取热图(如果 w_hpf > 0)s_org = nets.style_encoder(x_real, y_org)  # 提取输入图像的原始风格编码x_rec = nets.generator(x_fake, s_org, masks=masks)  # 将假图像转换回原始域loss_cyc = torch.mean(torch.abs(x_rec - x_real))  # 循环一致性损失,确保恢复的图像与原图像相似# 总损失,由对抗损失、风格重构损失、多样性损失和循环一致性损失组成loss = loss_adv + args.lambda_sty * loss_sty \- args.lambda_ds * loss_ds + args.lambda_cyc * loss_cyc# 返回总损失以及每部分的损失值return loss, Munch(adv=loss_adv.item(),sty=loss_sty.item(),ds=loss_ds.item(),cyc=loss_cyc.item())
判别器损失

它对真实图像和生成的假图像分别进行判别,并计算对应的对抗损失。对真实图像,函数计算其对抗损失(希望判别器将其判别为真)和 R1 正则化损失,以提高训练稳定性。对生成的假图像,生成器根据目标域的风格编码生成假图像,判别器再判断该假图像并计算对抗损失(希望判别器将其判别为假)。最后,将真实损失、假图像损失和正则化损失加和,作为判别器的总损失。

判别器损失计算源码
def compute_d_loss(nets, args, x_real, y_org, y_trg, z_trg=None, x_ref=None, masks=None):# 确保 z_trg 和 x_ref 中只有一个不为空assert (z_trg is None) != (x_ref is None)# 对真实图像进行操作x_real.requires_grad_()  # 允许对 x_real 进行梯度计算out = nets.discriminator(x_real, y_org)  # 使用判别器判断真实图像loss_real = adv_loss(out, 1)  # 真实图像的对抗损失,目标是 1loss_reg = r1_reg(out, x_real)  # R1 正则化损失,用于提高训练稳定性# 对生成的假图像进行操作with torch.no_grad():  # 假图像的生成不需要计算梯度if z_trg is not None:s_trg = nets.mapping_network(z_trg, y_trg)  # 通过映射网络生成目标风格编码else:  # x_ref 不为空时,通过风格编码器生成风格编码s_trg = nets.style_encoder(x_ref, y_trg)x_fake = nets.generator(x_real, s_trg, masks=masks)  # 生成假图像out = nets.discriminator(x_fake, y_trg)  # 判别器判断生成的假图像loss_fake = adv_loss(out, 0)  # 假图像的对抗损失,目标是 0# 总损失,由真实损失、假图像损失和正则化损失组成loss = loss_real + loss_fake + args.lambda_reg * loss_reg# 返回总损失以及每部分的损失值return loss, Munch(real=loss_real.item(),fake=loss_fake.item(),reg=loss_reg.item())

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.rhkb.cn/news/429252.html

如若内容造成侵权/违法违规/事实不符,请联系长河编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

(11)(2.1.2) DShot ESCs(二)

文章目录 前言 3 配置伺服功能 4 检查RC横幅 5 参数说明 前言 DShot 是一种数字 ESC 协议,它允许快速、高分辨率的数字通信,可以改善飞行器控制,这在多旋翼和 quadplane 应用中特别有用。 3 配置伺服功能 如上所述,如果使用…

《粮油与饲料科技》是什么级别的期刊?是正规期刊吗?能评职称吗?

问题解答 问:《粮油与饲料科技》是不是核心期刊? 答:不是,是知网收录的第一批认定 学术期刊。 问:《粮油与饲料科技》级别? 答:省级。主管单位:中文天地出版传媒集团股份有限公司…

Apache ZooKeeper 及 Curator 使用总结

1. 下载 官网地址:Apache ZooKeeper 点击下载按钮 选择对应的版本进行下载 2. 使用 1、解压 tar -zxf apache-zookeeper-3.9.2-bin.tar.gz2、复制配置文件,有一个示例配置文件 conf/zoo_sample.cfg,此文件不能生效,需要名称为…

C#和数据库高级:继承与多态

文章目录 一、继承的基本使用继承的概念:继承的特点:为什么使用继承? 二、继承的关键字1、this关键字2、base关键字3、Protected关键字4、子类调用父类的构造函数的总结: 三、继承的特性继承的传递性:继承的单根性&…

【服务器入门】Linux系统基础知识

【服务器入门】Linux系统基础知识 远程登录与文件传输基础命令与文本编辑vi/vim使用shell脚本基本命令1、目录操作2、文件创建与删改3、文件连接与查看 参考 目前超算使用的系统以Linux系统为主,肯定需要了解一些相关知识。本博客就以本人运行WRF模型所需&#xff0…

Remix在SPA模式下,出现ErrorBoundary错误页加载Ant Design组件报错,不能加载样式的问题

Remix是一个既能做服务端渲染,又能做单页应用的框架,如果想做单页应用,又想学服务端渲染,使用Remix可以降低学习成本。最近,在学习Remix的过程中,遇到了在SPA模式下与Ant Design整合的问题。 我用Remix官网…

Godot游戏如何提升触感体验

在游戏世界中,触感体验至关重要,既能极大提升玩家沉浸感,让其深度融入游戏,在操作角色或与环境互动时,通过触感反馈获得身临其境的真实感(比如动作游戏中角色攻击或受击时的振动反馈,能使玩家更…

花朵识别系统Python+卷积神经网络算法+人工智能+深度学习+计算机课设项目+TensorFlow+模型训练

一、介绍 花朵识别系统。本系统采用Python作为主要编程语言,基于TensorFlow搭建ResNet50卷积神经网络算法模型,并基于前期收集到的5种常见的花朵数据集(向日葵、玫瑰、蒲公英、郁金香、菊花)进行处理后进行模型训练,最…

解决DockerDesktop启动redis后采用PowerShell终端操作

如图: 在启动redis容器后,会计入以下界面 : 在进入执行界面后如图: 是否会觉得界面过于单调,于是想到使用PowerShell来操作。 步骤如下: 这样就能使用PowerShell愉快地敲命令了(颜值是第一生…

【stm32笔记】使用rtt-studio与stm32CubeMx联合创建项目

使用rtt-studio与stm32CubeMx联合创建项目 创建rt-thread项目 设置项目信息 在项目资源管理器中“右击“,创建RRT studio 项目 双击“RT-Thread 项目“。 选择MCU,设置UART,以及调试方式。添加项目名称,点击“完成“按钮。 …

Redis的主从模式、哨兵模式、集群模式

最近学习了一下这三种架构模式,这里记录一下,仅供参考 目录 一、主从架构 1、搭建方式 2、同步原理 3、优化策略: 4、总结: 二、哨兵架构 1、搭建哨兵集群 2、RedisTemplate如何使用哨兵模式 三、分片集群架构 1&#…

集成学习详细介绍

以下内容整理于: 斯图尔特.罗素, 人工智能.现代方法 第四版(张博雅等译)机器学习_温州大学_中国大学MOOC(慕课)XGBoost原理介绍------个人理解版_xgboost原理介绍 个人理解-CSDN博客 集成学习(ensemble):选择一个由一系列假设h1, h2, …, hn构成的集合…

AI运动小程序开发常见问题集锦一

截止到现在写博文时,我们的AI运动识别小程序插件已经迭代了23个版本,成功应用于健身、体育、体测、AR互动等场景;为了让正在集成或者计划进行功能扩展优化的用户,少走弯路、投入更少的开发资源,我们归集了一部分集中的…

Redis数据结构之set

一.set集合特性 集合类型也是保存多个字符串类型的元素的,但和list列表不一样,集合中的元素是无序的,而且元素不能够重复,不仅支持增删查改,还支持交集并集等操作 二.相关命令 1.sadd sadd key members…… 咱们把…

【机器学习】--- 决策树与随机森林

文章目录 决策树与随机森林的改进:全面解析与深度优化目录1. 决策树的基本原理2. 决策树的缺陷及改进方法2.1 剪枝技术2.2 树的深度控制2.3 特征选择的优化 3. 随机森林的基本原理4. 随机森林的缺陷及改进方法4.1 特征重要性改进4.2 树的集成方法优化4.3 随机森林的…

JavaScript ---案例(统计字符出现次数)

统计字符出现次数 <!DOCTYPE html> <html lang"en"> <head><meta charset"UTF-8"><meta http-equiv"X-UA-Compatible" content"IEedge"><meta name"viewport" content"widthdevice-w…

深度学习之微积分预备知识点(2)

极限&#xff08;Limit&#xff09; 定义&#xff1a;表示某一点处函数趋近于某一特定值的过程&#xff0c;一般记为 极限是一种变化状态的描述&#xff0c;核心思想是无限靠近而永远不能到达 公式&#xff1a; 表示 x 趋向 a 时 f(x) 的极限。 知识点口诀解释极限的存在左…

LabVIEW软件维护的内容是什么呢?

LabVIEW软件维护涉及多个方面&#xff0c;确保程序的正常运行和长期稳定性。维护内容包括以下几个方面&#xff1a; 1. Bug修复 在开发和运行过程中&#xff0c;可能会出现各种软件问题或缺陷&#xff08;bugs&#xff09;。维护工作之一就是识别这些问题并通过修复程序中的代…

uniapp child.onFieldChange is not a function

uni-forms // 所有子组件参与校验,使用 for 可以使用 awiatfor (let i in childrens) {const child childrens[i];let name realName(child.name);if (typeof child.onFieldChange function) {const result await child.onFieldChange(tempFormData[name]);if (result) {…

【网络】TCP/IP 五层网络模型:网络层

最核心的就是 IP 协议&#xff0c;是一个相当复杂的协议 TCP 详细展开讲解&#xff0c;是因为 TCP 确实在开发中非常关键&#xff0c;经常用到&#xff0c;IP 则不同&#xff0c;和普通程序猿联系比较浅。和专门开发网络的程序猿联系比较紧密&#xff08;开发路由器&#xff0…