【IEEE 2025】低光增強KANT(使用KAN代替MLP)----論文詳解與代碼解析

【IEEE 2025】本文參考論文Enhancing Low-Light Images with Kolmogorov–Arnold Networks in Transformer Attention
雖然不是頂刊,但是有值得學習的地方
論文地址:arxiv
源碼地址:github

文章目錄

  • Part1 --- 論文精讀
  • Part2 --- 代碼詳解
    • 形狀追蹤代碼 (將原代碼的n_features 從31修改為32)


Part1 — 論文精讀

該論文提出了一種名為 KAN-T 的新型 Transformer 網絡,用于低光圖像增強 (LLIE)。其核心創新在于引入了一種受 Kolmogorov-Arnold 表示定理啟發的 Transformer 注意力機制。
在這里插入圖片描述

1. 整體框架 (Overall Framework)

KAN-T 采用了一個 3 級編碼器-解碼器結構。

  • 輸入處理與編碼: 輸入圖像首先通過一個 1 × 1 1 \times 1 1×1 卷積層進行特征擴展,從 H × W × 3 H \times W \times 3 H×W×3 擴展到 H × W × C H \times W \times C H×W×C。隨后,圖像被送入編碼器,該編碼器包含不同分辨率級別的 Transformer 模塊 ( H × W × C H \times W \times C H×W×C, H 2 × W 2 × 2 C \frac{H}{2} \times \frac{W}{2} \times 2C 2H?×2W?×2C, 以及 H 4 × W 4 × 4 C \frac{H}{4} \times \frac{W}{4} \times 4C 4H?×4W?×4C)。編碼器的目標是將輸入圖像轉換為包含關鍵特征的抽象內部表示。
  • 瓶頸層: 編碼后的特征圖被下采樣至 H 8 × W 8 × 8 C \frac{H}{8} \times \frac{W}{8} \times 8C 8H?×8W?×8C,并通過 KAN-T 的瓶頸層,該瓶頸層利用四個順序排列的 Transformer 模塊來增強內部特征表示。
  • 解碼與輸出: 內部表示隨后進入解碼過程,該過程由一系列 Transformer 模塊在不同級別組成,與編碼器對稱排列。最終的 H × W × C H \times W \times C H×W×C 特征圖經過卷積操作以減少通道數,生成 H × W × 3 H \times W \times 3 H×W×3 的輸出圖像。
  • 跳躍連接: KAN-T 在相應的編碼器-解碼器級別采用跳躍連接,以幫助保留細節和豐富特征。

2. Transformer 模塊 (Transformer Block)

Transformer 模塊是 KAN-T 的主要構建單元,因其執行高級特征處理的能力而被使用。

  • 組成: 該模塊由一個 Kolmogorov-Arnold 多頭自注意力 (KAN-MSA) 模塊、一個前饋網絡 (FFN) 和兩個層歸一化 (LN) 操作組成,同時在自注意力和特征提取兩個階段之間采用殘差連接。
  • 處理流程:
    1. 自注意力階段: 輸入特征圖 F i n F_{in} Fin? 經過層歸一化后,由 KAN-MSA 處理,然后與原始輸入 F i n F_{in} Fin? 進行殘差連接,得到中間特征圖 F ^ \hat{F} F^。數學表達式為:
      F ^ = KAN-MSA ( LN ( F i n ) ) + F i n \hat{F} = \text{KAN-MSA}(\text{LN}(F_{in})) + F_{in} F^=KAN-MSA(LN(Fin?))+Fin?
    2. 特征提取階段: 中間特征圖 F ^ \hat{F} F^ 經過層歸一化后,由 FFN 處理,再與 F ^ \hat{F} F^ 進行殘差連接,得到輸出特征圖 F o u t F_{out} Fout?。數學表達式為:
      F o u t = FFN ( LN ( F ^ ) ) + F ^ F_{out} = \text{FFN}(\text{LN}(\hat{F})) + \hat{F} Fout?=FFN(LN(F^))+F^

3. Kolmogorov-Arnold 網絡多頭自注意力 (KAN-MSA)
在這里插入圖片描述

