专栏名称: GiantPandaLLM
专注于机器学习、深度学习、计算机视觉、图像处理等多个方向技术分享。团队由一群热爱技术且热衷于分享的小伙伴组成。我们坚持原创,每天一到两篇原创技术分享。希望在传播知识、分享知识的同时能够启发你,大家一起共同进步(・ω<)☆
目录
相关文章推荐
GiantPandaLLM  ·  图解Vllm ... ·  5 天前  
51好读  ›  专栏  ›  GiantPandaLLM

sglang 源码学习笔记(二)- backend & forward 过程

GiantPandaLLM  · 公众号  · 3D  · 2025-05-18 22:14

正文

请到「今天看啥」查看全文



else:
assert self.attn_backend.num_wrappers == 1
self.update = self.update_single_wrapper

另外需要说明的是forward_metadata,forward metadata 和 updater 对应,包括两种,由于sglang scheduler 本身同时只会有一个forward batch,所以只需要一份forward metadata 即可。

        # Other metadata
        self.forward_metadata: Union[PrefillMetadata, DecodeMetadata] = None
        self.decode_cuda_graph_metadata = {}
        self.prefill_cuda_graph_metadata = {}

这里的forward_metadata 其实主要是Wrapper 和周围参数,每次forward 都会在metadata里指定wrapper。

@dataclassclassDecodeMetadata:
    decode_wrappers:List[BatchDecodeWithPagedKVCacheWrapper]@dataclassclassPrefillMetadata:
    prefill_wrappers:List[BatchPrefillWithPagedKVCacheWrapper]
    use_ragged:bool
    extend_no_prefix:bool

wrapper 的数据结构和初始化过程

首先是检查入参,kv layout 只支持NHD和HND。

        _check_kv_layout(kv_layout)

有关jit 的功能, 我们先不管。所以直接看内部新引入的资源,以下以BatchPrefillWithPagedKVCacheWrapper 为例。

        ## 前三行就是入参赋值,没啥好说的
        self._kv_layout = kv_layout
        self._float_workspace_buffer = float_workspace_buffer
        self.device = float_workspace_buffer.device
        ## 这里的backend 更多程度上其实是底层实现用flashattention2还是flashattention3
        ## auto 即自动识别,如果当前硬件支持fa3则用fa3,否则用fa2,而对于fa3
        ## vector_sparse 可以理解为一种中间形式的稀疏表达,之所以fa3 需要存储这个中间数组
        ## 是由于考虑到fa3下将vector_sparse 存放于GPU寄存器的话,寄存器不够
        if backend in ["fa3", "auto"]:
            # NOTE(Zihao): assume maximum accumulate kv length is 16M
            self._vector_sparse_indices_buffer = torch.empty(
                (16 * 1024 * 1024,), dtype=torch.int32, device=self.device
            )
            # NOTE(Zihao): assume maximum batch size is 32768
            self._vector_sparse_indptr_buffer = torch.empty(
                (32768,), dtype=torch.int32, device=self.device
            )

        ## kv_lens_buffer 实际是请求对应的kv cache len的长度,单位也是tokens num
        self._kv_lens_buffer = torch.empty(
            (32768,), dtype=torch.int32, device=self.device
        )
        
        ## 如下两个都是存储控制信息的buffer,区别在于_int_workspace_buffer是device 侧的buffer
        ## _pin_memory_int_workspace_buffer是host 侧的buffer,二者通过cudaMemcpyAsync 互相交互
        ## 后面我们会看到这个结构就是init_forward_metadata的核心。
        self._int_workspace_buffer = torch.empty(
            (8 * 1024 * 1024,), dtype=torch.uint8, device=self.device
        )
        self._pin_memory_int_workspace_buffer = torch.empty(
            self._int_workspace_buffer.shape,
            dtype=self._int_workspace_buffer.dtype,
            device="cpu",
            pin_memory=True,
        )

以上的资源属于wrapper 内部的核心资源,最后就是将attentionbackend 引用进来方便访问。

        self._qo_indptr_buf = qo_indptr_buf
        self._paged_kv_indptr_buf = paged_kv_indptr_buf
        self._paged_kv_indices_buf = paged_kv_indices_buf
        self._paged_kv_last_page_len_buf = paged_kv_last_page_len_buf
        self._backend = backend

        ## 以下主要是cuda graph使用参数,我们放到cuda graph 模式讲。
        self._custom_mask_buf = custom_mask_buf
        self._mask_indptr_buf = mask_indptr_buf
        self._max_total_num_rows = None

