PyTorch训练简单的生成对抗网络GAN

文章目录

    • 原理
    • 代码
    • 结果
    • 参考

原理

同时训练两个网络:辨别器Discriminator 和 生成器Generator
Generator是 造假者,用来生成假数据。
Discriminator 是警察,尽可能的分辨出来哪些是造假的,哪些是真实的数据。

目的:使得判别模型尽量犯错,无法判断数据是来自真实数据还是生成出来的数据。

GAN的梯度下降训练过程:

在这里插入图片描述
上图来源:https://arxiv.org/abs/1406.2661

Train 辨别器: m a x max max l o g ( D ( x ) ) + l o g ( 1 − D ( G ( z ) ) ) log(D(x)) + log(1 - D(G(z))) log(D(x))+log(1D(G(z)))

Train 生成器: m i n min min l o g ( 1 − D ( G ( z ) ) ) log(1-D(G(z))) log(1D(G(z)))

我们可以使用BCEloss来计算上述两个损失函数

BCEloss的表达式: m i n − [ y l n x + ( 1 − y ) l n ( 1 − x ) ] min -[ylnx + (1-y)ln(1-x)] min[ylnx+(1y)ln(1x)]
具体过程参加代码中注释

代码

import torch
import torch.nn as nn
import torch.optim as optim
import torchvision
import torchvision.datasets as datasets
from torch.utils.data import DataLoader
import torchvision.transforms as transforms
from torch.utils.tensorboard import SummaryWriter  # to print to tensorboardclass Discriminator(nn.Module):def __init__(self, img_dim):super(Discriminator, self).__init__()self.disc = nn.Sequential(nn.Linear(img_dim, 128),nn.LeakyReLU(0.1),nn.Linear(128, 1),nn.Sigmoid(),)def forward(self, x):return self.disc(x)class Generator(nn.Module):def __init__(self, z_dim, img_dim): # z_dim 噪声的维度super(Generator, self).__init__()self.gen = nn.Sequential(nn.Linear(z_dim, 256),nn.LeakyReLU(0.1),nn.Linear(256, img_dim), # 28x28 -> 784nn.Tanh(),)def forward(self, x):return self.gen(x)# Hyperparameters
device = 'cuda' if torch.cuda.is_available() else 'cpu'
lr = 3e-4 # 3e-4是Adam最好的学习率
z_dim = 64 # 噪声维度
img_dim = 784 # 28x28x1
batch_size = 32
num_epochs = 50disc = Discriminator(img_dim).to(device)
gen = Generator(z_dim, img_dim).to(device)fixed_noise = torch.randn((batch_size, z_dim)).to(device)
transforms = transforms.Compose( # MNIST标准化系数:(0.1307,), (0.3081,)[transforms.ToTensor(), transforms.Normalize((0.1307, ), (0.3081,))] # 不同数据集就有不同的标准化系数
)dataset = datasets.MNIST(root='dataset/', transform=transforms, download=True)
loader = DataLoader(dataset, batch_size=batch_size, shuffle=True)opt_disc = optim.Adam(disc.parameters(), lr=lr)
opt_gen = optim.Adam(gen.parameters(), lr=lr)
# BCE 损失
criterion = nn.BCELoss()# 打开tensorboard:在该目录下,使用 tensorboard --logdir=runs
writer_fake = SummaryWriter(f"runs/GAN_MNIST/fake")
writer_real = SummaryWriter(f"runs/GAN_MNIST/real")
step = 0for epoch in range(num_epochs):for batch_idx, (real, _) in enumerate(loader):real = real.view(-1, 784).to(device) # view相当于reshapebatch_size = real.shape[0]### Train Discriminator: max log(D(real)) + log(1 - D(G(z)))noise = torch.randn(batch_size, z_dim).to(device)fake = gen(noise) # G(z)disc_real = disc(real).view(-1) # flatten# BCEloss的表达式:min -[ylnx + (1-y)ln(1-x)]# max log(D(real)) 相当于 min -log(D(real))# ones_like:1填充得到y=1, 即可忽略  min -[ylnx + (1-y)ln(1-x)]中的后一项# 得到 min -lnx,这里的x就是我们的real图片lossD_real = criterion(disc_real, torch.ones_like(disc_real))disc_fake = disc(fake).view(-1)# max log(1 - D(G(z))) 相当于 min -log(1 - D(G(z)))# zeros_like用0填充,得到y=0,即可忽略  min -[ylnx + (1-y)ln(1-x)]中的前一项# 得到 min -ln(1-x),这里的x就是我们的fake噪声lossD_fake = criterion(disc_fake, torch.zeros_like(disc_fake))lossD = (lossD_real + lossD_fake) / 2disc.zero_grad()lossD.backward(retain_graph=True)opt_disc.step()### Train Generator: min log(1-D(G(z))) <--> max log(D(G(z))) <--> min - log(D(G(z)))# 依然可使用BCEloss来做output = disc(fake).view(-1)lossG = criterion(output, torch.ones_like(output))gen.zero_grad()lossG.backward()opt_gen.step()if batch_idx == 0:print(f"Epoch [{epoch}/{num_epochs}] \ "f"Loss D: {lossD:.4f}, Loss G: {lossG:.4f}")with torch.no_grad():fake = gen(fixed_noise).reshape(-1, 1, 28, 28)data = real.reshape(-1, 1, 28, 28)img_grid_fake = torchvision.utils.make_grid(fake, normalize=True)img_grid_real = torchvision.utils.make_grid(data, normalize=True)writer_fake.add_image("Mnist Fake Images", img_grid_fake, global_step=step)writer_real.add_image("Mnist Real Images", img_grid_real, global_step=step)step += 1

