>- **🍨 This post is a study-log entry for the [🔗365-day deep learning training camp]**
>- **🍖 Original author: [K同学啊]**
My earlier posts are collected in: Deep Learning Summary
Basic tasks
- 1. Understand what a generative adversarial network is
- 2. Understand how a generative adversarial network is structured
- 3. Study the code in this post and run it end to end
- Use the trained model to generate new images
🏡 My environment:
- Language: Python 3.11
- IDE: PyCharm
- Deep learning framework: PyTorch
- torch==2.0.0+cu118
- torchvision==0.18.1+cu118
- GPU: NVIDIA GeForce GTX 1660
I. Theoretical Background
Generative Adversarial Networks (GANs) have been one of the hottest directions in deep learning in recent years. "GAN" does not refer to one specific network but to a family of networks designed around a game-theoretic idea. A GAN consists of two sub-networks, called the generator and the discriminator. The generator takes a random sample from some noise distribution as input and outputs an artificial sample that closely resembles the real samples in the training set; the discriminator takes either a real sample or an artificial sample as input and tries to tell the two apart as well as it can. The generator and discriminator are trained alternately, playing against each other, and each improves in the process. Ideally, after enough rounds of this game the discriminator can no longer judge whether a given sample is real, i.e. it outputs "50% real, 50% fake" for every sample. At that point the generator's artificial samples are realistic enough to fool the discriminator, the game stops, and we are left with a generator capable of "forging" realistic samples.
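The adversarial game described above is usually written as the minimax objective from the original GAN paper, where G tries to minimize, and D to maximize, the value function:

$$\min_G \max_D V(D, G) = \mathbb{E}_{x \sim p_{\mathrm{data}}(x)}\big[\log D(x)\big] + \mathbb{E}_{z \sim p_z(z)}\big[\log\big(1 - D(G(z))\big)\big]$$

At the equilibrium described above, the generator's distribution matches the data distribution and the optimal discriminator outputs $D(x) = 1/2$ everywhere.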
1. The Generator
In a GAN, the generator G takes random noise z as input and, through training, eventually outputs a forged sample G(z) with the same size as, and a similar distribution to, the real samples. The generator is in essence a generative model: it learns an assumed form of the data distribution together with its parameters, and then draws new samples from the learned model.
Mathematically, a generative method first makes a distributional assumption about the explicit or latent variables of the given real data; it then feeds the real data into the model to fit the parameters; finally it obtains a learned approximate distribution from which new data can be generated. From a machine learning perspective, the model does not fit a hand-crafted distributional assumption; instead it keeps learning from real data and correcting itself, eventually yielding a trained model that performs the generation task. Unlike the purely mathematical route, this learning process is less transparent to human understanding.
2. The Discriminator
In a GAN, the discriminator D takes a sample x as input and outputs a probability D(x) in [0,1]. The input may be a real sample x from the original dataset or an artificial sample G(z) produced by the generator G. By convention, the closer D(x) is to 1, the more likely the sample is real; the smaller the value, the more likely it is a forgery. The discriminator is therefore a binary neural-network classifier whose job is not to predict a sample's original class but to distinguish real from fake. Note that neither the generator nor the discriminator uses any class-label information, which is why GAN training is an unsupervised learning process.
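To make the [0,1] score concrete, here is a minimal pure-Python sketch of the per-sample binary cross-entropy that the training code later applies via torch.nn.BCELoss:

```python
import math

def bce(p, y):
    """Binary cross-entropy for a single sample.

    p: the discriminator's output D(x), in (0, 1)
    y: the target label (1 = real, 0 = fake)
    """
    return -(y * math.log(p) + (1 - y) * math.log(1 - p))

# A confident, correct discriminator has low loss:
print(round(bce(0.9, 1), 4))  # real sample scored 0.9 -> 0.1054
print(round(bce(0.1, 0), 4))  # fake sample scored 0.1 -> 0.1054
# At the ideal end state, D(x) = 0.5 for everything, and the loss is log 2:
print(round(bce(0.5, 1), 4))  # 0.6931
```

Training the discriminator pushes D(x) toward 1 on real samples and toward 0 on fakes, which is exactly what minimizes this loss.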
3. Basic Principle
GANs are a product of combining game theory with machine learning. They were introduced in 2014 in a paper by Ian Goodfellow, and the immediate excitement around them shows how widely the idea was recognized and how intensely it was studied. To understand GANs, it helps to know where they came from and why they mattered. Researchers had long wanted computers to generate data automatically, for example training a model on pictures of apples so it could then generate new apple pictures on its own. GAN was not the first generative algorithm: earlier generative algorithms typically measured the gap between a generated image and a real one with mean squared error as the loss, but researchers found that two generated images with the same MSE could look dramatically different in quality. It was to address this shortcoming that Ian Goodfellow proposed the GAN.
Figure 1: Schematic of the GAN model structure
So how does a GAN actually generate images? It is built from two models: a generative model G and a discriminative model D. The first-generation generator G1 takes random noise z as input and produces a crude image. During training, the discriminator D1 performs binary classification on it, labelling generated images 0 and real images 1. To fool the discriminator, the generator improves and upgrades to G2; once the generated images successfully fool D1, the discriminator is optimized in turn and upgrades to D2. Repeating this process produces generation after generation of G and D.
II. Preliminary Setup
1. Define the Hyperparameters
- n_epochs: the total number of training epochs. More epochs give the model more chances to learn patterns in the data, but can also lead to overfitting.
- batch_size: how much data is used per parameter update. Smaller batches make training noisier but can help the model escape local minima; larger batches make training more stable but need more memory.
- lr: the learning rate controls the step size of weight updates. Too large, and the model oscillates around, or even diverges from, the optimum; too small, and convergence is slow or the model gets stuck in a local minimum.
- b1 and b2: Adam optimizer coefficients that set the exponential decay rates of the first moment (the running average of the gradient) and the second moment (the running average of the squared gradient). They affect the stability and speed of convergence.
- n_cpu: the number of CPU workers used for data loading, which affects preprocessing and loading speed and therefore training efficiency.
- latent_dim: the dimensionality of the random noise vector, which affects the diversity and quality of generated images. Too low, and the images lack variety; too high, and the model becomes hard to train.
- img_size: the image size directly determines the model's receptive field and the compute required; larger images need more resources and longer training.
- channels: the number of image channels, usually 3 (RGB) for color images and 1 for grayscale. It determines how much information the model processes.
- sample_interval: how often generated images are saved during training, used to monitor their quality.
- cuda: whether to compute on the GPU; a GPU can dramatically accelerate training because it is far more efficient at large parallel workloads.
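Since argparse is already imported in the setup code below, the same hyperparameters could also be exposed as command-line flags; a hypothetical sketch (the post itself just assigns them as plain variables, and the defaults here match the values it uses):

```python
import argparse

parser = argparse.ArgumentParser(description="GAN hyperparameters")
parser.add_argument("--n_epochs", type=int, default=50)          # total training epochs
parser.add_argument("--batch_size", type=int, default=64)        # images per update
parser.add_argument("--lr", type=float, default=0.0002)          # Adam learning rate
parser.add_argument("--b1", type=float, default=0.5)             # Adam first-moment decay
parser.add_argument("--b2", type=float, default=0.999)           # Adam second-moment decay
parser.add_argument("--n_cpu", type=int, default=2)              # data-loading workers
parser.add_argument("--latent_dim", type=int, default=100)       # noise vector size
parser.add_argument("--img_size", type=int, default=28)          # image height/width
parser.add_argument("--channels", type=int, default=1)           # 1 for grayscale MNIST
parser.add_argument("--sample_interval", type=int, default=500)  # batches between saved samples
opt = parser.parse_args([])  # empty list -> just take the defaults
print(opt.n_epochs, opt.latent_dim)
```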
import argparse, os

import numpy as np
import torch
import torch.nn as nn
import torchvision.transforms as transforms
from torch.autograd import Variable
from torch.utils.data import DataLoader
from torchvision import datasets
from torchvision.utils import save_image

# Create the output folders
os.makedirs("./images/", exist_ok=True)         # sample images recorded during training
os.makedirs("./save/", exist_ok=True)           # where the trained models are saved
os.makedirs("./datasets/mnist", exist_ok=True)  # where the downloaded dataset is stored

# Hyperparameters
n_epochs = 50
batch_size = 64
lr = 0.0002
b1 = 0.5
b2 = 0.999
n_cpu = 2
latent_dim = 100
img_size = 28
channels = 1
sample_interval = 500

# Image shape (1, 28, 28) and pixel count per image (784)
img_shape = (channels, img_size, img_size)
img_area = np.prod(img_shape)

# Use CUDA if a GPU is available
cuda = True if torch.cuda.is_available() else False
print(cuda)
Output:
True
2. Download the Data
# Download the MNIST dataset
mnist = datasets.MNIST(
    root='./datasets/', train=True, download=True,  # download=True fetches the files on the first run
    transform=transforms.Compose([
        transforms.Resize(img_size),
        transforms.ToTensor(),
        transforms.Normalize([0.5], [0.5]),
    ]),
)
Because of my network conditions, this code kept failing with connection timeouts, so I downloaded the MNIST files directly and placed them in a fixed folder, as shown in the figure:
Modify the code above to point to the new path:
# Load the MNIST dataset from a local path
mnist = datasets.MNIST(
    root=r'E:\DATABASE', train=True, download=False,
    transform=transforms.Compose([
        transforms.Resize(img_size),
        transforms.ToTensor(),
        transforms.Normalize([0.5], [0.5]),
    ]),
)
Note that root should only go as far as DATABASE; do not point it at E:\DATABASE\MNIST\raw, or the call will still raise an error.
3. Configure the Data
# Wrap the dataset in a data loader
dataloader = DataLoader(mnist, batch_size=batch_size, shuffle=True)
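With batch_size = 64 and the DataLoader default drop_last=False, the number of batches per epoch can be predicted up front; a quick stdlib check (assuming the standard 60,000-image MNIST training split):

```python
import math

# MNIST's training split has 60,000 images; with batch_size = 64 and
# drop_last=False, each epoch yields ceil(60000 / 64) batches.
n_train, batch_size = 60_000, 64
batches_per_epoch = math.ceil(n_train / batch_size)
print(batches_per_epoch)  # 938, matching the batch counts in the training logs below
```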
III. Defining the Models
1. Define the Discriminator
The code below defines a class named Discriminator that inherits from nn.Module. It is the discriminator model, which judges whether an input image is real. A line-by-line walkthrough:
1. class Discriminator(nn.Module): defines a class named Discriminator inheriting from nn.Module, PyTorch's base class for building neural-network models.
2. def __init__(self): the constructor, which initializes the model's parameters and layers.
3. super(Discriminator,self).__init__(): calls the parent nn.Module constructor so the model is initialized correctly.
4. self.model=nn.Sequential: creates an nn.Sequential container, which stacks multiple layers in order.
5. nn.Linear(img_area,512): adds a linear layer with input size img_area (the number of pixels per image) and output size 512; it maps the flattened image into a new feature space.
6. nn.LeakyReLU(0.2,inplace=True): adds a Leaky ReLU activation with negative slope 0.2; inplace=True operates on the input tensor directly to save memory.
7. nn.Linear(512,256): adds a linear layer from 512 to 256, mapping the features into a smaller space.
8. nn.LeakyReLU(0.2,inplace=True): another Leaky ReLU activation, identical to the earlier one.
9. nn.Linear(256,1): adds a linear layer from 256 to 1, mapping the features to a single scalar that represents the image's authenticity.
10. nn.Sigmoid(): adds a sigmoid activation that squashes the output into [0,1], interpretable as the probability that the input image is real.
11. ): closes the nn.Sequential container.
12. def forward(self,img): the forward pass, which computes the output for an input image.
13. img_flat=img.view(img.size(0),-1): flattens each input image img into a one-dimensional vector; img.size(0) is the batch size and -1 infers the remaining dimension automatically.
14. validity=self.model(img_flat): passes the flattened images through the nn.Sequential model defined above to get a scalar authenticity score.
15. return validity: returns the computed authenticity value.
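The two activations in steps 6 and 10 can be sketched in plain Python (a toy illustration of their shapes, not the actual nn implementations):

```python
import math

def leaky_relu(x, negative_slope=0.2):
    # identity for x >= 0; a small linear slope for x < 0
    return x if x >= 0 else negative_slope * x

def sigmoid(x):
    # squashes any real number into the open interval (0, 1)
    return 1.0 / (1.0 + math.exp(-x))

print(leaky_relu(3.0), round(leaky_relu(-3.0), 2))  # 3.0 -0.6
print(sigmoid(0.0))                                  # 0.5
```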
# Flatten each 28*28 image into a 784-vector, pass it through a multilayer
# perceptron with LeakyReLU activations (negative slope 0.2), and finish with
# a sigmoid to get a probability in [0, 1] for binary classification.
class Discriminator(nn.Module):
    def __init__(self):
        super(Discriminator, self).__init__()
        self.model = nn.Sequential(
            nn.Linear(img_area, 512),         # 784 input features -> 512
            nn.LeakyReLU(0.2, inplace=True),  # non-linear mapping
            nn.Linear(512, 256),              # 512 -> 256
            nn.LeakyReLU(0.2, inplace=True),  # non-linear mapping
            nn.Linear(256, 1),                # 256 -> 1
            nn.Sigmoid(),                     # map the score to [0, 1] as a probability (use softmax for multi-class)
        )

    def forward(self, img):
        img_flat = img.view(img.size(0), -1)  # flatten the input images to (64, 784)
        validity = self.model(img_flat)       # run the discriminator network
        return validity                       # a probability in [0, 1]
2. Define the Generator
# The input is a 100-dimensional vector drawn from a standard normal distribution.
# A first linear layer maps it to 128 dimensions, followed by LeakyReLU; further
# linear + LeakyReLU blocks widen it to 256, 512 and 1024; a final linear layer
# maps it to 784 dimensions, and Tanh keeps the fake image's values in [-1, 1].
class Generator(nn.Module):
    def __init__(self):
        super(Generator, self).__init__()

        # A reusable building block: Linear -> (optional BatchNorm) -> LeakyReLU
        def block(in_feat, out_feat, normalize=True):
            layers = [nn.Linear(in_feat, out_feat)]           # linear map from in_feat to out_feat
            if normalize:
                layers.append(nn.BatchNorm1d(out_feat, 0.8))  # normalization
            layers.append(nn.LeakyReLU(0.2, inplace=True))    # non-linear activation
            return layers

        # np.prod() gives the product over the shape: 1*28*28 = 784
        self.model = nn.Sequential(
            *block(latent_dim, 128, normalize=False),  # 100 -> 128, LeakyReLU (no BatchNorm)
            *block(128, 256),                          # 128 -> 256, BatchNorm, LeakyReLU
            *block(256, 512),                          # 256 -> 512, BatchNorm, LeakyReLU
            *block(512, 1024),                         # 512 -> 1024, BatchNorm, LeakyReLU
            nn.Linear(1024, img_area),                 # 1024 -> 784
            nn.Tanh(),                                 # map each of the 784 values into [-1, 1]
        )

    # view(): like numpy's reshape; here it redefines the shape as (64, 1, 28, 28)
    def forward(self, z):                           # z: (64, 100) noise
        imgs = self.model(z)                        # pass the noise through the generator
        imgs = imgs.view(imgs.size(0), *img_shape)  # reshape to (64, 1, 28, 28)
        return imgs                                 # 64 images of shape (1, 28, 28)
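The shape bookkeeping in forward() can be mimicked with numpy, whose reshape plays the role of view() here (batch size 64 assumed):

```python
import numpy as np

# Shape bookkeeping for the generator's forward pass.
channels, img_size = 1, 28
img_shape = (channels, img_size, img_size)
img_area = int(np.prod(img_shape))      # 1 * 28 * 28 = 784

z = np.random.randn(64, 100)            # noise, like torch.randn(64, latent_dim)
flat = np.zeros((64, img_area))         # what self.model(z) returns: (64, 784)
imgs = flat.reshape(64, *img_shape)     # the view(imgs.size(0), *img_shape) step
print(z.shape, flat.shape, imgs.shape)  # (64, 100) (64, 784) (64, 1, 28, 28)
```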
IV. Training the Model
1. Create the Instances
# Create the generator and discriminator
generator = Generator()
discriminator = Discriminator()

# First, define the loss: binary cross-entropy
criterion = torch.nn.BCELoss()

# Next, define the optimizers, with learning rate lr = 0.0002
# betas: coefficients for the running averages of the gradient and its square
optimizer_G = torch.optim.Adam(generator.parameters(), lr=lr, betas=(b1, b2))
optimizer_D = torch.optim.Adam(discriminator.parameters(), lr=lr, betas=(b1, b2))

# If a GPU is available, run everything in CUDA mode
if torch.cuda.is_available():
    generator = generator.cuda()
    discriminator = discriminator.cuda()
    criterion = criterion.cuda()
2. Train the Model
# Train for multiple epochs
for epoch in range(n_epochs):                   # n_epochs: 50
    for i, (imgs, _) in enumerate(dataloader):  # imgs: (64, 1, 28, 28)

        # =============== Train the discriminator ===============
        # view(): like numpy's reshape; (64, 1, 28, 28) -> (64, 784)
        imgs = imgs.view(imgs.size(0), -1)  # flatten each 28*28 image to 784
        real_img = Variable(imgs).cuda()    # wrap the tensor in a Variable so gradients can be backpropagated
        real_label = Variable(torch.ones(imgs.size(0), 1)).cuda()   # label real images as 1
        fake_label = Variable(torch.zeros(imgs.size(0), 1)).cuda()  # label fake images as 0

        # ------------------------
        # Two parts: 1) real images should be judged real; 2) fakes judged fake
        # ------------------------
        # Loss on real images
        real_out = discriminator(real_img)             # feed the real images to the discriminator
        loss_real_D = criterion(real_out, real_label)  # loss on the real images
        real_scores = real_out                         # scores for real images; closer to 1 is better

        # Loss on fake images
        # detach(): cut the graph here so no gradients flow into G, which is not updated in this step
        z = Variable(torch.randn(imgs.size(0), latent_dim)).cuda()
        fake_img = generator(z).detach()
        fake_out = discriminator(fake_img)
        loss_fake_D = criterion(fake_out, fake_label)
        fake_scores = fake_out                         # scores for fake images; closer to 0 is better

        # Combine the losses and optimize
        loss_D = loss_real_D + loss_fake_D  # total loss = loss on reals + loss on fakes
        optimizer_D.zero_grad()             # zero the gradients before backpropagation
        loss_D.backward()                   # backpropagate the error
        optimizer_D.step()                  # update the parameters

        # ------------------------
        # Train the generator
        # Idea: we want the generated fakes to be judged real by the discriminator.
        # The discriminator is held fixed; the fakes are scored against the *real*
        # label, and backpropagation only updates the generator's parameters, so
        # the generator learns to make images the discriminator takes for real --
        # this is the adversarial objective.
        # ------------------------
        z = Variable(torch.randn(imgs.size(0), latent_dim)).cuda()
        fake_img = generator(z)
        output = discriminator(fake_img)
        # Loss and optimization
        loss_G = criterion(output, real_label)
        optimizer_G.zero_grad()
        loss_G.backward()
        optimizer_G.step()

        # Print training logs
        # item(): extract the Python scalar from a one-element tensor
        if (i + 1) % 300 == 0:
            print("[Epoch %d %d] [Batch %d %d] [D loss:%f] [G loss:%f] [D real:%f] [D fake:%f]"
                  % (epoch, n_epochs, i, len(dataloader),
                     loss_D.item(), loss_G.item(),
                     real_scores.data.mean(), fake_scores.data.mean()))

        # Save generated images during training
        batches_done = epoch * len(dataloader) + i
        if batches_done % sample_interval == 0:
            save_image(fake_img.data[:25], "./images/%d.png" % batches_done, nrow=5, normalize=True)
Output:
[Epoch 0 50] [Batch 299 938] [D loss:1.146173] [G loss:0.765958] [D real:0.456922] [D fake:0.281837]
[Epoch 0 50] [Batch 599 938] [D loss:0.969525] [G loss:1.153390] [D real:0.582100] [D fake:0.320095]
[Epoch 0 50] [Batch 899 938] [D loss:1.116867] [G loss:1.869213] [D real:0.705611] [D fake:0.505496]
[Epoch 1 50] [Batch 299 938] [D loss:0.881180] [G loss:1.415046] [D real:0.759470] [D fake:0.425569]
[Epoch 1 50] [Batch 599 938] [D loss:1.231377] [G loss:2.431004] [D real:0.902129] [D fake:0.662844]
[Epoch 1 50] [Batch 899 938] [D loss:0.909482] [G loss:1.374928] [D real:0.593177] [D fake:0.225221]
[Epoch 2 50] [Batch 299 938] [D loss:0.575307] [G loss:2.108607] [D real:0.738606] [D fake:0.190559]
[Epoch 2 50] [Batch 599 938] [D loss:0.811977] [G loss:1.611509] [D real:0.712358] [D fake:0.322535]
[Epoch 2 50] [Batch 899 938] [D loss:0.880121] [G loss:2.131778] [D real:0.841744] [D fake:0.490013]
[Epoch 3 50] [Batch 299 938] [D loss:0.986335] [G loss:1.951669] [D real:0.792694] [D fake:0.472771]
[Epoch 3 50] [Batch 599 938] [D loss:1.341987] [G loss:3.226815] [D real:0.879281] [D fake:0.671599]
[Epoch 3 50] [Batch 899 938] [D loss:0.776661] [G loss:1.787926] [D real:0.713198] [D fake:0.286383]
[Epoch 4 50] [Batch 299 938] [D loss:1.166095] [G loss:2.602376] [D real:0.840969] [D fake:0.589564]
[Epoch 4 50] [Batch 599 938] [D loss:1.046687] [G loss:2.091129] [D real:0.746495] [D fake:0.469416]
[Epoch 4 50] [Batch 899 938] [D loss:0.814004] [G loss:1.554609] [D real:0.671112] [D fake:0.266649]
[Epoch 5 50] [Batch 299 938] [D loss:0.724353] [G loss:1.310882] [D real:0.714876] [D fake:0.262048]
[Epoch 5 50] [Batch 599 938] [D loss:0.839810] [G loss:1.418422] [D real:0.645111] [D fake:0.260502]
[Epoch 5 50] [Batch 899 938] [D loss:0.826553] [G loss:1.109170] [D real:0.605030] [D fake:0.144340]
[Epoch 6 50] [Batch 299 938] [D loss:0.896352] [G loss:1.229198] [D real:0.684517] [D fake:0.319981]
[Epoch 6 50] [Batch 599 938] [D loss:1.040686] [G loss:0.764906] [D real:0.511610] [D fake:0.197549]
[Epoch 6 50] [Batch 899 938] [D loss:0.850218] [G loss:1.343065] [D real:0.742433] [D fake:0.361607]
[Epoch 7 50] [Batch 299 938] [D loss:1.200897] [G loss:2.300584] [D real:0.791116] [D fake:0.560915]
[Epoch 7 50] [Batch 599 938] [D loss:0.834816] [G loss:1.438616] [D real:0.660125] [D fake:0.255994]
[Epoch 7 50] [Batch 899 938] [D loss:1.014503] [G loss:1.111037] [D real:0.592165] [D fake:0.250707]
[Epoch 8 50] [Batch 299 938] [D loss:0.869173] [G loss:1.407378] [D real:0.703362] [D fake:0.347800]
[Epoch 8 50] [Batch 599 938] [D loss:0.858265] [G loss:1.736284] [D real:0.740475] [D fake:0.365312]
[Epoch 8 50] [Batch 899 938] [D loss:1.030783] [G loss:0.993396] [D real:0.551013] [D fake:0.228221]
[Epoch 9 50] [Batch 299 938] [D loss:0.934841] [G loss:1.730107] [D real:0.759375] [D fake:0.423129]
[Epoch 9 50] [Batch 599 938] [D loss:0.960858] [G loss:1.587723] [D real:0.713016] [D fake:0.372299]
[Epoch 9 50] [Batch 899 938] [D loss:0.971297] [G loss:1.219579] [D real:0.573871] [D fake:0.233515]
[Epoch 10 50] [Batch 299 938] [D loss:1.034411] [G loss:0.882193] [D real:0.556639] [D fake:0.225388]
[Epoch 10 50] [Batch 599 938] [D loss:0.928058] [G loss:1.446012] [D real:0.766924] [D fake:0.428943]
[Epoch 10 50] [Batch 899 938] [D loss:1.064512] [G loss:1.296003] [D real:0.640255] [D fake:0.396908]
[Epoch 11 50] [Batch 299 938] [D loss:1.092536] [G loss:1.372938] [D real:0.642308] [D fake:0.394312]
[Epoch 11 50] [Batch 599 938] [D loss:0.991018] [G loss:0.946512] [D real:0.595111] [D fake:0.310966]
[Epoch 11 50] [Batch 899 938] [D loss:0.954925] [G loss:1.597548] [D real:0.664621] [D fake:0.347017]
[Epoch 12 50] [Batch 299 938] [D loss:0.979235] [G loss:1.355778] [D real:0.674148] [D fake:0.367891]
[Epoch 12 50] [Batch 599 938] [D loss:1.052730] [G loss:0.941167] [D real:0.552950] [D fake:0.290553]
[Epoch 12 50] [Batch 899 938] [D loss:1.192770] [G loss:1.793072] [D real:0.769454] [D fake:0.560993]
[Epoch 13 50] [Batch 299 938] [D loss:1.115526] [G loss:1.788715] [D real:0.768149] [D fake:0.526000]
[Epoch 13 50] [Batch 599 938] [D loss:1.123425] [G loss:0.820183] [D real:0.524850] [D fake:0.275359]
[Epoch 13 50] [Batch 899 938] [D loss:1.156975] [G loss:0.734737] [D real:0.468749] [D fake:0.208848]
[Epoch 14 50] [Batch 299 938] [D loss:1.098107] [G loss:0.777794] [D real:0.550806] [D fake:0.301756]
[Epoch 14 50] [Batch 599 938] [D loss:1.037302] [G loss:1.495595] [D real:0.798072] [D fake:0.499724]
[Epoch 14 50] [Batch 899 938] [D loss:1.156028] [G loss:1.317196] [D real:0.697508] [D fake:0.489884]
[Epoch 15 50] [Batch 299 938] [D loss:1.224238] [G loss:1.737315] [D real:0.730509] [D fake:0.544304]
[Epoch 15 50] [Batch 599 938] [D loss:1.179392] [G loss:2.003285] [D real:0.793012] [D fake:0.566236]
[Epoch 15 50] [Batch 899 938] [D loss:1.196175] [G loss:0.796566] [D real:0.496473] [D fake:0.287953]
[Epoch 16 50] [Batch 299 938] [D loss:1.378808] [G loss:0.785516] [D real:0.397582] [D fake:0.202002]
[Epoch 16 50] [Batch 599 938] [D loss:1.157125] [G loss:1.051247] [D real:0.690098] [D fake:0.492802]
[Epoch 16 50] [Batch 899 938] [D loss:1.031536] [G loss:1.349716] [D real:0.687725] [D fake:0.427418]
[Epoch 17 50] [Batch 299 938] [D loss:1.051005] [G loss:0.958619] [D real:0.631580] [D fake:0.356514]
[Epoch 17 50] [Batch 599 938] [D loss:1.128377] [G loss:1.015510] [D real:0.603700] [D fake:0.404514]
[Epoch 17 50] [Batch 899 938] [D loss:0.960494] [G loss:1.295537] [D real:0.667236] [D fake:0.364065]
[Epoch 18 50] [Batch 299 938] [D loss:1.064532] [G loss:1.428140] [D real:0.698197] [D fake:0.438659]
[Epoch 18 50] [Batch 599 938] [D loss:0.999358] [G loss:1.615426] [D real:0.719155] [D fake:0.438274]
[Epoch 18 50] [Batch 899 938] [D loss:0.994069] [G loss:1.085498] [D real:0.729381] [D fake:0.432135]
[Epoch 19 50] [Batch 299 938] [D loss:0.978384] [G loss:1.053926] [D real:0.573148] [D fake:0.251593]
[Epoch 19 50] [Batch 599 938] [D loss:1.287538] [G loss:2.345575] [D real:0.834289] [D fake:0.632470]
[Epoch 19 50] [Batch 899 938] [D loss:1.063443] [G loss:0.980304] [D real:0.591775] [D fake:0.346237]
[Epoch 20 50] [Batch 299 938] [D loss:1.143715] [G loss:1.364538] [D real:0.703250] [D fake:0.489390]
[Epoch 20 50] [Batch 599 938] [D loss:1.054832] [G loss:1.229990] [D real:0.630416] [D fake:0.380441]
[Epoch 20 50] [Batch 899 938] [D loss:1.092473] [G loss:0.980192] [D real:0.640414] [D fake:0.396368]
[Epoch 21 50] [Batch 299 938] [D loss:1.114693] [G loss:0.815711] [D real:0.507922] [D fake:0.257900]
[Epoch 21 50] [Batch 599 938] [D loss:1.156865] [G loss:0.674195] [D real:0.485364] [D fake:0.263269]
[Epoch 21 50] [Batch 899 938] [D loss:1.122650] [G loss:1.010019] [D real:0.555538] [D fake:0.321152]
[Epoch 22 50] [Batch 299 938] [D loss:1.073445] [G loss:1.221281] [D real:0.612212] [D fake:0.366838]
[Epoch 22 50] [Batch 599 938] [D loss:1.313248] [G loss:1.968656] [D real:0.841517] [D fake:0.641602]
[Epoch 22 50] [Batch 899 938] [D loss:1.140289] [G loss:0.959995] [D real:0.545480] [D fake:0.317969]
[Epoch 23 50] [Batch 299 938] [D loss:1.078399] [G loss:1.082814] [D real:0.543149] [D fake:0.250095]
[Epoch 23 50] [Batch 599 938] [D loss:0.988625] [G loss:0.731643] [D real:0.570309] [D fake:0.271124]
[Epoch 23 50] [Batch 899 938] [D loss:0.899557] [G loss:1.041285] [D real:0.663554] [D fake:0.316601]
[Epoch 24 50] [Batch 299 938] [D loss:0.940427] [G loss:1.147060] [D real:0.600235] [D fake:0.254879]
[Epoch 24 50] [Batch 599 938] [D loss:1.129661] [G loss:0.986496] [D real:0.587398] [D fake:0.383421]
[Epoch 24 50] [Batch 899 938] [D loss:1.180622] [G loss:1.420076] [D real:0.663551] [D fake:0.453712]
[Epoch 25 50] [Batch 299 938] [D loss:1.194358] [G loss:1.158545] [D real:0.555807] [D fake:0.351430]
[Epoch 25 50] [Batch 599 938] [D loss:1.072583] [G loss:0.899997] [D real:0.611936] [D fake:0.377545]
[Epoch 25 50] [Batch 899 938] [D loss:1.287829] [G loss:1.398052] [D real:0.746196] [D fake:0.564117]
[Epoch 26 50] [Batch 299 938] [D loss:1.019067] [G loss:1.436803] [D real:0.705079] [D fake:0.421282]
[Epoch 26 50] [Batch 599 938] [D loss:0.999005] [G loss:0.934130] [D real:0.596882] [D fake:0.313623]
[Epoch 26 50] [Batch 899 938] [D loss:1.002017] [G loss:1.360719] [D real:0.691842] [D fake:0.418453]
[Epoch 27 50] [Batch 299 938] [D loss:0.970563] [G loss:1.134730] [D real:0.579426] [D fake:0.247908]
[Epoch 27 50] [Batch 599 938] [D loss:1.117202] [G loss:1.332376] [D real:0.686316] [D fake:0.463324]
[Epoch 27 50] [Batch 899 938] [D loss:1.184001] [G loss:1.458771] [D real:0.697385] [D fake:0.492752]
[Epoch 28 50] [Batch 299 938] [D loss:0.947888] [G loss:1.083961] [D real:0.588386] [D fake:0.223453]
[Epoch 28 50] [Batch 599 938] [D loss:1.066112] [G loss:0.992872] [D real:0.622085] [D fake:0.352469]
[Epoch 28 50] [Batch 899 938] [D loss:1.038555] [G loss:0.842831] [D real:0.508421] [D fake:0.207675]
[Epoch 29 50] [Batch 299 938] [D loss:1.016933] [G loss:1.008511] [D real:0.568845] [D fake:0.286478]
[Epoch 29 50] [Batch 599 938] [D loss:1.021461] [G loss:1.016894] [D real:0.585771] [D fake:0.298825]
[Epoch 29 50] [Batch 899 938] [D loss:1.079430] [G loss:1.287262] [D real:0.702781] [D fake:0.453398]
[Epoch 30 50] [Batch 299 938] [D loss:1.248966] [G loss:0.617484] [D real:0.411599] [D fake:0.147803]
[Epoch 30 50] [Batch 599 938] [D loss:1.064411] [G loss:1.064151] [D real:0.599990] [D fake:0.340874]
[Epoch 30 50] [Batch 899 938] [D loss:1.026167] [G loss:1.246001] [D real:0.730361] [D fake:0.453320]
[Epoch 31 50] [Batch 299 938] [D loss:1.197263] [G loss:1.408032] [D real:0.631419] [D fake:0.418329]
[Epoch 31 50] [Batch 599 938] [D loss:1.089052] [G loss:1.414903] [D real:0.673101] [D fake:0.417530]
[Epoch 31 50] [Batch 899 938] [D loss:1.008322] [G loss:1.245319] [D real:0.645966] [D fake:0.369107]
[Epoch 32 50] [Batch 299 938] [D loss:1.141207] [G loss:1.247429] [D real:0.687800] [D fake:0.448649]
[Epoch 32 50] [Batch 599 938] [D loss:1.035696] [G loss:1.277810] [D real:0.568379] [D fake:0.297915]
[Epoch 32 50] [Batch 899 938] [D loss:1.126694] [G loss:1.294756] [D real:0.638698] [D fake:0.418313]
[Epoch 33 50] [Batch 299 938] [D loss:0.998067] [G loss:1.268087] [D real:0.658714] [D fake:0.376252]
[Epoch 33 50] [Batch 599 938] [D loss:1.044655] [G loss:0.797849] [D real:0.548929] [D fake:0.264930]
[Epoch 33 50] [Batch 899 938] [D loss:0.974028] [G loss:1.008849] [D real:0.674779] [D fake:0.384628]
[Epoch 34 50] [Batch 299 938] [D loss:0.972191] [G loss:0.872447] [D real:0.618687] [D fake:0.303252]
[Epoch 34 50] [Batch 599 938] [D loss:1.152283] [G loss:1.027059] [D real:0.548038] [D fake:0.328294]
[Epoch 34 50] [Batch 899 938] [D loss:1.078370] [G loss:1.160610] [D real:0.579307] [D fake:0.324892]
[Epoch 35 50] [Batch 299 938] [D loss:0.951649] [G loss:1.953001] [D real:0.729165] [D fake:0.408322]
[Epoch 35 50] [Batch 599 938] [D loss:1.063294] [G loss:1.334316] [D real:0.555569] [D fake:0.243107]
[Epoch 35 50] [Batch 899 938] [D loss:1.027119] [G loss:1.213654] [D real:0.678924] [D fake:0.402880]
[Epoch 36 50] [Batch 299 938] [D loss:1.029951] [G loss:1.602960] [D real:0.705595] [D fake:0.412937]
[Epoch 36 50] [Batch 599 938] [D loss:0.971645] [G loss:1.134115] [D real:0.657235] [D fake:0.349733]
[Epoch 36 50] [Batch 899 938] [D loss:1.075806] [G loss:0.966678] [D real:0.599718] [D fake:0.344953]
[Epoch 37 50] [Batch 299 938] [D loss:1.002480] [G loss:0.985621] [D real:0.616168] [D fake:0.303731]
[Epoch 37 50] [Batch 599 938] [D loss:0.956862] [G loss:1.219155] [D real:0.616930] [D fake:0.307200]
[Epoch 37 50] [Batch 899 938] [D loss:1.125469] [G loss:1.351054] [D real:0.663360] [D fake:0.420789]
[Epoch 38 50] [Batch 299 938] [D loss:1.028449] [G loss:1.462967] [D real:0.668299] [D fake:0.382929]
[Epoch 38 50] [Batch 599 938] [D loss:1.207767] [G loss:0.959586] [D real:0.461430] [D fake:0.199629]
[Epoch 38 50] [Batch 899 938] [D loss:1.258663] [G loss:1.502312] [D real:0.617351] [D fake:0.445597]
[Epoch 39 50] [Batch 299 938] [D loss:1.208128] [G loss:0.871164] [D real:0.518784] [D fake:0.297728]
[Epoch 39 50] [Batch 599 938] [D loss:0.842718] [G loss:1.621715] [D real:0.767684] [D fake:0.387491]
[Epoch 39 50] [Batch 899 938] [D loss:1.014976] [G loss:1.573629] [D real:0.605987] [D fake:0.268213]
[Epoch 40 50] [Batch 299 938] [D loss:1.158162] [G loss:1.213044] [D real:0.638538] [D fake:0.432146]
[Epoch 40 50] [Batch 599 938] [D loss:1.073439] [G loss:1.219819] [D real:0.714997] [D fake:0.455074]
[Epoch 40 50] [Batch 899 938] [D loss:1.101056] [G loss:0.777753] [D real:0.534668] [D fake:0.248878]
[Epoch 41 50] [Batch 299 938] [D loss:1.167324] [G loss:1.538780] [D real:0.654459] [D fake:0.400560]
[Epoch 41 50] [Batch 599 938] [D loss:1.106753] [G loss:1.507490] [D real:0.733342] [D fake:0.462992]
[Epoch 41 50] [Batch 899 938] [D loss:1.065012] [G loss:1.261401] [D real:0.611546] [D fake:0.337609]
[Epoch 42 50] [Batch 299 938] [D loss:0.970665] [G loss:1.531132] [D real:0.743050] [D fake:0.412847]
[Epoch 42 50] [Batch 599 938] [D loss:1.135600] [G loss:2.105861] [D real:0.796165] [D fake:0.524620]
[Epoch 42 50] [Batch 899 938] [D loss:1.168230] [G loss:0.992182] [D real:0.521584] [D fake:0.242412]
[Epoch 43 50] [Batch 299 938] [D loss:0.986416] [G loss:1.345769] [D real:0.691150] [D fake:0.393895]
[Epoch 43 50] [Batch 599 938] [D loss:1.088583] [G loss:0.916690] [D real:0.523170] [D fake:0.204282]
[Epoch 43 50] [Batch 899 938] [D loss:1.009906] [G loss:1.337040] [D real:0.638337] [D fake:0.335985]
[Epoch 44 50] [Batch 299 938] [D loss:0.903776] [G loss:1.462499] [D real:0.717814] [D fake:0.354137]
[Epoch 44 50] [Batch 599 938] [D loss:0.909037] [G loss:1.107432] [D real:0.676036] [D fake:0.324931]
[Epoch 44 50] [Batch 899 938] [D loss:1.019271] [G loss:1.257354] [D real:0.719735] [D fake:0.428816]
[Epoch 45 50] [Batch 299 938] [D loss:1.082113] [G loss:0.695156] [D real:0.515746] [D fake:0.235972]
[Epoch 45 50] [Batch 599 938] [D loss:1.047832] [G loss:1.032727] [D real:0.603880] [D fake:0.322159]
[Epoch 45 50] [Batch 899 938] [D loss:0.957596] [G loss:1.237282] [D real:0.655737] [D fake:0.336426]
[Epoch 46 50] [Batch 299 938] [D loss:0.993955] [G loss:0.741267] [D real:0.602579] [D fake:0.290792]
[Epoch 46 50] [Batch 599 938] [D loss:1.105241] [G loss:1.356119] [D real:0.841550] [D fake:0.551605]
[Epoch 46 50] [Batch 899 938] [D loss:0.985514] [G loss:0.965528] [D real:0.586651] [D fake:0.265542]
[Epoch 47 50] [Batch 299 938] [D loss:1.047470] [G loss:1.568277] [D real:0.683706] [D fake:0.370035]
[Epoch 47 50] [Batch 599 938] [D loss:1.027143] [G loss:1.220082] [D real:0.582073] [D fake:0.298755]
[Epoch 47 50] [Batch 899 938] [D loss:0.908342] [G loss:1.165572] [D real:0.690787] [D fake:0.344307]
[Epoch 48 50] [Batch 299 938] [D loss:1.004957] [G loss:0.879145] [D real:0.577572] [D fake:0.245257]
[Epoch 48 50] [Batch 599 938] [D loss:0.883799] [G loss:1.223790] [D real:0.660361] [D fake:0.281309]
[Epoch 48 50] [Batch 899 938] [D loss:1.132943] [G loss:1.266990] [D real:0.634662] [D fake:0.386300]
[Epoch 49 50] [Batch 299 938] [D loss:1.086390] [G loss:1.289404] [D real:0.661946] [D fake:0.414245]
[Epoch 49 50] [Batch 599 938] [D loss:1.133659] [G loss:0.853763] [D real:0.569690] [D fake:0.297656]
[Epoch 49 50] [Batch 899 938] [D loss:1.489584] [G loss:1.584844] [D real:0.825768] [D fake:0.638332]
3. Final Generated Images
4. Save the Models
# Save the models
torch.save(generator.state_dict(), './save/generator.pth')
torch.save(discriminator.state_dict(), './save/discriminator.pth')
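The basic tasks at the top include calling the trained model to generate new images. A minimal sketch of how that might look (the Generator class must match the training-time definition exactly for load_state_dict to succeed; the load line is commented out here because it requires the checkpoint produced above):

```python
import torch
import torch.nn as nn

latent_dim = 100
img_shape = (1, 28, 28)
img_area = 1 * 28 * 28

# Same architecture as the Generator defined earlier in this post.
def block(in_feat, out_feat, normalize=True):
    layers = [nn.Linear(in_feat, out_feat)]
    if normalize:
        layers.append(nn.BatchNorm1d(out_feat, 0.8))
    layers.append(nn.LeakyReLU(0.2, inplace=True))
    return layers

class Generator(nn.Module):
    def __init__(self):
        super(Generator, self).__init__()
        self.model = nn.Sequential(
            *block(latent_dim, 128, normalize=False),
            *block(128, 256),
            *block(256, 512),
            *block(512, 1024),
            nn.Linear(1024, img_area),
            nn.Tanh(),
        )

    def forward(self, z):
        return self.model(z).view(z.size(0), *img_shape)

generator = Generator()
# generator.load_state_dict(torch.load('./save/generator.pth'))  # after training
generator.eval()                       # inference mode (BatchNorm uses running stats)
with torch.no_grad():
    imgs = generator(torch.randn(16, latent_dim))  # 16 new samples from noise
print(imgs.shape)  # torch.Size([16, 1, 28, 28])
# save_image(imgs, "./images/new_samples.png", nrow=4, normalize=True)
```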
V. Takeaways
Through this exercise I gained a general understanding of how GANs work and a first impression of their model structure, and finished by generating handwritten-digit images with the trained model. Judging from the final outputs, the generated images come fairly close to real handwritten digits.