這是該方法的核心創新點。

  • 標準 MSA 的局限性: 標準多頭自注意力 (MSA) 模塊利用全連接 (fc) 層來獲取查詢 (Q)、鍵 (K) 和值 (V) 分量。雖然 fc 層可以聯合處理整個多變量輸入來建模復雜關系,但它們可能無法有效捕獲單個通道內的單變量關系,并且由于參數數量龐大(尤其對于高維輸入)而計算量大。
  • KAN-MSA 原理: 為了克服這些限制,研究者引入了一種基于 KAN 的 MSA 機制,其靈感來源于 Kolmogorov-Arnold 表示定理。該定理指出,任何多變量連續函數都可以表示為連續單變量函數和加法的疊加。新方法還融入了可學習非線性的方面。
  • KAN-MSA 處理流程:
    1. 多變量分解 (通道拆分): 給定輸入特征圖 F i n ∈ R H × W × C F_{in} \in R^{H \times W \times C} Fin?RH×W×C,首先執行通道拆分,將其分解為 F 1 , F 2 , . . . , F C F_1, F_2, ..., F_C F1?,F2?,...,FC?,其中每個 F i ∈ R H × W × 1 F_i \in R^{H \times W \times 1} Fi?RH×W×1。這使得模型能夠捕獲數據中更復雜和特定的模式。
    2. 單變量處理與可學習非線性: 對于每個通道 i i i F i F_i Fi? 通過一個包含三個全連接層序列進行處理,每個層后都有非線性激活函數 Φ j i \Phi_j^i Φji?。通過使用三個順序的 fc 層,模型可以在激活過程中激活或停用某些神經元,從而確保可學習的非線性。
      h 1 i = Φ i 1 ( W i 1 F i + b i 1 ) h_1^i = \Phi_i^1(W_i^1 F_i + b_i^1) h1i?=Φi1?(Wi1?Fi?+bi1?) h i 2 = Φ i 2 ( W i 2 h i 1 + b i 2 ) h_i^2 = \Phi_i^2(W_i^2 h_i^1 + b_i^2) hi2?=Φi2?(Wi2?hi1?+bi2?) h i 3 = Φ i 3 ( W i 3 h i 2 + b i 3 ) , h i 3 ∈ R H × W × 3 h_i^3 = \Phi_i^3(W_i^3 h_i^2 + b_i^3), \quad h_i^3 \in \mathbb{R}^{H \times W \times 3} hi3?=Φi3?(Wi3?hi2?+bi3?),hi3?RH×W×3
    3. 合并與 QKV 生成: 單變量處理的結果在通道維度上進行拼接,得到 F o u t ∈ R H × W × 3 C F_{out} \in \mathbb{R}^{H \times W \times 3C} Fout?RH×W×3C,然后將其三向拆分以獲得 Q、K、 V ∈ R H × W × C V \in \mathbb{R}^{H \times W \times C} VRH×W×C
    4. 自注意力計算: Q、K、V 被重塑為 H W × C HW \times C HW×C,并用于生成自注意力特征圖 F o u t F_{out} Fout?
      F o u t = V × softmax ( K Q T T ) F_{out} = V \times \text{softmax}(\frac{K Q^T}{\mathcal{T}}) Fout?=V×softmax(TKQT?)
      其中 T \mathcal{T} T 是一個可學習的參數,用于平衡注意力分數。 F o u t F_{out} Fout? 隨后被重塑回 H × W × C H \times W \times C H×W×C

4. 前饋網絡 (Feed-Forward Network, FFN)

FFN 是 Transformer 模塊的另一個關鍵組成部分,它使用自注意力特征圖進行深度特征提取。

  • 結構: 它采用三重卷積設置,并使用高斯誤差線性單元 (GELU) 激活函數 ( ψ \psi ψ)。
  • 處理流程: 給定輸入特征圖 F i n ∈ R H × W × C F_{in} \in \mathbb{R}^{H \times W \times C} Fin?RH×W×C,其計算公式為:
    F o u t = conv1 × 1 ( ψ conv3 × 3 ( ψ conv1 × 1 ( F i n ) ) ) F_{out} = \text{conv1} \times \text{1}(\psi \text{conv3} \times \text{3}(\psi \text{conv1} \times \text{1}(F_{in}))) Fout?=conv1×1(ψconv3×3(ψconv1×1(Fin?)))
    其中,第一個 c o n v 1 × 1 conv1 \times 1 conv1×1 將特征圖擴展到 H × W × 4 C H \times W \times 4C H×W×4C 以幫助發現新模式; c o n v 3 × 3 conv3 \times 3 conv3×3 通過增加核大小執行高分辨率特征提取;最后一個 c o n v 1 × 1 conv1 \times 1 conv1×1 將特征圖壓縮回原始維度 H × W × C H \times W \times C H×W×C

5. 損失函數 (Loss Function)

