TensorFlow入门(十二、分布式训练)

1、按照并行方式来分

        ①模型并行

                假设我们有n张GPU,不同的GPU被输入相同的数据,运行同一个模型的不同部分。

                在实际训练过程中,如果遇到模型非常庞大,一张GPU不够存储的情况,可以使用模型并行的分布式训练,把模型的不同部分交给不同的GPU负责。这种方式存在一定的弊端:①这种方式需要不同的GPU之间通信,从而产生较大的通信成本。②由于每个GPU上运行的模型部分之间存在一定的依赖,导致规模伸缩性差。

        ②数据并行

                假设我们有n张GPU,不同的GPU被输入不同的数据,运行相同的完整的模型。

                如果遇到一张GPU就能够存下一个模型的情况,可以采用数据并行的方式,这种方式的各部分独立,伸缩性好。

2、按照更新方式来分

        采用数据并行方式时,由于每个GPU负责一部分数据,涉及到如何更新参数的问题,因此分为同步更新和异步更新两种方式。

        ①同步更新

                所有GPU计算完每一个batch(也就是每批次数据)后,再统一计算新权值,等所有GPU同步新值后,再开始进行下一轮计算。

                同步更新的好处是loss的下降比较稳定,但是这个的坏处也很明显,这种方式有等待,处理的速度取决于最慢的那个GPU计算的时间。

        ②异步更新

                每个GPU计算完梯度后,无需等待其他GPU更新,立即更新整体权值并同步。

                异步更新的好处是计算速度快,计算资源能得到充分利用,但是缺点是loss的下降不稳定,抖动大。

3、按照算法来分

        ①Parameter Sever算法

                原理:假设我们有n张GPU,GPU0将数据分成n份分到各张GPU上,每张GPU负责自己那一批次数据的训练,得到梯度后,返回给GPU0上做累计,得到更新的权重参数后,再分发给各张GPU。

        ②Ring AllReduce算法

                原理:假设我们有n张GPU,它们以环形相连,每张GPU都有一个左邻和一个右邻,每张GPU向各自的右邻发送数据,并从它的左邻接近数据。循环n-1次完成梯度积累,再循环n-1次做参数同步。整个算法过程分两个步骤进行:首先是scatter_reduce,然后是allgather。在scatter-reduce,然后是allgather。在scatter-reduce步骤中,GPU将交换数据,使每个GPU可得到最终结果的一个块。在allgather步骤中,gpu将交换这些块,以便所有gpu得到完整的最终结果。

tf.distribute API:

        它是TensorFlow在多GPU、多机器上进行分布式训练用的API。使用这个API,可以在尽可能少改动代码的同时,分布式训练模型。

        它的核心API是tf.distribute.Strategy,只需简单几行代码就可以实现单机多GPU,多机多GPU等情况的分布式训练。

        它的主要优点:

                ①简单易用,开箱即用,高性能

                ②便于各种分布式Strategy切换

                ③支持Custom Training Loop、Estimator、Keras

                ④支持eager excution

tf.distribute.Strategy目前主要有四个Strategy:

        ①MirroredStrategy,即镜像策略

                MirroredStrategy用于单机多GPU、数据并行、同步更新的情况,它会在每个GPU上保存一份模型副本,模型中的每个变量都镜像在所有副本中。这些变量一起形成一个名为MirroredVariable的概念变量。通过apply相同的更新,这些变量保持彼此同步。

                创建一个镜像策略的方法如下:

                        mirrored_strategy = tf.distribute.MirroredStrategy()

                也可以自定义用哪些devices,如:

                        mirrored_strategy = tf.distribute.MirroredStrategy(devices=["/gpu:0","/gpu:1"])

                训练过程中,镜像策略用了高效的All-reduce算法来实现设备之间变量的传递更新。默认情况下它使用NVIDA NCCL (tf.distribute.NcclAllReduce)作为all-reduce算法的实现。通过apply相同的更新,这些变量保持彼此同步。

                官方也提供了其他的一些all-reduce实现方法,可供选择,如:

                        tf.distribute.CrossDeviceOps

                        tf.distribute.HierarchicalCopyAllReduce

                        tf.distribute.ReductionToOneDevice

        ②CentralStorageStrategy,即中心存储策略

                使用该策略时,参数被统一存在CPU里,然后复制到所有GPU上,它的优点是通过这种方式,GPU是负载均衡的,但一般情况下CPU和GPU通信代价比较大。

                创建一个中心存储策略的方法如下:

                             central_storage_strategy = tf.distribute.experimental.CentralStorageStratygy()

        ③MultiWorkerMirroredStrategy,即多端镜像策略

                该API和MirroredStrategy类似,它是其多机多GPU分布式训练的版本。

                创建一个多端镜像策略的方法如下:

                             multiworker_strategy = tf.distribute.experimental.MultiWorkerMirroredStrategy()

        ④ParameterServerStrategy,即参数服务策略

                简称PS策略,由于计算速度慢和负载不均衡,很少使用这种策略。

                创建一个参数服务策略的方法如下:

                              ps_strategy = tf.distribute.experimental.ParameterServerStrategy()

示例代码如下:

import tensorflow as tf#设置总训练轮数
num_epochs = 5
#设置每轮训练的批大小
batch_size_per_replica = 64
#设置学习率,指定了梯度下降算法中用于更新权重的步长大小
learning_rate = 0.001#创建镜像策略
strategy = tf.distribute.MirroredStrategy()
#通过同步更新时副本的数量计算出本机的GPU设备数量
print("Number of devices: %d"% strategy.num_replicas_in_sync)
#通过副本数量乘以每轮训练的批大小,得出训练总数据量的大小
batch_size = batch_size_per_replica * strategy.num_replicas_in_sync#函数将输入的图片调整为224x224大小,再将像素值除以255进行归一化,同时返回标签信息
def resize(image,label):image = tf.image.resize(image,[224,224])/255.0return image,label#载入数据集并预处理
dataset,_ = tf.keras.datasets.cifar10.load_data()
images,labels = dataset
dataset = tf.data.Dataset.from_tensor_slices((images, labels))
dataset = dataset.map(resize).shuffle(1024).batch(batch_size)#在strategy.scope下创建模型和优化器
with strategy.scope():#载入了MobileNetV2模型,该模型在ImageNet上预先训练好了,并可以在分类问题上进行微调model = tf.keras.applications.MobileNetV2()#设置训练时用的优化器、损失函数和准确率评测标准model.compile(optimizer = tf.keras.optimizers.Adam(learning_rate = learning_rate),loss = tf.keras.losses.sparse_categorical_crossentropy,metrics = [tf.keras.metrics.sparse_categorical_accuracy])#执行训练过程
model.fit(dataset,epochs = num_epochs)

对于CIFAR-10数据集下载过慢的问题,可以手动去官网下载

https://www.cs.toronto.edu/~kriz/cifar-10-python.tar.gzicon-default.png?t=N7T8https://www.cs.toronto.edu/~kriz/cifar-10-python.tar.gz下载完成后将其放在如下图的路径下,并将数据集文件改名为cifar-10-batches-py.tar.gz并解压

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.rhkb.cn/news/155190.html

如若内容造成侵权/违法违规/事实不符,请联系长河编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

C进阶-自定义类型:结构体、枚举、联合

本章重点: 结构体: 结构体类型的声明 结构的自引用 结构体变量的定义和初始化 结构体内存对齐 结构体传参 结构体实现位段(位段的填充&可移植性) 1 结构体的声明 1.1 结构的基础知识 结构是一些值的集合,这些值称…

【WSN】模拟无线传感器网络研究(Matlab代码实现)

💥💥💞💞欢迎来到本博客❤️❤️💥💥 🏆博主优势:🌞🌞🌞博客内容尽量做到思维缜密,逻辑清晰,为了方便读者。 ⛳️座右铭&a…

配置mysql+Navicat+XShell宝塔:Mark

Centos7开放3306端口(iptables 防火墙 未设置) Centos7开放3306端口_centos开启3306端口-CSDN博客 firewall-cmd --zonepublic --add-port3306/tcp --permanent Navicat连接1130错误的解决方法 Navicat连接1130错误的解决方法 - 风纳云 ERROR 1062 …

JVM 虚拟机面试知识脑图 初高级

导图下载地址 https://mm.edrawsoft.cn/mobile-share/index.html?uuid3f88d904374599-src&share_type1 类加载器 双亲委派模型 当一个类收到类加载请求,它首先把类加载请求交给父类(如果还有父类,继续往上递交请求).如果父类无法加载该类,再交给子类加载 防止内存中出现…

vivado FFT IP仿真(3)FFT IP选项说明

xilinx FFT IP手册PG109 1 Configuration 2 Implementation 3 Detailed Implementation IP Symbol

Malformed \uxxxx encoding.问题解决方案

问题背景 Maven项目构建时报错如下, [ERROR] Malformed \uxxxx encoding. [ERROR] java.lang.IllegalArgumentException: Malformed \uxxxx encoding. [ERROR] [ERROR] To see the full stack trace of the errors, re-run Maven with the -e switch. [ERROR] Re…

IDEA使用模板创建webapp时,web.xml文件版本过低的一种解决方法

创建完成后的web.xml 文件&#xff0c;版本太低 <!DOCTYPE web-app PUBLIC"-//Sun Microsystems, Inc.//DTD Web Application 2.3//EN""http://java.sun.com/dtd/web-app_2_3.dtd" ><web-app><display-name>Archetype Created Web Appl…

使用IntelliJ Idea必备的插件!

趁手的工具让开发事半功倍&#xff0c;好用的IDEA插件让效率加倍。 今天给大家分享几个优秀的IDEA插件。 插件安装 首先得知道在IDEA哪安装插件&#xff1f; 点击File---->Settings---->找到Plugins标签&#xff0c;即可搜索想要的插件进行安装了。 现在来看下有哪些值…

