深度学习-迁移学习

深度学习中的迁移学习是通过在大规模数据上训练的模型,将其知识迁移到数据相对较少的相关任务中,能显著提升目标任务的模型性能。


一、迁移学习的核心概念

  1. 源任务(Source Task)与目标任务(Target Task)

    (1)源任务:通常拥有大量标注数据以及预训练好的模型,模型可以从中提取到通用特征。(2)目标任务:数据量相对有限,与源任务有相似性,但需要迁移模型知识适应特定的需求。
  2. 特征迁移

    (1)深度学习模型的层级结构有“自下而上”的特征表示,底层(如边缘、形状特征)更通用高层特征(如复杂纹理、特定形状)更具体。(2)迁移学习通过保留底层特征,并微调高层特征以适应新任务。
  3. 微调与冻结

    (1)冻结:冻结模型底层权重,保留已学到的底层特征,适合用于不同数据但相似的任务。(2)微调:对高层权重进行少量训练,使其适应目标任务,适用于源、目标任务有一定关联的情况。
  4. 模型剪枝与特征选择

    (1)剪枝可以减少模型复杂度,提升推理速度,适合在特定硬件上优化迁移模型的性能。

二、迁移学习的策略及示意图

迁移学习主要有以下策略,每个策略适用于不同场景。

1. 特征提取策略(Feature Extraction)
  • 使用预训练模型的卷积层作为固定的特征提取器,只在输出部分添加新的全连接层或分类层。
  • 应用于源任务和目标任务相似度较高的情况(如图像分类任务)。

代码示例

from tensorflow.keras.applications import VGG16
from tensorflow.keras.models import Model
from tensorflow.keras.layers import Dense, Flatten# 加载预训练的 VGG16 模型,不包含顶层
base_model = VGG16(weights='imagenet', include_top=False, input_shape=(224, 224, 3))# 将卷积层的权重冻结
for layer in base_model.layers:layer.trainable = False# 添加新的全连接层
x = Flatten()(base_model.output)
output = Dense(10, activation='softmax')(x)
model = Model(inputs=base_model.input, outputs=output)
2. 微调策略(Fine-tuning)
  • 在预训练模型的基础上保留底层特征,微调高层特征,适应新的目标任务。适合在源任务和目标任务高度相似时使用。

代码示例

# 微调部分卷积层
for layer in base_model.layers[:15]:layer.trainable = False
for layer in base_model.layers[15:]:layer.trainable = True

3. 跨领域迁移(Cross-domain Transfer)
  • 针对不同领域任务的特征迁移策略,如图像到文本、语音到文本的跨领域迁移。需要添加或替换特定的适应层以完成不同领域的转换。

三、迁移学习的代码实现示例

以下代码展示了在 ImageNet 预训练的 VGG16 模型上,通过冻结部分卷积层并添加自定义全连接层,用于一个新的分类任务(如猫狗分类)。