為了實現精確重建,采用了一個復合損失函數 L \mathcal{L} L。該混合損失函數集成了多個分量以解決圖像質量的各個方面,包括像素級準確性、結構完整性和感知保真度。

  • 總體損失:
    L = L M A E + α ? L M S ? S S I M + β ? L P e r c \mathcal{L} = \mathcal{L}_{MAE} + \alpha \cdot \mathcal{L}_{MS-SSIM} + \beta \cdot \mathcal{L}_{Perc} L=LMAE?+α?LMS?SSIM?+β?LPerc?
    其中 α \alpha α β \beta β 是平衡每個損失分量貢獻的超參數。
  • 各分量原理:
    • 平均絕對誤差損失 ( L M A E \mathcal{L}_{MAE} LMAE?): 作為主要項,它捕獲預測圖像 I ^ \hat{I} I^ 和真實圖像 I G T \mathcal{I}_{GT} IGT? 之間的平均差異。
      L M A E ( x , y ) = 1 N ∑ x , y ∣ ∣ I ^ ( x , y ) ? I G T ( x , y ) ∣ ∣ 1 \mathcal{L}_{MAE}(x,y) = \frac{1}{N} \sum_{x,y} ||\hat{I}(x,y) - \mathcal{I}_{GT}(x,y)||_1 LMAE?(x,y)=N1?x,y?∣∣I^(x,y)?IGT?(x,y)1?
    • 多尺度結構相似性指數度量損失 ( L M S ? S S I M \mathcal{L}_{MS-SSIM} LMS?SSIM?): 評估預測圖像和真實圖像在多個尺度上的結構相似性。它通過評估結構失真(尤其是在低光等挑戰性條件下)來捕獲對保持圖像結構完整性至關重要的高級特征。
    • 感知損失 ( L P e r c \mathcal{L}_{Perc} LPerc?): 利用預訓練的 VGG-19 網絡 ( Ψ \Psi Ψ) 來引入特征級監督。該損失測量預測圖像和真實圖像的高級特征表示之間的差異,有助于學習有意義的內部表示。
      L P e r c ( x , y ) = 1 N ∑ x , y ∣ ∣ Ψ ( I ^ ( x , y ) ) ? Ψ ( I G T ( x , y ) ) ∣ ∣ 1 \mathcal{L}_{Perc}(x,y) = \frac{1}{N} \sum_{x,y} ||\Psi(\hat{I}(x,y)) - \Psi(\mathcal{I}_{GT}(x,y))||_1 LPerc?(x,y)=N1?x,y?∣∣Ψ(I^(x,y))?Ψ(IGT?(x,y))1?
      通過集成這三個損失分量,該混合損失函數有效地平衡了像素級準確性、結構一致性和感知質量。
      在這里插入圖片描述

總而言之,該方法的核心原理在于利用 Kolmogorov-Arnold 表示定理的思想改進 Transformer 中的多頭自注意力機制,通過將多變量函數分解為單變量函數和線性組合,并引入可學習的非線性激活函數,從而在低光圖像增強任務中實現更靈活、更有效的特征表示和上下文信息捕獲。結合精心設計的編碼器-解碼器架構和復合損失函數,旨在實現卓越的性能。


Part2 — 代碼詳解

形狀追蹤代碼 (將原代碼的n_features 從31修改為32)

現在,我們編寫代碼來實例化 KANT 模型,并在其 forward 方法的每個重要步驟打印張量形狀。為了便于演示,我將稍微修改 KANT 類,以便在其前向傳播中更容易打印形狀。

