《迁移学习》—— 将 ResNet18 模型迁移到食物分类项目中

文章目录

  • 一、迁移学习的简单介绍
    • 1.迁移学习是什么?
    • 2.迁移学习的步骤
  • 二、数据集介绍
  • 三、代码实现
    • 1. 步骤
    • 2.所用到方法介绍的文章链接
    • 3. 完整代码

一、迁移学习的简单介绍

1.迁移学习是什么?

  • 迁移学习是指利用已经训练好的模型,在新的任务上进行微调。
  • 迁移学习可以加快模型训练速度,提高模型性能,并且在数据稀缺的情况下也能很好地工作。

2.迁移学习的步骤

  • (1) 选择预训练的模型和适当的层:通常,我们会选择在大规模图像数据集(如ImageNet)上预训练的模型,如VGG、ResNet等。然后,根据新数据集的特点,选择需要微调的模型层。对于低级特征的任务(如边缘检测),最好使用浅层模型的层,而对于高级特征的任务(如分类),则应选择更深层次的模型。
  • (2) 冻结预训练模型的参数:保持预训练模型的权重不变,只训练新增加的层或者微调一些层,避免因为在数据集中过拟合导致预训练模型过度拟合。
  • (3) 在新数据集上训练新增加的层:在冻结预训练模型的参数情况下,训练新增加的层。这样,可以使新模型适应新的任务,从而获得更高的性能。
  • (4) 微调预训练模型的层:在新层上进行训练后,可以解冻一些已经训练过的层,并且将它们作为微调的目标。这样做可以提高模型在新数据集上的性能。
  • (5) 评估和测试:在训练完成之后,使用测试集对模型进行评估。如果模型的性能仍然不够好,可以尝试调整超参数或者更改微调层。

二、数据集介绍

  • 下图是数据集的结构
    • 在 food_dataset2 文件夹下含有训练数据和测试数据
    • 训练集和测试集数据中都含有 20 种食物图片,数量在200~400不等
    • trainda.txt 和 testda.txt 文本中存放了每张图片的路径及标签,用 0~19 这20个数字分别对20种食物进行标签
    • 在代码中通过trainda.txt 和 testda.txt 文本中的内容来获取每张图片及对应的标签
      在这里插入图片描述
    • 下面是trainda.txt文本中的部分内容(testda.txt 中的内容格式相同)
      在这里插入图片描述
  • 送福利!!! 私信送此数据集 !!!

三、代码实现

1. 步骤

  • 1.调用resnet18模型,并保存需要训练的模型参数
  • 2.定义一个图像预处理和数据增强字典
  • 3.定义获取每张食物图片和标签的类方法
  • 4.获取训练集和测试集数据
  • 5.对数据集进行打包
  • 6.调用交叉熵损失函数并创建优化器
  • 7.定义训练模型的函数
  • 8.定义测试模型的函数
  • 9.训练模型,并每训练一轮测试一次

2.所用到方法介绍的文章链接

  • ResNet 残差网络神经网络
    • https://blog.csdn.net/weixin_73504499/article/details/142575775?spm=1001.2014.3001.5501
  • 数据增强
    • https://blog.csdn.net/weixin_73504499/article/details/142499263?spm=1001.2014.3001.5501
  • 调整学习率
    • https://blog.csdn.net/weixin_73504499/article/details/142526863?spm=1001.2014.3001.5501

3. 完整代码

