专栏名称: GiantPandaLLM
专注于机器学习、深度学习、计算机视觉、图像处理等多个方向技术分享。团队由一群热爱技术且热衷于分享的小伙伴组成。我们坚持原创,每天一到两篇原创技术分享。希望在传播知识、分享知识的同时能够启发你,大家一起共同进步(・ω<)☆
目录
相关文章推荐
GiantPandaLLM  ·  【博客转载】Row-Major VS ... ·  16 小时前  
GiantPandaLLM  ·  【博客转载】CUDA Coalesced ... ·  2 天前  
GiantPandaLLM  ·  【博客转载】C++/CUDA Data ... ·  4 天前  
GiantPandaLLM  ·  【博客转载】CUDA Kernel ... ·  4 天前  
51好读  ›  专栏  ›  GiantPandaLLM

SGLang的Expert Parallel特性解读

GiantPandaLLM  · 公众号  · 3D  · 2025-01-11 22:14

主要观点总结

SGLang实现了Expert Parallel(EPMoE),这是其率先在开源推理框架中实现的。SGLang通过修改上层接口和底层实现,特别是利用GroupedGemmRunner类进行矩阵乘法,并实现了EPMoE类和其Forward方法,实现了类似EP MoE训练时的步骤。SGLang EP MoE Kernel通过预重排序和两次Group GEMM,以及两次重排序,最终得到最终输出。EPMoE和MoE EP训练流程的区别在于,EPMoE在推理时通过优化All2All流程来降低通信成本。SGLang EPMoE计算流程中最耗时的Group GEMM尚未使用FalshInfer的优化版本,因此可能效率不高。

关键观点总结

关键观点1: SGLang实现Expert Parallel(EPMoE)

SGLang是开源推理框架中率先实现EPMoE的。

关键观点2: 上层接口和底层实现修改

通过修改上层接口和底层实现,特别是利用GroupedGemmRunner类进行矩阵乘法,并实现了EPMoE类和其Forward方法。

关键观点3: EPMoE计算流程

通过预重排序和两次Group GEMM,以及两次重排序,得到最终输出。

关键观点4: EPMoE和MoE EP训练流程的区别

EPMoE在推理时通过优化All2All流程来降低通信成本。

关键观点5: EPMoE效率问题

SGLang EPMoE计算流程中最耗时的Group GEMM尚未使用FalshInfer的优化版本,可能效率不高。


正文

请到「今天看啥」查看全文


EPMoE类的定义
class EPMoE(torch.nn.Module):
    """
    MoE专家并行实现
    
    Args:
        num_experts: 专家总数
        top_k: 每个token选择的专家数量
        hidden_size: 隐藏层大小
        intermediate_size: 中间层大小
        params_dtype: 参数数据类型,默认为None使用系统默认类型
        renormalize: 是否重新归一化,默认True
        use_grouped_topk: 是否使用分组topk,默认False
        num_expert_group: 专家组数量,仅在use_grouped_topk=True时使用
        topk_group: 每组选择的专家数量,仅在use_grouped_topk=True时使用
        quant_config: 量化配置,默认None
        tp_size: 张量并行大小,默认None
        prefix: 前缀,默认空字符串
        correction_bias: 修正偏置,默认None
    """


    def __init__(
        self,
        num_experts: int,
        top_k: int,
        hidden_size: int,
        intermediate_size: int,
        params_dtype: Optional[torch.dtype] = None,
        renormalize: bool = True,
        use_grouped_topk: bool = False,
        num_expert_group: Optional[int] = None,
        topk_group: Optional[int] = None,
        quant_config: Optional[QuantizationConfig] = None,
        tp_size: Optional[int] = None,
        prefix: str = "",
        correction_bias: Optional[torch.Tensor] = None,
    )
