正文
micro_batch, run_reqs, padded_req_num = _padded_prepare_decode_micro_batch(
req_objs[
0
:micro_batch1_req_num], micro_batch_size, is_multimodal=is_multimodal
)
# 创建第二个micro batch
micro_batch1, run_reqs1, padded_req_num1 = _padded_prepare_decode_micro_batch(
req_objs[micro_batch1_req_num:], micro_batch_size, is_multimodal=is_multimodal
)
return
micro_batch, run_reqs, padded_req_num, micro_batch1, run_reqs1, padded_req_num1
关键实现细节
:
-
均匀分配
:使用
triton.cdiv
确保两个micro batch大小尽可能均匀
-
Padding处理
:对于不足的请求,使用fake请求进行padding
-
内存分配
:每个micro batch都有独立的内存索引
0x03.2 Prefill阶段的MicroBatch拆分
def padded_overlap_prepare_prefill_inputs(req_objs: List[InferReq], max_prefill_num: int, is_multimodal=False):
"""
为Prefill阶段准备重叠执行的两个MicroBatch
Args:
req_objs: 待处理的推理请求列表
max_prefill_num: 最大prefill数量限制
is_multimodal: 是否为多模态输入
Returns:
tuple: (micro_batch, run_reqs, padded_req_num, micro_batch1, run_reqs1, padded_req_num1)
"""
assert max_prefill_num != 0
# 将请求列表平均分成两部分,为两个micro batch做准备
# 使用triton.cdiv确保向上取整,避免遗漏请求
micro_batch1_req_num = triton.cdiv(len(req_objs), 2)
# 创建第一个micro batch:处理前半部分请求
micro_batch, run_reqs, padded_req_num = _padded_prepare_prefill_micro_batch(
req_objs[0:micro_batch1_req_num], is_multimodal=is_multimodal
)
# 创建第二个micro batch:处理后半部分请求
micro_batch1, run_reqs1, padded_req_num1 = _padded_prepare_prefill_micro_batch(
req_objs[micro_batch1_req_num:], is_multimodal=is_multimodal
)
return micro_batch, run_reqs, padded_req_num, micro_batch1, run_reqs1, padded_req_num1
def _padded_prepare_prefill_micro_batch(req_objs: List[InferReq], is_multimodal=False):
"""
为单个MicroBatch准备Prefill阶段的数据
Args:
req_objs: 分配给当前micro batch的请求列表
is_multimodal: 是否支持多模态输入
Returns:
tuple: (micro_batch, run_reqs, padded_req_num)
"""
# === 初始化数据收集变量 ===
run_reqs = [] # 实际运行的请求列表
nopad_total_token_num = 0 # 未padding前的总token数量
nopad_max_len_in_batch = 0 # 批次中的最大输入长度
input_ids = [] # 所有请求的input token ids
nopad_b_req_idx = [] # 请求索引列表
nopad_b_seq_len = [] # 每个请求的序列长度
# prefill阶段只需要padding一个请求形成micro_batch
# 并不需要两个micro batch的batch_size相同(与decode阶段不同)
padded_req_num = 1if len(req_objs) == 0else0
b_ready_cache_len = [] # 每个请求已缓存的KV长度
batch_multimodal_params = [] # 多模态参数列表
# === 处理每个真实请求 ===
for req in req_objs:
run_reqs.append(req)
batch_multimodal_params.append(req.multimodal_params)
nopad_b_req_idx.append(req.req_idx)
# 获取请求的完整input token序列
input_token_ids = req.get_chuncked_input_token_ids()
seq_len = len(input_token_ids)
# 计算需要处理的新token长度(总长度 - 已缓存长度)
input_token_len = seq_len - req.cur_kv_len
# 提取需要处理的新token部分
input_id = input_token_ids[req.cur_kv_len :]
# 收集批次统计信息
nopad_b_seq_len.append(seq_len)
input_ids.extend(input_id) # 将新token添加到批次中
nopad_total_token_num += seq_len # 累加总token数
nopad_max_len_in_batch = max(nopad_max_len_in_batch, input_token_len) # 更新最大长度
b_ready_cache_len.append(req.cur_kv_len) # 记录已缓存长度
# === 添加padding请求(如果需要)===
# 当没有真实请求时,添加一个fake请求进行padding
for _ in range(padded_req_num):
input_ids.append(1) # 添加一个dummy token
nopad_b_req_idx.append(g_infer_context.req_manager.HOLD_REQUEST_ID) # 使用占位请求ID
nopad_b_seq_len.append(1) # 序列长度为1
b_ready_cache_len.append(0) # 无已缓存内容
nopad_total_token_num += 1 # 更新总token数
nopad_max_len_in_batch = max(nopad_max_len_in_batch, 1) # 更新最大长度
# === 转换为CUDA张量 ===
# 将Python列表转换为GPU上的张量,提高计算效率
input_ids = torch.tensor(input_ids, dtype=torch.int64, device="cuda")
nopad_b_req_idx = torch.tensor(nopad_b_req_idx, dtype=torch.int32, device="cuda")
nopad_b_seq_len = torch.tensor(nopad_b_seq_len, dtype=torch.int32, device="cuda")
b_ready_cache_len = torch.tensor(b_ready_cache_len, dtype=torch.int32, device="cuda")
# === 动态内存管理 ===
# 获取全局推理状态锁,确保内存分配的线程安全
g_infer_state_lock.acquire()
# 如果启用了radix cache,先释放足够的缓存空间
if g_infer_context.radix_cache isnotNone:
g_infer_context.radix_cache.free_radix_cache_to_get_enough_token(
input_ids.shape[0] - padded_req_num # 只为真实token分配空间
)
# 为真实token分配KV cache内存索引
mem_indexes = g_infer_context.req_manager.mem_manager.alloc(
input_ids.shape[0] - padded_req_num
).cuda()
# 释放锁
g_infer_state_lock.release()
# === 处理padding token的内存索引 ===
if padded_req_num > 0:
# 为padding token创建占位内存索引
padding_indexs = torch.full(
(padded_req_num,),
fill_value=g_infer_context.req_manager.mem_manager.HOLD_TOKEN_MEMINDEX, # 占位索引值
dtype=torch.int32,
device="cuda",
)
# 将真实token和padding token的内存索引合并
mem_indexes = torch.cat((mem_indexes, padding_indexs), dim=0)
# === 创建PrefillMicroBatch对象 ===
micro_batch = PrefillMicroBatch(
batch_size=nopad_b_seq_len.shape[0], # 批次大小(请求数量)
total_token_num=nopad_total_token_num, # 总token数量
max_len_in_batch=nopad_max_len_in_batch, # 批次中最大输入长度
input_ids=input_ids, # 输入token序列
mem_indexes=mem_indexes, # KV cache内存索引
b_req_idx=nopad_b_req_idx, # 请求索引
b_seq_len=nopad_b_seq_len, # 序列长度
b_ready_cache_len=b_ready_cache_len, # 已缓存长度
multimodal_params=batch_multimodal_params, # 多模态参数
)
return micro_batch, run_reqs, padded_req_num
0x04 Two MicroBatch Overlap 核心执行流程
0x04.1 MicroBatch Overlap Decode实现