生成式AI系列 —— DCGAN生成手写数字

1、模型构建

1.1 构建生成器

# 导入软件包
import torch
import torch.nn as nnclass Generator(nn.Module):def __init__(self, z_dim=20, image_size=256):super(Generator, self).__init__()self.layer1 = nn.Sequential(nn.ConvTranspose2d(z_dim, image_size * 32,kernel_size=4, stride=1),nn.BatchNorm2d(image_size * 32),nn.ReLU(inplace=True))self.layer2 = nn.Sequential(nn.ConvTranspose2d(image_size * 32, image_size * 16,kernel_size=4, stride=2, padding=1),nn.BatchNorm2d(image_size * 16),nn.ReLU(inplace=True))self.layer3 = nn.Sequential(nn.ConvTranspose2d(image_size * 16, image_size * 8,kernel_size=4, stride=2, padding=1),nn.BatchNorm2d(image_size * 8),nn.ReLU(inplace=True))self.layer4 = nn.Sequential(nn.ConvTranspose2d(image_size * 8, image_size *4,kernel_size=4, stride=2, padding=1),nn.BatchNorm2d(image_size * 4),nn.ReLU(inplace=True))self.layer5 = nn.Sequential(nn.ConvTranspose2d(image_size * 4, image_size * 2,kernel_size=4, stride=2, padding=1),nn.BatchNorm2d(image_size * 2),nn.ReLU(inplace=True))self.layer6 = nn.Sequential(nn.ConvTranspose2d(image_size * 2, image_size,kernel_size=4, stride=2, padding=1),nn.BatchNorm2d(image_size),nn.ReLU(inplace=True))self.last = nn.Sequential(nn.ConvTranspose2d(image_size, 3, kernel_size=4,stride=2, padding=1),nn.Tanh())# 注意:因为是黑白图像,所以只有一个输出通道def forward(self, z):out = self.layer1(z)out = self.layer2(out)out = self.layer3(out)out = self.layer4(out)out = self.layer5(out)out = self.layer6(out)out = self.last(out)return outif __name__ == "__main__":import matplotlib.pyplot as pltG = Generator(z_dim=20, image_size=256)# 输入的随机数input_z = torch.randn(1, 20)# 将张量尺寸变形为(1,20,1,1)input_z = input_z.view(input_z.size(0), input_z.size(1), 1, 1)#输出假图像fake_images = G(input_z)print(fake_images.shape)img_transformed = fake_images[0].detach().numpy().transpose(1, 2, 0)plt.imshow(img_transformed)plt.show()

[外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传(img-M0oWDbXr-1692468683782)(E:\学习笔记\深度学习笔记\生成模型\GAN\DCGAN.assets\Figure_1.png)]

1.1 构建判别器

class Discriminator(nn.Module):def __init__(self, z_dim=20, image_size=256):super(Discriminator, self).__init__()self.layer1 = nn.Sequential(nn.Conv2d(3, image_size, kernel_size=4,stride=2, padding=1),nn.LeakyReLU(0.1, inplace=True))#注意:因为是黑白图像,所以输入通道只有一个self.layer2 = nn.Sequential(nn.Conv2d(image_size, image_size*2, kernel_size=4,stride=2, padding=1),nn.LeakyReLU(0.1, inplace=True))self.layer3 = nn.Sequential(nn.Conv2d(image_size*2, image_size*4, kernel_size=4,stride=2, padding=1),nn.LeakyReLU(0.1, inplace=True))self.layer4 = nn.Sequential(nn.Conv2d(image_size*4, image_size*8, kernel_size=4,stride=2, padding=1),nn.LeakyReLU(0.1, inplace=True))self.layer5 = nn.Sequential(nn.Conv2d(image_size*8, image_size*16, kernel_size=4,stride=2, padding=1),nn.LeakyReLU(0.1, inplace=True))self.layer6 = nn.Sequential(nn.Conv2d(image_size*16, image_size*32, kernel_size=4,stride=2, padding=1),nn.LeakyReLU(0.1, inplace=True))self.last = nn.Conv2d(image_size*32, 1, kernel_size=4, stride=1)def forward(self, x):out = self.layer1(x)out = self.layer2(out)out = self.layer3(out)out = self.layer4(out)out = self.layer5(out)out = self.layer6(out)out = self.last(out)return outif __name__ == "__main__":#确认程序执行D = Discriminator(z_dim=20, image_size=64)#生成伪造图像input_z = torch.randn(1, 20)input_z = input_z.view(input_z.size(0), input_z.size(1), 1, 1)fake_images = G(input_z)#将伪造的图像输入判别器D中d_out = D(fake_images)#将输出值d_out乘以Sigmoid函数,将其转换成0~1的值print(torch.sigmoid(d_out))

