花卉分類CNN

tensorflow升級到1.0之后,增加了一些高級模塊: 如tf.layers, tf.metrics, 和tf.losses,使得代碼稍微有些簡化。

任務:花卉分類

版本:tensorflow 1.3

數據:http://download.tensorflow.org/example_images/flower_photos.tgz

花總共有五類,分別放在5個文件夾下。

閑話不多說,直接上代碼,希望大家能看懂:)

?

# -*- coding: utf-8 -*-
from skimage import io,transform
import glob
import os
import tensorflow as tf
import numpy as np
import timepath='e:/flower/'#將所有的圖片resize成100*100
w=100
h=100
c=3#讀取圖片
def read_img(path):cate=[path+x for x in os.listdir(path) if os.path.isdir(path+x)]imgs=[]labels=[]for idx,folder in enumerate(cate):for im in glob.glob(folder+'/*.jpg'):print('reading the images:%s'%(im))img=io.imread(im)img=transform.resize(img,(w,h))imgs.append(img)labels.append(idx)return np.asarray(imgs,np.float32),np.asarray(labels,np.int32)
data,label=read_img(path)#打亂順序
num_example=data.shape[0]
arr=np.arange(num_example)
np.random.shuffle(arr)
data=data[arr]
label=label[arr]#將所有數據分為訓練集和驗證集
ratio=0.8
s=np.int(num_example*ratio)
x_train=data[:s]
y_train=label[:s]
x_val=data[s:]
y_val=label[s:]#-----------------構建網絡----------------------
#占位符
x=tf.placeholder(tf.float32,shape=[None,w,h,c],name='x')
y_=tf.placeholder(tf.int32,shape=[None,],name='y_')#第一個卷積層(100——>50)
conv1=tf.layers.conv2d(inputs=x,  filters=32,  kernel_size=[5, 5], padding="same", activation=tf.nn.relu,kernel_initializer=tf.truncated_normal_initializer(stddev=0.01))
pool1=tf.layers.max_pooling2d(inputs=conv1, pool_size=[2, 2], strides=2)#第二個卷積層(50->25)
conv2=tf.layers.conv2d( inputs=pool1, filters=64,  kernel_size=[5, 5], padding="same", activation=tf.nn.relu,kernel_initializer=tf.truncated_normal_initializer(stddev=0.01))
pool2=tf.layers.max_pooling2d(inputs=conv2, pool_size=[2, 2], strides=2)#第三個卷積層(25->12)
conv3=tf.layers.conv2d(inputs=pool2, filters=128, kernel_size=[3, 3], padding="same", activation=tf.nn.relu,kernel_initializer=tf.truncated_normal_initializer(stddev=0.01))
pool3=tf.layers.max_pooling2d(inputs=conv3, pool_size=[2, 2], strides=2)#第四個卷積層(12->6)
conv4=tf.layers.conv2d(  inputs=pool3, filters=128, kernel_size=[3, 3], padding="same",activation=tf.nn.relu,kernel_initializer=tf.truncated_normal_initializer(stddev=0.01))
pool4=tf.layers.max_pooling2d(inputs=conv4, pool_size=[2, 2], strides=2)re1 = tf.reshape(pool4, [-1, 6 * 6 * 128])#全連接層
dense1 = tf.layers.dense(inputs=re1,      units=1024,   activation=tf.nn.relu,kernel_initializer=tf.truncated_normal_initializer(stddev=0.01),kernel_regularizer=tf.contrib.layers.l2_regularizer(0.003))
dense2= tf.layers.dense(inputs=dense1,  units=512,  activation=tf.nn.relu,kernel_initializer=tf.truncated_normal_initializer(stddev=0.01),kernel_regularizer=tf.contrib.layers.l2_regularizer(0.003))
logits= tf.layers.dense(inputs=dense2,  units=5,   activation=None,kernel_initializer=tf.truncated_normal_initializer(stddev=0.01),kernel_regularizer=tf.contrib.layers.l2_regularizer(0.003))
#---------------------------網絡結束---------------------------

loss=tf.losses.sparse_softmax_cross_entropy(labels=y_,logits=logits)
train_op=tf.train.AdamOptimizer(learning_rate=0.001).minimize(loss)
correct_prediction = tf.equal(tf.cast(tf.argmax(logits,1),tf.int32), y_)    
acc= tf.reduce_mean(tf.cast(correct_prediction, tf.float32))#定義一個函數,按批次取數據
def minibatches(inputs=None, targets=None, batch_size=None, shuffle=False):assert len(inputs) == len(targets)if shuffle:indices = np.arange(len(inputs))np.random.shuffle(indices)for start_idx in range(0, len(inputs) - batch_size + 1, batch_size):if shuffle:excerpt = indices[start_idx:start_idx + batch_size]else:excerpt = slice(start_idx, start_idx + batch_size)yield inputs[excerpt], targets[excerpt]#訓練和測試數據,可將n_epoch設置更大一些

