真实面经题目 · 原创解析

单机多卡 LLM 推理中的分布式 GEMM 如何切分矩阵,并完成跨 GPU 通信?

这题考 tensor parallel 下 GEMM 切分和 collective communication 的基本工程理解。回答要能把矩阵维度切分、局部计算、AllReduce/AllGather/ReduceScatter、通信计算重叠和推理场景约束讲清楚。

出现于:阿里巴巴 · C/C++

60 秒回答模板

我会从一个 GEMM 写起:C = A × B。单机多卡分布式 GEMM 的核心是把 A、B 或输出 C 按某个维度切到多张 GPU 上,每张卡做局部矩阵乘,再通过 collective communication 合成下游需要的结果。常见有列切 B,也就是每张卡持有一部分输出通道,局部算出 C 的一部分,后续若下一层也按兼容方式切分,可以暂时不通信;也有行切 B 或切输入通道,每张卡算一部分 partial sum,最后需要 AllReduce 或 ReduceScatter 把部分和合并。LLM 推理里的 tensor parallel 常把线性层成对设计,例如一层列并行减少输出聚合,后一层行并行再做规约。跨 GPU 信息共享一般通过 NCCL 这类库做 AllReduce、AllGather、ReduceScatter 或 Broadcast,单机内走 NVLink、PCIe、NVSwitch 等链路。工程优化要关注切分维度是否减少通信量、通信是否落在关键路径、能否和计算重叠、batch/sequence 太小时通信延迟是否主导,以及显存、KV cache、负载均衡和数值一致性。

考点 维度切分
难度 真实面经题
回答目标 讲清分布式 GEMM 切分和通信

深入解析

01

先把 GEMM 写成矩阵维度

GEMM 可以写成 C[M,N] = A[M,K] × B[K,N]。分布式 GEMM 本质是在 M、N 或 K 维度上切分数据和权重,让每张 GPU 只持有一部分矩阵,做局部乘法,再根据切分方式决定是否需要把结果拼接或规约。

02

按 N 维切是输出分片

如果把 B 按 N 维切成多份,每张 GPU 计算一部分输出通道 C_i = A × B_i。这个模式常被称为列并行。它的好处是局部结果就是输出的一段,可以用 AllGather 拼完整输出,也可以让下一层继续消费分片结果,从而推迟或减少通信。

03

按 K 维切需要规约部分和

如果把 K 维切开,每张 GPU 只计算 A_i × B_i 得到同形状的 partial C,最终 C 是所有 partial C 的求和。这类切分需要 AllReduce 或 ReduceScatter 完成跨卡求和。它适合和前后层的切分配合,但通信规约通常会成为关键成本。

04

通信原语取决于下游需要

如果下游需要每张卡都有完整结果,通常用 AllGather 或 AllReduce;如果下游继续按分片处理,可以用 ReduceScatter 保留分片结果;如果只有一张卡需要结果,则可能用 Gather 或 Reduce。LLM 推理系统会尽量让相邻层的并行方式匹配,避免每层都全量同步。

05

单机多卡还要看互联拓扑

单机内通信可能走 NVLink、NVSwitch 或 PCIe,不同拓扑的带宽和延迟差异很大。分布式 GEMM 的切分不只看计算量平均,还要看通信量、通信次数、collective 算法、GPU 放置和链路瓶颈。小 batch decode 阶段尤其容易被通信延迟放大。

06

优化目标是减少关键路径通信

工程上会用更合适的并行策略、算子融合、通信计算重叠、分组 GEMM、异步 collective、减少不必要的 AllGather,以及保持分片张量继续流动来优化。最终要用端到端延迟、token 吞吐、通信占比和显存占用验证,而不是只看单个 GEMM 的 TFLOPS。

易错点

  • 只说把模型参数平均分到多张卡,没有落到 GEMM 的 M/N/K 维度切分。
  • 把所有跨卡通信都叫同步,没有区分 AllReduce、AllGather、ReduceScatter 的语义。
  • 忽略相邻层并行方式配合,导致每层都拼完整张量,通信成本过高。
  • 只看单卡 GEMM 性能,不看单机互联拓扑和 collective 延迟。
  • 认为多卡一定更快,忽略小 batch decode 下通信延迟可能主导。
  • 把训练里的数据并行梯度同步直接套到推理分布式 GEMM,没有说明 tensor parallel 的局部计算和结果合成。

面试官追问

列并行和行并行为什么经常配对出现?

列并行产生输出通道分片,行并行可以消费分片输入并在输出处规约。两者配合可以减少中间层全量 AllGather,让通信集中在必要位置。

AllReduce 和 ReduceScatter 的区别是什么?

AllReduce 后每张卡都有完整规约结果;ReduceScatter 会先规约再把结果切分给各卡。若下游能继续使用分片结果,ReduceScatter 可以减少每张卡保留的数据量和后续通信。

为什么推理 decode 阶段多卡效率可能不高?

Decode 通常逐 token、小 batch,单步计算量较小,collective 的固定延迟和 KV cache 访问更容易占主导。多卡增加算力的同时也增加同步开销。

跨 GPU 信息共享一定要自己写通信吗?

通常不会手写底层通信,而是通过 NCCL 或框架封装调用 collective。工程重点是选择合适原语、安排张量切分和减少关键路径上的同步。