生成对抗网络(GAN)详解与实例

GAN介绍

理解GAN的直观方法是从博弈论的角度来理解它。GAN由两个参与者组成,即一个生成器和一个判别器,它们都试图击败对方。生成备从分巾中狄取一些随机噪声,并试图从中生成一些类似于输出的分布。生成器总是试图创建与真实分布没有区别的分布。也就是说,伪造的输出看起来应该是真实的图像。 然而,如果没有显式训练或标注,那么生成器将无法判别真实的图像,并且其唯一的来源就是随机浮点数的张量。

之后,GAN将在博弈中引入另一个参与者,即判别器。判别器仅负责通知生成器其生成的输出看起来不像真实图像,以便生成器更改其生成图像的方式以使判别器确信它是真实图像。 但是判别器总是可以告诉生成器其生成的图像不是真实的,因为判别器知道图像是从生成器生成的。为了解决这个事情,GAN将真实的图像引入博弈中,并将判别器与生成器隔离。现在,判别器从一组真实图像中获取一个图像,并从生成器中获取一个伪图像,而它必须找出每个图像的来源。

最初,判别器什么都不知道,而是随机预测结果。 但是,可以将判别器的任务修改为分类任务。判别器可以将输入图像分类为原始图像或生成图像,这是二元分类。同样,我们训练判别器网络以正确地对图像进行分类,最终,通过反向传播,判别器学会了区分真实图像和生成图像。

在这里插入图片描述

代码实例

数据集简介:
本次实验我们选用花卉数据集做图像的生成,本数据集共六类。
在这里插入图片描述

模型训练
训练判别器:
对于真图片,输出尽可能是1
对于假图片,输出尽可能是0
训练生成器:
对于假图片,输出尽可能是1
1、训练生成器时,无须调整判别器的参数;训练判别器时,无须调整生成器的参数。
2、在训练判别器时,需要对生成器生成的图片用detach操作进行计算图截断,避免反向传播将梯度传到生成器中。因为在训练判别器时我们不需要训练生成器,也就不需要生成器的梯度。
3、在训练判别器时,需要反向传播两次,一次是希望把真图片判为1,一次是希望把假图片判为0。也可以将这两者的数据放到一个batch中,进行一次前向传播和一次反向传播即可。
4、对于假图片,在训练判别器时,我们希望它输出0;而在训练生成器时,我们希望它输出1.因此可以看到一对看似矛盾的代码 error_d_fake = criterion(output, fake_labels)和error_g = criterion(output, true_labels)。判别器希望能够把假图片判别为fake_label,而生成器则希望能把他判别为true_label,判别器和生成器互相对抗提升。

import os
import torch
from torch.utils.data import Dataset, DataLoader
from dataloader import MyDataset
from model import Generator, Discriminator
import torchvision
import numpy as np
import matplotlib.pyplot as plt
if __name__ == '__main__':LR = 0.0002EPOCH = 1000  # 50BATCH_SIZE = 40N_IDEAS = 100EPS = 1e-10TRAINED = False#path = r'./data/image'train_data = MyDataset(path=path, resize=96, Len=10000, img_type='jpg')train_loader = DataLoader(dataset=train_data, batch_size=BATCH_SIZE, shuffle=True)torch.cuda.empty_cache()if TRAINED:G = torch.load('G.pkl').cuda()D = torch.load('D.pkl').cuda()else:G = Generator(N_IDEAS).cuda()D = Discriminator(3).cuda()optimizerG = torch.optim.Adam(G.parameters(), lr=LR)optimizerD = torch.optim.Adam(D.parameters(), lr=LR)for epoch in range(EPOCH):tmpD, tmpG = 0, 0for step, x in enumerate(train_loader):x = x.cuda()rand_noise = torch.randn((x.shape[0], N_IDEAS, 1, 1)).cuda()G_imgs = G(rand_noise)D_fake_probs = D(G_imgs)D_real_probs = D(x)p_d_fake = torch.squeeze(D_fake_probs)p_d_real = torch.squeeze(D_real_probs)D_loss = -torch.mean(torch.log(p_d_real + EPS) + torch.log(1. - p_d_fake + EPS))G_loss = -torch.mean(torch.log(p_d_fake + EPS))# D_loss = -torch.mean(torch.log(D_real_probs) + torch.log(1. - D_fake_probs))# G_loss = torch.mean(torch.log(1. - D_fake_probs))optimizerD.zero_grad()D_loss.backward(retain_graph=True)optimizerD.step()optimizerG.zero_grad()G_loss.backward(retain_graph=True)optimizerG.step()tmpD_ = D_loss.cpu().detach().datatmpG_ = G_loss.cpu().detach().datatmpD += tmpD_tmpG += tmpG_tmpD /= (step + 1)tmpG /= (step + 1)print('epoch %d avg of loss: D: %.6f, G: %.6f' % (epoch, tmpD, tmpG))# if (epoch+1) % 5 == 0:select_epoch = [1, 5, 10, 20, 50, 80, 100, 150, 200, 400, 500, 800, 999, 1500, 2000, 3000, 4000, 5000, 6000, 8000, 9999]if epoch in select_epoch:
plt.imshow(np.squeeze(G_imgs[0].cpu().detach().numpy().transpose((1, 2, 0))) * 0.5 + 0.5)plt.savefig('./result1/_%d.png' % epoch)torch.save(G, 'G.pkl')torch.save(D, 'D.pkl')

