CuTe C++ 簡介02,gemm_device cuda kernel 的實現

??

? ? ? ? 《CuTe C++ 簡介01,從示例開始 》?中,最后看到了 計算 gemm 的cuda kernel,使用 NVIDIA CUTLASS 的 CUTe (CUDA Tile) 庫實現的高性能 GEMM (通用矩陣乘法) CUDA kernel。接下來解釋一下這個內核的各個部分。文末再貼一遍代碼,方便查看。

1. 模板參數和函數簽名

template <class ProblemShape, class CtaTiler,class TA, class AStride, class ASmemLayout, class AThreadLayout,class TB, class BStride, class BSmemLayout, class BThreadLayout,class TC, class CStride, class CSmemLayout, class CThreadLayout,class Alpha, class Beta>

這個內核高度模板化,支持配置,

? ? ? ? 任意數據類型 (TA, TB, TC);

? ? ? ? 任意矩陣形狀和步長;

? ? ? ? 任意內存布局和線程映射;

? ? ? ? 不同的標量類型 (Alpha, Beta);

2. 靜態斷言和預條件檢查

? 大量的?CUTE_STATIC_ASSERT_V?和?static_assert?確保在編譯時驗證如下內容,

? ? ? ? 正確的張量維度;

? ? ? ? 線程布局和數據分塊的兼容性;

? ? ? ? 內存布局的一致性;

3. 張量創建和分塊

Tensor mA = make_tensor(make_gmem_ptr(A), select<0,2>(shape_MNK), dA);
Tensor gA = local_tile(mA, cta_tiler, cta_coord, Step<_1, X,_1>{});

? ? ? ? 創建全局內存中的矩陣張量;

? ? ? ? 使用?local_tile?將大矩陣分塊為線程 block 處理的子塊;

4. 共享內存分配

__shared__ TA smemA[cosize_v<ASmemLayout>];
Tensor sA = make_tensor(make_smem_ptr(smemA), sA_layout);

? ? ? ??為矩陣 A 和 B 分配共享內存;

? ? ? ? 使用模板化的布局確保內存訪問的高效性;

5. 數據分區

Tensor tAgA = local_partition(gA, tA, threadIdx.x);
Tensor tAsA = local_partition(sA, tA, threadIdx.x);

? ? ? ??將全局內存和共享內存的數據分區給各個線程;

? ? ? ? 每個線程負責加載特定的數據塊;

6. 累加器分配和初始化

Tensor tCrC = make_tensor_like(tCgC);
clear(tCrC);

? ? ? ??為每個線程創建寄存器中的累加器;

? ? ? ? 初始化為零;

7. 主循環 (核心計算部分)

for (int k_tile = 0; k_tile < K_TILE_MAX; ++k_tile)
{copy(tAgA(_,_,k_tile), tAsA);copy(tBgB(_,_,k_tile), tBsB);cp_async_fence();cp_async_wait<0>();__syncthreads();gemm(tCsA, tCsB, tCrC);__syncthreads();
}

這是內核的核心部分:

  1. 數據加載: 將全局內存中的數據拷貝到共享內存;

  2. 異步等待: 使用?cp_async?指令實現異步數據加載;

  3. 同步: 確保所有線程完成數據加載;

  4. 計算: 從共享內存加載數據并進行矩陣乘法計算;

  5. 再次同步: 確保所有線程完成計算;

8. 收尾處理

axpby(alpha, tCrC, beta, tCgC);

? ? ? ??應用縮放因子 alpha 和 beta;

? ? ? ? 將計算結果寫回全局內存;

關鍵點復盤

  1. 雙緩沖機制: 通過循環處理 K 維度,實現計算和數據加載的重疊;

  2. 高效內存訪問: 使用共享內存減少全局內存訪問;

  3. 線程級并行: 精細的線程調度和數據分區;

  4. 模板元編程: 編譯時優化,生成高度特化的代碼;

  5. 異步拷貝: 使用?cp_async?指令隱藏內存延遲;