import torch
import torch.nn as nn
import numbers
import torch.nn.functional as F
from einops import rearrange
import math# [KANT.py 中的所有類定義 (LayerNorm, GELU, KANAttention, FFN2, TransformerBlock 等) 必須粘貼在此處]
# ... (假設所有必要的類定義,如 LayerNorm, GELU, KolmogorovArnoldNetwork, KANAttention, FFN2, TransformerBlock 都已在此處定義) ...# --- KANT.py 內容開始 (為獨立執行而復制) ---
# 工具函數
def to_3d(x):return rearrange(x, 'b c h w -> b (h w) c')def to_4d(x,h,w):return rearrange(x, 'b (h w) c -> b c h w',h=h,w=w)class BiasFree_LayerNorm(nn.Module):def __init__(self, normalized_shape):super(BiasFree_LayerNorm, self).__init__()if isinstance(normalized_shape, numbers.Integral):normalized_shape = (normalized_shape,)normalized_shape = torch.Size(normalized_shape)assert len(normalized_shape) == 1self.weight = nn.Parameter(torch.ones(normalized_shape))self.normalized_shape = normalized_shapedef forward(self, x):sigma = x.var(-1, keepdim=True, unbiased=False)return x / torch.sqrt(sigma+1e-5) * self.weightclass WithBias_LayerNorm(nn.Module):def __init__(self, normalized_shape):super(WithBias_LayerNorm, self).__init__()if isinstance(normalized_shape, numbers.Integral):normalized_shape = (normalized_shape,)normalized_shape = torch.Size(normalized_shape)assert len(normalized_shape) == 1self.weight = nn.Parameter(torch.ones(normalized_shape))self.bias = nn.Parameter(torch.zeros(normalized_shape))self.normalized_shape = normalized_shapedef forward(self, x):mu = x.mean(-1, keepdim=True)sigma = x.var(-1, keepdim=True, unbiased=False)return (x - mu) / torch.sqrt(sigma+1e-5) * self.weight + self.biasclass LayerNorm(nn.Module):def __init__(self, dim, LayerNorm_type):super(LayerNorm, self).__init__()if LayerNorm_type =='BiasFree':self.body = BiasFree_LayerNorm(dim)else:self.body = WithBias_LayerNorm(dim)def forward(self, x):h, w = x.shape[-2:]return to_4d(self.body(to_3d(x)), h, w)class GELU(nn.Module):def forward(self, x):return F.gelu(x)# KolmogorovArnoldNetwork (基于 MLP 的 KAN)
class KolmogorovArnoldNetwork(nn.Module):def __init__(self, input_channels, hidden_size=256):super(KolmogorovArnoldNetwork, self).__init__()self.input_channels = input_channelsself.hidden_size = hidden_sizeself.fc1_list = nn.ModuleList([nn.Linear(1, hidden_size) for _ in range(input_channels)])self.fc2_list = nn.ModuleList([nn.Linear(hidden_size, hidden_size) for _ in range(input_channels)])self.fc3_list = nn.ModuleList([nn.Linear(hidden_size, 3) for _ in range(input_channels)]) # 輸出3個用于Q,K,V部分self.relu = nn.ReLU()def forward(self, x): # 期望輸入 x 形狀: (batch_size, H, W, C)batch_size, H, W, C = x.shapex_reshaped = x.reshape(-1, C)outputs_mlp = []for i in range(self.input_channels):xi = x_reshaped[:, i:i+1]xi = self.relu(self.fc1_list[i](xi))xi = self.relu(self.fc2_list[i](xi))xi = self.fc3_list[i](xi)outputs_mlp.append(xi)x_cat = torch.cat(outputs_mlp, dim=1)x_final = x_cat.view(batch_size, H, W, C*3)return x_final# KANAttention
class KANAttention(nn.Module):def __init__(self, dim, num_heads, bias=True):super(KANAttention, self).__init__()self.num_heads = num_headsself.temperature = nn.Parameter(torch.ones(num_heads, 1, 1))self.proj_in = KolmogorovArnoldNetwork(input_channels=dim, hidden_size=dim)self.proj_out = nn.Conv2d(dim, dim, kernel_size=1, bias=bias)def apply_kan(self, kan_layer, x_in_bcwh):x_permuted_for_kan = x_in_bcwh.permute(0, 2, 3, 1).contiguous()kan_output_bhwc = kan_layer(x_permuted_for_kan)x_out_bcwh = kan_output_bhwc.permute(0, 3, 1, 2).contiguous()return x_out_bcwhdef forward(self, x):b,c,h,w = x.shapeqkv = self.apply_kan(self.proj_in, x)q,k,v = qkv.chunk(3, dim=1)q = rearrange(q, 'b (head c_head) h w -> b head c_head (h w)', head=self.num_heads)k = rearrange(k, 'b (head c_head) h w -> b head c_head (h w)', head=self.num_heads)v = rearrange(v, 'b (head c_head) h w -> b head c_head (h w)', head=self.num_heads)q = torch.nn.functional.normalize(q, dim=-1)k = torch.nn.functional.normalize(k, dim=-1)attn = (q @ k.transpose(-2, -1)) * self.temperatureattn = attn.softmax(dim=-1)out = (attn @ v)out = rearrange(out, 'b head c_head (h w) -> b (head c_head) h w', head=self.num_heads, h=h, w=w)out = self.proj_out(out)return out# FFN2
class FFN2(nn.Module):def __init__(self, dim, mult=4):super().__init__()self.net = nn.Sequential(nn.Conv2d(dim, dim * mult, 1, 1, bias=False), GELU(),nn.Conv2d(dim * mult, dim * mult, 3, 1, 1, bias=False, groups=dim * mult), GELU(),nn.Conv2d(dim * mult, dim, 1, 1, bias=False),)def forward(self, x):return self.net(x)# TransformerBlock
class TransformerBlock(nn.Module):def __init__(self, in_channels, num_heads, num_experts, dim_feedforward=None, dropout=0.1, LayerNorm_type='WithBias'):super(TransformerBlock, self).__init__()self.attention = KANAttention(dim=in_channels, num_heads=num_heads)self.norm1 = LayerNorm(dim=in_channels, LayerNorm_type=LayerNorm_type)self.moe = FFN2(dim=in_channels) # 使用 FFN2self.norm2 = LayerNorm(dim=in_channels, LayerNorm_type=LayerNorm_type)def forward(self, x):f_in_normed_for_attn = self.norm1(x)attended_features = self.attention(f_in_normed_for_attn)x = x + attended_featuresf_hat_normed_for_ffn = self.norm2(x)ffn_features = self.moe(f_hat_normed_for_ffn)x = x + ffn_featuresreturn x# KANT 模型 (帶有形狀打印功能)
class KANT_ShapeTracer(nn.Module):def __init__(self, in_channels=3, out_channels=3, n_feat=31): # 為追蹤簡化參數super(KANT_ShapeTracer, self).__init__()print(f"--- KANT 模型初始化 ---")print(f"輸入通道數: {in_channels}, 輸出通道數: {out_channels}, 基礎特征數 (n_feat): {n_feat}\n")num_heads_start = 2num_experts = None # 未使用,因為直接使用 FFN2self.conv_in = nn.Conv2d(in_channels, n_feat, kernel_size=1, padding='same')print(f"  conv_in: Conv2d({in_channels}, {n_feat}, kernel_size=1)")# 第 1 層編碼器current_heads_l1 = num_heads_startself.transformer_block1_1 = TransformerBlock(n_feat, current_heads_l1, num_experts)print(f"  transformer_block1_1: TransformerBlock(n_feat={n_feat}, heads={current_heads_l1})")self.downsample1 = nn.Conv2d(n_feat, n_feat * 2, kernel_size=3, stride=2, padding=1)print(f"  downsample1: Conv2d({n_feat}, {n_feat*2}, kernel_size=3, stride=2)")# 第 2 層編碼器current_heads_l2 = num_heads_start # 在瓶頸層調整前,與l1保持一致 (根據原始代碼)self.transformer_block2_1 = TransformerBlock(n_feat * 2, current_heads_l2, num_experts)print(f"  transformer_block2_1: TransformerBlock(n_feat={n_feat*2}, heads={current_heads_l2})")self.transformer_block2_2 = TransformerBlock(n_feat * 2, current_heads_l2, num_experts)print(f"  transformer_block2_2: TransformerBlock(n_feat={n_feat*2}, heads={current_heads_l2})")self.downsample2 = nn.Conv2d(n_feat * 2, n_feat * 4, kernel_size=3, stride=2, padding=1)print(f"  downsample2: Conv2d({n_feat*2}, {n_feat*4}, kernel_size=3, stride=2)")# 瓶頸層current_heads_bn = current_heads_l2 * 2 # 瓶頸層頭數加倍self.bottleneck_1 = TransformerBlock(n_feat * 4, current_heads_bn, num_experts)print(f"  bottleneck_1: TransformerBlock(n_feat={n_feat*4}, heads={current_heads_bn})")self.bottleneck_2 = TransformerBlock(n_feat * 4, current_heads_bn, num_experts)print(f"  bottleneck_2: TransformerBlock(n_feat={n_feat*4}, heads={current_heads_bn})")# 第 2 層解碼器current_heads_up2 = current_heads_bn // 2self.upsample2 = nn.ConvTranspose2d(n_feat * 4, n_feat * 2, kernel_size=3, stride=2, padding=1, output_padding=1)print(f"  upsample2: ConvTranspose2d({n_feat*4}, {n_feat*2}, kernel_size=3, stride=2)")self.channel_adjust2 = nn.Conv2d(n_feat * 4, n_feat * 2, kernel_size=1) # 輸入是 n_feat*2 (上采樣) + n_feat*2 (跳躍) = n_feat*4print(f"  channel_adjust2: Conv2d({n_feat*4}, {n_feat*2}, kernel_size=1)")self.transformer_block_up2_1 = TransformerBlock(n_feat * 2, current_heads_up2, num_experts)print(f"  transformer_block_up2_1: TransformerBlock(n_feat={n_feat*2}, heads={current_heads_up2})")self.transformer_block_up2_2 = TransformerBlock(n_feat * 2, current_heads_up2, num_experts)print(f"  transformer_block_up2_2: TransformerBlock(n_feat={n_feat*2}, heads={current_heads_up2})")# 第 1 層解碼器current_heads_up1 = current_heads_up2 // 2self.upsample1 = nn.ConvTranspose2d(n_feat * 2, n_feat, kernel_size=3, stride=2, padding=1, output_padding=1)print(f"  upsample1: ConvTranspose2d({n_feat*2}, {n_feat}, kernel_size=3, stride=2)")self.channel_adjust1 = nn.Conv2d(n_feat * 2, n_feat, kernel_size=1) # 輸入是 n_feat (上采樣) + n_feat (跳躍) = n_feat*2print(f"  channel_adjust1: Conv2d({n_feat*2}, {n_feat}, kernel_size=1)")self.transformer_block_up1_1 = TransformerBlock(n_feat, current_heads_up1, num_experts)print(f"  transformer_block_up1_1: TransformerBlock(n_feat={n_feat}, heads={current_heads_up1})")self.conv_out = nn.Conv2d(n_feat, out_channels, kernel_size=1, padding='same')print(f"  conv_out: Conv2d({n_feat}, {out_channels}, kernel_size=1)")print(f"--- KANT 模型初始化結束 ---\n")def forward(self, x):print(f"\n--- KANT 前向傳播形狀追蹤 ---")print(f"初始輸入形狀: {x.shape}")x = self.conv_in(x)print(f"經過 conv_in 后: {x.shape}")# 編碼器路徑x1 = self.transformer_block1_1(x)print(f"經過 transformer_block1_1 (x1) 后: {x1.shape}")x1_down = self.downsample1(x1)print(f"經過 downsample1 (x1_down) 后: {x1_down.shape}")x2 = self.transformer_block2_1(x1_down)print(f"經過 transformer_block2_1 后: {x2.shape}")x2 = self.transformer_block2_2(x2)print(f"經過 transformer_block2_2 (x2) 后: {x2.shape}")x2_down = self.downsample2(x2)print(f"經過 downsample2 (x2_down) 后: {x2_down.shape}")# 瓶頸層bn = self.bottleneck_1(x2_down)print(f"經過 bottleneck_1 后: {bn.shape}")bn = self.bottleneck_2(bn)print(f"經過 bottleneck_2 (bn) 后: {bn.shape}")# 解碼器路徑x2_up_pre_cat = self.upsample2(bn)print(f"經過 upsample2 (x2_up_pre_cat) 后: {x2_up_pre_cat.shape}")x2_up = torch.cat([x2_up_pre_cat, x2], dim=1)print(f"經過 cat([x2_up_pre_cat, x2]) 后: {x2_up.shape}")x2_up = self.channel_adjust2(x2_up)print(f"經過 channel_adjust2 (x2_up) 后: {x2_up.shape}")x2_up = self.transformer_block_up2_1(x2_up)print(f"經過 transformer_block_up2_1 后: {x2_up.shape}")x2_up = self.transformer_block_up2_2(x2_up)print(f"經過 transformer_block_up2_2 (x2_up) 后: {x2_up.shape}")x1_up_pre_cat = self.upsample1(x2_up)print(f"經過 upsample1 (x1_up_pre_cat) 后: {x1_up_pre_cat.shape}")x1_up = torch.cat([x1_up_pre_cat, x1], dim=1)print(f"經過 cat([x1_up_pre_cat, x1]) 后: {x1_up.shape}")x1_up = self.channel_adjust1(x1_up)print(f"經過 channel_adjust1 (x1_up) 后: {x1_up.shape}")x1_up = self.transformer_block_up1_1(x1_up)print(f"經過 transformer_block_up1_1 (x1_up) 后: {x1_up.shape}")x_out = self.conv_out(x1_up)print(f"經過 conv_out (最終輸出) 后: {x_out.shape}")print(f"--- KANT 前向傳播形狀追蹤結束 ---\n")return x_out# --- KANT.py 內容結束 ---if __name__ == '__main__':# 用于測試的示例參數batch_size = 1input_channels = 3height, width = 256, 256 # 示例圖像尺寸n_features = 31 # 基礎特征數量,與 KANT 類中一致# 創建一個虛擬輸入張量dummy_input = torch.randn(batch_size, input_channels, height, width)print(f"正在創建 KANT_ShapeTracer 模型,n_feat={n_features}...")# 實例化模型model = KANT_ShapeTracer(in_channels=input_channels, out_channels=input_channels, n_feat=n_features)# 執行一次前向傳播以追蹤形狀print(f"使用形狀為 {dummy_input.shape} 的虛擬輸入執行前向傳播")with torch.no_grad(): # 追蹤形狀時無需計算梯度output = model(dummy_input)print(f"最終輸出張量形狀: {output.shape}")# 你也可以打印模型結構# print("\n模型結構:")# print(model)

