专栏名称: GiantPandaLLM
专注于机器学习、深度学习、计算机视觉、图像处理等多个方向技术分享。团队由一群热爱技术且热衷于分享的小伙伴组成。我们坚持原创,每天一到两篇原创技术分享。希望在传播知识、分享知识的同时能够启发你,大家一起共同进步(・ω<)☆
目录
相关文章推荐
GiantPandaLLM  ·  [Triton编程][基础] Triton ... ·  4 天前  
51好读  ›  专栏  ›  GiantPandaLLM

LightLLM中DeepSeek V3/R1 Two MicroBatch Overlap 实现解...

GiantPandaLLM  · 公众号  · 3D  · 2025-05-25 19:14

正文

请到「今天看啥」查看全文



micro_batch, run_reqs, padded_req_num = _padded_prepare_decode_micro_batch(
req_objs[ 0 :micro_batch1_req_num], micro_batch_size, is_multimodal=is_multimodal
)

# 创建第二个micro batch
micro_batch1, run_reqs1, padded_req_num1 = _padded_prepare_decode_micro_batch(
req_objs[micro_batch1_req_num:], micro_batch_size, is_multimodal=is_multimodal
)

return micro_batch, run_reqs, padded_req_num, micro_batch1, run_reqs1, padded_req_num1

关键实现细节

  1. 均匀分配 :使用 triton.cdiv 确保两个micro batch大小尽可能均匀
  2. Padding处理 :对于不足的请求,使用fake请求进行padding
  3. 内存分配 :每个micro batch都有独立的内存索引

0x03.2 Prefill阶段的MicroBatch拆分

def padded_overlap_prepare_prefill_inputs(req_objs: List[InferReq], max_prefill_num: int, is_multimodal=False):
    """
    为Prefill阶段准备重叠执行的两个MicroBatch
    
    Args:
        req_objs: 待处理的推理请求列表
        max_prefill_num: 最大prefill数量限制
        is_multimodal: 是否为多模态输入
    
    Returns:
        tuple: (micro_batch, run_reqs, padded_req_num, micro_batch1, run_reqs1, padded_req_num1)
    """

    assert max_prefill_num != 0
    
    # 将请求列表平均分成两部分,为两个micro batch做准备
    # 使用triton.cdiv确保向上取整,避免遗漏请求
    micro_batch1_req_num = triton.cdiv(len(req_objs), 2)
    
    # 创建第一个micro batch:处理前半部分请求
    micro_batch, run_reqs, padded_req_num = _padded_prepare_prefill_micro_batch(
        req_objs[0:micro_batch1_req_num], is_multimodal=is_multimodal
    )
    
    # 创建第二个micro batch:处理后半部分请求
    micro_batch1, run_reqs1, padded_req_num1 = _padded_prepare_prefill_micro_batch(
        req_objs[micro_batch1_req_num:], is_multimodal=is_multimodal
    )

    return micro_batch, run_reqs, padded_req_num, micro_batch1, run_reqs1, padded_req_num1


