CGAN|生成手势图像|可控制生成

  •     🍨 本文为🔗365天深度学习训练营 中的学习记录博客
  • 🍦 参考文章:TensorFlow入门实战|第3周:天气识别
  • 🍖 原作者:K同学啊|接辅导、项目定制

CGAN(条件生成对抗网络)的原理是在原始GAN的基础上,为生成器和判别器提供 额外的条件信息。

CGAN通过将条件信息(如类别标签或其他辅助信息)加入生成器和判别器的输入中,使得生成器能够根据这些条件信息生成特定类型的数据,而判别器则负责区分真实数据和生成数据是否符合这些条件。这种方式让生成器在生成数据时有了明确的方向,从而提高了生成数据的质量与相关性。

CGAN的特点包括有监督学习、联合隐层表征、可控性、使用卷积结构等,其具体内容为:

有监督学习:CGAN通过额外信息的使用,将原本无监督的GAN转变为一种有监督的学习模式,这使得网络的训练更加目标明确,生成结果更加符合预期。
联合隐层表征:在生成模型中,噪声输入和条件信息共同构成了联合隐层表征,这有助于生成更多样化且具有特定属性的数据。
可控性:CGAN的一个关键特点是提高了生成过程的可控性,即可以通过调整条件信息来指导模型生成特定类型的数据。
使用卷积结构:CGAN可以采用卷积神经网络作为其内部结构,这在图像相关的任务中尤其有效,因为它能够捕捉到局部特征,并提高模型对细节的处理能力。

一、前期工作

import torch
import numpy as np
import torch.nn as nn
import torch.optim as optim
from torchvision import datasets, transforms
from torch.autograd import Variable
from torchvision.utils import save_image, make_grid
from torchsummary import summary
import matplotlib.pyplot as pltdevice = torch.device('cuda' if torch.cuda.is_available() else 'cpu')batch_size = 128
train_transform = transforms.Compose([transforms.Resize(128),transforms.ToTensor(),transforms.Normalize([0.5,0.5,0.5], [0.5,0.5,0.5])])train_dataset = datasets.ImageFolder(root="H:/G3/rps/rps", transform=train_transform)
train_loader = torch.utils.data.DataLoader(dataset=train_dataset, batch_size=batch_size, shuffle=True,num_workers=6)
def show_images(dl):for images, _ in dl:fig, ax = plt.subplots(figsize=(10, 10))ax.set_xticks([]); ax.set_yticks([])ax.imshow(make_grid(images.detach(), nrow=16).permute(1, 2, 0))breakshow_images(train_loader)

 

二、构建模型 

latent_dim = 100
n_classes = 3
embedding_dim = 100def weights_init(m):classname = m.__class__.__name__if classname.find('Conv') != -1:torch.nn.init.normal_(m.weight, 0.0, 0.02)elif classname.find('BatchNorm') != -1:torch.nn.init.normal_(m.weight, 1.0, 0.02)torch.nn.init.zeros_(m.bias)class Generator(nn.Module):def __init__(self):super(Generator, self).__init__()self.label_conditioned_generator = nn.Sequential(nn.Embedding(n_classes, embedding_dim), nn.Linear(embedding_dim, 16)            )self.latent = nn.Sequential(nn.Linear(latent_dim, 4*4*512),  nn.LeakyReLU(0.2, inplace=True)  )self.model = nn.Sequential( nn.ConvTranspose2d(513, 64*8, 4, 2, 1, bias=False),nn.BatchNorm2d(64*8, momentum=0.1, eps=0.8),  nn.ReLU(True),            nn.ConvTranspose2d(64*8, 64*4, 4, 2, 1, bias=False),nn.BatchNorm2d(64*4, momentum=0.1, eps=0.8),nn.ReLU(True),     nn.ConvTranspose2d(64*4, 64*2, 4, 2, 1, bias=False),nn.BatchNorm2d(64*2, momentum=0.1, eps=0.8),nn.ReLU(True),       nn.ConvTranspose2d(64*2, 64*1, 4, 2, 1, bias=False),nn.BatchNorm2d(64*1, momentum=0.1, eps=0.8),nn.ReLU(True),       nn.ConvTranspose2d(64*1, 3, 4, 2, 1, bias=False),nn.Tanh()  )def forward(self, inputs):noise_vector, label = inputs  label_output = self.label_conditioned_generator(label)     label_output = label_output.view(-1, 1, 4, 4)        latent_output = self.latent(noise_vector)     latent_output = latent_output.view(-1, 512, 4, 4) concat = torch.cat((latent_output, label_output), dim=1)image = self.model(concat)return imagegenerator = Generator().to(device)
generator.apply(weights_init)
a = torch.ones(100)
b = torch.ones(1)
b = b.long()
a = a.to(device)
b = b.to(device)import torch
import torch.nn as nnclass Discriminator(nn.Module):def __init__(self):super(Discriminator, self).__init__()self.label_condition_disc = nn.Sequential(nn.Embedding(n_classes, embedding_dim),     nn.Linear(embedding_dim, 3*128*128)         )self.model = nn.Sequential(nn.Conv2d(6, 64, 4, 2, 1, bias=False),      nn.LeakyReLU(0.2, inplace=True),             nn.Conv2d(64, 64*2, 4, 3, 2, bias=False),    nn.BatchNorm2d(64*2, momentum=0.1, eps=0.8),  nn.LeakyReLU(0.2, inplace=True),nn.Conv2d(64*2, 64*4, 4, 3, 2, bias=False),  nn.BatchNorm2d(64*4, momentum=0.1, eps=0.8),nn.LeakyReLU(0.2, inplace=True),nn.Conv2d(64*4, 64*8, 4, 3, 2, bias=False),  nn.BatchNorm2d(64*8, momentum=0.1, eps=0.8),nn.LeakyReLU(0.2, inplace=True),nn.Flatten(),                               nn.Dropout(0.4),                            nn.Linear(4608, 1),                         nn.Sigmoid()                                )def forward(self, inputs):img, label = inputslabel_output = self.label_condition_disc(label)label_output = label_output.view(-1, 3, 128, 128)concat = torch.cat((img, label_output), dim=1)output = self.model(concat)return outputa = torch.ones(2,3,128,128)
b = torch.ones(2,1)
b = b.long()
a = a.to(device)
b = b.to(device)c = discriminator((a,b))

 三、训练模型及可视化

 这一部分主要定义初始化权重,构建鉴别器和生成器。

