使用pytorch构建一个初级的无监督的GAN网络模型

在这个系列中将系统的构建GAN及其相关的一些变种模型,来了解GAN的基本原理。本片为此系列的第一篇,实现起来很简单,所以不要期待有很好的效果出来。

第一篇我们搭建一个无监督的可以生成数字 (0-9) 手写图像的 GAN,使用MINIST数据集,包含0-9的60000张手写数字图像,如图:
在这里插入图片描述

原理

首先简单讲一下GAN的工作原理,如下为前向传播的过程:
在这里插入图片描述
GAN网络有两个模型,分别是生成器generator和判别器discriminator。generator的作用是生成图片的,也就是我们想要的结果,通过输入随机噪声来生成图片;而discriminator是判断输入的图片是真实数据还是生成的假数据,输入生成的假数据或真实数据,输出真与假的概率值。

而反向传播过程其实是分开的,即generator和discriminator是分别进行梯度更新的。且交替进行训练的,一个模型训练,另一个模型就要保持不变,保持两个模型的能力要相当才能一起进步,否则如果判别器的性能要比生成器要好的话就很容易陷入模式崩溃mdoel collapse或梯度消失等。
下图为discriminator的反向传播的过程:
在这里插入图片描述
discriminator的工作是为了将生成的假数据判别为0,将真实的数据判别为1,即公正判别非黑即白,所以loss的计算为:
在这里插入图片描述

下图为generator的反向传播的过程:
在这里插入图片描述
而generator的工作是为了将生成的假数据让discriminator判别为1,即骗过discriminator颠倒黑白,所以loss的计算为:
在这里插入图片描述

代码

下面开始直接上代码,我在网上学习别人代码的习惯是先把所有代码跑起来再来仔细看每个代码模块,我在这也就先放上所有代码再分析各个模块。
model.py:

