专栏名称: GiantPandaLLM
专注于机器学习、深度学习、计算机视觉、图像处理等多个方向技术分享。团队由一群热爱技术且热衷于分享的小伙伴组成。我们坚持原创,每天一到两篇原创技术分享。希望在传播知识、分享知识的同时能够启发你,大家一起共同进步(・ω<)☆
目录
相关文章推荐
GiantPandaLLM  ·  [Triton编程][基础] Triton ... ·  4 天前  
51好读  ›  专栏  ›  GiantPandaLLM

SGLang 源码学习笔记:Cache、Req与Scheduler

GiantPandaLLM  · 公众号  · 3D  · 2025-05-14 18:19

正文

请到「今天看啥」查看全文



policy = self._determine_active_policy(waiting_queue)

prefix_computed = False
if isinstance(policy, CacheAwarePolicy):
prefix_computed = True
temporary_deprioritized = self._compute_prefix_matches(
waiting_queue, policy
)
if policy == CacheAwarePolicy.LPM:
SchedulePolicy._sort_by_longest_prefix(
waiting_queue, temporary_deprioritized
)
elif policy == CacheAwarePolicy.DFS_WEIGHT:
SchedulePolicy._sort_by_dfs_weight(waiting_queue, self.tree_cache)
else:
raise ValueError(f"Unknown CacheAware Policy: {policy=}")
else:
if policy == CacheAgnosticPolicy.FCFS:
pass
elif policy == CacheAgnosticPolicy.LOF:
SchedulePolicy._sort_by_longest_output(waiting_queue)
elif policy == CacheAgnosticPolicy.RANDOM:
SchedulePolicy._sort_randomly(waiting_queue)
else:
raise ValueError(f"Unknown CacheAgnostic Policy: {policy=}")

return prefix_computed

当然了这里有一些细节的优化,有兴趣可以仔细阅读这部分代码,我这里提两个:

_determine_active_policy 中如果发现等待队列太长且默认采用的是LPM(最长前缀匹配),则换成FCFS。但如果是dfs-weight则不影响,本质还是计算成本的权衡。

def _determine_active_policy(self, waiting_queue: List[Req]) -> Policy:
        if len(waiting_queue) > 128 and self.policy == CacheAwarePolicy.LPM:
            # Turn off the expensive prefix matching and sorting when the #queue is large.
            return CacheAgnosticPolicy.FCFS
        return self.policy

