PyTorch训练深度卷积生成对抗网络DCGAN

文章目录

    • DCGAN介绍
    • 代码
    • 结果
    • 参考

DCGAN介绍

将CNN和GAN结合起来,把监督学习和无监督学习结合起来。具体解释可以参见 深度卷积对抗生成网络(DCGAN)

DCGAN的生成器结构:
在这里插入图片描述
图片来源:https://arxiv.org/abs/1511.06434

代码

model.py

import torch
import torch.nn as nnclass Discriminator(nn.Module):def __init__(self, channels_img, features_d):super(Discriminator, self).__init__()self.disc = nn.Sequential(# Input: N x channels_img x 64 x 64nn.Conv2d(channels_img, features_d, kernel_size=4, stride=2, padding=1), # 32 x 32nn.LeakyReLU(0.2),self._block(features_d, features_d*2, 4, 2, 1), # 16 x 16self._block(features_d*2, features_d*4, 4, 2, 1), # 8 x 8self._block(features_d*4, features_d*8, 4, 2, 1), # 4 x 4nn.Conv2d(features_d*8, 1, kernel_size=4, stride=2, padding=0), # 1 x 1nn.Sigmoid(),)def _block(self, in_channels, out_channels, kernel_size, stride, padding):return nn.Sequential(nn.Conv2d(in_channels, out_channels, kernel_size, stride, padding, bias=False),nn.BatchNorm2d(out_channels),nn.LeakyReLU(0.2),)def forward(self, x):return self.disc(x)class Generator(nn.Module):def __init__(self, z_dim, channels_img, features_g):super(Generator, self).__init__()self.gen = nn.Sequential(# Input: N x z_dim x 1 x 1self._block(z_dim, features_g*16, 4, 1, 0), # N x f_g*16 x 4 x 4self._block(features_g*16, features_g*8, 4, 2, 1), # 8x8self._block(features_g*8, features_g*4, 4, 2, 1), # 16x16self._block(features_g*4, features_g*2, 4, 2, 1), # 32x32nn.ConvTranspose2d(features_g*2, channels_img, kernel_size=4, stride=2, padding=1,),nn.Tanh(),)def _block(self, in_channels, out_channels, kernel_size, stride, padding):return nn.Sequential(nn.ConvTranspose2d(in_channels,out_channels,kernel_size,stride,padding,bias=False,),nn.BatchNorm2d(out_channels),nn.ReLU(),)def forward(self, x):return self.gen(x)def initialize_weights(model):for m in model.modules():if isinstance(m, (nn.Conv2d, nn.ConvTranspose2d, nn.BatchNorm2d)):nn.init.normal_(m.weight.data, 0.0, 0.02)def test():N, in_channels, H, W = 8, 3, 64, 64z_dim = 100x = torch.randn((N, in_channels, H, W))disc = Discriminator(in_channels, 8)initialize_weights(disc)assert disc(x).shape == (N, 1, 1, 1)gen = Generator(z_dim, in_channels, 8)initialize_weights(gen)z = torch.randn((N, z_dim, 1, 1))assert gen(z).shape == (N, in_channels, H, W)print("success")if __name__ == "__main__":test()

训练使用的数据集:CelebA dataset (Images Only) 总共1.3GB的图片,使用方法,将其解压到当前目录

图片如下图所示:
在这里插入图片描述

train.py

import torch
import torch.nn as nn
import torch.optim as optim
import torchvision
import torchvision.datasets as datasets
import torchvision.transforms as transforms
from torch.utils.data import DataLoader
from torch.utils.tensorboard import SummaryWriter
from model import Discriminator, Generator, initialize_weights# Hyperparameters etc.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
LEARNING_RATE = 2e-4  # could also use two lrs, one for gen and one for disc
BATCH_SIZE = 128
IMAGE_SIZE = 64
CHANNELS_IMG = 3 # 1 if MNIST dataset; 3 if celeb dataset
NOISE_DIM = 100
NUM_EPOCHS = 5
FEATURES_DISC = 64
FEATURES_GEN = 64transforms = transforms.Compose([transforms.Resize((IMAGE_SIZE, IMAGE_SIZE)),transforms.ToTensor(),transforms.Normalize([0.5 for _ in range(CHANNELS_IMG)], [0.5 for _ in range(CHANNELS_IMG)]),]
)# If you train on MNIST, remember to set channels_img to 1
# dataset = datasets.MNIST(
#     root="dataset/", train=True, transform=transforms, download=True
# )# comment mnist above and uncomment below if train on CelebA# If you train on celeb dataset, remember to set channels_img to 3
dataset = datasets.ImageFolder(root="celeb_dataset", transform=transforms)
dataloader = DataLoader(dataset, batch_size=BATCH_SIZE, shuffle=True)
gen = Generator(NOISE_DIM, CHANNELS_IMG, FEATURES_GEN).to(device)
disc = Discriminator(CHANNELS_IMG, FEATURES_DISC).to(device)
initialize_weights(gen)
initialize_weights(disc)opt_gen = optim.Adam(gen.parameters(), lr=LEARNING_RATE, betas=(0.5, 0.999))
opt_disc = optim.Adam(disc.parameters(), lr=LEARNING_RATE, betas=(0.5, 0.999))
criterion = nn.BCELoss()fixed_noise = torch.randn(32, NOISE_DIM, 1, 1).to(device)
writer_real = SummaryWriter(f"logs/real")
writer_fake = SummaryWriter(f"logs/fake")
step = 0gen.train()
disc.train()for epoch in range(NUM_EPOCHS):# Target labels not needed! <3 unsupervisedfor batch_idx, (real, _) in enumerate(dataloader):real = real.to(device)noise = torch.randn(BATCH_SIZE, NOISE_DIM, 1, 1).to(device)fake = gen(noise)### Train Discriminator: max log(D(x)) + log(1 - D(G(z)))disc_real = disc(real).reshape(-1)loss_disc_real = criterion(disc_real, torch.ones_like(disc_real))disc_fake = disc(fake.detach()).reshape(-1)loss_disc_fake = criterion(disc_fake, torch.zeros_like(disc_fake))loss_disc = (loss_disc_real + loss_disc_fake) / 2disc.zero_grad()loss_disc.backward()opt_disc.step()### Train Generator: min log(1 - D(G(z))) <-> max log(D(G(z))output = disc(fake).reshape(-1)loss_gen = criterion(output, torch.ones_like(output))gen.zero_grad()loss_gen.backward()opt_gen.step()# Print losses occasionally and print to tensorboardif batch_idx % 100 == 0:print(f"Epoch [{epoch}/{NUM_EPOCHS}] Batch {batch_idx}/{len(dataloader)} \Loss D: {loss_disc:.4f}, loss G: {loss_gen:.4f}")with torch.no_grad():fake = gen(fixed_noise)# take out (up to) 32 examplesimg_grid_real = torchvision.utils.make_grid(real[:32], normalize=True)img_grid_fake = torchvision.utils.make_grid(fake[:32], normalize=True)writer_real.add_image("Real", img_grid_real, global_step=step)writer_fake.add_image("Fake", img_grid_fake, global_step=step)step += 1

结果

训练5个epoch,部分结果如下:

Epoch [3/5] Batch 1500/1583                   Loss D: 0.4996, loss G: 1.1738
Epoch [4/5] Batch 0/1583                   Loss D: 0.4268, loss G: 1.6633
Epoch [4/5] Batch 100/1583                   Loss D: 0.4841, loss G: 1.7475
Epoch [4/5] Batch 200/1583                   Loss D: 0.5094, loss G: 1.2376
Epoch [4/5] Batch 300/1583                   Loss D: 0.4376, loss G: 2.1271
Epoch [4/5] Batch 400/1583                   Loss D: 0.4173, loss G: 1.4380
Epoch [4/5] Batch 500/1583                   Loss D: 0.5213, loss G: 2.1665
Epoch [4/5] Batch 600/1583                   Loss D: 0.5036, loss G: 2.1079
Epoch [4/5] Batch 700/1583                   Loss D: 0.5158, loss G: 1.0579
Epoch [4/5] Batch 800/1583                   Loss D: 0.5426, loss G: 1.9427
Epoch [4/5] Batch 900/1583                   Loss D: 0.4721, loss G: 1.2659
Epoch [4/5] Batch 1000/1583                   Loss D: 0.5662, loss G: 2.4537
Epoch [4/5] Batch 1100/1583                   Loss D: 0.5604, loss G: 0.8978
Epoch [4/5] Batch 1200/1583                   Loss D: 0.4085, loss G: 2.0747
Epoch [4/5] Batch 1300/1583                   Loss D: 1.1894, loss G: 0.1825
Epoch [4/5] Batch 1400/1583                   Loss D: 0.4518, loss G: 2.1509
Epoch [4/5] Batch 1500/1583                   Loss D: 0.3814, loss G: 1.9391

使用

tensorboard --logdir=logs

打开tensorboard

在这里插入图片描述

参考

[1] DCGAN implementation from scratch
[2] https://arxiv.org/abs/1511.06434

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.rhkb.cn/news/96209.html

如若内容造成侵权/违法违规/事实不符,请联系长河编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

win系统部署Apollo-quick-start-2.1.0

win系统部署Apollo-quick-start-2.1.0 携程Apollo配置中心&#xff0c;官方部署包里提供了2个sql文件&#xff0c;需要刷入数据库。之后修改demo.sh里的数据库配置,最后使用git bash启动demo.sh刷sql脚本 官方部署包里提供了2个sql文件 修改demo.sh文件 使用git bash启动demo…

WinCC V7.5 中的C脚本对话框不可见,将编辑窗口移动到可见区域的具体方法

WinCC V7.5 中的C脚本对话框不可见&#xff0c;将编辑窗口移动到可见区域的具体方法 由于 Windows 系统更新或使用不同的显示器&#xff0c;在配置C动作时&#xff0c;有可能会出现C脚本编辑窗口被移动到不可见区域的现象。 由于该窗口无法被关闭&#xff0c;故无法进行进一步…

安防监控视频云存储EasyCVR平台H.265转码功能更新:新增分辨率配置

安防视频集中存储EasyCVR视频监控综合管理平台可以根据不同的场景需求&#xff0c;让平台在内网、专网、VPN、广域网、互联网等各种环境下进行音视频的采集、接入与多端分发。在视频能力上&#xff0c;视频云存储平台EasyCVR可实现视频实时直播、云端录像、视频云存储、视频存储…

微服务中间件-分布式缓存Redis

分布式缓存 a.Redis持久化1) RDB持久化1.a) RDB持久化-原理 2) AOF持久化3) 两者对比 b.Redis主从1) 搭建主从架构2) 数据同步原理&#xff08;全量同步&#xff09;3) 数据同步原理&#xff08;增量同步&#xff09; c.Redis哨兵1) 哨兵的作用2) 搭建Redis哨兵集群3) RedisTem…

