目录
- 数据准备
- 定义模型并训练
- 用测试集评估性能
- 推理过程
- ⼀⾏代码查看⽹络结构
- ⼀⾏代码转onnx
- 结语
⼈⽣苦短,我⽤MATLAB。
Pytorch在深度学习领域占据了半壁江⼭,最主要的原因是⽣态完善,⽽且api直观易⽤。但谁能想到现在MATLAB⽤起来⽐Pytorch还好⽤。从数据集划分到训练,再到性能验证和画图,仅仅使⽤了⼏⼗⾏代码。炼丹师们终于可以解放编码时间,把⾃⼰的精⼒放在摸⻥(划掉)算法本身上了。
下⾯⽤⾃⼰的数据集训练⼀个YOLOv4,看看MATLAB到底怎么个事。
数据准备
我的数据内容是⾦属锻件磁粉探伤所显示的零件缺陷。如下例所示,轴孔上有⼀处明显缺陷:
数据集格式为Pascal VOC
格式(详⻅http://host.robots.ox.ac.uk/pascal/VOC/)。⾸先要将格式的标签转换为MATLAB能够使⽤的格式。使⽤以下函数:
function convertXMLtoMAT(xmlFolder, outputMATFile)% xmlFolder: 存放所有XML⽂件的⽂件夹% outputMATFile: 输出的MAT⽂件名xmlFiles = dir(fullfile(xmlFolder, '*.xml')); imageFilenames = {};boundingBoxes = {};for i = 1:length(xmlFiles) tryxmlFilePath fullfile(xmlFiles(i).folder, xmlFiles(i).name); xmlDoc = xmlread(xmlFilePath);filenameNode = xmlDoc.getElementsByTagName('filename').item(0);filename = char(filenameNode.getFirstChild.getData); pathNode = xmlDoc.getElementsByTagName('path').item(0); newPath = fullfile(pwd, 'img', filename);xminNode = xmlDoc.getElementsByTagName('xmin').item(0); yminNode = xmlDoc.getElementsByTagName('ymin').item(0); xmaxNode = xmlDoc.getElementsByTagName('xmax').item(0); ymaxNode = xmlDoc.getElementsByTagName('ymax').item(0); xmin = str2double(xminNode.getFirstChild.getData);ymin = str2double(yminNode.getFirstChild.getData); xmax = str2double(xmaxNode.getFirstChild.getData); ymax = str2double(ymaxNode.getFirstChild.getData);width = xmax - xmin; height = ymax - ymin;bbox = [xmin, ymin, width, height];imageFilenames{end+1, 1} = newPath; boundingBoxes{end+1, 1} = bbox;catch% passendend% 创建table并保存为MAT⽂件myDataset = table(imageFilenames, boundingBoxes, ... 'VariableNames', {'imageFilename', 'defect'});save(outputMATFile, 'myDataset');% 输出MAT⽂件的前⼏⾏disp('数据集的前⼏⾏:');disp(myDataset(1:4, :));end
转换后的标签⽂件就变成⼀个mat
⽂件,加载后就是以下的样⼦:
data = load("my_dataset.mat");
defectDataset = data.myDataset;
% 显示数据
disp(defectDataset(1:5, :));
table sample_num * 2 :imageFilename defect--------------------- ---------------------{'img/00001.jpg'} {[2635 435 133 682]}{'img/00001_z0.jpg'} {[2574 1086 125 561]}{'img/00001_z1.jpg'} {[2569 720 386 596]}{'img/00001_z2.jpg'} {[2608 951 303 647]}{'img/00001_z3.jpg'} {[1748 947 303 606]}
第⼆步,划分数据集。⼀般来讲,按照7:2:1的⽐例拆分训练集、测试集和验证集。
trainingRatio = 0.7;
validationRatio = 0.1;rng("default");
shuffledIndices = randperm(height(defectDataset));
idx = floor(trainingRatio * length(shuffledIndices) );trainingIdx = 1:idx;
trainingDataTbl = defectDataset(shuffledIndices(trainingIdx),:);validationIdx = idx+1 : idx + 1 + floor(validationRatio * length(shuffledIndices) );
validationDataTbl = defectDataset(shuffledIndices(validationIdx),:);testIdx = validationIdx(end)+1 : length(shuffledIndices);
testDataTbl = defectDataset(shuffledIndices(testIdx),:);
第三步,创建 Datastore
。Datastore
是MATLAB特有的⼀个概念,作⽤类似于Pytorch中的DataLoader
。它集成了⼀些常⽤的操作,如计数,按序读取,打乱顺序,合并数据集等⽅法。MATLAB的Datastore
⽐torch的DataLoader
更⽅便,⼀般⽆需⾃⼰定义,拿来就⽤。
imdsTrain = imageDatastore(trainingDataTbl{:,"imageFilename"});
bldsTrain = boxLabelDatastore(trainingDataTbl(:,"defect"));imdsValidation = imageDatastore(validationDataTbl{:,"imageFilename"});
bldsValidation = boxLabelDatastore(validationDataTbl(:,"defect"));imdsTest = imageDatastore(testDataTbl{:,"imageFilename"});
bldsTest = boxLabelDatastore(testDataTbl(:,"defect"));
% 合并
trainingData = combine(imdsTrain,bldsTrain);
validationData = combine(imdsValidation,bldsValidation);
testData = combine(imdsTest,bldsTest);
最后,为了验证数据是否正确,可以画⼀个样本检查⼀下:
data = read(trainingData);
I = data{1};
bbox = data{2};
annotatedImage = insertShape(I,"Rectangle",bbox,LineWidth=10);
annotatedImage = imresize(annotatedImage,2);figure
imshow(annotatedImage);
reset(trainingData);
定义模型并训练
这部分是本⽂的重点,但是却很短,因为MATLAB的流程太简洁了,狠狠爱了。
⾸先定义模型的主要参数:
optimizer = "adam"; % 优化器
gradientDecayFactor = 0.9; % 梯度衰减因⼦
squaredGradientDecayFactor = 0.999; % 平⽅梯度衰减因⼦
initialLearnRate = 0.001; % 初始学习率
learnRateSchedule = 'none'; % 学习率衰减策略
miniBatchSize = 4; % 批⼤⼩
L2Regularization = 0.0005; % L2正则化
MaxEpochs = 100; % 迭代轮数
inputSize = [416 416 3];
className = "defect";
第⼆步,设置yolov4的Anchors
:
rng("default")
trainingDataForEstimation =transform(trainingData,@(data)preprocessData(data,inputSize));
numAnchors = 6;
[anchors,meanIoU] = estimateAnchorBoxes(trainingDataForEstimation,numAnchors);area = anchors(:, 1).*anchors(:,2);
[~,idx] = sort(area,"descend");anchors = anchors(idx,:);
anchorBoxes = {anchors(1:3,:)
anchors(4:6,:)};
最后开始训练模型就ok了:
detector = yolov4ObjectDetector("tiny-yolov4-coco", ...
className,anchorBoxes,InputSize=inputSize);
% 如果你选择了从头开始训练,那么就得定义⼀⼤堆参数
options = trainingOptions(optimizer, ...GradientDecayFactor=gradientDecayFactor, ... SquaredGradientDecayFactor=squaredGradientDecayFactor, ... InitialLearnRate=initialLearnRate, ...LearnRateSchedule=learnRateSchedule, ... MiniBatchSize=miniBatchSize, ...L2Regularization=L2Regularization, ... MaxEpochs=MaxEpochs, ...DispatchInBackground=true, ... ResetInputNormalization=true, ... Shuffle="every-epoch", ...VerboseFrequency=20, ... ValidationFrequency=1000, ... CheckpointPath=tempdir, ...ValidationData=validationData, ... OutputNetwork="best-validation-loss");% 开始训练
[detector,info] = trainYOLOv4ObjectDetector(augmentedTrainingData,detector,options);
% 保存
save('yolov4_detector.mat',"detector"); % 模型
save('yolov4_detector_info.mat',"info"); % 训练过程
训练时会有如下输出:
Computing Input Normalization Statistics.
*************************************************************************
Training a YOLO v4 Object Detector for the following object classes:* defect正在使⽤ 'Processes' 配置⽂件启动并⾏池(parpool)...
已连接到具有 8 个⼯作进程的并⾏池。Epoch Iteration TimeElapsed LearnRate TrainingLoss ValidationLoss1 1 00:00:10 0.001 1174.1 1124.1
1 20 00:00:36 0.001 80.843
1 40 00:00:52 0.001 30.598
1 60 00:01:08 0.001 20.106
1 80 00:01:21 0.001 86.183
1 100 00:01:34 0.001 14.755
2 120 00:01:53 0.001 13.844
2 140 00:02:06 0.001 13.443
2 160 00:02:19 0.001 12.067
2 180 00:02:32 0.001 10.023
2 200 00:02:45 0.001 11.157
2 220 00:02:59 0.001 11.613 *************************************************************************
Detector training complete.
*************************************************************************
训练完成后,模型和相关的记录会分别保存在yolov4_detector.mat
和yolov4_detector_info.mat
中。
用测试集评估性能
⽤来评估性能的
detectionResults = detect(detector,testData,Threshold=0.01);
metrics = evaluateObjectDetection(detectionResults,testData);
AP = averagePrecision(metrics);
[precision,recall] = precisionRecall(metrics,ClassName=className);% 画pr图
figure 1
plot(recall{:},precision{:})
xlabel("Recall")
ylabel("Precision")
grid on
title(sprintf("Average Precision = %.2f",AP))
imshow(I);% 再画个loss曲线
figure 2
plot(info.TrainingLoss)
xlabel("迭代次数")
ylabel("loss") title('Loss曲线')
你可以说我练的不怎么样,但不能说MATLAB不⾏,因为在torch上也是⼀样的效果(哭…
推理过程
I = imread("img/00029.jpg");
[bboxes,scores,labels] = detect(detector,I);
I = insertObjectAnnotation(I,"rectangle",bboxes,scores, ...LineWidth=10,FontSize=72); figure
imshow(I)
⼀⾏代码查看⽹络结构
MATLAB⾥⾯⼀⾏代码可以做到图形化的查看⽹络结构,更专注于分析算法。
net = detector.Network
analyzeNetwork(net)
⼀⾏代码转onnx
exportONNXNetwork(net,'yolov4.onnx')
导出⽹络后,可以将该⽹络导⼊到其他深度学习框架中进⾏推理。妈妈再也不⽤担⼼我的部署了。
结语
MATLAB完全解放了炼丹师的编码时间,让炼丹师能够专注于算法本身,我愿称之为最好的深度学习框架(如果你有licence的话)