機器學習實踐三---神經網絡學習

Neural Networks

在這個練習中,將實現神經網絡BP算法,練習的內容是手寫數字識別。

Visualizing the data

這次數據還是5000個樣本,每個樣本是一張20*20的灰度圖片
fig, ax_array = plt.subplots(nrows=10, ncols=10, figsize=(6, 4))for row in range(10):for column in range(10):ax_array[row, column].matshow(sample_images[10 * row + column].reshape((20, 20)).T, cmap='gray')ax_array[row, column].axis('off')plt.show()returndata = loadmat("ex4data1.mat")
X = data['X']
y = data['y']m = X.shape[0]
rand_sample_num = np.random.permutation(m)
sample_images = X[rand_sample_num[0:100], :]
display_data(sample_images)

Model representation

這是一個簡單的神經網絡,輸入層、隱藏層、輸出,樣本圖片是20*20,所以輸入層是400個單元,(再加上一個額外偏差單元),第二層隱藏層是25個單元, 輸出層是10個單元。從上面的數據顯示中有兩個變量X 和y。
ex4weights.mat 中提供了訓練好的網絡參數theta1, theta2,
theta1 has size 25 x 401
theta2 has size 10 x 26
在這里插入圖片描述

Feedforward and cost function

為了最后的輸出,我們將標簽值也就是數字從0到9, 轉化為one-hot 碼

from sklearn.preprocessing import OneHotEncoder
def to_one_hot(y):encoder =  OneHotEncoder(sparse=False)  # return a array instead of matrixy_onehot = encoder.fit_transform(y.reshape(-1,1))return y_onehot

加載數據

X, label_y = load_mat('ex4data1.mat')
X = np.insert(X, 0, 1, axis=1)
y = to_one_hot(label_y)

load weight

def load_weight(path):data = loadmat(path)return data['Theta1'], data['Theta2']t1, t2 = load_weight('ex4weights.mat')

theta 轉化
因為opt.minimize傳參問題,我們這里對theta進行平坦化

# 展開
def unrool(var1, var2):return np.r_[var1.flatten(), var2.flatten()]
# 分開矩陣化
def rool(array):return array[:25*401].reshape(25, 401), array[25*401:].reshape(10, 26)

Feedforward Regularized cost function

這里主要是前饋傳播 和 代價函數的一些邏輯,正則化為了預防高方差問題。

def sigmoid(z):return 1 / (1 + np.exp(-z))#前饋傳播
def feed_forward(theta, X):theta1, theta2 = rool(theta)a1 = Xz2 = a1.dot(theta1.T)a2 = np.insert(sigmoid(z2), 0, 1, axis=1)z3 = a2.dot(theta2.T)a3 = sigmoid(z3)return a1, z2, a2, z3, a3# a1, z2, a2, z3, h = feed_forward(t1, t2, X)def cost(theta, X, y):a1, z2, a2, z3, h = feed_forward(theta, X)J = -y * np.log(h) - (1-y) * np.log(1 - h)return J# Implement Regularization
def regularized_cost(theta, X, y, l=1):theta1, theta2 = rool(theta)temp_theta1 = theta1[:, 1:]temp_theta2 = theta2[:, 1:]reg = temp_theta1.flatten().T.dot(temp_theta1.flatten()) + temp_theta2.flatten().T.dot(temp_theta2.flatten())regularized_theta = l / (2 * len(X)) * reg return regularized_theta + cost(theta, X, y)

Backprogation

反向傳播算法,是機器學習比較難推理的算法了, 也是最重要的算法,為了得到最優的theta值, 通過進行反向傳播,來不斷跟新theta值, 當然還有一些超參數,如lambda、a 、訓練迭代次數,如果進行adam、Rmsprop等優化學習效率算法,還有有一些其他的超參數。


# random initalization# 梯度
def gradient(theta, X, y):theta1, theta2 = rool(theta)a1, z2, a2, z3, h = feed_forward(theta, X)d3 = h - yd2 = d3.dot(theta2[:, 1:]) * sigmoid_gradient(z2)D2 = d3.T.dot(a2)D1 = d2.T.dot(a1)D = (1 / len(X)) * unrool(D1, D2)return D

Sigmoid gradient

也就是對sigmoid 函數求導

def sigmoid_gradient(z):return sigmoid(z) * (1 - sigmoid(z))

Random initialization

初始化參數,我們一般使用隨機初始化np.random.randn(-2,2),生成高斯分布,再乘以一個小的數,這樣把它初始化為很小的隨機數,
這樣直觀地看就相當于把訓練放在了邏輯回歸的直線部分進行開始,初始化參數還可以盡量避免梯度消失和梯度爆炸的問題。

def random_init(size):return np.random.randn(-2, 2, size) * 0.01

Backporpagation

Regularized Neural Networks

正則化神經網絡