# 定义损失函数
adversarial_loss = nn.BCELoss() def generator_loss(fake_output, label):gen_loss = adversarial_loss(fake_output, label)return gen_lossdef discriminator_loss(output, label):disc_loss = adversarial_loss(output, label)return disc_loss
learning_rate = 0.0002G_optimizer = optim.Adam(generator.parameters(),     lr = learning_rate, betas=(0.5, 0.999))
D_optimizer = optim.Adam(discriminator.parameters(), lr = learning_rate, betas=(0.5, 0.999))# 设置训练的总轮数
num_epochs = 100
# 初始化用于存储每轮训练中判别器和生成器损失的列表
D_loss_plot, G_loss_plot = [], []# 循环进行训练
for epoch in range(1, num_epochs + 1):# 初始化每轮训练中判别器和生成器损失的临时列表D_loss_list, G_loss_list = [], []# 遍历训练数据加载器中的数据for index, (real_images, labels) in enumerate(train_loader):# 清空判别器的梯度缓存D_optimizer.zero_grad()# 将真实图像数据和标签转移到GPU(如果可用)real_images = real_images.to(device)labels      = labels.to(device)# 将标签的形状从一维向量转换为二维张量(用于后续计算)labels = labels.unsqueeze(1).long()# 创建真实目标和虚假目标的张量(用于判别器损失函数)real_target = Variable(torch.ones(real_images.size(0), 1).to(device))fake_target = Variable(torch.zeros(real_images.size(0), 1).to(device))# 计算判别器对真实图像的损失D_real_loss = discriminator_loss(discriminator((real_images, labels)), real_target)# 从噪声向量中生成假图像(生成器的输入)noise_vector = torch.randn(real_images.size(0), latent_dim, device=device)noise_vector = noise_vector.to(device)generated_image = generator((noise_vector, labels))# 计算判别器对假图像的损失(注意detach()函数用于分离生成器梯度计算图)output = discriminator((generated_image.detach(), labels))D_fake_loss = discriminator_loss(output, fake_target)# 计算判别器总体损失(真实图像损失和假图像损失的平均值)D_total_loss = (D_real_loss + D_fake_loss) / 2D_loss_list.append(D_total_loss)# 反向传播更新判别器的参数D_total_loss.backward()D_optimizer.step()# 清空生成器的梯度缓存G_optimizer.zero_grad()# 计算生成器的损失G_loss = generator_loss(discriminator((generated_image, labels)), real_target)G_loss_list.append(G_loss)# 反向传播更新生成器的参数G_loss.backward()G_optimizer.step()# 打印当前轮次的判别器和生成器的平均损失print('Epoch: [%d/%d]: D_loss: %.3f, G_loss: %.3f' % ((epoch), num_epochs, torch.mean(torch.FloatTensor(D_loss_list)), torch.mean(torch.FloatTensor(G_loss_list))))# 将当前轮次的判别器和生成器的平均损失保存到列表中D_loss_plot.append(torch.mean(torch.FloatTensor(D_loss_list)))G_loss_plot.append(torch.mean(torch.FloatTensor(G_loss_list)))if epoch%10 == 0:# 将生成的假图像保存为图片文件save_image(generated_image.data[:50], './sample_%d' % epoch + '.png', nrow=5, normalize=True)# 将当前轮次的生成器和判别器的权重保存到文件torch.save(generator.state_dict(), './generator_epoch_%d.pth' % (epoch))torch.save(discriminator.state_dict(), './discriminator_epoch_%d.pth' % (epoch))

 from numpy.random import randint, randn
