python 邏輯回歸準確率是1,Python利用邏輯回歸模型解決MNIST手寫數字識別問題詳解...

本文實例講述了Python利用邏輯回歸模型解決MNIST手寫數字識別問題。分享給大家供大家參考,具體如下:

1、MNIST手寫識別問題

MNIST手寫數字識別問題:輸入黑白的手寫阿拉伯數字,通過機器學習判斷輸入的是幾。可以通過TensorFLow下載MNIST手寫數據集,通過import引入MNIST數據集并進行讀取,會自動從網上下載所需文件。

%matplotlib inline

import tensorflow as tf

import tensorflow.examples.tutorials.mnist.input_data as input_data

mnist=input_data.read_data_sets('MNIST_data/',one_hot=True)

import matplotlib.pyplot as plt

def plot_image(image): #圖片顯示函數

plt.imshow(image.reshape(28,28),cmap='binary')

plt.show()

print("訓練集數量:",mnist.train.num_examples,

"特征值組成:",mnist.train.images.shape,

"標簽組成:",mnist.train.labels.shape)

batch_images,batch_labels=mnist.train.next_batch(batch_size=10) #批量讀取數據

print(batch_images.shape,batch_labels.shape)

print('標簽值:',np.argmax(mnist.train.labels[1000]),end=' ') #np.argmax()得到實際值

print('獨熱編碼表示:',mnist.train.labels[1000])

plot_image(mnist.train.images[1000]) #顯示數據集中第1000張圖片

dxkgvjnepbg.jpg

5ovkbdwatcd.jpg

輸出訓練集 的數量有55000個,并打印特征值的shape為(55000,784),其中784代表每張圖片由28*28個像素點組成,由于是黑白圖片,每個像素點只有黑白單通道,即通過784個數可以描述一張圖片的特征值。可以將圖片在Jupyter中輸出,將784個特征值reshape為28×28的二維數組,傳給plt.imshow()函數,之后再通過show()輸出。

MNIST提供next_batch()方法用于批量讀取數據集,例如上面批量讀取10個對應的images與labels數據并分別返回。該方法會按順序一直往后讀取,直到結束后會自動打亂數據,重新繼續讀取。

在打開mnist數據集時,第二個參數設置one_hot,表示采用獨熱編碼方式打開。獨熱編碼是一種稀疏向量,其中一個元素為1,其他元素均為0,常用于表示有限個可能的組合情況。例如數字6的獨熱編碼為第7個分量為1,其他為0的數組。可以通過np.argmax()函數返回數組最大值的下標,即獨熱編碼表示的實際數字。通過獨熱編碼可以將離散特征的某個取值對應歐氏空間的某個點,有利于機器學習中特征之間的距離計算

數據集的劃分,一種劃分為訓練集用于模型的訓練,測試集用于結果的測試,要求集合數量足夠大,而且具有代表性。但是在多次執行后,會導致模型向測試集數據進行擬合,從而導致測試集數據失去了測試的效果。因此將數據集進一步劃分為訓練集、驗證集、測試集,將訓練后的模型用驗證集驗證,當多次迭代結束之后再拿測試集去測試。MNIST數據集中的訓練集為mnist.train,驗證集為mnist.validation,測試集為mnist.test

2、邏輯回歸

與線性回歸相對比,房價預測是根據多個輸入參數x與對應權重w相乘再加上b得到線性的輸出房價。而還有許多問題的輸出是非線性的、控制在[0,1]之間的,比如判斷郵件是否為垃圾郵件,手寫數字為0~9等,邏輯回歸就是用于處理此類問題。例如電子郵件分類器輸出0.8,表示該郵件為垃圾郵件的概率是0.8.

邏輯回歸通過Sigmoid函數保證輸出的值在[0,1]之間,該函數可以將全體實數映射到[0,1],從而將線性的輸出轉換為[0,1]的數。其定義與圖像如下:

iwq25lppscc.jpg

c23nl3szzjw.jpg

在邏輯回歸中如果采用均方差的損失函數,帶入sigmoid會得到一個非凸函數,這類函數會有多個極小值,采用梯度下降法便無法求得最優解。因此在邏輯回歸中采用對數損失函數

4kg2xepsv34.jpg,其中y是特征值x的標簽,y'是預測值。

在手寫數字識別中,通過單層神經元產生連續的輸出值y,將y再輸入到softmax層處理,經過函數計算將結果映射為0~9每個數字對應的概率,概率越大表示該圖片越像某個數字,所有數字的概率之和為1

ceubn3lnadm.jpg

交叉熵損失函數:交叉熵用于刻畫兩個概率分布之間的距離