import torchimport torchvision.models as models  # 导入存有各种深度学习模型的模块from torch import nn  # 导入神经网络模块from torch.utils.data import Dataset, DataLoader  # Dataset: 抽象类,一种用于获取数据的方法  DataLoader:数据包管理工具,打包数据from torchvision import transforms  # transforms模块提供了一系列用于图像预处理和数据增强的函数和类from PIL import Image  # 用于处理图片import numpy as np""" 调用resnet18模型 """resnet_model = models.resnet18(weights=models.ResNet18_Weights.DEFAULT)for param in resnet_model.parameters():param.requires_grad = False# 模型所有参数(即权重和偏差)的 requires_grad 属性设置成 False,从而冻结所有模型参数# 使得在反向传播过程中不会计算他们的梯度,从此减少模型的计算量,提高推理速度in_features = resnet_model.fc.in_features  # 获取resnet18模型全连接层原输入的特征个数resnet_model.fc = nn.Linear(in_features, 20)  # 创建一个全连接层输入特征个数为: in_features  输出特征个数为:数据集中事务的种类数量params_to_update = []  # 保存需要训练的参数,仅仅包含全连接层的参数for param in resnet_model.parameters():if param.requires_grad == True:params_to_update.append(param)""" 图像预处理和数据增强 """data_transforms = {'train':transforms.Compose([transforms.Resize([300, 300]),transforms.RandomRotation(45),  # 随机旋转,-45到45度之间随机transforms.CenterCrop(224),  # 中心裁剪transforms.RandomHorizontalFlip(p=0.5),  # 随机水平反转 选择一个概率transforms.RandomVerticalFlip(p=0.5),  # 随机垂直翻转# transforms.ColorJitter(brightness=0.2, contrast=0.1, saturation=0.1, hue=0.1),  # 亮度、对比度transforms.RandomGrayscale(p=0.1),  # 概率转换成灰度率,3通道就是R G Btransforms.ToTensor(),  # 转化为神经网络可以识别的 Tensor 类型transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])  # 对图片数据进行归一化,[均值],[标准差]]),'valid':transforms.Compose([transforms.Resize([224, 224]),transforms.ToTensor(),transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])])}""" 定义获取每张食物图片和标签的类方法 """class food_dataset(Dataset):def __init__(self, file_path, transform=None):self.file_path = file_pathself.imgs = []self.labels = []self.transform = transformwith open(self.file_path) as f:samples = [x.strip().split(' ') for x in f.readlines()]for img_path, label in samples:self.imgs.append(img_path)self.labels.append(label)def __len__(self):return len(self.imgs)def __getitem__(self, idx):image = Image.open(self.imgs[idx])if self.transform:image = self.transform(image)label = self.labels[idx]label = torch.from_numpy(np.array(label, dtype=np.int64))return image, label""" 获取训练集和测试集数据 """training_data = food_dataset(file_path='trainda.txt', transform=data_transforms['train'])test_data = food_dataset(file_path='testda.txt', transform=data_transforms['valid'])""" 对数据集进行打包 """train_dataloader = DataLoader(training_data, batch_size=64, shuffle=True)  # 64张图片为一个包,shuffle --> 打乱顺序test_dataloader = DataLoader(test_data, batch_size=64, shuffle=True)""" 判断当前设备是否支持GPU,其中mps是苹果m系列芯片的GPU """device = "cuda" if torch.cuda.is_available() else "mps" if torch.backends.mps.is_available() else "cpu"print(f"Using {device} device")# 把模型传入到 gpu 或 cpumodel = resnet_model.to(device)""" 调用交叉熵损失函数 """loss_fn = nn.CrossEntropyLoss()"""" 创建优化器并调整优化器中的学习率--> lr """optimizer = torch.optim.Adam(params_to_update, lr=0.001)scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=5, gamma=0.5)""" 定义训练模型的函数 """def train(dataloader, model, loss_fn, optimizer):model.train()  # 告诉模型,开始训练# pytorch提供2种方式来切换训练和测试的模式,分别是:model.train()和 model.eval()。# 一般用法是:在训练开始之前写上model.trian(),在测试时写上model.for X, y in dataloader:X, y = X.to(device), y.to(device)  # 把训练数据集和标签传入cpu或gpupred = model.forward(X)  # .forward可以被省略,父类中已经对次功能进行了设置。自动初始化w权值loss = loss_fn(pred, y)  # 通过交叉熵损失函数计算损失值 lossoptimizer.zero_grad()  # 梯度值清零loss.backward()  # 反向传播计算得到每个参数的梯度值woptimizer.step()  # 根据梯度更新网络w参数""" 定义测试模型的函数 """best_acc = 0  # 用于更新准确率def test(dataloader, model, loss_fn):global best_accsize = len(dataloader.dataset)num_batches = len(dataloader)model.eval()  # 测试,w就不能再更新test_loss, correct = 0, 0with torch.no_grad():for X, y in dataloader:X, y = X.to(device), y.to(device)pred = model.forward(X)test_loss += loss_fn(pred, y).item()  # test_loss是会自动累加每一个批次的损失值correct += (pred.argmax(1) == y).type(torch.float).sum().item()  # correct是会自动累加每一个批次的正确率test_loss /= num_batches  # 平均的损失值correct /= size  # 平均的正确率print(f"Test result: \n Accuracy: {(100 * correct)}%, Avg loss: {test_loss}")# 找到最好的准确率if correct > best_acc:best_acc = correct""" 定义模型训练的轮数,并每训练一轮测试一次 """epochs = 30for e in range(epochs):print(f"Epoch {e + 1}\n---------------------------")train(train_dataloader, model, loss_fn, optimizer)scheduler.step()  # 在每个epoch的训练中,使用scheduler.step()语句进行学习率更新test(test_dataloader, model, loss_fn)print('最优的训练结果为:', best_acc)
  • 结果如下
    • 此结果只是训练了30轮后的结果,可以训练更多轮,最后的准确率还会有所提高
      在这里插入图片描述

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.rhkb.cn/news/439571.html

