真实面经题目 · 原创解析
手写 CUDA Softmax2D 时,如何在 K 维做数值稳定的 softmax,避免指数溢出和精度爆炸?
这题考 row-wise softmax kernel 的正确性和性能:按 K 维减最大值、FP32 累加、block/warp reduction、不同 K 大小的线程映射,以及极端输入验证。
真实面经题目 · 原创解析
这题考 row-wise softmax kernel 的正确性和性能:按 K 维减最大值、FP32 累加、block/warp reduction、不同 K 大小的线程映射,以及极端输入验证。
我会先把输入看成 X[M, K],每一行 m 在 K 维做 softmax。直接算 exp(x_j) / sum(exp(x_t)) 会在 x 很大时指数溢出,也会在差距很大时造成精度问题。稳定做法是每行先求 row_max = max_j X[m, j],再计算 exp(X[m, j] - row_max),因为 softmax 对同一行减常数不变;此时最大指数是 exp(0)=1,可以避免正向溢出。然后对这些 exp 值求 row_sum,最后输出 Y[m, j] = exp(X[m, j] - row_max) / row_sum。CUDA 实现上,常见是一个 block 或一个 warp 负责一行,线程沿 K 维 stride 读取,先做 block/warp reduce max,再做 reduce sum,最后归一化写回。FP16/BF16 输入时,max、sum 和除法通常用 FP32 计算,再按输出类型写回。K 较小时可用 warp-level reduction,K 中等时一个 block 一行,K 很大时可能需要多 block 分段归约或两阶段 kernel。性能上要保证行内 K 维连续访问、减少重复 global memory 读写、必要时缓存中间 exp 或重算 exp,并用极端大值、全相等、带 mask、全 -inf 等用例验证正确性。
输入是 M*K,可以理解为 M 行、每行 K 个元素。题目说在 K 方向做 softmax,意味着每一行独立归一化:对固定行 m,输出 y[m,j] = exp(x[m,j]) / sum_t exp(x[m,t])。不同 M 行之间没有数学依赖,适合并行映射到不同 block 或 warp。
直接 exp(x) 的风险是 x 大时溢出到 inf。因为 softmax(x) = softmax(x - c),可以令 c 为该行最大值。这样每个 x[m,j] - row_max <= 0,最大的指数为 1,其他值在 0 到 1 之间,从根上降低指数溢出和分母爆炸风险。
稳定 softmax 通常分三步:第一遍沿 K 维 reduce 得到 row_max;第二遍计算 exp(x - row_max) 并 reduce 得到 row_sum;第三遍计算 exp(x - row_max) / row_sum 写出。为了省全局内存,有的实现会在第二遍把 exp 缓存到 shared memory 或寄存器,有的会第三遍重算 exp,取舍取决于 K 大小和资源占用。
K 很小时,一个 warp 处理一行可以减少同步开销;K 中等时,一个 block 处理一行,用 shared memory 或 warp shuffle 做 block reduction;K 很大时,单个 block 可能覆盖不完或效率不足,需要多 block 分段求 max/sum,再做二阶段合并。不能只写一个固定 block 配置就声称适配所有 K。
如果输入是 FP16 或 BF16,指数、sum 和除法最好在 FP32 中完成,最后再转换成目标输出类型。row_sum 是 K 个正数的累加,K 大时低精度累加误差会放大;FP32 accumulator 能显著提升稳定性。还要处理 row_sum 为 0、NaN、inf 或 mask 后全无效的边界策略。
Softmax2D 的访问通常按行连续,如果内存 layout 是 row-major,相邻线程读取连续 K 位置能获得较好 coalescing。优化可考虑向量化 load、减少 global memory pass、融合 mask/scale/dropout 或后续算子。但每次优化都要用 CPU/PyTorch reference、极端大正数、大负数、全相等值、随机长 K 和非整齐 K 做误差验证。
因为分子和分母都会乘上同一个 exp(-c),比例不变。选择 c=row_max 后,所有指数输入都不大于 0,最大指数为 1,数值更稳定。
每个位置的 exp(x - max) 都是 1,sum 是 K,所以输出是 1/K。这是验证实现是否正确的基础用例。
K 较大时,很多指数值相加会放大舍入误差,FP16 动态范围和有效位都有限。用 FP32 累加能降低误差,最后再转换输出类型。
单 block 线程数有限,K 很大时每个线程要处理很多元素,归约和多次读全局内存成本上升。可能需要多 block 分段归约 row_max 和 row_sum,再做归一化。
mask 位置通常设为极小值或在计算中排除。要特别处理一行全被 mask 的情况,否则 row_max、row_sum 和除法可能产生 NaN,需要定义清楚输出策略。