基于PyTorch搭建你的生成对抗性网络

cd0651e77ef4cd23df4cfc0475b1bb07.jpeg

前言

你听说过GANs吗?还是你才刚刚开始学?GANs是2014年由蒙特利尔大学的学生 Ian Goodfellow 博士首次提出的。GANs最常见的例子是生成图像。有一个网站包含了不存在的人的面孔,便是一个常见的GANs应用示例。也是我们将要在本文中进行分享的。

生成对抗网络由两个神经网络组成,生成器和判别器相互竞争。我将在后面详细解释每个步骤。希望在本文结束时,你将能够从零开始训练和建立自己的生财之道对抗性网络。所以闲话少说,让我们开始吧。

目录

步骤0: 导入数据集

步骤1: 加载及预处理图像

步骤2: 定义判别器算法

步骤3: 定义生成器算法

步骤4: 编写训练算法

步骤5: 训练模型

步骤6: 测试模型

步骤0: 导入数据集

第一步是下载并将数据加载到内存中。我们将使用 CelebFaces Attributes Dataset (CelebA)来训练你的对抗性网络。主要分以下三个步骤:

1. 下载数据集:

https://s3.amazonaws.com/video.udacity-data.com/topher/2018/November/5be7eb6f_processed-celeba-small/processed-celeba-small.zip;

2. 解压缩数据集;

3. Clone 如下 GitHub地址:

https://github.com/Ahmad-shaikh575/Face-Generation-using-GANS

这样做之后,你可以在 colab 环境中打开它,或者你可以使用你自己的 pc 来训练模型。

导入必要的库

#import the neccessary libraries
import pickle as pkl
import matplotlib.pyplot as plt
import numpy as np
import torch.nn as nn
import torch.nn.functional as F
import torch
from torchvision import datasets
from torchvision import transforms
import torch
import torch.optim as optim

步骤1: 加载及预处理图像

在这一步中,我们将预处理在前一节中下载的图像数据。

将采取以下步骤:

  1. 调整图片大小

  2. 转换成张量

  3. 加载到 PyTorch 数据集中

  4. 加载到 PyTorch DataLoader 中

# Define hyperparameters
batch_size = 32
img_size = 32
data_dir='processed_celeba_small/'# Apply the transformations
transform = transforms.Compose([transforms.Resize(image_size),transforms.ToTensor()])
# Load the dataset
imagenet_data = datasets.ImageFolder(data_dir,transform= transform)# Load the image data into dataloader
celeba_train_loader = torch.utils.data.DataLoader(imagenet_data,batch_size,shuffle=True)

图像的大小应该足够小,这将有助于更快地训练模型。Tensors 基本上是 NumPy 数组,我们只是将图像转换为在 PyTorch 中所必需的 NumPy 数组。

然后我们加载这个转换成的 PyTorch 数据集。在那之后,我们将把我们的数据分成小批量。这个数据加载器将在每次迭代时向我们的模型训练过程提供图像数据。

随着数据的加载完成。现在,我们可以预处理图像。

图像的预处理

我们将在训练过程中使用 tanh 激活函数。该生成器的输出范围在 -1到1之间。我们还需要对这个范围内的图像进行缩放。代码如下所示:

def scale(img, feature_range=(-1, 1)):'''Scales the input image into given feature_range'''min,max = feature_rangeimg = img * (max-min) + minreturn img

这个函数将对所有输入图像缩放,我们将在后面的训练中使用这个函数。

现在我们已经完成了无聊的预处理步骤。

接下来是令人兴奋的部分,现在我们需要为我们的生成器和判别器神经网络编写代码。

步骤2: 定义判别器算法

97127e11857e4e2d08f27fe2e2848c52.png

判别器是一个可以区分真假图像的神经网络。真实的图像和由生成器生成的图像都将提供给它。

我们将首先定义一个辅助函数,这个辅助函数在创建卷积网络层时非常方便。

# helper conv function
def conv(in_channels, out_channels, kernel_size, stride=2, padding=1, batch_norm=True):layers = []conv_layer = nn.Conv2d(in_channels, out_channels, kernel_size, stride, padding, bias=False)#Appending the layerlayers.append(conv_layer)#Applying the batch normalization if it's given trueif batch_norm:layers.append(nn.BatchNorm2d(out_channels))# returning the sequential containerreturn nn.Sequential(*layers)

