G1 GAN生成MNIST手写数字图像

  • 🍨 本文为🔗365天深度学习训练营 中的学习记录博客
  • 🍖 原作者:K同学啊

G1 GAN生成MNIST手写数字图像

1. 生成对抗网络 (GAN) 简介

生成对抗网络 (GAN) 是一种通过“对抗性”学习生成数据的深度学习模型,通常用于生成图像、视频等数据。GAN 由两个网络组成:

  • 生成器 (Generator):用于生成假的数据样本,试图让判别器无法分辨其为假的。
  • 判别器 (Discriminator):用于区分输入的数据是真实的还是生成器生成的。

GAN 的核心思想是,生成器和判别器通过相互对抗学习,生成器逐渐提高生成逼真数据的能力,而判别器逐渐提高区分真假数据的能力。最后,生成器生成的样本与真实样本之间的差异会越来越小。

GAN 的基本流程

  1. 判别器输入真实数据,判别器输出一个接近1的值,表示为真;
  2. 生成器生成假的数据,并试图欺骗判别器;
  3. 判别器输出接近0的值,表示为假;
  4. 生成器通过更新自身的参数,试图让判别器认为生成的数据是真实的。

GAN 的目标是使得生成器生成的假数据,能骗过判别器。

GAN 的损失函数

GAN 的训练目标是让生成器和判别器进行对抗训练,其损失函数分为两个部分:生成器损失和判别器损失。生成器的目标是最大化判别器判断生成数据为真的概率,判别器的目标是最大化正确判断真实数据和生成数据的概率。

判别器的损失函数定义为:

L D = − [ E x ∼ p data [ log ⁡ D ( x ) ] + E z ∼ p z [ log ⁡ ( 1 − D ( G ( z ) ) ) ] ] \mathcal{L}_D = - \left[ \mathbb{E}_{x \sim p_{\text{data}}} \left[ \log D(x) \right] + \mathbb{E}_{z \sim p_z} \left[ \log (1 - D(G(z))) \right] \right] LD=[Expdata[logD(x)]+Ezpz[log(1D(G(z)))]]

生成器的损失函数定义为:

L G = − E z ∼ p z [ log ⁡ D ( G ( z ) ) ] \mathcal{L}_G = - \mathbb{E}_{z \sim p_z} \left[ \log D(G(z)) \right] LG=Ezpz[logD(G(z))]

其中:

  • ( D(x) ) 表示判别器对真实数据 ( x ) 判别为真的概率;
  • ( G(z) ) 是生成器通过噪声 ( z ) 生成的假数据;
  • ( D(G(z)) ) 表示判别器对生成器生成数据的输出(希望趋向于1)。

2. PyTorch 实现

下面使用 PyTorch 实现 GAN 生成 MNIST 手写数字图像。

2.1 导入库与超参数设置

import os
import numpy as np
import torch
import torch.nn as nn
from torchvision import datasets, transforms
from torch.utils.data import DataLoader
from torchvision.utils import save_image# 创建文件夹
os.makedirs('./output/images/', exist_ok=True)# 超参数设置
n_epochs = 50
batch_size = 64
lr = 0.0002
latent_dim = 100
img_size = 28
channels = 1
img_shape = (channels, img_size, img_size)
img_area = np.prod(img_shape)cuda = True if torch.cuda.is_available() else False

2.2 数据预处理

使用 torchvision.datasets.MNIST 下载并处理 MNIST 数据集。数据会被标准化到 [-1, 1] 区间,并通过 DataLoader 转化为可迭代数据集。

# 下载MNIST数据集并进行预处理
mnist = datasets.MNIST(root='./data', train=True, download=True,transform=transforms.Compose([transforms.Resize(img_size),transforms.ToTensor(),transforms.Normalize([0.5], [0.5])]))dataloader = DataLoader(mnist, batch_size=batch_size, shuffle=True)

2.3 定义生成器模型

生成器接受一个随机噪声向量 ( z ),通过多层线性变换和激活函数逐步生成一个 28x28 的图像。

class Generator(nn.Module):def __init__(self):super(Generator, self).__init__()def block(in_feat, out_feat, normalize=True):layers = [nn.Linear(in_feat, out_feat)]if normalize:layers.append(nn.BatchNorm1d(out_feat, 0.8))layers.append(nn.LeakyReLU(0.2, inplace=True))return layersself.model = nn.Sequential(*block(latent_dim, 128, normalize=False),*block(128, 256),*block(256, 512),*block(512, 1024),nn.Linear(1024, img_area),nn.Tanh())def forward(self, z):img = self.model(z)return img.view(img.size(0), *img_shape)

2.4 定义判别器模型

判别器是一个二分类网络,输入一个 28x28 的图像,输出一个表示真假概率的值。

