GAN的原理分析与实例

为了便于理解,可以先玩一玩这个网站:GAN Lab: Play with Generative Adversarial Networks in Your Browser!

GAN的本质:枯叶蝶和鸟。生成器的目标:让枯叶蝶进化,变得像枯叶,不被鸟准确识别。判别器的目标:准确判别是枯叶还是鸟

伪代码: 

案例:

原始数据:

案例结果: 

 案例完整代码:

# import os
import torch
import torch.nn as nn
import torchvision as tv
from torch.autograd import Variable
import tqdm
import matplotlib.pyplot as plt
plt.rcParams['font.sans-serif'] = ['SimHei']  # 显示中文标签
plt.rcParams['axes.unicode_minus'] = False# dir = '... your path/faces/'
dir = './data/train_data'
# path = []
#
# for fileName in os.listdir(dir):
#     path.append(fileName)       # len(path)=51223noiseSize = 100     # 噪声维度
n_generator_feature = 64        # 生成器feature map数
n_discriminator_feature = 64        # 判别器feature map数
batch_size = 50
d_every = 1     # 每一个batch训练一次discriminator
g_every = 5     # 每五个batch训练一次generatorclass NetGenerator(nn.Module):def __init__(self):super(NetGenerator,self).__init__()self.main = nn.Sequential(      # 神经网络模块将按照在传入构造器的顺序依次被添加到计算图中执行nn.ConvTranspose2d(noiseSize, n_generator_feature * 8, kernel_size=4, stride=1, padding=0, bias=False),#转置卷积层:输入特征映射的尺寸会放大,通道数可能会减小,普通卷积层:输入特征映射的尺寸会缩小,但通道数可能会增加nn.BatchNorm2d(n_generator_feature * 8),nn.ReLU(True),       # (n_generator_feature * 8) × 4 × 4        (1-1)*1+1*(4-1)+0+1 = 4nn.ConvTranspose2d(n_generator_feature * 8, n_generator_feature * 4, kernel_size=4, stride=2, padding=1, bias=False),nn.BatchNorm2d(n_generator_feature * 4),nn.ReLU(True),      # (n_generator_feature * 4) × 8 × 8     (4-1)*2-2*1+1*(4-1)+0+1 = 8nn.ConvTranspose2d(n_generator_feature * 4, n_generator_feature * 2, kernel_size=4, stride=2, padding=1, bias=False),nn.BatchNorm2d(n_generator_feature * 2),nn.ReLU(True),  # (n_generator_feature * 2) × 16 × 16nn.ConvTranspose2d(n_generator_feature * 2, n_generator_feature, kernel_size=4, stride=2, padding=1, bias=False),nn.BatchNorm2d(n_generator_feature),nn.ReLU(True),      # (n_generator_feature) × 32 × 32nn.ConvTranspose2d(n_generator_feature, 3, kernel_size=5, stride=3, padding=1, bias=False),nn.Tanh()       # 3 * 96 * 96)def forward(self, input):return self.main(input)class NetDiscriminator(nn.Module):def __init__(self):super(NetDiscriminator,self).__init__()self.main = nn.Sequential(nn.Conv2d(3, n_discriminator_feature, kernel_size=5, stride=3, padding=1, bias=False),nn.LeakyReLU(0.2, inplace=True),        # n_discriminator_feature * 32 * 32nn.Conv2d(n_discriminator_feature, n_discriminator_feature * 2, kernel_size=4, stride=2, padding=1, bias=False),nn.BatchNorm2d(n_discriminator_feature * 2),nn.LeakyReLU(0.2, inplace=True),         # (n_discriminator_feature*2) * 16 * 16nn.Conv2d(n_discriminator_feature * 2, n_discriminator_feature * 4, kernel_size=4, stride=2, padding=1, bias=False),nn.BatchNorm2d(n_discriminator_feature * 4),nn.LeakyReLU(0.2, inplace=True),  # (n_discriminator_feature*4) * 8 * 8nn.Conv2d(n_discriminator_feature * 4, n_discriminator_feature * 8, kernel_size=4, stride=2, padding=1, bias=False),nn.BatchNorm2d(n_discriminator_feature * 8),nn.LeakyReLU(0.2, inplace=True),  # (n_discriminator_feature*8) * 4 * 4nn.Conv2d(n_discriminator_feature * 8, 1, kernel_size=4, stride=1, padding=0, bias=False),nn.Sigmoid()        # 输出一个概率)def forward(self, input):return self.main(input).view(-1)def train():for i, (image,_) in tqdm.tqdm(enumerate(dataloader)):       # type((image,_)) = <class 'list'>, len((image,_)) = 2 * 256 * 3 * 96 * 96real_image = Variable(image)#real_image = real_image.cuda()if (i + 1) % d_every == 0:  #d_every = 1,每一个batch训练一次discriminatoroptimizer_d.zero_grad()output = Discriminator(real_image)      # 尽可能把真图片判为Trueerror_d_real = criterion(output, true_labels)error_d_real.backward()noises.data.copy_(torch.randn(batch_size, noiseSize, 1, 1))fake_img = Generator(noises).detach()       # 根据噪声生成假图fake_output = Discriminator(fake_img)       # 尽可能把假图片判为Falseerror_d_fake = criterion(fake_output, fake_labels)error_d_fake.backward()optimizer_d.step()if (i + 1) % g_every == 0:optimizer_g.zero_grad()noises.data.copy_(torch.randn(batch_size, noiseSize, 1, 1))fake_img = Generator(noises)        # 这里没有detachfake_output = Discriminator(fake_img)       # 尽可能让Discriminator把假图片判为Trueerror_g = criterion(fake_output, true_labels)error_g.backward()optimizer_g.step()def show(num):fix_fake_imags = Generator(fix_noises)fix_fake_imags = fix_fake_imags.data.cpu()[:64] * 0.5 + 0.5# x = torch.rand(64, 3, 96, 96)fig = plt.figure(1)i = 1for image in fix_fake_imags:ax = fig.add_subplot(8, 8, eval('%d' % i)) #将Figure划分为8行8列的子图网格,并将当前的子图添加到第i个位置。# plt.xticks([]), plt.yticks([])  # 去除坐标轴plt.axis('off')plt.imshow(image.permute(1, 2, 0)) #permute()函数可以对维度进行重排,Matplotlib期望的图像格式是(H, W, C),即高度、宽度、通道i += 1plt.subplots_adjust(left=None,  # the left side of the subplots of the figureright=None,  # the right side of the subplots of the figurebottom=None,  # the bottom of the subplots of the figuretop=None,  # the top of the subplots of the figurewspace=0.05,  # the amount of width reserved for blank space between subplotshspace=0.05)  # the amount of height reserved for white space between subplots)plt.suptitle('第%d迭代结果' % num, y=0.91, fontsize=15)plt.savefig("images/%dcgan.png" % num)if __name__ == '__main__':transform = tv.transforms.Compose([tv.transforms.Resize(96),     # 图片尺寸, transforms.Scale transform is deprecatedtv.transforms.CenterCrop(96),tv.transforms.ToTensor(),tv.transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))       # 变成[-1,1]的数])dataset = tv.datasets.ImageFolder(dir, transform=transform)dataloader = torch.utils.data.DataLoader(dataset, batch_size=batch_size, shuffle=True, num_workers=4, drop_last=True)   # module 'torch.utils.data' has no attribute 'DataLoder'print('数据加载完毕!')Generator = NetGenerator()Discriminator = NetDiscriminator()optimizer_g = torch.optim.Adam(Generator.parameters(), lr=2e-4, betas=(0.5, 0.999))optimizer_d = torch.optim.Adam(Discriminator.parameters(), lr=2e-4, betas=(0.5, 0.999))criterion = torch.nn.BCELoss()true_labels = Variable(torch.ones(batch_size))     # batch_sizefake_labels = Variable(torch.zeros(batch_size))fix_noises = Variable(torch.randn(batch_size, noiseSize, 1, 1))noises = Variable(torch.randn(batch_size, noiseSize, 1, 1))     # 均值为0,方差为1的正态分布# if torch.cuda.is_available() == True:#     print('Cuda is available!')#     Generator.cuda()#     Discriminator.cuda()#     criterion.cuda()#     true_labels, fake_labels = true_labels.cuda(), fake_labels.cuda()#     fix_noises, noises = fix_noises.cuda(), noises.cuda()#plot_epoch = [1,5,10,50,100,200,500,800,1000,1500,2000,2500,3000]plot_epoch = [1,5,10,50,100,200,500,800,1000,1200,1500]for i in range(1500):        # 最大迭代次数train()print('迭代次数:{}'.format(i))if i in plot_epoch:show(i)

