【2.10-2.16学习周报】

文章目录

  • 摘要
  • Abstract
  • 一、理论方法介绍
    • 1.模糊类增量学习
    • 2.Rainbow Memory(RM)
      • 2.1多样性感知内存更新
      • 2.2通过数据增强增强样本多样性(DA)
  • 二、实验
    • 1.实验概况
    • 2.RM核心代码
    • 3.实验结果
  • 总结


摘要

本博客概述了文章《Rainbow Memory: Continual Learning with a Memory of Diverse Samples》聚焦于任务边界模糊的持续学习场景,提出基于样本分类不确定性和数据增强的Rainbow Memory (RM)记忆管理策略。多数研究在任务不共享类别的较人为的设置下评估相关方法,但在现实世界应用场景中,任务之间的类分布是不断变化的,更现实和实用的是任务共享类别的模糊CIL设置。在这种设置下,之前存储少量旧数据的方法虽在缓解灾难性遗忘方面有成果,但也引出了如何管理记忆(memory)的最优策略问题。基于该问题,研究者在新定义的模糊CIL设置下更好地持续学习的两个因素:记忆的采样和记忆中的数据增强,进而提出Rainbow Memory(RM)方法。通过在MNIST、CIFAR10、CIFAR100和ImageNet数据集上的实证验证,RM在模糊持续学习设置中显著提高了准确性,大幅超越现有技术。

文章链接
实验代码链接

Abstract

This blog summarizes the article “Rainbow Memory: Continual Learning with a Memory of Diverse Samples”, which focuses on the continuous learning scenario with fuzzy task boundaries, and proposes a Rainbow Memory (RM) memory management strategy based on sample classification uncertainty and data augmentation. Most studies evaluate the relevant methods in a more artificial setting where tasks do not share categories, but in real-world application scenarios, the class distribution between tasks is constantly changing, and it is more realistic and practical to see the fuzzy CIL settings of task sharing categories. In this setting, the previous method of storing a small amount of old data has been successful in mitigating catastrophic forgetting, but it also raises the question of the optimal strategy for managing memory. Based on this problem, the researchers proposed a rainbow memory (RM) method for better continuous learning under the newly defined fuzzy CIL setting: memory sampling and data enhancement in memory. Through empirical verification on MNIST, CIFAR10, CIFAR100, and ImageNet datasets, RM significantly improves accuracy in fuzzy continuous learning settings, significantly outperforming existing technologies.

一、理论方法介绍

1.模糊类增量学习

模糊类增量学习的设置要求如下:1)每个任务作为流顺序地给出,(2)大多数(分配的)任务类别彼此不同,以及(3)模型只能利用先前任务的非常小的一部分数据。 如下图所示,在模糊CIL中,任务共享类,与传统的不相交CIL相反。建议的记忆管理策略更新的情景记忆与当前任务的样本,以保持不同的样本在内存中。数据扩充(DA)进一步增强了内存中样本的多样性。
在这里插入图片描述

2.Rainbow Memory(RM)

在模糊类增量学习的场景中,现有方法因样本多样性不足导致模型易过拟合或遗忘严重。为了解决该问题,研究者提出了Rainbow Memory(RM),RM提出通过多样性记忆管理和数据增强解决 Blurry-CIL 问题。

2.1多样性感知内存更新

研究者认为,被选择存储在内存中的样本应该不仅是代表其相应的类,还要识别其他类。为了选择这样的样本,研究者认为,在分类边界附近的样本是最具鉴别力的,靠近分布中心的样本是最具代表性的。为了满足这两个特点,研究者建议抽样的样本是不同的特征空间。

