Transformer 架构精读:从 Self-Attention 的物理意义到 Multi-Head 的并行之美
「Attention is All You Need」(Vaswani et al., 2017) 不仅是 NLP 的转折点,更是现代大模型(LLM)的物理基石。本文旨在从数学本质、代码实现及系统效率三个维度,深度拆解 Transformer 的核心机制。
1. Self-Attention:非局部的全局关联
传统的 CNN 受限于卷积核大小,RNN 受限于时间步的串行依赖。而 Self-Attention 实现了 $O(1)$ 的路径长度,让序列中任意两个 Token 都能实现“瞬间通信”。
1.1 物理意义:Q、K、V 到底在做什么?
我们可以将 Self-Attention 理解为一个寻址过程:
- Query ($Q$): “我要找什么?”(当前 Token 的需求特征)
- Key ($K$): “我有什么?”(其他 Token 的属性标签)
- Value ($V$): “我能提供什么信息?”(实际承载的内容)
通过 $QK^T$ 计算相似度,模型实际上在学习:在当前语境下,哪些 Token 的信息对我是最重要的。
1.2 缩放因子 $\sqrt{d_k}$ 的数学必要性
为什么公式中要除以 $\sqrt{d_k}$?
- 防止梯度消失:当 $d_k$ 很大时,$QK^T$ 的点积结果方差会很大。
- Softmax 饱和区:如果点积值过大,经过 Softmax 后会进入极度平缓的分散区,导致梯度接近于 0。通过缩放,我们将分布拉回梯度敏感区。
$$
\text{Attention}(Q, K, V) = \text{softmax}\left(\frac{QK^T}{\sqrt{d_k}}\right)V
$$
2. 工程实现:PyTorch 视角下的注意力掩码
在处理 Batch 数据或 Decoder 自回归生成时,Mask(掩码) 是保证模型“不作弊”的关键。
import torch
import torch.nn.functional as F
def scaled_dot_product_attention(Q, K, V, mask=None):
"""
精简版 Scaled Dot-Product Attention 实现
"""
d_k = Q.size(-1)
# 1. 计算注意力分数: (Batch, Head, Seq, Seq)
scores = torch.matmul(Q, K.transpose(-2, -1)) / (d_k ** 0.5)
# 2. 应用掩码:
# 在 Decoder 中,需屏蔽未来信息;在 Padding 处,需屏蔽无效位置
if mask is not None:
# 将 mask 为 0 的位置设为极小值,Softmax 后权重接近 0
scores = scores.masked_fill(mask == 0, -1e9)
# 3. 归一化权重
attn_weights = F.softmax(scores, dim=-1)
# 4. 加权求和得到 context vector
return torch.matmul(attn_weights, V), attn_weights
3. Multi-Head Attention:特征空间的“分而治之”
单一的 Attention 容易让模型陷入局部关注。Multi-Head 的本质是特征空间的并行采样。
- 子空间投影:将 $d_{model}$ 维度的特征切分为 $h$ 个低维空间。
- 语义分工:在实际观测中,不同的 Head 会自发演化出不同的职能——有的 Head 关注句法结构(如动宾关系),有的关注实体指代,有的关注标点符号。
计算复杂度分析:
虽然看似增加了计算量,但在并行计算框架下,Multi-Head 实际上是通过矩阵分块实现的,总参数量与 Single-Head 保持一致(通过 $W^Q, W^K, W^V$ 的降维投影)。
4. 架构对比与系统瓶颈
作为计算机专业学生,我们需要关注算法背后的系统代价:
| 维度 | RNN / LSTM | Transformer |
|---|---|---|
| 并行度 | 差 (依赖前一时刻状态) | 极佳 (矩阵运算天然适配 GPU) |
| 长距离依赖 | 易丢失 (梯度消失/爆炸) | 无损 (任意距离 $O(1)$ 通信) |
| 计算复杂度 | $O(n \cdot d^2)$ | $O(n^2 \cdot d)$ (注意力矩阵平方向增长) |
| 显存瓶颈 | 较低 | 高 (Self-Attention 矩阵随序列长度平方级爆炸) |
5. 关于“Attention 局限性”的思考
尽管 Transformer 极度强大,但其 $O(n^2)$ 的复杂度限制了它处理超长文本(如整本书)的能力。目前学术界的研究热点,如 FlashAttention(通过算子融合减少内存 I/O)和 Linear Attention,正是在试图解决这个系统级的瓶颈。
延伸阅读: