真实面经题目 · 原创解析
单机多卡 LLM 推理中的分布式 GEMM 如何切分矩阵,并完成跨 GPU 通信?
这题考 tensor parallel 下 GEMM 切分和 collective communication 的基本工程理解。回答要能把矩阵维度切分、局部计算、AllReduce/AllGather/ReduceScatter、通信计算重叠和推理场景约束讲清楚。
真实面经题目 · 原创解析
这题考 tensor parallel 下 GEMM 切分和 collective communication 的基本工程理解。回答要能把矩阵维度切分、局部计算、AllReduce/AllGather/ReduceScatter、通信计算重叠和推理场景约束讲清楚。
我会从一个 GEMM 写起:C = A × B。单机多卡分布式 GEMM 的核心是把 A、B 或输出 C 按某个维度切到多张 GPU 上,每张卡做局部矩阵乘,再通过 collective communication 合成下游需要的结果。常见有列切 B,也就是每张卡持有一部分输出通道,局部算出 C 的一部分,后续若下一层也按兼容方式切分,可以暂时不通信;也有行切 B 或切输入通道,每张卡算一部分 partial sum,最后需要 AllReduce 或 ReduceScatter 把部分和合并。LLM 推理里的 tensor parallel 常把线性层成对设计,例如一层列并行减少输出聚合,后一层行并行再做规约。跨 GPU 信息共享一般通过 NCCL 这类库做 AllReduce、AllGather、ReduceScatter 或 Broadcast,单机内走 NVLink、PCIe、NVSwitch 等链路。工程优化要关注切分维度是否减少通信量、通信是否落在关键路径、能否和计算重叠、batch/sequence 太小时通信延迟是否主导,以及显存、KV cache、负载均衡和数值一致性。
GEMM 可以写成 C[M,N] = A[M,K] × B[K,N]。分布式 GEMM 本质是在 M、N 或 K 维度上切分数据和权重,让每张 GPU 只持有一部分矩阵,做局部乘法,再根据切分方式决定是否需要把结果拼接或规约。
如果把 B 按 N 维切成多份,每张 GPU 计算一部分输出通道 C_i = A × B_i。这个模式常被称为列并行。它的好处是局部结果就是输出的一段,可以用 AllGather 拼完整输出,也可以让下一层继续消费分片结果,从而推迟或减少通信。
如果把 K 维切开,每张 GPU 只计算 A_i × B_i 得到同形状的 partial C,最终 C 是所有 partial C 的求和。这类切分需要 AllReduce 或 ReduceScatter 完成跨卡求和。它适合和前后层的切分配合,但通信规约通常会成为关键成本。
如果下游需要每张卡都有完整结果,通常用 AllGather 或 AllReduce;如果下游继续按分片处理,可以用 ReduceScatter 保留分片结果;如果只有一张卡需要结果,则可能用 Gather 或 Reduce。LLM 推理系统会尽量让相邻层的并行方式匹配,避免每层都全量同步。
单机内通信可能走 NVLink、NVSwitch 或 PCIe,不同拓扑的带宽和延迟差异很大。分布式 GEMM 的切分不只看计算量平均,还要看通信量、通信次数、collective 算法、GPU 放置和链路瓶颈。小 batch decode 阶段尤其容易被通信延迟放大。
工程上会用更合适的并行策略、算子融合、通信计算重叠、分组 GEMM、异步 collective、减少不必要的 AllGather,以及保持分片张量继续流动来优化。最终要用端到端延迟、token 吞吐、通信占比和显存占用验证,而不是只看单个 GEMM 的 TFLOPS。
列并行产生输出通道分片,行并行可以消费分片输入并在输出处规约。两者配合可以减少中间层全量 AllGather,让通信集中在必要位置。
AllReduce 后每张卡都有完整规约结果;ReduceScatter 会先规约再把结果切分给各卡。若下游能继续使用分片结果,ReduceScatter 可以减少每张卡保留的数据量和后续通信。
Decode 通常逐 token、小 batch,单步计算量较小,collective 的固定延迟和 KV cache 访问更容易占主导。多卡增加算力的同时也增加同步开销。
通常不会手写底层通信,而是通过 NCCL 或框架封装调用 collective。工程重点是选择合适原语、安排张量切分和减少关键路径上的同步。