? ? ? ? 評價:這個內核展示了現代 GPU 編程的最佳實踐,通過精細的內存層次管理和線程調度來實現高性能的矩陣乘法運算。

template <class ProblemShape, class CtaTiler,class TA, class AStride, class ASmemLayout, class AThreadLayout,class TB, class BStride, class BSmemLayout, class BThreadLayout,class TC, class CStride, class CSmemLayout, class CThreadLayout,class Alpha, class Beta>
__global__ static
__launch_bounds__(decltype(size(CThreadLayout{}))::value)
void
gemm_device(ProblemShape shape_MNK, CtaTiler cta_tiler,TA const* A, AStride dA, ASmemLayout sA_layout, AThreadLayout tA,TB const* B, BStride dB, BSmemLayout sB_layout, BThreadLayout tB,TC      * C, CStride dC, CSmemLayout          , CThreadLayout tC,Alpha alpha, Beta beta)
{using namespace cute;// PreconditionsCUTE_STATIC_ASSERT_V(rank(shape_MNK) == Int<3>{});                   // (M, N, K)CUTE_STATIC_ASSERT_V(rank(cta_tiler) == Int<3>{});                   // (BLK_M, BLK_N, BLK_K)static_assert(is_static<AThreadLayout>::value);static_assert(is_static<BThreadLayout>::value);static_assert(is_static<CThreadLayout>::value);CUTE_STATIC_ASSERT_V(size(tA) == size(tB));                          // NumThreadsCUTE_STATIC_ASSERT_V(size(tC) == size(tA));                          // NumThreadsCUTE_STATIC_ASSERT_V(size<0>(cta_tiler) % size<0>(tA) == Int<0>{});  // BLK_M / THR_MCUTE_STATIC_ASSERT_V(size<2>(cta_tiler) % size<1>(tA) == Int<0>{});  // BLK_K / THR_KCUTE_STATIC_ASSERT_V(size<1>(cta_tiler) % size<0>(tB) == Int<0>{});  // BLK_N / THR_NCUTE_STATIC_ASSERT_V(size<2>(cta_tiler) % size<1>(tB) == Int<0>{});  // BLK_K / THR_KCUTE_STATIC_ASSERT_V(size<0>(cta_tiler) % size<0>(tC) == Int<0>{});  // BLK_M / THR_MCUTE_STATIC_ASSERT_V(size<1>(cta_tiler) % size<1>(tC) == Int<0>{});  // BLK_N / THR_Nstatic_assert(is_static<ASmemLayout>::value);static_assert(is_static<BSmemLayout>::value);static_assert(is_static<CSmemLayout>::value);CUTE_STATIC_ASSERT_V(size<0>(ASmemLayout{}) == size<0>(cta_tiler));  // BLK_MCUTE_STATIC_ASSERT_V(size<0>(CSmemLayout{}) == size<0>(cta_tiler));  // BLK_MCUTE_STATIC_ASSERT_V(size<0>(BSmemLayout{}) == size<1>(cta_tiler));  // BLK_NCUTE_STATIC_ASSERT_V(size<1>(CSmemLayout{}) == size<1>(cta_tiler));  // BLK_NCUTE_STATIC_ASSERT_V(size<1>(ASmemLayout{}) == size<2>(cta_tiler));  // BLK_KCUTE_STATIC_ASSERT_V(size<1>(BSmemLayout{}) == size<2>(cta_tiler));  // BLK_KCUTE_STATIC_ASSERT_V(congruent(select<0,2>(shape_MNK), dA));         // dA strides for shape MKCUTE_STATIC_ASSERT_V(congruent(select<1,2>(shape_MNK), dB));         // dB strides for shape NKCUTE_STATIC_ASSERT_V(congruent(select<0,1>(shape_MNK), dC));         // dC strides for shape MN//// Full and Tiled Tensors//// Represent the full tensorsTensor mA = make_tensor(make_gmem_ptr(A), select<0,2>(shape_MNK), dA); // (M,K)Tensor mB = make_tensor(make_gmem_ptr(B), select<1,2>(shape_MNK), dB); // (N,K)Tensor mC = make_tensor(make_gmem_ptr(C), select<0,1>(shape_MNK), dC); // (M,N)// Get the appropriate blocks for this thread blockauto cta_coord = make_coord(blockIdx.x, blockIdx.y, _);              // (m,n,k)Tensor gA = local_tile(mA, cta_tiler, cta_coord, Step<_1, X,_1>{});  // (BLK_M,BLK_K,k)Tensor gB = local_tile(mB, cta_tiler, cta_coord, Step< X,_1,_1>{});  // (BLK_N,BLK_K,k)Tensor gC = local_tile(mC, cta_tiler, cta_coord, Step<_1,_1, X>{});  // (BLK_M,BLK_N)// Shared memory buffers__shared__ TA smemA[cosize_v<ASmemLayout>];__shared__ TB smemB[cosize_v<BSmemLayout>];Tensor sA = make_tensor(make_smem_ptr(smemA), sA_layout);            // (BLK_M,BLK_K)Tensor sB = make_tensor(make_smem_ptr(smemB), sB_layout);            // (BLK_N,BLK_K)//// Partition the copying of A and B tiles across the threads//// TUTORIAL: Example of simple raked partitioning of ThreadLayouts tA|tB over data A|B tilesTensor tAgA = local_partition(gA, tA, threadIdx.x);                  // (THR_M,THR_K,k)Tensor tAsA = local_partition(sA, tA, threadIdx.x);                  // (THR_M,THR_K)Tensor tBgB = local_partition(gB, tB, threadIdx.x);                  // (THR_N,THR_K,k)Tensor tBsB = local_partition(sB, tB, threadIdx.x);                  // (THR_N,THR_K)CUTE_STATIC_ASSERT_V(size<0>(tAgA) == size<0>(tAsA));                // THR_MCUTE_STATIC_ASSERT_V(size<1>(tAgA) == size<1>(tAsA));                // THR_KCUTE_STATIC_ASSERT_V(size<0>(tBgB) == size<0>(tBsB));                // THR_NCUTE_STATIC_ASSERT_V(size<1>(tBgB) == size<1>(tBsB));                // THR_K//// Define A/B partitioning and C accumulators//// TUTORIAL: Example of partitioning via projections of a ThreadLayout tC// Partition sA (BLK_M, BLK_K) by the rows of tCTensor tCsA = local_partition(sA, tC, threadIdx.x, Step<_1, X>{});   // (THR_M,BLK_K)// Partition sB (BLK_N, BLK_K) by the cols of tCTensor tCsB = local_partition(sB, tC, threadIdx.x, Step< X,_1>{});   // (THR_N,BLK_K)// Partition gC (M,N) by the tile of tCTensor tCgC = local_partition(gC, tC, threadIdx.x, Step<_1,_1>{});   // (THR_M,THR_N)// Allocate the accumulators -- same shape/layout as the partitioned dataTensor tCrC = make_tensor_like(tCgC);                                // (THR_M,THR_N)CUTE_STATIC_ASSERT_V(size<0>(tCrC) == size<0>(tCgC));                // THR_MCUTE_STATIC_ASSERT_V(size<0>(tCrC) == size<0>(tCsA));                // THR_MCUTE_STATIC_ASSERT_V(size<1>(tCrC) == size<1>(tCgC));                // THR_NCUTE_STATIC_ASSERT_V(size<1>(tCrC) == size<0>(tCsB));                // THR_NCUTE_STATIC_ASSERT_V(size<1>(tCsA) == size<1>(tCsB));                // BLK_K// Clear the accumulatorsclear(tCrC);#if 0if(thread0()) {print("  mA : "); print(  mA); print("\n");print("  gA : "); print(  gA); print("\n");print("  sA : "); print(  sA); print("\n");print("tAgA : "); print(tAgA); print("\n");print("tAsA : "); print(tAsA); print("\n");}
#endif#if 0if(thread0()) {print("  mB : "); print(  mB); print("\n");print("  gB : "); print(  gB); print("\n");print("  sB : "); print(  sB); print("\n");print("tBgB : "); print(tBgB); print("\n");print("tBsB : "); print(tBsB); print("\n");}
#endif#if 0if(thread0()) {print("  mC : "); print(  mC); print("\n");print("  gC : "); print(  gC); print("\n");print("tCsA : "); print(tCsA); print("\n");print("tCsB : "); print(tCsB); print("\n");print("tCgC : "); print(tCgC); print("\n");print("tCrC : "); print(tCrC); print("\n");}
#endif#if 1// TUTORIAL: Example of a simple mainloop that read tiles of data into shared memory,//           and then computes on those tiles.//   copy(.) operates on the global and shared memory via the tA|tB partitioning//   gemm(.) operates on the shared and register memory via the tC partitioningauto K_TILE_MAX = size<2>(tAgA);for (int k_tile = 0; k_tile < K_TILE_MAX; ++k_tile){// Copy gmem to smem with tA|tB thread-partitioned tensorscopy(tAgA(_,_,k_tile), tAsA);      // A   (THR_M,THR_K) -> (THR_M,THR_K)copy(tBgB(_,_,k_tile), tBsB);      // B   (THR_N,THR_K) -> (THR_N,THR_K)// TUTORIAL: The above call to copy(tAgA(_,_,k_tile), tAsA) is equivalent to//   Tensor tAgAk = tAgA(_,_,k_tile);//   CUTE_UNROLL//   for (int i = 0; i < size(tAsA); ++i) {//     tAsA(i) = tAgAk(i);//   }cp_async_fence();        // Label the end of (potential) cp.async instructionscp_async_wait<0>();      // Sync on all (potential) cp.async instructions__syncthreads();         // Wait for all threads to write to smem// Compute gemm on tC thread-partitioned smemgemm(tCsA, tCsB, tCrC);            // (THR_M,THR_N) += (THR_M,BLK_K) * (THR_N,BLK_K)// TUTORIAL: The above call to gemm(tCsA, tCsB, tCrC) is equivalent to//   CUTE_UNROLL//   for (int k = 0; k < size<1>(tCsA); ++k) {//     CUTE_UNROLL//     for (int m = 0; m < size<0>(tCrC); ++m) {//       CUTE_UNROLL//       for (int n = 0; n < size<1>(tCrC); ++n) {//         tCrC(m,n) += tCsA(m,k) * tCsB(n,k);//       }//     }//   }__syncthreads();         // Wait for all threads to read from smem}#endif//// Epilogue//axpby(alpha, tCrC, beta, tCgC);// TUTORIAL: The above call to axpby(alpha, tCrC, beta, tCgC) is equivalent to//   CUTE_UNROLL//   for (int i = 0; i < size(tCrC); ++i) {//     tCgC(i) = alpha * tCrC(i) + beta * tCgC(i);//   }
}