由于计算样本与样本之间的距离O(N2)较为复杂和昂贵,研究者通过分类模型估计的样本的不确定性来估计相对位置,即假设模型的更确定的样本将位于更靠近类分布的中心,通过测量扰动样本的模型输出方差来计算样本的不确定性,扰动样本通过各种数据增强转换方法进行:包括颜色抖动、剪切和剪切,如下图所示:在这里插入图片描述
通过蒙特-卡罗(MC)法近似计算分布p(y = c)的不确定度|x),当给定扰动样本x的先验时,即p(x| x)的情况下,推导过程可以写成:
在这里插入图片描述
其中,x、x^~、y和A分别表示样本、扰动样本、样本的标签和扰动方法的数量。分布D * 表示由扰动样本λ x定义的数据分布。特别地,扰动样本λ x由随机函数fr(·)绘制,如下:
在这里插入图片描述
其中θr是表示第r次扰动的随机因子的超参数。
测量样品相对于扰动的不确定性为:
在这里插入图片描述
其中u(x)表示样本x的不确定性,Sc是类别c是预测的前1类别的次数。1c表示二进制类索引向量。较低的u(x)值对应于扰动上更一致的top-1类,表明x位于模型强置信的区域.

2.2通过数据增强增强样本多样性(DA)

为了进一步增强记忆中的示例的多样性,研究者采用了数据增强(DA)。 DA的通过图像级或特征扰动使给定的样本多样化,这对应于通过确保多样性来更新内存的理念
随着任务迭代的进行,新任务中的样本可能会遵循与情节内存中的样本(即,从以前的任务中)遵循不同的分布。 研究者在新任务的类别和内存中旧类的示例中采用混合标记的DA来“混合”图像。 这种混合标签DA减轻了由类分布在任务上的变化引起的副作用,并改善了表现。
混合标记的DA方法之一,CutMix 生成了混合样品和平滑标签,鉴于一组监督样品(X1,Y1)和(X2,Y2),其公式如下:
在这里插入图片描述
该集合根据从β-分布中得出的超参数β表示图像x1的随机选择像素区域.

二、实验

1.实验概况

研究者通过将RM与各种实验设置中的艺术状态进行比较,从经验上验证了RM的功效。 基准测试的CIL任务设置,情节内存的内存大小和性能指标。在MNIST、CIFAR10、CIFAR100和ImageNet数据集上进行实验。采用多种CIL任务设置、不同的记忆大小和性能指标评估RM方法。将RM与EWC、Rwalk、iCaRL等标准CIL方法对比 ,比较不同方法在各种设置下的Last Accuracy(A5)、Last Forgetting(F5)和Intransigence(I5)等指标。分析RM在不同模糊水平(如Blurry0、Blurry10、Blurry30)下的性能,还探究了不确定性测量方法、记忆更新算法、数据增强方法等对性能的影响。

2.RM核心代码

RM部分的完整核心代码如下:

