深度學習與遙感入門(六)|輕量化 MobileNetV2 高光譜分類

系列回顧:
(一)CNN 基礎:高光譜圖像分類可視化全流程
(二)HybridNet(CNN+Transformer):提升全局感受野
(三)GCN 入門實戰:基于光譜 KNN 的圖卷積分類與全圖預測
(四)空間–光譜聯合構圖的 GCN:RBF 邊權 + 自環 + 早停,得到更穩更自然的全圖分類結果
(五)GAT & 構圖消融 + 分塊全圖預測:更穩更快的高光譜圖分類(PyTorch Geometric 實戰)
合集鏈接:https://mp.weixin.qq.com/mp/appmsgalbum?__biz=MzkwMTE0MjI4NQ==&action=getalbum&album_id=4007114522736459789#wechat_redirect
本篇(六)聚焦“數據泄露”,采用僅訓練集像素擬合 StandardScaler+PCA,并在全圖預測中共享同一變換空間;模型選用輕量化 MobileNetV2深度可分離卷積,在顯存友好的坐標批推理下實現全圖預測

0. 前言:PCA 與高光譜分類中的“數據泄露”

  • 什么是泄露? 訓練階段直接/間接使用了測試數據統計信息(均值、方差、主成分方向等)。
  • 怎么產生? 在整圖上 fit 標準化與 PCA,然后再切訓練/測試或直接做全圖分類。
  • 為什么常見? 歷史習慣、樣本少時穩定性考慮、對比研究圖省事、實現方便。
  • 影響大嗎? 小數據集上常為 0.1%~1% 的 OA 差異;但在類分布差異大訓練樣本極少時,差距可達數個百分點。真實部署場景絕不允許整圖 fit

本文做法先分層抽樣得到訓練/測試索引僅用訓練像素 fit 標準化與 PCA用該變換對整圖 transform訓練與預測均在同一(訓練集擬合得到的)特征空間,從源頭避免泄露。

1. 任務要點

  1. 嚴格無泄露預處理:只用訓練像素擬合 StandardScaler+PCA;全圖在同一變換空間中變換。
  2. 輕量模型:用 MobileNetV2 的深度可分離卷積(3 段)+ GAP + FC。
  3. 全圖預測顯存友好:按坐標批收集 patch → 堆成 batch → 前向推理。
  4. 評估classification_report、混淆矩陣、OA;可視化支持 Windows 阻塞顯示。

2. 方法詳解

2.1 嚴格無泄露的 PCA 流程

  • 先劃分后擬合:對有標簽像素做分層抽樣得到訓練/測試索引;僅訓練像素擬合 StandardScalerPCA
  • 全圖共享空間:將整圖 (H×W×Bands) 用訓練集擬合的變換進行標準化與降維,得到 (H×W×PCA_DIM)
  • 提取 patch:在 PCA 空間內按坐標提取 (PATCH_SIZE×PATCH_SIZE×PCA_DIM) 的 patch 作為輸入。

這樣做的關鍵測試像素從未參與統計,評估更可信。

2.2 輕量化 MobileNetV2(HSI 版)

  • Depthwise Separable Conv:逐通道 3×3 深度卷積 + 1×1 點卷積,大幅降參與算力需求。
  • 網絡骨干:3 段深度可分離卷積 → GAP(自適應全局平均池化)→ 全連接輸出。
  • 輸入通道:這里輸入為 PCA 后的通道數(例如 30),以二維 patch 形式輸入(C×H×W)。

2.3 全圖預測策略(坐標批)

  • 坐標遍歷:生成所有像素坐標。
  • 反射填充:邊界像素也能提取完整 patch。
  • 批量收集:按 batch_size 組裝 patch → 前向 → 填回 pred_map
  • 顯存穩定:避免一次性張量過大導致溢出。

3. 代碼逐段 + 解釋

下面先按邏輯分段展示與解釋;最末提供“一鍵可跑腳本(整合版)”,復制后僅需修改數據路徑即可運行。

3.1 全局與可視化設置

