真实面经题目 · 原创解析
Qwen 这类大模型训练中,混合精度训练如何实现,为什么能提升吞吐并降低显存?
这题考的是大模型训练数值与系统效率的结合:混合精度不是简单把所有张量改成 FP16,而是在前向、反向、梯度、权重、优化器状态和通信之间选择合适精度来兼顾吞吐、显存和稳定性。
真实面经题目 · 原创解析
这题考的是大模型训练数值与系统效率的结合:混合精度不是简单把所有张量改成 FP16,而是在前向、反向、梯度、权重、优化器状态和通信之间选择合适精度来兼顾吞吐、显存和稳定性。
Qwen 这类大模型训练中的混合精度,可以概括为用低精度完成大部分矩阵计算和激活存储,用高精度保留容易出数值问题的部分。常见做法是前向和反向的 GEMM/attention/FFN 主计算用 FP16 或 BF16,让 GPU Tensor Core 发挥高吞吐;权重可能保留一份 FP32 master copy 或至少让 optimizer state 用 FP32 保存,Adam 的一阶二阶动量通常也用 FP32;LayerNorm/RMSNorm、softmax、loss、梯度累加或某些 reduction 可以保持 FP32 或做 FP32 accumulation。FP16 因为动态范围较小,训练时常配合 loss scaling:先把 loss 放大再反传,避免小梯度 underflow,更新前再 unscale,并在 overflow 时跳过 step 或调低 scale;BF16 动态范围接近 FP32,通常更少依赖 loss scaling,但精度尾数更少,也要监控收敛。性能提升来自三方面:低精度矩阵乘更快,activation/gradient 占用更小,显存带宽和分布式通信压力下降;显存降低后还可以放大 batch、序列长度或减少 checkpoint 压力。需要补充的是,混合精度不会把所有显存直接减半,因为 optimizer state、master weights、KV/activation、碎片和并行策略仍然占显存;稳定训练要持续监控 NaN、overflow、loss scale、梯度范数和评测指标。
训练中不同张量对速度、显存和数值稳定性的敏感度不同。矩阵乘和激活适合低精度,优化器状态和部分归约更适合高精度。好的回答要说明哪些部分低精度、哪些部分保高精度,而不是一句“用 FP16 训练”。
大模型训练的主要算力消耗来自 attention 和 FFN 中的大矩阵乘。FP16/BF16 能利用现代 GPU 的 Tensor Core,吞吐通常明显高于 FP32,同时中间激活、梯度等张量占用更少内存和带宽。
为了避免更新误差累积,训练常保留 FP32 master weights,或至少让 Adam 的一阶、二阶动量等 optimizer state 使用 FP32。归一化、softmax、loss 计算、梯度累加和某些 reduction 也常保持或累加到 FP32,以降低溢出、下溢和舍入误差。
FP16 动态范围较小,很多小梯度可能在反传中 underflow 到 0。loss scaling 会先把 loss 放大,使梯度进入可表示范围;更新前再把梯度 unscale。如果检测到 inf/NaN,动态 loss scaler 会降低 scale 并跳过本次参数更新,防止污染权重。
BF16 的指数位接近 FP32,动态范围更大,训练大模型时通常比 FP16 更不容易 overflow,因此很多场景更偏好 BF16。它的尾数位少于 FP16/FP32,精细数值精度有限,所以仍要关注收敛、归一化和累加精度。
低精度会减少参数副本、梯度和激活的存储压力,尤其 activation 随 batch、序列长度和层数增长很快。显存释放后可以使用更大 micro-batch、更长上下文、更少重计算,或在同样硬件上训练更大的模型。
低精度不仅让 Tensor Core 更快,也减少读写显存和跨卡传输的数据量。分布式训练中,如果梯度通信、参数 all-gather 或 reduce-scatter 使用合适低精度,也能降低通信带宽压力,但需要控制通信误差和累加精度。
混合精度上线前要监控 loss 曲线、梯度范数、NaN/inf、overflow 次数、loss scale 变化、评测指标和与 FP32/BF16 baseline 的差异。对数值敏感算子要进入 FP32 白名单或使用稳定 kernel,避免为了速度牺牲收敛。
因为 FP16 动态范围有限,小梯度容易下溢成 0。loss scaling 先放大 loss 和梯度,使其落在可表示范围内;更新前再把梯度缩回去,并在检测到 overflow 时调整 scale 或跳过 step。
不是。BF16 动态范围大,通常不太需要 loss scaling,但尾数位少,累加、归一化、softmax、优化器状态仍可能需要更高精度或稳定 kernel,最终要用 loss 和评测验证。
参数副本、梯度和激活从 FP32 换成 FP16/BF16 后单元素字节数减少,activation 尤其受益。显存降低后可以增加 micro-batch 或序列长度,也可以减少 activation checkpointing 压力。
容易出数值问题的 softmax、norm、loss、梯度累加、大范围 reduction、某些 log/exp 操作通常需要 FP32 计算或 FP32 accumulation。具体取决于 kernel 实现和模型稳定性。
混合精度降低单个张量的字节数,ZeRO/3D 并行改变张量和计算在多卡上的切分方式。它们可以组合:低精度减少每份数据大小,ZeRO/并行减少每张卡需要持有的数据范围。