kaggle計算機視覺比賽技巧,9. 計算機視覺 - 9.12. 實戰Kaggle比賽:圖像分類(CIFAR-10) - 《動手學深度學習》 - 書棧網 · BookStack...

9.12. 實戰Kaggle比賽:圖像分類(CIFAR-10)

到目前為止,我們一直在用Gluon的data包直接獲取NDArray格式的圖像數據集。然而,實際中的圖像數據集往往是以圖像文件的形式存在的。在本節中,我們將從原始的圖像文件開始,一步步整理、讀取并將其變換為NDArray格式。

我們曾在“圖像增廣”一節中實驗過CIFAR-10數據集。它是計算機視覺領域的一個重要數據集。現在我們將應用前面所學的知識,動手實戰CIFAR-10圖像分類問題的Kaggle比賽。該比賽的網頁地址是https://www.kaggle.com/c/cifar-10 。

圖9.16展示了該比賽的網頁信息。為了便于提交結果,請先在Kaggle網站上注冊賬號。

e21389f3c099b6fe538226d20f97726a.gif

圖 9.16 CIFAR-10圖像分類比賽的網頁信息。比賽數據集可通過點擊“Data”標簽獲取

首先,導入比賽所需的包或模塊。In[1]:importd2lzhasd2l

frommxnetimportautograd,gluon,init

frommxnet.gluonimportdataasgdata,lossasgloss,nn

importos

importpandasaspd

importshutil

importtime

9.12.1. 獲取和整理數據集

比賽數據分為訓練集和測試集。訓練集包含5萬張圖像。測試集包含30萬張圖像,其中有1萬張圖像用來計分,其他29萬張不計分的圖像是為了防止人工標注測試集并提交標注結果。兩個數據集中的圖像格式都是png,高和寬均為32像素,并含有RGB三個通道(彩色)。圖像一共涵蓋10個類別,分別為飛機、汽車、鳥、貓、鹿、狗、青蛙、馬、船和卡車。圖9.16的左上角展示了數據集中部分飛機、汽車和鳥的圖像。

9.12.1.1. 下載數據集

登錄Kaggle后,可以點擊圖9.16所示的CIFAR-10圖像分類比賽網頁上的“Data”標簽,并分別下載訓練數據集train.7z、測試數據集test.7z和訓練數據集標簽trainLabels.csv。

9.12.1.2. 解壓數據集

下載完訓練數據集train.7z和測試數據集test.7z后需要解壓縮。解壓縮后,將訓練數據集、測試數據集以及訓練數據集標簽分別存放在以下3個路徑:../data/kaggle_cifar10/train/[1-50000].png;

../data/kaggle_cifar10/test/[1-300000].png;

../data/kaggle_cifar10/trainLabels.csv。

為方便快速上手,我們提供了上述數據集的小規模采樣,其中train_tiny.zip包含100個訓練樣本,而test_tiny.zip僅包含1個測試樣本。它們解壓后的文件夾名稱分別為train_tiny和test_tiny。此外,將訓練數據集標簽的壓縮文件解壓,并得到trainLabels.csv。如果使用上述Kaggle比賽的完整數據集,還需要把下面demo變量改為False。In[2]:# 如果使用下載的Kaggle比賽的完整數據集,把demo變量改為False

demo=True

ifdemo:

importzipfile

forfin['train_tiny.zip','test_tiny.zip','trainLabels.csv.zip']:

withzipfile.ZipFile('../data/kaggle_cifar10/'+f,'r')asz:

z.extractall('../data/kaggle_cifar10/')

9.12.1.3. 整理數據集

我們需要整理數據集,以方便訓練和測試模型。以下的read_label_file函數將用來讀取訓練數據集的標簽文件。該函數中的參數valid_ratio是驗證集樣本數與原始訓練集樣本數之比。In[3]:defread_label_file(data_dir,label_file,train_dir,valid_ratio):

withopen(os.path.join(data_dir,label_file),'r')asf:

# 跳過文件頭行(欄名稱)

lines=f.readlines()[1:]

tokens=[l.rstrip().split(',')forlinlines]

idx_label=dict(((int(idx),label)foridx,labelintokens))

labels=set(idx_label.values())

n_train_valid=len(os.listdir(os.path.join(data_dir,train_dir)))

