[Machine Learning] 多任务学习

文章目录

  • 基于参数的MTL模型 (Parameter-based MTL Models)
  • 基于特征的MTL模型 (Feature-based MTL Models)
    • 基于特征的MTL模型 I:
    • 基于特征的MTL模型 II:
  • 基于特征和参数的MTL模型 (Feature- and Parameter-based MTL Models)


多任务学习 (Multi-task Learning, MTL) 是一种同时学习多个相关问题的方法,它通过利用这些问题之间的相关性来进行学习。


在单任务学习 (Single-Task Learning, STL) 中,每个任务有一个独立的模型,这些模型分别学习不同的任务。这里,每个任务(Task 1, Task 2, Task 3, Task 4)都有它自己的输入和独立的神经网络模型。这些模型不会共享学习到的特征或表示,它们是完全独立的。

在多任务学习中,一个单一的模型共同学习多个任务。模型共享输入层和可能还有一些隐藏层,但在最后,可以有特定于任务的输出层。通过这种方式,模型可以学习到在多个任务间共通的、有用的表示,这可以提升模型在各个任务上的性能,特别是当这些任务相关时。多任务学习还有助于提高数据利用率和学习效率,因为相同的数据和模型参数被用来解决多个问题。

这幅图用来说明的关键点是,在多任务学习中,我们期望通过任务之间的相关性来提升性能,而在单任务学习中,每个任务都是孤立地学习,无法从其他任务中学习到的信息中受益。

当任务彼此独立时,多任务学习与单任务学习相比并无优势。

对于数据不足的问题,当有多个相关任务且每个任务的训练样本有限时,多任务学习是一个很好的解决方案。

设定有 m m m个学习任务 { T i } i = 1 m \{T_i\}_{i=1}^m {Ti}i=1m,其中所有任务或其子集彼此相关,多任务学习旨在通过使用 m m m个任务中包含的知识来帮助提高模型对 T i \mathcal{T}_i Ti的学习。任务 T i \mathcal{T}_i Ti伴随着一个训练集 D i = { x j i , y j i } j = 1 n i D_i = \{ x_j^i, y_j^i \}_{j=1}^{n_i} Di={xji,yji}j=1ni

我们的任务是为 { T i } i = 1 m \{T_i\}_{i=1}^m {Ti}i=1m学习假设。

在MTL中,我们考虑线性假设函数,表示为 h ( x ) = w T x h(x) = w^T x h(x)=wTx。对于 m m m 个不同但相关的任务,即 { T i } i = 1 m \{T_i\}^m_{i=1} {Ti}i=1m,我们定义 w i w^i wi 为第 i i i 个任务的假设,其中 i = 1 , … , m i = 1, \ldots, m i=1,,m

MTL的经验风险最小化算法表示为:

min ⁡ W = [ w 1 , … , w m ] 1 m ∑ i = 1 m 1 n i ∑ j = 1 n i ℓ ( x j i , y j i , w i ) \min\limits_{W=[w^1,\ldots,w^m]} \frac{1}{m} \sum_{i=1}^{m} \frac{1}{n_i} \sum_{j=1}^{n_i} \ell (x^i_j, y^i_j, w^i) W=[w1,,wm]minm1i=1mni1j=1ni(xji,yji,wi)

MTL模型通常由两个主要组件组成:参数共享和特征变换。参数共享是指在多个任务间共享模型参数,这样可以使不同任务互相借鉴彼此的信息,从而提高学习效率。特征变换则是指对输入数据进行变换,以找到一个更适合所有任务的表示方式。

基于参数的MTL模型 (Parameter-based MTL Models)

在这种方法中,我们考虑多个相关的任务,并且假设每个任务的假设 w i w^i wi可以表示为一个共同的基础参数 w 0 w_0 w0加上一个特定任务的偏差 Δ w i \Delta w^i Δwi。这个模型的形式化为:

min ⁡ w 0 , Δ W = [ Δ w 1 , … , Δ w m ] 1 m ∑ i = 1 m 1 n i ∑ j = 1 n i ℓ ( x j i , y j i , w 0 + Δ w i ) \min_{w_0,\Delta W = [\Delta w^1, \ldots, \Delta w^m]} \frac{1}{m} \sum_{i=1}^{m} \frac{1}{n_i} \sum_{j=1}^{n_i} \ell(x^i_j, y^i_j, w_0 + \Delta w^i) w0,ΔW=[Δw1,,Δwm]minm1i=1mni1j=1ni(xji,yji,w0+Δwi)

这里的 ℓ \ell 是损失函数, x j i x^i_j xji y j i y^i_j yji是第 i i i个任务的第 j j j个训练样本及其标签。

