DeepSeek DeepEP学习(三)normal dispatch

上节介绍了normal kernel执行过程中会分成两部分,第一步通过notify_dispatch计算meta信息,然后本节介绍数据dispatch的过程。

notify_dispatch过程中会计算其他所有rank发送给当前rank多少token,写入到host的moe_recv_counter_mapped,还会计算其他所有rdma_rank发送给当前rank多少token,写入到host的moe_recv_rdma_counter_mapped,这里通过cpu端轮询这个值,轮询到之后就可以在notify_dispatch不结束的情况下开始分配显存,做到overlap。

Buffer::internode_dispatch(...) {internode::notify_dispatch(...);while (true) {// Read total countnum_recv_tokens = static_cast<int>(*moe_recv_counter);num_rdma_recv_tokens = static_cast<int>(*moe_recv_rdma_counter);// Read per-expert countbool ready = (num_recv_tokens >= 0) and (num_rdma_recv_tokens >= 0);for (int i = 0; i < num_local_experts and ready; ++ i)ready &= moe_recv_expert_counter[i] >= 0;if (ready)break;}auto recv_x = torch::empty({num_recv_tokens, hidden}, x.options());...internode::dispatch(...);...
}

角色分配

constexpr int kNumDispatchRDMASenderWarps = 7;
SETUP_LAUNCH_CONFIG(num_channels * 2, (kNumDispatchRDMASenderWarps + 1 + NUM_MAX_NVL_PEERS) * 32, stream);__global__ void dispatch(...) {const auto role_meta = [=]() -> std::pair<WarpRole, int> {if (is_forwarder) {if (warp_id < NUM_MAX_NVL_PEERS) {return {WarpRole::kRDMAAndNVLForwarder, (warp_id + channel_id) % NUM_MAX_NVL_PEERS};} else {return {WarpRole::kForwarderCoordinator, warp_id - NUM_MAX_NVL_PEERS};}    } else if (warp_id < kNumDispatchRDMASenderWarps) {return {WarpRole::kRDMASender, -1};} else if (warp_id == kNumDispatchRDMASenderWarps) {return {WarpRole::kRDMASenderCoordinator, -1};} else {return {WarpRole::kNVLReceivers, (warp_id + channel_id - kNumDispatchRDMASenderWarps) % NUM_MAX_NVL_PEERS};}}();auto warp_role = role_meta.first;auto target_rank = role_meta.second; // Not applicable for RDMA senders

dispatch的时候会对每个warp分配角色以执行不同的逻辑,可以看到一共有num_channels * 2个sm,每个sm有7 + 1 + 8个warp,角色分配如下图所示
kRDMASender warp数量为kNumDispatchRDMASenderWarps,即7个,这个角色主要计算token会被发送到哪些rdma_rank,然后将数据拷贝到对应rdma_rank的的buffer。
kRDMASenderCoordinator warp数量为1
kNVLReceivers warp数量为NUM_MAX_NVL_PEERS

kRDMAAndNVLForwarder warp数量为NUM_MAX_NVL_PEERS
kForwarderCoordinator warp数量为1

kRDMASender

首先是将本端的meta信息发送给其他rank,rdma_channel_meta是SymBuffer,对于当前sm的send_buffer保存了向每一个rdma_rank发送的meta信息。假设当前channel为x。
对于[0, 7]之间的lane_id,send_buffer(dst_rdma_rank)[lane_id]表示的是channel[0, x - 1]向dst_rdma_rank的gpu[lane_id]发送的token数
对于[8, 15]之间的lane_id,send_buffer(dst_rdma_rank)[lane_id]表示的是channel[0, x]向dst_rdma_rank的gpu[lane_id]发送的token数
send_buffer(dst_rdma_rank)[16]表示的是channel[0, x - 1]向dst_rdma_rank这个节点发送的token数
send_buffer(dst_rdma_rank)[16]表示的是channel[0, x]向dst_rdma_rank这个节点发送的token数

 auto rdma_channel_meta = SymBuffer<int>(rdma_buffer_ptr, NUM_MAX_NVL_PEERS * 2 + 2, kNumRDMARanks, channel_id, num_channels);if (warp_role == WarpRole::kRDMASender) {for (int dst_rdma_rank = warp_id; dst_rdma_rank < kNumRDMARanks; dst_rdma_rank += kNumDispatchRDMASenderWarps) {if (lane_id < NUM_MAX_NVL_PEERS) {rdma_channel_meta.send_buffer(dst_rdma_rank)[lane_id] = -(channel_id == 0 ? 0 : gbl_channel_prefix_matrix[(dst_rdma_rank * NUM_MAX_NVL_PEERS + lane_id) * num_channels + channel_id - 1]) - 1;} else if (lane_id < NUM_MAX_NVL_PEERS * 2) {rdma_channel_meta.send_buffer(dst_rdma_rank)[lane_id] = -gbl_channel_prefix_matrix[(dst_rdma_rank * NUM_MAX_NVL_PEERS + lane_id - NUM_MAX_NVL_PEERS) * num_channels + channel_id] - 1;} else if (lane_id == NUM_MAX_NVL_PEERS * 2) {rdma_channel_meta.send_buffer(dst_rdma_rank)[lane_id] = -(channel_id == 0 ? 0 : rdma_channel_prefix_matrix[dst_rdma_rank * num_channels + channel_id - 1]) - 1;} else if (lane_id == NUM_MAX_NVL_PEERS * 2 + 1) {rdma_channel_meta.send_buffer(dst_rdma_rank)[lane_id] = -rdma_channel_prefix_matrix[dst_rdma_rank * num_channels + channel_id] - 1;}nvshmemx_int_put_nbi_warp(rdma_channel_meta.recv_buffer(rdma_rank), rdma_channel_meta.send_buffer(dst_rdma_rank), NUM_MAX_NVL_PEERS * 2 + 2,translate_dst_rdma_rank<kLowLatencyMode>(dst_rdma_rank, nvl_rank));}}
}

每个channel处理一段token_start_idx到token_end_idx的连续token,每个warp每次处理一个token,每个warp内部的一个lane对应一个同号卡节点

__global__ void dispatch(...) {auto hidden_bytes = hidden_int4 * sizeof(int4);auto num_bytes_per_rdma_token = get_num_bytes_per_rdma_token(hidden_int4, num_scales, num_topk, num_topk);auto rdma_channel_data = SymBuffer<int8_t>(rdma_buffer_ptr, num_max_rdma_chunked_recv_tokens * num_bytes_per_rdma_token, kNumRDMARanks, channel_id, num_channels);auto rdma_channel_meta = SymBuffer<int>(rdma_buffer_ptr, NUM_MAX_NVL_PEERS * 2 + 2, kNumRDMARanks, channel_id, num_channels);if (warp_role == WarpRole::kRDMASender) {int token_start_idx, token_end_idx;get_channel_task_range(num_tokens, num_channels, channel_id, token_start_idx, token_end_idx);for (int dst_rdma_rank = warp_id; dst_rdma_rank < kNumRDMARanks; dst_rdma_rank += kNumDispatchRDMASenderWarps) {}...}}

rdma_channel_data为SymBuffer,假设一共两个channel,3台机器,当前为channel1,则rdma_channel_data格式如下,send_ptr和recv_ptr指向箭头位置,每一个矩形都是一个类似nccl的fifo,一共有num_max_rdma_chunked_recv_tokens个slot,每个slot存一个token,即num_bytes_per_rdma_token,每个矩形都对应了一个dst_rdma_rank,比如send_ptr开始的第一个矩形表示channel_id为1的warp如果要和rdma_rank 0的机器进行通信,则使用这一个矩形对应的buffer。
在这里插入图片描述

图 1
### 顺序锁 如下send_ptr(0)表示上述的send_ptr中第一块buffer,对应rdma_rank0,是一个fifo,由head,tail指针进行同步。因为一个warp对应一个token,所以这里warp0和warp1处理的两个token可能都会被发送到rdma_rank0,因此对send_ptr(0)的访问存在竞争,因此这里DeepEP引入了一个顺序锁来保证原子性。

在这里插入图片描述

图 2
__global__ void dispatch(...) {__shared__ volatile int rdma_send_next_token_idx;if (warp_role == WarpRole::kRDMASender) {for (token_idx = token_start_idx + warp_id; token_idx < token_end_idx; token_idx += kNumDispatchRDMASenderWarps) {// Acquire sequential lockwhile (lane_id == 0 and rdma_send_next_token_idx != token_idx);__syncwarp();...// Release sequential locklane_id == 0 ? (rdma_send_next_token_idx += 1) : 0;}
}

rdma_send_next_token_idx初始化为0,每次等到rdma_send_next_token_idx等于自己处理的token_idx的时候,说明轮到自己了,这时候才可以访问临界区,释放锁就将rdma_send_next_token_idx + 1。

rdma发送过程中的同步

前边有说到send_ptr(0)是一个fifo,对应发送到rdma_rank0,由head,tail进行同步,生产者就是kRDMASender,假设当前是warp0,正在处理token[x],然后发现这个token需要被发送到所有rdma_rank,那么首先获取每个rdma_rank对应的tail指针,然后判断每一个rdma_rank对应的fifo是否有空间,判断方法就是对比tail是否大于head,如果大于说明有空间,如果每个fifo都有空间,那么将数据拷贝到每个fifo的tail处。这里的消费者其实是对端rdma_rank,当这个fifo中head位置的token被对端处理完成后,对端会通过rdma_write更新这个head,完整的图如下所示
在这里插入图片描述

图 3
然后再看下这一过程的代码 首先获取一个token_idx,每个lane负责一个对端rdma_rank,然后每个lane去is_token_in_rank中获取当前token的mask,即is_token_in_rank_uint64,thread[lane_id]的is_token_in_rank_uint64表示是否需要发送到rdma_rank[lane_id]。 然后开始获取顺序锁,获取到锁之后判断is_token_in_rank_uint64,如果不为0,那么说明需要发送到lane_id这个rdma_rank,然后开始等待fifo中有空闲位置,rdma_tail_idx是这次要填充的位置,如果rdma_tail_idx - cached_rdma_channel_head大于等于fifo中slot数,那么就等待,否则说明有了位置,然后将上一次的tail + 1写入rdma_send_channel_tail,并更新last_rdma_tail_idx,表示上一个token已经写入到了各个fifo。
__global__ void dispatch(...) {if (warp_role == WarpRole::kRDMASender) {for (token_idx = token_start_idx + warp_id; token_idx < token_end_idx; token_idx += kNumDispatchRDMASenderWarps) {// Read RDMA rank existenceuint64_t is_token_in_rank_uint64 = 0; if (lane_id < kNumRDMARanks)is_token_in_rank_uint64 = *reinterpret_cast<const uint64_t*>(is_token_in_rank + token_idx * num_ranks + lane_id * NUM_MAX_NVL_PEERS);// Acquire sequential lockwhile (lane_id == 0 and rdma_send_next_token_idx != token_idx);__syncwarp();// Acquire next tailint rdma_tail_idx = -1;if (is_token_in_rank_uint64 != 0) { rdma_tail_idx = rdma_send_channel_next_tail[lane_id] ++;while (rdma_tail_idx - cached_rdma_channel_head >= num_max_rdma_chunked_recv_tokens)cached_rdma_channel_head = static_cast<int>(ld_volatile_global(rdma_channel_head.buffer(lane_id)));}    __syncwarp();// Update last token tailif (last_rdma_tail_idx >= 0)st_release_cta(const_cast<const int *>(rdma_send_channel_tail + lane_id), last_rdma_tail_idx + 1);last_rdma_tail_idx = rdma_tail_idx;// Release sequential locklane_id == 0 ? (rdma_send_next_token_idx += 1) : 0; }}
}

拷贝数据

现在各个lane_id都拿到了自己的rdma_tail_idx,然后开始拷贝数据,首先需要整个warp知道需要往哪几个lane_id对应的rdma_rank发送数据,然后开始遍历kNumRDMARanks,通过shfl_sync获取第i个lane的rdma_tail_idx,如果rdma_tail_idx大于0,说明需要向rdma_rank[i]发数据,则将i记录到topk_ranks中。
broadcast函数就是通过shfl_sync将lane[i]输入数据广播到所有lane,将第i个rdma_rank对应的fifo地址广播到dst_send_buffers。

__global__ void dispatch(...) {    if (warp_role == WarpRole::kRDMASender) {    for (...) {SourceMeta src_meta;int num_topk_ranks = 0, topk_ranks[kNumTopkRDMARanks];void* dst_send_buffers[kNumTopkRDMARanks];#pragma unrollfor (int i = 0, slot_idx; i < kNumRDMARanks; ++ i) if ((slot_idx = __shfl_sync(0xffffffff, rdma_tail_idx, i)) >= 0) {slot_idx = slot_idx % num_max_rdma_chunked_recv_tokens;topk_ranks[num_topk_ranks] = i;auto recv_is_token_in_rank_uint64 = broadcast(is_token_in_rank_uint64, i);auto recv_is_token_in_rank_values = reinterpret_cast<const bool*>(&recv_is_token_in_rank_uint64);if (lane_id == num_topk_ranks)src_meta = SourceMeta(rdma_rank, recv_is_token_in_rank_values);dst_send_buffers[num_topk_ranks ++] = reinterpret_cast<uint8_t*>(broadcast(send_buffer, i)) + slot_idx * num_bytes_per_rdma_token;}}}
}

然后开始拷贝embedding,UNROLLED_WARP_COPY就是一个warp将数据通过LD_FUNC从SRC读取到unrolled_values,然后通过ST_FUNC将unrolled_values写入DST。
SRC就是用户的输入数据,LD_FUNC就是下边两个PTX中的一个,而ST_FUNC就是st_broadcast,就是将数据写入到dst_send_buffers的每个地址中。

#ifndef DISABLE_AGGRESSIVE_PTX_INSTRS
#define LD_NC_FUNC "ld.global.nc.L1::no_allocate.L2::256B"
#else
#define LD_NC_FUNC "ld.volatile.global"
#endif
__global__ void dispatch(...) {    if (warp_role == WarpRole::kRDMASender) {    for (...) {           // Copy `x` into symmetric send bufferauto st_broadcast = [=](const int key, const int4& value) {#pragma unrollfor (int j = 0; j < num_topk_ranks; ++ j)st_na_global(reinterpret_cast<int4*>(dst_send_buffers[j]) + key, value);};UNROLLED_WARP_COPY(5, lane_id, hidden_int4, 0, x + token_idx * hidden_int4, ld_nc_global, st_broadcast);}}
}#define UNROLLED_WARP_COPY(UNROLL_FACTOR, LANE_ID, N, DST, SRC, LD_FUNC, ST_FUNC) \
{ \         constexpr int kLoopStride = 32 * (UNROLL_FACTOR); \typename std::remove_reference<decltype(LD_FUNC((SRC) + 0))>::type unrolled_values[(UNROLL_FACTOR)]; \auto __src = (SRC); \auto __dst = (DST); \for (int __i = (LANE_ID); __i < ((N) / kLoopStride) * kLoopStride; __i += kLoopStride) { \_Pragma("unroll") \for (int __j = 0; __j < (UNROLL_FACTOR); ++ __j) \unrolled_values[__j] = LD_FUNC(__src + __i + __j * 32); \_Pragma("unroll") \for (int __j = 0; __j < (UNROLL_FACTOR); ++ __j) \ST_FUNC(__dst + __i + __j * 32, unrolled_values[__j]); \} \     for (int __i = ((N) / kLoopStride) * kLoopStride + (LANE_ID); __i < (N); __i += 32) \ST_FUNC(__dst + __i, LD_FUNC(__src + __i)); \
} 

拷贝完数据之后,然后开始拷贝meta data等。

kRDMASenderCoordinator

kRDMASenderCoordinator只有一个warp,管理当前sm到其他所有节点的rdma发送,这个warp也参与到上图中rdma发送的同步过程,完整图如下。
这个warp中的每一个lane对应一个rdma节点,持有这个节点对应fifo的last_issue_tail,表示上一次执行的rdma_write到了哪里。
处理流程就是轮询每一个节点的last_issue_tail,如果发现有数据可以发送,那么发送数据然后更新last_issue_tail,并将last_issue_tail发送到对端节点。
如果对NCCL比较熟悉的话会发现下图的同步流程和NCCL非常像,kRDMASender就是NCCL里的kernel,会拷贝数据到fifo,
kRDMASenderCoordinator相当于proxy线程,轮询到fifo中有新数据之后就执行数据的发送。
在这里插入图片描述

图 4
然后看下代码,回顾上一节,rdma_channel_prefix_matrix记录了当前的每一个channel到每一个dst_rank的token前缀和,这里通过前缀和可以算出num_tokens_to_send,表示当前rank向rdma_rank[lane_id]发送多少token
__global__ void dispatch(...) {    else if (warp_role == WarpRole::kRDMASenderCoordinator) {int num_tokens_to_send = 0; if (lane_id < kNumRDMARanks) {num_tokens_to_send = rdma_channel_prefix_matrix[lane_id * num_channels + channel_id];if (channel_id > 0) num_tokens_to_send -= rdma_channel_prefix_matrix[lane_id * num_channels + channel_id - 1];}    }
}

然后初始化last_issued_tail为0,while循环直到所有lane的num_tokens_to_send为0,即全部发送完成。
然后开始轮询每一个rank,这里通过channel_id + i的方式避免尝试打散对其他机器的访问,dst_rdma_rank就是本次尝试发送的对端节点,然后dst_rdma_rank这个lane通过shfl_sync将num_tokens_to_send广播到所有线程,如果synced_num_tokens_to_send为0,则说明不需要向他发送数据。

__global__ void dispatch(...) {    else if (warp_role == WarpRole::kRDMASenderCoordinator) {int last_issued_tail = 0;while (__any_sync(0xffffffff, num_tokens_to_send > 0)) {for (int i = 0, synced_num_tokens_to_send; i < kNumRDMARanks; ++ i) {int dst_rdma_rank = (i + channel_id) % kNumRDMARanks;synced_num_tokens_to_send = __shfl_sync(0xffffffff, num_tokens_to_send, dst_rdma_rank);if (synced_num_tokens_to_send == 0)continue;...}
}

然后去查看dst_rdma_rank对应的fifo队列,首先将lane[dst_rdma_rank]的last_issued_tail广播给所有lane,表示上次网络发送的流程执行到了哪里,然后读取这个fifo的rdma_send_channel_tail,表示数据生产到了哪里,如果这俩值不想等,说明产生了新数据,num_tokens_processed表示这次需要发送多少个token,这里首先会尝试合并到一个比较大的数据量,如果这次发送的token小于一个阈值,并且不是最后一次发送,那么就不执行发送,等合并到较大规模数据再发送

__global__ void dispatch(...) {    else if (warp_role == WarpRole::kRDMASenderCoordinator) {...auto synced_last_issued_tail = __shfl_sync(0xffffffff, last_issued_tail, dst_rdma_rank);auto processed_tail = ld_acquire_cta(const_cast<const int*>(rdma_send_channel_tail + dst_rdma_rank));auto num_tokens_processed = processed_tail - synced_last_issued_tail;if (num_tokens_processed != synced_num_tokens_to_send and num_tokens_processed < num_max_rdma_chunked_send_tokens)continue;}
}

然后通过synced_last_issued_tail找到fifo中需要发送数据的位置,然后通过nvshmemx_int8_put_nbi_warp执行rdma_write发送数据到对端机器的recv buffer。然后执行fence操作,nvshmem的fence只能保证下发的顺序,不保证完成顺序,因此对于rdma来讲是一个空操作。下发write之后,更新last_issued_tail,通过nvshmemx_signal_op执行一个rdma atomic操作将这次发送的数据量atomic add到对端的rdma_channel_tail,通知对端机器,这里atomic是为了保序,前序的write落盘之后才会执行atomic操作。

 __global__ void dispatch(...) {    else if (warp_role == WarpRole::kRDMASenderCoordinator) {...auto num_tokens_to_issue = min(num_tokens_processed, num_max_rdma_chunked_send_tokens);if (dst_rdma_rank != rdma_rank) {auto dst_slot_idx = synced_last_issued_tail % num_max_rdma_chunked_recv_tokens;EP_DEVICE_ASSERT(dst_slot_idx + num_tokens_to_issue <= num_max_rdma_chunked_recv_tokens);nvshmemx_int8_put_nbi_warp(rdma_channel_data.recv_buffer(rdma_rank) + dst_slot_idx * num_bytes_per_rdma_token,rdma_channel_data.send_buffer(dst_rdma_rank) + dst_slot_idx * num_bytes_per_rdma_token,num_bytes_per_rdma_token * num_tokens_to_issue,translate_dst_rdma_rank<kLowLatencyMode>(dst_rdma_rank, nvl_rank));nvshmem_fence();} else {// Lighter fence for local RDMA rankmemory_fence();}    // Update tails__syncwarp();if (lane_id == dst_rdma_rank) {last_issued_tail += num_tokens_to_issue;num_tokens_to_send -= num_tokens_to_issue;nvshmemx_signal_op(rdma_channel_tail.buffer(rdma_rank), num_tokens_to_issue, NVSHMEM_SIGNAL_ADD,translate_dst_rdma_rank<kLowLatencyMode>(dst_rdma_rank, nvl_rank));}    }    
}

kRDMAAndNVLForwarder

这个角色的作用就是每个节点上负责接收其他节点发送过来的token,并将这些token转发到当前机器上的其他gpu。
到这里完整的rdma发送过程中的同步就可以看到了,如下图所示,左边为发送端,右边为接收端,假设是rank0,假设都是sm0,发送端的Coordinator在执行了write之后会通过rdma atomic写rank0的rdma_channel_tail,rank0发现fifo中有数据之后,开始尝试将数据转发给当前机器上的其他gpu。
在这里插入图片描述

图 5
因此这里又涉及到了rdma_rank0这个机器上机内转发时候的同步,当前sm的一个warp对应一个当前机器的gpu,负责将这个fifo中收到的数据转发到自己的负责的gpu 在角色分配的时候,可以看到kRDMAAndNVLForwarder的target_rank为(warp_id + channel_id) % NUM_MAX_NVL_PEERS,即每个warp对应了一个gpu
const auto role_meta = [=]() -> std::pair<WarpRole, int> {if (is_forwarder) {if (warp_id < NUM_MAX_NVL_PEERS) {return {WarpRole::kRDMAAndNVLForwarder, (warp_id + channel_id) % NUM_MAX_NVL_PEERS};} else {return {WarpRole::kForwarderCoordinator, warp_id - NUM_MAX_NVL_PEERS};}    } else ...}(); auto warp_role = role_meta.first;auto target_rank = role_meta.second

之前有说过,DeepEP初始化的时候,会在每个gpu上分配buffer_ptr,并映射到其他gpu,kRDMAAndNVLForwarder负责数据的发送,假设target_rank为2,当前nvl_rank为1,那么他将拿到gpu2的buffer_ptr,即ws_rr_buffer_ptr,rs_wr_rank为nvl_rank,假设为sm1,那么nvl_channel_x指向ws_rr_buffer_ptr中的位置如下图所示,表示他将写入到gpu2的第一个channel的buffer中src_rank为1的位置。
在这里插入图片描述

图 6
__global__ void dispatch(...) {    if (warp_role == WarpRole::kRDMAAndNVLForwarder) {// NOTES: `rs_wr_buffer_ptr` means "Read for Senders, Write for Receivers", `ws_rr_buffer_ptr` means "Write for Senders, Read for Receivers"rs_wr_buffer_ptr = buffer_ptrs[nvl_rank], ws_rr_buffer_ptr = buffer_ptrs[target_rank], rs_wr_rank = nvl_rank, ws_rr_rank = target_rank;auto nvl_channel_x = AsymBuffer<int4>(ws_rr_buffer_ptr, num_max_nvl_chunked_recv_tokens * hidden_int4, NUM_MAX_NVL_PEERS, channel_id, num_channels, rs_wr_rank).advance_also(rs_wr_buffer_ptr);auto nvl_channel_prefix_start = AsymBuffer<int>(ws_rr_buffer_ptr, kNumRDMARanks, NUM_MAX_NVL_PEERS, channel_id, num_channels, rs_wr_rank).advance_also(rs_wr_buffer_ptr);auto nvl_channel_prefix_end = AsymBuffer<int>(ws_rr_buffer_ptr, kNumRDMARanks, NUM_MAX_NVL_PEERS, channel_id, num_channels, rs_wr_rank).advance_also(rs_wr_buffer_ptr);}
}

然后看下解析send端发送过来的meta data的逻辑,每个lane_id对应一个源节点,轮询rdma_channel_meta直到收到数据,还是假设target_rank为2,当前nvl_rank为1,channel_id为x,那么对于lane[lane_id],会从meta中读取到第lane_id个机器的channel[0, x - 1]发到当前机器上gpu[dst_nvl_rank]的token总数,然后将这个作为写入gpu[dst_nvl_rank]的nvl_channel_prefix_start的lane_id位置,同理对于nvl_channel_prefix_end和rdma数据,num_tokens_to_recv_from_rdma表示当前sm会收到lane_id这个机器发送过来多少token。

__global__ void dispatch(...) {    if (warp_role == WarpRole::kRDMAAndNVLForwarder) {if (lane_id < kNumRDMARanks) {while (true) {auto meta_0 = ld_volatile_global(rdma_channel_meta.recv_buffer(lane_id) + dst_nvl_rank);auto meta_1 = ld_volatile_global(rdma_channel_meta.recv_buffer(lane_id) + NUM_MAX_NVL_PEERS + dst_nvl_rank);auto meta_2 = ld_volatile_global(rdma_channel_meta.recv_buffer(lane_id) + NUM_MAX_NVL_PEERS * 2);auto meta_3 = ld_volatile_global(rdma_channel_meta.recv_buffer(lane_id) + NUM_MAX_NVL_PEERS * 2 + 1);if (meta_0 < 0 and meta_1 < 0 and meta_2 < 0 and meta_3 < 0) {// Notify NVL ranksint start_sum = -meta_0 - 1, end_sum = -meta_1 - 1;EP_DEVICE_ASSERT(start_sum >= 0 and end_sum >= 0 and end_sum >= start_sum);st_relaxed_sys_global(nvl_channel_prefix_start.buffer() + lane_id, -start_sum - 1);st_relaxed_sys_global(nvl_channel_prefix_end.buffer() + lane_id, -end_sum - 1);// Save RDMA channel received token countsrc_rdma_channel_prefix = -meta_2 - 1;auto src_rdma_channel_prefix_1 = -meta_3 - 1;num_tokens_to_recv_from_rdma = src_rdma_channel_prefix_1 - src_rdma_channel_prefix;if (not kCachedMode)recv_rdma_channel_prefix_matrix[lane_id * num_channels + channel_id] = src_rdma_channel_prefix_1;src_rdma_channel_prefix += lane_id == 0 ? 0 : recv_rdma_rank_prefix_sum[lane_id - 1];EP_DEVICE_ASSERT(num_tokens_to_recv_from_rdma >= 0);break;}}}}
}

然后就是开始转发数据,机内数据的转发同步方式也是通过fifo进行同步,判断head/tail可以知道有没有新数据产生,以及buffer是否满了,这个过程和NCCL的机内kernel同步比较相似,如下图所示。
在这里插入图片描述

图 7

首先判断是否可以转发数据到其他gpu,第一个while循环表示所有节点发送过来的数据有没有都处理完,如果没处理结束,然后每个warp开始判断对应的gpu的fifo中是否有空闲,fifo总的容量为num_max_nvl_chunked_recv_tokens,已经使用的空间为tail - head,如果剩余的空间大于等于num_max_nvl_chunked_send_tokens,那么可以发送数据。

__global__ void dispatch(...) {    if (warp_role == WarpRole::kRDMAAndNVLForwarder) {while (__any_sync(0xffffffff, num_tokens_to_recv_from_rdma > 0)) {// Check destination queue emptiness, or wait a buffer to be releasedstart_time = clock64();while (lane_id == 0) { int num_used_slots = cached_nvl_channel_tail - cached_nvl_channel_head;if (num_max_nvl_chunked_recv_tokens - num_used_slots >= num_max_nvl_chunked_send_tokens)break;cached_nvl_channel_head = ld_volatile_global(nvl_channel_head.buffer());}  }}
}

还是假设target_rank为2,当前nvl_rank为1,那么当前warp发现gpu[2]的buffer有空闲之后,开始轮询其他所有的机器,假设这次处理的机器是src_rdma_rank,因为每个lane_id对应了一个节点,因此将lane[src_rdma_rank]的剩余token数广播到warp,如果剩余token大于0,那么开始判断src_rdma_rank有没有新发送数据过来,判断方法就是看rdma_channel_tail和rdma_channel_head是否相等,不相等的话说明有新数据产生。

__global__ void dispatch(...) {    if (warp_role == WarpRole::kRDMAAndNVLForwarder) {while (true) {src_rdma_rank = (src_rdma_rank + 1) % kNumRDMARanks;if (__shfl_sync(0xffffffff, num_tokens_to_recv_from_rdma, src_rdma_rank) > 0) { if (lane_id == src_rdma_rank and cached_rdma_channel_head == cached_rdma_channel_tail)cached_rdma_channel_tail = static_cast<int>(ld_acquire_sys_global(rdma_channel_tail.buffer(src_rdma_rank)));if (__shfl_sync(0xffffffff, cached_rdma_channel_tail > cached_rdma_channel_head, src_rdma_rank))break;}    }    auto src_rdma_head = __shfl_sync(0xffffffff, cached_rdma_channel_head, src_rdma_rank);auto src_rdma_tail = __shfl_sync(0xffffffff, cached_rdma_channel_tail, src_rdma_rank);}
}

然后开始将数据从当前gpu的rdma buffer转发到对端gpu的nvlbuffer。
遍历rdma buffer中的所有新token,即从src_rdma_head到src_rdma_tail,源数据为shifted,从shifted读取meta信息,判断这个token是否需要被转发到dst_nvl_rank,如果不需要就判断下一个token。通过cached_nvl_channel_tail可以算出这次填充到fifo的那个slot。然后开始实际的拷贝数据到nvl_channel_x。

__global__ void dispatch(...) {    if (warp_role == WarpRole::kRDMAAndNVLForwarder) {for (int i = src_rdma_head, num_tokens_sent = 0; i < src_rdma_tail; ++ i) {auto rdma_slot_idx = i % num_max_rdma_chunked_recv_tokens;void* shifted = rdma_channel_data.recv_buffer(src_rdma_rank) + rdma_slot_idx * num_bytes_per_rdma_token;auto src_meta = ld_nc_global(reinterpret_cast<SourceMeta*>(reinterpret_cast<int8_t*>(shifted) + hidden_bytes));lane_id == src_rdma_rank ? (num_tokens_to_recv_from_rdma -= 1) : 0;bool is_in_dst_nvl_rank = src_meta.is_token_in_nvl_rank(dst_nvl_rank);if (lane_id == src_rdma_rank) {auto cached_head = is_in_dst_nvl_rank ? rdma_nvl_token_idx : -1;rdma_nvl_token_idx += is_in_dst_nvl_rank;if (not kCachedMode)send_nvl_head[i * NUM_MAX_NVL_PEERS] = cached_head;}if (not is_in_dst_nvl_rank)continue;int dst_slot_idx = (cached_nvl_channel_tail ++) % num_max_nvl_chunked_recv_tokens;// Copy dataUNROLLED_WARP_COPY(5, lane_id, hidden_int4,nvl_channel_x.buffer() + dst_slot_idx * hidden_int4,reinterpret_cast<int4*>(shifted),ld_nc_global, st_na_global);}//     __shared__ volatile int forward_channel_head[NUM_MAX_NVL_PEERS][kNumRDMARanks];if (lane_id == src_rdma_rank)forward_channel_head[dst_nvl_rank][src_rdma_rank] = (cached_rdma_channel_head = src_rdma_tail);if (lane_id == 0)st_release_sys_global(nvl_channel_tail.buffer(), cached_nvl_channel_tail);}
}

可以看到除了更新了tail指针通知对端gpu有新数据之外,还更新了forward_channel_head,forward_channel_head[dst_nvl_rank][src_rdma_rank]表示从src_rdma_rank过来的数据,被转发到dst_nvl_rank的进度到了哪里,这个下边看kForwarderCoordinator的时候会介绍。

kForwarderCoordinator

kForwarderCoordinator只有一个warp,每个lane对应一个源节点,他做的事情就是将新的head通过rdma write写回到源节点,通知源节点自己已经处理过这个数据了。因为比如源节点写过来一个token之后,这个token可能会被转发给多个gpu,kRDMAAndNVLForwarder中每个warp负责转发这个token到一个gpu,那么只有当所有warp都处理完这个token之后,才能通知源节点。
具体看下代码,每个lane遍历forward_channel_head[0 - NUM_MAX_NVL_PEERS][lane]找到min_head,通过rdma_write将这个min_head发送给节点lane的rdma_channel_head,即图3中的绿线。

__global__ void dispatch(...) {    else if (warp_role == WarpRole::kForwarderCoordinator) {int last_head = 0, target_rdma = lane_id < kNumRDMARanks ? lane_id : 0; while (true) {// Find minimum headint min_head = std::numeric_limits<int>::max();#pragma unrollfor (int i = 0; i < NUM_MAX_NVL_PEERS; ++ i) if (not forward_channel_retired[i])min_head = min(min_head, forward_channel_head[i][target_rdma]);if (__all_sync(0xffffffff, min_head == std::numeric_limits<int>::max()))break;// Update remote headif (min_head != std::numeric_limits<int>::max() and min_head > last_head and lane_id < kNumRDMARanks)nvshmem_uint64_p(rdma_channel_head.buffer(rdma_rank), last_head = min_head,translate_dst_rdma_rank<kLowLatencyMode>(lane_id, nvl_rank));// Nanosleep and let other warps work__nanosleep(NUM_WAIT_NANOSECONDS);}    }
}

kNVLReceivers

kNVLReceivers用于数据的接收,一共有NUM_MAX_NVL_PEERS个warp,每个warp对应一个gpu,即src_nvl_rank,会从这个gpu收数据,每个warp的lane对应一个节点。
然后从recv_gbl_rank_prefix_sum中获取total_offset,recv_gbl_rank_prefix_sum[i]表示前i个rank发送到当前gpu的总token数,因此这里就是获取第lane_id个机器的gpu[src_nvl_rank]对应的rank发送过来的数据的起始位置。
假设当前是第x个channel,nvl_channel_prefix_start表示的是第lane_id个机器的channel[0, x - 1]发送到当前rank的token数,因此当前rank处理lane_id机器发送过来的token数需要被存到[total_offset + prefix_sum_start, total_offset + prefix_sum_end)的区间。
最后通过warp_reduce可以知道其他所有机器的gpu[src_nvl_rank]的同sm发送过来的token总数。

__global__ void dispatch(...) {    else {// NVL consumers// Retrieve rank offset from barrier results (each lane's register stores an RDMA rank)int src_nvl_rank = target_rank, total_offset = 0; if (lane_id < kNumRDMARanks and lane_id * NUM_MAX_NVL_PEERS + src_nvl_rank > 0) total_offset = recv_gbl_rank_prefix_sum[lane_id * NUM_MAX_NVL_PEERS + src_nvl_rank - 1];// Receive channel offsetsint start_offset = 0, end_offset = 0, num_tokens_to_recv;auto start_time = clock64();while (lane_id < kNumRDMARanks) {start_offset = ld_volatile_global(nvl_channel_prefix_start.buffer() + lane_id);end_offset = ld_volatile_global(nvl_channel_prefix_end.buffer() + lane_id);if (start_offset < 0 and end_offset < 0) { start_offset = -start_offset - 1, end_offset = -end_offset - 1; total_offset += start_offset;break;}    }    num_tokens_to_recv = warp_reduce_sum(end_offset - start_offset);}
}

然后通过head和tail判断是否有新数据产生,如果head和tail不相等,说明产生了新的数据。

__global__ void dispatch(...) {    else {int cached_channel_head_idx = 0, cached_channel_tail_idx = 0; while (num_tokens_to_recv > 0) { while (lane_id == 0) { // Ready to copyif (cached_channel_head_idx != cached_channel_tail_idx)break;cached_channel_tail_idx = ld_acquire_sys_global(nvl_channel_tail.buffer());}    }    }
}

然后开始处理从head开始到tail的token,通过head可以知道元数据在buffer中的哪个slot,读取meta信息可以知道数据是哪个src_rdma_rank发送过来的,然后将src_rdma_rank的total_offset广播到整个warp,即recv_token_idx,数据将被放到输出的这个位置,并将total_offset + 1。
然后开始拷贝数据,即将数据从nvl_channel_x中的slot拷贝到recv_x + recv_token_idx * hidden_int4。
同理拷贝scale,topk_idx,topk_weights,最后更新tail。

__global__ void dispatch(...) {    else {int num_recv_tokens = cached_channel_tail_idx - cached_channel_head_idx;for (int chunk_idx = 0; chunk_idx < num_recv_tokens; ++ chunk_idx, -- num_tokens_to_recv) {int token_idx_in_buffer = (cached_channel_head_idx ++) % num_max_nvl_chunked_recv_tokens;auto meta = ld_nc_global(nvl_channel_src_meta.buffer() + token_idx_in_buffer);int64_t recv_token_idx = __shfl_sync(0xffffffff, total_offset, meta.src_rdma_rank);(lane_id == meta.src_rdma_rank) ? (total_offset += 1) : 0; // Copy dataUNROLLED_WARP_COPY(5, lane_id, hidden_int4,recv_x + recv_token_idx * hidden_int4,nvl_channel_x.buffer() + token_idx_in_buffer * hidden_int4,ld_nc_global, st_na_global);// Copy source metaif (lane_id == 0 and not kCachedMode)st_na_global(recv_src_meta + recv_token_idx, meta);// Copy scalesUNROLLED_WARP_COPY(1, lane_id, num_scales,recv_x_scales + recv_token_idx * num_scales,nvl_channel_x_scales.buffer() + token_idx_in_buffer * num_scales,ld_nc_global, st_na_global);// Copy `topk_idx` and `topk_weights`if (lane_id < num_topk) {auto recv_idx = recv_token_idx * num_topk + lane_id;auto buffer_idx = token_idx_in_buffer * num_topk + lane_id;st_na_global(recv_topk_idx + recv_idx, static_cast<int64_t>(ld_nc_global(nvl_channel_topk_idx.buffer() + buffer_idx)));st_na_global(recv_topk_weights + recv_idx, ld_nc_global(nvl_channel_topk_weights.buffer() + buffer_idx));}}// Move queue__syncwarp();if (lane_id == 0)st_relaxed_sys_global(nvl_channel_head.buffer(), cached_channel_head_idx);}
}

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.rhkb.cn/news/30676.html

如若内容造成侵权/违法违规/事实不符,请联系长河编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

mysql创建数据库和表

登录 MySQL 并选择数据库 登录 MySQL 命令行&#xff1a;mysql -u root -p 查看所有数据库&#xff1a;SHOW DATABASES; 创建数据库&#xff1a;CREATE DATABASE my_database; 查看数据库是否创建成功&#xff1a;SHOW DATABASES; 选择数据库&#xff1a;USE 你的数据库名…

Python 实现图片提取文字

文章目录 一、效果图 二、库安装 三、使用示例 四、完整代码 一、效果图 使用的图片&#xff1a; 返回文字&#xff1a; 二、库安装 pip install easyocr opencv-python numpy 三、使用示例 ocr EasyOCRProcessor() results ocr.extract_text("test.png",&…

根目录下的两个包相互没有import成功

问题1&#xff1a; import models 或者import models.Models不成功 问题2&#xff1a;在项目名称后面出现unnamed的提示 查阅资料&#xff0c;说错误可能是&#xff1a;.idea 文件夹配置缺失或损坏 PyCharm 的项目配置信息&#xff08;包括名称&#xff09;存储在 .idea 目录中…

什么样的物联网框架适合开展共享自助KTV唱歌项目?

现在物联网的广泛应用&#xff0c;也让更多用户们看到了它的实力&#xff0c;也使得共享经济遍地开花。其中共享自助唱歌设备也备受欢迎&#xff0c;那么适合开展共享自助KTV唱歌项目的物联网框架都应具备哪些特点呢&#xff1f; 智能化与自动化管理 物联网技术在共享KTV中的应…

《白帽子讲 Web 安全》之深入同源策略(万字详解)

目录 引言 一、同源策略基础认知 &#xff08;一&#xff09;定义 &#xff08;二&#xff09;作用 &#xff08;三&#xff09;作用机制详解 二、同源策略的分类 &#xff08;一&#xff09;域名同源策略 &#xff08;二&#xff09;协议同源策略 &#xff08;三&…

【Linux篇】调试器-gdb/cgdb使用

&#x1f4cc; 个人主页&#xff1a; 孙同学_ &#x1f527; 文章专栏&#xff1a;Liunx &#x1f4a1; 关注我&#xff0c;分享经验&#xff0c;助你少走弯路&#xff01; 文章目录 1. 前言2.关于gdb2.1 快速认识gdb2.2 安装cgdb2.3 gdb命令2.4 调试 & 断点 3.常见技巧3.…

推荐一些免费开源支持Vue3甘特图组件

文章目录 前言一、dhtmlxGantt二、frappe-gantt三、vue-ganttastic四、gantt-elastic五、v-gantt六、vue-gantt-schedule-timeline-calendar七、vue-gantt八、总结 前言 在现代项目管理和任务调度中&#xff0c;甘特图是一种非常实用的工具。它能够直观地展示任务的时间安排、…

十大数据科学Python库

十大数据科学Python库 1、NumPy&#xff1a;脊髓2、Pandas&#xff1a;数据操纵专家3、Matplotlib&#xff1a;艺术之魂4、Scikit-Learn&#xff1a;瑞士军刀5、TensorFlow&#xff1a;聪明的家伙6、PyTorch&#xff1a;叛逆者7、Selenium&#xff1a;操纵大师8、NLTK&#xff…

【C++初阶】类与对象(下)

目录 再探构造函数&#xff1a;初始化列表 使用方法&#xff1a; 特点&#xff1a; &#xff11;、初始化列表是每个成员变量定义初始化的地方 &#xff12;、每一成员变量在初始化列表只出现一次 3、必须在初始化列表中出初始化的成员变量 4、成员变量给缺省值 5、在构…

Android设备是如何进入休眠的呢?

首先我们手机灭屏后&#xff0c;一般需要等一段时间CPU才真正进入休眠。即Android设备屏幕暗下来的时候&#xff0c;并不是立即就进入了休眠模式&#xff1b;当所有唤醒源都处于de-avtive状态后&#xff0c;系统才会进入休眠。在手机功耗中从灭屏开始到CPU进入休眠时间越短&…

多线程知识概述

目录 1. 基本知识概述 2. 多线程概述 2.1 优点 2.2 使用场景 3. 创建线程 3.1 继承 Thread 类 3.2 实现 Runnable 接口 3.3 比较 3.4 创建 Callable 接口 3.5 使用线程池 4. Thread 类常用方法 5. 线程生命周期 6. 线程安全机制 6.1 同步代码块 6.2 同步方法 6.3 …

elasticsearch是哪家的

Elasticsearch&#xff1a;数据搜索与分析的领航者 在当今这个信息爆炸的时代&#xff0c;快速且准确地处理海量数据成为了众多企业和组织追求的目标。而Elasticsearch正是在这个背景下脱颖而出的一款强大的开源搜索引擎。它是由位于美国加利福尼亚州的Elastic公司所开发和维护…

Spring学习笔记:工厂模式与反射机制实现解耦

1.什么是Spring? spring是一个开源轻量级的java开发应用框架&#xff0c;可以简化企业级应用开发 轻量级 1.轻量级(对于运行环境没有额外要求) 2.代码移植性高(不需要实现额外接口) JavaEE的解决方案 Spring更像是一种解决方案&#xff0c;对于控制层&#xff0c;它有Spring…

【一个月备战蓝桥算法】递归与递推

字典序 在刷题和计算机科学领域&#xff0c;字典序&#xff08;Lexicographical order&#xff09;也称为词典序、字典顺序、字母序&#xff0c;是一种对序列元素进行排序的方式&#xff0c;它模仿了字典中单词的排序规则。下面从不同的数据类型来详细解释字典序&#xff1a; …

前端学习——CSS

CSS CSS&#xff08;Cascading Style Sheets&#xff09;级联样式表语法 选择器全局选择器元素选择器类选择器ID选择器合并选择器选择器的优先级 字体属性字体颜色 背景属性background-color属性background-image属性background-repeat属性background-size属性background-posit…

【Python 2D绘图】Matplotlib绘图(统计图表)

【Python 2D绘图】Matplotlib绘图&#xff08;统计图表&#xff09; 1. 概述1.1 简介1.2 安装1.3 导入1.4 保存1.5 数据来源1.5.1 Numpy ndarray1.5.2 Pandas DataFrame 1.6 中文显示 2. 基础样式2.1 颜色2.1.1 简称2.1.2 全称 2.2 布局2.2.1 Matplotlib 画布划分2.2.2 绘制子图…

学习笔记:Python网络编程初探之基本概念(一)

一、网络目的 让你设备上的数据和其他设备上进行共享&#xff0c;使用网络能够把多方链接在一起&#xff0c;然后可以进行数据传递。 网络编程就是&#xff0c;让在不同的电脑上的软件能够进行数据传递&#xff0c;即进程之间的通信。 二、IP地址的作用 用来标记唯一一台电脑…

Spark-TTS:基于大模型的文本语音合成工具

GitHub&#xff1a;https://github.com/SparkAudio/Spark-TTS Spark-TTS是一个先进的文本到语音系统&#xff0c;它利用大型语言模型&#xff08;LLM&#xff09;的强大功能进行高度准确和自然的语音合成&#xff1b;旨在高效、灵活、强大地用于研究和生产用途。 一、介绍 Sp…

【RAG】检索后排序 提高回答精度

问题: RAG中&#xff0c;有时&#xff0c;最合适的答案不一定排在检索的最前面 user_query "how safe is llama 2" search_results vector_db.search(user_query, 5)for doc in search_results[documents][0]:print(doc"\n")response bot.chat(user_qu…

线程安全问题(面试重难点)

这里只是简单介绍以下线程安全,具体情况要结合代码进行判断 线程 是随机调度,及 抢占式执行 ,具有随机性,就可能会让我们的结果出现不同 当我们得到的结果并不是我们想要的时候(不符合需求),就会被认定为BUG,此时就是出现了线程安全问题 那么存在线程不安全的代码就被认为是…