MATLAB环境下生成对抗网络系列(11种)

为了构建有效的图像深度学习模型,数据增强是一个非常行之有效的方法。图像的数据增强是一套使用有限数据来提高训练数据集质量和规模的数据空间解决方案。广义的图像数据增强算法包括:几何变换、颜色空间增强、核滤波器、混合图像、随机擦除、特征空间增强、对抗训练、生成对抗网络和风格迁移等内容。增强的数据代表一个分布覆盖性更广、可靠性更高的数据点集,使用增强数据能够有效增加训练样本的多样性,最小化训练集和验证集以及测试集之间的距离。使用数据增强后的数据集训练模型,可以达到提升模型稳定性、泛化能力的效果。

使用生成对抗网络GAN提取原数据集特征,对抗生成新的目标域图像,已成为众多学者在数据增强技术研究中的优选方法。相比于传统的图像数据增强方法,通过基于GAN的生成式建模技术进行数据增强的思路来源于博弈论中的二人零和博弈,由网络中包含的生成器和判别器利用对抗学习的方法来指导网络训练,在两个网络对抗过程中估计原始数据样本的分布并生成与之相似的新数据。

近期的研究在原始生成对抗网络框架的基础上又提出了多种不同的改进方案,通过设计不同的神经网络架构和损失函数等手段不断提升生成对抗网络的性能。生成对抗网络应用在图像数据增强任务上的思想主要是其通过生成新的训练数据来扩充模型的训练数据,通过样本空间的扩充实现图像分类等任务效果的提升。但目前基于GAN的图像数据增强技术普遍存在模型收敛不稳定、生成图像质量低等问题,如何正确引入高频信息,提升图像数据质量是破解这一系列问题的关键。

MATLAB环境配置如下:

  • MATLAB 2021b
  • Deep Learning Toolbox
  • Parallel Computing Toolbox (optional for GPU usage)

目录如下

  • Generative Adversarial Network (GAN) [paper]
  • Least Squares Generative Adversarial Network (LSGAN) [paper]
  • Deep Convolutional Generative Adversarial Network (DCGAN) [paper]
  • Conditional Generative Adversarial Network (CGAN)[paper]
  • Auxiliary Classifier Generative Adversarial Network (ACGAN) [paper]
  • InfoGAN [paper]
  • Adversarial AutoEncoder (AAE)[paper]
  • Pix2Pix[paper]
  • Wasserstein Generative Adversarial Network (WGAN) [paper]
  • Semi-Supervised Generative Adversarial Network (SGAN) [paper]
  • CycleGAN [paper]
  • DiscoGAN [paper]

部分代码如下:

首先,导入相关的mnist手写数字图

load('mnistAll.mat')

然后对训练、测试图像进行预处理

trainX = preprocess(mnist.train_images); 
trainY = mnist.train_labels;%训练标签
testX = preprocess(mnist.test_images); 
testY = mnist.test_labels;%测试标签

preprocess为归一化函数,如下

function x = preprocess(x)
x = double(x)/255;
x = (x-.5)/.5;
x = reshape(x,28*28,[]);
end

然后进行参数设置,包括潜变量空间维度,batch_size大小,学习率,最大迭代次数等等

settings.latent_dim = 10;
settings.batch_size = 32; settings.image_size = [28,28,1]; 
settings.lrD = 0.0002; settings.lrG = 0.0002; settings.beta1 = 0.5;
settings.beta2 = 0.999; settings.maxepochs = 50;

下面进行编码器初始化,代码还是很容易看懂的

paramsEn.FCW1 = dlarray(initializeGaussian([512,...prod(settings.image_size)],.02));
paramsEn.FCb1 = dlarray(zeros(512,1,'single'));
paramsEn.FCW2 = dlarray(initializeGaussian([512,512]));
paramsEn.FCb2 = dlarray(zeros(512,1,'single'));
paramsEn.FCW3 = dlarray(initializeGaussian([2*settings.latent_dim,512]));
paramsEn.FCb3 = dlarray(zeros(2*settings.latent_dim,1,'single'));

