正文
else:
assert self.attn_backend.num_wrappers == 1
self.update = self.update_single_wrapper
另外需要说明的是forward_metadata,forward metadata 和 updater 对应,包括两种,由于sglang scheduler 本身同时只会有一个forward batch,所以只需要一份forward metadata 即可。
# Other metadata
self.forward_metadata: Union[PrefillMetadata, DecodeMetadata] = None
self.decode_cuda_graph_metadata = {}
self.prefill_cuda_graph_metadata = {}
这里的forward_metadata 其实主要是Wrapper 和周围参数,每次forward 都会在metadata里指定wrapper。
@dataclassclassDecodeMetadata:
decode_wrappers:List[BatchDecodeWithPagedKVCacheWrapper]@dataclassclassPrefillMetadata:
prefill_wrappers:List[BatchPrefillWithPagedKVCacheWrapper]
use_ragged:bool
extend_no_prefix:bool
wrapper 的数据结构和初始化过程
首先是检查入参,kv layout 只支持NHD和HND。
_check_kv_layout(kv_layout)
有关jit 的功能, 我们先不管。所以直接看内部新引入的资源,以下以BatchPrefillWithPagedKVCacheWrapper 为例。
## 前三行就是入参赋值,没啥好说的
self._kv_layout = kv_layout
self._float_workspace_buffer = float_workspace_buffer
self.device = float_workspace_buffer.device
## 这里的backend 更多程度上其实是底层实现用flashattention2还是flashattention3
## auto 即自动识别,如果当前硬件支持fa3则用fa3,否则用fa2,而对于fa3
## vector_sparse 可以理解为一种中间形式的稀疏表达,之所以fa3 需要存储这个中间数组
## 是由于考虑到fa3下将vector_sparse 存放于GPU寄存器的话,寄存器不够
if backend in ["fa3", "auto"]:
# NOTE(Zihao): assume maximum accumulate kv length is 16M
self._vector_sparse_indices_buffer = torch.empty(
(16 * 1024 * 1024,), dtype=torch.int32, device=self.device
)
# NOTE(Zihao): assume maximum batch size is 32768
self._vector_sparse_indptr_buffer = torch.empty(
(32768,), dtype=torch.int32, device=self.device
)
## kv_lens_buffer 实际是请求对应的kv cache len的长度,单位也是tokens num
self._kv_lens_buffer = torch.empty(
(32768,), dtype=torch.int32, device=self.device
)
## 如下两个都是存储控制信息的buffer,区别在于_int_workspace_buffer是device 侧的buffer
## _pin_memory_int_workspace_buffer是host 侧的buffer,二者通过cudaMemcpyAsync 互相交互
## 后面我们会看到这个结构就是init_forward_metadata的核心。
self._int_workspace_buffer = torch.empty(
(8 * 1024 * 1024,), dtype=torch.uint8, device=self.device
)
self._pin_memory_int_workspace_buffer = torch.empty(
self._int_workspace_buffer.shape,
dtype=self._int_workspace_buffer.dtype,
device="cpu",
pin_memory=True,
)
以上的资源属于wrapper 内部的核心资源,最后就是将attentionbackend 引用进来方便访问。
self._qo_indptr_buf = qo_indptr_buf
self._paged_kv_indptr_buf = paged_kv_indptr_buf
self._paged_kv_indices_buf = paged_kv_indices_buf
self._paged_kv_last_page_len_buf = paged_kv_last_page_len_buf
self._backend = backend
## 以下主要是cuda graph使用参数,我们放到cuda graph 模式讲。
self._custom_mask_buf = custom_mask_buf
self._mask_indptr_buf = mask_indptr_buf
self._max_total_num_rows = None
而 decode wrapper 也比较类似,但是没有了其中一部分,注意我们这里也暂时屏蔽了cuda graph相关的实现。
可以看到,decode 下,query 相关的结构不见了(_qo_indptr_buf与_kv_lens_buffer),对于decode,都是one-by-one 的输出,query相关的内容本身也已经在gpu cache上,不需要额外传入(但是cuda graph 模式下也有query 相关结构,我们再解析)。另外,由于fa3 主要是在fa2 基础上加了seq parallel,只影响prefill,所以decode 这边不需要vector_sparse 这个中间层的buffer。
self._kv_layout = kv_layout
self._float_workspace_buffer = float_workspace_buffer
self.device = float_workspace_buffer.device
self._int_workspace_buffer = torch.empty(
(8 * 1024 * 1024,), dtype=torch.uint8, device=self.device
)
self._pin_memory_int_workspace_buffer = torch.empty(
(8 * 1024 * 1024,),
dtype=torch.uint8,
pin_memory=True,
device="cpu",
)
self._fixed_batch_size = 0
self._paged_kv_indptr_buf = paged_kv_indptr_buffer
self._paged_kv_indices_buf = paged_kv_indices_buffer
self._paged_kv_last_page_len_buf = paged_kv_last_page_len_buffer
self._use_tensor_cores = use_tensor_cores
self._use_cuda_graph = use_cuda_graph
本节最后,简单说明一下wrapper 的两个主要接口的功能。
1.
plan
forward 过程的第一步,将控制信息写入gpu
2.
run
forward过程第二步,从model 里forward会获得更新后的kvcache(有些forward 会在model 层进行kvcache的更新,比如deepseek),此时调用wrapper的run进行low-level 的run。
下图就是forward 过程中各结构接口的调用关系。
Plan Info 是什么
另外,我们还需要介绍一下plan info。这是wrapper的核心数据结构之一,属于运行的配置信息,plan info即flashinfer 计划为本次forward 提供的配置信息。数据结构如下:
struct PrefillPlanInfo {
int64_t padded_batch_size; # batch size, 和forward batchsize有一丝区别
# forwardbatch的batchsize 是当前batch里请求的个数
# padded_batch_size 可能比forward size大,它面向GPU CTA
# 每个CTA 需要计算的tile根据请求情况获得,padding_batchsize 根据tile 计算
int64_t total_num_rows; # 当前batch 处理总输入token长度, 对应sum(qo_indptr)
int64_t total_num_rows_offset; # 对应qo_indptr的数组指针
int64_t cta_tile_q; # 一个CTA 负责处理的query 长度,即tile 后的query 长度, 下面假设query 0 input 被tile成三个tile
int64_t request_indices_offset; # tile 后的request index 数组指针,like [0, 0, 0]
int64_t qo_tile_indices_offset; # tile 后的query index 数组指针,like 请求0 被tile 成三份,like[0, 1, 2]
int64_t kv_tile_indices_offset; # tile 后的kv index 数组指针,如果kv chunksize > need_kv_len, 则为[0, 0, 0]
int64_t merge_indptr_offset; # merge indptr 与tile 无关,是与请求和gqa有关的
# 如果模型的group size为4,则一个请求对应四个merge_indptr项, 比如[100,200,300,400]
int64_t o_indptr_offset; # 一个请求一个,值为对齐到tile_kv_len * group_size(mha下为1)
int64_t kv_chunk_size_ptr_offset; # 这里也是个数组指针,但是数组size 为1,内容就是kv_chunk_size
int64_t v_offset; # attention中间态计算结果,s_ = q*k, v_ = softmax(s_)*v
int64_t s_offset;
int64_t block_valid_mask_offset; # 数组指针,数组内容是根据tile 分片后的block是不是有效的
# 一般都是有效,但在cudagraph的使用下,会有不对齐的情况,以后再说
bool enable_cuda_graph; # 使用cuda graph
bool split_kv; # 是否进行了分片
}
plan info 实在是面向gpu的核心数据结构,这里才有了我们以往耳熟能详的tiling 过程。现在我们可以继续补充forward batch的流程图,forward batch 走进wrapper 里就是plan info了。
init_forward_meta 与 plan/prefill 举例
理解了整个初始化和主要的数据结构,接下来我们可以看看init_forward_meta的过程了。这里的核心就是wrapper的plan 接口。如下我列出了其中prefill 的branch的case(不考虑encoder-decoder 和 sliding window的实现)。
def init_forward_metadata(self, forward_batch: ForwardBatch):
prefix_lens = forward_batch.extend_prefix_lens
# Some heuristics to check whether to use ragged forward
# 如果有prefill token 太长的情况,使用ragged tensor
if forward_batch.extend_num_tokens >= 4096 and self.num_wrappers == 1:
use_ragged = True
extend_no_prefix = not any(forward_batch.extend_prefix_lens_cpu)
else:
use_ragged = False
extend_no_prefix = False
# 通过updater 更新 prefill wrapper
self.indices_updater_prefill.update(
forward_batch.req_pool_indices,
forward_batch.seq_lens,
forward_batch.seq_lens_sum,
prefix_lens,
prefill_wrappers=self.prefill_wrappers_paged,
use_ragged=use_ragged,
encoder_lens=forward_batch.encoder_lens,
spec_info=None,
)
## 更新forward_metadata
self.forward_metadata = PrefillMetadata(
self.prefill_wrappers_paged, use_ragged, extend_no_prefix
)
这里ForwardBatch的字段内容可以参考Bruce 仗剑走天涯:sglang 源码学习笔记(一)- Cache、Req与Scheduler(
https://zhuanlan.zhihu.com/p/17186885141
)里的说明。其中最重要的是update接口, 这里最终会调用到wrapper的plan,栈如下。实际上begin_forward 就是 plan,指针是同一个。
FlashInferAttnBackend.init_forward_metadata->
FlashInferIndicesUpdaterPrefill.update_single_wrapper->
FlashInferIndicesUpdaterPrefill.call_begin_forward->
BatchPrefillWithPagedKVCacheWrapper.begin_forward->
BatchPrefillWithPagedKVCacheWrapper.plan
我们主要讲两个函数的实现,call_begin_forward与plan。
call_begin_forward
def call_begin_forward(
self,
wrapper_ragged: BatchPrefillWithRaggedKVCacheWrapper, # ragged wrapper, 用于输入较长的情况
wrapper_paged: BatchPrefillWithPagedKVCacheWrapper, # 主要使用的wrapper
req_pool_indices: torch.Tensor, # batch中包含的请求index
paged_kernel_lens: torch.Tensor, # 请求的长度(对ragged 情况下,对应extend_prefix_lens)
paged_kernel_lens_sum: int, # paged_kernel_lens 之和
seq_lens: torch.Tensor, # 请求的完整长度
prefix_lens: torch.Tensor, # extend_prefix_lens
kv_start_idx: torch.Tensor, # 传参为一般None,实际上指各请求kv cache的起始index
kv_indptr: torch.Tensor, # attention backend 的kv 数组
qo_indptr: torch.Tensor, # attention backend 的qo 数组
use_ragged: bool, # 是否使用ragged tensor用于query
spec_info: Optional[SpecInfo], # 是否是投机推理
):
如上是call_begin_forward的传参,方便大家理解一些上下文。接下来是具体的实现。
## 获得当前batch的batchsize
bs = len(req_pool_indices)
if spec_info is None:
# Normal extend
# indptr 意为矩阵中每行非零值的起始位置,以下说明了每个请求的输出token存kv cache的位置
kv_indptr[1 : bs + 1] = torch.cumsum(paged_kernel_lens, dim=0)
kv_indptr = kv_indptr[: bs + 1]
# 这里将分配一个数组,具体赋值在create_flashinfer_kv_indices_triton中
kv_indices = torch.empty(
paged_kernel_lens_sum, dtype=torch.int32, device="cuda"
)
## 注意,这是一个并行函数,并行度是bs,如下调用说明同时起了bs个trtion 内核执行,入参都一样
create_flashinfer_kv_indices_triton[(bs,)](
self.req_to_token,
req_pool_indices,
paged_kernel_lens,
kv_indptr,
kv_start_idx,
kv_indices,
self.req_to_token.shape[1],
)
## 以下说明了每个请求的输出token的起始位置
qo_indptr[1 : bs + 1] = torch.cumsum(seq_lens - prefix_lens, dim=0)
qo_indptr = qo_indptr[: bs + 1]
custom_mask = None
## end_forward 接口已被废弃,可以忽略
wrapper_paged.end_forward()
## begin_forward == plan接口
wrapper_paged.begin_forward(
qo_indptr,
kv_indptr,
kv_indices,
self.kv_last_page_len[:bs],
self.num_qo_heads,
self.num_kv_heads,
self.head_dim,
1,
q_data_type=self.q_data_type,
custom_mask=custom_mask,
)
begin forward 稍后分析,我们先看看create_flashinfer_kv_indices_triton,首先这是一个并行执行的函数,依赖triton的jit 。上述我们看到了调用方的调用代码,注意[bs,] 这个部分,这表明在0轴上起了bs 个triton 内核执行该函数,bs的并行度 最终体现在函数内,就是tl.program_id(axis=0)的返回值,该返回值为[0, bs-1],所以以下其实是对入参数组的并行访问。
@triton.jitdefcreate_flashinfer_kv_indices_triton(
req_to_token_ptr,# [max_batch, max_context_len]
req_pool_indices_ptr,
page_kernel_lens_ptr,
kv_indptr,
kv_start_idx,
kv_indices_ptr,
req_to_token_ptr_stride:tl.constexpr,):
BLOCK_SIZE:tl.constexpr=512
pid=tl.program_id(axis=0)
# batch 中第N个请求的req_pool_index和kv_indices_offset
req_pool_index=tl.load(req_pool_indices_ptr+pid)
kv_indices_offset=tl.load(kv_indptr+pid)
kv_start=0
kv_end=0
ifkv_start_idx:
kv_start=tl.load(kv_start_idx+pid).to(tl.int32)
kv_end=kv_start
## 获得请求的kvcache 长度
kv_end+=tl.load(page_kernel_lens_ptr+pid).to(tl.int32)
num_loop=tl.cdiv(kv_end-kv_start,BLOCK_SIZE)
foriinrange(num_loop):
# block_size 又是一个并行度,意在加速load,store的并行效率
# offset 返回的是一个BLOCK_SIZE 维度的array
offset=tl.arange(0,BLOCK_SIZE)+i*BLOCK_SIZE
mask=offset<kv_end-kv_start
# 并行读取req_to_token pool 中req 对应的token 索引
data=tl.load(
req_to_token_ptr
+req_pool_index*req_to_token_ptr_stride
+kv_start
+offset,
mask=mask,
)
## 并行写入kv indices数组
## 注意,这里kv_indices是临时结构,和token_to_kv_pool 没有关系,但最终会作为wrapper的入参
tl.store(kv_indices_ptr+kv_indices_offset+offset,data,mask=mask)
backend CacheModule 是什么
在理解wrapper->plan 的调用链之前,我们先看看wrapper 里的核心结构————cache module。
具体上说,cache module 就是wrapper 真正的核心,是cpp 入口结构。它是被延后初始化的,因为主要是接口抽象类,不是实际资源,所以延后初始化也可以接受。cache module 被构建的时机是plan 接口调用时,根据backend的值和硬件当前情况,再次进行一次backend的判定,并根据backend的判定情况获取相应的cache module和挂载相应的具体接口。
首先是判断要不要使用flashattention3.
if self._backend == "auto":
self._backend = determine_attention_backend(
self.device,