【博士每天一篇文献-算法】Overcoming catastrophic forgetting in neural networks

阅读时间:2023-10-24

1 介绍

年份:2016
作者:James Kirkpatrick, Razvan Pascanu, Neil Rabinowitz, Joel Veness, Guillaume Desjardins, Andrei A. Rusu, Kieran Milan, John Quan, Tiago Ramalho, Agnieszka Grabska-Barwinska, Demis Hassabis, Claudia Clopath, Dharshan Kumaran, Raia Hadsell,加州史丹佛大學史丹佛大學
期刊:Proceedings of the national academy of sciences
引用量:5449
这篇论文的主题是关于神经网络如何克服灾难性遗忘的问题,灾难性遗忘是神经网络在顺序学习任务时的一个限制。论文提出了一种称为弹性权重合并(EWC)的方法,可以使神经网络在学习新任务的同时记住旧任务。EWC会有选择地降低对先前学习任务重要的权重的学习速度,从而防止灾难性遗忘。作者通过在MNIST数据集上解决分类任务和顺序学习Atari 2600游戏的实验来证明EWC的有效性。论文将EWC与其他方法如L2正则化和dropout正则化进行了比较,结果表明EWC在保持旧任务高性能的同时学习新任务方面优于这些方法。论文解释了EWC的实现和合理性,包括如何约束重要参数和确定哪些权重对于每个任务是重要的。论文还讨论了哺乳动物大脑可能支持无灾难遗忘连续学习的神经机制。总的来说,这篇论文通过使用EWC提出了解决神经网络灾难性遗忘问题的方法。

2 创新点

  1. EWC方法:论文提出了一种名为弹性权重整合的算法,用于实现神经网络的连续学习。该算法根据先前学习任务中权重的重要性,减缓学习过程,从而保留旧任务的知识。
  2. 在MNIST数据集和Atari 2600游戏中的应用:论文通过在MNIST数据集上进行分类任务和在Atari 2600游戏中进行学习来展示EWC的有效性。结果表明,相比于L2正则化和dropout正则化等其他方法,EWC在学习新任务的同时能够维持旧任务的高性能。
  3. EWC的实施和正当性:论文解释了EWC的具体实施和合理性,包括对重要参数的约束和确定每个任务中哪些权重是重要的。论文还提到了哺乳动物大脑中支持连续学习而不发生灾难性遗忘的神经机制。

3 算法

(1)计算步骤

  • 计算每个权重在先前任务中的重要性:
    • 先前任务的损失函数:L_prev(θ),其中θ表示网络的权重。
    • Fisher信息矩阵:F_prev(θ) = E[∇²L_prev(θ)],其中∇²表示梯度的二阶导数。
    • 权重重要性:I_prev(θ) = F_prev(θ) * (θ - θ_prev)²,其中θ_prev表示在先前任务上训练后的权重。
  • 计算当前任务的损失函数:当前任务的损失函数:L_curr(θ)。
  • 计算正则化项并更新网络权重:
    • 正则化项:EWC_loss(θ) = L_curr(θ) + λ * Σ[I_prev(θ)], 其中λ是正则化项的权重,Σ表示对所有权重求和。
    • 更新网络权重:θ_new = argmin(θ)[EWC_loss(θ)]

(2)推理过程
训练了一个模型,其参数为 θ \theta θ,定义最小化以下损失函数来完成此操作:
L ( θ ) = L n e w ( θ ) + ∑ i = 1 n λ 2 F i ( θ i − θ i ∗ ) 2 \mathcal{L}(\theta) = \mathcal{L}_{new}(\theta) + \sum_{i=1}^{n} \frac{\lambda}{2} F_i (\theta_i - \theta_i^*)^2 L(θ)=Lnew(θ)+i=1n2λFi(θiθi)2

