【Pytorch】生成对抗网络实战

GAN框架基于两个模型的竞争,Generator生成器和Discriminator鉴别器。生成器生成假图像,鉴别器则尝试从假图像中识别真实的图像。作为这种竞争的结果,生成器将生成更好看的假图像,而鉴别器将更好地识别它们。

目录

创建数据集

定义生成器

定义鉴别器

初始化模型权重

定义损失函数

定义优化器

训练模型

部署生成器


创建数据集

使用 PyTorch torchvision 包中提供的 STL-10 数据集,数据集中有 10 个类:飞机、鸟、车、猫、鹿、狗、马、猴、船、卡车。图像为96*96像素的RGB图像。数据集包含 5,000 张训练图像和 8,000 张测试图像。在训练数据集和测试数据集中,每个类分别有 500 和 800 张图像。

 STL-10数据集详细参考http://t.csdnimg.cn/ojBn6中数据加载和处理部分 

from torchvision import datasets
import torchvision.transforms as transforms
import os# 定义数据集路径
path2data="./data"
# 创建数据集路径
os.makedirs(path2data, exist_ok= True)# 定义图像尺寸
h, w = 64, 64
# 定义均值
mean = (0.5, 0.5, 0.5)
# 定义标准差
std = (0.5, 0.5, 0.5)
# 定义数据预处理
transform= transforms.Compose([transforms.Resize((h,w)),  # 调整图像尺寸transforms.CenterCrop((h,w)),  # 中心裁剪transforms.ToTensor(),  # 转换为张量transforms.Normalize(mean, std)])  # 归一化# 加载训练集
train_ds=datasets.STL10(path2data, split='train', download=False,transform=transform)

 展示示例图像张量形状、最小值和最大值

import torch
for x, _ in train_ds:print(x.shape, torch.min(x), torch.max(x))break

 展示示例图像

from torchvision.transforms.functional import to_pil_image
import matplotlib.pylab as plt
%matplotlib inline
plt.imshow(to_pil_image(0.5*x+0.5))

 

创建数据加载器 

import torch
batch_size = 32
train_dl = torch.utils.data.DataLoader(train_ds, batch_size=batch_size, shuffle=True)

 示例

for x,y in train_dl:print(x.shape, y.shape)break

定义生成器

GAN框架是基于两个模型的竞争,generator生成器和discriminator鉴别器。生成器生成假图像,鉴别器尝试从假图像中识别真实的图像。

作为这种竞争的结果,生成器将生成更好看的假图像,而鉴别器将更好地识别它们。

定义生成器模型 

from torch import nn
import torch.nn.functional as Fclass Generator(nn.Module):def __init__(self, params):super(Generator, self).__init__()# 获取参数nz = params["nz"]ngf = params["ngf"]noc = params["noc"]# 定义反卷积层1self.dconv1 = nn.ConvTranspose2d( nz, ngf * 8, kernel_size=4,stride=1, padding=0, bias=False)# 定义批归一化层1self.bn1 = nn.BatchNorm2d(ngf * 8)# 定义反卷积层2self.dconv2 = nn.ConvTranspose2d(ngf * 8, ngf * 4, kernel_size=4, stride=2, padding=1, bias=False)# 定义批归一化层2self.bn2 = nn.BatchNorm2d(ngf * 4)# 定义反卷积层3self.dconv3 = nn.ConvTranspose2d( ngf * 4, ngf * 2, kernel_size=4, stride=2, padding=1, bias=False)# 定义批归一化层3self.bn3 = nn.BatchNorm2d(ngf * 2)# 定义反卷积层4self.dconv4 = nn.ConvTranspose2d( ngf * 2, ngf, kernel_size=4, stride=2, padding=1, bias=False)# 定义批归一化层4self.bn4 = nn.BatchNorm2d(ngf)# 定义反卷积层5self.dconv5 = nn.ConvTranspose2d( ngf, noc, kernel_size=4, stride=2, padding=1, bias=False)# 前向传播def forward(self, x):# 反卷积层1x = F.relu(self.bn1(self.dconv1(x)))# 反卷积层2x = F.relu(self.bn2(self.dconv2(x)))            # 反卷积层3x = F.relu(self.bn3(self.dconv3(x)))        # 反卷积层4x = F.relu(self.bn4(self.dconv4(x)))    # 反卷积层5out = torch.tanh(self.dconv5(x))return out

设定生成器模型参数、移动模型到cuda设备并打印模型结构 

params_gen = {"nz": 100,"ngf": 64,"noc": 3,}
model_gen = Generator(params_gen)
device = torch.device("cuda:0")
model_gen.to(device)
print(model_gen)

定义鉴别器

定义鉴别器模型, 用于鉴别真实图像

class Discriminator(nn.Module):def __init__(self, params):super(Discriminator, self).__init__()# 获取参数nic= params["nic"]ndf = params["ndf"]# 定义卷积层1self.conv1 = nn.Conv2d(nic, ndf, kernel_size=4, stride=2, padding=1, bias=False)# 定义卷积层2self.conv2 = nn.Conv2d(ndf, ndf * 2, kernel_size=4, stride=2, padding=1, bias=False)# 定义批归一化层2self.bn2 = nn.BatchNorm2d(ndf * 2)            # 定义卷积层3self.conv3 = nn.Conv2d(ndf * 2, ndf * 4, kernel_size=4, stride=2, padding=1, bias=False)# 定义批归一化层3self.bn3 = nn.BatchNorm2d(ndf * 4)# 定义卷积层4self.conv4 = nn.Conv2d(ndf * 4, ndf * 8, kernel_size=4, stride=2, padding=1, bias=False)# 定义批归一化层4self.bn4 = nn.BatchNorm2d(ndf * 8)# 定义卷积层5self.conv5 = nn.Conv2d(ndf * 8, 1, kernel_size=4, stride=1, padding=0, bias=False)def forward(self, x):# 使用leaky_relu激活函数对卷积层1的输出进行激活x = F.leaky_relu(self.conv1(x), 0.2, True)# 使用leaky_relu激活函数对卷积层2的输出进行激活,并使用批归一化层2进行批归一化x = F.leaky_relu(self.bn2(self.conv2(x)), 0.2, inplace = True)# 使用leaky_relu激活函数对卷积层3的输出进行激活,并使用批归一化层3进行批归一化x = F.leaky_relu(self.bn3(self.conv3(x)), 0.2, inplace = True)# 使用leaky_relu激活函数对卷积层4的输出进行激活,并使用批归一化层4进行批归一化x = F.leaky_relu(self.bn4(self.conv4(x)), 0.2, inplace = True)        # 使用sigmoid激活函数对卷积层5的输出进行激活,并返回结果# Sigmoid激活函数是一种常用的非线性激活函数,它将输入值压缩到0和1之间,[ \sigma(x) = \frac{1}{1 + e^{-x}} ]out = torch.sigmoid(self.conv5(x))return out.view(-1)

设置模型参数,移动模型到cuda设备,打印模型结构 


params_dis = {"nic": 3,"ndf": 64}
model_dis = Discriminator(params_dis)
model_dis.to(device)
print(model_dis)

初始化模型权重

定义函数,初始化模型权重 

def initialize_weights(model):# 获取模型类的名称classname = model.__class__.__name__# 如果模型类名称中包含'Conv',则初始化权重为均值为0,标准差为0.02的正态分布if classname.find('Conv') != -1:nn.init.normal_(model.weight.data, 0.0, 0.02)# 如果模型类名称中包含'BatchNorm',则初始化权重为均值为1,标准差为0.02的正态分布,偏置为0elif classname.find('BatchNorm') != -1:nn.init.normal_(model.weight.data, 1.0, 0.02)nn.init.constant_(model.bias.data, 0)

初始化生成器模型和鉴别器模型的权重 

# 对生成器模型应用初始化权重函数
model_gen.apply(initialize_weights);
# 对判别器模型应用初始化权重函数
model_dis.apply(initialize_weights);

定义损失函数

定义二元交叉熵(BCE)损失函数 

loss_func = nn.BCELoss()

定义优化器

定义Adam优化器

from torch import optim
# 学习率
lr = 2e-4 
# Adam优化器的beta1参数
beta1 = 0.5
# 定义鉴别器模型的优化器,学习率为lr,beta1参数为beta1,beta2参数为0.999
opt_dis = optim.Adam(model_dis.parameters(), lr=lr, betas=(beta1, 0.999))
# 定义生成器模型的优化器
opt_gen = optim.Adam(model_gen.parameters(), lr=lr, betas=(beta1, 0.999))

训练模型

 示例训练1000个epochs

# 定义真实标签和虚假标签
real_label = 1
fake_label = 0
# 获取生成器的噪声维度
nz = params_gen["nz"]
# 设置训练轮数
num_epochs = 1000
# 定义损失历史记录
loss_history={"gen": [],"dis": []}
# 定义批次数
batch_count = 0
# 遍历训练轮数
for epoch in range(num_epochs):# 遍历训练数据for xb, yb in train_dl:# 获取批大小ba_si = xb.size(0)# 将判别器梯度置零model_dis.zero_grad()# 将输入数据移动到指定设备xb = xb.to(device)# 将标签数据转换为指定设备yb = torch.full((ba_si,), real_label, device=device)# 判别器输出out_dis = model_dis(xb)# 将输出和标签转换为浮点数out_dis = out_dis.float()yb = yb.float()# 计算真实样本的损失loss_r = loss_func(out_dis, yb)# 反向传播loss_r.backward()# 生成噪声noise = torch.randn(ba_si, nz, 1, 1, device=device)# 生成器输出out_gen = model_gen(noise)# 判别器输出out_dis = model_dis(out_gen.detach())# 将标签数据填充为虚假标签yb.fill_(fake_label)    # 计算虚假样本的损失loss_f = loss_func(out_dis, yb)# 反向传播loss_f.backward()# 计算判别器的总损失loss_dis = loss_r + loss_f  # 更新判别器的参数opt_dis.step()   # 将生成器梯度置零model_gen.zero_grad()# 将标签数据填充为真实标签yb.fill_(real_label)  # 判别器输出out_dis = model_dis(out_gen)# 计算生成器的损失loss_gen = loss_func(out_dis, yb)# 反向传播loss_gen.backward()# 更新生成器的参数opt_gen.step()# 记录生成器和判别器的损失loss_history["gen"].append(loss_gen.item())loss_history["dis"].append(loss_dis.item())# 更新批次数batch_count += 1# 每100个批打印一次损失if batch_count % 100 == 0:print(epoch, loss_gen.item(),loss_dis.item())

 绘制损失图像

plt.figure(figsize=(10,5))
plt.title("Loss Progress")
plt.plot(loss_history["gen"],label="Gen. Loss")
plt.plot(loss_history["dis"],label="Dis. Loss")
plt.xlabel("batch count")
plt.ylabel("Loss")
plt.legend()
plt.show()

存储模型权重 

import os
path2models = "./models/"
os.makedirs(path2models, exist_ok=True)
path2weights_gen = os.path.join(path2models, "weights_gen_128.pt")
path2weights_dis = os.path.join(path2models, "weights_dis_128.pt")
torch.save(model_gen.state_dict(), path2weights_gen)
torch.save(model_dis.state_dict(), path2weights_dis)

部署生成器

通常情况下,训练完成后放弃鉴别器模型而保留生成器模型,部署经过训练的生成器来生成新的图像。为部署生成器模型,将训练好的权重加载到模型中,然后给模型提供随机噪声。

# 加载生成器模型的权重
weights = torch.load(path2weights_gen)
# 将权重加载到生成器模型中
model_gen.load_state_dict(weights)
# 将生成器模型设置为评估模式
model_gen.eval()

 生成图像

import numpy as np
with torch.no_grad():# 生成固定噪声fixed_noise = torch.randn(16, nz, 1, 1, device=device)# 打印噪声形状print(fixed_noise.shape)# 生成假图像img_fake = model_gen(fixed_noise).detach().cpu()    
# 打印假图像形状
print(img_fake.shape)
# 创建画布
plt.figure(figsize=(10,10))
# 遍历假图像
for ii in range(16):# 在画布上绘制图像plt.subplot(4,4,ii+1)# 将图像转换为PIL图像plt.imshow(to_pil_image(0.5*img_fake[ii]+0.5))# 关闭坐标轴plt.axis("off")

其中一些可能看起来扭曲,而另一些看起来相对真实。为改进结果,可以在单个数据类上训练模型,而不是在多个类上一起训练。GAN在使用单个类进行训练时表现更好。此外,可以尝试更长时间地训练模型。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.rhkb.cn/news/409302.html

如若内容造成侵权/违法违规/事实不符,请联系长河编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

前端宝典十一:前端工程化稳定性方案

一、工程化体系介绍 1、什么是前端工程化 前端工程化 前端 软件工程;前端工程化 将工程方法系统化地应用到前端开发中;前端工程化 系统、严谨、可量化的方法开发、运营和维护前端应用程序;前端工程化 基于业务诉求,梳理出最…

redhawk:STA timing data file解析

我正在「拾陆楼」和朋友们讨论有趣的话题,你⼀起来吧? 拾陆楼知识星球入口 往期文章:

PyTorch深度学习网络(一:MLP)

全连接神经网络,又称多层感知机(MLP),是深度学习最基础的神经网络。全连接神经网络主要由输入层、隐藏层和输出层构成。本文实现了一个通用MLP网络,包括以下功能: 根据输入的特征数、类别数、各隐藏层神经…

以简单的例子从头开始建spring boot web多模块项目(五)-thymeleaf引擎

继续向里面加,这次是引入thymeleaf渲染引擎。 使用这个引擎的很多,主要是以下几个优点: Thymeleaf是适用于Web和独立环境的现代服务器端Java模板引擎。Thymeleaf的主要目标是为您的开发工作流程带来优雅的自然模板 -HTML可以在浏览器中正确显…

Vue3加vite使用Cesium绘制图形

Vue3加vite使用Cesium绘制图形 1、项目开发准备 Node版本:16.20.2 1.1创建一个新的工程:my-cesium-app npm create vitelatest my-cesium-app – --template vue1.2 安装Element Plus npm install element-plus --save // main.js import ElementPl…

【STM32】看门狗

看门狗,还没有别的地方用上,暂时还不清楚在实际应用中最多的场景是什么,我感觉是用来强制重启系统。 大部分图片来源:正点原子HAL库教程 专栏目录:记录自己的嵌入式学习之路-CSDN博客 目录 1 应用场景 1.1 解决…

Langchain Memory组件深度剖析:从对话基础到高级链式应用

文章目录 前言一、Langchain memory 记忆1.Memory 组件基本介绍2.Memory 组件的类型1.ChatMessageHistory2.ConversationBufferMemory3.ConversationBufferWindowMemory4.ConversationEntityMemory5.ConversationKGMemory6.ConversationSummaryMemory 二、长时记忆1.简单介绍2.…

解决ubuntu22.04无法识别CH340/CH341和vscode espidf插件无法选择串口设备节点问题

文章目录 解决ubuntu22.04无法识别CH340/CH341和vscode espidf插件无法选择串口设备节点问题不识别CH340/CH341报错解决办法升级驱动编译安装 卸载brltty程序 vscode espidf插件无法选择串口设备节点问题解决办法编译安装 解决ubuntu22.04无法识别CH340/CH341和vscode espidf插…

C#开发中ImageComboBox控件数据源实时变换

在C#开发中,我们如何将控件的数据源实时变换,当然我们可以在窗口实例化的时候指定固定的数据源,但是这样对于用户来说数据源永远固定,并不利于我们对于用户的数据存储,优化用户的操作,遇到这种问题&#xf…

Flutter ListView滑动

在Flutter中,ScrollController可以精确地控制和管理滚动行为。通过ScrollController,可以监听滚动的位置、速度,甚至可以在用户滚动时触发自定义的动作。此外,ScrollController还提供了对滚动位置的直接控制,可以编程地…

DRF——请求的封装与版本管理

文章目录 django restframework1. 快速上手2. 请求数据的封装3. 版本管理3.1 URL的GET参数传递(*)3.2 URL路径传递(*)3.3 请求头传递3.4 二级域名传递3.5 路由的namespace传递 小结 django restframework 快速上手请求的封装版本…

科大讯飞刘聪:大模型加持,人形机器人将跨越三大瓶颈

2024年,AI大模型成为机器人产业新的加速器。 今年3月,ChatGPT4加持的机器人Figure01向外界展示了大模型赋能人形机器人的巨大潜力。Figure01能理解周围环境,流畅地与人类交谈,理解人类的需求并完成具体行动,包括给人类…

虚幻5|AI视力系统,听力系统,预测系统(2)听力系统

虚幻5|AI视力系统,听力系统,预测系统(1)视力系统-CSDN博客 一,把之前的听力系统,折叠成函数,复制粘贴一份改名为听力系统 1.小个体修改如下,把之前的视力系统改成听力系统 2.整体修…

隐私指纹浏览器产品系列 —— 浏览器指纹 中(三)

1.引言 在上一篇文章中,我们聊到了最老牌的浏览器指纹检测站——BrowserLeaks。BrowserLeaks曾经是浏览器指纹检测的权威,但它似乎更像是一本老旧的工具书,只能呆板告诉你浏览器的指纹值,并对比不同浏览器的指纹差异。 今天&…

C语言 之 浮点数在内存中的存储 详细讲解

文章目录 浮点数浮点数的存储浮点数的存储浮点数的读取例题 浮点数 常见的浮点数:3.14159、1E10(表示1*10^10)等 浮点数家族包括: float、double、long double 类型。 浮点数表示的范围在float.h 中有定义 浮点数的存储 浮点数…

C++研发笔记1——github注册文档

1、第一步:登录网站 GitHub: Let’s build from here GitHub 最新跳转页面如下: 2、选择“sign up”进行注册,并填写设置账户信息 3、创建账户成功之后需要进行再次登录 4、根据实际情况填写个人状态信息 登录完成后页面网站: 5…

手写SpringAOP

一、非注解式简易版AOP 整体流程 1.1 代码 public class Test {public static void main(String[] args){// Aop代理工厂DefaultAopProxyFactory factory new DefaultAopProxyFactory();// 测试对象AOPDemoImpl demo new AOPDemoImpl();// 支撑类:用于存放目标…

配置策略路由实战 附带基础网络知识

背景 作为一个软件开发人员,不可能做到只负责业务开发工作,一旦功能上线或者系统切换就会遇到非常多考验开发人员个人能力的场景,网络调整就是非常重要的一个方面,如果你在系统上线的过程中无法处理一些简单的网络问题或者听不懂…

文件包含漏洞(1)

目录 PHP伪协议 php://input Example 1&#xff1a; 造成任意代码执行 Example 2&#xff1a; 文件内容绕过 php://filer zip:// PHP伪协议 php://input Example 1&#xff1a; 造成任意代码执行 搭建环境 <meta charset"utf8"> <?php error_repo…

Modern C++——不准确“类型声明”引发的非必要性能损耗

大纲 案例代码地址 C是一种强类型语言。我们在编码时就需要明确指出每个变量的类型&#xff0c;进而让编译器可以正确的编译。看似C编译器比其他弱类型语言的编译器要死板&#xff0c;实则它也做了很多“隐藏”的操作。它会在尝试针对一些非预期类型进行相应转换&#xff0c;以…