本文來自互聯網用戶投稿,該文觀點僅代表作者本人,不代表本站立場。本站僅提供信息存儲空間服務,不擁有所有權,不承擔相關法律責任。
如若轉載,請注明出處:http://www.pswp.cn/bicheng/95922.shtml
繁體地址,請注明出處:http://hk.pswp.cn/bicheng/95922.shtml
英文地址,請注明出處:http://en.pswp.cn/bicheng/95922.shtml

如若內容造成侵權/違法違規/事實不符,請聯系多彩編程網進行投訴反饋email:809451989@qq.com,一經查實,立即刪除!

相關文章

萬代《寶可夢》主題新品扭蛋公開!史上最大尺寸

使用jQuery的常用方法與返回值分析 jQuery是一個輕量級的JavaScript庫&#xff0c;旨在簡化HTML文檔遍歷和操作、事件處理以及動畫效果的創建。本文將介紹一些常用的jQuery方法及其返回值&#xff0c;幫助開發者更好地理解和運用這一強大的庫。 1. 選擇器方法 jQuery提供了多種…

【FastDDS】Layer Transport ( 05-Shared Memory Transport)

6.4 共享內存傳輸 共享內存&#xff08;SHM&#xff09;傳輸依靠主機操作系統提供的共享內存機制&#xff0c;實現了在同一處理單元/機器上運行的實體之間的快速通信。注意 Fast DDS 利用域參與者&#xff08;DomainParticipant&#xff09;的 GuidPrefix_t 來識別在同一主機上…

