@浙大疏錦行
作業:
了解下cifar數據集,嘗試獲取其中一張圖片
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader , Dataset # DataLoader 是 PyTorch 中用于加載數據的工具
from torchvision import datasets, transforms # torchvision 是一個用于計算機視覺的庫,datasets 和 transforms 是其中的模塊
import matplotlib.pyplot as plt# 設置隨機種子,確保結果可復現
torch.manual_seed(42)
# 先歸一化,再標準化
transform = transforms.Compose([transforms.ToTensor(), # 轉換為張量并歸一化到[0,1]transforms.Normalize((0.1307,), (0.3081,)) # MNIST數據集的均值和標準差,這個值很出名,所以直接使用
])# 2. 加載MNIST數據集,如果沒有會自動下載
train_dataset = datasets.CIFAR10(root='./data',train=True,download=True,transform=transform
)test_dataset = datasets.CIFAR10(root='./data',train=False,transform=transform
)import matplotlib.pyplot as plt# 隨機選擇一張圖片,可以重復運行,每次都會隨機選擇
sample_idx = torch.randint(0, len(train_dataset), size=(1,)).item() # 隨機選擇一張圖片的索引
# len(train_dataset) 表示訓練集的圖片數量;size=(1,)表示返回一個索引;torch.randint() 函數用于生成一個指定范圍內的隨機數,item() 方法將張量轉換為 Python 數字
image, label = train_dataset[sample_idx] # 獲取圖片和標簽# minist數據集的簡化版本
class MNIST(Dataset):def __init__(self, root, train=True, transform=None):# 初始化:加載圖片路徑和標簽self.data, self.targets = fetch_mnist_data(root, train) # 這里假設 fetch_mnist_data 是一個函數,用于加載 MNIST 數據集的圖片路徑和標簽self.transform = transform # 預處理操作def __len__(self): return len(self.data) # 返回樣本總數def __getitem__(self, idx): # 獲取指定索引的樣本# 獲取指定索引的圖像和標簽img, target = self.data[idx], self.targets[idx]# 應用圖像預處理(如ToTensor、Normalize)if self.transform is not None: # 如果有預處理操作img = self.transform(img) # 轉換圖像格式# 這里假設 img 是一個 PIL 圖像對象,transform 會將其轉換為張量并進行歸一化return img, target # 返回處理后的圖像和標簽# 可視化原始圖像(需要反歸一化)
def imshow(img):img = img * 0.3081 + 0.1307 # 反標準化npimg = img.numpy()plt.imshow(npimg[0], cmap='gray') # 顯示灰度圖像plt.show()print(f"Label: {label}")
imshow(image)
Files already downloaded and verified
Label: 6