今天在看網上的視頻學習深度學習的時候,用到了CIFAR-10數據集。當我興高采烈的運行代碼時,卻發現了一些錯誤:
# -*- coding: utf-8 -*-
import pickle as p
import numpy as np import os def load_CIFAR_batch(filename): """ 載入cifar數據集的一個batch """ with open(filename, 'r') as f: datadict = p.load(f) X = datadict['data'] Y = datadict['labels'] X = X.reshape(10000, 3, 32, 32).transpose(0, 2, 3, 1).astype("float") Y = np.array(Y) return X, Y def load_CIFAR10(ROOT): """ 載入cifar全部數據 """ xs = [] ys = [] for b in range(1, 6): f = os.path.join(ROOT, 'data_batch_%d' % (b,)) X, Y = load_CIFAR_batch(f) xs.append(X) ys.append(Y) Xtr = np.concatenate(xs) Ytr = np.concatenate(ys) del X, Y Xte, Yte = load_CIFAR_batch(os.path.join(ROOT, 'test_batch')) return Xtr, Ytr, Xte, Yte
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
錯誤代碼如下:
'gbk' codec can't decode byte 0x80 in position 0: illegal multibyte sequence
1
于是乎開始各種搜索問題,問大佬,網上的答案都是類似:
然而并沒有解決問題!還是錯誤的!(我大概搜索了一下午吧,都是上面的答案)
哇,就當我很絕望的時候,我終于發現了一個新奇的答案,抱著試一試的態度,嘗試了一下:
def load_CIFAR_batch(filename):
""" 載入cifar數據集的一個batch """ with open(filename, 'rb') as f: datadict = p.load(f, encoding='latin1') X = datadict['data'] Y = datadict['labels'] X = X.reshape(10000, 3, 32, 32).transpose(0, 2, 3, 1).astype("float") Y = np.array(Y) return X, Y
1
2
3
4
5
6
7
8
9
10
竟然成功了,這里沒有報錯了!欣喜之余,我就很好奇,encoding=’latin1’到底是啥玩意呢,以前沒有見過啊?于是,我搜索了一下,了解到:
Latin1是ISO-8859-1的別名,有些環境下寫作Latin-1。ISO-8859-1編碼是單字節編碼,向下兼容ASCII,其編碼范圍是0x00-0xFF,0x00-0x7F之間完全和ASCII一致,0x80-0x9F之間是控制字符,0xA0-0xFF之間是文字符號。
因為ISO-8859-1編碼范圍使用了單字節內的所有空間,在支持ISO-8859-1的系統中傳輸和存儲其他任何編碼的字節流都不會被拋棄。換言之,把其他任何編碼的字節流當作ISO-8859-1編碼看待都沒有問題。這是個很重要的特性,MySQL數據庫默認編碼是Latin1就是利用了這個特性。ASCII編碼是一個7位的容器,ISO-8859-1編碼是一個8位的容器。
還沒等我高興起來,運行后,又發現了一個問題:
memory error
1
什么鬼?內存錯誤!哇,原來是數據大小的問題。
X = X.reshape(10000, 3, 32, 32).transpose(0,2,3,1).astype("float")
1
這告訴我們每批數據都是10000 * 3 * 32 * 32,相當于超過3000萬個浮點數。 float數據類型實際上與float64相同,意味著每個數字大小占8個字節。這意味著每個批次占用至少240 MB。你加載6這些(5訓練+ 1測試)在總產量接近1.4 GB的數據。
for b in range(1,2):
f = os.path.join(ROOT, 'data_batch_%d' % (b,)) X, Y = load_CIFAR_batch(f) xs.append(X) ys.append(Y)
1
2
3
4
5
所以如有可能,如上代碼所示只能一次運行一批。
到此為止,錯誤基本搞定,下面貼出正確代碼:
# -*- coding: utf-8 -*-
import pickle as p
import numpy as np import os def load_CIFAR_batch(filename): """ 載入cifar數據集的一個batch """ with open(filename, 'rb') as f: datadict = p.load(f, encoding='latin1') X = datadict['data'] Y = datadict['labels'] X = X.reshape(10000, 3, 32, 32).transpose(0, 2, 3, 1).astype("float") Y = np.array(Y) return X, Y def load_CIFAR10(ROOT): """ 載入cifar全部數據 """ xs = [] ys = [] for b in range(1, 2): f = os.path.join(ROOT, 'data_batch_%d' % (b,)) X, Y = load_CIFAR_batch(f) xs.append(X) #將所有batch整合起來 ys.append(Y) Xtr = np.concatenate(xs) #使變成行向量,最終Xtr的尺寸為(50000,32,32,3) Ytr = np.concatenate(ys) del X, Y Xte, Yte = load_CIFAR_batch(os.path.join(ROOT, 'test_batch')) return Xtr, Ytr, Xte, Yte
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
import numpy as np
from julyedu.data_utils import load_CIFAR10
import matplotlib.pyplot as plt plt.rcParams['figure.figsize'] = (10.0, 8.0) plt.rcParams['image.interpolation'] = 'nearest' plt.rcParams['image.cmap'] = 'gray' # 載入CIFAR-10數據集 cifar10_dir = 'julyedu/datasets/cifar-10-batches-py' X_train, y_train, X_test, y_test = load_CIFAR10(cifar10_dir) # 看看數據集中的一些樣本:每個類別展示一些 print('Training data shape: ', X_train.shape) print('Training labels shape: ', y_train.shape) print('Test data shape: ', X_test.shape) print('Test labels shape: ', y_test.shape)
順便看一下CIFAR-10數據組成: