import torch


def naive_softmax(x):
    """Compute row-wise softmax of X using native pytorch

    We subtract the maximum element in order to avoid overflows. Softmax is invariant to
    this shift.
    """
    # read MN elements ; write M elements
    x_max = x.max(dim=1)[0]
    # read MN + M elements ; write MN elements
    z = x - x_max[:, None]
    # read MN elements ; write MN elements
    numerator = torch.exp(z)
    # read MN elements ; write M elements
    denominator = numerator.sum(dim=1)
    # read MN + M elements ; write MN elements
    ret = numerator / denominator[:, None]
    # in total: read 5MN + 2M elements ; wrote 3MN + 2M elements
    return ret
As the comments in the code indicate, the memory traffic of naive softmax is 5MN + 2M elements read and 3MN + 2M elements written, i.e. 8MN + 4M elements in total.
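As a quick sanity check (a minimal sketch; the input shape below is arbitrary), naive_softmax can be compared against torch.softmax:

import torch

x = torch.randn(1823, 781)  # arbitrary M x N input
assert torch.allclose(naive_softmax(x), torch.softmax(x, dim=1))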
0x02 Triton Fused Softmax Implementation
The main idea of softmax_kernel is as follows: the kernel is launched with num_programs programs (i.e. thread blocks; below, "program" is used interchangeably with "thread block"), and each thread block processes a disjoint subset of the rows. For each row it computes a safe softmax: first the row maximum, then the exponentials, and finally softmax_output = numerator / denominator. This softmax_kernel only reads x once and writes y once, so compared with the 8MN + 4M memory traffic of naive softmax, the Triton softmax_kernel needs only 2MN, roughly 1/4 of the original.
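As a rough check of the 1/4 figure (the shape below is hypothetical; traffic is counted in elements, independent of dtype):

M, N = 4096, 2048                     # hypothetical matrix shape
naive_traffic = 8 * M * N + 4 * M     # elements moved by naive_softmax
fused_traffic = 2 * M * N             # elements moved by softmax_kernel
print(naive_traffic / fused_traffic)  # ~4.0, i.e. the fused kernel moves about 1/4 the data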
@triton.jit
defsoftmax_kernel(output_ptr,input_ptr,input_row_stride,output_row_stride,n_rows,n_cols,BLOCK_SIZE:tl.constexpr,
num_stages:tl.constexpr):
# starting row of the program
row_start=tl.program_id(0)
row_step=tl.num_programs(0)
forrow_idxintl.range(row_start,n_rows,row_step,num_stages=num_stages):
# The stride represents how much we need to increase the pointer to advance 1 row
row_start_ptr=input_ptr+row_idx*input_row_stride
# The block size is the next power of two greater than n_cols, so we can fit each
# row in a single block
col_offsets=tl.arange(0,BLOCK_SIZE)
input_ptrs=row_start_ptr+col_offsets
# Load the row into SRAM, using a mask since BLOCK_SIZE may be > than n_cols
mask=col_offsets<n_cols
row=tl.load(input_ptrs,mask=mask,other=-float('inf'))
# Subtract maximum for numerical stability
row_minus_max=row-tl.max(row,axis=0)
# Note that exponentiation in Triton is fast but approximate (i.e., think __expf in CUDA)
numerator=tl.exp(row_minus_max)
denominator=tl.sum(numerator,axis=0)
softmax_output=numerator/denominator
# Write back output to DRAM
output_row_start_ptr=output_ptr+row_idx*output_row_stride
output_ptrs=output_row_start_ptr+col_offsets
tl.store(output_ptrs,softmax_output,mask=mask)
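For completeness, here is a minimal launch sketch (not the tutorial's full wrapper: the softmax function name, the choice of 128 programs, and num_stages=4 are illustrative assumptions; the official tutorial derives num_programs from an occupancy estimate instead):

import torch
import triton

def softmax(x: torch.Tensor) -> torch.Tensor:
    n_rows, n_cols = x.shape
    # next power of two >= n_cols, so a single block covers a whole row
    BLOCK_SIZE = triton.next_power_of_2(n_cols)
    y = torch.empty_like(x)
    # each program strides over the rows with step num_programs; 128 is an arbitrary choice here
    num_programs = min(n_rows, 128)
    softmax_kernel[(num_programs,)](
        y, x, x.stride(0), y.stride(0), n_rows, n_cols,
        BLOCK_SIZE,  # constexpr BLOCK_SIZE, passed positionally
        4,           # constexpr num_stages for tl.range pipelining (assumed value)
    )
    return y

x = torch.randn(1823, 781, device='cuda')
assert torch.allclose(softmax(x), torch.softmax(x, dim=1))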
0x03 How the row index is computed