5 手寫卷積函數

5 手寫卷積函數

  • 背景
  • 介紹
  • 滑動窗口的方式
    • 代碼
    • 問題
  • 矩陣乘法的方式
    • 原理
    • 代碼
    • 結果
  • 效果對比
    • 對比代碼
    • 日志
    • 結果
  • 一些思考

背景

從現在開始各種手寫篇章,先從最經典的卷積開始

介紹

對于卷積層的具體操作,我這里就不在具體說卷積具體是什么東西了。
對于手寫卷積操作而言,有兩種方式,一種就是最樸素的通過滑動窗口來實現的方式,另一種方式就是使用矩陣乘法來簡化操作過程的方式。

滑動窗口的方式

在這里插入圖片描述

卷積操作的動圖https://github.com/vdumoulin/conv_arithmetic/blob/master/README.md

通過上面的圖片和連接就可以很直觀地感受到卷積操作的方式,也能很直接想到使用簡單的滑動窗口來實現,如果還不能理解,建議去B站搜下視頻學習下

代碼

"""
-*- coding: utf-8 -*-
使用滑動窗口方式的手動卷積
@Author : Leezed
@Time : 2025/6/27 15:33
"""import numpy as npclass ManualSlideWindowConv():"""手動實現卷積操作,使用滑動窗口方式沒有實現反向傳播功能"""def __init__(self, kernel_size, in_channel, out_channel, stride=1, padding=0, bias=True):self.kernel_size = kernel_sizeself.in_channel = in_channelself.out_channel = out_channelself.stride = strideself.padding = paddingself.bias = biasself.weight = np.random.randn(out_channel, in_channel, kernel_size, kernel_size)if bias:self.bias = np.random.randn(out_channel)else:self.bias = Nonedef print_weight(self):print("Weight shape:", self.weight.shape)print("Weight values:\n", self.weight)def get_weight(self):return self.weightdef set_weight(self, weight):if weight.shape != self.weight.shape:raise ValueError(f"Weight shape mismatch: expected {self.weight.shape}, got {weight.shape}")self.weight = weightdef __call__(self, x, *args, **kwargs):if self.padding > 0:x = np.pad(x, ((0, 0), (0, 0), (self.padding, self.padding), (self.padding, self.padding)), mode='constant')  # 在四周填充0batch_size, in_channel, height, width = x.shapekernel_size = self.kernel_size# 計算輸出的高度和寬度out_height = (height - kernel_size) // self.stride + 1out_width = (width - kernel_size) // self.stride + 1output = np.zeros((batch_size, self.out_channel, out_height, out_width))for channel in range(self.out_channel):# 取出當前輸出通道的權重kernel = self.weight[channel, :, :, :]# 添加biasif self.bias is not None:output[:, channel, :, :] += self.bias[channel]else:output[:, channel, :, :] = 0for i, end_height in enumerate(range(kernel_size - 1, height, self.stride)):for j, end_width in enumerate(range(kernel_size - 1, width, self.stride)):# 取出圖像的滑動窗口start_height = end_height - kernel_size + 1start_width = end_width - kernel_size + 1window = x[:, :, start_height:end_height + 1, start_width:end_width + 1]# 計算卷積result = np.sum(kernel * window, axis=(1, 2, 3))output[:, channel, i, j] += resultreturn outputif __name__ == '__main__':# 測試代碼x = np.random.randn(2, 3, 5, 5)  # batch_size=2, in_channel=3, height=5, width=5conv_layer = ManualSlideWindowConv(kernel_size=3, in_channel=3, out_channel=2, stride=1, padding=1)output = conv_layer(x)print("Output shape:", output.shape)conv_layer.print_weight()

問題

但是活動卷積的方式有一個問題就是這個方式太費時了,因為有三層循環,而對于python而言是用循環去計算是一件費力不討好的事情,具體的時間花費我會在后面畫出圖來直觀地展現

矩陣乘法的方式

原理

https://zhuanlan.zhihu.com/p/360859627
https://gist.github.com/hsm207/7bfbe524bfd9b60d1a9e209759064180
https://blog.csdn.net/caip12999203000/article/details/126494740

