我们可以把 DeepSeek 稀疏注意力(DeepSeek Sparse Attention, DSA)中的闪电索引器(Lightning Indexer) 想象成一位专门负责阅读和检索《红楼梦》全书信息的“记忆筛选专家”。
例如,《红楼梦》全书篇幅巨大,如果我们想让一个语言模型(比如 DeepSeek-V3.2-Exp)记住书里的所有细节,并在读到某个句子时能立刻回想起所有相关信息,效率是个大问题。
1. 核心挑战:全书阅读的 $O(L^2)$ 困境
想象《红楼梦》全书有 $L$ 个 token(可以把一个字或词语看作一个 token)。当模型读到第 $L$ 个字(比如“散”)时,如果它需要同时回顾并计算之前 $L-1$ 个字中每一个字对“散”这个字的影响,那么总体的计算量就是 $L \times L$,即 $O(L^2)$ 复杂度。对于 128K 这样长的上下文,$L^2$ 的计算量是难以承受的。
闪电索引器就是用来解决这个效率问题的关键工具。
2. 闪电索引器的作用与机制
DSA 的原型(prototype)主要由两部分组成:lightning indexer 和 fine-grained token selection mechanism(细粒度 token 选择机制)。
A. 计算关联性:索引分数(Index Score)
闪电索引器扮演了“筛选专家”的角色。当模型读到当前的查询 token ($h_t$),比如“宝玉哭了”,它需要迅速判断出书本前面所有的先前 token ($h_s$) 中,哪些是高度相关的(比如“黛玉回天乏术”)。
闪电索引器就是通过计算索引分数 ($I_{t,s}$) 来判断这种关联性的。
计算公式(公式 1) 闪电索引器使用一个高效的公式来计算索引分数: $$I_{t,s} = \sum_{j=1}^{H^I} w_{t,j}^I \cdot \text{ReLU}(q_{t,j}^I \cdot k_s^I) \quad \text{(1)} \text{}$$
- $H^I$ (Indexer Heads):索引器拥有少量的头部。
- 定义:$H^I$ 表示索引器头部的数量(the number of indexer heads)。
- 作用:在多头注意力(Multi-Head Attention)的结构中,头部允许模型从不同的表示子空间中捕获信息。这里的 $H^I$ 表示索引器利用了多个独立的“视角”或“计算通道”来计算索引分数。公式中的求和符号 $\sum_{j=1}^{H^I}$ 表示将所有这些头部计算出的分数进行累加。
- 效率考量:闪电索引器设计得非常高效,其中一个原因是它具有少量的头部(a small number of heads)。
- $h_t$ (Query Token):当前的 token(“宝玉哭了”)会生成 $q_{t,j}^I$ 和 $w_{t,j}^I$。
- $h_s$ (Preceding Token):前面的 token(“黛玉回天乏术”)会生成 $k_s^I$。
- ReLU:选择 ReLU 作为激活函数是出于对吞吐量(throughput consideration)的考量,意味着它计算速度快。
- 作用:ReLU 是一种激活函数(activation function)。在公式(1)中,它应用于 $q_{t,j}^I$ 和 $k_s^I$ 的点积(dot product)之后。
- 选择原因:选择 ReLU 作为激活函数是出于吞吐量(throughput consideration)的考量。这意味着 ReLU 相比其他复杂的激活函数,计算速度更快,有助于提高整体的计算效率。
- 效率:虽然理论上它也要遍历所有先前 token,但由于它设计轻量(头部少,可用 FP8 实现),其计算效率是显著的(computational efficiency is remarkable),计算量远小于之前 DeepSeek-V3.1-Terminus 中使用的 MLA。
B. 筛选记忆:Top-k 选择
当闪电索引器为“宝玉哭了”计算完所有先前句子的索引分数 ${I_{t,s}}$ 之后,细粒度 token 选择机制就会启动:
- 它只检索对应于 Top-k 索引分数的 key-value entries ${c_s}$。
- 假设《红楼梦》全书有 10 万个 token,但 Top-k 只选择 2048 个 key-value tokens。
- 最终模型的注意力输出 ($u_t$),只在这 2048 个被稀疏选择的 key-value entries ${c_s}$ 上进行计算。
这个机制将主模型的核心注意力复杂度(core attention complexity)从 $O(L^2)$ 降低到了 $O(L k)$,其中 $k$ (2048) 远小于 $L$ (128K)。
上面的公式还是难以理解,不太能够理解其中的细节,我们通过deepseek-v3.2-exp的代码库中的 fp8_index_kernel 这个闪电索引器的实现来尝试解释下。
公式与代码的对应关系
符号映射
公式中的各个符号在代码中的对应关系如下:
- $I_{t,s}$: 对应输出张量
o[i_b, i_m, i1_n * blk_n1 + i2_n * blk_n2],表示位置 $t$ 的查询对位置 $s$ 的键的索引分数
|
|
- $H^I$: 对应参数
h,即注意力头的数量
|
|
- $j$: 对应循环变量
i_h,遍历所有注意力头
|
|
- $w_{t,j}^I$: 对应查询缩放因子
q_s_frag[i_h],即q_s[i_b, i_m, i_h]
|
|
- $q_{t,j}^I$: 对应查询张量
q_smem,即q[i_b, i_m, :, :]的第i_h个头
|
|
- $k_s^I$: 对应键张量
k_smem,即k[i_b, s, :]
|
|
计算步骤对应
公式的计算在代码中按以下步骤实现:
1. 点积 $q_{t,j}^I \cdot k_s^I$: 通过矩阵乘法 T.gemm(k_smem, q_smem, logits, ...) 计算,结果存储在 logits[i3_n, i_h] 中,对应查询头 $j$ 与键位置 $s$ 的点积
2. ReLU 激活: 使用 T.max(logits[i3_n, i_h], 0) 实现 $\text{ReLU}(q_{t,j}^I \cdot k_s^I)$
3. 乘以权重 $w_{t,j}^I$: 在同一行代码中,ReLU 结果乘以查询缩放因子 q_s_frag[i_h],即 logits[i3_n, i_h] = T.max(logits[i3_n, i_h], 0) * q_s_frag[i_h]
4. 跨头求和 $\sum_{j=1}^{H^I}$: 通过 T.reduce_sum(logits, logits_sum, dim=1) 在头维度(dim=1)上求和,实现对所有头的累加
5. 键缩放: 最后乘以键缩放因子 k_s_frag[i3_n],即 logits_sum[i3_n] *= k_s_frag[i3_n]
完整计算流程
代码实际计算的完整公式为:
$$I_{t,s} = k_s[s] \cdot \sum_{j=1}^{H^I} q_s[t,j] \cdot \text{ReLU}(q[t,j] \cdot k[s])$$
这与您给出的公式一致,其中 $w_{t,j}^I = q_s[t,j]$ 是查询的缩放因子,而键的缩放因子 $k_s[s]$ 在求和后统一应用。
ReLU 是目前最流行的激活函数之一,在深度学习中广泛应用。
数学定义:f(x) = max(0, x)
简单来说:
- 当 x > 0 时,输出 x
- 当 x ≤ 0 时,输出 0
