[Triton Programming][Basics] Triton Fused Softmax Kernel Explained: From Pyt...

import torch


def naive_softmax(x):
    """Compute row-wise softmax of x using native PyTorch.

    We subtract the maximum element in order to avoid overflows. Softmax is
    invariant to this shift.
    """
    # read MN elements; write M elements
    x_max = x.max(dim=1)[0]
    # read MN + M elements; write MN elements
    z = x - x_max[:, None]
    # read MN elements; write MN elements
    numerator = torch.exp(z)
    # read MN elements; write M elements
    denominator = numerator.sum(dim=1)
    # read MN + M elements; write MN elements
    ret = numerator / denominator[:, None]
    # in total: read 5MN + 2M elements; write 3MN + 2M elements
    return ret
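
As a quick sanity check (this snippet is not from the original post), naive_softmax should agree with PyTorch's built-in softmax:

x = torch.randn(1823, 781)  # arbitrary illustrative shape
print(torch.allclose(naive_softmax(x), torch.softmax(x, dim=1)))  # expected: True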

As the comments in the code indicate, the memory traffic of naive softmax is 5MN + 2M elements read and 3MN + 2M elements written, i.e. 8MN + 4M elements in total.

0x02 Triton Fused Softmax Implementation

The main idea of softmax_kernel is as follows: the kernel is launched with num_programs programs (i.e. thread blocks; below, "program" and "thread block" are used interchangeably), and each thread block processes a disjoint subset of the rows. For every row it computes a safe softmax: first the row maximum, then the exponentials, and finally softmax_output = numerator / denominator. This softmax_kernel reads x only once and writes y only once, so compared with the 8MN + 4M memory traffic of naive softmax, the Triton softmax_kernel needs only 2MN, roughly 1/4 of the original.
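
To make the comparison concrete, here is a back-of-the-envelope byte count for a hypothetical fp32 input with M = 4096 rows and N = 1024 columns (the shape is purely illustrative, not from the original post):

M, N = 4096, 1024
bytes_per_elem = 4  # fp32
naive_traffic = (8 * M * N + 4 * M) * bytes_per_elem   # read 5MN + 2M, write 3MN + 2M
fused_traffic = 2 * M * N * bytes_per_elem              # read x once, write y once
print(naive_traffic / 1e6)            # ~134.3 MB
print(fused_traffic / 1e6)            # ~33.6 MB
print(naive_traffic / fused_traffic)  # ~4.0x less traffic for the fused kernel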

import triton
import triton.language as tl


@triton.jit
def softmax_kernel(output_ptr, input_ptr, input_row_stride, output_row_stride, n_rows, n_cols,
                   BLOCK_SIZE: tl.constexpr, num_stages: tl.constexpr):
    # starting row of the program
    row_start = tl.program_id(0)
    row_step = tl.num_programs(0)
    for row_idx in tl.range(row_start, n_rows, row_step, num_stages=num_stages):
        # The stride represents how much we need to increase the pointer to advance 1 row
        row_start_ptr = input_ptr + row_idx * input_row_stride
        # The block size is the next power of two greater than n_cols, so we can fit each
        # row in a single block
        col_offsets = tl.arange(0, BLOCK_SIZE)
        input_ptrs = row_start_ptr + col_offsets
        # Load the row into SRAM, using a mask since BLOCK_SIZE may be greater than n_cols
        mask = col_offsets < n_cols
        row = tl.load(input_ptrs, mask=mask, other=-float('inf'))
        # Subtract the maximum for numerical stability
        row_minus_max = row - tl.max(row, axis=0)
        # Note that exponentiation in Triton is fast but approximate (i.e., think __expf in CUDA)
        numerator = tl.exp(row_minus_max)
        denominator = tl.sum(numerator, axis=0)
        softmax_output = numerator / denominator
        # Write back the output to DRAM
        output_row_start_ptr = output_ptr + row_idx * output_row_stride
        output_ptrs = output_row_start_ptr + col_offsets
        tl.store(output_ptrs, softmax_output, mask=mask)
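
For completeness, a minimal host-side launcher is sketched below. It is an illustrative simplification, not taken from the original post: the wrapper name softmax, the grid cap of 1024 programs, and the num_stages=2 value forwarded to the kernel's constexpr parameter are all assumptions of this sketch (the official Triton tutorial sizes the grid from occupancy instead). It shows the two key points: BLOCK_SIZE is padded to the next power of two so a single program can load a whole row, and a 1-D grid of programs strides over the rows exactly as row_start / row_step do in the kernel above.

import torch
import triton


def softmax(x: torch.Tensor) -> torch.Tensor:
    n_rows, n_cols = x.shape
    # One program must hold a full row, so pad the block to the next power of two.
    BLOCK_SIZE = triton.next_power_of_2(n_cols)
    y = torch.empty_like(x)
    # Each program starts at row tl.program_id(0) and advances by tl.num_programs(0).
    num_programs = min(n_rows, 1024)  # illustrative cap on the grid size
    softmax_kernel[(num_programs,)](
        y, x,
        x.stride(0), y.stride(0),
        n_rows, n_cols,
        BLOCK_SIZE=BLOCK_SIZE,
        num_stages=2,
    )
    return y

With this wrapper, softmax(x) for a CUDA tensor x should match torch.softmax(x, dim=1) up to floating-point tolerance.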

0x03 How the Row Index Is Computed