Transformer 架构精读:从 Self-Attention 的物理意义到 Multi-Head 的并行之美

「Attention is All You Need」(Vaswani et al., 2017) 不仅是 NLP 的转折点,更是现代大模型(LLM)的物理基石。本文旨在从数学本质、代码实现及系统效率三个维度,深度拆解 Transformer 的核心机制。

1. Self-Attention:非局部的全局关联

传统的 CNN 受限于卷积核大小,RNN 受限于时间步的串行依赖。而 Self-Attention 实现了 $O(1)$ 的路径长度,让序列中任意两个 Token 都能实现“瞬间通信”。

1.1 物理意义:Q、K、V 到底在做什么?

我们可以将 Self-Attention 理解为一个寻址过程

  • Query ($Q$): “我要找什么?”(当前 Token 的需求特征)
  • Key ($K$): “我有什么?”(其他 Token 的属性标签)
  • Value ($V$): “我能提供什么信息?”(实际承载的内容)

通过 $QK^T$ 计算相似度,模型实际上在学习:在当前语境下,哪些 Token 的信息对我是最重要的。

1.2 缩放因子 $\sqrt{d_k}$ 的数学必要性

为什么公式中要除以 $\sqrt{d_k}$?

  • 防止梯度消失:当 $d_k$ 很大时,$QK^T$ 的点积结果方差会很大。
  • Softmax 饱和区:如果点积值过大,经过 Softmax 后会进入极度平缓的分散区,导致梯度接近于 0。通过缩放,我们将分布拉回梯度敏感区。

$$
\text{Attention}(Q, K, V) = \text{softmax}\left(\frac{QK^T}{\sqrt{d_k}}\right)V
$$


2. 工程实现:PyTorch 视角下的注意力掩码

在处理 Batch 数据或 Decoder 自回归生成时,Mask(掩码) 是保证模型“不作弊”的关键。

import torch
import torch.nn.functional as F

def scaled_dot_product_attention(Q, K, V, mask=None):
    """
    精简版 Scaled Dot-Product Attention 实现
    """
    d_k = Q.size(-1)
    
    # 1. 计算注意力分数: (Batch, Head, Seq, Seq)
    scores = torch.matmul(Q, K.transpose(-2, -1)) / (d_k ** 0.5)
    
    # 2. 应用掩码: 
    # 在 Decoder 中,需屏蔽未来信息;在 Padding 处,需屏蔽无效位置
    if mask is not None:
        # 将 mask 为 0 的位置设为极小值,Softmax 后权重接近 0
        scores = scores.masked_fill(mask == 0, -1e9)
    
    # 3. 归一化权重
    attn_weights = F.softmax(scores, dim=-1)
    
    # 4. 加权求和得到 context vector
    return torch.matmul(attn_weights, V), attn_weights

3. Multi-Head Attention:特征空间的“分而治之”

单一的 Attention 容易让模型陷入局部关注。Multi-Head 的本质是特征空间的并行采样

  • 子空间投影:将 $d_{model}$ 维度的特征切分为 $h$ 个低维空间。
  • 语义分工:在实际观测中,不同的 Head 会自发演化出不同的职能——有的 Head 关注句法结构(如动宾关系),有的关注实体指代,有的关注标点符号

计算复杂度分析
虽然看似增加了计算量,但在并行计算框架下,Multi-Head 实际上是通过矩阵分块实现的,总参数量与 Single-Head 保持一致(通过 $W^Q, W^K, W^V$ 的降维投影)。


4. 架构对比与系统瓶颈

作为计算机专业学生,我们需要关注算法背后的系统代价

维度 RNN / LSTM Transformer
并行度 差 (依赖前一时刻状态) 极佳 (矩阵运算天然适配 GPU)
长距离依赖 易丢失 (梯度消失/爆炸) 无损 (任意距离 $O(1)$ 通信)
计算复杂度 $O(n \cdot d^2)$ $O(n^2 \cdot d)$ (注意力矩阵平方向增长)
显存瓶颈 较低 高 (Self-Attention 矩阵随序列长度平方级爆炸)

5. 关于“Attention 局限性”的思考

尽管 Transformer 极度强大,但其 $O(n^2)$ 的复杂度限制了它处理超长文本(如整本书)的能力。目前学术界的研究热点,如 FlashAttention(通过算子融合减少内存 I/O)和 Linear Attention,正是在试图解决这个系统级的瓶颈。


延伸阅读