xniabmaaf3s.jpg,其中p代表正確答案,q代表預測值,交叉熵越小距離越近,從而模型的預測越準確。例如正確答案為(1,0,0),甲模型預測為(0.5,0.2,0.3),其交叉熵=-1*log0.5≈0.3,乙模型(0.7,0.1,0.2),其交叉熵=-1*log0.7≈0.15,所以乙模型預測更準確

模型的訓練

首先定義二維浮點數占位符x、y,以及二維參數變量W、b并隨機賦初值。之后定義前向計算為向量x與W對應叉乘再加b,并將得到的線性結果經過softmax處理得到獨熱編碼預測值。

之后定義準確率accuracy,其值為預測值pred與真實值y相等個數來衡量

接下來初始化變量、設置超參數,并定義損失函數、優化器,之后開始訓練。每輪訓練中分批次讀取數據進行訓練,每輪訓練結束后輸出損失與準確率。

import numpy as np

import tensorflow as tf

import tensorflow.examples.tutorials.mnist.input_data as input_data

mnist=input_data.read_data_sets('MNIST_data/',one_hot=True)

import matplotlib.pyplot as plt

#定義占位符、變量、前向計算

x=tf.placeholder(tf.float32,[None,784],name='x')

y=tf.placeholder(tf.float32,[None,10],name='y')

W=tf.Variable(tf.random_normal([784,10]),name='W')

b=tf.Variable(tf.zeros([10]),name='b')

forward=tf.matmul(x,W)+b

pred=tf.nn.softmax(forward) #通過softmax將線性結果分類處理

#計算預測值與真實值的匹配個數

correct_prediction=tf.equal(tf.argmax(pred,1),tf.argmax(y,1))

#將上一步得到的布爾值轉換為浮點數,并求平均值,得到準確率

accuracy=tf.reduce_mean(tf.cast(correct_prediction,tf.float32))

ss=tf.Session()

init=tf.global_variables_initializer()

ss.run(init)

#超參數設置

train_epochs=50

batch_size=100 #每個批次的樣本數

batch_num=int(mnist.train.num_examples/batch_size) #一輪需要訓練多少批

learning_rate=0.01

#定義交叉熵損失函數、梯度下降優化器

loss_function=tf.reduce_mean(-tf.reduce_sum(y*tf.log(pred),reduction_indices=1))

optimizer=tf.train.GradientDescentOptimizer(learning_rate).minimize(loss_function)

for epoch in range(train_epochs):

for batch in range(batch_num): #分批次讀取數據進行訓練

xs,ys=mnist.train.next_batch(batch_size)

ss.run(optimizer,feed_dict={x:xs,y:ys})

#每輪訓練結束后通過帶入驗證集的數據,檢測模型的損失與準去率

loss,acc=ss.run([loss_function,accuracy],\

feed_dict={x:mnist.validation.images,y:mnist.validation.labels})

print('第%2d輪訓練:損失為:%9f,準確率:%.4f'%(epoch+1,loss,acc))

從每輪訓練結果可以看出損失在逐漸下降,準確率在逐步上升。

y0lyn4vmxcs.jpg

結果預測

使用訓練好的模型對測試集中的數據進行預測,即將mnist.test.images數據帶入去求pred的值。

為了使結果更便于顯示,可以借助plot函數庫將圖片數據顯示出來,并配以文字label與predic的值。首先通過plt.gcf()得到一副圖像資源并設置其大小。再通過plt.subplot(5,5,index+1)函數將其劃分為5×5個子圖,遍歷第index+1個子圖,分別將圖像資源繪制到子圖,通過set_title()設置每個子圖的title顯示內容。子圖繪制結束后顯示整個圖片,并調用函數傳入圖片、標簽、預測值等參數。

prediction=ss.run(tf.argmax(pred,1),feed_dict={x:mnist.test.images})

def show_result(images,labels,prediction,index,num=10): #繪制圖形顯示預測結果

pic=plt.gcf() #獲取當前圖像

pic.set_size_inches(10,12) #設置圖片大小

for i in range(0,num):

sub_pic=plt.subplot(5,5,i+1) #獲取第i個子圖

#將第index個images信息顯示到子圖上

sub_pic.imshow(np.reshape(images[index],(28,28)),cmap='binary')

title="label:"+str(np.argmax(labels[index])) #設置子圖的title內容

if len(prediction)>0:

title+=",predict:"+str(prediction[index])

sub_pic.set_title(title,fontsize=10)

sub_pic.set_xticks([]) #設置x、y坐標軸不顯示

sub_pic.set_yticks([])

index+=1

plt.show()

show_result(mnist.test.images,mnist.test.labels,prediction,10)

運行結果如下,可以看到預測的結果大多準確