def regularized_gradient(theta, X, y, l=1):a1, z2, a2, z3, h = feed_forward(theta, X)D1, D2 = rool(gradient(theta, X, y))t1[:, 0] = 0t2[:, 0] = 0reg_D1 = D1 + (l / len(X)) * t1reg_D2 = D2 + (l / len(X)) * t2return unrool(reg_D1, reg_D2)

Learning parameters using fmincg

調優參數

def nn_training(X, y):init_theta = random_init(10285)  # 25*401 + 10*26res = opt.minimize(fun=regularized_cost,x0=init_theta,args=(X, y, 1),method='TNC',jac=regularized_gradient,options={'maxiter': 400})return resres = nn_training(X, y)

準確率
def accuracy(theta, X, y):
_, _, _, _, h = feed_forward(res.x, X)
y_pred = np.argmax(h, axis=1) + 1
print(classification_report(y, y_pred))

accuracy(res.x, X, label_y)

Visualizing the hidden layer

隱藏層顯示跟輸入層顯示差不多

def plot_hidden(theta):t1, _ = rool(theta)t1 = t1[:, 1:]fig, ax_array = plt.subplots(5, 5, sharex=True, sharey=True, figsize=(6, 6))for r in range(5):for c in range(5):ax_array[r, c].matshow(t1[r * 5 + c].reshape(20, 20), cmap='gray_r')plt.xticks([])plt.yticks([])plt.show()plot_hidden(res.x)

super parameter lambda update

神經網絡是非常強大的模型,可以形成高度復雜的決策邊界。如果沒有正則化,神經網絡就有可能“過度擬合”一個訓練集,從而使它在訓練集上獲得接近100%的準確性,但在以前沒有見過的新例子上則不會。你可以設置較小的正則化λ值和MaxIter參數高的迭代次數為自己看到這個結果。

本文來自互聯網用戶投稿,該文觀點僅代表作者本人,不代表本站立場。本站僅提供信息存儲空間服務,不擁有所有權,不承擔相關法律責任。
如若轉載,請注明出處:http://www.pswp.cn/news/388945.shtml
繁體地址,請注明出處:http://hk.pswp.cn/news/388945.shtml
英文地址,請注明出處:http://en.pswp.cn/news/388945.shtml

如若內容造成侵權/違法違規/事實不符,請聯系多彩編程網進行投訴反饋email:809451989@qq.com,一經查實,立即刪除!

相關文章

Microsoft Expression Blend 2 密鑰,key

Microsoft Expression Blend 2 密鑰,key,序列TJ2R3-WHW22-B848T-B78YJ-HHJWJ號

ethereumjs/ethereumjs-common-3-test

查看test能夠讓你更好滴了解其API文檔的使用 ethereumjs-common/tests/chains.js const tape require(tape) const Common require(../index.js)tape([Common]: Initialization / Chain params, function (t) {t.test(Should initialize with chain provided, function (st) …

mysql修改_mysql修改表操作

一: 修改表信息1.修改表名alter table test_a rename to sys_app;2.修改表注釋alter table sys_application comment 系統信息表;二:修改字段信息1.修改字段類型和注釋alter table sys_application modify column app_name varchar(20) COMMENT 應用的名…

機器學習實踐四--正則化線性回歸 和 偏差vs方差

這次實踐的前半部分是,用水庫水位的變化,來預測大壩的出水量。 給數據集擬合一條直線,可能得到一個邏輯回歸擬合,但它并不能很好地擬合數據,這是高偏差(high bias)的情況,也稱為“欠…

深度學習 推理 訓練_使用關系推理的自我監督學習進行訓練而無需標記數據

深度學習 推理 訓練背景與挑戰📋 (Background and challenges 📋) In a modern deep learning algorithm, the dependence on manual annotation of unlabeled data is one of the major limitations. To train a good model, usually, we have to prepa…

Android strings.xml中定義字符串顯示空格

<string name"str">字 符 串</string> 其中 就表示空格。如果直接在里面鍵入空格&#xff0c;無論多少空格都只會顯示一個。 用的XML轉義字符記錄如下&#xff1a; 空格&#xff1a; <string name"out_bound_submit">出 庫</strin…

WCF開發入門的六個步驟

在這里我就用一個據于一個簡單的場景&#xff1a;服務端為客服端提供獲取客戶信息的一個接口讀取客戶信息&#xff0c;來完成WCF開發入門的六個步驟。 1. 定義WCF服務契約 A. 項目引用節點右鍵添加引用。 B. 在代碼文件里&#xff0c;添加以下命名空間的引…

LOJ116 有源匯有上下界最大流(上下界網絡流)

考慮有源匯上下界可行流&#xff1a;由匯向源連inf邊&#xff0c;那么變成無源匯圖&#xff0c;按上題做法跑出可行流。此時該inf邊的流量即為原圖中該可行流的流量。因為可以假裝把加上去的那些邊的流量放回原圖。 此時再從原來的源向原來的匯跑最大流。超源超匯相關的邊已經流…

CentOS 7 使用 ACL 設置文件權限

Linux 系統標準的 ugo/rwx 集合并不允許為不同的用戶配置不同的權限&#xff0c;所以 ACL 便被引入了進來&#xff0c;為的是為文件和目錄定義更加詳細的訪問權限&#xff0c;而不僅僅是這些特別指定的特定權限。 ACL 可以為每個用戶&#xff0c;每個組或不在文件所屬組中的用…

機器學習實踐五---支持向量機(SVM)

之前已經學到了很多監督學習算法&#xff0c; 今天的監督學習算法是支持向量機&#xff0c;與邏輯回歸和神經網絡算法相比&#xff0c;它在學習復雜的非線性方程時提供了一種更為清晰&#xff0c;更強大的方式。 Support Vector Machines SVM hypothesis Example Dataset 1…

作為微軟技術.net 3.5的三大核心技術之一的WCF雖然沒有WPF美麗的外觀

作為微軟技術.net 3.5的三大核心技術之一的WCF雖然沒有WPF美麗的外觀 但是它卻是我們開發分布式程序的利器 但是目前關于WCF方面的資料相當稀少 希望我的這一系列文章可以幫助大家盡快入門 下面先介紹一下我的開發環境吧 操作系統&#xff1a;windows vista business版本 編譯器…

服務器安裝mysql_阿里云服務器上安裝MySQL

關閉防火墻和selinuxCentOS7以下&#xff1a;service iptables stopsetenforce 0CentOS7.xsystemctl stop firewalldsystemctl disable firewalldsystemctl status firewalldvi /etc/selinux/config把SELINUXenforcing 改成 SELINUXdisabled一、安裝依賴庫yum -y install make …

在PyTorch中轉換數據

In continuation of my previous post ,we will keep on deep diving into basic fundamentals of PyTorch. In this post we will discuss about ways to transform data in PyTorch.延續我以前的 發布后 &#xff0c;我們將繼續深入研究PyTorch的基本原理。 在這篇文章中&a…

「網絡流24題」試題庫問題

傳送門&#xff1a;>Here< 題意&#xff1a;有K種類型的共N道試題用來出卷子&#xff0c;要求卷子須有M道試題。已知每道題屬于p種類型&#xff0c;每種類型的試題必須有且僅有k[i]道。現問出這套試卷的一種具體方案 思路分析 昨天打了一天的Dinic&#xff0c;今天又打了…

機器學習實踐六---K-means聚類算法 和 主成分分析(PCA)

在這次練習中將實現K-means 聚類算法并應用它壓縮圖片&#xff0c;第二部分&#xff0c;將使用主成分分析算法去找到一個臉部圖片的低維描述。 K-means Clustering Implementing K-means K-means算法是一種自動將相似的數據樣本聚在一起的方法,K-means背后的直觀是一個迭代過…

航海家軟件公式全破解

水手突破 上趨勢:MA(LOW,20)*1.2,color0080ff,linethick2;次上趨勢:MA(LOW,20)*1.1,COLORYELLOW;次下趨勢:MA(HIGH,20)*0.9,COLORWHITE;下趨勢:MA(HIGH,20)*0.8,COLORGREEN,linethick2;ZD:(C-REF(C,1))/REF(C,1)*100;HDZF:(HHV(H,20)-C)/(HHV(H,20)-LLV(L,20));趨勢強度:IF(C&g…

打包 壓縮 命令tar zip

2019獨角獸企業重金招聘Python工程師標準>>> 打包 壓縮 命令tar zip tar語法 #壓縮 tar -czvf ***.tar.gz tar -cjvf ***.tar.bz2 #解壓縮 tar -xzvf ***.tar.gz tar -xjvf ***.tar.bz2 tar [主選項輔選項] 文件或目錄 主選項是必須要有的&#xff0c;它告訴tar要做…

mysql免安裝5.7.17_mysql免安裝5.7.17數據庫配置

首先要有 mysql-5.7.10-winx64環境: mysql-5.7.10-winx64 win10(64位)配置環境變量&#xff1a;1、把mysql-5.7.10-winx64放到D盤&#xff0c;進入D\mysql-5.7.10-winx64\bin目錄&#xff0c;復制路徑&#xff0c;配置環境變量&#xff0c;在path后面添加D\mysql-5.7.10-winx6…

tidb數據庫_異構數據庫復制到TiDB

tidb數據庫This article is based on a talk given by Tianshuang Qin at TiDB DevCon 2020.本文基于Tianshuang Qin在 TiDB DevCon 2020 上的演講 。 When we convert from a standalone system to a distributed one, one of the challenges is migrating the database. We’…

機器學習實踐七----異常檢測和推薦系統

Anomaly detection 異常檢測是機器學習中比較常見的應用&#xff0c;它主要用于非監督學習問題&#xff0c;從某些角度看&#xff0c; 它又類似于一些監督學習問題。 什么是異常檢測&#xff1f;來看幾個例子&#xff1a; 例1. 假設是飛機引擎制造商&#xff0c; 要對引擎進行…