[外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传(img-SEGvF8xh-1692468683784)(E:\学习笔记\深度学习笔记\生成模型\GAN\DCGAN.assets\image-20230817224333376.png)]

2、数据集构建

import os
import sys
sys.path.append(os.path.abspath(os.path.join(os.getcwd(), ".")))
import time
from PIL import Image
import torch
import torch.utils.data as data
import torch.nn as nnfrom torchvision import transforms
from model.DCGAN import Generator, Discriminator
from matplotlib import pyplot as pltdef make_datapath_list(root):"""创建用于学习和验证的图像数据及标注数据的文件路径列表。 """train_img_list = list() #保存图像文件的路径for img_idx in range(200):img_path = f"{root}/img_7_{str(img_idx)}.jpg"train_img_list.append(img_path)img_path = f"{root}/img_8_{str(img_idx)}.jpg"train_img_list.append(img_path)return train_img_listclass ImageTransform:"""图像的预处理类"""def __init__(self, mean, std):self.data_transform = transforms.Compose([transforms.ToTensor(), transforms.Normalize(mean, std)])def __call__(self, img):return self.data_transform(img)class GAN_Img_Dataset(data.Dataset):"""图像的 Dataset 类,继承自 PyTorchd 的 Dataset 类"""def __init__(self, file_list, transform):self.file_list = file_listself.transform = transformdef __len__(self):'''返回图像的张数'''return len(self.file_list)def __getitem__(self, index):'''获取经过预处理后的图像的张量格式的数据'''img_path = self.file_list[index]img = Image.open(img_path)  # [ 高度 ][ 宽度 ] 黑白# 图像的预处理img_transformed = self.transform(img)return img_transformed# 创建DataLoader并确认执行结果# 创建文件列表
root = "./img_78"
train_img_list = make_datapath_list(root)# 创建Dataset
mean = (0.5)
std = (0.5)
train_dataset = GAN_Img_Dataset(file_list=train_img_list, transform=ImageTransform(mean, std)
)# 创建DataLoader
batch_size = 2
train_dataloader = torch.utils.data.DataLoader(train_dataset, batch_size=batch_size, shuffle=True
)# 确认执行结果
batch_iterator = iter(train_dataloader)  # 转换为迭代器
imges = next(batch_iterator)  # 取出位于第一位的元素
print(imges.size())  # torch.Size([64, 1, 64, 64])

数据请在访问链接获取:

3、train接口实现


