【论文解读】元学习:MAML

一、简介

元学习的目标是在各种学习任务上训练模型,这样它就可以只使用少量的训练样本来解决新任务。
在这里插入图片描述

论文所提出的算法训练获取较优模型的参数,使其易于微调,从而实现快速自适应。该算法与任何用梯度下降训练的模型兼容,适用于各种学习问题,包括分类、回归和强化学习。
论文中表明,该算法在few-shot image classification基准上达到了SOTA的性能,在few-shot regression上也产出了良好的结果,并加速了策略梯度强化学习的微调

1.1 元学习与一般ML的区别

  • ML: 根据给定数据找到一个函数f,后续在相同的任务上运用该函数
  • Meta Learning: 根据大量任务(数据)找一个 F可以输出f 的能力,后续运用的时候在F上进行较少数据量的update 后就可以得到对应运用任务的函数f
    在这里插入图片描述

二、算法思路与伪代码(监督学习)

2.1 主要思路

核心思路就是找到一个较好的初始参数值,可以在任何同一类型的任务上进行少量数据较少次数update 后就可以得到较好的模型,下图展示了meta Learning 最终学习的参数 ϕ \phi ϕ
在这里插入图片描述

2.2 伪代码

Algorithm2 MAML for Few-Shot Supervised Learning Require: p ( T ) : distribution over tasks Require: α : 一系列task训练-supportSet,梯度更新学习率-在循环内更新 β : 一系列task评估-querySet,梯度更新学习率-在循环外更新 1: 初始化参数  θ 2:  while not done  do 3:  从任务集合中抽取任务  T i ∼ p ( T ) 4:  for all T i do 5:  从任务中抽取k shot个样本 D = { X j , Y j } ∈ T i 6:  基于任务的损失函数计算损失 L T i = l ( Y j , f θ i ( X j ) ) 7:  基于损失函数计算梯度, 并更新参数 ∂ L T i ∂ θ i = ∇ θ L T i ( f θ ) θ i ′ = θ − α ∇ θ L T i ( f θ ) 8:  从任务中抽取 q query 个样本 D ′ = { X j , Y j } ∈ T i 基于更新后的 θ ′ 进行预测并计算损失,用于循环后更新 L T i ′ = l ( Y j , f θ i ′ ( X j ) ) 计算梯度 ∂ L T i ′ ∂ θ i ′ = ∇ θ L T i ′ ( f θ ′ ) 计算最终梯度 ∇ θ L T i ( f θ ′ ) = ∂ L T i ′ ∂ θ i = ∂ L T i ′ ∂ θ i ′ ∂ θ i ′ ∂ θ i 9:  end for 10:  Update  θ ← θ − β ∑ T i ∼ p ( T ) ∇ θ L T i ( f θ ′ ) 11:  end while r e t u r n θ \begin{aligned} &\rule{110mm}{0.4pt} \\ &\text{Algorithm2 MAML for Few-Shot Supervised Learning}\\ &\rule{110mm}{0.4pt} \\ &\textbf{Require: } p(\mathcal{T}): \text{distribution over tasks}\\ &\textbf{Require: } \alpha \text{: 一系列task训练-supportSet,梯度更新学习率-在循环内更新} \\ &\hspace{17mm} \beta \text{: 一系列task评估-querySet,梯度更新学习率-在循环外更新}\\ &\rule{110mm}{0.4pt} \\ &\text{ 1: 初始化参数 } \theta \\ &\text{ 2: }\textbf{while }\text{not done }\textbf{do }\\ &\text{ 3: }\hspace{5mm}\text{从任务集合中抽取任务 }\mathcal{T}_i \sim p(\mathcal{T}) \\ &\text{ 4: }\hspace{5mm}\textbf{for all }\mathcal{T}_i\textbf{ do }\\ &\text{ 5: }\hspace{10mm}\text{从任务中抽取k shot个样本} \mathcal{D}=\{X^j, Y^j\} \in \mathcal{T}_i\\ &\text{ 6: }\hspace{10mm}\text{基于任务的损失函数计算损失} \mathcal{L}_{\mathcal{T}_i}=l(Y^j, f_{\theta_{i}}(X^j))\\ &\text{ 7: }\hspace{10mm}\text{基于损失函数计算梯度, 并更新参数} \frac{\partial{\mathcal{L}_{\mathcal{T}_i}}}{\partial \theta_i} = \nabla_\theta \mathcal{L}_{\mathcal{T}_i}(f_\theta) \\ &\hspace{17mm} \theta_i^{\prime} = \theta - \alpha \nabla_\theta \mathcal{L}_{\mathcal{T}_i}(f_\theta) \\ &\text{ 8: }\hspace{10mm}\text{从任务中抽取 q query 个样本} \mathcal{D}^{\prime}=\{X^j, Y^j\} \in \mathcal{T}_i\\ &\hspace{15mm} \text{基于更新后的}\theta^{\prime}\text{进行预测并计算损失,用于循环后更新} \mathcal{L}^{\prime}_{\mathcal{T}_i}=l(Y^j, f_{\theta^{\prime}_{i}}(X^j))\\ &\hspace{15mm} \text{计算梯度}\frac{\partial{\mathcal{L}^{\prime}_{\mathcal{T}_i}}}{\partial \theta^{\prime}_i} = \nabla_\theta \mathcal{L}^{\prime}_{\mathcal{T}_i}(f_{\theta^{\prime}}) \\ &\hspace{15mm} \text{计算最终梯度} \nabla_\theta \mathcal{L}_{\mathcal{T}_i}(f_{\theta^{\prime}}) = \frac{\partial{\mathcal{L}^{\prime}_{\mathcal{T}_i}}}{\partial \theta_i}=\frac{\partial{\mathcal{L}^{\prime}_{\mathcal{T}_i}}}{\partial \theta^{\prime}_i}\frac{\partial \theta^{\prime}_i}{\partial \theta_i} \\ &\text{ 9: }\hspace{5mm}\textbf{end for} \\ &\text{10: }\hspace{5mm}\text{Update } \theta \leftarrow \theta - \beta \sum_{\mathcal{T}_i \sim p(\mathcal{T})} \nabla_\theta \mathcal{L}_{\mathcal{T}_i}(f_{\theta^{\prime}}) \\ &\text{11: }\textbf{end while } \\ &\bf{return} \: \theta \\[-1.ex] &\rule{110mm}{0.4pt} \\[-1.ex] \end{aligned} Algorithm2 MAML for Few-Shot Supervised LearningRequire: p(T):distribution over tasksRequire: α一系列task训练-supportSet,梯度更新学习率-在循环内更新β一系列task评估-querySet,梯度更新学习率-在循环外更新 1: 初始化参数 θ 2: while not done do  3: 从任务集合中抽取任务 Tip(T) 4: for all Ti do  5: 从任务中抽取k shot个样本D={Xj,Yj}Ti 6: 基于任务的损失函数计算损失LTi=l(Yj,fθi(Xj)) 7: 基于损失函数计算梯度并更新参数θiLTi=θLTi(fθ)θi=θαθLTi(fθ) 8: 从任务中抽取 q query 个样本D={Xj,Yj}Ti基于更新后的θ进行预测并计算损失,用于循环后更新LTi=l(Yj,fθi(Xj))计算梯度θiLTi=θLTi(fθ)计算最终梯度θLTi(fθ)=θiLTi=θiLTiθiθi 9: end for10: Update θθβTip(T)θLTi(fθ)11: end while returnθ

