专栏名称: GiantPandaLLM
专注于机器学习、深度学习、计算机视觉、图像处理等多个方向技术分享。团队由一群热爱技术且热衷于分享的小伙伴组成。我们坚持原创,每天一到两篇原创技术分享。希望在传播知识、分享知识的同时能够启发你,大家一起共同进步(・ω<)☆
目录
相关文章推荐
GiantPandaLLM  ·  【博客转载】Row-Major VS ... ·  9 小时前  
GiantPandaLLM  ·  【博客转载】CUDA Coalesced ... ·  2 天前  
GiantPandaLLM  ·  【博客转载】C++/CUDA Data ... ·  3 天前  
GiantPandaLLM  ·  【博客转载】CUDA Kernel ... ·  4 天前  
51好读  ›  专栏  ›  GiantPandaLLM

Triton Kernel 编译阶段

GiantPandaLLM  · 公众号  · 3D  · 2024-12-31 23:17

正文

请到「今天看啥」查看全文




def add (x: torch.Tensor, y: torch.Tensor) :
output = torch.empty_like(x)
assert x.is_cuda and y.is_cuda and output.is_cuda
n_elements = output.numel()

grid = lambda meta: (triton.cdiv(n_elements, meta[ 'BLOCK_SIZE' ]), )
triton_kernel=add_kernel[grid](x, y, output, n_elements, BLOCK_SIZE= 1024 )
torch.cuda.synchronize()

# Save compilation stages - some of the stages identified here are specific to NVIDIA devices:
with open( 'triton_IR.txt' , 'w' ) as f:
print(triton_kernel.asm[ 'ttir' ], file=f)
with open( 'triton_TTGIR.txt' , 'w' ) as f:
print(triton_kernel.asm[ 'ttgir' ], file=f)
with open( 'triton_LLVMIR.txt' , 'w' ) as f:
print(triton_kernel.asm[ 'llir' ], file=f)
with open( 'triton_PTX.ptx' , 'w' ) as f:
print(triton_kernel.asm[ 'ptx' ], file=f)
with open( 'triton_cubin.txt' , 'w' ) as f:
print(triton_kernel.asm[ 'cubin' ], file=f)

return output

torch.manual_seed( 0 )
size = 98432
x = torch.rand(size, device= 'cuda' )
y = torch.rand(size, device= 'cuda' )
output_torch = x + y
output_triton = add(x, y)
print(output_torch)
print(output_triton)
print( f'The maximum difference between torch and triton is '
f' {torch.max(torch.abs(output_torch - output_triton))} ' )

Triton 向量加法kernel包含 @triton.jit 装饰器。Triton 编译器会编译被 @triton.jit 标记的函数,通过多个编译阶段将函数降级。辅助函数 add 分配输出张量,计算适当的 GPU 网格大小,并额外保存中间编译阶段。

聚焦于编译过程,Triton 内核通过以下图中所示的一系列阶段被降级为设备特定的汇编代码。

内核编译首先通过遍历被装饰的Python函数的抽象语法树(AST)来创建Triton中间表示(Triton-IR)。Triton-IR是一个未优化的、与机器无关的中间表示。它引入了块级编程要求,并基于开源LLVM编译器项目。接下来,Triton编译器优化并将Triton-IR转换为Triton-GPU IR(Triton-TTGIR)阶段,然后转换为LLVM-IR。Triton-IR和Triton-GPUIR表示都是以MLIR Dialect的形式编写的,其中MLIR是LLVM的一个子项目,旨在改进异构硬件的编译。







请到「今天看啥」查看全文