联邦学习FedAvg-基于去中心化数据的深度网络高效通信学习

        随着计算机算力的提升,机器学习作为海量数据的分析处理技术,已经广泛服务于人类社会。 然而,机器学习技术的发展过程中面临两大挑战:一是数据安全难以得到保障,隐私泄露问题亟待解决;二是网络安全隔离和行业隐私,不同行业部门之间存在数据壁垒,导致数据形成“孤岛”无法安全共享,而仅凭各部门独立数据训练的机器学习模型性能无法达到全局最优化。为解决上述问题,谷歌提出了联邦学习(FL,federated learning)技术。

        本文主要对联邦学习的开山之作《Communication-Efficient Learning of Deep Networks from Decentralized Data》 进行重点内容的解读与整理总结。

论文链接:Communication-Efficient Learning of Deep Networks from Decentralized Data

源码实现:https://gitcode.net/mirrors/WHDY/fedavg?utm_source=csdn_github_accelerator 

目录

摘要

1. 介绍

1.1 问题来源

1.2 本文贡献

1.3 联邦学习特性

1.4 联邦优化

1.5 相关工作

1.6 联邦学习框架图

2. 算法介绍

2.1 联邦随机梯度下降(FedSGD)

2.2 联邦平均算法(FedAvg)

3. 实验设计与实现

3.1 模型初始化

3.2 数据集的设置

3.2.1 MNIST数据集

3.2.2 莎士比亚作品集

3.3 实验优化

3.3.1 增加并行性

3.3.2 增加客户端计算量

 3.4 探究客户端数据集的过度优化

3.5 CIFAR实验

3.6 大规模LSTM实验

4. 总结展望

 摘要

现代移动设备拥有大量的适合模型学习的数据,基于这些数据训练得到的模型可以极大地提升用户体验。例如,语言模型能提升语音设别的准确率和文本输入的效率,图像模型能自动筛选好的照片。然而,移动设备拥有的丰富的数据经常具有关于用户的敏感的隐私信息且多个移动设备所存储的数据总量很大,这样一来,不适合将各个移动设备的数据上传到数据中心,然后使用传统的方法进行模型训练。作者提出了一个替代方法,这种方法可以基于分布在各个设备上的数据(无需上传到数据中心),然后通过局部计算的更新值进行聚合来学习到一个共享模型。作者定义这种非中心化方法为“联邦学习”。作者针对深度网络的联邦学习任务提出了一种实用方法,这种方法在学习过程中多次对模型进行平均。同时,作者使用了五种不同的模型和四个数据集对这种方法进行了实验验证。实验结果表明,这种方法面对不平衡以及非独立同分布的数据,具有较好的鲁棒性。在这种方法中,通信所产生的资源开销是主要的瓶颈,实验结果表明,与同步随机梯度下降相比,该方法的通信轮次减少了10-100倍。

1 介绍

1.1 问题来源

        移动设备中有大量数据适合机器学习任务,利用这些数据反过来可以改善用户体验。例如图像识别模型可以帮助用户挑选好的照片。但是这些数据具有高度私密性,并且数据量大,所以我们不可能把这些数据拿到云端服务器进行集中训练。论文提出了一种分布式机器学习方法称为联邦学习(Federal Learning),在该框架中,服务器将全局模型下发给客户,客户端利用本地数据集进行训练,并将训练后的权重上传到服务器,从而实现全局模型的更新。

1.2 本文贡献

  • 提出了从分散的存储于各个移动设备的数据中训练模型是一个重要的研究方向
  • 提出了一个简单实用的算法来解决这种在非中心化设置下的学习问题
  • 做了大量实验来评估所提算法

        具体来说,本文介绍了“联邦平均”算法,这种算法融合了客户端上的局部随机梯度下降计算与服务器上的模型平均。作者使用该算法进行了大量实验,结果表明了这种算法对于不平衡且非独立同分布的数据具有很好的鲁棒性,并且使得在非中心存储的数据上进行深度网络训练所需的通信轮次减少了几个数量级。

