[PyTorch][chapter 57][WGAN-GP 代码实现]

前言:

 下图为WGAN 的效果图:

  绿色为真实数据的分布: 8个高斯分布

  红色: 为随机产生的数据分布,跟真实分布基本一致

WGAN-GP:

1 判别器D: 最后一层去掉sigmoid
2 生成器G 和判别器D: loss不取log
3 损失函数 增加了penalty,使用Adam

 Wasserstein GAN
1 判别器D: 最后一层去掉sigmoid
2 生成器G 和判别器D: loss不取log
3 每次更新判别器的参数之后把它们的绝对值截断到不超过一个固定常数c
4 不要用基于动量的优化算法(包括momentum和Adam),推荐RMSProp,SGD也行
 


一  简介

    1.1 模型结构

 1.2 伪代码

      

从Wasserstein距离、对偶理论到WGAN - 科学空间|Scientific Spaces


二  wgan.py

 主要变化:

    Generator 中 去掉了之前的logit 函数

# -*- coding: utf-8 -*-
"""
Created on Thu Sep 28 11:10:19 2023@author: chengxf2
"""import torch
from   torch import nn#生成器模型
h_dim = 400
class Generator(nn.Module):def __init__(self):super(Generator,self).__init__()# z: [batch,input_features]self.net = nn.Sequential(nn.Linear(2, h_dim),nn.ReLU(True),nn.Linear( h_dim, h_dim),nn.ReLU(True),nn.Linear(h_dim, h_dim),nn.ReLU(True),nn.Linear(h_dim, 2))def forward(self, z):output = self.net(z)return output#鉴别器模型
class Discriminator(nn.Module):def __init__(self):super(Discriminator,self).__init__()hDim=400# x: [batch,input_features]self.net = nn.Sequential(nn.Linear(2, hDim),nn.ReLU(True),nn.Linear(hDim, hDim),nn.ReLU(True),nn.Linear(hDim, hDim),nn.ReLU(True),nn.Linear(hDim, 1),)def forward(self, x):#x:[batch,1]output = self.net(x)out = output.view(-1)return out

三 main.py

  主要变化:

    损失函数中增加了gradient_penalty