記 2025/9/6

人工智能常見的模型按照處理問題分為6大類&#xff1a;處理權重問題的權重模型、處理狀態問題的狀態模型、處理序列問題的問題模型、處理表示問題的表示模型、處理相似度的相似模型、處理分類問題的分類模型。權重是計算特定狀態下事物的重要性。狀態問題是刻畫權重動態變化的過…

開啟Python之路,第一節學習大綱-從入門到進階

前端開啟Python之路&#xff0c;前端有沒有必要卷后端技術&#xff0c;歡迎各位大神批評指正 第一階段&#xff1a;基礎入門 (打好根基) 目標&#xff1a; 理解編程基本概念&#xff0c;掌握 Python 核心語法&#xff0c;能編寫簡單的腳本程序。 1、環境搭建與開發工具 安裝 Py…

webshell及冰蝎雙擊無法打開?

什么是webshell&#xff1f; web:萬維網 shell&#xff1a;是指一種應用程序&#xff0c;為用戶和系統之間建立連接&#xff0c;通過這個界面訪問操作系統內核的服務 webshell:是以asp、aspx、php、jsp或者cgi等網頁文件形式存在的一種命令執行環境&#xff0c;也可以將其稱做…

【星閃】Hi2821 | PWM脈寬調制模塊 + 呼吸燈例程