这个辅助函数接收创建任何卷积层所需的参数,并返回一个序列化的容器。现在我们将使用这个辅助函数来创建我们自己的判别器网络。

class Discriminator(nn.Module):def __init__(self, conv_dim):super(Discriminator, self).__init__()self.conv_dim = conv_dim#32 x 32self.cv1 = conv(3, self.conv_dim, 4, batch_norm=False)#16 x 16self.cv2 = conv(self.conv_dim, self.conv_dim*2, 4, batch_norm=True)#4 x 4self.cv3 = conv(self.conv_dim*2, self.conv_dim*4, 4, batch_norm=True)#2 x 2self.cv4 = conv(self.conv_dim*4, self.conv_dim*8, 4, batch_norm=True)#Fully connected Layerself.fc1 = nn.Linear(self.conv_dim*8*2*2,1)def forward(self, x):# After passing through each layer# Applying leaky relu activation functionx = F.leaky_relu(self.cv1(x),0.2)x = F.leaky_relu(self.cv2(x),0.2)x = F.leaky_relu(self.cv3(x),0.2)x = F.leaky_relu(self.cv4(x),0.2)# To pass throught he fully connected layer# We need to flatten the image firstx = x.view(-1,self.conv_dim*8*2*2)# Now passing through fully-connected layerx = self.fc1(x)return x

步骤3: 定义生成器算法

d30da4cfe3a34a4e28baacfc833e75f8.png

正如你们从图中看到的,我们给网络一个高斯矢量或者噪声矢量,它输出 s 中的值。图上的“ z”表示噪声,右边的 G (z)表示生成的样本。

与判别器一样,我们首先创建一个辅助函数来构建生成器网络,如下所示:

def deconv(in_channels, out_channels, kernel_size, stride=2, padding=1, batch_norm=True):layers = []convt_layer = nn.ConvTranspose2d(in_channels, out_channels, kernel_size, stride, padding, bias=False)# Appending the above conv layerlayers.append(convt_layer)if batch_norm:# Applying the batch normalization if Truelayers.append(nn.BatchNorm2d(out_channels))# Returning the sequential containerreturn nn.Sequential(*layers)

现在,是时候构建生成器网络了! !

class Generator(nn.Module):def __init__(self, z_size, conv_dim):super(Generator, self).__init__()self.z_size = z_sizeself.conv_dim = conv_dim#fully-connected-layerself.fc = nn.Linear(z_size, self.conv_dim*8*2*2)#2x2self.dcv1 = deconv(self.conv_dim*8, self.conv_dim*4, 4, batch_norm=True)#4x4self.dcv2 = deconv(self.conv_dim*4, self.conv_dim*2, 4, batch_norm=True)#8x8self.dcv3 = deconv(self.conv_dim*2, self.conv_dim, 4, batch_norm=True)#16x16self.dcv4 = deconv(self.conv_dim, 3, 4, batch_norm=False)#32 x 32def forward(self, x):# Passing through fully connected layerx = self.fc(x)# Changing the dimensionx = x.view(-1,self.conv_dim*8,2,2)# Passing through deconv layers# Applying the ReLu activation functionx = F.relu(self.dcv1(x))x= F.relu(self.dcv2(x))x= F.relu(self.dcv3(x))x= F.tanh(self.dcv4(x))#returning the modified imagereturn x

为了使模型更快地收敛,我们将初始化线性和卷积层的权重。根据相关研究论文中的描述:所有的权重都是从0中心的正态分布初始化的,标准差为0.02

我们将为此目的定义一个功能如下:

def weights_init_normal(m):classname = m.__class__.__name__# For the linear layersif 'Linear' in classname:torch.nn.init.normal_(m.weight,0.0,0.02)m.bias.data.fill_(0.01)# For the convolutional layersif 'Conv' in classname or 'BatchNorm2d' in classname:torch.nn.init.normal_(m.weight,0.0,0.02)

现在我们将超参数和两个网络初始化如下:

# Defining the model hyperparamameters
d_conv_dim = 32
g_conv_dim = 32
z_size = 100   #Size of noise vectorD = Discriminator(d_conv_dim)
G = Generator(z_size=z_size, conv_dim=g_conv_dim)
# Applying the weight initialization
D.apply(weights_init_normal)
G.apply(weights_init_normal)print(D)
print()
print(G)

