Transformer归一化策略:Pre-Norm, Post-Norm与高级技巧
Transformer归一化策略:Pre-Norm, Post-Norm与高级技巧
本文档整理了关于Transformer架构中层归一化(Layer Normalization)的各种策略,以及为提升训练稳定性而设计的多种技巧,并包含了相关的深入问答。
1. Pre-Norm vs. Post-Norm:归一化的位置之争
Pre-Norm 和 Post-Norm 指的是在 Transformer 的一个基本构建块中,层归一化(Layer Normalization)层相对于子层(Sub-layer,即多头自注意力或前馈神经网络)和残差连接(Residual Connection)的位置。这个看似微小的结构差异,对模型的训练稳定性和最终性能有着巨大的影响。
| 特性 | Post-Norm (原始结构) | Pre-Norm (现代标准) |
|---|---|---|
| 结构 | output = LayerNorm(x + Sublayer(x)) |
output = x + Sublayer(LayerNorm(x)) |
| 训练稳定性 | 差,尤其在深层网络中 | 非常好 |
| 梯度问题 | 容易出现梯度爆炸或消失 | 梯度被有效约束,流动稳定 |
| 学习率预热 | 通常必需,且需要精心设计 | 通常不需要或需要很短的预热 |
| 应用 | 原始 Transformer 论文 | GPT、BERT、Llama 等几乎所有现代 LLMs |
Brainstorm Q&A: Pre-Norm 与 Post-Norm 的深层差异
问:后归一化(Post-Norm)最终也等效于在每个模块前进行了归一化,因为前一个模块的输出被归一化了,这和预归一化(Pre-Norm)怎么理解区别?
答: 这是一个非常精准且容易混淆的观察。这个推理忽略了最关键的角色:**残差连接 (Residual Connection)**。关键的区别在于梯度流(反向传播):
- Post-Norm 的问题:梯度在反向传播时,有一路会直接沿着残差主干道(
x)向后传递,这条路径上的梯度是未经处理和约束的。在深层网络中,这些来自不同模块的、未经约束的梯度会在主干道上不断累加,最终可能导致梯度爆炸。 - Pre-Norm 的优势:
LayerNorm扮演了每个Sublayer的“门卫”。任何梯度想要进入Sublayer进行计算,都必须先通过这个“门卫”的检查和约束。这使得整个梯度传播路径(包括主干道和旁路)都变得非常稳定。
结论:关键的区别在于残差连接 Add 操作的位置,它决定了 LayerNorm 是在残差主干上起作用之前还是之后。Pre-Norm 将 LayerNorm 放在了残差连接之前,有效地控制了每一部分的计算和梯度,从而带来了卓越的训练稳定性。
2. 前沿思想:“双重归一化” (Double Norm)
这是最新的架构思想,旨在融合Pre-Norm和Post-Norm的优点。
- 核心思想:保持残差流的“纯净”,不在主干道上放置任何LayerNorm。所有归一化操作都放在“旁路”(即注意力或FFN模块所在的支路)上。
- 流程:输入
x从主干道复制出来,先经过一次LN(Pre-Norm思想),送入计算模块(如Attention),其输出再经过一次LN(Post-Norm思想),最后才将这个“双重净化”后的结果加回主干道。 - 优势:既有Pre-Norm的稳定性,又有比它更通畅的梯度流。Grok、Gemma 2等新模型已采纳此思想。
3. 注意力机制内部的稳定性技巧
除了调整模块结构,还有一些针对注意力计算本身的稳定性技巧。
a. QK Norm
- 问题: Q和K的点积结果(注意力分数)可能变得非常大,导致Softmax函数输出变得“尖锐”,并引发梯度消失或数值不稳定。
- 解决方案: 在计算Q和K的点积之前,分别为Q和K增加一个独立的归一化层(通常是RMSNorm)。这能有效控制注意力分数的大小,稳定计算过程。
- 注意:Value (V) 向量不参与归一化,因为它负责传递内容,而不是参与相似度计算。
b. Logit soft-capping
- 问题: 与QK Norm类似,防止logits(包括注意力分数或最终输出分数)变得过大。
- 解决方案: 使用
tanh函数对logits进行“软上限”处理:logits ← cap * tanh(logits / cap)。这个公式能平滑地将logits的值限制在[-cap, +cap]的范围内,且处处可导。 - 权衡: 这是一种在稳定性和模型性能之间的权衡。它通过直接限制模型的“置信度”上限来保证训练不崩溃,但可能会轻微损害模型的学习能力。
4. z-loss
- 问题: 针对模型最终输出层的Softmax。如果logits过大,会导致Softmax的分母(配分函数
Z(x))出现浮点数上溢,使训练崩溃。 - 解决方案: 引入一个辅助损失函数
α * log²(Z)。这个损失项会“惩罚”log(Z)偏离0的行为,从而激励模型在训练时主动将logits的整体大小控制在合理范围内。 - 对比: 与QK Norm的“事前审查”不同,z-loss是一种“事后惩罚”的正则化方法。
整理自: WordNet:历久弥新
Transformer归一化策略:Pre-Norm, Post-Norm与高级技巧
http://zl1bks.github.io/2025/09/20/Transformer归一化策略:Pre-Norm, Post-Norm与高级技巧/