G7 - Semi-Supervised GAN 理论与实战

  • 🍨 本文为🔗365天深度学习训练营 中的学习记录博客
  • 🍖 原作者:K同学啊

目录

  • 理论知识
  • 模型实现
    • 引用、配置参数
    • 初始化权重
    • 定义算法模型
    • 模型配置
    • 模型训练
    • 训练模型
  • 模型效果
  • 总结与心得体会


理论知识

在条件GAN中,判别器只用来判断图像的真和假,到了条件GAN中,图像本身其实是有标签的,这时候我们可能会想,为什么不直接让判别器输出图像的标签呢?本节要探究的SGAN就实现了这样一个GAN网络。

SGAN将GAN拓展到半监督学习,通过强制判别器D来输出类别标签来实现。

SGAN在一个数据集上训练一个生成器G和一个判别器D,输入是N类中的一个,在训练的时候,判别器D也被用于预测输入是属于N+1类中的哪一个,这个N+1是对应了生成器G的输出,这里的判别器D同时也充当起了分类器C的效果。

经过实验发现,这种方法可以用于训练效果更好的判别器D,并且可以比普通的GAN产生更加高质量的样本。

Semi-Supervised GAN有如下成果:

  • 作者对GANs做了一个新的扩展,允许它同时学习一个生成模型和一个分类器,我们把这个拓展称为半监督GAN或者SGAN
  • 实验结果表明,SGAN在有限数据集上比没有生成部分的基准分类器提升了分类性能
  • 实验结果表明,SGAN可以显著地提升生成样本的质量并降低生成器的训练时间
    模型判别器工作示意
    效果对比
    对比生成效果发现,SGAN比普通的DCGAN算法的结果更好。

模型实现

引用、配置参数

import argparse
import os
import numpy as np
import mathfrom torchvision import datasets, transforms
from torchvision.utils import save_imagefrom torch.utils.data import DataLoader
from torch.autograd import Variableimport torch.nn as nn
import torch.nn.functional as F
import torch# 创建结果输出目录,没有就新增,有就跳过
os.makedirs('images', exist_ok=True)# 参数
n_epochs = 50 # 训练轮数
batch_size = 64 # 每个批次的样本数量
lr = 0.0002 # 学习率
b1 = 0.5 # Adam优化器的第一个动量衰减参数
b2 = 0.999 # Adam 优化器的第二个动量衰减参数
n_cpu = 8 # 用于批次生成的CPU线程数
latent_dim = 100 # 潜在空间的维度
num_classes = 10 # 数据集的类别数
img_size = 32 # 每个图像的尺寸(高度和宽度相等)
channels = 1 # 图像的通道数(灰度图像通道数为1)
sample_interval = 400 # 图像采样间隔# 全局设备
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

初始化权重

def weights_init_normal(m):classname = m.__class__.__name__if classname.find('Conv')!= -1:torch.nn.init.normal_(m.weight.data, 0.0, 0.02)elif classname.find('BatchNorm') != -1:torch.nn.init.normal_(m.weight.data, 1.0, 0.02)torch.nn.init.normal_(m.bias.data, 0.0)

定义算法模型