:

        super().__init__()

        # 如果未指定参数类型,使用系统默认类型
        if params_dtype is None:
            params_dtype = torch.get_default_dtype()

        # 设置张量并行相关参数
        self.tp_size = (
            tp_size if tp_size is not None else get_tensor_model_parallel_world_size()
        )
        self.tp_rank = get_tensor_model_parallel_rank()

        # 设置专家相关参数
        self.num_experts = num_experts
        assert self.num_experts % self.tp_size == 0  # 确保专家数可以被tp_size整除
        self.num_experts_per_partition = self.num_experts // self.tp_size  # 每个分区的专家数
        self.start_expert_id = self.tp_rank * self.num_experts_per_partition  # 当前分区起始专家ID
        self.end_expert_id = self.start_expert_id + self.num_experts_per_partition - 1  # 当前分区结束专家ID

        # 设置其他参数
        self.top_k = top_k
        self.intermediate_size = intermediate_size
        self.renormalize = renormalize
        self.use_grouped_topk = use_grouped_topk
        if self.use_grouped_topk:
            assert num_expert_group is not None and topk_group is not None
        self.num_expert_group = num_expert_group
        self.topk_group = topk_group
        self.correction_bias = correction_bias

        # 设置量化方法
        if quant_config is None:
            self.quant_method: Optional[QuantizeMethodBase] = UnquantizedEPMoEMethod()
            self.use_fp8_w8a8 = False
            self.activation_scheme = None
        else:
            self.quant_method: Optional[QuantizeMethodBase] = Fp8EPMoEMethod(
                quant_config
            )
            self.use_fp8_w8a8 = True
            self.fp8_dtype = torch.float8_e4m3fn
            self.activation_scheme = quant_config.activation_scheme

        # 创建权重
        self.quant_method.create_weights(
            layer=self,
            num_experts_per_partition=self.num_experts_per_partition,
            hidden_size=hidden_size,
            intermediate_size=self.intermediate_size,
            params_dtype=params_dtype,
            weight_loader=self.weight_loader,
        )

        # 初始化分组矩阵乘法运行器
        self.grouped_gemm_runner = None

这个类定义中我们可以看到它主要是做一些准备工作,同时EPMoE复用了Tensor Parallel的进程组,所以也是直接在Tensor Parallel进程组上获取当前Rank需要处理的是哪些Expert ID。

EPMoE 类的 Forward

简单添加几行注释:

