文章目录
- 前言
- 🎓一、数据集准备
- 🎓二、模型训练
- 🍀🍀1.初始化
- 🍀🍀2.加载数据集
- 🍀🍀3.划分数据集,并保存到新的文件夹
- 🍀🍀4.可视化数据集
- 🍀🍀5.模型构建
- 🍀🍀6.数据增强
- 🍀🍀7.设置训练参数
- 🍀🍀8.训练与测试
- 🎓三、模型测试
- 🍀🍀1.初始化
- 🍀🍀2.读取模型
- 🍀🍀3.读取一张图片
- 🍀🍀4.将图像调整为模型所需的输入尺寸
- 🍀🍀5.预测
- 🍀🍀6.条形图展示
- 🎓四、界面设计与实现
- 🎓五、源码下载
前言
饮食在人们的日常生活、营养与医疗建议以及运动员等专业人士的训练上起着越来越重要的作用。随着互联网的发展和医学的进步,人们普遍上传分享和记录的食物图像形成了非常多的数据库,为了改善饮食结构,塑造更健康的生活方式,分析食物的种类、热量和进食时间成为了营养学上非常重要的研究方法,本文的使用深度学习算法可以根据用户上传的食物图像自动分析食物种类。深度学习是机器学习领域的一个研究方向。深度学习通过对数据特征的学习将原始数据转化为计算机可以理解的抽象数据,根据学习到的特征,可以对原始数据进行检测或分类。深度学习有三个组成部分:架构、目标函数和学习规则。架构是模型中数据的连接规则,目标函数用于评估预测结果与真实结果的一致性,是评估深度学习模型输出结果的标准,学习规则是更新模型中参数的方式。制定合理高效的框架,目标函数和学习规则对提高深度学习模型的精度和效率至关重要。本文构建了一个不同类别的菜品或甜点食物数据集,数据集一共 4000 张,采用 ResNet-50 网络模型训练,接下来手把手教你训练、界面设计与实现。
源码不做任何修改,重点重点重点:为了代码运行不报错,安装 matlab 版本为 2021 及以上版本。
matlab安装教程大家可以参考这个教程: 手把手教你安装matlab软件
文章末尾获取源码+数据集+界面。
🎓一、数据集准备
训练前准备好数据集,可以是开源的,自己采集的都行,每一个文件夹包含一个食物,命名规则是英文的,我这里准备 10 个类别,大家根据自己需要准备自己所需的数据集就行
🎓二、模型训练
编写训练脚本,下面一步一步构建
🍀🍀1.初始化
方便运行脚本时无需手动调整路径:
clear
clc
rng default
mpath = matlab.desktop.editor.getActiveFilename;
[pathstr,~]=fileparts(mpath);
cd(pathstr);
🍀🍀2.加载数据集
加载数据集,并统计每个类别的图像数量
imgDir = fullfile("F:/BaiduNetdiskDownload/matlab_data/");
dataSet = imageDatastore(imgDir, 'IncludeSubfolders', true, 'LabelSource', 'foldernames');
trainSetDetail = countEachLabel(dataSet) % 训练数据
🍀🍀3.划分数据集,并保存到新的文件夹
划分比例根据自己需求进行划分就行,我训练集,验证集, 测试集按 6:2:2 比例划分
% 划分数据集
[trainData, valData, testData] = splitEachLabel(dataSet, 0.6, 0.2, 'randomize');% 创建保存图像的目录
trainDir = fullfile(pathstr, 'TrainingSet');
valDir = fullfile(pathstr, 'ValidationSet');
testDir = fullfile(pathstr, 'TestSet');if ~exist(trainDir, 'dir')mkdir(trainDir);
endif ~exist(valDir, 'dir')mkdir(valDir);
endif ~exist(testDir, 'dir')mkdir(testDir);
end% 复制训练数据集图像
for idx = 1:numel(trainData.Files)currentImgPath = trainData.Files{idx};[~, imgBaseName, imgExt] = fileparts(currentImgPath);destinationPath = fullfile(trainDir, [imgBaseName, imgExt]);copyfile(currentImgPath, destinationPath);
end% 复制验证数据集图像
for idx = 1:numel(valData.Files)currentImgPath = valData.Files{idx};[~, imgBaseName, imgExt] = fileparts(currentImgPath);destinationPath = fullfile(valDir, [imgBaseName, imgExt]);copyfile(currentImgPath, destinationPath);
end% 复制测试数据集图像
for idx = 1:numel(testData.Files)currentImgPath = testData.Files{idx};[~, imgBaseName, imgExt] = fileparts(currentImgPath);destinationPath = fullfile(testDir, [imgBaseName, imgExt]);copyfile(currentImgPath, destinationPath);
end
运行代码之后生成训练集,验证集, 测试集文件夹
🍀🍀4.可视化数据集
从训练图像数据集中随机抽取并显示 16 张图像进行查看
numTrainImages = numel(trainData.Labels);
idx = randperm(numTrainImages,16);
figure
for i = 1:16subplot(4,4,i)I = readimage(trainData,idx(i));imshow(I)title(trainData.Labels(idx(i)))
end
运行代码输出如下:
🍀🍀5.模型构建
我使用 resnet50 ,大家可以跟换其他模型
net = resnet50;
layers = net.Layers
inputSize = net.Layers(1).InputSize % 网络输入尺寸
运行代码报错:
错误使用 resnet50resnet50 需要 Deep Learning Toolbox Model for ResNet-50 Network 支持包以使用预训练的权重。要安装此支持包,请使用附加功能资源管理器。要使用未训练的层,请使用resnet50(Weights,"none),这不需要支持包。
解决方法:
双击我给的文件进行安装
没有邮箱自己注册一个,qq邮箱就行,之后登录就行
等待安装
之后重新运行,输出如下
下面代码的第一个 10 是全连接层的输出节点数量,意思就是数据集中的类别数量,因为我有 10 个类别,大概根据自己的数据集进行修改。‘WeightLearnRateFactor’,10 和 ‘BiasLearnRateFactor’,10 意思就是设置权重和偏置的学习率因子为 10。
lgraph = layerGraph(net);
newLearnableLayer = fullyConnectedLayer(10, ...'Name','new_fc', ...'WeightLearnRateFactor',10, ...'BiasLearnRateFactor',10);lgraph = replaceLayer(lgraph,'fc1000',newLearnableLayer);
newClassLayer = classificationLayer('Name','new_classoutput');
lgraph = replaceLayer(lgraph,'ClassificationLayer_fc1000',newClassLayer);
🍀🍀6.数据增强
为了提高模型的泛化能力,我们在训练过程中使用数据增强技术。数据增强可以通过对训练图像进行随机变换(例如翻转、平移等)来生成更多的训练样本。这有助于模型在面对不同视角、尺度和形态的图像时具有更好的泛化能力。在本例中,采用随机水平翻转以及随机平移的方法进行数据增强。此外,还需要处理灰度图像。由于预训练的 ResNet-50 模型期望输入为彩色图像,需要将灰度图像转换为伪彩色图像,使其可以被模型正确处理。
pixelRange = [-30 30];
imageAugmenter = imageDataAugmenter( ...'RandXReflection',true, ...'RandXTranslation',pixelRange, ...'RandYTranslation',pixelRange);
augimdsTrain = augmentedImageDatastore(inputSize(1:2),trainData, ...'DataAugmentation',imageAugmenter,'ColorPreprocessing','gray2rgb');
augimdsValidation = augmentedImageDatastore(inputSize(1:2),valData);
augimdsTest = augmentedImageDatastore(inputSize(1:2),testData);
🍀🍀7.设置训练参数
options = trainingOptions('sgdm', ...'MiniBatchSize',32, ...'MaxEpochs',15, ...'InitialLearnRate',1e-4, ...'Shuffle','every-epoch', ...'ValidationData',augimdsValidation, ...'ValidationFrequency',3, ...'Verbose', true, ...'Plots','training-progress');
参数解释:
- 第一个设置优化器为 sgdm
- MiniBatchSize,32:每次迭代时使用的批次大小为 32
- MaxEpochs,15:最大训练轮数为 15
- InitialLearnRate,1e-4:初始学习率为 1e-4
- Shuffle,every-epoch:在每个 epoch 结束时打乱训练数据的顺序
- ValidationData,augimdsValidation:指定用于验证的增强图像数据存储对象 augimdsValidation。
- ValidationFrequency,3:每训练 3 次迭代就进行一次验证。
- Verbose, true:在命令窗口显示训练过程的详细信息。
- Plots,training-progress:显示一个实时图表,展示训练进度,包括训练和验证的损失、准确率等指标。
大家根据自己调整训练参数就行
🍀🍀8.训练与测试
开始运行和保存训练模型
model = trainNetwork(augimdsTrain,lgraph,options);
save('ResNet_4_classes.mat', 'model');
使用该模型对测试数据进行分类,并计算测试集上的准确率
[YPred,scores] = classify(model, augimdsTest);
YTest = imdsTest.Labels;
accuracy = mean(YPred == YTest)
训练过程,这是一个比较漫长的过程
🎓三、模型测试
编写测试代码
🍀🍀1.初始化
clear
clc
rng default % 保证结果运行一致
🍀🍀2.读取模型
读取训练好的模型和获取标签名字
load('ResNet_3_classes.mat','model')
inputSize= model. Layers(1).InputSize
class_names = model. Layers(end).ClassNames ;
num_classes = numel(class_names) ;
disp(class_names(randperm(num_classes, 10)))
🍀🍀3.读取一张图片
Img = imread('11483.jpg');
figure
imshow(Img)
🍀🍀4.将图像调整为模型所需的输入尺寸
Img = imresize(Img, inputSize(1: 2));
figure
imshow(Img)
🍀🍀5.预测
将输出预测的类别和对应的置信度分数,显示在图像上方的标题中
[labels, scores] = classify(model, Img);
figure
imshow(Img)
title(string(labels)+","+ num2str(100*scores(class_names == labels),3)+"%");
运行结果如下:
🍀🍀6.条形图展示
[~,idx] = sort(scores, 'descend');
idx=idx(3:-1:1);
classNamesTop = model.Layers(end).ClassNames(idx);
scoresTop = scores(idx);
figure
%绘制条形图显示置信度值
mybar=barh(scoresTop);
xlim([0 1.1])
yticklabels(classNamesTop)
xtips = mybar(1) .YEndPoints + 0.02;
ytips =mybar(1) .XEndPoints ;
labels = string(round(mybar(1).YData, 2));
text(xtips ,ytips ,labels, 'VerticalAlignment','middle')
title('指名前3的预测结果')
xlabel('置信度值')
运行如下:
🎓四、界面设计与实现
上传图片即可输出结果,可以更换自己模型,构建成自己的系统
🎓五、源码下载
下载链接: 源码+数据集+界面
关注我,带你不挂科!!!