# 生成器
class Generator(nn.Module):def __init__(self):super().__init__()# 创建一个标签嵌入层,用于将条件标签映射到潜在空间self.label_emb = nn.Embedding(num_classes, latent_dim)# 初始化图像尺寸, 用于上采样之前self.init_size = img_size //4# 第一个全连接层,将随机噪声映射到合适的维度self.l1 = nn.Sequential(nn.Linear(latent_dim, 128*self.init_size**2))# 生成器的卷积块self.conv_blocks = nn.Sequential(nn.BatchNorm2d(128),nn.Upsample(scale_factor=2),nn.Conv2d(128, 128, 3, stride=1, padding=1),nn.BatchNorm2d(128, 0.8),nn.LeakyReLU(0.2, inplace=True),nn.Upsample(scale_factor=2),nn.Conv2d(128, 64, 3, stride=1, padding=1),nn.BatchNorm2d(64, 0.8),nn.LeakyReLU(0.2, inplace=True),nn.Conv2d(64, channels, 3, stride=1, padding=1),nn.Tanh(),)def forward(self, noise):out = self.l1(noise)out = out.view(out.shape[0], 128, self.init_size, self.init_size)img = self.conv_blocks(out)return img
# 判别器,一个分类网络
class Discriminator(nn.Module):def __init__(self):super().__init__()def discriminator_block(in_filters, out_filters, bn=True):"""返回每个鉴别器块的层"""block = [nn.Conv2d(in_filters, out_filters, 3, 2, 1), nn.LeakyReLU(0.2, inplace=True), nn.Dropout2d(0.25)]if bn:block.append(nn.BatchNorm2d(out_filters, 0.8))return block# 鉴别器的卷积块self.conv_blocks = nn.Sequential(*discriminator_block(channels, 16, bn=False),*discriminator_block(16, 32),*discriminator_block(32, 64),*discriminator_block(64, 128),)# 下采样图像的高度和宽度ds_size = img_size // 2 ** 4# 输出层self.adv_layer = nn.Sequential(nn.Linear(128*ds_size**2, 1),nn.Sigmoid())self.aux_layer = nn.Sequential(nn.Linear(128*ds_size**2, num_classes + 1), nn.Softmax())def forward(self, img):out = self.conv_blocks(img)out = out.view(out.shape[0], -1)validity = self.adv_layer(out)label = self.aux_layer(out)return validity, label

模型配置

# 定义损失函数# 二元交叉熵损失,用于对抗训练
adversarial_loss = nn.BCELoss().to(device)
# 交叉熵损失,用于辅助分类
auxiliary_loss = nn.CrossEntropyLoss().to(device)# 初始化生成器和鉴别器
generator = Generator().to(device)
discriminator = Discriminator().to(device)# 初始化模型权重
generator.apply(weights_init_normal)
discriminator.apply(weights_init_normal)# 配置数据加载器
os.makedirs('data/mnist', exist_ok=True)dataloader = DataLoader(datasets.MNIST('data/mnist',train=True,download=True,transform=transforms.Compose([transforms.Resize(img_size),transforms.ToTensor(),transforms.Normalize([0.5], [0.5])]),),batch_size=batch_size,shuffle=True)# 创建优化器
optimizer_G = torch.optim.Adam(generator.parameters(), lr=lr, betas=(b1, b2))
optimizer_D = torch.optim.Adam(discriminator.parameters(), lr=lr, betas=(b1, b2))

模型训练

for epoch in range(opt._epochs):for i, (imgs, labels) in enumerate(dataloader):batch_size = imgs.shape[0]# 生成对抗训练的标签valid = Variable(FloatTensor(batch_size, 1).fill_(1.0), requires_grad=Falsefake = Variable(FloatTensor(batch_size, 2).fill_(0,0)fake_aux_gt = Variable(LongTensor(batch_size).fill_(opt.num_classes, requires_grad)

训练模型