import os, time, numpy as np, scipy.io as sio
import torch, torch.nn as nn
from torch.utils.data import Dataset, DataLoader
from sklearn.decomposition import PCA
from sklearn.preprocessing import StandardScaler
from sklearn.metrics import classification_report, confusion_matrix, accuracy_score
from sklearn.model_selection import train_test_split
import matplotlib# Windows 下 TkAgg 更穩;Linux/服務器用 Agg(無顯示)
if os.name == 'nt':matplotlib.use('TkAgg')
else:matplotlib.use('Agg')import matplotlib.pyplot as plt
import seaborn as snsmatplotlib.rcParams['font.family'] = 'SimHei'
matplotlib.rcParams['axes.unicode_minus'] = False
plt.rcParams['figure.dpi'] = 120
sns.set_theme(context="notebook", style="whitegrid", font="SimHei")torch.backends.cudnn.benchmark = True

3.2 隨機種子

def set_seeds(seed=42):import randomrandom.seed(seed)np.random.seed(seed)torch.manual_seed(seed)torch.cuda.manual_seed_all(seed)

固定隨機性,保證復現。

3.3 輕量化網絡

class DepthwiseSeparableConv(nn.Module):def __init__(self, in_ch, out_ch, stride=1):super().__init__()self.depthwise = nn.Conv2d(in_ch, in_ch, 3, stride=stride, padding=1, groups=in_ch, bias=False)self.pointwise = nn.Conv2d(in_ch, out_ch, 1, bias=False)self.bn = nn.BatchNorm2d(out_ch)self.act = nn.ReLU6(inplace=True)def forward(self, x):x = self.depthwise(x)x = self.pointwise(x)x = self.bn(x)return self.act(x)class MobileNetV2_HSI(nn.Module):def __init__(self, in_ch, num_classes, width_mult=1.0):super().__init__()c1, c2, c3 = int(32 * width_mult), int(64 * width_mult), int(128 * width_mult)self.layer1 = DepthwiseSeparableConv(in_ch, c1)self.layer2 = DepthwiseSeparableConv(c1, c2)self.layer3 = DepthwiseSeparableConv(c2, c3)self.gap = nn.AdaptiveAvgPool2d(1)self.fc = nn.Linear(c3, num_classes)def forward(self, x):x = self.layer1(x)x = self.layer2(x)x = self.layer3(x)x = self.gap(x).flatten(1)return self.fc(x)

結構極簡但實用:3 段深度可分離卷積 + GAP + FC,對 HSI 的小樣本/低算力場景友好。

3.4 數據集與 Patch 封裝

class HSIPatchDataset(Dataset):def __init__(self, patches, labels):# patches: (N, H, W, C) → 張量 (N, C, H, W)self.X = torch.tensor(patches, dtype=torch.float32).permute(0, 3, 1, 2)self.y = torch.tensor(labels, dtype=torch.long)def __len__(self): return len(self.y)def __getitem__(self, idx): return self.X[idx], self.y[idx]

3.5 全圖預測(坐標批推理)

@torch.inference_mode()
def predict_full_image_by_coords(model, X_img_pca, patch_size, device,batch_size=2048, title="全圖預測(坐標批推理)"):model.eval()H, W, C = X_img_pca.shapem = patch_size // 2padded = np.pad(X_img_pca, ((m, m), (m, m), (0, 0)), mode='reflect')coords = np.mgrid[0:H, 0:W].reshape(2, -1).Tpred_map = np.zeros((H, W), dtype=np.int32)t0 = time.time()for i in range(0, len(coords), batch_size):batch_coords = coords[i:i + batch_size]patches = np.empty((len(batch_coords), patch_size, patch_size, C), dtype=np.float32)for k, (r, c) in enumerate(batch_coords):patches[k] = padded[r:r + patch_size, c:c + patch_size, :]tensor = torch.from_numpy(patches).permute(0, 3, 1, 2).to(device)preds = model(tensor).argmax(dim=1).cpu().numpy() + 1  # +1 便于和 GT 對齊for (r, c), p in zip(batch_coords, preds):pred_map[r, c] = pprint(f"全圖預測耗時:{time.time() - t0:.2f} 秒")# 可視化(阻塞顯示,避免“最后的圖沒有顯示”)try:plt.figure(figsize=(10, 7.5))cmap = matplotlib.colormaps.get_cmap('tab20')vmin, vmax = pred_map.min(), pred_map.max()if vmin == vmax: vmin, vmax = 0, 1im = plt.imshow(pred_map, cmap=cmap, interpolation='nearest', vmin=vmin, vmax=vmax)cbar = plt.colorbar(im, shrink=0.85); cbar.set_label('預測類別', rotation=90)plt.title(title, fontsize=14, weight='bold'); plt.axis('off'); plt.tight_layout()print("嘗試顯示全圖預測結果...")plt.show(block=True)except Exception as e:print(f"顯示全圖預測結果時出錯: {e}")try:plt.savefig("prediction_map.png", bbox_inches='tight')print("已保存為 prediction_map.png")except Exception as se:print(f"保存失敗: {se}")return pred_map

