timm 視覺庫中的 create_model 函數詳解

timm 視覺庫中的 create_model 函數詳解

最近一年 Vision Transformer 及其相關改進的工作層出不窮,在他們開源的代碼中,大部分都用到了這樣一個庫:timm。各位煉丹師應該已經想必已經對其無比熟悉了,本文將介紹其中最關鍵的函數之一:create_model 函數。

timm簡介

PyTorchImageModels,簡稱timm,是一個巨大的PyTorch代碼集合,包括了一系列:

  • image models
  • layers
  • utilities
  • optimizers
  • schedulers
  • data-loaders / augmentations
  • training / validation scripts

旨在將各種 SOTA 模型、圖像實用工具、常用的優化器、訓練策略等視覺相關常用函數的整合在一起,并具有復現ImageNet訓練結果的能力。

源碼:https://github.com/rwightman/pytorch-image-models

文檔:https://fastai.github.io/timmdocs/

create_model 函數的使用及常用參數

本小節先介紹 create_model 函數,及常用的參數 **kwargs

顧名思義,create_model 函數是用來創建一個網絡模型(如 ResNet、ViT 等),timm 庫本身可供直接調用的模型已有接近400個,用戶也可以自己實現一些模型并注冊進 timm (這一部分內容將在下一小節著重介紹),供自己調用。

model_name

我們首先來看最簡單地用法:直接傳入模型名稱 model_name

import timm 
# 創建 resnet-34 
model = timm.create_model('resnet34')
# 創建 efficientnet-b0
model = timm.create_model('efficientnet_b0')

我們可以通過 list_models 函數來查看已經可以直接創建、有預訓練參數的模型列表:

all_pretrained_models_available = timm.list_models(pretrained=True)
print(all_pretrained_models_available)
print(len(all_pretrained_models_available))

輸出:

[..., 'vit_large_patch16_384', 'vit_large_patch32_224_in21k', 'vit_large_patch32_384', 'vit_small_patch16_224', 'wide_resnet50_2', 'wide_resnet101_2', 'xception', 'xception41', 'xception65', 'xception71']
452

如果沒有設置 pretrained=True 的話有將會輸出612,即有預訓練權重參數的模型有452個,沒有預訓練參數,只有模型結構的共有612個。

pretrained

如果我們傳入 pretrained=True,那么 timm 會從對應的 URL 下載模型權重參數并載入模型,只有當第一次(即本地還沒有對應模型參數時)會去下載,之后會直接從本地加載模型權重參數。

model = timm.create_model('resnet34', pretrained=True)

輸出:

Downloading: "https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/resnet34-43635321.pth" to /home/song/.cache/torch/hub/checkpoints/resnet34-43635321.pth

features_only、out_indices

create_mode 函數還支持 features_only=True 參數,此時函數將返回部分網絡,該網絡提取每一步最深一層的特征圖。還可以使用 out_indices=[…] 參數指定層的索引,以提取中間層特征。

# 創建一個 (1, 3, 224, 224) 形狀的張量
x = torch.randn(1, 3, 224, 224)
model = timm.create_model('resnet34')
preds = model(x)
print('preds shape: {}'.format(preds.shape))all_feature_extractor = timm.create_model('resnet34', features_only=True)
all_features = all_feature_extractor(x)
print('All {} Features: '.format(len(all_features)))
for i in range(len(all_features)):print('feature {} shape: {}'.format(i, all_features[i].shape))out_indices = [2, 3, 4]
selected_feature_extractor = timm.create_model('resnet34', features_only=True, out_indices=out_indices)
selected_features = selected_feature_extractor(x)
print('Selected Features: ')
for i in range(len(out_indices)):print('feature {} shape: {}'.format(out_indices[i], selected_features[i].shape))

我們以一個 (1, 3, 224, 224) 形狀的張量為輸入,在視覺任務中,圖像輸入張量總是類似的形狀。上面例程展示了,創建完整模型 model,創建完整特征提取器 all_feature_extractor,和創建某幾層特征提取器 selected_feature_extractor 的具體輸出。

可以結合下面 ResNet34 的結構圖來理解(圖中不同的顏色表示不同的 layer),根據下圖分析各層的卷積操作,計算各層最后一個卷積的輸入,并與上面例程的輸出(附在圖后)驗證是否一致。

在這里插入圖片描述

輸出:

preds shape: torch.Size([1, 1000])
All 5 Features:
feature 0 shape: torch.Size([1, 64, 112, 112])
feature 1 shape: torch.Size([1, 64, 56, 56])
feature 2 shape: torch.Size([1, 128, 28, 28])
feature 3 shape: torch.Size([1, 256, 14, 14])
feature 4 shape: torch.Size([1, 512, 7, 7])
Selected Features:
feature 2 shape: torch.Size([1, 128, 28, 28])
feature 3 shape: torch.Size([1, 256, 14, 14])
feature 4 shape: torch.Size([1, 512, 7, 7])

這樣,我們就可以通過 timm_model 函數及其 features_onlyout_indices 參數將預訓練模型方便地轉換為自己想要的特征提取器。

接下來我們來看一下這些特征提取器究竟是什么類型:

import timm
feature_extractor = timm.create_model('resnet34', features_only=True, out_indices=[3])print('type:', type(feature_extractor))
print('len: ', len(feature_extractor))
for item in feature_extractor:print(item)

輸出:

type: <class 'timm.models.features.FeatureListNet'>
len:  7
conv1
bn1
act1
maxpool
layer1
layer2
layer3

可以看到,feature_extractor 其實也是一個神經網絡,在 timm 中稱為 FeatureListNet,而我們通過 out_indices 參數來指定截取到哪一層特征。

需要注意的是,ViT 模型并不支持 features_only 選項(0.4.12版本)。

extractor = timm.create_model('vit_base_patch16_224', features_only=True)

輸出:

RuntimeError: features_only not implemented for Vision Transformer models.

create_model 函數究竟做了什么

registry

在了解了 create_model 函數的基本使用之后,我們來深入探索一下 create_model 函數的源碼,看一下究竟是怎樣實現從模型到特征提取器的轉換的。

create_model 主體只有 50 行左右的代碼,因此所有這些神奇的事情是在其他地方完成的。我們知道 timm.list_models() 函數中的每一個模型名字(str)實際上都是一個函數。以下代碼可以測試這一點:

import timm
import random 
from timm.models import registrym = timm.list_models()[-1]
print(m)
registry.is_model(m)

輸出:

xception71
True

實際上,在 timm 內部,有一個字典稱為 _model_entrypoints 包含了所有的模型名稱和他們各自的函數。比如說,我們可以通過 model_entrypoint 函數從 _model_entrypoints 內部得到 xception71 模型的構造函數。

constuctor_fn = registry.model_entrypoint(m)
print(constuctor_fn)

輸出:

<function timm.models.xception_aligned.xception71(pretrained=False, **kwargs)>

也有可能輸出:

<function xception71 at 0x7fc0cba0eca0>

一樣的。

如我們所見,在 timm.models.xception_aligned 模塊中有一個函數稱為 xception71 。類似的,timm 中的每一個模型都有著一個這樣的構造函數。事實上,內部的 _model_entrypoints 字典大概長這個樣子:

_model_entrypoints
> > 
{
'cspresnet50':<function timm.models.cspnet.cspresnet50(pretrained=False, **kwargs)>,'cspresnet50d': <function timm.models.cspnet.cspresnet50d(pretrained=False, **kwargs)>,
'cspresnet50w': <function timm.models.cspnet.cspresnet50w(pretrained=False, **kwargs)>,
'cspresnext50': <function timm.models.cspnet.cspresnext50(pretrained=False, **kwargs)>,
'cspresnext50_iabn': <function timm.models.cspnet.cspresnext50_iabn(pretrained=False, **kwargs)>,
'cspdarknet53': <function timm.models.cspnet.cspdarknet53(pretrained=False, **kwargs)>,
'cspdarknet53_iabn': <function timm.models.cspnet.cspdarknet53_iabn(pretrained=False, **kwargs)>,
'darknet53': <function timm.models.cspnet.darknet53(pretrained=False, **kwargs)>,
'densenet121': <function timm.models.densenet.densenet121(pretrained=False, **kwargs)>,
'densenetblur121d': <function timm.models.densenet.densenetblur121d(pretrained=False, **kwargs)>,
'densenet121d': <function timm.models.densenet.densenet121d(pretrained=False, **kwargs)>,
'densenet169': <function timm.models.densenet.densenet169(pretrained=False, **kwargs)>,
'densenet201': <function timm.models.densenet.densenet201(pretrained=False, **kwargs)>,
'densenet161': <function timm.models.densenet.densenet161(pretrained=False, **kwargs)>,
'densenet264': <function timm.models.densenet.densenet264(pretrained=False, **kwargs)>,}