for epoch in range(n_epochs):for i, (imgs, labels) in enumerate(dataloader):batch_size = imgs.shape[0]# 定义对抗训练的标签valid = torch.ones((batch_size, 1), requires_grad=False, device=device)fake = torch.zeros((batch_size, 1), requires_grad=False, device=device)fake_aux_gt= torch.ones((batch_size), dtype=torch.int64, device=device, requires_grad=False)*num_classes# 配置输入数据real_imgs = imgs.to(device)labels = labels.to(device)# ** 训练生成器 **optimizer_G.zero_grad()# 采样噪声和类别标签作为生成器的输入z = torch.rand([batch_size, latent_dim], device=device)# 生成一批图像gen_imgs = generator(z)# 计算生成器的损失衡量生成器欺骗鉴别器的能力validity, _ = discriminator(gen_imgs)g_loss = adversarial_loss(validity, valid)g_loss.backward()optimizer_G.step()# ** 训练鉴别器 **optimizer_D.zero_grad()# 真实图像的损失real_pred, real_aux = discriminator(real_imgs)d_real_loss = (adversarial_loss(real_pred, valid) + auxiliary_loss(real_aux, labels)) / 2# 生成图像的损失fake_pred, fake_aux = discriminator(gen_imgs.detach())d_fake_loss = (adversarial_loss(fake_pred, fake) + auxiliary_loss(fake_aux, fake_aux_gt)) /2# 总的鉴别器损失d_loss = (d_real_loss + d_fake_loss) / 2# 计算鉴别器准确率pred = np.concatenate([real_aux.data.cpu().numpy(), fake_aux.data.cpu().numpy()], axis=0)gt = np.concatenate([labels.data.cpu().numpy(), fake_aux_gt.data.cpu().numpy()], axis=0)d_acc = np.mean(np.argmax(pred, axis=1) == gt)d_loss.backward()optimizer_D.step()batches_done = epoch*len(dataloader) + iif batches_done % sample_interval == 0:save_image(gen_imgs.data[:25], 'images/%d.png' % batches_done, nrow=5, normalize=True)print('[Epoch %d/%d] [Batch %d/%d] [D loss: %f, acc: %d%%] [G loss: %f]' % (epoch, n_epochs, i, len(dataloader), d_loss.item(), 100*d_acc, g_loss.item()))

训练过程