这样,第 i i i个任务的模型参数可以表示为 w i = w 0 + Δ w i w^i = w_0 + \Delta w^i wi=w0+Δwi。全局参数 w 0 w_0 w0捕获了所有任务之间的共性,而 Δ w i \Delta w^i Δwi则捕获了任务特有的特性。我们的优化目标是最小化所有任务的总损失,同时尽可能地使得各任务参数相互接近,这通常通过添加一个正则化项 ∥ Δ W ∥ F 2 \|\Delta W\|_F^2 ∥ΔWF2来实现:

min ⁡ w 0 , Δ W = [ Δ w 1 , … , Δ w m ] 1 m ∑ i = 1 m 1 n i ∑ j = 1 n i ℓ ( x j i , y j i , w 0 + Δ w i ) + λ ∥ Δ W ∥ F 2 \min_{w_0,\Delta W = [\Delta w^1, \ldots, \Delta w^m]} \frac{1}{m} \sum_{i=1}^{m} \frac{1}{n_i} \sum_{j=1}^{n_i} \ell(x^i_j, y^i_j, w_0 + \Delta w^i) + \lambda \|\Delta W\|_F^2 w0,ΔW=[Δw1,,Δwm]minm1i=1mni1j=1ni(xji,yji,w0+Δwi)+λ∥ΔWF2

这个模型更好,因为它鼓励多任务学习算法具有更强的相关性。

另一个模型使用秩约束:

min ⁡ W = [ w 1 , … , w m ] 1 m ∑ i = 1 m 1 n i ∑ j = 1 n i ℓ ( x j i , y j i , w i ) + λ rank ( W ) \min\limits_{W=[w^1,\ldots,w^m]} \frac{1}{m} \sum\limits_{i=1}^{m} \frac{1}{n_i} \sum\limits_{j=1}^{n_i} \ell(x^i_j, y^i_j, w^i) + \lambda \text{ rank}(W) W=[w1,,wm]minm1i=1mni1j=1ni(xji,yji,wi)+λ rank(W)

基于特征的MTL模型 (Feature-based MTL Models)

在基于特征的MTL模型中,假设是从训练样例中学到的:

给定一组数据 D i = { x j i , y j i } j = 1 n i \mathcal{D}_i = \{ x_j^{i}, y_j^{i} \}_{j=1}^{n_i} Di={xji,yji}j=1ni

我们希望通过特征映射使得任务之间更加相关。即,我们希望找到一个投影矩阵 P P P,使得 D i \mathcal{D}_i Di变换为 D i = { P T x j i , y j i } j = 1 n i \mathcal{D}_i = \{ P^T x_j^{i}, y_j^{i} \}_{j=1}^{n_i} Di={PTxji,yji}j=1ni

基于特征的MTL模型 I:

min ⁡ W , P 1 m ∑ i = 1 m 1 n i ∑ j = 1 n i ℓ ( P T x j i , y j i , w i ) + λ rank ( W ) s.t.  P P T = I \min_{W,P} \frac{1}{m} \sum_{i=1}^{m} \frac{1}{n_i} \sum_{j=1}^{n_i} \ell(P^T x_j^{i}, y_j^{i},w^i) + \lambda \text{rank}(W) \text{ s.t. } PP^T = I W,Pminm1i=1mni1j=1ni(PTxji,yji,wi)+λrank(W) s.t. PPT=I

这个损失函数计算的是映射后的特征与目标值之间的误差,并加入了正则化项以控制权重矩阵W的复杂度。损失函数以 ℓ ( P T x i j , y i j , w i ) \ell(P^T x_i^j, y_i^j, w^i) (PTxij,yij,wi) 表示, x i j x_i^j xij 是第i个任务的第j个样本的特征, y i j y_i^j yij 是对应的目标值, w i w^i wi 是第i个任务的权重向量, P P P 是一个投影矩阵,使得通过 P T x j i P^T x_j^{i} PTxji变换后的特征可以更好地为多个任务服务。

λ \lambda λ 是正则化项的权重, rank ( W ) \text{rank}(W) rank(W) 是权重矩阵的秩,用于控制模型的复杂度。

基于特征的MTL模型 II:

这是一个共享隐藏层的神经网络架构,其中隐藏层的节点可以被看作是特征提取器。

对应的优化问题考虑了一个共享参数 w 0 w_0 w0 和针对每个任务的调整参数 Δ w i \Delta w_i Δwi

这个模型的目标是最小化包含共享参数和任务特定调整的损失函数,并通过 λ ∣ ∣ Δ W ∣ ∣ F 2 \lambda ||\Delta W||_F^2 λ∣∣ΔWF2 正则化每个任务的参数调整量。

隐藏层对于所有任务来说是共享的,这意味着模型可以学习通用的特征表示,而输出层则是特定于任务的。

基于特征和参数的MTL模型 (Feature- and Parameter-based MTL Models)