from torch import nn
import torchdef get_generator_block(input_dim, output_dim):return nn.Sequential(nn.Linear(input_dim, output_dim),nn.BatchNorm1d(output_dim),nn.ReLU(inplace=True),)class Generator(nn.Module):def __init__(self, z_dim=10, im_dim=784, hidden_dim=128):super(Generator, self).__init__()self.gen = nn.Sequential(get_generator_block(z_dim, hidden_dim),get_generator_block(hidden_dim, hidden_dim * 2),get_generator_block(hidden_dim * 2, hidden_dim * 4),get_generator_block(hidden_dim * 4, hidden_dim * 8),nn.Linear(hidden_dim * 8, im_dim),nn.Sigmoid())def forward(self, noise):return self.gen(noise)def get_gen(self):return self.gendef get_discriminator_block(input_dim, output_dim):return nn.Sequential(nn.Linear(input_dim, output_dim), #Layer 1nn.LeakyReLU(0.2, inplace=True))class Discriminator(nn.Module):def __init__(self, im_dim=784, hidden_dim=128):super(Discriminator, self).__init__()self.disc = nn.Sequential(get_discriminator_block(im_dim, hidden_dim * 4),get_discriminator_block(hidden_dim * 4, hidden_dim * 2),get_discriminator_block(hidden_dim * 2, hidden_dim),nn.Linear(hidden_dim, 1))def forward(self, image):return self.disc(image)def get_disc(self):return self.disc

train.py:

import torch
from torch import nn
from tqdm.auto import tqdm
from torchvision import transforms
from torchvision.datasets import MNIST # Training dataset
from torchvision.utils import make_grid
from torch.utils.data import DataLoader
import matplotlib.pyplot as plt
from model import Discriminator, Generator
torch.manual_seed(0) # Set for testing purposes, please do not change!def show_tensor_images(image_tensor, num_images=25, size=(1, 28, 28)):image_unflat = image_tensor.detach().cpu().view(-1, *size)image_grid = make_grid(image_unflat[:num_images], nrow=5)plt.imshow(image_grid.permute(1, 2, 0).squeeze())plt.show()def get_noise(n_samples, z_dim, device='cpu'):return torch.randn(n_samples,z_dim,device=device)criterion = nn.BCEWithLogitsLoss()
n_epochs = 200
z_dim = 64
display_step = 500
batch_size = 128
lr = 0.00001
device = 'cuda'dataloader = DataLoader(MNIST('./', download=True, transform=transforms.ToTensor()),  # 已经下载过的可以改为False跳过下载batch_size=batch_size,shuffle=True)gen = Generator(z_dim).to(device)
gen_opt = torch.optim.Adam(gen.parameters(), lr=lr)
disc = Discriminator().to(device)
disc_opt = torch.optim.Adam(disc.parameters(), lr=lr)def get_disc_loss(gen, disc, criterion, real, num_images, z_dim, device):fake_noise = get_noise(num_images, z_dim, device=device)fake = gen(fake_noise)disc_fake_pred = disc(fake.detach())disc_fake_loss = criterion(disc_fake_pred, torch.zeros_like(disc_fake_pred))disc_real_pred = disc(real)disc_real_loss = criterion(disc_real_pred, torch.ones_like(disc_real_pred))disc_loss = (disc_fake_loss + disc_real_loss) / 2return disc_lossdef get_gen_loss(gen, disc, criterion, num_images, z_dim, device):fake_noise = get_noise(num_images, z_dim, device=device)fake = gen(fake_noise)disc_fake_pred = disc(fake)gen_loss = criterion(disc_fake_pred, torch.ones_like(disc_fake_pred))return gen_losscur_step = 0
mean_generator_loss = 0
mean_discriminator_loss = 0
gen_loss = False
error = False
for epoch in range(n_epochs):# Dataloader returns the batchesfor real, _ in tqdm(dataloader):cur_batch_size = len(real)# Flatten the batch of real images from the datasetreal = real.view(cur_batch_size, -1).to(device)### Update discriminator #### Zero out the gradients before backpropagationdisc_opt.zero_grad()# Calculate discriminator lossdisc_loss = get_disc_loss(gen, disc, criterion, real, cur_batch_size, z_dim, device)# Update gradientsdisc_loss.backward(retain_graph=True)# Update optimizerdisc_opt.step()### Update generator ###gen_opt.zero_grad()gen_loss = get_gen_loss(gen, disc, criterion, cur_batch_size, z_dim, device)gen_loss.backward()gen_opt.step()# Keep track of the average discriminator lossmean_discriminator_loss += disc_loss.item() / display_step# Keep track of the average generator lossmean_generator_loss += gen_loss.item() / display_step### Visualization code ###if cur_step % display_step == 0 and cur_step > 0:print(f"Step {cur_step}: Generator loss: {mean_generator_loss}, discriminator loss: {mean_discriminator_loss}")fake_noise = get_noise(cur_batch_size, z_dim, device=device)fake = gen(fake_noise)show_tensor_images(fake)show_tensor_images(real)mean_generator_loss = 0mean_discriminator_loss = 0cur_step += 1

运行结果

运行后每隔500个epoch画出fake和real,刚开始的fake和real是这样的:
在这里插入图片描述
在这里插入图片描述
到后面的fake逐渐变成这样:
在这里插入图片描述

代码解释

网络模型

model.py里面存放了generator和discriminator的网络模型,神经元使用的是简单的全连接层,后面的文章再使用卷积。
在这里插入图片描述
生成器输出为784 = 28 * 28,因为使用的是MINIST手写字体数据集,每张图的shape是28 * 28 * 1(黑白图单通道),所以输出的假数据要与真实数据的shape一致,这样输入鉴别器才不会出错。
在这里插入图片描述
生成的图片(或真实数据)直接输入鉴别器,所以鉴别器的输入也是28*28,而输出为1,即输出判别结果为真或假。
在这里插入图片描述
每个优化器仅优化一个模型的参数,所以一个模型构建一个优化器。

图像显示

在这里插入图片描述
首先将图像的tensor转到cpu上,因为PyTorch中的大部分图像处理和显示函数都是在CPU上执行的,包括我们使用的imshow。
detach() 方法将张量从计算图中分离出来,但是仍指向原变量的存放位置,不同之处只是requirse_grad为false,得到的这个tensor永远不需要计算器梯度,不具有grad,这样做的目的是避免梯度计算的影响,因为在展示图像时通常不需要计算梯度。
Pytorch的计算图由节点和边组成,节点表示张量或者Function,边表示张量和Function之间的依赖关系,类似这样:
在这里插入图片描述
一个网络模型就是一个计算图,在网络backward时候,需要用链式求导法则求出网络最后输出的梯度,然后再对网络进行优化,求导过程就如上图这样。
make_grid 函数用于将多个图像组成一个网格,方便显示。
在这里插入图片描述

在这里插入图片描述
然后每500个batch显示一次当前模型性能所能生成的图片以及当前batch的真实图片(虽然一个batch设置了128张,但是我们只展示25张),以及print出生成器和鉴别器的loss。

损失函数

在这里插入图片描述
损失函数的原理在上面的“原理”中有讲解,这里不再赘述。
在计算鉴别器的loss里,disc_fake_pred = disc(fake.detach())是对生成图片的判别,这里也使用 .detach() 的目的是将生成器产生的假数据与生成器的参数分离,使得在计算 disc_fake_pred 时不会对生成器的梯度进行传播。这是因为在训练鉴别器的阶段,我们只希望更新鉴别器的参数,而不希望更新生成器的参数(就如上面说的生成器的训练和鉴别器的应该要隔开分别训练、交替训练)。

反向传播

在这里插入图片描述
retain_graph=True参数是用来指示 PyTorch 在反向传播时保留计算图。这个参数的作用是为了在一次反向传播之后保留计算图的状态,以便后续再次调用 backward() 函数时能够继续使用这个计算图进行梯度计算。
Pytoch构建的计算图是动态图,为了节约内存,所以每次一轮迭代完之后计算图就被在内存释放,所以当你想要多次backward时候就会报如下错:

RuntimeError: Trying to backward through the graph a second time, but the buffers have already been freed.

而在GAN中一次的迭代需要先更新鉴别器的参数,然后再更新生成器的参数;在更新生成器的参数时,我们仍然需要使用鉴别器来鉴别real or fake,只要使用到鉴别器就需要他的计算图。因此,我们需要在调用 disc_loss.backward() 后保留计算图,以便后续调用 gen_loss.backward() 时能够继续使用相同的计算图进行梯度计算。而对于生成器的梯度更新 gen_loss.backward(),不需要显式指定 retain_graph=True。
所以,在同一个计算图上多次调用 backward() 函数时才需要使用它。

下一篇构建DCGAN。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.rhkb.cn/news/296818.html

如若内容造成侵权/违法违规/事实不符,请联系长河编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

BugKu:Simple SSTI

1.进入此题 2.查看源代码 可以知道要传入一个名为flag的参数,又说我们经常设置一个secret_key 3.flask模版注入 /?flag{{config.SECRET_KEY}} 4.学有所思 4.1 什么是flask? flask是用python编写的一个轻量web开发框架 4.2 SSTI成因(SST…

RIP协议(路由信息协议)

一、RIP协议概述 RIP协议(Routing Information Protocol,路由信息协议)是一种基于距离矢量的内部网关协议,即根据跳数来度量路由开销,进行路由选择。 相比于其它路由协议(如OSPF、ISIS等)&#…

Spring Cloud微服务入门(二)

微服务的技术栈 服务治理: 服务注册、发现、调用。 负载均衡: 高可用、集群部署。 容错: 避免雪崩、削峰、服务降级。 消息总线: 消息队列、异步通信,数据一致性。 网关: 校验路径、请求转发、服务集成…

使用Python获取红某书笔记详情并批量无水印下载

根据红某手最新版 请求接口必须要携带x-s x-s-c x-t,而调用官方接口又必须携带cookie,缺一不可,获取笔记详情可以通过爬取网页的形式获取,虽然也是无水印,但是一些详情信息只能获取大概,并不是详细的数值,因此既不想自己破解x-s x…

UE5 C++ LevelSequence

前言 最近在用UE C做一些功能,用到了Level Sequence功能,但是看了下UE官方论坛包括一些文章基本没有关于C 处理Level Sequence 这块内容,有的也是一些修改或者源码原理的一些内容分析,接下来我就把我新建Sequence包括一些库的调用…

# 达梦数据库知识点

达梦数据库知识点 测试数据 -- SYSDBA.TABLE_CLASS_TEST definitionCREATE TABLE SYSDBA.TABLE_CLASS_TEST (ID VARCHAR(100) NOT NULL,NAME VARCHAR(100) NULL,CODE VARCHAR(100) NULL,TITLE VARCHAR(100) NULL,CREATETIME TIMESTAMP NULL,COLUMN1 VARCHAR(100) NULL,COLUMN…

云容器引擎CCE弹性伸缩

CCE弹性伸缩介绍 CCE的弹性伸缩能力分为如下两个维度: 工作负载弹性伸缩:即调度层弹性,主要是负责修改负载的调度容量变化。例如,HPA是典型的调度层弹性组件,通过HPA可以调整应用的副本数,调整的副本数会…

ShardingJdbc+Mybatis实现多数据源

Mybatis多数据源 这个是对shardingjdbc应用的一个升级,如果对于shardingjdbc的整合还没看过之前的文章的,可以先看看文章https://blog.csdn.net/Think_and_work/article/details/137174049?spm1001.2014.3001.5501 整合步骤 1、依赖 和全新项目的单…

翻译: 硅谷软件工程师面试:准备所需的一切

没有人有时间去做成百上千道LeetCode题目,好消息是你实际上并不需要做那么多题目就能够在FAANG公司找到工作! 我曾经在Grab工作,这是东南亚的一家共享出行公司,但我对工作感到沮丧,想要进入FAANG公司,但我…

Linux------一篇博客了解Linux最常用的指令

🎈个人主页:靓仔很忙i 💻B 站主页:👉B站👈 🎉欢迎 👍点赞✍评论⭐收藏 🤗收录专栏:Linux 🤝希望本文对您有所裨益,如有不足之处&#…

LeetCode-199. 二叉树的右视图【树 深度优先搜索 广度优先搜索 二叉树】

LeetCode-199. 二叉树的右视图【树 深度优先搜索 广度优先搜索 二叉树】 题目描述:解题思路一:广度优先搜索解题思路二:深度优先搜索解题思路三:0 题目描述: 给定一个二叉树的 根节点 root,想象自己站在它…

基于顺序表的学生成绩管理系统

🌈 个人主页:白子寰 🔥 分类专栏:python从入门到精通,魔法指针,进阶C,C语言,C语言题集,C语言实现游戏👈 希望得到您的订阅和支持~ 💡 坚持创作博文…

Android 自定义View 测量控件宽高、自定义viewgroup测量

1、View生命周期以及View层级 1.1、View生命周期 View的主要生命周期如下所示, 包括创建、测量(onMeasure)、布局(onLayout)、绘制(onDraw)以及销毁等流程。 自定义View主要涉及到onMeasure、…

设置asp.net core WebApi函数请求参数可空的两种方式

以下面定义的asp.net core WebApi函数为例,客户端发送申请时,默认三个参数均为必填项,不填会报错,如下图所示: [HttpGet] public string GetSpecifyValue(string param1,string param2,string param3) {return $"…

LeetCode-108. 将有序数组转换为二叉搜索树【树 二叉搜索树 数组 分治 二叉树】

LeetCode-108. 将有序数组转换为二叉搜索树【树 二叉搜索树 数组 分治 二叉树】 题目描述:解题思路一:中序遍历,总是选择中间位置左边的数字作为根节点解题思路二:0解题思路三:0 题目描述: 给你一个整数数…

【软件工程】概要设计

1. 导言 1.1 目的 该文档的目的是描述学生成绩管理系统的概要设计,其主要内容包括: 系统功能简介 系统结构简介 系统接口设计 数据设计 模块设计 界面设计 本文的预期读者是: 项目开发人员 项目管理人员 项目评测人员(…

利用开源AI引擎:打造安全生产作业人员穿戴检测应用平台

在电力行业中,作业人员的安全是至关重要的。为了确保工作人员在进行电力设施操作时的个人安全,需要对作业人员的安全穿戴情况进行严格监控。随着计算视觉技术的发展,特别是图像处理和目标检测技术的进步,我们可以通过自动化的方式…

扫雷(蓝桥杯)

题目描述 小明最近迷上了一款名为《扫雷》的游戏。其中有一个关卡的任务如下, 在一个二维平面上放置着 n 个炸雷,第 i 个炸雷 (xi , yi ,ri) 表示在坐标 (xi , yi) 处存在一个炸雷,它的爆炸范围是以半径为 ri 的一个圆。 为了顺利通过这片土…

链路追踪原理

分布式系统为什么需要链路追踪? 随着互联网业务快速扩展,软件架构也日益变得复杂,为了适应海量用户高并发请求,系统中越来越多的组件开始走向分布式化,如单体架构拆分为微服务、服务内缓存变为分布式缓存、服务组件通…

网络原理 - HTTP / HTTPS(3)——http响应

目录 一、认识 “状态码”(status code) 常见的状态码 (1)200 OK (2)404 Not Found (3)403 ForBidden (4)405 Method Not Allowed (5&…