系列回顧:
(一)CNN 基礎:高光譜圖像分類可視化全流程
(二)HybridNet(CNN+Transformer):提升全局感受野
(三)GCN 入門實戰:基于光譜 KNN 的圖卷積分類與全圖預測
(四)空間–光譜聯合構圖的 GCN:RBF 邊權 + 自環 + 早停,得到更穩更自然的全圖分類結果
(五)GAT & 構圖消融 + 分塊全圖預測:更穩更快的高光譜圖分類(PyTorch Geometric 實戰)
合集鏈接:https://mp.weixin.qq.com/mp/appmsgalbum?__biz=MzkwMTE0MjI4NQ==&action=getalbum&album_id=4007114522736459789#wechat_redirect
本篇(六):聚焦“數據泄露”,采用僅訓練集像素擬合 StandardScaler+PCA,并在全圖預測中共享同一變換空間;模型選用輕量化 MobileNetV2 的深度可分離卷積,在顯存友好的坐標批推理下實現全圖預測。
0. 前言:PCA 與高光譜分類中的“數據泄露”
- 什么是泄露? 訓練階段直接/間接使用了測試數據統計信息(均值、方差、主成分方向等)。
- 怎么產生? 在整圖上
fit
標準化與 PCA,然后再切訓練/測試或直接做全圖分類。 - 為什么常見? 歷史習慣、樣本少時穩定性考慮、對比研究圖省事、實現方便。
- 影響大嗎? 小數據集上常為 0.1%~1% 的 OA 差異;但在類分布差異大或訓練樣本極少時,差距可達數個百分點。真實部署場景絕不允許整圖
fit
。
本文做法:先分層抽樣得到訓練/測試索引 → 僅用訓練像素
fit
標準化與 PCA → 用該變換對整圖 transform → 訓練與預測均在同一(訓練集擬合得到的)特征空間,從源頭避免泄露。
1. 任務要點
- 嚴格無泄露預處理:只用訓練像素擬合
StandardScaler+PCA
;全圖在同一變換空間中變換。 - 輕量模型:用 MobileNetV2 的深度可分離卷積(3 段)+ GAP + FC。
- 全圖預測顯存友好:按坐標批收集 patch → 堆成 batch → 前向推理。
- 評估:
classification_report
、混淆矩陣、OA;可視化支持 Windows 阻塞顯示。
2. 方法詳解
2.1 嚴格無泄露的 PCA 流程
- 先劃分后擬合:對有標簽像素做分層抽樣得到訓練/測試索引;僅訓練像素擬合
StandardScaler
和PCA
。 - 全圖共享空間:將整圖
(H×W×Bands)
用訓練集擬合的變換進行標準化與降維,得到(H×W×PCA_DIM)
。 - 提取 patch:在 PCA 空間內按坐標提取
(PATCH_SIZE×PATCH_SIZE×PCA_DIM)
的 patch 作為輸入。
這樣做的關鍵:測試像素從未參與統計,評估更可信。
2.2 輕量化 MobileNetV2(HSI 版)
- Depthwise Separable Conv:逐通道 3×3 深度卷積 + 1×1 點卷積,大幅降參與算力需求。
- 網絡骨干:3 段深度可分離卷積 → GAP(自適應全局平均池化)→ 全連接輸出。
- 輸入通道:這里輸入為 PCA 后的通道數(例如 30),以二維 patch 形式輸入(
C×H×W
)。
2.3 全圖預測策略(坐標批)
- 坐標遍歷:生成所有像素坐標。
- 反射填充:邊界像素也能提取完整 patch。
- 批量收集:按 batch_size 組裝 patch → 前向 → 填回
pred_map
。 - 顯存穩定:避免一次性張量過大導致溢出。
3. 代碼逐段 + 解釋
下面先按邏輯分段展示與解釋;最末提供“一鍵可跑腳本(整合版)”,復制后僅需修改數據路徑即可運行。
3.1 全局與可視化設置
import os, time, numpy as np, scipy.io as sio
import torch, torch.nn as nn
from torch.utils.data import Dataset, DataLoader
from sklearn.decomposition import PCA
from sklearn.preprocessing import StandardScaler
from sklearn.metrics import classification_report, confusion_matrix, accuracy_score
from sklearn.model_selection import train_test_split
import matplotlib# Windows 下 TkAgg 更穩;Linux/服務器用 Agg(無顯示)
if os.name == 'nt':matplotlib.use('TkAgg')
else:matplotlib.use('Agg')import matplotlib.pyplot as plt
import seaborn as snsmatplotlib.rcParams['font.family'] = 'SimHei'
matplotlib.rcParams['axes.unicode_minus'] = False
plt.rcParams['figure.dpi'] = 120
sns.set_theme(context="notebook", style="whitegrid", font="SimHei")torch.backends.cudnn.benchmark = True
3.2 隨機種子
def set_seeds(seed=42):import randomrandom.seed(seed)np.random.seed(seed)torch.manual_seed(seed)torch.cuda.manual_seed_all(seed)
固定隨機性,保證復現。
3.3 輕量化網絡
class DepthwiseSeparableConv(nn.Module):def __init__(self, in_ch, out_ch, stride=1):super().__init__()self.depthwise = nn.Conv2d(in_ch, in_ch, 3, stride=stride, padding=1, groups=in_ch, bias=False)self.pointwise = nn.Conv2d(in_ch, out_ch, 1, bias=False)self.bn = nn.BatchNorm2d(out_ch)self.act = nn.ReLU6(inplace=True)def forward(self, x):x = self.depthwise(x)x = self.pointwise(x)x = self.bn(x)return self.act(x)class MobileNetV2_HSI(nn.Module):def __init__(self, in_ch, num_classes, width_mult=1.0):super().__init__()c1, c2, c3 = int(32 * width_mult), int(64 * width_mult), int(128 * width_mult)self.layer1 = DepthwiseSeparableConv(in_ch, c1)self.layer2 = DepthwiseSeparableConv(c1, c2)self.layer3 = DepthwiseSeparableConv(c2, c3)self.gap = nn.AdaptiveAvgPool2d(1)self.fc = nn.Linear(c3, num_classes)def forward(self, x):x = self.layer1(x)x = self.layer2(x)x = self.layer3(x)x = self.gap(x).flatten(1)return self.fc(x)
結構極簡但實用:3 段深度可分離卷積 + GAP + FC,對 HSI 的小樣本/低算力場景友好。
3.4 數據集與 Patch 封裝
class HSIPatchDataset(Dataset):def __init__(self, patches, labels):# patches: (N, H, W, C) → 張量 (N, C, H, W)self.X = torch.tensor(patches, dtype=torch.float32).permute(0, 3, 1, 2)self.y = torch.tensor(labels, dtype=torch.long)def __len__(self): return len(self.y)def __getitem__(self, idx): return self.X[idx], self.y[idx]
3.5 全圖預測(坐標批推理)
@torch.inference_mode()
def predict_full_image_by_coords(model, X_img_pca, patch_size, device,batch_size=2048, title="全圖預測(坐標批推理)"):model.eval()H, W, C = X_img_pca.shapem = patch_size // 2padded = np.pad(X_img_pca, ((m, m), (m, m), (0, 0)), mode='reflect')coords = np.mgrid[0:H, 0:W].reshape(2, -1).Tpred_map = np.zeros((H, W), dtype=np.int32)t0 = time.time()for i in range(0, len(coords), batch_size):batch_coords = coords[i:i + batch_size]patches = np.empty((len(batch_coords), patch_size, patch_size, C), dtype=np.float32)for k, (r, c) in enumerate(batch_coords):patches[k] = padded[r:r + patch_size, c:c + patch_size, :]tensor = torch.from_numpy(patches).permute(0, 3, 1, 2).to(device)preds = model(tensor).argmax(dim=1).cpu().numpy() + 1 # +1 便于和 GT 對齊for (r, c), p in zip(batch_coords, preds):pred_map[r, c] = pprint(f"全圖預測耗時:{time.time() - t0:.2f} 秒")# 可視化(阻塞顯示,避免“最后的圖沒有顯示”)try:plt.figure(figsize=(10, 7.5))cmap = matplotlib.colormaps.get_cmap('tab20')vmin, vmax = pred_map.min(), pred_map.max()if vmin == vmax: vmin, vmax = 0, 1im = plt.imshow(pred_map, cmap=cmap, interpolation='nearest', vmin=vmin, vmax=vmax)cbar = plt.colorbar(im, shrink=0.85); cbar.set_label('預測類別', rotation=90)plt.title(title, fontsize=14, weight='bold'); plt.axis('off'); plt.tight_layout()print("嘗試顯示全圖預測結果...")plt.show(block=True)except Exception as e:print(f"顯示全圖預測結果時出錯: {e}")try:plt.savefig("prediction_map.png", bbox_inches='tight')print("已保存為 prediction_map.png")except Exception as se:print(f"保存失敗: {se}")return pred_map
3.6 主流程(數據→劃分→無泄露預處理→訓練→評估→全圖預測)
下面是主函數的關鍵片段(末尾附完整可運行腳本):
def main():set_seeds(42)device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')print(f"使用設備: {device}")# ---- 路徑與超參(按需修改)----DATA_DIR = r"your_path"X_FILE, Y_FILE = "KSC.mat", "KSC_gt.mat"PCA_DIM, PATCH_SIZE, TRAIN_RATIO = 30, 5, 0.30BATCH_SIZE, EPOCHS, LR, WEIGHT_DECAY = 64, 30, 1e-3, 1e-4NUM_WORKERS = 0 if os.name == 'nt' else min(4, os.cpu_count() or 0)PIN_MEMORY = (device.type == 'cuda')PREDICT_BATCH_SIZE = 4096# ---- 讀取數據 ----def load_data():X = sio.loadmat(os.path.join(DATA_DIR, X_FILE))Y = sio.loadmat(os.path.join(DATA_DIR, Y_FILE))x_key = [k for k in X.keys() if not k.startswith("__")][0]y_key = [k for k in Y.keys() if not k.startswith("__")][0]return X[x_key], Y[y_key]X_img, Y_img = load_data()h, w, bands = X_img.shapeprint(f"數據尺寸: {h}×{w}, 波段: {bands}")# ---- 有標簽索引 + 分層劃分 ----labeled_idx_rc = np.array([(i, j) for i in range(h) for j in range(w) if Y_img[i, j] != 0])labels_all = np.array([Y_img[i, j] - 1 for i, j in labeled_idx_rc], dtype=np.int64)num_classes = len(np.unique(labels_all))print(f"有標簽樣本: {len(labeled_idx_rc)},類別數: {num_classes}")train_ids, test_ids = train_test_split(np.arange(len(labeled_idx_rc)),test_size=1 - TRAIN_RATIO, stratify=labels_all, random_state=42)# ---- 僅訓練像素擬合 Scaler+PCA(無泄露)----print("擬合 StandardScaler/PCA(僅訓練像素)...")train_pixels = np.array([X_img[i, j] for i, j in labeled_idx_rc[train_ids]], dtype=np.float32)scaler = StandardScaler().fit(train_pixels)pca = PCA(n_components=PCA_DIM, random_state=42).fit(scaler.transform(train_pixels))# 整圖進入同一空間(float32)X_pca_img = pca.transform(scaler.transform(X_img.reshape(-1, bands).astype(np.float32))).astype(np.float32)X_pca_img = X_pca_img.reshape(h, w, PCA_DIM)# ---- 提取訓練/測試 patch ----def extract_patches(sel_ids):m = PATCH_SIZE // 2padded = np.pad(X_pca_img, ((m, m), (m, m), (0, 0)), mode='reflect')patches = np.empty((len(sel_ids), PATCH_SIZE, PATCH_SIZE, PCA_DIM), dtype=np.float32)labs = np.empty((len(sel_ids),), dtype=np.int64)for n, k in enumerate(sel_ids):i, j = labeled_idx_rc[k]patches[n] = padded[i:i + PATCH_SIZE, j:j + PATCH_SIZE, :]labs[n] = labels_all[k]return patches, labsX_train, y_train = extract_patches(train_ids)X_test, y_test = extract_patches(test_ids)# ---- DataLoader ----train_loader = DataLoader(HSIPatchDataset(X_train, y_train), batch_size=BATCH_SIZE,shuffle=True, num_workers=NUM_WORKERS, pin_memory=PIN_MEMORY)test_loader = DataLoader(HSIPatchDataset(X_test, y_test), batch_size=BATCH_SIZE,shuffle=False, num_workers=NUM_WORKERS, pin_memory=PIN_MEMORY)# ---- 模型與優化器 ----model = MobileNetV2_HSI(PCA_DIM, num_classes).to(device)optimizer = torch.optim.Adam(model.parameters(), lr=LR, weight_decay=WEIGHT_DECAY)criterion = nn.CrossEntropyLoss()scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode='max', factor=0.5, patience=3)# ---- 評估函數 ----@torch.no_grad()def evaluate(loader):model.eval()all_y, all_pred = [], []for xb, yb in loader:xb = xb.to(device)pred = model(xb).argmax(dim=1).cpu().numpy()all_pred.extend(pred); all_y.extend(yb.numpy())return accuracy_score(all_y, all_pred), np.array(all_y), np.array(all_pred)# ---- 訓練循環 ----print("開始訓練...")best_acc, model_path = 0.0, "best_mnv2_hsi.pth"for epoch in range(1, EPOCHS + 1):model.train(); total_loss = 0.0for xb, yb in train_loader:xb, yb = xb.to(device), yb.to(device)optimizer.zero_grad()loss = criterion(model(xb), yb)loss.backward(); optimizer.step()total_loss += loss.item() * xb.size(0)test_acc, _, _ = evaluate(test_loader)scheduler.step(test_acc)print(f"Epoch {epoch:02d}/{EPOCHS} | 損失: {total_loss/len(train_loader.dataset):.4f} | 測試準確率: {test_acc:.4f}")if test_acc > best_acc:best_acc = test_acctorch.save(model.state_dict(), model_path)print(f"訓練完成,最佳測試準確率:{best_acc:.4f}")# ---- 安全加載最佳權重 ----try:state = torch.load(model_path, map_location=device, weights_only=True)except TypeError:state = torch.load(model_path, map_location=device)model.load_state_dict(state)# ---- 測試報告 & 混淆矩陣 ----test_acc, y_true, y_pred = evaluate(test_loader)print("\n測試集分類報告:")print(classification_report(y_true, y_pred, digits=4, zero_division=0))plt.figure(figsize=(10, 7))class_names = [f"類{i + 1}" for i in range(num_classes)]sns.heatmap(confusion_matrix(y_true, y_pred),annot=True, fmt='d', cmap="Blues",xticklabels=class_names, yticklabels=class_names,cbar=False, square=True)plt.xlabel("預測標簽"); plt.ylabel("真實標簽")plt.title("MobileNetV2 測試集混淆矩陣", fontsize=14, weight='bold')plt.tight_layout(); plt.show(block=True)# ---- 全圖預測 ----print("全圖預測中(坐標→收集 patch→堆成 batch→前向)...")pred_map = predict_full_image_by_coords(model, X_pca_img, patch_size=PATCH_SIZE, device=device,batch_size=PREDICT_BATCH_SIZE, title="MobileNetV2 全圖預測(坐標批推理)")print(f"預測圖統計: min={pred_map.min()}, max={pred_map.max()}, mean={pred_map.mean():.3f}")print("完成。")
3.7 Windows 入口保護(多進程/顯示更穩)
if __name__ == "__main__":try:import multiprocessing as mpmp.set_start_method("spawn", force=True)mp.freeze_support()except Exception:passmain()
4. 結果展示
歡迎大家關注下方我的公眾獲取更多內容!