deepseek-v3.2-exp的闪电索引器

我们可以把 DeepSeek 稀疏注意力(DeepSeek Sparse Attention, DSA)中的闪电索引器(Lightning Indexer) 想象成一位专门负责阅读和检索《红楼梦》全书信息的“记忆筛选专家”。

例如,《红楼梦》全书篇幅巨大,如果我们想让一个语言模型(比如 DeepSeek-V3.2-Exp)记住书里的所有细节,并在读到某个句子时能立刻回想起所有相关信息,效率是个大问题。

1. 核心挑战:全书阅读的 $O(L^2)$ 困境

想象《红楼梦》全书有 $L$ 个 token(可以把一个字或词语看作一个 token)。当模型读到第 $L$ 个字(比如“散”)时,如果它需要同时回顾并计算之前 $L-1$ 个字中每一个字对“散”这个字的影响,那么总体的计算量就是 $L \times L$,即 $O(L^2)$ 复杂度。对于 128K 这样长的上下文,$L^2$ 的计算量是难以承受的。

闪电索引器就是用来解决这个效率问题的关键工具。

2. 闪电索引器的作用与机制

DSA 的原型(prototype)主要由两部分组成:lightning indexerfine-grained token selection mechanism(细粒度 token 选择机制)。

A. 计算关联性:索引分数(Index Score)

闪电索引器扮演了“筛选专家”的角色。当模型读到当前的查询 token ($h_t$),比如“宝玉哭了”,它需要迅速判断出书本前面所有的先前 token ($h_s$) 中,哪些是高度相关的(比如“黛玉回天乏术”)。

闪电索引器就是通过计算索引分数 ($I_{t,s}$) 来判断这种关联性的。

计算公式(公式 1) 闪电索引器使用一个高效的公式来计算索引分数: $$I_{t,s} = \sum_{j=1}^{H^I} w_{t,j}^I \cdot \text{ReLU}(q_{t,j}^I \cdot k_s^I) \quad \text{(1)} \text{}$$

  • $H^I$ (Indexer Heads):索引器拥有少量的头部。
    • 定义:$H^I$ 表示索引器头部的数量(the number of indexer heads)。
    • 作用:在多头注意力(Multi-Head Attention)的结构中,头部允许模型从不同的表示子空间中捕获信息。这里的 $H^I$ 表示索引器利用了多个独立的“视角”或“计算通道”来计算索引分数。公式中的求和符号 $\sum_{j=1}^{H^I}$ 表示将所有这些头部计算出的分数进行累加。
    • 效率考量:闪电索引器设计得非常高效,其中一个原因是它具有少量的头部(a small number of heads)。
  • $h_t$ (Query Token):当前的 token(“宝玉哭了”)会生成 $q_{t,j}^I$ 和 $w_{t,j}^I$。
  • $h_s$ (Preceding Token):前面的 token(“黛玉回天乏术”)会生成 $k_s^I$。
  • ReLU:选择 ReLU 作为激活函数是出于对吞吐量(throughput consideration)的考量,意味着它计算速度快。
    • 作用:ReLU 是一种激活函数(activation function)。在公式(1)中,它应用于 $q_{t,j}^I$ 和 $k_s^I$ 的点积(dot product)之后。
    • 选择原因:选择 ReLU 作为激活函数是出于吞吐量(throughput consideration)的考量。这意味着 ReLU 相比其他复杂的激活函数,计算速度更快,有助于提高整体的计算效率。
  • 效率:虽然理论上它也要遍历所有先前 token,但由于它设计轻量(头部少,可用 FP8 实现),其计算效率是显著的(computational efficiency is remarkable),计算量远小于之前 DeepSeek-V3.1-Terminus 中使用的 MLA。

B. 筛选记忆:Top-k 选择

当闪电索引器为“宝玉哭了”计算完所有先前句子的索引分数 ${I_{t,s}}$ 之后,细粒度 token 选择机制就会启动:

  1. 它只检索对应于 Top-k 索引分数的 key-value entries ${c_s}$。
  2. 假设《红楼梦》全书有 10 万个 token,但 Top-k 只选择 2048 个 key-value tokens
  3. 最终模型的注意力输出 ($u_t$),只在这 2048 个被稀疏选择的 key-value entries ${c_s}$ 上进行计算。