from numpy import linspace
from matplotlib import pyplot as plt, gridspec
import numpy as np# Assuming 'generator' and 'device' are defined earlier in your codegenerator.load_state_dict(torch.load('./generator_epoch_100.pth'), strict=False)
generator.eval()interpolated = randn(100)
interpolated = torch.tensor(interpolated).to(device).type(torch.float32)label = 0
labels = torch.ones(1) * label
labels = labels.to(device).unsqueeze(1).long()predictions = generator((interpolated, labels))
predictions = predictions.permute(0, 2, 3, 1).detach().cpu()import warnings
warnings.filterwarnings("ignore")
plt.rcParams['font.sans-serif'] = ['SimHei']
plt.rcParams['axes.unicode_minus'] = False
plt.rcParams['figure.dpi'] = 100plt.figure(figsize=(8, 3))pred = (predictions[0, :, :, :] + 1) * 127.5
pred = np.array(pred)
plt.imshow(pred.astype(np.uint8))
plt.show()

 代码中的操作将预测结果的值加1(这样所有的值都变为非负数),然后乘以127.5,最终得到的值就在0到255之间。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.rhkb.cn/news/330116.html

如若内容造成侵权/违法违规/事实不符,请联系长河编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

【Crypto】Rabbit

文章目录 一、Rabbit解题感悟 一、Rabbit 题目提示很明显是Rabbit加密,直接解 小小flag,拿下! 解题感悟 提示的太明显了

二分查找

题目链接 题目: 分析: 如果按照从头到尾的顺序一次比较, 每次只能舍弃一个元素, 效率是非常低的, 而且没有用到题目的要求, 数组是有序的因为数组是有序的, 所以如果我们随便找到一个位置, 和目标元素进行比较, 如果大于目标元素, 说明该位置的右侧元素都比目标元素大, 都可…

一键恢复安卓手机数据:3个快速简便的解决方案!

安卓手机作为我们不可或缺的数字伙伴,承载着大量珍贵的个人和工作数据。然而,随着我们在手机上进行各种操作,不可避免地会遇到一些令人头痛的问题,比如意外删除文件、系统故障或其他不可预见的情况,导致重要数据的丢失…

汽车生产线中的工业机器人应用HT3S-PNS-ECS(EtherCAT/Profinet)协议转换通讯方案案例分析

汽车生产线中的工业机器人应用HT3S-PNS-ECS(EtherCAT/Profinet)协议转换通讯方案案例分析 ——北京中科易联科技有限公司供稿—— 一、摘要 随着工业自动化的快速发展,汽车生产线对工业机器人的依赖日益增加。HT3S-PNS-ECS作为工业机器人中的关键组件,其…

OpenFeign快速入门 替代RestTemplate

1.引入依赖 <!--openFeign--><dependency><groupId>org.springframework.cloud</groupId><artifactId>spring-cloud-starter-openfeign</artifactId></dependency><!--负载均衡器--><dependency><groupId>org.spr…

重学java 38.创建线程的方式⭐

It is during our darkest moments that we must focus to see the light —— 24.5.24 一、第一种方式_继承extends Thread方法 1.定义一个类,继承Thread 2.重写run方法,在run方法中设置线程任务(所谓的线程任务指的是此线程要干的具体的事儿,具体执行的代码) 3.创建自定义线程…

响应式处理-一篇打尽

纯pc端响应式 pc端平常用到的响应式布局 大致就如下三种&#xff0c;当然也会有其他方法&#xff0c;欢迎评论区补充 将div height、width设置成100% flex布局 flex布局主要是将flex-wrap: wrap&#xff0c; 最后&#xff0c;你可以通过给子元素设置 flex 属性来控制它们的…

c4d云渲染是工程文件会暴露吗?

在数字创意产业飞速发展的今天&#xff0c;C4D云渲染因其高效便捷而备受欢迎。然而&#xff0c;随着技术应用的深入&#xff0c;人们开始关注一个核心问题&#xff1a;在享受云渲染带来的便利的同时&#xff0c;C4D工程文件安全吗&#xff1f;是否会有暴露的风险&#xff1f;下…

【30天精通Prometheus:一站式监控实战指南】第4天:node_exporter从入门到实战:安装、配置详解与生产环境搭建指南,超详细

