【IEEE 2025】This post reviews the paper Enhancing Low-Light Images with Kolmogorov–Arnold Networks in Transformer Attention.
While not a top-tier venue, the paper has ideas worth learning from.
Paper: arXiv
Code: GitHub
Table of Contents
- Part 1 — Paper Walkthrough
- Part 2 — Code Walkthrough
  - Shape-tracing script (n_features changed from 31 to 32 relative to the original code)
Part 1 — Paper Walkthrough
The paper proposes KAN-T, a novel Transformer network for low-light image enhancement (LLIE). Its core innovation is a Transformer attention mechanism inspired by the Kolmogorov–Arnold representation theorem.
1. Overall Framework
KAN-T adopts a 3-level encoder–decoder structure.
- Input processing and encoding: The input image first passes through a $1 \times 1$ convolution that expands the features from $H \times W \times 3$ to $H \times W \times C$. The image is then fed into the encoder, which consists of Transformer blocks at successive resolution levels ($H \times W \times C$, $\frac{H}{2} \times \frac{W}{2} \times 2C$, and $\frac{H}{4} \times \frac{W}{4} \times 4C$). The encoder's goal is to convert the input image into an abstract internal representation that captures its key features.
- Bottleneck: The encoded feature map is downsampled to $\frac{H}{8} \times \frac{W}{8} \times 8C$ and passed through KAN-T's bottleneck, which uses four sequential Transformer blocks to enhance the internal feature representation.
- Decoding and output: The internal representation then goes through the decoding stage, a series of Transformer blocks at different levels arranged symmetrically to the encoder. The final $H \times W \times C$ feature map passes through a convolution that reduces the channel count, producing the $H \times W \times 3$ output image.
- Skip connections: KAN-T uses skip connections between corresponding encoder and decoder levels to help preserve detail and enrich features.
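To make the stage-by-stage shapes concrete, here is a minimal sketch (not the authors' code) that prints the resolution and channel width at each stage described above, assuming C = 32 and a 256×256 input:

```python
# Minimal shape-progression sketch for KAN-T's encoder-decoder,
# following the levels described in the text (C = 32 is an assumed example value).
H, W, C = 256, 256, 32

stages = [
    ("conv_in (1x1)",         (H,      W,      C)),
    ("encoder level 1",       (H,      W,      C)),
    ("encoder level 2",       (H // 2, W // 2, 2 * C)),
    ("encoder level 3",       (H // 4, W // 4, 4 * C)),
    ("bottleneck (4 blocks)", (H // 8, W // 8, 8 * C)),
    ("decoder level 3",       (H // 4, W // 4, 4 * C)),
    ("decoder level 2",       (H // 2, W // 2, 2 * C)),
    ("decoder level 1",       (H,      W,      C)),
    ("conv_out (to RGB)",     (H,      W,      3)),
]
for name, (h, w, c) in stages:
    print(f"{name:22s} -> {h} x {w} x {c}")
```

(Note that the simplified tracing script in Part 2 uses only two downsampling steps, so its bottleneck sits at $\frac{H}{4} \times \frac{W}{4} \times 4C$.)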
2. Transformer Block
The Transformer block is KAN-T's main building unit, chosen for its ability to perform high-level feature processing.
- Composition: The block consists of a Kolmogorov–Arnold multi-head self-attention (KAN-MSA) module, a feed-forward network (FFN), and two layer normalization (LN) operations, with residual connections around both the self-attention and the feature-extraction stages.
- Processing flow (a minimal sketch follows this list):
  - Self-attention stage: The input feature map $F_{in}$ is layer-normalized, processed by KAN-MSA, and combined with the original input $F_{in}$ through a residual connection, yielding the intermediate feature map $\hat{F}$:
$$\hat{F} = \text{KAN-MSA}(\text{LN}(F_{in})) + F_{in}$$
  - Feature-extraction stage: The intermediate feature map $\hat{F}$ is layer-normalized, processed by the FFN, and combined with $\hat{F}$ through a residual connection, yielding the output feature map $F_{out}$:
$$F_{out} = \text{FFN}(\text{LN}(\hat{F})) + \hat{F}$$
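As a compact illustration of the two equations above, here is a minimal PyTorch sketch of the pre-norm residual structure (the `msa` and `ffn` arguments stand in for the KAN-MSA and FFN modules detailed in Sections 3 and 4; this is a skeleton for exposition, not the authors' implementation):

```python
import torch
import torch.nn as nn

class TransformerBlockSketch(nn.Module):
    """F_hat = KAN-MSA(LN(F_in)) + F_in;  F_out = FFN(LN(F_hat)) + F_hat."""
    def __init__(self, dim, msa: nn.Module, ffn: nn.Module):
        super().__init__()
        self.norm1, self.msa = nn.LayerNorm(dim), msa  # LN before self-attention
        self.norm2, self.ffn = nn.LayerNorm(dim), ffn  # LN before feed-forward

    def forward(self, f_in):                            # f_in: (B, HW, C) token layout for simplicity
        f_hat = self.msa(self.norm1(f_in)) + f_in       # self-attention stage + residual
        return self.ffn(self.norm2(f_hat)) + f_hat      # feature-extraction stage + residual

# Smoke test with identity placeholders:
blk = TransformerBlockSketch(32, msa=nn.Identity(), ffn=nn.Identity())
print(blk(torch.randn(1, 64, 32)).shape)  # torch.Size([1, 64, 32])
```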
3. Kolmogorov–Arnold Network Multi-Head Self-Attention (KAN-MSA)
This is the method's core innovation.
- Limitations of standard MSA: A standard multi-head self-attention (MSA) module uses fully connected (fc) layers to obtain the query (Q), key (K), and value (V) components. Although fc layers can model complex relationships by processing the entire multivariate input jointly, they may fail to capture univariate relationships within individual channels, and they are computationally expensive due to their large parameter counts, especially for high-dimensional inputs.
- KAN-MSA principle: To overcome these limitations, the authors introduce a KAN-based MSA mechanism inspired by the Kolmogorov–Arnold representation theorem, which states that any continuous multivariate function can be represented as a superposition of continuous univariate functions and addition. The method also incorporates learnable nonlinearities.
- KAN-MSA processing flow (a minimal sketch follows this list):
  - Multivariate decomposition (channel split): Given an input feature map $F_{in} \in \mathbb{R}^{H \times W \times C}$, a channel split first decomposes it into $F_1, F_2, ..., F_C$, where each $F_i \in \mathbb{R}^{H \times W \times 1}$. This lets the model capture more complex and specific patterns in the data.
  - Univariate processing with learnable nonlinearity: Each channel $F_i$ is processed by a sequence of three fully connected layers, each followed by a nonlinear activation $\Phi_i^j$. With three sequential fc layers, the model can activate or deactivate particular neurons along the way, providing a learnable nonlinearity:
$$h_i^1 = \Phi_i^1(W_i^1 F_i + b_i^1)$$
$$h_i^2 = \Phi_i^2(W_i^2 h_i^1 + b_i^2)$$
$$h_i^3 = \Phi_i^3(W_i^3 h_i^2 + b_i^3), \quad h_i^3 \in \mathbb{R}^{H \times W \times 3}$$
  - Merging and QKV generation: The per-channel outputs are concatenated along the channel dimension to give $F_{out} \in \mathbb{R}^{H \times W \times 3C}$, which is then split three ways to obtain $Q, K, V \in \mathbb{R}^{H \times W \times C}$.
  - Self-attention computation: Q, K, and V are reshaped to $HW \times C$ and used to produce the self-attention feature map $F_{out}$:
$$F_{out} = V \times \text{softmax}\left(\frac{K Q^T}{\mathcal{T}}\right)$$
where $\mathcal{T}$ is a learnable parameter that rescales the attention scores. $F_{out}$ is then reshaped back to $H \times W \times C$.
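The following is a minimal, self-contained sketch of the QKV-generation idea: one tiny three-layer MLP per channel, each mapping that channel's scalar values to three outputs, concatenated and split into Q/K/V. It mirrors the structure of the `KolmogorovArnoldNetwork` class shown in Part 2 (with toy sizes), not the paper's exact implementation:

```python
import torch
import torch.nn as nn

class KANQKVSketch(nn.Module):
    """Per-channel univariate MLPs: each scalar channel value -> 3 outputs (for Q, K, V)."""
    def __init__(self, channels, hidden=16):
        super().__init__()
        # One small 3-layer MLP per input channel (learnable univariate functions).
        self.mlps = nn.ModuleList(
            nn.Sequential(nn.Linear(1, hidden), nn.ReLU(),
                          nn.Linear(hidden, hidden), nn.ReLU(),
                          nn.Linear(hidden, 3))
            for _ in range(channels)
        )

    def forward(self, x):                    # x: (B, H, W, C)
        b, h, w, c = x.shape
        flat = x.reshape(-1, c)              # one row per pixel; columns are channels
        outs = [self.mlps[i](flat[:, i:i + 1]) for i in range(c)]  # each (B*H*W, 3)
        qkv = torch.cat(outs, dim=1).view(b, h, w, 3 * c)          # (B, H, W, 3C)
        return qkv.chunk(3, dim=-1)          # Q, K, V, each (B, H, W, C)

q, k, v = KANQKVSketch(channels=4)(torch.randn(2, 8, 8, 4))
print(q.shape, k.shape, v.shape)  # each torch.Size([2, 8, 8, 4])
```

Note that the repository code in Part 2 then computes channel-wise (transposed) attention per head with a learnable temperature, rather than a full $HW \times HW$ spatial attention map.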
4. Feed-Forward Network (FFN)
The FFN is the other key component of the Transformer block; it performs deep feature extraction on the self-attention feature map.
- Structure: It uses a triple-convolution setup with Gaussian Error Linear Unit (GELU) activations ($\psi$).
- Processing flow: Given an input feature map $F_{in} \in \mathbb{R}^{H \times W \times C}$, the computation is:
$$F_{out} = \text{conv}_{1 \times 1}(\psi(\text{conv}_{3 \times 3}(\psi(\text{conv}_{1 \times 1}(F_{in})))))$$
Here the first $\text{conv}_{1 \times 1}$ expands the feature map to $H \times W \times 4C$ to help discover new patterns; the $\text{conv}_{3 \times 3}$ performs feature extraction over a larger spatial context thanks to its bigger kernel; and the final $\text{conv}_{1 \times 1}$ compresses the feature map back to the original $H \times W \times C$.
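A minimal sketch of this expand-extract-compress pattern follows. It tracks the paper's prose with a plain 3×3 convolution; note that the repository's `FFN2` in Part 2 instead uses a depthwise 3×3 (`groups=dim * mult`):

```python
import torch
import torch.nn as nn

class FFNSketch(nn.Module):
    """conv1x1 expand (C -> 4C) -> GELU -> conv3x3 -> GELU -> conv1x1 compress (4C -> C)."""
    def __init__(self, dim, mult=4):
        super().__init__()
        self.net = nn.Sequential(
            nn.Conv2d(dim, dim * mult, kernel_size=1),                    # expand to 4C
            nn.GELU(),
            nn.Conv2d(dim * mult, dim * mult, kernel_size=3, padding=1),  # spatial feature extraction
            nn.GELU(),
            nn.Conv2d(dim * mult, dim, kernel_size=1),                    # compress back to C
        )

    def forward(self, x):  # x: (B, C, H, W)
        return self.net(x)

print(FFNSketch(32)(torch.randn(1, 32, 64, 64)).shape)  # torch.Size([1, 32, 64, 64])
```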
5. Loss Function
For accurate reconstruction, a composite loss function $\mathcal{L}$ is used. This hybrid loss integrates several components to address different aspects of image quality: pixel-level accuracy, structural integrity, and perceptual fidelity.
- Overall loss:
$$\mathcal{L} = \mathcal{L}_{MAE} + \alpha \cdot \mathcal{L}_{MS\text{-}SSIM} + \beta \cdot \mathcal{L}_{Perc}$$
where $\alpha$ and $\beta$ are hyperparameters that balance the contribution of each component.
- Component details:
  - Mean absolute error loss ($\mathcal{L}_{MAE}$): The primary term, capturing the average difference between the predicted image $\hat{I}$ and the ground-truth image $I_{GT}$:
$$\mathcal{L}_{MAE} = \frac{1}{N} \sum_{x,y} \left\| \hat{I}(x,y) - I_{GT}(x,y) \right\|_1$$
  - Multi-scale structural similarity loss ($\mathcal{L}_{MS\text{-}SSIM}$): Evaluates the structural similarity between the predicted and ground-truth images at multiple scales. By assessing structural distortion (especially under challenging conditions such as low light), it captures high-level features essential for preserving the image's structural integrity.
  - Perceptual loss ($\mathcal{L}_{Perc}$): Uses a pretrained VGG-19 network ($\Psi$) to introduce feature-level supervision. It measures the difference between high-level feature representations of the predicted and ground-truth images, helping the network learn meaningful internal representations:
$$\mathcal{L}_{Perc} = \frac{1}{N} \sum_{x,y} \left\| \Psi(\hat{I}(x,y)) - \Psi(I_{GT}(x,y)) \right\|_1$$
By integrating these three components, the hybrid loss balances pixel-level accuracy, structural consistency, and perceptual quality.
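A hedged sketch of this composite loss in PyTorch is shown below. The values of $\alpha$ and $\beta$ and the VGG layer cut are not specified here, so the choices below (α = β = 0.1, VGG-19 `features[:16]`) are illustrative assumptions, and MS-SSIM comes from the third-party `pytorch_msssim` package:

```python
import torch
import torch.nn as nn
from torchvision.models import vgg19, VGG19_Weights
from pytorch_msssim import ms_ssim  # third-party: pip install pytorch-msssim

class CompositeLossSketch(nn.Module):
    def __init__(self, alpha=0.1, beta=0.1):  # alpha/beta are illustrative, not the paper's values
        super().__init__()
        self.alpha, self.beta = alpha, beta
        # Frozen VGG-19 feature extractor for the perceptual term (layer cut is an assumption).
        self.vgg = vgg19(weights=VGG19_Weights.DEFAULT).features[:16].eval()
        for p in self.vgg.parameters():
            p.requires_grad_(False)

    def forward(self, pred, target):  # both (B, 3, H, W), values in [0, 1]; H, W >= ~161 for MS-SSIM defaults
        l_mae = nn.functional.l1_loss(pred, target)
        l_msssim = 1.0 - ms_ssim(pred, target, data_range=1.0)  # loss = 1 - similarity
        l_perc = nn.functional.l1_loss(self.vgg(pred), self.vgg(target))
        return l_mae + self.alpha * l_msssim + self.beta * l_perc
```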
In summary, the method's core principle is to improve the multi-head self-attention mechanism in the Transformer with ideas from the Kolmogorov–Arnold representation theorem: decomposing multivariate functions into univariate functions plus linear combinations, and introducing learnable nonlinear activations, so as to achieve more flexible and effective feature representation and context capture for low-light image enhancement. Combined with the carefully designed encoder–decoder architecture and the composite loss, the approach aims for superior performance.
Part 2 — Code Walkthrough
Shape-tracing script (n_features changed from 31 to 32 relative to the original code)
Now we write a script that instantiates the KANT model and prints the tensor shape at every significant step of its forward method. For ease of demonstration, the KANT class is lightly modified so that shapes are easy to print during the forward pass.
```python
import torch
import torch.nn as nn
import numbers
import torch.nn.functional as F
from einops import rearrange
import math

# [All class definitions from KANT.py (LayerNorm, GELU, KANAttention, FFN2,
#  TransformerBlock, etc.) are reproduced below so the script runs standalone.]

# --- Start of KANT.py contents (copied for standalone execution) ---

# Utility functions
def to_3d(x):
    return rearrange(x, 'b c h w -> b (h w) c')

def to_4d(x, h, w):
    return rearrange(x, 'b (h w) c -> b c h w', h=h, w=w)

class BiasFree_LayerNorm(nn.Module):
    def __init__(self, normalized_shape):
        super(BiasFree_LayerNorm, self).__init__()
        if isinstance(normalized_shape, numbers.Integral):
            normalized_shape = (normalized_shape,)
        normalized_shape = torch.Size(normalized_shape)
        assert len(normalized_shape) == 1
        self.weight = nn.Parameter(torch.ones(normalized_shape))
        self.normalized_shape = normalized_shape

    def forward(self, x):
        sigma = x.var(-1, keepdim=True, unbiased=False)
        return x / torch.sqrt(sigma + 1e-5) * self.weight

class WithBias_LayerNorm(nn.Module):
    def __init__(self, normalized_shape):
        super(WithBias_LayerNorm, self).__init__()
        if isinstance(normalized_shape, numbers.Integral):
            normalized_shape = (normalized_shape,)
        normalized_shape = torch.Size(normalized_shape)
        assert len(normalized_shape) == 1
        self.weight = nn.Parameter(torch.ones(normalized_shape))
        self.bias = nn.Parameter(torch.zeros(normalized_shape))
        self.normalized_shape = normalized_shape

    def forward(self, x):
        mu = x.mean(-1, keepdim=True)
        sigma = x.var(-1, keepdim=True, unbiased=False)
        return (x - mu) / torch.sqrt(sigma + 1e-5) * self.weight + self.bias

class LayerNorm(nn.Module):
    def __init__(self, dim, LayerNorm_type):
        super(LayerNorm, self).__init__()
        if LayerNorm_type == 'BiasFree':
            self.body = BiasFree_LayerNorm(dim)
        else:
            self.body = WithBias_LayerNorm(dim)

    def forward(self, x):
        h, w = x.shape[-2:]
        return to_4d(self.body(to_3d(x)), h, w)

class GELU(nn.Module):
    def forward(self, x):
        return F.gelu(x)

# KolmogorovArnoldNetwork (MLP-based KAN)
class KolmogorovArnoldNetwork(nn.Module):
    def __init__(self, input_channels, hidden_size=256):
        super(KolmogorovArnoldNetwork, self).__init__()
        self.input_channels = input_channels
        self.hidden_size = hidden_size
        self.fc1_list = nn.ModuleList([nn.Linear(1, hidden_size) for _ in range(input_channels)])
        self.fc2_list = nn.ModuleList([nn.Linear(hidden_size, hidden_size) for _ in range(input_channels)])
        self.fc3_list = nn.ModuleList([nn.Linear(hidden_size, 3) for _ in range(input_channels)])  # 3 outputs per channel, for the Q, K, V parts
        self.relu = nn.ReLU()

    def forward(self, x):  # expected input shape: (batch_size, H, W, C)
        batch_size, H, W, C = x.shape
        x_reshaped = x.reshape(-1, C)
        outputs_mlp = []
        for i in range(self.input_channels):
            xi = x_reshaped[:, i:i + 1]
            xi = self.relu(self.fc1_list[i](xi))
            xi = self.relu(self.fc2_list[i](xi))
            xi = self.fc3_list[i](xi)
            outputs_mlp.append(xi)
        x_cat = torch.cat(outputs_mlp, dim=1)
        x_final = x_cat.view(batch_size, H, W, C * 3)
        return x_final

# KANAttention
class KANAttention(nn.Module):
    def __init__(self, dim, num_heads, bias=True):
        super(KANAttention, self).__init__()
        self.num_heads = num_heads
        self.temperature = nn.Parameter(torch.ones(num_heads, 1, 1))
        self.proj_in = KolmogorovArnoldNetwork(input_channels=dim, hidden_size=dim)
        self.proj_out = nn.Conv2d(dim, dim, kernel_size=1, bias=bias)

    def apply_kan(self, kan_layer, x_in_bchw):
        x_permuted_for_kan = x_in_bchw.permute(0, 2, 3, 1).contiguous()  # (B, C, H, W) -> (B, H, W, C)
        kan_output_bhwc = kan_layer(x_permuted_for_kan)
        x_out_bchw = kan_output_bhwc.permute(0, 3, 1, 2).contiguous()    # back to (B, 3C, H, W)
        return x_out_bchw

    def forward(self, x):
        b, c, h, w = x.shape
        qkv = self.apply_kan(self.proj_in, x)
        q, k, v = qkv.chunk(3, dim=1)
        q = rearrange(q, 'b (head c_head) h w -> b head c_head (h w)', head=self.num_heads)
        k = rearrange(k, 'b (head c_head) h w -> b head c_head (h w)', head=self.num_heads)
        v = rearrange(v, 'b (head c_head) h w -> b head c_head (h w)', head=self.num_heads)
        q = torch.nn.functional.normalize(q, dim=-1)
        k = torch.nn.functional.normalize(k, dim=-1)
        attn = (q @ k.transpose(-2, -1)) * self.temperature
        attn = attn.softmax(dim=-1)
        out = (attn @ v)
        out = rearrange(out, 'b head c_head (h w) -> b (head c_head) h w', head=self.num_heads, h=h, w=w)
        out = self.proj_out(out)
        return out

# FFN2
class FFN2(nn.Module):
    def __init__(self, dim, mult=4):
        super().__init__()
        self.net = nn.Sequential(
            nn.Conv2d(dim, dim * mult, 1, 1, bias=False), GELU(),
            nn.Conv2d(dim * mult, dim * mult, 3, 1, 1, bias=False, groups=dim * mult), GELU(),
            nn.Conv2d(dim * mult, dim, 1, 1, bias=False),
        )

    def forward(self, x):
        return self.net(x)

# TransformerBlock
class TransformerBlock(nn.Module):
    def __init__(self, in_channels, num_heads, num_experts, dim_feedforward=None, dropout=0.1, LayerNorm_type='WithBias'):
        super(TransformerBlock, self).__init__()
        self.attention = KANAttention(dim=in_channels, num_heads=num_heads)
        self.norm1 = LayerNorm(dim=in_channels, LayerNorm_type=LayerNorm_type)
        self.moe = FFN2(dim=in_channels)  # uses FFN2
        self.norm2 = LayerNorm(dim=in_channels, LayerNorm_type=LayerNorm_type)

    def forward(self, x):
        f_in_normed_for_attn = self.norm1(x)
        attended_features = self.attention(f_in_normed_for_attn)
        x = x + attended_features
        f_hat_normed_for_ffn = self.norm2(x)
        ffn_features = self.moe(f_hat_normed_for_ffn)
        x = x + ffn_features
        return x

# KANT model (with shape printing)
class KANT_ShapeTracer(nn.Module):
    def __init__(self, in_channels=3, out_channels=3, n_feat=31):  # parameters simplified for tracing
        super(KANT_ShapeTracer, self).__init__()
        print(f"--- KANT model initialization ---")
        print(f"Input channels: {in_channels}, output channels: {out_channels}, base feature count (n_feat): {n_feat}\n")
        num_heads_start = 2
        num_experts = None  # unused, since FFN2 is used directly

        self.conv_in = nn.Conv2d(in_channels, n_feat, kernel_size=1, padding='same')
        print(f"  conv_in: Conv2d({in_channels}, {n_feat}, kernel_size=1)")

        # Encoder level 1
        current_heads_l1 = num_heads_start
        self.transformer_block1_1 = TransformerBlock(n_feat, current_heads_l1, num_experts)
        print(f"  transformer_block1_1: TransformerBlock(n_feat={n_feat}, heads={current_heads_l1})")
        self.downsample1 = nn.Conv2d(n_feat, n_feat * 2, kernel_size=3, stride=2, padding=1)
        print(f"  downsample1: Conv2d({n_feat}, {n_feat * 2}, kernel_size=3, stride=2)")

        # Encoder level 2
        current_heads_l2 = num_heads_start  # same as level 1 until the bottleneck adjusts it (per the original code)
        self.transformer_block2_1 = TransformerBlock(n_feat * 2, current_heads_l2, num_experts)
        print(f"  transformer_block2_1: TransformerBlock(n_feat={n_feat * 2}, heads={current_heads_l2})")
        self.transformer_block2_2 = TransformerBlock(n_feat * 2, current_heads_l2, num_experts)
        print(f"  transformer_block2_2: TransformerBlock(n_feat={n_feat * 2}, heads={current_heads_l2})")
        self.downsample2 = nn.Conv2d(n_feat * 2, n_feat * 4, kernel_size=3, stride=2, padding=1)
        print(f"  downsample2: Conv2d({n_feat * 2}, {n_feat * 4}, kernel_size=3, stride=2)")

        # Bottleneck
        current_heads_bn = current_heads_l2 * 2  # head count doubles in the bottleneck
        self.bottleneck_1 = TransformerBlock(n_feat * 4, current_heads_bn, num_experts)
        print(f"  bottleneck_1: TransformerBlock(n_feat={n_feat * 4}, heads={current_heads_bn})")
        self.bottleneck_2 = TransformerBlock(n_feat * 4, current_heads_bn, num_experts)
        print(f"  bottleneck_2: TransformerBlock(n_feat={n_feat * 4}, heads={current_heads_bn})")

        # Decoder level 2
        current_heads_up2 = current_heads_bn // 2
        self.upsample2 = nn.ConvTranspose2d(n_feat * 4, n_feat * 2, kernel_size=3, stride=2, padding=1, output_padding=1)
        print(f"  upsample2: ConvTranspose2d({n_feat * 4}, {n_feat * 2}, kernel_size=3, stride=2)")
        self.channel_adjust2 = nn.Conv2d(n_feat * 4, n_feat * 2, kernel_size=1)  # input is n_feat*2 (upsampled) + n_feat*2 (skip) = n_feat*4
        print(f"  channel_adjust2: Conv2d({n_feat * 4}, {n_feat * 2}, kernel_size=1)")
        self.transformer_block_up2_1 = TransformerBlock(n_feat * 2, current_heads_up2, num_experts)
        print(f"  transformer_block_up2_1: TransformerBlock(n_feat={n_feat * 2}, heads={current_heads_up2})")
        self.transformer_block_up2_2 = TransformerBlock(n_feat * 2, current_heads_up2, num_experts)
        print(f"  transformer_block_up2_2: TransformerBlock(n_feat={n_feat * 2}, heads={current_heads_up2})")

        # Decoder level 1
        current_heads_up1 = current_heads_up2 // 2
        self.upsample1 = nn.ConvTranspose2d(n_feat * 2, n_feat, kernel_size=3, stride=2, padding=1, output_padding=1)
        print(f"  upsample1: ConvTranspose2d({n_feat * 2}, {n_feat}, kernel_size=3, stride=2)")
        self.channel_adjust1 = nn.Conv2d(n_feat * 2, n_feat, kernel_size=1)  # input is n_feat (upsampled) + n_feat (skip) = n_feat*2
        print(f"  channel_adjust1: Conv2d({n_feat * 2}, {n_feat}, kernel_size=1)")
        self.transformer_block_up1_1 = TransformerBlock(n_feat, current_heads_up1, num_experts)
        print(f"  transformer_block_up1_1: TransformerBlock(n_feat={n_feat}, heads={current_heads_up1})")

        self.conv_out = nn.Conv2d(n_feat, out_channels, kernel_size=1, padding='same')
        print(f"  conv_out: Conv2d({n_feat}, {out_channels}, kernel_size=1)")
        print(f"--- KANT model initialization done ---\n")

    def forward(self, x):
        print(f"\n--- KANT forward-pass shape trace ---")
        print(f"Initial input shape: {x.shape}")
        x = self.conv_in(x)
        print(f"After conv_in: {x.shape}")

        # Encoder path
        x1 = self.transformer_block1_1(x)
        print(f"After transformer_block1_1 (x1): {x1.shape}")
        x1_down = self.downsample1(x1)
        print(f"After downsample1 (x1_down): {x1_down.shape}")
        x2 = self.transformer_block2_1(x1_down)
        print(f"After transformer_block2_1: {x2.shape}")
        x2 = self.transformer_block2_2(x2)
        print(f"After transformer_block2_2 (x2): {x2.shape}")
        x2_down = self.downsample2(x2)
        print(f"After downsample2 (x2_down): {x2_down.shape}")

        # Bottleneck
        bn = self.bottleneck_1(x2_down)
        print(f"After bottleneck_1: {bn.shape}")
        bn = self.bottleneck_2(bn)
        print(f"After bottleneck_2 (bn): {bn.shape}")

        # Decoder path
        x2_up_pre_cat = self.upsample2(bn)
        print(f"After upsample2 (x2_up_pre_cat): {x2_up_pre_cat.shape}")
        x2_up = torch.cat([x2_up_pre_cat, x2], dim=1)
        print(f"After cat([x2_up_pre_cat, x2]): {x2_up.shape}")
        x2_up = self.channel_adjust2(x2_up)
        print(f"After channel_adjust2 (x2_up): {x2_up.shape}")
        x2_up = self.transformer_block_up2_1(x2_up)
        print(f"After transformer_block_up2_1: {x2_up.shape}")
        x2_up = self.transformer_block_up2_2(x2_up)
        print(f"After transformer_block_up2_2 (x2_up): {x2_up.shape}")
        x1_up_pre_cat = self.upsample1(x2_up)
        print(f"After upsample1 (x1_up_pre_cat): {x1_up_pre_cat.shape}")
        x1_up = torch.cat([x1_up_pre_cat, x1], dim=1)
        print(f"After cat([x1_up_pre_cat, x1]): {x1_up.shape}")
        x1_up = self.channel_adjust1(x1_up)
        print(f"After channel_adjust1 (x1_up): {x1_up.shape}")
        x1_up = self.transformer_block_up1_1(x1_up)
        print(f"After transformer_block_up1_1 (x1_up): {x1_up.shape}")

        x_out = self.conv_out(x1_up)
        print(f"After conv_out (final output): {x_out.shape}")
        print(f"--- End of KANT forward-pass shape trace ---\n")
        return x_out

# --- End of KANT.py contents ---

if __name__ == '__main__':
    # Example parameters for testing
    batch_size = 1
    input_channels = 3
    height, width = 256, 256  # example image size
    n_features = 32  # base feature count, changed from the original 31 so it divides evenly by the attention heads

    # Create a dummy input tensor
    dummy_input = torch.randn(batch_size, input_channels, height, width)

    print(f"Creating KANT_ShapeTracer model with n_feat={n_features}...")
    # Instantiate the model
    model = KANT_ShapeTracer(in_channels=input_channels, out_channels=input_channels, n_feat=n_features)

    # Run one forward pass to trace shapes
    print(f"Running a forward pass with a dummy input of shape {dummy_input.shape}")
    with torch.no_grad():  # no gradients needed for shape tracing
        output = model(dummy_input)

    print(f"Final output tensor shape: {output.shape}")
    # You can also print the model structure:
    # print("\nModel structure:")
    # print(model)
```
How to run the shape-tracing script:
- Save the code above as a Python file (e.g., trace_kant_shape_cn.py).
- Run it from a terminal: `python trace_kant_shape_cn.py`

This prints the KANT_ShapeTracer model's configuration at initialization, then traces the tensor shape as the dummy input passes through each significant layer/operation in the forward method. The n_feat parameter is set to 32 rather than the KANT class default of 31: 31 is not divisible by the 2 attention heads used at the first level, so the channel rearrange inside KANAttention would fail. You can change height and width to test different input resolutions.
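For reference, with n_feat=32 and a 256×256 input, the traced shapes work out as follows (derived by hand from the layer definitions above):
- conv_in: (1, 32, 256, 256)
- encoder level 1 (x1): (1, 32, 256, 256); after downsample1: (1, 64, 128, 128)
- encoder level 2 (x2): (1, 64, 128, 128); after downsample2: (1, 128, 64, 64)
- bottleneck (bn): (1, 128, 64, 64)
- upsample2: (1, 64, 128, 128); after the skip concat: (1, 128, 128, 128); after channel_adjust2: (1, 64, 128, 128)
- upsample1: (1, 32, 256, 256); after the skip concat: (1, 64, 256, 256); after channel_adjust1: (1, 32, 256, 256)
- conv_out (final output): (1, 3, 256, 256)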