生成对抗网络(GAN)入门与编程实现

生成对抗网络(Generative Adversarial Networks, 简称 GAN)自 2014 年由 Ian Goodfellow 等人提出以来,迅速成为机器学习和深度学习领域的重要工具之一。GAN 以其在图像生成、风格转换、数据增强等领域的出色表现,吸引了广泛的研究兴趣和应用探索。本文将介绍 GAN 的基本概念、工作原理以及如何通过代码实现一个简单的 GAN 模型。

什么是生成对抗网络(GAN)?

GAN 是一种生成模型,旨在通过学习数据的潜在分布,生成与真实数据相似的样本。它由两个核心部分组成:

  • 生成器(Generator):输入一个随机噪声向量,通过一系列的变换生成假数据,目标是让生成的假数据尽可能接近真实数据。
  • 判别器(Discriminator):输入真实数据和生成器生成的假数据,输出判断其真实性的概率,目标是尽可能准确地区分真实数据和生成数据。
    二者在训练过程中相互对抗,形成一个博弈过程。

在这里插入图片描述

GAN 的工作原理

GAN 的训练过程可以看作是生成器和判别器之间的"零和博弈":

  1. 生成器:
  • 输入随机噪声向量 z z z(通常服从正态分布)。
  • 输出生成的样本 G ( z ) G(z) G(z)
  • 目标是让判别器无法区分 G ( z ) G(z) G(z) 和真实数据。
  1. 判别器:
  • 输入真实样本 x x x 和生成器生成的假样本 G ( z ) G(z) G(z)
  • 输出区分真假样本的概率。
  • 目标是最大化对真实样本和生成样本的区分能力。

通过对模型进行训练,生成器逐渐生成更接近真实分布的样本,而判别器也不断提高其判别能力,直到达到平衡。
在这里插入图片描述
完整的训练过程如下:
在这里插入图片描述

GAN 的代码实现

接下来,我们通过 PyTorch 实现一个简单的 GAN 模型,生成 MNIST 手写数字图片。

  1. 数据加载与预处理
    MNIST 是一个常用的手写数字数据集,每张图片的大小为 28x28,灰度范围为 0-1。
# data_loader
transform = transforms.Compose([transforms.ToTensor(),transforms.Normalize(mean=(0.5), std=(0.5))
])
train_loader = torch.utils.data.DataLoader(datasets.MNIST('data', train=True, download=True, transform=transform),batch_size=batch_size, shuffle=True)

使用 torchvision 的 datasets.MNIST 下载MNIST数据集。之后,将图片转换为Tensor格式,并对像素值进行归一化(均值0.5,标准差0.5)。

  1. 构建生成器与判别器
    生成器和判别器都是多层全连接神经网络。
# G(z)
class generator(nn.Module):# initializersdef __init__(self, input_size=32, n_class = 10):super(generator, self).__init__()self.fc1 = nn.Linear(input_size, 256)self.fc2 = nn.Linear(self.fc1.out_features, 512)self.fc3 = nn.Linear(self.fc2.out_features, 1024)self.fc4 = nn.Linear(self.fc3.out_features, n_class)# forward methoddef forward(self, input):x = F.leaky_relu(self.fc1(input), 0.2)x = F.leaky_relu(self.fc2(x), 0.2)x = F.leaky_relu(self.fc3(x), 0.2)x = F.tanh(self.fc4(x))x = x.squeeze(-1)return xclass discriminator(nn.Module):# initializersdef __init__(self, input_size=32, n_class=10):super(discriminator, self).__init__()self.fc1 = nn.Linear(input_size, 1024)self.fc2 = nn.Linear(self.fc1.out_features, 512)self.fc3 = nn.Linear(self.fc2.out_features, 256)self.fc4 = nn.Linear(self.fc3.out_features, n_class)# forward methoddef forward(self, input):x = F.leaky_relu(self.fc1(input), 0.2)x = F.dropout(x, 0.3)x = F.leaky_relu(self.fc2(x), 0.2)x = F.dropout(x, 0.3)x = F.leaky_relu(self.fc3(x), 0.2)x = F.dropout(x, 0.3)x = F.sigmoid(self.fc4(x))x = x.squeeze(-1)return x# network
G = generator(input_size=100, n_class=28*28)
D = discriminator(input_size=28*28, n_class=1)
  • 生成器 (generator):

    • 输入:一个大小为100的噪声向量。
    • 结构:包含4个全连接层(fc1到fc4),每层后面跟随一个激活函数:
      • 前三层使用 LeakyReLU 激活函数,最后一层使用 tanh。
      • 输出大小为 28×28(MNIST图片的尺寸)。
    • 功能:将随机噪声映射为类似于手写数字的图片。
  • 判别器 (discriminator):

    • 输入:展平的MNIST图片(大小为 28×28)。
    • 结构:包含4个全连接层(fc1到fc4),每层后面跟随:
      • LeakyReLU 激活函数和 Dropout(用于防止过拟合)。
      • 最后一层使用 sigmoid 激活函数。
    • 输出:一个介于0和1之间的值,表示输入是“真实图片”的概率。
  1. 定义训练参数以及损失函数和优化器