n_train=int(n_train_valid*(1-valid_ratio))

assert0

returnn_train// len(labels), idx_label

下面定義一個輔助函數,從而僅在路徑不存在的情況下創建路徑。In[4]:defmkdir_if_not_exist(path):# 本函數已保存在d2lzh包中方便以后使用

ifnotos.path.exists(os.path.join(*path)):

os.makedirs(os.path.join(*path))

我們接下來定義reorg_train_valid函數來從原始訓練集中切分出驗證集。以valid_ratio=0.1為例,由于原始訓練集有50,000張圖像,調參時將有45,000張圖像用于訓練并存放在路徑input_dir/train下,而另外5,000張圖像將作為驗證集并存放在路徑input_dir/valid下。經過整理后,同一類圖像將被放在同一個文件夾下,便于稍后讀取。In[5]:defreorg_train_valid(data_dir,train_dir,input_dir,n_train_per_label,

idx_label):

label_count={}

fortrain_fileinos.listdir(os.path.join(data_dir,train_dir)):

idx=int(train_file.split('.')[0])

label=idx_label[idx]

mkdir_if_not_exist([data_dir,input_dir,'train_valid',label])

shutil.copy(os.path.join(data_dir,train_dir,train_file),

os.path.join(data_dir,input_dir,'train_valid',label))

iflabelnotinlabel_countorlabel_count[label]

mkdir_if_not_exist([data_dir,input_dir,'train',label])

shutil.copy(os.path.join(data_dir,train_dir,train_file),

os.path.join(data_dir,input_dir,'train',label))

label_count[label]=label_count.get(label,0)+1

else:

mkdir_if_not_exist([data_dir,input_dir,'valid',label])

shutil.copy(os.path.join(data_dir,train_dir,train_file),

os.path.join(data_dir,input_dir,'valid',label))

下面的reorg_test函數用來整理測試集,從而方便預測時的讀取。In[6]:defreorg_test(data_dir,test_dir,input_dir):

mkdir_if_not_exist([data_dir,input_dir,'test','unknown'])

fortest_fileinos.listdir(os.path.join(data_dir,test_dir)):

shutil.copy(os.path.join(data_dir,test_dir,test_file),

os.path.join(data_dir,input_dir,'test','unknown'))

最后,我們用一個函數分別調用前面定義的read_label_file函數、reorg_train_valid函數以及reorg_test函數。In[7]:defreorg_cifar10_data(data_dir,label_file,train_dir,test_dir,input_dir,

valid_ratio):

n_train_per_label,idx_label=read_label_file(data_dir,label_file,

train_dir,valid_ratio)

reorg_train_valid(data_dir,train_dir,input_dir,n_train_per_label,

idx_label)

reorg_test(data_dir,test_dir,input_dir)

我們在這里只使用100個訓練樣本和1個測試樣本。訓練數據集和測試數據集的文件夾名稱分別為train_tiny和test_tiny。相應地,我們僅將批量大小設為1。實際訓練和測試時應使用Kaggle比賽的完整數據集,并將批量大小batch_size設為一個較大的整數,如128。我們將10%的訓練樣本作為調參使用的驗證集。In[8]:ifdemo:

# 注意,此處使用小訓練集和小測試集并將批量大小相應設小。使用Kaggle比賽的完整數據集時可

# 設批量大小為較大整數

train_dir,test_dir,batch_size='train_tiny','test_tiny',1

else:

train_dir,test_dir,batch_size='train','test',128

data_dir,label_file='../data/kaggle_cifar10','trainLabels.csv'

input_dir,valid_ratio='train_valid_test',0.1

reorg_cifar10_data(data_dir,label_file,train_dir,test_dir,input_dir,

valid_ratio)

9.12.2. 圖像增廣