输出结果大致如下:

b0dbbadcdafb3a9f8bf28da76d19bd8f.png

判别器损失:

根据 DCGAN Research Paper 论文中描述:

        判别器总损失 = 真图像损失 + 假图像损失,即:d_loss = d_real_loss + d_fake_loss。

       不过,我们希望鉴别器输出1表示真正的图像和0表示假图像,所以我们需要设置的损失来反映这一点。

我们将定义双损失函数。一个是真正的损失,另一个是假的损失,如下:

def real_loss(D_out,smooth=False):batch_size = D_out.size(0)if smooth:labels = torch.ones(batch_size)*0.9else:labels = torch.ones(batch_size)labels = labels.to(device)criterion = nn.BCEWithLogitsLoss()loss = criterion(D_out.squeeze(), labels)return lossdef fake_loss(D_out):batch_size = D_out.size(0)labels = torch.zeros(batch_size)labels = labels.to(device)criterion = nn.BCEWithLogitsLoss()loss = criterion(D_out.squeeze(), labels)return loss

生成器损失:

根据 DCGAN Research Paper 论文中描述:

        生成器的目标是让判别器认为它生成的图像是真实的。

现在,是时候为我们的网络设置优化器了:

lr = 0.0005
beta1 = 0.3
beta2 = 0.999 # default value
# Optimizers
d_optimizer = optim.Adam(D.parameters(), lr, betas=(beta1, beta2))
g_optimizer = optim.Adam(G.parameters(), lr, betas=(beta1, beta2))

我将为我们的训练使用 Adam 优化器。因为它目前被认为是对GANs最有效的。根据上述介绍论文中的研究成果,确定了超参数的取值范围。他们已经尝试了它,这些被证明是最好的!超参数设置如下:

步骤4: 编写训练算法

我们必须为我们的两个神经网络编写训练算法。首先,我们需要初始化噪声向量,并在整个训练过程中保持一致。

# Initializing arrays to store losses and samples
samples = []
losses = []# We need to initilialize fixed data for sampling
# This would help us to evaluate model's performance
sample_size=16
fixed_z = np.random.uniform(-1, 1, size=(sample_size, z_size))
fixed_z = torch.from_numpy(fixed_z).float()

对于判别器:

我们首先将真实的图像输入判别器网络,然后计算它的实际损失。然后生成伪造图像并输入判别器网络以计算虚假损失。

在计算了真实和虚假损失之后,我们对其进行求和,并采取优化步骤进行训练。

# setting optimizer parameters to zero
# to remove previous training data residue
d_optimizer.zero_grad()# move real images to gpu memory
real_images = real_images.to(device)# Pass through discriminator network
dreal = D(real_images)# Calculate the real loss
dreal_loss = real_loss(dreal)# For fake images# Generating the fake images
z = np.random.uniform(-1, 1, size=(batch_size, z_size))
z = torch.from_numpy(z).float()# move z to the GPU memory
z = z.to(device)# Generating fake images by passing it to generator
fake_images = G(z)# Passing fake images from the disc network        
dfake = D(fake_images)
# Calculating the fake loss
dfake_loss = fake_loss(dfake)#Adding both lossess
d_loss = dreal_loss + dfake_loss
# Taking the backpropogation step
d_loss.backward()
d_optimizer.step()

对于生成器:

对于生成器网络的训练,我们也会这样做。刚才在通过判别器网络输入假图像之后,我们将计算它的真实损失。然后优化我们的生成器网络。

## Training the generator for adversarial loss
#setting gradients to zero
g_optimizer.zero_grad()# Generate fake images
z = np.random.uniform(-1, 1, size=(batch_size, z_size))
z = torch.from_numpy(z).float()
# moving to GPU's memory
z = z.to(device)# Generating Fake images
fake_images = G(z)# Calculating the generator loss on fake images
# Just flipping the labels for our real loss function
D_fake = D(fake_images)
g_loss = real_loss(D_fake, True)# Taking the backpropogation step
g_loss.backward()
g_optimizer.step()

步骤5: 训练模型

现在我们将开始100个epoch的训练: D

经过训练,损失的图表看起来大概是这样的:

297583a673349a29e8b1e16942b38a73.png