# training parameters
batch_size = 256
lr = 0.0002
train_epoch = 200
device = torch.cuda.is_available()
if device:print("running on GPU!")# Binary Cross Entropy loss
BCE_loss = nn.BCELoss()#move to cuda
if device:G.cuda()D.cuda()BCE_loss = BCE_loss.cuda()# Adam optimizer
G_optimizer = optim.Adam(G.parameters(), lr=lr)
D_optimizer = optim.Adam(D.parameters(), lr=lr)
4. 训练过程
在训练过程中,我们交替训练判别器和生成器。
train_hist = {}
train_hist['D_losses'] = []
train_hist['G_losses'] = []for epoch in range(train_epoch):D_losses = []G_losses = []#生成任务,不需要标签for x_, _ in train_loader:#训练图像展平x_ = x_.view(-1, 28 * 28)mini_batch = x_.size()[0]y_real_ = torch.ones(mini_batch)y_fake_ = torch.zeros(mini_batch)# train discriminator DD.zero_grad()z_ = torch.randn((mini_batch, 100))if device:x_, y_real_, y_fake_ = x_.cuda(), y_real_.cuda(), y_fake_.cuda()z_ = z_.cuda()#真数据lossD_result = D(x_)D_real_loss = BCE_loss(D_result, y_real_)D_real_score = D_result#假数据lossG_result = G(z_)D_result = D(G_result)D_fake_loss = BCE_loss(D_result, y_fake_)D_fake_score = D_resultD_train_loss = D_real_loss + D_fake_lossD_train_loss.backward()D_optimizer.step()D_losses.append(D_train_loss.item())# train generator GG.zero_grad()# z_ = torch.randn((mini_batch, 100))# if device:#     z_ = z_.cuda()G_result = G(z_)D_result = D(G_result)G_train_loss = BCE_loss(D_result, y_real_)G_train_loss.backward()G_optimizer.step()G_losses.append(G_train_loss.item())print('[%d/%d]: loss_d: %.3f, loss_g: %.3f' % ((epoch + 1), train_epoch, torch.mean(torch.FloatTensor(D_losses)), torch.mean(torch.FloatTensor(G_losses))))if epoch %10 == 0:p = 'MNIST_GAN_results/Random_results/MNIST_GAN_' + str(epoch + 1) + '.png'fixed_p = 'MNIST_GAN_results/Fixed_results/MNIST_GAN_' + str(epoch + 1) + '.png'show_result((epoch+1), save=True, path=p, isFix=False)show_result((epoch+1), save=True, path=fixed_p, isFix=True)train_hist['D_losses'].append(torch.mean(torch.FloatTensor(D_losses)))train_hist['G_losses'].append(torch.mean(torch.FloatTensor(G_losses)))

