Shape Changes in Multi-Head Attention¶
QKV¶
On whether Q, K, and V in a Transformer come from the same source.
In shape terms, \(QK^\top V\) multiplies \([B_Q, L_Q, D_Q] \cdot [B_K, D_K, L_K] \cdot [B_V, L_V, D_V]\) (with \(K\) transposed).
(1) Where Q, K, and V come from:
- In self-attention, \(Q\), \(K\), and \(V\) are all derived from the same sequence.
- In cross-attention, \(Q\) comes from one sequence while \(K\) and \(V\) come from another, capturing dependencies between the two sequences.
(2) Shape constraints:
- Required: \(D_Q = D_K\) and \(L_K = L_V\).
- In short: the attention score matrix has shape \(L_Q \times L_K\), as the sketch below illustrates.
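A minimal single-head sketch of these constraints (the sizes below are illustrative, not from the original post): the inner dimension of \(QK^\top\) forces \(D_Q = D_K\), and multiplying the weights by \(V\) forces \(L_K = L_V\).
Python
import torch
B, L_q, L_kv = 2, 5, 7       # batch size, query length, key/value length
D_qk, D_v = 16, 32           # shared Q/K dimension, V dimension
Q = torch.randn(B, L_q, D_qk)    # D_Q = D_K is required for Q @ K^T
K = torch.randn(B, L_kv, D_qk)
V = torch.randn(B, L_kv, D_v)    # L_K = L_V is required for weights @ V
scores = torch.matmul(Q, K.transpose(-1, -2)) / D_qk ** 0.5   # [B, L_q, L_kv]
weights = torch.softmax(scores, dim=-1)                       # [B, L_q, L_kv]
out = torch.matmul(weights, V)                                # [B, L_q, D_v]
print(scores.shape, out.shape)   # torch.Size([2, 5, 7]) torch.Size([2, 5, 32])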
Multi-head attention shape change: \(BLD \stackrel{D = H \cdot d}{\rightarrow} BLHd\), i.e. the model dimension \(D\) is split into \(H\) heads of size \(d\).
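A quick sketch of that split (illustrative sizes): view factors \(D\) into \(H \times d\), and transpose moves the head axis ahead of the length axis.
Python
import torch
B, L, H, d = 2, 6, 8, 64
D = H * d                       # D = H * d
x = torch.randn(B, L, D)        # [B, L, D]
x = x.view(B, L, H, d)          # [B, L, H, d]
x = x.transpose(1, 2)           # [B, H, L, d]
print(x.shape)                  # torch.Size([2, 8, 6, 64])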
Shape Changes¶
Python
def forward(self, q, k, v, mask=None):
    batch_size, sequence_length, model_dim = q.shape
    kv_len = k.size(1)  # key/value length S: equals L in self-attention, may differ in cross-attention
    # [B, L, d_model] → linear projection → [B, L, H*d_k] → view → [B, L, H, d_k] → transpose → [B, H, L, d_k]
    q = self.q_proj(q).view(batch_size, sequence_length, self.num_heads, self.head_dim).transpose(1, 2)
    # [B, S, d_model] → linear projection → [B, S, H*d_k] → view → [B, S, H, d_k] → transpose → [B, H, S, d_k]
    k = self.k_proj(k).view(batch_size, kv_len, self.num_heads, self.head_dim).transpose(1, 2)
    # [B, S, d_model] → linear projection → [B, S, H*d_v] → view → [B, S, H, d_v] → transpose → [B, H, S, d_v]
    v = self.v_proj(v).view(batch_size, kv_len, self.num_heads, self.head_dim).transpose(1, 2)
    # [B, H, L, d_k] × [B, H, d_k, S] → torch.matmul → [B, H, L, S], scaled by sqrt(d_k)
    scores = torch.matmul(q, k.transpose(-1, -2)) / math.sqrt(self.head_dim)
    if mask is not None:
        scores = scores.masked_fill(mask == 0, -1e9)
    # [B, H, L, S] → softmax over the last dimension → [B, H, L, S]
    prob = F.softmax(scores, dim=-1)
    prob = self.dropout(prob)
    # [B, H, L, S] × [B, H, S, d_v] → torch.matmul → [B, H, L, d_v] → transpose → [B, L, H, d_v] → view → [B, L, d_model]
    attn_weights = torch.matmul(prob, v).transpose(1, 2).contiguous().view(batch_size, sequence_length, model_dim)
    # restore the model dimension: [B, L, d_model] → linear projection → [B, L, d_model]
    output = self.o_proj(attn_weights)
    return output
(1) Queries:
[B, L, d_model] → linear projection (nn.Linear) → [B, L, H*d_k] → reshape (.view) → [B, L, H, d_k] → transpose (.transpose) → [B, H, L, d_k]
- The queries are linearly projected into the multi-head attention space and reshaped into per-head form.
Python
q = self.q_proj(q).view(batch_size, sequence_length, self.num_heads, self.head_dim).transpose(1, 2)
(2) Keys: \(D_Q = D_K \triangleq d_k\)
[B, S, d_model] → linear projection (nn.Linear) → [B, S, H*d_k] → reshape (.view) → [B, S, H, d_k] → transpose (.transpose) → [B, H, S, d_k]
- The keys are linearly projected into the multi-head attention space and reshaped into per-head form.
- More precise notation: Q and K must share the same per-head dimension, denoted \(d_k\); K and V must share the same length, denoted \(S\) (equal to \(L\) in self-attention).
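The corresponding line from the forward pass above, where kv_len is the key/value length \(S\):
Python
k = self.k_proj(k).view(batch_size, kv_len, self.num_heads, self.head_dim).transpose(1, 2)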
(3) Values:
[B, S, d_model] → linear projection (nn.Linear) → [B, S, H*d_v] → reshape (.view) → [B, S, H, d_v] → transpose (.transpose) → [B, H, S, d_v]
- The values are linearly projected into the multi-head attention space and reshaped into per-head form.
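The corresponding line from the forward pass above:
Python
v = self.v_proj(v).view(batch_size, kv_len, self.num_heads, self.head_dim).transpose(1, 2)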
(4) Computing attention scores:
[B, H, L, d_k] × ([B, H, S, d_k] → transpose → [B, H, d_k, S]) → matrix multiplication (torch.matmul) → [B, H, L, S]
- The dot-product scores between queries and keys are computed by matrix multiplication and scaled by \(\sqrt{d_k}\).
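The corresponding lines from the forward pass above, including the optional mask:
Python
scores = torch.matmul(q, k.transpose(-1, -2)) / math.sqrt(self.head_dim)
if mask is not None:
    scores = scores.masked_fill(mask == 0, -1e9)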
(5) Computing attention weights:
[B, H, L, S] → softmax → [B, H, L, S]
- Softmax over the last dimension turns the scores into attention weights, followed by dropout.
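The corresponding lines from the forward pass above:
Python
prob = F.softmax(scores, dim=-1)
prob = self.dropout(prob)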
(6) Applying the attention weights:
[B, H, L, S] × [B, H, S, d_v] → matrix multiplication → [B, H, L, d_v] → transpose → [B, L, H, d_v] → reshape → [B, L, d_model]
- Matrix multiplication applies the attention weights to the values; the heads are then merged back into the original shape.
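The corresponding line from the forward pass above:
Python
attn_weights = torch.matmul(prob, v).transpose(1, 2).contiguous().view(batch_size, sequence_length, model_dim)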
(7) Output: restoring the model dimension
[B, L, d_model] → linear projection → [B, L, d_model]
- A final linear projection maps the merged output back to the model dimension.
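The corresponding line from the forward pass above:
Python
output = self.o_proj(attn_weights)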
📢 Simplified summary
- \(Q \triangleq BLD_k,\ K \triangleq BSD_k,\ V \triangleq BSD_v\)
- linear, view, transpose: \(Q: BLD_k \rightarrow BLHd_k \rightarrow BHLd_k\)
- \(K: BSD_k \rightarrow BSHd_k \rightarrow BHSd_k\)
- \(V: BSD_v \rightarrow BSHd_v \rightarrow BHSd_v\)
- torch.matmul: \(QK^\top: BHLd_k \cdot BHd_kS \rightarrow BHLS\)
- softmax, dropout, torch.matmul: \(\mathrm{softmax}(QK^\top)V: BHLS \cdot BHSd_v \rightarrow BHLd_v \rightarrow \text{transpose} \rightarrow \text{view} \rightarrow BLD\)
The complete code, which is worth knowing by heart:
Python
import torch
import torch.nn as nn
import torch.nn.functional as F
import math

class MultiHeadAttention(nn.Module):
    def __init__(self, model_dim, num_heads, dropout=0.1):
        super().__init__()
        assert model_dim % num_heads == 0, "model_dim must be divisible by num_heads"
        self.model_dim = model_dim
        self.num_heads = num_heads
        self.head_dim = model_dim // num_heads
        self.q_proj = nn.Linear(model_dim, model_dim)
        self.k_proj = nn.Linear(model_dim, model_dim)
        self.v_proj = nn.Linear(model_dim, model_dim)
        self.dropout = nn.Dropout(dropout)
        self.o_proj = nn.Linear(model_dim, model_dim)

    def forward(self, q, k, v, mask=None):
        batch_size, sequence_length, model_dim = q.shape
        kv_len = k.size(1)  # key/value length S: equals L in self-attention, may differ in cross-attention
        # project and split into heads: [B, L, D] -> [B, H, L, d_k]; k and v use their own length S
        q = self.q_proj(q).view(batch_size, sequence_length, self.num_heads, self.head_dim).transpose(1, 2)
        k = self.k_proj(k).view(batch_size, kv_len, self.num_heads, self.head_dim).transpose(1, 2)
        v = self.v_proj(v).view(batch_size, kv_len, self.num_heads, self.head_dim).transpose(1, 2)
        # scaled dot-product scores: [B, H, L, d_k] x [B, H, d_k, S] -> [B, H, L, S]
        scores = torch.matmul(q, k.transpose(-1, -2)) / math.sqrt(self.head_dim)
        if mask is not None:
            scores = scores.masked_fill(mask == 0, -1e9)
        prob = F.softmax(scores, dim=-1)
        prob = self.dropout(prob)
        # apply weights and merge heads: [B, H, L, S] x [B, H, S, d_v] -> [B, H, L, d_v] -> [B, L, D]
        attn_weights = torch.matmul(prob, v).transpose(1, 2).contiguous().view(batch_size, sequence_length, model_dim)
        output = self.o_proj(attn_weights)
        return output

model_dim = 512
num_heads = 8
mha = MultiHeadAttention(model_dim=model_dim, num_heads=num_heads, dropout=0.1)

batch_size = 10
sequence_length = 60
q = torch.randn(batch_size, sequence_length, model_dim)
k = torch.randn(batch_size, sequence_length, model_dim)
v = torch.randn(batch_size, sequence_length, model_dim)
mask = None

output = mha(q, k, v, mask)
print(output.shape)  # the output shape should be (batch_size, sequence_length, model_dim)
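As a quick cross-attention shape check (a hypothetical extra example, not part of the original test): when the keys and values come from a different sequence of length S, the output keeps the query length L.
Python
L, S = 60, 48                                      # query length vs key/value length (illustrative)
queries = torch.randn(batch_size, L, model_dim)
memory = torch.randn(batch_size, S, model_dim)     # e.g. a hypothetical encoder output used as keys and values
out = mha(queries, memory, memory, mask=None)
print(out.shape)  # torch.Size([10, 60, 512]): the query length is preserved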