正文
policy = self._determine_active_policy(waiting_queue)
prefix_computed = False
if isinstance(policy, CacheAwarePolicy):
prefix_computed = True
temporary_deprioritized = self._compute_prefix_matches(
waiting_queue, policy
)
if policy == CacheAwarePolicy.LPM:
SchedulePolicy._sort_by_longest_prefix(
waiting_queue, temporary_deprioritized
)
elif policy == CacheAwarePolicy.DFS_WEIGHT:
SchedulePolicy._sort_by_dfs_weight(waiting_queue, self.tree_cache)
else:
raise ValueError(f"Unknown CacheAware Policy: {policy=}")
else:
if policy == CacheAgnosticPolicy.FCFS:
pass
elif policy == CacheAgnosticPolicy.LOF:
SchedulePolicy._sort_by_longest_output(waiting_queue)
elif policy == CacheAgnosticPolicy.RANDOM:
SchedulePolicy._sort_randomly(waiting_queue)
else:
raise ValueError(f"Unknown CacheAgnostic Policy: {policy=}")
return prefix_computed
当然了这里有一些细节的优化,有兴趣可以仔细阅读这部分代码,我这里提两个:
_determine_active_policy 中如果发现等待队列太长且默认采用的是LPM(最长前缀匹配),则换成FCFS。但如果是dfs-weight则不影响,本质还是计算成本的权衡。
def _determine_active_policy(self, waiting_queue: List[Req]) -> Policy: if len(waiting_queue) > 128 and self.policy == CacheAwarePolicy.LPM: # Turn off the expensive prefix matching and sorting when the #queue is large. return CacheAgnosticPolicy.FCFS return self.policy
_compute_prefix_matches 有一种提高缓存命中率的策略in-batch prefix caching。如果当前batch(waiting queue)中,有不少请求有同一个前缀,而且前缀在已有cache中仅匹配了一小部分(
# NOTE(sang): This logic is for in-batch prefix caching; # If there are more than 1 request that have small matching prefix from # existing cache, but all those requests share the same prefix, we prefer # to schedule only one of them so that we can increase the cache hit rate. # We prefer to set IN_BATCH_PREFIX_CACHING_CHECK_THRESHOLD > 0 because too small # threshold means we cannot use in-batch prefix caching for short prefixes. # It is kind of common when the engine is long running (e.g., imagine the prefix "the").
PrefillAdder
第二个和调度相关的类是PrefillAdder,它决定了还能不能插入新请求,其返回有三种,语义很字面直白。
class AddReqResult(Enum): CONTINUE = auto() # Continue to add requests NO_TOKEN = auto() # No token left OTHER = auto() # Other reasons to stop adding requests
PrefillAdder核心数据结构如下,关键是rem_total_tokens,rem_input_tokens,rem_chunk_tokens。他们的区别是:
• rem_total_tokens 包括prefill和decoding 一共的上下文长度
• rem_input_tokens 则只包括prefill 的输入
• rem_chunk_tokens 则是一个chunk可以包含的token数
## in python/sglang/srt/managers/schedule_policy.py class PrefillAdder: def __init__( self, tree_cache: BasePrefixCache, running_batch: ScheduleBatch, new_token_ratio: float, rem_total_tokens: int, rem_input_tokens: int, rem_chunk_tokens: Optional[int], mixed_with_decode_tokens: int = 0, ): self.tree_cache = tree_cache self.running_batch = running_batch self.new_token_ratio = new_token_ratio self.rem_total_tokens = rem_total_tokens - mixed_with_decode_tokens self.rem_input_tokens = rem_input_tokens - mixed_with_decode_tokens self.rem_chunk_tokens = rem_chunk_tokens if self.rem_chunk_tokens is not None: self.rem_chunk_tokens -= mixed_with_decode_tokens self.cur_rem_tokens = rem_total_tokens - mixed_with_decode_tokens self.req_states = None self.can_run_list = [] self.new_being_chunked_req = None self.log_hit_tokens = 0 self.log_input_tokens = 0
我们可以用一个简单的接口函数,来体会返回状态和这几个关键变量的关系,如下函数是add_one_req 请求的最后一个环节,用来最终判断是否可以插入请求。
def budget_state(self): if self.rem_total_tokens <= 0 or self.cur_rem_tokens <= 0: return AddReqResult.NO_TOKEN if self.rem_input_tokens <= 0 or ( self.rem_chunk_tokens is not None and self.rem_chunk_tokens <= 0 ): return AddReqResult.OTHER return AddReqResult.CONTINUE
当然,判断返回的地方不只这个函数,具体可以阅读PrefillAdder add_one_req源码理解。对于理解kvcache 管理,目前这点可能就够了。最后,可以被插入的请求都会放在can_run_list这个列表中。
Req
Req 是核心请求类,包括判断请求是否可以结束,以及核心的数据结构。在介绍req 核心结构前,先简单看看几种finish reason。
FINISH_MATCHED_TOKEN # 匹配了终止的token,比如tokenizer,sampler,scheduler 等设置的eos token FINISH_MATCHED_STR # 匹配了终止的字符串,一般是sampler设置的 FINISH_LENGTH # 匹配了最大输出长度 FINISH_ABORT # 由于其他原因终止,比如请求不合法等等
req 核心成员变量较多,但为了理解调度,我们有必要过一些。为了方便分析,分为几段介绍。首先是输入输出信息,也是最重要的。
# Input and output info self.rid = rid #请求id, chunkedCache entry的key self.origin_input_text = origin_input_text #原始请求输入文本字符串 self.origin_input_ids_unpadded = ( #原始请求输入token list origin_input_ids_unpadded if origin_input_ids_unpadded else origin_input_ids # Before image padding ) self.origin_input_ids = origin_input_ids #也是原始请求输入,但可能是padding过后的。 #通常和origin_input_ids_unpadded一样 #在image input下,sglang对输入做额外的padding,则有区别 self.output_ids = [] # Each decode stage's output ids #输出token list self.fill_ids = None # fill_ids = origin_input_ids + output_ids # 完整的上下文token list self.session_id = session_id # 会话id,一轮用户会话可能有多个请求 self.input_embeds = input_embeds # embedding 化后的输入 # Memory pool info self.req_pool_idx = None #对于req_token_pool的索引
其次是用于判断结束的成员变量
# Check finish self.tokenizer = None # tokenizer,可以用于eos等stop token判断 self.finished_reason = None # 结束理由 self.to_abort = False # 是否是finished_abort self.stream = stream # 是否是流式的请求 self.eos_token_ids = eos_token_ids # eos token list,用于结束判断
然后是用于推理的成员变量
# For incremental decoding # ----- | --------- read_ids -------| # ----- | surr_ids | # xxxxx | xxxxxxxxxxx | xxxxxxxxxxx | # ----- ^ ----------- ^ ----------- ^ # ----- 1 ----------- 2 ----------- 3 # 1: surr_offset # 2: read_offset # 3: last token self.surr_offset = None # Surrounding offset to defeat the cleanup algorithm self.read_offset = None # 上图已经说明了surr_offset和read_offset的区别 # surr_offset通常记录上一次处理到的位置,read_offset 说明正在处理的位置 self.decoded_text = "" # 解码的输出 # Prefix info,与共享prefix 的kvcache 有关 self.prefix_indices = [] # Tokens to run prefill. input_tokens - shared_prefix_tokens. self.extend_input_len = 0 self.last_node = None # Chunked prefill self.is_being_chunked = 0 # The number of cached tokens, that were already cached in the KV cache # cached的tokens self.cached_tokens = 0 self.vid = 0 # version id to sync decode status with in detokenizer_manager # 只有jumpforward 会对其进行修改,同步detokenizer的状态 # For retraction # 用于撤回类似的功能,即需要回退decode 的输出 self.is_retracted = False # Constrained decoding, 一般用于类似json的结构化输出 self.grammar: Optional[BaseGrammarObject] = None # Sampling info self.sampling_params = sampling_params self.lora_path = lora_path #剩下一众logits 相关的,为了方便大家理解源码,我这里也介绍 self.return_logprob # 是否有必要返回logits self.logprob_start_len # 从哪个位置开始算logits self.extend_logprob_start_len #, extend_即extend的部分开始算,简单理解extend_logprob_start_len = extend_logprob_start_lens - prefix_len self.normalized_prompt_logprob #归一化后prompt的logits # _idx 的list,即token 本身(idx 指词表里的index) # _val 的list,即log值,即分布概率 # _output 和 _input 即输入输出,top即按照val 的top分布 self.top_logprobs_num self.output_token_logprobs_idx self.output_token_logprobs_val self.output_top_logprobs_idx self.output_top_logprobs_val self.input_token_logprobs_idx self.input_token_logprobs_val
上面提到jump forward decodig,有些同学可能不熟悉,这里简单介绍一下jump forward,其实这很容易理解,prompt 有时候会是一种类似“完形填空”的方式,而我们只需要生成其中”空白“的部分,不需要生成prompt 已经有的部分。图例如下。
最后我们再介绍一下req 几个比较重要的成员函数。
第一组: finished() && check_finished() 用于判断是否可以结束,以及finished_reason 是哪种情况 第二组: init_next_round_input #初始化本请求下一轮inference 需要的参数,比如计算需要用多长的kvcache (主要是计算fill_ids和extend_input_len) 第三组: init_incremental_detokenize 与 get_next_inc_detokenization 这两个函数通常是用于获取下一轮detokenizer 相关的参数并进行相关配置 逻辑上,detokenizer 自己会管理相关配置,req的这两个接口主要是for jump forward decoding, 由于jump forward的解码过程存在一些跳跃,所以需要请求级别自己去配置 同理上面也只有jump forward 需要单独提供detokinizer的vid,其他detokenize manager 自己就可以管理 第四组: jump_forward_and_retokenize 也是jump forward 相关 所以我们看到jump forward和结构化输出相关,在sglang 这边也是一个相当重要的角色。 第五组: reset_for_retract 为了撤回decode,重置decode 参数,比如
ModelConfig & ForwardMode
为了更方便理解SchdeuleBatch,我们还需要了解两个类,一个是modelConfig,另一个是forwardmode。
forward mode 主要是说明了sglang 支持的各种inference 模式,包括如下8种。
class ForwardMode(IntEnum): # Prefill a new sequence. This is deprecated now. "EXTEND" covers this case. PREFILL = auto() # Extend a sequence. The KV cache of the beginning part of the sequence is already computed (e.g., system prompt). # 即带cache的prefill,场景上覆盖了PREFILL EXTEND = auto() # Decode one token. DECODE = auto() # Contains both EXTEND and DECODE when doing chunked prefill. # 即一个batch 里既有prefill,又有decode MIXED = auto() # No sequence to forward. For data parallel attention, some workers wil be IDLE if no sequence are allocated. IDLE = auto() # 空闲 # Used in speculative decoding: verify a batch in the target model. TARGET_VERIFY = auto() # Used in speculative decoding: extend a batch in the draft model. DRAFT_EXTEND = auto() # 上面两个是投机推理的模式,自回归不会用到。了解投机推理的应该很好理解这两个阶段。 # A dummy first batch to start the pipeline for overlap scheduler. # It is now used for triggering the sampling_info_done event for the first prefill batch. # 这是一个特殊的模式,用于初始化scheduler的各种配置和相关预热 ,是第一个batch的forward 模式 DUMMY_FIRST = auto()
ModelConfig则是有关inference的配置。其中重要的参数如下:
self.model_path = model_path #模型路径 self.revision = revision # 版本,主要是拿开源配置用 self.quantization = quantization #量化 # Parse args, huggingface 开源配置,还允许override 配置 self.model_override_args = json.loads(model_override_args) self.hf_config = get_config( model_path, trust_remote_code=trust_remote_code, revision=revision, model_override_args=self.model_override_args, ) self.hf_text_config = get_hf_text_config(self.hf_config) self.dtype = _get_and_verify_dtype(self.hf_text_config, dtype) # Check model type self.is_generation = is_generation_model( # 是不是生成模型 self.hf_config.architectures, is_embedding ) self.is_multimodal = is_multimodal_model(self.hf_config.architectures) # 是不是多模态 self.is_encoder_decoder = is_encoder_decoder_model(self.hf_config.architectures) #是不是encoder-decoder模式 # 模型的配置信息,比如是MHA还是MLA,支持最大上下文长度,各自dim,lora/rope # 逻辑上有了下面这堆参数,我们可以计算出kvcache需要多少 self.context_len = context_length self.head_dim = 256 self.attention_arch = AttentionArch.MLA self.kv_lora_rank = self.hf_config.kv_lora_rank self.qk_rope_head_dim = self.hf_config.qk_rope_head_dim self.num_attention_heads = self.hf_text_config.num_attention_heads self.num_key_value_heads = self.num_attention_heads self.hidden_size = self.hf_text_config.hidden_size self.num_hidden_layers = self.hf_text_config.num_hidden_layers self.vocab_size = self.hf_text_config.vocab_size #词表大小,和tokenizer/sampling有关 self.hf_eos_token_id = self.get_hf_eos_token_id() # 终止tokenlist
ScheduleBatch->ModelWorkerBatch->ForwardBatch
接下来,我们隆重介绍batch三兄弟里的第一位,schedule batch,他是最上层的batch 结构,和scheduler 直接交互。有了以上的铺垫,理解scheduleBatch就相对简单了。