采用交叉熵损失函数(BCE)计算Loss,即
在这里插入图片描述
其中判别器的loss计算如下:
在这里插入图片描述
生成器的loss计算如下:
在这里插入图片描述

  1. 保存模型及数据
    将生成器和判别器的模型参数进行保存,保存训练过程的loss数据。
print("Training finish!... save training results")
torch.save(G.state_dict(), "MNIST_GAN_results/generator_param.pkl")
torch.save(D.state_dict(), "MNIST_GAN_results/discriminator_param.pkl")
with open('MNIST_GAN_results/train_hist.pkl', 'wb') as f:pickle.dump(train_hist, f)
  1. 数据可视化
show_train_hist(train_hist, save=True, path='MNIST_GAN_results/MNIST_GAN_train_hist.png')images = []
for e in range(train_epoch):img_name = 'MNIST_GAN_results/Fixed_results/MNIST_GAN_' + str(e + 1) + '.png'images.append(imageio.imread(img_name))
imageio.mimsave('MNIST_GAN_results/generation_animation.gif', images, fps=5)
  1. 完整代码
import os
import matplotlib.pyplot as plt
import itertools
import pickle
import imageio
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torchvision import datasets, transforms
# from torch.autograd import Variable# G(z)
class generator(nn.Module):# initializersdef __init__(self, input_size=32, n_class = 10):super(generator, self).__init__()self.fc1 = nn.Linear(input_size, 256)self.fc2 = nn.Linear(self.fc1.out_features, 512)self.fc3 = nn.Linear(self.fc2.out_features, 1024)self.fc4 = nn.Linear(self.fc3.out_features, n_class)# forward methoddef forward(self, input):x = F.leaky_relu(self.fc1(input), 0.2)x = F.leaky_relu(self.fc2(x), 0.2)x = F.leaky_relu(self.fc3(x), 0.2)x = F.tanh(self.fc4(x))x = x.squeeze(-1)return xclass discriminator(nn.Module):# initializersdef __init__(self, input_size=32, n_class=10):super(discriminator, self).__init__()self.fc1 = nn.Linear(input_size, 1024)self.fc2 = nn.Linear(self.fc1.out_features, 512)self.fc3 = nn.Linear(self.fc2.out_features, 256)self.fc4 = nn.Linear(self.fc3.out_features, n_class)# forward methoddef forward(self, input):x = F.leaky_relu(self.fc1(input), 0.2)x = F.dropout(x, 0.3)x = F.leaky_relu(self.fc2(x), 0.2)x = F.dropout(x, 0.3)x = F.leaky_relu(self.fc3(x), 0.2)x = F.dropout(x, 0.3)x = F.sigmoid(self.fc4(x))x = x.squeeze(-1)return xfixed_z_ = torch.randn((5 * 5, 100))    # fixed noise
with torch.no_grad():fixed_z_ = fixed_z_.cuda()def show_result(num_epoch, show = False, save = False, path = 'result.png', isFix=False):z_ = torch.randn((5*5, 100))with torch.no_grad():z_ = z_.cuda()# z_ = Variable(z_.cuda(), volatile=True)G.eval()if isFix:test_images = G(fixed_z_)else:test_images = G(z_)G.train()size_figure_grid = 5fig, ax = plt.subplots(size_figure_grid, size_figure_grid, figsize=(5, 5))for i, j in itertools.product(range(size_figure_grid), range(size_figure_grid)):ax[i, j].get_xaxis().set_visible(False)ax[i, j].get_yaxis().set_visible(False)for k in range(5*5):i = k // 5j = k % 5ax[i, j].cla()ax[i, j].imshow(test_images[k, :].cpu().data.view(28, 28).numpy(), cmap='gray')label = 'Epoch {0}'.format(num_epoch)fig.text(0.5, 0.04, label, ha='center')plt.savefig(path)if show:plt.show()else:plt.close()def show_train_hist(hist, show = False, save = False, path = 'Train_hist.png'):x = range(len(hist['D_losses']))y1 = hist['D_losses']y2 = hist['G_losses']plt.plot(x, y1, label='D_loss')plt.plot(x, y2, label='G_loss')plt.xlabel('Epoch')plt.ylabel('Loss')plt.legend(loc=4)plt.grid(True)plt.tight_layout()if save:plt.savefig(path)if show:plt.show()else:plt.close()# training parameters
batch_size = 256
lr = 0.0002
train_epoch = 200
device = torch.cuda.is_available()
if device:print("running on GPU!")# data_loader
transform = transforms.Compose([transforms.ToTensor(),transforms.Normalize(mean=(0.5), std=(0.5))
])
train_loader = torch.utils.data.DataLoader(datasets.MNIST('data', train=True, download=True, transform=transform),batch_size=batch_size, shuffle=True)# network
G = generator(input_size=100, n_class=28*28)
D = discriminator(input_size=28*28, n_class=1)# Binary Cross Entropy loss
BCE_loss = nn.BCELoss()#move to cuda
if device:G.cuda()D.cuda()BCE_loss = BCE_loss.cuda()# Adam optimizer
G_optimizer = optim.Adam(G.parameters(), lr=lr)
D_optimizer = optim.Adam(D.parameters(), lr=lr)# results save folder
if not os.path.isdir('MNIST_GAN_results'):os.mkdir('MNIST_GAN_results')
if not os.path.isdir('MNIST_GAN_results/Random_results'):os.mkdir('MNIST_GAN_results/Random_results')
if not os.path.isdir('MNIST_GAN_results/Fixed_results'):os.mkdir('MNIST_GAN_results/Fixed_results')train_hist = {}
train_hist['D_losses'] = []
train_hist['G_losses'] = []for epoch in range(train_epoch):D_losses = []G_losses = []#生成任务,不需要标签for x_, _ in train_loader:#训练图像展平x_ = x_.view(-1, 28 * 28)mini_batch = x_.size()[0]y_real_ = torch.ones(mini_batch)y_fake_ = torch.zeros(mini_batch)# train discriminator DD.zero_grad()z_ = torch.randn((mini_batch, 100))if device:x_, y_real_, y_fake_ = x_.cuda(), y_real_.cuda(), y_fake_.cuda()z_ = z_.cuda()#真数据lossD_result = D(x_)D_real_loss = BCE_loss(D_result, y_real_)D_real_score = D_result#假数据lossG_result = G(z_)D_result = D(G_result)D_fake_loss = BCE_loss(D_result, y_fake_)D_fake_score = D_resultD_train_loss = D_real_loss + D_fake_lossD_train_loss.backward()D_optimizer.step()D_losses.append(D_train_loss.item())# train generator GG.zero_grad()# z_ = torch.randn((mini_batch, 100))# if device:#     z_ = z_.cuda()G_result = G(z_)D_result = D(G_result)G_train_loss = BCE_loss(D_result, y_real_)G_train_loss.backward()G_optimizer.step()G_losses.append(G_train_loss.item())print('[%d/%d]: loss_d: %.3f, loss_g: %.3f' % ((epoch + 1), train_epoch, torch.mean(torch.FloatTensor(D_losses)), torch.mean(torch.FloatTensor(G_losses))))if epoch %10 == 0:p = 'MNIST_GAN_results/Random_results/MNIST_GAN_' + str(epoch + 1) + '.png'fixed_p = 'MNIST_GAN_results/Fixed_results/MNIST_GAN_' + str(epoch + 1) + '.png'show_result((epoch+1), save=True, path=p, isFix=False)show_result((epoch+1), save=True, path=fixed_p, isFix=True)train_hist['D_losses'].append(torch.mean(torch.FloatTensor(D_losses)))train_hist['G_losses'].append(torch.mean(torch.FloatTensor(G_losses)))print("Training finish!... save training results")
torch.save(G.state_dict(), "MNIST_GAN_results/generator_param.pkl")
torch.save(D.state_dict(), "MNIST_GAN_results/discriminator_param.pkl")
with open('MNIST_GAN_results/train_hist.pkl', 'wb') as f:pickle.dump(train_hist, f)show_train_hist(train_hist, save=True, path='MNIST_GAN_results/MNIST_GAN_train_hist.png')images = []
for e in range(train_epoch):img_name = 'MNIST_GAN_results/Fixed_results/MNIST_GAN_' + str(e + 1) + '.png'images.append(imageio.imread(img_name))
imageio.mimsave('MNIST_GAN_results/generation_animation.gif', images, fps=5)

