Pytorch中CNN入門思想及實現

CNN卷積神經網絡

基礎概念:

以卷積操作為基礎的網絡結構,每個卷積核可以看成一個特征提取器。
在這里插入圖片描述

思想:

每次觀察數據的一部分,如圖,在整個矩陣中只觀察黃色部分3×3的矩陣,將這【3×3】矩陣·(點乘)權重得到特征矩陣的第一項,然后進行平移進行第二項的計算。依此類推,得到最后的特征矩陣。
在這里插入圖片描述
在這里插入圖片描述

利用Pytorch框架實現CNN

import torch
import torch.nn as nn
import numpy as np"""
使用pytorch實現CNN
不考慮偏差值
"""class TorchCNN(nn.Module):def __init__(self, in_channel, out_channel, kernel):super(TorchCNN, self).__init__()self.layer = nn.Conv2d(in_channel, out_channel, kernel, bias=False)def forward(self, x):return self.layer(x)x = np.array([[0.1, 0.2, 0.3, 0.4],[-3, -4, -5, -6],[5.1, 6.2, 7.3, 8.4],[-0.7, -0.8, -0.9, -1]])  #網絡輸入#torch實驗
in_channel = 1      #單通道(NLP中一般用單通道)
out_channel = 3     #多少個卷積核(每一個卷積核代表一個獨立的權重)
kernel_size = 2     #2*2的方塊(功能就是圖中黃色[3×3]矩陣)
torch_model = TorchCNN(in_channel, out_channel, kernel_size)
# print(torch_model.state_dict())
torch_w = torch_model.state_dict()["layer.weight"]
# print(torch_w.numpy().shape)
torch_x = torch.FloatTensor([[x]])
#權重是4維,輸入應該也為四維,通過多一個[],將輸入由三維變成四維
output = torch_model.forward(torch_x)
output = output.detach().numpy()
print(output, output.shape, "torch模型預測結果\n")

自定義模型代碼實現CNN:

采用自定義模型實現CNN,不考慮偏差值,因為要與Pytorch框架結果相對比,需要調取在Pytorch模型中的輸入和隨機權重。因此如果要運行,須將此代碼放在Pytorch框架下運行。

"""
手動實現簡單的神經網絡
與Pytorch對比實驗
"""
#自定義CNN模型
class DiyModel:def __init__(self, input_height, input_width, weights, kernel_size):self.height = input_heightself.width = input_widthself.weights = weightsself.kernel_size = kernel_sizedef forward(self, x):output = []for kernel_weight in self.weights:kernel_weight = kernel_weight.squeeze().numpy()#weight取出來時是[1×2×2],通過squeeze變成[2×2],然后變成numpy取出kernel_output = np.zeros((self.height - kernel_size + 1, self.width - kernel_size + 1)) #全0輸出矩陣for i in range(self.height - kernel_size + 1):for j in range(self.width - kernel_size + 1):window = x[i:i+kernel_size, j:j+kernel_size]   #x是原始輸入 剩下的是矩陣索引方法kernel_output[i, j] = np.sum(kernel_weight * window)  #np.dot != x*y   x*y是點乘(對應位置相乘)output.append(kernel_output)return np.array(output)diy_model = DiyModel(x.shape[0], x.shape[1], torch_w, kernel_size)
output = diy_model.forward(x)
print(output, "diy模型預測結果")

最終對比結果:

在這里插入圖片描述
可以清楚看到Pytorch框架下的結果與自定義框架下的結果相同。

本文來自互聯網用戶投稿,該文觀點僅代表作者本人,不代表本站立場。本站僅提供信息存儲空間服務,不擁有所有權,不承擔相關法律責任。
如若轉載,請注明出處:http://www.pswp.cn/news/389582.shtml
繁體地址,請注明出處:http://hk.pswp.cn/news/389582.shtml
英文地址,請注明出處:http://en.pswp.cn/news/389582.shtml

如若內容造成侵權/違法違規/事實不符,請聯系多彩編程網進行投訴反饋email:809451989@qq.com,一經查實,立即刪除!

相關文章

java常用設計模式一:單例模式