3.6 主流程(數據→劃分→無泄露預處理→訓練→評估→全圖預測)

下面是主函數的關鍵片段(末尾附完整可運行腳本):

def main():set_seeds(42)device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')print(f"使用設備: {device}")# ---- 路徑與超參(按需修改)----DATA_DIR = r"your_path"X_FILE, Y_FILE = "KSC.mat", "KSC_gt.mat"PCA_DIM, PATCH_SIZE, TRAIN_RATIO = 30, 5, 0.30BATCH_SIZE, EPOCHS, LR, WEIGHT_DECAY = 64, 30, 1e-3, 1e-4NUM_WORKERS = 0 if os.name == 'nt' else min(4, os.cpu_count() or 0)PIN_MEMORY = (device.type == 'cuda')PREDICT_BATCH_SIZE = 4096# ---- 讀取數據 ----def load_data():X = sio.loadmat(os.path.join(DATA_DIR, X_FILE))Y = sio.loadmat(os.path.join(DATA_DIR, Y_FILE))x_key = [k for k in X.keys() if not k.startswith("__")][0]y_key = [k for k in Y.keys() if not k.startswith("__")][0]return X[x_key], Y[y_key]X_img, Y_img = load_data()h, w, bands = X_img.shapeprint(f"數據尺寸: {h}×{w}, 波段: {bands}")# ---- 有標簽索引 + 分層劃分 ----labeled_idx_rc = np.array([(i, j) for i in range(h) for j in range(w) if Y_img[i, j] != 0])labels_all = np.array([Y_img[i, j] - 1 for i, j in labeled_idx_rc], dtype=np.int64)num_classes = len(np.unique(labels_all))print(f"有標簽樣本: {len(labeled_idx_rc)},類別數: {num_classes}")train_ids, test_ids = train_test_split(np.arange(len(labeled_idx_rc)),test_size=1 - TRAIN_RATIO, stratify=labels_all, random_state=42)# ---- 僅訓練像素擬合 Scaler+PCA(無泄露)----print("擬合 StandardScaler/PCA(僅訓練像素)...")train_pixels = np.array([X_img[i, j] for i, j in labeled_idx_rc[train_ids]], dtype=np.float32)scaler = StandardScaler().fit(train_pixels)pca = PCA(n_components=PCA_DIM, random_state=42).fit(scaler.transform(train_pixels))# 整圖進入同一空間(float32)X_pca_img = pca.transform(scaler.transform(X_img.reshape(-1, bands).astype(np.float32))).astype(np.float32)X_pca_img = X_pca_img.reshape(h, w, PCA_DIM)# ---- 提取訓練/測試 patch ----def extract_patches(sel_ids):m = PATCH_SIZE // 2padded = np.pad(X_pca_img, ((m, m), (m, m), (0, 0)), mode='reflect')patches = np.empty((len(sel_ids), PATCH_SIZE, PATCH_SIZE, PCA_DIM), dtype=np.float32)labs = np.empty((len(sel_ids),), dtype=np.int64)for n, k in enumerate(sel_ids):i, j = labeled_idx_rc[k]patches[n] = padded[i:i + PATCH_SIZE, j:j + PATCH_SIZE, :]labs[n] = labels_all[k]return patches, labsX_train, y_train = extract_patches(train_ids)X_test, y_test = extract_patches(test_ids)# ---- DataLoader ----train_loader = DataLoader(HSIPatchDataset(X_train, y_train), batch_size=BATCH_SIZE,shuffle=True, num_workers=NUM_WORKERS, pin_memory=PIN_MEMORY)test_loader = DataLoader(HSIPatchDataset(X_test, y_test), batch_size=BATCH_SIZE,shuffle=False, num_workers=NUM_WORKERS, pin_memory=PIN_MEMORY)# ---- 模型與優化器 ----model = MobileNetV2_HSI(PCA_DIM, num_classes).to(device)optimizer = torch.optim.Adam(model.parameters(), lr=LR, weight_decay=WEIGHT_DECAY)criterion = nn.CrossEntropyLoss()scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode='max', factor=0.5, patience=3)# ---- 評估函數 ----@torch.no_grad()def evaluate(loader):model.eval()all_y, all_pred = [], []for xb, yb in loader:xb = xb.to(device)pred = model(xb).argmax(dim=1).cpu().numpy()all_pred.extend(pred); all_y.extend(yb.numpy())return accuracy_score(all_y, all_pred), np.array(all_y), np.array(all_pred)# ---- 訓練循環 ----print("開始訓練...")best_acc, model_path = 0.0, "best_mnv2_hsi.pth"for epoch in range(1, EPOCHS + 1):model.train(); total_loss = 0.0for xb, yb in train_loader:xb, yb = xb.to(device), yb.to(device)optimizer.zero_grad()loss = criterion(model(xb), yb)loss.backward(); optimizer.step()total_loss += loss.item() * xb.size(0)test_acc, _, _ = evaluate(test_loader)scheduler.step(test_acc)print(f"Epoch {epoch:02d}/{EPOCHS} | 損失: {total_loss/len(train_loader.dataset):.4f} | 測試準確率: {test_acc:.4f}")if test_acc > best_acc:best_acc = test_acctorch.save(model.state_dict(), model_path)print(f"訓練完成,最佳測試準確率:{best_acc:.4f}")# ---- 安全加載最佳權重 ----try:state = torch.load(model_path, map_location=device, weights_only=True)except TypeError:state = torch.load(model_path, map_location=device)model.load_state_dict(state)# ---- 測試報告 & 混淆矩陣 ----test_acc, y_true, y_pred = evaluate(test_loader)print("\n測試集分類報告:")print(classification_report(y_true, y_pred, digits=4, zero_division=0))plt.figure(figsize=(10, 7))class_names = [f"類{i + 1}" for i in range(num_classes)]sns.heatmap(confusion_matrix(y_true, y_pred),annot=True, fmt='d', cmap="Blues",xticklabels=class_names, yticklabels=class_names,cbar=False, square=True)plt.xlabel("預測標簽"); plt.ylabel("真實標簽")plt.title("MobileNetV2 測試集混淆矩陣", fontsize=14, weight='bold')plt.tight_layout(); plt.show(block=True)# ---- 全圖預測 ----print("全圖預測中(坐標→收集 patch→堆成 batch→前向)...")pred_map = predict_full_image_by_coords(model, X_pca_img, patch_size=PATCH_SIZE, device=device,batch_size=PREDICT_BATCH_SIZE, title="MobileNetV2 全圖預測(坐標批推理)")print(f"預測圖統計: min={pred_map.min()}, max={pred_map.max()}, mean={pred_map.mean():.3f}")print("完成。")