训练结果

在这里插入图片描述

以上是训练190个epoch后得到的结果,可以看到其中某些图片已经有了数字的模样。这里仅仅是使用了全连接层来搭建模型,如果使用卷积神经网络,效果会有更好的提升,大家可以尝试一下。

遇到的问题

可以适当地提高batch size来提高训练速度,也可以切换更简单的loss函数来提高训练速度。
建议batch size从底到高慢慢调节,若batch size过高,可能导致模型训练出现问题。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.rhkb.cn/news/5847.html

如若内容造成侵权/违法违规/事实不符,请联系长河编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

【若依】添加数据字典

接下来,在生成代码的页面将“学科”字段改为下拉框,然后选择数据字典 然后,将生成的代码中的index文件复制到vue3的index中,替换掉之前的index文件 修改数据库中的subject的值,这样就可以通过数据字典来查询 以上操作成…

ngrok同时配置多个内网穿透方法

一、概要 ngrok可以用来配置免费的内网穿透,启动后就可以用外网ip:端口访问到自己计算机的某个端口了。 可以用来从外网访问自己的测试页面(80、8080)、ftp文件传输(21)、远程桌面(3389)等。 …

MySQL可直接使用的查询表的列信息

文章目录 背景实现方案模板SQL如何查询列如何转大写如何获取字符位置如何拼接字段 SQL适用场景 背景 最近产品找来,想让帮忙出下表的信息,字段驼峰展示,每张表信息show create table全部展示,再逐个粘贴,有点太耗费时…