http://t.csdnimg.cn/FTSriicon-default.png?t=N7T8http://t.csdnimg.cn/FTSri

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.rhkb.cn/news/217597.html

如若内容造成侵权/违法违规/事实不符,请联系长河编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

vim + ctags 跳转, 查看函数定义

yum install ctags Package ctags-5.8-13.el7.x86_64 already installed and latest version 创建 /home/mzh/pptp-master/tags.sh #!/usr/bin/shWORKDIR/home/mzh/pptp-masterfind ${WORKDIR} -name "*.[c|h]" | xargs ctags -f ${WORKDIR}/tags find /usr/inclu…

排序算法:【冒泡排序】、逻辑运算符not用法、解释if not tag:

注意&#xff1a; 1、排序&#xff1a;将一组无序序列&#xff0c;调整为有序的序列。所谓有序&#xff0c;就是说&#xff0c;要么升序要么降序。 2、列表排序&#xff1a;将无序列表变成有序列表。 3、列表这个类里&#xff0c;内置排序方法&#xff1a;sort( )&#xff0…

大数据讲课笔记1.4 进程管理

文章目录 零、学习目标一、导入新课二、新课讲解&#xff08;一&#xff09;进程概述1、基本概念2、三维度看待进程3、引入多道编程模型&#xff08;1&#xff09;CPU利用率与进程数关系&#xff08;2&#xff09;从三个视角看多进程 4、进程的产生和消亡&#xff08;1&#xf…