3.7 Windows 入口保護(多進程/顯示更穩)

if __name__ == "__main__":try:import multiprocessing as mpmp.set_start_method("spawn", force=True)mp.freeze_support()except Exception:passmain()

4. 結果展示

在這里插入圖片描述
在這里插入圖片描述
歡迎大家關注下方我的公眾獲取更多內容!

本文來自互聯網用戶投稿,該文觀點僅代表作者本人,不代表本站立場。本站僅提供信息存儲空間服務,不擁有所有權,不承擔相關法律責任。
如若轉載,請注明出處:http://www.pswp.cn/pingmian/93193.shtml
繁體地址,請注明出處:http://hk.pswp.cn/pingmian/93193.shtml
英文地址,請注明出處:http://en.pswp.cn/pingmian/93193.shtml

如若內容造成侵權/違法違規/事實不符,請聯系多彩編程網進行投訴反饋email:809451989@qq.com,一經查實,立即刪除!

相關文章

第4節 神經網絡從公式簡化到卷積神經網絡(CNN)的進化之路

?? 深度學習的"玄學進化史" 從CNN用卷積層池化層處理圖片,循環網絡RNN如何利用上下文處理序列數據,到注意力機制讓Transformer橫空出世,現在的大語言模型已經能寫能畫能決策!每個新技巧都讓人驚呼"還能這么玩",難怪說深度學習像玄學——但這玄學,…