具體的原理我就不贅述了,上面的三個鏈接認真看也能看明白了,他的本質思想就是講滑動窗口中的多次乘法,直接改成矩陣乘法,通過這種方式來進行加速,而且加速的幅度不小,但是會生成一個較大的矩陣,不可避免的帶來內存的開銷,也就是說本質是拿空間換時間

具體的例子就在圖中
在這里插入圖片描述
mmConv = ManualMatMulConv(kernel_size=3, in_channel=3, out_channel=64,padding=1)的卷積層面對
x = np.random.randn(64, 3, 224, 224).astype(np.float32)的特征就要吃1.5G+的內存了。

代碼

class ManualMatMulConv():"""手動實現卷積操作,使用卷積乘法方式沒有實現反向傳播功能"""def __init__(self, kernel_size, in_channel, out_channel, stride=1, padding=0, bias=True):self.kernel_size = kernel_sizeself.in_channel = in_channelself.out_channel = out_channelself.stride = strideself.padding = paddingself.bias = biasself.weight = np.random.randn(out_channel, in_channel, kernel_size, kernel_size)if bias:self.bias = np.random.randn(out_channel)else:self.bias = Nonedef print_weight(self):print("Weight shape:", self.weight.shape)print("Weight values:\n", self.weight)def get_weight(self):return self.weightdef set_weight(self, weight):if weight.shape != self.weight.shape:raise ValueError(f"Weight shape mismatch: expected {self.weight.shape}, got {weight.shape}")self.weight = weightdef __call__(self, x, *args, **kwargs):if self.padding > 0:x = np.pad(x, ((0, 0), (0, 0), (self.padding, self.padding), (self.padding, self.padding)), mode='constant')  # 在四周填充0batch_size, in_channel, height, width = x.shapekernel_size = self.kernel_size# 計算輸出的高度和寬度out_height = (height - kernel_size) // self.stride + 1out_width = (width - kernel_size) // self.stride + 1# 將權重轉換為矩陣形式weight_matrix = self.weight.reshape(self.out_channel, -1)  # shape (out_channel, in_channel * kernel_size * kernel_size)# 將輸入轉為矩陣形式 手寫unfold方式unfolded_x = []for i in range(0, height - kernel_size + 1, self.stride):for j in range(0, width - kernel_size + 1, self.stride):# 取出圖像的滑動窗口 轉成矩陣形式window = x[:, :, i:i + kernel_size, j:j + kernel_size].reshape(batch_size, -1)unfolded_x.append(window)unfolded_x = np.array(unfolded_x)  # shape: (num_windows, batch_size, in_channel * kernel_size * kernel_size)unfolded_x = np.transpose(unfolded_x, (1, 0, 2))  # shape: (batch_size, num_windows, in_channel * kernel_size * kernel_size)# 使用矩陣乘法計算卷積output = np.matmul(unfolded_x, weight_matrix.T)  # shape (batch_size, num_windows, out_channel)output = np.transpose(output, (0, 2, 1))  # shape (batch_size, out_channel, num_windows)output = output.reshape(batch_size, self.out_channel, out_height, out_width)# 添加biasif self.bias is not None:output += self.bias.reshape(1, -1, 1, 1)# 輸出結果return output

結果

檢測代碼

if __name__ == '__main__':# 測試代碼conv = ManualMatMulConv(kernel_size=3, in_channel=3, out_channel=2, stride=1, padding=0, bias=False)slide_window_conv = ManualSlideWindowConv(kernel_size=3, in_channel=3, out_channel=2, stride=1, padding=0, bias=False)conv.set_weight(slide_window_conv.get_weight())x = np.random.randn(1, 3, 5, 5)  # 輸入形狀 (batch_size, in_channel, height, width)output = conv(x)slide_window_output = slide_window_conv(x)print("Output shape:", output.shape)print("slide_window_output shape:", slide_window_output.shape)assert np.allclose(conv.get_weight(), slide_window_conv.get_weight()), "Weights do not match!"print("output:")print(output)print("slide_window_output:")print(slide_window_output)# 校驗是否相同assert np.allclose(output, slide_window_output), "Outputs do not match!"print("Outputs match!")

在這里插入圖片描述

效果對比