小程序的数据绑定和事件绑定

小程序的数据绑定 1.需要渲染的数据放在index.js中的data里 Page({data: {info:HELLO WORLD,imgSrc:/images/1.jpg,randomNum:Math.random()*10,randomNum1:Math.random().toFixed(2)}, }) 2.在WXML中通过{{}}获取数据 <view>{{info}}</view><image src"{{…

设计模式之责任链模式

简介 责任链模式是一种行为设计模式&#xff0c; 允许你将请求沿着处理者链&#xff08;单向链表&#xff09;进行发送。 收到请求后&#xff0c; 每个处理者根据自身条件对请求进行处理&#xff0c; &#xff0c;如果处理不了则将其传递给链上的下个处理者&#xff0c;以此类…

【LVS】3、LVS+Keepalived群集

为什么用它&#xff0c;为了做高可用 服务功能 1.故障自动切换 2.健康检查 3.节点服务器高可用-HA Keepalived的三个模块&#xff1a; core&#xff1a;Keepalived的核心&#xff0c;负责主进程的启动、维护&#xff1b;调用全局配置文件进行加载和解析 vrrp&#xff1a;实…

5个高清视频素材网站

推荐5个高清视频素材网站&#xff0c;免费、付费、商用的都有&#xff0c;可根据自己需求去选择&#xff0c;赶紧收藏吧&#xff01; 菜鸟图库 https://www.sucai999.com/video.html?vNTYxMjky ​ 菜鸟图库网素材非常丰富&#xff0c;网站主要还是以设计类素材为主&#xff…