所以說,在 timm 對應的模塊中,每個模型都有一個構造器。比如說 ResNets 系列模型被定義在 timm.models.resnet 模塊中。因此,實際上我們有兩種方式來創建一個 resnet34 模型:

import timm
from timm.models.resnet import resnet34# 使用 create_model
m = timm.create_model('resnet34')# 直接調用構造函數
m = resnet34()

但使用上,我們無須調用構造函數。所用模型都可以通過 create_model 函數來將創建。

Register model

resnet34 構造函數的源碼如下:

@register_model
def resnet34(pretrained=False, **kwargs):"""Constructs a ResNet-34 model."""model_args = dict(block=BasicBlock, layers=[3, 4, 6, 3], **kwargs)return _create_resnet('resnet34', pretrained, **model_args)

我們會發現 timm 中的每個模型都有一個 register_model 裝飾器。最開始, _model_entrypoints 是一個空字典。我們是通過 register_model 裝飾器來不斷地像其中添加模型名稱和它對應的構造函數。該裝飾器的定義如下:

def register_model(fn):# lookup containing modulemod = sys.modules[fn.__module__]module_name_split = fn.__module__.split('.')module_name = module_name_split[-1] if len(module_name_split) else ''# add model to __all__ in modulemodel_name = fn.__name__if hasattr(mod, '__all__'):mod.__all__.append(model_name)else:mod.__all__ = [model_name]# add entries to registry dict/sets_model_entrypoints[model_name] = fn_model_to_module[model_name] = module_name_module_to_models[module_name].add(model_name)has_pretrained = False  # check if model has a pretrained url to allow filtering on thisif hasattr(mod, 'default_cfgs') and model_name in mod.default_cfgs:# this will catch all models that have entrypoint matching cfg key, but miss any aliasing# entrypoints or non-matching comboshas_pretrained = 'url' in mod.default_cfgs[model_name] and 'http' in mod.default_cfgs[model_name]['url']if has_pretrained:_model_has_pretrained.add(model_name)return fn

我們可以看到, register_model 函數完成了一些比較基礎的步驟,但這里需要指出的是這一句:

_model_entrypoints[model_name] = fn

它將給定的 fn 添加到 _model_entrypoints 其鍵名為 fn.__name__。所以說 resnet34 函數上的裝飾器 @register_model_model_entrypoints 中創建一個新的條目,像這樣:

