Pytorch从零开始实战17

Pytorch从零开始实战——生成对抗网络入门

本系列来源于365天深度学习训练营

原作者K同学

文章目录

  • Pytorch从零开始实战——生成对抗网络入门
    • 环境准备
    • 模型定义
    • 开始训练
    • 总结

环境准备

本文基于Jupyter notebook,使用Python3.8,Pytorch1.8+cpu,本次实验目的是了解生成对抗网络。

生成对抗网络(GAN)是一种深度学习模型。GAN由两个主要组成部分组成:生成器和判别器。这两个部分通过对抗的方式共同学习,使得生成器能够生成逼真的数据,而判别器能够区分真实数据和生成的数据。

生成器的任务是生成与真实数据相似的样本。它接收一个随机噪声向量,然后通过深度神经网络将这个随机噪声转换为具体的数据样本。在图像生成的场景中,生成器通常将随机噪声映射为图像。生成器的目标是欺骗判别器,使其无法区分生成的样本和真实的样本。生成器的训练目标是最小化生成的样本与真实样本之间的差异。

判别器的任务是对给定的样本进行分类,判断它是来自真实数据集还是由生成器生成的。它接收真实样本和生成样本,然后通过深度神经网络输出一个概率,表示输入样本是真实样本的概率。判别器的目标是准确地分类样本,使其能够正确地区分真实数据和生成的数据。判别器的训练目标是最大化正确分类的概率。

导入相关包。

import torch
import torch.nn as nn
import argparse
import os
import numpy as np
import torchvision.transforms as transforms
from torchvision.utils import save_image
from torch.utils.data import DataLoader
from torchvision import datasets
from torch.autograd import Variable

创建文件夹,分别保存训练过程中的图像、模型参数和数据集。

os.makedirs("./images/", exist_ok=True) # 训练过程中图片效果
os.makedirs("./save/", exist_ok=True) # 训练完成时模型保存位置
os.makedirs("./datasets/", exist_ok=True) # 数据集位置

设置超参数。
b1、b2为Adam优化算法的参数,其中b1是梯度的一阶矩估计的衰减系数,b2是梯度的二阶矩估计的衰减系数。
latent_dim表示生成器输入的随机噪声向量的维度。这个噪声向量用于生成器产生新样本。
sample_interval表示在训练过程中每隔多少个batch保存一次生成器生成的样本图像,以便观察生成效果。

epochs = 20
batch_size = 64
lr = 0.0002
b1 = 0.5
b2 = 0.999
latent_dim=100
img_size=28
channels=1
sample_interval=500

设定图像尺寸并检查cuda,本次使用的设备没有cuda。

img_shape = (channels, img_size, img_size) # (1, 28, 28)
img_area = np.prod(img_shape) # 784## 设置cuda
cuda = True if torch.cuda.is_available() else False
print(cuda) # False

本次使用GAN来生成手写数字,首先下载mnist数据集。

mnist = datasets.MNIST(root='./datasets/', train=True, download=True,transform=transforms.Compose([transforms.Resize(img_size),transforms.ToTensor(), transforms.Normalize([0.5], [0.5])]))

使用dataloader划分批次与打乱。

dataloader = DataLoader(mnist,batch_size=batch_size,shuffle=True,
)len(dataloader) # 938

模型定义

首先定义鉴别器模型,代码中LeakyReLU是ReLU激活函数的变体,它引入了一个小的负斜率,在负输入值范围内,而不是将它们直接置零。这个斜率通常是一个小的正数,例如0.01。

class Discriminator(nn.Module):def __init__(self):super(Discriminator, self).__init__()self.model = nn.Sequential(nn.Linear(img_area, 512),        nn.LeakyReLU(0.2, inplace=True),  nn.Linear(512, 256),             nn.LeakyReLU(0.2, inplace=True), nn.Linear(256, 1),              nn.Sigmoid(),                    )def forward(self, img):img_flat = img.view(img.size(0), -1) validity = self.model(img_flat)     return validity         # 返回的是一个[0, 1]间的概率

定义生成器模型,用于输出图像。