def train_model(G, D, dataloader, num_epochs):# 确认是否能够使用GPU加速device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")print("使用设备:", device)# 设置最优化算法g_lr, d_lr = 0.0001, 0.0004beta1, beta2 = 0.0, 0.9g_optimizer = torch.optim.Adam(G.parameters(), g_lr, [beta1, beta2])d_optimizer = torch.optim.Adam(D.parameters(), d_lr, [beta1, beta2])# 定义误差函数criterion = nn.BCEWithLogitsLoss(reduction='mean')# 使用硬编码的参数z_dim = 20mini_batch_size = 8# 将网络载入GPU中G.to(device)D.to(device)G.train()  # 将模式设置为训练模式D.train()  # 将模式设置为训练模式# 如果网络相对固定,则开启加速torch.backends.cudnn.benchmark = True# 图像张数num_train_imgs = len(dataloader.dataset)batch_size = dataloader.batch_size# 设置迭代计数器iteration = 1logs = []# epoch循环for epoch in range(num_epochs):# 保存开始时间t_epoch_start = time.time()epoch_g_loss = 0.0  # epoch的损失总和epoch_d_loss = 0.0  # epoch的损失总和print('-------------')print('Epoch {}/{}'.format(epoch, num_epochs))print('-------------')print('(train)')# 以minibatch为单位从数据加载器中读取数据的循环for imges in dataloader:# --------------------# 1.判别器D的学习# --------------------# 如果小批次的尺寸设置为1,会导致批次归一化处理产生错误,因此需要避免if imges.size()[0] == 1:continue# 如果能使用GPU,则将数据送入GPU中imges = imges.to(device)# 创建正确答案标签和伪造数据标签# 在epoch最后的迭代中,小批次的数量会减少mini_batch_size = imges.size()[0]label_real = torch.full((mini_batch_size,), 1).to(device)label_fake = torch.full((mini_batch_size,), 0).to(device)# 对真正的图像进行判定d_out_real = D(imges)# 生成伪造图像并进行判定input_z = torch.randn(mini_batch_size, z_dim).to(device)input_z = input_z.view(input_z.size(0), input_z.size(1), 1, 1)fake_images = G(input_z)d_out_fake = D(fake_images)# 计算误差d_loss_real = criterion(d_out_real.view(-1), label_real.to(torch.float))d_loss_fake = criterion(d_out_fake.view(-1), label_fake.to(torch.float))d_loss = d_loss_real + d_loss_fake# 反向传播处理g_optimizer.zero_grad()d_optimizer.zero_grad()d_loss.backward()d_optimizer.step()# --------------------# 2.生成器G的学习# --------------------# 生成伪造图像并进行判定input_z = torch.randn(mini_batch_size, z_dim).to(device)input_z = input_z.view(input_z.size(0), input_z.size(1), 1, 1)fake_images = G(input_z)d_out_fake = D(fake_images)# 计算误差g_loss = criterion(d_out_fake.view(-1), label_real.to(torch.float))# 反向传播处理g_optimizer.zero_grad()d_optimizer.zero_grad()g_loss.backward()g_optimizer.step()# --------------------# 3.记录结果# --------------------epoch_d_loss += d_loss.item()epoch_g_loss += g_loss.item()iteration += 1# epoch的每个phase的loss和准确率t_epoch_finish = time.time()print('-------------')print('epoch {} || Epoch_D_Loss:{:.4f} ||Epoch_G_Loss:{:.4f}'.format(epoch, epoch_d_loss / batch_size, epoch_g_loss / batch_size))print('timer:  {:.4f} sec.'.format(t_epoch_finish - t_epoch_start))t_epoch_start = time.time()return G, D

4、训练


G = Generator(z_dim=20, image_size=64)
D = Discriminator(z_dim=20, image_size=64)
# 定义误差函数
criterion = nn.BCEWithLogitsLoss(reduction='mean')
num_epochs = 200
G_update, D_update = train_model(G, D, dataloader=train_dataloader, num_epochs=num_epochs
)

[外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传(img-EvqY6h3G-1692468683786)(E:\学习笔记\深度学习笔记\生成模型\GAN\DCGAN.assets\image-20230820020234209.png)]

5、测试

# 将生成的图像和训练数据可视化
# 反复执行本单元中的代码,直到生成感觉良好的图像为止device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")# 生成用于输入的随机数
batch_size = 8
z_dim = 20
fixed_z = torch.randn(batch_size, z_dim)
fixed_z = fixed_z.view(fixed_z.size(0), fixed_z.size(1), 1, 1)# 生成图像
fake_images = G_update(fixed_z.to(device))# 训练数据
imges = next(iter(train_dataloader))  # 取出位于第一位的元素# 输出结果
fig = plt.figure(figsize=(15, 6))
for i in range(0, 5):# 将训练数据放入上层plt.subplot(2, 5, i + 1)plt.imshow(imges[i][0].cpu().detach().numpy(), 'gray')# 将生成数据放入下层plt.subplot(2, 5, 5 + i + 1)plt.imshow(fake_images[i][0].cpu().detach().numpy(), 'gray')