其中, L n e w ( θ ) \mathcal{L}_{new}(\theta) Lnew(θ) 是新任务的损失函数,n 是先前任务的数量, F i F_i Fi 是 Fisher 信息矩阵的对角线元素, θ i ∗ \theta_i^* θi 是在先前任务i中找到的最优参数。 λ \lambda λ 是一个超参数,控制先前任务对新任务的影响。
Fisher 信息矩阵是 Hessian 矩阵的期望值,它衡量了损失函数对参数的二阶导数。 在 EWC 中,只计算对角线元素,因为它们提供了最大的信息,同时也更容易计算。Fisher 信息矩阵的对角线元素可以通过以下公式计算:
F i , j = E x ∼ D i [ ∂ log ⁡ p ( y ∣ x , θ ) ∂ θ i ∂ log ⁡ p ( y ∣ x , θ ) ∂ θ j ] F_{i,j} = \mathbb{E}_{x\sim D_i}[\frac{\partial \log p(y|x,\theta)}{\partial \theta_i} \frac{\partial \log p(y|x,\theta)}{\partial \theta_j}] Fi,j=ExDi[θilogp(yx,θ)θjlogp(yx,θ)]

其中, D i D_i Di是先前任务i的数据分布, p ( y ∣ x , θ ) p(y|x,\theta) p(yx,θ)是模型在给定输入x 和参数 θ \theta θ的情况下预测输出y的概率分布。
在每次学习新任务之前,需要计算 Fisher 信息矩阵和最优参数 θ i ∗ \theta_i^* θi。这可以通过在先前任务上运行梯度下降来实现,直到收敛为止。一旦计算出 Fisher 信息矩阵和最优参数,就可以使用 EWC 来学习新任务,同时保留先前任务的知识。
最后,可以使用以下公式计算 EWC 梯度:
g i = ∇ θ i L n e w ( θ ) + λ ∑ j = 1 n F i , j ( θ i − θ i ∗ ) g_i = \nabla_{\theta_i} \mathcal{L}_{new}(\theta) + \lambda \sum_{j=1}^{n} F_{i,j} (\theta_i - \theta_i^*) gi=θiLnew(θ)+λj=1nFi,j(θiθi)
其中, g i g_i gi是 EWC 梯度, ∇ θ i L n e w ( θ ) \nabla_{\theta_i} \mathcal{L}_{new}(\theta) θiLnew(θ)是新任务的梯度。通过添加正则化项,EWC 可以确保新任务不会完全覆盖先前任务的知识,从而在连续学习中实现知识共享。

5 实验结果分析

(1)总结一
image.png

  • 使用纯随机梯度下降(SGD)训练这个任务序列会引发灾难性遗忘。
  • 图2A展示了两个不同任务的测试集性能。在训练从第一个任务切换到第二个任务时,任务B的性能迅速下降,而任务A的性能迅速上升。
  • 任务A的遗忘问题会随着更长的训练时间而进一步恶化。
  • 使用L2正则化不能解决这个问题,因为它对所有权重施加了相同的保护限制,导致在任务B上学习的能力受到限制。
  • 然而,使用EWC可以根据任务A中每个权重的重要性,使网络能够在不遗忘任务A的情况下很好地学习任务B。
  • 图2B展示了使用EWC和使用SGD与dropout正则化的所有任务的平均性能。可以看到EWC在旧任务上保持了高性能,并且仍然能够学习新任务。
  • 图2C展示了两个不同置换程度下网络深度的Fisher信息矩阵的相似性。任务越不相似,早期层的Fisher信息矩阵重叠越小。

(2)总结二

  • 当网络在两个非常相似的任务上训练(两个MNIST版本,只有少数像素被重排),这两个任务在整个网络中依赖于相似的权重集
  • 当两个任务之间更不相似时,网络开始为两个任务分配单独的能力(即权重)。

在进行大量重排时,网络靠近输出的层确实被两个任务重复使用。这反映了重排使得输入对内容是非常不同的,但输出的内容(即类别标签)是共享的。
(3)总结三
EWC可以在要求更高的强化学习(RL)领域中支持连续学习。作者测试了在经典的Atari 2600游戏集上,将Deep Q Networks与EWC相结合的方法。实验中,通过使用EWC,能够学习多个游戏,而不会忘记以前学习的游戏 。与以前的RL方法相比,EWC利用了固定资源(即网络容量)的单个网络,并且计算开销较小。

