政安晨:【Keras机器学习示例演绎】(四十二)—— 使用 KerasNLP 和 tf.distribute 进行数据并行训练

目录

简介

导入

基本批量大小和学习率

计算按比例分配的批量大小和学习率


政安晨的个人主页政安晨

欢迎 👍点赞✍评论⭐收藏

收录专栏: TensorFlow与Keras机器学习实战

希望政安晨的博客能够对您有所裨益,如有不足之处,欢迎在评论区提出指正!

本文目标:使用 KerasNLP 和 tf.distribute 进行数据并行训练。

简介


分布式训练是一种在多台设备或机器上同时训练深度学习模型的技术。它有助于缩短训练时间,并允许使用更多数据训练更大的模型。KerasNLP 是一个为自然语言处理任务(包括分布式训练)提供工具和实用程序的库。

在本文中,我们将使用 KerasNLP 在 wikitext-2 数据集(维基百科文章的 200 万字数据集)上训练基于 BERT 的屏蔽语言模型 (MLM)。MLM 任务包括预测句子中的屏蔽词,这有助于模型学习单词的上下文表征。

本指南侧重于数据并行性,尤其是同步数据并行性,即每个加速器(GPU 或 TPU)都拥有一个完整的模型副本,并查看不同批次的部分输入数据。部分梯度在每个设备上计算、汇总,并用于计算全局梯度更新。

具体来说,本文将教您如何在以下两种设置中使用 tf.distribute API 在多个 GPU 上训练 Keras 模型,只需对代码做最小的改动:

—— 在一台机器上安装多个 GPU(通常为 2 至 8 个)(单主机、多设备训练)。这是研究人员和小规模行业工作流程最常见的设置。
—— 在由多台机器组成的集群上,每台机器安装一个或多个 GPU(多设备分布式训练)。这是大规模行业工作流程的良好设置,例如在 20-100 个 GPU 上对十亿字数据集进行高分辨率文本摘要模型训练。

!pip install -q --upgrade keras-nlp
!pip install -q --upgrade keras  # Upgrade to Keras 3.

导入

import osos.environ["KERAS_BACKEND"] = "tensorflow"import tensorflow as tf
import keras
import keras_nlp

在开始任何训练之前,让我们配置一下我们的单 GPU,使其显示为两个逻辑设备。

在使用两个或更多物理 GPU 进行训练时,这完全没有必要。这只是在默认 colab GPU 运行时(只有一个 GPU 可用)上显示真实分布式训练的一个技巧。

