题目:《Towards Unbiased Training in Federated Open-world Semi-supervised》
来源:ICML2023
注意比较与 ORCA 的区别
Abstract
联邦半监督学习(FedSSL)已经成为一种新范式,允许分布式客户端在稀缺的标记数据和丰富的未标记数据上协作训练机器学习模型。然而,现有的 FedSSL 工作依赖于一个 封闭世界 的假设,即所有的本地训练数据和全局测试数据都来自标记数据集中观察到的类别。向更进一步发展至关重要:使 FL 模型适应开放世界设置,即未标记数据中存在未见过的类别(unseen classes)
在本文中,我们提出了一个新的联邦开放世界半监督学习(FedoSSL)框架,它可以解决分布式和开放世界设置中的关键挑战,即异构分布的未见过类别的有偏训练过程。具体来说,由于某种未见过的类别的出现取决于客户端,因此在多个客户端中存在的本地未见过的类别(Locally Unseen Classes)可能比只在一个客户端中存在的全局未见过的类别(Globally Unseen Classes)获得不同的、更优越的聚合效果。我们采用 不确定性感知的抑制损失 来减轻本地未见过和全局未见过类别之间的有偏训练
类似于 ORCA 的 不确定性自适应边
不确定性自适应边界用来 减少已知和新类别的类内方差之间的差距
此外,我们通过补充一个 校准模块 来启用全球聚合,以避免由于不同客户端间不一致的数据分布而在知识转移过程中可能产生的潜在冲突。提出的 FedoSSL 可以轻松适应最新的 FL 方法,这也通过在基准测试和真实世界数据集(CIFAR-10、CIFAR-100 和 CINIC-10)上的广泛实验得到了验证
Introduction
标记数据的稀缺性和非标记数据的丰富性促成了联邦半监督学习(FedSSL)的出现,它可以同时利用标记和非标记数据来优化分布式环境中的全局模型。现有的 FedSSL 方案已经证明了基于客户端或服务器端少量标记数据训练模型的能力。然而,这些工作依赖于封闭世界假设,即 所有本地训练数据和全局测试数据都来自标记数据集中包含的相同类别集,这在实际场景中往往是无效的。相比之下,开放世界设置允许新类别的发现,因此在野外数据中很常见,例如在医学图像分类任务中,一些疾病自然稀缺,以前从未标记过,而模型可能需要既能将图像分类为预定义类型(已知类别),又能发现新的未知疾病(未知类别)。因此,出现了一个新的基本问题:如何在开放世界设置下,通过分布式数据协作训练模型,以实现对已知和未知类别的分类?
为此,我们 首先构建了一个新的 FedSSL 基准,将传统的封闭世界训练框架扩展到开放世界设置。令人惊讶的是,研究表明,由于训练过程中未标记数据中存在未知类别,导致性能显著下降。尽管一些工作探讨了未知类别问题,以避免将来自未知类别的未标记样本错误分类为已知类别,同时减轻已知和未知类别之间的不良性能差距,但尚无文献考虑分布式环境中的这一问题。随着多个参与者的加入,问题定义不同,即一个客户端中的一些未知类别可能从全局视角来看存在于其他客户端,因此需要对不同类型的样本/类别进行新的细粒度定义以及训练机制。值得一提的是,由于不同客户端间类别的异构分布,按照传统的 FL 机制简单地聚合参数可能会导致拥有不同未知类别的客户端之间的有偏训练过程
为应对上述挑战,本文提出了一个全新的联邦开放世界半监督学习框架(FedoSSL),它能够实现不同类型样本之间的无偏训练过程。具体来说,本地客户端中的未知类别首先被重新定义为局部未知类别 和 全局未知类别。然后,我们设计了一个不确定性感知的抑制损失,以适应性地控制局部和全局未知类别之间的不良训练发散,即局部未知类别通常由于跨客户端协作而比全局未知类别具有更高的训练效率。此外,考虑到不同客户端间未知类别的异构分布导致聚合阶段潜在的知识转移冲突,我们引入了一个校准模块,以产生相应的全局中心,以便在每个客户端上进行专门的本地调整。我们展示了我们的方法在开放世界设置下与广泛使用的模型(即 ResNet-18)和下游任务(即 CIFAR-10、CIFAR-100、CINIC-10)相比,与最先进的基线相比,显著提高了模型的准确性
本文的贡献总结如下:
- 据我们所知,我们是第一个考虑 FedSSL 中开放世界设置的人,其中未标记数据中存在未知类别,由于未知类别的异构分布而具有挑战性
- 我们设计了一个全新的 FedoSSL 框架,它可以在不同类型的类别(即局部未知和全局未知类别)之间实现无偏学习,并在异构数据分布下进行校准的知识聚合
- 我们在三个典型的图像分类任务上进行了广泛的实验。实证评估显示,FedoSSL 的性能优于最先进的方法
图1: 提出的FedoSSL算法框架。流程:每个客户端首先在私有数据集上执行本地训练,经过几个周期(例如,通过优化公式(1)中的损失函数),然后通过基于Sinkhorn-Knopp的聚类算法(Genevay等人,2019)计算本地中心。将模型参数和本地中心上传到服务器。服务器执行标准的模型聚合。服务器再次使用Sinkhorn-Knopp聚类来获得全局中心。将全局模型和全局中心返回给客户端,客户端使用它们进行本地训练。
Related Work
Methodology
Preliminary and Problem Definition
我们专注于开放世界设置中的联邦半监督学习,其中每个客户端的数据都是 部分标记 的,并且标记数据和未标记数据的类别不同,即未标记数据中存在未知类别(unseen classes)
假设有个客户端每个客户端持有一个包含标记部分和未标记部分的私有分类数据集,其中。整个标记和未标记数据集可以分别表示为和
我们记在全部标记数据中已知类别集为,未标记测试数据中的类别集为,与传统的(封闭世界)FedSSL 不同,后者
本文我们考虑,并记,
分别为已知类和未知类的集合
FedSSL 的目标是在多个分散的客户端中训练一个泛化的全局模型,参数为,即
其中
是总数据量
是客户端的损失函数
具体来说,模型可以分解为:
一个特征提取器,参数为,用于学习低维特征
一个分类器,参数为
每个客户端上的半监督学习算法的训练损失通常包括:
半监督损失
无监督损失,权重参数,
通常,对标记实例应用 标准交叉熵损失
其中,表示模型对输入产生的预测概率
是交叉熵函数
对于无监督损失,有两种典型形式:
基于标记数据的伪标签、基于数据增强的一致性正则化
然而在开放世界设置中,由于存在未知类别,上述方法无法对已知和未知类别进行分类。因此类似于 ORCA(Cao 等人,2022)和 NACH(Guo 等人,2022),我们使用 成对目标 作为未标记数据的无监督损失对未知类别进行分类
其中,和分别是标记和未标记数据的全部特征集,是在小批量中的 最近集,通过计算所有特征表示对的余弦距离来计算
Cao, K., Brbic, M., and Leskovec, J. Open-world semi-supervised learning. In International Conference on Learning Representations, 2022. URL https://openreview.net/forum?id=O-r8LOR-CCA.
Guo, L.-Z., Zhang, Y.-G., Wu, Z.-F., Shao, J.-J., and Li,Y.-F. Robust semi-supervised learning when not all classes have labels. In Oh, A. H., Agarwal, A., Bel-grave, D., and Cho, K. (eds.), Advances in Neural In-formation Processing Systems, 2022. URL https://openreview.net/forum?id=lDohSFOHr0.
Overview of FedoSSL
先前的 FedSSL 方法没有考虑到每个客户端上存在未知类别,这导致许多来自位置类别的数据被错误地分类为已知类别。此外不同客户端上数据分布的不一致性带来一个新问题:一些未知类别可能存在于多个客户端中,导致不同未知类别之间的训练偏差
例如在图 1 中,客户端 1、2 和 K 都有类别 4,而类别 5、6、7、8 和 9 只存在于其中一个客户端中。在这种情况下,我们需要在全局视角中进行更细粒度(fine-grained)的定义
定义1(局部未知类和全局未知类)。在FedoSSL中,客户端上的未知类别可以分为两种类型:
局部未知类,其中
全局未知类,其中
当 FedoSSL 中存在多个客户端的未知类别时,需要考虑两个挑战:
首先,由于客户端之间的协作,局部未知类别可能比全局未知类别学习得更快。现有的无监督成对损失对每个类别都同等对待,而未知类别之间的不平衡训练进度将导致伪标签生成上的大偏差,甚至在已知类别上的性能下降。因此我们提出了一个 不确定性感知的正则化损失 来减轻不同类别之间的训练偏差
此外由于在同一模型分类器上需要输入标记数据和未标记数据,因此在不同客户端上生成的聚类/类别 ID 是异构的。设计一个 校准策略,在模型聚合阶段对同一未知类别的输出进行对齐至关重要
OBJECTIVE
为了在FedSSL中实现不同未知类别之间的无偏训练,我们提出了FedoSSL方法,整体目标包括三部分:
1)所有数据的基本半监督损失
2)一个不确定性感知的正则化损失
3)一个校准损失,以实现有效的模型聚合
其中,和是权衡超参数(trade-off hyper-parameters)
UNCERTAINTY-AWARE LOSS
考虑到不同类型的未知类别具有不同的训练进度,即局部未知类别可以从模型协作中受益,我们寻求添加一个正则项来减轻 局部未知类别 和 全局未知类别 之间的训练偏差
具体来说,我们使用生成的簇/类 id 的不确定性来反映训练进度,并对具有高不确定性的样本施加更大的惩罚。然后不确定性感知损失可以定义为:
其中是数据不确定性函数。为了准确地探索不确定性,我们依赖于由 softmax 函数输出计算的伪标签置信度 和 属于相关伪标签的样本比例,即:
其中是类别的权重,可以从标记数据中估计,
是模型预测类别的训练样本数量
是具有最大样本量的类别的样本数量
注意可以是任何与成反比的函数
在本文中,我们更关注 局部未知类别 和 全局未知类别之间的无偏训练,而忽略了已知类别和未知类别之间的训练不一致性,因为 ORCA 和 NACH 已经解决了这个问题
CALIBRATION MODULE
由于未标记数据上的成对目标函数,即,仅确保相似样本被分类到同一个组/簇,因此同一个未知类别在不同客户端上可能被分到不同的簇 id(不同的类标签)
例如,在图 2 中,由于缺乏标记数据的监督,未知类别 4、5 和 6 可能被分类为 4 到 9 之间的任何标签。这种标签的异质性会显著降低聚合性能。因此需要设计一个校准模块,在模型聚合阶段之前,对不同客户端上局部分类器的异构输出进行对齐
受到基于聚类的 FL 技术的启发(Lubana等人,2022),该技术旨在通过添加全局质心聚合机制来对齐不同客户端的局部聚类性能,我们将这种技术扩展到我们的 FedoSSL 场景中,并使用全局质心作为自监督信号来指导未知类别的分类过程。校准模块的损失可以表示为:
这里的是与全局质心相关的交叉熵损失(Cross entropy)
具体来说,我们首先在服务器上从所有客户端聚合局部中心,以获得全局聚类,即通过使用Sinkhorn-Knopp(Genevay等人,2019)聚类算法。然后,将全局中心返回给客户端进行进一步的校准。以全局中心为指导,分类器在未知类别上的输出可以通过接近聚类分配的交叉熵损失进行更新:
其中表示全局质心。此外,为了防止一个类别的聚类分配在训练过程中发生剧烈变化,我们还设计了以下损失函数,以促进特征表示的聚类能力:
这里和是相应的聚类分配,通过将表示与全局质心匹配来计算
Algorithm Workflow
客户端更新
在每一轮通信中,客户端从服务器下载全局模型和全局质心。在客户端更新阶段,每个客户端进行几次基于局部梯度的更新(例如,E 个周期),以优化公式 4 中的局部目标。然后,通过基于 Sinkhorn-Knopp 的聚类算法计算局部质心
服务器聚合
在客户端完成局部模型的训练后,更新后的模型和局部质心将被发送到服务器进行进一步的聚合。具体来说,服务器首先通过对它们进行加权平均来聚合局部模型。然后通过聚合局部质心(即再次使用 Sinkhorn-Knopp 聚类算法)来计算全局质心。上述步骤将重复执行,直到模型收敛