主要观点总结
文章主要介绍了作者如何使用CUDA对MiniMax的Lightning Attention模块进行优化,并展示了使用Cursor+NCU进行CUDA优化的过程。文章详细说明了在GPU上加速RWKV6模型的Linear Attention计算,硬件高效的线性注意力机制Gated Linear Attention论文阅读,以及如何实现MiniMaxText01LightningAttention模块。通过对比不同的实现方式,如Naive版本、Triton优化版本和CUDA版本,并利用NCU工具进行性能分析,找出了性能瓶颈并进行了优化。文章最后总结了使用Cursor Claude-sonnet-3.5-2024102进行CUDA优化的限制,并强调了人工反馈的重要性,不推荐直接使用AI生成的优化代码。
关键观点总结
关键观点1: 使用CUDA对MiniMax的Lightning Attention模块进行优化
文章通过对比不同的实现方式,如Naive版本、Triton优化版本和CUDA版本,展示了如何使用CUDA进行优化。
关键观点2: 使用Cursor+NCU进行CUDA优化
文章详细介绍了如何使用Cursor和NCU工具进行性能分析,找出了性能瓶颈并进行了优化。
关键观点3: 关于Linear Attention架构的算法原理和做推理的优势
文章提到了硬件高效的线性注意力机制Gated Linear Attention论文阅读,并参考了之前的blog。
关键观点4: 实现MiniMaxText01LightningAttention模块
文章说明了如何在SGLang推理框架中支持MiniMax Text01模型,并建立了Prefill和Decode过程的优化算子和Benchmark。
关键观点5: 关于使用Cursor Claude-sonnet-3.5-2024102进行CUDA优化的限制
文章最后总结了使用Cursor Claude-sonnet-3.5-2024102这种最先进的大模型进行CUDA优化的限制,并强调了人工反馈的重要性。
正文
详细数据可以参考 https://github.com/sgl-project/sglang/pull/3030
最后,这个kernel还有非常大的可提升空间,不过这不是本文重点,本文的重点是我将在下一节演示一下我是如何使用Cursor+NCU来联合优化CUDA Kernel的,如果你想在Cursor中使用最先进的Claude-3.5-sonnet-20241022来直接给你写出性能不错的CUDA kernel,根据我的使用记录来看是非常困难的。大模型既不会给你避免Bank Conflict,也不会给你合并内存访问,并且大多数时候还会给你写出效率非常低的Python直译cuda代码。然而Cursor下的Claude-3.5-sonnet-2024102有多模态功能是可以看懂图片的,所以我们可以把NCU的一些关键Profile信息给他,手工强化学习,我稍后会演示如何利用NCU的结果让Cursor更聪明,从而写出我们想要的优化代码。
0x1. 实操版
0x1.1 Triton naive版本
kernel代码:https://github.com/sgl-project/sglang/pull/2920/files#diff-16ed66afc4b7f52545a3fffd55c9fd6daaf87189d9a0d252fccba42951c1cc40R14-R105
首先是一个最Naive的版本,对于q,k,v的每个头使用一个Block来计算,也就是一共有
个Block,然后每个头的维度都从92 padding到128来满足Triton kernel的计算需求。
从上面的性能结果来看,和原始的PyTorch实现几乎没有区别。
0x1.2 Triton 优化版本
https://github.com/sgl-project/sglang/pull/2966
把上面的naive版本的Triton kernel之前的手动Padding到128移除了,然后在kernel中使用Mask的方式来解决dim维度没有对齐到2的幂次的问题。从上面的结果可以看到,Lightning Attention模块的端到端时间确实是下降了一些。
0x1.3 CUDA 版本
把上面那几行 Lighting Attention Decode Python代码直接扔给Cursor Sonnet 3.5 20241022模型,然后它很快就产生了一份cuda kernel。
#define THREADS_PER_BLOCK 128
template<typename T>
__global__ void lightning_attention_decode_kernel(
const T* __restrict__ q, // [b, h, 1, d]
const T* __restrict__ k, // [b, h, 1, d]
const T* __restrict__ v, // [b, h, 1, e]
const float* __restrict__ past_kv, // [b, h, d, e]
const float* __restrict__ slope, // [h, 1, 1]
T* __restrict__ output, // [b, h, 1, e]