而 decode wrapper 也比较类似,但是没有了其中一部分,注意我们这里也暂时屏蔽了cuda graph相关的实现。

可以看到,decode 下,query 相关的结构不见了(_qo_indptr_buf与_kv_lens_buffer),对于decode,都是one-by-one 的输出,query相关的内容本身也已经在gpu cache上,不需要额外传入(但是cuda graph 模式下也有query 相关结构,我们再解析)。另外,由于fa3 主要是在fa2 基础上加了seq parallel,只影响prefill,所以decode 这边不需要vector_sparse 这个中间层的buffer。

        self._kv_layout = kv_layout
        self._float_workspace_buffer = float_workspace_buffer
        self.device = float_workspace_buffer.device
        self._int_workspace_buffer = torch.empty(
            (8 * 1024 * 1024,), dtype=torch.uint8, device=self.device
        )
        self._pin_memory_int_workspace_buffer = torch.empty(
            (8 * 1024 * 1024,),
            dtype=torch.uint8,
            pin_memory=True,
            device="cpu",
        )
        self._fixed_batch_size = 0
        self._paged_kv_indptr_buf = paged_kv_indptr_buffer
        self._paged_kv_indices_buf = paged_kv_indices_buffer
        self._paged_kv_last_page_len_buf = paged_kv_last_page_len_buffer
        self._use_tensor_cores = use_tensor_cores
        self._use_cuda_graph = use_cuda_graph

本节最后,简单说明一下wrapper 的两个主要接口的功能。

1. plan forward 过程的第一步,将控制信息写入gpu

2. run forward过程第二步,从model 里forward会获得更新后的kvcache(有些forward 会在model 层进行kvcache的更新,比如deepseek),此时调用wrapper的run进行low-level 的run。

下图就是forward 过程中各结构接口的调用关系。

Plan Info 是什么

另外,我们还需要介绍一下plan info。这是wrapper的核心数据结构之一,属于运行的配置信息,plan info即flashinfer 计划为本次forward 提供的配置信息。数据结构如下:

struct PrefillPlanInfo {
  int64_t padded_batch_size;                    # batch size, 和forward batchsize有一丝区别
                                                # forwardbatch的batchsize 是当前batch里请求的个数
                                                # padded_batch_size 可能比forward size大,它面向GPU CTA
                                                # 每个CTA 需要计算的tile根据请求情况获得,padding_batchsize 根据tile 计算
  int64_t total_num_rows;                       # 当前batch 处理总输入token长度, 对应sum(qo_indptr)
  int64_t total_num_rows_offset;                # 对应qo_indptr的数组指针
  int64_t cta_tile_q;                           # 一个CTA 负责处理的query 长度,即tile 后的query 长度, 下面假设query 0 input 被tile成三个tile
  int64_t request_indices_offset;               # tile 后的request index 数组指针,like [0, 0, 0]
  int64_t qo_tile_indices_offset;               # tile 后的query index 数组指针,like 请求0 被tile 成三份,like[0, 1, 2]
  int64_t kv_tile_indices_offset;               # tile 后的kv index 数组指针,如果kv chunksize > need_kv_len, 则为[0, 0, 0]
  int64_t merge_indptr_offset;                  # merge indptr 与tile 无关,是与请求和gqa有关的
                                                # 如果模型的group size为4,则一个请求对应四个merge_indptr项, 比如[100,200,300,400]
  int64_t o_indptr_offset;                      # 一个请求一个,值为对齐到tile_kv_len * group_size(mha下为1)
  int64_t kv_chunk_size_ptr_offset;             # 这里也是个数组指针,但是数组size 为1,内容就是kv_chunk_size
  int64_t v_offset;                             # attention中间态计算结果,s_ = q*k, v_ = softmax(s_)*v
  int64_t s_offset;                             
  int64_t block_valid_mask_offset;              # 数组指针,数组内容是根据tile 分片后的block是不是有效的
                                                # 一般都是有效,但在cudagraph的使用下,会有不对齐的情况,以后再说
  bool enable_cuda_graph;                       # 使用cuda graph
  bool split_kv;                                # 是否进行了分片
}