import torch
import torch.nn as nn
import torchvision.models as models
from torchvision import datasets, transforms
from torch.utils.data import DataLoader# 1. 加载预训练的 VGG16 模型
vgg16 = models.vgg16(pretrained=True)# 2. 冻结前面的卷积层
for param in vgg16.features.parameters():param.requires_grad = False# 3. 修改分类器部分,适应猫狗二分类任务
# 获取 VGG16 的输入特征数,并替换最后一层为适合二分类的线性层
num_features = vgg16.classifier[6].in_features
vgg16.classifier[6] = nn.Linear(num_features, 2)  # 2 classes for binary classification# 4. 定义训练参数
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
vgg16 = vgg16.to(device)
criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(vgg16.classifier[6].parameters(), lr=0.001)  # 只更新最后一层参数# 5. 定义数据预处理和加载
data_transforms = {'train': transforms.Compose([transforms.Resize((224, 224)),transforms.RandomHorizontalFlip(),transforms.ToTensor(),transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])]),'val': transforms.Compose([transforms.Resize((224, 224)),transforms.ToTensor(),transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])]),
}train_dataset = datasets.ImageFolder(root='data/train', transform=data_transforms['train'])
val_dataset = datasets.ImageFolder(root='data/val', transform=data_transforms['val'])
train_loader = DataLoader(train_dataset, batch_size=32, shuffle=True)
val_loader = DataLoader(val_dataset, batch_size=32, shuffle=False)# 6. 训练模型
def train_model(model, criterion, optimizer, num_epochs=10):for epoch in range(num_epochs):model.train()running_loss = 0.0correct = 0for inputs, labels in train_loader:inputs, labels = inputs.to(device), labels.to(device)optimizer.zero_grad()# 前向传播outputs = model(inputs)loss = criterion(outputs, labels)loss.backward()optimizer.step()# 统计损失和准确率running_loss += loss.item() * inputs.size(0)_, preds = torch.max(outputs, 1)correct += torch.sum(preds == labels)epoch_loss = running_loss / len(train_loader.dataset)epoch_acc = correct.double() / len(train_loader.dataset)print(f'Epoch {epoch}/{num_epochs - 1} - Loss: {epoch_loss:.4f}, Acc: {epoch_acc:.4f}')# 7. 调用训练函数
train_model(vgg16, criterion, optimizer, num_epochs=10)

  • 冻结卷积层:使用 for param in vgg16.features.parameters(): param.requires_grad = False 冻结了 vgg16.features 中的参数,使其在训练中不更新。

  • 修改分类层:更改 vgg16.classifier[6] 中的最后一个线性层,使其适应二分类任务(猫狗分类)。

  • 数据预处理与加载:利用 transforms 进行图像的标准化和尺寸调整,确保模型输入一致,加载后的数据放入 DataLoader 中便于批量处理。

  • 训练循环:在 train_model 函数中进行批次训练,计算损失并更新模型参数。

四、迁移学习的实际应用场景

  1. 图像分类:用于医疗影像分析、卫星图像识别等。例如使用 ImageNet 预训练模型进行皮肤癌检测。
  2. 目标检测与分割:自动驾驶中的行人检测、视频监控中的异常事件检测等。
  3. 自然语言处理:在 BERT、GPT-3 等预训练模型基础上微调,以适应情感分析、文本分类等任务。
  4. 语音识别:预训练语音模型可用于语音情感识别、口音识别等任务。

五、迁移学习的优缺点

优点

  • 数据需求少:不需要大量标注数据,可以显著缩短模型开发时间。
  • 训练高效:利用已有模型权重,减少训练时间。
  • 泛化能力强:预训练模型在大数据上学到的特征更具普适性,提高目标任务的泛化能力。

缺点

  • 源任务与目标任务的相似性要求:源任务和目标任务若差异较大,迁移效果会明显下降。
  • 存在偏差风险:源任务的偏差可能会迁移到目标任务中,对任务结果产生负面影响。
  • 额外存储开销:需要存储源模型的权重,对计算和存储资源有额外要求。

六、迁移学习的注意事项

  1. 选择合适的源任务:尽量选择与目标任务具有相似特征的源任务模型。
  2. 调整学习率:微调时的学习率应小于源任务,避免过度改变预训练模型的特征。
  3. 慎重选择微调层数:微调的层数应考虑目标任务的复杂性,避免过拟合。
  4. 数据预处理保持一致:确保源任务和目标任务的数据预处理方式一致,否则会影响模型性能。

七、总结

迁移学习在深度学习应用中已成为提升模型训练效率和性能的关键技术,尤其在目标任务与源任务具有一定关联性、且标注数据有限的情况下效果尤为显著。迁移学习通过利用在大规模数据集(如 ImageNet)上预训练的模型知识,将其迁移到新任务中,减少了对大规模数据和计算资源的需求。不同的迁移学习策略(如特征提取、微调、参数冻结等)能够针对性地调整模型层级的学习参数,实现高效的模型适应性。深入理解和灵活应用这些策略是深度学习项目开发的重要技能,能够在分类、检测、分割、文本分析等领域中有效缩短训练周期,并在数据有限的情况下显著提升模型的泛化性能和准确性。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.rhkb.cn/news/465775.html