如何運行形狀追蹤腳本:

  1. 將上述代碼保存為一個 Python 文件 (例如, trace_kant_shape_cn.py)。
  2. 從終端運行它: python trace_kant_shape_cn.py

這將打印 KANT_ShapeTracer 模型在初始化時的配置,然后在虛擬輸入通過 forward 方法中的每個重要層/操作時追蹤張量形狀。n_feat 參數設置為 31,與你的 KANT 類默認值一致。你可以更改 heightwidth 來測試不同的輸入分辨率。

本文來自互聯網用戶投稿,該文觀點僅代表作者本人,不代表本站立場。本站僅提供信息存儲空間服務,不擁有所有權,不承擔相關法律責任。
如若轉載,請注明出處:http://www.pswp.cn/news/906844.shtml
繁體地址,請注明出處:http://hk.pswp.cn/news/906844.shtml
英文地址,請注明出處:http://en.pswp.cn/news/906844.shtml

如若內容造成侵權/違法違規/事實不符,請聯系多彩編程網進行投訴反饋email:809451989@qq.com,一經查實,立即刪除!

相關文章

naivechain:簡易區塊鏈實現

naivechain:簡易區塊鏈實現 naivechain A naive and simple implementation of blockchains. 項目地址: https://gitcode.com/gh_mirrors/nai/naivechain 項目介紹 naivechain 是一個簡單且易于理解的區塊鏈實現項目。它使用 Go 語言編寫,以極簡…