class Discriminator(nn.Module):def __init__(self):super(Discriminator, self).__init__()self.model = nn.Sequential(nn.Linear(img_area, 512),nn.LeakyReLU(0.2, inplace=True),nn.Linear(512, 256),nn.LeakyReLU(0.2, inplace=True),nn.Linear(256, 1),nn.Sigmoid())def forward(self, img):img_flat = img.view(img.size(0), -1)validity = self.model(img_flat)return validity

2.5 定义优化器与损失函数

generator = Generator()
discriminator = Discriminator()# 定义损失函数
criterion = nn.BCELoss()# 定义生成器和判别器的优化器
optimizer_G = torch.optim.Adam(generator.parameters(), lr=lr, betas=(0.5, 0.999))
optimizer_D = torch.optim.Adam(discriminator.parameters(), lr=lr, betas=(0.5, 0.999))if cuda:generator.cuda()discriminator.cuda()criterion.cuda()

2.6 训练过程

2.6.1 训练判别器

判别器需要区分真实图像和生成的假图像,通过两个损失值相加,更新判别器的参数。

real_img = Variable(imgs.type(torch.cuda.FloatTensor))
real_label = Variable(torch.ones(imgs.size(0), 1).cuda())
fake_label = Variable(torch.zeros(imgs.size(0), 1).cuda())real_out = discriminator(real_img)
loss_real = criterion(real_out, real_label)z = Variable(torch.randn(imgs.size(0), latent_dim).cuda())
fake_img = generator(z).detach()
fake_out = discriminator(fake_img)
loss_fake = criterion(fake_out, fake_label)loss_D = loss_real + loss_fake
optimizer_D.zero_grad()
loss_D.backward()
optimizer_D.step()
2.6.2 训练生成器

生成器的目标是让判别器认为生成的数据是真实的,因此生成器的损失是判别器对假图像的输出。

z = Variable(torch.randn(imgs.size(0), latent_dim).cuda())
fake_img = generator(z)
output = discriminator(fake_img)loss_G = criterion(output, real_label)
optimizer_G.zero_grad()
loss_G.backward()
optimizer_G.step()

在这里插入图片描述

2.7 保存与可视化生成图像

if batches_done % sample_interval == 0:save_image(fake_img.data[:25], "./output/images/%d.png" % batches_done, nrow=5, normalize=True)

在这里插入图片描述

4. 总结

这周学习了如何使用 PyTorch 实现生成对抗网络 (GAN) 来生成 MNIST 手写数字图像。GAN 通过生成器与判别器之间的对抗学习,不断提升生成图像的质量,是一种非常强大的生成模型。可以在论文中将其作为数据增强的一种方式。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.rhkb.cn/news/451941.html

如若内容造成侵权/违法违规/事实不符,请联系长河编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

SQL Injection | SQL 注入概述

关注这个漏洞的其他相关笔记:SQL 注入漏洞 - 学习手册-CSDN博客 0x01:SQL 注入漏洞介绍 SQL 注入就是指 Web 应用程序对用户输入数据的合法性没有判断,前端传入后端的参数是可控的,并且参数会带入到数据库中执行,导致…

CCS字体、字号更改+CCS下载官方链接

Step1、 按照图示箭头操作 step2 Step3 点击确定,点击Apply(应用),点击Apply and close(应用和关闭) 4、历代版本下载链接 CCS下载:官方链接https://www.ti.com/tool/CCSTUDIO The last but not least 如果成功的解决了你的问题&#x…

MEMC功能详解

