【图像合成】基于DCGAN典型网络的MNIST字符生成(pytorch)

关于

 

近年来,基于卷积网络(CNN)的监督学习已经 在计算机视觉应用中得到了广泛的采用。相比之下,无监督 使用 CNN 进行学习受到的关注较少。在这项工作中,我们希望能有所帮助 缩小了 CNN 在监督学习和无监督学习方面的成功之间的差距。我们介绍一类称为深度卷积生成的 CNN 对抗性网络(DCGAN),具有一定的架构限制,以及 证明他们是无监督学习的有力候选人。训练 在各种图像数据集上,我们展示了令人信服的证据,表明我们的深度卷积对抗对学习了从对象部分到 生成器和鉴别器中的场景。此外,我们使用学到的 新任务的特征 - 证明它们作为一般图像表示的适用性。(https://arxiv.org/pdf/1511.06434.pdf)

工具

 数据集

方法实现

加载必要的库函数和自定义函数

import torch
import torchvision
import torch.nn as nn
import torch.nn.functional as Ffrom torch.utils.data import DataLoader
from torchvision import datasets
from torchvision import transforms
from torchvision.utils import save_image
def get_sample_image(G, n_noise):"""save sample 100 images"""z = torch.randn(100, n_noise).to(DEVICE)y_hat = G(z).view(100, 28, 28) # (100, 28, 28)result = y_hat.cpu().data.numpy()img = np.zeros([280, 280])for j in range(10):img[j*28:(j+1)*28] = np.concatenate([x for x in result[j*10:(j+1)*10]], axis=-1)return img

定义判别模型

class Discriminator(nn.Module):"""Convolutional Discriminator for MNIST"""def __init__(self, in_channel=1, num_classes=1):super(Discriminator, self).__init__()self.conv = nn.Sequential(# 28 -> 14nn.Conv2d(in_channel, 512, 3, stride=2, padding=1, bias=False),nn.BatchNorm2d(512),nn.LeakyReLU(0.2),# 14 -> 7nn.Conv2d(512, 256, 3, stride=2, padding=1, bias=False),nn.BatchNorm2d(256),nn.LeakyReLU(0.2),# 7 -> 4nn.Conv2d(256, 128, 3, stride=2, padding=1, bias=False),nn.BatchNorm2d(128),nn.LeakyReLU(0.2),nn.AvgPool2d(4),)self.fc = nn.Sequential(# reshape input, 128 -> 1nn.Linear(128, 1),nn.Sigmoid(),)def forward(self, x, y=None):y_ = self.conv(x)y_ = y_.view(y_.size(0), -1)y_ = self.fc(y_)return y_

定义生成模型

class Generator(nn.Module):"""Convolutional Generator for MNIST"""def __init__(self, input_size=100, num_classes=784):super(Generator, self).__init__()self.fc = nn.Sequential(nn.Linear(input_size, 4*4*512),nn.ReLU(),)self.conv = nn.Sequential(# input: 4 by 4, output: 7 by 7nn.ConvTranspose2d(512, 256, 3, stride=2, padding=1, bias=False),nn.BatchNorm2d(256),nn.ReLU(),# input: 7 by 7, output: 14 by 14nn.ConvTranspose2d(256, 128, 4, stride=2, padding=1, bias=False),nn.BatchNorm2d(128),nn.ReLU(),# input: 14 by 14, output: 28 by 28nn.ConvTranspose2d(128, 1, 4, stride=2, padding=1, bias=False),nn.Tanh(),)def forward(self, x, y=None):x = x.view(x.size(0), -1)y_ = self.fc(x)y_ = y_.view(y_.size(0), 512, 4, 4)y_ = self.conv(y_)return y_

 模型超参数定义配置

batch_size = 64criterion = nn.BCELoss()
D_opt = torch.optim.Adam(D.parameters(), lr=0.001, betas=(0.5, 0.999))
G_opt = torch.optim.Adam(G.parameters(), lr=0.001, betas=(0.5, 0.999))max_epoch = 30 # need more than 20 epochs for training generator
step = 0
n_critic = 1 # for training more k steps about Discriminator
n_noise = 100D_labels = torch.ones([batch_size, 1]).to(DEVICE) # Discriminator Label to real
D_fakes = torch.zeros([batch_size, 1]).to(DEVICE) # Discriminator Label to fake

 模型训练

for epoch in range(max_epoch):for idx, (images, labels) in enumerate(data_loader):# Training Discriminatorx = images.to(DEVICE)x_outputs = D(x)D_x_loss = criterion(x_outputs, D_labels)z = torch.randn(batch_size, n_noise).to(DEVICE)z_outputs = D(G(z))D_z_loss = criterion(z_outputs, D_fakes)D_loss = D_x_loss + D_z_lossD.zero_grad()D_loss.backward()D_opt.step()if step % n_critic == 0:# Training Generatorz = torch.randn(batch_size, n_noise).to(DEVICE)z_outputs = D(G(z))G_loss = criterion(z_outputs, D_labels)D.zero_grad()G.zero_grad()G_loss.backward()G_opt.step()if step % 500 == 0:print('Epoch: {}/{}, Step: {}, D Loss: {}, G Loss: {}'.format(epoch, max_epoch, step, D_loss.item(), G_loss.item()))if step % 1000 == 0:G.eval()img = get_sample_image(G, n_noise)imsave('./{}_step{}.jpg'.format(MODEL_NAME, str(step).zfill(3)), img, cmap='gray')G.train()step += 1

测试生成效果

# generation to image
G.eval()
imshow(get_sample_image(G, n_noise), cmap='gray')

 

模型和状态参量保存

def save_checkpoint(state, file_name='checkpoint.pth.tar'):torch.save(state, file_name)# Saving params.
# torch.save(D.state_dict(), 'D_c.pkl')
# torch.save(G.state_dict(), 'G_c.pkl')
save_checkpoint({'epoch': epoch + 1, 'state_dict':D.state_dict(), 'optimizer' : D_opt.state_dict()}, 'D_dc.pth.tar')
save_checkpoint({'epoch': epoch + 1, 'state_dict':G.state_dict(), 'optimizer' : G_opt.state_dict()}, 'G_dc.pth.tar')

应用

DCGAN作为一个成熟的生成模型,在自然图像,医学图像,医学电生理信号数据分析中,都可以用来实现数据的合成,达到数据增强的目的,同时,如何减少增强数据对于后端任务的不利干扰,也是一个需要关注的方面。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.rhkb.cn/news/289173.html

如若内容造成侵权/违法违规/事实不符,请联系长河编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

FPGA时钟资源详解(2)——Clock-Capable Inputs

FPGA时钟系列文章总览:FPGA原理与结构(14)——时钟资源https://ztzhang.blog.csdn.net/article/details/132307564 目录 一、概述 1.1 为什么使用CC 1.2 如何使用CC 二、Clock-Capable Inputs 2.1 SRCC 2.2 MRCC 2.3 其他用途 2.3.1…

element-plus中的日期时间选择器el-date-picker;日期选择面板中选定起始与结束的日期只能改具体的时刻,日期默认是一个月没法动态修改问题

目前遇到一个问题,在使用element-plus中的日期时间选择器el-date-picker,type为datetimerange时,展示的日期选择面板有两个输入框,开始时间和结束时间,element-plus只提供了default-time 使用datetimerange进行范围选择…

我们是如何测试人工智能的(八)包含大模型的企业级智能客服系统拆解与测试方法 -- 大模型 RAG

大模型的缺陷 -- 幻觉 接触过 GPT 这样的大模型产品的同学应该都知道大模型的强大之处, 很多人都应该调戏过 GPT,跟 GPT 聊很多的天。 作为一个面向大众的对话机器人,GPT 明显是鹤立鸡群,在世界范围内还没有看到有能跟 GPT 扳手腕…

五款会让你爱不离手的编程工具,用了都说好,加班变得少。

作为一名“CV工程师” 勤勤恳恳地复制粘贴 没想到AI来了之后 连搬运都不用了! 融入了AI代码生成能力的工具 真的能代替程序员的位置吗? 看完这5个AI工具 咱们再来说结论吧! aiXcoder 在平时写代码的过程中,经常需要通过上…

flutter3_douyin:基于flutter3+dart3短视频直播实例|Flutter3.x仿抖音

flutter3-dylive 跨平台仿抖音短视频直播app实战项目。 全新原创基于flutter3.19.2dart3.3.0getx等技术开发仿抖音app实战项目。实现了类似抖音整屏丝滑式上下滑动视频、左右滑动切换页面模块,直播间进场/礼物动效,聊天等模块。 运用技术 编辑器&#x…

吴恩达2022机器学习专项课程(一) 4.2 梯度下降实践

问题预览/关键词 本节内容梯度下降更新w的公式梯度下降更新b的公式的含义α的含义为什么要控制梯度下降的幅度?导数项的含义为什么要控制梯度下降的方向?梯度下降何时结束?梯度下降算法收敛的含义正确更新梯度下降的顺序错误更新梯度下降的顺…

网络编程之流式套接字

流式套接字(SOCK_STREAM)是一种网络编程接口,它提供了一种面向连接的、可靠的、无差错和无重复的数据传输服务。这种服务保证了数据按照发送的顺序被接收,使得数据传输具有高度的稳定性和正确性。通常用于那些对数据的顺序和完整性…

【vue3学习笔记(一)】vue3简介;使用vue-cli创建工程;使用vite创建工程;分析工程结构;安装开发者工具

尚硅谷Vue2.0Vue3.0全套教程丨vuejs从入门到精通 对应课程136-140节 课程 P136节 《vue3简介》笔记 课程 P137节 《使用vue-cli创建工程》笔记 官方文档: https://cli.vuejs.org/zh/guide/creating-a-project.html#vue-create官方文档地址 查看vue-cli版本&#x…

不要盲目开抖店,这才是开店的正确流程,2024全新版教程

我是王路飞。 抖音小店和视频号小店,我更建议没有经验的新手去做抖音小店。 虽然现在抖音小店不属于是一个蓝海项目了,但它依旧是我们普通人借助抖音流量变现非常重要的一个渠道,甚至没有之一。 至于视频号小店,可以说是当下最…

【JSON2WEB】11 基于 Amis 角色功能权限设置页面

【JSON2WEB】01 WEB管理信息系统架构设计 【JSON2WEB】02 JSON2WEB初步UI设计 【JSON2WEB】03 go的模板包html/template的使用 【JSON2WEB】04 amis低代码前端框架介绍 【JSON2WEB】05 前端开发三件套 HTML CSS JavaScript 速成 【JSON2WEB】06 JSON2WEB前端框架搭建 【J…

油缸位置传感器871D-DW2NP524-N4

概述 油缸位置传感器是一种使用电感原理来检测物体接近的开关装置。它通过感应物体的电磁场来判断物体的位置,并将信号转化为电信号输出。当物体靠近或远离电感式接近开关时,物体的电磁场会改变,从而使接近开关产生不同的信号输出。电感式接…

Chrome 插件 tabs API 解析

Chrome.tabs API 解析 使用 chrome.tabs API 与浏览器的标签页系统进行交互,可以使用此 API 在浏览器中创建、修改和重新排列标签页 Tabs API 不仅提供操作和管理标签页的功能,还可以检测标签页的语言、截取屏幕截图,以及与标签页的内容脚本…

MySQL面试汇总(一)

MySQL 如何定位慢查询 如何优化慢查询 索引及其底层实现 索引是一个数据结构,可以帮助MySQL高效获取数据。 聚簇索引和非聚簇索引 覆盖索引 索引创建原则 联合索引

6. 学习方法和Java概述

文章目录 1)学习方法2)Java是什么? 1)学习方法 作为一个0基础入门的同学,在刚开始学习的时候,我们不要追求知识点的深度,而是要追求知识点的广度。简单来说,学一个知识点不要想的太…

