1. 引言
在深度學習計算機視覺任務中,數據預處理和數據增強是模型訓練的關鍵步驟,直接影響模型的泛化能力和最終性能表現。PyTorch 提供的 torchvision.transforms 模塊,封裝了豐富的圖像變換方法,能夠高效地完成圖像標準化、裁剪、翻轉等操作。該模塊支持兩種主要的使用方式:單步變換(Single Transform)和組合變換(Compose),可以靈活應對不同場景下的圖像處理需求。
本文將詳細解析 transforms 的核心 API、參數含義,并通過完整代碼示例演示其使用方法。主要內容包括:
基礎變換操作
- 尺寸調整:Resize(target_size)
- 隨機裁剪:RandomCrop(size, padding=None, pad_if_needed=False)
- 中心裁剪:CenterCrop(size)
- 隨機水平/垂直翻轉:RandomHorizontalFlip(p=0.5), RandomVerticalFlip(p=0.5)
顏色空間變換
- 顏色抖動:ColorJitter(brightness=0, contrast=0, saturation=0, hue=0)
- 隨機灰度化:RandomGrayscale(p=0.1)
- 高斯模糊:GaussianBlur(kernel_size, sigma=(0.1, 2.0))
數據標準化
- 歸一化:Normalize(mean, std)
- 張量轉換:ToTensor()
實用組合方法
- 變換鏈:Compose([transforms1, transforms2,...])
- 隨機選擇:RandomApply(transforms, p=0.5)
- 隨機排序:RandomOrder(transforms)
以圖像分類任務為例,一個典型的數據增強流程可能如下:
from torchvision import transformstrain_transform = transforms.Compose([transforms.RandomResizedCrop(224),transforms.RandomHorizontalFlip(),transforms.ColorJitter(brightness=0.4, contrast=0.4, saturation=0.4),transforms.ToTensor(),transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
])val_transform = transforms.Compose([transforms.Resize(256),transforms.CenterCrop(224),transforms.ToTensor(),transforms.Normalize(mean=[0.485, 0.456, 0.406],std=[0.229, 0.224, 0.225])
])
其中,訓練集使用更豐富的增強策略以提高模型魯棒性,而驗證集則采用較簡單的預處理保持數據原始分布。通過合理配置這些變換參數,可以顯著提升模型在各種視覺任務(如圖像分類、目標檢測、語義分割等)中的表現。
2. transforms 概述
transforms
是 PyTorch 生態系統中 torchvision
庫的核心模塊之一,專門用于計算機視覺任務中的圖像數據處理。它提供了豐富的圖像變換方法,主要分為三大類功能:
圖像預處理:
- 尺寸調整:
transforms.Resize()
可將圖像統一縮放到指定尺寸(如 256x256) - 歸一化:
transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
使用 ImageNet 的均值和標準差進行標準化 - 中心裁剪:
transforms.CenterCrop(224)
從圖像中心裁剪出指定大小的區域
- 尺寸調整:
數據增強(常用于訓練階段防止過擬合):
- 隨機裁剪:
transforms.RandomCrop(224)
在隨機位置裁剪 - 顏色變換:
transforms.ColorJitter(brightness=0.2, contrast=0.2, saturation=0.2, hue=0.1)
- 隨機水平翻轉:
transforms.RandomHorizontalFlip(p=0.5)
- 隨機旋轉:
transforms.RandomRotation(degrees=15)
- 隨機裁剪:
格式轉換:
- PIL圖像轉張量:
transforms.ToTensor()
將圖像轉換為 PyTorch 張量(并自動將像素值歸一化到 [0,1]) - 張量轉PIL圖像:
transforms.ToPILImage()
- PIL圖像轉張量:
組合使用示例:
from torchvision import transforms# 訓練階段的變換流水線
train_transform = transforms.Compose([transforms.Resize(256), # 縮放至256x256transforms.RandomCrop(224), # 隨機裁剪224x224transforms.RandomHorizontalFlip(), # 隨機水平翻轉transforms.ToTensor(), # 轉為張量transforms.Normalize(mean=[0.485, 0.456, 0.406], # 標準化std=[0.229, 0.224, 0.225])
])# 驗證階段的變換流水線(通常不包含隨機增強)
val_transform = transforms.Compose([transforms.Resize(256),transforms.CenterCrop(224),transforms.ToTensor(),transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
])
在實際應用中,這些變換可以顯著提升模型的泛化能力,特別是在數據量有限的情況下。對于不同的計算機視覺任務(如圖像分類、目標檢測等),可以根據具體需求組合不同的變換操作。
3. 核心 API 詳解
3.1 基礎變換
(1)?Resize(size)
功能:調整圖像尺寸。
參數:
size
?(int or tuple):目標尺寸。如果是?int
,短邊縮放至該值,長邊按比例調整;如果是?(h, w)
,則強制縮放到指定大小。
示例:
transform = transforms.Resize(256) # 短邊縮放到256,長邊按比例調整
transform = transforms.Resize((224, 224)) # 強制縮放到224x224
(2)?CenterCrop(size)
功能:從圖像中心裁剪指定大小的區域。
參數:
size
?(int or tuple):裁剪尺寸(int
?表示正方形,(h, w)
?表示矩形)。
示例:
transform = transforms.CenterCrop(224) # 裁剪224x224的正方形
(3)?RandomCrop(size)
功能:隨機位置裁剪圖像。
參數:
size
?(int or tuple):裁剪尺寸。padding
?(int or tuple, optional):填充邊緣(防止裁剪過小)。
示例:
transform = transforms.RandomCrop(224, padding=10) # 隨機裁剪224x224,邊緣填充10像素
(4)?RandomHorizontalFlip(p=0.5)
功能:以概率?
p
?水平翻轉圖像(默認?p=0.5
)。示例:
transform = transforms.RandomHorizontalFlip(p=0.7) # 70%概率水平翻轉
(5)?RandomRotation(degrees)
功能:隨機旋轉圖像。
參數:
degrees
?(float or tuple):旋轉角度范圍(如?30
?表示?[-30°, 30°]
,(10, 30)
?表示?[10°, 30°]
)。
示例:
transform = transforms.RandomRotation(30) # 隨機旋轉 ±30°
3.2 張量轉換 & 標準化
(1)?ToTensor()
功能:
將?
PIL.Image
?或?numpy.ndarray
?轉換為?torch.Tensor
([C, H, W]
?格式)。像素值從?
[0, 255]
?縮放到?[0.0, 1.0]
。
示例:
transform = transforms.ToTensor() # 轉換為張量
(2)?Normalize(mean, std)
功能:對張量進行標準化(逐通道計算:
(x - mean) / std
)。參數:
mean
?(list):各通道均值(如 ImageNet 的?[0.485, 0.456, 0.406]
)。std
?(list):各通道標準差(如 ImageNet 的?[0.229, 0.224, 0.225]
)。
示例:
transform = transforms.Normalize(mean=[0.485, 0.456, 0.406],std=[0.229, 0.224, 0.225]
)
3.3 顏色變換
(1)?ColorJitter
功能:隨機調整亮度、對比度、飽和度和色相。
參數說明:
brightness (float 或 tuple):亮度調整范圍
- 當輸入為單個浮點數時(如 0.2),表示亮度調整范圍為 [1-0.2, 1+0.2] = [0.8, 1.2]
- 當輸入為元組時(如 (0.7, 1.3)),表示直接指定亮度調整范圍
- 示例:brightness=0.5 表示圖片亮度將在原始值的50%-150%之間隨機調整
contrast (float 或 tuple):對比度調整范圍
- 調節方式與brightness相同
- 示例:contrast=(0.8, 1.5) 表示對比度將在原始值的80%-150%之間隨機調整
應用場景:
- 這些參數常用于圖像增強和數據增強任務
- 在訓練深度學習模型時,隨機調整這些參數可以增加訓練數據的多樣性
- 每個參數的調整都是在指定范圍內隨機取值,而不是固定值
saturation (float 或 tuple):飽和度調整范圍
- 調節方式與brightness相同
- 示例:saturation=0.3 表示飽和度將在原始值的70%-130%之間隨機調整
hue (float 或 tuple):色相調整范圍
- 當輸入為單個浮點數時(如 0.1),表示色相調整范圍為 [-0.1, 0.1]
- 當輸入為元組時(如 (-0.2, 0.3)),表示直接指定色相調整范圍
- 注意:色相值通常以弧度表示,范圍一般為[-0.5, 0.5]
- 示例:hue=0.05 表示色相將在[-0.05, 0.05]范圍內隨機調整
示例:
transform = transforms.ColorJitter(brightness=0.2,contrast=0.2,saturation=0.2,hue=0.1
)
(2)?Grayscale(num_output_channels=1)
功能:將圖像轉為灰度圖。
參數:
num_output_channels
:輸出通道數(1 或 3)。
示例:
transform = transforms.Grayscale(num_output_channels=3) # 轉為3通道灰度圖
4. 完整代碼示例
4.1 定義訓練和測試的變換
from torchvision import transforms# 訓練集變換(含數據增強)
train_transform = transforms.Compose([transforms.RandomResizedCrop(224), # 隨機縮放裁剪至224x224transforms.RandomHorizontalFlip(), # 50%概率水平翻轉transforms.ColorJitter( # 隨機顏色調整brightness=0.2, contrast=0.2, saturation=0.2),transforms.ToTensor(), # 轉為張量 [C, H, W], 值范圍[0, 1]transforms.Normalize( # 標準化(ImageNet參數)mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
])# 測試集變換(僅預處理)
test_transform = transforms.Compose([transforms.Resize(256), # 短邊縮放到256transforms.CenterCrop(224), # 中心裁剪224x224transforms.ToTensor(),transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
])
4.2 應用到數據集?
from torchvision.datasets import CIFAR10
from torch.utils.data import DataLoader# 加載CIFAR10數據集(應用變換)
train_dataset = CIFAR10(root='./data', train=True, transform=train_transform, # 應用訓練變換download=True
)test_dataset = CIFAR10(root='./data', train=False, transform=test_transform, # 應用測試變換download=True
)# 創建DataLoader
train_loader = DataLoader(train_dataset, batch_size=32, shuffle=True)
test_loader = DataLoader(test_dataset, batch_size=32, shuffle=False)
5. 總結
使用 Compose
可以方便地組合多個變換操作,這些變換會按照添加順序依次執行。例如:
transforms.Compose([transforms.Resize(256), # 調整圖像大小transforms.RandomCrop(224), # 隨機裁剪transforms.ToTensor(), # 轉換為張量transforms.Normalize( # 標準化mean=[0.485, 0.456, 0.406],std=[0.229, 0.224, 0.225])
])
在實際應用中,訓練和測試階段通常采用不同的轉換策略:
標準化(Normalize)是一個關鍵步驟,它能:
當使用預訓練模型時,應該采用該模型訓練時使用的均值和標準差(常見的是 ImageNet 的統計值:mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])。如果不使用預訓練模型,可以計算自己數據集的統計值進行標準化。
PyTorch 中的
transforms
模塊是計算機視覺任務中圖像處理的核心工具,它提供了一系列用于圖像預處理、數據增強和數據類型轉換的功能。這些轉換操作可以高效地將原始圖像數據轉換為適合深度學習模型訓練的格式。主要功能包括:
- 預處理:如圖像大小調整(Resize)、中心裁剪(CenterCrop)、轉換為張量(ToTensor)等基礎操作
- 數據增強:訓練時增加數據多樣性的隨機變換,如隨機水平翻轉(RandomHorizontalFlip)、隨機旋轉(RandomRotation)
- 張量轉換:將 PIL 圖像或 numpy 數組轉換為 PyTorch 張量,并進行數值歸一化等操作
- 訓練階段:建議使用數據增強來提升模型泛化能力,常用增強方法包括:
RandomCrop
:隨機裁剪圖像ColorJitter
:隨機調整亮度、對比度、飽和度RandomHorizontalFlip
:隨機水平翻轉RandomRotation
:隨機旋轉
- 測試階段:通常只需基礎預處理,如固定大小的裁剪和標準化
- 將輸入數據縮放到相近的數值范圍
- 加速模型收斂過程
- 提高訓練穩定性
掌握?transforms
?的使用,可以顯著提升計算機視覺任務的效率和模型性能!?