{&#8217;resnet34&#8217;: <function timm.models.resnet.resnet34(pretrained=False, **kwargs)>}

我們同樣可以看到在 resnet34 構造函數的源碼中,在設置完一些 model_args 之后,它會隨后調用 _create_resnet 函數。讓我們再來看一下該函數的源碼:

def _create_resnet(variant, pretrained=False, **kwargs):return build_model_with_cfg(ResNet, variant, default_cfg=default_cfgs[variant], pretrained=pretrained, **kwargs)

所以在 _create_resnet 函數之中,會再調用 build_model_with_cfg 函數并將一個構造器類 ResNet 、變量名 resnet34、一個 default_cfg 和一些 **kwargs 傳入其中。

default config

timm 中所有的模型都有一個默認的配置,包括指向它的預訓練權重參數的URL、類別數、輸入圖像尺寸、池化尺寸等。

resnet34 的默認配置如下:

{'url': 'https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/resnet34-43635321.pth',
'num_classes': 1000,
'input_size': (3, 224, 224),
'pool_size': (7, 7),
'crop_pct': 0.875,
'interpolation': 'bilinear',
'mean': (0.485, 0.456, 0.406),
'std': (0.229, 0.224, 0.225),
'first_conv': 'conv1',
'classifier': 'fc'}

此默認配置與其他參數(如構造函數類和一些模型參數)一起傳遞給 build_model_with_cfg 函數。

build model with config

這個 build_model_with_cfg 函數負責:

  1. 真正地實例化一個模型類來創建一個模型
  2. pruned=True,對模型進行剪枝
  3. pretrained=True,加載預訓練模型參數
  4. features_only=True,將模型轉換為特征提取器

看一下該函數的源碼:

def build_model_with_cfg(model_cls: Callable,variant: str,pretrained: bool,default_cfg: dict,model_cfg: dict = None,feature_cfg: dict = None,pretrained_strict: bool = True,pretrained_filter_fn: Callable = None,pretrained_custom_load: bool = False,**kwargs):pruned = kwargs.pop('pruned', False)features = Falsefeature_cfg = feature_cfg or {}if kwargs.pop('features_only', False):features = Truefeature_cfg.setdefault('out_indices', (0, 1, 2, 3, 4))if 'out_indices' in kwargs:feature_cfg['out_indices'] = kwargs.pop('out_indices')model = model_cls(**kwargs) if model_cfg is None else model_cls(cfg=model_cfg, **kwargs)model.default_cfg = deepcopy(default_cfg)if pruned:model = adapt_model_from_file(model, variant)# for classification models, check class attr, then kwargs, then default to 1k, otherwise 0 for featsnum_classes_pretrained = 0 if features else getattr(model, 'num_classes', kwargs.get('num_classes', 1000))if pretrained:if pretrained_custom_load:load_custom_pretrained(model)else:load_pretrained(model,num_classes=num_classes_pretrained, in_chans=kwargs.get('in_chans', 3),filter_fn=pretrained_filter_fn, strict=pretrained_strict)if features:feature_cls = FeatureListNetif 'feature_cls' in feature_cfg:feature_cls = feature_cfg.pop('feature_cls')if isinstance(feature_cls, str):feature_cls = feature_cls.lower()if 'hook' in feature_cls:feature_cls = FeatureHookNetelse:assert False, f'Unknown feature class {feature_cls}'model = feature_cls(model, **feature_cfg)model.default_cfg = default_cfg_for_features(default_cfg)  # add back default_cfgreturn model

我們可以看到,模型在這一步被創建出來:model = model_cls(**kwargs)。本文將不再深入到 prunedadapt_model_from_file 內部查看。

總結

通過本文,我們已經完全了解了 create_model 函數,我們了解到:

  • 每個模型有不同的構造函數,可以傳入不同的參數, _model_entrypoints 字典包括了所有的模型名稱及其對應的構造函數
  • build_with_model_cfg 函數接收模型構造器類和其中的一些具體參數,真正地實例化一個模型
  • load_pretrained 會加載預訓練參數
  • FeatureListNet 類可以將模型轉換為特征提取器

Ref:

https://github.com/rwightman/pytorch-image-models

https://fastai.github.io/timmdocs/

https://fastai.github.io/timmdocs/create_model#Turn-any-model-into-a-feature-extractor

https://fastai.github.io/timmdocs/tutorial_feature_extractor

https://zhuanlan.zhihu.com/p/404107277

本文來自互聯網用戶投稿,該文觀點僅代表作者本人,不代表本站立場。本站僅提供信息存儲空間服務,不擁有所有權,不承擔相關法律責任。
如若轉載,請注明出處:http://www.pswp.cn/news/532666.shtml
繁體地址,請注明出處:http://hk.pswp.cn/news/532666.shtml
英文地址,請注明出處:http://en.pswp.cn/news/532666.shtml

如若內容造成侵權/違法違規/事實不符,請聯系多彩編程網進行投訴反饋email:809451989@qq.com,一經查實,立即刪除!

相關文章

C--數據結構--樹的學習

6.2.1二叉樹的性質 1.二叉樹 性質&#xff1a; 1.若二叉樹的層次從1開始&#xff0c;則在二叉樹的第i層最多有2^(i-1)個結點 2.深度為k的二叉樹最多有2^k -1個結點 &#xff08;k>1&#xff09; 3.對任何一顆二叉樹&#xff0c;如果其葉結點個數為n0,度為2的非葉結點個數…

TVM:使用 Schedule 模板和 AutoTVM 來優化算子

TVM&#xff1a;使用 Schedule 模板和 AutoTVM 來優化算子 在本文中&#xff0c;我們將介紹如何使用 TVM 張量表達式&#xff08;Tensor Expression&#xff0c;TE&#xff09;語言編寫 Schedule 模板&#xff0c;AutoTVM 可以搜索通過這些模板找到最佳 Schedule。這個過程稱為…

TVM:使用 Auto-scheduling 來優化算子

TVM&#xff1a;使用 Auto-scheduling 來優化算子 在本教程中&#xff0c;我們將展示 TVM 的 Auto-scheduling 功能如何在無需編寫自定義模板的情況下找到最佳 schedule。 與基于模板的 AutoTVM 依賴手動模板定義搜索空間不同&#xff0c;auto-scheduler 不需要任何模板。 用…

C語言—sort函數比較大小的快捷使用--algorithm頭文件下

sort函數 一般情況下要將一組數從的大到小排序或從小到大排序&#xff0c;要定義一個新的函數排序。 而我們也可以直接使用在函數下的sort函數&#xff0c;只需加上頭文件&#xff1a; #include<algorithm> using namespace std;sort格式&#xff1a;sort(首元素地址&…

散列的使用

散列 散列簡單來說&#xff1a;給N個正整數和M個負整數&#xff0c;問這M個數中的每個數是否在N中出現過。 比如&#xff1a;N&#xff1a;{1,2,3,4}&#xff0c;M{2,5,7}&#xff0c;其中M的2在N中出現過 對這個問題最直觀的思路是&#xff1a;對M中每個欲查的值x&#xff0…

關于C++中的unordered_map和unordered_set不能直接以pair作為鍵名的問題

關于C中的unordered_map和unordered_set不能直接以pair作為鍵名的問題 在 C STL 中&#xff0c;不同于有序的 std::map 和 std::set 是基于紅黑樹實現的&#xff0c;std::unordered_map 和 std::unordered_set 是基于哈希實現的&#xff0c;在不要求容器內的鍵有序&#xff0c…

AI編譯器與傳統編譯器的聯系與區別

AI編譯器與傳統編譯器的區別與聯系 總結整理自知乎問題 針對神經網絡的編譯器和傳統編譯器的區別和聯系是什么&#xff1f;。 文中提到的答主的知乎主頁&#xff1a;金雪鋒、楊軍、藍色、SunnyCase、貝殼與知了、工藤福爾摩 筆者本人理解 為了不用直接手寫機器碼&#xff0…

python學習1:注釋\變量類型\轉換函數\轉義字符\運算符

python基礎學習 與大多數語言不同&#xff0c;python最具特色的就是使用縮進來表示代碼塊&#xff0c;不需要使用大括號 {} 。縮進的空格數是可變的&#xff0c;但是同一個代碼塊的語句必須包含相同的縮進空格數。 &#xff08;一個tab4個空格&#xff09; Python語言中常見的…

Python、C++ lambda 表達式

Python、C lambda 表達式 lambda函數簡介 匿名函數lambda&#xff1a;是指一類無需定義標識符&#xff08;函數名&#xff09;的函數或子程序。所謂匿名函數&#xff0c;通俗地說就是沒有名字的函數&#xff0c;lambda函數沒有名字&#xff0c;是一種簡單的、在同一行中定義函…

python 學習2 /輸入/ 輸出 /列表 /字典

python基礎學習第二天 輸入輸出 xinput("輸入內容") print(x)input輸出&#xff1a; eval :去掉字符串外圍的引號&#xff0c;按照python的語法執行內容 aeval(12) print(a)eval輸出樣式&#xff1a; 列表 建立&#xff0c;添加&#xff0c;插入&#xff0c;刪去…

Linux、Mac 命令行快捷鍵

Linux、Mac 命令行快捷鍵 Linux 命令行編輯快捷鍵&#xff0c;參考了好多個&#xff0c;應該算是比較全的了&#xff0c;Linux 和 Mac 的都有&#xff0c;筆者本人比較常用的也已經紅色標出來了&#xff0c;如有錯誤或遺漏&#xff0c;歡迎留言指出。 光標移動及編輯&#xff…

Python 命令行傳參

Python 命令行傳參 說到 python 命令行傳參&#xff0c;可能大部分人的第一反應就是用 argparse。的確&#xff0c;argparse 在我們需要指定多個預設的參數&#xff08;如深度學習中指定模型的超參數等&#xff09;時&#xff0c;是非常有用的。但是如果有時我們只需要一個參數…

快速排序 C++

快速排序 C 本文圖示借鑒自清華大學鄧俊輝老師數據結構課程。 快速排序的思想 快速排序是分治思想的典型應用。該排序算法可以原地實現&#xff0c;即空間復雜度為 O(1)O(1)O(1)&#xff0c;而時間復雜度為 O(nlogn)O(nlogn)O(nlogn) 。 算法將待排序的序列 SSS 分為兩個子…

Linux命令行下感嘆號的幾個用法

Linux命令行下 " ! " 的幾個用法 ! 在大多數編程語言中表示取反的意思&#xff0c;但是在命令行中&#xff0c;他還有一些其他的神奇用法。熟練掌握這些用法&#xff0c;可以大大提高我們日常命令行操作的效率。 1 執行歷史命令 !! ! 在命令行中可以用來執行歷史…

三地址碼簡介

三地址碼簡介 三地址碼&#xff08;Three Address Code&#xff09;是一種最常用的中間語言&#xff0c;編譯器可以通過它來改進代碼轉換效率。每個三地址碼指令&#xff0c;都可以被分解為一個四元組&#xff08;4-tuple&#xff09;的形式&#xff1a;&#xff08;運算符&am…

llvm與gcc

llvm與gcc llvm 是一個編譯器&#xff0c;也是一個編譯器架構&#xff0c;是一系列編譯工具&#xff0c;也是一個編譯器工具鏈&#xff0c;開源 C11 實現。 gcc 相對于 clang 的優勢&#xff1a; gcc 支持更過語言前端&#xff0c;如 Java, Ada, FORTRAN, Go等gcc 支持更多地 …

攻防世界web新手區解題 view_source / robots / backup

1**. view_source** 題目描述&#xff1a;X老師讓小寧同學查看一個網頁的源代碼&#xff0c;但小寧同學發現鼠標右鍵好像不管用了。 f12查看源碼即可發現flag 2. robots 題目描述&#xff1a;X老師上課講了Robots協議&#xff0c;小寧同學卻上課打了瞌睡&#xff0c;趕緊來教教…

python參數傳遞*args和**kwargs

python參數傳遞*args和**kwargs 和* 實際上真正的Python參數傳遞語法是 * 和 ** 。*args 和 **kwargs 只是一種約定俗成的編程實踐。我們也可以寫成 *vars 和 **kvars 。就如同其他常規變量的命名一樣&#xff0c; args 和 kwargs 只是一種習慣的名稱。 *args 和 **kwargs 一…

聽GPT 講Rust源代碼--src/tools(25)

File: rust/src/tools/clippy/clippy_lints/src/methods/suspicious_command_arg_space.rs 在Rust源代碼中&#xff0c;suspicious_command_arg_space.rs文件位于clippy_lints工具包的methods目錄下&#xff0c;用于實現Clippy lint SUSPICIOUS_COMMAND_ARG_SPACE。 Clippy是Ru…

Java一次編譯,到處運行是如何實現的

Java一次編譯&#xff0c;到處運行是如何實現的 轉自&#xff1a;https://cloud.tencent.com/developer/article/1415194 &#xff08;排版微調&#xff09; JAVA編譯運行總覽 Java是一種高級語言&#xff0c;要讓計算機執行你撰寫的Java程序&#xff0c;也得通過編譯程序的…