class Generator(nn.Module):def __init__(self):super(Generator, self).__init__()def block(in_feat, out_feat, normalize=True):      layers = [nn.Linear(in_feat, out_feat)]          if normalize:layers.append(nn.BatchNorm1d(out_feat, 0.8)) layers.append(nn.LeakyReLU(0.2, inplace=True))   return layersself.model = nn.Sequential(*block(latent_dim, 128, normalize=False), *block(128, 256),                         *block(256, 512),                         *block(512, 1024),                       nn.Linear(1024, img_area),                nn.Tanh()                                )def forward(self, z):                           imgs = self.model(z)                       imgs = imgs.view(imgs.size(0), *img_shape)  # reshape成(64, 1, 28, 28)return imgs                                 # 输出为64张大小为(1, 28, 28)的图像

开始训练

创建生成器、判别器对象。

generator = Generator()
discriminator = Discriminator()

定义损失函数。这个其实就是二分类的交叉熵损失。

criterion = torch.nn.BCELoss()

定义优化函数。

optimizer_G = torch.optim.Adam(generator.parameters(), lr=lr, betas=(b1, b2))
optimizer_D = torch.optim.Adam(discriminator.parameters(), lr=lr, betas=(b1, b2))

开始训练,实现GAN训练过程,其中生成器和判别器交替训练,通过对抗过程使得生成器生成逼真的图像,而判别器不断提高对真实和生成图像的判别能力。

for epoch in range(epochs):                   # epoch:50for i, (imgs, _) in enumerate(dataloader):  # imgs:(64, 1, 28, 28)     _:label(64)imgs = imgs.view(imgs.size(0), -1)    # 将图片展开为28*28=784  imgs:(64, 784)real_img = Variable(imgs)     # 将tensor变成Variable放入计算图中,tensor变成variable之后才能进行反向传播求梯度real_label = Variable(torch.ones(imgs.size(0), 1))    ## 定义真实的图片label为1fake_label = Variable(torch.zeros(imgs.size(0), 1))    ## 定义假的图片的label为0real_out = discriminator(real_img)            # 将真实图片放入判别器中loss_real_D = criterion(real_out, real_label) # 得到真实图片的lossreal_scores = real_out                        # 得到真实图片的判别值,输出的值越接近1越好## 计算假的图片的损失## detach(): 从当前计算图中分离下来避免梯度传到G,因为G不用更新z = Variable(torch.randn(imgs.size(0), latent_dim))     ## 随机生成一些噪声, 大小为(128, 100)fake_img    = generator(z).detach()                                    ## 随机噪声放入生成网络中,生成一张假的图片。fake_out    = discriminator(fake_img)                                  ## 判别器判断假的图片loss_fake_D = criterion(fake_out, fake_label)                       ## 得到假的图片的lossfake_scores = fake_out## 损失函数和优化loss_D = loss_real_D + loss_fake_D  # 损失包括判真损失和判假损失optimizer_D.zero_grad()             # 在反向传播之前,先将梯度归0loss_D.backward()                   # 将误差反向传播optimizer_D.step()                  # 更新参数z = Variable(torch.randn(imgs.size(0), latent_dim))     ## 得到随机噪声fake_img = generator(z)                                             ## 随机噪声输入到生成器中,得到一副假的图片output = discriminator(fake_img)                                    ## 经过判别器得到的结果## 损失函数和优化loss_G = criterion(output, real_label)                              ## 得到的假的图片与真实的图片的label的lossoptimizer_G.zero_grad()                                             ## 梯度归0loss_G.backward()                                                   ## 进行反向传播optimizer_G.step()                                                  ## step()一般用在反向传播后面,用于更新生成网络的参数## 打印训练过程中的日志## item():取出单元素张量的元素值并返回该值,保持原元素类型不变if (i + 1) % 100 == 0:print("[Epoch %d/%d] [Batch %d/%d] [D loss: %f] [G loss: %f] [D real: %f] [D fake: %f]"% (epoch, epochs, i, len(dataloader), loss_D.item(), loss_G.item(), real_scores.data.mean(), fake_scores.data.mean()))## 保存训练过程中的图像batches_done = epoch * len(dataloader) + iif batches_done % sample_interval == 0:save_image(fake_img.data[:25], "./images/%d.png" % batches_done, nrow=5, normalize=True)

在这里插入图片描述
保存模型。

torch.save(generator.state_dict(), './save/generator.pth')
torch.save(discriminator.state_dict(), './save/discriminator.pth')

查看最初的噪声图像。
在这里插入图片描述
查看后面生成的图像。
在这里插入图片描述

总结

对于GAN,生成器的任务是从随机噪声生成逼真的数据样本,判别器的任务是对给定的数据样本进行分类,判断其是真实数据还是由生成器生成的。生成器和判别器通过对抗的方式进行训练。在每个训练迭代中,生成器试图生成逼真的样本以欺骗判别器,而判别器努力提高自己的能力,以正确地区分真实和生成的样本。这种对抗训练的动态平衡最终导致生成器生成高质量、逼真的样本。

总之,GAN实现了在无监督情况下学习数据分布的能力,被广泛用于生成逼真图像、视频等数据。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.rhkb.cn/news/241807.html

如若内容造成侵权/违法违规/事实不符,请联系长河编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

什么是DDOS高防ip?DDOS高防ip是怎么防护攻击的

随着互联网的快速发展,网络安全问题日益突出,DDoS攻击和CC攻击等网络威胁对企业和网站的正常运营造成了巨大的威胁。为了解决这些问题,高防IP作为一种网络安全服务应运而生。高防IP通过实时监测和分析流量,识别和拦截恶意流量&…

AI时代—ChatGPT-4.5的正确打开方式

前言 前些天发现了一个巨牛的人工智能学习网站,通俗易懂,风趣幽默,忍不住分享一下给大家:https://www.captainbed.cn/z ChatGPT体验地址 文章目录 前言4.5key价格泄漏ChatGPT4.0使用地址ChatGPT正确打开方式最新功能语音助手存档…

航空飞行器运维VR模拟互动教学更直观有趣

传统的二手车鉴定评估培训模式存在实践性不强、教学样本不足、与实际脱节等一些固有的不足。有了VR虚拟仿真技术的加持,二手车鉴定评估VR虚拟仿真实训系统逐渐进入实训领域,为院校及企业二手车检测培训提供了全新的解决方案。 高职院校汽车专业虚拟仿真实…

Excel 根据日期按月汇总公式

Excel 根据日期按月汇总公式 数据透视表日期那一列右击,选择“组合”,步长选择“月” 参考 Excel 根据日期按月汇总公式Excel如何按着日期来做每月求和

想要在线使用XD就这么做!简单又高效

XD 文件是由 Adobe XD 这款免费轻量级原型软件制作的静态页面设计或原型交互动态文件。打开它 XD 文件的优点是可以快速设计和制作原型,并向团队或客户展示。目前,Adobe XD 基本上可以满足原型和设计草案的绘制,但与其他专业的交互原型制作软…

KBP206-ASEMI小功率家用电源KBP206

