第G1周:生成对抗网络(GAN)入门

🍨 本文为[🔗365天深度学习训练营]内部限免文章(版权归 *K同学啊* 所有)
🍖 作者:[K同学啊]

一、理论基础
生成对抗网络(Generative Adversarial Networks, GAN)是近年来深度学习领域的一个热点方向。GAN并不指代某一个具体的神经网络,而是指一类基于博弈思想而设计的神经网络。GAN由两个分别被称为生成器(Generator)和判别器(Discriminator)的神经网络组成。其中,生成器从某种噪声分布中随机采样作为输入,输出与训练集中真实样本非常相似的人工样本;判别器的输入则为真实样本或人工样本,其目的是将人工样本与真实样本尽可能地区分出来。生成器和判别器交替运行,相互博弈,各自的能力都得到升。理想情况下,经过足够次数的博弈之后,判别器无法判断给定样本的真实性,即对于所有样本都输出50%真,50%假的判断。此时,生成器输出的人工样本已经逼真到使判别器无法分辨真假,停止博弈。这样就可以得到一个具有“伪造”真实样本能力的生成器。
1. 生成器

GANs中,生成器 G 选取随机噪声 z 作为输入,通过生成器的不断拟合,最终输出一个和真实样本尺寸相同,分布相似的伪造样本G(z)。生成器的本质是一个使用生成式方法的模型,它对数据的分布假设和分布参数进行学习,然后根据学习到的模型重新采样出新的样本。
从数学上来说,生成式方法对于给定的真实数据,首先需要对数据的显式变量或隐含变量做分布假设;然后再将真实数据输入到模型中对变量、参数进行训练;最后得到一个学习后的近似分布,这个分布可以用来生成新的数据。从机器学习的角度来说,模型不会去做分布假设,而是通过不断地学习真实数据,对模型进行修正,最后也可以得到一个学习后的模型来做样本生成任务。这种方法不同于数学方法,学习的过程对人类理解较不直观。

2. 判别器
GANs中,判别器 D 对于输入的样本 x,输出一个[0,1]之间的概率数值D(x)。x 可能是来自于原始数据集中的真实样本 x,也可能是来自于生成器 G 的人工样本G(z)。通常约定,概率值D(x)越接近于1就代表此样本为真实样本的可能性更大;反之概率值越小则此样本为伪造样本的可能性越大。也就是说,这里的判别器是一个二分类的神经网络分类器,目的不是判定输入数据的原始类别,而是区分输入样本的真伪。可以注意到,不管在生成器还是判别器中,样本的类别信息都没有用到,也表明 GAN 是一个无监督的学习过程。

3. 基本原理
GAN是博弈论和机器学习相结合的产物,于2014年Ian Goodfellow的论文中问世,一经问世即火爆足以看出人们对于这种算法的认可和狂热的研究热忱。想要更详细的了解GAN,就要知道它是怎么来的,以及这种算法出现的意义是什么。研究者最初想要通过计算机完成自动生成数据的功能,例如通过训练某种算法模型,让某模型学习过一些苹果的图片后能自动生成苹果的图片,具备些功能的算法即认为具有生成功能。但是GAN不是第一个生成算法,而是以往的生成算法在衡量生成图片和真实图片的差距时采用均方误差作为损失函数,但是研究者发现有时均方误差一样的两张生成图片效果却截然不同,鉴于此不足Ian Goodfellow提出了GAN。

image.png

那么GAN是如何完成生成图片这项功能的呢,如图1所示,GAN是由两个模型组成的:生成模型G和判别模型D。首先第一代生成模型1G的输入是随机噪声z,然后生成模型会生成一张初级照片,训练一代判别模型1D另其进行二分类操作,将生成的图片判别为0,而真实图片判别为1;为了欺瞒一代鉴别器,于是一代生成模型开始优化,然后它进阶成了二代,当它生成的数据成功欺瞒1D时,鉴别模型也会优化更新,进而升级为2D,按照同样的过程也会不断更新出N代的G和D。

二、前期准备工作

1. 定义超参数