TCP和UDP分别是什么?TCP和UDP的区别

在计算机网络通信中,TCP(Transmission Control Protocol)和UDP(User Datagram Protocol)是两种重要的传输层协议,它们在数据传输过程中发挥着关键作用。本文将深入探讨TCP和UDP的定义、特点以及它们之间的区…

【Qt】QDialog对话框

目录 一、概念 二、对话框的分类 2.1 模态对话框 2.2 非模态对话框 2.3 混合属性对话框 三、消息对话框QMessageBox 四、颜色对话框QColorDialog 五、文件对话框QFileDialog 六、字体对话框QFontDialog 七、输入对话框QInputDialog 一、概念 对话框是GUI程序中不可或…

MrDoc寻思文档 个人wiki搭建

通过Docker快速搭建个人wiki,开源wiki系统用于知识沉淀,教学管理,技术学习 部署步骤 ## 拉取 MrDoc 代码 ### 开源版: git clone https://gitee.com/zmister/MrDoc.git### 专业版: git clone https://{用户名}:{密码…

「媒体宣传」如何针对不同行业制定媒体邀约方案

传媒如春雨,润物细无声,大家好,我是51媒体网胡老师。 针对不同行业制定媒体邀约方案时,需要考虑行业特点、目标受众、媒体偏好以及市场趋势等因素。 一、懂行业 先弄清楚你的行业是啥样,有啥特别之处。 了解行业的热…

PPT没保存怎么恢复?3个方法(更新版)!

“我刚做完一个PPT,正准备保存的时候电脑没电自动关机了,打开电脑后才发现我的PPT没保存。这可怎么办?还有机会恢复吗?” 在日常办公和学习中,PowerPoint是制作演示文稿的重要工具。我们会在各种场景下使用它。但有时候…

鸿蒙OS开发实例:【工具类封装-页面路由】

import common from ohos.app.ability.common; import router from ohos.router 封装app内的页面之间跳转、app与app之间的跳转工具类 【使用要求】 DevEco Studio 3.1.1 Release api 9 【使用示例】 import MyRouterUtil from ../common/utils/MyRouterUtil MyRouterUtil…