初学迁移学习的理解

1.迁移学习(Transfer Learning)是什么?

简而言之,迁移学习(Transfer Learning)是一种机器学习方法,就是把为任务 A 开发的模型作为初始点,重新使用在为任务 B 开发模型的过程中

迁移学习是通过从已学习的相关任务中转移里面的知识来改进学习的新任务,例如,我们可能会发现学习识别苹果可能有助于识别梨,或者学习弹奏电子琴可能有助于学习钢琴。

找到目标问题的相似性,迁移学习任务就是从相似性出发,将旧领域(domain)学习过的模型应用在新领域上。我们通常希望利用源领域的知识和特征来辅助目标任务的学习。

2.为什么要进行迁移学习?

1、大数据与少标注的矛盾: 虽然有大量的数据,但往往都是没有标注的,无法训练机器学习模型。人工进行数据标定太耗时**(非监督)**。

2、大数据与弱计算的矛盾(本身很穷): 普通人无法拥有庞大的数据量与计算资源。因此需要借助于模型的迁移。

3.如何进行迁移学习(How to transfer?)

1.选择源领域和目标领域:

确定你要解决的目标任务(目标领域)。另外,选择一个与目标任务相关但不完全相同的任务作为源领域。源领域通常具有足够的数据和标签,可以帮助我们通过这个模型提取出所需的知识去学习通用的特征

2.选择模型架构:

模型架构的设计旨在从源领域中提取知识特征,然后将这些特征应用到目标任务中。我们通常希望利用源领域的知识和特征来辅助目标任务的学习,特别是当目标任务的数据较少或者目标领域与源领域有一定的相似性时。

根据任务的复杂度和数据情况选择合适的模型架构,可以是经典的卷积神经网络(如VGG、ResNet等)、循环神经网络(如LSTM、GRU等)、生成对抗网络(如GANs)等。

3.冻结部分模型参数:

如果源领域和目标领域之间的特征差异较大,可以选择冻结源领域模型的部分参数,只训练部分参数以适应目标任务。这有助于防止源领域特定的特征影响目标任务的学习。

4.选择合适的损失函数:

根据目标任务的性质选择合适的损失函数,如分类任务可以选择交叉熵损失函数,回归任务可以选择**均方误差损失函数(MSE)**等。

5.进行数据预处理:

比如进行数据归一化**(比如:对源领域和目标领域的图片数据进行归一化处理,确保它们的像素值处于相同的量级),对图像数据进行降噪处理,去除可能存在的图像噪声(比如:可以采用图像平滑算法(如高斯模糊)对图像进行平滑处理,降低噪声的影响),另外还有数据清洗(检测并处理源领域和目标领域数据中的缺失值和异常值,确保数据质量良好)** 等等。

import pandas as pd
from sklearn.preprocessing import StandardScaler, LabelEncoder
from sklearn.model_selection import train_test_split# 假设我们有一个CSV文件包含汽车数据,其中有数值型和分类型特征# 读取数据
data = pd.read_csv('car_data.csv')# 数据清洗(处理缺失值)
data = data.dropna()# 数据预处理(数值型特征归一化,分类型特征编码)
# 数值型特征归一化
scaler = StandardScaler()
numerical_features = ['horsepower', 'weight', 'acceleration']
data[numerical_features] = scaler.fit_transform(data[numerical_features])# 分类型特征编码
label_encoder = LabelEncoder()
data['origin'] = label_encoder.fit_transform(data['origin'])# 划分训练集和测试集
X = data.drop('mpg', axis=1)  # 特征矩阵
y = data['mpg']  # 目标变量X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2, random_state=42)# 查看预处理后的数据
print(X_train.head())
print(y_train.head())

在这里插入图片描述

4.When to transfer: 什么时候可以进行迁移

在这里插入图片描述

5.迁移学习的一些概念

1.基本定义:
域(Domain):数据特征和特征分布组成,是学习的主体
源域 (Source domain):已有知识的域
目标域 (Target domain):要进行学习的域
任务 (Task):由目标函数和学习结果组成,是学习的结果

2.按特征空间分类:
同构迁移学习(Homogeneous TL): 源域和目标域的特征空间相同
异构迁移学习(Heterogeneous TL):源域和目标域的特征空间不同