如若内容造成侵权/违法违规/事实不符,请联系长河编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

解锁炎症和肿瘤免疫治疗新靶点:TREM1&TREM2

前 言 TREM家族属于细胞表面受体,介导调控炎症反应,现已成为癌症、神经退行性疾病以及炎症性疾病等多种疾病最有潜力的药物靶点。截至2023年6月,有5项FDA注册的临床前或临床试验正在进行中,有3项是TREM2在阿尔茨海默症&#xff…

【Unity】Unity拖拽在Android设备有延迟和卡顿问题的解决

一、介绍 在制作Block类游戏时,其核心的逻辑就是拖拽方块放入到地图中,这里最先想到的就是Unity的拖拽接口IDragHandler,然后通过 IPointerDownHandler, IPointerUpHandler 这两个接口判断按下和松手,具体的实现逻辑就是下面 public void On…

Postman断言与依赖接口测试详解!

在接口测试中,断言是不可或缺的一环。它不仅能够自动判断业务逻辑的正确性,还能确保接口的实际功能实现符合预期。Postman作为一款强大的接口测试工具,不仅支持发送HTTP请求和接收响应,还提供了丰富的断言功能,帮助测试…

NewStar CTF 2024 misc WP

decompress 压缩包套娃,一直解到最后一层,将文件提取出来 提示给出了一个正则,按照正则爆破密码,一共五位,第四位是数字 ^([a-z]){3}\d[a-z]$ 一共就五位数,直接ARCHPR爆破,得到密码 xtr4m&…

鸿蒙开发案例:七巧板

【1】引言(完整代码在最后面) 本文介绍的拖动七巧板游戏是一个简单的益智游戏,用户可以通过拖动和旋转不同形状的七巧板块来完成拼图任务。整个游戏使用鸿蒙Next框架开发,利用其强大的UI构建能力和数据响应机制,实现了…

C++_STL_xx_番外01_关于STL的总结(常见容器的总结;关联式容器分类及特点;二叉树、二叉搜索树、AVL树(平衡二叉搜索树)、B树、红黑树)

文章目录 1. 常用容器总结2. 关联式容器分类3. 二叉树、二叉搜索树、AVL树、B树、红黑树 1. 常用容器总结 针对常用容器的一些总结: 2. 关联式容器分类 关联式容器分为两大类: 基于红黑树的set和map;基于hash表的unorder_set和unorder_ma…

Django目录结构最佳实践

Django项目目录结构 项目目录结构配置文件引用修改创建自定义子应用方法修改自定义注册目录从apps目录开始 项目目录结构 └── backend # 后端项目目录(项目名称)├── __init__.py├── logs # 项目日志目录├── manage.py #…

AnytimeCL:难度加大,支持任意持续学习场景的新方案 | ECCV‘24

来源:晓飞的算法工程笔记 公众号,转载请注明出处 论文: Anytime Continual Learning for Open Vocabulary Classification 论文地址:https://arxiv.org/abs/2409.08518论文代码:https://github.com/jessemelpolio/AnytimeCL 创新…

2020年美国总统大选数据分析与模型预测

数据集取自:2020年🇺🇸🇺🇸美国大选数据集 - Heywhale.com 前言 对2020年美国总统大选数据的深入分析,提供各州和县层面的投票情况及选民行为的可视化展示。数据预处理阶段将涉及对异常值的处理&#xff0…

工业以太网PLC无线网桥,解决用户布线难题!

工业以太网无线网桥 功能概述 本产品是工业以太网(Profinet、EtherNet/IP、ModbusTCP等)转无线设备,成对使用(一对一),出厂前已经配对好,用户不需要再配对,即插即用。适用于用户布线不方便的场景。使用方式简单,只需要把拨码开关设置好并上电即可工作,无需进行其它设置。支持P…