def forward(self, hidden_states: torch.Tensor, router_logits: torch.Tensor):
        """前向传播函数
        Args:
            hidden_states: 输入的隐藏状态张量
            router_logits: 路由器输出的logits张量
        Returns:
            output: 经过MoE层处理后的输出张量
        """

        assert self.quant_method is not None

        # 初始化分组矩阵乘法运行器
        if self.grouped_gemm_runner is None:
            self.grouped_gemm_runner = GroupedGemmRunner(
                hidden_states.device, use_flashinfer=False  TODO: use flashinfer
            )

        # 选择专家,获取topk权重和ID
        topk_weights, topk_ids = select_experts(
            hidden_states=hidden_states,
            router_logits=router_logits,
            top_k=self.top_k,
            use_grouped_topk=self.use_grouped_topk,
            renormalize=self.renormalize,
            topk_group=self.topk_group,
            num_expert_group=self.num_expert_group,
            correction_bias=self.correction_bias,
        )

        # 预处理topk ID,获取重排序信息
        reorder_topk_ids, src2dst, seg_indptr = run_moe_ep_preproess(
            topk_ids, self.num_experts
        )

        # 初始化门控输入张量
        gateup_input = torch.empty(
            (int(hidden_states.shape[0] * self.top_k), hidden_states.shape[1]),
            device=hidden_states.device,
            dtype=self.fp8_dtype if self.use_fp8_w8a8 else hidden_states.dtype,
        )
        
        # 动态量化时计算输入缩放因子
        if self.activation_scheme == "dynamic":
            max_value = (
                torch.max(hidden_states)
                .repeat(self.num_experts_per_partition)
                .to(torch.float32)
            )
            self.w13_input_scale = max_value / torch.finfo(self.fp8_dtype).max

        # 预重排序,重新排列输入数据
        pre_reorder_triton_kernel[(hidden_states.shape[0],)](
            hidden_states,
            gateup_input,
            src2dst,
            topk_ids,
            self.w13_input_scale,
            self.start_expert_id,
            self.end_expert_id,
            self.top_k,
            hidden_states.shape[1],
            BLOCK_SIZE=512,
        )

        # 获取当前rank的分段指针和权重索引
        seg_indptr_cur_rank = seg_indptr[self.start_expert_id : self.end_expert_id + 2]
        weight_indices_cur_rank = torch.arange(
            0,
            self.num_experts_per_partition,
            device=hidden_states.device,
            dtype=torch.int64,
        )
        
        # 第一次分组矩阵乘法
        gateup_output = torch.empty(
            gateup_input.shape[0],
            self.w13_weight.shape[1],
            device=hidden_states.device,
            dtype=hidden_states.dtype,
        )
        gateup_output = self.grouped_gemm_runner(
            a=gateup_input,
            b=self.w13_weight,
            c=gateup_output,
            batch_size=self.num_experts_per_partition,
            weight_column_major=True,
            seg_indptr=seg_indptr_cur_rank,
            weight_indices=weight_indices_cur_rank,
            use_fp8_w8a8=self.use_fp8_w8a8,
            scale_a=self.w13_input_scale,
            scale_b=self.w13_weight_scale,
        )

        # 激活函数处理
        down_input = torch.empty(
            gateup_output.shape[0],
            gateup_output.shape[1] // 2,
            device=gateup_output.device,
            dtype=self.fp8_dtype if self.use_fp8_w8a8 else hidden_states.dtype,
        )
        if self.w2_input_scale is None:
            self.w2_input_scale = torch.ones(
                self.num_experts_per_partition,
                dtype=torch.float32,
                device=hidden_states.device,
            )
        silu_and_mul_triton_kernel[(gateup_output.shape[0],)](
            gateup_output,
            down_input,
            gateup_output.shape[1],
            reorder_topk_ids,
            self.w2_input_scale,
            self.start_expert_id,
            self.end_expert_id,
            BLOCK_SIZE=512,
        )

        # 第二次分组矩阵乘法
        down_output = torch.empty(
            down_input.shape[0],
            self.w2_weight.shape[1],
            device=hidden_states.device,
            dtype=hidden_states.dtype,
        )
        down_output = self.grouped_gemm_runner(
            a=down_input,
            b=self.w2_weight,
            c=down_output,
            batch_size=self.num_experts_per_partition,
            weight_column_major=True,
            seg_indptr=seg_indptr_cur_rank,
            weight_indices=weight_indices_cur_rank,
            use_fp8_w8a8=self.use_fp8_w8a8,
            scale_a=self.w2_input_scale,
            scale_b=self.w2_weight_scale,
        )

        # 后重排序,生成最终输出
        output = torch.empty_like(hidden_states)
        post_reorder_triton_kernel[(hidden_states.size(0),)](
            down_output,
            output,
            src2dst,
            topk_ids,
            topk_weights,
            self.start_expert_id,
            self.end_expert_id,
            self.top_k,
            hidden_states.size(1),
            BLOCK_SIZE=512,
        )
        return output

这个forward函数的流程还是比较清晰的:

  • 首先根据router_logits选择每个token要使用的top-k个专家及其权重
  • 对输入数据进行预处理和重排序,将相同专家的数据分组在一起以便后续批量计算
  • 执行第一次分组矩阵乘法(grouped gemm),将输入与gate和up投影权重(w13_weight)相乘
  • 对第一次矩阵乘法的结果应用SiLU激活函数并进行处理
  • 执行第二次分组矩阵乘法,将激活后的结果与down投影权重(w2_weight)相乘
  • 最后进行后重排序,将各个专家的输出按原始token顺序重组,并根据专家权重进行加权组合得到最终输出

这个过程基本上和EP MoE训练时的步骤一致,其中第二步和最后一步就对应了EP中的两次All2All。

权重加载逻辑

笔者注 :对于本篇文章的主题来说,可以不用在意这几个工具函数。







请到「今天看啥」查看全文