跟代码执行流程,读Megatron源码(四)megatron初始化脚本initialize.py之initialize_megatron()分布式环境初始化

  在前文中,我们讲述了pretrain函数的执行流程,其首要步骤是megatron分组的初始化与环境的配置。本文将深入initialize_megatron函数源码,剖析其初始化分布式训练环境的内部机制。

  注:在此假设读者具备3D并行相关知识

一. initialize_megatron函数的上下文调用关系(initialize.py)

  在Megatron-LM中,initialize.py文件中的initialize_megatron函数是分布式训练环境的初始化核心。该函数由trainning.py的pretrain函数调用,是整个pretrain流程中的第一个核心步骤,负责配置3D并行的环境和分组信息。

1. initialize_megatron函数源码

  initialize_megatron函数的核心代码段如下:

  需要注意的是,尽管initialize_megatron函数还涵盖了设置全局参数、分词器构建、自动恢复配置、TensorBoard日志记录、计时器设置以及依赖编译等辅助功能,但这些功能在初始化流程中虽具重要性,却非其核心职责所在。其核心功能聚焦于上述代码段所描述的分布式与模型并行初始化流程,而该流程主要是通过finish_mpu_init中调用的_initialize_distributed函数实现的,如下图。

2. _initialize_distributed

  _initialize_distributed函数主要有两个作用:

  a. 通过调用torch.distributed.init_process_group()初始化分布式环境,该函数设置了分布式训练所需的基本通信环境,包括进程间的通信后端、worldsize(参与训练的进程总数)、每个进程的rank号等。默认情况下,它会创建一个全局的进程组,这个进程组定义了哪些进程可以相互通信,也可以根据需要创建更多的进程组以支持更复杂的通信模式。

  在使用 init_process_group() 初始化分布式环境之后,我们可以使用PyTorch提供的分布式通信和同步 API 来实现跨进程的通信和数据同步。这包括使用dist.all_reduce() 来聚合梯度、dist.barrier() 同步所有进程的执行点等。

  请注意,在调用init_process_group()之前,需要确保已经正确设置了所有相关的环境变量(如MASTER_ADDR、MASTER_PORT),并且这些环境变量对于每张卡都是唯一的。

  b. 通过调用mpu.initialize_model_parallel()来初始化分布式训练环境中的数据并行(DP)、张量并行(TP)、和流水线并行(PP)的分组,如下图。

  mpu.initialize_model_parallel()的入参解释如下:

  • tensor_model_parallel_size:张量并行的大小。

  • pipeline_model_parallel_size:流水线并行的大小。

  • virtual_pipeline_model_parallel_size:虚拟流水线并行的大小,这是一个更高级的特性,允许在流水线阶段内部进一步分割模型。

  • pipeline_model_parallel_split_rank:该参数指定了流水线分割的起始rank,即决定了哪个rank的设备将开始处理流水线的第一个阶段,然后接下来的阶段按顺序分配给rank号递增的设备。

  • context_parallel_size、expert_model_parallel_size等参数用于特定的模型架构,如带有上下文并行或专家并行的Transformer模型。

  • distributed_timeout_minutes:分布式操作的超时时间(以分钟为单位)。

  • nccl_communicator_config_path:NCCL通信器的配置文件路径,NCCL是用于NVIDIA GPU的高效通信库。

  • order:指定并行策略的顺序,例如'tp-cp-ep-dp-pp'表示张量并行(Tensor Parallelism)、上下文并行(Context Parallelism)、专家并行(Expert Parallelism)、数据并行(Data Parallelism)和流水线并行(Pipeline Parallelism)的顺序。

  • encoder_pipeline_model_parallel_size:专门用于编码器的流水线并行大小。

  • get_embedding_ranks和get_position_embedding_ranks:用于获取特定用于embedding或position embedding的GPU rank号,以便为这些组件配置特定的并行策略。