下面是训练多次的效果
在这里插入图片描述
在这里插入图片描述
在这里插入图片描述
在这里插入图片描述
在这里插入图片描述
在这里插入图片描述
完整代码如下:

# import os
import torch
import torch.nn as nn
import torchvision as tv
from torch.autograd import Variable
import tqdm
import matplotlib.pyplot as plt
plt.rcParams['font.sans-serif'] = ['SimHei']  # 显示中文标签
plt.rcParams['axes.unicode_minus'] = False# dir = '... your path/faces/'
dir = './data/train_data'
# path = []
#
# for fileName in os.listdir(dir):
#     path.append(fileName)       # len(path)=51223noiseSize = 100     # 噪声维度
n_generator_feature = 64        # 生成器feature map数
n_discriminator_feature = 64        # 判别器feature map数
batch_size = 50
d_every = 1     # 每一个batch训练一次discriminator
g_every = 5     # 每五个batch训练一次generatorclass NetGenerator(nn.Module):def __init__(self):super(NetGenerator,self).__init__()self.main = nn.Sequential(      # 神经网络模块将按照在传入构造器的顺序依次被添加到计算图中执行nn.ConvTranspose2d(noiseSize, n_generator_feature * 8, kernel_size=4, stride=1, padding=0, bias=False),nn.BatchNorm2d(n_generator_feature * 8),nn.ReLU(True),       # (n_generator_feature * 8) × 4 × 4        (1-1)*1+1*(4-1)+0+1 = 4nn.ConvTranspose2d(n_generator_feature * 8, n_generator_feature * 4, kernel_size=4, stride=2, padding=1, bias=False),nn.BatchNorm2d(n_generator_feature * 4),nn.ReLU(True),      # (n_generator_feature * 4) × 8 × 8     (4-1)*2-2*1+1*(4-1)+0+1 = 8nn.ConvTranspose2d(n_generator_feature * 4, n_generator_feature * 2, kernel_size=4, stride=2, padding=1, bias=False),nn.BatchNorm2d(n_generator_feature * 2),nn.ReLU(True),  # (n_generator_feature * 2) × 16 × 16nn.ConvTranspose2d(n_generator_feature * 2, n_generator_feature, kernel_size=4, stride=2, padding=1, bias=False),nn.BatchNorm2d(n_generator_feature),nn.ReLU(True),      # (n_generator_feature) × 32 × 32nn.ConvTranspose2d(n_generator_feature, 3, kernel_size=5, stride=3, padding=1, bias=False),nn.Tanh()       # 3 * 96 * 96)def forward(self, input):return self.main(input)class NetDiscriminator(nn.Module):def __init__(self):super(NetDiscriminator,self).__init__()self.main = nn.Sequential(nn.Conv2d(3, n_discriminator_feature, kernel_size=5, stride=3, padding=1, bias=False),nn.LeakyReLU(0.2, inplace=True),        # n_discriminator_feature * 32 * 32nn.Conv2d(n_discriminator_feature, n_discriminator_feature * 2, kernel_size=4, stride=2, padding=1, bias=False),nn.BatchNorm2d(n_discriminator_feature * 2),nn.LeakyReLU(0.2, inplace=True),         # (n_discriminator_feature*2) * 16 * 16nn.Conv2d(n_discriminator_feature * 2, n_discriminator_feature * 4, kernel_size=4, stride=2, padding=1, bias=False),nn.BatchNorm2d(n_discriminator_feature * 4),nn.LeakyReLU(0.2, inplace=True),  # (n_discriminator_feature*4) * 8 * 8nn.Conv2d(n_discriminator_feature * 4, n_discriminator_feature * 8, kernel_size=4, stride=2, padding=1, bias=False),nn.BatchNorm2d(n_discriminator_feature * 8),nn.LeakyReLU(0.2, inplace=True),  # (n_discriminator_feature*8) * 4 * 4nn.Conv2d(n_discriminator_feature * 8, 1, kernel_size=4, stride=1, padding=0, bias=False),nn.Sigmoid()        # 输出一个概率)def forward(self, input):return self.main(input).view(-1)def train():for i, (image,_) in tqdm.tqdm(enumerate(dataloader)):       # type((image,_)) = <class 'list'>, len((image,_)) = 2 * 256 * 3 * 96 * 96real_image = Variable(image)real_image = real_image.cuda()if (i + 1) % d_every == 0:optimizer_d.zero_grad()output = Discriminator(real_image)      # 尽可能把真图片判为Trueerror_d_real = criterion(output, true_labels)error_d_real.backward()noises.data.copy_(torch.randn(batch_size, noiseSize, 1, 1))fake_img = Generator(noises).detach()       # 根据噪声生成假图fake_output = Discriminator(fake_img)       # 尽可能把假图片判为Falseerror_d_fake = criterion(fake_output, fake_labels)error_d_fake.backward()optimizer_d.step()if (i + 1) % g_every == 0:optimizer_g.zero_grad()noises.data.copy_(torch.randn(batch_size, noiseSize, 1, 1))fake_img = Generator(noises)        # 这里没有detachfake_output = Discriminator(fake_img)       # 尽可能让Discriminator把假图片判为Trueerror_g = criterion(fake_output, true_labels)error_g.backward()optimizer_g.step()def show(num):fix_fake_imags = Generator(fix_noises)fix_fake_imags = fix_fake_imags.data.cpu()[:64] * 0.5 + 0.5# x = torch.rand(64, 3, 96, 96)fig = plt.figure(1)i = 1for image in fix_fake_imags:ax = fig.add_subplot(8, 8, eval('%d' % i))# plt.xticks([]), plt.yticks([])  # 去除坐标轴plt.axis('off')plt.imshow(image.permute(1, 2, 0))i += 1plt.subplots_adjust(left=None,  # the left side of the subplots of the figureright=None,  # the right side of the subplots of the figurebottom=None,  # the bottom of the subplots of the figuretop=None,  # the top of the subplots of the figurewspace=0.05,  # the amount of width reserved for blank space between subplotshspace=0.05)  # the amount of height reserved for white space between subplots)plt.suptitle('第%d迭代结果' % num, y=0.91, fontsize=15)plt.savefig("images/%dcgan.png" % num)if __name__ == '__main__':transform = tv.transforms.Compose([tv.transforms.Resize(96),     # 图片尺寸, transforms.Scale transform is deprecatedtv.transforms.CenterCrop(96),tv.transforms.ToTensor(),tv.transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))       # 变成[-1,1]的数])dataset = tv.datasets.ImageFolder(dir, transform=transform)dataloader = torch.utils.data.DataLoader(dataset, batch_size=batch_size, shuffle=True, num_workers=4, drop_last=True)   # module 'torch.utils.data' has no attribute 'DataLoder'print('数据加载完毕!')Generator = NetGenerator()Discriminator = NetDiscriminator()optimizer_g = torch.optim.Adam(Generator.parameters(), lr=2e-4, betas=(0.5, 0.999))optimizer_d = torch.optim.Adam(Discriminator.parameters(), lr=2e-4, betas=(0.5, 0.999))criterion = torch.nn.BCELoss()true_labels = Variable(torch.ones(batch_size))     # batch_sizefake_labels = Variable(torch.zeros(batch_size))fix_noises = Variable(torch.randn(batch_size, noiseSize, 1, 1))noises = Variable(torch.randn(batch_size, noiseSize, 1, 1))     # 均值为0,方差为1的正态分布if torch.cuda.is_available() == True:print('Cuda is available!')Generator.cuda()Discriminator.cuda()criterion.cuda()true_labels, fake_labels = true_labels.cuda(), fake_labels.cuda()fix_noises, noises = fix_noises.cuda(), noises.cuda()plot_epoch = [1,5,10,50,100,200,500,800,1000,1500,2000,2500,3000]for i in range(3000):        # 最大迭代次数train()print('迭代次数:{}'.format(i))if i in plot_epoch:show(i)

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.rhkb.cn/news/60841.html