1、餓漢式 package singleton.demo;/*** author Administrator* date 2019/01/07*/ public class Singleton {//在調用getInstance方法前,實例已經創建好private static Singleton instance new Singleton();//私有構造,防止被實例化private Singleton(…

SDUT-2121_數據結構實驗之鏈表六:有序鏈表的建立

數據結構實驗之鏈表六:有序鏈表的建立 Time Limit: 1000 ms Memory Limit: 65536 KiB Problem Description 輸入N個無序的整數,建立一個有序鏈表,鏈表中的結點按照數值非降序排列,輸出該有序鏈表。 Input 第一行輸入整數個數N&…

事件映射 消息映射_映射幻影收費站

事件映射 消息映射When I was a child, I had a voracious appetite for books. I was constantly visiting the library and picking new volumes to read, but one I always came back to was The Phantom Tollbooth, written by Norton Juster and illustrated by Jules Fei…

前端代碼調試常用

轉載于:https://www.cnblogs.com/tabCtrlShift/p/9076752.html

Pytorch中BN層入門思想及實現

批歸一化層-BN層(Batch Normalization) 作用及影響: 直接作用:對輸入BN層的張量進行數值歸一化,使其成為均值為零,方差為一的張量。 帶來影響: 1.使得網絡更加穩定,結果不容易受到…

JDK源碼學習筆記——TreeMap及紅黑樹

找了幾個分析比較到位的,不再重復寫了…… Java 集合系列12之 TreeMap詳細介紹(源碼解析)和使用示例 【Java集合源碼剖析】TreeMap源碼剖析 java源碼分析之TreeMap基礎篇 關于紅黑樹: Java數據結構和算法(十一)——紅黑樹 【數據結…

匿名內部類和匿名類_匿名schanonymous

匿名內部類和匿名類Everybody loves a fad. You can pinpoint someone’s generation better than carbon dating by asking them what their favorite toys and gadgets were as a kid. Tamagotchi and pogs? You were born around 1988, weren’t you? Coleco Electronic Q…

Pytorch框架中SGD&Adam優化器以及BP反向傳播入門思想及實現

因為這章內容比較多,分開來敘述,前面先講理論后面是講代碼。最重要的是代碼部分,結合代碼去理解思想。 SGD優化器 思想: 根據梯度,控制調整權重的幅度 公式: 權重(新) 權重(舊) - 學習率 梯度 Adam…

朱曄和你聊Spring系列S1E3:Spring咖啡罐里的豆子

標題中的咖啡罐指的是Spring容器,容器里裝的當然就是被稱作Bean的豆子。本文我們會以一個最基本的例子來熟悉Spring的容器管理和擴展點。閱讀PDF版本 為什么要讓容器來管理對象? 首先我們來聊聊這個問題,為什么我們要用Spring來管理對象&…

ab實驗置信度_為什么您的Ab測試需要置信區間

ab實驗置信度by Alos Bissuel, Vincent Grosbois and Benjamin HeymannAlosBissuel,Vincent Grosbois和Benjamin Heymann撰寫 The recent media debate on COVID-19 drugs is a unique occasion to discuss why decision making in an uncertain environment is a …

基于Pytorch的NLP入門任務思想及代碼實現:判斷文本中是否出現指定字

今天學了第一個基于Pytorch框架的NLP任務: 判斷文本中是否出現指定字 思路:(注意:這是基于字的算法) 任務:判斷文本中是否出現“xyz”,出現其中之一即可 訓練部分: 一&#xff…

erlang下lists模塊sort(排序)方法源碼解析(二)

上接erlang下lists模塊sort(排序)方法源碼解析(一),到目前為止,list列表已經被分割成N個列表,而且每個列表的元素是有序的(從大到小) 下面我們重點來看看mergel和rmergel模塊,因為我…

洛谷P4841 城市規劃(多項式求逆)

傳送門 這題太珂怕了……如果是我的話完全想不出來…… 題解 1 //minamoto2 #include<iostream>3 #include<cstdio>4 #include<algorithm>5 #define ll long long6 #define swap(x,y) (x^y,y^x,x^y)7 #define mul(x,y) (1ll*(x)*(y)%P)8 #define add(x,y) (x…

支撐阻力指標_使用k表示聚類以創建支撐和阻力

支撐阻力指標Note from Towards Data Science’s editors: While we allow independent authors to publish articles in accordance with our rules and guidelines, we do not endorse each author’s contribution. You should not rely on an author’s works without seek…

高版本(3.9版本)python在anaconda安裝opencv庫及skimage庫(scikit_image庫)諸多問題解決辦法

今天開始CV方向的學習&#xff0c;然而剛拿到基礎代碼的時候發現 from skimage.color import rgb2gray 和 import cv2標紅&#xff08;這里是因為我已經配置成功了&#xff0c;所以沒有紅標&#xff09;&#xff0c;我以為是單純兩個庫沒有下載&#xff0c;去pycharm中下載ski…

python 實現斐波那契數列

# coding:utf8 __author__ blueslidef fun(arg1,arg2,stop):if arg10:print(arg1,arg2)arg3 arg1arg2print(arg3)if arg3<stop:arg3 fun(arg2,arg3,stop)fun(0,1,100)轉載于:https://www.cnblogs.com/bluesl/p/9079705.html

單機安裝ZooKeeper

2019獨角獸企業重金招聘Python工程師標準>>> zookeeper下載、安裝以及配置環境變量 本節介紹單機的zookeeper安裝&#xff0c;官方下載地址如下&#xff1a; https://archive.apache.org/dist/zookeeper/ 我這里使用的是3.4.11版本&#xff0c;所以找到相應的版本點…

均線交易策略的回測 r_使用r創建交易策略并進行回測

均線交易策略的回測 rR Programming language is an open-source software developed by statisticians and it is widely used among Data Miners for developing Data Analysis. R can be best programmed and developed in RStudio which is an IDE (Integrated Development…

opencv入門課程:彩色圖像灰度化和二值化(采用skimage庫和opencv庫兩種方法)

用最簡單的辦法實現彩色圖像灰度化和二值化&#xff1a; 首先采用skimage庫&#xff08;skimage庫現在在scikit_image庫中&#xff09;實現&#xff1a; from skimage.color import rgb2gray import numpy as np import matplotlib.pyplot as plt""" skimage庫…

SVN中Revert changes from this revision 跟Revert to this revision

譬如有個文件&#xff0c;有十個版本&#xff0c;假定版本號是1&#xff0c;2&#xff0c;3&#xff0c;4&#xff0c;5&#xff0c;6&#xff0c;7&#xff0c;8&#xff0c;9&#xff0c;10。Revert to this revision&#xff1a; 如果是在版本6這里點擊“Revert to this rev…