1. 簡介PWM&#xff08;Pulse Width Modulation&#xff09;&#xff0c;全稱脈寬調制&#xff0c;通過對一系列脈沖的寬度進行調制&#xff0c;等效出所需波形。即對模擬信號電平進行數字編碼&#xff0c;通過調節頻率、占空比的變化來調節信號的變化。一個 PWM 周期內由一段高…

51單片機---硬件學習(電子琴、主從應答模式、modbus模型、DS18B20傳感器顯示溫度)

一、串行通信與并行通信1、串行通信定義&#xff1a;數據一位一位地按順序通過單條傳輸線進行傳輸的通信方式。優點&#xff1a;傳輸線少&#xff0c;成本低&#xff0c;適合長距離傳輸缺點&#xff1a;傳輸速度相對較慢2、并行通信定義&#xff1a;數據的各位同時通過多條并行…

SpringBoot后端開發常用工具詳細介紹——SpringSecurity認證用戶保證安全

簡單的開始 創建SpringBoot項目 首先創建一個簡單的springboot項目&#xff0c;假設端口為8888&#xff0c;添加controller控制層&#xff0c;并在其中添加TestController控制類&#xff0c;那么啟動springboot項目之后&#xff0c;訪localhost:8888/api/message頁面會顯示my…

別再手工縫合API了!開源LLMOps神器LMForge,讓你像搭積木一樣玩轉AI智能體!

你是否受夠了這些&#xff1f; 剛調通OpenAI的API&#xff0c;老板說“咱們試試國產模型降本增效”&#xff0c;你看著滿屏的if-else只想說“我暈”。想給AI加上“查天氣”、“執行代碼”的能力&#xff0c;卻發現Function Calling的代碼復雜得讓人頭皮發麻。本地的Agentdemo驚…

window使用ffmep工具,加自定義腳本執行視頻轉碼成h264(運營人員使用)

技術文章大綱&#xff1a;ffmep配合腳本使用1. 需要提供腳本給視頻轉碼的給運營,給運營上傳視頻使用安裝ffmep windows版本(目前我使用的就是windows)將腳本里面的執行路徑修改成自己的電腦安裝ffmep/bin/ffmep.exe路徑處理好之后就點擊執行2.環境準備ffmep windows版解壓到一個…

Leetcode 240. 搜索二維矩陣 II 矩陣 / 二分