6 代码

https://github.com/yashkant/Elastic-Weight-Consolidation

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.rhkb.cn/news/174907.html

如若内容造成侵权/违法违规/事实不符,请联系长河编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

ChatGPT从入门到精通

目录 什么是ChatGPT?ChatGPT能帮我干什么?标题在哪里可以使用ChatGPT?什么是ILoveChatGPT(IMYAI)?标题如何拥有头像?如何获取更多对话次数?!标题如何提问GPT?如何正确地利…

0X01

打开题目 点了几下跳出一个新的页面 点击secret 在上一个页面查看源代码,出现action.php然后点击之后就会在地址栏里面出现end.php 抓包看看,出现secr3t.php huidao开始的页面,访问看看 这是一个PHP脚本,以HTML标签开头。该脚本包…

1300*C. Social Distance(贪心构造)

Problem - 1367C - Codeforces 解析&#xff1a; 统计出所有连续0序列&#xff0c;并且记录其左右两侧有没有1&#xff0c;然后对于四种情况分别判断即可。 #include<bits/stdc.h> using namespace std; int t,n,k; signed main(){scanf("%d",&t);while(…

python使用ffmpeg来制作音频格式转换工具(优化版)

简介:一个使用python加上ffmpeg模块来进行音频格式转换的工具。 日志: 20231030:第一版,设置了简单的UI布局和配色,实现音频转为Mp3、AAC、wav、flac四种格式。可解析音频并显示信息,可设置转换后的保存路径 UI界面: 编程平台:visual studio code 编程语言:python 3…

YugaByteDB -- 全新的 “PostgreSQL“ 存储层

文章目录 0 背景1 架构1.1 Master1.2 TServer1.3 Tablet 2 读写链路2.1 DDL2.2 DML2.3 事务 3 KEY 的设计4 Rocksdb 在 YB 中的一些实践总结 0 背景 YugaByteDB 的诞生也是抓住了 spanner 推行的NewSQL 浪潮的尾巴&#xff0c;以 PG 生态为基础 用C实现的 支持 SQL 以及 CQL 语…

Android---如何同view进行渲染

ViewRootImpl 在 Activity、window 和 View 三者关系之间起着承上启下的作用。一方面&#xff0c;ViewRootImpl 中通过 Binder 通信机制&#xff0c;远程调用 WindowSession 将 View 添加到 Window 中&#xff1b;另一方面&#xff0c;ViewRootImpl 在添加 View 之前&#xff0…

vscode打开settings.json方法

cmd shift p&#xff0c;输入setting Open Workspace Settings 也会打开UI设置界面&#xff1b; Open User Settings (JSON) 会打开用户设置 settings.json 文件&#xff1b; Open Workspace Settings (JSON) 会打开工作区设置 settings.json 文件 vscode存在两种设置 sett…

Rust编程基础之变量与可变性

