多头注意力机制(Multi-Head Attention)是Transformer模型的核心组件,由Vaswani等人在2017年提出,广泛应用于自然语言处理(NLP)、计算机视觉(CV)等领域。以下从原理、实现到应用展开说明:
1. 核心思想
多头注意力的设计目的是让模型能够并行关注输入的不同子空间信息,通过多组独立的注意力头(Attention Heads)捕捉序列中多样化的依赖关系(如局部/全局、语法/语义等)。
关键概念
- 单头注意力:计算查询(Q)、键(K)、值(V)的加权和,权重由Q与K的相似度决定。
- 多头扩展:将Q、K、V拆分为多组(例如8头),每组独立计算注意力后拼接结果,最后通过线性层融合。
2. 数学实现步骤
-
线性投影:
对输入的Q、K、V分别进行$h$次不同的线性变换($h$为头数),得到$h$组$Q_i, K_i, V_i \in \mathbb{R}^{n \times d_k}$($d_k = d_{\text{model}}/h$)。 -
缩放点积注意力:
每组头计算注意力分数:
$$
\text{Attention}(Q_i, K_i, V_i) = \text{softmax}\left(\frac{Q_i K_i^T}{\sqrt{d_k}}\right) V_i
$$
($\sqrt{d_k}$用于缓解梯度消失) -
多头拼接与融合:
将$h$个头的输出拼接后通过线性层$W^O$映射:
$$
\text{MultiHead}(Q, K, V) = \text{Concat}(\text{head}_1, ..., \text{head}_h) W^O
$$
3. 为什么有效?
- 并行捕捉多样化模式:不同头可能关注不同位置或特征(例如一个头看句法,另一个头看指代关系)。
- 模型容量提升:通过多组参数化投影,增强表达能力。
- 计算效率:拆分后维度降低($d_k = d_{\text{model}}/h$),总计算量与单头相近。
4. 典型应用场景
自然语言处理(NLP)
- 机器翻译(如Transformer):编码器捕捉源语言结构,解码器交叉注意力关联目标语言。
- 文本生成(如GPT):自回归解码时关注前文关键信息。
- BERT:通过自注意力理解上下文依赖(如消歧任务)。
计算机视觉(CV)
- ViT(Vision Transformer):将图像分块为序列,用多头注意力建模全局关系。
- 目标检测(如DETR):注意力机制替代传统卷积,直接预测物体关系。
多模态任务
- CLIP:对齐图像和文本的跨模态注意力。
- 语音识别:融合声学与文本特征的混合注意力。
5. 代码示例(PyTorch)
import torch
import torch.nn as nn
import torch.nn.functional as F
class MultiHeadAttention(nn.Module):
def __init__(self, d_model=512, h=8):
super().__init__()
self.d_k = d_model // h
self.h = h
self.W_q = nn.Linear(d_model, d_model) # Q投影
self.W_k = nn.Linear(d_model, d_model) # K投影
self.W_v = nn.Linear(d_model, d_model) # V投影
self.W_o = nn.Linear(d_model, d_model) # 输出融合
def forward(self, x):
batch_size = x.size(0)
# 投影到Q/K/V并分头 (batch_size, seq_len, h, d_k)
Q = self.W_q(x).view(batch_size, -1, self.h, self.d_k).transpose(1, 2)
K = self.W_k(x).view(batch_size, -1, self.h, self.d_k).transpose(1, 2)
V = self.W_v(x).view(batch_size, -1, self.h, self.d_k).transpose(1, 2)
# 缩放点积注意力
scores = torch.matmul(Q, K.transpose(-2, -1)) / torch.sqrt(torch.tensor(self.d_k))
attn = F.softmax(scores, dim=-1)
out = torch.matmul(attn, V) # (batch_size, h, seq_len, d_k)
# 拼接多头结果并融合
out = out.transpose(1, 2).contiguous().view(batch_size, -1, self.h * self.d_k)
return self.W_o(out)
6. 改进与变体
- 稀疏注意力(如Longformer):降低长序列的计算复杂度。
- 相对位置编码(如Transformer-XL):解决位置编码的泛化问题。
- 跨模态注意力(如UniT):融合视觉和语言的多头注意力。
多头注意力通过并行化、子空间分解和特征融合,显著提升了模型对复杂关系的建模能力,成为现代深度学习架构的基石。理解其机制有助于灵活应用于不同任务或设计新的注意力变体。