三、简单实践

用Meta Learning 学习 y = a × s i n ( x + b ) y = a\times sin(x + b) y=a×sin(x+b), 不同的a, b代表不同的任务

3.1 任务数据准备

class SineWaveTask:def __init__(self):self.a = np.random.uniform(0.1, 5.0)self.b = np.random.uniform(1, 2 * np.pi)self.train_x = Nonedef f(self, x):return self.a * np.sin(x + self.b)def train_set(self, size=10, force_new=False):if self.train_x is None and not force_new:self.train_x = np.random.uniform(-5, 5, size)x = self.train_xelif not force_new:x = self.train_xelse:x = np.random.uniform(-5, 5, size)y = self.f(x)return torch.Tensor(x).float(), torch.Tensor(y).float()def test_set(self, size=50):x = np.linspace(-5, 5, size)y = self.f(x)return torch.Tensor(x).float(), torch.Tensor(y).float()def plot(self, *args, **kwargs):x, y = self.test_set()return plt.plot(x.cpu().detach().numpy(), y.cpu().detach().numpy(), *args, **kwargs)SineWaveTask().plot()
SineWaveTask().plot()
SineWaveTask().plot()
plt.show()

在这里插入图片描述

3.2 模型

因为query task中需要用support task后的参数进行推理,后进行二阶导来update 参数,所以多了一个query_forward 方法