為應對過擬合,我們使用圖像增廣。例如,加入transforms.RandomFlipLeftRight()即可隨機對圖像做鏡面翻轉,也可以通過transforms.Normalize()對彩色圖像RGB三個通道分別做標準化。下面列舉了其中的部分操作,你可以根據需求來決定是否使用或修改這些操作。In[9]:transform_train=gdata.vision.transforms.Compose([

# 將圖像放大成高和寬各為40像素的正方形

gdata.vision.transforms.Resize(40),

# 隨機對高和寬各為40像素的正方形圖像裁剪出面積為原圖像面積0.64~1倍的小正方形,再放縮為

# 高和寬各為32像素的正方形

gdata.vision.transforms.RandomResizedCrop(32,scale=(0.64,1.0),

ratio=(1.0,1.0)),

gdata.vision.transforms.RandomFlipLeftRight(),

gdata.vision.transforms.ToTensor(),

# 對圖像的每個通道做標準化

gdata.vision.transforms.Normalize([0.4914,0.4822,0.4465],

[0.2023,0.1994,0.2010])])

測試時,為保證輸出的確定性,我們僅對圖像做標準化。In[10]:transform_test=gdata.vision.transforms.Compose([

gdata.vision.transforms.ToTensor(),

gdata.vision.transforms.Normalize([0.4914,0.4822,0.4465],

[0.2023,0.1994,0.2010])])

9.12.3. 讀取數據集

接下來,可以通過創建ImageFolderDataset實例來讀取整理后的含原始圖像文件的數據集,其中每個數據樣本包括圖像和標簽。In[11]:# 讀取原始圖像文件。flag=1說明輸入圖像有3個通道(彩色)

train_ds=gdata.vision.ImageFolderDataset(

os.path.join(data_dir,input_dir,'train'),flag=1)

valid_ds=gdata.vision.ImageFolderDataset(

os.path.join(data_dir,input_dir,'valid'),flag=1)

train_valid_ds=gdata.vision.ImageFolderDataset(

os.path.join(data_dir,input_dir,'train_valid'),flag=1)

test_ds=gdata.vision.ImageFolderDataset(

os.path.join(data_dir,input_dir,'test'),flag=1)

我們在DataLoader中指明定義好的圖像增廣操作。在訓練時,我們僅用驗證集評價模型,因此需要保證輸出的確定性。在預測時,我們將在訓練集和驗證集的并集上訓練模型,以充分利用所有標注的數據。In[12]:train_iter=gdata.DataLoader(train_ds.transform_first(transform_train),

batch_size,shuffle=True,last_batch='keep')

valid_iter=gdata.DataLoader(valid_ds.transform_first(transform_test),

batch_size,shuffle=True,last_batch='keep')

train_valid_iter=gdata.DataLoader(train_valid_ds.transform_first(

transform_train),batch_size,shuffle=True,last_batch='keep')

test_iter=gdata.DataLoader(test_ds.transform_first(transform_test),

batch_size,shuffle=False,last_batch='keep')

9.12.4. 定義模型

與“殘差網絡(ResNet)”一節中的實現稍有不同,這里基于HybridBlock類構建殘差塊。這是為了提升執行效率。In[13]:classResidual(nn.HybridBlock):

def__init__(self,num_channels,use_1x1conv=False,strides=1,**kwargs):

super(Residual,self).__init__(**kwargs)

self.conv1=nn.Conv2D(num_channels,kernel_size=3,padding=1,

strides=strides)

self.conv2=nn.Conv2D(num_channels,kernel_size=3,padding=1)

ifuse_1x1conv:

self.conv3=nn.Conv2D(num_channels,kernel_size=1,

strides=strides)

else:

self.conv3=None

self.bn1=nn.BatchNorm()

self.bn2=nn.BatchNorm()

defhybrid_forward(self,F,X):

Y=F.relu(self.bn1(self.conv1(X)))

Y=self.bn2(self.conv2(Y))

ifself.conv3:

X=self.conv3(X)

returnF.relu(Y+X)

下面定義ResNet-18模型。In[14]:defresnet18(num_classes):

net=nn.HybridSequential()

net.add(nn.Conv2D(64,kernel_size=3,strides=1,padding=1),

nn.BatchNorm(),nn.Activation('relu'))

defresnet_block(num_channels,num_residuals,first_block=False):

blk=nn.HybridSequential()

foriinrange(num_residuals):

ifi==0andnotfirst_block:

blk.add(Residual(num_channels,use_1x1conv=True,strides=2))

else:

blk.add(Residual(num_channels))

returnblk

net.add(resnet_block(64,2,first_block=True),

resnet_block(128,2),

resnet_block(256,2),

resnet_block(512,2))