如若内容造成侵权/违法违规/事实不符,请联系长河编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

Windows应急响应-Auto病毒

文章目录 应急背景分析样本开启监控感染病毒查看监控分析病毒行为1.autorun.inf分析2.异常连接3.进程排查4.启动项排查 查杀1.先删掉autorun.inf文件2.使用xuetr杀掉进程3.启动项删除重启排查入侵排查正常流程 应急背景 运维人员准备通过windows共享文档方式为公司员工下发软件…

保险丝基础知识

一、简介 保险丝(fuse)也被称为电流保险丝,它能够在电流异常升高到一定的高度和热度时,自动熔断切断电流,从而保护电路安全运行。 IEC127标准将它定义为“熔断体(fuse-link)”。熔断体是由电阻率比较大而熔…

强化学习笔记之【Q-learning算法和DQN算法】

强化学习笔记(一)——Q-learning和DQN算法核心公式 文章目录 强化学习笔记(一)——Q-learning和DQN算法核心公式前言:Q-learning算法DQN算法 前言: 强化学习领域,繁冗复杂的大段代码里面&#…

《软件工程概论》作业一:新冠疫情下软件产品设计

课程说明:《软件工程概论》为浙江科技学院2018级软件工程专业在大二下学期开设的必修课。课程使用《软件工程导论(第6版)》(张海藩等编著,清华大学出版社)作为教材。以《软件设计文档国家标准GBT8567-2006》…

【小沐学GIS】blender导入OpenTopography地形数据(BlenderGIS、OSM、Python)

文章目录 1、简介1.1 blender1.2 OpenStreetMap地图 2、BlenderGIS2.1 下载BlenderGIS2.2 安装BlenderGIS2.3 申请opentopography的key2.4 抓取卫星地图2.5 生成高度图2.6 获取OSM数据 结语 1、简介 1.1 blender https://www.blender.org/ Blender 是一款免费的开源 3D 创作套…

IMS添加实体按键流程 - Android14

IMS添加实体按键流程 - Android14 1、实体按键信息(Mi 9 左侧实体按键)2、硬件添加2.1 内核添加设备节点2.2 Generic.kl映射文件2.3 映射文件文件加载loadKeyMapLocked2.4 addDeviceLocked 添加设备相关对象 3、keycode对应scankode4、KeyEvent.java 添加…

【星汇极客】手把手教学STM32 HAL库+FreeRTOS之创建工程(0)

前言 本人是一名嵌入式学习者,在大学期间也参加了不少的竞赛并获奖,包括但不限于:江苏省电子设计竞赛省一、睿抗机器人国二、中国高校智能机器人国二、嵌入式设计竞赛国三、光电设计竞赛国三、节能减排竞赛国三。 后面会经常写一下博客&…

服务器conda环境安装rpy2

参考博客 https://stackoverflow.com/questions/68936589/how-to-select-r-installation-when-using-rpy2-on-conda 现在我遇到这样一个问题,服务器系统环境没有R(没有权限安装),我只能在minconda的conda环境中使用R, 使用方法如下 我现在…

(11)MATLAB莱斯(Rician)衰落信道仿真2