# -*- coding: utf-8 -*-
"""
Created on Thu Sep 28 11:28:32 2023@author: chengxf2
"""import visdom
from gan  import  Discriminator
from gan  import Generator
import numpy as np
import random
import torch
from   torch import nn, optim
from    matplotlib import pyplot as plt
from torch import autogradh_dim =400
batchsz = 256
viz = visdom.Visdom()
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")def weights_init(net):if isinstance(net, nn.Linear):# net.weight.data.normal_(0.0, 0.02)nn.init.kaiming_normal_(net.weight)net.bias.data.fill_(0)def data_generator():"""8- gaussian destributionReturns-------None."""scale = 2a = np.sqrt(2.0)centers =[(1,0),(-1,0),(0,1),(0,-1),(1/a,1/a),(1/a,-1/a),(-1/a, 1/a),(-1/a,-1/a)]centers = [(scale*x, scale*y) for x,y in centers]while True:dataset =[]for i in range(batchsz):point = np.random.randn(2)*0.02center = random.choice(centers)point[0] += center[0]point[1] += center[1]dataset.append(point)dataset = np.array(dataset).astype(np.float32)dataset /=a#生成器函数是一个特殊的函数,可以返回一个迭代器yield datasetdef generate_image(D, G, xr, epoch):      #xr表示真实的sample"""Generates and saves a plot of the true distribution, the generator, and thecritic."""N_POINTS = 128RANGE = 3plt.clf()points = np.zeros((N_POINTS, N_POINTS, 2), dtype='float32')points[:, :, 0] = np.linspace(-RANGE, RANGE, N_POINTS)[:, None]points[:, :, 1] = np.linspace(-RANGE, RANGE, N_POINTS)[None, :]points = points.reshape((-1, 2))             # (16384, 2)x = y = np.linspace(-RANGE, RANGE, N_POINTS)N = len(x)# draw contourwith torch.no_grad():points = torch.Tensor(points)      # [16384, 2]disc_map = D(points).cpu().numpy() # [16384]plt.contour(x, y, disc_map.reshape((N, N)).transpose())#plt.clabel(cs, inline=1, fontsize=10)plt.colorbar()# draw sampleswith torch.no_grad():z = torch.randn(batchsz, 2)                 # [b, 2]samples = G(z).cpu().numpy()                # [b, 2]plt.scatter(xr[:, 0], xr[:, 1], c='green', marker='.')plt.scatter(samples[:, 0], samples[:, 1], c='red', marker='+')viz.matplot(plt, win='contour', opts=dict(title='p(x):%d'%epoch))def gradient_penalty(D, xr,xf):#[b,1]t =  torch.rand(batchsz, 1).to(device)       #[b,1]=>[b,2]  保证每个sample t 相同t =  t.expand_as(xr)#sample penalty interpoation [b,2]mid = t*xr +(1-t)*xfmid.requires_grad_()pred = D(mid) #[256]'''grad_outputs:   如果outputs 是向量,则此参数必须写retain_graph:  True 则保留计算图, False则释放计算图create_graph: 若要计算高阶导数,则必须选为Trueallow_unused: 允许输入变量不进入计算'''grads = autograd.grad(outputs= pred, inputs = mid,grad_outputs= torch.ones_like(pred),create_graph=True,retain_graph=True,only_inputs=True)[0]gp = torch.pow(grads.norm(2, dim=1)-1,2).mean()return gpdef main():lambd = 0.2 #超参数maxIter = 1000torch.manual_seed(10)np.random.seed(10)data_iter  = data_generator()G = Generator().to(device)D = Discriminator().to(device)G.apply(weights_init)D.apply(weights_init)optim_G = optim.Adam(G.parameters(),lr =5e-4, betas=(0.5,0.9))optim_D = optim.Adam(D.parameters(),lr =5e-4, betas=(0.5,0.9))K = 5viz.line([[0,0]], [0], win='loss', opts=dict(title='loss', legend=['D', 'G']))for epoch in range(maxIter):#1: train Discrimator fistlyfor k in range(K):#1.1: train on real dataxr = next(data_iter)xr = torch.from_numpy(xr).to(device)predr = D(xr)#max(predr) == min(-predr)lossr = -predr.mean()#1.2: train on fake dataz = torch.randn(batchsz,2).to(device) #[b,2] 随机产生的噪声xf = G(z).detach() #固定G,不更新G参数 tf.stop_gradient()predf =D(xf)lossf = predf.mean()#1.3 gradient_penaltygp = gradient_penalty(D, xr,xf.detach())#aggregate allloss_D = lossr + lossf +lambd*gpoptim_D.zero_grad()loss_D.backward()optim_D.step()#print("\n Discriminator 训练结束 ",loss_D.item())# 2 train  Generator#2.1 train on fake dataz = torch.randn(batchsz, 2).to(device)xf = G(z)predf =D(xf) #期望最大loss_G= -predf.mean()#optimizeoptim_G.zero_grad()loss_G.backward()optim_G.step()if epoch %100 ==0:viz.line([[loss_D.item(), loss_G.item()]], [epoch], win='loss', update='append')generate_image(D, G, xr, epoch)print("\n epoch: %d"%epoch,"\t lossD: %7.4f"%loss_D.item(),"\t lossG: %7.4f"%loss_G.item())if __name__ == "__main__":main()

参考:

课时130 WGAN-GP实战_哔哩哔哩_bilibili

WGAN基本原理及Pytorch实现WGAN-CSDN博客

CSDN

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.rhkb.cn/news/152548.html

如若内容造成侵权/违法违规/事实不符,请联系长河编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

讲讲项目里的仪表盘编辑器(四)分页卡和布局容器组件

讲讲两个经典布局组件的实现 ① 布局容器组件 配置面板是给用户配置布局容器背景颜色等属性。这里我们不需要关注 定义文件 规定了组件类的类型、标签、图标、默认布局属性、主文件等等。 // index.js import Container from ./container.vue; class ContainerControl extends…

java:代理模式

概念代理模式 概念: 真实对象:被代理的对象,背景的联想总部代理对象:也就是那个西安联想代理商代理模式:代理对象代理真实对象,达到增强真实对象功能的目的 实现方式: 静态代理:有一…

边缘计算网关

一、项目整体框架图 二、项目整体描述 边缘计算网关项目主要实现了智能家居场景和工业物联网场景下设备的数据采集和控制。 整个项目分为三大层:用户接口层、网关层、设备层。 其中用户层通过QT客户端、WEB界面及阿里云提供数据展示和用户接口。 网关使用虚拟机代替…

ArcGIS Engine:实现Shp/Mxd数据的加载、图层的简单查询

本博客参考&#xff1a;BiliBili UP主 <羊羊旸> &#xff1a; Arcgis Engine学习 目录 01 加载控件以及控件的基本信息等调整 02 编写 <菜单-地图控件> 中各个子工具的代码 2.1 加载Shapefile数据-代码 2.2 加载地图文档数据-代码 2.3 获取图层数量-代码 2.…

如何从零开始系统的学习项目管理?

一、项目的概念 根据项目管理协会&#xff08;PMI&#xff09;的定义&#xff0c;项目是指为了创造独特的产品、服务或成果而进行的临时性工作。这意味着项目需要有明确的目标&#xff0c;且不是日常重复性工作。尽管项目是临时性工作&#xff0c;但它所交付的成果可能会持续存…

汽车冲压车间的RFID技术设计解决方案

一、RFID技术的基本原理 RFID技术是一种利用非接触式自动识别的技术&#xff0c;通过将RFID标签放置在被识别物品上&#xff0c;并使用RFID读写器对标签进行扫描和识别&#xff0c;实现对物品的自动识别和追踪。RFID标签分为被动式和主动式两种。被动式标签无内置电源&#xf…

解决远程git服务器路径改变导致本地无法push的问题

解决远程git服务器路径改变导致本地无法push的问题 &#xff08;1&#xff09;第一步&#xff1a;查看git配置 git config -l&#xff08;2&#xff09;第二步&#xff1a;删除远程git地址 git remote remove origin&#xff08;3&#xff09;第三步&#xff1a;再次查看git配…

Vue3 + Ts实现NPM插件 - 定制loading

目录 你的 Loading&#x1f916; 安装&#x1f6f9; 简介苍白请 您移步文档&#xff1a;✈️ 使用方法&#x1f6e0;️ 配置 loading 类型&#x1f3b2; 定制 loading 色彩 &#x1f4a1; 注意事项 前期回顾 你的 Loading 开箱即可用的 loading&#xff0c; 说明&#xff1a;vu…

Java练习题-用冒泡排序法实现数组排序

✅作者简介&#xff1a;CSDN内容合伙人、阿里云专家博主、51CTO专家博主、新星计划第三季python赛道Top1&#x1f3c6; &#x1f4c3;个人主页&#xff1a;hacker707的csdn博客 &#x1f525;系列专栏&#xff1a;Java练习题 &#x1f4ac;个人格言&#xff1a;不断的翻越一座又…

MySql017——组合查询UNION和UNION ALL

一、UNION作用 可用UNION操作符来组合数条SQL查询。 二、UNION 使用规则 1、UNION的使用很简单。所需做的只是给出每条SELECT语句&#xff0c;在各条语句之间放上关键字UNION。2、UNION必须由两条或两条以上的SELECT语句组成&#xff0c;语句之间用关键字UNION分隔&#xff…

Mac mov转mp4,详细转换步骤

Mac mov转mp4怎么转&#xff1f;视频文件格式为.mov是由Apple公司所开发的特殊格式。因其只能在苹果设备上播放&#xff0c;与他人分享时就会变得困难。为此&#xff0c;我们通常会选择使用MP4这种最受欢迎的视频格式。在日常使用中&#xff0c;MP4成为了大家首选的视频格式。而…

Vulnhub系列靶机-The Planets Earth

文章目录 Vulnhub系列靶机-The Planets: Earth1. 信息收集1.1 主机扫描1.2 端口扫描1.3 目录爆破 2. 漏洞探测2.1 XOR解密2.2 解码 3. 漏洞利用3.1 反弹Shell 4. 权限提升4.1 NC文件传输 Netcat&#xff08;nc&#xff09;文件传输 Vulnhub系列靶机-The Planets: Earth 1. 信息…

【管理运筹学】第 9 章 | 网络计划(1,网络图的组成及绘制)

文章目录 引言一、网络图的组成及绘制1.1 网络图的组成1. 基本要素2. 线路与关键线路3. 网络图的类型 1.2 网络图的绘制1. 画图原则2. 绘图一般步骤 写在最后 引言 大纲里关于网络计划这一章的描述&#xff0c;就两个&#xff0c;一个是基本概念&#xff1a;网络计划、时间参数…

计算机竞赛 题目:基于深度学习的中文对话问答机器人

文章目录 0 简介1 项目架构2 项目的主要过程2.1 数据清洗、预处理2.2 分桶2.3 训练 3 项目的整体结构4 重要的API4.1 LSTM cells部分&#xff1a;4.2 损失函数&#xff1a;4.3 搭建seq2seq框架&#xff1a;4.4 测试部分&#xff1a;4.5 评价NLP测试效果&#xff1a;4.6 梯度截断…

API网关是什么?

API网关是什么&#xff1f; API网关很多人都知道它的实现原理&#xff0c;但是并不清楚它存在的意义和背景是什么&#xff0c;这里我给大家通俗易懂地讲解下&#xff01;举个例子&#xff0c;假设你正在开发一个电商网站&#xff0c;那么这里会涉及到很多后端的微服务&#xf…

排序算法之【归并排序】

&#x1f4d9;作者简介&#xff1a; 清水加冰&#xff0c;目前大二在读&#xff0c;正在学习C/C、Python、操作系统、数据库等。 &#x1f4d8;相关专栏&#xff1a;C语言初阶、C语言进阶、C语言刷题训练营、数据结构刷题训练营、有感兴趣的可以看一看。 欢迎点赞 &#x1f44d…

postman测试文件上传接口教程

postman是一个很好的接口测试软件&#xff0c;有时候接口是Get请求方式的&#xff0c;肯定在浏览器都可以测了&#xff0c;不过对于比较规范的RestFul接口&#xff0c;限定了只能post请求的&#xff0c;那你只能通过工具来测了&#xff0c;浏览器只能支持get请求的接口&#xf…

【计算机网络】poll | epoll

文章目录 1. pollpoll函数参数解析代码解析PollServer代码 poll 特点 2. epoll认识接口epoll_createepoll_ctlepoll_wait 基本原理红黑树就绪队列 1. poll poll函数参数解析 输入 man poll poll的第一个参数是文件描述符 poll的第二个参数为 等待的多个文件描述符(fd)数字层面…

点云分割segmentation

点云分割是根据空间、几何和纹理等特征对点云进行划分&#xff0c;使得同一划分区域内的点云拥有相似的特征 。点云的有效分割往往是许多应用的前提。例如&#xff0c;在逆向工程CAD/CAM 领域&#xff0c;对零件的不同扫描表面进行分割&#xff0c;然后才能更好地进行孔洞修复、…

Go 并发编程

并发编程 1.1 并发与并⾏ 并⾏与并发是两个不同的概念&#xff0c;普通解释&#xff1a; 并发&#xff1a;交替做不同事情的能⼒并⾏&#xff1a;同时做不同事情的能⼒ 如果站在程序员的⻆度去解释是这样的&#xff1a; 并发&#xff1a;不同的代码块交替执⾏并⾏&#xf…