ekasv0igwm1.jpg

更多關于Python相關內容感興趣的讀者可查看本站專題:《Python數據結構與算法教程》、《Python加密解密算法與技巧總結》、《Python編碼操作技巧總結》、《Python函數使用技巧總結》、《Python字符串操作技巧匯總》及《Python入門與進階經典教程》

希望本文所述對大家Python程序設計有所幫助。

本文來自互聯網用戶投稿,該文觀點僅代表作者本人,不代表本站立場。本站僅提供信息存儲空間服務,不擁有所有權,不承擔相關法律責任。
如若轉載,請注明出處:http://www.pswp.cn/news/533195.shtml
繁體地址,請注明出處:http://hk.pswp.cn/news/533195.shtml
英文地址,請注明出處:http://en.pswp.cn/news/533195.shtml

如若內容造成侵權/違法違規/事實不符,請聯系多彩編程網進行投訴反饋email:809451989@qq.com,一經查實,立即刪除!

相關文章

php面試題接口方面,php面試題6 - osc_xb4v1nhl的個人空間 - OSCHINA - 中文開源技術交流社區...

php面試題6一、總結二、php面試題6寫出你認為語言中的高級函數:1)preg_replace()2)preg_match()3) ignore_user_abort()4) debug_backtrace()5) date_default_timezone_set(“PRC”)6) get_class_methods() 得到類的方法名的數組7) preg_split() 字符串分割成數組8)json_encode…

軌道車輛垂向振動Matlab建模與仿真,基于matlab/simulink的車輛建模與故障分析

隨著鐵路行業高速發展,列車運行速度逐漸提高,鐵路安全越來越受到人們的重視,如何保證鐵道車輛運行安全及其故障監測成為一個亟待解決的重大課題。客車車輛在結構上的故障主要有一系彈簧斷裂、減振器失效、空氣彈簧漏氣、高圓彈簧斷裂、車輪踏面擦傷、軸承故障以及蛇形減震器故障…

關于php的問題有哪些,關于PHP的報錯問題?

關于這個報錯的表格我不知到怎么去做,下面的是代碼:header(content-type:text/html;charsetutf-8);session_start();include_once ../include/conf.php;include_once ../include/func.php;include_once ../include/mysql.func.php;check_login();$pageSi…

oracle消耗內存的查詢,在AIX中計算ORACLE消耗的私有內存總數

一早就收到兄弟伙發的QQ信息,關于aix中oracle內存計算的內容The RSS number is equal to the sum of the number of working-segment pages in memory times 4 andthe code-segment pages in memory times 4.The TRS number is equal to just the code-segment page…

php讀取ds18b20,DS18B20_單總線協議

.H文件#ifndef _ONEWIRE_H#define _ONEWIRE_H#include "STC15F2K60S2.H"#include #define OW_SKIP_ROM 0xcc#define DS18B20_CONVERT 0x44#define DS18B20_READ 0xbe//IC引腳定義sbit DQ P1^4;//函數聲明extern void Delay_OneWire(unsigned int t);extern void Wri…

oracle官方文檔查看方法,oracle官方文檔_查看初始化參數(舉例)

深藍的blog:http://blog.csdn.net/huangyanlong/article/details/46864217記錄了一下,使用oracle11g聯機文檔,查看初始化參數的步驟。如果想查看,可以修改的初始化參數的概念信息,可以點擊“ChangingParameter Values …

matlab usewhitebg,Matlab的:geo??show的網格和框架

對於問題1和問題2,原因是軸總是在圖的後面。因此,一種解決方案是在當前的軸上添加新軸並顯示網格,框和自定義刻度。對於問題3,我使用regexprep以取代S後綴負緯度(同上爲經度)。我唯一的問題是經度0將是0E,緯度0,0N。這…

oracle p l,使用P.A.L制作便攜軟件 (一) 基本原理 | 么么噠擁有者

因愛好自學所得,并非專業,此處只是拋磚引玉,歡迎相互交流、學習、提高,辛苦碼字不易,如轉載望保留鏈接出處。簡單介紹:P.A.L是PortableApps.com Launcher的簡稱,它是PortableApps.com開發的便攜…

oracle form執行后左上角沒出現oracle標記,oracle form學習筆記

新增form步驟打開模板TEMPLATE,將其改成自己所要的名稱,刪除Data Blacks中的BLOCKNAME,DETAILBLOCK,刪除Canvases中的BLOCKNAME,刪除Windows中的BLOCKNAME,新增自己的Windows,Canvases,DateBlacks,在form級別的PRE-FOR…

linux 建oracle分區表,Oracle 10g 11g分區表創建舉例

