paddle2.3-基于联邦学习实现FedAVg算法-CNN

目录

1. 联邦学习介绍

2. 实验流程

3. 数据加载

4. 模型构建

5. 数据采样函数

6. 模型训练


1. 联邦学习介绍

联邦学习是一种分布式机器学习方法,中心节点为server(服务器),各分支节点为本地的client(设备)。联邦学习的模式是在各分支节点分别利用本地数据训练模型,再将训练好的模型汇合到中心节点,获得一个更好的全局模型。

联邦学习的提出是为了充分利用用户的数据特征训练效果更佳的模型,同时,为了保证隐私,联邦学习在训练过程中,server和clients之间通信的是模型的参数(或梯度、参数更新量),本地的数据不会上传到服务器。

本项目主要是升级1.8版本的联邦学习fedavg算法至2.3版本,内容取材于基于PaddlePaddle实现联邦学习算法FedAvg - 飞桨AI Studio星河社区

2. 实验流程

联邦学习的基本流程是:

1. server初始化模型参数,所有的clients将这个初始模型下载到本地;

2. clients利用本地产生的数据进行SGD训练;

3. 选取K个clients将训练得到的模型参数上传到server;

4. server对得到的模型参数整合,所有的clients下载新的模型。

5. 重复执行2-5,直至收敛或达到预期要求

import os
import numpy as np
import matplotlib
import matplotlib.pyplot as plt
import random
import time
import paddle
import paddle.nn as nn
import numpy as np
from paddle.io import Dataset,DataLoader
import paddle.nn.functional as F

3. 数据加载

mnist_data_train=np.load('data/data2489/train_mnist.npy')
mnist_data_test=np.load('data/data2489/test_mnist.npy')
print('There are {} images for training'.format(len(mnist_data_train)))
print('There are {} images for testing'.format(len(mnist_data_test)))
# 数据和标签分离(便于后续处理)
Label=[int(i[0]) for i in mnist_data_train]
Data=[i[1:] for i in mnist_data_train]
There are 60000 images for training
There are 10000 images for testing

4. 模型构建

class CNN(nn.Layer):def __init__(self):super(CNN,self).__init__()self.conv1=nn.Conv2D(1,32,5)self.relu = nn.ReLU()self.pool1=nn.MaxPool2D(kernel_size=2,stride=2)self.conv2=nn.Conv2D(32,64,5)self.pool2=nn.MaxPool2D(kernel_size=2,stride=2)self.fc1=nn.Linear(1024,512)self.fc2=nn.Linear(512,10)# self.softmax = nn.Softmax()def forward(self,inputs):x = self.conv1(inputs)x = self.relu(x)x = self.pool1(x)x = self.conv2(x)x = self.relu(x)x = self.pool2(x)x=paddle.reshape(x,[-1,1024])x = self.relu(self.fc1(x))y = self.fc2(x)return y

5. 数据采样函数

# 均匀采样,分配到各个client的数据集都是IID且数量相等的
def IID(dataset, clients):num_items_per_client = int(len(dataset)/clients)client_dict = {}image_idxs = [i for i in range(len(dataset))]for i in range(clients):client_dict[i] = set(np.random.choice(image_idxs, num_items_per_client, replace=False)) # 为每个client随机选取数据image_idxs = list(set(image_idxs) - client_dict[i]) # 将已经选取过的数据去除client_dict[i] = list(client_dict[i])return client_dict
# 非均匀采样,同时各个client上的数据分布和数量都不同
def NonIID(dataset, clients, total_shards, shards_size, num_shards_per_client):shard_idxs = [i for i in range(total_shards)]client_dict = {i: np.array([], dtype='int64') for i in range(clients)}idxs = np.arange(len(dataset))data_labels = Labellabel_idxs = np.vstack((idxs, data_labels)) # 将标签和数据ID堆叠label_idxs = label_idxs[:, label_idxs[1,:].argsort()]idxs = label_idxs[0,:]for i in range(clients):rand_set = set(np.random.choice(shard_idxs, num_shards_per_client, replace=False)) shard_idxs = list(set(shard_idxs) - rand_set)for rand in rand_set:client_dict[i] = np.concatenate((client_dict[i], idxs[rand*shards_size:(rand+1)*shards_size]), axis=0) # 拼接return client_dict