解码器初始化

paramsDe.FCW1 = dlarray(initializeGaussian([512,settings.latent_dim],.02));
paramsDe.FCb1 = dlarray(zeros(512,1,'single'));
paramsDe.FCW2 = dlarray(initializeGaussian([512,512]));
paramsDe.FCb2 = dlarray(zeros(512,1,'single'));
paramsDe.FCW3 = dlarray(initializeGaussian([prod(settings.image_size),512]));
paramsDe.FCb3 = dlarray(zeros(prod(settings.image_size),1,'single'));

判别器初始化

paramsDis.FCW1 = dlarray(initializeGaussian([512,settings.latent_dim],.02));
paramsDis.FCb1 = dlarray(zeros(512,1,'single'));
paramsDis.FCW2 = dlarray(initializeGaussian([256,512]));
paramsDis.FCb2 = dlarray(zeros(256,1,'single'));
paramsDis.FCW3 = dlarray(initializeGaussian([1,256]));
paramsDis.FCb3 = dlarray(zeros(1,1,'single'));%平均梯度和平均梯度平方数组
avgG.Dis = []; avgGS.Dis = []; avgG.En = []; avgGS.En = [];
avgG.De = []; avgGS.De = [];

开始训练

dlx = gpdl(trainX(:,1),'CB');
dly = Encoder(dlx,paramsEn);
numIterations = floor(size(trainX,2)/settings.batch_size);
out = false; epoch = 0; global_iter = 0;
while ~outtic; shuffleid = randperm(size(trainX,2));trainXshuffle = trainX(:,shuffleid);fprintf('Epoch %d\n',epoch) for i=1:numIterationsglobal_iter = global_iter+1;idx = (i-1)*settings.batch_size+1:i*settings.batch_size;XBatch=gpdl(single(trainXshuffle(:,idx)),'CB');[GradEn,GradDe,GradDis] = ...dlfeval(@modelGradients,XBatch,...paramsEn,paramsDe,paramsDis,settings);% 更新判别器网络参数[paramsDis,avgG.Dis,avgGS.Dis] = ...adamupdate(paramsDis, GradDis, ...avgG.Dis, avgGS.Dis, global_iter, ...settings.lrD, settings.beta1, settings.beta2);% 更新编码器网络参数[paramsEn,avgG.En,avgGS.En] = ...adamupdate(paramsEn, GradEn, ...avgG.En, avgGS.En, global_iter, ...settings.lrG, settings.beta1, settings.beta2);% 更新解码器网络参数[paramsDe,avgG.De,avgGS.De] = ...adamupdate(paramsDe, GradDe, ...avgG.De, avgGS.De, global_iter, ...settings.lrG, settings.beta1, settings.beta2);if i==1 || rem(i,20)==0progressplot(paramsDe,settings);if i==1 h = gcf;% 捕获图像frame = getframe(h); im = frame2im(frame); [imind,cm] = rgb2ind(im,256); % 写入 GIF 文件if epoch == 0imwrite(imind,cm,'AAEmnist.gif','gif', 'Loopcount',inf); else imwrite(imind,cm,'AAEmnist.gif','gif','WriteMode','append'); end endendendelapsedTime = toc;disp("Epoch "+epoch+". Time taken for epoch = "+elapsedTime + "s")epoch = epoch+1;if epoch == settings.maxepochsout = true;end    
end

下面是完整的辅助函数

模型的梯度计算函数

