主要观点总结
SGLang实现了Expert Parallel(EPMoE),这是其率先在开源推理框架中实现的。SGLang通过修改上层接口和底层实现,特别是利用GroupedGemmRunner类进行矩阵乘法,并实现了EPMoE类和其Forward方法,实现了类似EP MoE训练时的步骤。SGLang EP MoE Kernel通过预重排序和两次Group GEMM,以及两次重排序,最终得到最终输出。EPMoE和MoE EP训练流程的区别在于,EPMoE在推理时通过优化All2All流程来降低通信成本。SGLang EPMoE计算流程中最耗时的Group GEMM尚未使用FalshInfer的优化版本,因此可能效率不高。
关键观点总结
关键观点1: SGLang实现Expert Parallel(EPMoE)
SGLang是开源推理框架中率先实现EPMoE的。
关键观点2: 上层接口和底层实现修改
通过修改上层接口和底层实现,特别是利用GroupedGemmRunner类进行矩阵乘法,并实现了EPMoE类和其Forward方法。
关键观点3: EPMoE计算流程
通过预重排序和两次Group GEMM,以及两次重排序,得到最终输出。
关键观点4: EPMoE和MoE EP训练流程的区别
EPMoE在推理时通过优化All2All流程来降低通信成本。
关键观点5: EPMoE效率问题
SGLang EPMoE计算流程中最耗时的Group GEMM尚未使用FalshInfer的优化版本,可能效率不高。
正文
EPMoE类的定义
class EPMoE(torch.nn.Module):
"""
MoE专家并行实现
Args:
num_experts: 专家总数
top_k: 每个token选择的专家数量
hidden_size: 隐藏层大小
intermediate_size: 中间层大小
params_dtype: 参数数据类型,默认为None使用系统默认类型
renormalize: 是否重新归一化,默认True
use_grouped_topk: 是否使用分组topk,默认False
num_expert_group: 专家组数量,仅在use_grouped_topk=True时使用
topk_group: 每组选择的专家数量,仅在use_grouped_topk=True时使用
quant_config: 量化配置,默认None
tp_size: 张量并行大小,默认None
prefix: 前缀,默认空字符串
correction_bias: 修正偏置,默认None
"""
def __init__(
self,
num_experts: int,
top_k: int,
hidden_size: int,
intermediate_size: int,
params_dtype: Optional[torch.dtype] = None,
renormalize: bool = True,
use_grouped_topk: bool = False,
num_expert_group: Optional[int] = None,
topk_group: Optional[int] = None,
quant_config: Optional[QuantizationConfig] = None,
tp_size: Optional[int] = None,
prefix: str = "",
correction_bias: Optional[torch.Tensor] = None,
):
super().__init__()
# 如果未指定参数类型,使用系统默认类型
if params_dtype is None:
params_dtype = torch.get_default_dtype()
# 设置张量并行相关参数
self.tp_size = (
tp_size if tp_size is not None else get_tensor_model_parallel_world_size()
)
self.tp_rank = get_tensor_model_parallel_rank()
# 设置专家相关参数
self.num_experts = num_experts
assert self.num_experts % self.tp_size == 0 # 确保专家数可以被tp_size整除
self.num_experts_per_partition = self.num_experts // self.tp_size # 每个分区的专家数
self.start_expert_id = self.tp_rank * self.num_experts_per_partition # 当前分区起始专家ID
self.end_expert_id = self.start_expert_id + self.num_experts_per_partition - 1 # 当前分区结束专家ID
# 设置其他参数
self.top_k = top_k
self.intermediate_size = intermediate_size
self.renormalize = renormalize
self.use_grouped_topk = use_grouped_topk
if self.use_grouped_topk:
assert num_expert_group is not None and topk_group is not None
self.num_expert_group = num_expert_group
self.topk_group = topk_group
self.correction_bias = correction_bias
# 设置量化方法
if quant_config is None:
self.quant_method: Optional[QuantizeMethodBase] = UnquantizedEPMoEMethod()
self.use_fp8_w8a8 = False
self.activation_scheme = None
else:
self.quant_method: Optional[QuantizeMethodBase] = Fp8EPMoEMethod(
quant_config
)
self.use_fp8_w8a8 = True
self.fp8_dtype = torch.float8_e4m3fn
self.activation_scheme = quant_config.activation_scheme
# 创建权重
self.quant_method.create_weights(
layer=self,
num_experts_per_partition=self.num_experts_per_partition,
hidden_size=hidden_size,
intermediate_size=self.intermediate_size,
params_dtype=params_dtype,
weight_loader=self.weight_loader,
)
# 初始化分组矩阵乘法运行器
self.grouped_gemm_runner = None
这个类定义中我们可以看到它主要是做一些准备工作,同时EPMoE复用了Tensor Parallel的进程组,所以也是直接在Tensor Parallel进程组上获取当前Rank需要处理的是哪些Expert ID。
EPMoE 类的 Forward
简单添加几行注释:
def forward(self, hidden_states: torch.Tensor, router_logits: torch.Tensor):
"""前向传播函数
Args:
hidden_states: 输入的隐藏状态张量
router_logits: 路由器输出的logits张量
Returns:
output: 经过MoE层处理后的输出张量
"""
assert self.quant_method is not None
# 初始化分组矩阵乘法运行器
if self.grouped_gemm_runner is None:
self.grouped_gemm_runner = GroupedGemmRunner(
hidden_states.device, use_flashinfer=False # TODO: use flashinfer
)
# 选择专家,获取topk权重和ID
topk_weights, topk_ids = select_experts(
hidden_states=hidden_states,
router_logits=router_logits,
top_k=self.top_k,
use_grouped_topk=self.use_grouped_topk,
renormalize=self.renormalize,
topk_group=self.topk_group,
num_expert_group=self.num_expert_group,
correction_bias=self.correction_bias,
)
# 预处理topk ID,获取重排序信息
reorder_topk_ids, src2dst, seg_indptr = run_moe_ep_preproess(
topk_ids, self.num_experts
)
# 初始化门控输入张量
gateup_input = torch.empty(
(int(hidden_states.shape[0] * self.top_k), hidden_states.shape[1]),
device=hidden_states.device,
dtype=self.fp8_dtype if self.use_fp8_w8a8 else hidden_states.dtype,
)
# 动态量化时计算输入缩放因子
if self.activation_scheme == "dynamic":
max_value = (
torch.max(hidden_states)
.repeat(self.num_experts_per_partition)
.to(torch.float32)
)
self.w13_input_scale = max_value / torch.finfo(self.fp8_dtype).max
# 预重排序,重新排列输入数据
pre_reorder_triton_kernel[(hidden_states.shape[0],)](
hidden_states,
gateup_input,
src2dst,
topk_ids,
self.w13_input_scale,
self.start_expert_id,
self.end_expert_id,
self.top_k,
hidden_states.shape[1],
BLOCK_SIZE=512,
)
# 获取当前rank的分段指针和权重索引
seg_indptr_cur_rank = seg_indptr[self.start_expert_id : self.end_expert_id + 2]
weight_indices_cur_rank = torch.arange(
0,
self.num_experts_per_partition,
device=hidden_states.device,
dtype=torch.int64,
)
# 第一次分组矩阵乘法
gateup_output = torch.empty(
gateup_input.shape[0],
self.w13_weight.shape[1],
device=hidden_states.device,
dtype=hidden_states.dtype,
)
gateup_output = self.grouped_gemm_runner(
a=gateup_input,
b=self.w13_weight,
c=gateup_output,
batch_size=self.num_experts_per_partition,
weight_column_major=True,
seg_indptr=seg_indptr_cur_rank,
weight_indices=weight_indices_cur_rank,
use_fp8_w8a8=self.use_fp8_w8a8,
scale_a=self.w13_input_scale,
scale_b=self.w13_weight_scale,
)
# 激活函数处理
down_input = torch.empty(
gateup_output.shape[0],
gateup_output.shape[1] // 2,
device=gateup_output.device,
dtype=self.fp8_dtype if self.use_fp8_w8a8 else hidden_states.dtype,
)
if self.w2_input_scale is None:
self.w2_input_scale = torch.ones(
self.num_experts_per_partition,
dtype=torch.float32,
device=hidden_states.device,
)
silu_and_mul_triton_kernel[(gateup_output.shape[0],)](
gateup_output,
down_input,
gateup_output.shape[1],
reorder_topk_ids,
self.w2_input_scale,
self.start_expert_id,
self.end_expert_id,
BLOCK_SIZE=512,
)
# 第二次分组矩阵乘法
down_output = torch.empty(
down_input.shape[0],
self.w2_weight.shape[1],
device=hidden_states.device,
dtype=hidden_states.dtype,
)
down_output = self.grouped_gemm_runner(
a=down_input,
b=self.w2_weight,
c=down_output,
batch_size=self.num_experts_per_partition,
weight_column_major=True,
seg_indptr=seg_indptr_cur_rank,
weight_indices=weight_indices_cur_rank,
use_fp8_w8a8=self.use_fp8_w8a8,
scale_a=self.w2_input_scale,
scale_b=self.w2_weight_scale,
)
# 后重排序,生成最终输出
output = torch.empty_like(hidden_states)
post_reorder_triton_kernel[(hidden_states.size(0),)](
down_output,
output,
src2dst,
topk_ids,
topk_weights,
self.start_expert_id,
self.end_expert_id,
self.top_k,
hidden_states.size(1),
BLOCK_SIZE=512,
)
return output
这个forward函数的流程还是比较清晰的:
-
首先根据router_logits选择每个token要使用的top-k个专家及其权重
-
对输入数据进行预处理和重排序,将相同专家的数据分组在一起以便后续批量计算
-
执行第一次分组矩阵乘法(grouped gemm),将输入与gate和up投影权重(w13_weight)相乘
-
对第一次矩阵乘法的结果应用SiLU激活函数并进行处理
-
执行第二次分组矩阵乘法,将激活后的结果与down投影权重(w2_weight)相乘
-
最后进行后重排序,将各个专家的输出按原始token顺序重组,并根据专家权重进行加权组合得到最终输出
这个过程基本上和EP MoE训练时的步骤一致,其中第二步和最后一步就对应了EP中的两次All2All。
权重加载逻辑
笔者注
:对于本篇文章的主题来说,可以不用在意这几个工具函数。