2026 年 4 月,DeepSeek 在 DeepGEMM 仓库的 PR #304 中引入了一个新的 MoE 推理实现。其核心设计是将 dispatch、两层 GEMM、SwiGLU 激活、combine 全部合并到一个 CUDA kernel 中,利用 Warp Specialization 使 NVLink 通信与 Tensor Core 计算在同一 SM 上并发执行。
本文直接从源码出发,逐层拆解实现。
本文引用的所有代码基于 DeepGEMM 仓库 commit 714dd1a。
1. 问题
标准 MoE 推理包含五个阶段:
EP-Dispatch → Linear1 (Gate+Up) → SwiGLU → Linear2 (Down) → EP-Combine传统实现中,每个阶段是独立的 kernel launch。每个 kernel 将中间结果写入 HBM,下一个 kernel 再从 HBM 读出。其中 dispatch 和 combine 阶段还涉及跨 GPU 的 NVLink 通信。
这种分离式设计的问题:
- 五次 kernel launch 意味着五次 global memory 往返
- 通信和计算是串行的——dispatch 通信完成后才能启动 Linear1,Linear2 完成后才能启动 combine 通信
- SM 在通信阶段处于闲置状态(通信由专门的 copy engine 或通信 kernel 处理)
MegaMoE 将这些阶段全部合并到一个 kernel 中,中间数据通过 SMEM 和 TMEM 传递,同时使通信与计算 overlap。
性能方面,官方给出的数据是分离方案相比推理提速 1.50∼1.73 倍,延迟敏感场景最高 1.96 倍。
2. 硬件前提:Blackwell SM100
MegaMoE 目前仅支持 sm_100f(Blackwell)目标架构。以下硬件特性是 MegaMoE 得以实现的前提。
| 特性 | 说明 | 在 MegaMoE 中的用途 |
|---|---|---|
| TMEM (Tensor Memory) | 256KB/SM,与通用寄存器独立的张量专用内存 | L1/L2 GEMM 之间的中间结果存储,避免占用 GPR 或回写 HBM |
| UMMA (Unbound MMA) | 替代 SM90 WGMMA 的张量核心指令集,原生支持 FP4 | 执行 FP8×FP4 混合精度 MMA,双 SM 协同(2-CTA cluster) |
| TMA (Tensor Memory Accelerator) | 硬件异步数据搬运引擎,支持跨 GPU RDMA | 从远端 GPU 直接 load token 数据到本地 SMEM |
| FP4 原生支持 | E2M1 格式 | 权重存储带宽需求为 FP8 的一半 |
| UE8M0 缩放格式 | 8-bit 无符号指数,0-bit 尾数 | Block scaling 的紧凑 scale factor 表示 |
| NVLink 6.0 | 3.6 TB/s 双向 | 提供足够通信带宽使 overlap 有意义 |
| ClusterTransactionBarrier | 集群级硬件同步原语 | 2-CTA 之间同步 |
TMEM 是 SM100 相对于 SM90 最关键的变化。SM90 的 WGMMA 要求累加器 D 矩阵全部驻留在通用寄存器中,这成为限制 occupancy 的瓶颈——寄存器容量决定了每个 SM 能同时执行的 tile 数量。TMEM 将累加器移出 GPR 文件,256KB/SM 的容量允许容纳更多 inflight tile,直接提高了 occupancy。
UMMA 在 SM100 上提供的 SM100_MMA_MXF8F6F4_2x1SM_SS 指令是 MegaMoE MMA 发射的核心。其中 MXF8F6F4 表示 A=FP8 (E4M3)、B=FP4 (E2M1)、D=F32 的混合精度;2x1SM 表示两个 SM 通过 2-CTA cluster 协同计算;SS 表示 block scaled。
3. Symmetric Memory
Symmetric Memory 是 MegaMoE 实现通信计算重叠的基础。在讨论 warp 分工之前,需要先理解这一机制。
3.1 原理
Symmetric Memory 是 PyTorch 2.x 通过 torch.distributed._symmetric_memory 导出的特性,底层基于 CUDA P2P memory access over NVLink。多个 GPU 将各自的物理 HBM 映射到相同的虚拟地址范围。每个 GPU 可以通过普通的 load/store 指令直接读写其他 GPU 的物理 HBM——无需经过 cudaMemcpy、NCCL 集合通信或 kernel launch。
从 MegaMoE 的 SymBuffer 实现(sym_buffer.cuh)可以看到完整的地址映射逻辑:
template <uint32_t kNumRanks>struct SymBuffer { int64_t base; // 本 GPU 在本地虚拟地址空间的基地址 int64_t offsets[kNumMaxRanks]; // offsets[i] = GPU_i.BASE - base uint32_t rank_idx;};
template <typename ptr_t>CUTLASS_DEVICE ptr_t map(const ptr_t& ptr, const uint32_t& dst_rank_idx) const { int64_t mapped_ptr = offsets[dst_rank_idx] + reinterpret_cast<int64_t>(ptr); return *reinterpret_cast<ptr_t*>(&mapped_ptr);}offsets[i] 存储的是 (GPU_i 本地物理 HBM 的虚拟地址) - (本 GPU 本地物理 HBM 的虚拟地址)。因此 map(local_ptr, dst_rank) 的返回值就是一个直接指向远端 GPU HBM 的有效指针,可以直接 dereference。
在 Python 侧,SymBuffer 的创建通过 symm_mem.empty() 和 symm_mem.rendezvous() 完成(mega/__init__.py):
self.buffer = symm_mem.empty(num_bytes, dtype=torch.int8, device='cuda')self.handle = symm_mem.rendezvous(self.buffer, group=group)empty() 在所有 GPU 上分配大小相同的 buffer,rendezvous() 交换各 GPU 的地址信息并填充到 SymBuffer 的 offsets 数组中。
3.2 远端访问延迟
通过 NVLink 进行 P2P load 的单次延迟约 1∼3 μs,远高于本地 HBM 的数百 ns。如果每个 token 的每个元素都通过细粒度 load 从远端获取,延迟将无法接受。
MegaMoE 通过 TMA 批量传输处理这个问题。在 dispatch 阶段拉取 token 数据时,每次 TMA load 传输的是整个 hidden 维度的连续数据(默认 7168 bytes)(sm100_fp8_fp4_mega_moe.cuh#L548):
ptx::tma_load_1d( pull_buffer.get_base_ptr(), sym_buffer.map(input_token_buffer.get_data_buffer(src_token_idx).get_base_ptr(), current_rank_in_expert_idx), pull_mbarrier, kHidden);TMA 是 GPU 的硬件 DMA 引擎,异步执行数据搬运。SM 发射 TMA 指令后可以立即切换到其他 warp 继续计算,不需要自旋等待搬运完成。单次 TMA load 的 payload 达到 KB 级别,摊薄了 NVLink 的固定访问延迟。
3.3 Dispatch Pull vs Combine Push
MegaMoE 在 dispatch 阶段使用 NVLink read(pull),在 combine 阶段使用 NVLink write(push)。
代码注释指出 pull 的优势:read 可以避免 write 的 flag 额外延迟。Pull 模式下,发起方完全控制读取时机——只需要知晓远端 buffer 已写入完成(通过 NVLink barrier 保证),不需要远端显式设置 per-transaction flag。Push 模式下,发送方写入后远端必须通过某种同步机制确认 buffer 可用,增加了一轮协调。
3.4 适用边界
Symmetric Memory 的底层 NVLink P2P 仅在同一 NVSwitch 域内有效。这意味着 MegaMoE 的通信计算融合限于单机/超节点内部。跨节点的 expert dispatch 需要其他机制(如 RDMA 转发),不在 MegaMoE 的设计范围内。
4. Warp Specialization
MegaMoE kernel 占用所有 SM,每个 SM 内部的 warp 被划分为三种功能角色,各自执行不同的代码路径。这是实现通信与计算重叠的核心机制。
4.1 线程布局
Kernel 模板参数定义了三种角色的线程数量(sm100_fp8_fp4_mega_moe.cuh#L21):
template < uint32_t kNumDispatchThreads, // Dispatch: 可配置 uint32_t kNumNonEpilogueThreads, // MMA: 固定 128 (= 4 warps) uint32_t kNumEpilogueThreads, // Epilogue: 可配置 ...>派生常量:
kNumDispatchWarps = kNumDispatchThreads / 32;kNumMMANonEpilogueWarps = 128 / 32; // = 4kNumEpilogueWarps = kNumEpilogueThreads / 32;三种 warp 的寄存器分配(sm100_fp8_fp4_mega_moe.cuh#L346):
constexpr uint32_t kNumDispatchRegisters = 48;constexpr uint32_t kNumNonEpilogueRegisters = 40;constexpr uint32_t kNumEpilogueRegisters = 208;// 约束:总和 ≤ 64512(Blackwell 单 SM 寄存器文件大小)Dispatch warp 需要的寄存器最少(48),因为它的主要操作是索引计算和 TMA 发射。MMA warp 只需要 40 个寄存器——矩阵乘法的主要工作状态在 TMEM 中,GPR 仅存储指令描述符和少量控制变量。Epilogue warp 需要 208 个寄存器,因为 SwiGLU 计算、amax reduction、FP8 cast、NVLink 写回等多步操作需要在寄存器中维护大量中间状态。
每个 SM 内部的 warp 角色分配:
┌──────────────────────┬──────────────────┬──────────────────────┐│ Dispatch Warps │ MMA Warps (N) │ Epilogue Warps ││ 48 regs/warp │ 40 regs/warp │ 208 regs/warp │├──────────────────────┼──────────────────┼──────────────────────┤│ Expert token 计数 │ TMA Load Token │ L1: SwiGLU + FP8 ││ 跨 GPU 写源 token 索引│ TMA Load Weight │ 重量化 + TMA store││ NVLink Pull Token │ UMMA Issue │ L2: BF16 cast + ││ 信号 l1_arrival │ 信号 tmem_full │ NVLink write ││ Workspace 清理 │ │ Combine reduce │└──────────────────────┴──────────────────┴──────────────────────┘4.2 Dispatch Warps
Dispatch warp 的入口在 warp_idx < kNumDispatchWarps 分支(sm100_fp8_fp4_mega_moe.cuh#L359)。
Step 1 — 本地计数
Dispatch warp 遍历本 GPU 上的所有 token,对每个 token 的 top-k expert 进行 atomicAdd_block 累加,得到每个 expert 的本地 token 数(L386)。
Step 2 — 跨 GPU 通告 token 数
每个 expert 的 token 数通过 64-bit atomic_add 写入 Symmetric Memory workspace 的 expert_send_count(L393):
for (uint32_t i = thread_idx; i < kNumExperts; i += kNumDispatchThreads) { const uint64_t send_value = (1ull << 32) | static_cast<uint64_t>(smem_expert_count[i]); smem_expert_count[i] = static_cast<uint32_t>( ptx::atomic_add(workspace.get_expert_send_count_ptr(i), send_value));}64-bit atomic 的高 32 位是「SM 完成计数器」,所有 SM 都写入后其值等于 kNumSMs * kNumRanks,表示所有 GPU 的所有 SM 都已完成该 expert 的 token 数统计。
Step 3 — 写源 token 索引到远端
dispatch warp 再次遍历 token,这次将 (token_idx, topk_idx) 组合写入对应 expert 所在 GPU 的 workspace 区域(L401):
*sym_buffer.map(dst_ptr, dst_rank_idx) = token_topk_idx;dst_ptr 指向 expert 所在 GPU 的 src_token_topk_idx 数组中的位置,sym_buffer.map() 将其转换为远端可写的指针。
Step 4 — Grid Sync + NVLink Barrier
所有 dispatch warp 完成写入后,执行 grid sync(SM 级同步),然后由 SM 0 的 dispatch warp 通过 red_add_rel_sys 将 per-rank recv count 搬运到 expert 所在 GPU。最后是 NVLink barrier,确保所有 rank 之间的计数信息完全同步。
Step 5 — Pull Token 数据
这是 dispatch 与计算重叠的关键阶段。dispatch warp 遍历本 GPU 上的所有 local expert,对每个 expert 通过 Round-Robin Min-Peeling 算法在所有 src rank 之间均衡选取 token。每选取一个 token,发射一次 TMA load 从远端 pull 该 token 的 hidden 向量到本地 SMEM(L548):
if (cute::elect_one_sync()) { ptx::tma_load_1d( pull_buffer.get_base_ptr(), sym_buffer.map(input_token_buffer.get_data_buffer(src_token_idx).get_base_ptr(), current_rank_in_expert_idx), pull_mbarrier, kHidden);}TMA load 完成后,dispatch warp 将 token 数据从 pull buffer(SMEM)搬到 L1 token buffer(也位于 Symmetric Memory workspace),写入 token 权重和源元数据,然后通过 red_add_rel 递增 l1_arrival_count,通知 MMA warp 该 token block 已就绪(L599):
ptx::red_add_rel( workspace.get_l1_arrival_count_ptr( expert_pool_block_offset + token_idx_in_expert / BLOCK_M), 1);4.3 MMA Warps — TMA Load 数据 (warp_idx == kNumDispatchWarps)
该 warp 负责通过 TMA 将 token 激活量(activations)和 scale factor A 从 global memory 搬运到 SMEM,是 MMA 流水线的生产者。通过 scheduler.for_each_block() 遍历所有待计算的 block(L666)。
关键同步逻辑在 block 循环开始处(L682):
if (block_phase == BlockPhase::Linear1) { const auto ptr = workspace.get_l1_arrival_count_ptr(pool_block_idx); const auto expected = scheduler.get_valid_m<false>(); while (ptx::ld_acq(ptr) != expected);}这里 ld_acq 是 acquire 语义的 load,等待 dispatch warp 的 red_add_rel(release 语义)。expected 是当前 block 中有效 token 的数量(BLOCK_M 或最后一个 block 的余数)。当所有 token 数据到达后,MMA warp 启动 TMA load 将 token 数据搬入 SMEM。
对于 L2 phase,同步逻辑不同——它等待的是 L1 epilogue 的结果(L694):
const uint64_t expected = ((1ull << num_k_blocks) << num_k_blocks) - 1;while (ptx::ld_acq_gpu(ptr) != expected);因为 BLOCK_K == BLOCK_N,L1 输出的每个 block 被切分成两半(gate 和 up 分别对应 SwiGLU 的两个输入),所以 L2 需要等待 2 * num_k_blocks 个 bit 的 mask 全为 1。((1ull << n) << n) - 1 等价于 (1ull << 2n) - 1,但避免了 n==32 时的未定义行为。
TMA load 完成后,warp 通过 full_barrier.arrive() 通知 MMA issue warp 数据已就绪。
4.4 MMA Warps — TMA Load 权重 (warp_idx == kNumDispatchWarps + 1)
与上一 warp 对称,负责通过 TMA 将权重(weights)和 scale factor B 从 global memory 搬运到 SMEM(L735)。权重数据不需要等待 dispatch 完成,只需要等待流水线中的 consumer release(empty_barriers[stage_idx]->wait(phase ^ 1))。
4.5 MMA Warps — Issue (warp_idx == kNumDispatchWarps + 2)
这是唯一发射 UMMA 指令的 warp,且只有 leader CTA 执行(L778)。2-CTA cluster 中 leader CTA 负责计算,follower CTA 仅提供 TMA load 带宽。
在每个 K-block 迭代中,warp 先等待 TMA load 完成(full_barriers[stage_idx]->wait(phase)),然后通过 UTCCP 将 scale factor 从 SMEM copy 到 TMEM,最后发射 UMMA 指令(L854):
for (uint32_t k = 0; k < BLOCK_K / UMMA_K; ++ k) { const auto runtime_instr_desc = mma::sm100::make_runtime_instr_desc_with_sf_id(instr_desc, k, k); ptx::SM100_MMA_MXF8F6F4_2x1SM_SS::fma( b_desc, a_desc, accum_stage_idx * UMMA_N, k_block_idx > 0 or k > 0, // clear: 第一个 K-block 需要清零累加器 runtime_instr_desc, kTmemStartColOfSFB, kTmemStartColOfSFA);}BLOCK_K 固定为 128,UMMA_K 为 32,所以每个 block 发射 4 次 UMMA 指令。第一条指令的 clear 标志为 1(累加器初始化为 0),后续指令的 clear 标志为 0(累加到已有结果上)。
MMA 发射完成后通过 empty_barrier_arrive() 释放流水线阶段,允许 TMA load warp 继续加载下一批数据。同时,如果当前 block 的所有 K 都已完成,则通过 tmem_full_barriers[accum_stage_idx]->arrive() 和 umma_arrive_multicast_2x1SM 通知 epilogue warp。
4.6 Epilogue Warps
Epilogue warp 的入口在 warp_idx >= kNumDispatchWarps + kNumMMANonEpilogueWarps(L887)。这是寄存器需求最高的角色(208 regs/warp)。
L1 Epilogue(L941):
从 TMEM 加载 GEMM 累加结果,通过 SM100_TMEM_LOAD_16dp256b1x 以 16 个 float 为粒度读出,然后执行 SwiGLU(L997):
// silu(x) = x * sigmoid(x), 其中 sigmoid(x) = 1 / (1 + e^(-x))auto gate = __bfloat1622float2(bf16_gate);auto neg_gate_exp = make_float2( kFastMath ? __expf(-gate.x) : expf(-gate.x), kFastMath ? __expf(-gate.y) : expf(-gate.y));const auto denom = __fadd2_rn({1.0f, 1.0f}, neg_gate_exp);gate = {gate.x / denom.x, gate.y / denom.y};swiglu_values[i * 2 + k] = __fmul2_rn(__fmul2_rn(gate, up), weights);kFastMath 路径使用 __expf()(内联 PTX 的 ex2.approx 指令,吞吐远高于 expf())和 fast_rcp()(用牛顿迭代逼近倒数),牺牲末位精度换取吞吐。
SwiGLU 之后,epilogue warp 执行:
- amax reduction:跨 warp 求 max-abs,用于计算 FP8 的 scale factor(L1024)
- FP8 cast:将 BF16 结果量化为 FP8 E4M3,写入 SMEM 的 CD buffer(L1055)
- TMA store:将 FP8 结果通过 TMA 写回 Symmetric Memory workspace 中的 L2 输入区域(L1099)
- 通知 L2:通过
red_or_rel_gpu设置l2_arrival_mask的对应 bit(L1118)
L2 Epilogue(L1124):
与 L1 epilogue 类似,从 TMEM 读出 GEMM 累加结果,转换为 BF16 后写入 SMEM。然后通过 NVLink write 直接将结果写回远端 GPU 的 combine buffer(L1212):
// 读取 dispatch 阶段存储的源元数据const auto src_metadata = *workspace.get_token_src_metadata_ptr(m_idx + m_idx_in_block);const uint32_t dst_rank_idx = src_metadata.rank_idx;const uint32_t dst_token_idx = src_metadata.token_idx;
// 直接通过 Symmetric Memory 写入远端 GPU*sym_buffer.map(dst_ptr, dst_rank_idx) = packed;Combine reduction(L1242):
所有 L2 epilogue 完成后,epilogue warp 执行最终的 combine——将同一 token 在不同 expert 上的 top-k 结果按路由权重归约。这通过 TMA load 从 combine buffer 将 chunk 数据搬入 SMEM 逐元素累加,再 TMA store 写回最终的 output tensor。
5. 通信计算重叠的同步机制
MegaMoE 的流水线重叠依赖两个核心同步变量,均存储在 Symmetric Memory workspace 中:
| 变量 | 类型 | 生产者 | 消费者 | 语义 |
|---|---|---|---|---|
l1_arrival_count[pool_block_idx] | uint32_t | Dispatch warp (red_add_rel) | MMA TMA warp (ld_acq spin) | 一个 token block 的所有 token 已 pull 到本地 L1 buffer |
l2_arrival_mask[pool_block_idx] | uint64_t | L1 Epilogue warp (red_or_rel_gpu) | MMA TMA warp (ld_acq_gpu spin) | L1 输出中某个 K 方向的结果已写入,以 bitmask 表示 |
这两个变量实现了三级重叠:
-
Dispatch ↔ L1 GEMM:dispatch warp 在 pull token 的同时,MMA warp 在等待
l1_arrival_count。一旦某个 block 就绪,MMA 立即启动。后续 block 的 NVLink pull 与当前 block 的 GEMM 并发。 -
L1 Epilogue ↔ L2 GEMM:L1 epilogue 完成后通过
red_or_rel_gpu设置l2_arrival_mask。L2 MMA warp 等待 mask 全 1 后启动。L1 后处理与 L2 GEMM 的 K-block 迭代卷积。 -
L2 Epilogue ↔ Workspace Clean:L2 epilogue 的 combine 写回与 dispatch warp 的 workspace 清理通过
sync_unaligned(dispatch + epilogue, barrier)交错执行。
注意:调度器(见第 6 节)保证了每个 wave 内所有 expert 的 L1 全部完成后才启动 L2。这简化了 L2 的同步——开始时大部分 L1 block 的 mask 已经就绪,减少了 L2 MMA warp 的自旋等待时间。
6. AB Swap
在 UMMA 配置中,一个非直观的设计选择是交换了 A(activation)和 B(weight)两个操作数。通常 GEMM 中激活作为 A、权重作为 B,但 MegaMoE 相反(sm100_fp8_fp4_mega_moe.cuh#L162):
// NOTES: always swap A/Bconstexpr uint32_t UMMA_M = LAYOUT_AD_M * 2; // = 256 (2-CTA)constexpr uint32_t UMMA_N = BLOCK_M; // Swap AB这一选择的动机来自 SM100 的 UMMA 指令约束。SM100 UMMA 的 M 维度固定为 128(一 CTA)或 256(2-CTA),而 N 维度可以灵活取值。在 MoE 推理中,尤其是 decode 阶段的小 batch 场景,每个 expert 的 token 数很少,BLOCK_M 可能取 16、32 等小值。如果将高维的 intermediate_hidden(通常 2048~7168)放在 M 方向,指令的 M=256 约束无法满足。
通过 AB Swap,权重维度(hidden × intermediate_hidden)对齐到 UMMA 的 M 方向,激活的 token 维度对齐到灵活的 N 方向。这意味着 BLOCK_M 可以按 token 数自由选择小值(如 16),而不受 UMMA 指令的 256 限制。
Swap 之后,MMA 中的 A 矩阵是权重(FP4),B 矩阵是激活(FP8)。由于 2-CTA 的 TMA multicast,两个 CTA 各自只需要加载一半的 A tile(LOAD_BLOCK_M = BLOCK_M / 2)。B tile 由两个 CTA 各自完整加载 BLOCK_N = 128 宽度。TMA Producer B(权重加载 warp)利用 TMA multicast,一次 load 的权重数据被两个 CTA 共享。
这一优化在 decode(小 BLOCK_M)和 prefill(大 BLOCK_M)场景下均有效:BLOCK_M 大小由 heuristics 根据 token-per-expert 动态选择,不受指令硬件约束。
7. MegaMoEScheduler 与启发式配置
MegaMoEScheduler(scheduler/mega_moe.cuh)是一个 per-SM 状态机,决定每个 SM 在每一轮分配到哪个 expert 的哪个 block。与调度器紧密配合的是启发式配置模块(csrc/jit_kernels/heuristics/mega_moe.hpp),它在 JIT 编译期根据运行时参数推导出正确的 kernel 配置。
7.1 启发式配置
启发式配置需要满足三个约束:正确性、共享内存上限、SM 利用率。
BLOCK_M 选择
get_block_config_for_mega_moe 根据每 expert 期望 token 数选择 BLOCK_M,候选值为 {16, 32, 64, 96, 128, 192}。预期 token 数的计算公式:
expected_tokens_per_expert = num_tokens × num_ranks × num_topk / num_experts分子是全局所有 token 乘以 topk 的总路由次数,分母是 expert 总数。实际路由不均匀,因此所有估算中引入了 kImbalanceFactor = 2,把目标工作量放大两倍以吸收热专家的尾延迟。
BLOCK_M 的选择阶梯为:token 很少时用小 BLOCK_M(16 或 32),避免 M 维度大量 padding 浪费计算;token 多时用大 BLOCK_M(128 或 192),提高单 tile 的算术密度。
Pool 容量
pool 容量必须覆盖所有 rank 的所有 token 全部路由到本 rank 的最坏情况(mega_moe.cuh#L17):
num_max_recv_tokens = num_ranks × num_max_tokens_per_ranknum_max_experts_per_token = min(num_topk, num_experts_per_rank)pool_size = align( num_max_recv_tokens × num_max_experts_per_token + num_experts_per_rank × (kMaxCandidateBlockM - 1), kLCMCandidateBlockM)其中 num_max_experts_per_token 是因为一个 token 的 top-k 选择不可能重复落在同一个 expert 上,所以本 rank 最多收到该 token 的 min(num_topk, num_experts_per_rank) 份副本。额外的 num_experts_per_rank × (kMaxCandidateBlockM - 1) 是各 expert 之间 BLOCK_M 对齐填充的上界。最终对齐到 kLCMCandidateBlockM = 384(所有候选 BLOCK_M 的最小公倍数)。
Expert Wave 粒度
get_num_experts_per_wave_for_mega_moe 的核心逻辑是:在保证所有 SM 都有 block 可算的前提下,wave 内 expert 越少越好。原因是 expert 少 → L1 整体完成得快 → L2 启动得早 → pipeline 深度浅。
具体计算:先估计一个 expert 的 L1 block 数(num_m_blocks × num_n_blocks),然后除以 SM 数得到需要的 expert 数下界。用 kImbalanceFactor = 2 放大以吸收路由不均。最后向上取整到 num_experts_per_rank 的因子(scheduler 要求所有 wave 处理相同数量的 expert)。
SMEM 分配与流水线深度
get_pipeline_config_for_mega_moe 估算共享内存的固定开销和单级流水线开销,然后计算 num_stages = (total_smem - fixed) / per_stage。固定开销包括 dispatch 区的 expert 计数器和 send buffer、CD 输出区(L1/L2 epilogue 共享 SMEM 的 CD buffer,取两者 max)、amax reduction 区和所有 mbarrier 对象。单级流水线开销包括 A tile + B tile + SFA + SFB + full_barrier + empty_barrier。
7.2 Wave 划分
MoE 计算被划分为多个 wave,每个 wave 处理 kNumExpertsPerWave 个 expert,包含两个 phase:L1、L2。
get_num_experts_per_wave_for_mega_moe 的选取原则:在保证所有 SM 都有活可干的前提下,expert 越少越好。原因是:expert 少 → L1 完成得快 → L2 启动得早 → pipeline 深度浅、延迟低。但 expert 太少会导致部分 SM 闲置。因此最优取值是刚好能打满所有 SM 的最小 expert 数。
6.2 块分配
同一个 expert 内部按先 N 方向(output channel)、后 M 方向(token)遍历所有 block。分配策略是 round-robin:
// fetch_next_l1_block 中n_block_idx = block_idx - m_block_idx * kNumL1BlockNs;block_idx += kNumSMs; // 跳转到下一轮属于本 SM 的 block所有 SM 从自己的 sm_id 开始,每轮前进 kNumSMs 步。这保证了每个 SM 的工作量均衡。
6.3 状态机
核心状态机位于 for_each_block 调用的 get_next_block()(L148):
L1 phase: 遍历当前 wave 所有 expert 的 L1 block → L1 完成 → 切换到 L2 phaseL2 phase: 遍历当前 wave 所有 expert 的 L2 block → L2 完成 → 切换到下一个 wave 的 L1 phase所有 wave 完成 → BlockPhase::None → kernel 结束状态机确保了一个关键性质:同一 wave 内,所有 SM 先共同完成所有 expert 的 L1,再切换到 L2。这意味着 L2 开始时大部分 L1 结果已就绪,L2 MMA warp 的自旋等待时间被最小化。
8. FP8×FP4 混合精度管线
7.1 精度配置
| 分量 | 格式 | 量化方式 | Block Size | Scale 格式 |
|---|---|---|---|---|
| 输入激活 | FP8 E4M3 | per-token | 32 | UE8M0 |
| L1 权重 | FP4 E2M1 | per-group | 32×32 | UE8M0 |
| L1 输出(L2 输入) | FP8 E4M3 | per-32 | 32 | UE8M0 |
| L2 权重 | FP4 E2M1 | per-group | 32×32 | UE8M0 |
| 最终输出 | BF16 | — | — | — |
FP4 E2M1 格式:1-bit 符号、2-bit 指数、1-bit 尾数,共 16 个可能值(包含 ±0、±NaN 等特殊值)。有效范围约为 [0.5, 7.5](不含非正规数)。
UE8M0 格式:8-bit 无符号整数指数,0-bit 尾数,值为 2^e。4 个 UE8M0 打包为一个 uint32_t。
7.2 Block Scaling
MegaMoE 的 scaling 是 block-wise 而非 per-tensor 的。对于权重,每个 32×32 的 block 有自己的 scale factor。对于激活,每 32 个元素(在 K 维度上)有一个 scale factor。
这种粒度避免了 per-tensor scaling 的精度问题:不同 expert 的权重分布可能差异很大,per-tensor scaling 会导致某些 expert 的权重在量化后丧失精度。Block-wise scaling 以存储少量额外 scale 为代价,保证了各 expert 的精度。
Block scaling 的 UMMA 指令语义为 D = (A * scale_A) * (B * scale_B) + D。硬件在 MMA 过程中自动应用 scale factor,不增加额外的指令开销。
SF 存储在 TMEM 中,与累加器共享同一块 TMEM 区域(L217):
constexpr uint32_t kNumAccumTmemCols = UMMA_N * kNumEpilogueStages;constexpr uint32_t kNumSFATmemCols = SF_BLOCK_M / 32;constexpr uint32_t kNumSFBTmemCols = SF_BLOCK_N / 32;7.3 L1 权重 Interleave
L1 权重(gate + up projection)使用 interleave 布局,而非简单的 [gate | up] 拼接(mega/__init__.py#L75):
# [gate: 0..7, up: 0..7, gate: 8..15, up: 8..15, ...]这是为了配合 SwiGLU epilogue 的 gate/up pair 消费模式。TMEM load 指令 SM100_TMEM_LOAD_16dp256b1x 一次读出 8 个值,排列为 (0,2) (1,3) (4,6) (5,7) 的 gate/up pair。Interleave 布局保证了 gate 和对应位置的 up 在 TMEM 中相邻,简化了 SwiGLU 的索引计算。
9. NVLink Barrier
MegaMoE 定义了三层同步原语(comm/barrier.cuh):
Grid Sync:SM 级全局同步。基于 cooperative_groups::this_grid().sync() 的思想,但使用轻量级的 atomic_add_rel + ld_acq 自旋实现。每个 SM 的一个线程递增计数器后自旋等待标志位翻转。
NVLink Barrier:跨 rank 同步。仅 SM 0 参与。使用 red_add_rel_sys(PTX release-consume atomic add with system scope)向远端 GPU 的 signal buffer 写入信号,接收方自旋等待。超时 30 秒(以 2 GHz 时钟计数),超时打印所有 rank 的同步状态后触发 device assert。
Cluster Sync:2-CTA 内同步。使用 SM100 的 ClusterTransactionBarrier,在 2-CTA cluster 的两个 CTA 之间同步。初始化阶段使用 cluster.arrive.relaxed + cluster.wait 的弱序变体。
三个关键调用点的 tag:
kBeforeDispatchPullBarrier — dispatch 开始 pull token 前,确保所有 rank 的计数/索引已完成kBeforeCombineReduceBarrier — combine reduction 前,确保所有 rank 的 L2 结果已写回kAfterWorkspaceCleanBarrier — workspace 清理完成,确保下次调用时状态正确10. 与 DeepEP 的差异
MegaMoE 和 DeepEP 是两种不同的 MoE 推理实现方向:
| 维度 | DeepEP | MegaMoE |
|---|---|---|
| 设计层级 | 通信库(dispatch / combine 为独立 kernel) | 全融合 kernel(通信嵌入计算) |
| 目标硬件 | H800 (SM90) | Blackwell (SM100) |
| 内核数量 | 3+(dispatch、compute、combine 独立) | 1(mega-kernel) |
| 去重 | src→dst 相同 expert 只传输一次 | 按 expert 粒度传输 |
| 通信域 | NVLink + RDMA(支持跨节点) | 仅 NVLink(单机/超节点) |
| 中间结果 | HBM | SMEM/TMEM |
DeepEP 的设计更适合跨节点场景和灵活的组合——用户可以将 dispatch/combine 与任意的 GEMM 实现配合使用。MegaMoE 消除了 kernel 边界和 intermediate write 的开销,但限于 SM100 和单节点,且需要整套流程使用其内部的 GEMM。
11. 计算-通信比的量化分析
MegaMoE 的设计建立在一个基本不等式之上:通信能否完全隐藏,取决于计算量是否足够覆盖通信时间。DeepSeek-V4 技术报告给出了量化结果。
对于 DeepSeek-V4-Pro 的配置(每 token-专家对的计算量为 SwiGLU 的 gate、up、down 三个投影的总 FLOPs),每 GB/s 的互连带宽足以隐藏约 6.1 TFLOP/s 的计算对应的通信。一旦单 GPU 的 NVLink 带宽满足这个阈值,通信就不再是瓶颈。
这个结果也解释了为什么 MegaMoE 宣称在小 batch(decode)场景的加速比(最高 1.96×)优于大 batch(prefill)场景(1.50×):大 batch 时单 expert 的 token 更多,计算时间更长,通信占比本来就低,overlap 带来的收益相对较小。而在 decode 的”长尾小批量”场景,每次 expert 收到的 token 极少(RL rollout 可能每个 expert 只有几个 token),传统分离式方案中通信延迟占比极高,overlap 的效果最显著。
反过来看,当计算和通信同时被推到极限时,功耗限制(power throttling)、NOC 拥塞、通信引发的 cache miss 等次生问题会浮现。这也是为啥 MegaMoE 目前限制在单节点 NVLink 域内——跨节点的 RDMA 延迟和功耗特征与 NVLink 差别很大,overlap 的设计参数需要重新计算。
12. 总结
MegaMoE 将 MoE 推理的五个阶段合并到单个 kernel 中,利用 Warp Specialization 实现 NVLink 通信与 Tensor Core 计算的流水线重叠。关键设计决策包括:
- Kernel 融合:dispatch + L1 GEMM + SwiGLU + L2 GEMM + combine 五合一,中间数据走 SMEM/TMEM
- Warp Specialization:四种子角色在同一 SM 内并发,通过
l1_arrival_count、l2_arrival_mask、tmem_full/empty_barrier等同步原语协调 - Symmetric Memory:基于 NVLink P2P 的远端内存直接访问,dispatch 用 pull(NVLink read),combine 用 push(NVLink write)
- AB Swap:交换 GEMM 的 A/B 操作数以突破 SM100 UMMA 指令的 M=256 维度约束,使
BLOCK_M可以自由取小值 - Heuristic JIT 配置:根据运行时 token 数和硬件参数动态选择
BLOCK_M、num_experts_per_wave、num_stages - 计算-通信比:每 GB/s NVLink 带宽可隐藏约 6.1 TFLOP/s 计算,构成 overlap 可行的理论基础
核心实现约 1380 行 CUDA C++(sm100_fp8_fp4_mega_moe.cuh),加上调度器 221 行、同步原语 83 行和启发式配置,总计不到 2000 行。