class sineModel(nn.Module):def __init__(self):super(sineModel, self).__init__()self.l1 = nn.Linear(1, 40)self.l2 = nn.Linear(40, 40)self.head = nn.Linear(40, 1)def forward(self, x):x = torch.relu(self.l1(x))x = torch.relu(self.l2(x))return self.head(x)def query_forward(self, x, support_param_dict):x = torch.relu(F.linear(x, support_param_dict['l1.weight'], support_param_dict['l1.bias']))x = torch.relu(F.linear(x, support_param_dict['l2.weight'], support_param_dict['l2.bias']))return F.linear(x, support_param_dict['head.weight'], support_param_dict['head.bias'])SUPPORT_QUERY_TASKS = [SineWaveTask() for _ in range(1000)]
TEST_TASKS = [SineWaveTask() for _ in range(1000)]

3.3 MAML


def maml_sine(model, epochs, lr=1e-3, inner_lr=0.1, batch_size=1, first_order=False):opt = torch.optim.Adam(model.parameters(), lr=lr)loss_fn = nn.MSELoss()ep_loss = []for ep_i in range(epochs):tqd_bar = tqdm(enumerate(random.sample(SUPPORT_QUERY_TASKS, len(SUPPORT_QUERY_TASKS))),total=len(SUPPORT_QUERY_TASKS))tqd_bar.set_description(f'[ {ep_i+1:02d} / {epochs:02d} ]')task_loss = []for idx, suport_t in tqd_bar:fast_weights = OrderedDict(model.named_parameters())s_x, s_y = suport_t.train_set(force_new=False)q_x, q_y = suport_t.train_set(force_new=True)# supportfor _ in range(1): s_y_hat = model(torch.Tensor(s_x[:, None]))loss = loss_fn(s_y_hat, torch.Tensor(s_y.reshape(-1, 1)))grads = torch.autograd.grad(loss, fast_weights.values(), create_graph=not first_order) # 便于进行二阶导fast_weights = OrderedDict((name, param - inner_lr * (grad.detach().data if first_order else grad) )for ((name, param), grad) in zip(fast_weights.items(), grads))# querylogits = model.query_forward(torch.Tensor(q_x[:, None]), fast_weights)loss = loss_fn(logits, torch.Tensor(q_y.reshape(-1, 1)))task_loss.append(loss)if (idx + 1) % batch_size == 0:# updatemodel.train()opt.zero_grad()meta_batch_loss = torch.stack(task_loss).mean()meta_batch_loss.backward()opt.step()loss_item = meta_batch_loss.cpu().detach().numpy()tqd_bar.set_postfix({'loss': "{:.3f}".format(loss_item)})task_loss = []ep_loss.append(loss_item)return ep_losssine_model = sineModel()
ep_losses = maml_sine(sine_model, epochs=5, lr=1e-3, inner_lr=0.02, batch_size=2, first_order=False)

结果查看

全部代码见笔者github:maml.ipynb

maml训练结果显然要好于随机模型
在这里插入图片描述

参考

  • Model-Agnostic Meta-Learning for Fast Adaptation of Deep Networks
  • 李宏毅老师的课程PPT(国立台湾大学)

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.rhkb.cn/news/133039.html

如若内容造成侵权/违法违规/事实不符,请联系长河编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

前端JavaScript深拷贝与浅拷贝