1.3. 創建其他類型分區表1.3.1. 用多列分區鍵創建范圍分區表SQL> create table aning_mutilcol_range2 (aning_id number,3 aning_name varchar2(100),4 aning_year number,5 aning_month number,6 aning_day number,7 aning_amount number8 )9 partition by range (aning_y…

php carbon 連續日期,日期及時間處理包 Carbon 在 Laravel 中的簡單使用

在編寫 PHP 應用時經常需要處理日期和時間,這篇文章帶你了解一下 Carbon – 繼承自 PHP DateTime 類的 API 擴展,它使得處理日期和時間更加簡單。Laravel 中默認使用的時間處理類就是 Carbon。namespace Carbon;class Carbon extends \DateTime{// code …

chmod g s oracle,chmod

chmod(1)名稱chmod - 更改文件的權限模式用法概要chmod [-fR] absolute-mode file...chmod [-fR] symbolic-mode-list file...chmod [-fR] acl_operation file...chmod [-fR] [- named_attribute]...attribute_specification_list file...描述chmod 實用程序可更改或分配文件的…

linux lzo 壓縮文件,Linux常用壓縮和解壓命令

.tar 解包 tar xvf filename.tar.tar 打包 tar cvf filename.tar dirname.gz 解壓1 gunzip filename.gz.gz 解壓2 gzip -d filename.gz.gz 壓縮 gzip filename.tar.gz 和 .tgz 解壓 tar zxvf filename.tar.gz.tar.gz 和 .tgz 壓縮 tar zcvf filename.tar.gz dirname.bz2 解壓1 …

linux進程cpu時間片,能講一下在Linux系統中時間片是怎么分配的還有優先級的具體算法是...

該樓層疑似違規已被系統折疊 隱藏此樓查看此樓圖 1 RT-Linux結構RT -Linux的關鍵技術是通過軟件來模擬硬件的中斷控制器。當Linux系統要封鎖CPU的中斷時時,RT-Linux中的實時子系統會截取到這個請求,把它記錄下來,而實際上并不真正封鎖硬件中斷…

linux中進行遠程服務器連機可以采用telnet,端口號為,使用telnet測試指定端口的連通性...

原標題:使用telnet測試指定端口的連通性telnet 是一個閹割版的 ssh ,它數據不加密,數據容易被盜竊,也容易受中間人攻擊,所以默認情況下 telnet 端口是必須要被關閉的。telnet為用戶提供了在本地計算機上完成遠程主機工…

linux xd命令,看Linux文件的內容:用cat,less,more,head,tail,nl,od,xxd,gv,xdvi命令

使用命令在Linux系統中查看文件的內容是Linux管理員的基本技能之一,在Linux中,有許多應用程序以不同的方式顯示文件內容。您可以使用cat、less、more、head、tail、nl、od、xxd、gv、xdvi命令來查看文本文件或任何其他文件。為了對此進行測試&#xff0c…

linux遠程拷貝免手動輸入密碼,scp遠程拷貝避免輸入密碼

使用scp遠程拷貝文件到指定服務器上,在客戶端生成密鑰放在需要驗證的服務器上,這樣再次連接后直接登陸,避免輸入密碼。設定場景我們需要將tomcat服務器(client1)192.168.30.20 上的catalina.out日志文件,每天使用指定用戶拷貝到日…

玩轉linux文件描述符和重定向,玩轉Linux文件描述符和重定向

本文介紹linux中文件描述符與重定向的相關知識,文件描述符是與文件輸入、輸出相關聯的整數,它們用來跟蹤已打開的文件。有需要的朋友參考下。原文出處:linux下的文件描述符是與文件輸入、輸出相關聯的整數。它們用來跟蹤已打開的文件。最常見…

linux哪個指令可以設定使用者的密碼,linux期末考試練習題 2

一、單項選擇題1、下面不是對Linux操作系統特點描述的是()A、良好的可移植性B、單用戶C、多用戶D、設備獨立性2、查看創建目錄命令mkdir的幫助文檔可以使用()A、mkdir -hB、man mkdirC、help mkdirD、info mkdir3、用標準的輸出重定向(>)像”>file01”能使文件file01的數…

linux腳本格式模板,Linux Shell 常見的命令行格式簡明總結

#在后臺執行 cmd 指令cmd &#命令序列. 在同一行執行多個命令cmd1 ; cmd2#在當前 shell 中以一組的形式執行多個命令{ cmd1 ; cmd2 ; }#在子 shell 中以一組的形式執行多個命令(cmd1 ; cmd2)#管道. 以 cmd1 的執行輸出作為 cmd2 的輸入cmd1 | cmd2#命令替換. 以 cmd2 的執行…