学习笔记:Pytorch利用MNIST数据集训练生成对抗网络(GAN)

2023.8.27

       在进行深度学习的进阶的时候,我发了生成对抗网络是一个很神奇的东西,为什么它可以“将一堆随机噪声经过生成器变成一张图片”,特此记录一下学习心得。

一、生成对抗网络百科

        2014年,还在蒙特利尔读博士的Ian Goodfellow发表了论 文《Generative Adversarial Networks》(网址: https://arxiv.org/abs/1406.2661),将生成对抗网络引入 深度学习领域。2016年,GAN热潮席卷AI领域顶级会议, 从ICLR到NIPS,大量高质量论文被发表和探讨。Yann LeCun曾评价GAN是“20年来机器学习领域最酷的想法”。

机器学习的模型可大体分为两类,生成模型( Generative Model)和判别模型(Discriminative Model)。判别模型需要输入变量 ,通过某种模型来 预测 。生成模型是给定某种隐含信息,来随机产生观 测数据。

GAN百科:

GAN(生成对抗网络)的系统全面介绍(醍醐灌顶)_打灰人的博客-CSDN博客

二、GAN代码

训练代码:

                epoch=1000时的效果就不错啦

import torch
import torch.nn as nn
import torch.optim as optim
from torchvision import datasets, transforms
from torch.utils.data import DataLoader
import numpy as np
import matplotlib.pyplot as pltclass Generator(nn.Module):  # 生成器def __init__(self, latent_dim):super(Generator, self).__init__()self.model = nn.Sequential(nn.Linear(latent_dim, 256),nn.LeakyReLU(0.2),nn.Linear(256, 512),nn.LeakyReLU(0.2),nn.Linear(512, 1024),nn.LeakyReLU(0.2),nn.Linear(1024, 784),nn.Tanh())def forward(self, z):img = self.model(z)img = img.view(img.size(0), 1, 28, 28)return imgclass Discriminator(nn.Module):  # 判别器def __init__(self):super(Discriminator, self).__init__()self.model = nn.Sequential(nn.Linear(784, 512),nn.LeakyReLU(0.2),nn.Linear(512, 256),nn.LeakyReLU(0.2),nn.Linear(256, 1),nn.Sigmoid())def forward(self, img):img = img.view(img.size(0), -1)validity = self.model(img)return validitydef gen_img_plot(model, test_input):pred = np.squeeze(model(test_input).detach().cpu().numpy())fig = plt.figure(figsize=(4, 4))for i in range(16):plt.subplot(4, 4, i + 1)plt.imshow((pred[i] + 1) / 2)plt.axis('off')plt.show(block=False)plt.pause(3)  # 停留0.5splt.close()# 调用GPU
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")# 超参数设置
lr = 0.0001
batch_size = 128
latent_dim = 100
epochs = 1000# 数据集载入和数据变换
transform = transforms.Compose([transforms.ToTensor(),transforms.Normalize((0.5,), (0.5,))
])# 训练数据
train_dataset = datasets.MNIST(root='./data', train=True, transform=transform, download=False)
train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)# 测试数据 torch.randn()函数的作用是生成一组均值为0,方差为1(即标准正态分布)的随机数
# test_data = torch.randn(batch_size, latent_dim).to(device)
test_data = torch.FloatTensor(batch_size, latent_dim).to(device)# 实例化生成器和判别器,并定义损失函数和优化器
generator = Generator(latent_dim).to(device)
discriminator = Discriminator().to(device)
adversarial_loss = nn.BCELoss()
optimizer_G = optim.Adam(generator.parameters(), lr=lr)
optimizer_D = optim.Adam(discriminator.parameters(), lr=lr)# 开始训练模型
for epoch in range(epochs):for i, (imgs, _) in enumerate(train_loader):batch_size = imgs.shape[0]real_imgs = imgs.to(device)# 训练判别器z = torch.FloatTensor(batch_size, latent_dim).to(device)z.data.normal_(0, 1)fake_imgs = generator(z)  # 生成器生成假的图片real_labels = torch.full((batch_size, 1), 1.0).to(device)fake_labels = torch.full((batch_size, 1), 0.0).to(device)real_loss = adversarial_loss(discriminator(real_imgs), real_labels)fake_loss = adversarial_loss(discriminator(fake_imgs.detach()), fake_labels)d_loss = (real_loss + fake_loss) / 2optimizer_D.zero_grad()d_loss.backward()optimizer_D.step()# 训练生成器z.data.normal_(0, 1)fake_imgs = generator(z)g_loss = adversarial_loss(discriminator(fake_imgs), real_labels)optimizer_G.zero_grad()g_loss.backward()optimizer_G.step()torch.save(generator.state_dict(), "Generator_mnist.pth")print(f"Epoch [{epoch}/{epochs}] Loss_D: {d_loss.item():.4f} Loss_G: {g_loss.item():.4f}")# gen_img_plot(Generator, test_data)
gen_img_plot(generator, test_data)