在这里插入图片描述

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.rhkb.cn/news/101095.html

如若内容造成侵权/违法违规/事实不符,请联系长河编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

无涯教程-TensorFlow - Keras

Keras易于学习的高级Python库,可在TensorFlow框架上运行,它的重点是理解深度学习技术,如为神经网络创建层,以维护形状和数学细节的概念。框架的创建可以分为以下两种类型- 顺序API功能API 无涯教程将使用Jupyter Notebook执行和…

5、css学习5(链接、列表)

1、css可以设置链接的四种状态样式。 a:link - 正常,未访问过的链接a:visited - 用户已访问过的链接a:hover - 当用户鼠标放在链接上时a:active - 链接被点击的那一刻 2、 a:hover 必须在 a:link 和 a:visited 之后, a:active 必须在 a:hover 之后&…

两款开箱即用的Live2d

目录 背景第一款:开箱即用的Live2d在vue项目中使用html页面使用在线预览依赖文件地址配置相关参数成员属性源码 模型下载 第二款:换装模型超多的Live2d在线预览代码示例源码 模型下载 背景 从第一次使用服务器建站已经三年多了,记得那是在2…

php+echarts实现数据可视化实例2

效果: 代码 php <?php include(includes/session.inc); include(includes/SQL_CommonFunctions.inc); ?> <!DOCTYPE html> <html lang"en"><head><meta charset"UTF-8"><meta http-equiv"X-UA-Compatible"…

nginx防盗链

防盗链介绍 通过二次访问&#xff0c;请求头中带有referer&#xff0c;的方式不允许访问静态资源。 我们只希望用户通过反向代理服务器才可以拿到我们的静态资源&#xff0c;不希望别的服务器通过二次请求拿到我们的静态资源。 盗链是指在自己的页面上展示一些并不在自己服务…

Android Selector 的使用

什么是 Selector&#xff1f; Selector 和 Shape 相似&#xff0c;是Drawable资源的一种&#xff0c;可以根据不同的状态&#xff0c;设置不同的图片效果&#xff0c;关键节点 < selector > &#xff0c;例如&#xff1a;我们只需要将Button的 background 属性设置为该dr…

数据结构-->栈

&#x1f495;休对故人思故国&#xff0c;且将新火试新茶&#xff0c;诗酒趁年华&#x1f495; 作者&#xff1a;Mylvzi 文章主要内容&#xff1a;详解链表OJ题 前言&#xff1a; 前面已经学习过顺序表&#xff0c;链表。他们都是线性表&#xff0c;今天要学习的栈也是一种线…

MYSQL数据库

数据库概述 1、基本概念 1.1、数据&#xff1a;(DATA) 描述事物的符号记录,包括数字&#xff0c;文字、图形、图像、声音、档案记录等,以“记录”形式按统一的格式进行存储 1.2、表&#xff1a; 将不同的记录组织在一起&#xff0c;用来存储具体数据 1.3、数据库&#xff1…

Nacos

Nacos介绍 Nacos /nɑ:kəʊs/ 是 Dynamic Naming and Configuration Service的⾸字⺟简称&#xff0c;⼀个更易于构 建云原⽣应⽤的动态服务发现、配置管理和服务管理平台。 在这个介绍中&#xff0c;可以看出Nacos⾄少有三个核⼼功能&#xff1a; 1. 动态服务发现 2. 配…

html动态爱心代码【三】(附源码)

目录 前言 特效 内容修改 完整代码 总结 前言 七夕马上就要到了&#xff0c;为了帮助大家高效表白&#xff0c;下面再给大家带来了实用的HTML浪漫表白代码(附源码)背景音乐&#xff0c;可用于520&#xff0c;情人节&#xff0c;生日&#xff0c;表白等场景&#xff0c;可直…

【Java从入门到精通|1】从特点到第一个Hello World程序

