【PyTorch】圖像多分類項目
目錄
StratifiedShuffleSplit
transforms.ToTensor
Counter
StratifiedShuffleSplit
sss = StratifiedShuffleSplit(n_splits=1, test_size=0.2, random_state=0)
創建StratifiedShuffleSplit對象,用于將數據集劃分為訓練集和測試集。
- n_splits=1:劃分次數為1,大于1則多次劃分,每次劃分生成一組新訓練集和新測試集。
- test_size=0.2:測試集比例為0.2,即測試集的大小占總樣本的20%
- random_state=0:隨機種子為0,類似random的種子,保證每次抽樣到的數據一樣?
StratifiedShuffleSplit是scikit-learn庫中的一個類,用于創建訓練集和測試集的劃分,同時保持每個類別中的樣本比例一致。核心思想:分層抽樣。
StratifiedShuffleSplit?類的工作原理:
先根據每個類別的樣本數量將數據集劃分為盡可能相等的子集(分層)
然后在這些子集中隨機選擇樣本拆分創建訓練集和測試集(隨機拆分)
插入空格更好理解:Stratified Shuffle Split分層隨機拆分類!
transforms.ToTensor
data_transformer = transforms.Compose([transforms.ToTensor()])
?transforms.ToTensor()的作用是將PIL圖像或NumPy數組轉換為PyTorch張量,并且將圖像的像素值從[0, 255]范圍縮放到[0.0, 1.0]范圍,即在[0.0, 1.0]范圍內對像素值進行歸一化。轉換后的張量形狀為(C, H, W)
Compose是 torchvision.transforms 模塊的一個類,創建一個Compose對象時,需要傳入一個包含一個或多個變換操作的列表。Compose對象一般包含四個變換操作:調整圖像大小、從中心裁剪圖像、將圖像轉換為張量以及歸一化。
Counter
counter_train=collections.Counter(y_train)
用于統計圖像標簽,即每類標簽圖像數量,Counter是用于計數的子類字典。例如PyTorch torchvision包中STL-10數據集的訓練數據集: