专栏名称: GiantPandaLLM
专注于机器学习、深度学习、计算机视觉、图像处理等多个方向技术分享。团队由一群热爱技术且热衷于分享的小伙伴组成。我们坚持原创,每天一到两篇原创技术分享。希望在传播知识、分享知识的同时能够启发你,大家一起共同进步(・ω<)☆
目录
相关文章推荐
GiantPandaLLM  ·  图解Vllm ... ·  3 天前  
51好读  ›  专栏  ›  GiantPandaLLM

【CUDA 博客】使用PTX指令更高效地加载和存储矩阵

GiantPandaLLM  · 公众号  · 3D  · 2025-05-26 12:00

正文

请到「今天看啥」查看全文


.x4
addr0–addr7
addr8–addr15
addr16–addr23
addr24–addr31

下图展示了使用 ldmatrix 加载的 8x8 矩阵的fragment布局:

// 使用64位地址加载一个8x8矩阵
.reg .b64 addr;
.reg .b32 d;
ldmatrix.sync.aligned.m8n8.x1.shared::cta.b16 {d}, [addr];

// 加载两个8x8矩阵,以列主格式
.reg .b64 addr;
.reg .b32 d<2>;
ldmatrix.sync.aligned.m8n8.x2.trans.shared.b16 {d0, d1}, [addr];

// 加载四个8x8矩阵
.reg .b64 addr;
.reg .b32 d<4>;
ldmatrix.sync.aligned.m8n8.x4.b16 {d0, d1, d2, d3}, [addr];

实现

如上所述,指针应该位于 .shared 空间中。有多种方法将通用指针转换为 .shared 空间。最简单的方法如下(https://forums.developer.nvidia.com/t/problem-about-ptx-instruction-cp-async-ca-shared-global/224219/2):

size_t asl = __cvta_generic_to_shared(smem+threadIdx.x);

我们也可以使用内联汇编:

asm volatile(".reg .u64 smem_ptr64; cvta.to.shared.u64 smem_ptr64, %0;\n" :: "l"(smem+threadIdx.x));

或者像这样:

asm volatile(".reg .u64 smem_ptr64; cvta.to.shared.u64 smem_ptr64, %0;\n" :: "l"(smem+threadIdx.x))
asm volatile(".reg .u32 smem_ptr32; cvt.u32.u64 smem_ptr32, smem_ptr64;\n" ::);

我们也可以参考CUTLASS库(https://github.com/NVIDIA/cutlass/blob/ad7b2f5e84fcfa124cb02b91d5bd26d238c0459e/include/cute/arch/copy_sm75.hpp#L39)来获取实现思路。

从这里开始,实现就比较直接了:

#include 
#include 

// 定义一个设备端内联函数,用于从共享内存加载8x8矩阵
// d0: 输出参数,用于存储加载的数据
// address: 输入参数,共享内存中的地址
__device__ __forceinline__ void ldmatrix_sync_aligned_m8n8_x1_b16(
    uint32_t &d0, const uint32_t &address)
 
{
// 使用内联PTX汇编指令加载矩阵
// ldmatrix.sync.aligned.m8n8.x1.shared.b16: 同步加载8x8矩阵,每个元素16位
// {%0}: 输出寄存器,存储加载的数据
// [%1]: 输入寄存器,指定共享内存地址
asm volatile("ldmatrix.sync.aligned.m8n8.x1.shared.b16 {%0}, [%1];"
               : "=r" (d0)    // 输出约束,表示d0是一个输出寄存器
               : "r"(address))
// 输入约束,表示address是一个输入寄存器
}

// CUDA核函数,用于演示矩阵加载
__global__ void ldmatrix(uint16_t *value) {
// 定义共享内存大小
constexprint N = 64;
// 声明共享内存数组
  __shared__ uint16_t smem[N];
// 获取当前线程ID
auto tid = threadIdx.x;

// 计算行偏移量:每个线程负责8个元素,所以乘以8
constuint32_t offset_rows = sizeof(uint16_t) * (tid % 8) * 8;
// 计算最终地址:共享内存基址 + 行偏移
constuint32_t address = __cvta_generic_to_shared(smem) + offset_rows;

// 初始化共享内存
for (uint32_t i = tid; i < N; i += blockDim.x) {
    smem[i] = i;
  }
// 同步所有线程,确保共享内存初始化完成
  __syncthreads();

// 声明用于存储加载数据的变量
uint32_t frag;
// 调用矩阵加载函数
  ldmatrix_sync_aligned_m8n8_x1_b16(frag, address);

// 再次同步,确保所有线程都完成加载
  __syncthreads();

// 从32位数据中提取两个16位值
// 提取低16位
uint16_t number1 = static_cast<uint16_t>(frag & 0xFFFF);
// 提取高16位
uint16_t number2 = static_cast<uint16_t>((frag >> 16) & 0xFFFF);
// 打印结果
printf("%d -> %d  %d   %d   \n", tid, (int)(smem[2 * tid]), (int)number1,
         (int)number2);
}

// 主函数
int main() {
// 声明设备端指针
uint16_t *d_value;
// 分配设备内存
  cudaMalloc(&d_value, sizeof(uint16_t));
// 启动核函数,使用1个块,32个线程
  ldmatrix<<<132>>>(d_value);
// 等待设备完成
  cudaDeviceSynchronize();
// 释放设备内存
  cudaFree(d_value);
return0;
}

注意,根据上表,线程0-7需要对应于前8行的地址:

const uint32_t offset_rows = sizeof(uint16_t) * (tid % 8) * 8;
const uint32_t address = __cvta_generic_to_shared(smem) + offset_rows;

我们将在加载时传递地址和fragment。注意,每个fragment有 32bit ,我们可以通过先使用全16位掩码来提取最后16位,然后右移并再次执行相同的操作来提取前16位来输出加载的fragment。

0 -> 0  0   1   
1 -> 2  2   3   
2 -> 4  4   5   
3 -> 6  6   7   
4 -> 8  8   9   
5 -> 10  10   11   
6 -> 12  12   13   
7 -> 14  14   15   
8 -> 16  16   17   
9 -> 18  18   19   
10 -> 20  20   21   
11 -> 22  22   23   
12 -> 24  24   25   
13 -> 26  26   27   
14 -> 28  28   29   
15 -> 30  30   31   
16 -> 32  32   33   
17 -> 34  34   35   
18 -> 36  36   37   
19 -> 38  38   39   
20 -> 40  40   41   
21 -> 42  42   43   
22 -> 44  44   45   
23 -> 46  46   47   
24 -> 48  48   49   
25 -> 50  50   51   
26 -> 52  52   53   
27 -> 54  54   55   
28 -> 56  56   57   
29 -> 58  58   59   
30 -> 60  60   61   
31 -> 62  62   63

我们可以看到每个寄存器包含两个值。

我们可以在一个warp中同时写入两个矩阵。我们需要考虑到地址是按线程组提供的:

.num 线程 0-7 线程 8-15 线程 16-23 线程 24-31






请到「今天看啥」查看全文