class MNISTDataset(Dataset):def __init__(self, data,label):self.data = dataself.label = labeldef __getitem__(self, idx):image=np.array(self.data[idx]).astype('float32')image=np.reshape(image,[1,28,28])label=np.array(self.label[idx]).astype('int64')return image, labeldef __len__(self):return len(self.label)

6. 模型训练

class ClientUpdate(object):def __init__(self, data, label, batch_size, learning_rate, epochs):dataset = MNISTDataset(data,label)self.train_loader = DataLoader(dataset,batch_size=batch_size,shuffle=True,drop_last=True)self.learning_rate = learning_rateself.epochs = epochsdef train(self, model):optimizer=paddle.optimizer.SGD(learning_rate=self.learning_rate,parameters=model.parameters())criterion = nn.CrossEntropyLoss(reduction='mean')model.train()e_loss = []for epoch in range(1,self.epochs+1):train_loss = []for image,label in self.train_loader:# image=paddle.to_tensor(image)# label=paddle.to_tensor(label.reshape([label.shape[0],1]))output=model(image)loss= criterion(output,label)# print(loss)loss.backward()optimizer.step()optimizer.clear_grad()train_loss.append(loss.numpy()[0])t_loss=sum(train_loss)/len(train_loss)e_loss.append(t_loss)total_loss=sum(e_loss)/len(e_loss)return model.state_dict(), total_loss

train_x = np.array(Data)
train_y = np.array(Label)
BATCH_SIZE = 32
# 通信轮数
rounds = 100
# client比例
C = 0.1
# clients数量
K = 100
# 每次通信在本地训练的epoch
E = 5
# batch size
batch_size = 10
# 学习率
lr=0.001
# 数据切分
iid_dict = IID(mnist_data_train, 100)
def training(model, rounds, batch_size, lr, ds,L, data_dict, C, K, E, plt_title, plt_color):global_weights = model.state_dict()train_loss = []start = time.time()# clients与server之间通信for curr_round in range(1, rounds+1):w, local_loss = [], []m = max(int(C*K), 1) # 随机选取参与更新的clientsS_t = np.random.choice(range(K), m, replace=False)for k in S_t:# print(data_dict[k])sub_data = ds[data_dict[k]]sub_y = L[data_dict[k]]local_update = ClientUpdate(sub_data,sub_y, batch_size=batch_size, learning_rate=lr, epochs=E)weights, loss = local_update.train(model)w.append(weights)local_loss.append(loss)# 更新global weightsweights_avg = w[0]for k in weights_avg.keys():for i in range(1, len(w)):# weights_avg[k] += (num[i]/sum(num))*w[i][k]weights_avg[k]=weights_avg[k]+w[i][k]   weights_avg[k]=weights_avg[k]/len(w)global_weights[k].set_value(weights_avg[k])# global_weights = weights_avg# print(global_weights)#模型加载最新的参数model.load_dict(global_weights)loss_avg = sum(local_loss) / len(local_loss)if curr_round % 10 == 0:print('Round: {}... \tAverage Loss: {}'.format(curr_round, np.round(loss_avg, 5)))train_loss.append(loss_avg)end = time.time()fig, ax = plt.subplots()x_axis = np.arange(1, rounds+1)y_axis = np.array(train_loss)ax.plot(x_axis, y_axis, 'tab:'+plt_color)ax.set(xlabel='Number of Rounds', ylabel='Train Loss',title=plt_title)ax.grid()fig.savefig(plt_title+'.jpg', format='jpg')print("Training Done!")print("Total time taken to Train: {}".format(end-start))return model.state_dict()#导入模型
mnist_cnn = CNN()
mnist_cnn_iid_trained = training(mnist_cnn, rounds, batch_size, lr, train_x,train_y, iid_dict, C, K, E, "MNIST CNN on IID Dataset", "orange")

