The previous section described how the execution of the normal kernel is split into two steps: the first step computes meta information via notify_dispatch. This section covers the second step, the actual data dispatch.
During notify_dispatch, the kernel computes how many tokens every other rank will send to the current rank and writes the result into the host-mapped moe_recv_counter_mapped; it also computes how many tokens every other rdma_rank will send to the current rank and writes that into moe_recv_rdma_counter_mapped. The CPU polls these values, and as soon as they arrive it can start allocating GPU memory before notify_dispatch has finished, overlapping the allocation with the kernel.
Buffer::internode_dispatch(...) {
    internode::notify_dispatch(...);
    while (true) {
        // Read total count
        num_recv_tokens = static_cast<int>(*moe_recv_counter);
        num_rdma_recv_tokens = static_cast<int>(*moe_recv_rdma_counter);

        // Read per-expert count
        bool ready = (num_recv_tokens >= 0) and (num_rdma_recv_tokens >= 0);
        for (int i = 0; i < num_local_experts and ready; ++ i)
            ready &= moe_recv_expert_counter[i] >= 0;

        if (ready)
            break;
    }
    auto recv_x = torch::empty({num_recv_tokens, hidden}, x.options());
    ...
    internode::dispatch(...);
    ...
}
Role assignment
constexpr int kNumDispatchRDMASenderWarps = 7;
SETUP_LAUNCH_CONFIG(num_channels * 2, (kNumDispatchRDMASenderWarps + 1 + NUM_MAX_NVL_PEERS) * 32, stream);

__global__ void dispatch(...) {
    const auto role_meta = [=]() -> std::pair<WarpRole, int> {
        if (is_forwarder) {
            if (warp_id < NUM_MAX_NVL_PEERS) {
                return {WarpRole::kRDMAAndNVLForwarder, (warp_id + channel_id) % NUM_MAX_NVL_PEERS};
            } else {
                return {WarpRole::kForwarderCoordinator, warp_id - NUM_MAX_NVL_PEERS};
            }
        } else if (warp_id < kNumDispatchRDMASenderWarps) {
            return {WarpRole::kRDMASender, -1};
        } else if (warp_id == kNumDispatchRDMASenderWarps) {
            return {WarpRole::kRDMASenderCoordinator, -1};
        } else {
            return {WarpRole::kNVLReceivers, (warp_id + channel_id - kNumDispatchRDMASenderWarps) % NUM_MAX_NVL_PEERS};
        }
    }();
    auto warp_role = role_meta.first;
    auto target_rank = role_meta.second; // Not applicable for RDMA senders
During dispatch, each warp is assigned a role and runs different logic. There are num_channels * 2 SMs in total and each SM has 7 + 1 + 8 warps; the role assignment is shown in the figure below (and reproduced by the small sketch after the role list).
kRDMASender: kNumDispatchRDMASenderWarps warps, i.e. 7. This role computes which rdma_ranks each token must be sent to and copies the data into the buffer of the corresponding rdma_rank.
kRDMASenderCoordinator: 1 warp
kNVLReceivers: NUM_MAX_NVL_PEERS warps
kRDMAAndNVLForwarder: NUM_MAX_NVL_PEERS warps
kForwarderCoordinator: 1 warp
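To make the mapping concrete, here is a small host-side sketch (illustrative only, not DeepEP code) that tabulates the role of every warp for one channel, assuming NUM_MAX_NVL_PEERS = 8 and kNumDispatchRDMASenderWarps = 7. Whether a given SM runs the forwarder branch is decided by the kernel from the SM index, so both cases are simply enumerated here.

#include <cstdio>

constexpr int NUM_MAX_NVL_PEERS = 8;
constexpr int kNumDispatchRDMASenderWarps = 7;

// Same branching as the role_meta lambda above; `target` is the second member of the pair.
const char* role_of(bool is_forwarder, int warp_id, int channel_id, int& target) {
    target = -1;
    if (is_forwarder) {
        if (warp_id < NUM_MAX_NVL_PEERS) {
            target = (warp_id + channel_id) % NUM_MAX_NVL_PEERS;
            return "kRDMAAndNVLForwarder";
        }
        target = warp_id - NUM_MAX_NVL_PEERS;
        return "kForwarderCoordinator";
    }
    if (warp_id < kNumDispatchRDMASenderWarps)
        return "kRDMASender";
    if (warp_id == kNumDispatchRDMASenderWarps)
        return "kRDMASenderCoordinator";
    target = (warp_id + channel_id - kNumDispatchRDMASenderWarps) % NUM_MAX_NVL_PEERS;
    return "kNVLReceivers";
}

int main() {
    const int channel_id = 0;
    const int num_warps = kNumDispatchRDMASenderWarps + 1 + NUM_MAX_NVL_PEERS;   // 16 warps per SM
    for (int is_forwarder = 0; is_forwarder <= 1; ++ is_forwarder)
        for (int warp_id = 0; warp_id < num_warps; ++ warp_id) {
            int target;
            const char* role = role_of(is_forwarder, warp_id, channel_id, target);
            printf("forwarder_sm=%d warp=%2d -> %-22s target=%2d\n", is_forwarder, warp_id, role, target);
        }
    return 0;
}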
kRDMASender
The first step is to send this rank's meta information to the other ranks. rdma_channel_meta is a SymBuffer; for the current SM, its send_buffer holds the meta information to be sent to each rdma_rank. Assume the current channel is x.
For lane_id in [0, 7], send_buffer(dst_rdma_rank)[lane_id] is the number of tokens that channels [0, x - 1] send to GPU lane_id of dst_rdma_rank.
For lane_id in [8, 15], send_buffer(dst_rdma_rank)[lane_id] is the number of tokens that channels [0, x] send to GPU lane_id - 8 of dst_rdma_rank.
send_buffer(dst_rdma_rank)[16] is the number of tokens that channels [0, x - 1] send to the node dst_rdma_rank as a whole.
send_buffer(dst_rdma_rank)[17] is the number of tokens that channels [0, x] send to the node dst_rdma_rank as a whole.
All four values are written in the encoded form -count - 1; a small encode/decode sketch follows this list.
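Because every encoded entry is strictly negative, the receiver (see the kRDMAAndNVLForwarder section) can poll until the slot turns negative and then decode it. A tiny illustrative sketch, assuming the receive buffer starts out zeroed:

#include <cassert>

inline int encode_meta(int count) { return -count - 1; }   // count >= 0  =>  encoded <= -1
inline int decode_meta(int enc)   { return -enc - 1; }

int main() {
    const int counts[] = {0, 1, 17, 4096};
    for (int count : counts) {
        int enc = encode_meta(count);
        assert(enc < 0);                   // the receiver detects arrival by the sign alone
        assert(decode_meta(enc) == count); // and then recovers the original prefix count
    }
    return 0;
}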
auto rdma_channel_meta = SymBuffer<int>(rdma_buffer_ptr, NUM_MAX_NVL_PEERS * 2 + 2, kNumRDMARanks, channel_id, num_channels);

if (warp_role == WarpRole::kRDMASender) {
    for (int dst_rdma_rank = warp_id; dst_rdma_rank < kNumRDMARanks; dst_rdma_rank += kNumDispatchRDMASenderWarps) {
        if (lane_id < NUM_MAX_NVL_PEERS) {
            rdma_channel_meta.send_buffer(dst_rdma_rank)[lane_id] = -(channel_id == 0 ? 0 : gbl_channel_prefix_matrix[(dst_rdma_rank * NUM_MAX_NVL_PEERS + lane_id) * num_channels + channel_id - 1]) - 1;
        } else if (lane_id < NUM_MAX_NVL_PEERS * 2) {
            rdma_channel_meta.send_buffer(dst_rdma_rank)[lane_id] = -gbl_channel_prefix_matrix[(dst_rdma_rank * NUM_MAX_NVL_PEERS + lane_id - NUM_MAX_NVL_PEERS) * num_channels + channel_id] - 1;
        } else if (lane_id == NUM_MAX_NVL_PEERS * 2) {
            rdma_channel_meta.send_buffer(dst_rdma_rank)[lane_id] = -(channel_id == 0 ? 0 : rdma_channel_prefix_matrix[dst_rdma_rank * num_channels + channel_id - 1]) - 1;
        } else if (lane_id == NUM_MAX_NVL_PEERS * 2 + 1) {
            rdma_channel_meta.send_buffer(dst_rdma_rank)[lane_id] = -rdma_channel_prefix_matrix[dst_rdma_rank * num_channels + channel_id] - 1;
        }
        nvshmemx_int_put_nbi_warp(rdma_channel_meta.recv_buffer(rdma_rank), rdma_channel_meta.send_buffer(dst_rdma_rank), NUM_MAX_NVL_PEERS * 2 + 2,
                                  translate_dst_rdma_rank<kLowLatencyMode>(dst_rdma_rank, nvl_rank));
    }
}
Each channel handles a contiguous range of tokens, token_start_idx to token_end_idx; each warp processes one token at a time, and each lane within a warp corresponds to one RDMA node (the peer GPU with the same local index on that node).
__global__ void dispatch(...) {
    auto hidden_bytes = hidden_int4 * sizeof(int4);
    auto num_bytes_per_rdma_token = get_num_bytes_per_rdma_token(hidden_int4, num_scales, num_topk, num_topk);
    auto rdma_channel_data = SymBuffer<int8_t>(rdma_buffer_ptr, num_max_rdma_chunked_recv_tokens * num_bytes_per_rdma_token, kNumRDMARanks, channel_id, num_channels);
    auto rdma_channel_meta = SymBuffer<int>(rdma_buffer_ptr, NUM_MAX_NVL_PEERS * 2 + 2, kNumRDMARanks, channel_id, num_channels);

    if (warp_role == WarpRole::kRDMASender) {
        int token_start_idx, token_end_idx;
        get_channel_task_range(num_tokens, num_channels, channel_id, token_start_idx, token_end_idx);
        for (int dst_rdma_rank = warp_id; dst_rdma_rank < kNumRDMARanks; dst_rdma_rank += kNumDispatchRDMASenderWarps) {
        }
        ...
    }
}
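get_channel_task_range above is what hands each channel its contiguous slice of tokens. The following is only an illustrative stand-in (the real DeepEP helper may chunk differently); it shows the idea of splitting [0, num_tokens) into roughly equal contiguous ranges:

#include <cstdio>

// Hypothetical helper, not the actual get_channel_task_range implementation.
void channel_task_range(int num_tokens, int num_channels, int channel_id,
                        int& token_start_idx, int& token_end_idx) {
    int per_channel = (num_tokens + num_channels - 1) / num_channels;  // ceil-div
    token_start_idx = channel_id * per_channel;
    token_end_idx = token_start_idx + per_channel;
    if (token_start_idx > num_tokens) token_start_idx = num_tokens;
    if (token_end_idx > num_tokens) token_end_idx = num_tokens;
}

int main() {
    int start, end;
    for (int channel_id = 0; channel_id < 4; ++ channel_id) {
        channel_task_range(10, 4, channel_id, start, end);
        printf("channel %d: tokens [%d, %d)\n", channel_id, start, end);  // [0,3) [3,6) [6,9) [9,10)
    }
    return 0;
}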
rdma_channel_data is a SymBuffer. Suppose there are two channels and three machines and the current channel is channel 1; then rdma_channel_data is laid out as shown below, with send_ptr and recv_ptr pointing at the arrows. Each rectangle is a FIFO similar to NCCL's, with num_max_rdma_chunked_recv_tokens slots, each slot holding one token of num_bytes_per_rdma_token bytes. Each rectangle corresponds to one dst_rdma_rank; for example, the first rectangle after send_ptr is the buffer that a channel-1 warp uses when communicating with the machine whose rdma_rank is 0.
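As a sketch of the slot arithmetic inside one of these rectangles (a hypothetical helper; in the kernel the base pointer comes from rdma_channel_data.send_buffer(dst_rdma_rank), and how SymBuffer interleaves channels is ignored here):

#include <cstdint>

int8_t* slot_ptr(int8_t* fifo_base,                 // e.g. send_buffer(dst_rdma_rank)
                 int64_t tail,                      // monotonically increasing slot index
                 int num_max_rdma_chunked_recv_tokens,
                 int num_bytes_per_rdma_token) {
    // The FIFO is a ring of num_max_rdma_chunked_recv_tokens fixed-size slots,
    // one RDMA token (hidden + scales + topk + meta) per slot.
    int slot_idx = static_cast<int>(tail % num_max_rdma_chunked_recv_tokens);
    return fifo_base + static_cast<int64_t>(slot_idx) * num_bytes_per_rdma_token;
}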
__global__ void dispatch(...) {
    __shared__ volatile int rdma_send_next_token_idx;

    if (warp_role == WarpRole::kRDMASender) {
        for (token_idx = token_start_idx + warp_id; token_idx < token_end_idx; token_idx += kNumDispatchRDMASenderWarps) {
            // Acquire sequential lock
            while (lane_id == 0 and rdma_send_next_token_idx != token_idx);
            __syncwarp();
            ...
            // Release sequential lock
            lane_id == 0 ? (rdma_send_next_token_idx += 1) : 0;
        }
    }
}
rdma_send_next_token_idx is initialized to 0. A warp waits until rdma_send_next_token_idx equals the token_idx it is handling, which means it is its turn and it may enter the critical section; releasing the lock simply increments rdma_send_next_token_idx by 1. A standalone sketch of this sequential lock follows.
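Here is a runnable, self-contained sketch of the same sequential-lock idea (kernel and names are illustrative, not DeepEP code): warps of one block must enter the critical section in token order, coordinated only through a shared volatile counter.

#include <cstdio>
#include <cuda_runtime.h>

constexpr int kNumWarps = 4;

__global__ void sequential_lock_demo(int num_tokens, int* order_out) {
    // Shared "ticket": the token index whose owner warp may enter the critical section next.
    __shared__ volatile int next_token_idx;
    const int warp_id = static_cast<int>(threadIdx.x) / 32;
    const int lane_id = static_cast<int>(threadIdx.x) % 32;
    if (threadIdx.x == 0)
        next_token_idx = 0;
    __syncthreads();

    for (int token_idx = warp_id; token_idx < num_tokens; token_idx += kNumWarps) {
        // Acquire: lane 0 spins until it is this warp's turn
        while (lane_id == 0 and next_token_idx != token_idx);
        __syncwarp();

        // Critical section: because of the lock, tokens are processed strictly in order,
        // so this records the identity permutation
        if (lane_id == 0)
            order_out[token_idx] = token_idx;
        __syncwarp();

        // Release: hand the lock to the warp owning token_idx + 1
        if (lane_id == 0)
            next_token_idx = token_idx + 1;
    }
}

int main() {
    const int num_tokens = 16;
    int* order = nullptr;
    cudaMallocManaged(&order, num_tokens * sizeof(int));
    sequential_lock_demo<<<1, kNumWarps * 32>>>(num_tokens, order);
    cudaDeviceSynchronize();
    for (int i = 0; i < num_tokens; ++ i)
        printf("%d ", order[i]);   // prints 0 1 2 ... 15
    printf("\n");
    cudaFree(order);
    return 0;
}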
Synchronization of the RDMA send
As mentioned above, send_ptr(0) is a FIFO whose contents go to rdma_rank 0, synchronized through a head and a tail. The producer is the kRDMASender. Suppose warp 0 is processing token[x] and finds that this token must be sent to every rdma_rank: it first obtains the tail of every rdma_rank's FIFO and then checks whether each FIFO still has free space, i.e. whether tail - head is smaller than the FIFO capacity (num_max_rdma_chunked_recv_tokens). If every FIFO has space, the data is copied to each FIFO's tail. The consumer here is really the remote rdma_rank: once the token at the head of this FIFO has been processed on the remote side, the remote side advances the head via an RDMA write. The complete picture is shown below.
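Head and tail on this path are monotonically increasing counters (only reduced modulo the capacity when indexing a slot), so the occupancy check reduces to a subtraction. A minimal sketch:

// Sketch of the ring-buffer occupancy test: `head` and `tail` are monotonic
// counters, so `tail - head` is the number of slots currently in flight.
bool fifo_has_free_slot(int head, int tail, int capacity /* num_max_rdma_chunked_recv_tokens */) {
    // The producer may claim slot `tail` only while fewer than `capacity`
    // slots are outstanding; otherwise it must re-read the remote head.
    return tail - head < capacity;
}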
__global__ void dispatch(...) {
    if (warp_role == WarpRole::kRDMASender) {
        for (token_idx = token_start_idx + warp_id; token_idx < token_end_idx; token_idx += kNumDispatchRDMASenderWarps) {
            // Read RDMA rank existence
            uint64_t is_token_in_rank_uint64 = 0;
            if (lane_id < kNumRDMARanks)
                is_token_in_rank_uint64 = *reinterpret_cast<const uint64_t*>(is_token_in_rank + token_idx * num_ranks + lane_id * NUM_MAX_NVL_PEERS);

            // Acquire sequential lock
            while (lane_id == 0 and rdma_send_next_token_idx != token_idx);
            __syncwarp();

            // Acquire next tail
            int rdma_tail_idx = -1;
            if (is_token_in_rank_uint64 != 0) {
                rdma_tail_idx = rdma_send_channel_next_tail[lane_id] ++;
                while (rdma_tail_idx - cached_rdma_channel_head >= num_max_rdma_chunked_recv_tokens)
                    cached_rdma_channel_head = static_cast<int>(ld_volatile_global(rdma_channel_head.buffer(lane_id)));
            }
            __syncwarp();

            // Update last token tail
            if (last_rdma_tail_idx >= 0)
                st_release_cta(const_cast<const int*>(rdma_send_channel_tail + lane_id), last_rdma_tail_idx + 1);
            last_rdma_tail_idx = rdma_tail_idx;

            // Release sequential lock
            lane_id == 0 ? (rdma_send_next_token_idx += 1) : 0;
        }
    }
}
Copying the data
Now every lane has its own rdma_tail_idx and the copy can start. First the whole warp needs to know which lanes (i.e. which rdma_ranks) the token must be sent to, so it iterates over kNumRDMARanks and uses shfl_sync to fetch the i-th lane's rdma_tail_idx; if that value is >= 0, the token must be sent to rdma_rank[i], and i is recorded in topk_ranks.
The broadcast function uses shfl_sync to broadcast lane i's value to all lanes; here it broadcasts the i-th rdma_rank's FIFO address into dst_send_buffers.
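A sketch of what such a broadcast helper can look like (an assumed implementation in the spirit of DeepEP's broadcast, not a copy of it): lane root's copy of an arbitrary 4-byte-multiple value is shuffled to every lane, 32 bits at a time.

// Hypothetical warp-wide broadcast: every lane receives lane `root`'s copy of `value`.
template <typename T>
__device__ __forceinline__ T warp_broadcast(const T& value, int root) {
    static_assert(sizeof(T) % sizeof(int) == 0, "T must be a multiple of 4 bytes");
    T result;
    auto src = reinterpret_cast<const int*>(&value);
    auto dst = reinterpret_cast<int*>(&result);
    #pragma unroll
    for (int i = 0; i < static_cast<int>(sizeof(T) / sizeof(int)); ++ i)
        dst[i] = __shfl_sync(0xffffffff, src[i], root);   // 32-bit shuffle from lane `root`
    return result;
}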
__global__ void dispatch(...) {
    if (warp_role == WarpRole::kRDMASender) {
        for (...) {
            SourceMeta src_meta;
            int num_topk_ranks = 0, topk_ranks[kNumTopkRDMARanks];
            void* dst_send_buffers[kNumTopkRDMARanks];
            #pragma unroll
            for (int i = 0, slot_idx; i < kNumRDMARanks; ++ i)
                if ((slot_idx = __shfl_sync(0xffffffff, rdma_tail_idx, i)) >= 0) {
                    slot_idx = slot_idx % num_max_rdma_chunked_recv_tokens;
                    topk_ranks[num_topk_ranks] = i;
                    auto recv_is_token_in_rank_uint64 = broadcast(is_token_in_rank_uint64, i);
                    auto recv_is_token_in_rank_values = reinterpret_cast<const bool*>(&recv_is_token_in_rank_uint64);
                    if (lane_id == num_topk_ranks)
                        src_meta = SourceMeta(rdma_rank, recv_is_token_in_rank_values);
                    dst_send_buffers[num_topk_ranks ++] = reinterpret_cast<uint8_t*>(broadcast(send_buffer, i)) + slot_idx * num_bytes_per_rdma_token;
                }
        }
    }
}
Then the embedding is copied. UNROLLED_WARP_COPY is a macro in which a warp reads data from SRC into unrolled_values via LD_FUNC and then writes unrolled_values to DST via ST_FUNC.
SRC is the user's input data; LD_FUNC is one of the two PTX loads below, and ST_FUNC is st_broadcast, which writes the data to every address in dst_send_buffers.
#ifndef DISABLE_AGGRESSIVE_PTX_INSTRS
#define LD_NC_FUNC "ld.global.nc.L1::no_allocate.L2::256B"
#else
#define LD_NC_FUNC "ld.volatile.global"
#endif
__global__ void dispatch(...) {
    if (warp_role == WarpRole::kRDMASender) {
        for (...) {
            // Copy `x` into symmetric send buffer
            auto st_broadcast = [=](const int key, const int4& value) {
                #pragma unroll
                for (int j = 0; j < num_topk_ranks; ++ j)
                    st_na_global(reinterpret_cast<int4*>(dst_send_buffers[j]) + key, value);
            };
            UNROLLED_WARP_COPY(5, lane_id, hidden_int4, 0, x + token_idx * hidden_int4, ld_nc_global, st_broadcast);
        }
    }
}

#define UNROLLED_WARP_COPY(UNROLL_FACTOR, LANE_ID, N, DST, SRC, LD_FUNC, ST_FUNC) \
{ \
    constexpr int kLoopStride = 32 * (UNROLL_FACTOR); \
    typename std::remove_reference<decltype(LD_FUNC((SRC) + 0))>::type unrolled_values[(UNROLL_FACTOR)]; \
    auto __src = (SRC); \
    auto __dst = (DST); \
    for (int __i = (LANE_ID); __i < ((N) / kLoopStride) * kLoopStride; __i += kLoopStride) { \
        _Pragma("unroll") \
        for (int __j = 0; __j < (UNROLL_FACTOR); ++ __j) \
            unrolled_values[__j] = LD_FUNC(__src + __i + __j * 32); \
        _Pragma("unroll") \
        for (int __j = 0; __j < (UNROLL_FACTOR); ++ __j) \
            ST_FUNC(__dst + __i + __j * 32, unrolled_values[__j]); \
    } \
    for (int __i = ((N) / kLoopStride) * kLoopStride + (LANE_ID); __i < (N); __i += 32) \
        ST_FUNC(__dst + __i, LD_FUNC(__src + __i)); \
}
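A minimal usage sketch of the macro (not from DeepEP): one warp copies n int4 elements, with plain loads/stores standing in for the PTX wrappers ld_nc_global / st_na_global; ld_plain and st_plain are hypothetical helpers.

#include <type_traits>

__device__ __forceinline__ int4 ld_plain(const int4* ptr) { return *ptr; }
__device__ __forceinline__ void st_plain(int4* ptr, const int4& value) { *ptr = value; }

__global__ void warp_copy_demo(const int4* src, int4* dst, int n) {
    const int lane_id = static_cast<int>(threadIdx.x) % 32;
    // With UNROLL_FACTOR = 5, each main-loop iteration moves 32 * 5 int4s per warp;
    // the tail loop handles the remaining n % 160 elements.
    UNROLLED_WARP_COPY(5, lane_id, n, dst, src, ld_plain, st_plain);
}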
After the data itself is copied, the meta data, scales, and so on are copied in the same way.
kRDMASenderCoordinator
kRDMASenderCoordinator is a single warp that manages all of this SM's RDMA sends to the other nodes. This warp also takes part in the RDMA-send synchronization shown above; the complete picture is below.
Each lane of this warp corresponds to one RDMA node and holds that node's last_issued_tail, i.e. how far the previously issued RDMA writes have progressed in the FIFO.
The processing loop polls each node's last_issued_tail; whenever new data is ready to send, it issues the send, advances last_issued_tail, and notifies the remote node by advancing its rdma_channel_tail.
If you are familiar with NCCL, the synchronization flow in the figure below will look very familiar: kRDMASender plays the role of the NCCL kernel, copying data into the FIFO, while kRDMASenderCoordinator acts like the proxy thread, issuing the actual network send once it sees new data in the FIFO.
__global__ void dispatch(...) {
    else if (warp_role == WarpRole::kRDMASenderCoordinator) {
        int num_tokens_to_send = 0;
        if (lane_id < kNumRDMARanks) {
            num_tokens_to_send = rdma_channel_prefix_matrix[lane_id * num_channels + channel_id];
            if (channel_id > 0)
                num_tokens_to_send -= rdma_channel_prefix_matrix[lane_id * num_channels + channel_id - 1];
        }
    }
}
last_issued_tail is then initialized to 0, and the warp loops until num_tokens_to_send is 0 on every lane, i.e. everything has been sent.
Inside the loop it polls each rank; the (i + channel_id) % kNumRDMARanks indexing staggers which remote machine each channel visits first, spreading out the accesses to other machines. dst_rdma_rank is the remote node tried in this iteration; the lane for dst_rdma_rank broadcasts its num_tokens_to_send to all threads via shfl_sync, and if synced_num_tokens_to_send is 0 there is nothing left to send to that node.
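The effect of the (i + channel_id) % kNumRDMARanks staggering can be seen with a tiny host-side loop (illustrative numbers only):

#include <cstdio>

int main() {
    const int kNumRDMARanks = 4, num_channels = 2;
    for (int channel_id = 0; channel_id < num_channels; ++ channel_id) {
        printf("channel %d visits:", channel_id);
        for (int i = 0; i < kNumRDMARanks; ++ i)
            printf(" %d", (i + channel_id) % kNumRDMARanks);   // channel 0: 0 1 2 3, channel 1: 1 2 3 0
        printf("\n");
    }
    return 0;
}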
__global__ void dispatch(...) {
    else if (warp_role == WarpRole::kRDMASenderCoordinator) {
        int last_issued_tail = 0;
        while (__any_sync(0xffffffff, num_tokens_to_send > 0)) {
            for (int i = 0, synced_num_tokens_to_send; i < kNumRDMARanks; ++ i) {
                int dst_rdma_rank = (i + channel_id) % kNumRDMARanks;
                synced_num_tokens_to_send = __shfl_sync(0xffffffff, num_tokens_to_send, dst_rdma_rank);
                if (synced_num_tokens_to_send == 0)
                    continue;
                ...
            }
        }
    }
}
Next it inspects dst_rdma_rank's FIFO. It first broadcasts lane[dst_rdma_rank]'s last_issued_tail to all lanes, indicating how far the previous network sends have progressed, then reads that FIFO's rdma_send_channel_tail, indicating how far data has been produced. If the two values differ, new data has been produced, and num_tokens_processed is the number of tokens that could be sent this round. The coordinator tries to coalesce sends into larger chunks: if the batch is smaller than the threshold (num_max_rdma_chunked_send_tokens) and it is not the final batch, it skips the send and waits until more data has accumulated.
__global__ void dispatch(...) {
    else if (warp_role == WarpRole::kRDMASenderCoordinator) {
        ...
        auto synced_last_issued_tail = __shfl_sync(0xffffffff, last_issued_tail, dst_rdma_rank);
        auto processed_tail = ld_acquire_cta(const_cast<const int*>(rdma_send_channel_tail + dst_rdma_rank));
        auto num_tokens_processed = processed_tail - synced_last_issued_tail;
        if (num_tokens_processed != synced_num_tokens_to_send and num_tokens_processed < num_max_rdma_chunked_send_tokens)
            continue;
    }
}
synced_last_issued_tail then locates the data to send inside the FIFO, and nvshmemx_int8_put_nbi_warp issues an RDMA write of that data into the remote machine's recv buffer. A fence follows; nvshmem's fence only guarantees issue order, not completion order, so for RDMA it is effectively a no-op here. After issuing the write, last_issued_tail is advanced, and nvshmemx_signal_op performs an RDMA atomic add of the number of tokens just sent onto the remote rdma_channel_tail, notifying the remote machine. The atomic is used for ordering: it only takes effect after the preceding writes have landed on the remote side.
__global__ void dispatch(...) {
    else if (warp_role == WarpRole::kRDMASenderCoordinator) {
        ...
        auto num_tokens_to_issue = min(num_tokens_processed, num_max_rdma_chunked_send_tokens);
        if (dst_rdma_rank != rdma_rank) {
            auto dst_slot_idx = synced_last_issued_tail % num_max_rdma_chunked_recv_tokens;
            EP_DEVICE_ASSERT(dst_slot_idx + num_tokens_to_issue <= num_max_rdma_chunked_recv_tokens);
            nvshmemx_int8_put_nbi_warp(rdma_channel_data.recv_buffer(rdma_rank) + dst_slot_idx * num_bytes_per_rdma_token,
                                       rdma_channel_data.send_buffer(dst_rdma_rank) + dst_slot_idx * num_bytes_per_rdma_token,
                                       num_bytes_per_rdma_token * num_tokens_to_issue,
                                       translate_dst_rdma_rank<kLowLatencyMode>(dst_rdma_rank, nvl_rank));
            nvshmem_fence();
        } else {
            // Lighter fence for local RDMA rank
            memory_fence();
        }

        // Update tails
        __syncwarp();
        if (lane_id == dst_rdma_rank) {
            last_issued_tail += num_tokens_to_issue;
            num_tokens_to_send -= num_tokens_to_issue;
            nvshmemx_signal_op(rdma_channel_tail.buffer(rdma_rank), num_tokens_to_issue, NVSHMEM_SIGNAL_ADD,
                               translate_dst_rdma_rank<kLowLatencyMode>(dst_rdma_rank, nvl_rank));
        }
    }
}
kRDMAAndNVLForwarder
On each node, this role receives the tokens sent over from other nodes and forwards them to the other GPUs on the local machine.
At this point the complete RDMA-send synchronization is visible, as shown in the figure below: the left side is the sender and the right side the receiver. Assume the receiver is rank 0 and both sides use SM 0. After the sender's coordinator issues the write, it advances rank 0's rdma_channel_tail via an RDMA atomic; once rank 0 sees data in the FIFO, it starts forwarding that data to the other GPUs on its machine.
const auto role_meta = [=]() -> std::pair<WarpRole, int> {
    if (is_forwarder) {
        if (warp_id < NUM_MAX_NVL_PEERS) {
            return {WarpRole::kRDMAAndNVLForwarder, (warp_id + channel_id) % NUM_MAX_NVL_PEERS};
        } else {
            return {WarpRole::kForwarderCoordinator, warp_id - NUM_MAX_NVL_PEERS};
        }
    } else ...
}();
auto warp_role = role_meta.first;
auto target_rank = role_meta.second;
As described earlier, DeepEP allocates a buffer_ptr on every GPU at initialization and maps it into the other GPUs. kRDMAAndNVLForwarder performs the intra-node send. Suppose target_rank is 2 and the current nvl_rank is 1: the warp obtains GPU 2's buffer_ptr as ws_rr_buffer_ptr, and rs_wr_rank is nvl_rank. Assuming SM 1, nvl_channel_x then points into ws_rr_buffer_ptr at the position shown in the figure below, i.e. it will write into GPU 2's buffer, into the first channel's region at the slot reserved for src_rank 1.
__global__ void dispatch(...) {
    if (warp_role == WarpRole::kRDMAAndNVLForwarder) {
        // NOTES: `rs_wr_buffer_ptr` means "Read for Senders, Write for Receivers", `ws_rr_buffer_ptr` means "Write for Senders, Read for Receivers"
        rs_wr_buffer_ptr = buffer_ptrs[nvl_rank], ws_rr_buffer_ptr = buffer_ptrs[target_rank], rs_wr_rank = nvl_rank, ws_rr_rank = target_rank;
        auto nvl_channel_x = AsymBuffer<int4>(ws_rr_buffer_ptr, num_max_nvl_chunked_recv_tokens * hidden_int4, NUM_MAX_NVL_PEERS,
                                              channel_id, num_channels, rs_wr_rank).advance_also(rs_wr_buffer_ptr);
        auto nvl_channel_prefix_start = AsymBuffer<int>(ws_rr_buffer_ptr, kNumRDMARanks, NUM_MAX_NVL_PEERS,
                                                        channel_id, num_channels, rs_wr_rank).advance_also(rs_wr_buffer_ptr);
        auto nvl_channel_prefix_end = AsymBuffer<int>(ws_rr_buffer_ptr, kNumRDMARanks, NUM_MAX_NVL_PEERS,
                                                      channel_id, num_channels, rs_wr_rank).advance_also(rs_wr_buffer_ptr);
    }
}
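Conceptually, the target GPU's buffer holds one token ring per (channel, writing NVL rank) pair, and the constructor arguments above select the ring for this channel and this writer. The sketch below only illustrates that indexing idea; the actual AsymBuffer layout in DeepEP (including the ordering of the channel and rank dimensions and the extra regions chained via advance_also) may differ.

#include <cstdint>

// Hypothetical index helper: pick the ring for (channel_id, src_nvl_rank) inside the
// receiver's buffer. The dimension ordering here is an assumption, not DeepEP's layout.
int8_t* nvl_ring_base(int8_t* target_gpu_buffer,          // buffer_ptrs[target_rank]
                      int channel_id, int src_nvl_rank,   // src_nvl_rank == rs_wr_rank (the writer)
                      int num_max_nvl_chunked_recv_tokens,
                      int64_t bytes_per_token,            // e.g. hidden_int4 * sizeof(int4)
                      int num_nvl_peers /* NUM_MAX_NVL_PEERS */) {
    int64_t ring_index = static_cast<int64_t>(channel_id) * num_nvl_peers + src_nvl_rank;
    return target_gpu_buffer + ring_index * num_max_nvl_chunked_recv_tokens * bytes_per_token;
}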
Next comes the logic that parses the meta data sent by the sender. Each lane corresponds to one source node and polls rdma_channel_meta until the data arrives. Still assuming target_rank is 2, nvl_rank is 1 and the channel is x: lane[lane_id] reads from the meta the total number of tokens that channels [0, x - 1] of machine lane_id send to GPU dst_nvl_rank on the current machine, and writes this value into position lane_id of nvl_channel_prefix_start in GPU dst_nvl_rank's buffer; the same applies to nvl_channel_prefix_end and to the node-level RDMA counts. num_tokens_to_recv_from_rdma is the number of tokens the current SM will receive from machine lane_id.
__global__ void dispatch(...) {
    if (warp_role == WarpRole::kRDMAAndNVLForwarder) {
        if (lane_id < kNumRDMARanks) {
            while (true) {
                auto meta_0 = ld_volatile_global(rdma_channel_meta.recv_buffer(lane_id) + dst_nvl_rank);
                auto meta_1 = ld_volatile_global(rdma_channel_meta.recv_buffer(lane_id) + NUM_MAX_NVL_PEERS + dst_nvl_rank);
                auto meta_2 = ld_volatile_global(rdma_channel_meta.recv_buffer(lane_id) + NUM_MAX_NVL_PEERS * 2);
                auto meta_3 = ld_volatile_global(rdma_channel_meta.recv_buffer(lane_id) + NUM_MAX_NVL_PEERS * 2 + 1);
                if (meta_0 < 0 and meta_1 < 0 and meta_2 < 0 and meta_3 < 0) {
                    // Notify NVL ranks
                    int start_sum = -meta_0 - 1, end_sum = -meta_1 - 1;
                    EP_DEVICE_ASSERT(start_sum >= 0 and end_sum >= 0 and end_sum >= start_sum);
                    st_relaxed_sys_global(nvl_channel_prefix_start.buffer() + lane_id, -start_sum - 1);
                    st_relaxed_sys_global(nvl_channel_prefix_end.buffer() + lane_id, -end_sum - 1);

                    // Save RDMA channel received token count
                    src_rdma_channel_prefix = -meta_2 - 1;
                    auto src_rdma_channel_prefix_1 = -meta_3 - 1;
                    num_tokens_to_recv_from_rdma = src_rdma_channel_prefix_1 - src_rdma_channel_prefix;
                    if (not kCachedMode)
                        recv_rdma_channel_prefix_matrix[lane_id * num_channels + channel_id] = src_rdma_channel_prefix_1;
                    src_rdma_channel_prefix += lane_id == 0 ? 0 : recv_rdma_rank_prefix_sum[lane_id - 1];
                    EP_DEVICE_ASSERT(num_tokens_to_recv_from_rdma >= 0);
                    break;
                }
            }
        }
    }
}
Then the forwarding of data begins. Intra-node forwarding is also synchronized through a FIFO: comparing head and tail tells whether new data has been produced and whether the buffer is full. This is quite similar to NCCL's intra-node kernel synchronization, as shown in the figure below.
First the warp checks whether it may forward data to its target GPU. The outer while loop checks whether the data from all nodes has been fully processed; if not, each warp checks whether its GPU's FIFO has free space: the total capacity is num_max_nvl_chunked_recv_tokens, the space in use is tail - head, and if the remaining space is at least num_max_nvl_chunked_send_tokens the warp may send.
__global__ void dispatch(...) {
    if (warp_role == WarpRole::kRDMAAndNVLForwarder) {
        while (__any_sync(0xffffffff, num_tokens_to_recv_from_rdma > 0)) {
            // Check destination queue emptiness, or wait a buffer to be released
            start_time = clock64();
            while (lane_id == 0) {
                int num_used_slots = cached_nvl_channel_tail - cached_nvl_channel_head;
                if (num_max_nvl_chunked_recv_tokens - num_used_slots >= num_max_nvl_chunked_send_tokens)
                    break;
                cached_nvl_channel_head = ld_volatile_global(nvl_channel_head.buffer());
            }
        }
    }
}
Still assuming target_rank is 2 and nvl_rank is 1: once the warp sees that GPU 2's buffer has free space, it starts polling all the other machines. Suppose the machine handled in this round is src_rdma_rank. Since each lane corresponds to one node, lane[src_rdma_rank]'s remaining token count is broadcast to the warp; if it is greater than 0, the warp checks whether src_rdma_rank has sent new data by comparing rdma_channel_tail with rdma_channel_head: if they differ, new data has arrived.
__global__ void dispatch(...) {
    if (warp_role == WarpRole::kRDMAAndNVLForwarder) {
        while (true) {
            src_rdma_rank = (src_rdma_rank + 1) % kNumRDMARanks;
            if (__shfl_sync(0xffffffff, num_tokens_to_recv_from_rdma, src_rdma_rank) > 0) {
                if (lane_id == src_rdma_rank and cached_rdma_channel_head == cached_rdma_channel_tail)
                    cached_rdma_channel_tail = static_cast<int>(ld_acquire_sys_global(rdma_channel_tail.buffer(src_rdma_rank)));
                if (__shfl_sync(0xffffffff, cached_rdma_channel_tail > cached_rdma_channel_head, src_rdma_rank))
                    break;
            }
        }
        auto src_rdma_head = __shfl_sync(0xffffffff, cached_rdma_channel_head, src_rdma_rank);
        auto src_rdma_tail = __shfl_sync(0xffffffff, cached_rdma_channel_tail, src_rdma_rank);
    }
}
The data is then forwarded from the current GPU's RDMA buffer to the target GPU's NVLink buffer.
The warp iterates over all the new tokens in the RDMA buffer, i.e. from src_rdma_head to src_rdma_tail. The source data is shifted; the meta information is read from shifted to decide whether this token needs to be forwarded to dst_nvl_rank, and if not, the next token is examined. cached_nvl_channel_tail determines which FIFO slot this token fills, and then the actual copy into nvl_channel_x is performed.
__global__ void dispatch(...) {
    if (warp_role == WarpRole::kRDMAAndNVLForwarder) {
        for (int i = src_rdma_head, num_tokens_sent = 0; i < src_rdma_tail; ++ i) {
            auto rdma_slot_idx = i % num_max_rdma_chunked_recv_tokens;
            void* shifted = rdma_channel_data.recv_buffer(src_rdma_rank) + rdma_slot_idx * num_bytes_per_rdma_token;
            auto src_meta = ld_nc_global(reinterpret_cast<SourceMeta*>(reinterpret_cast<int8_t*>(shifted) + hidden_bytes));
            lane_id == src_rdma_rank ? (num_tokens_to_recv_from_rdma -= 1) : 0;
            bool is_in_dst_nvl_rank = src_meta.is_token_in_nvl_rank(dst_nvl_rank);
            if (lane_id == src_rdma_rank) {
                auto cached_head = is_in_dst_nvl_rank ? rdma_nvl_token_idx : -1;
                rdma_nvl_token_idx += is_in_dst_nvl_rank;
                if (not kCachedMode)
                    send_nvl_head[i * NUM_MAX_NVL_PEERS] = cached_head;
            }
            if (not is_in_dst_nvl_rank)
                continue;

            int dst_slot_idx = (cached_nvl_channel_tail ++) % num_max_nvl_chunked_recv_tokens;
            // Copy data
            UNROLLED_WARP_COPY(5, lane_id, hidden_int4,
                               nvl_channel_x.buffer() + dst_slot_idx * hidden_int4,
                               reinterpret_cast<int4*>(shifted),
                               ld_nc_global, st_na_global);
        }

        // __shared__ volatile int forward_channel_head[NUM_MAX_NVL_PEERS][kNumRDMARanks];
        if (lane_id == src_rdma_rank)
            forward_channel_head[dst_nvl_rank][src_rdma_rank] = (cached_rdma_channel_head = src_rdma_tail);
        if (lane_id == 0)
            st_release_sys_global(nvl_channel_tail.buffer(), cached_nvl_channel_tail);
    }
}
Note that besides advancing the tail pointer to tell the target GPU that new data is available, the warp also updates forward_channel_head: forward_channel_head[dst_nvl_rank][src_rdma_rank] records how far the data coming from src_rdma_rank has been forwarded to dst_nvl_rank. This is consumed by kForwarderCoordinator, described next.
kForwarderCoordinator
kForwarderCoordinator is a single warp, with each lane corresponding to one source node. Its job is to write the new head back to the source node via an RDMA write, telling the source that this data has been consumed. A coordinator is needed because a token written by the source node may have to be forwarded to several GPUs, and each kRDMAAndNVLForwarder warp forwards it to only one GPU; only when all warps are done with the token may the source be notified.
In the code, each lane scans forward_channel_head[0 .. NUM_MAX_NVL_PEERS - 1][lane] to find min_head and writes this min_head via RDMA to node lane's rdma_channel_head, i.e. the green line in Figure 3.
__global__ void dispatch(...) {
    else if (warp_role == WarpRole::kForwarderCoordinator) {
        int last_head = 0, target_rdma = lane_id < kNumRDMARanks ? lane_id : 0;
        while (true) {
            // Find minimum head
            int min_head = std::numeric_limits<int>::max();
            #pragma unroll
            for (int i = 0; i < NUM_MAX_NVL_PEERS; ++ i)
                if (not forward_channel_retired[i])
                    min_head = min(min_head, forward_channel_head[i][target_rdma]);
            if (__all_sync(0xffffffff, min_head == std::numeric_limits<int>::max()))
                break;

            // Update remote head
            if (min_head != std::numeric_limits<int>::max() and min_head > last_head and lane_id < kNumRDMARanks)
                nvshmem_uint64_p(rdma_channel_head.buffer(rdma_rank), last_head = min_head,
                                 translate_dst_rdma_rank<kLowLatencyMode>(lane_id, nvl_rank));

            // Nanosleep and let other warps work
            __nanosleep(NUM_WAIT_NANOSECONDS);
        }
    }
}
kNVLReceivers
kNVLReceivers receive the data. There are NUM_MAX_NVL_PEERS such warps, each corresponding to one local GPU, src_nvl_rank, from which it receives data; each lane of the warp corresponds to one node.
The warp then fetches total_offset from recv_gbl_rank_prefix_sum, where recv_gbl_rank_prefix_sum[i] is the total number of tokens that ranks 0..i send to the current GPU, so this gives the starting output position of the data coming from the global rank formed by machine lane_id and GPU src_nvl_rank.
Assuming the current channel is x, nvl_channel_prefix_start is the number of tokens that channels [0, x - 1] of machine lane_id send to the current rank, so the tokens from machine lane_id handled by this channel must be stored in the interval [total_offset + start_offset, total_offset + end_offset).
Finally, warp_reduce_sum gives the total number of tokens that this channel will receive from GPU src_nvl_rank, summed over all source machines.
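A worked example of how the output position falls out of these prefix sums (illustrative numbers only; recv_gbl_rank_prefix_sum and the offsets are as defined above):

#include <cstdio>

int main() {
    // 2 RDMA ranks * 2 NVL peers = 4 global ranks, sending {3, 5, 2, 4} tokens to this GPU
    int recv_gbl_rank_prefix_sum[4] = {3, 8, 10, 14};
    int global_rank = 2;                      // lane_id * NUM_MAX_NVL_PEERS + src_nvl_rank
    int start_offset = 1, end_offset = 2;     // this channel's slice within rank 2's tokens
    int total_offset = global_rank > 0 ? recv_gbl_rank_prefix_sum[global_rank - 1] : 0;
    printf("rank %d tokens for this channel land in recv_x[%d, %d)\n",
           global_rank, total_offset + start_offset, total_offset + end_offset);   // recv_x[9, 10)
    return 0;
}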
__global__ void dispatch(...) {
    else {
        // NVL consumers
        // Retrieve rank offset from barrier results (each lane's register stores an RDMA rank)
        int src_nvl_rank = target_rank, total_offset = 0;
        if (lane_id < kNumRDMARanks and lane_id * NUM_MAX_NVL_PEERS + src_nvl_rank > 0)
            total_offset = recv_gbl_rank_prefix_sum[lane_id * NUM_MAX_NVL_PEERS + src_nvl_rank - 1];

        // Receive channel offsets
        int start_offset = 0, end_offset = 0, num_tokens_to_recv;
        auto start_time = clock64();
        while (lane_id < kNumRDMARanks) {
            start_offset = ld_volatile_global(nvl_channel_prefix_start.buffer() + lane_id);
            end_offset = ld_volatile_global(nvl_channel_prefix_end.buffer() + lane_id);
            if (start_offset < 0 and end_offset < 0) {
                start_offset = -start_offset - 1, end_offset = -end_offset - 1;
                total_offset += start_offset;
                break;
            }
        }
        num_tokens_to_recv = warp_reduce_sum(end_offset - start_offset);
    }
}
The receiver then uses head and tail to detect new data: if head and tail differ, new data has been produced.
__global__ void dispatch(...) {
    else {
        int cached_channel_head_idx = 0, cached_channel_tail_idx = 0;
        while (num_tokens_to_recv > 0) {
            while (lane_id == 0) {
                // Ready to copy
                if (cached_channel_head_idx != cached_channel_tail_idx)
                    break;
                cached_channel_tail_idx = ld_acquire_sys_global(nvl_channel_tail.buffer());
            }
        }
    }
}
It then processes the tokens from head to tail. The head gives the buffer slot that holds the metadata; the meta tells which src_rdma_rank the data came from, so that rank's total_offset is broadcast to the whole warp as recv_token_idx, the output position of this token, and that lane's total_offset is incremented by 1.
The data itself is then copied, i.e. from the slot in nvl_channel_x to recv_x + recv_token_idx * hidden_int4.
The scales, topk_idx, and topk_weights are copied in the same way, and finally the head is advanced to tell the forwarder that the slots have been consumed.
__global__ void dispatch(...) {
    else {
        int num_recv_tokens = cached_channel_tail_idx - cached_channel_head_idx;
        for (int chunk_idx = 0; chunk_idx < num_recv_tokens; ++ chunk_idx, -- num_tokens_to_recv) {
            int token_idx_in_buffer = (cached_channel_head_idx ++) % num_max_nvl_chunked_recv_tokens;
            auto meta = ld_nc_global(nvl_channel_src_meta.buffer() + token_idx_in_buffer);
            int64_t recv_token_idx = __shfl_sync(0xffffffff, total_offset, meta.src_rdma_rank);
            (lane_id == meta.src_rdma_rank) ? (total_offset += 1) : 0;

            // Copy data
            UNROLLED_WARP_COPY(5, lane_id, hidden_int4,
                               recv_x + recv_token_idx * hidden_int4,
                               nvl_channel_x.buffer() + token_idx_in_buffer * hidden_int4,
                               ld_nc_global, st_na_global);

            // Copy source meta
            if (lane_id == 0 and not kCachedMode)
                st_na_global(recv_src_meta + recv_token_idx, meta);

            // Copy scales
            UNROLLED_WARP_COPY(1, lane_id, num_scales,
                               recv_x_scales + recv_token_idx * num_scales,
                               nvl_channel_x_scales.buffer() + token_idx_in_buffer * num_scales,
                               ld_nc_global, st_na_global);

            // Copy `topk_idx` and `topk_weights`
            if (lane_id < num_topk) {
                auto recv_idx = recv_token_idx * num_topk + lane_id;
                auto buffer_idx = token_idx_in_buffer * num_topk + lane_id;
                st_na_global(recv_topk_idx + recv_idx, static_cast<int64_t>(ld_nc_global(nvl_channel_topk_idx.buffer() + buffer_idx)));
                st_na_global(recv_topk_weights + recv_idx, ld_nc_global(nvl_channel_topk_weights.buffer() + buffer_idx));
            }
        }

        // Move queue
        __syncwarp();
        if (lane_id == 0)
            st_relaxed_sys_global(nvl_channel_head.buffer(), cached_channel_head_idx);
    }
}