我们可以看到,判别器 Loss 是相当平滑的,甚至在100个epoch之后收敛到某个特定值。而生成器的Loss则飙升。

我们可以从下面步骤6中的结果看出,60个时代之后生成的图像是扭曲的。由此可以得出结论,60个epoch是一个最佳的训练节点。

步骤6: 测试模型

10个epoch之后:

96d12f9703a12ac0e77a0c2f87c3f146.png

20个epoch之后:

dc196f60f97d770c6d0c79c8f4d85ae7.png

30个epoch之后:

33365e822eda268d5e048d296cbea7be.png

40个epoch之后:

0198316b876a53571d5550f9f97556d5.png

50个epoch之后:

bdcf6ad6c3a1977c55cdd7f6fd9a4816.png

60个epoch之后:

4152c319441182ae8907a6b354fb6a44.png

70个epoch之后:

b48254e7675b521a66d059da034db13b.png

80个epoch之后:

e567311618525ee9d2e67e0292fe067d.png

90个epoch之后:

c933353f759c001a2d1cd8d9ec81f3ba.png

100个epoch之后:

0f83fe27e164d062d3afad75e7960e75.png

总结

我们可以看到,训练一个生成对抗性网络并不意味着它一定会产生好的图像。

从结果中我们可以看出,训练40-60个 epoch 的生成器生成的图像相对比其他更好。

您可以尝试更改优化器、学习速率和其他超参数,以使其生成更好的图像!

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.rhkb.cn/news/193325.html

如若内容造成侵权/违法违规/事实不符,请联系长河编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

深入理解 pytest Fixture 方法及其应用!

当涉及到编写自动化测试时,测试框架和工具的选择对于测试用例的设计和执行非常重要。在Python 中,pytest是一种广泛使用的测试框架,它提供了丰富的功能和灵活的扩展性。其中一个很有用的功 能是fixture方法,它允许我们初始化测试环…

Leadshop开源商城小程序源码 – 支持公众号H5

Leadshop是一款出色的开源电商系统,具备轻量级、高性能的特点,并提供持续更新和迭代服务。该系统采用前后端分离架构(uniappyii2.0),以实现最佳用户体验为目标。 前端部分采用了uni-app、ES6、Vue、Vuex、Vue Router、…

系列三、双亲委派机制

一、概述 当一个类收到了类加载的请求,它首先不会尝试自己去加载这个类,而是把这个请求委派给父类去完成,每一层的类加载器都是如此,因此所有的请求都应该传送到启动类加载器中,只有当父类加载器反馈自己无法完成这个…

EtherCAT从站EEPROM组成信息详解(2):字8-15产品标识区

0 工具准备 1.EtherCAT从站EEPROM数据(本文使用DE3E-556步进电机驱动器)1 字8-字15产品标识区 1.1 产品标识区组成规范 对于不同厂家和型号的从站,主站是如何区分它们的呢?这就要提起SII的字8-字15区域存储的产品标识&#xff…

Solidity案例详解(四)投票智能合约

该合约为原创合约,功能要求如下 在⼀定时间能进⾏投票超过时间投票截⽌,并投赞同票超过50%则为通过。 使⽤safeMath库,使⽤Owner 第三⽅库拥有参与投票权的⽤户在创建合约时确定Voter 结构 要有时间戳、投票是否同意等;struct 结构…

VSCode 使用CMakePreset找不到cl.exe编译器的问题

在用vscode开发c项目的时候,使用预先配置的CMakePresets.json可以把一些特定的cmake选项固定下来,在配置时直接使用 "cmake --config --preset presetname"就可以进行配置,免去在命令行输入过多的配置参数。 但是在vscode中&#…

新版本!飞凌嵌入式RK3568系列开发板全面支持Debian 11系统

飞凌嵌入式OK3568-C/OK3568J-C开发板现已全面支持Debian 11系统,新系统的加持能为用户提供主控新选择,并为开发者带来更多开发便利! Debian系统作为一种广受欢迎和信赖的开源操作系统,以其稳定性、可靠性和开放性而闻名&#xff0…

posix定时器的使用

POSIX定时器是基于POSIX标准定义的一组函数,用于实现在Linux系统中创建和管理定时器。POSIX定时器提供了一种相对较高的精度,可用于实现毫秒级别的定时功能。 POSIX定时器的主要函数包括: timer_create():用于创建一个定时器对象…