Round: 10... 	Average Loss: [0.024]
Round: 20... 	Average Loss: [0.015]
Round: 30... 	Average Loss: [0.008]
Round: 40... 	Average Loss: [0.003]
Round: 50... 	Average Loss: [0.004]
Round: 60... 	Average Loss: [0.002]
Round: 70... 	Average Loss: [0.002]
Round: 80... 	Average Loss: [0.002]
Round: 90... 	Average Loss: [0.001]
Round: 100... 	Average Loss: [0.]
Training Done!
Total time taken to Train: 759.6239657402039

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.rhkb.cn/news/146904.html

如若内容造成侵权/违法违规/事实不符,请联系长河编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

cad图纸如何防止盗图(一个的制造设计型企业如何保护设计图纸文件)

在现代企业中,设计图纸是公司的重要知识产权,关系到公司的核心竞争力。然而,随着技术的发展,员工获取和传播设计图纸的途径越来越多样化,如何有效地防止员工复制设计图纸成为了企业管理的一大挑战。本文将从技术、管理…

QT常用控件介绍

QT信号与槽机制 connect (A,SIGNLA(aaa()),B, SLOT(bbb())); GUI继承简介 布局管理器 垂直布局水平布局网格布局表单布局 输出控件 Label: 标签Text Browser: 文本浏览器Graphics View : 图形视图框架Calendar Widget: 日历控件LCD Number: 液晶字体数…

VSCode 在部分 Linux 设备上终端和文本编辑器显示文本不正常的解决方法

部分Linux设备上运行VSCode时,发现文本编辑器的缩放不明显,终端字体间距过大等。 这里以Kali Linux为例,其他Linux发行版请选择对应的系统内置的等宽字体 我们依次打开 设置 -> 外观 -> 字体 这里我们可以发现,Kali Linux默…

数据结构与算法----递归

