timm
- PyTorch Image Models (timm) 技術指南
- 功能概述
- 一、引言
- 二、timm 庫概述
- 三、安裝 timm 庫
- 四、模型加載與推理示例
- 4.1 通用推理流程
- 4.2 具體模型示例
- 4.2.1 ResNeXt50-32x4d
- 4.2.2 EfficientNet-V2 Small 模型
- 4.2.3 DeiT-3 large 模型
- 4.2.4 RepViT-M2 模型
- 4.2.5 ResNet-RS-101
- 4.2.6 Vision Transformer (ViT)
- 4.2.7 Swin Transformer
- 4.2.8 Swin Transformer V2
- 4.2.9 Swin Transformer V2 Cr
- 4.2.10 Levit
- 4.3 加載自定義模型
- 4.4 提取模型的中間特征
- 4.5 凍結模型的部分層
- 4.6 創建模型時指定輸入圖像尺寸
- 4.7 數據預處理階段調整圖像尺寸
- 4.8 調整輸出分類個數
- 4.9 綜合示例
- 更多模型說明
- 五、timm 庫近期更新
- 5.1 2025 年 2 月 21 日更新
- 5.2 其他更新
- 六、分布式訓練支持
- 七、學習率調度器
- 7.1 余弦退火調度器(CosineLRScheduler)
- 7.2 多步學習率調度器(MultiStepLRScheduler)
- 八、總結
- 九、參考資料
PyTorch Image Models (timm) 技術指南
timm
(PyTorch Image Models)是一個廣泛使用的 PyTorch 庫,它集合了大量的圖像模型、層、實用工具、優化器、調度器、數據加載器/增強器以及參考訓練/驗證腳本。以下是對 timm
庫的詳細介紹,包括功能、模型案例、加載與使用示例以及相關教程的信息。
功能概述
- 豐富的圖像模型:包含眾多預訓練的圖像分類、目標檢測、語義分割等模型,如 ResNet、EfficientNet、ViT 等。
- 實用工具:提供了一系列用于模型訓練、驗證和推理的實用工具,如優化器、調度器、數據加載器和增強器等。
- 模型構建與管理:支持輕松構建和管理不同類型的模型,包括模型的初始化、權重加載和保存等。
- 分布式訓練:支持分布式訓練,方便在多個 GPU 或節點上進行高效訓練。
一、引言
在深度學習領域,圖像分類、目標檢測等任務常常需要使用預訓練的圖像模型。PyTorch Image Models (timm)
是一個功能強大的庫,它提供了大量預訓練的圖像模型,涵蓋了各種架構,方便開發者快速搭建和訓練自己的模型。本文將詳細介紹 timm
庫的使用,包括模型加載、推理以及近期更新的模型和功能。
二、timm 庫概述
timm
是一個基于 PyTorch 的圖像模型庫,它收集了眾多先進的圖像模型,如 ResNet、ViT、Swin Transformer 等,并提供了預訓練的權重。通過 timm
,開發者可以輕松地加載這些模型,進行圖像分類、特征提取等任務。
三、安裝 timm 庫
在使用 timm
之前,需要先安裝該庫。可以使用以下命令進行安裝:
pip install timm
四、模型加載與推理示例
4.1 通用推理流程
以下是一個通用的使用 timm
加載模型并進行推理的示例代碼:
import torch
import timm
from PIL import Image
from torchvision import transforms# 加載預訓練模型
model = timm.create_model('model_name', pretrained=True)
model.eval()# 定義圖像預處理轉換
transform = transforms.Compose([transforms.Resize(256),transforms.CenterCrop(224),transforms.ToTensor(),transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
])# 加載圖像
image = Image.open('path_to_your_image.jpg')
image = transform(image).unsqueeze(0)# 進行推理
with torch.no_grad():output = model(image)# 獲取預測結果
_, predicted_idx = torch.max(output, 1)
print(f"Predicted class index: {predicted_idx.item()}")
在上述代碼中,'model_name'
需要替換為具體的模型名稱,'path_to_your_image.jpg'
需要替換為實際的圖像文件路徑。
4.2 具體模型示例
4.2.1 ResNeXt50-32x4d
import torch
import timm
from PIL import Image
from torchvision import transforms# 加載預訓練的 ResNeXt50-32x4d 模型
model = timm.create_model('resnext50_32x4d', pretrained=True)
model.eval()# 定義圖像預處理轉換
transform = transforms.Compose([transforms.Resize(256),transforms.CenterCrop(224),transforms.ToTensor(),transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
])# 加載圖像
image = Image.open('path_to_your_image.jpg')
image = transform(image).unsqueeze(0)# 進行推理
with torch.no_grad():output = model(image)# 獲取預測結果
_, predicted_idx = torch.max(output, 1)
print(f"ResNeXt50-32x4d Predicted class index: {predicted_idx.item()}")
4.2.2 EfficientNet-V2 Small 模型
import torch
import timm
from PIL import Image
from torchvision import transforms# 加載預訓練的 EfficientNet-V2 Small 模型
model = timm.create_model('efficientnetv2_s', pretrained=True)
model.eval()# 定義圖像預處理轉換
transform = transforms.Compose([transforms.Resize(256),transforms.CenterCrop(224),transforms.ToTensor(),transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
])# 加載圖像
image = Image.open('path_to_your_image.jpg')
image = transform(image).unsqueeze(0)# 進行推理
with torch.no_grad():output = model(image)# 獲取預測結果
_, predicted_idx = torch.max(output, 1)
print(f"EfficientNet-V2 Small Predicted class index: {predicted_idx.item()}")
4.2.3 DeiT-3 large 模型
import torch
import timm
from PIL import Image
from torchvision import transforms# 加載預訓練的 DeiT-3 large 模型
model = timm.create_model('deit3_large_patch16_384', pretrained=True)
model.eval()# 定義圖像預處理轉換,注意輸入尺寸為 384x384
transform = transforms.Compose([transforms.Resize(384),transforms.CenterCrop(384),transforms.ToTensor(),transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
])# 加載圖像
image = Image.open('path_to_your_image.jpg')
image = transform(image).unsqueeze(0)# 進行推理
with torch.no_grad():output = model(image)# 獲取預測結果
_, predicted_idx = torch.max(output, 1)
print(f"DeiT-3 large Predicted class index: {predicted_idx.item()}")
4.2.4 RepViT-M2 模型
import torch
import timm
from PIL import Image
from torchvision import transforms# 加載預訓練的 RepViT-M2 模型
model = timm.create_model('repvit_m2', pretrained=True)
model.eval()# 定義圖像預處理轉換
transform = transforms.Compose([transforms.Resize(256),transforms.CenterCrop(224),transforms.ToTensor(),transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
])# 加載圖像
image = Image.open('path_to_your_image.jpg')
image = transform(image).unsqueeze(0)# 進行推理
with torch.no_grad():output = model(image)# 獲取預測結果
_, predicted_idx = torch.max(output, 1)
print(f"RepViT-M2 Predicted class index: {predicted_idx.item()}")
4.2.5 ResNet-RS-101
import torch
import timm
from PIL import Image
from torchvision import transforms# 加載預訓練的 ResNet-RS-101 模型
model = timm.create_model('resnetrs101', pretrained=True)
model.eval()# 定義圖像預處理轉換
transform = transforms.Compose([transforms.Resize(256),transforms.CenterCrop(224),transforms.ToTensor(),transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
])# 加載圖像
image = Image.open('path_to_your_image.jpg')
image = transform(image).unsqueeze(0)# 進行推理
with torch.no_grad():output = model(image)# 獲取預測結果
_, predicted_idx = torch.max(output, 1)
print(f"ResNet-RS-101 Predicted class index: {predicted_idx.item()}")
4.2.6 Vision Transformer (ViT)
import torch
import timm
from PIL import Image
from torchvision import transforms# 加載預訓練的 ViT-Base/32 模型
model = timm.create_model('vit_base_patch32_224', pretrained=True)
model.eval()# 定義圖像預處理轉換
transform = transforms.Compose([transforms.Resize(256),transforms.CenterCrop(224),transforms.ToTensor(),transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
])# 加載圖像
image = Image.open('path_to_your_image.jpg')
image = transform(image).unsqueeze(0)# 進行推理
with torch.no_grad():output = model(image)# 獲取預測結果
_, predicted_idx = torch.max(output, 1)
print(f"ViT-Base/32 Predicted class index: {predicted_idx.item()}")
4.2.7 Swin Transformer
import torch
import timm
from PIL import Image
from torchvision import transforms# 加載預訓練的 Swin Transformer 模型
model = timm.create_model('swin_base_patch4_window7_224', pretrained=True)
model.eval()# 定義圖像預處理轉換
transform = transforms.Compose([transforms.Resize(256),transforms.CenterCrop(224),transforms.ToTensor(),transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
])# 加載圖像
image = Image.open('path_to_your_image.jpg')
image = transform(image).unsqueeze(0)# 進行推理
with torch.no_grad():output = model(image)# 獲取預測結果
_, predicted_idx = torch.max(output, 1)
print(f"Swin Transformer Predicted class index: {predicted_idx.item()}")
4.2.8 Swin Transformer V2
import torch
import timm
from PIL import Image
from torchvision import transforms# 加載預訓練的 Swin Transformer V2 模型
model = timm.create_model('swinv2_base_window12_192_22k', pretrained=True)
model.eval()# 定義圖像預處理轉換
transform = transforms.Compose([transforms.Resize(256),transforms.CenterCrop(224),transforms.ToTensor(),transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
])# 加載圖像
image = Image.open('path_to_your_image.jpg')
image = transform(image).unsqueeze(0)# 進行推理
with torch.no_grad():output = model(image)# 獲取預測結果
_, predicted_idx = torch.max(output, 1)
print(f"Swin Transformer V2 Predicted class index: {predicted_idx.item()}")
4.2.9 Swin Transformer V2 Cr
import torch
import timm
from PIL import Image
from torchvision import transforms# 加載預訓練的 Swin Transformer V2 Cr 模型
model = timm.create_model('swinv2_cr_base_224', pretrained=True)
model.eval()# 定義圖像預處理轉換
transform = transforms.Compose([transforms.Resize(256),transforms.CenterCrop(224),transforms.ToTensor(),transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
])# 加載圖像
image = Image.open('path_to_your_image.jpg')
image = transform(image).unsqueeze(0)# 進行推理
with torch.no_grad():output = model(image)# 獲取預測結果
_, predicted_idx = torch.max(output, 1)
print(f"Swin Transformer V2 Cr Predicted class index: {predicted_idx.item()}")
4.2.10 Levit
import torch
import timm
from PIL import Image
from torchvision import transforms# 加載預訓練的 Levit 模型
model = timm.create_model('levit_256', pretrained=True)
model.eval()# 定義圖像預處理轉換
transform = transforms.Compose([transforms.Resize(256),transforms.CenterCrop(224),transforms.ToTensor(),transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
])# 加載圖像
image = Image.open('path_to_your_image.jpg')
image = transform(image).unsqueeze(0)# 進行推理
with torch.no_grad():output = model(image)# 獲取預測結果
_, predicted_idx = torch.max(output, 1)
print(f"Levit Predicted class index: {predicted_idx.item()}")
4.3 加載自定義模型
如果你需要加載自定義的模型,可以使用 timm.create_model
函數,并指定模型的名稱和相關參數:
import timm# 創建自定義的 EfficientNet 模型
model = timm.create_model('efficientnet_b0', pretrained=False, num_classes=10)
print(model)
4.4 提取模型的中間特征
import torch
import timm# 加載預訓練的模型
model = timm.create_model('resnet18', pretrained=True, features_only=True)# 生成隨機輸入
x = torch.randn(1, 3, 224, 224)# 提取中間特征
features = model(x)
for i, feat in enumerate(features):print(f"Feature {i} shape: {feat.shape}")
4.5 凍結模型的部分層
import torch
import timm
from timm.utils.model import freeze# 加載預訓練的模型
model = timm.create_model('resnet18', pretrained=True)# 凍結模型的前幾層
submodules = [n for n, _ in model.named_children()]
freeze(model, submodules[:submodules.index('layer2') + 1])# 檢查凍結情況
print(model.layer2[0].conv1.weight.requires_grad) # 輸出: False
print(model.layer3[0].conv1.weight.requires_grad) # 輸出: True
在使用 timm
庫加載預訓練模型后,我們經常需要根據具體的任務需求調整模型的參數,例如輸入圖像尺寸、輸出分類個數等。下面將結合提供的代碼片段詳細介紹如何進行這些參數的調整。
4.6 創建模型時指定輸入圖像尺寸
部分模型在創建時可以通過 img_size
參數指定輸入圖像的尺寸。以下是一個使用 SwinTransformer
模型的示例:
import timm
import torch# 加載預訓練的 SwinTransformer 模型,并指定輸入圖像尺寸為 384x384
model = timm.create_model('swin_base_patch4_window7_224', pretrained=True, img_size=(384, 384))
model.eval()# 隨機生成一個符合指定尺寸的輸入張量進行測試
input_tensor = torch.randn(1, 3, 384, 384)
with torch.no_grad():output = model(input_tensor)
print("Output shape:", output.shape)
在上述代碼中,我們通過 img_size=(384, 384)
指定了輸入圖像的尺寸為 384x384。
4.7 數據預處理階段調整圖像尺寸
除了在創建模型時指定輸入圖像尺寸,還需要在數據預處理階段將輸入圖像調整為指定的尺寸。可以使用 torchvision.transforms
來實現這一點,示例如下:
from torchvision import transforms
from PIL import Image# 定義圖像預處理轉換,將圖像調整為 384x384
transform = transforms.Compose([transforms.Resize((384, 384)),transforms.ToTensor(),transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
])# 加載圖像
image = Image.open('path_to_your_image.jpg')
image = transform(image).unsqueeze(0)# 進行推理
with torch.no_grad():output = model(image)
print("Output shape:", output.shape)
在這個示例中,我們使用 transforms.Resize((384, 384))
將輸入圖像調整為 384x384 的尺寸。
4.8 調整輸出分類個數
輸出分類個數的調整通常在創建模型時通過 num_classes
參數來實現。以下是一個使用 MetaFormer
模型的示例:
import timm# 加載預訓練的 MetaFormer 模型,并指定輸出分類個數為 10
model = timm.create_model('metaformer', pretrained=True, num_classes=10)
model.eval()# 隨機生成一個輸入張量進行測試
input_tensor = torch.randn(1, 3, 224, 224)
with torch.no_grad():output = model(input_tensor)
print("Output shape:", output.shape)
在上述代碼中,我們通過 num_classes=10
指定了模型的輸出分類個數為 10。
4.9 綜合示例
下面是一個綜合示例,展示了如何同時調整輸入圖像尺寸和輸出分類個數:
import timm
import torch
from torchvision import transforms
from PIL import Image# 加載預訓練的 SwinTransformer 模型,調整輸入圖像尺寸為 384x384,輸出分類個數為 10
model = timm.create_model('swin_base_patch4_window7_224', pretrained=True, img_size=(384, 384), num_classes=10)
model.eval()# 定義圖像預處理轉換,將圖像調整為 384x384
transform = transforms.Compose([transforms.Resize((384, 384)),transforms.ToTensor(),transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
])# 加載圖像
image = Image.open('path_to_your_image.jpg')
image = transform(image).unsqueeze(0)# 進行推理
with torch.no_grad():output = model(image)
print("Output shape:", output.shape)
在這個綜合示例中,我們同時調整了輸入圖像尺寸和輸出分類個數,并進行了圖像預處理和推理操作。
通過以上方法,我們可以根據具體的任務需求靈活調整預訓練模型的輸入圖像尺寸和輸出分類個數。
更多模型說明
除了上述示例中的模型,timm
庫還包含了許多其他的模型,如 Aggregating Nested Transformers、BEiT、Big Transfer ResNetV2 (BiT) 等。你可以在 timm
的官方文檔 https://huggingface.co/docs/timm 中找到完整的模型列表。
要使用其他模型,只需將 timm.create_model
函數中的模型名稱替換為你想要使用的模型名稱即可。例如,要使用 BEiT 模型,可以使用以下代碼:
import torch
import timm
from PIL import Image
from torchvision import transforms# 加載預訓練的 BEiT 模型
model = timm.create_model('beit_base_patch16_224', pretrained=True)
model.eval()# 定義圖像預處理轉換
transform = transforms.Compose([transforms.Resize(256),transforms.CenterCrop(224),transforms.ToTensor(),transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
])# 加載圖像
image = Image.open('path_to_your_image.jpg')
image = transform(image).unsqueeze(0)# 進行推理
with torch.no_grad():output = model(image)# 獲取預測結果
_, predicted_idx = torch.max(output, 1)
print(f"BEiT Predicted class index: {predicted_idx.item()}")
五、timm 庫近期更新
5.1 2025 年 2 月 21 日更新
- 新增 SigLIP 2 ViT 圖像編碼器:可從 https://huggingface.co/collections/timm/siglip-2-67b8e72ba08b09dd97aecaf9 獲取。
- 新增 ‘SO150M2’ ViT 權重:使用 SBB 配方訓練,在 ImageNet 上取得了很好的效果。例如,
vit_so150m2_patch16_reg1_gap_448.sbb_e200_in12k_ft_in1k
的 top-1 準確率達到 88.1%。 - 更新 InternViT - 300M ‘2.5’ 權重。
- 發布 1.0.15 版本。
5.2 其他更新
在 2025 年 1 月至 2024 年 10 月期間,timm
庫還進行了許多其他更新,包括添加新的優化器(如 Kron Optimizer、MARS 優化器等)、支持新的模型(如 convnext_nano
、AIM - v2 編碼器等)、修復一些 bug 以及改進代碼結構等。
六、分布式訓練支持
timm
庫還提供了分布式訓練的支持,相關代碼在 timm/utils/distributed.py
中。以下是一些關鍵函數的介紹:
reduce_tensor
:用于在分布式環境中對張量進行規約操作。distribute_bn
:確保每個節點具有相同的運行時 BN 統計信息。init_distributed_device
:初始化分布式訓練設備。
以下是一個簡單的分布式訓練初始化示例:
import torch
from timm.utils.distributed import init_distributed_deviceargs = type('', (), {})() # 創建一個空的參數對象
device = init_distributed_device(args)
print(f"Device: {device}, World size: {args.world_size}, Rank: {args.rank}")
七、學習率調度器
timm
庫提供了多種學習率調度器,可在 timm/scheduler
目錄下找到相關代碼。以下是一些常見的調度器及其使用示例:
7.1 余弦退火調度器(CosineLRScheduler)
import torch
import timm
from timm.scheduler.scheduler_factory import create_scheduler# 定義優化器
optimizer = torch.optim.SGD(model.parameters(), lr=0.01)# 定義調度器參數
scheduler_args = type('', (), {'sched': 'cosine','epochs': 100,'decay_epochs': 30,'warmup_epochs': 5
})()# 創建調度器
scheduler, num_epochs = create_scheduler(scheduler_args, optimizer)# 訓練循環
for epoch in range(num_epochs):# 訓練代碼...scheduler.step(epoch)
7.2 多步學習率調度器(MultiStepLRScheduler)
import torch
import timm
from timm.scheduler.scheduler_factory import create_scheduler# 定義優化器
optimizer = torch.optim.SGD(model.parameters(), lr=0.01)# 定義調度器參數
scheduler_args = type('', (), {'sched': 'multistep','epochs': 100,'decay_milestones': [30, 60],'decay_rate': 0.1,'warmup_epochs': 5
})()# 創建調度器
scheduler, num_epochs = create_scheduler(scheduler_args, optimizer)# 訓練循環
for epoch in range(num_epochs):# 訓練代碼...scheduler.step(epoch)
八、總結
PyTorch Image Models (timm)
是一個非常實用的圖像模型庫,它提供了豐富的預訓練模型和便捷的使用接口,同時支持分布式訓練和多種學習率調度器。通過本文的介紹,你可以快速上手 timm
庫,進行圖像分類等任務的開發。希望本文對你有所幫助,祝你在深度學習領域取得更好的成果!
九、參考資料
timm
官方文檔:https://huggingface.co/docs/timmtimm
代碼庫:https://github.com/rwightman/pytorch-image-models