使用pytorch构建带梯度惩罚的Wasserstein GAN(WGAN-GP)网络模型

本文为此系列的第三篇WGAN-GP,上一篇为DCGAN。文中仍然不会过多详细的讲解之前写过的,只会写WGAN-GP相对于之前版本的改进点,若有不懂的可以重点看第一篇比较详细。

原理

具有梯度惩罚的 Wasserstein GAN (WGAN-GP)可以解决 GAN 的一些稳定性问题。 具体来说,使用W-loss 作为损失函数替代传统的 BCE 等 loss,并使用梯度惩罚来防止 mode collapse。

  • WGAN-GP 使用了 Wasserstein distance(也成为Earth Mover’s distance, EMD)作为训练 GAN 模型的目标函数,Wasserstein distance is a function of amount and distance,体现的是生成的数据的分布移动到真实数据的分布之间所需的距离与量。
    在这里插入图片描述
    随着判别器训练的越来越好,使用 BCE loss 的话会让鉴别器给出接近于 0 或者接近于 1 的极端值,如下为 sigmoid 曲线,极端值的梯度无限接近于 0,这样判别器就没有太多有用的信息反馈给生成器让它学习,导致梯度消失或 model collapse。使用距离的方式可以有效解决,分布距离再远都不再限制。
    在这里插入图片描述
    在这里插入图片描述
  • BCE loss 本质是一个 minimax game, d 即 discriminator 希望尽可能的 minimize,g 即 generator 希望尽可能的 maximize(意味着造出来的假东西对于鉴别器来说看起来很真实),可以进行如下的简化:
    在这里插入图片描述
    基于 Wasserstein distance 的 W-loss 的的式子与其简化版进行对比:
    在这里插入图片描述
    在 Wasserstein GAN 中不再是 discriminator 了,因为输出不再是 0-1 之间来进行分类,既然不分类了就不是 discriminator 了,而是 critic,所以这里使用 c 代表 critic。critic 希望其尽可能的 maximize,因为希望让 real 和 feak 的距离尽可能的大,起到划清界限的目的;generator 希望其尽可能的minimize,减小两者之间的距离,达到以假乱真的目的。
  • mode collapse 即模式崩溃,当生成器学会从单个类生成特征来欺骗鉴别器时,就会发生 mode collapse(陷入一种模式出不来),跟 cnn 的局部最优是一个概念。这会导致输出出现重复,缺乏多样性和细节。

但在使用 W-loss 训练 GAN 时需要对 critic 有一定的条件 —— critic 需要 1-L(1-Lipschitz)连续:
∣ f ( x 1 ) − f ( x 2 ) ∣ ≤ k ∣ x 1 − x 2 ∣ |f(x_1)-f(x_2)|\le k|x_1-x_2\ | f(x1)f(x2)kx1x2 
这里的 k = 1,也就是 critic 的 nn 函数曲线的梯度(斜率)始终在 -1 到 1 之间,即梯度的 L2 范数不超过1:

在这里插入图片描述
如图:
在这里插入图片描述
曲线的每个点的斜率都是在绿色区域内,很显然这个曲线并不符合。像如下这个曲线就是符合的:
在这里插入图片描述
达到 1-L 连续有两种方法:weigh clipping、gradient penalty。

  • weigh clipping 将权重裁剪到固定范围内,从而限制 critic 的学习能力。但是这样有缺点,可能让所有参数走极端,要么取最大值要么取最小值,critic 会非常倾向于学习一个简单的映射函数。
  • gradient penalty 则是添加一个正则项在 loss function 中,相比 weigh clipping 更加柔和对critic参数的限制更加灵活,通常不会导致梯度消失或梯度爆炸问题。
    在这里插入图片描述
    这里的 λ \lambda λ 为超参值,reg 等于 critic 神经网络梯度范数 -1 的平方,即:
    在这里插入图片描述
    当 critic 神经网络梯度范数 >1 时正则化项发挥作用。平方的作用是为了让其偏离越大,惩罚越大。
    这里的 x ^ \hat{x} x^ 为真实数据与生成数据之间随机取样得到的中间数据,随机值 ϵ \epsilon ϵ 作为权重值,假设 ϵ \epsilon ϵ 为0.3,那么 1- ϵ \epsilon ϵ 为0.7。
    在这里插入图片描述