1、迷宫回溯问题 package com.yhb.code.datastructer.recursion¥5;public class MiGong {public static void main(String[] args) {// 先创建一个二维数组,模拟迷宫// 地图int[][] map new int[8][7];// 使用1 表示墙// 上下全部置为1for (int i 0; i…

SpringBoot详解

文章目录 1. Xml 和 JavaConfig1.1 什么是 JavaConfig1.2 JavaConfig 配置容器1.3 ImportResource1.4 PropertyResource 2. 注解SpringBootApplication3.Spring Boot 核心配置文件3.1 yaml格式3.2 .yml 文件3.3 多环境配置3.4 Spring Boot 自定义配置3.4.1 Value 注解 3.5 注解…

浏览器指定DNS

edge--设置 https://dns.alidns.com/dns-query

计算机毕设 大数据B站数据分析与可视化 - python 数据分析 大数据

文章目录 0 前言1 课题背景2 实现效果3 数据获取4 数据可视化5 最后 0 前言 🔥 这两年开始毕业设计和毕业答辩的要求和难度不断提升,传统的毕设题目缺少创新和亮点,往往达不到毕业答辩的要求,这两年不断有学弟学妹告诉学长自己做…

有时候,使用 clang -g test.c 编译出可执行文件后,发现 gdb a.out 进行调试无法读取符号信息,为什么?

经过测试,gdb 并不是和所有版本的 llvm/clang 都兼容的 当 gdb 版本为 9.2 时,能支持 9.0.1-12 版本的 clang,但无法支持 16.0.6 版本的 clang 可以尝试使用 LLVM 专用的调试器 lldb 我尝试使用了 16.0.6 版本的 lldb 调试 16.0.6 的 clan…

kafka客户端应用参数详解

一、基本客户端收发消息 Kafka提供了非常简单的客户端API。只需要引入一个Maven依赖即可&#xff1a; <dependency><groupId>org.apache.kafka</groupId><artifactId>kafka_2.13</artifactId><version>3.4.0</version></depend…

Java后端模拟面试,题集①

1.Spring bean的生命周期 实例化 Instantiation属性赋值 Populate初始化 Initialization销毁 Destruction 2.Spring AOP的创建在bean的哪个时期进行的 &#xff08;图片转载自Spring Bean的完整生命周期&#xff08;带流程图&#xff0c;好记&#xff09;&#xff09; 3.MQ如…

2023年哪款PDF虚拟打印机好用?

PDF文档想必大家都不陌生&#xff0c;在工作中经常会用到该格式的文档&#xff0c;那么有哪些方法能制作PDF文档呢&#xff1f;一般都是借助PDF虚拟打印机的&#xff0c;那么有哪些好用的软件呢&#xff1f; pdfFactory不仅为用户提供了丰富的PDF文档生成、打印功能&#xff0…

【JavaEE】JavaScript

JavaScript 文章目录 JavaScript组成书写方式行内式内嵌式外部式&#xff08;推荐写法&#xff09; 输入输出变量创建动态类型基本数据类型数字类型特殊数字值 String转义字符求长度字符串拼接布尔类型undefined未定义数据类型null 运算符条件语句if语句三元表达式switch 循环语…

嵌入式Linux应用开发-基础知识-第十九章驱动程序基石③

嵌入式Linux应用开发-基础知识-第十九章驱动程序基石③ 第十九章 驱动程序基石③19.5 定时器19.5.1 内核函数19.5.2 定时器时间单位19.5.3 使用定时器处理按键抖动19.5.4 现场编程、上机19.5.5 深入研究&#xff1a;定时器的内部机制19.5.6 深入研究&#xff1a;找到系统滴答 1…

数学建模Matlab之检验与相关性分析

只要做C题基本上都会用到相关性分析、一般性检验等&#xff01; 回归模型性能检验 下面讲一下回归模型的性能评估指标&#xff0c;用来衡量模型预测的准确性。下面是每个指标的简单解释以及它们的应用情境&#xff1a; 1. MAPE (平均绝对百分比误差) 描述: 衡量模型预测的相对…

VRRP配置案例(路由走向分析,端口切换)

以下配置图为例 PC1的配置 acsw下行为access口&#xff0c;上行为trunk口&#xff0c; 将g0/0/3划分到vlan100中 <Huawei>sys Enter system view, return user view with CtrlZ. [Huawei]sysname acsw [acsw] Sep 11 2023 18:15:48-08:00 acsw DS/4/DATASYNC_CFGCHANGE:O…

【Spring Cloud】Ribbon 实现负载均衡的原理,策略以及饥饿加载

文章目录 前言一、什么是 Ribbon二、Ribbon 实现负载均衡的原理2.1 负载均衡的流程2.2 Ribbon 实现负载均衡的源码剖析 三、Ribbon 负载均衡策略3.1 负载均衡策略3.2 演示 Ribbon 负载均衡策略的更改 四、Ribbon 的饥饿加载4.1查看 Ribbon 的懒加载4.2 Ribbon 的饥饿加载模式 前…

Python无废话-办公自动化Excel修改数据

如何修改Excel 符合条件的数据&#xff1f;用Python 几行代码搞定。 需求&#xff1a;将销售明细表的产品名称为PG手机、HW手机、HW电脑的零售价格分别修改为4500、5500、7500&#xff0c;并保存Excel文件。如下图 Python 修改Excel 数据&#xff0c;常见步骤&#xff1a; 1&…

Ubuntu20 QT6.0 编译 ODBC 驱动

一、新建测试项目 新建一个控制台项目&#xff0c; // main.cpp #include <QCoreApplication> #include <QSqlDatabase> #include <QDebug>int main(int argc, char *argv[]) {QCoreApplication a(argc, argv);// 获取当前Qt支持的驱动列表QStringList driv…

1300*C. Coin Rows(枚举模拟)

解析&#xff1a; 两人都绝对聪明&#xff0c;Alice先走&#xff0c;尽量让Bob所能拿的分数最少&#xff0c;Alice有一次往下走的机会&#xff0c;剩余没走过的点正好分为两断断开的区域&#xff0c;所以Bob的最大分数要么在第一格向下或者在最后一列向下。 遍历区间&#xff0…

stm32之1602+DHT11+继电器

描述&#xff1a; 1、DHT11监测温室度&#xff0c;并显示到1602液晶上 2、通过串口打印&#xff08;或通过蓝牙模块在手机上查看&#xff09; 3、当温度大于24度时&#xff0c;开启继电器。小于时关闭继电器&#xff08;继电器可连接风扇---假想O(∩_∩)O哈哈~&#xff09; 一、…