如若内容造成侵权/违法违规/事实不符,请联系长河编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

深度对抗神经网络(DANN)笔记

一 总体介绍 DANN是一种迁移学习方法&#xff0c;是对抗迁移学习方法的代表方法。基本结构由特征提取层f,分类器部分c和对抗部分d组成&#xff0c;其中f和c其实就是一个标准的分类模型&#xff0c;通过GAN&#xff08;生成对抗网络&#xff09;得到迁移对抗模型的灵感。但此时…

机器学习:BP神经网络,CNN卷积神经网络,GAN生成对抗网络

1&#xff0c;基础知识 1.1&#xff0c;概述 机器学习&#xff1a;概念_燕双嘤-CSDN博客1&#xff0c;机器学习概述1.1&#xff0c;机器学习概念机器学习即Machine Learning&#xff0c;涉及概率论、统计学、逼近论、凸分析、算法复杂度理论等多门学科。目的是让计算机模拟或实…

生成对抗网络(GAN)简单梳理

作者&#xff1a;xg123321123 - 时光杂货店 出处&#xff1a;http://blog.csdn.net/xg123321123/article/details/78034859 声明&#xff1a;版权所有&#xff0c;转载请联系作者并注明出处 网上已经贴满了关于GAN的博客&#xff0c;写这篇帖子只是梳理下思路&#xff0c;以便以…