二. mpu.initialize_model_parallel函数的调用关系(_init_.py)

  mpu.initialize_model_parallel()函数的调用关系在import中表明,如下图:

  其中mpu定义在megatron/core/_init_.py中:

  如上图,mpu指向parallel_state,因此,对于mpu.initialize_model_parallel()的调用既是对parallel_state.initialize_model_parallel()的调用,该函数实现在parallel_state.py中。

三. 分组逻辑的具体实现(parallel_state.py)

1. parallel_state.initialize_model_parallel的调用关系

  parallel_state.initialize_model_parallel函数在分布式训练架构中扮演着关键但非终结性的角色,其核心功能是启动模型并行所需的基础设置。该函数专注于预配置模型并行相关的进程群组(process groups)与全局状态变量,确保并行执行环境的基础通信框架得以确立。

  下面以dp为例,展示该函数的代码执行逻辑。

  首先,该函数创建RankGenerator对象实体,该对象的作用是根据用户输入的tp/dp/pp大小,以及总卡数(world_size),确定最终每张卡的分组,如下图。

  其次,它首先调用RankGenerator组件动态生成高效的rank分配方案,这一步骤是优化资源利用与通信效率的关键。随后,通过调用后端(backend)接口,根据RankGenerator产出的分组策略,构建起实际用于数据交换的通信群组(communication groups)。这些通信群组覆盖了数据并行(dp)、张量并行(tp)、以及流水线并行(pp)等维度(这里只以dp为例),确保模型训练过程中的数据流通与参数同步能够高效且有序地进行,如下图。

  随着所有必需通信群组的成功建立,以及全局分组变量的初始化,模型并行的核心通信网络得以全面搭建完成。这一网络的构建不仅标志着模型并行化训练环境的初步就绪,更为后续的高性能计算任务奠定了坚实的基础,确保了分布式训练过程中数据一致性与效率的最优化实现。

2. RankGenerator.get_ranks

  RankGenerator,顾名思义,其主要职责为生成rank分组,它管理着用户的tp/pp/dp/ep/cp数值,以及全局卡数(world_size)等分组相关的配置项,并在get_ranks函数中调用generate_masked_orthogonal_rank_groups函数获取用户需要的最终分组信息,如下图:

3. generate_masked_orthogonal_rank_groups

  generate_masked_orthogonal_rank_groups函数是rank分组的最终实现,其代码逻辑如下:

  a. 筛选并行性尺寸:

  masked_shape:从parallel_size和mask中筛选出被掩码(即True)的并行尺寸,这些尺寸将用于生成组内的rank。

  unmasked_shape:同样从parallel_size和mask中筛选出未被掩码的并行尺寸,这些尺寸将用于在更广泛的并行环境中(跨组)定位每个组。

  b. 计算步长:

  global_stride:通过prefix_product(parallel_size)计算得到全局步长。

  masked_stride和unmasked_stride:分别根据mask从global_stride中筛选出被掩码和未被掩码的步长。这些步长用于计算全局rank。

  c. 确定组大小和组数:

  group_size:通过prefix_product(masked_shape)[-1]计算得到每个组的大小。

  num_of_group:通过world_size // group_size计算得到组的数量,即全局大小除以每个组的大小。

  d. 生成rank:

  遍历每个组(group_index从0到num_of_group-1)。

  使用decompose(group_index, unmasked_shape)根据未被掩码的并行尺寸分解组索引,得到该组在全局并行环境中的位置(decomposed_group_idx)。

  在每个组内,遍历每个rank(rank_in_group从0到group_size-1)。

  使用decompose(rank_in_group, masked_shape)根据被掩码的并行尺寸分解组内rank,得到该rank在组内的位置(decomposed_rank_idx)。

  计算每个rank的全局索引,通过将被掩码和未被掩码的索引分别与其对应的步长进行内积(inner_product),然后将两个内积相加得到。

  最后,将计算得到的每个组内的rank添加到ranks列表中。

  e. 返回rank列表:

  函数最终返回ranks列表,其中包含了每个组内的所有rank,这些rank在全局并行环境中是唯一的。

  至此分布式训练分组的全部逻辑均介绍完毕,后续文章将继续解析分组完成后的训练逻辑。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.rhkb.cn/news/383588.html