[Epoch 0/50] [Batch 937/938] [D loss: 1.361139, acc: 50%] [G loss: 0.692496]
[Epoch 1/50] [Batch 937/938] [D loss: 1.339481, acc: 50%] [G loss: 0.743710]
[Epoch 2/50] [Batch 937/938] [D loss: 1.324114, acc: 50%] [G loss: 0.876165]
[Epoch 3/50] [Batch 937/938] [D loss: 1.163164, acc: 50%] [G loss: 1.803553]
[Epoch 4/50] [Batch 937/938] [D loss: 1.115438, acc: 50%] [G loss: 3.103937]
[Epoch 5/50] [Batch 937/938] [D loss: 1.045782, acc: 50%] [G loss: 4.149418]
[Epoch 6/50] [Batch 937/938] [D loss: 0.996207, acc: 56%] [G loss: 5.366407]
[Epoch 7/50] [Batch 937/938] [D loss: 0.944309, acc: 67%] [G loss: 5.629416]
[Epoch 8/50] [Batch 937/938] [D loss: 0.923516, acc: 71%] [G loss: 5.145832]
[Epoch 9/50] [Batch 937/938] [D loss: 0.941332, acc: 78%] [G loss: 3.469946]
[Epoch 10/50] [Batch 937/938] [D loss: 0.945412, acc: 75%] [G loss: 3.282685]
[Epoch 11/50] [Batch 937/938] [D loss: 0.862699, acc: 84%] [G loss: 3.509322]
[Epoch 12/50] [Batch 937/938] [D loss: 0.880701, acc: 87%] [G loss: 2.907838]
[Epoch 13/50] [Batch 937/938] [D loss: 0.853650, acc: 92%] [G loss: 4.008491]
[Epoch 14/50] [Batch 937/938] [D loss: 0.814380, acc: 93%] [G loss: 4.354833]
[Epoch 15/50] [Batch 937/938] [D loss: 0.907486, acc: 89%] [G loss: 4.128651]
[Epoch 16/50] [Batch 937/938] [D loss: 0.839670, acc: 87%] [G loss: 3.847980]
[Epoch 17/50] [Batch 937/938] [D loss: 1.118082, acc: 75%] [G loss: 3.573672]
[Epoch 18/50] [Batch 937/938] [D loss: 0.877845, acc: 87%] [G loss: 3.234770]
[Epoch 19/50] [Batch 937/938] [D loss: 1.176042, acc: 75%] [G loss: 4.499653]
[Epoch 20/50] [Batch 937/938] [D loss: 0.942495, acc: 84%] [G loss: 4.823555]
[Epoch 21/50] [Batch 937/938] [D loss: 0.874024, acc: 93%] [G loss: 3.880158]
[Epoch 22/50] [Batch 937/938] [D loss: 0.887224, acc: 90%] [G loss: 3.924105]
[Epoch 23/50] [Batch 937/938] [D loss: 0.876955, acc: 89%] [G loss: 4.332885]
[Epoch 24/50] [Batch 937/938] [D loss: 1.164700, acc: 79%] [G loss: 5.855463]
[Epoch 25/50] [Batch 937/938] [D loss: 0.824182, acc: 100%] [G loss: 3.745309]
[Epoch 26/50] [Batch 937/938] [D loss: 0.991236, acc: 87%] [G loss: 4.963309]
[Epoch 27/50] [Batch 937/938] [D loss: 0.906700, acc: 92%] [G loss: 5.675440]
[Epoch 28/50] [Batch 937/938] [D loss: 0.864558, acc: 93%] [G loss: 5.964598]
[Epoch 29/50] [Batch 937/938] [D loss: 0.788707, acc: 96%] [G loss: 7.074716]
[Epoch 30/50] [Batch 937/938] [D loss: 1.044333, acc: 84%] [G loss: 4.304685]
[Epoch 31/50] [Batch 937/938] [D loss: 0.797054, acc: 100%] [G loss: 5.197765]
[Epoch 32/50] [Batch 937/938] [D loss: 0.824380, acc: 100%] [G loss: 5.913801]
[Epoch 33/50] [Batch 937/938] [D loss: 0.978360, acc: 87%] [G loss: 3.314190]
[Epoch 34/50] [Batch 937/938] [D loss: 1.014248, acc: 78%] [G loss: 8.149563]
[Epoch 35/50] [Batch 937/938] [D loss: 1.352330, acc: 68%] [G loss: 8.068608]
[Epoch 36/50] [Batch 937/938] [D loss: 0.906131, acc: 89%] [G loss: 7.385222]
[Epoch 37/50] [Batch 937/938] [D loss: 0.813954, acc: 98%] [G loss: 5.816649]
[Epoch 38/50] [Batch 937/938] [D loss: 0.840815, acc: 98%] [G loss: 6.768866]
[Epoch 39/50] [Batch 937/938] [D loss: 0.864865, acc: 90%] [G loss: 2.277655]
[Epoch 40/50] [Batch 937/938] [D loss: 0.810660, acc: 93%] [G loss: 6.076533]
[Epoch 41/50] [Batch 937/938] [D loss: 1.189352, acc: 78%] [G loss: 4.746247]
[Epoch 42/50] [Batch 937/938] [D loss: 0.823831, acc: 90%] [G loss: 9.117917]
[Epoch 43/50] [Batch 937/938] [D loss: 0.975088, acc: 85%] [G loss: 2.690667]
[Epoch 44/50] [Batch 937/938] [D loss: 0.911645, acc: 89%] [G loss: 6.431296]
[Epoch 45/50] [Batch 937/938] [D loss: 1.214794, acc: 65%] [G loss: 5.860756]
[Epoch 46/50] [Batch 937/938] [D loss: 0.849733, acc: 98%] [G loss: 4.305855]
[Epoch 47/50] [Batch 937/938] [D loss: 0.910819, acc: 90%] [G loss: 6.148373]
[Epoch 48/50] [Batch 937/938] [D loss: 0.828892, acc: 96%] [G loss: 9.507065]
[Epoch 49/50] [Batch 937/938] [D loss: 1.086049, acc: 84%] [G loss: 5.026798]

模型效果

第一次输出的图像
第一次输出图像
25轮输出的图像
25轮输出的图像
最后输出的图像
最后输出的图像
通过图像可以发现,不同类型的数字间有很大区别,SGAN可以生成的很好

总结与心得体会