Ubuntu16.04apt更新失败

先设置网络设置 换成nat、桥接&#xff0c;如果发现都不行&#xff0c;那么就继续下面操作 1.如果出现一开始就e&#xff0c;检查源&#xff0c;先换源 2.换完源成功之后&#xff0c;ping网络&#xff0c;如果ping不通就是网络问题 如果ping baidu.com ping不通但是ping 112…

Go流程控制与快乐路径原则

Go流程控制与快乐路径原则 文章目录 Go流程控制与快乐路径原则一、流程控制基本介绍二、if 语句2.1 if 语句介绍2.2 单分支结构的 if 语句形式2.3 Go 的 if 语句的特点2.3.1 分支代码块左大括号与if同行2.3.2 条件表达式不需要括号 三、操作符3.1 逻辑操作符3.2 操作符的优先级…

面试经典 150 题 20 —(数组 / 字符串)— 151. 反转字符串中的单词

151. 反转字符串中的单词 方法一 class Solution { public:string reverseWords(string s) {istringstream instr(s);vector<string> words{};string word;while(instr>>word){words.push_back(word);}int length words.size();string result words[length-1];f…

振弦传感器和振弦采集仪应用隧道安全监测的解决方案

振弦传感器和振弦采集仪应用隧道安全监测的解决方案 现代隧道越来越复杂&#xff0c;对于隧道安全的监测也变得越来越重要。振弦传感器和振弦采集仪已经成为了一种广泛应用的技术&#xff0c;用于隧道结构的监测和评估。它们可以提供更精确的测量结果&#xff0c;并且可以在实…

Android Termux安装MySQL,并使用cpolar实现公网安全远程连接[内网穿透]

文章目录 前言1.安装MariaDB2.安装cpolar内网穿透工具3. 创建安全隧道映射mysql4. 公网远程连接5. 固定远程连接地址 前言 Android作为移动设备&#xff0c;尽管最初并非设计为服务器&#xff0c;但是随着技术的进步我们可以将Android配置为生产力工具&#xff0c;变成一个随身…

微服务10-Sentinel中的隔离和降级

文章目录 降级和隔离1.Feign整合Sentinel来完成降级1.2总结 2.线程隔离两种实现方式的区别3.线程隔离中的舱壁模式3.2总结 4.熔断降级5.熔断策略&#xff08;根据异常比例或者异常数&#xff09; 回顾 我们的限流——>目的&#xff1a;在并发请求的情况下服务出现故障&…

吸烟检测Y8N,支持C++,PYTHON,ANDROID

吸烟检测Y8N&#xff0c;支持C,PYTHON,ANDROID 现在&#xff0c;深度学习已经非常流行&#xff0c;最新出来的YOLOV8&#xff0c;更是将精度和速度达到极限。训练一个目标很简单&#xff0c;首先&#xff0c;标记图片&#xff1b;然后训练得到PT模型&#xff1b;最后转换成ONNX…

抖音小店创业攻略,快速了解这些适合新手经营的类目

抖音小店是抖音平台上的一种新型电商形态&#xff0c;它允许用户在抖音上开设自己的小店&#xff0c;销售自己的商品。抖音小店的开设门槛低&#xff0c;成本也不高&#xff0c;因此很受新手创业者的青睐。那么&#xff0c;下面不若与众将介绍抖音小店中有哪些适合新手创业者经…

hive 知识总结

​编辑 社区公告教程下载分享问答JD 登 录 注册 01 hive 介绍与安装 1 hive介绍与原理分析 Hive是一个基于Hadoop的开源数据仓库工具&#xff0c;用于存储和处理海量结构化数据。它是Facebook 2008年8月开源的一个数据仓库框架&#xff0c;提供了类似于SQL语法的HQL&#xf…

jmeter 请求发送加密参数

最近在做http加密接口&#xff0c;请求头的uid参数及body的请求json参数都经过加密再发送请求&#xff0c;加密方式为&#xff1a;ase256。所以&#xff0c;jmeter发送请求前也需要对uid及json参数进行加密。我这里是让开发写了个加密、解密的jar&#xff0c;jmeter直接调用这个…

掌握 Web3 的关键工具:9大宝藏APP助你玩转区块链

Web3世界充满了无限机遇&#xff0c;但要掌握它&#xff0c;您需要合适的工具&#xfffd;&#xfffd;&#xfffd;。今天&#xff0c;我将为您介绍9款Web3必备APP&#xff0c;涵盖钱包、DEX、和工具三大类别。而且&#xff0c;我要特别强烈推荐一个强大的钱包——Bitget Wall…

初识Java 13-2 异常

目录 标准Java异常 新特性&#xff1a;更好的NullPointerException报告机制 使用finally执行清理 finally有什么用 在return时使用finally 缺陷&#xff1a;异常丢失 异常的约束 构造器 本笔记参考自&#xff1a; 《On Java 中文版》 标准Java异常 Throwable类描述了任…