正文
-
CUDA运行2个块,每个块有4个线程。8个线程中的每一个计算一个单独的位置,例如
z[0] = x[0] + y[0]
-
Triton也运行2个块,每个块执行向量化加法。向量的大小是块的大小,即4。例如
z[0:3] = x[0:3] + y[0:3]
所有
Triton kernel中的操作都是向量化的:加载数据、操作数据、存储数据和创建掩码。
让我们考虑另一个简单的例子:
同样,我们要将大小为
6
的向量
x
和
y
相加,并将输出保存到大小也为6的向量
z
中。我们使用大小为4的块,因此我们有
cdiv(6, 4) = 2
个块。
x = torch.tensor([1,2,3,4,5,6])
y = torch.tensor([0,1,0,1,0,1])
x, y, x+y
CUDA kernel将类似于以下C代码:
# x,y = 输入张量, z = 输出张量, n = x的大小, bs = 块大小
def add_cuda_k(x, y, z, n, bs):
# 定位此特定kernel正在执行的整体计算的哪一部分
block_id = ... # 在我们的例子中: 是[0,1]中的一个
thread_id = ... # 在我们的例子中: 是[0,1,2,3]中的一个
# 识别此特定kernel需要的数据位置
offs = block_id * bs + thread_id
# 保护子句, 确保我们不会越界
if offs
# 读取数据
x_value = x[offs]
y_value = y[offs]
# 执行操作
z_value = x_value + y_value
# 写入数据
z[offs] = z_value
# 重要: offs, x_value, y_value, x_value 都是标量!
# 保护条件也是一种标量, 因为它检查一个值上的一个条件。
为了说明,这里是每个kernel的变量:
现在让我们看一下相应的Triton kernel,大致如下所示:
# 注意:这是为了说明,语法不完全正确。请参见下文以获取正确的Triton语法
def add_triton_k(x, y, z, n, bs):
# 定位此特定kernel正在执行的整体计算的哪一部分
block_id = tl.program_id(0) # 在我们的例子中: 是[0,1]中的一个
# 识别此特定kernel需要的数据位置
offs = block_id * bs + tl.arange(0, bs) #
# 保护子句变成一个掩码,这是一个布尔向量
mask = offs #
# 读取数据
x_values = x[offs] #
y_values = y[offs] #
# 执行操作
z_value = x_value + y_value #
# 写入数据
z[offs] = z_value #
再次说明,这里是每个kernel的变量:
术语说明:在Triton术语中,每个处理块的kernel被称为“program”。也就是说,我们上面的例子运行了2个program。因此,“block_id”通常被称为“pid”(“program id”的缩写),但它们是相同的。
示例1: 复制张量
让我们看一些例子。为了保持简单,我们将使用非常小的块大小。
目标: 给定一个形状为 (n) 的张量
x
,将其复制到另一个张量
z
中。
# # 这是一个普通的Python函数,用于启动Triton kernel
def copy(x, bs, kernel_fn):
z = torch.zeros_like(x)
check_tensors_gpu_ready(x, z)
n = x.numel()
n_blocks = cdiv(n, bs)
grid = (n_blocks,) # 我们有多少个块?可以是1d/2d/3d元组或返回1d/2d/3d元组的函数
# 启动网格!
# - kernel_fn是我们下面编写的Triton kernel
# - grid是我们上面构建的网格
# - x,z,n,bs是传递给每个kernel函数的参数
kernel_fn[grid](x,z,n,bs)
return z
注意:
出于教育目的,下面的kernel有一个逻辑错误(但语法是正确的)。你能发现它吗?
# # 这是Triton kernel:
# triton.jit装饰器将一个Python函数转换为Triton kernel,该kernel在GPU上运行。
# 在这个函数内部,只允许使用部分Python操作。
# 例如,当不进行模拟时,我们不能打印或使用断点,因为这些在GPU上不存在。
@triton.jit
# 当我们传递torch张量时,它们会自动转换为指向其第一个值的指针
# 例如,上面我们传递了x,但在这里我们接收到x_ptr
def copy_k(x_ptr, z_ptr, n, bs: tl.constexpr):
pid = tl.program_id(0)
offs = tl.arange(0, bs) # 从pid计算偏移量
mask = offs x = tl.load(x_ptr + offs, mask) # 加载一个值向量,将`x_ptr + offs`视为`x_ptr[offs]`
tl.store(z_ptr + offs, x, mask) # 存储一个值向量
print_if(f'pid = {pid} | offs = {offs}, mask = {mask}, x = {x}', '')
# 问题: 这个kernel有什么问题?
z = copy(x, bs=2, kernel_fn=copy_k)
pid = [0] | offs = [0 1], mask = [ True True], x = [1 2]
pid = [1] | offs = [0 1], mask = [ True True], x = [1 2]
pid = [2] | offs = [0 1], mask = [ True True], x = [1 2]
z
tensor([1, 2, 0, 0, 0, 0])
我们没有正确地移动偏移量。我们总是使用 offsets = [0,1],但它们应该随着 pid 变化。
@triton.jit
def copy_k(x_ptr, z_ptr, n, bs: tl.constexpr):
pid = tl.program_id(0)
offs = pid * n + tl.arange(0, bs)
mask = offs x = tl.load(x_ptr + offs, mask)
tl.store(z_ptr + offs, x, mask)
print_if(f'pid = {pid} | offs = {offs}, mask = {mask}, x = {x}', '')
z = copy(x, bs=2, kernel_fn=copy_k)
pid = [0] | offs = [0 1], mask = [ True True], x = [1 2]
pid = [1] | offs = [6 7], mask = [False False], x = [1 1]
pid = [2] | offs = [12 13], mask = [False False], x = [1 1]
不完全正确。我们添加了
pid * n
,但想要添加
pid * bs
@triton.jit
def copy_k(x_ptr, z_ptr, n, bs: tl.constexpr):
pid = tl.program_id(0)
offs = pid * bs + tl.arange(0, bs)
mask = offs x = tl.load(x_ptr + offs, mask)
tl.store(z_ptr + offs, x, mask)
print_if(f'pid = {pid} | offs = {offs}, mask = {mask}, x = {x}', '')
z = copy(x, bs=2, kernel_fn=copy_k)
pid = [0] | offs = [0 1], mask = [ True True], x = [1 2]
pid = [1] | offs = [2 3], mask = [ True True], x = [3 4]
pid = [2] | offs = [4 5], mask = [ True True], x = [5 6]
Yes!
x, z
(tensor([1, 2, 3, 4, 5, 6]), tensor([1, 2, 3, 4, 5, 6]))
正如我们所见,编写GPU程序涉及许多索引,我们很容易搞混。因此,我强烈建议先在模拟模式下编写和调试kernel,并首先使用小示例进行测试!
示例2:灰度化图像
在这个示例中,我们将灰度化一张小狗的图像。我们将看到如何处理二维数据。
这同样适用于三维数据。
我们改编了Jeremy Howard的示例,来自这个colab / youtube。因此,感谢他的示例和选择的小狗图像。
注:在这个示例中,如果不重启jupyter内核,会发生两件奇怪的事情:
-
无法导入torchvision,可能是由于循环依赖。-> 目前不知道为什么,需要深入挖掘。
-
下面的模拟triton kernel失败,因为浮点数不能乘以uint向量 -> 在GPU上不进行模拟时可以工作,所以似乎是
TRITON_INTERPRET
的bug。