1.3 联邦学习特性

  • 从多个移动设备中存储的真实数据中进行模型训练比从存储在数据中心的数据中进行模型训练更具优势
  • 由于数据具有隐私,且多个移动设备所存储的数据总量很大,因此不适合将其上传至数据中心再进行模型训练
  • 对于监督学习任务,数据中的标签信息可以从用户与应用程序的交互中推断出来

1.4 联邦优化

        传统分布式学习关注点在于如何将一个大型神经网络训练分布式进行,数据仍然可能是在几个大的训练中心存储。而联邦学习更关注数据本身,利用联邦学习保证了数据不出本地,并根据数据的特点,对学习模型进行改进。相比于典型的分布式优化问题,联邦优化具有几个关键特性:

  • Non-IID:数据的特征和分布在不同参与方间存在差异
  • Unbalanced:一些用户会更多地使用服务或应用程序,导致本地训练数据量存在差
  • Massively distributed:参与优化的用户数>>平均每个用户的数据量
  • Limited communication:无法保证客户端和服务器端的高效通信

 本文重点关注优化任务中非独立同分布和不平衡问题,以及通信受限的临界属性。

注:独立同分布假设(IID)

        非凸神经网络的目标函数:

对于一个机器学习的问题来说,有,即用模型参数w预测实例的损失。

        设有K个client,第k个client的数据点为P_{k},对应的数据集数量为n_{k}=\left | P_{k} \right |上式可写为:

P_{k}上的数据集是随机均匀采样的,称IID设置,此时有:

不成立则称Non-IID。 

1.5 相关工作

        相关工作中,2010年通过迭代平均本地训练的模型来对感知机进行分布式训练,2015年研究了语音识别深度神经网络的分布式训练,在2015论文里研究了使用“软”平均的异步训练方法。这些工作都考虑的是数据中心化背景下的分布式训练,没有考虑具有数据不平衡且非独立同分布特点的联邦学习任务。但是它们提供了一种思路,即通过迭代平均本地训练模型的算法来解决联邦学习的问题。与本文的研究动机相似在这篇论文中讨论了保护设备中的用户数据的隐私的优点。而在这篇论文中,作者关注于训练深度网络,强调隐私的重要性以及通过在每一轮通信中仅共享一部分参数,进而降低通信开销;但是,他们也没有考虑数据的不平衡以及非独立同分布性,并且他们的研究工作缺乏实验评估。

1.6 联邦学习框架图

2 算法介绍

2.1 联邦随机梯度下降(FedSGD)

设置固定的学习率η,对K个客户端的数据计算其损失梯度:

中心服务器聚合每个客户端计算的梯度,以此来更新模型参数:

其中,

2.2 联邦平均算法(FedAvg)

在客户端进行局部模型的更新:

中心服务器对每个客户端更新后的参数进行加权平均:

每个客户端可以独立地更新模型参数多次,然后再将更新好的参数发送给中心服务器进行加权平均:

FedAvg的计算量与三个参数有关:

  • C:每轮训练选择客户端的比例
  • E:每个客户端更新参数的循环次数所设计的一个因子
  • B:客户端更新参数时,每次梯度下降所使用的数据量

对于一个拥有n_{k}个数据样本的客户端,每轮本地参数更新的次数为:

注:FedSGD只是FedAvg的一个特例,即当参数E=1,B=∞时,FedAvg等价于FedSGD。
 
FedSGD和FedAvg的关系示意图:
地址:https://blog.csdn.net/biongbiongdou/article/details/104358321

3 实验设计与实现

3.1 模型初始化

实验设置
  • 数据集:MNIST中600个无重复的独立同分布样本
  • E=20; C=1; B=50; 中心服务器聚合一次
  • 不同模型使用不同/相同的初始化模型,并通过θ对两模型参数进行加权求和
       

研究模型平均对模型效果的影响:

        这里有两种情况,一种是不同模型使用不同的初始化模型;一种是不同模型使用相同的初始化模型。并且可以通过参数控制权重比进行模型的加权求和。

        可看到,采用不同的初始化参数进行模型平均后,平均模型的效果变差,模型性能比两个父模型都差;采用相同的初始化参数进行模型平均后,对模型的平均可以显著的减少整个训练集的损失,模型性能优于两个父模型。

        该结论是用于实现联邦学习的重要支撑,在每一轮训练时,server发布全局模型,使各个client采用相同的参数模型进行训练,可以有效的减少训练集的损失。