import argparse
import os
import numpy as np
import torchvision.transforms as transforms
from torchvision.utils import save_image
from torch.utils.data import DataLoader
from torchvision import datasets
from torch.autograd import Variable
import torch.nn as nn
import torch## 创建文件夹
os.makedirs("./images/", exist_ok=True)         ## 记录训练过程的图片效果
os.makedirs("./save/", exist_ok=True)           ## 训练完成时模型保存的位置
os.makedirs("./datasets/mnist", exist_ok=True)      ## 下载数据集存放的位置## 超参数配置
n_epochs=50
batch_size=512
lr=0.0002
b1=0.5
b2=0.999
n_cpu=2
latent_dim=100
img_size=28
channels=1
sample_interval=500## 图像的尺寸:(1, 28, 28),  和图像的像素面积:(784)
img_shape = (channels, img_size, img_size)
img_area = np.prod(img_shape)## 设置cuda:(cuda:0)
cuda = True if torch.cuda.is_available() else False
print(cuda)
## mnist数据集下载
mnist = datasets.MNIST(root='./datasets/', train=True, download=True, transform=transforms.Compose([transforms.Resize(img_size), transforms.ToTensor(), transforms.Normalize([0.5], [0.5])]),
)
## 配置数据到加载器
dataloader = DataLoader(mnist,batch_size=batch_size,shuffle=True,
)
class Discriminator(nn.Module):def __init__(self):super(Discriminator, self).__init__()self.model = nn.Sequential(nn.Linear(img_area, 512),         # 输入特征数为784,输出为512nn.LeakyReLU(0.2, inplace=True),  # 进行非线性映射nn.Linear(512, 256),              # 输入特征数为512,输出为256nn.LeakyReLU(0.2, inplace=True),  # 进行非线性映射nn.Linear(256, 1),                # 输入特征数为256,输出为1nn.Sigmoid(),                     # sigmoid是一个激活函数,二分类问题中可将实数映射到[0, 1],作为概率值, 多分类用softmax函数)def forward(self, img):img_flat = img.view(img.size(0), -1) # 鉴别器输入是一个被view展开的(784)的一维图像:(64, 784)validity = self.model(img_flat)      # 通过鉴别器网络return validity                      # 鉴别器返回的是一个[0, 1]间的概率
class Generator(nn.Module):def __init__(self):super(Generator, self).__init__()## 模型中间块儿def block(in_feat, out_feat, normalize=True):        # block(in, out )layers = [nn.Linear(in_feat, out_feat)]          # 线性变换将输入映射到out维if normalize:layers.append(nn.BatchNorm1d(out_feat, 0.8)) # 正则化layers.append(nn.LeakyReLU(0.2, inplace=True))   # 非线性激活函数return layers## prod():返回给定轴上的数组元素的乘积:1*28*28=784self.model = nn.Sequential(*block(latent_dim, 128, normalize=False), # 线性变化将输入映射 100 to 128, 正则化, LeakyReLU*block(128, 256),                         # 线性变化将输入映射 128 to 256, 正则化, LeakyReLU*block(256, 512),                         # 线性变化将输入映射 256 to 512, 正则化, LeakyReLU*block(512, 1024),                        # 线性变化将输入映射 512 to 1024, 正则化, LeakyReLUnn.Linear(1024, img_area),                # 线性变化将输入映射 1024 to 784nn.Tanh()                                 # 将(784)的数据每一个都映射到[-1, 1]之间)## view():相当于numpy中的reshape,重新定义矩阵的形状:这里是reshape(64, 1, 28, 28)def forward(self, z):                           # 输入的是(64, 100)的噪声数据imgs = self.model(z)                        # 噪声数据通过生成器模型imgs = imgs.view(imgs.size(0), *img_shape)  # reshape成(64, 1, 28, 28)return imgs                                 # 输出为64张大小为(1, 28, 28)的图像
## 创建生成器,判别器对象
generator = Generator()
discriminator = Discriminator()## 首先需要定义loss的度量方式  (二分类的交叉熵)
criterion = torch.nn.BCELoss()## 其次定义 优化函数,优化函数的学习率为0.0003
## betas:用于计算梯度以及梯度平方的运行平均值的系数
optimizer_G = torch.optim.Adam(generator.parameters(), lr=lr, betas=(b1, b2))
optimizer_D = torch.optim.Adam(discriminator.parameters(), lr=lr, betas=(b1, b2))## 如果有显卡,都在cuda模式中运行
if torch.cuda.is_available():generator     = generator.cuda()discriminator = discriminator.cuda()criterion     = criterion.cuda()
for epoch in range(n_epochs):                   # epoch:50for i, (imgs, _) in enumerate(dataloader):  # imgs:(64, 1, 28, 28)     _:label(64)imgs = imgs.view(imgs.size(0), -1)    # 将图片展开为28*28=784  imgs:(64, 784)real_img = Variable(imgs).cuda()      # 将tensor变成Variable放入计算图中,tensor变成variable之后才能进行反向传播求梯度real_label = Variable(torch.ones(imgs.size(0), 1)).cuda()      ## 定义真实的图片label为1fake_label = Variable(torch.zeros(imgs.size(0), 1)).cuda()     ## 定义假的图片的label为0real_out = discriminator(real_img)            # 将真实图片放入判别器中loss_real_D = criterion(real_out, real_label) # 得到真实图片的lossreal_scores = real_out                        # 得到真实图片的判别值,输出的值越接近1越好## 计算假的图片的损失## detach(): 从当前计算图中分离下来避免梯度传到G,因为G不用更新z = Variable(torch.randn(imgs.size(0), latent_dim)).cuda()      ## 随机生成一些噪声, 大小为(128, 100)fake_img    = generator(z).detach()                                    ## 随机噪声放入生成网络中,生成一张假的图片。fake_out    = discriminator(fake_img)                                  ## 判别器判断假的图片loss_fake_D = criterion(fake_out, fake_label)                       ## 得到假的图片的lossfake_scores = fake_out## 损失函数和优化loss_D = loss_real_D + loss_fake_D  # 损失包括判真损失和判假损失optimizer_D.zero_grad()             # 在反向传播之前,先将梯度归0loss_D.backward()                   # 将误差反向传播optimizer_D.step()                  # 更新参数z = Variable(torch.randn(imgs.size(0), latent_dim)).cuda()      ## 得到随机噪声fake_img = generator(z)                                             ## 随机噪声输入到生成器中,得到一副假的图片output = discriminator(fake_img)                                    ## 经过判别器得到的结果## 损失函数和优化loss_G = criterion(output, real_label)                              ## 得到的假的图片与真实的图片的label的lossoptimizer_G.zero_grad()                                             ## 梯度归0loss_G.backward()                                                   ## 进行反向传播optimizer_G.step()                                                  ## step()一般用在反向传播后面,用于更新生成网络的参数## 打印训练过程中的日志## item():取出单元素张量的元素值并返回该值,保持原元素类型不变if ( i + 1 ) % 100 == 0:print("[Epoch %d/%d] [Batch %d/%d] [D loss: %f] [G loss: %f] [D real: %f] [D fake: %f]"% (epoch, n_epochs, i, len(dataloader), loss_D.item(), loss_G.item(), real_scores.data.mean(), fake_scores.data.mean()))## 保存训练过程中的图像batches_done = epoch * len(dataloader) + iif batches_done % sample_interval == 0:save_image(fake_img.data[:25], "./images/%d.png" % batches_done, nrow=5, normalize=True)
torch.save(generator.state_dict(), './generator.pth')
torch.save(discriminator.state_dict(), './discriminator.pth')

部分运行截图:

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.rhkb.cn/news/90949.html

如若内容造成侵权/违法违规/事实不符,请联系长河编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

adb对安卓app进行抓包(ip连接设备)

adb对安卓app进行抓包(ip连接设备) 一,首先将安卓设备的开发者模式打开,提示允许adb调试 二,自己的笔记本要和安卓设备在同一个网段下(同连一个WiFi就可以了) 三,在笔记本上根据i…

【Git】安装以及基本操作

目录 一、初识Git二、 在Linux底下安装Git一)centOS二)Ubuntu 三、 Git基本操作一) 创建本地仓库二)配置本地仓库三)认识工作区、暂存区、版本库四)添加文件五)查看.git文件六)修改文…