import logging
import randomimport numpy as np
import pandas as pd
import torch
from torch.utils.data import DataLoader
from torch.utils.tensorboard import SummaryWriterfrom methods.finetune import Finetune
from utils.data_loader import cutmix_data, ImageDatasetlogger = logging.getLogger()
writer = SummaryWriter("tensorboard")def cycle(iterable):# iterate with shufflingwhile True:for i in iterable:yield iclass RM(Finetune):def __init__(self, criterion, device, train_transform, test_transform, n_classes, **kwargs):super().__init__(criterion, device, train_transform, test_transform, n_classes, **kwargs)self.batch_size = kwargs["batchsize"]self.n_worker = kwargs["n_worker"]self.exp_env = kwargs["stream_env"]if kwargs["mem_manage"] == "default":self.mem_manage = "uncertainty"def train(self, cur_iter, n_epoch, batch_size, n_worker, n_passes=0):if len(self.memory_list) > 0:mem_dataset = ImageDataset(pd.DataFrame(self.memory_list),dataset=self.dataset,transform=self.train_transform,)memory_loader = DataLoader(mem_dataset,shuffle=True,batch_size=(batch_size // 2),num_workers=n_worker,)stream_batch_size = batch_size - batch_size // 2else:memory_loader = Nonestream_batch_size = batch_size# train_list == streamed_list in RMtrain_list = self.streamed_listtest_list = self.test_listrandom.shuffle(train_list)# Configuring a batch with streamed and memory data equally.train_loader, test_loader = self.get_dataloader(stream_batch_size, n_worker, train_list, test_list)logger.info(f"Streamed samples: {len(self.streamed_list)}")logger.info(f"In-memory samples: {len(self.memory_list)}")logger.info(f"Train samples: {len(train_list)+len(self.memory_list)}")logger.info(f"Test samples: {len(test_list)}")# TRAINbest_acc = 0.0eval_dict = dict()self.model = self.model.to(self.device)for epoch in range(n_epoch):# initialize for each taskif epoch <= 0:  # Warm start of 1 epochfor param_group in self.optimizer.param_groups:param_group["lr"] = self.lr * 0.1elif epoch == 1:  # Then set to maxlrfor param_group in self.optimizer.param_groups:param_group["lr"] = self.lrelse:  # Aand go!self.scheduler.step()train_loss, train_acc = self._train(train_loader=train_loader, memory_loader=memory_loader,optimizer=self.optimizer, criterion=self.criterion)eval_dict = self.evaluation(test_loader=test_loader, criterion=self.criterion)writer.add_scalar(f"task{cur_iter}/train/loss", train_loss, epoch)writer.add_scalar(f"task{cur_iter}/train/acc", train_acc, epoch)writer.add_scalar(f"task{cur_iter}/test/loss", eval_dict["avg_loss"], epoch)writer.add_scalar(f"task{cur_iter}/test/acc", eval_dict["avg_acc"], epoch)writer.add_scalar(f"task{cur_iter}/train/lr", self.optimizer.param_groups[0]["lr"], epoch)logger.info(f"Task {cur_iter} | Epoch {epoch+1}/{n_epoch} | train_loss {train_loss:.4f} | train_acc {train_acc:.4f} | "f"test_loss {eval_dict['avg_loss']:.4f} | test_acc {eval_dict['avg_acc']:.4f} | "f"lr {self.optimizer.param_groups[0]['lr']:.4f}")best_acc = max(best_acc, eval_dict["avg_acc"])return best_acc, eval_dictdef update_model(self, x, y, criterion, optimizer):optimizer.zero_grad()do_cutmix = self.cutmix and np.random.rand(1) < 0.5if do_cutmix:x, labels_a, labels_b, lam = cutmix_data(x=x, y=y, alpha=1.0)logit = self.model(x)loss = lam * criterion(logit, labels_a) + (1 - lam) * criterion(logit, labels_b)else:logit = self.model(x)loss = criterion(logit, y)_, preds = logit.topk(self.topk, 1, True, True)loss.backward()optimizer.step()return loss.item(), torch.sum(preds == y.unsqueeze(1)).item(), y.size(0)def _train(self, train_loader, memory_loader, optimizer, criterion):total_loss, correct, num_data = 0.0, 0.0, 0.0self.model.train()if memory_loader is not None and train_loader is not None:data_iterator = zip(train_loader, cycle(memory_loader))elif memory_loader is not None:data_iterator = memory_loaderelif train_loader is not None:data_iterator = train_loaderelse:raise NotImplementedError("None of dataloder is valid")for data in data_iterator:if len(data) == 2:stream_data, mem_data = datax = torch.cat([stream_data["image"], mem_data["image"]])y = torch.cat([stream_data["label"], mem_data["label"]])else:x = data["image"]y = data["label"]x = x.to(self.device)y = y.to(self.device)l, c, d = self.update_model(x, y, criterion, optimizer)total_loss += lcorrect += cnum_data += dif train_loader is not None:n_batches = len(train_loader)else:n_batches = len(memory_loader)return total_loss / n_batches, correct / num_datadef allocate_batch_size(self, n_old_class, n_new_class):new_batch_size = int(self.batch_size * n_new_class / (n_old_class + n_new_class))old_batch_size = self.batch_size - new_batch_sizereturn new_batch_size, old_batch_size

1.内存管理与数据混合(对应论文 Section 4.1)
将内存中的旧任务样本(memory_loader)与当前任务的流数据(train_loader)按比例混合(默认各占50%)。

使用cycle(memory_loader)循环读取内存数据,避免内存样本因容量限制被忽略。
实现多样性记忆回放,通过混合新旧任务样本缓解灾难性遗忘,确保模型同时学习新任务和巩固旧任务知识。