3.2 数据集的设置

        初步研究包括两个数据集三个模型族,前两个模型用于识别MNIST数据集,后一个用于实现莎士比亚作品集单词预测。

3.2.1 MNIST数据集

2NN:拥有两个隐藏层,每层200个神经元的多层感知机模型,ReLu激活;

CNN:两个卷积核大小为5X5的卷积层(分别是32通道和64通道,每层后都有一个2X2的最大池化层);

IDD:数据随机打乱分给100个客户端,每个客户端600个样例;

Non-IDD:按数字标签将数据集划分为200个大小为300的碎片,每个客户端两个碎片;

  • 3.2.2 莎士比亚作品集

LSTM:将输入字符嵌入到一个已学习的8维空间中,然后通过两个LSTM层处理嵌入的字符,每层256个节点,最后,第二个LSTM层的输出被发送到每一个字符有一个节点的softmax输出层,使用unroll的80个字符长度进行训练;

Unbalanced-Non-IID:每个角色形成一个客户端,共1146个客户端;

Balanced-IID:直接将数据集划分给1146个客户端;

3.3 实验优化

        在数据中心存储的优化中,通信开销相对较小,计算开销占主导地位。而在联邦优化中,任何一个单一设备所具有的数据量较少,且现代移动设备有相对快的处理器所以这里更关注通信开销因此,我们想要使用额外的计算来减少训练模型所需通信的轮次主要有两个方法,分别是提高并行度以及增加每个客户端的计算量。

3.3.1 增加并行性

固定参数E,对C和B进行讨论。

  •  当B=∞时,增加客户端比例,效果提升的优势较小;
  • 当B=10时,有显著改善,特别是在Non-IID情况下;
  • 在B=10,当C≥0.1时,收敛速度有明显改进,当用户达到一定数量时,收敛增加的速度不再明显。

3.3.2 增加客户端计算量

对于增加每个客户端的计算量,可以通过减小B或者增加E来实现。

  • 每轮增加更多的本地SGD更新可以显著降低通信成本;
  • 对于Unbalanced-Non-IDD的莎士比亚数据减少通信轮数倍数更多,推测可能某些客户端有相对较大的本地数据集,使得增加本地训练更有价值;

 将上述实验结果用折线图的形式展示,这里蓝色线表示的是联邦随机梯度下降的结果:

  • FedAvg相比FedSGD不仅降低通信轮数,还具有更高的测试精度。推测是平均模型产生了类似Dropout的正则化效益; 

 3.4 探究客户端数据集的过度优化

        在E=5以及E=25的设置下,对于大的本地更新次数而言,联邦平均的训练损失会停滞或发散;因此在实际应用时,对于一些模型,在训练后期减少本地训练周期将有助于收敛。 

3.5 CIFAR实验

在CTFAR数据集上进行实验,模型是TensorFlow教程中的模型包括两个卷积层,两个全连接层和一个线性传输层,大约10^6个参数。下表给出了baselineSGD、FedSGD和FedAvg达到三种不同精度目标的通信轮数。

不同学习率下FedSGD和FedAvg的曲线:

3.6 大规模LSTM实验

 为了证明我们的方法对于解决实际问题的有效性,我们进行了一项大规模单词预测任务。

训练集包含来自大型社交网络的100万个公共帖子。我们根据作者对帖子进行分组,总共有超过50个客户端。我们将每个客户的数据集限制为最多5000个单词。模型是一个256节点的LSTM,其词汇量为10000个单词。每个单词的输入和输出嵌入为192维,并与模型共同训练;总共有4950544个参数,使用10个字符的unroll。

对于联邦平均和联邦随机梯度下降的最佳学习率曲线:

  • 相同准确率的情况下,FedAvg的通信轮数更少;测试精度方差更小;
  • E=1比E=5的表现效果更好; 