function [GradEn,GradDe,GradDis]=modelGradients(x,paramsEn,paramsDe,paramsDis,settings)
dly = Encoder(x,paramsEn);
latent_fake = dly(1:settings.latent_dim,:)+...dly(settings.latent_dim+1:2*settings.latent_dim)*...randn(settings.latent_dim,settings.batch_size);
latent_real = gpdl(randn(settings.latent_dim,settings.batch_size),'CB');%训练判别器
d_output_fake = Discriminator(latent_fake,paramsDis);
d_output_real = Discriminator(latent_real,paramsDis);
d_loss = -.5*mean(log(d_output_real+eps)+log(1-d_output_fake+eps));%训练编码器和解码器
x_ = Decoder(latent_fake,paramsDe);
g_loss = .999*mean(mean(.5*(x_-x).^2,1))-.001*mean(log(d_output_fake+eps));%对于每个网络,计算关于损失函数的梯度
[GradEn,GradDe] = dlgradient(g_loss,paramsEn,paramsDe,'RetainData',true);
GradDis = dlgradient(d_loss,paramsDis);
end

提取数据函数

function x = gatext(x)
x = gather(extractdata(x));
end

GPU深度学习数组wrapper函数

function dlx = gpdl(x,labels)
dlx = gpuArray(dlarray(x,labels));
end

权重初始化函数

function parameter = initializeGaussian(parameterSize,sigma)
if nargin < 2sigma = 0.05;
end
parameter = randn(parameterSize, 'single') .* sigma;
end

dropout函数

function dly = dropout(dlx,p)
if nargin < 2p = .3;
end
[n,d] = rat(p);
mask = randi([1,d],size(dlx));
mask(mask<=n)=0;
mask(mask>n)=1;
dly = dlx.*mask;
end

编码器函数

function dly = Encoder(dlx,params)
dly = fullyconnect(dlx,params.FCW1,params.FCb1);
dly = leakyrelu(dly,.2);
dly = fullyconnect(dly,params.FCW2,params.FCb2);
dly = leakyrelu(dly,.2);
dly = fullyconnect(dly,params.FCW3,params.FCb3);
dly = leakyrelu(dly,.2);
end

解码器函数

function dly = Decoder(dlx,params)
dly = fullyconnect(dlx,params.FCW1,params.FCb1);
dly = leakyrelu(dly,.2);
dly = fullyconnect(dly,params.FCW2,params.FCb2);
dly = leakyrelu(dly,.2);
dly = fullyconnect(dly,params.FCW3,params.FCb3);
dly = leakyrelu(dly,.2);
dly = tanh(dly);
end

判别器函数

function dly = Discriminator(dlx,params)
dly = fullyconnect(dlx,params.FCW1,params.FCb1);
dly = leakyrelu(dly,.2);
dly = fullyconnect(dly,params.FCW2,params.FCb2);
dly = leakyrelu(dly,.2);
dly = fullyconnect(dly,params.FCW3,params.FCb3);
dly = sigmoid(dly);
end

工学博士,担任《Mechanical System and Signal Processing》审稿专家,担任
《中国电机工程学报》优秀审稿专家,《控制与决策》,《系统工程与电子技术》,《电力系统保护与控制》,《宇航学报》等EI期刊审稿专家。

擅长领域:现代信号处理,机器学习,深度学习,数字孪生,时间序列分析,设备缺陷检测、设备异常检测、设备智能故障诊断与健康管理PHM等。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.rhkb.cn/news/253039.html

如若内容造成侵权/违法违规/事实不符,请联系长河编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

告别mPDF迎来TCPDF和中文打印遇到的问题

mPDF是一个用PHP编写的开源PDF生成库。它最初由Claus Holler创建&#xff0c;于2004年发布。原来用开源软件打印中文没有问题&#xff0c;最近发现新的软件包中mPDF被TCPDF代替了&#xff0c;当然如果只用西文的PDF是没有发现问题&#xff0c;但要打印中文就有点抓瞎了如图1&am…

选择大语言模型:2024 年开源 LLM 入门指南