def train(self, cur_iter, n_epoch, batch_size, n_worker, n_passes=0):# 加载内存数据(旧任务样本)和流数据(新任务样本)if len(self.memory_list) > 0:mem_dataset = ImageDataset(self.memory_list, transform=self.train_transform)memory_loader = DataLoader(mem_dataset, batch_size=(batch_size // 2), ...)stream_batch_size = batch_size - batch_size // 2else:memory_loader = Nonestream_batch_size = batch_size# 混合流数据和内存数据data_iterator = zip(train_loader, cycle(memory_loader))  # 循环迭代内存数据x = torch.cat([stream_data["image"], mem_data["image"]])y = torch.cat([stream_data["label"], mem_data["label"]])
  1. 数据增强:CutMix
    以50%概率应用CutMix,将两张图像局部区域混合,并生成对应的混合标签(labels_a和labels_b)。

计算混合损失(lam * loss_a + (1-lam) * loss_b),鼓励模型学习更鲁棒的特征,实现标签混合增强(Section 4.2),通过生成边界复杂的样本提升记忆库多样性,增强模型泛化能力。

def update_model(self, x, y, criterion, optimizer):# CutMix增强:混合图像和标签do_cutmix = self.cutmix and np.random.rand(1) < 0.5if do_cutmix:x, labels_a, labels_b, lam = cutmix_data(x=x, y=y, alpha=1.0)logit = self.model(x)loss = lam * criterion(logit, labels_a) + (1 - lam) * criterion(logit, labels_b)else:logit = self.model(x)loss = criterion(logit, y)

3.动态学习率与批量调整

# Warm start学习率调整
if epoch <= 0:for param_group in self.optimizer.param_groups:param_group["lr"] = self.lr * 0.1  # 初始低学习率
elif epoch == 1:param_group["lr"] = self.lr  # 恢复基准学习率
else:self.scheduler.step()         # 后续按计划调整# 动态调整新旧任务批量大小
def allocate_batch_size(self, n_old_class, n_new_class):new_batch_size = int(self.batch_size * n_new_class / (n_old_class + n_new_class))old_batch_size = self.batch_size - new_batch_sizereturn new_batch_size, old_batch_size

初始阶段使用低学习率(10%基准值)进行预热(Warm-up),避免训练初期不稳定。

根据新旧类别比例动态分配批量大小,平衡新旧任务的学习强度,防止新任务数据主导学习过程。
4. 训练流程与评估

# 训练与评估循环
for epoch in range(n_epoch):train_loss, train_acc = self._train(...)  # 训练eval_dict = self.evaluation(...)          # 评估logger.info(f"Task {cur_iter} | Epoch {epoch+1} | train_acc {train_acc:.4f} | test_acc {eval_dict['avg_acc']:.4f}")

3.实验结果

研究者将提出的RM与各种数据集的“ Blurry10-Online”设置中的其他方法进行了比较,并总结了如下表的结果,如表所示,RM始终优于所有其他方法,并且当类(| C |)增加时,增益会更大。但是,在MNIST上,没有DA的RM表现最好。 研究者认为,DA会干扰模型培训,因为示例足以避免忘记。
在这里插入图片描述
下表列出了三个情节记忆大小(K)的CIFAR10-Blurry10Online的比较; 200、500和1,000。结果表明,这些方法在最终任务中保留了有效的示例,足以恢复以前任务中发生的遗忘。 ICARL,GDUMB和BIC对于不固定(i5)的有效性较小,并且与EWC和RWALK相比,它们在忘记方面的表现较大,作为权衡。
在这里插入图片描述
研究者进一步比较了任务流的准确性轨迹; 由随机分配的函数ψ(c)生成的三个流,具有不同的随机种子,用于Imagenet和单个流,用Imagenet,并总结了下图中的结果:
在这里插入图片描述
RM在整个任务流中始终优于其他基线。

总结

研究结论:研究者提出一种名为彩虹记忆(RM)的方法,用于处理任务共享类别(模糊 - CIL)的现实持续学习场景。通过基于样本分类不确定性的新的多样性增强采样方法和多种数据增强技术,在CIFAR10、CIFAR100和ImageNet的模糊 - CIL场景中,RM大幅优于现有方法,在不连续和离线CIL设置中也有可比性能。
研究的创新性:一是提出基于样本扰动不确定性的多样性增强采样方法管理有限容量记忆;二是采用多种数据增强技术提高样本多样性,增强记忆中样本的代表性和判别性。
研究展望:可研究基于不确定性的记忆更新和数据增强在训练时的关系,及其对不同CIL任务的影响。还可探索RM在更多类型数据集或其他领域持续学习场景中的应用效果。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.rhkb.cn/news/18463.html