最新去水印小程序系統 前端+后端全套源碼 多套模版 免授權(源碼下載)

最新去水印小程序系統 前端后端全套源碼 多套模版 免授權 源碼下載:https://download.csdn.net/download/m0_66047725/91669468 更多資源下載:關注我

TCP Socket 編程實戰:實現簡易英譯漢服務

前言:TCP(傳輸控制協議)是一種面向連接、可靠的流式傳輸協議,與 UDP 的無連接特性不同,它通過三次握手建立連接、四次揮手斷開連接,提供數據確認、重傳機制,保證數據有序且完整傳輸。本文將基于…

CF566C Logistical Questions Solution

Description 給定一棵 nnn 個點的樹 TTT,點有點權 aia_iai?,邊有邊權 www. 定義 dist?(u,v)\operatorname{dist}(u,v)dist(u,v) 為 u→vu\to vu→v 的簡單路徑上的邊權和. 找到一個節點 uuu,使得 W∑i1ndist?(u,i)32aiW\sum\limits_{i1}^n…

聊天室全棧開發-保姆級教程(Node.js+Websocket+Redis+HTML+CSS)

前言 最近在學習websocket全雙工通信,想要做一個聯機小游戲,做游戲之前先做一個聊天室練練手。 跟著本篇博客,可以從0搭建一個屬于你自己的聊天室。 準備階段 什么人適合學習本篇文章? 答:前端開發者,有一…

后臺管理系統-2-vue3之路由配置和Main組件的初步搭建布局

文章目錄1 路由搭建1.1 路由創建(router/index.js)1.2 路由組件(views/Main.vue)1.3 路由引入并注冊(main.js)1.4 路由渲染(App.vue)2 element-plus的應用2.1 完整引入并注冊(main.js)2.2 示例應用(App.vue)3 ElementPlusIconsVue的應用3.1 圖標引入并注冊(main.js)3.2 示例應用…

使用 Let’s Encrypt 免費申請泛域名 SSL 證書,并實現自動續期

使用 Let’s Encrypt 免費申請泛域名 SSL 證書,并實現自動續期 目錄 使用 Let’s Encrypt 免費申請泛域名 SSL 證書,并實現自動續期 🛠? 環境準備💡 什么是 Let’s Encrypt?🧠 Let’s Encrypt 證書頒發原…

一鍵自動化:Kickstart無人值守安裝指南

Kickstart文件實現自動安裝1. Kickstart文件概述1.1 定義與作用Kickstart文件是Red Hat系Linux發行版(如RHEL、CentOS、Fedora)用于實現自動化安裝的配置文件,采用純文本格式保存。它通過預設安裝參數的方式,使系統安裝過程無需人…

深度解讀 Browser-Use:讓 AI 驅動瀏覽器自動化成為可能

目錄 一、什么是 Browser-Use? 二、Browser-Use 的核心功能 1. AI 與瀏覽器的鏈接橋梁 2. 無代碼 / 低代碼操作界面 3. 支持多家 LLM 4. 開發體驗簡潔 可快速上手 三、核心價值與適用場景 四、與 Playwright 的結合使用 五、總結與展望 https://github.com…

React.memo、useMemo 和 React.PureComponent的區別

