正文
"=r"
(weight_fp16.x) :
"r"
(weight_fp16.x),
"r"
(zeros.x))
;
asm
volatile
(
"mul.rn.f16x2 %0, %1, %2;\n"
:
"=r"
(weight_fp16.x) :
"r"
(weight_fp16.x),
"r"
(loaded_scale.x))
;
// 处理第二对fp16值
asm
volatile
(
"sub.f16x2 %0, %1, %2;\n"
:
"=r"
(weight_fp16.y) :
"r"
(weight_fp16.y),
"r"
(zeros.y))
;
asm
volatile
(
"mul.rn.f16x2 %0, %1, %2;\n"
:
"=r"
(weight_fp16.y) :
"r"
(weight_fp16.y),
"r"
(loaded_scale.y))
;
// 处理第三对fp16值
asm
volatile
(
"sub.f16x2 %0, %1, %2;\n"
:
"=r"
(weight_fp16.z) :
"r"
(weight_fp16.z),
"r"
(zeros.z))
;
asm
volatile
(
"mul.rn.f16x2 %0, %1, %2;\n"
:
"=r"
(weight_fp16.z) :
"r"
(weight_fp16.z),
"r"
(loaded_scale.z))
;
// 处理第四对fp16值
asm
volatile
(
"sub.f16x2 %0, %1, %2;\n"
:
"=r"
(weight_fp16.w) :
"r"
(weight_fp16.w),
"r"
(zeros.w))
;
asm
volatile
(
"mul.rn.f16x2 %0, %1, %2;\n"
:
"=r"
(weight_fp16.w) :
"r"
(weight_fp16.w),
"r"
(loaded_scale.w))
;
// 计算输出指针位置并存储结果
half* output_ptr = output +
8
* col +
8
* row * qweight_cols;
*(uint4*)output_ptr = weight_fp16;
}
这里整体是非常好理解的,我们根据线程id定位到当前线程处理的列和行索引之后分别加载零点zeros,缩放系数loaded_scale和权重weight_fp16并对zeros/weight_fp16应用
dequantize_s4_to_fp16x2
反量化kernel把当前行列所在的int32类型的值(8个int4)反量化为8个half类型的输出值,注意这里是用4个half2来存储的。然后使用
(weight - zero) * scale
操作来完成反量化的过程。
这里解析一个
asm volatile("sub.f16x2 %0, %1, %2;\n" : "=r"(weight_fp16.x) : "r"(weight_fp16.x), "r"(zeros.x));
指令:
这行代码使用了CUDA PTX,用于执行半精度浮点数(fp16)的减法操作。它的基本语法为:
asm [volatile] ("汇编指令" : 输出操作数 : 输入操作数 : 可能被修改的寄存器);
下面是详细解析:
-
-
-
volatile
修饰符告诉编译器不要优化或重排这段汇编代码,确保它按照指定的顺序执行
-
sub.f16x2 %0, %1, %2;\n
:
-
-
sub.f16x2
是CUDA的指令,表示对两个并排的fp16值(packed half2)执行减法操作
-
%0, %1, %2
是占位符,分别对应后面定义的输出和输入操作数
-
-
: "=r"(weight_fp16.x) : "r"(weight_fp16.x), "r"(zeros.x));
-
第一个冒号后的
"=r"(weight_fp16.x)
是输出操作数,=r 表示这是一个输出到通用寄存器的值
-
第二个冒号后的
"r"(weight_fp16.x)
和
"r"(zeros.x))
是两个输入操作数,r 表示它们来自通用寄存器
通过这个指令就实现了反量化中的减零点的功能,kernel中其它的ptx指令类推。
0x3. dequantize_s4_to_fp16x2 kernel(魔法发生的地方)
这段代码对应的原理在nvidia 2023年夏日专场其实简单讲了一下,我这里结合当时的PPT复述一下这里的原理,通过这个复述读者稍后就可以知道代码中的那一堆魔术和用于计算的PTX指令是做了什么了。注意下面引用的图来BiliBili NVIDIA英伟达频道 上传的《TensorRT-LLM中的 Quantization GEMM(Ampere Mixed GEMM)的 CUTLASS 2.x 实现讲解》。
FasterTransformer 高效的Int8/Int4 快速Convert为FP16
这张slides展示了FP16的IEEE 754标准,一个16bit的数里面包含1个符号位,5个基码位,10个尾数。
假设我们有一个uint8的数143,如果我们把它放到实际的FP16的尾数位里面去,那么我们是否有办法通过合理的设置基码位把143表达出来呢?那我们按照已知的FP16的数值计算方法,拿基码位的二进制前面加上一个1.x,然后去乘以2的(基码位的值-15)次方,我们已知143对应的实际上对应的是下面的值。假设我们想用这个FP16的值来表达Int8,我们可以发现如果x=25的话,我们把上面的FP16的值减去1024就是下面的143了。因此,我们只需要把int8的值放到尾数位,然后把它的基码位设置成25,然后再把FP16的数值结果减去1024就可以得到UINT8转换到FP16的值。
总结一下就是直接把UINT8的数值放在FP16的尾数位,