测试代码:

import torch
import torch.nn as nn
import matplotlib.pyplot as plt
import numpy as np
import randomdevice = torch.device("cuda:0" if torch.cuda.is_available() else 'cpu')class Generator(nn.Module):  # 生成器def __init__(self, latent_dim):super(Generator, self).__init__()self.model = nn.Sequential(nn.Linear(latent_dim, 256),nn.LeakyReLU(0.2),nn.Linear(256, 512),nn.LeakyReLU(0.2),nn.Linear(512, 1024),nn.LeakyReLU(0.2),nn.Linear(1024, 784),nn.Tanh())def forward(self, z):img = self.model(z)img = img.view(img.size(0), 1, 28, 28)return img# test_data = torch.FloatTensor(128, 100).to(device)
test_data = torch.randn(128, 100).to(device)  # 随机噪声model = Generator(100).to(device)
model.load_state_dict(torch.load('Generator_mnist.pth'))
model.eval()pred = np.squeeze(model(test_data).detach().cpu().numpy())for i in range(64):plt.subplot(8, 8, i + 1)plt.imshow((pred[i] + 1) / 2)plt.axis('off')
plt.savefig(fname='image.png', figsize=[5, 5])
plt.show()

三、结果

       在超参数设置 epoch=1000,batch_size=128,lr=0.0001,latent_dim = 100 时,gan生成的权重测的结果如图所示

四,GAN的损失函数曲线

                一开始训练时,我的gan的损失函数的曲线是类似这样的,就是知乎这文章里一样,生成器损失函数的曲线一直发散。首先,这个loss的曲线一看就是网络崩了,一般正常的情况,d_loss的值会一直下降然后收敛,而g_loss的曲线会先增大后减少,最后同样也会收敛。其次,网络拿到手以后先不要训练太多次,容易出现过拟合的情况。

生成对抗网络的损失函数图像如下合理吗? - 知乎

这是训练了10轮的生成器和鉴别器的损失函数值变化吧:

效果如图所示: 

 

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.rhkb.cn/news/112090.html

如若内容造成侵权/违法违规/事实不符,请联系长河编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

统信OS国产操作系统身份证读卡器社保卡读卡web网页开发使用操作流程

用于DONSEE系列身份证阅读器谷歌Chrome火狐Firefox插件,支持的型号有:EST-100、EST-100GS、EST-100G、EST-100U、EST-200G、EST-J13X等。 本方案无缝支持最新版本谷歌Chrome火狐Firefox等网页浏览器,支持H5、Vue、React、Node.js、Electron、…

区块链金融项目怎么做?

区块链技术的兴起引发了金融领域的变革,为金融行业带来了前所未有的机遇与挑战。在这个快速发展的领域中,如何在区块链金融领域做出卓越的表现?本文将从专业性和思考深度两个方面,探讨区块链金融的发展路径,并为读者提…

低代码/无代码平台:加速应用开发的工具

在数字化时代,软件应用已经成为企业和组织不可或缺的一部分。然而,传统的应用开发过程往往需要大量的时间、资源和专业知识。为了解决这个问题,低代码/无代码平台应运而生,它们为开发者提供了一种更快速、更简便的应用开发方式。本…

lab11 net

background 在开始写代码之前,回顾一下xv6book的第五章会有帮助你将使用E1000去处理网络通信 E1000会和qemu模拟的lan通信在qemu模拟的lan中 xv6的地址是10.0.2.15qemu模拟的计算机的地址是10.0.2.2 qemu会将所有的网络包都记录在packets.pcap中文件kernel/e1000.…

【LeetCode-中等题】148. 排序链表

