使用matlab的生成对抗网络(Generative Adversarial Network,GAN)以及条件CGAN时,案例中 的生成器的输入为图像,改为.mat格式输入遇到的问题。解决方法
官方资源
训练条件生成对抗网络 (CGAN)- MATLAB & Simulink- MathWorks 中国此示例说明如何训练条件生成对抗网络来生成图像。https://ww2.mathworks.cn/help/deeplearning/ug/train-conditional-generative-adversarial-network.html
训练生成对抗网络 (GAN)- MATLAB & Simulink- MathWorks 中国此示例说明如何训练生成对抗网络来生成图像。https://ww2.mathworks.cn/help/deeplearning/ug/train-generative-adversarial-network.html
原代码段
imds = imageDatastore(datasetFolder,IncludeSubfolders=true,LabelSource="foldernames");augmenter = imageDataAugmenter(RandXReflection=true);
augimds = augmentedImageDatastore([64 64],imds,DataAugmentation=augmenter);augimds.MiniBatchSize = miniBatchSize;mbq = minibatchqueue(augimds, ...MiniBatchSize=miniBatchSize, ...PartialMiniBatch="discard", ...MiniBatchFcn=@preprocessData, ...MiniBatchFormat=["SSCB" "BC"], ...OutputEnvironment="auto");
为了支持批量.mat文件以Datastore方式输入到神经网络,修改如下
imds = imageDatastore(imageFolder,'FileExtensions', {'.mat'}, 'ReadFcn', @loadMAT,IncludeSubfolders=true,LabelSource="foldernames");......mbq = minibatchqueue(augimds, ...MiniBatchSize=miniBatchSize, ...PartialMiniBatch="discard", ...MiniBatchFcn=@preprocessData, ... MiniBatchFormat=["SSCB" "BC"], ...OutputEnvironment="auto");
其中“ReadFcn”属性值为句柄“@loadMAT”,定义如下
function data = loadMAT(filename)
data = load(filename);
end
运行报错提示:
错误使用 imageDataAugmenter/augment
输入图像必须为数值,并且通道数少于 4。出错 augmentedImageDatastore/augmentData (第 411 行)
miniBatchData = self.ImageAugmenter.augment(miniBatchData);出错 augmentedImageDatastore/applyAugmentationPipeline (第 405 行)
temp = self.augmentData(temp);出错 augmentedImageDatastore>@(c)self.applyAugmentationPipeline(c) (第 382 行)
Xout = cellfun(@(c) self.applyAugmentationPipeline(c),X,'UniformOutput',false);出错 augmentedImageDatastore/applyAugmentationPipelineToBatch (第 382 行)
Xout = cellfun(@(c) self.applyAugmentationPipeline(c),X,'UniformOutput',false);出错 augmentedImageDatastore/read (第 316 行)
input = self.applyAugmentationPipelineToBatch(input);出错 matlab.io.Datastore/preview (第 288 行)
data = read(copyds);出错 getPreviewFromDatastore (第 9 行)
previewData = preview(inputDatastore);出错 minibatchqueue (第 290 行)
numVariables = numel(getPreviewFromDatastore(originalDatastore));出错 ComparisonMethod (第 311 行)
mbq = minibatchqueue(augimds, ...
错误分析
定位到底层错误处 ,发现传入miniBatchData是一个struct,而正确的格式应该是一个数组
错误示例:
正确示例:
修改方式
句柄函数得到的是struct,从struct中提取数组
修改 loadMAT函数如下
function dataUse = loadMAT(filename)
data = load(filename);
fieldNames = fieldnames(data);
dataUse=data.(fieldNames{1});
end
验证
再次运行,问题解决。