min ⁡ w 0 , Δ W , P 1 m ∑ i = 1 m 1 n i ∑ j = 1 n i ℓ ( P T x j i , y j i , w 0 + Δ w i ) + λ ∥ Δ W ∥ F 2 s.t.  P P T = I \min_{w_0, \Delta W,P} \frac{1}{m} \sum_{i=1}^{m} \frac{1}{n_i} \sum_{j=1}^{n_i} \ell(P^T x_j^{i}, y_j^{i},w_0 + \Delta w^i) + \lambda \|\Delta W\|_F^2 \text{ s.t. } PP^T = I w0,ΔW,Pminm1i=1mni1j=1ni(PTxji,yji,w0+Δwi)+λ∥ΔWF2 s.t. PPT=I

该模型旨在找到一个跨任务共享的特征投影( P P P)和一组针对所有任务优化的参数( w 0 w_0 w0 Δ W ΔW ΔW)。

  • 目标函数: min ⁡ w 0 , Δ W , P \min_{w_0, \Delta W,P} minw0,ΔW,P 表示我们的目标是最小化关于 w 0 w_0 w0(共享参数)、 Δ W \Delta W ΔW(任务特定参数变化)和 P P P(特征投影矩阵)的某个函数。

  • 任务平均: 1 m ∑ i = 1 m \frac{1}{m} \sum_{i=1}^{m} m1i=1m 表示我们考虑 m m m 个不同的任务,并对这些任务的结果取平均。

  • 任务内平均**: 对于每个任务 i i i 1 n i ∑ j = 1 n i \frac{1}{n_i} \sum_{j=1}^{n_i} ni1j=1ni 用于对该任务中的 n i n_i ni 个样本进行平均。

  • 损失函数: ℓ ( P T x j i , y j i , w 0 + Δ w i ) \ell(P^T x_j^{i}, y_j^{i},w_0 + \Delta w^i) (PTxji,yji,w0+Δwi) 是损失函数,用于量化模型预测 P T x j i P^T x_j^{i} PTxji(经过特征转换的输入)和真实标签 y j i y_j^{i} yji 之间的差异,同时考虑共享参数 w 0 w_0 w0 和任务特定参数的调整 Δ w i \Delta w^i Δwi

  • 正则化项: λ ∥ Δ W ∥ F 2 \lambda \|\Delta W\|_F^2 λ∥ΔWF2 是正则化项,用于防止过拟合。它通过控制任务特定参数变化的大小(使用Frobenius范数)来实现。

  • 约束条件: P P T = I PP^T = I PPT=I 是一个约束条件,确保投影矩阵 P P P 是正交的。这有助于保持映射后的特征间的独立性。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.rhkb.cn/news/188143.html

如若内容造成侵权/违法违规/事实不符,请联系长河编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

基于ubuntu22.04手动安装openstack——2023.2版本(最新版)的问题汇总

前言:基本上按照openstack官方网站动手可以搭建成功(如有需要私信发部署文档)。 但是任然有些小问题,所以汇总如下。 第一个问题 问题: ubuntu搭建2023.2版本neutorn报错,ERROR neutron.plugins.ml2.driv…

MATLAB仿真通信系统的眼图

eyediagram eyediagram(complex(used_i,used_q),1100)

与创新者同行,Apache Doris in 2023

在刚刚过去的 Doris Summit Asia 2023 峰会上,Apache Doris PMC 成员、飞轮科技技术副总裁衣国垒带来了“与创新者同行”的主题演讲,回顾了 Apache Doris 在过去一年所取得的技术突破与社区发展,重新思考了在面对海量数据实时分析上的挑战与机…

数据挖掘:分类,聚类,关联关系,回归

数据挖掘: 2022找工作是学历、能力和运气的超强结合体,遇到寒冬,大厂不招人,可能很多算法学生都得去找开发,测开 测开的话,你就得学数据库,sql,oracle,尤其sql要学&…

用于图像处理的高斯滤波器 (LoG) 拉普拉斯

一、说明 欢迎来到拉普拉斯和高斯滤波器的拉普拉斯的故事。LoG是先进行高斯处理,继而进行拉普拉斯算子的图像处理算法。用拉普拉斯具有过零功能,实现边缘岭脊提取。 二、LoG算法简述 在这篇博客中,让我们看看拉普拉斯滤波器和高斯滤波器的拉普…

某XX自考小程序的AES加密分析

前言 主要是报了自考在这个小程序上面做题,就研究了一下这个接口本文仅供学习交流使用,请勿随意传播。如有侵犯你的权益及时联系我删除。 一、抓包分析打开小程序,打开devtools 工具,这里就不啰嗦,直接上过程。 点击…

深入浅出理解ResNet网络模型+PyTorch实现

温故而知新,可以为师矣! 一、参考资料 论文:Identity Mappings in Deep Residual Networks 论文:Deep Residual Learning for Image Recognition ResNet详解PyTorch实现 PyTorch官方实现ResNet 【pytorch】ResNet18、ResNet20、…

