Extract layers from a neural network trained on a large dataset, and fine-tune them on a new dataset. This example fine-tunes using a subset of ImageNet.
This example retrains a SqueezeNet neural network using transfer learning. This network has been trained on over a million images, and can classify images into 1000 object categories (such as keyboard, coffee mug, pencil, and many animals). The network has learned rich feature representations for a wide range of images. The network takes an image as input and outputs a prediction score for each of these classes.
Performing transfer learning and fine-tuning of a pretrained neural network typically requires less data, is much faster, and is easier than training a neural network from scratch.
To adapt a pretrained neural network for a new task, replace the last few layers (the network head) so that it outputs prediction scores for each of the classes of the new task. This diagram outlines the architecture of a neural network that makes predictions for the original set of classes, and illustrates how to edit the network so that it outputs predictions for the new classes.
ImageNet uses WordNet's hierarchical taxonomy, and each category has a unique WordNet ID. This example uses three categories:
- Tiger: WordNet ID n02129604. Subcategories include the Bengal tiger and the Siberian tiger.
- Rabbit: WordNet ID n02325366. Subcategories include the European rabbit and the hare.
- Chicken: WordNet ID n01514668. Subcategories include the hen, rooster, and chick.

The fine-tuning dataset contains 1,300 images per class:
- Tiger: 1,300 images (covering different tiger subspecies).
- Rabbit: 1,300 images (including European rabbits and hares).
- Chicken: 1,300 images (covering different breeds and ages).
Load Training Data
Create an image datastore. An image datastore enables you to store large collections of image data, including data that does not fit in memory, and efficiently read batches of images when training a neural network. Specify the folder with the extracted images, and indicate that the subfolder names correspond to the image labels.
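The datastore expects one subfolder per class, named with the WordNet IDs (for example, imagenet_subset/n01514668, imagenet_subset/n02129604, and imagenet_subset/n02325366). The path variable used below is an assumption; point it at wherever you extracted the images.

digitDatasetPath = fullfile(pwd,"imagenet_subset");  % assumed location of the extracted images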
imds = imageDatastore(digitDatasetPath, ...
    IncludeSubfolders=true,LabelSource="foldernames");
imds.Labels = renamecats(imds.Labels, ...
    {'n01514668','n02129604','n02325366'}, ...
    {'chicken','tiger','rabbit'});
numObsPerClass = countEachLabel(imds)
numObsPerClass =

     Label      Count
    _______     _____

    chicken     1300
    tiger       1300
    rabbit      1300
Load Pretrained Network
Load a pretrained SqueezeNet neural network into the workspace by using the imagePretrainedNetwork
function. To return a neural network ready for retraining for the new data, specify the number of classes. When you specify the number of classes, the imagePretrainedNetwork
function adapts the neural network so that it outputs prediction scores for each of the specified number of classes.
You can try other pretrained networks. Deep Learning Toolbox™ provides various pretrained networks that have different sizes, speeds, and accuracies. These additional networks usually require a support package. If the support package for a selected network is not installed, then the function provides a download link. For more information, see Pretrained Deep Neural Networks.
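The variable numClasses is not defined by the preceding code; a natural choice, assuming the datastore created above, is the number of label categories:

numClasses = numel(categories(imds.Labels));  % 3 classes: chicken, tiger, rabbit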
net = imagePretrainedNetwork("squeezenet",NumClasses=numClasses);
inputSize = net.Layers(1).InputSize
The learnable layer in the network head (the last layer with learnable parameters) requires retraining. The layer is usually a fully connected layer, or a convolutional layer, with an output size that matches the number of classes.
The networkHead function, attached to this example as a supporting file, returns the layer and learnable parameter names of the learnable layer in the network head.
[layerName,learnableNames] = networkHead(net)
For transfer learning, you can freeze the weights of earlier layers in the network by setting the learning rates in those layers to 0. During training, the trainnet function does not update the parameters of these frozen layers. Because the function does not compute the gradients of the frozen layers, freezing the weights can significantly speed up network training. For small datasets, freezing the network layers prevents those layers from overfitting to the new dataset.
Freeze the weights of the network, keeping the last learnable layer unfrozen.
net = freezeNetwork(net,LayerNamesToIgnore=layerName);
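The freezeNetwork function is a supporting file attached to this example. As a rough sketch of what such a helper can do, assuming net is a dlnetwork, you can zero the learn rate factors of every learnable parameter except those in the layers to ignore:

function net = freezeNetworkSketch(net,args)
% Sketch of a freezing helper (the example ships its own freezeNetwork
% supporting file; this assumed equivalent is for illustration only).
arguments
    net dlnetwork
    args.LayerNamesToIgnore string = string.empty
end
for i = 1:height(net.Learnables)
    layerName = net.Learnables.Layer(i);
    paramName = net.Learnables.Parameter(i);
    if ~ismember(layerName,args.LayerNamesToIgnore)
        % A learn rate factor of 0 means trainnet never updates this parameter.
        net = setLearnRateFactor(net,layerName,paramName,0);
    end
end
end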
Prepare Data for Training
The images in the datastore can have different sizes. To automatically resize the training images, use an augmented image datastore.
augImds = augmentedImageDatastore(inputSize(1:2),imds,ColorPreprocessing='gray2rgb');
Specify Training Options
Specify the training options. Choosing among the options requires empirical analysis. To explore different training option configurations by running experiments, you can use the Experiment Manager app.
For this example, use these options:
Train using the Adam optimizer.
Validate the network using the validation data every five iterations. For larger datasets, to prevent validation from slowing down training, increase this value.
Display the training progress in a plot, and monitor the accuracy metric.
Disable the verbose output.
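The options below reference a validation datastore augImdsVal. Because this example does not hold out a separate split (see the note after training), one minimal assumption is to monitor the training data itself:

augImdsVal = augImds;  % assumed: no held-out split, validate on the training set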
opts = trainingOptions("adam", ...
    InitialLearnRate=1e-4, ...
    MaxEpochs=50, ...
    ValidationData=augImdsVal, ...
    ValidationFrequency=5, ...
    Verbose=false, ...
    Plots="training-progress", ...
    MiniBatchSize=128, ...
    Metrics="accuracy");
Train Neural Network
Train the neural network using the trainnet function. For classification, use cross-entropy loss. By default, the trainnet function uses a GPU if one is available. Using a GPU requires a Parallel Computing Toolbox™ license and a supported GPU device. For information on supported devices, see GPU Computing Requirements. Otherwise, the trainnet function uses the CPU. To specify the execution environment, use the ExecutionEnvironment training option.
rng default
net = trainnet(augImds,net,"crossentropy",opts);
The dataset is not split into separate training and validation sets, because the purpose of this example is to observe the CNN's feature transformations.
>> summary(net)
   Initialized: true

   Number of learnables: 724k

   Inputs:
      1   'data'   227×227×3 images
Observe the performance on the training set.
You can apply the trained neural network directly to classification problems. To classify new images, use minibatchpredict. To convert the predicted classification scores to labels, use the scores2label function. For an example showing how to use a pretrained neural network for classification, see Classify Image Using GoogLeNet.
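As a sketch of that evaluation on the training images: because the retrained SqueezeNet ends in a softmax layer, the scores that minibatchpredict returns are the softmax activations used in the next section.

softmaxActivations = minibatchpredict(net,augImds);
classNames = categories(imds.Labels);
YPred = scores2label(softmaxActivations,classNames);
trainingAccuracy = mean(YPred == imds.Labels)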
Ambiguity of Classifications
You can use the softmax activations to calculate the image classifications that are most likely to be incorrect. Define the ambiguity of a classification as the ratio of the second-largest probability to the largest probability. The ambiguity of a classification is between zero (a nearly certain classification) and 1 (the image is nearly as likely to belong to the second class as to the most likely class). An ambiguity near 1 means the network is unsure of the class to which a particular image belongs. This uncertainty might be caused by two classes whose observations appear so similar to the network that it cannot learn the differences between them. Or, a high ambiguity can occur because a particular observation contains elements of more than one class, so the network cannot decide which classification is correct. Note that low ambiguity does not necessarily imply correct classification; even if the network assigns a high probability to a class, the classification can still be incorrect.
[R,RI] = maxk(softmaxActivations,2,2);
ambiguity = R(:,2)./R(:,1);
Find the most ambiguous images.
[ambiguity,ambiguityIdx] = sort(ambiguity,"descend");
View the most probable classes of the ambiguous images and the true classes.
classList = unique(imds.Labels);
top10Idx = ambiguityIdx(1:10);
top10Ambiguity = ambiguity(1:10);
mostLikely = classList(RI(ambiguityIdx,1));
secondLikely = classList(RI(ambiguityIdx,2));
table(top10Idx,top10Ambiguity,mostLikely(1:10),secondLikely(1:10), ...
    imds.Labels(ambiguityIdx(1:10)), ...
    VariableNames=["Image #","Ambiguity","Likeliest","Second","True Class"])
ans =

  10×5 table

    Image #    Ambiguity    Likeliest    Second     True Class
    _______    _________    _________    _______    __________

     2268       0.99602     chicken      tiger       tiger
     3330       0.99584     tiger        rabbit      rabbit
      104       0.99187     chicken      tiger       chicken
      304       0.98644     rabbit       chicken     chicken
     1163       0.98466     tiger        chicken     chicken
     3071       0.95684     chicken      rabbit      rabbit
     1925       0.95373     rabbit       tiger       tiger
     3006       0.95209     rabbit       chicken     rabbit
     2772       0.93734     chicken      rabbit      rabbit
     3461       0.9258      tiger        rabbit      rabbit
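To inspect these cases visually, a short sketch (assuming the indices above index directly into imds) is:

idx = top10Idx(1);  % most ambiguous image
figure
imshow(readimage(imds,idx))
title("True: " + string(imds.Labels(idx)) + ", predicted: " + string(mostLikely(1)))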
The misclassifications concentrate in these three confusable pairs. These samples are relatively complex: the foreground is not prominent, or the background is cluttered, so the extracted features are ambiguous.