作者&#xff1a;来自 Elastic Aditya Tripathi 如果说人工智能在 2023 年起飞&#xff0c;这绝对是轻描淡写的说法。数千种新的人工智能工具被推出&#xff0c;人工智能功能被添加到现有的应用程序中&#xff0c;好莱坞因对这项技术的担忧而戛然而止。 甚至还有一个人工智能工…

【计算机网络】Socket的SO_TIMEOUT与连接超时时间

SO_TIMEOUT选项是Socket的一个选项&#xff0c;用于设置读取数据的超时时间。它指定了在读取数据时等待的最长时间&#xff0c;如果在指定的时间内没有数据可读取&#xff0c;将抛出SocketTimeoutException异常。 SO_TIMEOUT的设置 默认情况下&#xff0c;SO_TIMEOUT选项的值…

vue3项目中使用mapv

vue3项目中使用mapv mapv是百度地图官方提供的地图数据可视化开源项目&#xff0c;提供了很多效果酷炫的绘图api mapv地址在这里&#xff0c;示例图在这里 先解释为什么要用mapv echarts画的地图&#xff0c;都是行政区划&#xff0c;就算是geo地图&#xff0c;也只能在行政…

神经网络 | 常见的激活函数

Hi&#xff0c;大家好&#xff0c;我是半亩花海。本文主要介绍神经网络中必要的激活函数的定义、分类、作用以及常见的激活函数的功能。 目录 一、激活函数定义 二、激活函数分类 三、常见的几种激活函数 1. Sigmoid 函数 &#xff08;1&#xff09;公式 &#xff08;2&a…

【DevOps】产品需求文档(PRD)与常见原型软件

文章目录 1、PRD介绍1.1、概述1.2、前提条件1.3、主要目的1.4、关键内容1.5、表述方式1.6、需求评审人员1.7、一般内容结构 2、需求流程3、常见原型软件3.1、Word3.2、Axure3.2.1、详细介绍3.2.2、应用分类3.2.3、优缺点 3.3、摹客RP3.4、蓝湖3.5、GUI Design Studio 1、PRD介绍…

[VulnHub靶机渗透] dpwwn: 1

&#x1f36c; 博主介绍&#x1f468;‍&#x1f393; 博主介绍&#xff1a;大家好&#xff0c;我是 hacker-routing &#xff0c;很高兴认识大家~ ✨主攻领域&#xff1a;【渗透领域】【应急响应】 【python】 【VulnHub靶场复现】【面试分析】 &#x1f389;点赞➕评论➕收藏…

HTML -- 常用标签

目录 HTML 标签 单标签 双标签 常见标签的使用 标题和段落 换行、分隔、超链接 列表标签 表单标签 属性 属性的使用 HTML HTML&#xff08;Hyper Text Markup Language&#xff09;&#xff0c;超文本标记语言&#xff0c;是一门标记语言&#xff0c;不是编程语言&am…

微信小程序(三十七)选项点击高亮效果

注释很详细&#xff0c;直接上代码 上一篇 新增内容&#xff1a; 1.选择性渲染类 2.以数字为需渲染内容&#xff08;数量&#xff09; 源码&#xff1a; index.wxml <view class"Area"><!-- {{activeNumindex?Active:}}是选择性添加类名进行渲染 -->&l…

VC++添加菜单学习

新建一个单文档工程&#xff1b; 完成以后看一下有没有出现如下图的 资源视图 的tab&#xff1b;如果没有&#xff0c;在文件列表中找到xxx.rc2文件&#xff1b; 点击 资源视图 的tab&#xff0c;或者双击 .rc2 文件名&#xff0c;就会转到如下图的资源视图&#xff1b;然后展…

Redis(十三)缓存双写一致性策略

文章目录 概述示例 缓存双写一致性缓存按照操作来分&#xff0c;细分2种读写缓存&#xff1a;同步直写策略读写缓存&#xff1a;异步缓写策略双检加锁策略 数据库和缓存一致性更新策略先更新数据库&#xff0c;再更新缓存先更新缓存&#xff0c;再更新数据库先删除缓存&#xf…

大模型工作方法论

这是去年探索大模型留下的一些有效工作方法论&#xff0c;给大家分享出来。看懂着&#xff0c;一点就通&#xff1b;看不懂着&#xff0c;会老追问这到底是什么呀。 &#xff08;1&#xff09; 1、成功&#xff1a;成功才是成功之母&#xff0c;失败不是成功之母。老研究失败没…

网络选择流程分析(首选网络类型切换流程)

首先是界面,我在此平台的界面如下: 对应的入口源码位置在Settings的UniEnabledNetworkModePreferenceController中,当然其他平台可能在PreferredNetworkModePreferenceController中,流程上都是大同小异 然后点击切换按钮会调用到UniEnabledNetworkModePreferenceControlle…

Fink CDC数据同步(三)Flink集成Hive

1 目的 持久化元数据 Flink利用Hive的MetaStore作为持久化的Catalog&#xff0c;我们可通过HiveCatalog将不同会话中的 Flink元数据存储到Hive Metastore 中。 利用 Flink 来读写 Hive 的表 Flink打通了与Hive的集成&#xff0c;如同使用SparkSQL或者Impala操作Hive中的数据…

4、ChatGPT 无法完成的 5 项编码任务

ChatGPT 无法完成的 5 项编码任务 这是 ChatGPT 不能做的事情的一个清单,但这并非详尽无遗。ChatGPT 可以从头开始生成相当不错的代码,但是它不能取代你的工作。 我喜欢将 ChatGPT 视为 StackOverflow 的更智能版本。非常有帮助,但不会很快取代专业人士。当 ChatGPT 问世时…

远程主机可能不符合 glibc 和 libstdc++ Vs Code 服务器的先决条件

vscode连接远程主机报错&#xff0c;原因官方已经公布过了&#xff0c;需要远程主机 glibc>2.28&#xff0c;所以Ubuntu18及以下版本没法再远程连接了&#xff0c;其他Linux系统执行ldd --version查看glibc版本自行判断。 解决方案建议&#xff1a; 不要再想升级glibc了 问题…

少儿编程考级:智慧启迪还是智商税?

在当前科技日新月异的时代背景下&#xff0c;少儿编程教育日益受到家长和社会的广泛关注。与此同时&#xff0c;各类少儿编程考级应运而生&#xff0c;引发了公众对于其价值和意义的深度探讨。一部分人认为这是对孩子逻辑思维与创新能力的有效锻炼&#xff0c;是智慧启迪的重要…

WebGL+Three.js入门与实战——绘制水平移动的点、通过鼠标控制绘制(点击绘制、移动绘制、模拟画笔)

个人简介 &#x1f440;个人主页&#xff1a; 前端杂货铺 &#x1f64b;‍♂️学习方向&#xff1a; 主攻前端方向&#xff0c;正逐渐往全干发展 &#x1f4c3;个人状态&#xff1a; 研发工程师&#xff0c;现效力于中国工业软件事业 &#x1f680;人生格言&#xff1a; 积跬步…

C语言之找单身狗

个人主页&#xff08;找往期文章包括但不限于本期文章中不懂的知识点&#xff09;&#xff1a; 我要学编程(ಥ_ಥ)-CSDN博客 题目&#xff1a; 在一个整型数组中&#xff0c;只有一个数字出现一次&#xff0c;其他数组都是成对出现的&#xff0c;请找出那个只出现一次的数字。…

Python HTTP隧道在远程通信中的应用:穿越网络的“魔法门”

在这个数字化时代&#xff0c;远程通信就像是我们日常生活中的“魔法门”&#xff0c;让我们可以随时随地与远方的朋友、同事或服务器进行交流。而在这扇“魔法门”的背后&#xff0c;Python HTTP隧道技术发挥着举足轻重的作用。 想象一下&#xff0c;你坐在家里的沙发上&…