3.按迁移情景分类:
归纳式迁移学习(Inductive TL):源域和目标域的学习任务不同
直推式迁移学习(Transductive TL):源域和目标域不同,学习任务相同
无监督迁移学习(Unsupervised TL):源域和目标域均没有标签

4.按迁移方法分类:

基于样本的迁移 (Instance based TL): 通过权重重用源域和目标域的样例进行迁移

基于样本的迁移学习方法 (Instance based Transfer Learning) 根据一定的权重生成规则,对数据样本进行重用,来进行迁移学习。下图形象地表示了基于样本迁移方法的思想源域中存在不同种类的动物,如狗、鸟、猫等,目标域只有狗这一种类别。在迁移时,为了最大限度地和目标域相似,我们可以人为地提高源域中属于狗这个类别的样本权重。

**基于特征的迁移 (Feature based TL):**将源域和目标域的特征变换到相同空间

基于特征的迁移方法 (Feature based Transfer Learning) 是指将通过特征变换的方式互相迁移,来减少源域和目标域之间的差距;或者将源域和目标域的数据特征变换到统一特征空间中,然后利用传统的机器学习方法进行分类识别。根据特征的同构和异构性,又可以分为同构和异构迁移学习。下图很形象地表示了两种基于特 征的迁移学习方法。

**基于模型的迁移 (Parameter based TL):**利用源域和目标域的参数共享模型

基于模型的迁移方法 (Parameter/Model based Transfer Learning) 是指从源域和目标域中找到他们之间共享的参数信息,以实现迁移的方法。这种迁移方式要求的假设条件是: 源域中的数据与目标域中的数据可以共享一些模型的参数。下图形象地表示了基于模型的迁移学习方法的基本思想。

**基于关系的迁移 (Relation based TL):**利用源域中的逻辑网络关系进行迁移

基于关系的迁移学习方法 (Relation Based Transfer Learning) 与上述三种方法具有截然不同的思路。这种方法比较关注源域和目标域的样本之间的关系。下图形象地表示了不 同领域之间相似的关系。

6.什么是FineTune微调

微调是迁移学习的一种技术,它通常指的是在已经预训练好的模型基础上,对模型的部分或全部参数进行调整,以适应新任务的需求。微调可以在源域数据上进行,也可以在目标域数据上进行。

举个例子: 假设你有一个在大规模图像数据集上预训练好的卷积神经网络(CNN),用于识别不同物体的图片。现在你有一个小型的数据集,包含了特定类型的物体图片,比如狗和猫。你可以使用迁移学习,将预训练的CNN模型作为基础模型,在你的小型数据集上进行微调,以便让模型学习到狗和猫的识别任务。