plan info 实在是面向gpu的核心数据结构,这里才有了我们以往耳熟能详的tiling 过程。现在我们可以继续补充forward batch的流程图,forward batch 走进wrapper 里就是plan info了。

init_forward_meta 与 plan/prefill 举例

理解了整个初始化和主要的数据结构,接下来我们可以看看init_forward_meta的过程了。这里的核心就是wrapper的plan 接口。如下我列出了其中prefill 的branch的case(不考虑encoder-decoder 和 sliding window的实现)。

        def init_forward_metadata(self, forward_batch: ForwardBatch):
            prefix_lens = forward_batch.extend_prefix_lens
            # Some heuristics to check whether to use ragged forward
            # 如果有prefill token 太长的情况,使用ragged tensor
            if forward_batch.extend_num_tokens >= 4096 and self.num_wrappers == 1:
                use_ragged = True
                extend_no_prefix = not any(forward_batch.extend_prefix_lens_cpu)
            else:
                use_ragged = False
                extend_no_prefix = False 
            # 通过updater 更新 prefill wrapper
            self.indices_updater_prefill.update(
                forward_batch.req_pool_indices,
                forward_batch.seq_lens,
                forward_batch.seq_lens_sum,
                prefix_lens,
                prefill_wrappers=self.prefill_wrappers_paged,
                use_ragged=use_ragged,
                encoder_lens=forward_batch.encoder_lens,
                spec_info=None,
            )
            ## 更新forward_metadata
            self.forward_metadata = PrefillMetadata(
                self.prefill_wrappers_paged, use_ragged, extend_no_prefix
            )