4 总结展望

         我们的实验表明,联邦学习可以在实践中实现,因为它可以使用相对较少的几轮通信来训练高质量的模型,这一点在各种模型体系结构上得到了证明:一个多层感知器、两个不同的卷积NNs、一个两层LSTM和一个大规模LSTM。虽然联邦学习提供了许多实用的隐私保护,但是通过差分隐私、安全多方计算提供了可以提供更有力的保障,或者他们的组合是未来工作的一个有趣方向。请注意,这两类技术最自然地应用于像FedAvg这样的同步算法。

参考文章:

https://blog.csdn.net/qq_41605740/article/details/124584939?spm=1001.2014.3001.5506

https://blog.csdn.net/weixin_45662974/article/details/119464191?spm=1001.2014.3001.5506 

https://zhuanlan.zhihu.com/p/515756280 

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.rhkb.cn/news/115598.html

如若内容造成侵权/违法违规/事实不符,请联系长河编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

使用飞桨实现的第一个AI项目——波士顿的房价预测

part1.首先引入相应的函数库: 值得说明的地方: (1)首先,numpy是一个python库,主要用于提供线性代数中的矩阵或者多维数组的运算函数,利用import numpy as np引入numpy,并将np作为它的别名 part…

linux字符串处理

目录 1. C 截取字符串,截取两个子串中间的字符串linux串口AT指令 2. 获取该字符串后面的字符串用 strstr() 函数查找需要提取的特定字符串,然后通过指针运算获取该字符串后面的字符串用 strtok() 函数分割字符串,找到需要提取的特定字符串后,…

如何在小程序中给会员设置备注

给会员设置备注是一项非常有用的功能,它可以帮助商家更好地管理和了解自己的会员。下面是一个简单的教程,告诉商家如何在小程序中给会员设置备注。 1. 找到指定的会员卡。在管理员后台->会员管理处,找到需要设置备注的会员卡。也支持对会…

宠物赛道,用AI定制宠物头像搞钱项目教程

今天给大家介绍一个非常有趣,而粉丝价值又极高,用AI去定制宠物头像或合照的AI项目。 接触过宠物行业应该知道,获取1位铲屎官到私域,这类用户的价值是极高的,一个宠物粉,是连铲个屎都要花钱的,每…

USRP 简介,对于NI软件无线电你所需要了解的一切

什么是 USRP 通用软件无线电外设( USRP ) 是由 Ettus Research 及其母公司National Instruments设计和销售的一系列软件定义无线电。USRP 产品系列由Matt Ettus领导的团队开发,被研究实验室、大学和业余爱好者广泛使用。 大多数 USRP 通过以太网线连接到主机&…

本地部署 Stable Diffusion(Mac 系统)

在 Mac 系统本地部署 Stable Diffusion 与在 Windows 系统下本地部署的方法本质上是差不多的。 一、安装 Homebrew Homebrew 是一个流行的 macOS (或 Linux)软件包管理器,用于自动下载、编译和安装各种命令行工具和应用程序。有关说明请访问官…

【分享】PDF如何拆分成2个或多个文件呢?

当我们需要把一个多页的PDF文件拆分成2个或多个独立的PDF文件,可以怎么操作呢?这种情况需要使用相关工具,下面小编就来分享两个常用的工具。 1. PDF编辑器 PDF编辑器不仅可以用来编辑PDF文件,还具备多种功能,拆分PDF文…

GPT-4.0技术大比拼:New Bing与ChatGPT,哪个更适合你

随着GPT-4.0技术的普及和发展,越来越多的平台开始将其应用于各种场景。New Bing已经成功接入GPT-4.0,并将其融入搜索和问答等功能。同样,在ChatGPT官网上,用户只需开通Plus账号,即可体验到GPT-4.0带来的智能交流和信息…

使用flink sqlserver cdc 同步数据到StarRocks

前沿: flink cdc功能越发强大,支持的数据源也越多,本篇介绍使用flink cdc实现: sqlserver-》(using flink cdc)-〉flink -》(using flink starrocks connector)-〉starrocks整个流程…

小游戏分发平台如何以技术拓流?

2023年,小游戏的发展将受到多方面的影响,例如新技术的引入、参与小游戏的新玩家以及游戏市场的激烈竞争等。首先,新技术如虚拟现实(VR)、增强现实(AR)和机器人技术都可以带来新颖的游戏体验。其…