🎬 岸边的风:个人主页 🔥 个人专栏 :《 VUE 》 《 javaScript 》 ⛺️ 生活的理想,就是为了理想的生活 ! 目录 引言 1. 深拷贝的实现 1.1 基本类型和特殊类型的处理 1.2 处理循环引用 1.3 性能优化 1.4 完整的深拷贝实现示…

[qt]vs2022+qt5.13.2代码报错QChartView不明确

报错类似下面&#xff1a; 鼠标指上去错误代码显示QChartView不明确,解决方法 在xxx.ui对应的头文件包含”ui_xxx.h“的前方添加如下代码&#xff1a; #include <qchart.h> QT_CHARTS_USE_NAMESPACE

赋能3D智慧校园!老子云数字孪生可视化,学校运维高效之选!

老子云专注于3D领域&#xff0c;自主研发3D可视化底层&#xff0c;已打造了行业智慧园区、智慧交通、智慧机房、智慧水利等标杆案例&#xff0c;构建了可视化数字孪生智慧体系&#xff0c;其中智慧校园不仅实现了技术上的三维落地&#xff0c;更是成为了管控超百万师生校园安全…

SwiftUI 中的几种毛玻璃效果

Preview Code // // testtt.swift // bill2 // // Created by 朱洪苇 on 2023/8/9. //import SwiftUIstruct testtt: View {var body: some View {ZStack {Image("bg1").blur(radius: 5) // 给背景图加模糊VStack {Text("ultraThinMaterial").padding()…

【Linux从入门到精通】线程 | 线程介绍线程控制

本篇文章主要对线程的概念和线程的控制进行了讲解。其中我们再次对进程概念理解。同时对比了进程和线程的区别。希望本篇文章会对你有所帮助。 文章目录 一、线程概念 1、1 什么是线程 1、2 再次理解进程概念 1、3 轻量级进程 二、进程控制 2、1 创建线程 pthread_create 2、2…

AI绘画Stable Diffusion原理之扩散模型DDPM

前言 传送门&#xff1a; stable diffusion&#xff1a;Git&#xff5c;论文 stable-diffusion-webui&#xff1a;Git Google Colab Notebook部署stable-diffusion-webui&#xff1a;Git kaggle Notebook部署stable-diffusion-webui&#xff1a;Git AI绘画&#xff0c;输入一段…

SpringMVC文件的上传下载JRebel的使用

目录 前言 一、JRebel的使用 1.IDea内安装插件 2.激活 3.离线使用 使用JRebel的优势 二、文件上传与下载 1 .导入pom依赖 2.配置文件上传解析器 3.数据表 4.配置文件 5.前端jsp页面 6.controller层 7.测试结果 前言 当涉及到Web应用程序的开发时&…

Android窗口层级(Window Type)分析

前言 Android的窗口Window分为三种类型&#xff1a; 应用Window&#xff0c;比如Activity、Dialog&#xff1b;子Window&#xff0c;比如PopupWindow&#xff1b;系统Window&#xff0c;比如Toast、系统状态栏、导航栏等等。 应用Window的Z-Ordered最低&#xff0c;就是在系…

uni-app 使用uCharts-进行图表展示(折线图带单位)

前言 在uni-app经常是需要进行数据展示&#xff0c;针对这个情况也是有人开发好了第三方包&#xff0c;来兼容不同平台展示 uCharts和pc端的Echarts使用差不多&#xff0c;甚至会感觉在uni-app使用uCharts更轻便&#xff0c;更舒服 但是这个第三方包有优点就会有缺点&#xf…

医院安全不良事件报告系统源码 PHP+ vue2+element+ laravel8+ mysql5.7+ vscode开发

不良事件上报系统通过 “事前的人员知识培训管理和制度落地促进”、“事中的事件上报和跟进处理”、 以及 “事后的原因分析和工作持续优化”&#xff0c;结合预存上百套已正在使用的模板&#xff0c;帮助医院从对护理事件、药品事件、医疗器械事件、医院感染事件、输血事件、意…

数字人员工成企业得力助手,虚拟数字人为企业注入高科技基因