【分布式技术专题】RocketMQ延迟消息实现原理和源码分析

痛点背景 业务场景 假设有这么一个需求,用户下单后如果30分钟未支付,则该订单需要被关闭。你会怎么做? 之前方案 最简单的做法,可以服务端启动个定时器,隔个几秒扫描数据库中待支付的订单,如果(当前时间-订…

RocketMQ双主双从同步集群部署

🎈 作者:互联网-小啊宇 🎈 简介: CSDN 运维领域创作者、阿里云专家博主。目前从事 Kubernetes运维相关工作,擅长Linux系统运维、开源监控软件维护、Kubernetes容器技术、CI/CD持续集成、自动化运维、开源软件部署维护…

Java接口压力测试—如何应对并优化Java接口的压力测试

导言 在如今的互联网时代,Java接口压力测试是评估系统性能和可靠性的关键一环。一旦接口不能承受高并发量,用户体验将受到严重影响,甚至可能导致系统崩溃。因此,了解如何进行有效的Java接口压力测试以及如何优化接口性能至关重要…

Linux系统USB摄像头测试程序(二)_读取配置

1、收先安装gtk3,我的测试机器是ubutn16.04,只要执行下面的安装命令就可以了 apt-get install libgtk-3-dev 使用下列命令验证是否安装好gtk3: pkg-config --cflags --libs gtk-3.0 2、显示结果类似如下: -pthre…

这是一篇关于SQL 脚本表间连接join的可视化说明

使用SQL合并两个数据集可以通过JOINS来完成。JOIN是查询的FROM子句中的SQL指令,用于标识要查询的表以及它们应该如何组合。 主键和外键 通常,在关系数据库中,数据被组织到由属性(列)和记录(行&#xff09…

MySQL运维

日志 错误日志 show VARIABLES like %log_error%;使用 tail -f 错误文件路径 可以查看具体错误二进制日志 show variables like %log_bin%;在my.ini文件下的mysqlID下添加 log_binmysql-bin binlog-formatROW重启就开启binlog了 show VARIABLES like %binlog_format%;mys…

i18n 配置vue项目中英文语言包(中英文转化)

一、实现效果 二、下载插件创建文件夹 2.1 下载cookie来存储 npm install --save js-cookienpm i vue-i18n -S 2.2 封装组件多页面应用 2.3 创建配置语言包字段 三、示例代码 3.1 main.js 引用 i18n.js import i18n from ./lang// 实现语言切换:i18n处理element&#xff0c…

屏蔽socket 实例化时,握手阶段报错信息WebSocket connection to ‘***‘ failed

事情起因是这样的: 我们网站是需要socket链接实行实时推送服务,有恶意竞争对手通过抓包或者断网,获取到了我们的socket链接地址,那么他就可以通过java写一个脚本无限链接这个socket地址。形成dos攻击。使socket服务器资源耗尽&…

【运维】linkis安装dss保姆级教程与踩坑实践

文章目录 一. 安装准备二. 创建用户三. 准备安装包四. 修改配置1. 修改config.sh2. 修改db.sh 五、安装和使用1. 执行安装脚本2. 启动服务3. 查看验证是否成功 六. 报错处理报错一:The user is not logged in报错二:dss接口报错报错三:执行没…

算法随笔:图论问题之割点割边

割点 定义 割点的定义:如果一个点被删除之后会导致整个图不再是一个连通图,那么这个顶点就是这个图的割点。举例: 上图中的点2就是一个割点,如果它被删除,则整个图被分为两个连通分量,不再是一个连通图。…

vue3多条件搜索功能

搜索功能在后台管理页面中非常常见&#xff0c;本篇就着重讲一下vue3-admin-element框架中如何实现一个顶部多条件搜索功能 一、首先需要在vue页面的<template></template>中写入对应的结构 <!-- 搜索 --><div style"display: flex; justify-content…

突破大模型 | Alluxio助力AI大模型训练-成功案例(一)

更多详细内容可见《Alluxio助力AI大模型训练制胜宝典》 【案例一&#xff1a;知乎】多云缓存在知乎的探索:从UnionStore到Alluxio 作者&#xff1a;胡梦宇-知乎大数据基础架构开发工程师&#xff08;内容转载自InfoQ&#xff09; 一、背景 随着云原生技术的飞速发展&#xff…

ApacheCon - 云原生大数据上的 Apache 项目实践

Apache 软件基金会的官方全球系列大会 CommunityOverCode Asia&#xff08;原 ApacheCon Asia&#xff09;首次中国线下峰会将于 2023 年 8 月 18-20 日在北京丽亭华苑酒店举办&#xff0c;大会含 17 个论坛方向、上百个前沿议题。 字节跳动云原生计算团队在此次 CommunityOve…

Java多线程编程中的线程间通信

Java多线程编程中的线程间通信 基本概念&#xff1a; ​ 线程间通信是多线程编程中的一个重要概念&#xff0c;指的是不同线程之间如何协调和交换信息&#xff0c;以达到共同完成任务的目的。 线程间通信的目的 ​ 是确保多个线程能够按照一定的顺序和规则进行协作&#xff…

使用QT可视化设计对话框详细步骤与代码

一、创建对话框基本步骤 创建并初始化子窗口部件把子窗口部件放到布局中设置tab键顺序建立信号-槽之间的连接实现对话框中的自定义槽 首先前面三步在这里是通过ui文件里面直接进行的&#xff0c;剩下两步则是通过代码来实现 二、项目创建详细步骤 创建新项目 为项目命名 为…

计算机组成原理之地址映射

例1&#xff1a;某计算机主存容量256MB&#xff0c;按字编址&#xff0c;字长1B&#xff0c;块大小32B&#xff0c;Cache容量512KB。对如下的直接映射方式、4-路组相联映射方式、全相联映射方式的内存地址格式&#xff0c;求&#xff1a; &#xff08;1&#xff09;计算A、B、C…

案例研究|大福中国通过JumpServer满足等保合规和资产管理双重需求

“大福中国为了满足安全合规要求引入堡垒机产品&#xff0c;在对比了传统型堡垒机后&#xff0c;发现JumpServer使用部署更加灵活&#xff0c;功能特性丰富&#xff0c;能够较好地满足公司在等保合规和资产管理方面的双重需求。” ——大福&#xff08;中国&#xff09;有限公…

网络编程(8.14)TCP并发服务器模型

作业&#xff1a; 1. 多线程中的newfd&#xff0c;能否修改成全局&#xff0c;不行&#xff0c;为什么&#xff1f; 2. 多线程中分支线程的newfd能否不另存&#xff0c;直接用指针间接访问主线程中的newfd,不行&#xff0c;为什么&#xff1f; 多线程并发服务器模型原代码&…