写在前面 在计算机编程领域&#xff0c;Java是一门广泛应用的高级编程语言。它以其强大的跨平台性能、丰富的库和生态系统以及易于学习的语法而备受开发者欢迎。本文将引导您逐步了解Java的特点、如何安装和配置开发环境&#xff0c;以及如何编写您的第一个Java程序。 一、Java…

RISC-V公测平台发布· CoreMark测试报告

一. CoreMark简介 CoreMark是一款用于评估CPU性能的基准测试程序&#xff0c;它包含了多种不同的计算任务&#xff0c;包括浮点数、整数、缓存、内存等方面的测试。CoreMark的测试结果通常被用来作为CPU性能的参考&#xff0c;它可以帮助开发人员和系统管理员评估不同处理器和…

【音视频】基于webrtc的聊天室的设计

目录 术语 webrtc建连流程 系统整体架构 信令服务器房间状态管理 用户加入房间流程 用户加入房间并推流&#xff1a; 其他用户订阅此用户流 用户加入房间并订阅房间其他所有用户 用户退出房间流程 平行集群模式​编辑 第一阶段demo 设计 参考文章 术语 sdp: 在webrt…

冷冻冷藏自动化立体库|HEGERLS四向穿梭车助力打造冷链智能仓储新力量

随着中国仓储物流整体规模和低温产品消费需求的稳步增长&#xff0c;冷链市场应用潜力不断释放。而在实际运行中&#xff0c;由于冷库容量不足、基础设施落后、管理机制欠缺等原因&#xff0c;经常出现“断链”现象&#xff0c;严重威胁到产品质量和消费者安全。 河北沃克金属…

美国FDA医疗器械分类目录数据库查询

最近我们在接到FDA医疗器械咨询项目时&#xff0c;经常收到客户关于公司产品在美国FDA医疗器械认证中或是国内所属的产品类别以及如何查询产品分类的疑问。在这里&#xff0c;我将为大家解答这些问题&#xff0c;希望能够提供帮助&#xff01; 美国FDA医疗器械产品目录中包含了…

CSS如何将浏览器文字设置小于12px

CSS如何将浏览器文字设置小于12px 使用transform: scale进行缩放 transform: scale(0.8);<div><p class"first">第一段文字</p><p class"second">第二段文字</p> </div>.first {font-size: 12px; }.second {font-si…

技术分享 | 如何编写同时兼容 Vue2 和 Vue3 的代码?

LigaAI 的评论编辑器、附件展示以及富文本编辑器都支持在 Vue2&#xff08;Web&#xff09;与 Vue3&#xff08;VSCode、lDEA&#xff09;中使用。这样不仅可以在不同 Vue 版本的工程中间共享代码&#xff0c;还能为后续升级 Vue3 减少一定阻碍。 那么&#xff0c;同时兼容 Vue…

Chapter 14: Using Web Services | Python for Everybody 讲义笔记_En

文章目录 Python for Everybody课程简介Python and Web ServicesUsing Web ServiceseXtensible Markup Language - XMLParsing XMLJavaScript Object Notation - JSONParsing JSONApplication Programming InterfacesSecurity and API usageGlossary Python for Everybody Expl…

java 项目运行时,后端控制台出现空指针异常---java.lang.NullPointerException

项目场景&#xff1a; 提示&#xff1a;这里简述项目背景&#xff1a; 场景如下&#xff1a; java 项目运行时&#xff0c;后端控制台出现如下图所示报错信息&#xff1a;— 问题描述 提示&#xff1a;这里描述项目中遇到的问题&#xff1a; java 项目运行时&#xff0c;后…

代码随想录算法训练营第60天|动态规划part17| 647. 回文子串、516.最长回文子序列、动态规划总结篇

代码随想录算法训练营第60天&#xff5c;动态规划part17&#xff5c; 647. 回文子串、516.最长回文子序列、动态规划总结篇 647. 回文子串 647. 回文子串 思路&#xff1a; 暴力解法 两层for循环&#xff0c;遍历区间起始位置和终止位置&#xff0c;然后还需要一层遍历判断…