Zabbix開源監控的全面詳解!

一、zabbix的基本概述 zabbix,這款企業級監控軟件,能全方位監控各類網絡參數,確保企業服務架構的安全穩定運行。它提供了靈活多樣的告警機制,幫助運維人員迅速發現并解決問題。此外,zabbix還具備分布式監控功能&#…

軟考軟件評測師——軟件工程之開發模型與方法

目錄 一、核心概念 二、主流模型詳解 (一)經典瀑布模型 (二)螺旋演進模型 (三)增量交付模型 (四)原型驗證模型 (五)敏捷開發實踐 三、模型選擇指南 四…

50天50個小項目 (Vue3 + Tailwindcss V4) ? | Blurry Loading (毛玻璃加載)

📅 我們繼續 50 個小項目挑戰!—— Blurry Loading 組件 倉庫地址:https://github.com/SunACong/50-vue-projects 項目預覽地址:https://50-vue-projects.vercel.app/ ? 組件目標 實現一個加載進度條,隨著加載進度的…

WPF性能優化之延遲加載(解決頁面卡頓問題)

文章目錄 前言一. 基礎知識回顧二. 問題分析三. 解決方案1. 新建一個名為DeferredContentHost的控件。2. 在DeferredContentHost控件中定義一個名為Content的object類型的依賴屬性,用于承載要加載的子控件。3. 在DeferredContentHost控件中定義一個名為Skeleton的ob…