原題鏈接&#xff1a; Leetcode 240. 搜索二維矩陣 II 解法一&#xff1a;排除法 參考 【圖解】排除法&#xff0c;一圖秒懂&#xff01;&#xff08;Python/Java/C/C/Go/JS/Rust&#xff09; 從右上角&#xff1a; class Solution { public:bool searchMatrix(vector<vec…

OCR 證件識別:驅動澳門酒店自助入住智能化

澳門酒店作為國際旅游窗口&#xff0c;每日接待持多元證件的旅客&#xff0c;OCR 證件識別技術的應用&#xff0c;讓自助入住終端實現 “一證通辦”&#xff0c;大幅提升服務效率。?旅客在自助終端辦理入住時&#xff0c;只需將護照、港澳通行證、回鄉證、電子身份證等證件貼近…

深入解析匯編語言的奧秘

匯編語言簡介匯編語言&#xff08;Assembly Language&#xff09;是一種低級編程語言&#xff0c;直接對應計算機的機器指令集。它通過助記符&#xff08;如 MOV、ADD&#xff09;代替二進制操作碼&#xff0c;更接近硬件架構&#xff0c;常用于性能優化、嵌入式開發或逆向工程…

Nextcloud 實戰:打造屬于你的私有云與在線協作平臺

隨著數據安全與隱私保護意識的提升&#xff0c;越來越多的個人和組織選擇自建云平臺來替代公有云。Nextcloud 作為一款開源的文件同步與協作套件&#xff0c;不僅能實現類似網盤的文件存儲與分享&#xff0c;還提供日歷、聯系人、即時通訊、在線文檔編輯等協作功能&#xff0c;…

實踐指南:利用衡石AI Data Agent實現自然語言驅動的指標開發與歸因

在數字化轉型的深水區&#xff0c;企業數據團隊常面臨兩難困境&#xff1a;業務部門需要敏捷響應的指標分析&#xff0c;但傳統BI工具依賴技術團隊編寫SQL&#xff0c;導致需求交付周期長達數周&#xff1b;而直接暴露底層數據又存在安全與合規風險。衡石科技推出的AI Data Age…

知微集:Python中的線程(三)

歡迎來到"一起學點什么吧"的合集「NLP知微集」。在這里&#xff0c;我們不愿宏大敘事&#xff0c;只聚焦于自然語言處理領域中那些細微卻關鍵的“齒輪”與“螺絲釘”。我相信&#xff0c;真正深刻的理解&#xff0c;源于對細節的洞察。本期&#xff0c;我將為您拆解的…

動態規劃入門:從記憶化搜索到動態規劃

在開始對動態規劃的講解之前&#xff0c;我們需要先對記憶化搜索進行回顧&#xff1a; 什么是記憶化搜索&#xff1f; 在搜索過程中&#xff0c;當搜索樹中存在大量重復的節點時&#xff0c;我們可以通過引入一個"備忘錄"&#xff08;通常是一個數組或哈希表&#…

Boost搜索引擎 網絡庫與前端(4)

文章目錄前言一、引入網絡庫模塊引入cpp-httplibcpp-httplib測試正式編寫http_server二、前端模塊三、項目的可能拓展總結前言 終于到了最后一篇嘍&#xff0c;嘻嘻&#xff01; 一、引入網絡庫模塊 引入cpp-httplib 下載地址如下&#xff0c;我個人不喜歡新版本 ??cpp-http…

Flink反壓問題

背景在使用flink的過程中&#xff0c;多次遇到過反壓&#xff08;backpressure&#xff09;的問題&#xff0c;這通常是因為數據處理的速率超過了數據源或下游系統的處理能力導致。反壓的底層剖析網絡流控一個重要的概念是網絡流控&#xff0c;如上圖&#xff0c;不同的Consume…

Day5-中間件與請求處理

昨天搞定了異步優化&#xff0c;今天來解決一些實際問題。Day4的API雖然性能不錯&#xff0c;但還缺少一些企業級應用必備的功能。 現在的問題 前端無法訪問API&#xff08;跨域問題&#xff09;沒有請求日志&#xff0c;出問題難以排查錯誤信息格式不統一缺少統一的請求處理機…