SGAN对于GAN的改进,更像是一个拥有着部分共同权重的一组小模型,可以让每个分类的图像生成的更加精确,避免生成的图像同时拥有着几种手势的特点,有点不伦不类。
目前的判别器中,无法对生成的图像打上准确的标签,这样应该会影响生成的精度,如果可以结合CGAN,直接让生成器也学习不同分类的特点,然后让判别器精确的区分,应该可以得到一个更精确的条件生成网络。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.rhkb.cn/news/355725.html

如若内容造成侵权/违法违规/事实不符,请联系长河编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

【机器学习】Lasso回归:稀疏建模与特征选择的艺术

🌈个人主页: 鑫宝Code 🔥热门专栏: 闲话杂谈| 炫酷HTML | JavaScript基础 ​💫个人格言: "如无必要,勿增实体" 文章目录 Lasso回归:稀疏建模与特征选择的艺术引言一、Lasso回归简介1.1 基本…

leetcode:557. 反转字符串中的单词 III(python3解法)

难度:简单 给定一个字符串 s ,你需要反转字符串中每个单词的字符顺序,同时仍保留空格和单词的初始顺序。 示例 1: 输入:s "Lets take LeetCode contest" 输出:"steL ekat edoCteeL tsetnoc…

深度学习推理显卡设置

深度学习推理显卡设置 进入NVIDIA控制面板,选择 “管理3D设置”设置 "低延时模式"为 "“超高”"设置 “电源管理模式” 为 “最高性能优先” 使用锁频来获得稳定的推理 法一:命令行操作 以管理员身份打开CMD查看GPU核心可用频率&…

泵设备的监测控制和智慧运维

泵是一种输送流体或使流体增压的机械。它通过各种工作原理(如离心、柱塞等)将机械能转换为流体的动能或压力能,从而实现液体的输送、提升、循环等操作。 泵的一些具体应用场景: 1.智能水务:在城市供水管网中&#xff…

MATLAB绘图技巧-多边形区域填充图

MATLAB绘图技巧-多边形区域填充图 以下内容来自:科学网—MATLAB绘图技巧-多边形区域填充图 - 彭真明的博文 (sciencenet.cn)START 为了突出某个区域或局部数据的特性,便于数据的可视化和解释,常需要绘制二维区域填充图。MATLAB提供了三种类型…

Linux-DNS分离解析、多域名解析和子域

一、分离解析 在下面的图片中外网用户通过DNS服务器域名解析得到外网的公有IP地址;内网用户通过DNS服务器域名解析得到内网的私有IP 地址(本次实验想要得到的结果);一个IP可以对应多个域名;一个域名可以对应多个IP。 1…

【机器学习300问】119、什么是语言模型?

语言模型(Language Models)是自然语言处理(NLP)的重要组成部分,它的目的是量化一段文本或一个序列的概率。简单讲就是你给语言模型一个句子,它给你计算出特定语言中这个句子出现的概率。这样的概率度量可以…

Mac电脑FTP客户端推荐:Transmit 5 for Mac 中文版

Transmit 5是一款专为macOS平台设计的功能强大的FTP(文件传输协议)客户端软件。Transmit 5凭借其强大的功能、直观易用的界面和高效的性能,成为需要频繁进行文件传输和管理的个人用户和专业用户的理想选择。无论是对于新手还是经验丰富的用户…

报表开发工具DevExpress Reporting v23.2 - 增强PDF导出、多平台打印等

DevExpress Reporting是.NET Framework下功能完善的报表平台,它附带了易于使用的Visual Studio报表设计器和丰富的报表控件集,包括数据透视表、图表,因此您可以构建无与伦比、信息清晰的报表。 DevExpress Reporting控件日前正式发布了v23.2…

MS17-010(Eternal blue永恒之蓝)漏洞利用+修复方法

目录 一、漏洞简介 漏洞原理 影响版本 二、漏洞复现 三、复现过程 1、扫描局域网内的C段主机(主机发现) 扫描结果: 2.使用MSF的永恒之蓝漏洞模块 3.对主机进行扫描,查看其是否有永恒之蓝漏洞 4.准备攻击 四、漏洞利用 …