文章目录 题目方法一:集合排序(核心是内部的排序)方法二: 优先队列(核心也是内部的排序)方法三:归并排序(带递归) 从上往下方法四:归并排序(省去递…

java八股文面试[多线程]——什么是守护线程

知识来源: 【2023年面试】什么是守护线程_哔哩哔哩_bilibili

新亮点!安防视频监控/视频集中存储/云存储平台EasyCVR平台六分屏功能展示

安防视频监控/视频集中存储/云存储/磁盘阵列EasyCVR平台可拓展性强、视频能力灵活、部署轻快,可支持的主流标准协议有国标GB28181、RTSP/Onvif、RTMP等,以及支持厂家私有协议与SDK接入,包括海康Ehome、海大宇等设备的SDK等。平台既具备传统安…

WebRTC之FEC前向纠错协议

FEC前向纠错用于丢包恢复,对媒体包进行异或或其他算法生成冗余包进行发送。如果接收端出现丢包,可以通过冗余包恢复出原始的媒体包。FEC的代价是增加码率带宽,所以一般会根据网络状况、丢包率来动态调整FEC冗余系数,也会结合NACK/…

一文带你全面理解向量数据库

近些年来,向量数据库引起业界的广泛关注,一个相关事实是许多向量数据库初创公司在短期内就筹集到数百万美元的资金。 你很可能已经听说过向量数据库,但也许直到现在才真正关心向量数据库——至少,我想这就是你现在阅读本文的原因…

cvat 安装部署

官网地址: https://github.com/opencv/cvat/tree/masterhttps://github.com/opencv/cvat/tree/master 1.从官网上下载源码地址。 2.配置环境变量 vim /etc/profile source /etc/profile 或者执行: export CVAT_HOSTyour-ip-address 3.执行命令 …

基于Spring Boot 的 Ext JS 应用框架之coworkee

Ext JS 官方提供了一个人员管理的完整应用框架 - coworkee。该框架的显示如下: 该框架的布局特点如下: 布局方式: 左右布局, 左侧导航栏默认收合特点:左侧导航区占用空间小, 工作区较大, 适合没有二级导航栏,工作区需要显示的内容较多的系统。如果导航栏是横向底部,就…

ssm毕业生就业状况管理系统源码和论文

ssm毕业生就业状况管理系统源码和论文093 开发工具:idea 数据库mysql5.7 数据库链接工具:navcat,小海豚等 技术:ssm 摘 要 现代经济快节奏发展以及不断完善升级的信息化技术,让传统数据信息的管理升级为软件存储&#xff…

无涯教程-Android - Grid View函数

Android GridView在二维滚动网格(行和列)中显示项目,并且网格项目不一定是预定的,但它们会使用ListAdapter自动插入到布局中 Grid View - Grid view ListView 和 GridView 是 AdapterView 的子类,可以通过将它们绑定到 Adapter 来填充&#x…

MongoDB实验——在MongoDB集合中查找文档

在MongoDB集合中查找文档 一、实验目的二、实验原理三、实验步骤1.启动MongoDB数据库、启动MongoDB Shell客户端2.数据准备-->person.json3.指定返回的键4 .包含或不包含 i n 或 in 或 in或nin、$elemMatch(匹配数组)5.OR 查询 $or6.Null、$exists7.…

简易虚拟培训系统-UI控件的应用1

目录 前言 UI结构总体介绍 建立初步的系统UI结构 Image控件 前言 前面的文章介绍了关于Oculus设备与UI控件的关联,从本文开始采用小示例的方式介绍基本的UI控件在系统中的基本作用(仅介绍“基本作用”,详细的API教程可参考官方文档&#x…

用wireshark流量分析的四个案例

目录 第一题 1 2 3 4 第二题 1 2 3. 第三题 1 2 第四题 1 2 3 第一题 题目: 1.黑客攻击的第一个受害主机的网卡IP地址 2.黑客对URL的哪一个参数实施了SQL注入 3.第一个受害主机网站数据库的表前缀(加上下划线例如abc) 4.…

2024毕业设计选题指南【附选题大全】

title: 毕业设计选题指南 - 如何选择合适的毕业设计题目 date: 2023-08-29 categories: 毕业设计 tags: 选题指南, 毕业设计, 毕业论文, 毕业项目 - 如何选择合适的毕业设计题目 当我们站在大学生活的十字路口,毕业设计便成了我们面临的一项重要使命。这不仅是对我们…

文本编辑器Vim常用操作和技巧

文章目录 1. Vim常用操作1.1 Vim简介1.2 Vim工作模式1.3 插入命令1.4 定位命令1.5 删除命令1.6 复制和剪切命令1.7 替换和取消命令1.8 搜索和搜索替换命令1.9 保存和退出命令 2. Vim使用技巧 1. Vim常用操作 1.1 Vim简介 Vim是一个功能强大的全屏幕文本编辑器,是L…

计算机竞赛 基于机器视觉的手势检测和识别算法

0 前言 🔥 优质竞赛项目系列,今天要分享的是 基于深度学习的手势检测与识别算法 该项目较为新颖,适合作为竞赛课题方向,学长非常推荐! 🧿 更多资料, 项目分享: https://gitee.com/dancheng…

今天学了二叉树的前序,中序和后序遍历oier

同时了解一些趣图笑死我了 所以想入门先入坟,这是最好的礼物。 废话说多了,谈谈正事,我们了解到二叉树有节点和边权;分为有向和无向图;这里如果我们需要搜索一下每一个节点的情况,所以就需要这个遍历了 1…