如若内容造成侵权/违法违规/事实不符,请联系长河编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

ABP - 事件总线之分布式事件总线

ABP - 事件总线之分布式事件总线 1. 分布式事件总线的集成1.2 基于 RabbitMQ 的分布式事件总线 2. 分布式事件总线的使用2.1 发布2.2 订阅2.3 事务和异常处理 3. 自己扩展的分布式事件总线实现 事件总线可以实现代码逻辑的解耦&#xff0c;使代码模块之间功能职责更清晰。而分布…

Zotero7 从下载到安装

Zotero7 从下载到安装 目录 Zotero7 从下载到安装下载UPDATE2025.2.16 解决翻译api异常的问题 下载 首先贴一下可用的链接 github官方仓库&#xff1a;https://github.com/zotero/zotero中文社区&#xff1a;https://zotero-chinese.com/官网下载页&#xff1a;https://www.z…

typecho快速发布文章

typecho_Pytools typecho_Pytools工具由python编写&#xff0c;可以快速批量的在本地发布文章&#xff0c;不需要登陆后台粘贴md文件内容&#xff0c;同时此工具还能查看最新的评论消息。… 开源地址: GitHub Gitee 使用教学&#xff1a;B站 一、主要功能 所有操作不用登陆博…

Redis7——基础篇(一)

前言&#xff1a;此篇文章系本人学习过程中记录下来的笔记&#xff0c;里面难免会有不少欠缺的地方&#xff0c;诚心期待大家多多给予指教。 基础篇&#xff1a; Redis&#xff08;一&#xff09; 一、Redis定义 官网地址&#xff1a;Redis - The Real-time Data Platform R…

K8s组件

一、Kubernetes 集群架构组件 K8S 是属于主从设备模型&#xff08;Master-Slave 架构&#xff09;&#xff0c;即有 Master 节点负责集群的调度、管理和运维&#xff0c;Slave 节点是集群中的运算工作负载节点。 主节点一般被称为 Master 节点&#xff0c;master节点上有 apis…

草图绘制技巧

1、点击菜单栏文件–》新建–》左下角高级新手切换–》零件&#xff1b; 2、槽口&#xff1a;直槽口&#xff0c;中心点槽口&#xff0c;三点源槽口&#xff0c;中心点圆弧槽口&#xff1b; 3、草图的约束&#xff1a;需要按住ctrl键&#xff0c;选中两个草图&#xff0c;然后…

一款基于若依的wms系统

Wms-Ruoyi-仓库库存管理 若依wms是一套基于若依的wms仓库管理系统&#xff0c;支持lodop和网页打印入库单、出库单。毫无保留给个人及企业免费使用。 前端采用Vue、Element UI。后端采用Spring Boot、Spring Security、Redis & Jwt。权限认证使用Jwt&#xff0c;支持多终…

AWS transit gateway 的作用

说白了,就是根据需要,来起到桥梁的作用,内部沟通,或者面向internet. 先看一下diagram 图: 最中间的就是transit gateway, 要达到不同vpc 直接通讯的目的: The following is an example of a default transit gateway route table for the attachments shown in the previ…

把 CSV 文件摄入到 Elasticsearch 中 - CSVES

在我们之前的很多文章里&#xff0c;我有讲到这个话题。在今天的文章中&#xff0c;我们就提重谈。我们使用一种新的方法来实现。这是一个基于 golang 的开源项目。项目的源码在 https://github.com/githubesson/csves/。由于这个原始的代码并不支持 basic security 及带有安全…

[操作系统] 基础 IO:理解“文件”与 C 接口

在 Linux 操作系统中&#xff0c;“一切皆文件”这一哲学思想贯穿始终。从基础 IO 学习角度来看&#xff0c;理解“文件”不仅仅意味着了解磁盘上存储的数据&#xff0c;还包括对内核如何管理各种资源的认识。本文将从狭义与广义两个层面对“文件”进行解读&#xff0c;归纳文件…

