解析生成对抗网络(GAN):原理与应用

目录

一、引言

二、生成对抗网络原理

(一)基本架构

(二)训练过程

三、生成对抗网络的应用

(一)图像生成

无条件图像生成:

(二)数据增强

(三)风格迁移

四、生成对抗网络训练中的挑战与解决策略

(一)模式崩溃

(二)梯度消失


一、引言

生成对抗网络(GAN)自 2014 年被 Goodfellow 等人提出以来,在深度学习领域引起了广泛的关注和研究热潮。它创新性地引入了一种对抗训练的思想,通过生成器和判别器的相互博弈,使得生成器能够学习到数据的潜在分布,从而生成逼真的样本数据。这种独特的机制使得 GAN 在图像生成、文本生成、音频生成等多个领域展现出了巨大的潜力,为人工智能技术的发展带来了新的突破和方向。

二、生成对抗网络原理

(一)基本架构

GAN 主要由两个核心组件构成:生成器(Generator)和判别器(Discriminator)。

  1. 生成器
    • 生成器的任务是接收一个随机噪声向量 (通常从一个简单的分布,如标准正态分布 N(0,1)采样得到),并通过一系列的神经网络层将其映射为与真实数据相似的生成数据G(z)
    • 例如,在图像生成任务中,生成器的输出将是一张与训练数据集中图像具有相似特征的合成图像。
    • 生成器通常采用多层的反卷积神经网络(Deconvolutional Neural Network)或转置卷积神经网络(Transposed Convolutional Neural Network)结构。以生成64*64其网络结构如下:
      import torch
      import torch.nn as nnclass Generator(nn.Module):def __init__(self):super(Generator, self).__init__()# 输入为 100 维的噪声向量self.fc = nn.Linear(100, 4 * 4 * 1024)self.deconv1 = nn.ConvTranspose2d(1024, 512, kernel_size=4, stride=2, padding=1)self.bn1 = nn.BatchNorm2d(512)self.deconv2 = nn.ConvTranspose2d(512, 256, kernel_size=4, stride=2, padding=1)self.bn2 = nn.BatchNorm2d(256)self.deconv3 = nn.ConvTranspose2d(256, 128, kernel_size=4, stride=2, padding=1)self.bn3 = nn.BatchNorm2d(128)self.deconv4 = nn.ConvTranspose2d(128, 3, kernel_size=4, stride=2, padding=1)def forward(self, x):x = self.fc(x)x = x.view(-1, 1024, 4, 4)x = torch.relu(self.bn1(self.deconv1(x)))x = torch.relu(self.bn2(self.deconv2(x)))x = torch.relu(self.bn3(self.deconv3(x)))x = torch.tanh(self.deconv4(x))return x

  2. 判别器
  • 判别器的作用是区分输入的数据是来自真实数据分布还是由生成器生成的数据。它接收真实数据 x 或生成数据 G(z),并输出一个表示数据真实性的概率值  D(x)或D(G(z)) ,取值范围在 0 到  1之间,接近  表示数据更可能是真实的,接近  表示数据更可能是生成的。

判别器通常采用卷积神经网络(Convolutional Neural Network)结构。例如,对于判断  彩色图像的判别器网络结构如下:

class Discriminator(nn.Module):def __init__(self):super(Discriminator, self).__init__()self.conv1 = nn.Conv2d(3, 128, kernel_size=4, stride=2, padding=1)self.bn1 = nn.BatchNorm2d(128)self.conv2 = nn.Conv2d(128, 256, kernel_size=4, stride=2, padding=1)self.bn2 = nn.BatchNorm2d(256)self.conv3 = nn.Conv2d(256, 512, kernel_size=4, stride=2, padding=1)self.bn3 = nn.BatchNorm2d(512)self.conv4 = nn.Conv2d(512, 1, kernel_size=4, stride=2, padding=0)def forward(self, x):x = torch.relu(self.bn1(self.conv1(x)))x = torch.relu(self.bn2(self.conv2(x)))x = torch.relu(self.bn3(self.conv3(x)))x = torch.sigmoid(self.conv4(x))return x.view(-1)

(二)训练过程

GAN 的训练过程是一个对抗性的迭代过程。

三、生成对抗网络的应用

(一)图像生成

1.无条件图像生成

GAN 可以用于生成各种类型的图像,如人脸图像、风景图像等。例如,在人脸图像生成任务中,通过在大规模人脸数据集上训练 GAN,生成器能够学习到人脸的各种特征,如五官的形状、肤色、表情等,从而生成全新的、逼真的人脸图像。

代码示例:

# 假设已经定义好生成器 G 和判别器 D,以及相关的优化器和损失函数
# 训练循环
num_epochs = 100
for epoch in range(num_epochs):for i, (real_images, _) in enumerate(dataloader):# 训练判别器# 采样噪声z = torch.randn(real_images.shape[0], 100).to(device)# 生成假图像fake_images = G(z)# 计算判别器损失real_loss = criterion(D(real_images), torch.ones(real_images.shape[0]).to(device))fake_loss = criterion(D(fake_images.detach()), torch.zeros(fake_images.shape[0]).to(device))d_loss = (real_loss + fake_loss) / 2# 更新判别器参数d_optimizer.zero_grad()d_loss.backward()d_optimizer.step()# 训练生成器# 再次采样噪声z = torch.randn(real_images.shape[0], 100).to(device)# 生成假图像fake_images = G(z)# 计算生成器损失g_loss = criterion(D(fake_images), torch.ones(fake_images.shape[0]).to(device))# 更新生成器参数g_optimizer.zero_grad()g_loss.backward()g_optimizer.step()

2.条件图像生成

可以通过在生成器和判别器的输入中加入条件信息,实现条件图像生成。例如,根据给定的文本描述生成相应的图像,或者根据特定的类别标签生成属于该类别的图像。

以根据类别标签生成图像为例,在生成器的输入中除了噪声向量 ,还加入类别标签的编码向量 ,生成器的网络结构需要进行相应修改,如:

class ConditionalGenerator(nn.Module):def __init__(self, num_classes):super(ConditionalGenerator, self).__init__()# 输入为 100 维噪声向量和类别编码向量self.fc = nn.Linear(100 + num_classes, 4 * 4 * 1024)# 后续的反卷积层与无条件生成器类似self.deconv1 = nn.ConvTranspose2d(1024, 512, kernel_size=4, stride=2, padding=1)self.bn1 = nn.BatchNorm2d(512)#...def forward(self, x, y):# 拼接噪声向量和类别编码向量x = torch.cat([x, y], dim=1)x = self.fc(x)x = x.view(-1, 1024, 4, 4)x = torch.relu(self.bn1(self.deconv1(x)))#...return x

(二)数据增强

  • 图像数据增强
    • 在图像分类、目标检测等任务中,数据量不足可能导致模型过拟合。GAN 可以用于生成额外的图像数据来扩充数据集。通过在原始图像数据集上训练 GAN,生成与原始图像相似但又有一定变化的图像,如不同角度、光照条件下的图像,从而增加数据的多样性,提高模型的泛化能力。
  • 其他数据类型的数据增强
    • 除了图像数据,GAN 也可以应用于其他数据类型的数据增强,如文本数据。例如,通过生成与原始文本相似的新文本,扩充文本数据集,有助于训练更强大的文本处理模型,如文本分类、机器翻译等模型。

(三)风格迁移

  • 图像风格迁移原理
    • GAN 可以实现图像风格迁移,即将一幅图像的内容与另一幅图像的风格进行融合。其原理是通过定义内容损失和风格损失,利用生成器生成具有目标风格的图像,同时判别器用于判断生成图像的质量和风格一致性。
    • 例如,使用预训练的 VGG 网络来计算内容损失和风格损失。内容损失衡量生成图像与原始内容图像在特征表示上的差异,风格损失衡量生成图像与目标风格图像在风格特征(如纹理、颜色分布等)上的差异。