结果

训练50轮的的损失

Epoch [0/50] \ Loss D: 0.7366, Loss G: 0.7051
Epoch [1/50] \ Loss D: 0.2483, Loss G: 1.6877
Epoch [2/50] \ Loss D: 0.1049, Loss G: 2.4980
Epoch [3/50] \ Loss D: 0.1159, Loss G: 3.4923
Epoch [4/50] \ Loss D: 0.0400, Loss G: 3.8776
Epoch [5/50] \ Loss D: 0.0450, Loss G: 4.1703
...
Epoch [43/50] \ Loss D: 0.0022, Loss G: 7.7446
Epoch [44/50] \ Loss D: 0.0007, Loss G: 9.1281
Epoch [45/50] \ Loss D: 0.0138, Loss G: 6.2177
Epoch [46/50] \ Loss D: 0.0008, Loss G: 9.1188
Epoch [47/50] \ Loss D: 0.0025, Loss G: 8.9419
Epoch [48/50] \ Loss D: 0.0010, Loss G: 8.3315
Epoch [49/50] \ Loss D: 0.0007, Loss G: 7.8302

使用

tensorboard --logdir=runs

打开tensorboard:

在这里插入图片描述
可以看到效果并不好,这是由于我们只是采用了简单的线性网络来做辨别器和生成器。后面的博文我们会使用更复杂的网络来训练GAN。

参考

[1] Building our first simple GAN
[2] https://arxiv.org/abs/1406.2661

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.rhkb.cn/news/103892.html

如若内容造成侵权/违法违规/事实不符,请联系长河编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

小区新冠疫情管理系统的设计与实现/基于springboot的小区疫情管理系统

摘要 采用更加便于维护和使用的Java语言&#xff0c;其可拓展性高且更富于表现力&#xff0c;基于mysql数据库、Springboot框架开发的小区新冠疫情管理系统&#xff0c;方便用户查看物资信息、疫苗信息。通过Eclipse来进行网页编程&#xff0c;其方便易用、移植适用性广、更加安…

日志搞不定?手把手教你如何使用Log4j2