这个机制将主模型的核心注意力复杂度(core attention complexity)从 $O(L^2)$ 降低到了 $O(L k)$,其中 $k$ (2048) 远小于 $L$ (128K)。

上面的公式还是难以理解,不太能够理解其中的细节,我们通过deepseek-v3.2-exp的代码库中的 fp8_index_kernel 这个闪电索引器的实现来尝试解释下。

公式与代码的对应关系

符号映射

公式中的各个符号在代码中的对应关系如下:

  • $I_{t,s}$: 对应输出张量 o[i_b, i_m, i1_n * blk_n1 + i2_n * blk_n2],表示位置 $t$ 的查询对位置 $s$ 的键的索引分数
1
T.copy(logits_sum, o[i_b, i_m, i1_n * blk_n1 + i2_n * blk_n2])
  • $H^I$: 对应参数 h,即注意力头的数量
1
def fp8_index_kernel(h: int, d: int):
  • $j$: 对应循环变量 i_h,遍历所有注意力头
1
2
3
4
5
6
7
8
9
10
11
T.gemm(
k_smem,
q_smem,
logits,
transpose_A=False,
transpose_B=True,
clear_accum=True,
)
for i_h, i3_n in T.Parallel(h, blk_n2):
logits[i3_n, i_h] = T.max(logits[i3_n, i_h], 0) * q_s_frag[i_h]
  • $w_{t,j}^I$: 对应查询缩放因子 q_s_frag[i_h],即 q_s[i_b, i_m, i_h]
1
2
q_s_frag = T.alloc_fragment(h, FP32)
T.copy(q_s[i_b, i_m, 0], q_s_frag)
  • $q_{t,j}^I$: 对应查询张量 q_smem,即 q[i_b, i_m, :, :] 的第 i_h 个头
1
2
q_smem = T.alloc_shared((h, d), FP8)
T.copy(q[i_b, i_m, 0, 0], q_smem)
  • $k_s^I$: 对应键张量 k_smem,即 k[i_b, s, :]
1
2
k_smem = T.alloc_shared((blk_n2, d), FP8)
T.copy(k[i_b, i1_n * blk_n1 + i2_n * blk_n2, 0], k_smem)

计算步骤对应

公式的计算在代码中按以下步骤实现:

1. 点积 $q_{t,j}^I \cdot k_s^I$: 通过矩阵乘法 T.gemm(k_smem, q_smem, logits, ...) 计算,结果存储在 logits[i3_n, i_h] 中,对应查询头 $j$ 与键位置 $s$ 的点积

2. ReLU 激活: 使用 T.max(logits[i3_n, i_h], 0) 实现 $\text{ReLU}(q_{t,j}^I \cdot k_s^I)$

3. 乘以权重 $w_{t,j}^I$: 在同一行代码中,ReLU 结果乘以查询缩放因子 q_s_frag[i_h],即 logits[i3_n, i_h] = T.max(logits[i3_n, i_h], 0) * q_s_frag[i_h]

4. 跨头求和 $\sum_{j=1}^{H^I}$: 通过 T.reduce_sum(logits, logits_sum, dim=1) 在头维度(dim=1)上求和,实现对所有头的累加

5. 键缩放: 最后乘以键缩放因子 k_s_frag[i3_n],即 logits_sum[i3_n] *= k_s_frag[i3_n]

完整计算流程

代码实际计算的完整公式为:

$$I_{t,s} = k_s[s] \cdot \sum_{j=1}^{H^I} q_s[t,j] \cdot \text{ReLU}(q[t,j] \cdot k[s])$$

这与您给出的公式一致,其中 $w_{t,j}^I = q_s[t,j]$ 是查询的缩放因子,而键的缩放因子 $k_s[s]$ 在求和后统一应用。

ReLU 是目前最流行的激活函数之一,在深度学习中广泛应用。
数学定义: f(x) = max(0, x)
简单来说:

  • 当 x > 0 时,输出 x
  • 当 x ≤ 0 时,输出 0