!nvidia-smi --query-gpu=memory.total --format=csv,noheader
physical_devices = tf.config.list_physical_devices("GPU")
tf.config.set_logical_device_configuration(physical_devices[0],[tf.config.LogicalDeviceConfiguration(memory_limit=15360 // 2),tf.config.LogicalDeviceConfiguration(memory_limit=15360 // 2),],
)logical_devices = tf.config.list_logical_devices("GPU")
logical_devicesEPOCHS = 3
24576 MiB

要使用 Keras 模型进行单主机、多设备同步训练,您需要使用 tf.distribute.MirroredStrategy API。下面是其工作原理:

—— 实例化 MirroredStrategy,可选择配置要使用的特定设备(默认情况下,该策略将使用所有可用的 GPU)。
—— 使用该策略对象打开一个作用域,并在该作用域中创建所需的包含变量的所有 Keras 对象。通常情况下,这意味着在分发作用域内创建和编译模型。
—— 像往常一样通过 fit() 训练模型。

strategy = tf.distribute.MirroredStrategy()
print(f"Number of devices: {strategy.num_replicas_in_sync}")
INFO:tensorflow:Using MirroredStrategy with devices ('/job:localhost/replica:0/task:0/device:GPU:0', '/job:localhost/replica:0/task:0/device:GPU:1')
Number of devices: 2

基本批量大小和学习率

base_batch_size = 32
base_learning_rate = 1e-4

计算按比例分配的批量大小和学习率

scaled_batch_size = base_batch_size * strategy.num_replicas_in_sync
scaled_learning_rate = base_learning_rate * strategy.num_replicas_in_sync

现在,我们需要下载并预处理 wikitext-2 数据集。该数据集将用于预训练 BERT 模型。我们将过滤掉短行,以确保数据有足够的语境用于训练。

keras.utils.get_file(origin="https://s3.amazonaws.com/research.metamind.io/wikitext/wikitext-2-v1.zip",extract=True,
)
wiki_dir = os.path.expanduser("~/.keras/datasets/wikitext-2/")# Load wikitext-103 and filter out short lines.
wiki_train_ds = (tf.data.TextLineDataset(wiki_dir + "wiki.train.tokens",).filter(lambda x: tf.strings.length(x) > 100).shuffle(buffer_size=500).batch(scaled_batch_size).cache().prefetch(tf.data.AUTOTUNE)
)
wiki_val_ds = (tf.data.TextLineDataset(wiki_dir + "wiki.valid.tokens").filter(lambda x: tf.strings.length(x) > 100).shuffle(buffer_size=500).batch(scaled_batch_size).cache().prefetch(tf.data.AUTOTUNE)
)
wiki_test_ds = (tf.data.TextLineDataset(wiki_dir + "wiki.test.tokens").filter(lambda x: tf.strings.length(x) > 100).shuffle(buffer_size=500).batch(scaled_batch_size).cache().prefetch(tf.data.AUTOTUNE)
)

在上述代码中,我们下载并提取了 wikitext-2 数据集。然后,我们定义了三个数据集:wiki_train_ds、wiki_val_ds 和 wiki_test_ds。我们对这些数据集进行了过滤,以去除短行,并对其进行批处理,以提高训练效率。

在 NLP 训练/调整中,使用衰减学习率是一种常见的做法。在这里,我们将使用多项式衰减时间表(PolynomialDecay schedule)。

total_training_steps = sum(1 for _ in wiki_train_ds.as_numpy_iterator()) * EPOCHS
lr_schedule = tf.keras.optimizers.schedules.PolynomialDecay(initial_learning_rate=scaled_learning_rate,decay_steps=total_training_steps,end_learning_rate=0.0,
)class PrintLR(tf.keras.callbacks.Callback):def on_epoch_end(self, epoch, logs=None):print(f"\nLearning rate for epoch {epoch + 1} is {model_dist.optimizer.learning_rate.numpy()}")

我们还要回调 TensorBoard,这样就能在本教程后半部分训练模型时可视化不同的指标。我们将所有回调放在一起,如下所示:

callbacks = [tf.keras.callbacks.TensorBoard(log_dir="./logs"),PrintLR(),
]print(tf.config.list_physical_devices("GPU"))
[PhysicalDevice(name='/physical_device:GPU:0', device_type='GPU')]

准备好数据集后,我们现在要在 strategy.scope() 中初始化并编译模型和优化器:

with strategy.scope():# Everything that creates variables should be under the strategy scope.# In general this is only model construction & `compile()`.model_dist = keras_nlp.models.BertMaskedLM.from_preset("bert_tiny_en_uncased")# This line just sets pooled_dense layer as non-trainiable, we do this to avoid# warnings of this layer being unusedmodel_dist.get_layer("bert_backbone").get_layer("pooled_dense").trainable = Falsemodel_dist.compile(loss=keras.losses.SparseCategoricalCrossentropy(from_logits=True),optimizer=tf.keras.optimizers.AdamW(learning_rate=scaled_learning_rate),weighted_metrics=[keras.metrics.SparseCategoricalAccuracy()],jit_compile=False,)model_dist.fit(wiki_train_ds, validation_data=wiki_val_ds, epochs=EPOCHS, callbacks=callbacks)
Epoch 1/3
Learning rate for epoch 1 is 0.00019999999494757503239/239 ━━━━━━━━━━━━━━━━━━━━ 43s 136ms/step - loss: 3.7009 - sparse_categorical_accuracy: 0.1499 - val_loss: 1.1509 - val_sparse_categorical_accuracy: 0.3485
Epoch 2/3239/239 ━━━━━━━━━━━━━━━━━━━━ 0s 122ms/step - loss: 2.6094 - sparse_categorical_accuracy: 0.5284
Learning rate for epoch 2 is 0.00019999999494757503239/239 ━━━━━━━━━━━━━━━━━━━━ 32s 133ms/step - loss: 2.6038 - sparse_categorical_accuracy: 0.5274 - val_loss: 0.9812 - val_sparse_categorical_accuracy: 0.4006
Epoch 3/3239/239 ━━━━━━━━━━━━━━━━━━━━ 0s 123ms/step - loss: 2.3564 - sparse_categorical_accuracy: 0.6053
Learning rate for epoch 3 is 0.00019999999494757503239/239 ━━━━━━━━━━━━━━━━━━━━ 32s 134ms/step - loss: 2.3514 - sparse_categorical_accuracy: 0.6040 - val_loss: 0.9213 - val_sparse_categorical_accuracy: 0.4230

根据范围拟合模型后,我们对其进行正常评估!

model_dist.evaluate(wiki_test_ds)
 29/29 ━━━━━━━━━━━━━━━━━━━━ 3s 60ms/step - loss: 1.9197 - sparse_categorical_accuracy: 0.8527[0.9470901489257812, 0.4373602867126465]

对于跨多台计算机的分布式训练(而不是只利用单台计算机上的多个设备进行训练),您可以使用两种分布式策略:MultiWorkerMirroredStrategy 和 ParameterServerStrategy:

—— tf.distribution.MultiWorkerMirroredStrategy(多工作站策略)实现了一种 CPU/GPU 多工作站同步解决方案,可与 Keras 风格的模型构建和训练循环配合使用,并使用跨副本的梯度同步还原。
—— tf.distribution.experimental.ParameterServerStrategy(参数服务器策略)实现了一种异步 CPU/GPU 多工作站解决方案,其中参数存储在参数服务器上,工作站异步更新梯度到参数服务器。


本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.rhkb.cn/news/324988.html

如若内容造成侵权/违法违规/事实不符,请联系长河编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

Linux提权--内核漏洞--web用户提权(脏牛)本地提权(脏管道)

免责声明:本文仅做技术交流与学习... 目录 Linux-内核漏洞Web用户提权-探针&利用-脏牛dcow nmap扫描目标IP及端口 导入脚本,进行探针 通过MSF发现目标机器存在脏牛漏洞 ---上传信息搜集的文件,查找漏洞,利用漏洞,继续上传EXP. --密码改了,再用xshell连一下就行了. …

对话易参创始人黄怡然:股权能不能赋能企业增长?| 极新企服直播实录

“ 致所有爱画饼的老板 ” 整理 | 云舒 编辑 | 小白 出品|极新 2022年以前,股权激励作为企业实现增长、吸引人才、保留人才并大幅度激发人才价值的重要手段,几乎成为每一个企业的标配。但是,现在这个时代,股权激励几…

Python 将Excel转换为多种图片格式(PNG, JPG, BMP, SVG)

目录 安装Python Excel库 使用Python将Excel工作表转换为PNG,JPG或BMP图片 使用Python将Excel特定单元格区域转换为PNG,JPG或BMP图片 使用Python将Excel工作表转换为SVG图片 有时,你可能希望以图片形式分享Excel数据,以防止他…

基于遗传优化的双BP神经网络金融序列预测算法matlab仿真

目录 1.程序功能描述 2.测试软件版本以及运行结果展示 3.核心程序 4.本算法原理 5.完整程序 1.程序功能描述 基于遗传优化的双BP神经网络金融序列预测算法matlab仿真,采用的双BP神经网络结构如下: 2.测试软件版本以及运行结果展示 MATLAB2022A版本…

【讲解下目标追踪】

🌈个人主页: 程序员不想敲代码啊 🏆CSDN优质创作者,CSDN实力新星,CSDN博客专家 👍点赞⭐评论⭐收藏 🤝希望本文对您有所裨益,如有不足之处,欢迎在评论区提出指正,让我们共…

个人微信api

简要描述: 退出群聊 请求URL: http://域名地址/quitChatRoom 请求方式: POST 请求头Headers: Content-Type:application/json Authorization:login接口返回 参数: 参数名 必选 类型 …

物流EDI:GEFCO EDI 需求分析

GEFCO专注于汽车物流领域近70年,是欧洲整车市场的物流供应商,也是欧洲十大领先的运输和物流集团之一。GEFCO的业务遍及六大洲,业务覆盖150个国家,在全球拥有庞大的员工队伍,在全球汽车行业的挑战中茁壮成长。为汽车制造…

MySQL-索引篇

文章目录 什么是索引?索引的优缺点索引底层数据结构选型Hash表二叉查找树AVL树红黑树B树&B树 索引类型总结主键索引二级索引聚集索引与非聚集索引聚集索引非聚集索引 覆盖索引与关联索引覆盖索引联合查询最左前缀匹配原则 索引下推如何正确使用索引选择合适的字…

iOS Xcode Debug View Hierarchy 查看视图层级结构

前言 我们难免会遇到接手别人项目的情况,让你去改他遗留的问题,想想都头大,😂可是也不得不面对。作为开发者只要让我们找到出问题的代码文件,我们就总有办法去解决它,那么如何快速定位问题对应的代码文件呢…

r语言数据分析案例-北京市气温预测分析与研究

一、选题背景 近年来,人类大量燃烧煤炭、天然气等含碳燃料导致温室气 体过度排放,大量温室气体强烈吸收地面辐射中的红外线,造 成温室效应不断累积,使得地球温度上升,造成全球气候变暖。 气象温度的预测一直以来都是…

python视频转码脚本

今天有一个临时的需求,就是需要将一个wmv的初步转码成mp4的格式。找了一圈,免费的工具少,即使有免费的工具,在功能上也是有所限制,或者会给你塞广告或者附带安装其它流氓小游戏或者杀毒程序。 我并非不支持正版&#…

前端笔记-day05

文章目录 01-结构伪类选择器02-结构伪类选择器-公式用法03-伪元素选择器04-盒子模型-组成05-盒子模型-边框线06-盒子模型-单方向边框线07-盒子模型-内边距08-盒子模型-padding多值写法09-盒子模型-尺寸计算10-盒子模型-版心居中11-清除默认样式12-元素溢出overflow13-外边距合并…

贪心算法-----柠檬水找零

今日题目:leetcode860 题目链接:点击跳转题目 分析: 顾客只会给三种面值:5、10、20,先分类讨论 当收到5美元时:不用找零,面值5张数1当收到10美元时:找零5美元,面值5张数…

3588 pwm android12 的操作

问题: 客户需要在android12 的界面上操作板卡上的 PWM 蜂鸣器设备。 过程: 1 了解一下 3588 android12 源码的 关于PWM 的驱动。 设备树找不到 pwm 但是, 还不知道,android12 最终包含的 设备树是哪个,但是经过我的…

Meilisearch使用过程趟过的坑

Elasticsearch 做为老牌搜索引擎,功能基本满足,但复杂,重量级,适合大数据量。 MeiliSearch 设计目标针对数据在 500GB 左右的搜索需求,极快,单文件,超轻量。 所以,对于中小型项目来说…

使用html和css实现个人简历表单的制作

根据下列要求,做出下图所示的个人简历(表单) 表单要求 Ⅰ、表格整体的边框为1像素,单元格间距为0,表格中前六列列宽均为100像素,第七列 为200像素,表格整体在页面上居中显示; Ⅱ、前…

2024年电工杯数学建模竞赛A题B题思路代码分享

您的点赞收藏是我继续更新的最大动力! 一定要点击如下的卡片链接,那是获取资料的入口! 点击链接加入群聊【2024电工杯】:http://qm.qq.com/cgi-bin/qm/qr?_wv1027&k_PrjarulWZU8JsAOA9gnj_oHKIjFe195&authKeySbv2XM853…

简洁大气APP下载单页源码

源码介绍 简洁大气APP下载单页源码,源码由HTMLCSSJS组成,记事本打开源码文件可以进行内容文字之类的修改,双击html文件可以本地运行效果,也可以上传到服务器里面 效果截图 源码下载 简洁大气APP下载单页源码

3D Web轻量化引擎HOOPS Communicator如何处理DWG文件中的图纸?

在当今工程设计和建筑领域,数字化技术已经成为不可或缺的一部分。HOOPS Communicator作为一种强大的三维数据可视化工具,被广泛应用于处理各种CAD文件,其中包括AutoCAD的DWG格式。在这篇文章中,我们将探讨HOOPS Communicator是如何…

【Win10设备管理器中无端口选项】

计算机疑难杂症分享002 Win10设备管理器中无端口选项1、问题现象2、问题原因3、问题解决3.1、驱动精灵(亲测的此方法)3.2、添加过时硬件3.3、官方的方法 Win10设备管理器中无端口选项 1、问题现象 当我调试串口通信时,发现打开设备管理器没有端口,打开…