如若内容造成侵权/违法违规/事实不符,请联系长河编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

Linux - 进程的概念、状态、僵尸进程、孤儿进程及进程优先级

目录 进程基本概念 描述进程-PCB task_struct-PCB的一种 task_struct内容分类 查看进程 通过系统目录查看 通过ps命令查看 通过系统调用获取进程的PID和PPID 通过系统调用创建进程- fork初始 fork函数创建子进程 使用if进行分流 Linux进程状态 运行状态-R 浅度睡眠状态-S…

AV1技术学习:Constrained Directional Enhancement Filter

CDEF允许编解码器沿某些(可能是倾斜的)方向应用非线性消阶滤波器。它以88为单位进行。如下图所示,通过旋转和反射所示的三个模板来定义八个预设方向。 Templates of preset directions and their associated directions. The templates correspond to directions of…

Python小工具之httpstat网络分析

一、简介 Python httpstat是一个基于Python的命令行工具,用于测量HTTP请求的性能和状态信息。它能够向目标服务器发送HTTP请求,并显示详细的统计信息,包括DNS解析时间、建立连接时间、TLS/SSL握手时间、首字节时间、总时间等。这些信息对于排…

Mailspring搭建安装教程:打造个性邮件体验

Mailspring搭建安装教程步骤!如何选择电子邮件服务商? Mailspring作为一款功能强大、界面友好的邮件客户端,成为了许多用户的首选。AokSend将为大家提供详细的Mailspring搭建安装教程,帮助您打造个性化的邮件体验。 Mailspring搭…

若依 ruoyi poi Excel合并行的导入

