From the table released officially, we can see that with Triton, memory-transaction coalescing, SRAM management, and thread scheduling within an SM are all handled automatically; we only need to spend our effort on managing work across SMs. In other words, Triton's programming granularity is the Block (each Block is scheduled onto exactly one SM), not the Thread. We only have to think about what each Block should do; the distribution and scheduling of Threads/Warps is handled for us by Triton. So how is the Block concept expressed in Triton? The answer: the program.
import triton
import triton.language as tl


@triton.jit
def add_kernel(x_ptr,  # *Pointer* to first input vector.
               y_ptr,  # *Pointer* to second input vector.
               output_ptr,  # *Pointer* to output vector.
               n_elements,  # Size of the vector.
               BLOCK_SIZE: tl.constexpr,  # Number of elements each program should process.
               # NOTE: `constexpr` so it can be used as a shape value.
               ):
    # There are multiple 'programs' (i.e. Blocks) processing different data.
    # We identify which program we are here:
    pid = tl.program_id(axis=0)  # We use a 1D launch grid so axis is 0.
    # This program will process inputs that are offset from the initial data.
    # For instance, if you had a vector of length 256 and block_size of 64, the programs
    # would each access the elements [0:64, 64:128, 128:192, 192:256].
    # Note that offsets is a list of pointers:
    block_start = pid * BLOCK_SIZE
    offsets = block_start + tl.arange(0, BLOCK_SIZE)
    # Create a mask to guard memory operations against out-of-bounds accesses.
    mask = offsets < n_elements
    # Load x and y from DRAM, masking out any extra elements in case the input is not a
    # multiple of the block size.
    x = tl.load(x_ptr + offsets, mask=mask)
    y = tl.load(y_ptr + offsets, mask=mask)
    output = x + y
    # Write x + y back to DRAM.
    tl.store(output_ptr + offsets, output, mask=mask)
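To see how the launch grid determines how many programs (Blocks) get created, here is a minimal host-side sketch of how such a kernel is typically launched, in the style of the official vector-add tutorial. The wrapper name `add` and the choice of BLOCK_SIZE=1024 are illustrative assumptions; it reuses the `triton` import and the `add_kernel` defined above.

import torch

def add(x: torch.Tensor, y: torch.Tensor):
    # Preallocate the output; all tensors are assumed to already live on the GPU.
    output = torch.empty_like(x)
    n_elements = output.numel()
    # The launch grid is 1D here, so its single entry is the number of programs
    # (Blocks) to launch: ceil(n_elements / BLOCK_SIZE).
    grid = lambda meta: (triton.cdiv(n_elements, meta['BLOCK_SIZE']),)
    # Each program handles BLOCK_SIZE contiguous elements; how the Threads/Warps
    # inside each Block are distributed and scheduled is left to Triton.
    add_kernel[grid](x, y, output, n_elements, BLOCK_SIZE=1024)
    return output

# Example usage (assumes a CUDA-capable device):
# x = torch.rand(98432, device='cuda')
# y = torch.rand(98432, device='cuda')
# print(add(x, y))

For a vector of 98432 elements and BLOCK_SIZE=1024, the grid evaluates to 97 programs, each of which runs the kernel body above with its own pid.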