net.add(nn.GlobalAvgPool2D(),nn.Dense(num_classes))

returnnet

CIFAR-10圖像分類問題的類別個數為10。我們將在訓練開始前對模型進行Xavier隨機初始化。In[15]:defget_net(ctx):

num_classes=10

net=resnet18(num_classes)

net.initialize(ctx=ctx,init=init.Xavier())

returnnet

loss=gloss.SoftmaxCrossEntropyLoss()

9.12.5. 定義訓練函數

我們將根據模型在驗證集上的表現來選擇模型并調節超參數。下面定義了模型的訓練函數train。我們記錄了每個迭代周期的訓練時間,這有助于比較不同模型的時間開銷。In[16]:deftrain(net,train_iter,valid_iter,num_epochs,lr,wd,ctx,lr_period,

lr_decay):

trainer=gluon.Trainer(net.collect_params(),'sgd',

{'learning_rate':lr,'momentum':0.9,'wd':wd})

forepochinrange(num_epochs):

train_l_sum,train_acc_sum,n,start=0.0,0.0,0,time.time()

ifepoch>0andepoch%lr_period==0:

trainer.set_learning_rate(trainer.learning_rate*lr_decay)

forX,yintrain_iter:

y=y.astype('float32').as_in_context(ctx)

withautograd.record():

y_hat=net(X.as_in_context(ctx))

l=loss(y_hat,y).sum()

l.backward()

trainer.step(batch_size)

train_l_sum+=l.asscalar()

train_acc_sum+=(y_hat.argmax(axis=1)==y).sum().asscalar()

n+=y.size

time_s="time %.2f sec"%(time.time()-start)

ifvalid_iterisnotNone:

valid_acc=d2l.evaluate_accuracy(valid_iter,net,ctx)

epoch_s=("epoch %d, loss %f, train acc %f, valid acc %f, "

%(epoch+1,train_l_sum/n,train_acc_sum/n,

valid_acc))

else:

epoch_s=("epoch %d, loss %f, train acc %f, "%

(epoch+1,train_l_sum/n,train_acc_sum/n))

print(epoch_s+time_s+', lr '+str(trainer.learning_rate))

9.12.6. 訓練并驗證模型

現在,我們可以訓練并驗證模型了。下面的超參數都是可以調節的,如增加迭代周期等。由于lr_period和lr_decay分別設為80和0.1,優化算法的學習率將在每80個迭代周期后自乘0.1。簡單起見,這里僅訓練1個迭代周期。In[17]:ctx,num_epochs,lr,wd=d2l.try_gpu(),1,0.1,5e-4

lr_period,lr_decay,net=80,0.1,get_net(ctx)

net.hybridize()

train(net,train_iter,valid_iter,num_epochs,lr,wd,ctx,lr_period,

lr_decay)epoch1,loss5.998157,train acc0.055556,valid acc0.100000,time1.34sec,lr0.1

9.12.7. 對測試集分類并在Kaggle提交結果

得到一組滿意的模型設計和超參數后,我們使用所有訓練數據集(含驗證集)重新訓練模型,并對測試集進行分類。In[18]:net,preds=get_net(ctx),[]

net.hybridize()

train(net,train_valid_iter,None,num_epochs,lr,wd,ctx,lr_period,

lr_decay)

forX,_intest_iter:

y_hat=net(X.as_in_context(ctx))

preds.extend(y_hat.argmax(axis=1).astype(int).asnumpy())

sorted_ids=list(range(1,len(test_ds)+1))

sorted_ids.sort(key=lambdax:str(x))

df=pd.DataFrame({'id':sorted_ids,'label':preds})

df['label']=df['label'].apply(lambdax:train_valid_ds.synsets[x])

df.to_csv('submission.csv',index=False)epoch1,loss6.620115,train acc0.090000,time1.24sec,lr0.1

執行完上述代碼后,我們會得到一個submission.csv文件。這個文件符合Kaggle比賽要求的提交格式。提交結果的方法與“實戰Kaggle比賽:房價預測”一節中的類似。

9.12.8. 小結可以通過創建ImageFolderDataset實例來讀取含原始圖像文件的數據集。

可以應用卷積神經網絡、圖像增廣和混合式編程來實戰圖像分類比賽。

9.12.9. 練習使用Kaggle比賽的完整CIFAR-10數據集。把批量大小batch_size和迭代周期數num_epochs分別改為128和300。可以在這個比賽中得到什么樣的準確率和名次?

如果不使用圖像增廣的方法能得到什么樣的準確率?

參與討論,在社區交流方法和結果。你能發掘出其他更好的技巧嗎?

本文來自互聯網用戶投稿,該文觀點僅代表作者本人,不代表本站立場。本站僅提供信息存儲空間服務,不擁有所有權,不承擔相關法律責任。
如若轉載,請注明出處:http://www.pswp.cn/news/542271.shtml
繁體地址,請注明出處:http://hk.pswp.cn/news/542271.shtml
英文地址,請注明出處:http://en.pswp.cn/news/542271.shtml

如若內容造成侵權/違法違規/事實不符,請聯系多彩編程網進行投訴反饋email:809451989@qq.com,一經查實,立即刪除!

相關文章

qthread中獲取當前優先級_Linux中強大的top命令

top命令算是最直觀、好用的查看服務器負載的命令了。它實時動態刷新顯示服務器狀態信息,且可以通過交互式命令自定義顯示內容,非常強大。在終端中輸入top,回車后會顯示如下內容:top - 21:48:39 up 8:57, 2 users, load average: 0…

JavaScript中帶示例的String repeat()方法

JavaScript | 字符串repeat()方法 (JavaScript | String repeat() Method) The String.repeat() method in JavaScript is used to generate a string by repeating the calling string n number of times. n can be any integer from o to any possible number in JavaScript.…

Python生成驗證碼

#!/usr/bin/env python #coding:utf8 import random #方法1: str_codezxcvbnmasdfghjklqwertyuiopZXCVBNMASDFGHJKLQWERTYUIOP0123456789new_codefor i in range(4):   new_coderandom.choice(str_code)print new_code #方法2: new_code[]def str_code…

snmp 獲得硬件信息_計算機網絡基礎課程—簡單網絡管理協議(SNMP)

簡單網絡管理協議(Simple Network Management Protocol)?除了提供網絡層服務的協議和使用那些服務的應用程序,因特網還需要運行一些讓管理員進行設備管理、調試問題、控制路由、監測機器狀態的軟件。這種行為稱為網絡管理。??隨著網絡技術的飛速發展,…

僵尸毀滅工程 服務器已停止運行,《僵尸毀滅工程》steam is not enabled錯誤解決方法...

Steam 上面的 Project Zomboid 因為帶有 VAC 所以建服開服需要 Steam服務器認證,這也是出現 steam is not enabled 錯誤主要原因,也是無法和普通零售正版所建的服務器聯機的罪魁禍首。分兩種情況(下面 Project Zomboid 均簡稱PZ):1、steam版P…

spring boot 1.4默認使用 hibernate validator

spring boot 1.4默認使用 hibernate validator 5.2.4 Final實現校驗功能。hibernate validator 5.2.4 Final是JSR 349 Bean Validation 1.1的具體實現。 How to disable Hibernate validation in a Spring Boot project As [M. Deinum] mentioned in a comment on my original …

python mpi開銷_GitHub - hustpython/MPIK-Means

并行計算的K-Means聚類算法實現一,實驗介紹聚類是擁有相同屬性的對象或記錄的集合,屬于無監督學習,K-Means聚類算法是其中較為簡單的聚類算法之一,具有易理解,運算深度塊的特點.1.1 實驗內容通過本次課程我們將使用C語…

服務器修改開機啟動項,啟動項設置_服務器開機啟動項

最近很多觀眾老爺在苦覓關于啟動項設置的解答,今天欽編為大家綜合5條解答來給大家解開疑惑! 有98%玩家認為啟動項設置_服務器開機啟動項值得一讀!啟動項設置1.如何在bios設置硬盤為第一啟動項詳細步驟根據BIOS分類的不同操作不同:…

字符串查找字符出現次數_查找字符串作為子序列出現的次數

字符串查找字符出現次數Description: 描述: Its a popular interview question based of dynamic programming which has been already featured in Accolite, Amazon. 這是一個流行的基于動態編程的面試問題,已經在亞馬遜的Accolite中得到了體現。 Pr…