Spring-Bean的生命周期

目录 生命周期汇总 细分生命周期 1.实例化 2.属性赋值&#xff08;依赖注入&#xff09; 3.Aware接口 4.BeanPostProcessor接口 5.初始化 6.销毁 测试验证 类结构 业务类 测试类 生命周期汇总 Spring Bean 的生命周期见下图 &#xff08;一定记忆好下图&#x…

二进制数的左移和右移位运算numpy.left_shift()numpy.right_shift()

【小白从小学Python、C、Java】 【计算机等考500强证书考研】 【Python-数据分析】 二进制数的左移和右移位运算 numpy.left_shift() numpy.right_shift() [太阳]选择题 下列代码最后一次输出的结果是&#xff1f; import numpy as np a 8 print("【显示】a ", a)…

【编织时空四:探究顺序表与链表的数据之旅】

本章重点 链表的分类 带头双向循环链表接口实现 顺序表和链表的区别 缓存利用率参考存储体系结构 以及 局部原理性。 一、链表的分类 实际中链表的结构非常多样&#xff0c;以下情况组合起来就有8种链表结构&#xff1a; 1. 单向或者双向 2. 带头或者不带头 3. 循环或者非…

CANoe软件Tools中无法找到LDF Explorer

关联文章&#xff1a; LDF概述和LDF Explorer工具介绍 问题描述 使用CANoe软件的菜单栏Tools中无法找到LDF Explorer 原因分析/解决方案&#xff1a; ①查看CANoe硬件是否带LIN license&#xff0c;并且license在正常激活时间内。 ②查看CANoe是否配置了LIN通道&#xff0c;…