LabVIEW示波器连续触发编程

LabVIEW示波器连续触发编程 niScopeEX Fetch Forever范例利用了如何设置硬件和驱动的优点来进行连续采集。 当NI-SCOPE设备被设置为采集预触发扫描,设备上的板载内存被用作一个环形缓冲。这样,无论触发何时到来,设备都可以追踪和检索所有要求…

使用 HTTP Client 轻松进行 API 测试

在开发过程中,我们经常需要测试 API 接口以确保其正常工作。JetBrains 的集成开发环境(IDE)如 CLion、IntelliJ IDEA、PyCharm 等,默认内置了 HTTP Client 插件,可以方便地进行API测试。本文将介绍如何使用HTTP Client…

共话医疗数据安全,美创科技@2023南湖HIT论坛,11月11日见

11月11日浙江嘉兴 2023南湖HIT论坛 如约而来 深入数据驱动运营管理、运营数据中心建设、数据治理和数据安全、数据资产“入表”等热点、前沿话题 医疗数据安全、数字化转型深耕者—— 美创科技再次深入参与 全新发布:医疗数据安全白皮书 深度探讨:数字…

提升中小企业效率的不可或缺的企业云盘网盘

相比之大型企业,中小型企业在挑选企业云盘工具更注重灵活性和成本。那么市面上有哪些企业云盘产品更适合中小企业呢? 说起中小企业不能错过的企业云盘网盘,Zoho Workdrive企业云盘绝对榜上有名! Zoho Workdrive企业云盘为用户提…

轻松按需缩放图片像素——批量处理图片,满足不同需求!

在图片使用过程中,我们经常需要按照不同的要求调整图片的像素大小。如果一张一张地手动调整,不仅耗时而且容易出错。这款软件可以帮助您轻松实现批量处理图片,按需缩放图片像素,让您的图片管理更加高效、便捷! 第一步…

反转链表 --- 递归回溯算法练习三

目录 1. 分析题意 2. 分析算法原理 2.1. 递归思路: 1. 挖掘子问题: 3. 编写代码 3.1. step 1: 3.2. step 2: 3.3. step 3: 3.4. 递归代码: 1. 分析题意 力扣原题链接如下: 206. 反转链…

ChatGPT 如何改变科研之路

《Nature》全球博士后调查[1]中约有三分之一的受访者正在使用人工智能聊天机器人来帮助完善文本、生成或编辑代码、整理其领域的文献等等。 来自巴西的 Rafael Bretas 在日本生活了十多年,日语说得很好。书面日语的各个方面,例如严格的礼貌等级制度&…

4面百度软件测试工程师的面试经验总结

没有绝对的天才,只有持续不断的付出。对于我们每一个平凡人来说,改变命运只能依靠努力幸运,但如果你不够幸运,那就只能拉高努力的占比。 2023年7月,我有幸成为了百度的一名测试工程师,从外包辞职了历经100…

多状态Dp问题——买卖股票的最佳时机含冷冻期

目录 一,题目 二,题目接口 三,解题思路及其代码 一,题目 给定一个整数数组prices,其中第 prices[i] 表示第 i 天的股票价格 。​ 设计一个算法计算出最大利润。在满足以下约束条件下,你可以尽可能地完成…

自媒体项目详述

总体框架 本项目主要着手于获取最新最热新闻资讯,以微服务构架为技术基础搭建校内仅供学生教师使用的校园新媒体app。以文章为主线的核心业务主要分为如下子模块。自媒体模块实现用户创建功能、文章发布功能、素材管理功能。app端用户模块实现文章搜索、文章点赞、…

electron 内部api capturePage实现webview截图

官方文档 .capturePage([rect]) rect Rectangle (可选) - 要捕获的页面区域。 返回 Promise - 完成后返回一个NativeImage 在 rect内捕获页面的快照。 省略 rect 将捕获整个可见页面。 async function cap(){ let image await webviewRef.value.capturePage() console.log(im…

Android修行手册 - POI操作Excel常用样式(字体,背景,颜色,Style)

点击跳转>Unity3D特效百例点击跳转>案例项目实战源码点击跳转>游戏脚本-辅助自动化点击跳转>Android控件全解手册点击跳转>Scratch编程案例点击跳转>软考全系列 👉关于作者 专注于Android/Unity和各种游戏开发技巧,以及各种资源分享&…

Java学习 9.Java-数组 讲解及习题

一、数组的定义与使用 1.数组的基本概念 1.1 为什么要使用数组 数组是最简单的一种数据结构 组织一组相同类型数据的集合 数据结构本身是来描述和组织数据的 数据加结构 1.2.1 数组的创建 代码实现 new int 可省略; char[] chars{a,b,c};//定义一个整形类型数…