生成对抗网络(GAN)简介以及Python实现

本篇博客简单介绍了生成对抗网络(Generative Adversarial Networks,GAN)&#xff0c;并基于Keras实现深度卷积生成对抗网络(DCGAN)。 以往的生成模型都是预先假设生成样本服从某一分布族&#xff0c;然后用深度网络学习分布族的参数&#xff0c;最后从学习到的分布中采样生成新…

生成对抗网络(GAN)教程 - 多图详解

一.生成对抗网络简介 1.生成对抗网络模型主要包括两部分&#xff1a;生成模型和判别模型。 生成模型是指我们可以根据任务、通过模型训练由输入的数据生成文字、图像、视频等数据。 [1]比如RNN部分讲的用于生成奥巴马演讲稿的RNN模型,通过输入开头词就能生成下来。…

对抗神经网络学习和实现(GAN)

一&#xff0c;GAN的原理介绍 \quad GAN的基本原理其实非常简单&#xff0c;这里以生成图片为例进行说明。假设我们有两个网络&#xff0c;G&#xff08;Generator&#xff09;和D&#xff08;Discriminator&#xff09;。正如它的名字所暗示的那样&#xff0c;它们的功能分别是…

生成对抗网络(GAN)

1 GAN基本概念 1.1 如何通俗理解GAN&#xff1f; ​ 生成对抗网络(GAN, Generative adversarial network)自从2014年被Ian Goodfellow提出以来&#xff0c;掀起来了一股研究热潮。GAN由生成器和判别器组成&#xff0c;生成器负责生成样本&#xff0c;判别器负责判断生成器生成…

基于图神经网络的对抗攻击 Nettack: Adversarial Attacks on Neural Networks for Graph Data

研究意义 随着GNN的应用越来越广&#xff0c;在安全非常重要的应用中应用GNN&#xff0c;存在漏洞可能是非常严重的。 比如说金融系统和风险管理&#xff0c;在信用评分系统中&#xff0c;欺诈者可以伪造与几个高信用客户的联系&#xff0c;以逃避欺诈检测模型&#xff1b;或者…

生成对抗网络(Generative Adversial Network,GAN)原理简介

生成对抗网络(GAN)是深度学习中一类比较大的家族&#xff0c;主要功能是实现图像、音乐或文本等生成(或者说是创作)&#xff0c;生成对抗网络的主要思想是&#xff1a;通过生成器(generator)与判别器(discriminator)不断对抗进行训练。最终使得判别器难以分辨生成器生成的数据(…

快讯|莫言用 ChatGPT 写《颁奖辞》;特斯拉人形机器人集体出街!已与FSD算法打通