亲爱的读者们&#x1f44b;   欢迎加入【30天精通Prometheus】专栏&#xff01;&#x1f4da; 在这里&#xff0c;我们将探索Prometheus的强大功能&#xff0c;并将其应用于实际监控中。这个专栏都将为你提供宝贵的实战经验。&#x1f680;   Prometheus是云原生和DevOps的…

绿联硬盘数据恢复方法:安全、高效找回珍贵数据

在数字化时代&#xff0c;硬盘承载着大量的个人和企业数据&#xff0c;一旦数据丢失或损坏&#xff0c;后果往往不堪设想。绿联硬盘以其稳定的性能和良好的口碑赢得了众多用户的信赖&#xff0c;但即便如此&#xff0c;数据恢复问题仍然是用户可能面临的一大挑战。本文将为您详…

炫酷网页设计:HTML5 + CSS3打造8种心形特效

你以为520过去了&#xff0c;你就逃过一劫了&#xff1f;那不是还有分手呢&#xff0c;那不是还得再找对象呢&#xff0c;那不是还有七夕节呢&#xff0c;那不是还有纪念日呢&#xff0c;那不是还有各种各样的节日呢&#xff0c;所以呀&#xff0c;这8种HTML5 CSS3打造8种心形…

Java 程序的基本结构,编写和运行第一个Java程序(Hello World)!

Java程序的基本结构 Java是一种面向对象的编程语言&#xff0c;其程序结构较为规范。Java程序由一个或多个类组成&#xff0c;每个类包含数据成员和方法。 1. 包声明&#xff08;Package Declaration&#xff09; 包是Java中组织类的一种机制&#xff0c;使用包可以避免类名…

华为编程题目(实时更新)

1.大小端整数 计算机中对整型数据的表示有两种方式&#xff1a;大端序和小端序&#xff0c;大端序的高位字节在低地址&#xff0c;小端序的高位字节在高地址。例如&#xff1a;对数字 65538&#xff0c;其4字节表示的大端序内容为00 01 00 02&#xff0c;小端序内容为02 00 01…

【Django】从零开始学Django(持续更新中)

pip install Djangopython manage.py startapp index运行&#xff1a; 成功&#xff01;&#xff01;&#xff01; 在templates中新建index.html文件&#xff1a;

SpringBoot搭建Eureka注册中心

系列文章目录 文章目录 系列文章目录前言前言 前些天发现了一个巨牛的人工智能学习网站,通俗易懂,风趣幽默,忍不住分享一下给大家。点击跳转到网站,这篇文章男女通用,看懂了就去分享给你的码吧。 1、Spring-Cloud Euraka介绍 Spring-Cloud Euraka是Spring Cloud集合中一…

linux系统CPU持续飙高的排查方法

目录 前言&#xff1a; 1、查看系统cpu使用情况 2、找出占用cpu高的进程 3、进一步分析进程占用的原因&#xff01;&#xff01;&#xff01; 4、解决办法 前言&#xff1a; 如果一台服务器&#xff0c;它的cpu使用率一直处于一个高峰值&#xff0c;此时服务器可能导致无…

【数据结构与算法】之堆及其实现!

目录 1、堆的概念及结构 2、堆的实现 2.1 堆向下和向上调整算法 2.2 堆的创建 2.3 建堆时间复杂度 2.4 堆的插入 2.5 堆的删除 2.6 完整代码 3、完结散花 个人主页&#xff1a;秋风起&#xff0c;再归来~ 数据结构与算法 个人格言&#…

Hadoop3:HDFS的Fsimage和Edits文件介绍

一、概念 Fsimage文件&#xff1a;HDFS文件系统元数据的一个永久性的检查点&#xff0c;其中包含HDFS文件系统的所有目 录和文件inode的序列化信息。 Edits文件&#xff1a;存放HDFS文件系统的所有更新操作的路径&#xff0c;文件系统客户端执行的所有写操作首先 会被记录到Ed…

【状态压缩dp】最短Hamilton路径

题意&#xff1a; 从0开始&#xff0c;必须走完全部节点&#xff0c;且不重复走&#xff0c;不漏走的最短距离 关键思路&#xff1a; 从0开始 走到j 节点所走情况是 state【state表示经过的点&#xff0c;不代表顺序&#xff0c;就表示经过的点】 f[i][j]表示 从0开始 走到j…

经纬恒润第三代重载自动驾驶平板车

随着无人驾驶在封闭场地和干线道路场景的加速落地&#xff0c;港口作为无人化运营的先行者&#xff0c;其场景的复杂度、特殊性对无人化运营的技术提出了各种挑战。经纬恒润作为无人驾驶解决方案提供商&#xff0c;见证了港口在无人化运营方面的尝试及发展&#xff0c;并深度参…