cuda編程筆記(18)-- 使用im2col + GEMM 實現卷積

我們之前介紹了cudnn調用api直接實現卷積,本文我們探究手動實現。

對于直接使用for循環在cpu上的實現方法,就不過多介紹,只要了解卷積的原理,就很容易實現。

im2col 的核心思想

im2col = image to column

  • 把輸入 feature map 的每個卷積感受野(sliding window)展開成一列向量

  • 卷積核也展開成一個行向量

  • 然后把卷積轉化成 矩陣乘法(GEMM, General Matrix Multiply)

舉例

假設:

  • 輸入:1 channel, 4×4

  • 卷積核:1×2×2

  • 步長 stride=1

輸入:

1  2  3  4
5  6  7  8
9  10 11 12
13 14 15 16

卷積核 2×2 的每個 sliding window:

[[1,2],[5,6]] → 展開為 [1,2,5,6]
[[2,3],[6,7]] → 展開為 [2,3,6,7]
...

im2col 后:

X_col = [[1, 2, 5, 6],   # 第一個位置[2, 3, 6, 7],   # 第二個位置[3, 4, 7, 8],...]

卷積核也展開成:

W_col = [w1, w2, w3, w4]

然后 卷積計算就變成矩陣乘法

Y=W_{col}\cdot X_{col}

  • 輸出每個位置就是矩陣乘法的一個元素

  • 對多通道、多卷積核也可以批量做

說人話,就是把卷積核每次對準的這塊矩陣區域展平成向量;比如X_col的第一行,與W_col作向量乘法,結果就是第一次卷積得到的結果。

但是一般我們會將W_col的參數作為一行,所以X_col實際存儲需要轉置一下,這樣才符合矩陣乘法的要求

為什么效率高?

  1. GEMM 有高度優化

    • BLAS/cuBLAS/cuDNN 都對矩陣乘法做了很多優化

    • 可以充分利用 SIMD / GPU 并行

  2. 循環嵌套少

    • 原始卷積是 6 層循環(batch, output channel, height, width, input channel, kernel height/width)

    • im2col + GEMM → 只要一次矩陣乘法

  3. 容易擴展到多通道、多 batch

代碼實現

代碼中對應的X_col就是按列存儲展開的元素

#ifndef __CUDACC__
#define __CUDACC__
#endif
#include <cuda_runtime.h>
#include <device_launch_parameters.h>
#include <cudnn.h>
#include <cublas_v2.h>
#include <iostream>
#include<cstdio>
#include <cmath>
#include <cstdlib>
#include <vector>__global__ void im2col_kernel(const float *data_im,//輸入圖片數據 [C,H,W]int channels,int height,int width,// 輸入通道數 C,輸入高度 H, 寬度 Wint ksize,int pad,int stride,// 卷積核大小 (假設方形: kH=kW=ksize),padding 大小,卷積步長int height_col,int width_col,// 輸出特征圖的高寬float *data_col){ // im2col 展開的矩陣輸出 [C*ksize*ksize, H_out*W_out]int index=blockIdx.x*blockDim.x+threadIdx.x;//index: 每個線程負責展開一個卷積窗口中的一個元素。int n=channels*ksize*ksize*height_col*width_col;//n: 總的元素數if(index<n){//對比一下w_out*h_out*k_idx*k_w*k_h與n就知道為什么要這么計算這五個變量了//w_out, h_out: 代表當前處理的輸出特征圖位置 (卷積結果的坐標)。int w_out=index%width_col;int h_out=(index/width_col)%height_col;//k_idx → (c_in, k_h, k_w): 把卷積核的 index 分解成通道號、核內行列。int k_idx=(index/width_col/height_col);int k_w=k_idx%ksize;int k_h=(k_idx/ksize)%ksize;//c_in輸入通道索引int c_in=k_idx/(ksize*ksize);//im_row, im_col: 對應輸入圖片上的真實位置(考慮 stride、padding 后)//這是經典的把輸出坐標映射回輸入坐標的公式int im_row=h_out*stride-pad+k_h;int im_col=w_out*stride-pad+k_w;//col_index: data_col 的存儲索引int col_index=(c_in*ksize*ksize+k_h*ksize+k_w)* (height_col * width_col) + h_out * width_col + w_out;float val=0;//如果 im_row, im_col 在輸入圖像范圍內 → 取值;否則屬于padding → 0if(im_row>=0&&im_row<height&&im_col>=0&&im_col<width){val=data_im[(c_in*height+im_row)*width+im_col];}data_col[col_index]=val;}
}
void im2col_gpu(const float *d_im,int channels,int height,int width,int ksize,int pad,int stride,float *d_col){//計算輸出大小:就是卷積輸出 H_out, W_out 的公式。int height_col=(height +2*pad-ksize)/stride+1;int width_col=(width+2*pad-ksize)/stride+1;int n=channels*ksize*ksize*height_col*width_col;int threads=256;int blocks=(n+threads-1)/threads;im2col_kernel<<<blocks,threads>>>(d_im,channels,height,width,ksize,pad,stride,height_col,width_col,d_col);cudaDeviceSynchronize();
}
void conv_forward_im2col(const float *d_input,// 輸入圖像 [C,H,W]const float *d_weight,// 卷積核 [K,C,ksize,ksize]float *d_output,//輸出 [K,H_out,W_out]int C,int H,int W,// 輸入通道數,高,寬int K,// 卷積核數量 (輸出通道數)int ksize,int stride,int pad,// 核大小,步長,填充cublasHandle_t &handle){int H_out=(H+2*pad-ksize)/stride+1;int W_out=(W+2*pad-ksize)/stride+1;float *d_col;// im2col buffer: [C*ksize*ksize, H_out*W_out]cudaMalloc(&d_col,sizeof(float)*C*ksize*ksize*H_out*W_out);im2col_gpu(d_input,C,H,W,ksize,pad,stride,d_col);//原本是W*x的順序,但是由于cublasSgemm函數的特性,寫的時候參照實際傳遞的方式// GEMM: d_weight:[K, C*ksize*ksize] * d_col:[C*ksize*ksize, H_out*W_out] = [K, H_out*W_out]const float alpha=1.0f,beta=0.0f;cublasSgemm(handle,CUBLAS_OP_N,CUBLAS_OP_N,H_out*W_out,K,C*ksize*ksize,&alpha,d_col,H_out*W_out,d_weight,C*ksize*ksize,&beta,d_output,H_out*W_out);cudaFree(d_col);
}
int main() {int C=1, H=5, W=5;int K=1, ksize=3, stride=1, pad=1;std::vector<float> h_input(C*H*W, 1.0f);   // 全1輸入std::vector<float> h_weight(K*C*ksize*ksize, 1.0f); // 全1卷積核std::vector<float> h_output(K*H*W);float *d_input, *d_weight, *d_output;cudaMalloc(&d_input, h_input.size()*sizeof(float));cudaMalloc(&d_weight, h_weight.size()*sizeof(float));cudaMalloc(&d_output, h_output.size()*sizeof(float));cudaMemcpy(d_input, h_input.data(), h_input.size()*sizeof(float), cudaMemcpyHostToDevice);cudaMemcpy(d_weight, h_weight.data(), h_weight.size()*sizeof(float), cudaMemcpyHostToDevice);cublasHandle_t handle;cublasCreate(&handle);conv_forward_im2col(d_input, d_weight, d_output,C,H,W,K,ksize,stride,pad, handle);cudaMemcpy(h_output.data(), d_output, h_output.size()*sizeof(float), cudaMemcpyDeviceToHost);cublasDestroy(handle);// 打印結果int H_out=(H+2*pad-ksize)/stride+1;int W_out=(W+2*pad-ksize)/stride+1;std::cout << "Output (" << K << "," << H_out << "," << W_out << "):\n";for(int i=0;i<K;i++){for(int h=0;h<H_out;h++){for(int w=0;w<W_out;w++){std::cout << h_output[i*H_out*W_out+h*W_out+w] << " ";}std::cout << "\n";}}cudaFree(d_input);cudaFree(d_weight);cudaFree(d_output);return 0;
}

本文來自互聯網用戶投稿,該文觀點僅代表作者本人,不代表本站立場。本站僅提供信息存儲空間服務,不擁有所有權,不承擔相關法律責任。
如若轉載,請注明出處:http://www.pswp.cn/diannao/98013.shtml
繁體地址,請注明出處:http://hk.pswp.cn/diannao/98013.shtml
英文地址,請注明出處:http://en.pswp.cn/diannao/98013.shtml

如若內容造成侵權/違法違規/事實不符,請聯系多彩編程網進行投訴反饋email:809451989@qq.com,一經查實,立即刪除!

相關文章

Loopback for Mac:一鍵打造虛擬音頻矩陣,實現跨應用音頻自由流轉

虛擬音頻設備創建 模擬物理設備&#xff1a;Loopback允許用戶在Mac上創建虛擬音頻設備&#xff0c;這些設備可被系統及其他應用程序識別為真實硬件&#xff0c;實現音頻的虛擬化傳輸。多源聚合&#xff1a;支持將麥克風、應用程序&#xff08;如Skype、Zoom、GarageBand、Logic…

深入解析Django重定向機制

概述 核心是一個基類 HttpResponseRedirectBase&#xff0c;以及兩個具體的子類 HttpResponseRedirect&#xff08;302 臨時重定向&#xff09;和 HttpResponsePermanentRedirect&#xff08;301 永久重定向&#xff09;。它們都是 HttpResponse 的子類&#xff0c;專門用于告訴…

【Java實戰?】從IO到NIO:Java高并發編程的飛躍

目錄一、NIO 與 IO 的深度剖析1.1 IO 的局限性1.2 NIO 核心特性1.3 NIO 核心組件1.4 NIO 適用場景二、NIO 核心組件實戰2.1 Buffer 緩沖區2.2 Channel 通道2.3 Selector 選擇器2.4 NIO 文件操作案例三、NIO2.0 實戰3.1 Path 類3.2 Files 類3.3 Files 類高級操作3.4 NIO2.0 實戰…

OpenCV 實戰:圖像模板匹配與旋轉處理實現教程

目錄 一、功能概述&#xff1a;代碼能做什么&#xff1f; 二、環境準備&#xff1a;先搭好運行基礎 1. 安裝 Python 2. 安裝 OpenCV 庫 3. 準備圖像文件 三、代碼逐段解析&#xff1a;從基礎到核心 1. 導入 OpenCV 庫 2. 讀取圖像文件 3. 模板圖像旋轉&#xff1a;處理…

一、cadence的安裝及入門教學(反相器的設計與仿真)

一、Cadence的安裝 1、安裝VMware虛擬機 2、安裝帶有cadence軟件的Linux系統 注&#xff1a;網盤鏈接 分享鏈接&#xff1a;https://disk.ningsuan.com.cn/#s/8XaVdtRQ 訪問密碼&#xff1a;11111 所有文件壓縮包及文檔密碼&#xff1a; Cadence_ic 3、安裝tsmc18工藝庫…

用ai寫了個UE5插件

文章目錄實際需求1.頭文件2.源文件3.用法小結實際需求 這個需求來源于之前的一個項目&#xff0c;當時用了一個第三方插件&#xff0c;里邊有一些繪制線段的代碼&#xff0c;c層用的是drawdebugline&#xff0c;當時看底層&#xff0c;覺得應該沒問題&#xff0c;不應該在rele…

機器學習從入門到精通 - 強化學習初探:Q-Learning到Deep Q-Network實戰

機器學習從入門到精通 - 強化學習初探&#xff1a;從 Q-Learning 到 Deep Q-Network 實戰 一、開場白&#xff1a;推開強化學習這扇門 不知道你有沒有過這種感覺 —— 盯著一個復雜的系統&#xff0c;既想讓它達到某個目標&#xff0c;又苦于無法用傳統規則去精確描述每一步該怎…

【OpenHarmony文件管理子系統】文件訪問接口解析

OpenHarmony文件訪問接口&#xff08;filemanagement_file_api&#xff09; 概述 OpenHarmony文件訪問接口&#xff08;filemanagement_file_api&#xff09;是開源鴻蒙操作系統中的核心文件系統接口&#xff0c;為應用程序提供了完整的文件IO操作能力。該項目基于Node-API&…

云手機運行是否消耗自身流量?

云手機運行是否消耗自身流量&#xff0c;取決于具體的使用場景和設置&#xff1a;若用戶在連接云手機時&#xff0c;使用的是家中Wi-Fi、辦公室局域網等非移動數據網絡&#xff0c;那么在云手機運行過程中&#xff0c;基本不會消耗用戶自身的移動數據流量&#xff0c;在家中連接…

JavaSe之多線程

一、多線程基本了解 1、多線程基本知識 1.進程:進入到內存中執行的應用程序 2.線程:內存和CPU之間開通的通道->進程中的一個執行單元 3.線程作用:負責當前進程中程序的運行.一個進程中至少有一個線程,一個進程還可以有多個線程,這樣的應用程序就稱之為多線程程序 4.簡單理解…

