paddle2.3-基于联邦学习实现FedAVg算法

目录

1. 联邦学习介绍

2. 实验流程

3. 数据加载

4. 模型构建

5. 数据采样函数

6. 模型训练


1. 联邦学习介绍

联邦学习是一种分布式机器学习方法,中心节点为server(服务器),各分支节点为本地的client(设备)。联邦学习的模式是在各分支节点分别利用本地数据训练模型,再将训练好的模型汇合到中心节点,获得一个更好的全局模型。

联邦学习的提出是为了充分利用用户的数据特征训练效果更佳的模型,同时,为了保证隐私,联邦学习在训练过程中,server和clients之间通信的是模型的参数(或梯度、参数更新量),本地的数据不会上传到服务器。

本项目主要是升级1.8版本的联邦学习fedavg算法至2.3版本,内容取材于基于PaddlePaddle实现联邦学习算法FedAvg - 飞桨AI Studio星河社区

2. 实验流程

联邦学习的基本流程是:

1. server初始化模型参数,所有的clients将这个初始模型下载到本地;

2. clients利用本地产生的数据进行SGD训练;

3. 选取K个clients将训练得到的模型参数上传到server;

4. server对得到的模型参数整合,所有的clients下载新的模型。

5. 重复执行2-5,直至收敛或达到预期要求

import os
import numpy as np
import matplotlib
import matplotlib.pyplot as plt
import random
import time
import paddle
import paddle.nn as nn
import numpy as np
from paddle.io import Dataset,DataLoader
import paddle.nn.functional as F

3. 数据加载

mnist_data_train=np.load('data/data2489/train_mnist.npy')
mnist_data_test=np.load('data/data2489/test_mnist.npy')
print('There are {} images for training'.format(len(mnist_data_train)))
print('There are {} images for testing'.format(len(mnist_data_test)))
# 数据和标签分离(便于后续处理)
Label=[int(i[0]) for i in mnist_data_train]
Data=[i[1:] for i in mnist_data_train]
There are 60000 images for training
There are 10000 images for testing

4. 模型构建

class CNN(nn.Layer):def __init__(self):super(CNN,self).__init__()self.conv1=nn.Conv2D(1,32,5)self.relu = nn.ReLU()self.pool1=nn.MaxPool2D(kernel_size=2,stride=2)self.conv2=nn.Conv2D(32,64,5)self.pool2=nn.MaxPool2D(kernel_size=2,stride=2)self.fc1=nn.Linear(1024,512)self.fc2=nn.Linear(512,10)# self.softmax = nn.Softmax()def forward(self,inputs):x = self.conv1(inputs)x = self.relu(x)x = self.pool1(x)x = self.conv2(x)x = self.relu(x)x = self.pool2(x)x=paddle.reshape(x,[-1,1024])x = self.relu(self.fc1(x))y = self.fc2(x)return y

5. 数据采样函数

# 均匀采样,分配到各个client的数据集都是IID且数量相等的
def IID(dataset, clients):num_items_per_client = int(len(dataset)/clients)client_dict = {}image_idxs = [i for i in range(len(dataset))]for i in range(clients):client_dict[i] = set(np.random.choice(image_idxs, num_items_per_client, replace=False)) # 为每个client随机选取数据image_idxs = list(set(image_idxs) - client_dict[i]) # 将已经选取过的数据去除client_dict[i] = list(client_dict[i])return client_dict
# 非均匀采样,同时各个client上的数据分布和数量都不同
def NonIID(dataset, clients, total_shards, shards_size, num_shards_per_client):shard_idxs = [i for i in range(total_shards)]client_dict = {i: np.array([], dtype='int64') for i in range(clients)}idxs = np.arange(len(dataset))data_labels = Labellabel_idxs = np.vstack((idxs, data_labels)) # 将标签和数据ID堆叠label_idxs = label_idxs[:, label_idxs[1,:].argsort()]idxs = label_idxs[0,:]for i in range(clients):rand_set = set(np.random.choice(shard_idxs, num_shards_per_client, replace=False)) shard_idxs = list(set(shard_idxs) - rand_set)for rand in rand_set:client_dict[i] = np.concatenate((client_dict[i], idxs[rand*shards_size:(rand+1)*shards_size]), axis=0) # 拼接return client_dict