本文仅针对文字相关的合并做了处理 ,图片合并及保存需要另做处理!! 目标:Excel合并行内容的导入 结果: 1. ExcelUtil.java 类,新增方法:判断是否是合并行 /*** 新增 合并行相关代码:…

Java | Leetcode Java题解之第264题丑数II

题目&#xff1a; 题解&#xff1a; class Solution {public int nthUglyNumber(int n) {int[] dp new int[n 1];dp[1] 1;int p2 1, p3 1, p5 1;for (int i 2; i < n; i) {int num2 dp[p2] * 2, num3 dp[p3] * 3, num5 dp[p5] * 5;dp[i] Math.min(Math.min(num2…

Docker-Compose配置zookeeper+KaFka+CMAK简单集群

1. 本地DNS解析管理 # 编辑hosts文件 sudo nano /etc/hosts # 添加以下三个主机IP 192.168.186.77 zoo1 k1 192.168.186.18 zoo2 k2 192.168.186.216 zoo3 k3注&#xff1a;zoo1是192.168.186.77的别名&#xff0c;zoo2是192.168.186.18的别名&#xff0c;zoo3是192.168.186.1…

react中组件间的通信

一、父传子 1.代码展示 import React, { useState } from react;function SonPage(props){ // 子组件const {msg} propsreturn (<div>我是子组件 {msg}</div>) }function App() { // 父组件const [msgText,setMsgText] useState(父传子)return (<div classN…

全国区块链职业技能大赛第八套区块链产品需求分析与方案设计

任务1-1:区块链产品需求分析与方案设计 医疗健康平台中涉及到医院、医生、患者等参与方,他们需要在区块链医疗健康平台中完成账户注册、身份上链、挂号就诊、查询病例等多种业务活动。通过对业务活动的功能分析,可以更好的服务系统的开发流程。基于医疗健康平台系统架构,以…

Vue3可媲美Element Plus Tree组件开发之append节点

在前面的章节&#xff0c;我们完成了可媲美Element Plus Tree组件的基本开发。通过实现各种计算属性&#xff0c;tree数据状态变化引起的视图更新被计算属性所接管了&#xff0c;无需我们再手动做各种遍历、查找以及手动监听操作&#xff0c;这样后续开发高级功能变得易如反掌啦…

kafka架构+原理+源码

1.安装jdk17 sudo yum -y update sudo wget https://download.oracle.com/java/17/latest/jdk-17_linux-x64_bin.rpm sudo yum -y install ./jdk-17_linux-x64_bin.rpm、 sudo java -version 2.安装kafka How to easily install kafka without zookeeper | Aditya’s Blog …

好用的资产管理系统 国内5款资产管理系统排名

选择合适的固定资产管理系统对于企业的资产跟踪和维护至关重要。市场上有许多优秀的资产管理系统&#xff0c;每款系统都有其独特的功能和优势。本文将盘点5个好用的固定资产管理系统排名不分先后&#xff0c;帮助您了解它们的主要特点和适用场景&#xff0c;从而选择最适合您企…

【Java 数据结构】ArrayList类介绍

ArrayList类介绍 初识List接口ArrayList类ArrayList类是什么顺序表的模拟实现初始化增加元素删除元素查找元素修改元素 ArrayList类使用构造方法ArrayList源码阅读常用方法及其注意事项 初识List接口 List 是集合框架中的一个接口, 它的里面包含了一些方法, 例如add(), remove…

JAVA项目样本

学生管理系统SISM-v2.0 项目构建 ebtity 学生类:属性,setter,getter,toString(),构造器… dao层 数据交互,数组CRUD(增删改查) 接口 实现

《Techporters架构搭建》-Day03 功能权限设计

功能权限设计 引言权限介绍什么是权限权限的作用 RBAC概述RBAC的组成RBAC支持的安全原则RBAC模型 基于RBAC的权限设计用户管理角色管理菜单管理部门管理岗位管理 权限系统设计ER图标准RBAC模型表复杂RBAC模型表 多租户架构什么是多租户&#xff1f;多租户特点多租户模型竖井隔离…

汽车免拆诊断案例 | 2014 款上汽名爵 GT 车发动机无法起动

故障现象 一辆2014款上汽名爵GT车&#xff0c;搭载15S4G发动机&#xff0c;累计行驶里程约为18.4万km。该车因左前部发生碰撞事故进厂维修&#xff0c;更换损坏的部件后起动发动机&#xff0c;起动机运转有力&#xff0c;但无着机迹象。用故障检测仪检测&#xff0c;发现无法与…

(leetcode学习)236. 二叉树的最近公共祖先

给定一个二叉树, 找到该树中两个指定节点的最近公共祖先。 百度百科中最近公共祖先的定义为&#xff1a;“对于有根树 T 的两个节点 p、q&#xff0c;最近公共祖先表示为一个节点 x&#xff0c;满足 x 是 p、q 的祖先且 x 的深度尽可能大&#xff08;一个节点也可以是它自己的祖…

【BUG】已解决:TypeError: the JSON object must be str, bytes or bytearray, not dict

已解决&#xff1a;TypeError: the JSON object must be str, bytes or bytearray, not dict 目录 已解决&#xff1a;TypeError: the JSON object must be str, bytes or bytearray, not dict 【常见模块错误】 错误原因&#xff1a; 解决方案&#xff1a; 欢迎来到英杰社区…

2024最新手机软件APP下载排行网站源码 软件下载站PHP源码

源码介绍 这是一款简洁蓝色的手机软件下载应用排行、平台和最新发布网站源码&#xff0c;主要包括主页、APP列表页、APP详情介绍页、新闻资讯列表、新闻详情页、关于我们等模块页面。 软件下载站PHP网站源码&#xff0c;简单的部署上线&#xff0c;访问首页安装程序&#xff…

探索PyMuPDF:Python中的强大PDF处理库

探索PyMuPDF&#xff1a;Python中的强大PDF处理库 背景&#xff1a;为何选择PyMuPDF 在数字化时代&#xff0c;PDF文件因其跨平台的兼容性和对格式的严格保持而成为文档交换的通用格式。然而&#xff0c;处理PDF文件往往需要专门的工具或库。这就是PyMuPDF库的用武之地。PyMuP…