代码示例实现风格迁移

import torchvision.models as models
import torch.nn.functional as F# 加载预训练的 VGG 网络
vgg = models.vgg19(pretrained=True).features.eval().to(device)# 定义内容损失函数
def content_loss(content_features, generated_features):return F.mse_loss(content_features, generated_features)# 定义风格损失函数
def style_loss(style_features, generated_features):style_loss = 0for s_feat, g_feat in zip(style_features, generated_features):# 计算 Gram 矩阵s_gram = gram_matrix(s_feat)g_gram = gram_matrix(g_feat)style_loss += F.mse_loss(s_gram, g_gram)return style_loss# Gram 矩阵计算函数
def gram_matrix(x):b, c, h, w = x.size()features = x.view(b * c, h * w)gram = torch.mm(features, features.t())return gram.div(b * c * h * w)

四、生成对抗网络训练中的挑战与解决策略

(一)模式崩溃

问题描述

模式崩溃是 GAN 训练中常见的问题之一,表现为生成器生成的样本多样性不足,往往集中在少数几种模式上。例如,在生成人脸图像时,可能生成的人脸都具有相似的特征,而不能涵盖人脸的多种可能形态。

解决策略

Wasserstein GAN(WGAN):WGAN 对 GAN 的损失函数进行了改进,采用 Wasserstein 距离来衡量真实数据分布和生成数据分布之间的差异,而不是传统的 JS 散度。这使得训练过程更加稳定,减少了模式崩溃的发生。其关键代码修改如下:

# 判别器的最后一层不再使用 Sigmoid 激活函数
self.conv4 = nn.Conv2d(512, 1, kernel_size=4, stride=2, padding=0)
# 定义 WGAN 的损失函数
def wgan_loss(real_pred, fake_pred):return -torch.mean(real_pred) + torch.mean(fake_pred)

模式正则化:通过在生成器的损失函数中加入正则化项,鼓励生成器生成更多样化的样本。例如,在生成器的损失函数中加入对生成样本的熵约束,使得生成样本的分布更加均匀。

(二)梯度消失

  • 问题描述
    • 在 GAN 训练初期,当判别器的性能非常好时,生成器的梯度可能会变得非常小,导致生成器难以更新参数,无法有效地学习到数据的分布。这是因为判别器能够很容易地区分真实数据和生成数据,使得生成器的损失函数接近饱和,梯度趋近于 。
  • 解决策略
    • 梯度惩罚(Gradient Penalty):在判别器的损失函数中加入梯度惩罚项,限制判别器的梯度大小,使得判别器不会过于强大,从而缓解生成器的梯度消失问题。例如,在 WGAN-GP(Wasserstein GAN with Gradient Penalty)中,梯度惩罚项的计算如下:
      def gradient_penalty(critic, real, fake, device):BATCH_SIZE, C, H, W = real.shape# 随机采样插值系数alpha = torch.rand((BATCH_SIZE, 1, 1, 1)).repeat(1, C, H, W).to(device)# 计算插值数据interpolated_images = real * alpha + fake * (1 - alpha)# 计算判别器对插值数据的输出mixed_scores = critic(interpolated_images)# 计算梯度gradient = torch.autograd.grad(inputs=interpolated_images,outputs=mixed_scores,grad_outputs=torch.ones_like(mixed_scores),create_graph=True,retain_graph=True,)[0]# 计算梯度惩罚项gradient = gradient.view(gradient.shape[0], -1)gradient_norm = gradient.norm(2, dim=1)gradient_penalty = torch.mean((gradient_norm - 1) ** 2)return gradient_penalty

    • 使用 Leaky ReLU 激活函数:在判别器和生成器的网络中使用 Leaky ReLU 激活函数替代传统的 ReLU 激活函数。Leaky ReLU 允许在负半轴有一个较小的斜率,从而避免了在某些情况下神经元完全不激活导致的梯度消失问题。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.rhkb.cn/news/481647.html

如若内容造成侵权/违法违规/事实不符,请联系长河编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

学习视频超分辨率扩散模型中的空间适应和时间相干性(原文翻译)

文章目录 摘要1. Introduction2. Related Work3. Our Approach3.1. Video Upscaler3.2. Spatial Feature Adaptation Module3.3. Temporal Feature Alignment Module3.4. Video Refiner3.5. Training Strategy 4. Experiments4.1. Experimental Settings4.2. Comparisons with …

全面解读权限控制与RBAC模型在若依中的实现

目录 前言1 权限控制基础概念1.1 权限控制的核心要素1.2 常见权限控制模型 2 RBAC模型详解2.1 RBAC的基本原理2.2 RBAC的优点2.3 RBAC的扩展模型 3 若依框架中的权限管理3.1 菜单管理3.2 角色管理3.3 用户管理 4 若依权限管理的实现流程4.1 创建菜单4.2 创建角色并分配权限4.3 …

Mybatis:CRUD数据操作之单个条件(动态SQL)

Mybatis基础环境准备请看:Mybatis基础环境准备 本篇讲解Mybati数据CRUD数据操作之单个条件(动态SQL) 如上图所示,用户在查询时只能选择 品牌名称、当前状态、企业名称 这三个条件中的一个,但是用户到底选择哪儿一个&am…

2023信息安全管理与评估-linux应急响应-1

靶机的环境:Linux webserver 5.4.0-109-generic 1.提交攻击者的 ip地址 linux应急响应需要知道黑客想要进入服务器或者内网,一定从web入手 所以应急响应的i第一步应该就是看一下web日志,看进行了神码操作 apache的网站日志是/var/log/apac…

【Springboot】@Autowired和@Resource的区别

【Springboot】Autowired和Resource的区别 【一】定义【1】Autowired【2】Resource 【二】区别【1】包含的属性不同【2】Autowired默认按byType自动装配,而Resource默认byName自动装配【3】注解应用的地方不同【4】出处不同【5】装配顺序不用(1&#xff…

服务器遭受DDoS攻击后如何恢复运行?

当服务器遭受 DDoS(分布式拒绝服务)攻击 后,恢复运行需要快速采取应急措施来缓解攻击影响,并在恢复后加强防护以减少未来攻击的风险。以下是详细的分步指南: 一、应急处理步骤 1. 确认服务器是否正在遭受 DDoS 攻击 …

Linux命令系列-常见查看系统资源命令

Linux命令系列-常见查看命令 进程管理内存管理磁盘空间管理网络管理主机系统 摘要:本文将对linux系统上常见的查看系统各种资源的命令进行介绍,包括du,df,netstat等命令。所有这些命令都有相关实验截图,实验平台为ubun…

1-1 Gerrit实用指南

注:学习gerrit需要拥有git相关知识,如果没有学习过git请先回顾git相关知识点 黑马程序员git教程 一小时学会git git参考博客 git 实操博客 1.0 定义 Gerrit 是一个基于 Web 的代码审查系统,它使用 Git 作为底层版本控制系统。Gerrit 的主要功…

Node.js:开发和生产之间的区别

Node.js 中的开发和生产没有区别,即,你无需应用任何特定设置即可使 Node.js 在生产配置中工作。但是,npm 注册表中的一些库会识别使用 NODE_ENV 变量并将其默认为 development 设置。始终在设置了 NODE_ENVproduction 的情况下运行 Node.js。…

【Linux】【字符设备驱动】深入解析

Linux字符设备驱动程序用于控制不支持随机访问的硬件设备,如串行端口、打印机、调制解调器等。这类设备通常以字符流的形式与用户空间程序进行交互。本节将深入探讨字符设备驱动的设计原理、实现细节及其与内核其他组件的交互。 1. 引言 字符设备驱动程序是Linux内…

计算机毕业设计Python异常流量检测 流量分类 流量分析 网络流量分析与可视化系统 网络安全 信息安全 机器学习 深度学习

温馨提示:文末有 CSDN 平台官方提供的学长联系方式的名片! 温馨提示:文末有 CSDN 平台官方提供的学长联系方式的名片! 温馨提示:文末有 CSDN 平台官方提供的学长联系方式的名片! 作者简介:Java领…

排序算法之选择排序堆排序

算法时间复杂度辅助空间复杂度稳定性选择排序O(N^2)O(1)不稳定堆排序O(NlogN)O(1)不稳定 1.选择排序 这应该算是最简单的排序算法了,每次在右边无序区里选最小值,没有无序区时,就宣告排序完毕 比如有一个数组:[2,3,2,6,5,1,4]排…

搜索二维矩阵 II(java)

题目描述 编写一个高效的算法来搜索 m x n 矩阵 matrix 中的一个目标值 target 。该矩阵具有以下特性: 每行的元素从左到右升序排列。每列的元素从上到下升序排列。 代码思路: 用暴力算法: class Solution {public boolean searchMatrix(…

week 9 - Entity-Relationship Modelling

一、数据库设计的重要性 • 设计数据库可使查询更高效、简洁。 • 减少数据冗余(data redundancy),提升表的整洁性。 二、Key Components of ER Modelling 实体-关系建模的基本构成 1. 实体(Entity):表…

玻璃效果和窗户室内效果模拟

一、玻璃效果 首先来讲如何模拟玻璃效果。玻璃的渲染包括三部分,普通场景物体的渲染、反射和折射模拟、毛玻璃模拟。作为场景物体,那么类似其它场景物体Shader一样,可以使用PBR、BlingPhong或者Matcap,甚至三阶色卡通渲染都可以。…

STL算法之set相关算法

STL一共提供了四种与set(集合)相关的算法,分别是并集(union)、交集(intersection)、差集(difference)、对称差集(symmetric difference)。 目录 set_union set_itersection set_difference set_symmetric_difference 所谓set,可细分为数学上定义的和…

房屋结构安全监测系统守护房屋安全卫士

一、系统背景 随着时间的流逝,建筑物的主体结构、设备设施等会因为自然老化、材料疲劳、使用环境的变化以及维护不当等各种因素的影响,逐渐出现性能下降甚至安全隐患。因此,进行房屋安全监测显得尤为重要。房屋结构安全是指建筑物的结构体系在…

uniapp实现组件竖版菜单

社区图片页面 scroll-view scroll-view | uni-app官网 (dcloud.net.cn) 可滚动视图区域。用于区域滚动。 需注意在webview渲染的页面中&#xff0c;区域滚动的性能不及页面滚动。 <template><view class"pics"><scroll-view class"left"…

241127学习日志——[CSDIY] [InternStudio] 大模型训练营 [20]

CSDIY&#xff1a;这是一个非科班学生的努力之路&#xff0c;从今天开始这个系列会长期更新&#xff0c;&#xff08;最好做到日更&#xff09;&#xff0c;我会慢慢把自己目前对CS的努力逐一上传&#xff0c;帮助那些和我一样有着梦想的玩家取得胜利&#xff01;&#xff01;&…

多模态图像生成模型Qwen2vl-Flux,利用Qwen2VL的视觉语言理解能力增强FLUX,可集成ControlNet

Qwen2vl-Flux 是一种先进的多模态图像生成模型&#xff0c;它利用 Qwen2VL 的视觉语言理解能力增强了 FLUX。该模型擅长根据文本提示和视觉参考生成高质量图像&#xff0c;提供卓越的多模态理解和控制。让 FLUX 的多模态图像理解和提示词理解变得很强。 Qwen2vl-Flux有以下特点…