国产编辑器EverEdit - 二进制模式下观察Window/Linux/MacOs换行符差异

1 换行符格式 1.1 应用场景 稍微了解计算机历史的人都知道&#xff0c; 计算机3大操作系统&#xff1a; Windows、Linux/Unix、MacOS&#xff0c;这3大系统对文本换行的定义各不相同&#xff0c;且互不相让&#xff0c;导致在文件的兼容性方面存在一些问题&#xff0c;比如它们…

设计模式Python版 命令模式(下)

文章目录 前言一、命令队列的实现二、撤销操作的实现三、请求日志四、宏命令 前言 GOF设计模式分三大类&#xff1a; 创建型模式&#xff1a;关注对象的创建过程&#xff0c;包括单例模式、简单工厂模式、工厂方法模式、抽象工厂模式、原型模式和建造者模式。结构型模式&…

Linux:进程概念详解

​ 进程概念详解 一、进程的基本概念 进程在书本上的定义是&#xff1a;计算机中正在运行的程序实例。仅此描述可能让很多人感到困惑。 我们磁盘上存储着.exe文件&#xff0c;启动文件时&#xff0c;文件会从磁盘加载到内存&#xff0c;由CPU对文件的数据和代码进行运算。但…

04性能监控与调优篇(D1_学习前言)

目录 一、引言 二、基本介绍 三、JVM基础 1. java堆 2. 垃圾回收 3. STW 四、调优层次 五、调优指标 六、JVM调优原则 1. 优先原则 2. 堆设置 3. 垃圾回收器设置 1> GC 发展阶段 2> G1的适用场景 3> 其他收集器适⽤场景 4. 年轻代设置 5. 年⽼代设置 …

系统思考—慢就是快

“所有成长&#xff0c;都是一个缓慢渗透的过程&#xff0c;回头看&#xff0c;才发现自己已经走了很远。” —— 余秋雨 这让我想起一个最近做的项目。和一家公司合作&#xff0c;他们的管理模式一直陷入困境&#xff0c;员工积极性低&#xff0c;领导层的决策效率也不高。刚…

String常量池(2)

大家好&#xff0c;今天我们继续学习String常量池&#xff0c;昨天我们已经做了一个介绍&#xff0c;相信大家✓String常量池有了一定了解&#xff0c;那么就来看看它的应用。 字符串常量地(String Table). 字常量她在IVM中是StringTable类,实际是一个固定大小的 HashTable(一…

LabVIEW显微镜成像偏差校准

在高精度显微镜成像中&#xff0c;用户常常需要通过点击图像的不同位置&#xff0c;让电机驱动探针移动到指定点进行观察。然而&#xff0c;在实际操作中&#xff0c;经常会遇到一个问题&#xff1a;当点击位于图像中心附近的点时&#xff0c;探针能够相对准确地定位&#xff1…

Typora“使用”教程

文章目录 零、Typora简介一、下载并安装Typora二、修改License文件三、每次启动第一个Typora时&#xff0c;总弹出Activate窗口四、去除软件左下角未Activate提示五、参考文章 零、Typora简介 Typora 是一款由 Abner Lee 开发的轻量级 Markdown 编辑器&#xff0c;与其他 Mark…

【scikit-multiflow】使用 scikit-multiflow 的流数据生成器生成概念漂移数据流

说在前面 scikit-multiflow 是一个专注于多流学习&#xff08;multi-stream learning&#xff09;的Python库&#xff0c;它为数据流挖掘和在线学习提供了丰富的工具集。这个库的设计灵感来源于著名的scikit-learn&#xff0c;旨在为研究人员和从业者提供一个易于使用且功能强…

计算机视觉-局部特征

一、局部特征 1.1全景拼接 先用RANSAC估计出变换&#xff0c;就可以拼接两张图片 ①提取特征 ②匹配特征 ③拼接图像 1.2 点的特征 怎么找到对应点&#xff1f;&#xff08;才能做点对应关系RANSAC&#xff09; &#xff1a;特征检测 我们希望找到的点具有的特征有什么特…