In 《CuTe C++ 簡介01,從示例開始》 we ended with a CUDA kernel that computes GEMM: a high-performance general matrix-multiply kernel written with CuTe (CUDA Tile), the tile-programming library that ships with NVIDIA CUTLASS. This post walks through each part of that kernel; the full code is repeated at the end for easy reference.
1. Template Parameters and Function Signature
template <class ProblemShape, class CtaTiler,
          class TA, class AStride, class ASmemLayout, class AThreadLayout,
          class TB, class BStride, class BSmemLayout, class BThreadLayout,
          class TC, class CStride, class CSmemLayout, class CThreadLayout,
          class Alpha, class Beta>
The kernel is heavily templated, so a single definition can be configured with:
- arbitrary data types (TA, TB, TC);
- arbitrary matrix shapes and strides;
- arbitrary memory layouts and thread mappings;
- different scalar types (Alpha, Beta).
A concrete instantiation is sketched right after this list.
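For intuition, here is a hedged host-side sketch of how these template parameters get filled in, modeled on the CuTe tutorial's gemm_nt driver. The wrapper name gemm_nt_sketch and the leading-dimension parameters ldA/ldB/ldC are illustrative additions, and the tile and thread sizes are the tutorial's defaults, not the only valid choice.

#include <cute/tensor.hpp>

template <class TA, class TB, class TC, class Alpha, class Beta>
void gemm_nt_sketch(int m, int n, int k,
                    Alpha alpha,
                    TA const* A, int ldA,
                    TB const* B, int ldB,
                    Beta beta,
                    TC      * C, int ldC,
                    cudaStream_t stream = 0)
{
  using namespace cute;

  // Problem shape (M, N, K)
  auto prob_shape = make_shape(m, n, k);

  // Strides: A is (M,K) and B is (N,K), both column-major; C is (M,N) column-major
  auto dA = make_stride(Int<1>{}, ldA);
  auto dB = make_stride(Int<1>{}, ldB);
  auto dC = make_stride(Int<1>{}, ldC);

  // CTA tile: each thread block computes a 128x128 tile of C, stepping over K by 8
  auto bM = Int<128>{};
  auto bN = Int<128>{};
  auto bK = Int<  8>{};
  auto cta_tiler = make_shape(bM, bN, bK);

  // Static shared-memory layouts for the tiles (simple column-major here)
  auto sA = make_layout(make_shape(bM, bK));   // ASmemLayout: (BLK_M, BLK_K)
  auto sB = make_layout(make_shape(bN, bK));   // BSmemLayout: (BLK_N, BLK_K)
  auto sC = make_layout(make_shape(bM, bN));   // CSmemLayout: (BLK_M, BLK_N)

  // Static thread layouts: 256 threads per block
  auto tA = make_layout(make_shape(Int<32>{}, Int< 8>{}));  // copy A with 32x8 threads
  auto tB = make_layout(make_shape(Int<32>{}, Int< 8>{}));  // copy B with 32x8 threads
  auto tC = make_layout(make_shape(Int<16>{}, Int<16>{}));  // compute C with 16x16 threads

  dim3 dimBlock(size(tC));
  dim3 dimGrid(size(ceil_div(m, bM)),
               size(ceil_div(n, bN)));
  gemm_device<<<dimGrid, dimBlock, 0, stream>>>
      (prob_shape, cta_tiler,
       A, dA, sA, tA,
       B, dB, sB, tB,
       C, dC, sC, tC,
       alpha, beta);
}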
2. Static Assertions and Precondition Checks
A large number of CUTE_STATIC_ASSERT_V and static_assert statements verify, at compile time, that:
- the tensors and tilers have the expected ranks;
- the thread layouts are compatible with (evenly divide) the data tiles;
- the memory layouts are static and consistent with the CTA tile sizes.
A concrete reading of two of these checks follows below.
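For example, with the tutorial's 128x128x8 CTA tile and a 32x8 thread layout for loading A (illustrative values), the kernel's divisibility checks boil down to compile-time integer arithmetic:

using namespace cute;

auto cta_tiler = make_shape(Int<128>{}, Int<128>{}, Int<8>{});  // (BLK_M, BLK_N, BLK_K)
auto tA        = make_layout(make_shape(Int<32>{}, Int<8>{}));  // 32x8 threads copy A

// Both conditions are static, so a violation is a compile error, not a runtime bug:
CUTE_STATIC_ASSERT_V(size<0>(cta_tiler) % size<0>(tA) == Int<0>{});  // 128 % 32 == 0 (BLK_M / THR_M)
CUTE_STATIC_ASSERT_V(size<2>(cta_tiler) % size<1>(tA) == Int<0>{});  //   8 %  8 == 0 (BLK_K / THR_K)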
3. Tensor Creation and Tiling
Tensor mA = make_tensor(make_gmem_ptr(A), select<0,2>(shape_MNK), dA);
Tensor gA = local_tile(mA, cta_tiler, cta_coord, Step<_1, X,_1>{});
- Create tensors that view the full matrices in global memory;
- Use local_tile to carve each full matrix into the sub-tile this thread block works on (a concrete example of the resulting shapes follows below).
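Concretely, inside the kernel (where using namespace cute is in scope), suppose M = N = K = 4096 with a 128x128x8 tiler; these numbers are illustrative and not from the original post, but they make the shapes tangible:

// Full A as a (M, K) = (4096, 4096) column-major view of the raw pointer A
Tensor mA = make_tensor(make_gmem_ptr(A),
                        make_shape(4096, 4096),
                        make_stride(Int<1>{}, 4096));

auto cta_tiler = make_shape(Int<128>{}, Int<128>{}, Int<8>{});  // (BLK_M, BLK_N, BLK_K)
auto cta_coord = make_coord(blockIdx.x, blockIdx.y, _);         // fix (m, n), keep every k

// Step<_1, X, _1> projects away the N mode, because A only has M and K modes.
Tensor gA = local_tile(mA, cta_tiler, cta_coord, Step<_1, X, _1>{});
// gA has shape (BLK_M, BLK_K, k) = (128, 8, 512):
// this block's 128x8 slice of A for each of the 4096/8 = 512 K-tiles.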
4. Shared Memory Allocation
__shared__ TA smemA[cosize_v<ASmemLayout>];
Tensor sA = make_tensor(make_smem_ptr(smemA), sA_layout);
- Allocate shared-memory buffers for the A and B tiles;
- The templated, fully static layouts make the buffer sizes compile-time constants and keep shared-memory accesses efficient (see the sketch below).
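A minimal sketch of why the static layout matters, using the simple column-major layout from the tutorial inside the kernel; the element type float and the sizes are illustrative, and production kernels often swap in swizzled layouts instead:

using namespace cute;

// (BLK_M, BLK_K) = (128, 8); make_layout defaults to column-major, i.e. stride (1, 128)
auto sA_layout = make_layout(make_shape(Int<128>{}, Int<8>{}));

// cosize_v is the number of elements the layout addresses, known at compile time,
// so the shared-memory array has a constant size: 128 * 8 = 1024 elements.
__shared__ float smemA[cosize_v<decltype(sA_layout)>];
Tensor sA = make_tensor(make_smem_ptr(smemA), sA_layout);   // (BLK_M, BLK_K) view over smem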
5. Data Partitioning
Tensor tAgA = local_partition(gA, tA, threadIdx.x);
Tensor tAsA = local_partition(sA, tA, threadIdx.x);
- Partition the global-memory and shared-memory tiles across the threads of the block;
- Each thread is responsible for loading its own small slice of the tile (a concrete example follows below).
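For instance, with the 32x8 thread layout tA raked over the 128x8 A tile (illustrative sizes, inside the kernel), each thread ends up with a (4, 1) fragment:

auto tA = make_layout(make_shape(Int<32>{}, Int<8>{}));  // 256 threads arranged as 32x8

// gA: (BLK_M, BLK_K, k) = (128, 8, k);  sA: (BLK_M, BLK_K) = (128, 8)
Tensor tAgA = local_partition(gA, tA, threadIdx.x);      // (THR_M, THR_K, k) = (4, 1, k)
Tensor tAsA = local_partition(sA, tA, threadIdx.x);      // (THR_M, THR_K)    = (4, 1)

// Each thread now owns matching 4x1 slices of the global and shared tiles,
// so copy(tAgA(_,_,k_tile), tAsA) moves exactly this thread's elements.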
6. Accumulator Allocation and Initialization
Tensor tCrC = make_tensor_like(tCgC);
clear(tCrC);
- Create each thread's accumulator tensor in registers;
- Initialize it to zero before the main loop starts accumulating (a small concrete sketch follows).
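With the tutorial's 16x16 compute thread layout over the 128x128 C tile (illustrative sizes), each thread's accumulator covers an 8x8 piece of C:

// tCgC is this thread's (THR_M, THR_N) slice of the C tile in global memory;
// with a 128x128 tile and 16x16 compute threads that is (8, 8) = 64 values per thread.
Tensor tCrC = make_tensor_like(tCgC);   // same shape and element type, but backed by registers
clear(tCrC);                            // zero every accumulator element before the K loop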
7. Main Loop (the Core Computation)
auto K_TILE_MAX = size<2>(tAgA);
for (int k_tile = 0; k_tile < K_TILE_MAX; ++k_tile)
{
  copy(tAgA(_,_,k_tile), tAsA);
  copy(tBgB(_,_,k_tile), tBsB);
  cp_async_fence();
  cp_async_wait<0>();
  __syncthreads();
  gemm(tCsA, tCsB, tCrC);
  __syncthreads();
}
This is the heart of the kernel:
- Data loading: copy the current K-tile of A and B from global memory into shared memory;
- Asynchronous copy handling: cp_async_fence() and cp_async_wait<0>() fence and drain any asynchronous cp.async instructions the copies may have issued;
- Synchronization: __syncthreads() makes sure every thread has finished writing to shared memory;
- Computation: read the tiles from shared memory and accumulate the matrix product into the register accumulators;
- Synchronization again: make sure every thread has finished reading shared memory before the next iteration overwrites it.
The tutorial's own comments spell out what copy and gemm expand to; that expansion is reproduced below.
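As the tutorial's own comments point out, the copy and gemm calls expand, per thread, to roughly the following loops:

// copy(tAgA(_,_,k_tile), tAsA) is roughly equivalent to:
Tensor tAgAk = tAgA(_,_,k_tile);
CUTE_UNROLL
for (int i = 0; i < size(tAsA); ++i) {
  tAsA(i) = tAgAk(i);                       // each thread moves its own elements gmem -> smem
}

// gemm(tCsA, tCsB, tCrC) is roughly equivalent to:
CUTE_UNROLL
for (int k = 0; k < size<1>(tCsA); ++k) {
  CUTE_UNROLL
  for (int m = 0; m < size<0>(tCrC); ++m) {
    CUTE_UNROLL
    for (int n = 0; n < size<1>(tCrC); ++n) {
      tCrC(m,n) += tCsA(m,k) * tCsB(n,k);   // accumulate in registers
    }
  }
}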
8. Epilogue
axpby(alpha, tCrC, beta, tCgC);
- Apply the scaling factors alpha and beta;
- Write the final result back to global memory (the per-thread expansion is shown below).
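Likewise, per the tutorial's comments, axpby expands per thread to:

// axpby(alpha, tCrC, beta, tCgC) is roughly equivalent to:
CUTE_UNROLL
for (int i = 0; i < size(tCrC); ++i) {
  tCgC(i) = alpha * tCrC(i) + beta * tCgC(i);   // scale, blend with old C, and write back
}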
Key Takeaways
- Tiled K loop: the K dimension is processed tile by tile; this loop is where more advanced variants add double buffering to overlap computation with data loading (this simple version waits for each copy to finish before computing);
- Efficient memory access: staging tiles in shared memory reduces redundant global-memory traffic;
- Thread-level parallelism: fine-grained thread scheduling and data partitioning;
- Template metaprogramming: compile-time optimization generates highly specialized code;
- Asynchronous copies: the cp_async_fence / cp_async_wait calls let cp.async instructions hide memory latency when the copy lowers to them.
A plain CPU reference of the computation is sketched after this list.
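To see exactly what the kernel computes without any CuTe machinery, here is a plain CPU reference, a sketch assuming the same NT convention as the kernel (column-major A of shape M×K, B stored as N×K, C of shape M×N); the function name and parameters are illustrative:

// Reference for C = alpha * A * B^T + beta * C, matching the kernel's data views:
// A(m,k) = A[m + k*ldA], B(n,k) = B[n + k*ldB], C(m,n) = C[m + n*ldC].
void gemm_reference(int M, int N, int K, float alpha,
                    float const* A, int ldA,
                    float const* B, int ldB,
                    float beta,
                    float* C, int ldC)
{
  for (int n = 0; n < N; ++n) {
    for (int m = 0; m < M; ++m) {
      float acc = 0.0f;
      for (int k = 0; k < K; ++k) {
        acc += A[m + k*ldA] * B[n + k*ldB];             // same as tCrC(m,n) += tCsA(m,k) * tCsB(n,k)
      }
      C[m + n*ldC] = alpha * acc + beta * C[m + n*ldC]; // same role as the axpby epilogue
    }
  }
}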
Overall, this kernel demonstrates the core practices of modern GPU programming: careful management of the memory hierarchy and of thread scheduling to achieve high-performance matrix multiplication.

The full kernel code, repeated for reference:
template <class ProblemShape, class CtaTiler,
          class TA, class AStride, class ASmemLayout, class AThreadLayout,
          class TB, class BStride, class BSmemLayout, class BThreadLayout,
          class TC, class CStride, class CSmemLayout, class CThreadLayout,
          class Alpha, class Beta>
__global__ static
__launch_bounds__(decltype(size(CThreadLayout{}))::value)
void
gemm_device(ProblemShape shape_MNK, CtaTiler cta_tiler,
            TA const* A, AStride dA, ASmemLayout sA_layout, AThreadLayout tA,
            TB const* B, BStride dB, BSmemLayout sB_layout, BThreadLayout tB,
            TC      * C, CStride dC, CSmemLayout          , CThreadLayout tC,
            Alpha alpha, Beta beta)
{
  using namespace cute;

  // Preconditions
  CUTE_STATIC_ASSERT_V(rank(shape_MNK) == Int<3>{});                   // (M, N, K)
  CUTE_STATIC_ASSERT_V(rank(cta_tiler) == Int<3>{});                   // (BLK_M, BLK_N, BLK_K)

  static_assert(is_static<AThreadLayout>::value);
  static_assert(is_static<BThreadLayout>::value);
  static_assert(is_static<CThreadLayout>::value);

  CUTE_STATIC_ASSERT_V(size(tA) == size(tB));                          // NumThreads
  CUTE_STATIC_ASSERT_V(size(tC) == size(tA));                          // NumThreads

  CUTE_STATIC_ASSERT_V(size<0>(cta_tiler) % size<0>(tA) == Int<0>{});  // BLK_M / THR_M
  CUTE_STATIC_ASSERT_V(size<2>(cta_tiler) % size<1>(tA) == Int<0>{});  // BLK_K / THR_K
  CUTE_STATIC_ASSERT_V(size<1>(cta_tiler) % size<0>(tB) == Int<0>{});  // BLK_N / THR_N
  CUTE_STATIC_ASSERT_V(size<2>(cta_tiler) % size<1>(tB) == Int<0>{});  // BLK_K / THR_K
  CUTE_STATIC_ASSERT_V(size<0>(cta_tiler) % size<0>(tC) == Int<0>{});  // BLK_M / THR_M
  CUTE_STATIC_ASSERT_V(size<1>(cta_tiler) % size<1>(tC) == Int<0>{});  // BLK_N / THR_N

  static_assert(is_static<ASmemLayout>::value);
  static_assert(is_static<BSmemLayout>::value);
  static_assert(is_static<CSmemLayout>::value);

  CUTE_STATIC_ASSERT_V(size<0>(ASmemLayout{}) == size<0>(cta_tiler));  // BLK_M
  CUTE_STATIC_ASSERT_V(size<0>(CSmemLayout{}) == size<0>(cta_tiler));  // BLK_M
  CUTE_STATIC_ASSERT_V(size<0>(BSmemLayout{}) == size<1>(cta_tiler));  // BLK_N
  CUTE_STATIC_ASSERT_V(size<1>(CSmemLayout{}) == size<1>(cta_tiler));  // BLK_N
  CUTE_STATIC_ASSERT_V(size<1>(ASmemLayout{}) == size<2>(cta_tiler));  // BLK_K
  CUTE_STATIC_ASSERT_V(size<1>(BSmemLayout{}) == size<2>(cta_tiler));  // BLK_K

  CUTE_STATIC_ASSERT_V(congruent(select<0,2>(shape_MNK), dA));         // dA strides for shape MK
  CUTE_STATIC_ASSERT_V(congruent(select<1,2>(shape_MNK), dB));         // dB strides for shape NK
  CUTE_STATIC_ASSERT_V(congruent(select<0,1>(shape_MNK), dC));         // dC strides for shape MN

  //
  // Full and Tiled Tensors
  //

  // Represent the full tensors
  Tensor mA = make_tensor(make_gmem_ptr(A), select<0,2>(shape_MNK), dA); // (M,K)
  Tensor mB = make_tensor(make_gmem_ptr(B), select<1,2>(shape_MNK), dB); // (N,K)
  Tensor mC = make_tensor(make_gmem_ptr(C), select<0,1>(shape_MNK), dC); // (M,N)

  // Get the appropriate blocks for this thread block
  auto cta_coord = make_coord(blockIdx.x, blockIdx.y, _);              // (m,n,k)
  Tensor gA = local_tile(mA, cta_tiler, cta_coord, Step<_1, X,_1>{});  // (BLK_M,BLK_K,k)
  Tensor gB = local_tile(mB, cta_tiler, cta_coord, Step< X,_1,_1>{});  // (BLK_N,BLK_K,k)
  Tensor gC = local_tile(mC, cta_tiler, cta_coord, Step<_1,_1, X>{});  // (BLK_M,BLK_N)

  // Shared memory buffers
  __shared__ TA smemA[cosize_v<ASmemLayout>];
  __shared__ TB smemB[cosize_v<BSmemLayout>];
  Tensor sA = make_tensor(make_smem_ptr(smemA), sA_layout);            // (BLK_M,BLK_K)
  Tensor sB = make_tensor(make_smem_ptr(smemB), sB_layout);            // (BLK_N,BLK_K)

  //
  // Partition the copying of A and B tiles across the threads
  //

  // TUTORIAL: Example of simple raked partitioning of ThreadLayouts tA|tB over data A|B tiles

  Tensor tAgA = local_partition(gA, tA, threadIdx.x);                  // (THR_M,THR_K,k)
  Tensor tAsA = local_partition(sA, tA, threadIdx.x);                  // (THR_M,THR_K)

  Tensor tBgB = local_partition(gB, tB, threadIdx.x);                  // (THR_N,THR_K,k)
  Tensor tBsB = local_partition(sB, tB, threadIdx.x);                  // (THR_N,THR_K)

  CUTE_STATIC_ASSERT_V(size<0>(tAgA) == size<0>(tAsA));                // THR_M
  CUTE_STATIC_ASSERT_V(size<1>(tAgA) == size<1>(tAsA));                // THR_K
  CUTE_STATIC_ASSERT_V(size<0>(tBgB) == size<0>(tBsB));                // THR_N
  CUTE_STATIC_ASSERT_V(size<1>(tBgB) == size<1>(tBsB));                // THR_K

  //
  // Define A/B partitioning and C accumulators
  //

  // TUTORIAL: Example of partitioning via projections of a ThreadLayout tC

  // Partition sA (BLK_M, BLK_K) by the rows of tC
  Tensor tCsA = local_partition(sA, tC, threadIdx.x, Step<_1, X>{});   // (THR_M,BLK_K)
  // Partition sB (BLK_N, BLK_K) by the cols of tC
  Tensor tCsB = local_partition(sB, tC, threadIdx.x, Step< X,_1>{});   // (THR_N,BLK_K)
  // Partition gC (M,N) by the tile of tC
  Tensor tCgC = local_partition(gC, tC, threadIdx.x, Step<_1,_1>{});   // (THR_M,THR_N)

  // Allocate the accumulators -- same shape/layout as the partitioned data
  Tensor tCrC = make_tensor_like(tCgC);                                // (THR_M,THR_N)

  CUTE_STATIC_ASSERT_V(size<0>(tCrC) == size<0>(tCgC));                // THR_M
  CUTE_STATIC_ASSERT_V(size<0>(tCrC) == size<0>(tCsA));                // THR_M
  CUTE_STATIC_ASSERT_V(size<1>(tCrC) == size<1>(tCgC));                // THR_N
  CUTE_STATIC_ASSERT_V(size<1>(tCrC) == size<0>(tCsB));                // THR_N
  CUTE_STATIC_ASSERT_V(size<1>(tCsA) == size<1>(tCsB));                // BLK_K

  // Clear the accumulators
  clear(tCrC);

#if 0
  if(thread0()) {
    print("  mA : "); print(  mA); print("\n");
    print("  gA : "); print(  gA); print("\n");
    print("  sA : "); print(  sA); print("\n");
    print("tAgA : "); print(tAgA); print("\n");
    print("tAsA : "); print(tAsA); print("\n");
  }
#endif

#if 0
  if(thread0()) {
    print("  mB : "); print(  mB); print("\n");
    print("  gB : "); print(  gB); print("\n");
    print("  sB : "); print(  sB); print("\n");
    print("tBgB : "); print(tBgB); print("\n");
    print("tBsB : "); print(tBsB); print("\n");
  }
#endif

#if 0
  if(thread0()) {
    print("  mC : "); print(  mC); print("\n");
    print("  gC : "); print(  gC); print("\n");
    print("tCsA : "); print(tCsA); print("\n");
    print("tCsB : "); print(tCsB); print("\n");
    print("tCgC : "); print(tCgC); print("\n");
    print("tCrC : "); print(tCrC); print("\n");
  }
#endif

#if 1

  // TUTORIAL: Example of a simple mainloop that read tiles of data into shared memory,
  //           and then computes on those tiles.
  //   copy(.) operates on the global and shared memory via the tA|tB partitioning
  //   gemm(.) operates on the shared and register memory via the tC partitioning

  auto K_TILE_MAX = size<2>(tAgA);

  for (int k_tile = 0; k_tile < K_TILE_MAX; ++k_tile)
  {
    // Copy gmem to smem with tA|tB thread-partitioned tensors
    copy(tAgA(_,_,k_tile), tAsA);      // A   (THR_M,THR_K) -> (THR_M,THR_K)
    copy(tBgB(_,_,k_tile), tBsB);      // B   (THR_N,THR_K) -> (THR_N,THR_K)

    // TUTORIAL: The above call to copy(tAgA(_,_,k_tile), tAsA) is equivalent to
    //   Tensor tAgAk = tAgA(_,_,k_tile);
    //   CUTE_UNROLL
    //   for (int i = 0; i < size(tAsA); ++i) {
    //     tAsA(i) = tAgAk(i);
    //   }

    cp_async_fence();        // Label the end of (potential) cp.async instructions
    cp_async_wait<0>();      // Sync on all (potential) cp.async instructions
    __syncthreads();         // Wait for all threads to write to smem

    // Compute gemm on tC thread-partitioned smem
    gemm(tCsA, tCsB, tCrC);            // (THR_M,THR_N) += (THR_M,BLK_K) * (THR_N,BLK_K)

    // TUTORIAL: The above call to gemm(tCsA, tCsB, tCrC) is equivalent to
    //   CUTE_UNROLL
    //   for (int k = 0; k < size<1>(tCsA); ++k) {
    //     CUTE_UNROLL
    //     for (int m = 0; m < size<0>(tCrC); ++m) {
    //       CUTE_UNROLL
    //       for (int n = 0; n < size<1>(tCrC); ++n) {
    //         tCrC(m,n) += tCsA(m,k) * tCsB(n,k);
    //       }
    //     }
    //   }

    __syncthreads();         // Wait for all threads to read from smem
  }

#endif

  //
  // Epilogue
  //

  axpby(alpha, tCrC, beta, tCgC);

  // TUTORIAL: The above call to axpby(alpha, tCrC, beta, tCgC) is equivalent to
  //   CUTE_UNROLL
  //   for (int i = 0; i < size(tCrC); ++i) {
  //     tCgC(i) = alpha * tCrC(i) + beta * tCgC(i);
  //   }
}