LLaMA中Pre-normalization的设计细节
LLaMA采用Pre-normalization(先归一化再子层)。请解释Pre-Norm vs Post-Norm的架构差异、训练稳定性差异、以及为什么现代LLM普遍采用Pre-Norm。请给出具体的层排列伪代码。
回答
屠龙少年
架构差异:Post-Norm(原始Transformer):LayerNorm → 残差?实际为:子层(Self-Attention/FFN)→ 残差连接 → LayerNorm。Pre-Norm(LLaMA/GPT-2+):LayerNorm → 子层 → 残差连接。具体:Post-Norm:x = LayerNorm(x + SubLayer(x));Pre-Norm:x = x + SubLayer(LayerNorm(x))。训练稳定性:Pre-Norm在深层模型中训练更稳定,梯度回传路径更清晰(残差分支直接传递梯度,不受Norm影响);Post-Norm需要仔细的学习率warmup,深层的梯度噪声较大。为什么LLM用Pre-Norm:1)可扩展性——训练65B+模型时Pre-Norm稳定;2)移除warmup步骤,简化训练流程;3)最终层输出可直接用于预测(最后再经过LayerNorm即可)。GPT-2最早在Decoder-only中使用Pre-Norm,后续LLaMA沿用。LLaMA的具体实现:每个Transformer块为 x = x + Attention(RMSNorm(x)),然后 x = x + FFN(RMSNorm(x))。最后再做一次RMSNorm。