专栏名称: GiantPandaLLM
专注于机器学习、深度学习、计算机视觉、图像处理等多个方向技术分享。团队由一群热爱技术且热衷于分享的小伙伴组成。我们坚持原创,每天一到两篇原创技术分享。希望在传播知识、分享知识的同时能够启发你,大家一起共同进步(・ω<)☆
目录
相关文章推荐
GiantPandaLLM  ·  【博客转载】Row-Major VS ... ·  昨天  
51好读  ›  专栏  ›  GiantPandaLLM

基于o1-preview解读 Optimized GPTQ INT4 Dequantization ...

GiantPandaLLM  · 公众号  · 3D  · 2024-09-16 22:00

正文

请到「今天看啥」查看全文



粘贴一下o1-preview-128k给出的回复:

# 导入 PyTorch 库
import torch
# 导入 Triton 库
import triton
# 从 Triton 中导入语言模块,简写为 tl
from triton import language as tl

# 使用 @triton.jit 装饰器,将函数编译为 Triton 内核
@triton.jit()
def swizzle_tile(pid,
                 m, n,
                 block_m: tl.constexpr, block_n: tl.constexpr, group_m: tl.constexpr)
:

    # 计算 M 方向上的网格大小,向上取整
    grid_m = tl.cdiv(m, block_m)
    # 计算 N 方向上的网格大小,向上取整
    grid_n = tl.cdiv(n, block_n)

    # 计算每个组的宽度,即组内程序数
    width = group_m * grid_n
    # 计算当前组的 ID,pid 为程序的全局 ID
    group_id = pid // width
    # 计算当前组的大小,防止越界
    group_size = tl.minimum(grid_m - group_id * group_m, group_m)

    # 计算当前程序在 M 维度上的块索引 pid_m
    pid_m = group_id * group_m + (pid % group_size)
    # 计算当前程序在 N 维度上的块索引 pid_n
    pid_n = (pid % width) // group_size

    # 返回计算的 pid_m 和 pid_n,用于后续计算
    return pid_m, pid_n

# 定义矩阵乘法的 Triton 内核,支持 K 维度的分块计算(Split-K)
@triton.jit()
def matmul_split_k_kernel(a_ptr, b_ptr, c_ptr, scales_ptr, zeros_ptr,
                          stride_am, stride_ak,
                          stride_bk, stride_bn,
                          stride_cm, stride_cn,
                          stride_scales_g, stride_scales_n,
                          stride_zeros_g, stride_zeros_n,
                          groupsize,
                          m, n, k,
                          block_m: tl.constexpr, block_n: tl.constexpr, block_k: tl.constexpr,
                          group_m: tl.constexpr, split_k: tl.constexpr)
:

    # 获取当前程序的 ID,在第一个维度(M*N 维度)
    pid = tl.program_id(0)
    # 获取在 K 维度上的程序 ID
    pid_k = tl.program_id(1)
    # 计算 K 维度上总的块数,向上取整
    total_blocks_k = tl.cdiv(k, block_k * split_k)

    # 使用自定义的 swizzle_tile 函数计算当前程序对应的块索引
    pid_m, pid_n = swizzle_tile(pid,
                                m, n,
                                block_m, block_n, group_m)

    # 计算当前程序在 M、N、K 维度上的元素偏移
    offs_m = pid_m * block_m + tl.arange(0, block_m)
    offs_n = pid_n * block_n + tl.arange(0, block_n)
    offs_k = pid_k * block_k + tl.arange(0, block_k)

    # 确保 offs_m 和 offs_n 的连续性和对齐
    offs_am = tl.max_contiguous(tl.multiple_of(offs_m, block_m), block_m)
    offs_bn = tl.max_contiguous(tl.multiple_of(offs_n, block_n), block_n)

    # 计算矩阵 A 和矩阵 B 中当前块的指针
    a_ptrs = a_ptr + (offs_am[:, None] * stride_am + offs_k[None, :] * stride_ak)
    b_ptrs = b_ptr + ((offs_k[:, None] // 8) * stride_bk + offs_bn[None, :] * stride_bn)

    # 计算 scales 和 zeros 的指针
    scales_ptrs = scales_ptr + offs_bn * stride_scales_n
    zeros_ptrs = zeros_ptr + ((offs_bn // 8) * stride_zeros_n)

    # 计算移位量,用于从压缩的表示中提取实际的值
    shifter = (offs_k % 8) * 4
    zeros_shifter = (offs_bn % 8) * 4

    # 初始化累加器为 0,形状为 (block_m, block_n),数据类型为 float32
    acc = tl.zeros((block_m, block_n), dtype=tl.float32)
    # 遍历 K 维度上的所有块
    for k in range(0, total_blocks_k):
        # 从全局内存中加载矩阵 A 和矩阵 B 的当前块
        a = tl.load(a_ptrs)
        b = tl.load(b_ptrs)

        # 计算当前分组的 ID,用于获取对应的 scales 和 zeros
        g_id = (k * split_k + pid_k) // (groupsize // block_k)

        # 加载对应的 scales
        ptr = scales_ptrs + g_id * stride_scales_g
        scales = tl.load(ptr)

        # 加载对应的 zeros
        ptr = zeros_ptrs + g_id * stride_zeros_g
        zeros = tl.load(ptr)






请到「今天看啥」查看全文