class MNISTDataset(Dataset):def __init__(self, data,label):self.data = dataself.label = labeldef __getitem__(self, idx):image=np.array(self.data[idx]).astype('float32')image=np.reshape(image,[1,28,28])label=np.array(self.label[idx]).astype('int64')return image, labeldef __len__(self):return len(self.label)

6. 模型训练

class ClientUpdate(object):def __init__(self, data, label, batch_size, learning_rate, epochs):dataset = MNISTDataset(data,label)self.train_loader = DataLoader(dataset,batch_size=batch_size,shuffle=True,drop_last=True)self.learning_rate = learning_rateself.epochs = epochsdef train(self, model):optimizer=paddle.optimizer.SGD(learning_rate=self.learning_rate,parameters=model.parameters())criterion = nn.CrossEntropyLoss(reduction='mean')model.train()e_loss = []for epoch in range(1,self.epochs+1):train_loss = []for image,label in self.train_loader:# image=paddle.to_tensor(image)# label=paddle.to_tensor(label.reshape([label.shape[0],1]))output=model(image)loss= criterion(output,label)# print(loss)loss.backward()optimizer.step()optimizer.clear_grad()train_loss.append(loss.numpy()[0])t_loss=sum(train_loss)/len(train_loss)e_loss.append(t_loss)total_loss=sum(e_loss)/len(e_loss)return model.state_dict(), total_loss

train_x = np.array(Data)
train_y = np.array(Label)
BATCH_SIZE = 32
# 通信轮数
rounds = 100
# client比例
C = 0.1
# clients数量
K = 100
# 每次通信在本地训练的epoch
E = 5
# batch size
batch_size = 10
# 学习率
lr=0.001
# 数据切分
iid_dict = IID(mnist_data_train, 100)
def training(model, rounds, batch_size, lr, ds,L, data_dict, C, K, E, plt_title, plt_color):global_weights = model.state_dict()train_loss = []start = time.time()# clients与server之间通信for curr_round in range(1, rounds+1):w, local_loss = [], []m = max(int(C*K), 1) # 随机选取参与更新的clientsS_t = np.random.choice(range(K), m, replace=False)for k in S_t:# print(data_dict[k])sub_data = ds[data_dict[k]]sub_y = L[data_dict[k]]local_update = ClientUpdate(sub_data,sub_y, batch_size=batch_size, learning_rate=lr, epochs=E)weights, loss = local_update.train(model)w.append(weights)local_loss.append(loss)# 更新global weightsweights_avg = w[0]for k in weights_avg.keys():for i in range(1, len(w)):# weights_avg[k] += (num[i]/sum(num))*w[i][k]weights_avg[k]=weights_avg[k]+w[i][k]   weights_avg[k]=weights_avg[k]/len(w)global_weights[k].set_value(weights_avg[k])# global_weights = weights_avg# print(global_weights)#模型加载最新的参数model.load_dict(global_weights)loss_avg = sum(local_loss) / len(local_loss)if curr_round % 10 == 0:print('Round: {}... \tAverage Loss: {}'.format(curr_round, np.round(loss_avg, 5)))train_loss.append(loss_avg)end = time.time()fig, ax = plt.subplots()x_axis = np.arange(1, rounds+1)y_axis = np.array(train_loss)ax.plot(x_axis, y_axis, 'tab:'+plt_color)ax.set(xlabel='Number of Rounds', ylabel='Train Loss',title=plt_title)ax.grid()fig.savefig(plt_title+'.jpg', format='jpg')print("Training Done!")print("Total time taken to Train: {}".format(end-start))return model.state_dict()#导入模型
mnist_cnn = CNN()
mnist_cnn_iid_trained = training(mnist_cnn, rounds, batch_size, lr, train_x,train_y, iid_dict, C, K, E, "MNIST CNN on IID Dataset", "orange")

