timm 視覺庫中的 create_model 函數詳解
最近一年 Vision Transformer 及其相關改進的工作層出不窮,在他們開源的代碼中,大部分都用到了這樣一個庫:timm。各位煉丹師應該已經想必已經對其無比熟悉了,本文將介紹其中最關鍵的函數之一:create_model
函數。
timm簡介
PyTorchImageModels,簡稱timm,是一個巨大的PyTorch代碼集合,包括了一系列:
- image models
- layers
- utilities
- optimizers
- schedulers
- data-loaders / augmentations
- training / validation scripts
旨在將各種 SOTA 模型、圖像實用工具、常用的優化器、訓練策略等視覺相關常用函數的整合在一起,并具有復現ImageNet訓練結果的能力。
源碼:https://github.com/rwightman/pytorch-image-models
文檔:https://fastai.github.io/timmdocs/
create_model 函數的使用及常用參數
本小節先介紹 create_model
函數,及常用的參數 **kwargs
。
顧名思義,create_model
函數是用來創建一個網絡模型(如 ResNet、ViT 等),timm 庫本身可供直接調用的模型已有接近400個,用戶也可以自己實現一些模型并注冊進 timm (這一部分內容將在下一小節著重介紹),供自己調用。
model_name
我們首先來看最簡單地用法:直接傳入模型名稱 model_name
import timm
# 創建 resnet-34
model = timm.create_model('resnet34')
# 創建 efficientnet-b0
model = timm.create_model('efficientnet_b0')
我們可以通過 list_models
函數來查看已經可以直接創建、有預訓練參數的模型列表:
all_pretrained_models_available = timm.list_models(pretrained=True)
print(all_pretrained_models_available)
print(len(all_pretrained_models_available))
輸出:
[..., 'vit_large_patch16_384', 'vit_large_patch32_224_in21k', 'vit_large_patch32_384', 'vit_small_patch16_224', 'wide_resnet50_2', 'wide_resnet101_2', 'xception', 'xception41', 'xception65', 'xception71']
452
如果沒有設置 pretrained=True
的話有將會輸出612,即有預訓練權重參數的模型有452個,沒有預訓練參數,只有模型結構的共有612個。
pretrained
如果我們傳入 pretrained=True
,那么 timm 會從對應的 URL 下載模型權重參數并載入模型,只有當第一次(即本地還沒有對應模型參數時)會去下載,之后會直接從本地加載模型權重參數。
model = timm.create_model('resnet34', pretrained=True)
輸出:
Downloading: "https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/resnet34-43635321.pth" to /home/song/.cache/torch/hub/checkpoints/resnet34-43635321.pth
features_only、out_indices
create_mode
函數還支持 features_only=True
參數,此時函數將返回部分網絡,該網絡提取每一步最深一層的特征圖。還可以使用 out_indices=[…]
參數指定層的索引,以提取中間層特征。
# 創建一個 (1, 3, 224, 224) 形狀的張量
x = torch.randn(1, 3, 224, 224)
model = timm.create_model('resnet34')
preds = model(x)
print('preds shape: {}'.format(preds.shape))all_feature_extractor = timm.create_model('resnet34', features_only=True)
all_features = all_feature_extractor(x)
print('All {} Features: '.format(len(all_features)))
for i in range(len(all_features)):print('feature {} shape: {}'.format(i, all_features[i].shape))out_indices = [2, 3, 4]
selected_feature_extractor = timm.create_model('resnet34', features_only=True, out_indices=out_indices)
selected_features = selected_feature_extractor(x)
print('Selected Features: ')
for i in range(len(out_indices)):print('feature {} shape: {}'.format(out_indices[i], selected_features[i].shape))
我們以一個 (1, 3, 224, 224)
形狀的張量為輸入,在視覺任務中,圖像輸入張量總是類似的形狀。上面例程展示了,創建完整模型 model
,創建完整特征提取器 all_feature_extractor
,和創建某幾層特征提取器 selected_feature_extractor
的具體輸出。
可以結合下面 ResNet34 的結構圖來理解(圖中不同的顏色表示不同的 layer),根據下圖分析各層的卷積操作,計算各層最后一個卷積的輸入,并與上面例程的輸出(附在圖后)驗證是否一致。
輸出:
preds shape: torch.Size([1, 1000])
All 5 Features:
feature 0 shape: torch.Size([1, 64, 112, 112])
feature 1 shape: torch.Size([1, 64, 56, 56])
feature 2 shape: torch.Size([1, 128, 28, 28])
feature 3 shape: torch.Size([1, 256, 14, 14])
feature 4 shape: torch.Size([1, 512, 7, 7])
Selected Features:
feature 2 shape: torch.Size([1, 128, 28, 28])
feature 3 shape: torch.Size([1, 256, 14, 14])
feature 4 shape: torch.Size([1, 512, 7, 7])
這樣,我們就可以通過 timm_model
函數及其 features_only
、out_indices
參數將預訓練模型方便地轉換為自己想要的特征提取器。
接下來我們來看一下這些特征提取器究竟是什么類型:
import timm
feature_extractor = timm.create_model('resnet34', features_only=True, out_indices=[3])print('type:', type(feature_extractor))
print('len: ', len(feature_extractor))
for item in feature_extractor:print(item)
輸出:
type: <class 'timm.models.features.FeatureListNet'>
len: 7
conv1
bn1
act1
maxpool
layer1
layer2
layer3
可以看到,feature_extractor
其實也是一個神經網絡,在 timm 中稱為 FeatureListNet
,而我們通過 out_indices
參數來指定截取到哪一層特征。
需要注意的是,ViT 模型并不支持 features_only
選項(0.4.12版本)。
extractor = timm.create_model('vit_base_patch16_224', features_only=True)
輸出:
RuntimeError: features_only not implemented for Vision Transformer models.
create_model 函數究竟做了什么
registry
在了解了 create_model
函數的基本使用之后,我們來深入探索一下 create_model
函數的源碼,看一下究竟是怎樣實現從模型到特征提取器的轉換的。
create_model
主體只有 50 行左右的代碼,因此所有這些神奇的事情是在其他地方完成的。我們知道 timm.list_models()
函數中的每一個模型名字(str)實際上都是一個函數。以下代碼可以測試這一點:
import timm
import random
from timm.models import registrym = timm.list_models()[-1]
print(m)
registry.is_model(m)
輸出:
xception71
True
實際上,在 timm 內部,有一個字典稱為 _model_entrypoints
包含了所有的模型名稱和他們各自的函數。比如說,我們可以通過 model_entrypoint
函數從 _model_entrypoints
內部得到 xception71
模型的構造函數。
constuctor_fn = registry.model_entrypoint(m)
print(constuctor_fn)
輸出:
<function timm.models.xception_aligned.xception71(pretrained=False, **kwargs)>
也有可能輸出:
<function xception71 at 0x7fc0cba0eca0>
一樣的。
如我們所見,在 timm.models.xception_aligned
模塊中有一個函數稱為 xception71
。類似的,timm 中的每一個模型都有著一個這樣的構造函數。事實上,內部的 _model_entrypoints
字典大概長這個樣子:
_model_entrypoints
> >
{
'cspresnet50':<function timm.models.cspnet.cspresnet50(pretrained=False, **kwargs)>,'cspresnet50d': <function timm.models.cspnet.cspresnet50d(pretrained=False, **kwargs)>,
'cspresnet50w': <function timm.models.cspnet.cspresnet50w(pretrained=False, **kwargs)>,
'cspresnext50': <function timm.models.cspnet.cspresnext50(pretrained=False, **kwargs)>,
'cspresnext50_iabn': <function timm.models.cspnet.cspresnext50_iabn(pretrained=False, **kwargs)>,
'cspdarknet53': <function timm.models.cspnet.cspdarknet53(pretrained=False, **kwargs)>,
'cspdarknet53_iabn': <function timm.models.cspnet.cspdarknet53_iabn(pretrained=False, **kwargs)>,
'darknet53': <function timm.models.cspnet.darknet53(pretrained=False, **kwargs)>,
'densenet121': <function timm.models.densenet.densenet121(pretrained=False, **kwargs)>,
'densenetblur121d': <function timm.models.densenet.densenetblur121d(pretrained=False, **kwargs)>,
'densenet121d': <function timm.models.densenet.densenet121d(pretrained=False, **kwargs)>,
'densenet169': <function timm.models.densenet.densenet169(pretrained=False, **kwargs)>,
'densenet201': <function timm.models.densenet.densenet201(pretrained=False, **kwargs)>,
'densenet161': <function timm.models.densenet.densenet161(pretrained=False, **kwargs)>,
'densenet264': <function timm.models.densenet.densenet264(pretrained=False, **kwargs)>,}
所以說,在 timm 對應的模塊中,每個模型都有一個構造器。比如說 ResNets 系列模型被定義在 timm.models.resnet
模塊中。因此,實際上我們有兩種方式來創建一個 resnet34
模型:
import timm
from timm.models.resnet import resnet34# 使用 create_model
m = timm.create_model('resnet34')# 直接調用構造函數
m = resnet34()
但使用上,我們無須調用構造函數。所用模型都可以通過 create_model
函數來將創建。
Register model
resnet34
構造函數的源碼如下:
@register_model
def resnet34(pretrained=False, **kwargs):"""Constructs a ResNet-34 model."""model_args = dict(block=BasicBlock, layers=[3, 4, 6, 3], **kwargs)return _create_resnet('resnet34', pretrained, **model_args)
我們會發現 timm 中的每個模型都有一個 register_model
裝飾器。最開始, _model_entrypoints
是一個空字典。我們是通過 register_model
裝飾器來不斷地像其中添加模型名稱和它對應的構造函數。該裝飾器的定義如下:
def register_model(fn):# lookup containing modulemod = sys.modules[fn.__module__]module_name_split = fn.__module__.split('.')module_name = module_name_split[-1] if len(module_name_split) else ''# add model to __all__ in modulemodel_name = fn.__name__if hasattr(mod, '__all__'):mod.__all__.append(model_name)else:mod.__all__ = [model_name]# add entries to registry dict/sets_model_entrypoints[model_name] = fn_model_to_module[model_name] = module_name_module_to_models[module_name].add(model_name)has_pretrained = False # check if model has a pretrained url to allow filtering on thisif hasattr(mod, 'default_cfgs') and model_name in mod.default_cfgs:# this will catch all models that have entrypoint matching cfg key, but miss any aliasing# entrypoints or non-matching comboshas_pretrained = 'url' in mod.default_cfgs[model_name] and 'http' in mod.default_cfgs[model_name]['url']if has_pretrained:_model_has_pretrained.add(model_name)return fn
我們可以看到, register_model
函數完成了一些比較基礎的步驟,但這里需要指出的是這一句:
_model_entrypoints[model_name] = fn
它將給定的 fn
添加到 _model_entrypoints
其鍵名為 fn.__name__
。所以說 resnet34
函數上的裝飾器 @register_model
在 _model_entrypoints
中創建一個新的條目,像這樣:
{’resnet34’: <function timm.models.resnet.resnet34(pretrained=False, **kwargs)>}
我們同樣可以看到在 resnet34
構造函數的源碼中,在設置完一些 model_args
之后,它會隨后調用 _create_resnet
函數。讓我們再來看一下該函數的源碼:
def _create_resnet(variant, pretrained=False, **kwargs):return build_model_with_cfg(ResNet, variant, default_cfg=default_cfgs[variant], pretrained=pretrained, **kwargs)
所以在 _create_resnet
函數之中,會再調用 build_model_with_cfg
函數并將一個構造器類 ResNet
、變量名 resnet34
、一個 default_cfg
和一些 **kwargs
傳入其中。
default config
timm 中所有的模型都有一個默認的配置,包括指向它的預訓練權重參數的URL、類別數、輸入圖像尺寸、池化尺寸等。
resnet34
的默認配置如下:
{'url': 'https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/resnet34-43635321.pth',
'num_classes': 1000,
'input_size': (3, 224, 224),
'pool_size': (7, 7),
'crop_pct': 0.875,
'interpolation': 'bilinear',
'mean': (0.485, 0.456, 0.406),
'std': (0.229, 0.224, 0.225),
'first_conv': 'conv1',
'classifier': 'fc'}
此默認配置與其他參數(如構造函數類和一些模型參數)一起傳遞給 build_model_with_cfg
函數。
build model with config
這個 build_model_with_cfg
函數負責:
- 真正地實例化一個模型類來創建一個模型
- 若
pruned=True
,對模型進行剪枝 - 若
pretrained=True
,加載預訓練模型參數 - 若
features_only=True
,將模型轉換為特征提取器
看一下該函數的源碼:
def build_model_with_cfg(model_cls: Callable,variant: str,pretrained: bool,default_cfg: dict,model_cfg: dict = None,feature_cfg: dict = None,pretrained_strict: bool = True,pretrained_filter_fn: Callable = None,pretrained_custom_load: bool = False,**kwargs):pruned = kwargs.pop('pruned', False)features = Falsefeature_cfg = feature_cfg or {}if kwargs.pop('features_only', False):features = Truefeature_cfg.setdefault('out_indices', (0, 1, 2, 3, 4))if 'out_indices' in kwargs:feature_cfg['out_indices'] = kwargs.pop('out_indices')model = model_cls(**kwargs) if model_cfg is None else model_cls(cfg=model_cfg, **kwargs)model.default_cfg = deepcopy(default_cfg)if pruned:model = adapt_model_from_file(model, variant)# for classification models, check class attr, then kwargs, then default to 1k, otherwise 0 for featsnum_classes_pretrained = 0 if features else getattr(model, 'num_classes', kwargs.get('num_classes', 1000))if pretrained:if pretrained_custom_load:load_custom_pretrained(model)else:load_pretrained(model,num_classes=num_classes_pretrained, in_chans=kwargs.get('in_chans', 3),filter_fn=pretrained_filter_fn, strict=pretrained_strict)if features:feature_cls = FeatureListNetif 'feature_cls' in feature_cfg:feature_cls = feature_cfg.pop('feature_cls')if isinstance(feature_cls, str):feature_cls = feature_cls.lower()if 'hook' in feature_cls:feature_cls = FeatureHookNetelse:assert False, f'Unknown feature class {feature_cls}'model = feature_cls(model, **feature_cfg)model.default_cfg = default_cfg_for_features(default_cfg) # add back default_cfgreturn model
我們可以看到,模型在這一步被創建出來:model = model_cls(**kwargs)
。本文將不再深入到 pruned
和 adapt_model_from_file
內部查看。
總結
通過本文,我們已經完全了解了 create_model
函數,我們了解到:
- 每個模型有不同的構造函數,可以傳入不同的參數,
_model_entrypoints
字典包括了所有的模型名稱及其對應的構造函數 build_with_model_cfg
函數接收模型構造器類和其中的一些具體參數,真正地實例化一個模型load_pretrained
會加載預訓練參數FeatureListNet
類可以將模型轉換為特征提取器
Ref:
https://github.com/rwightman/pytorch-image-models
https://fastai.github.io/timmdocs/
https://fastai.github.io/timmdocs/create_model#Turn-any-model-into-a-feature-extractor
https://fastai.github.io/timmdocs/tutorial_feature_extractor
https://zhuanlan.zhihu.com/p/404107277