Android13 系统/用户证书安装相关分析总结(三) 增加安装系统证书的接口遇到的问题和坑

一、前言 接上回说到,修改了程序,增加了接口,却不知道有没有什么问题,于是心怀忐忑等了几天。果然过了几天,应用那边的小伙伴报过来了问题。用户证书安装没有问题,系统证书(新增的接口)还是出现了问题。调…

AUTOSAR CP NVRAM Manager规范导读

一、NVRAM Manager功能概述 NVRAM Manager是AUTOSAR(AUTomotive Open System ARchitecture)框架中的一个模块,负责管理非易失性随机访问存储器(NVRAM)。它提供了一组服务和API,用于在汽车环境中存储、维护和恢复NV数据。以下是NVRAM Manager的一些关键功能: 数据存储和…

kelp protocol

道阻且长,行而不辍,未来可期 有很长一段时间我都在互联网到处拾金,but,东拼西凑的,总感觉不踏实,最近在老老实实的看官方文档 & 阅读白皮书 &看合约,挑拣一些重要的部分配上官方的证据,和过路公主or王子分享一下,愿我们早日追赶上公司里那些可望不可及大佬们。…

LeetCode25:K个一组翻转链表

原题地址:. - 力扣(LeetCode) 题目描述 给你链表的头节点 head ,每 k 个节点一组进行翻转,请你返回修改后的链表。 k 是一个正整数,它的值小于或等于链表的长度。如果节点总数不是 k 的整数倍,那…

k8s图形化显示(KRM)

在master节点 kubectl get po -n kube-system 这个命令会列出 kube-system 命名空间中的所有 Pod 的状态和相关信息,比如名称、状态、重启次数等。 systemctl status kubelet #查看kubelet状态 yum install git #下载git命令 git clone https://gitee.com/duk…

Github 2024-11-07 Go开源项目日报 Top10

根据Github Trendings的统计,今日(2024-11-07统计)共有10个项目上榜。根据开发语言中项目的数量,汇总情况如下: 开发语言项目数量Go项目10HTML项目1Kubernetes: 容器化应用程序管理系统 创建周期:3618 天开发语言:Go协议类型:Apache License 2.0Star数量:106913 个Fork数…

HTML 标签属性——<a>、<img>、<form>、<input>、<table> 标签属性详解

文章目录 1. `<a>`元素属性hreftargetname2. `<img>`元素属性srcaltwidth 和 height3. `<form>`元素属性actionmethodenctype4. `<input>`元素属性typevaluenamereadonly5. `<table>`元素属性cellpaddingcellspacing小结HTML元素除了可以使用全局…

制造业仓储信息化总体规划方案

文件是一份关于制造业仓储信息化的总体规划方案&#xff0c;主要内容包括项目背景、现状调研、项目目标、建设思路、业务蓝图设计方案、系统设计方案以及场景展示等。以下是对PPT内容的分析和总结&#xff1a; 1. 项目背景 目标&#xff1a;通过物流执行系统&#xff08;LES&a…

Ubuntu使用Qt虚拟键盘,支持中英文切换

前言 ​最近领导给了个需求&#xff0c;希望将web嵌入到客户端里面&#xff0c;做一个客户端外壳&#xff0c;可以控制程序的启动、停止、重启&#xff0c;并且可以调出键盘在触摸屏上使用(我们的程序虽然是BS架构&#xff0c;但程序还是运行在本地工控机上的)&#xff0c;我研…

python爬取旅游攻略(1)

参考网址&#xff1a; https://blog.csdn.net/m0_61981943/article/details/131262987 导入相关库&#xff0c;用get请求方式请求网页方式&#xff1a; import requests import parsel import csv import time import random url fhttps://travel.qunar.com/travelbook/list.…