Marin说PCB之如何在CST仿真软件中添加三端子的电容模型?

上期文章小编我给诸位道友们分享了Murata家的三端子电容的一些特性,这期文章接着上回把三端子电容模型如何在CST软件中搭建给大家分享一下,小编我辛辛苦苦兢兢业业的给各位帖子们免费分享我的一些设计心得,这些按照小编我华山派门派的要求都是…

人脸识别——可解释的人脸识别(XFR)人脸识别模型是根据什么来识别个人的

可解释性人脸识别(XFR)? 人脸识别有一个任务叫1:N(识别)。这个任务将一个人的照片与N张注册照片进行比较,找出相似度最高的人。 这项任务用于刑事调查和出入境点。在犯罪调查中,任务从监控摄像…

大厂晋升学习方法一:海绵学习法

早晨 30 分钟 首先,我们可以把起床的闹钟提前 30 分钟,比如原来 07:30 的闹钟可以改为 07:00。不用担心提前 30 分钟起床会影响休息质量,习惯以后,早起 30 分钟不但不会影响一天的精力,甚至可能反而让人更有精神。早起…

51单片机宏定义的例子

代码 demo.c #include "hardware.h"void delay() {volatile unsigned int n;for(n 0; n < 50000; n); }int main(void) {IO_init();while(1){PINSET(LED);delay();PINCLR(LED);delay();}return 0; }cfg.h #ifndef _CFG_H_ #define _CFG_H_// #define F_CPU …

学校教育为什么要选择SOLIDWORKS教育版?

在数字化和智能化时代&#xff0c;学校教育正面临着挑战与机遇。为了培养具备创新能力和实践技能的新时代人才&#xff0c;学校教育需要引入先进的教学工具和资源。SOLIDWORKS教育版作为一款专为教育和培训目的而设计的软件&#xff0c;以其全方面的功能、友好的用户界面、丰富…

顶顶通呼叫中心中间件-替换授权文件使授权文件生效指南

一、登录my.ddrj.com下载授权文件 登录地址&#xff1a;用户-顶顶通授权管理系统 登录之后正式授权然后点击查看把license.json下载下来&#xff0c;然后替换到fs的授权文件路径&#xff0c;默认路径是&#xff1a;/ddt/fs/conf 如果安装路径不一样就需要自己去看看授权文件存…

【网络安全的神秘世界】文件上传、JBOSS、Struct漏洞复现

&#x1f31d;博客主页&#xff1a;泥菩萨 &#x1f496;专栏&#xff1a;Linux探索之旅 | 网络安全的神秘世界 | 专接本 | 每天学会一个渗透测试工具 攻防环境搭建及漏洞原理学习 Kali安装docker 安装教程 PHP攻防环境搭建 中间件介绍 介于应用系统和系统软件之间的软件。…

Day7—zookeeper基本操作

ZooKeeper介绍 ZooKeeper&#xff08;动物园管理员&#xff09;是一个分布式的、开源的分布式应用程序的协调服务框架&#xff0c;简称zk。ZooKeeper是Apache Hadoop 项目下的一个子项目&#xff0c;是一个树形目录服务。 ZooKeeper的主要功能 配置管理 分布式锁 集群管理…

VBA学习(13):获取多层文件夹内文件名并建立超链接

代码使用了FileSystemObject对象和递归的方法实现文件夹和文件的遍历功能。分别将文件夹名称和文件名提取在表格的A/B列&#xff0c;并对文件名创建了超链接。 示例代码如下&#xff1a; Sub AutoAddLink()Dim strFldPath As StringWith Application.FileDialog(msoFileDialog…

Comparison method violates its general contract! 神奇的报错

发生情况 定位到问题代码如下&#xff08;脱敏处理过后&#xff09;&#xff0c;意思是集合排序&#xff0c;如果第一个元素大于第二个元素&#xff0c;比较结果返回1&#xff0c;否则返回-1&#xff0c;这里粗略的认为小于和等于是一样的结果 List<Integer> list Arr…