【詳細圖解】再次理解im2col
轉自:https://mp.weixin.qq.com/s/GPDYKQlIOq6Su0Ta9ipzig
一句話:im2col是將一個[C,H,W]矩陣變成一個[H,W]矩陣的一個方法,其原理是利用了行列式進行等價轉換。
為什么要做im2col? 減少調用gemm的次數。
重要:本次的代碼只是為了方便理解im2col,不是用來做加速,所以代碼寫的很簡單且沒有做任何優化。
一、卷積的可視化
例子是一個[1, 6, 6]的輸入,卷積核是[1, 3, 3],stride等于1,padding等于0。那么卷積的過程可視化如下圖,一共需要做16次卷積計算,每次卷積計算有9次乘法和8次加法。
輸出的公式如下,即Output_height = (6 - 3 + 2*0)/1 + 1 = 4 = Output_width
二、行列式
乘號左邊的橫條,跟乘號右邊的豎條進行點乘(即每個元素對應相乘后再全部加起來)。
關于行列式,大家都清楚的一點,一根橫條的元素個數要等于一根豎條的元素個數(這樣才可以讓做點乘的時候能一一對應起來,不會讓小方塊落單)。豎條有多少條,出來的結果就有多少個小方塊(在橫條的個數為1的情況下)。
出來的結果(等號的右邊)的行數等于乘號左邊的橫條的行數,出來的結果(等號的右邊)的列數等于乘號右邊的橫條的列數,公式表示就是[row, x] * [x, col] = [row, col]。舉個例子[3, 8] * [8, 4] = [3, 4]
三、[1, H, W]的im2col
展開后,就可以直接做兩個數組的矩陣乘積了
import numpy as npscr = np.array(np.arange(0,7**2).reshape(7, 7))
intH, intW= scr.shapekernel = np.array([-0.2589, 0.2106, -0.1583, -0.0107, 0.1177, 0.1693, -0.1582, -0.3048, -0.1946]).reshape(3,3)
KHeight, KWeight = kernel.shaperow_num = intH - KHeight + 1
col_num = intW - KWeight + 1
OutScrIm2Col = np.zeros([row_num*col_num,KHeight*KWeight]) ii, jj = 0, 0
col_cnt, row_cnt = 0, 0
for i in range(0, row_num):for j in range(0, col_num): # 這倆個for是為了遍歷列,即乘了多少次,這里完全可以merge成一個for循環,只需要提前計算好就行ii = ijj = jfor iii in range(0, KHeight): # 這倆個for是為了取出一次 一橫 * 一豎 的 行列式,這里完全可以mege成一個for循環,只需要提前計算好就行for jjj in range(0, KHeight):OutScrIm2Col[row_cnt][col_cnt] = scr[ii][jj]jj +=1col_cnt += 1ii += 1jj = jcol_cnt = 0row_cnt += 1im2col_kernel = im2col_kernel.reshape(-1,9)
OutScrIm2Col = OutScrIm2Col.T
out = np.matmul(im2col_kernel,OutScrIm2Col) # 這步就是做兩個數組的矩陣乘積
中間倆個for循環是來填滿展開的數組/矩陣的每一列,即卷積核對應的元素,其個數等于卷積核的元素個數,舉個例子,[1, 3, 3]的卷積核,那么該卷積核的元素個數等于9;最外層的兩個for循環是用來填滿展開的數組/矩陣的每一行,即列數,也就是卷積核在輸入滑動了多少次
pytorch來做驗證
import torch
from torch import nn
import numpy as np
torch.manual_seed(100)net = nn.Conv2d(1, 1, 3, padding=0, bias=False)scr = np.array(np.arange(0, 7**2).reshape(1, 1, 7, 7)).astype(np.float32)
scr = torch.from_numpy(scr)print(net.weight.data) # 把這里的weight的值復制到上面numpy的代碼來做驗證
print(net(scr))# print的信息
tensor([[[[-0.2589, 0.2106, -0.1583],[-0.0107, 0.1177, 0.1693],[-0.1582, -0.3048, -0.1946]]]])
tensor([[[[ -7.6173, -8.2053, -8.7934, -9.3815, -9.9695],[-11.7337, -12.3217, -12.9098, -13.4978, -14.0859],[-15.8500, -16.4381, -17.0261, -17.6142, -18.2022],[-19.9664, -20.5545, -21.1425, -21.7306, -22.3186],[-24.0828, -24.6708, -25.2589, -25.8469, -26.4350]]]],grad_fn=<ThnnConv2DBackward>)
四、[C, H, W]的im2col
前面一堆圖,是我故意不寫文字,希望大家能夠通過圖能夠看明白。前面卷積核只有一行的情況,跟[1, H, W]的情況基本一摸一樣,只是這一行的元素個數等于卷積核的元素個數即可5x3x3=45,展開的特征圖的每一個豎條也是45。
當卷積核函數等于3的時候,就是對應的只要增加卷積核的橫條數即可,展開的特征圖沒有改變。這里希望大家用行列式的計算和普通卷積的過程聯想起來,你會發現是一摸一樣的計算過程。
代碼其實跟[1,H, W]只有一初不同,就是從特征圖里面取數據的時候多了個維度,需要取對應的通道。這里為什么要取對應的通道數呢?原因是行列式的計算中,橫條和豎條是元素一一對應做乘法。
import numpy as np
np.set_printoptions(threshold=np.inf)src = np.array(np.arange(0, 9**3))[0:5*9*9]
src = np.tile(src, 5)
src = src.reshape(-1, 5, 9, 9)
kernel = np.array([[[[-0.1158, 0.0942, -0.0708],[-0.0048, 0.0526, 0.0757],[-0.0708, -0.1363, -0.0870]],[[-0.1139, -0.1128, 0.0702],[ 0.0631, 0.0857, -0.0244],[ 0.1197, 0.1481, 0.0765]],[[-0.0823, -0.0589, -0.0959],[ 0.0966, 0.0166, 0.1422],[-0.0167, 0.1335, 0.0729]],[[-0.0032, -0.0768, 0.0597],[ 0.0083, -0.0754, 0.0867],[-0.0228, -0.1440, -0.0832]],[[ 0.1352, 0.0615, -0.1005],[ 0.1163, 0.0049, -0.1384],[ 0.0440, -0.0468, -0.0542]]]])scrN, srcChannel, intH, intW= src.shape
KoutChannel, KinChannel, kernel_H, kernel_W = kernel.shape
im2col_kernel = kernel.reshape(KoutChannel, -1)outChannel, outH, outW = KoutChannel, (intH - kernel_H + 1) , (intW - kernel_W + 1)
OutScrIm2Col = np.zeros( [ kernel_H*kernel_W*KinChannel, outH*outW ] )
row_num, col_num = OutScrIm2Col.shapeii, jj, cnt_row, cnt_col = 0, 0, 0, 0# 卷積核的reshape準備 :outchannel, k*k*inchannel
im2col_kernel = kernel.reshape(KoutChannel, -1)
# 輸入的reshape準備 :outH = (intH - k + 2*pading)/stride + 1
outChannel, outH, outW = KoutChannel, (intH - kernel_H + 1) , (intW - kernel_W + 1)i_id = -1
cnt_col = -1
cnr = 0
for Outim2colCol_H in range(0, outH):i_id += 1j_id = -1cnt_row = -1for Outim2colCol_W in range(0, outW):j_id, cnt_col += 1, += 1cnt_row = 0for c in range(0, srcChannel): # 取一次卷積的數據,放到一列for iii in range(0, kernel_H):i_number = iii + i_idfor jjj in range(0, kernel_W):j_number = jjj + j_idOutScrIm2Col[cnt_row][cnt_col] = src[bs][c][i_number][j_number]cnr +=1cnt_row += 1Out = np.matmul(im2col_kernel, OutScrIm2Col)
Out.reshape(outChannel, outH, outW)
print(Out.shape)
print(outChannel, outH, outW)
pytorch代碼的驗證
import torch
from torch import nn
import numpy as np
torch.manual_seed(100)net = nn.Conv2d(in_channels=5, out_channels=1, kernel_size=3, padding=0, bias=False)
print(net.weight.data.shape)
print(net.weight.data)scr = np.array(np.arange(0, 9**3))[:9*9*5].reshape(1, -1, 9, 9).astype(np.float32)scr = torch.from_numpy(src)
print("data:", scr.shape)
scr = torch.from_numpy(scr)
print("data:", scr.shape)Out = net(scr)
print("Our:", Out.shape)
print(Out)
五、[B, C, H, W]的im2col
問題:如何bs=9的情況呢,要怎么做im2col+gemm呢?方法 1:把filter攤平的shape變成[3,5339],把input攤平的shape變成[5339,16]
– output的shape就為[3,16]了 - ?
方法 2:把filter攤平的shape變成[39,533],把input攤平的shape變成[533,16],output的shape就為[39,16]了
– 隱患:如何filter數量是51233這種數量,那么非常占用顯存/內存
方法 3:im2col+gemm外面加一層關于bs的for循環
– 隱患:加一層for循環嵌套非常耗時
經過簡單分析,發現采取for循環的方式來進行im2col是相對合適的情況。我向msnh2012的作者穆士凝魂請教,得到的答案是,是用加一層for循環的方式居多,而且由于可以并發,多一層循環的開銷比想象中小一些。如果是推理框架的話,有部分情況bs是等于1的,所以可以規避這個問題。