随着互联网和人工智能技术的快速发展&#xff0c;以“数字员工”为代表的数字生产力&#xff0c;正在出现在各行各业的业务场景中。数字人员工的出现不是替代人类&#xff0c;而是通过技术提高工作效率&#xff0c;实现更加智能化的服务体验&#xff0c;帮助企业实现大规模自动…

网上企业订货系统功能列表介绍|企业APP订单管理软件

网上企业订货系统功能列表介绍|企业APP订单管理软件 后台功能列表 &#xff08;后台支持手机版本 订货APP,管理订单的APP&#xff09; 后台登陆 输入账号密码登录企业订货管理软件系统 后台首页 显示近日,月,年订单统计&#xff0c;和收款欠款等统计。 订单模块 新建订单 …

人脸识别三部曲

人脸识别三部曲 首先看目录结构图像信息采集 采集图片.py模型训练 训练模型.py人脸识别 人脸识别.py效果 首先看目录结构 引用文121本 opencv │ 采集图片.py │ 训练模型.py │ 人脸识别.py │ └───trainer │ │ trainer.yml │ └───data │ └──…

使用 Sealos 一键部署高可用 MinIO,开启对象存储之旅

大家好&#xff01;今天这篇文章主要向大家介绍如何通过 Sealos 一键部署高可用 MinIO 集群。 MinIO 对象存储是什么&#xff1f; 对象是二进制数据&#xff0c;例如图像、音频文件、电子表格甚至二进制可执行代码。对象的大小可以从几 B 到几 TB 不等。像 MinIO 这样的对象存储…

零基础学前端(三)重点讲解 HTML

1. 该篇适用于从零基础学习前端的小白 2. 初学者不懂代码得含义也要坚持模仿逐行敲代码&#xff0c;以身体感悟带动头脑去理解新知识 3. 初学者切忌&#xff0c;不要眼花缭乱&#xff0c;不要四处找其它文档&#xff0c;要坚定一个教授者的方式&#xff0c;将其学通透&#xff…

家政服务接单小程序开发源码 家政保洁上门服务小程序源码 开源完整版

分享一个家政服务接单小程序开发源码&#xff0c;家政保洁上门服务小程序源码&#xff0c;一整套完整源码开源&#xff0c;可二开&#xff0c;含完整的前端后端和详细的安装部署教程&#xff0c;让你轻松搭建家政类的小程序。家政服务接单小程序开发源码为家政服务行业带来了诸…

shell脚本学习笔记02(小滴课堂)

可以在home目录下创建一个shell.sh文件。 按w进入命令行模式。按i进入插入模式。如果想返回命令行模式&#xff0c;按esc即可。然后可以使用x和dd进行删除内容。 在插入模式下我们点击esc键&#xff0c;再去按:键&#xff0c;我们就可以进入到底行模式了&#xff1a; 可以设…

518抽奖软件,可生成几排几列的号码座号

518抽奖软件简介 518抽奖软件&#xff0c;518我要发&#xff0c;超好用的年会抽奖软件&#xff0c;简约设计风格。 包含文字号码抽奖、照片抽奖两种模式&#xff0c;支持姓名抽奖、号码抽奖、数字抽奖、照片抽奖。(www.518cj.net) 生成号码/座号 入口&#xff1a; 主界面上点…

基于深度学习的加密恶意流量检测

加密恶意流量检测 研究目标定位数据收集数据处理基于特征分类算法的数据预处理基于源数据分类算法的数据预处理 特征提取模型选择基于数据特征的深度学习检测算法基于特征自学习的深度学习检测算法 训练和评估精确性指标实时性指标 应用检验改进 摘录自&#xff1a;Mingfang ZH…

ZABBIX 6.4官方安装文档

一、官网地址 Zabbix&#xff1a;企业级开源监控解决方案 二、下载 1.选择您Zabbix服务器的平台 2. Install and configure Zabbix for your platform a. Install Zabbix repository # rpm -Uvh https://repo.zabbix.com/zabbix/6.4/rhel/8/x86_64/zabbix-release-6.4-1.el8…