编辑:ll KBP206-ASEMI小功率家用电源KBP206 型号:KBP206 品牌:ASEMI 正向电流(Id):2A 反向耐压(VRRM):600V 正向浪涌电流:60A 正向电压(V…

【RT-DETR有效改进】华为 | Ghostnetv1一种专为移动端设计的特征提取网络

前言 大家好,这里是RT-DETR有效涨点专栏。 本专栏的内容为根据ultralytics版本的RT-DETR进行改进,内容持续更新,每周更新文章数量3-10篇。 专栏以ResNet18、ResNet50为基础修改版本,同时修改内容也支持ResNet32、ResNet101和PP…

matplotlib从起点出发(12)_Tutorial_12_MultiAxes

在一个Figure中安排多个Axes 通常在一个图像中,需要同时呈现多于一个Axes,并且需要对齐到网格. Matplotlib有多种工具用于处理在本库历史中演变的Axes网格,我们将讨论我们认为用户最常使用的工具,支持Axes组织方式的工具&#xf…

新能源汽车智慧充电桩方案:如何实现充电停车智慧化管理?

一、方案概述 基于新能源汽车充电桩的监管运营等需求,安徽旭帆科技携手合作伙伴触角云共同打造“智能充电设备+云平台+APP小程序”一体化完整的解决方案,为充电桩车位场所提供精细化管理车位的解决办法,解决燃油车恶意…

GO 中如何防止 goroutine 泄露

文章目录 概述如何监控泄露一个简单的例子泄露情况分类chanel 引起的泄露发送不接收接收不发送nil channel真实的场景 传统同步机制MutexWaitGroup 总结参考资料 今天来简单谈谈,Go 如何防止 goroutine 泄露。 概述 Go 的并发模型与其他语言不同,虽说它…

第十二篇【传奇开心果系列】Ant Design Mobile of React开发移动应用:内置组件实现酷炫CSS 动画

Ant Design Mobile of React 开发移动应用示例博文系列 第一篇【传奇开心果系列】Ant Design Mobile of React 开发移动应用:从helloworld开始 第二篇【传奇开心果系列】Ant Design Mobile of React 开发移动应用:天气应用 第三篇【传奇开心果系列】Ant Design Mobile of Reac…

逸学Docker【java工程师基础】3.4Docker安装redis

1.拉取redis docker pull redis 2.选择一个合适的redis 版本的配置文件 Redis configuration | Redis 或者这个 链接:https://pan.baidu.com/s/1RRdtgec4xBAgQghlhm0x1Q 提取码:ycyc 在1044行修改密码 3.提前在服务器建立 /data/redis 文件夹&…

前端下载文件流,设置返回值类型responseType:‘blob‘无效的问题

前言: 本是一个非常简单的请求,即是下载文件。通常的做法如下: 1.前端通过Vue Axios向后端请求,同时在请求中设置响应体为Blob格式。 2.后端相应前端的请求,同时返回Blob格式的文件给到前端(如果没有步骤…

AI短视频制作:创意与技术的完美结合

文章目录 一、充分了解AI技术的应用范围和优势二、创意策划,确定作品主题和风格三、素材收集,丰富作品内容四、特效制作,提升作品视觉效果五、配音处理,增强作品表现力六、作品发布,扩大作品传播范围《AI短视频制作一本…

Laravel7 + easyWeChat 实现微信公众号支付功能

注册服务号,需进行微信认证,此时需缴费 300 元/年,必须是认证成功的服务号才能开通微信支付。 注册微信支付商户号 1、登录 https://pay.weixin.qq.com/index.php/core/home/login?return_urlhttps%3A%2F%2Fpay.weixin.qq.com%2Findex.php%…

Python爬虫学习笔记(一)---Python入门

一、pycharm的安装及使用二、python的基础使用1、字符串连接2、单双引号转义3、换行4、三引号跨行字符串5、命名规则6、注释7、 优先级not>and>or8、列表(list)9、字典(dictionary)10、元组(tuple)11…

PE解释器之PE文件结构(二)

接下来的内容是对IMAGE_OPTIONAL_HEADER32中的最后一个成员DataDirectory,虽然他只是一个结构体数组,每个结构体的大小也不过是个字节,但是它却是PE文件中最重要的成员。PE装载器通过查看它才能准确的找到某个函数或某个资源。 一&#xff1…

qt学习:实战 读取txt文件+定时器点名

目录 目标 步骤 头文件 配置ui界面 在.h里定义槽函数和字符串链表和定时器指针 在构造函数里读取txt文件并初始化定时器 开始按钮点击函数 开始定时器 停止按钮点击函数 关闭定时器 定时器槽函数 目标 两个按钮,一个开始点名,一个停止点名一个…

用Go plan9汇编实现斐波那契数列计算

斐波那契数列是一个满足递推关系的数列,如:1 1 2 3 5 8 ... 其前两项为1,第3项开始,每一项都是其前两项之和。 用Go实现一个简单的斐波那契计算逻辑 func fib(n int) int {if n 1 || n 2 {return 1}return fib(n-1) fib(n-2) …

C# 获取QQ会话聊天信息

目录 利用UIAutomation获取QQ会话聊天信息 效果 代码 目前遇到一个问题 其他解决办法 利用UIAutomation获取QQ会话聊天信息 效果 代码 AutomationElement window AutomationElement.FromHandle(get.WindowHwnd); AutomationElement QQMsgList window.FindFirst(Tr…