n_epoch=1000
batch_size=64
sess=tf.InteractiveSession()  
sess.run(tf.global_variables_initializer())
for epoch in range(n_epoch):start_time = time.time()#trainingtrain_loss, train_acc, n_batch = 0, 0, 0for x_train_a, y_train_a in minibatches(x_train, y_train, batch_size, shuffle=True):_,err,ac=sess.run([train_op,loss,acc], feed_dict={x: x_train_a, y_: y_train_a})train_loss += err; train_acc += ac; n_batch += 1print("   train loss: %f" % (train_loss/ n_batch))print("   train acc: %f" % (train_acc/ n_batch))#validationval_loss, val_acc, n_batch = 0, 0, 0for x_val_a, y_val_a in minibatches(x_val, y_val, batch_size, shuffle=False):err, ac = sess.run([loss,acc], feed_dict={x: x_val_a, y_: y_val_a})val_loss += err; val_acc += ac; n_batch += 1print("   validation loss: %f" % (val_loss/ n_batch))print("   validation acc: %f" % (val_acc/ n_batch))sess.close()

本文來自互聯網用戶投稿,該文觀點僅代表作者本人,不代表本站立場。本站僅提供信息存儲空間服務,不擁有所有權,不承擔相關法律責任。
如若轉載,請注明出處:http://www.pswp.cn/news/387228.shtml
繁體地址,請注明出處:http://hk.pswp.cn/news/387228.shtml
英文地址,請注明出處:http://en.pswp.cn/news/387228.shtml

如若內容造成侵權/違法違規/事實不符,請聯系多彩編程網進行投訴反饋email:809451989@qq.com,一經查實,立即刪除!

相關文章

【模板】可持久化線段樹

可持久化線段樹/主席樹: 顧名思義,該數據結構是可以訪問歷史版本的線段樹。用于解決需要查詢歷史信息的區間問題。 在功能與時間復雜度上與開n棵線段樹無異,然而空間復雜度從$O(n\times nlogn)$降到了$O(nlogn)$。 實現方法: 每次…

skimage庫需要依賴 numpy+mkl 和scipy

skimage庫需要依賴 numpymkl 和scipy1、打開運行,輸入cmd回車,輸入python回車,查看python版本2、在https://www.lfd.uci.edu/~gohlke/pythonlibs/#numpy 中,根據自己python版本下載需要的包 (因為我的是python 2.7.13 …

操作系統04進程同步與通信

4.1 進程間的相互作用 4.1.1 進程間的聯系資源共享關系相互合作關系臨界資源應互斥訪問。臨界區:不論是硬件臨界資源,還是軟件臨界資源,多個進程必須互斥地對它們進行訪問。把在每個進程中訪問臨界資源的那段代碼稱為臨界資源區。顯然&#x…

oracle遷移到greenplum的方案

oracle數據庫是一種關系型數據庫管理系統,在數據庫領域一直處于領先的地位,適合于大型項目的開發;銀行、電信、電商、金融等各領域都大量使用Oracle數據庫。 greenplum是一款開源的分布式數據庫存儲解決方案,主要關注數據倉庫和BI…

CNN框架的搭建及各個參數的調節

本文代碼下載地址:我的github本文主要講解將CNN應用于人臉識別的流程,程序基于PythonnumpytheanoPIL開發,采用類似LeNet5的CNN模型,應用于olivettifaces人臉數據庫,實現人臉識別的功能,模型的誤差降到了5%以…

操作系統05死鎖

進程管理4--Deadlock and Starvation Concurrency: Deadlock and Starvation 內容提要 >產生死鎖與饑餓的原因 >解決死鎖的方法 >死鎖/同步的經典問題:哲學家進餐問題 Deadlock 系統的一種隨機性錯誤 Permanent blocking of a set of processes that eith…

CNN tensorflow 人臉識別

數據材料這是一個小型的人臉數據庫,一共有40個人,每個人有10張照片作為樣本數據。這些圖片都是黑白照片,意味著這些圖片都只有灰度0-255,沒有rgb三通道。于是我們需要對這張大圖片切分成一個個的小臉。整張圖片大小是1190 942&am…