平台工程与 DevOps 和 SRE 有何不同?

在现代软件开发和运营的动态领域中 &#xff0c;平台工程、DevOps 和站点可靠性工程 (SRE) 等术语 经常使用&#xff0c;有时可以互换使用&#xff0c;这常常会导致进入或浏览这些领域的专业人员感到困惑。了解这些概念之间的细微差别对于努力构建强大且可扩展的系统的组织至关…

爱智EdgerOS之深入解析安全可靠的开放协议SDDC

一、协议简介 在 EdgerOS 的智慧生态场景中&#xff0c;许多智能设备或传感器的生命周期都与 SDDC 协议息息相关&#xff0c;这些设备可能是使用 libsddc 智能配网技术开发的&#xff0c;也有可能是因为主要功能上是使用其他技术如 MQTT、LoRa 等但是设备的上下线依然是使用上…

构建外卖小程序:技术代码实践

在这个数字化的时代&#xff0c;外卖小程序已经成为餐饮业的一项重要工具。在本文中&#xff0c;我们将通过一些简单而实用的技术代码&#xff0c;向您展示如何构建一个基本的外卖小程序。我们将使用微信小程序平台作为例子&#xff0c;但这些原理同样适用于其他小程序平台。 …

连连看游戏

连通块记忆性递归的综合运用 这里x&#xff0c;y的设置反我平常的习惯&#xff0c;搞得我有点晕 实际上可以一输入就交换x&#xff0c;y的数据的 如果设置y1为全局变量的话会warning&#xff1a; warning: built-in function y1 declared as non-function 所以我改成p和q了…

阿里云人工智能平台PAI多篇论文入选EMNLP 2023

近期&#xff0c;阿里云人工智能平台PAI主导的多篇论文在EMNLP2023上入选。EMNLP是人工智能自然语言处理领域的顶级国际会议&#xff0c;聚焦于自然语言处理技术在各个应用场景的学术研究&#xff0c;尤其重视自然语言处理的实证研究。该会议曾推动了预训练语言模型、文本挖掘、…

Bytebase 2.12.0 - 改进自动补全和布局导航

&#x1f680; 新功能 支持 MySQL 高级自动补全。支持从 UI 上导入分类分级配置。 &#x1f514; 重大变更 作废已有企业版试用证书。之后可以通过提交申请获取新的试用证书。 &#x1f384; 改进 改进整体布局和导航。 支持在 SQL 编辑器里显示以及查询 PostgreSQL 数据…

HCIA-H12-811题目解析(9)

1、【单选题】下面选项中&#xff0c;能使一台IP地址为10.0.0.1的主机访问Interne的必要技术是&#xff1f; 2、【单选题】 FTP协议控制平面使用的端口号为&#xff1f; 3、【单选题】 使用FTP进行文件传输时&#xff0c;会建立多少个TCP连接&#xff1f; 4、【单选题】完成…