一分钟速览新闻点 言用 ChatGPT 写《颁奖辞》孙其君研究员团队 Adv. Funct. Mater.&#xff1a;多功能离子凝胶纤维膜用于能量离电皮肤微软CEO反驳马斯克&#xff1a;我们没有控制OpenAI特斯拉人形机器人集体出街&#xff01;已与FSD算法打通微软CEO称小型公司仍可在人工智能领…

从供应链到价值链:人形机器人产业链深入研究

原创 | 文 BFT机器人 01 人形机器人产业进展&#xff1a;AI赋能&#xff0c;人形机器人迭代有望加速 目前人形机器人产业所处从“0”到“1”的萌芽期&#xff0c;从现在到未来的时间里&#xff0c;人形机器人以其仿人外形、身体构成及其智能大脑&#xff0c;能极大解放生产力、…

人形机器人火出圈!OpenAI领投挪威人形机器人公司“1X”

文&#xff5c;牛逼的AI 编&#xff5c;猫猫咪子 源&#xff5c;AI源起 目前已经实现了对接GPT4技术的机器人Ameca&#xff0c;它拥有逼真的外观和丰富的表情。 随着人形机器人技术的飞速发展&#xff0c;未来可能不再适用阿西莫夫所提出的“机器人三定律”&#xff0c;因为超级…

chatgpt赋能python:Python登录界面制作指南

Python登录界面制作指南 介绍 登录界面是许多应用程序的关键组成部分之一。Python作为一种优秀的编程语言&#xff0c;拥有着强大的界面开发框架&#xff0c;能够帮助开发人员更轻松地创作出完美的登录界面。 在本文中&#xff0c;我们将向您介绍使用Python如何制作一个简单…

The Journal of Neuroscience: 珠心算训练有助于提高儿童的视觉空间工作记忆

《本文同步发布于“脑之说”微信公众号&#xff0c;欢迎搜索关注~~》 珠心算是指个体在熟练进行珠算操作后&#xff0c;可摆脱实际算盘&#xff0c;借助大脑中虚拟算盘进行数字计算的方式&#xff08;图1&#xff09;。早期行为学研究表明&#xff0c;珠心算个体的数字计算能力…

php珠心算源码,深度解析珠心算的“开智”功能

编者按&#xff1a;本文来自李绵军校长在廊坊智慧特训营演讲。李绵军校长通过十几年来对珠心算的钻研练习&#xff0c;详细解读了珠心算的开智功能&#xff0c;以及“一门深入”的作用。 珠心算的开智价值是在哪里&#xff1f;大家都说开发智力&#xff0c;我在这讲开发智力不是…

php珠心算源码,NOIP201401珠心算测验

珠心算测验 问题描述】 珠心算是一种通过在脑中模拟算盘变化来完成快速运算的一种计算技术。珠心算训练&#xff0c;既能够开发智力&#xff0c;又能够为日常生活带来很多便利&#xff0c;因而在很多学校得到普及。 某学校的珠心算老师采用一种快速考察珠心算加法能力的测验方…

雨课堂提交作业步骤 10步帮你弄好

1 2 3 4 5 6 7 中间计数记错了… 8 9 10 弹出对话框&#xff0c;点击确认即可 提交成功的截图&#xff1a;

2022李宏毅作业hw4 - 挫败感十足的一次作业。

系列文章&#xff1a; 2022李宏毅作业hw1—新冠阳性人员数量预测。_亮子李的博客-CSDN博客_李宏毅hw1 hw-2 李宏毅2022年作业2 phoneme识别 单strong-hmm详细解释。_亮子李的博客-CSDN博客_李宏毅hw2 2021李宏毅作业hw3 --食物分类。对比出来的80准确率。_亮子李的博客-CSDN博客…

php老师的一个作业展示

1.在ScanCode.php中 在judgeTrayCodeEnableIntoWarehouse方法中: 2在CommitTrayCodes.php中 访问时&#xff0c;直接访问CommitTrayCodes.php&#xff0c;这个CommitTrayCodes.php是要建在controller下&#xff0c;建议导一下&#xff0c;老师的数据库&#xff0c;以防止命名…

HCIA网络课程第七周作业

&#xff08;1&#xff09;请用自己的语言描述基本ACL和高级ACL的区别 &#xff08;2&#xff09;AAA支持的认证、授权和计费方式分别有哪几种&#xff1f; AAA支持的认证方式有不认证 本地认证 远端认证AAA支持的授权方式为不授权 本地授权 远端授权AAA支持计费方式为不计费…