from tensorflow.keras.applications import VGG16  # 导入VGG16模型
from tensorflow.keras.models import Sequential  # 导入Sequential模型
from tensorflow.keras.layers import Dense, Flatten, Dropout  # 导入Dense、Flatten和Dropout层
from tensorflow.keras.optimizers import Adam  # 导入Adam优化器
from tensorflow.keras.preprocessing.image import ImageDataGenerator  # 导入ImageDataGenerator用于数据增强# 加载预训练的VGG16模型,不包含顶层分类器
base_model = VGG16(weights='imagenet', include_top=False, input_shape=(224, 224, 3))# 冻结VGG16的卷积层,只微调顶层分类器
for layer in base_model.layers:layer.trainable = False# 添加自定义的顶层分类器
model = Sequential([base_model,  # 添加预训练的VGG16模型作为基础模型Flatten(),  # 展平层Dense(256, activation='relu'),  # 全连接层1,256个神经元,激活函数为ReLUDropout(0.5),  # Dropout层,防止过拟合Dense(2, activation='softmax')  # 全连接层2,2个输出类别:狗和猫,激活函数为softmax
])# 编译模型
model.compile(optimizer=Adam(learning_rate=0.0001),  # 使用Adam优化器,学习率设为0.0001loss='categorical_crossentropy',  # 使用交叉熵损失函数metrics=['accuracy'])  # 评估指标为准确率# 数据增强
train_datagen = ImageDataGenerator(rescale=1./255,  # 像素值缩放到0~1之间rotation_range=20,  # 随机旋转角度范围为20度width_shift_range=0.2,  # 水平随机偏移范围为20%height_shift_range=0.2,  # 垂直随机偏移范围为20%shear_range=0.2,  # 剪切强度范围为20%zoom_range=0.2,  # 缩放范围为20%horizontal_flip=True,  # 水平翻转fill_mode='nearest'  # 填充像素的方式为最近像素
)# 加载训练数据和验证数据
train_generator = train_datagen.flow_from_directory('train_data_dir',  # 训练数据目录target_size=(224, 224),  # 图像尺寸设为224x224batch_size=32,  # 批量大小为32class_mode='categorical'  # 多分类任务
)validation_generator = train_datagen.flow_from_directory('validation_data_dir',  # 验证数据目录target_size=(224, 224),  # 图像尺寸设为224x224batch_size=32,  # 批量大小为32class_mode='categorical'  # 多分类任务
)# 进行模型训练
model.fit(train_generator,  # 训练数据生成器steps_per_epoch=train_generator.samples // 32,  # 每个epoch的步数epochs=10,  # 迭代次数为10validation_data=validation_generator,  # 验证数据生成器validation_steps=validation_generator.samples // 32  # 每个epoch的验证步数
)# 保存微调后的模型
model.save('fine_tuned_model.h5')

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.rhkb.cn/news/332975.html

如若内容造成侵权/违法违规/事实不符,请联系长河编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

01JAVA基础

目录 1.基础语法 1.1 注释 1.2 关键字 1.3 常量 1.4 数据类型 1.5 变量 1.6 标识符 1.7 类型转换 2.算数运算符和分支语句 2.1 算数运算符 1.常规运算符 2.赋值运算符 3.自增自减 4.关系运算符 5.逻辑运算符 6.三元运算符 2.2 数据输入(Scanner) 2.3 分支判断…

mac 安装java jdk8 jdk11 jdk17 等

oracle官网 https://www.oracle.com/java/technologies/downloads/ 查看当前电脑是英特尔的x86 还是arm uname -m 选择指定版本,指定平台的安装包: JDK8 JDK11的,需要当前页面往下拉: 下载到的安装包,双击安装&#x…

基于微信小程序+ JAVA后端实现的【医院挂号预约系统】 设计与实现 (内附设计LW + PPT+ 源码+ 演示视频 下载)

项目名称 项目名称: 《基于微信小程序的医院挂号预约系统设计与实现》 项目技术栈 该项目采用了以下核心技术栈: 后端框架/库: Java, SSM框架数据库: MySQL前端技术: 微信小程序, uni-app 项目展示 全文概括 本…

关于亚马逊、速卖通、虾皮、Lazada等平台自养号测评IP的重要性

在自养号测评中,IP的纯净度是一个至关重要的问题,它直接关系到账号的安全性和稳定性如果使用了被平台识别为异常或存在风险的IP地址,那么账号可能会面临被封禁的风险。这将对账号的正常使用和测评过程中造成严重影响。而使用纯净的IP地址&…

前端开发的设计思路【精炼】(含数据结构设计、组件设计)

数据结构设计 用数据描述所有的内容数据要结构化,易于程序操作(遍历、查找),比如数组、对象、对象为元素构成的数组(每个元素记得设置唯一的 id 属性,以便对元素进行删改操作)数据要可扩展,以便增加新的功能…

EtherCAT总线掉线如何自动重启

EtherCAT通信如果是从站掉线我们可以勾选上自动重启功能如下图所示: 1、自动重启从站 待续.....

MacOS使用PhpStorm+Xdebug断点调式

基本环境: MacOS m1 PhpStorm 2024.1 PHP7.4.33 Xdebug v3.1.6 1、php.ini 配置 [xdebug] zend_extension "/opt/homebrew/Cellar/php7.4/7.4.33_6/pecl/20190902/xdebug.so" xdebug.idekey "PHPSTORM" xdebug.c…

【数组】Leetcode 452. 用最少数量的箭引爆气球【中等】

用最少数量的箭引爆气球 有一些球形气球贴在一堵用 XY 平面表示的墙面上。墙面上的气球记录在整数数组 points ,其中points[i] [xstart, xend] 表示水平直径在 xstart 和 xend之间的气球。你不知道气球的确切 y 坐标。 一支弓箭可以沿着 x 轴从不同点 完全垂直 地…

Tensors张量操作

定义Tensor 下面是一个常见的tensor,包含了里面的数值,属性,以及存储位置 tensor([[0.3565,0.1826,0.6719],[0.6695,0.5364,0.7057]],dtypetorch.float32,devicecuda:0)Tensor的属…

【机器学习】Python中的决策树算法探索

🌈个人主页: 鑫宝Code 🔥热门专栏: 闲话杂谈| 炫酷HTML | JavaScript基础 ​💫个人格言: "如无必要,勿增实体" 文章目录 Python中的决策树算法探索引言1. 决策树基础理论1.1 算法概述1.2 构建过程 2. P…

解决:LVGL+GUI Guider 1.7.2运行一段时间就会卡死死机,内存泄露溢出的问题

概括: 我在使用NXP官方GUI Guider生成的代码出现了内存泄漏的问题。但我遇到的并不是像其他人所说的style的问题,如下链接。而是因为在页面渲染之前就使用了该页面内的组件,内存就会不断增加。 LVGL 死机 内存泄漏_lvgl 内存溢出-CSDN博客 运…

一文读懂Linux

前言 为了便于理解,本文从常用操作和概念开始讲起。虽然已经尽量做到简化,但是涉及到的内容还是有点多。在面试中,Linux 知识点相对于网络和操作系统等知识点而言不是那么重要,只需要重点掌握一些原理和命令即可。为了方便大家准…

操作系统总结(2)

目录 2.1 进程的概念、组成、特征 (1)知识总览 (2)进程的概念 (3)进程的组成—PCB (4)进程的组成---程序段和数据段 (5)程序是如何运行的呢&#xff1f…

嵌入式开发中树莓派和单片机关键区别

综合了几篇帖子作以信息收录:树莓派和单片机作为嵌入式系统领域中两种广泛使用的设备,各自有着不同的特性和应用场景,文章从五个方面进行比对展开。 架构与性能: 树莓派:是一款微型计算机,通常配备基于AR…

Django性能优化:提升加载速度

title: Django性能优化:提升加载速度 date: 2024/5/20 20:16:28 updated: 2024/5/20 20:16:28 categories: 后端开发 tags: 缓存策略HTTP请求DNS查询CDN分发前端优化服务器响应浏览器缓存 第一章:Django性能优化概述 1.1 性能优化的意义 性能优化是…

Spring中@Component注解

Component注解 在Spring框架中,Component是一个通用的注解,用于标识一个类作为Spring容器管理的组件。当Spring扫描到被Component注解的类时,会自动创建一个该类的实例并将其纳入Spring容器中管理。 使用方式 1、基本用法: Co…

深入浅出MySQL事务实现底层原理

重要概念 事务的ACID 原子性(Atomicity):即不可分割性,事务中的操作要么全不做,要么全做一致性(Consistency):一个事务在执行前后,数据库都必须处于正确的状态&#xf…

vb.net打开CAD指指定路径文件

首先打开vsto,创建窗体,添加一个按钮,双击按钮录入代码: Public Class Form1Private Sub Button1_Click(sender As Object, e As EventArgs) Handles Button1.ClickDim cad As Objectcad CreateObject("autocad.Application")cad…

火箭升空AR虚拟三维仿真演示满足客户的多样化场景需求

在航空工业的协同研发领域,航空AR工业装配系统公司凭借前沿的AR增强现实技术,正引领一场革新。通过将虚拟信息无缝融入实际环境中,我们为工程师、设计师和技术专家提供了前所未有的共享和审查三维模型的能力,极大地提升了研发效率…

Go 语言简介 -- 高效、简洁与现代化编程的完美结合

在现代软件开发领域,选择合适的编程语言对于项目的成功至关重要。Go 语言(又称 Golang )自 2009 年由Google发布以来,以其简洁的语法、高效的并发模型以及强大的性能,迅速成为开发者们的新宠。Go语言不仅融合了传统编译…