VLM-MPC:自動駕駛中模型預測控制器增強視覺-語言模型

《VLM-MPC: Model Predictive Controller Augmented Vision Language Model for Autonomous Driving》2024年8月發表,來自威斯康星大學的論文。 受視覺語言模型(VLM)的緊急推理能力及其提高自動駕駛系統可理解性的潛力的啟發,本文…

推薦系統里真的存在“反饋循環”嗎?

推薦系統里真的存在“反饋循環”嗎? 許多人說,推薦算法不過是把用戶早已存在的興趣挖掘出來,你本來就愛聽流行歌、買潮牌玩具,系統只是在合適的時間把它們端到你面前,再怎么迭代,算法也改變不了人的天性&a…

代碼混淆技術的還原案例

案例一 eval 混淆 特征 : 反常的 eval 連接了一堆數據 練習網站 https://scrape.center/ spa9 這個案例 基本的還原方法 但是這個代碼還是非常的模糊不好看 優化一下 : 當然還有更快捷的方法 : 好用的 js混淆還原的 web &#xf…

鴻蒙Flutter實戰:22-混合開發詳解-2-Har包模式引入

以 Har 包的方式加載到 HarmonyOS 工程 創建工作 創建一個根目錄 mkdir ohos_flutter_module_demo這個目錄用于存放 flutter 項目和鴻蒙項目。 創建 Flutter 模塊 首先創建一個 Flutter 模塊,我們選擇與 ohos_app 項目同級目錄 flutter create --templatemodu…