系列文章目录 从零开始&#xff0c;手把手教你搭建Spring Boot后台工程并说明 Spring框架与SpringBoot的关联与区别 SpringBean生成流程详解 —— 由浅入深(附超精细流程图) Spring监听器用法与原理详解 Spring事务畅谈 —— 由浅入深彻底弄懂 Transactional注解 面试热点详解…

装备制造企业如何执行精益管理?

导 读 ( 文/ 2358 ) 精益管理是一种以提高效率、降低成本和优化流程为目标的管理方法。装备制造行业具备人工参与度高&#xff0c;产成品价值高&#xff0c;质量要求高的特点。 在装备制造企业中实施精益管理可以帮助企业提高竞争力、提升生产效率并提供高质量的产品。本文将…

边缘计算节点BEC典型实践:如何快速上手PC-Farm服务器?

百度智能云边缘计算节点BEC&#xff08;Baidu Edge Computing&#xff09;基于运营商边缘节点和网络构建&#xff0c;一站式提供靠近终端用户的弹性计算资源。边缘计算节点在海外覆盖五大洲&#xff0c;在国内覆盖全国七大区、三大运营商。BEC通过就近计算和处理&#xff0c;大…

【算法日志】贪心算法刷题:单调递增数列,贪心算法总结(day32)

代码随想录刷题60Day 目录 前言 单调递增数列 贪心算法总结 前言 今天是贪心算法刷题的最后一天&#xff0c;今天本来是打算刷两道题&#xff0c;其中的一道hard题做了好久都没有做出来(主要思路错了)。然后再总结一下。 单调递增数列 int monotoneIncreasingDigits(int n…

Wlan——STA上线流程与802.11MAC帧讲解以及报文转发路径

目录 802.11MAC帧基本概念 802.11帧结构 802.11MAC帧的分类 管理帧 控制帧 数据帧 STA接入无线网络流程 信号扫描—管理帧 链路认证—管理帧 用户关联—管理帧 用户上线 不同802.11帧的转发路径 802.11MAC帧基本概念 802.11协议在802家族中的角色位置 其中802.3标…

Linux内核学习(八)—— 内存管理(基于Linux 2.6内核)

目录 一、页&#xff08;page&#xff09; 二、区&#xff08;zone&#xff09; 三、页操作 四、kmalloc() 五、vmalloc() 六、slab 分配器 七、在栈上的静态分配 一、页&#xff08;page&#xff09; 内核把物理页作为内存管理的基本单位。尽管处理器的最小可寻 …

网工内推 | 锐捷招云工程师,HCIE、CCIE、RHCE优先,25k*13薪

01 锐捷网络 招聘岗位&#xff1a;云方案工程师 职责描述&#xff1a; 1、负责云数据中心方案项目方案设计撰写、项目实施交付、故障处理、业务割接、客户培训、现场保障、网络优化、网络巡检等技术相关业务 2、负责云数据中心方案新技术文档沉淀、体系建设、工具开发等标准化…

使用element-plus组件,默认显示英文 转换为中文

最近在边写项目边学习vue3 所以这几天没有更新 找机会把vue3的知识也统计一下吧 先说今天遇到的问题 最近做项目的时候使用element-plus分页组件时发现&#xff0c;显示的不是中文的了&#xff0c;是英文的 解决方法 在app.vue里面配置 <template><el-config-provi…

Pyqt5打开电脑摄像头进行拍照

目录 1、设计UI界面 2、设计逻辑代码&#xff0c;建立连接显示窗口 3、结果 1、设计UI界面 将ui界面转为py文件后获得的逻辑代码为&#xff1a;&#xff08;文件名为 Camera.py&#xff09; # -*- coding: utf-8 -*-# Form implementation generated from reading ui file …

【算法专题突破】双指针 - 移动零(1)