【算法Hot100系列】寻找两个正序数组的中位数

&#x1f49d;&#x1f49d;&#x1f49d;欢迎来到我的博客&#xff0c;很高兴能够在这里和您见面&#xff01;希望您在这里可以感受到一份轻松愉快的氛围&#xff0c;不仅可以获得有趣的内容和知识&#xff0c;也可以畅所欲言、分享您的想法和见解。 推荐:kwan 的首页,持续学…

WordPress主题Lolimeow v8.0.1二次元风格支持erphpdown付费下载

WordPress国人原创动漫主题lolimeow免费下载 lolimeow是一款WordPress国人原创主题&#xff0c;风格属于二次元、动漫、可爱萝莉风&#xff0c;带有后台设置&#xff0c;支持会员中心。该主题为免费主题。 1.侧栏/无侧栏切换&#xff01; 2.会员中心&#xff08;配套Erphpdown…

JVM 详解(JVM组成部分、双亲委派机制、垃圾回收算法、回收器、回收类型、了解调优思路)

目录 JVM 详解&#xff08;JVM组成部分、双亲委派机制、垃圾回收算法、回收器、回收类型、了解调优思路&#xff09;1、概念&#xff1a;什么是 JVM ?JVM 的作用&#xff1f; 2、JVM 的主要组成部分&#xff1f;类加载器&#xff08;Class Loader&#xff09;&#xff1a;简单…

Go实现http同步文件操作 - 增删改查

http同步文件操作 - 增删改查 http同步文件操作 - 增删改查1. 前置要求1.1. 构建结构体 文件名 文件内容1.1.1. 页面结构体1.1.2. 为Page结构体绑定方法&#xff1a;Save1.1.3. 对Page结构体支持页面内容查看方法&#xff0c;同时提供页面文件是否存在的方法 1.2. 简单验证上面…

联想笔记本如何安装Vmware ESXi

环境&#xff1a; Vmware ESXi 8.0 Vmware ESXi 6.7 联想E14笔记本 问题描述&#xff1a; 联想笔记本如何安装Vmware ESXi 解决方案&#xff1a; 1.官网下载镜像文件 https://customerconnect.vmware.com/en/downloads/search?queryesxi%208 下载 2.没有账户注册一个 …

vscode报错:建立连接:XHR failed

文章目录 问题解决方案 问题 Windows端ssh远程连接Linux端&#xff0c;Windows端vscode报错&#xff1a;“…XHR failed.” 解决方案 参考&#xff1a;解决 Windows 端 VS Code “无法与 “…“ 建立连接&#xff1a;XHR failed.” 问题 亲测有效。 总结&#xff1a; linux…

【媒体开发】利用FFMPEG进行推拉流

目录 1. 下载并启动媒体服务 2. 使用 FFMPEG 拉流并推送到指定服务地址 3. 客户端拉流 1. 下载并启动媒体服务 MediaMTX&#xff0c;也即之前的rtsp-simple-server&#xff0c;是一个即用型、零依赖的实时媒体服务器和媒体代理&#xff0c;允许发布、读取、代理和记录视频和…

深度学习第5天:GAN生成对抗网络

☁️主页 Nowl &#x1f525;专栏 《深度学习》 &#x1f4d1;君子坐而论道&#xff0c;少年起而行之 ​​ 文章目录 一、GAN1.基本思想2.用途3.模型架构 二、具体任务与代码1.任务介绍2.导入库函数3.生成器与判别器4.预处理5.模型训练6.图片生成7.不同训练轮次的结果对比 一…

51单片机的外部中断的以及相关寄存器的讲解

中断系统 本文主要涉及8051单片机的中断系统的讲解与使用 其中包括中断相关寄存器的介绍与使用以及外部中断初始化的代码分析。 文章目录 中断系统一、 中断的介绍二、 中断结构及相关寄存器2.1 中断源 2.2 中断请求控制器2.2.1 TCON寄存器2.2.2 SCON寄存器2.2.3 中断允许寄存器…

关于“Python”的核心知识点整理大全21

9.3.2 Python 2.7 中的继承 在Python 2.7中&#xff0c;继承语法稍有不同&#xff0c;ElectricCar类的定义类似于下面这样&#xff1a; class Car(object):def __init__(self, make, model, year):--snip-- class ElectricCar(Car):def __init__(self, make, model, year):supe…