代码

model.py

from torch import nnclass Generator(nn.Module):def __init__(self, z_dim=10, im_chan=1, hidden_dim=64):super(Generator, self).__init__()self.z_dim = z_dim# Build the neural networkself.gen = nn.Sequential(self.make_gen_block(z_dim, hidden_dim * 4),self.make_gen_block(hidden_dim * 4, hidden_dim * 2, kernel_size=4, stride=1),self.make_gen_block(hidden_dim * 2, hidden_dim),self.make_gen_block(hidden_dim, im_chan, kernel_size=4, final_layer=True),)def make_gen_block(self, input_channels, output_channels, kernel_size=3, stride=2, final_layer=False):if not final_layer:return nn.Sequential(nn.ConvTranspose2d(input_channels, output_channels, kernel_size, stride),nn.BatchNorm2d(output_channels),nn.ReLU(inplace=True),)else:return nn.Sequential(nn.ConvTranspose2d(input_channels, output_channels, kernel_size, stride),nn.Tanh(),)def forward(self, noise):x = noise.view(len(noise), self.z_dim, 1, 1)return self.gen(x)class Critic(nn.Module):def __init__(self, im_chan=1, hidden_dim=64):super(Critic, self).__init__()self.crit = nn.Sequential(self.make_crit_block(im_chan, hidden_dim),self.make_crit_block(hidden_dim, hidden_dim * 2),self.make_crit_block(hidden_dim * 2, 1, final_layer=True),)def make_crit_block(self, input_channels, output_channels, kernel_size=4, stride=2, final_layer=False):if not final_layer:return nn.Sequential(nn.Conv2d(input_channels, output_channels, kernel_size, stride),nn.BatchNorm2d(output_channels),nn.LeakyReLU(0.2, inplace=True),)else:return nn.Sequential(nn.Conv2d(input_channels, output_channels, kernel_size, stride),)def forward(self, image):crit_pred = self.crit(image)return crit_pred.view(len(crit_pred), -1)

train.py

import torch
from torch import nn
from tqdm.auto import tqdm
from torchvision import transforms
from torchvision.datasets import MNIST
from torchvision.utils import make_grid
from torch.utils.data import DataLoader
import matplotlib.pyplot as plt
from model import *
torch.manual_seed(0) # Set for testing purposes, please do not change!def show_tensor_images(image_tensor, num_images=25, size=(1, 28, 28)):image_tensor = (image_tensor + 1) / 2image_unflat = image_tensor.detach().cpu()image_grid = make_grid(image_unflat[:num_images], nrow=5)plt.imshow(image_grid.permute(1, 2, 0).squeeze())plt.show()def get_noise(n_samples, z_dim, device='cpu'):return torch.randn(n_samples, z_dim, device=device)n_epochs = 100
z_dim = 64
display_step = 50
batch_size = 128
lr = 0.0002
beta_1 = 0.5
beta_2 = 0.999
c_lambda = 10
crit_repeats = 5
device = 'cuda'transform = transforms.Compose([transforms.ToTensor(),transforms.Normalize((0.5,), (0.5,)),
])dataloader = DataLoader(MNIST('.', download=False, transform=transform),batch_size=batch_size,shuffle=True)gen = Generator(z_dim).to(device)
gen_opt = torch.optim.Adam(gen.parameters(), lr=lr, betas=(beta_1, beta_2))
crit = Critic().to(device)
crit_opt = torch.optim.Adam(crit.parameters(), lr=lr, betas=(beta_1, beta_2))def weights_init(m):if isinstance(m, nn.Conv2d) or isinstance(m, nn.ConvTranspose2d):torch.nn.init.normal_(m.weight, 0.0, 0.02)if isinstance(m, nn.BatchNorm2d):torch.nn.init.normal_(m.weight, 0.0, 0.02)torch.nn.init.constant_(m.bias, 0)
gen = gen.apply(weights_init)
crit = crit.apply(weights_init)def get_gradient(crit, real, fake, epsilon):# Mix the images togethermixed_images = real * epsilon + fake * (1 - epsilon)# Calculate the critic's scores on the mixed imagesmixed_scores = crit(mixed_images)# Take the gradient of the scores with respect to the imagesgradient = torch.autograd.grad(inputs=mixed_images,outputs=mixed_scores,# These other parameters have to do with the pytorch autograd engine worksgrad_outputs=torch.ones_like(mixed_scores),create_graph=True,retain_graph=True,)[0]return gradientdef gradient_penalty(gradient):# Flatten the gradients so that each row captures one imagegradient = gradient.view(len(gradient), -1)# Calculate the magnitude of every rowgradient_norm = gradient.norm(2, dim=1)# Penalize the mean squared distance of the gradient norms from 1penalty = torch.mean((gradient_norm - 1) ** 2)return penaltydef get_gen_loss(crit_fake_pred):gen_loss = -1. * torch.mean(crit_fake_pred)return gen_lossdef get_crit_loss(crit_fake_pred, crit_real_pred, gp, c_lambda):crit_loss = torch.mean(crit_fake_pred) - torch.mean(crit_real_pred) + c_lambda * gpreturn crit_losscur_step = 0
generator_losses = []
critic_losses = []
for epoch in range(n_epochs):# Dataloader returns the batchesfor real, _ in tqdm(dataloader):cur_batch_size = len(real)real = real.to(device)mean_iteration_critic_loss = 0for _ in range(crit_repeats):### Update critic ###crit_opt.zero_grad()fake_noise = get_noise(cur_batch_size, z_dim, device=device)fake = gen(fake_noise)crit_fake_pred = crit(fake.detach())crit_real_pred = crit(real)epsilon = torch.rand(len(real), 1, 1, 1, device=device, requires_grad=True)gradient = get_gradient(crit, real, fake.detach(), epsilon)gp = gradient_penalty(gradient)crit_loss = get_crit_loss(crit_fake_pred, crit_real_pred, gp, c_lambda)# Keep track of the average critic loss in this batchmean_iteration_critic_loss += crit_loss.item() / crit_repeats# Update gradientscrit_loss.backward(retain_graph=True)# Update optimizercrit_opt.step()critic_losses += [mean_iteration_critic_loss]### Update generator ###gen_opt.zero_grad()fake_noise_2 = get_noise(cur_batch_size, z_dim, device=device)fake_2 = gen(fake_noise_2)crit_fake_pred = crit(fake_2)gen_loss = get_gen_loss(crit_fake_pred)gen_loss.backward()# Update the weightsgen_opt.step()# Keep track of the average generator lossgenerator_losses += [gen_loss.item()]### Visualization code ###if cur_step % display_step == 0 and cur_step > 0:gen_mean = sum(generator_losses[-display_step:]) / display_stepcrit_mean = sum(critic_losses[-display_step:]) / display_stepprint(f"Step {cur_step}: Generator loss: {gen_mean}, critic loss: {crit_mean}")show_tensor_images(fake)show_tensor_images(real)step_bins = 20num_examples = (len(generator_losses) // step_bins) * step_binsplt.plot(range(num_examples // step_bins),torch.Tensor(generator_losses[:num_examples]).view(-1, step_bins).mean(1),label="Generator Loss")plt.plot(range(num_examples // step_bins),torch.Tensor(critic_losses[:num_examples]).view(-1, step_bins).mean(1),label="Critic Loss")plt.legend()plt.show()cur_step += 1

在这里插入图片描述

代码讲解

网络模型与上一篇的DCGAN没有变动。
在这里插入图片描述
这个模块进行梯度计算,即上文原理中正则项公式里面的梯度L2范数里的梯度。首先计算真实数据与生成数据之间随机取样的混合数据,然后输入 critic,最后计算出其梯度。
在这里插入图片描述
梯度惩罚模块,即上文原理中的整个正则项公式,梯度范数 -1 的平方。
在这里插入图片描述
critic 的 loss function 公式如下,generator 因为和真实数据无关,且与正则项也无关,所以只有中间一项。
在这里插入图片描述————————————————————————————————————————————

总之,WGAN-GP 不一定要提高 GAN 的整体性能,但会很好的提高稳定性并避免模式崩溃。

下一篇条件生成GAN。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.rhkb.cn/news/295428.html

如若内容造成侵权/违法违规/事实不符,请联系长河编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

OpenHarmony error: signature verification failed due to not trusted app source

问题:error: signature verification failed due to not trusted app source 今天在做OpenHarmony App开发,之前一直用的设备A在测试开效果,今天换成了设备B,通过DevEco Studio安装应用程序的时候,就出现错误&#xf…

蓝桥杯刷题-四平方和

四平方和 代码: from copy import deepcopy n int(input()) maxn int(5e6) 10 dic dict() for a in range(maxn):if a * a > n:breakfor b in range(a,maxn):if a * a b * b > n:breakif dic.get(a*ab*b) is None:dic[a*ab*b] (a,b) ans [maxn for _ …

基于springboot+vue+Mysql的教学视频点播系统

开发语言:Java框架:springbootJDK版本:JDK1.8服务器:tomcat7数据库:mysql 5.7(一定要5.7版本)数据库工具:Navicat11开发软件:eclipse/myeclipse/ideaMaven包:…

网络安全 | 什么是DDoS攻击?

关注WX:CodingTechWork DDoS-介绍 DoS:Denial of Service,拒绝服务。DDoS是通过大规模的网络流量使得正常流量不能访问受害者目标,是一种压垮性的网络攻击,而不是一种入侵手段。NTP网络时间协议,设备需要…

c++ 插值搜索-迭代与递归(Interpolation Search)

给定一个由 n 个均匀分布值 arr[] 组成的排序数组,编写一个函数来搜索数组中的特定元素 x。 线性搜索需要 O(n) 时间找到元素,跳转搜索需要 O(? n) 时间,二分搜索需要 O(log n) 时间。 插值搜索是对实例二分搜索的改进,…

C#开发中一些常用的工具类分享

一、配置文件读写类 用于在开发时候C#操作配置文件读写信息 1、工具类 ReadIni 代码 using System; using System.Collections.Generic; using System.IO; using System.Linq; using System.Runtime.InteropServices; using System.Text; using System.Threading.Tasks;namesp…

Lua 和 Love 2d 教程 二十一点朴克牌 (上篇lua源码)

GitCode - 开发者的代码家园 Lua版完整原码 规则 庄家和玩家各发两张牌。庄家的第一张牌对玩家是隐藏的。 玩家可以拿牌(即拿另一张牌)或 停牌(即停止拿牌)。 如果玩家手牌的总价值超过 21,那么他们就爆掉了。 面牌…

数据结构——红黑树详解

一、红黑树的定义 红黑树,是一种二叉搜索树,但在每个结点上增加一个存储位表示结点的颜色,可以是Red或Black。 通过对任何一条从根到叶子的路径上各个结点着色方式的限制,红黑树确保没有一条路径会比其他路径长出两倍&#xff0c…

数据库加载驱动问题(java.lang.ClassNotFoundException: com.mysql.cj.jdbc.Driver)

java.lang.ClassNotFoundException: com.mysql.cj.jdbc.Driver 遇到此问题,首先检查IDEA外部库中是否有mysql数据库驱动。如下所示: 如果发现外部库中存有mysql数据库驱动,需要在数据库配置文件中查看是否设置有时区mysql8.0以上版本需要设…

【机器学习】机器学习创建算法第3篇:K-近邻算法,学习目标【附代码文档】

机器学习(算法篇)完整教程(附代码资料)主要内容讲述:机器学习算法课程定位、目标,K-近邻算法定位,目标,学习目标,1 什么是K-近邻算法,1 Scikit-learn工具介绍,2 K-近邻算法API。K-近邻算法,1.4 …

rhce复习2

HTTPS协议 https简介 超文本传输协议HTTP协议被用于在Web浏览器和网站服务器之间传递信息。HTTP协议以明文方式发送内容,不提供任何方式的数据加密,如果攻击者截取了Web浏览器和网站服务器之间的传输报文,就可以直接读懂其中的信息&#xf…

Canal1.1.5整Springboot在MQ模式和TCP模式监听mysql

canal本实验使用的是1.1.5,自行决定版本:[https://github.com/alibaba/canal/releases] canal 涉及的几个角色 canal-admin:canal 后台管理系统,管理 canal 服务canal-deployer:即canal-server(客户端&…

某眼实时票房接口获取

某眼实时票房接口获取 前言解决方案1.找到veri.js2.找到signKey所在位置3.分析它所处的这个函数的内容4.index参数的获取5.signKey参数的获取运行结果关键代码另一种思路票房接口:https://piaofang.maoyan.com/dashboard-ajax https://piaofang.maoyan.com/dashboard 实时票房…

Meta 推出Ego-Exo4D:一个研究视频学习和多模态感知的基础数据集

每周跟踪AI热点新闻动向和震撼发展 想要探索生成式人工智能的前沿进展吗?订阅我们的简报,深入解析最新的技术突破、实际应用案例和未来的趋势。与全球数同行一同,从行业内部的深度分析和实用指南中受益。不要错过这个机会,成为AI领…

安全架构设计理论与实践相关知识总结

一、安全架构概述 常见信息威胁介绍: 1. 信息泄露:信息被泄露或透露给某个非授权实体 2. 破坏信息完整性:数据被非授权地进行增删改查货破坏而受到损失 3. 拒绝服务:对信息会其他资源的合法访问被无条件的组织 4. 非法使用&#x…

IDE/VS2015和VS2017帮助文档MSDN安装和使用

文章目录 概述VS2015MSDN离线安装离线MSDN的下载离线MSDN安装 MSDN使用方法从VS内F1启动直接启动帮助程序跳转到了Qt的帮助网页 VS2017在线安装MSDN有些函数在本地MSDN没有帮助?切换中英文在线帮助文档 概述 本文主要介绍了VS集成开发环境中,帮助文档MS…

hibernate session接口

hibernate session接口 Session接口是hibernate向应用程序提供的操纵数据库的最主要的接口,提供了保存、更新、删除和加载Java对象的方法。 session具有一个缓存,位于缓存中的对象成为持久化对象,和数据库中的相关记录对应。session能够在某些…

【御控物联】JavaScript JSON结构转换(17):数组To对象——键值互换属性重组

文章目录 一、JSON结构转换是什么?二、核心构件之转换映射三、案例之《JSON数组 To JSON对象》四、代码实现五、在线转换工具六、技术资料 一、JSON结构转换是什么? JSON结构转换指的是将一个JSON对象或JSON数组按照一定规则进行重组、筛选、映射或转换…

MES可视化管理,提高制造业工厂管理水平

在制造业中,生产过程的可视化管理对于提高生产效率、降低成本以及提升决策的准确性和及时性至关重要。MES系统作为制造执行过程中的核心管理工具,为实现生产过程的可视化管理提供了有力支持。 生产车间应用MES生产管理系统之后,能通过电子屏幕…

异地文件如何共享访问?

异地文件共享访问是一种让不同地区的用户能够快速、安全地共享文件的解决方案。人们越来越需要在不同地点之间共享文件和数据。由于复杂的网络环境和安全性的问题,实现异地文件共享一直是一个挑战。 为了解决这个问题,许多公司和组织研发了各种异地文件共…