Go核心特性與并發編程

Go核心特性與并發編程 1. 結構體與方法(擴展) 高級結構體特性 // 嵌套結構體與匿名字段 type Employee struct {Person // 匿名嵌入Department stringsalary float64 // 私有字段 }// 構造函數模式 func NewPerson(name string, age int) *Pe…

Java 函數式接口(Functional Interface)

一、理論說明 1. 函數式接口的定義 Java 函數式接口是一種特殊的接口,它只包含一個抽象方法(Single Abstract Method, SAM),但可以包含多個默認方法或靜態方法。函數式接口是 Java 8 引入 Lambda 表達式的基礎,通過函…

【python代碼】一些小實驗

目錄 1. 測試Resnet50 ONNX模型的推理速度 1. 測試Resnet50 ONNX模型的推理速度 ############################### # 導出resnet50 模型 # 測試onnx模型推理 cpu 和 GPU 的對比 ###############################import time import numpy as np import onnxruntime as ort im…

5.Java 面向對象編程入門:類與對象的創建和使用?

在現實生活中,我們常常會接觸到各種各樣的對象,比如一輛汽車、一個學生、一部手機等。這些對象都具有各自的屬性和行為。例如,汽車有顏色、品牌、型號等屬性,還有啟動、加速、剎車等行為;學生有姓名、年齡、學號等屬性…

從開發者角度看數據庫架構進化史:JDBC - 中間件 - TiDB

作者: Lucien-盧西恩 原文來源: https://tidb.net/blog/e7034d1b Java 應用開發技術發展歷程 在業務開發早期,用 Java 借助 JDBC 進行數據庫操作,雖能實現基本交互,但需手動管理連接、編寫大量 SQL 及處理結果集&a…

工業智能網關建立烤漆設備故障預警及遠程診斷系統

一、項目背景 烤漆房是汽車、機械、家具等工業領域廣泛應用的設備,主要用于產品的表面涂裝。傳統的烤漆房控制柜采用本地控制方式,操作人員需在現場進行參數設置和設備控制,且存在設備智能化程度低、數據孤島、設備維護成本高以及依靠傳統人…

故障率預測:基于LSTM的GPU集群硬件健康監測系統(附Prometheus監控模板)

一、GPU集群健康監測的挑戰與價值 在大規模深度學習訓練場景下,GPU集群的硬件故障率顯著高于傳統計算設備。根據2023年MLCommons統計,配備8卡A100的服務器平均故障間隔時間(MTBF)僅為1426小時,其中顯存故障占比達38%&…

Vue 樣式不一致問題全面分析與解決方案

文章目錄 1. 問題概述1.1 問題表現1.2 問題影響 2. 根本原因分析2.1 Vue 的渲染機制與樣式加載時機2.2 Scoped CSS 的工作原理2.3 CSS 模塊化與作用域隔離2.4 樣式加載順序問題2.5 熱重載(HMR)與樣式更新 3. 解決方案3.1 確保樣式加載順序3.1.1 預加載關鍵 CSS3.1.2 控制全局樣…

[免費]微信小程序寵物醫院管理系統(uni-app+SpringBoot后端+Vue管理端)【論文+源碼+SQL腳本】

大家好,我是java1234_小鋒老師,看到一個不錯的微信小程序寵物醫院管理系統(uni-appSpringBoot后端Vue管理端),分享下哈。 項目視頻演示 【免費】微信小程序寵物醫院管理系統(uni-appSpringBoot后端Vue管理端) Java畢業設計_嗶哩嗶哩_bilibi…

測試總結(一)

一、測試流程 參與需求評審-制定測試計劃-編寫測試用例-用例評審-冒煙測試-測試執行-缺陷管理-預發驗收測試-發布線上-線上回歸-線上觀察-項目總結 二、測試用例設計方法 等價類劃分(處理有效/無效輸入) 邊界值分析(臨界值測試&#xff09…

SAP-ABAP:ABAP異常處理與SAP現代技術融合—— 面向云原生、微服務與低代碼場景的創新實踐

專題三:ABAP異常處理與SAP現代技術融合 —— 面向云原生、微服務與低代碼場景的創新實踐 一、SAP技術演進與異常處理的挑戰 隨著SAP技術棧向云端、微服務化和低代碼方向演進,異常處理面臨新場景: Fiori UX敏感度:用戶期望前端友…