1.Rust变量 在Rust语言中, 变量默认是不可改变的(immutable), 这是Rust提供给我们的众多优势之一, 让我们可以充分利用Rust提供的安全性和简单并发性来编写代码。 当变量不可变时, 一旦值被绑定在一个名称上, 就不能改变这个值。下面是一段代码的例子: fn main() {let x 1;…

Panda3d 介绍

Panda3d 介绍 文章目录 Panda3d 介绍Panda3d 的安装Panda3d 的坐标系统介绍Panda3d 的运行Panda3d 加载一个熊猫父节点和子节点之间的关系 验证Panda3d 的坐标系统X 轴的平移Y 轴的平移Z 轴的平移X 轴的旋转Y 轴的旋转Z 轴的旋转 Panda3D是一个3D引擎:一个用于3D渲染和游戏开发…

[Linux]线程池

[Linux]线程池 文章目录 [Linux]线程池线程池的概念线程池的优点线程池的应用场景线程池的实现 线程池的概念 线程池是一种线程使用模式。线程池是一种特殊的生产消费模型&#xff0c;用户作为生产者&#xff0c;线程池作为消费者和缓冲区。 线程过多会带来调度开销&#xff0c…

Generalized Zero-Shot Learning With Multi-Channel Gaussian Mixture VAE

L D A _{DA} DA​最大化编码后两种特征分布之间的相似性 辅助信息 作者未提供代码

1400*C. Element Extermination(贪心规律)

Problem - 1375C - Codeforces 解析&#xff1a; 可以发现&#xff0c;最左端的数字&#xff0c;无论删除自己还是下一个&#xff0c;这个位置的值都不会变小。 同理&#xff0c;最右端位置的值都不会变大。 所以当最后剩余两个数字的时候&#xff0c;只有左端小于右端数字&…

【经典面试】87 字符串解码

字符串解码 题解1 递归(程序栈)——形式语言自动机(LL(1)) : O(S)另一种递归(直观) 题解2 2个栈(逆波兰式)1个栈(参考官方&#xff0c;但是不喜欢) 给定一个经过编码的字符串&#xff0c;返回它解码后的字符串。 编码规则为: k[encoded_string]&#xff0c;表示其中方括号内部的…

深入探究深度学习、神经网络与卷积神经网络以及它们在多个领域中的应用

目录 1、什么是深度学习&#xff1f; 2、深度学习的思想 3、深度学习与神经网络 4、深度学习训练过程 4.1、先使用自下上升非监督学习&#xff08;就是从底层开始&#xff0c;一层一层的往顶层训练&#xff09; 4.2、后自顶向下的监督学习&#xff08;就是通过带标签的数…

文件管理怎么清内存?效率提升一倍

定期清理文件管理可以释放存储空间和提高系统性能。随着时间的推移&#xff0c;手机中可能会存储大量无用的数据&#xff0c;例如缓存、垃圾文件等&#xff0c;导致系统运行缓慢。那么如何清理文件管理的内存呢&#xff1f;下面介绍三种方法。 一、搜索无用的文件夹进行清理 1…

【Bug—eNSP】华为eNsp路由器设备启动一直是0解决方案!

目录 一、项目场景 二、问题描述 三、原因分析 四、解决方案 注意&#

(二开)Flink 修改源码拓展 SQL 语法

1、Flink 扩展 calcite 中的语法解析 1&#xff09;定义需要的 SqlNode 节点类-以 SqlShowCatalogs 为例 a&#xff09;类位置 flink/flink-table/flink-sql-parser/src/main/java/org/apache/flink/sql/parser/dql/SqlShowCatalogs.java 核心方法&#xff1a; Override pu…

高阶数据结构学习 —— 图(1)

文章目录 1、并查集2、了解图3、邻接矩阵4、压缩路径5、基本概念6、邻接表 1、并查集 并查集是一个森林&#xff0c;是由多棵树组成的。这相当于整套数据&#xff0c;分成多个集合。查找有交集的集合们&#xff0c;会把它们合并起来&#xff0c;所以叫并查集。 一开始拿到的是…

电脑突然提示找不到msvcp140.dll怎么办,解决msvcp140.dll丢失的办法

当我们在电脑上运行某些软件或游戏时&#xff0c;可能会遇到一个常见的错误消息&#xff1a;“找不到msvcp140.dll”。出现这样的情况通常意味着系统缺少一个重要的动态链接库文件&#xff0c;而这可能会导致程序无法正常启动。如果你现在遇到了这个问题&#xff0c;哪有可以用…

3D LUT 滤镜 shader 源码分析

最近在做滤镜相关的渲染学习&#xff0c;目前大部分 LUT 滤镜代码实现都是参考由 GPUImage 提供的 LookupFilter 的逻辑&#xff0c;整个代码实现不多。参考网上的博文也有各种解释&#xff0c;参考了大量博文之后终于理解了&#xff0c;所以自己重新整理了一份&#xff0c;方便…