文章目录 前言一、莱斯衰落信道仿真模型二、仿真代码与结果1.仿真代码2.仿真结果画图 三、后续:四、参考文献: 前言 首先给出莱斯衰落信道仿真模型,该模型由直射路径分量和反射路径分量组成,其中反射路径分量由瑞利衰落信道模型构…

Art. 1 | 信号、信息与消息的区别及其在通信中的应用

信号、信息与消息的区别及其在通信中的应用 通信技术是现代社会的基石,其广泛应用于日常生活的各个方面。从手机、互联网到企业信息管理,通信系统无处不在。在这一技术领域中,信号、信息和消息是三大基础概念,支撑着整个通信系统…

云计算Openstack Glance

OpenStack Glance(或称为Glance,但通常OpenStack官方文档中使用的是“Glance”作为项目代号)是OpenStack的镜像服务组件,为创建虚拟机提供镜像服务。以下是对OpenStack Glance的详细解析: 一、基本功能 Glance主要提…

【AI人工智能】文心智能体,双人冒险游戏智能体创作分享

背景 最近半年,“AI agent”(智能体)这一词汇变得非常热门。许多人以为创建自己的智能体会很复杂,实际上,现有的平台已经大大降低了操作门槛。只要有创意,几乎每个人都可以轻松创建属于自己的智能体。今天…

Linux下静态库与动态库制作及分文件编程

Linux下静态库与动态库制作及分文件编程 文章目录 Linux下静态库与动态库制作及分文件编程1.分文件编程1.1优点1.2操作逻辑1.3示例 2.Linux库的概念3.静态库的制作与使用3.1优缺点3.2命名规则3.3制作步骤3.4开始享用 4.动态库的制作与使用4.1优缺点4.2动态库命名规则4.3制作步骤…

Uniapp API

1.uni.showToast 显示消息提示框 unishowToast({ obj参数 }) 2.uni.showLoading 显示 loading 提示框, 需主动调用 uni.hideLoading 才能关闭提示框。 3.uni.showModal 显示模态弹窗,可以只有一个确定按钮,也可以同时有确定和取消按钮。类似于一个A…

VLAN:虚拟局域网

VLAN:虚拟局域网 交换机和路由器协同工作后,将原先的一个广播域,逻辑上,切分为多个广播域。 第一步:创建VLAN [SW1]dispaly vlan 查询vlan VID(VLAN ID):用来区分和标定不同的vlan 由12位二进制构成 范围: 0-4…

算法笔记(十一)——优先级队列(堆)

文章目录 最后一块石头的重量数据流中的第 K 大元素前K个高频单词数据流的中位数 优先级队列是一种特殊的队列,元素按照优先级从高到低(或从低到高)排列,高优先级的元素先出队,可以用 堆来实现 堆是一种二叉树的结构&…

HTB:Preignition[WriteUP]

连接至HTB服务器并启动靶机 靶机IP:10.129.157.49 分配IP:10.10.16.12 1.Directory Brute-forcing is a technique used to check a lot of paths on a web server to find hidden pages. Which is another name for this? (i) Local File Inclusion, (…

如何安全地大规模部署 GenAI 应用程序

大型语言模型和其他形式的生成式人工智能(GenAI) 的广泛使用带来了许多组织可能没有意识到的安全风险。幸运的是,网络和安全提供商正在寻找方法来应对这些前所未有的威胁。 随着人工智能越来越深入地融入日常业务流程,它面临着泄露专有信息、提供错误答…

2.创建第一个MySQL存储过程(2/10)

引言 在现代数据库管理中,存储过程扮演着至关重要的角色。它们是一组为了执行特定任务而编写的SQL语句集合,这些语句被保存在数据库中,并且可以被多次调用执行。存储过程不仅可以提高数据库操作的效率,还能增强数据的安全性和一致…

Docker 启动 Neo4j:详细配置指南和浏览器访问

Docker 启动 Neo4j:详细配置指南和浏览器访问 文章目录 Docker 启动 Neo4j:详细配置指南和浏览器访问一 Neo4j compose 得 yml 配置二 配置描述三 浏览器访问 这篇文章详细介绍了如何使用 Docker Compose 启动 Neo4j 数据库,包括 docker-com…