Ubuntu 忘記密碼的處理方法

Ubuntu系統啟動時選擇recovery mode,也就是恢復模式。接著選擇Drop to root shell prompt ,也就是獲取root權限。輸入命令查看用戶名 cat /etc/shadow ,$號前面的是用戶名輸入命令:passwd "用戶名" 回車就可以輸入新密碼了轉載于:…

服務器mdl文件轉換,Simulink Project 中 MDL 到 SLX 模型文件格式的轉換

打開彈體示例項目并將 MDL 文件另存為 SLX運行以下命令以創建并打開“sldemo_slproject_airframe”示例的工作副本。Simulink.ModelManagement.Project.projectDemo(airframe, svn);rebuild_s_functions(no_progress_dialog);Creating sandbox for project.Created example fil…

vue 修改div寬度_Vue 組件通信方式及其應用場景總結(1.5W字)

前言相信實際項目中用過vue的同學,一定對vue中父子組件之間的通信并不陌生,vue中采用良好的數據通訊方式,避免組件通信帶來的困擾。今天筆者和大家一起分享vue父子組件之間的通信方式,優缺點,及其實際工作中的應用場景…

Java System類identityHashCode()方法及示例

系統類identityHashCode()方法 (System class identityHashCode() method) identityHashCode() method is available in java.lang package. identityHashCode()方法在java.lang包中可用。 identityHashCode() method is used to return the hashcode of the given object – B…

Linux中SysRq的使用(魔術鍵)

轉:http://www.chinaunix.net/old_jh/4/902287.html 魔術鍵:Linux Magic System Request Key Hacks 當Linux 系統不能正常響應用戶請求時, 可以使用SysRq小工具控制Linux. 一 SysRq的啟用與關閉 要想啟用SysRq, 需要在配置內核時設置Magic SysRq key (CO…

鏈接服務器訪問接口返回了消息沒有活動事務,因為鏈接服務器 SQLEHR 的 OLE DB 訪問接口 SQLNCLI10 無法啟動分布式事務。...

查看一下MSDTC啟動是否正確1、運行 regedt32,瀏覽至 HKEY_LOCAL_MACHINE\Software\Microsoft\MSDTC。添加一個 DWORD 值 TurnOffRpcSecurity,值數據為 1。2、重啟MS DTC服務。3、打開“管理工具”的“組件服務”。a. 瀏覽至"啟動管理工具"。b.…

micropython 蜂鳴器_基于MicroPython的TPYBoard微信遠程可燃氣體報警器的設計與實現...

前言在我們平時的生活中,經常看到因氣體泄漏發生爆炸事故的新聞。房屋起火、人體中毒等此類的新聞報道層出不窮。這種情況下,人民就發明了可燃氣體報警器。當工業環境、日常生活環境(如使用天然氣的廚房)中可燃性氣體發生泄露,可燃氣體報警器…

Java PropertyPermission getActions()方法與示例

PropertyPermission類的getActions()方法 (PropertyPermission Class getActions() method) getActions() method is available in java.util package. getActions()方法在java.util包中可用。 getActions() method is used to get the list of current actions in the form of…

源碼安裝nginx以及平滑升級

源碼安裝nginx以及平滑升級作者:尹正杰版權聲明:原創作品,謝絕轉載!否則將追究法律責任。歡迎加入:高級運維工程師之路 598432640這個博客不方便上傳軟件包,我給大家把軟件包放到百度云鏈接:htt…

ajax 跨站返回值,jquery ajax 跨域問題

補充回答:你的動態頁只是一個請求頁。例如你新建一個 get.asp 頁面,用以下代碼,在服務端實現像URL異步(ajax)請求,將請求結果輸出。客戶端頁面再次用ajax(JS或者jquery的)向get.asp請求數據。兩次ajax完成異域數據請求。get.asp代…

Bootstrap學習筆記系列1-------Bootstrap網格系統

目錄 Bootstrap網格系統 學習筆記簡單網格偏移列嵌套列列排序Bootstrap網格系統 學習筆記 簡單網格 先上代碼再解釋 <!DOCTYPE html> <html><head><title>Bootstrap 模板</title><meta charset"utf-8"><!-- 引入 Bootstrap -…