useMemo 和 React.memo 都是 React 提供的性能優化工具,但它們的作用和使用場景有顯著不同。以下是兩者的全面對比: 一、核心區別總結特性useMemoReact.memo類型React Hook高階組件(HOC)作用對象緩存計算結果緩存組件渲染結果優化目標避免重復計算避免不…

Lumerical INTERCONNECT ------ CW Laser 和 OPWM 組成的系統

Lumerical INTERCONNECT ------ CW Laser 和 OPWM 組成的系統 引言 正文 引言 這里我們來簡單介紹一下 CW Laser 與 OSA 組成的簡單系統結構的仿真。 正文 我們構建一個如下圖所示的仿真結構。 我們將 CWL 中的 power 設置為 1 W。 然后直接運行仿真查看結果如下: 雖然 …

想漲薪30%?別只盯著大廠了!轉型AI產品經理的3個通用方法,人人都能學!

在AI產品經理剛成為互聯網公司香餑餑的時候,剛做產品1年的月月就規劃了自己的轉型計劃,然后用3個月時間成功更換賽道,轉戰AI產品經理,漲薪30%。 問及她有什么上岸秘訣?她也復盤總結了3個踩坑經驗和正確路徑&#xff0c…

基于Hadoop的全國農產品批發價格數據分析與可視化與價格預測研究

文章目錄有需要本項目的代碼或文檔以及全部資源,或者部署調試可以私信博主項目介紹每文一語有需要本項目的代碼或文檔以及全部資源,或者部署調試可以私信博主 項目介紹 隨著我國農業數字化進程的加快,農產品批發市場每天都會產生海量的價格…

STM32在使用DMA發送和接收時的模式區別

在STM32的DMA傳輸中,發送使用DMA_Mode_Normal而接收使用DMA_Mode_Circular的設計基于以下關鍵差異:1. ?觸發機制的本質區別??發送方向(TX)?:由USART的?TXE標志(發送寄存器空)觸發?&#x…

【秋招筆試】2025.08.15餓了么秋招機考-第三題

?? 點擊直達筆試專欄 ??《大廠筆試突圍》 ?? 春秋招筆試突圍在線OJ ?? 筆試突圍在線刷題 bishipass.com 03. A先生的商貿網絡投資 問題描述 A先生是一位精明的商人,他計劃在 n n n 個城市之間建立商貿網絡。目前有 m m

Socket 套接字的學習--UDP

上次我們大概介紹了一些關于網絡的基礎知識,這次我們利用編程來深入學習一下一:套接字Socket1.1什么是Socketsocket API 是一層抽象的網絡編程接口,適用于各種底層網絡協議,如 IPv4、IPv6,. 然而, 各種網絡協議的地址格式并不相同。1.2套接字的分類套接字…

AI - MCP 協議(一)

AI應用開發的高級特性——MCP模型上下文協議,打通AI與外部服務的邊界。 ************************************************************************************************************** 一、需求分析 當你的AI具備了RAG的能力,具備了調用工具的…

在es中安裝kibana

一 安裝 1.1 驗證訪問https的連通性 # 測試 80 端口(HTTP) curl -I -m 5 http://目標IP:端口號 說明: -I:僅獲取 HTTP 頭部(Head 請求),不下載正文,減少數據傳輸。 -m 5&#x…

嵌入式開發學習———Linux環境下網絡編程學習(二)

UDP服務器客戶端搭建UDP服務器代碼#include <stdio.h> #include <stdlib.h> #include <string.h> #include <sys/socket.h> #include <netinet/in.h>#define PORT 8080 #define BUFFER_SIZE 1024int main() {int sockfd;char buffer[BUFFER_SIZE…

UVa1465/LA4841 Searchlights

UVa12345 UVa1465/LA4841 Searchlights題目鏈接題意輸入格式輸出格式分析AC 代碼題目鏈接 本題是2010年icpc亞洲區域賽杭州賽區的I題 題意 在一個 n 行 m 列&#xff08;n≤100&#xff0c;m≤10 000&#xff09;的網格中有一些探照燈&#xff0c;每個探照燈有一個最大亮度 k&…