这里ForwardBatch的字段内容可以参考Bruce 仗剑走天涯:sglang 源码学习笔记(一)- Cache、Req与Scheduler( https://zhuanlan.zhihu.com/p/17186885141 )里的说明。其中最重要的是update接口, 这里最终会调用到wrapper的plan,栈如下。实际上begin_forward 就是 plan,指针是同一个。

FlashInferAttnBackend.init_forward_metadata->
    FlashInferIndicesUpdaterPrefill.update_single_wrapper->
        FlashInferIndicesUpdaterPrefill.call_begin_forward->
            BatchPrefillWithPagedKVCacheWrapper.begin_forward->
                BatchPrefillWithPagedKVCacheWrapper.plan

我们主要讲两个函数的实现,call_begin_forward与plan。

call_begin_forward

def call_begin_forward(
        self,
        wrapper_ragged: BatchPrefillWithRaggedKVCacheWrapper,  # ragged wrapper, 用于输入较长的情况
        wrapper_paged: BatchPrefillWithPagedKVCacheWrapper,    # 主要使用的wrapper
        req_pool_indices: torch.Tensor,                        # batch中包含的请求index
        paged_kernel_lens: torch.Tensor,                       # 请求的长度(对ragged 情况下,对应extend_prefix_lens)
        paged_kernel_lens_sum: int,                            # paged_kernel_lens 之和
        seq_lens: torch.Tensor,                                # 请求的完整长度
        prefix_lens: torch.Tensor,                             # extend_prefix_lens
        kv_start_idx: torch.Tensor,                            # 传参为一般None,实际上指各请求kv cache的起始index 
        kv_indptr: torch.Tensor,                               # attention backend 的kv 数组
        qo_indptr: torch.Tensor,                               # attention backend 的qo 数组
        use_ragged: bool,                                      # 是否使用ragged tensor用于query
        spec_info: Optional[SpecInfo],                         # 是否是投机推理
    ):

如上是call_begin_forward的传参,方便大家理解一些上下文。接下来是具体的实现。

        ## 获得当前batch的batchsize
        bs = len(req_pool_indices)
        if spec_info is None:
            # Normal extend
            # indptr 意为矩阵中每行非零值的起始位置,以下说明了每个请求的输出token存kv cache的位置
            kv_indptr[1 : bs + 1] = torch.cumsum(paged_kernel_lens, dim=0)
            kv_indptr = kv_indptr[: bs + 1]
            # 这里将分配一个数组,具体赋值在create_flashinfer_kv_indices_triton中
            kv_indices = torch.empty(
                paged_kernel_lens_sum, dtype=torch.int32, device="cuda"
            )
            ## 注意,这是一个并行函数,并行度是bs,如下调用说明同时起了bs个trtion 内核执行,入参都一样
            create_flashinfer_kv_indices_triton[(bs,)](
                self.req_to_token,
                req_pool_indices,
                paged_kernel_lens,
                kv_indptr,
                kv_start_idx,
                kv_indices,
                self.req_to_token.shape[1],
            )
            ## 以下说明了每个请求的输出token的起始位置
            qo_indptr[1 : bs + 1] = torch.cumsum(seq_lens - prefix_lens, dim=0)
            qo_indptr = qo_indptr[: bs + 1]
            custom_mask = None
        ## end_forward 接口已被废弃,可以忽略
        wrapper_paged.end_forward()
        ## begin_forward == plan接口
        wrapper_paged.begin_forward(
            qo_indptr,
            kv_indptr,
            kv_indices,
            self.kv_last_page_len[:bs],
            self.num_qo_heads,
            self.num_kv_heads,
            self.head_dim,
            1,
            q_data_type=self.q_data_type,
            custom_mask=custom_mask,
        )

begin forward 稍后分析,我们先看看create_flashinfer_kv_indices_triton,首先这是一个并行执行的函数,依赖triton的jit 。上述我们看到了调用方的调用代码,注意[bs,] 这个部分,这表明在0轴上起了bs 个triton 内核执行该函数,bs的并行度 最终体现在函数内,就是tl.program_id(axis=0)的返回值,该返回值为[0, bs-1],所以以下其实是对入参数组的并行访问。

@triton.jitdefcreate_flashinfer_kv_indices_triton(
    req_to_token_ptr,# [max_batch, max_context_len]
    req_pool_indices_ptr,
    page_kernel_lens_ptr,
    kv_indptr,
    kv_start_idx,
    kv_indices_ptr,
    req_to_token_ptr_stride:tl.constexpr,):
    BLOCK_SIZE:tl.constexpr=512
    pid=tl.program_id(axis=0)

    # batch 中第N个请求的req_pool_index和kv_indices_offset
    req_pool_index=tl.load(req_pool_indices_ptr+pid)
    kv_indices_offset=tl.load(kv_indptr+pid)

    kv_start=0
    kv_end=0
    ifkv_start_idx:
        kv_start=tl.load(kv_start_idx+pid).to(tl.int32)
        kv_end=kv_start
    ## 获得请求的kvcache 长度
    kv_end+=tl.load(page_kernel_lens_ptr+pid).to(tl.int32)

    num_loop=tl.cdiv(kv_end-kv_start,BLOCK_SIZE)
    foriinrange(num_loop):
        # block_size 又是一个并行度,意在加速load,store的并行效率
        # offset 返回的是一个BLOCK_SIZE 维度的array
        offset=tl.arange(0,BLOCK_SIZE)+i*BLOCK_SIZE
        mask=offset<kv_end-kv_start
        # 并行读取req_to_token pool 中req 对应的token 索引
        data=tl.load(
            req_to_token_ptr
            +req_pool_index*req_to_token_ptr_stride
            +kv_start
            +offset,
            mask=mask,
        )
        ## 并行写入kv indices数组
        ## 注意,这里kv_indices是临时结构,和token_to_kv_pool 没有关系,但最终会作为wrapper的入参
        tl.store(kv_indices_ptr+kv_indices_offset+offset,data,mask=mask)

backend CacheModule 是什么

在理解wrapper->plan 的调用链之前,我们先看看wrapper 里的核心结构————cache module。

具体上说,cache module 就是wrapper 真正的核心,是cpp 入口结构。它是被延后初始化的,因为主要是接口抽象类,不是实际资源,所以延后初始化也可以接受。cache module 被构建的时机是plan 接口调用时,根据backend的值和硬件当前情况,再次进行一次backend的判定,并根据backend的判定情况获取相应的cache module和挂载相应的具体接口。

首先是判断要不要使用flashattention3.

            if self._backend == "auto":
                self._backend = determine_attention_backend(
                    self.device,






请到「今天看啥」查看全文