白玉微瑕:闲谈 SwiftUI 过渡(Transition)动画的“口是心非”(下)

概述 秃头小码农们都知道,SwiftUI 不仅仅是一个静态 UI 构建框架那么简单,辅以海量默认或自定义的动画和过渡(Transition)特效,我们可以将 App 界面的绚丽升华到极致。 不过,目前 SwiftUI 中的过渡&#x…

cookie 与 session -- 会话管理

目录 前言 -- HTTP的无状态 cookie 概念 工作原理 Cookie 分类 会话 Cookie -- 内存级存储 持久 Cookie -- 文件级存储 代码验证 cookie 用户 username 过期时间 expires 指定路径 path cookie 的不足 session 概念 工作原理 代码验证session THE END 前言 -…

微信小程序使用上拉加载onReachBottom。页面拖不动。一直无法触发上拉的事件。

1,可能是原因是你使用了scroll-view的标签,用onReachBottom触发加载事件。这两个是有冲突的。没办法一起使用。如果页面的样式是滚动的是无法去触发页面的onReachBottom的函数的。因此,你使用overflow:auto.来使用页面的某些元素滚动&#xf…

计算机网络——网络层

重点内容: (1) 虚拟互连网络的概念。 (2) IP 地址与物理地址的关系。 (3) 传统的分类的 IP 地址(包括子网掩码)和无分类域间路由选择 CIDR 。 (4) 路由选择协议的工作原理。 目录 重点内容: 一.网络层提供的两种服务 二…

【动态规划】落花人独立,微雨燕双飞 - 8. 01背包问题

本篇博客给大家带来的是01背包问题之动态规划解法技巧. 🐎文章专栏: 动态规划 🚀若有问题 评论区见 ❤ 欢迎大家点赞 评论 收藏 分享 如果你不知道分享给谁,那就分享给薯条. 你们的支持是我不断创作的动力 . 王子,公主请阅🚀 要开心要快乐顺便…

yolov11 pose 推理代码

目录 效果图: yolo_pose.py 效果图: yolo_pose.py import osimport cv2 from PIL import Imagefrom ultralytics import YOLO import json from pathlib import Path import tqdmclass YOLOPose:def __init__(self, detections_file):self.detections_file = detections_f…

MFC程序设计(二)基于对话框编程

从现在开始,我们将以基于对话框的MFC应用程序来讲解MFC应用 向导生成基于对话框MFC应用程序 对话框是一种特殊类型的窗口,绝大多数Windows程序都通过对话框与用户进行交互。在Visual C中,对话框既可以单独组成一个简单的应用程序&#xff0…

ubuntu20.04有亮度调节条但是调节时亮度不变

尝试了修改grub文件,没有作用,下载了brightness-controllor,问题解决了。 sudo add-apt-repository ppa:apandada1/brightness-controller sudo apt update sudo apt install brightness-controller 之后在应用软件中找到brightness-contro…

深入探讨视图更新:提升数据库灵活性的关键技术

title: 深入探讨视图更新:提升数据库灵活性的关键技术 date: 2025/1/21 updated: 2025/1/21 author: cmdragon excerpt: 在现代数据库的管理中,视图作为一种高级的抽象机制,为数据的管理提供了多种便利。它不仅简化了复杂查询的过程,还能用来增强数据的安全性,限制用户…

75,【7】BUUCTF WEB [Weblogic]SSRF(未作出)

看到这个更是降龙十八掌 给的源代码进不去 给的靶场打不开 未完待续

16_动态提示窗口_协程延时

创建动态提示窗口DynamicWnd.cs 编写代码 using UnityEngine; using UnityEngine.UI; //功能 : 动态窗口界面 public class DynamicWnd : WindowsRoot{public Animation tipsAni;public Text txtTips;protected override void InitWnd() {base.InitWnd();//在启动时先隐藏提示…

麒麟监控工具rpm下载

确认系统 cat /etc/.productinfo麒麟v10 sp1 sp2 sp3 rpm包下载链接 sar - sysstat mtr iostat - sysstat netstat - net-tools https://update.cs2c.com.cn/NS/V10/V10SP3-2403/os/adv/lic/base/x86_64/Packages/sysstat-12.2.1-7.p01.ky10.x86_64.rpm https://update.cs…

2024年智慧消防一体化安全管控年度回顾与2025年预测

随着科技的飞速发展,智慧营区一体化安全管控在2024年取得了显著进展,同时也为2025年的发展奠定了坚实基础。 2024年年度回顾 政策支持力度持续加大:国家对消防安全的重视程度不断提高,出台了一系列涵盖技术创新、市场应用、人才培…

C#深度神经网络(TensorFlow.NET)

C#深度神经网络 文章目录 C#深度神经网络前言专业术语讲解模型[Model]向量[Vector]矩阵[Matrix]张量[Tensor]批量大小(Batch Size)迭代次数(Epochs)交叉熵[Cross Entropy] 训练流程数据预处理数据打标签数据转换标准化/归一化选择…

java 根据前端传回的png图片数组,后端加水印加密码生成pdf,返回给前端

前端传回的png图片数组,后端加水印加密码生成pdf,返回给前端 场景:重点:maven依赖controllerservice 场景: 当前需求,前端通过html2canvas将页面报表生成图片下载,可以仍然不满意。 需要java后…

Linux(LAMP)

赛题拓扑: 题目: 安装WEB服务。 服务以用户webuser系统用户运行。 限制WEB服务只能使用系统500M物理内存。 全站点启用TLS访问,使用本机上的“CSK Global Root CA”颁发机构颁发,网站证书信息如下: C CN ST China…

vue3+elementPlus之后台管理系统(从0到1)(day3-管理员管理)

管理员管理 搭建管理员页面 在views中创建一个manager文件夹&#xff0c;并创建ManagerIndexView.vue、MangagerListView.vue、UserList.vue <!-- src/views/manager/ManagerIndexView.vue --> <template><!-- 作为一个占位符&#xff0c;用于渲染与当前 URL…