W0605 23:22:00.961916 10307 gpu_context.cc:278] Please NOTE: device: 0, GPU Compute Capability: 7.0, Driver API Version: 11.2, Runtime API Version: 10.1
W0605 23:22:00.966121 10307 gpu_context.cc:306] device: 0, cuDNN Version: 7.6.Round: 10... 	Average Loss: 0.033
Round: 20... 	Average Loss: 0.011
Round: 30... 	Average Loss: 0.012
Round: 40... 	Average Loss: 0.008
Round: 50... 	Average Loss: 0.003
Round: 60... 	Average Loss: 0.002
Round: 70... 	Average Loss: 0.001
Round: 80... 	Average Loss: 0.001
Round: 90... 	Average Loss: 0.001

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.rhkb.cn/news/142217.html

如若内容造成侵权/违法违规/事实不符,请联系长河编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

【乳腺超声、乳腺钼靶、宫颈癌】等项目数据调研,及相关参考内容整理汇总

一、乳腺超声内容整理 1.1、数据集 Breast Ultrasound Images Dataset;下载地址2STU-Hospital处理和训练参考文档:https://blog.csdn.net/weixin_51511389/article/details/127594654 1.2、可以参考的论文 AAU-net: An Adaptive Attention U-net for Breast Lesions Segmen…

京东获得JD商品详情 API 返回值说明

京东商品详情API接口可以获得JD商品详情原数据。 这个API接口有两种参数,公共参数和请求参数。 公共参数有以下几个: apikey:这是您自己的API密钥,可以在京东开发者中心获取。 请求参数有以下几个: num_iid&#…

LeetCode01

LeetCode01 两数之和 给定一个整数数组 nums 和一个整数目标值 target,请你在该数组中找出 和 为目标值 target 的那两个整数,并返回它们的数组下标。 你可以假设每种输入只会对应一个答案。但是,数组中同一个元素在答案里不能重复出现。 你…

【神印王座】悲啸洞穴之物揭晓,圣采儿差点被骗,幸好龙皓晨聪明

Hello,小伙伴们,我是小郑继续为大家深度解析神印王座。 神印王座动漫现阶段已经出到龙皓晨等人接取新任务深入魔族地界的阶段,而龙皓晨等人接取的任务想必现在大家都知道了,那就是探索魔族地界中的悲啸洞穴。但是大家知道悲啸洞穴里面藏着什么…

软件测试:全链路追踪工具 Zipkin导入、安装(Windows版本)

1.0全链路追踪技术出现的原因 公司内部一个功能的实现,底层可能调用多个应用系统 在调用这个功能的同时,可能会出现多种情况,比如访问较慢,出现错误,可能需要进行定位 所以,我们需要快速定位服务错误点 大…

基于SSM的办公用品管理系统设计与实现

末尾获取源码 开发语言:Java Java开发工具:JDK1.8 后端框架:SSM 前端:采用JSP技术开发 数据库:MySQL5.7和Navicat管理工具结合 服务器:Tomcat8.5 开发软件:IDEA / Eclipse 是否Maven项目&#x…

Docker(三)、Dockerfile探究

Dockerfile探究 一、镜像层概念1、通过执行命令显化docker的机制 二、Dockerfile基础命令1、FROM 基于基准镜像【即构建镜像的时候,依托原有镜像做拓展】2、LABEL & MAINTAINER -说明信息3、WORKDIR 设置工作目录4、ADD & COPY 复制文件5、ENV 设置环境常量…

Ceph存储部署

这里写自定义目录标题 一、Ceph概述二、Ceph的组件三、架构四、安装步骤一、环境部署二、修改ssh配置三、hosts文件修改四、ssh免密配置五、时间同步六、格式化磁盘七、后续的操作暂时都在centos1执行 五、成功将ceph配置完成 前言:后续配置的解释可能标题不是很清晰…

【前端】零基础快速搞定JavaScript核心知识点

文章目录 1.初识JavaScript1.1.JavaScript语言简介1.2.JavaScript引入方式和注释1.3.Javascript变量声明详解1.4.JavaScript变量提升详解 2.JavaScript基础数据类型2.1.JavaScript基础数据类型简介2.2.基础类型数据-Number2.3.基础类型数据-String2.4.基础类型数据-Boolean2.5.…

HarmonyOS 4.0 实况窗上线!支付宝实现医疗场景智能提醒

本文转载自支付宝体验科技,作者是蚂蚁集团客户端工程师博欢,介绍了支付宝如何基于 HarmonyOS 4.0 实况窗实现医疗场景履约智能提醒。 1.话题背景 8 月 4 日,华为在 HDC(华为 2023 开发者大会)上推出了新版本操作系统…

一文教你学会ArcGIS Pro地图设计与制图系列全流程(2)

ArcGIS Pro做的成果图及系列文章目录: 系列文章全集: 《一文教你学会ArcGIS Pro地图设计与制图系列全流程(1)》《一文教你学会ArcGIS Pro地图设计与制图系列全流程(2)》《一文教你学会ArcGIS Pro地图设计与…

jupyterlab开发环境最佳构建方式

文章目录 背景jupyterlab环境构建运行虚拟环境构建以及kernel映射验证总结 背景 从jupyter notebook切换到了jupyter lab. 这里记录一下本地环境的最佳构建方式. jupyter lab 安装在jupyterlab-local的anaconda 虚拟环境中.建立多个其他虚拟环境安装各种python包实现环境隔离,…

【实战详解】如何快速搭建接口自动化测试框架?Python + Requests

摘要: 本文主要介绍如何使用Python语言和Requests库进行接口自动化测试,并提供详细的代码示例和操作步骤。希望能对读者有所启发和帮助。 前言 随着移动互联网的快速发展,越来越多的应用程序采用Web API(也称为RESTful API&…

python 探索分形世界|曼德布洛特|np.frompyfunc()

文章目录 分形的重要特征曼德布洛特集合曼德布洛特集合有一个以证明的结论:图像展示np.ogrid[]np.frompyfunc()集合转图像 julia集合 无边的奇迹源自简单规则的无限重复 ---- 分形之父Benoit B.Mandelbrot 分形的重要特征 自相似性无标度性非线性 曼德布洛特集合…

8+单基因+细胞凋亡+WGCNA+单细胞+实验验证

今天给同学们分享一篇单基因细胞凋亡WGCNA实验验证的生信文章“RASGRP2 is a potential immune-related biomarker and regulates mitochondrial-dependent apoptosis in lung adenocarcinoma”,这篇文章于2023年2月3日发表在Front Immunol期刊上,影响因…

订阅《复现SCI文章系列教程》

写在前面 《小杜生信笔记》准备开启新的订阅专栏**《复现期刊文章系列教程》,本专栏小杜会寻找一些自己感兴趣的文章进行复现(不说百分之百的复现,但是也会百分之八十进行复现)。本期刊的教程代码会全部进行公开(通过订…

孙哥Spring源码第26集

第26集、AnnotationAwareAspectJAutoProxyCreator源码 【视频来源于:B站up主孙帅suns Spring源码视频】【微信号:suns45】 26.1、postProcessAfterInitialization分析 26.2、wrapIfNecessary分析 26.3、createProxy分析 26.4、getProxy 26.5、BeanPost…

Deep Span Representations for Named Entity Recognition

原文链接: https://aclanthology.org/2023.findings-acl.672.pdf ACL 2023 介绍 问题 作者认为,一个好的span表征对于NER任务是非常重要的,而之前的工作都是将第一个或最后一个的表征简单的进行组合后,没有进行充分的交互就送入到…

linux 查看CPU架构是AMD还是ARM

要查看 Linux 系统的 CPU 架构是 AMD 还是 ARM,可以使用以下命令: 使用 lscpu 命令并查找 Architecture 字段: lscpu | grep Architecture如果输出结果中包含 x86_64 或 i686,则表示系统的 CPU 架构是 AMD(或者是 x86…

android去掉 原生锁屏

1. /frameworks/base/core/java/com/android/internal/widget/LockPatternUtils.java 直接 return true 2./packages/apps/Settings/src/com/android/settings/password/ScreenLockType.java 都改成 none 类型