數據結構01緒論

第一章緒論 1.1 什么是數據結構 數據結構是一門研究非數值計算的程序設計問題中,計算機的操作對象以及他們之間的關系和操作的學科。 面向過程程序數據結構算法 數據結構是介于數學、計算機硬件、計算機軟件三者之間的一門核心課程。 數據結構是程序設計、編譯…

css3動畫、2D與3D效果

1.兼容性 css3針對同一樣式在不同瀏覽器的兼容 需要在樣式屬性前加上內核前綴; 谷歌(chrome) -webkit-transition: Opera(歐鵬) -o-transition: Firefox(火狐) -moz-transition Ie -ms-tr…

ES6學習筆記(六)數組的擴展

1.擴展運算符 1.1含義 擴展運算符(spread)是三個點(...)。它好比 rest 參數的逆運算,將一個數組轉為用逗號分隔的參數序列。 console.log(...[1, 2, 3]) // 1 2 3console.log(1, ...[2, 3, 4], 5) // 1 2 3 4 5[...doc…

數據結構02線性表

第二章 線性表 C中STL順序表:vector http://blog.csdn.net/weixin_37289816/article/details/54710677鏈表:list http://blog.csdn.net/weixin_37289816/article/details/54773406在數據元素的非空有限集中: (1)存在唯一一個被稱作“第…

訓練一個神經網絡 能讓她認得我

寫個神經網絡,讓她認得我(?????)(Tensorflow,opencv,dlib,cnn,人臉識別) 這段時間正在學習tensorflow的卷積神經網絡部分,為了對卷積神經網絡能夠有一個更深的了解,自己動手實現一個例程是比較好的方式,所以就選了一個這樣比…

數據結構03棧和隊列

第三章棧和隊列 STL棧:stack http://blog.csdn.net/weixin_37289816/article/details/54773495隊列:queue http://blog.csdn.net/weixin_37289816/article/details/54773581priority_queue http://blog.csdn.net/weixin_37289816/article/details/5477…

Java動態編譯執行

在某些情況下,我們需要動態生成java代碼,通過動態編譯,然后執行代碼。JAVA API提供了相應的工具(JavaCompiler)來實現動態編譯。下面我們通過一個簡單的例子介紹,如何通過JavaCompiler實現java代碼動態編譯…

樹莓派pwm驅動好盈電調及伺服電機

本文講述如何通過樹莓派的硬件PWM控制好盈電調來驅動RC車子的前進后退,以及如何驅動伺服電機來控制車子轉向。 1. 好盈電調簡介 車子上的電調型號為:WP-10BLS-A-RTR,在好盈官網并沒有搜到對應手冊,但找到一份通用RC競速車的電調使…

數據結構04串

第四章 串 STL:string http://blog.csdn.net/weixin_37289816/article/details/54716009計算機上非數值處理的對象基本上是字符串數據。 在不同類型的應用中,字符串具有不同的特點,要有效的實現字符串的處理,必須選用合適的存儲…

CAS單點登錄原理解析

CAS單點登錄原理解析 SSO英文全稱Single Sign On,單點登錄。SSO是在多個應用系統中,用戶只需要登錄一次就可以訪問所有相互信任的應用系統。CAS是一種基于http協議的B/S應用系統單點登錄實現方案,認識CAS之前首先要熟悉http協議、Session與Co…

JDK1.6版添加了新的ScriptEngine類,允許用戶直接執行js代碼。

JDK1.6版添加了新的ScriptEngine類,允許用戶直接執行js代碼。在Java中直接調用js代碼 不能調用瀏覽器中定義的js函數,會拋出異常提示ReferenceError: “alert” is not defined。[java] view plaincopypackage com.sinaapp.manjushri; import javax.sc…

數據結構05數組和廣義表

第五章 數組 和 廣義表 數組和廣義表可以看成是線性表在下述含義上的擴展:表中的數據元素本身也是一個數據結構。 5.1 數組的定義 n維數組中每個元素都受著n個關系的約束,每個元素都有一個直接后繼元素。 可以把二維數組看成是這樣一個定長線性表&…

k8s的ingress使用

ingress 可以配置一個入口來提供k8s上service從外部來訪問的url、負載平衡流量、終止SSL和提供基于名稱的虛擬主機。 配置ingress的yaml: 要求域名解析無誤 要求service對應的pod正常 一、test1.domain.com --> service1:8080 apiVersion: extensions/v1beta1…