文章目录 MEMC的工作原理:优点:缺点:适用场景:1. Deblur(去模糊)2. Dejudder(去抖动)总结两者区别: MEMC(Motion Estimation and Motion Compensation&#x…

【开源免费】基于SpringBoot+Vue.JS房屋租赁系统(JAVA毕业设计)

本文项目编号 T 020 ,文末自助获取源码 \color{red}{T020,文末自助获取源码} T020,文末自助获取源码 目录 一、系统介绍二、演示录屏三、启动教程四、功能截图五、文案资料5.1 选题背景5.2 国内外研究现状5.3 可行性分析 六、核心代码6.1 查…

编码方式知识整理【ASCII、Unicode和UTF-8】

编码方式 一、ASCII编码二、Unicode 编码三、UTF-8编码四、GB2312编码五、GBK编码 计算机中对数据的存储为二进制形式,但采用什么样的编码方式存储,效率更高。主要编码方式有 ASCII、Unicode、UTF-8等。 英文一般为1个字节,汉字一般为3个字节…

代码复现(五):GCPANet

文章目录 net.py1.class Bottleneck:残差块2.class ResNet:特征提取3.class SRM:SR模块4.class FAM:FIA模块5.class CA:GCF模块6.class SA:HA模块7.class GCPANet:网络架构 train.pytest.py 论文…

【数学二】一元函数积分学-定积分的应用-平面图形面积、旋转体体积、函数的平均值、平面曲线的弧长、旋转曲面面积

考试要求 1、理解原函数的概念,理解不定积分和定积分的概念. 2、掌握不定积分的基本公式,掌握不定积分和定积分的性质及定积分中值定理,掌握换元积分法与分部积分法. 3、会求有理函数、三角函数有理式和简单无理函数的积分. 4、理解积分上限…

进程与线程的区别

1.进程的简单了解 进程是计算机中程序在某个数据集合上的一次运行活动,是操作系统进行资源分配和调度的基本单位。 从不同角度来看: ● 资源分配角度:进程拥有独立的内存地址空间、系统资源(如 CPU 时间、文件描述符等&#xf…

【OD】【E卷】【真题】【100分】光伏场地建设规划(PythonJavajavaScriptC++C)

题目描述 祖国西北部有一片大片荒地,其中零星的分布着一些湖泊,保护区,矿区; 整体上常年光照良好,但是也有一些地区光照不太好。 某电力公司希望在这里建设多个光伏电站,生产清洁能源对每平方公里的土地进行了发电评…

关于测试翻译准确率的相关方法

本文提到的翻译准确率测试指标是BLEU,以及使用Python库-fuzzywuzzy来计算相似度 一、基于BLEU值评估 1.只评估一段话,代码如下 from nltk.translate.bleu_score import sentence_bleu, SmoothingFunction# 机器翻译结果 machine_translation "Ho…

【动手学深度学习】8.3 语言模型(个人向笔记)

下面是语言模型的简介 1. 学习语言模型 使用计数来建模 N元语法:这里的元可以理解为我们之前的时间变量。对于 N 元语法,我们可以把所有长度为 N 的子序列存下来。其中 1 元语法用的很少。这里其实就是算概率的时候我们不往前看所有的概率,…

ACL和NAT

一、ACL 1.概述 访问控制列表Access Control List是由一系列permit或deny语句组成的、有序规则的列表是一个匹配工具,对报文进行匹配和区分 2.ACL应用 匹配流量在traffic-filter中被调用在NAT(Natwork Address Translation)中被调用在路由策略中被调用在防火墙的…

Linux系统:本机(物理主机)访问不了虚拟机中的apache服务问题的解决方案

学习目标: 提示:本文主要讲述-本机(物理主机)访问不了虚拟机中的apache服务情况下的解决方案 Linux系统:Ubuntu 23.04; 文中提到的“本机”:代表,宿主机,物理主机; 首先&#xff0c…

OpenCV高级图形用户界面(14)交互式地选择一个或多个感兴趣区域函数selectROIs()的使用

操作系统:ubuntu22.04 OpenCV版本:OpenCV4.9 IDE:Visual Studio Code 编程语言:C11 算法描述 允许用户在给定的图像上选择多个 ROI。 该函数创建一个窗口,并允许用户使用鼠标来选择多个 ROI。控制方式:使用空格键或…

【Unity新闻】Unity 6 正式版发布

Unity CEO Matt Bromberg 在今天自豪地宣布,Unity 6 正式发布!作为迄今为止最强大和稳定的版本,Unity 6 为游戏和应用开发者提供了大量的新功能和工具,帮助他们加速开发并提升性能。 本次正式版是6.0000.0.23f1(LTS&a…

Django学习(三)

Django的设计模式及模板层 传统的MVC(例如java) Django的MTV 模板层: 模板加载: 代码: views.py def test_html(request):#方案一# from django.template import loader# 1. 使用loader加载模板# t loader.get_…

WIFI实现透传+接线图

单片机通过TX接WIFI模块的RX将设置的AT代码写入WIFI模块(连接WIFI调为设备模式(有设备,路由,双模等模式)) WIFI模块将响应信号通过TX通过CH340发给PC的RX 通过STC-ISP或安信可串口调试助手查看响应信息 …

Parallels Desktop20最新版本虚拟机 让双系统无缝切换成为现实!

Parallels Desktop 20最新版本虚拟机:让双系统无缝切换成为现实! 嘿,各位小伙伴们~🎉 如果你是像我一样,既爱 Windows 又放不下 macOS 的纠结星人,那今天这篇分享你可要仔细看啰!&am…

Linux学习笔记9 文件系统的基础

一、查看文件组织结构 Linux中一切都是文件。 Linux和Win的文件系统不是一个结构,Linux存在的根目录是所有目录的起点。 所有的存储空间和设备共享一个根目录,不同的磁盘块和分区挂载在其下,成为某个子目录的子目录,甚至设备也挂…

Windows系统部署redis自启动服务【亲测可用】

文章目录 引言I redis以本地服务运行(Windows service)使用MSI安装包配置文件,配置端口和密码II redis服务以终端命令启动缺点运行redis-server并指定端口和密码III 知识扩展确认redis-server可用性Installing the Service引言 服务器是Windows系统,所以使用Windows不是re…