RMSNorm的原理及相比LayerNorm的改进
LLaMA使用RMSNorm替代传统LayerNorm。请解释RMSNorm的数学公式、计算方式,相比LayerNorm去除了哪些组件?为什么去除后性能反而接近或更好?RMSNorm的计算效率优势如何体现?
回答
小字辈
RMSNorm公式:RMSNorm(x) = x / RMS(x) × γ,其中RMS(x) = sqrt(mean(x²)),γ为可学习缩放参数。与LayerNorm对比:LayerNorm计算均值和方差后进行归一化:LN(x) = (x-μ)/σ × γ + β。RMSNorm去除了均值计算(μ=False)、去除了α减法和β偏置项。为什么有效:RMSNorm仅通过均方根对激活进行缩放,保留了方向信息。研究表明LN中均值减法对Transformer性能贡献很小,去除后模型性能几乎不变甚至略好(更简洁的归一化减少了噪声)。效率优势:减少了一次均值计算和减法操作,在反向传播中梯度计算更简单。在大规模模型中,RMSNorm约节省5-10%的归一化计算时间,累积效果显著。LLaMA、Mistral、Falcon等主流开源模型均采用RMSNorm。