Chrome 浏览器经常卡死问题解决

Chrome 浏览器经常卡死问题解决 chrome 任务管理器杀进程 mac 后台有很多 google chrome helper 线程并且内存占用较高 一直怀疑是插件的锅 其实并不是-0- 查看是哪个网页,哪个插件占用内存 chrome 更多工具 -> 任务管理器 切换到稳定版本的 chrome&#xff0c…

gin索引 btree索引 gist索引比较

创建例子数据 postgres# create table t_hash as select id,md5(id::text) from generate_series(1,5000000) as id; SELECT 5000000postgres# vacuum ANALYZE t_hash; VACUUMpostgres# \timing Timing is on. postgres# select * from t_hash limit 10;id | …

手机开机入网流程 KPI接通率和掉线率

今天我们来学习手机开机入网流程是怎么样的。以及RRC连接和重建流程(和博主之前讲TCP三次握手,四次挥手原理很相似)是什么样的,还有天线的KPI指标都包括什么,是不是很期待啊~ 目录 手机开机入网流程 ATTACH/RRC连接建立过程 KPI接通率和掉…

ubuntu 18.04安裝QT+PCL+VTK+Opencv

资源 qt5.14.1:qt5.14.1.run opencv4.5.5:opecv4.5.5压缩包 1.国内换中科大源,加快下载速度 cd /etc/apt/ sudo gedit sources.list 替换成如下内容 deb https://mirrors.ustc.edu.cn/ubuntu/ bionic main restricted universe multiverse deb-src https://mirro…

WordPress 媒体库文件夹管理插件 FileBird v5.5.4和谐版下载

FileBird是一款WordPress 按照文件夹管理方式的插件。 拖放界面 拖放功能现已成为现代软件和网站的标配。本机拖动事件(包括仅在刀片中将文件移动到文件夹以及将文件夹移动到文件夹)极大地减少了完成任务所需的点击次数。 一流设计的文件夹树展示 我们…

<MySQL> 查询数据进阶操作 -- 联合查询

目录 一、什么是笛卡尔积? 二、什么是联合查询? 三、内连接 3.1 简介 3.2 语法 3.3 更多的表 3.4 操作演示 四、外连接 4.1 简介 4.2 语法 4.3 操作演示 五、自连接 5.1 简介 5.2 自连接非必要不使用 六、子查询(嵌套查询) 6.1 简介 6.…

Docker Compose详细教程(从入门到放弃)

对于现代应用来说,大多都是通过很多的微服务互相协同组成的一个完整应用。例如, 订单管理、用户管理、品类管理、缓存服务、数据库服务等,它们构成了一个电商平台的应 用。而部署和管理大量的服务容器是一件非常繁琐的事情。而 Docker Compos…

arcgis--填充面域空洞

方法一:使用【编辑器】-【合并工具】进行填充。首选需要在相同图层中构造一个填充空洞的面域,然后利用【合并】工具进行最后填充。 打开一幅含有空洞的矢量数据,如下: 打开【开始编辑】-【构造工具】-【面】进行覆盖空洞的面域的…

RabbitMQ之交换机

文章目录 一、Exchanges1、Exchanges 概念2、Exchanges 的类型3、无名 exchange 二、临时队列三、绑定(bindings)四、Fanout(扇出)1、Fanout 介绍2、Fanout 实战 五、Direct exchange(直连交换机)1、Direct exchange 介绍2、多重绑…

相对强弱指标 RSI

SMA(A,B,1)MA AA ,一天前的收盘价; BB,如果时涨的,把涨幅返回; CC,12天的涨幅占12天全部涨跌幅的多少; 画一条50 的线条。

一道 python 数据分析的题目

python 数据分析的题目。 做题方法:使用 pandas 读取数据,然后分析。 知识点:pandas,正则表达式,py知识。 过程:不断使用 GPT,遇到有问题的地方自己分析,把分析的结果告诉 GPT&am…

vagrant+virtualbox的踩坑记录

vagrant virtualbox 文章目录 vagrant virtualbox一、导入虚拟机ova文件失败二、修改虚拟机的保存位置三、无法使用xshell等软件用密码进行连接四、vagrant up失败 一、导入虚拟机ova文件失败 背景:手动删除了虚拟机文件导致无法重新导入相同名称虚拟机的ova文件…