目录 写在前面 1. 题目解析 2. 算法原理 3. 代码编写 写在最后&#xff1a; 写在前面 在进行了剑指Offer和LeetCode hot100的毒打之后&#xff0c; 我决心系统地学习一些经典算法&#xff0c;增强我的综合算法能力。 1. 题目解析 题目链接&#xff1a;283. 移动零 - 力…

centos7.9和redhat6.9 离线升级OpenSSH和openssl (2023年的版本)

升级注意事项&#xff01; 1、多开几个连接窗口&#xff08;xshell&#xff09;&#xff0c;避免升级openssh失败无法再次连接终端&#xff0c;否则要跑机房了。 2、可开启telnet服务、vnc服务、打快照。多几个“保命”的路数。一、centos7.9的信息 [rootnode2 ~]# openssl v…

学会Mybatis框架:一文掌握MyBatis与GitHub插件分页的完美结合【三.分页】

&#x1f973;&#x1f973;Welcome Huihuis Code World ! !&#x1f973;&#x1f973; 接下来看看由辉辉所写的关于Mybatis的相关操作吧 目录 &#x1f973;&#x1f973;Welcome Huihuis Code World ! !&#x1f973;&#x1f973; 一.Mybatis分页 1. Mybatis自带分页 2…

AIGC如何借AI Agent落地?TARS-RPA-Agent破解RPA与LLM融合难题

文/王吉伟 大语言模型&#xff08;LLM&#xff0c;Large Language Model&#xff09;的持续爆发&#xff0c;让AIGC一直处于这股AI风暴最中央&#xff0c;不停席卷各个领域。 在国内&#xff0c;仍在雨后春笋般上新的大语言模型&#xff0c;在持续累加“千模大战”大模型数量的…

Linux之基础IO文件系统讲解

基础IO文件系统讲解 回顾C语言读写文件读文件操作写文件操作输出信息到显示器的方法stdin & stdout & stderr总结 系统文件IOIO接口介绍文件描述符fd文件描述符的分配规则C标准库文件操作函数简易模拟实现重定向dup2 系统调用在minishell中添加重定向功能 FILE文件系统…

什么是Pytorch?

当谈及深度学习框架时&#xff0c;PyTorch 是当今备受欢迎的选择之一。作为一个开源的机器学习库&#xff0c;PyTorch 为研究人员和开发者们提供了一个强大的工具来构建、训练以及部署各种深度学习模型。你可能会问&#xff0c;PyTorch 是什么&#xff0c;它有什么特点&#xf…

为什么需要单元测试?

为什么需要单元测试&#xff1f; 从产品角度而言&#xff0c;常规的功能测试、系统测试都是站在产品局部或全局功能进行测试&#xff0c;能够很好地与用户的需要相结合&#xff0c;但是缺乏了对产品研发细节&#xff08;特别是代码细节的理解&#xff09;。 从测试人员角度而言…

Mysql简短又易懂

MySql 连接池:的两个参数 最大连接数&#xff1a;可以同时发起的最大连接数 单次最大数据报文&#xff1a;接受数据报文的最大长度 数据库如何存储数据 存储引擎&#xff1a; InnoDB:通过执行器对内存和磁盘的数据进行写入和读出 优化SQL语句innoDB会把需要写入或者更新的数…

Java如何调用接口API并返回数据(两种方法)

Java如何调用接口API并返回数据&#xff08;两种方法&#xff09; java处理请求接口后返回的json数据-直接处理json字符串 处理思路&#xff1a; 将返回的数据接收到一个String对象中&#xff08;有时候需要自己选择性的取舍接收&#xff09; 再将string转换为JSONObject对象 …

Go 语言在 Windows 上的安装及配置

1. Go语言的下载 Golang官网&#xff1a;All releases - The Go Programming Language Golang中文网&#xff1a;Go下载 - Go语言中文网 - Golang中文社区 两个网站打开的内容只有语言不同而已&#xff0c;网站上清晰的标注了不同操作系统需要对应安装哪个版本&#xff0c;其中…