產品月報|睿本云8月產品功能迭代

睿本云8月更新已陸續上線&#xff01; 睿本云8月產品月報&#xff0c;點擊查收&#x1f447;小程序支付成功彈窗廣告、企業會員增加卡券銷售和卡券退貨模塊、工廠端可批量新增多門店訂貨單、門店端和工廠端新增“極速訂貨”、商品調撥業務支持自定義多種流程配置等功能迭代更新…

融云:當我們談論 AI 重構業務時,我們到底在談論什么

所有業務都值得用 AI 重新做一次。 這句話正在從一句鼓舞人心的口號&#xff0c;演變為一場無人可避的商業現實。AI 帶來的結構性機會&#xff0c;意味著企業有機會從根本上重構成本、效率與體驗的曲線。但這一切最終都要回到一個無比務實的問題上&#xff1a; AI 究竟如何在我…

org.yaml.snakeyaml.error.YAMLException: java.nio.charset.MalformedInputException: Input length = 1異常

org.yaml.snakeyaml.error.YAMLException: java.nio.charset.MalformedInputException: Input length 1異常問題解決一、問題背景二、錯誤現象三、原因分析核心問題&#xff1a;字符集不匹配四、解決過程試錯路徑記錄五、最終方案1.創建launch.json文件&#xff0c;修改VSCode…

【C語言】深入理解指針(5)

目錄 sizeof和strlen 1.sizeof 2.strlen 3. sizeof 和 strlen 的對比 sizeof和strlen 1.sizeof sizeo正名&#xff1a;sizeof是操作符&#xff0c;不是函數&#xff0c;sizeof是操作符&#xff0c;括號內如果有計算不會進行計算sizeof 是操作符&#xff0c;用于計算變量所…

動態代理設計模式

JDK動態代理實現 動態代理利用了JDK API,動態地在內存中構建代理對象,從而實現對目標對象的代理功能.動態代理又被稱為JDK代理或接口代理. 靜態代理與動態代理的區別: 靜態代理在編譯時就已經實現了,編譯完成后代理類是一個實際的class文 動態代理是在運行時動態生成的,即編譯…

《Html泛型魔法學院:用霍格沃茨風格網頁教授集合框架》

一、項目概述 這個創意教學網頁&#xff0c;將Java泛型與集合框架知識融入霍格沃茨魔法世界主題。通過沉浸式UI設計和交互式代碼練習&#xff0c;讓抽象的技術概念變得生動有趣。主要技術棧包括&#xff1a; HTML5語義化結構Tailwind CSS框架Font Awesome圖標庫純JavaScript交…

學習PaddlePaddle--環境配置-PyCharm + Conda?

第一階段&#xff1a;安裝與配置 Python 和 Conda?? 雖然 PyCharm 可以管理環境&#xff0c;但我們先獨立準備好 Conda 環境&#xff0c;這樣更清晰可靠。 ??1. 安裝 Miniconda (Python 環境管理)?? 1. ??下載??&#xff1a; ? 訪問 Miniconda 官網。 ? 選擇 ??M…

【數據庫】Sql Server數據庫中isnull、iif、case when三種方式的使用和空值判斷

大家好&#xff0c;我是全棧小5&#xff0c;歡迎來到《小5講堂》。 這是《Sql Server》系列文章&#xff0c;每篇文章將以博主理解的角度展開講解。 溫馨提示&#xff1a;博主能力有限&#xff0c;理解水平有限&#xff0c;若有不對之處望指正&#xff01; 目錄前言ISNULL用法c…

【藍橋杯選拔賽真題64】C++最大空白區 第十四屆藍橋杯青少年創意編程大賽 算法思維 C++編程選拔賽真題解

C++最大空白區 第十四屆藍橋杯青少年創意編程大賽C++選拔賽真題 博主推薦 所有考級比賽學習相關資料合集【推薦收藏】 1、C++專欄 電子學會C++一級歷年真題解析 電子學會C++二級歷年真題解析

試用Augment編寫python腳本實現智能家居3D環境交互響應

環境配置 VS Code中直接安裝Augment擴展&#xff0c;然后郵箱登錄就能獲得7天的試用。 從如下位置安裝3D建模軟件Blender&#xff1a; https://www.blendercn.org/downloadme#xiazai Blender 是一款免費開源的 3D 創作套件。它支持整個三維流程&#xff1a;建模、綁定、動畫…