在 Transformer 中,Multi-head attention 的计算过程是:
MultiHeadAttn
(
z
q
,
x
)
=
∑
m
=
1
M
W
m
[
∑
k
∈
Ω
k
A
m
q
k
⋅
W
m
′
x
k
]
\text{MultiHeadAttn}(z_q, \mathbb{x}) = \sum_{m=1}^M W_m[\sum_{k\in \Omega_k} A_{mqk} \cdot {W'_m} \mathbb{x}_k]
MultiHeadAttn(zq,x)=∑m=1MWm[∑k∈ΩkAmqk⋅Wm′xk].
其中
m
m
m是 attention head 的索引,
W
m
′
∈
R
C
v
×
C
{W_m}'\in \mathbb{R}^{C_v\times C}
Wm′∈RCv×C 是输入的映射矩阵,
W
m
∈
R
C
×
C
v
{W_m}\in \mathbb{R}^{C\times C_v}
Wm∈RC×Cv 是输出的映射矩阵,二者都是可学习的权重(
C
v
=
C
/
M
C_v = C/M
Cv=C/M)。Attention 权重
A
m
q
k
∝
exp
{
z
q
T
U
m
T
V
m
x
k
C
v
}
A_{mqk}\propto \exp\lbrace \frac{z_q^T U_m^T V_m x_k}{\sqrt{C_v}} \rbrace
Amqk∝exp{CvzqTUmTVmxk},并且
∑
k
∈
Ω
k
A
m
q
k
=
1
\sum_{k\in \Omega_{k}} A_{mqk}=1
∑k∈ΩkAmqk=1,其中
U
m
,
V
m
∈
R
C
v
×
C
U_m,V_m \in \mathbb{R}^{C_v\times C}
Um,Vm∈RCv×C分别是 query 的映射矩阵和 key 的映射矩阵,也都是可学习权重。设 query 和 key 元素的个数分别是
N
q
N_q
Nq和
N
k
N_k
Nk.
MultiHeadAttn
(
z
q
,
x
)
\text{MultiHeadAttn}(z_q, \mathbb{x})
MultiHeadAttn(zq,x)的计算复杂度是
O
(
N
q
C
2
+
N
k
C
2
+
N
q
N
k
C
)
O(N_q C^2 + N_k C^2 + N_q N_k C)
O(NqC2+NkC2+NqNkC)。
输入是
X
∈
R
N
×
C
X\in \mathbb{R}^{N\times C}
X∈RN×C,用
U
m
,
V
m
∈
R
C
v
×
C
U_m,V_m \in \mathbb{R}^{C_v\times C}
Um,Vm∈RCv×C分别对 query 和 key 做线性变换,计算得到
Q
,
K
∈
R
N
×
C
Q,K\in \mathbb{R}^{N\times C}
Q,K∈RN×C矩阵。这样,计算
Q
Q
Q和
K
K
K的复杂度就是
O
(
N
q
×
C
2
)
O(N_q\times C^2)
O(Nq×C2)和
O
(
N
k
×
C
2
)
O(N_k\times C^2)
O(Nk×C2).
然后计算
A
m
q
k
∝
exp
{
z
q
T
U
m
T
V
m
x
k
C
v
}
A_{mqk}\propto \exp\lbrace \frac{z_q^T U_m^T V_m x_k}{\sqrt{C_v}} \rbrace
Amqk∝exp{CvzqTUmTVmxk},复杂度是
O
(
N
q
×
N
k
×
C
)
O(N_q \times N_k \times C)
O(Nq×Nk×C).
A
m
q
k
A_{mqk}
Amqk与
x
k
x_k
xk相乘,计算复杂度是
O
(
N
q
×
N
k
×
C
)
O(N_q \times N_k \times C)
O(Nq×Nk×C).
总体的计算复杂度就是
O
(
N
q
×
C
2
+
N
k
×
C
2
+
N
q
N
k
C
)
O(N_q\times C^2 + N_k\times C^2 + N_q N_k C)
O(Nq×C2+Nk×C2+NqNkC).
在 DETR 中,Transformer encoder 的 query 和 key 元素就是特征图上的像素点,假设输入特征图的宽度和高度分别是
W
W
W和
H
H
H。
Encoder 中的 self-attention 的计算复杂度就是
O
(
H
2
W
2
C
)
O(H^2W^2C)
O(H2W2C).
Decoder 包括了 self attention 和 cross attention,输入包括来自于 encoder 的特征图、
N
N
N个 object queries。
在 decoder 的 cross attention 中,query 元素来自于 object queries,key 元素来自于 encoder 特征图,从 encoder 提供的特征图上提取 key 元素,
N
q
=
N
,
N
k
=
H
×
W
N_q=N, N_k=H\times W
Nq=N,Nk=H×W,计算复杂度是
O
(
N
k
C
2
+
N
N
k
C
)
=
O
(
H
W
C
2
+
N
H
W
C
2
)
O(N_kC^2+NN_kC)=O(HWC^2+NHWC^2)
O(NkC2+NNkC)=O(HWC2+NHWC2).
在 decoder 的 self attention 中,object queries 相互作用,query 和 key 元素都来自于 object queries。
N
q
=
N
k
=
N
N_q=N_k=N
Nq=Nk=N,复杂度就是
O
(
2
N
C
2
+
N
2
C
)
O(2NC^2 + N^2C)
O(2NC2+N2C).