一、論文
????????MobileNetV2 論文提出了一種新的移動架構,該架構提高了移動模型在多個任務和基準測試中的性能,以及在各種不同模型大小范圍內的性能. 該架構基于倒殘差結構,其中 shortcut 連接在 thin bottleneck 層之間. 中間的 expansion 層使用輕量級 depthwise 卷積來過濾特征,作為非線性的來源. 此外,作者發現,為了保持表示能力,重要的是要移除 narrow 層中的非線性. 論文展示了這種架構在 ImageNet 分類、COCO 目標檢測和 VOC 圖像分割任務中的有效性. 論文評估了準確性、計算成本(以 multiply-adds 衡量)、延遲和參數數量之間的權衡. 該論文的關鍵貢獻是具有線性 bottleneck 的倒殘差模塊,它可以實現高效的推理并減少內存占用.
1.1、基本信息?
-
標題: MobileNetV2: Inverted Residuals and Linear Bottlenecks ?
-
作者: Mark Sandler, Andrew Howard, Menglong Zhu, Andrey Zhmoginov, Liang-Chieh Chen ?
-
單位: Google Inc. ?
-
主要貢獻: 提出了一種新的移動架構,MobileNetV2,它提高了移動模型在多個任務和基準測試中的性能,以及不同模型大小范圍內的性能。 ?
1.2、主要內容
1.2.1、倒殘差塊和線性瓶頸
-
MobileNetV2 的核心創新在于“倒殘差塊”。傳統的殘差塊(如 ResNet 中的)連接的是較寬的層,而倒殘差塊連接的是較窄的 bottleneck 層。中間的 expansion 層負責擴展通道維度。
-
這些塊使用深度可分離卷積以提高效率。
-
一個關鍵要素是“線性 bottleneck”。論文認為,ReLU 非線性激活在低維空間中可能會破壞信息。因此,每個塊的最后一層是線性的(沒有 ReLU),以保留信息。
1.2.2、架構設計
-
MobileNetV2 在 MobileNetV1 的深度可分離卷積的基礎上構建。
-
該架構由一個初始卷積層、一系列倒殘差塊和一個最終卷積層組成。
-
平均池化層和全連接層用于分類。
-
使用寬度乘數和分辨率乘數來創建更小、計算量更少的模型。
1.2.3、效率
-
倒殘差結構旨在實現高內存效率,這對于移動設備至關重要。它可以減少內存占用和對昂貴的內存訪問的需求。
-
與標準卷積相比,深度可分離卷積顯著減少了參數數量和計算量。
1.2.4、性能
-
論文表明,MobileNetV2 在 ImageNet 分類任務上比 MobileNetV1 實現了更高的準確率和效率。
-
它還在目標檢測(使用 SSDLite)和語義分割(使用 Mobile DeepLabv3)方面取得了出色的效果。
1.3、作用
????????MobileNetV2 旨在為移動和資源受限環境提供高效的卷積神經網絡。 ?它在準確性和計算成本之間實現了有效的權衡。
1.4、影響
????????MobileNetV2 改進了 MobileNetV1 的架構,并在多個任務和基準測試中實現了最先進的性能。它為在資源受限的設備上部署深度學習模型提供了一種有效的方法。
1.5、優點
-
效率: MobileNetV2 通過使用 depthwise separable 卷積和 inverted residuals,顯著減少了計算需求和內存占用。 ?
-
性能: MobileNetV2 在圖像分類、目標檢測和語義分割等任務中表現出色。 ?
-
可調性: 寬度乘數和分辨率乘數允許輕松調整模型大小,以適應不同的資源約束。 ?
-
內存效率: 倒殘差 bottleneck 層允許特別節省內存的實現,這對于移動應用程序非常重要。
1.6、缺點
-
復雜性: 與 MobileNetV1 相比,MobileNetV2 引入了 inverted residual 塊,這可能會使模型架構更復雜。
-
權衡: 雖然寬度乘數和分辨率乘數提供了靈活性,但它們需要在準確性和效率之間進行仔細的權衡。
論文地址:
????????[1801.04381] MobileNetV2: Inverted Residuals and Linear Bottlenecks
?二、MobileNetV2
2.1、網絡的背景
????????MobileNetV1網絡的depthwise部分的卷積核容易參數為0,導致浪費掉。 MobileNetV2網絡是由谷歌團隊在2018年提出的,它對于MobileNetV1而言,有著更高的準確率和更小的網絡模型。
2.2、Inverted Residuals(倒殘差結構)?
????????在MobileNetV2 中,是先升維,再降維的操作,所以該結構叫倒殘差結構,網絡結構表格中的 bottleneck就是倒殘差結構。
殘差結構的過程是:
????????1、1x1卷積降維
????????2、3x3卷積
????????3、1x1卷積升維
????????即對輸入特征矩陣進行利用1x1卷積進行降維,減少輸入特征矩陣的channel,然后 通過3x3的卷積核進行處理提取特征,最后通過1x1的卷積核進行升維,那么它的結 構就是兩邊深,中間淺的結構。
????????在MobileNetV2網絡結構中,采用了倒殘差結構,從名字可以聯想到,它的結構應該 是中間深,兩邊淺,它的結構如下圖所示:
倒殘差結構的過程是:?
????????1、 首先會通過一個1x1卷積層來進行升維處理,在卷積后會跟有BN和Relu6激活函 數
????????2、 緊接著是一個3x3大小DW卷積,卷積后面依舊會跟有BN和Relu6激活函數
????????3、 最后一個卷積層是1x1卷積,起到降維作用,注意卷積后只跟了BN結構,并沒有 使用Relu6激活函數
????????在MobileNetV1中,DW卷積的個數局限于上一層的輸出通道數,無法自由改變,但 是加入PW卷積之后,也就是升維卷積之后,DW卷積的個數取決于PW卷積的輸出通 道數,而這個通道數是可以任意指定的,因此解除了3x3卷積核個數的限制。
2.3、Relu6
????????因為在低維空間使用非線性函數(如Relu)會損失一些信息,高維空間中會損失得相 對少一些,因此引入Inverted Residuals,升維之后再卷積,能夠更好地保留特征。
????????在移動端設備float16的低精度的時 候,也能有很好的數值分辨率,如果對Relu的激活范圍不加限制,輸出范圍為0到正 無窮,如果激活值非常大,分布在一個很大的范圍內,則低精度的float16無法很好 地精確描述如此大范圍的數值,帶來精度損失,所以在量化過程中,Relu6能夠有更 好的量化表現和更小的精度下降。
2.4、Linear Bottlenecks
2.5、Shortcut?
????????注意:Shortcut并不是該網絡提出的,而是殘差結構提出的。
????????這里由于具有倒殘差結構,所以也會有shortcut操作,如下圖所示,左側為有 shortcut連接的倒殘差結構,右側是無shortcut連接的倒殘差結構:
????????shortcut將輸入與輸出直接進行相加,可以使得網絡在較深的時候依舊可以進行訓 練。
注意:這里只有stride=1且輸入特征矩陣與輸出特征矩陣shape相同時才有 shortcut連接。
2.6、拓展因子
舉例:
????????假如輸入特征矩陣是128維的,隱藏層個數是256,條件1是輸出必須有128個 為正,并且這128個輸出為正的神經元的應該是非線性相關的。
~~~~當隱藏層m遠大于輸入n,那么 更容易讓ReLU保持可逆。
在訓練前和訓練后,各個層是否發生了不可逆的情況, 如下圖所示:
????????在訓練前,由于都是初始化的,所以激活為正的特征個數都比較集中,圖中虛線是信 息會丟失的閾值,縱坐標低于該虛線,會導致信息丟失。在訓練后,平均值變化較 小,但是兩級分化較為嚴重,有一些最小值已經低于了閾值,就會造成信息丟失,但 是絕大多數層還是可逆的。????????
2.7、網絡的結構
????????Input是每一層結構的輸入矩陣尺寸和channel;
????????Operator是操作;
????????t是拓展因子;
????????c是輸出特征矩陣channel;
????????n是bottleneck的重復次數;
????????s是步距,如果bottleneck重復,但只針對于第一次bottleneck的DW卷積,其他為 1。
????????由網絡結構的圖可以看到最后一層是卷積層,但是其實就是一個全連接層的作用,k 是輸出的類別,如果是ImageNet數據集,那么k就是1000。
注意:
????????在每個DW卷積之后都有batchNorm操作,這里組件中為了減少學習者工 作量并沒有體現該結構,但是學習者需要知道。
注意:
????????在第一個bottleneck結構中,由于t=1,所以并沒有進行升維操作,即沒有 第一個Conv2D層。
import torch
import torch.nn as nn
import torch.nn.functional as F
from torchsummary import summary# 定義一個輔助函數,用于確保通道數可以被 divisor 整除,這對于硬件優化很重要
def _make_divisible(v, divisor, min_value=None):"""確保通道數 v 可以被 divisor 整除,如果需要可以指定最小值 min_value。這個技巧在移動端神經網絡設計中常用,有助于硬件加速。"""if min_value is None:min_value = divisornew_v = max(min_value, int(v + divisor / 2) // divisor * divisor)# 為了避免向下取整過多導致信息損失,進行一個小的調整if new_v < 0.9 * v:new_v += divisorreturn new_v# 定義一個卷積 + BatchNorm + ReLU6 的基本塊
class ConvBNReLU(nn.Sequential):def __init__(self, in_planes, out_planes, kernel_size, stride, groups=1):"""卷積、批歸一化和 ReLU6 激活函數的組合。groups 參數用于實現分組卷積(當 groups > 1 時)或深度可分離卷積(當 groups 等于輸入通道數時)。"""padding = (kernel_size - 1) // 2super(ConvBNReLU, self).__init__(nn.Conv2d(in_planes, out_planes, kernel_size, stride, padding, groups=groups, bias=False),nn.BatchNorm2d(out_planes),nn.ReLU6(inplace=True) # ReLU6 激活函數,限制輸出范圍在 [0, 6])# 定義 MobileNetV2 的核心模塊:倒置殘差塊 (Inverted Residual Block)
class InvertedResidual(nn.Module):def __init__(self, inp, oup, stride, expand_ratio):"""倒置殘差塊是 MobileNetV2 的基本構建單元。它首先通過一個擴展層增加通道數,然后進行深度可分離卷積,最后通過一個投影層減少通道數。如果輸入和輸出的形狀相同且步長為 1,則使用殘差連接。Args:inp (int): 輸入通道數oup (int): 輸出通道數stride (int): 卷積步長 (1 或 2)expand_ratio (int): 擴展率,中間層的通道數是輸入通道數的 expand_ratio 倍"""super(InvertedResidual, self).__init__()self.stride = strideassert stride in [1, 2]hidden_dim = int(round(inp * expand_ratio)) # 中間擴展層的通道數self.use_res_connect = self.stride == 1 and inp == oup # 判斷是否使用殘差連接layers = []if expand_ratio != 1:# 擴展層:1x1 卷積增加通道數layers.append(ConvBNReLU(inp, hidden_dim, kernel_size=1, stride=1))layers.extend([# 深度可分離卷積:一個深度卷積和一個逐點卷積的組合# 深度卷積 (Depthwise Convolution): 每個輸入通道應用一個獨立的卷積核ConvBNReLU(hidden_dim, hidden_dim, kernel_size=3, stride=stride, groups=hidden_dim),# 逐點卷積 (Pointwise Convolution): 1x1 卷積用于線性地組合深度卷積的輸出nn.Conv2d(hidden_dim, oup, kernel_size=1, stride=1, bias=False),nn.BatchNorm2d(oup),])self.conv = nn.Sequential(*layers)def forward(self, x):if self.use_res_connect:return x + self.conv(x) # 如果滿足條件,使用殘差連接else:return self.conv(x)# 定義 MobileNetV2 網絡結構
class MobileNetV2(nn.Module):def __init__(self, num_classes=1000, width_mult=1.0, inverted_residual_setting=None, round_nearest=8):"""MobileNet V2 的主體網絡結構。Args:num_classes (int): 分類的類別數width_mult (float): 寬度乘數,用于調整網絡中所有層的通道數,控制模型大小inverted_residual_setting (list of list): 倒置殘差塊的配置列表,每個子列表包含 [t, c, n, s],分別表示擴展率 (t),輸出通道數 (c),重復次數 (n),步長 (s)。round_nearest (int): 將通道數調整為最接近的 round_nearest 的倍數,有助于硬件優化。"""super(MobileNetV2, self).__init__()if inverted_residual_setting is None:# 默認的 MobileNetV2 結構參數input_channel = 32last_channel = 1280inverted_residual_setting = [# t, c, n, s[1, 16, 1, 1],[6, 24, 2, 2],[6, 32, 3, 2],[6, 64, 4, 2],[6, 96, 3, 1],[6, 160, 3, 2],[6, 320, 1, 1],]else:# 如果提供了自定義的 inverted_residual_setting,則使用它if len(inverted_residual_setting) != 7:raise ValueError("inverted_residual_setting 應該包含 7 個層級的配置")# 構建網絡的第一個卷積層input_channel = _make_divisible(input_channel * width_mult, round_nearest)self.last_channel = _make_divisible(last_channel * max(1.0, width_mult), round_nearest)features = [ConvBNReLU(3, input_channel, kernel_size=3, stride=2)] # 輸入通道數為 3 (RGB),輸出通道數為調整后的 input_channel,步長為 2# 構建倒置殘差塊for t, c, n, s in inverted_residual_setting:output_channel = _make_divisible(c * width_mult, round_nearest)for i in range(n):stride = s if i == 0 else 1 # 只有每個模塊的第一個殘差塊使用指定的步長,其余步長都為 1features.append(InvertedResidual(input_channel, output_channel, stride, expand_ratio=t))input_channel = output_channel # 更新下一個殘差塊的輸入通道數# 構建最后一個卷積層features.append(ConvBNReLU(input_channel, self.last_channel, kernel_size=1, stride=1))# 將特征提取層封裝成 Sequential 容器self.features = nn.Sequential(*features)# 構建分類器self.classifier = nn.Sequential(nn.Dropout(0.2), # Dropout 正則化,防止過擬合nn.Linear(self.last_channel, num_classes), # 全連接層,將特征映射到類別數)# 初始化權重for m in self.modules():if isinstance(m, nn.Conv2d):nn.init.kaiming_normal_(m.weight, mode='fan_out') # 使用 Kaiming 正態分布初始化卷積層權重if m.bias is not None:nn.init.zeros_(m.bias) # 如果有偏置,初始化為 0elif isinstance(m, nn.BatchNorm2d):nn.init.ones_(m.weight) # BatchNorm 的權重初始化為 1nn.init.zeros_(m.bias) # BatchNorm 的偏置初始化為 0elif isinstance(m, nn.Linear):nn.init.normal_(m.weight, 0, 0.01) # 全連接層的權重使用均值為 0,標準差為 0.01 的正態分布初始化nn.init.zeros_(m.bias) # 全連接層的偏置初始化為 0def _forward_impl(self, x):# 前向傳播的實現x = self.features(x) # 通過特征提取層x = F.adaptive_avg_pool2d(x, 1) # 自適應平均池化,將特征圖尺寸變為 1x1x = torch.flatten(x, 1) # 將特征圖展平成一維向量x = self.classifier(x) # 通過分類器return xdef forward(self, x):# 定義前向傳播函數return self._forward_impl(x)if __name__ == '__main__':model = MobileNetV2()print(summary(model,(3,224,224)))
----------------------------------------------------------------Layer (type) Output Shape Param #
================================================================Conv2d-1 [-1, 32, 112, 112] 864BatchNorm2d-2 [-1, 32, 112, 112] 64ReLU6-3 [-1, 32, 112, 112] 0Conv2d-4 [-1, 32, 112, 112] 288BatchNorm2d-5 [-1, 32, 112, 112] 64ReLU6-6 [-1, 32, 112, 112] 0Conv2d-7 [-1, 16, 112, 112] 512BatchNorm2d-8 [-1, 16, 112, 112] 32InvertedResidual-9 [-1, 16, 112, 112] 0Conv2d-10 [-1, 96, 112, 112] 1,536BatchNorm2d-11 [-1, 96, 112, 112] 192ReLU6-12 [-1, 96, 112, 112] 0Conv2d-13 [-1, 96, 56, 56] 864BatchNorm2d-14 [-1, 96, 56, 56] 192ReLU6-15 [-1, 96, 56, 56] 0Conv2d-16 [-1, 24, 56, 56] 2,304BatchNorm2d-17 [-1, 24, 56, 56] 48InvertedResidual-18 [-1, 24, 56, 56] 0Conv2d-19 [-1, 144, 56, 56] 3,456BatchNorm2d-20 [-1, 144, 56, 56] 288ReLU6-21 [-1, 144, 56, 56] 0Conv2d-22 [-1, 144, 56, 56] 1,296BatchNorm2d-23 [-1, 144, 56, 56] 288ReLU6-24 [-1, 144, 56, 56] 0Conv2d-25 [-1, 24, 56, 56] 3,456BatchNorm2d-26 [-1, 24, 56, 56] 48InvertedResidual-27 [-1, 24, 56, 56] 0Conv2d-28 [-1, 144, 56, 56] 3,456BatchNorm2d-29 [-1, 144, 56, 56] 288ReLU6-30 [-1, 144, 56, 56] 0Conv2d-31 [-1, 144, 28, 28] 1,296BatchNorm2d-32 [-1, 144, 28, 28] 288ReLU6-33 [-1, 144, 28, 28] 0Conv2d-34 [-1, 32, 28, 28] 4,608BatchNorm2d-35 [-1, 32, 28, 28] 64InvertedResidual-36 [-1, 32, 28, 28] 0Conv2d-37 [-1, 192, 28, 28] 6,144BatchNorm2d-38 [-1, 192, 28, 28] 384ReLU6-39 [-1, 192, 28, 28] 0Conv2d-40 [-1, 192, 28, 28] 1,728BatchNorm2d-41 [-1, 192, 28, 28] 384ReLU6-42 [-1, 192, 28, 28] 0Conv2d-43 [-1, 32, 28, 28] 6,144BatchNorm2d-44 [-1, 32, 28, 28] 64InvertedResidual-45 [-1, 32, 28, 28] 0Conv2d-46 [-1, 192, 28, 28] 6,144BatchNorm2d-47 [-1, 192, 28, 28] 384ReLU6-48 [-1, 192, 28, 28] 0Conv2d-49 [-1, 192, 28, 28] 1,728BatchNorm2d-50 [-1, 192, 28, 28] 384ReLU6-51 [-1, 192, 28, 28] 0Conv2d-52 [-1, 32, 28, 28] 6,144BatchNorm2d-53 [-1, 32, 28, 28] 64InvertedResidual-54 [-1, 32, 28, 28] 0Conv2d-55 [-1, 192, 28, 28] 6,144BatchNorm2d-56 [-1, 192, 28, 28] 384ReLU6-57 [-1, 192, 28, 28] 0Conv2d-58 [-1, 192, 14, 14] 1,728BatchNorm2d-59 [-1, 192, 14, 14] 384ReLU6-60 [-1, 192, 14, 14] 0Conv2d-61 [-1, 64, 14, 14] 12,288BatchNorm2d-62 [-1, 64, 14, 14] 128InvertedResidual-63 [-1, 64, 14, 14] 0Conv2d-64 [-1, 384, 14, 14] 24,576BatchNorm2d-65 [-1, 384, 14, 14] 768ReLU6-66 [-1, 384, 14, 14] 0Conv2d-67 [-1, 384, 14, 14] 3,456BatchNorm2d-68 [-1, 384, 14, 14] 768ReLU6-69 [-1, 384, 14, 14] 0Conv2d-70 [-1, 64, 14, 14] 24,576BatchNorm2d-71 [-1, 64, 14, 14] 128InvertedResidual-72 [-1, 64, 14, 14] 0Conv2d-73 [-1, 384, 14, 14] 24,576BatchNorm2d-74 [-1, 384, 14, 14] 768ReLU6-75 [-1, 384, 14, 14] 0Conv2d-76 [-1, 384, 14, 14] 3,456BatchNorm2d-77 [-1, 384, 14, 14] 768ReLU6-78 [-1, 384, 14, 14] 0Conv2d-79 [-1, 64, 14, 14] 24,576BatchNorm2d-80 [-1, 64, 14, 14] 128InvertedResidual-81 [-1, 64, 14, 14] 0Conv2d-82 [-1, 384, 14, 14] 24,576BatchNorm2d-83 [-1, 384, 14, 14] 768ReLU6-84 [-1, 384, 14, 14] 0Conv2d-85 [-1, 384, 14, 14] 3,456BatchNorm2d-86 [-1, 384, 14, 14] 768ReLU6-87 [-1, 384, 14, 14] 0Conv2d-88 [-1, 64, 14, 14] 24,576BatchNorm2d-89 [-1, 64, 14, 14] 128InvertedResidual-90 [-1, 64, 14, 14] 0Conv2d-91 [-1, 384, 14, 14] 24,576BatchNorm2d-92 [-1, 384, 14, 14] 768ReLU6-93 [-1, 384, 14, 14] 0Conv2d-94 [-1, 384, 14, 14] 3,456BatchNorm2d-95 [-1, 384, 14, 14] 768ReLU6-96 [-1, 384, 14, 14] 0Conv2d-97 [-1, 96, 14, 14] 36,864BatchNorm2d-98 [-1, 96, 14, 14] 192InvertedResidual-99 [-1, 96, 14, 14] 0Conv2d-100 [-1, 576, 14, 14] 55,296BatchNorm2d-101 [-1, 576, 14, 14] 1,152ReLU6-102 [-1, 576, 14, 14] 0Conv2d-103 [-1, 576, 14, 14] 5,184BatchNorm2d-104 [-1, 576, 14, 14] 1,152ReLU6-105 [-1, 576, 14, 14] 0Conv2d-106 [-1, 96, 14, 14] 55,296BatchNorm2d-107 [-1, 96, 14, 14] 192
InvertedResidual-108 [-1, 96, 14, 14] 0Conv2d-109 [-1, 576, 14, 14] 55,296BatchNorm2d-110 [-1, 576, 14, 14] 1,152ReLU6-111 [-1, 576, 14, 14] 0Conv2d-112 [-1, 576, 14, 14] 5,184BatchNorm2d-113 [-1, 576, 14, 14] 1,152ReLU6-114 [-1, 576, 14, 14] 0Conv2d-115 [-1, 96, 14, 14] 55,296BatchNorm2d-116 [-1, 96, 14, 14] 192
InvertedResidual-117 [-1, 96, 14, 14] 0Conv2d-118 [-1, 576, 14, 14] 55,296BatchNorm2d-119 [-1, 576, 14, 14] 1,152ReLU6-120 [-1, 576, 14, 14] 0Conv2d-121 [-1, 576, 7, 7] 5,184BatchNorm2d-122 [-1, 576, 7, 7] 1,152ReLU6-123 [-1, 576, 7, 7] 0Conv2d-124 [-1, 160, 7, 7] 92,160BatchNorm2d-125 [-1, 160, 7, 7] 320
InvertedResidual-126 [-1, 160, 7, 7] 0Conv2d-127 [-1, 960, 7, 7] 153,600BatchNorm2d-128 [-1, 960, 7, 7] 1,920ReLU6-129 [-1, 960, 7, 7] 0Conv2d-130 [-1, 960, 7, 7] 8,640BatchNorm2d-131 [-1, 960, 7, 7] 1,920ReLU6-132 [-1, 960, 7, 7] 0Conv2d-133 [-1, 160, 7, 7] 153,600BatchNorm2d-134 [-1, 160, 7, 7] 320
InvertedResidual-135 [-1, 160, 7, 7] 0Conv2d-136 [-1, 960, 7, 7] 153,600BatchNorm2d-137 [-1, 960, 7, 7] 1,920ReLU6-138 [-1, 960, 7, 7] 0Conv2d-139 [-1, 960, 7, 7] 8,640BatchNorm2d-140 [-1, 960, 7, 7] 1,920ReLU6-141 [-1, 960, 7, 7] 0Conv2d-142 [-1, 160, 7, 7] 153,600BatchNorm2d-143 [-1, 160, 7, 7] 320
InvertedResidual-144 [-1, 160, 7, 7] 0Conv2d-145 [-1, 960, 7, 7] 153,600BatchNorm2d-146 [-1, 960, 7, 7] 1,920ReLU6-147 [-1, 960, 7, 7] 0Conv2d-148 [-1, 960, 7, 7] 8,640BatchNorm2d-149 [-1, 960, 7, 7] 1,920ReLU6-150 [-1, 960, 7, 7] 0Conv2d-151 [-1, 320, 7, 7] 307,200BatchNorm2d-152 [-1, 320, 7, 7] 640
InvertedResidual-153 [-1, 320, 7, 7] 0Conv2d-154 [-1, 1280, 7, 7] 409,600BatchNorm2d-155 [-1, 1280, 7, 7] 2,560ReLU6-156 [-1, 1280, 7, 7] 0Dropout-157 [-1, 1280] 0Linear-158 [-1, 1000] 1,281,000
================================================================
Total params: 3,504,872
Trainable params: 3,504,872
Non-trainable params: 0
----------------------------------------------------------------
Input size (MB): 0.57
Forward/backward pass size (MB): 152.87
Params size (MB): 13.37
Estimated Total Size (MB): 166.81
----------------------------------------------------------------