【链表】经典链表题LeetCode

文章目录 160. 相交链表 简单&#x1f525;206. 反转链表 简单&#x1f525;876. 链表的中间结点 简单234. 回文链表 简单&#x1f525;141. 环形链表 简单&#x1f525;142. 环形链表 II 中等&#x1f525;21. 合并两个有序链表 简单&#x1f525;2. 两数相加 中等&#x1f52…

【Unity】ShaderGraph应用(模型膨胀流动)

【Unity】ShaderGraph应用&#xff08;模型膨胀流动&#xff09; 实现效果 ShaderGraph是 unity的图形化 Shader 编程工具。本文介绍使用ShaderGraph实现模型的膨胀流动效果。该效果可以由于模拟流体在管线中的流动等相关功能。 一、实现的方法 1.使用节点介绍 关键节点 UV…

【Redis实践篇】使用Redisson 优雅实现项目实践过程中的5种场景

文章目录 1.前言2.使用方式1. 添加Redisson依赖&#xff1a;2. 配置Redis连接信息3. 使用场景3.1. 分布式锁3.2. 限流器&#xff08;Rate Limiter&#xff09;3.3. 可过期的对象&#xff08;Expirable Object&#xff09;3.4. 信号量&#xff08;Semaphore&#xff09;3.5. 分布…

首起针对国内金融企业的开源组件投毒攻击事件

简述 2023年8月9日&#xff0c;墨菲监控到用户名为 snugglejack_org (邮件地址&#xff1a;SnuggleBearrxxhotmail.com&#xff09;的用户发布到 NPM 仓库中的 ws-paso-jssdk 组件包具有发向 https://ql.rustdesk[.]net 的可疑流量&#xff0c;经过确认该组件包携带远控脚本&a…

PHP 从 URL(链接) 字符串中获取参数

PHP 从 URL&#xff08;链接&#xff09; 字符串中获取参数 //URL(链接)字符串 $url https://www.baidu.com/?name小洪帽i&sex男&age999; //parse_url 函数从一个 URL 字符串中获取参数 $urlparse_url($url); //输出获取到的内容 echo "<pre>"; pri…

PyTorch学习笔记(十三)——现有网络模型的使用及修改

以分类模型的VGG为例 vgg16_false torchvision.models.vgg16(weightsFalse) vgg16_true torchvision.models.vgg16(weightsTrue) 设置为 False 的情况&#xff0c;相当于网络模型中的参数都是初始化的、默认的设置为 True 时&#xff0c;网络模型中的参数在数据集上是训练好…

WSL2 ubuntu子系统OpenCV调用本机摄像头的RTSP视频流做开发测试

文章目录 前言一、Ubuntu安装opencv库二、启动 Windows 本机的 RTSP 视频流下载解压 EasyDarwin查看本机摄像头设备开始推流 三、在ubuntu 终端编写代码创建目录及文件创建CMakeLists.txt文件启动 cmake 配置并构建 四、结果展示启动图形界面在图形界面打开终端找到 rtsp_demo运…

汽车租赁管理系统/汽车租赁网站的设计与实现

摘 要 租赁汽车走进社区&#xff0c;走进生活&#xff0c;成为当今生活中不可缺少的一部分。随着汽车租赁业的发展&#xff0c;加强管理和规范管理司促进汽车租赁业健康发展的重要推动力。汽车租赁业为道路运输车辆一种新的融资服务形式、广大人民群众一种新的出行消费方式和…