這里采取了四種方式的卷積來進行對比

  1. 滑動窗口方式的卷積
  2. 矩陣乘法的卷積
  3. torch.nn.Conv2d
  4. torch.nn.Conv2d 使用cuda

對比代碼

import numpy as np
from matplotlib import pyplot as plt
from manual.conv.slide_window import ManualSlideWindowConv
from manual.conv.matmul import ManualMatMulConv
import torch
import time# 對比不同batchsize的卷積速度speeds = {'manual_matmul': [],'manual_slide_window': [],'torch': [],'torch_cuda': []
}swConv = ManualSlideWindowConv(kernel_size=3, in_channel=3, out_channel=64,padding=1)
mmConv = ManualMatMulConv(kernel_size=3, in_channel=3, out_channel=64,padding=1)
torchConv = torch.nn.Conv2d(in_channels=3, out_channels=64, kernel_size=3, padding=1)
torchCudaConv = torch.nn.Conv2d(in_channels=3, out_channels=64, kernel_size=3, padding=1).cuda()def timing_conv(conv,x):start = time.time()y = conv(x)end = time.time()return y, end - startfor bs in [1, 2, 4, 8, 16, 32]:x = np.random.randn(bs, 3, 224, 224).astype(np.float32)x_torch = torch.from_numpy(x)x_torch_cuda = x_torch.cuda()y, speed = timing_conv(swConv, x)speeds['manual_slide_window'].append(speed)print(f'slide_window bs={bs}, speed={speed:.4f}s')y, speed = timing_conv(mmConv, x)speeds['manual_matmul'].append(speed)print(f'matmul bs={bs}, speed={speed:.4f}s')y, speed = timing_conv(torchConv, x_torch)speeds['torch'].append(speed)print(f'torch bs={bs}, speed={speed:.4f}s')y, speed = timing_conv(torchCudaConv, x_torch_cuda)speeds['torch_cuda'].append(speed)print(f'torch_cuda bs={bs}, speed={speed:.4f}s')print('-' * 50)

日志

slide_window bs=1, speed=39.8342s
matmul bs=1, speed=0.1436s
torch bs=1, speed=0.0080s
torch_cuda bs=1, speed=0.0000s
--------------------------------------------------
slide_window bs=2, speed=39.8841s
matmul bs=2, speed=0.2185s
torch bs=2, speed=0.0172s
torch_cuda bs=2, speed=0.0010s
--------------------------------------------------
slide_window bs=4, speed=44.0416s
matmul bs=4, speed=0.3975s
torch bs=4, speed=0.0329s
torch_cuda bs=4, speed=0.0000s
--------------------------------------------------
slide_window bs=8, speed=41.7520s
matmul bs=8, speed=0.3222s
torch bs=8, speed=0.0588s
torch_cuda bs=8, speed=0.0000s
--------------------------------------------------
slide_window bs=16, speed=45.5278s
matmul bs=16, speed=0.5858s
torch bs=16, speed=0.1067s
torch_cuda bs=16, speed=0.0010s
--------------------------------------------------
slide_window bs=32, speed=58.1965s
matmul bs=32, speed=1.2161s
torch bs=32, speed=0.2045s
torch_cuda bs=32, speed=0.0010s
--------------------------------------------------

結果

在這里插入圖片描述

去掉最慢的滑動窗口的結果展示
在這里插入圖片描述
可以看到矩陣乘法的方式還是挺快的,至少比滑動窗口快多了

一些思考

但是隨之而來的就是還有一個問題,為什么矩陣乘法的方式的內存開銷這么大,但是torch.nn.Conv2d好像并沒有這個問題

經過查閱了一些資料,我簡單總結一下

  1. 瓦片式(Tiling)或分塊(Blocking)計算

    雖然 矩陣乘法 概念上是將整個輸入和卷積核展開,但實際的硬件實現(如GPU)并不總是一次性處理所有數據。它們可能會將計算任務分解成更小的“瓦片”或“塊”。

    局部 矩陣乘法: 不是一次性將整個圖像展開,而是每次只對輸入的一小部分(例如一個批次、或者一個輸出通道的一個小區域)進行 im2col 變換和矩陣乘法。這樣可以限制中間矩陣的大小,從而減少瞬時內存占用。計算完成后,再將結果拼接到最終的輸出特征圖上。

    重用數據: 這種分塊策略有助于更好地利用 CPU 緩存或 GPU 顯存,因為同一小塊數據可以在其被完全處理完畢前,反復用于計算,減少數據在主存和緩存之間的移動。

  2. 智能算法選擇
    根據卷積參數動態選擇最適合的底層算法(如 Winograd, FFT, 或優化過的直接卷積),而不是單一地依賴 im2col。