def _padded_prepare_prefill_micro_batch(req_objs: List[InferReq], is_multimodal=False):
    """
    为单个MicroBatch准备Prefill阶段的数据
    
    Args:
        req_objs: 分配给当前micro batch的请求列表
        is_multimodal: 是否支持多模态输入
    
    Returns:
        tuple: (micro_batch, run_reqs, padded_req_num)
    """

    # === 初始化数据收集变量 ===
    run_reqs = []                    # 实际运行的请求列表
    nopad_total_token_num = 0        # 未padding前的总token数量
    nopad_max_len_in_batch = 0       # 批次中的最大输入长度
    input_ids = []                   # 所有请求的input token ids
    nopad_b_req_idx = []             # 请求索引列表
    nopad_b_seq_len = []             # 每个请求的序列长度
    
    # prefill阶段只需要padding一个请求形成micro_batch
    # 并不需要两个micro batch的batch_size相同(与decode阶段不同)
    padded_req_num = 1if len(req_objs) == 0else0
    
    b_ready_cache_len = []           # 每个请求已缓存的KV长度
    batch_multimodal_params = []     # 多模态参数列表

    # === 处理每个真实请求 ===
    for req in req_objs:
        run_reqs.append(req)
        batch_multimodal_params.append(req.multimodal_params)
        nopad_b_req_idx.append(req.req_idx)

        # 获取请求的完整input token序列
        input_token_ids = req.get_chuncked_input_token_ids()
        seq_len = len(input_token_ids)
        
        # 计算需要处理的新token长度(总长度 - 已缓存长度)
        input_token_len = seq_len - req.cur_kv_len
        
        # 提取需要处理的新token部分
        input_id = input_token_ids[req.cur_kv_len :]

        # 收集批次统计信息
        nopad_b_seq_len.append(seq_len)
        input_ids.extend(input_id)                    # 将新token添加到批次中
        nopad_total_token_num += seq_len              # 累加总token数
        nopad_max_len_in_batch = max(nopad_max_len_in_batch, input_token_len)  # 更新最大长度
        b_ready_cache_len.append(req.cur_kv_len)      # 记录已缓存长度

    # === 添加padding请求(如果需要)===
    # 当没有真实请求时,添加一个fake请求进行padding
    for _ in range(padded_req_num):
        input_ids.append(1)  # 添加一个dummy token
        nopad_b_req_idx.append(g_infer_context.req_manager.HOLD_REQUEST_ID)  # 使用占位请求ID
        nopad_b_seq_len.append(1)        # 序列长度为1
        b_ready_cache_len.append(0)      # 无已缓存内容
        nopad_total_token_num += 1       # 更新总token数
        nopad_max_len_in_batch = max(nopad_max_len_in_batch, 1)  # 更新最大长度

    # === 转换为CUDA张量 ===
    # 将Python列表转换为GPU上的张量,提高计算效率
    input_ids = torch.tensor(input_ids, dtype=torch.int64, device="cuda")
    nopad_b_req_idx = torch.tensor(nopad_b_req_idx, dtype=torch.int32, device="cuda")
    nopad_b_seq_len = torch.tensor(nopad_b_seq_len, dtype=torch.int32, device="cuda")
    b_ready_cache_len = torch.tensor(b_ready_cache_len, dtype=torch.int32, device="cuda")

    # === 动态内存管理 ===
    # 获取全局推理状态锁,确保内存分配的线程安全
    g_infer_state_lock.acquire()
    
    # 如果启用了radix cache,先释放足够的缓存空间
    if g_infer_context.radix_cache isnotNone:
        g_infer_context.radix_cache.free_radix_cache_to_get_enough_token(
            input_ids.shape[0] - padded_req_num  # 只为真实token分配空间
        )
    
    # 为真实token分配KV cache内存索引
    mem_indexes = g_infer_context.req_manager.mem_manager.alloc(
        input_ids.shape[0] - padded_req_num
    ).cuda()
    
    # 释放锁
    g_infer_state_lock.release()
    
    # === 处理padding token的内存索引 ===
    if padded_req_num > 0:
        # 为padding token创建占位内存索引
        padding_indexs = torch.full(
            (padded_req_num,),
            fill_value=g_infer_context.req_manager.mem_manager.HOLD_TOKEN_MEMINDEX,  # 占位索引值
            dtype=torch.int32,
            device="cuda",
        )
        # 将真实token和padding token的内存索引合并
        mem_indexes = torch.cat((mem_indexes, padding_indexs), dim=0)

    # === 创建PrefillMicroBatch对象 ===
    micro_batch = PrefillMicroBatch(
        batch_size=nopad_b_seq_len.shape[0],          # 批次大小(请求数量)
        total_token_num=nopad_total_token_num,        # 总token数量
        max_len_in_batch=nopad_max_len_in_batch,      # 批次中最大输入长度
        input_ids=input_ids,                          # 输入token序列
        mem_indexes=mem_indexes,                      # KV cache内存索引
        b_req_idx=nopad_b_req_idx,                    # 请求索引
        b_seq_len=nopad_b_seq_len,                    # 序列长度
        b_ready_cache_len=b_ready_cache_len,          # 已缓存长度
        multimodal_params=batch_multimodal_params,    # 多模态参数
    )

    return micro_batch, run_reqs, padded_req_num

0x04 Two MicroBatch Overlap 核心执行流程

0x04.1 MicroBatch Overlap Decode实现







请到「今天看啥」查看全文