嘉泰实业和您共创未来财富生活

每一次暖心的沟通都是一次公益,真诚不会因为它的渺小而被忽略;每一声问候都是一次公益,善意不会因为它的普通而被埋没。熟悉嘉泰实业的人都知道,这家企业不但擅长在金融理财领域里面呼风唤雨,同时也非常擅长在公益事业当中践行,属于企业的责任心,为更多有困难的群体带来大爱的传…

大数据课程K13——Spark的距离度量相似度度量

文章作者邮箱:yugongshiye@sina.cn 地址:广东惠州 ▲ 本章节目的 ⚪ 掌握Spark的距离度量和相似度度量; ⚪ 掌握Spark的欧氏距离; ⚪ 掌握Spark的曼哈顿距离; ⚪ 掌握Spark的切比雪夫距离; ⚪ 掌握Spark的最小二乘法; 一、距离度量和相似度度量 1. …

打磨 8 个月、功能全面升级,Milvus 2.3.0 文字发布会现在开始!

Milvus 社区的各位伙伴: 大家晚上好!欢迎来到 Milvus 2.3.0 文字发布会! 作为整个团队的匠心之作,Milvus 2.3.0 历经 8 个月的设计与打磨,无论在新功能、应用场景还是可靠度方面都有不小的提升。 具体来看:…

UG\NX CAM二次开发 插入工序 UF_OPER_create

文章作者:代工 来源网站:NX CAM二次开发专栏 简介: UG\NX CAM二次开发 插入工序 UF_OPER_create 效果: 代码: void MyClass::do_it() {tag_t setup_tag=NULL_TAG;UF_SETUP_ask_setup(&setup_tag);if (setup_tag==NULL_TAG){uc1601("请先初始化加工环境…

【Ubuntu】解决ubuntu虚拟机和物理机之间复制粘贴问题(无需桌面工具)

解决Ubuntu虚拟机和物理机之间复制粘贴问题 第一步 先删除原来的vmware tools(如果有的话) sudo apt-get autoremove open-vm-tools第二步 安装软件包,一般都是用的desktop版本(如果是server换一下) sudo apt-get …

开源vue动态表单组件

一、项目简介 vueelement的动态表单组件,拖拽组件到面板即可实现一个表单 二、实现功能 支持拖拽 支持输入框 支持文本框 支持数字输入框 支持下拉选择器 支持多选框 支持日期控件 支持开关 支持动态表格 支持上传图片 支持上传文件 支持标签 支持ht…

简单了解网络传输介质

目录 一、同轴电缆 二、双绞线 三、光纤 四、串口电缆 一、同轴电缆 10BASE前面的数字表示传输带宽为10M,由于带宽较低、现在已不再使用。 50Ω同轴电缆主要用来传送基带数字信号,因此也被称作为基带同轴电缆,在局域网中得到了广泛的应用…

基于OFDM的水下图像传输通信系统matlab仿真

目录 1.算法运行效果图预览 2.算法运行软件版本 3.部分核心程序 4.算法理论概述 5.算法完整程序工程 1.算法运行效果图预览 2.算法运行软件版本 matlab2022a 3.部分核心程序 function [rx_img] func_TR(tx_img, num_path, pathdelays, pathgains, snr) rng(default); …

AI图像行为分析算法 opencv

AI图像行为分析算法通过pythonopencv深度学习框架对现场操作行为进行全程实时分析,AI图像行为分析算法通过人工智能视觉能够准确判断出现场人员的作业行为是否符合SOP流程规定,并对违规操作行为进行自动抓拍告警。OpenCV是一个基于Apache2.0许可&#xf…

Ubuntu 22.04安装 —— Win11 22H2

目录 Ubuntu使用下载UbuntuVmware 安装图示安装步骤图示 Ubuntu使用 系统环境: Windows 11 22H2Vmware 17 ProUbutun 22.04.3 Server Ubuntu Server documentation | Ubuntu 下载 Ubuntu 官网下载 建议安装长期支持版本 ——> 可以选择桌面版或服务器版(仅包…