這也就是為啥我們寫出來的方法跟官方版本的有差距的原因。

本文來自互聯網用戶投稿,該文觀點僅代表作者本人,不代表本站立場。本站僅提供信息存儲空間服務,不擁有所有權,不承擔相關法律責任。
如若轉載,請注明出處:http://www.pswp.cn/news/912194.shtml
繁體地址,請注明出處:http://hk.pswp.cn/news/912194.shtml
英文地址,請注明出處:http://en.pswp.cn/news/912194.shtml

如若內容造成侵權/違法違規/事實不符,請聯系多彩編程網進行投訴反饋email:809451989@qq.com,一經查實,立即刪除!

相關文章

vue3+element-plus,實現兩個表格同步滾動

需求:現在需要兩個表格,為了方便對比左右的數據,需要其中一邊的表格滾動時,另一邊的表格也跟著一起滾動,并且保持滾動位置的一致性。具體如下圖所示。 實現步驟: 確保兩個表格的寬度一致:如果兩…

Mysql架構

思考:Mysql需要重點學習什么: 索引:索引存儲結構、索引優化......事務:鎖機制與隔離級別、日志、集群架構 本文是對Mysql架構進行初步學習 1、Mysql鏈接 Mysql監聽器是長連接 BIO(阻塞同步IO調用), 不是NIO. 為什么…

使用deepseek制作“喝什么奶茶”隨機抽簽小網頁

教程很簡單,如下操作 1. 新建文本文檔,命名為奶茶.txt 2. 打開deepseek,發送下面這段提示詞:用html5幫我生成一個喝什么奶茶的網頁,點擊按鈕隨機生成奶茶品牌等,包括喜茶等眾多常見的奶茶品牌如果不滿意還…

WOE值:風險建模中的“證據權重”量化術——從似然比理論到FICO評分卡實踐

WOE值(Weight of Evidence,證據權重) 是信用評分和風險建模中用于量化特征分箱對目標變量的預測能力的核心指標。 本文由「大千AI助手」原創發布,專注用真話講AI,回歸技術本質。拒絕神話或妖魔化。搜索「大千AI助手」關…

js遞歸性能優化

JavaScript 遞歸性能優化 遞歸是編程中強大的技術,但在 JavaScript 中如果不注意優化可能會導致性能問題甚至棧溢出。以下是幾種優化遞歸性能的方法: 1. 尾調用優化 (Tail Call Optimization, TCO) ES6 引入了尾調用優化,但只在嚴格模式下…

vue界面增加自定義水印 js