_compute_prefix_matches 有一种提高缓存命中率的策略in-batch prefix caching。如果当前batch(waiting queue)中,有不少请求有同一个前缀,而且前缀在已有cache中仅匹配了一小部分(

# NOTE(sang): This logic is for in-batch prefix caching;
# If there are more than 1 request that have small matching prefix from
# existing cache, but all those requests share the same prefix, we prefer
# to schedule only one of them so that we can increase the cache hit rate.
# We prefer to set IN_BATCH_PREFIX_CACHING_CHECK_THRESHOLD > 0 because too small
# threshold means we cannot use in-batch prefix caching for short prefixes.
# It is kind of common when the engine is long running (e.g., imagine the prefix "the").

PrefillAdder

第二个和调度相关的类是PrefillAdder,它决定了还能不能插入新请求,其返回有三种,语义很字面直白。

class AddReqResult(Enum):
    CONTINUE = auto()  # Continue to add requests
    NO_TOKEN = auto()  # No token left
    OTHER = auto()  # Other reasons to stop adding requests

PrefillAdder核心数据结构如下,关键是rem_total_tokens,rem_input_tokens,rem_chunk_tokens。他们的区别是:

  • • rem_total_tokens 包括prefill和decoding 一共的上下文长度
  • • rem_input_tokens 则只包括prefill 的输入
  • • rem_chunk_tokens 则是一个chunk可以包含的token数
## in python/sglang/srt/managers/schedule_policy.py
class PrefillAdder:
    def __init__(
        self,
        tree_cache: BasePrefixCache,
        running_batch: ScheduleBatch,
        new_token_ratio: float,
        rem_total_tokens: int,
        rem_input_tokens: int,
        rem_chunk_tokens: Optional[int],
        mixed_with_decode_tokens: int = 0,
    ):
        self.tree_cache = tree_cache
        self.running_batch = running_batch
        self.new_token_ratio = new_token_ratio
        self.rem_total_tokens = rem_total_tokens - mixed_with_decode_tokens
        self.rem_input_tokens = rem_input_tokens - mixed_with_decode_tokens
        self.rem_chunk_tokens = rem_chunk_tokens
        if self.rem_chunk_tokens is not None:
            self.rem_chunk_tokens -= mixed_with_decode_tokens

        self.cur_rem_tokens = rem_total_tokens - mixed_with_decode_tokens

        self.req_states = None
        self.can_run_list = []
        self.new_being_chunked_req = None
        self.log_hit_tokens = 0
        self.log_input_tokens = 0

我们可以用一个简单的接口函数,来体会返回状态和这几个关键变量的关系,如下函数是add_one_req 请求的最后一个环节,用来最终判断是否可以插入请求。

def budget_state(self):
    if self.rem_total_tokens <= 0 or self.cur_rem_tokens <= 0:
        return AddReqResult.NO_TOKEN

    if self.rem_input_tokens <= 0 or (
        self.rem_chunk_tokens is not None and self.rem_chunk_tokens <= 0
    ):
        return AddReqResult.OTHER

    return AddReqResult.CONTINUE

当然,判断返回的地方不只这个函数,具体可以阅读PrefillAdder add_one_req源码理解。对于理解kvcache 管理,目前这点可能就够了。最后,可以被插入的请求都会放在can_run_list这个列表中。

Req

Req 是核心请求类,包括判断请求是否可以结束,以及核心的数据结构。在介绍req 核心结构前,先简单看看几种finish reason。

FINISH_MATCHED_TOKEN # 匹配了终止的token,比如tokenizer,sampler,scheduler 等设置的eos token
FINISH_MATCHED_STR   # 匹配了终止的字符串,一般是sampler设置的
FINISH_LENGTH        # 匹配了最大输出长度
FINISH_ABORT         # 由于其他原因终止,比如请求不合法等等

req 核心成员变量较多,但为了理解调度,我们有必要过一些。为了方便分析,分为几段介绍。首先是输入输出信息,也是最重要的。

# Input and output info
self.rid = rid                                               #请求id, chunkedCache entry的key
self.origin_input_text = origin_input_text                   #原始请求输入文本字符串
self.origin_input_ids_unpadded = (                           #原始请求输入token list
            origin_input_ids_unpadded
            if origin_input_ids_unpadded
            else origin_input_ids  # Before image padding
        )
self.origin_input_ids = origin_input_ids                     #也是原始请求输入,但可能是padding过后的。
                                                             #通常和origin_input_ids_unpadded一样
                                                             #在image input下,sglang对输入做额外的padding,则有区别
self.output_ids = []  # Each decode stage's output ids       #输出token list
self.fill_ids = None  # fill_ids = origin_input_ids + output_ids # 完整的上下文token list
self.session_id = session_id                                 # 会话id,一轮用户会话可能有多个请求
self.input_embeds = input_embeds                             # embedding 化后的输入

# Memory pool info
self.req_pool_idx = None                                     #对于req_token_pool的索引

其次是用于判断结束的成员变量

# Check finish
self.tokenizer = None                           # tokenizer,可以用于eos等stop token判断
self.finished_reason = None                     # 结束理由
self.to_abort = False                           # 是否是finished_abort
self.stream = stream                            # 是否是流式的请求
self.eos_token_ids = eos_token_ids              # eos token list,用于结束判断

然后是用于推理的成员变量

# For incremental decoding
# ----- | --------- read_ids -------|
# ----- |   surr_ids  |
# xxxxx | xxxxxxxxxxx | xxxxxxxxxxx |
# ----- ^ ----------- ^ ----------- ^
# ----- 1 ----------- 2 ----------- 3
# 1: surr_offset
# 2: read_offset
# 3: last token
self.surr_offset = None  # Surrounding offset to defeat the cleanup algorithm
self.read_offset = None  
# 上图已经说明了surr_offset和read_offset的区别
# surr_offset通常记录上一次处理到的位置,read_offset 说明正在处理的位置
self.decoded_text = ""   # 解码的输出

# Prefix info,与共享prefix 的kvcache 有关
self.prefix_indices = []
# Tokens to run prefill. input_tokens - shared_prefix_tokens.
self.extend_input_len = 0
self.last_node = None

# Chunked prefill
self.is_being_chunked = 0

# The number of cached tokens, that were already cached in the KV cache
# cached的tokens
self.cached_tokens = 0

self.vid = 0  # version id to sync decode status with in detokenizer_manager
              # 只有jumpforward 会对其进行修改,同步detokenizer的状态

# For retraction  # 用于撤回类似的功能,即需要回退decode 的输出
self.is_retracted = False

# Constrained decoding, 一般用于类似json的结构化输出
self.grammar: Optional[BaseGrammarObject] = None

# Sampling info
self.sampling_params = sampling_params
self.lora_path = lora_path

#剩下一众logits 相关的,为了方便大家理解源码,我这里也介绍
self.return_logprob # 是否有必要返回logits
self.logprob_start_len # 从哪个位置开始算logits
self.extend_logprob_start_len  #, extend_即extend的部分开始算,简单理解extend_logprob_start_len = extend_logprob_start_lens - prefix_len
self.normalized_prompt_logprob #归一化后prompt的logits
# _idx 的list,即token 本身(idx 指词表里的index)
# _val 的list,即log值,即分布概率
# _output 和 _input 即输入输出,top即按照val 的top分布
self.top_logprobs_num
self.output_token_logprobs_idx
self.output_token_logprobs_val
self.output_top_logprobs_idx
self.output_top_logprobs_val
self.input_token_logprobs_idx
self.input_token_logprobs_val

上面提到jump forward decodig,有些同学可能不熟悉,这里简单介绍一下jump forward,其实这很容易理解,prompt 有时候会是一种类似“完形填空”的方式,而我们只需要生成其中”空白“的部分,不需要生成prompt 已经有的部分。图例如下。

最后我们再介绍一下req 几个比较重要的成员函数。

第一组:
finished() && check_finished() 用于判断是否可以结束,以及finished_reason 是哪种情况

第二组:
init_next_round_input #初始化本请求下一轮inference 需要的参数,比如计算需要用多长的kvcache
                     (主要是计算fill_ids和extend_input_len)

第三组:
init_incremental_detokenize 与 get_next_inc_detokenization
这两个函数通常是用于获取下一轮detokenizer 相关的参数并进行相关配置
逻辑上,detokenizer 自己会管理相关配置,req的这两个接口主要是for jump forward decoding,
由于jump forward的解码过程存在一些跳跃,所以需要请求级别自己去配置
同理上面也只有jump forward 需要单独提供detokinizer的vid,其他detokenize manager 自己就可以管理

第四组:
jump_forward_and_retokenize
也是jump forward 相关
所以我们看到jump forward和结构化输出相关,在sglang 这边也是一个相当重要的角色。

第五组:
reset_for_retract
为了撤回decode,重置decode 参数,比如

ModelConfig & ForwardMode

为了更方便理解SchdeuleBatch,我们还需要了解两个类,一个是modelConfig,另一个是forwardmode。

forward mode 主要是说明了sglang 支持的各种inference 模式,包括如下8种。

class ForwardMode(IntEnum):
    # Prefill a new sequence. This is deprecated now. "EXTEND" covers this case.
    PREFILL = auto()
    # Extend a sequence. The KV cache of the beginning part of the sequence is already computed (e.g., system prompt).
    # 即带cache的prefill,场景上覆盖了PREFILL
    EXTEND = auto()
    # Decode one token.
    DECODE = auto()
    # Contains both EXTEND and DECODE when doing chunked prefill.
    # 即一个batch 里既有prefill,又有decode
    MIXED = auto()
    # No sequence to forward. For data parallel attention, some workers wil be IDLE if no sequence are allocated.
    IDLE = auto() # 空闲

    # Used in speculative decoding: verify a batch in the target model.
    TARGET_VERIFY = auto()
    # Used in speculative decoding: extend a batch in the draft model.
    DRAFT_EXTEND = auto()
    # 上面两个是投机推理的模式,自回归不会用到。了解投机推理的应该很好理解这两个阶段。

    # A dummy first batch to start the pipeline for overlap scheduler.
    # It is now used for triggering the sampling_info_done event for the first prefill batch.
    # 这是一个特殊的模式,用于初始化scheduler的各种配置和相关预热 ,是第一个batch的forward 模式
    DUMMY_FIRST = auto()

ModelConfig则是有关inference的配置。其中重要的参数如下:

self.model_path = model_path #模型路径
self.revision = revision     # 版本,主要是拿开源配置用
self.quantization = quantization #量化

# Parse args, huggingface 开源配置,还允许override 配置
self.model_override_args = json.loads(model_override_args)
self.hf_config = get_config(
            model_path,
            trust_remote_code=trust_remote_code,
            revision=revision,
            model_override_args=self.model_override_args,
        )
self.hf_text_config = get_hf_text_config(self.hf_config)
self.dtype = _get_and_verify_dtype(self.hf_text_config, dtype)

# Check model type
self.is_generation = is_generation_model(                               # 是不是生成模型
            self.hf_config.architectures, is_embedding
        )
self.is_multimodal = is_multimodal_model(self.hf_config.architectures)  # 是不是多模态
self.is_encoder_decoder = is_encoder_decoder_model(self.hf_config.architectures) #是不是encoder-decoder模式

# 模型的配置信息,比如是MHA还是MLA,支持最大上下文长度,各自dim,lora/rope
# 逻辑上有了下面这堆参数,我们可以计算出kvcache需要多少
self.context_len = context_length
self.head_dim = 256
self.attention_arch = AttentionArch.MLA
self.kv_lora_rank = self.hf_config.kv_lora_rank
self.qk_rope_head_dim = self.hf_config.qk_rope_head_dim
self.num_attention_heads = self.hf_text_config.num_attention_heads
self.num_key_value_heads = self.num_attention_heads
self.hidden_size = self.hf_text_config.hidden_size
self.num_hidden_layers = self.hf_text_config.num_hidden_layers

self.vocab_size = self.hf_text_config.vocab_size #词表大小,和tokenizer/sampling有关
self.hf_eos_token_id = self.get_hf_eos_token_id() # 终止tokenlist

ScheduleBatch->ModelWorkerBatch->ForwardBatch

接下来,我们隆重介绍batch三兄弟里的第一位,schedule batch,他是最上层的batch 结构,和scheduler 直接交互。有了以上的铺垫,理解scheduleBatch就相对简单了。







请到「今天看啥」查看全文