vue整個界面增加自定義水印 需求:領導想要增加自定義水印 好不容易調完,還是想記錄一下,在.vue界面編寫 export default {mounted() {this.$nextTick(() > {this.addWatermark()})},methods: {// 關鍵:添加水印// 動態添加水印addWaterm…

Go開發工程師-Golang基礎知識篇

開篇 我們嘗試從2個方面來進行介紹: 1. 社招實際面試問題 2. 問題涉及的基礎點梳理 社招面試題 米哈游 1. Go 里面使用 Map 時應注意問題和數據結構 2. Map 擴容是怎么做的? 3. Map 的 panic 能被 recover 掉嗎?了解 panic 和 recover …

能否僅用兩臺服務器實現集群的高可用性??

我們將問題分為兩部分來回答:一是使用 Redis 或 Hazelcast 確保數據一致性后是否仍需 Oracle 或 MySQL 等數據庫;二是能否僅用兩臺服務器實現集群的高可用性。以下是詳細探討: 1. 使用 Redis 或 Hazelcast 確保數據一致性后,還需要…

spring-ai-alibaba DashScopeCloudStore自動裝配問題

問題 在學習spring-ai-alibaba時,發現1.0.0.2版本在自動裝配DashScopeCloudStore時,會報如下錯誤: Field dashScopeCloudStore in com.example.spring_ai_alibaba_examples.examples.SpringAiAlibabaExample01 required a bean of type com…

docker-compose部署nacos

1、docker-compose內容 高版本的nacos使用docker啟動,需要將所有的端口放開,僅僅開放8848端口,spring-boot客戶端獲取nacos配置的時候,可能取到的內容為空。 version: 3# 定義自定義網絡,確保服務間通信和外部訪問 ne…

CSRF 與 SSRF 的關聯與區別

CSRF 與 SSRF 的關聯與區別 區別 特性CSRF (跨站請求偽造)SSRF (服務器端請求偽造)攻擊方向客戶端 → 目標網站服務器 → 內部/外部資源攻擊目標利用用戶身份執行非預期操作利用服務器訪問內部資源或發起對外請求受害者已認證的用戶存在漏洞的服務器利用條件用戶必須已登錄目…

Payload-SDK自動升級

Payload-SDK自動升級 前言 自動升級旨在通過無人機更新負載上的軟件,包括不限于:Payload-SDK應用、配置文件等。對于文件的傳輸,大疆的Payload-SDK給我們提供了兩種方式:使用FTP協議和使用大疆自研的DCFTP。我們實現的自動升級是…

第五代移動通信新型調制及非正交多址傳輸技術研究與設計

第五代移動通信新型調制及非正交多址傳輸技術研究與設計 一、新型調制技術研究與實現 1. FBMC (濾波器組多載波) 調制實現 import numpy as np import matplotlib.pyplot as plt from scipy.fft import fft, ifft, fftshift from scipy.signal import get_window

AI 智能運維,重塑大型企業軟件運維:從自動化到智能化的進階實踐?

一、引言:企業軟件運維的智能化轉型浪潮? 在數字化轉型加速的背景下,大型企業軟件架構日益復雜,微服務、多云環境、分布式系統的普及導致傳統運維模式面臨效率瓶頸。AI 技術的滲透催生了智能運維(AIOps)的落地&#x…

Apache CXF安裝詳細教程(Windows)

本章教程,主要介紹,如何在Windows上安裝Apache CXF,JDK版本是使用的1.8. 一、下載Apache CXF Apache CXF(Apache Celtix Fireworks)是一個開源的 Web 服務框架,用于 構建和開發服務端與客戶端的 Web 服務應用程序。它支持多種 Web 服務標準,尤其是 SOAP(基于 XML 的協議…

逆向入門(22)程序逆向篇-TraceMe

界面看起來很普通 也沒有殼,直接搜索字符串找到關鍵代碼處 但是發現這些都是賦值,并沒有實現跳轉相關的函數。這里通過給彈窗函數下斷點,追一下返回函數來找觸發點。 再次點擊check,觸發斷點,接著按ctrlF9返回到函數…

中文PDF解析準確率排名

市面上的文檔解析工具種類各異,包括更適用于論文解析的,專精于表格數據提取的,針對手寫體優化的,適用于技術文檔的,擅長處理復雜多語言混排文檔的,專門處理政府招標文檔表格的,以及擅長金融類表…

Conformal LEC:官方學習教程

相關閱讀 Conformal LEChttps://blog.csdn.net/weixin_45791458/category_12993839.html?spm1001.2014.3001.5482 本文是對Conformal Equivalence Checking User Guide中附錄實驗的翻譯(有刪改),實驗文件可見安裝目錄Conformal/share/cfm/l…

【Torch】nn.Embedding算法詳解

1. 定義 nn.Embedding 是 PyTorch 中的 查表式嵌入層(lookup‐table),用于將離散的整數索引(如詞 ID、實體 ID、離散特征類別等)映射到一個連續的、可訓練的低維向量空間。它通過維護一個形狀為 (num_embeddings, emb…

cdq 三維偏序應用 / P4169 [Violet] 天使玩偶/SJY擺棋子

最近學了 cdq 分治想來做做這道題,結果被有些毒瘤的代碼惡心到了。 /ll 題目大意:一開始給定一些平面中的點。然后給定一些修改和詢問: 修改:增加一個點。詢問:給定一個點,求離這個點最近(定義…