OK, let's focus on the first breakthrough point:
learning the distribution of good samples more efficiently, via methods along the lines of data distillation or active-learning-style sampling.
Here is a complete code example:
Masked image reconstruction + residual heatmap
This is a variant of self-supervised distillation:
- A pretrained MAE-style model (or a lightweight ViT) performs masked reconstruction on normal samples
- The residual between the reconstruction and the original image reflects the "degree of anomaly"
Environment dependencies
pip install timm einops torchvision matplotlib
Complete code (using MVTec images as an example)
import torch
import torch.nn as nn
import torchvision.transforms as T
from torchvision.utils import save_image
from torchvision.datasets.folder import default_loader
from einops import rearrange
import timm
import matplotlib.pyplot as plt
import os
from glob import glob
from PIL import Image
import numpy as np

# ---------------------------
# Model definition: ViT encoder + simple decoder
# ---------------------------
class MAE(nn.Module):
    def __init__(self, encoder_name='vit_base_patch16_224', mask_ratio=0.4):
        super().__init__()
        self.encoder = timm.create_model(encoder_name, pretrained=True)
        self.mask_ratio = mask_ratio
        self.patch_size = self.encoder.patch_embed.patch_size[0]
        self.num_patches = self.encoder.patch_embed.num_patches
        self.embed_dim = self.encoder.embed_dim
        self.decoder = nn.Sequential(
            nn.Linear(self.embed_dim, self.embed_dim),
            nn.GELU(),
            nn.Linear(self.embed_dim, self.patch_size ** 2 * 3)
        )

    def forward(self, x):
        B, C, H, W = x.shape
        x_patch = self.encoder.patch_embed(x)  # [B, num_patches, dim]
        B, N, D = x_patch.shape
        # Random masking: keep (1 - mask_ratio) of the patches
        rand_idx = torch.rand(B, N, device=x.device).argsort(dim=1)
        num_keep = int(N * (1 - self.mask_ratio))
        keep_idx = rand_idx[:, :num_keep]
        x_keep = torch.gather(x_patch, 1, keep_idx.unsqueeze(-1).expand(-1, -1, D))
        x_encoded = self.encoder.blocks(x_keep)
        x_decoded = self.decoder(x_encoded)
        # Scatter the reconstructed patches back to their original positions
        # (only the kept patches are reconstructed; masked positions stay zero)
        output = torch.zeros(B, N, self.patch_size ** 2 * 3, device=x.device)
        output.scatter_(1, keep_idx.unsqueeze(-1).expand(-1, -1, self.patch_size ** 2 * 3), x_decoded)
        output = rearrange(
            output, 'b (h w) (p1 p2 c) -> b c (h p1) (w p2)',
            p1=self.patch_size, p2=self.patch_size, c=3,
            h=H // self.patch_size, w=W // self.patch_size
        )
        return output

# ---------------------------
# Data loading + preprocessing
# ---------------------------
transform = T.Compose([
    T.Resize((224, 224)),
    T.ToTensor(),
    T.Normalize([0.5] * 3, [0.5] * 3)
])
inv_transform = T.Compose([
    T.Normalize(mean=[-1] * 3, std=[2] * 3)
])

def load_images(path):
    files = sorted(glob(os.path.join(path, '*.png')) + glob(os.path.join(path, '*.jpg')))
    images = []
    for f in files:
        img = default_loader(f)
        images.append(transform(img))
    return torch.stack(images)

# ---------------------------
# Test image -> reconstruction -> residual heatmap
# ---------------------------
def visualize_anomaly(original, recon, save_path='result.png'):
    residual = (original - recon).abs().sum(dim=1, keepdim=True)
    residual = residual / residual.max()
    fig, axs = plt.subplots(1, 3, figsize=(12, 4))
    axs[0].imshow(inv_transform(original[0]).clamp(0, 1).permute(1, 2, 0).cpu().numpy())
    axs[0].set_title('Original')
    axs[1].imshow(inv_transform(recon[0]).clamp(0, 1).permute(1, 2, 0).cpu().numpy())
    axs[1].set_title('Reconstruction')
    axs[2].imshow(residual[0, 0].cpu().numpy(), cmap='hot')
    axs[2].set_title('Anomaly Map')
    for ax in axs:
        ax.axis('off')
    plt.tight_layout()
    plt.savefig(save_path)
    plt.close()

# ---------------------------
# Main
# ---------------------------
if __name__ == '__main__':
    device = 'cuda' if torch.cuda.is_available() else 'cpu'
    model = MAE().to(device)
    model.eval()
    # Replace with any MVTec / VisA category path; load good images only
    image_dir = './mvtec/bottle/good/'
    images = load_images(image_dir).to(device)
    with torch.no_grad():
        for i in range(min(5, len(images))):
            input_img = images[i:i + 1]
            recon_img = model(input_img)
            visualize_anomaly(input_img, recon_img, f'output_{i}.png')
Example output (saved as output_0.png, etc.):
- Left: original image
- Middle: reconstruction (the model's notion of a "good" part)
- Right: anomaly heatmap (residual)

On normal samples the residual map should stay close to 0; on an anomalous image, the defective region produces a high response.
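To turn the residual heatmap into a single go/no-go decision, here is a minimal image-level scoring sketch; it assumes the `MAE` model and `load_images` defined above, and the helper name `image_anomaly_score` plus the 3-sigma threshold are illustrative choices that would need calibration on held-out good images:

```python
import torch

@torch.no_grad()
def image_anomaly_score(model, img_batch):
    # One score per image: mean absolute residual between input and reconstruction.
    recon = model(img_batch)
    residual = (img_batch - recon).abs().sum(dim=1)  # [B, H, W]
    return residual.mean(dim=(1, 2))

# Illustrative calibration on good images only:
# scores = image_anomaly_score(model, images)        # images from load_images(...)
# threshold = scores.mean() + 3 * scores.std()
# a new image is then flagged as anomalous if its score exceeds `threshold`
```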
Possible extensions

Module | Possible optimization
--- | ---
Encoder | Swap in a lighter ViT (e.g. vit_tiny_patch16_224)
Mask strategy | Use structured masking (e.g. block masking) to make reconstruction more challenging
Anomalous images | Feed defect images (e.g. from the MVTec test split) to check generalization
Training | Fine-tune with a reconstruction loss to model good samples more precisely (a minimal sketch follows the table)
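For the training row, here is a minimal fine-tuning sketch against the `MAE` class above, assuming `images` is the tensor returned by `load_images` on a good-only folder; the optimizer, batch size, learning rate, and epoch count are placeholder choices rather than tuned values:

```python
import torch
import torch.nn.functional as F
from torch.utils.data import DataLoader, TensorDataset

def finetune_on_good_samples(model, images, epochs=10, lr=1e-4, device='cuda'):
    # Fine-tune with a plain MSE reconstruction loss on good images only.
    # With the forward pass above, masked positions stay zero, so only the
    # visible patches actually drive the gradient.
    loader = DataLoader(TensorDataset(images), batch_size=8, shuffle=True)
    optimizer = torch.optim.AdamW(model.parameters(), lr=lr)
    model.train()
    for epoch in range(epochs):
        total = 0.0
        for (batch,) in loader:
            batch = batch.to(device)
            recon = model(batch)
            loss = F.mse_loss(recon, batch)  # reconstruction error on good samples
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            total += loss.item()
        print(f'epoch {epoch}: recon loss = {total / len(loader):.4f}')
    model.eval()
    return model

# e.g. model = finetune_on_good_samples(model, load_images('./mvtec/bottle/good/'), device=device)
```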
If you'd like, I can extend this further into:
- a version that can be quickly corrected with a small number of anomalous images;
- or one that adds an active sample-selection mechanism (a minimal sketch of that idea follows below).
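On the active sample-selection point, here is a minimal k-center-greedy sketch over pooled encoder features, so that only the most diverse good samples are kept for training or fine-tuning; the helper name `select_coreset`, the batch size, and `n_select` are illustrative, and the feature path simply reuses the encoder from the `MAE` class above:

```python
import torch

@torch.no_grad()
def select_coreset(model, images, n_select=50, device='cuda'):
    # Greedy k-center selection: repeatedly keep the sample whose encoder
    # feature is farthest from everything already selected.
    n_select = min(n_select, len(images))
    feats = []
    for i in range(0, len(images), 16):
        batch = images[i:i + 16].to(device)
        # Same feature path as MAE.forward, but without masking
        tokens = model.encoder.blocks(model.encoder.patch_embed(batch))
        feats.append(tokens.mean(dim=1).cpu())  # one pooled vector per image
    feats = torch.cat(feats)  # [N, dim]

    selected = [0]  # start from an arbitrary sample
    dists = torch.cdist(feats, feats[selected]).min(dim=1).values
    for _ in range(n_select - 1):
        idx = int(dists.argmax())  # farthest point from the current selection
        selected.append(idx)
        dists = torch.minimum(dists, torch.cdist(feats, feats[idx:idx + 1]).squeeze(1))
    return images[selected]
```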
Great, let's continue building on the self-supervised reconstruction (MAE) approach above and wrap it in a Gradio demo for a more hands-on anomaly-detection experience (this additionally requires `pip install gradio`).
New functionality
- Upload any image (good or defective)
- Display in real time:
  - the original image
  - the model's reconstruction
  - the residual heatmap (high response = anomalous region)
Complete code (with Gradio UI)
import torch
import torch.nn as nn
import torchvision.transforms as T
from torchvision.utils import save_image
from torchvision.datasets.folder import default_loader
from einops import rearrange
import timm
import gradio as gr
import numpy as np
from PIL import Image
import matplotlib.pyplot as plt
import io

# ---------------------------
# Model definition (same as above)
# ---------------------------
class MAE(nn.Module):
    def __init__(self, encoder_name='vit_base_patch16_224', mask_ratio=0.4):
        super().__init__()
        self.encoder = timm.create_model(encoder_name, pretrained=True)
        self.mask_ratio = mask_ratio
        self.patch_size = self.encoder.patch_embed.patch_size[0]
        self.num_patches = self.encoder.patch_embed.num_patches
        self.embed_dim = self.encoder.embed_dim
        self.decoder = nn.Sequential(
            nn.Linear(self.embed_dim, self.embed_dim),
            nn.GELU(),
            nn.Linear(self.embed_dim, self.patch_size ** 2 * 3)
        )

    def forward(self, x):
        B, C, H, W = x.shape
        x_patch = self.encoder.patch_embed(x)
        B, N, D = x_patch.shape
        rand_idx = torch.rand(B, N, device=x.device).argsort(dim=1)
        num_keep = int(N * (1 - self.mask_ratio))
        keep_idx = rand_idx[:, :num_keep]
        x_keep = torch.gather(x_patch, 1, keep_idx.unsqueeze(-1).expand(-1, -1, D))
        x_encoded = self.encoder.blocks(x_keep)
        x_decoded = self.decoder(x_encoded)
        output = torch.zeros(B, N, self.patch_size ** 2 * 3, device=x.device)
        output.scatter_(1, keep_idx.unsqueeze(-1).expand(-1, -1, self.patch_size ** 2 * 3), x_decoded)
        output = rearrange(
            output, 'b (h w) (p1 p2 c) -> b c (h p1) (w p2)',
            p1=self.patch_size, p2=self.patch_size, c=3,
            h=H // self.patch_size, w=W // self.patch_size
        )
        return output

# ---------------------------
# Pre- and post-processing
# ---------------------------
transform = T.Compose([
    T.Resize((224, 224)),
    T.ToTensor(),
    T.Normalize([0.5] * 3, [0.5] * 3)
])
inv_transform = T.Compose([
    T.Normalize(mean=[-1] * 3, std=[2] * 3)
])

def tensor_to_pil(t):
    t = inv_transform(t.squeeze(0)).clamp(0, 1)
    return T.ToPILImage()(t)

def residual_map(orig, recon):
    residual = (orig - recon).abs().sum(dim=1, keepdim=True)
    residual = residual / (residual.max() + 1e-8)
    heat = residual.squeeze().cpu().numpy()
    fig, ax = plt.subplots()
    ax.imshow(heat, cmap='hot')
    ax.axis('off')
    buf = io.BytesIO()
    plt.savefig(buf, format='png')
    plt.close(fig)
    buf.seek(0)
    return Image.open(buf)

# ---------------------------
# Inference function
# ---------------------------
device = 'cuda' if torch.cuda.is_available() else 'cpu'
model = MAE().to(device)
model.eval()

def infer(img_pil):
    # Convert to RGB in case the upload has an alpha channel
    img_tensor = transform(img_pil.convert('RGB')).unsqueeze(0).to(device)
    with torch.no_grad():
        recon = model(img_tensor)
    recon_img = tensor_to_pil(recon)
    input_img = tensor_to_pil(img_tensor)
    heatmap = residual_map(img_tensor, recon)
    return input_img, recon_img, heatmap

# ---------------------------
# Gradio UI
# ---------------------------
demo = gr.Interface(
    fn=infer,
    inputs=gr.Image(type="pil", label="Upload image"),
    outputs=[
        gr.Image(type="pil", label="Original"),
        gr.Image(type="pil", label="Reconstruction"),
        gr.Image(type="pil", label="Residual heatmap")
    ],
    title="Anomaly detection from good samples only (MAE reconstruction)",
    description="Upload an image; the model reconstructs it and produces a residual heatmap highlighting anomalous regions."
)

if __name__ == '__main__':
    demo.launch()
What to expect
You can upload the following kinds of images for on-the-fly inspection:
- Good images: the residual map should be fairly flat overall, with low response values;
- Anomalous images (e.g. scratches or damage): the defective regions clearly light up in the residual map (high response).
Suggested follow-up extensions:

Module | Possible enhancement
--- | ---
Reconstruction network | Replace with DRAEM / Reverse Distillation
Anomaly scoring | Whole-image mean residual + Otsu binarization for segmentation (see the sketch after this table)
Multi-sample comparison | Support directory upload and batch visualization
Transfer fine-tuning | Fine-tune on a small amount of target-domain data to improve domain robustness
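For the anomaly-scoring row, here is a minimal sketch combining a whole-image mean residual with Otsu binarization of the heatmap; it assumes scikit-image is installed for `threshold_otsu`, and the helper name `score_and_mask` is illustrative:

```python
import numpy as np
import torch
from skimage.filters import threshold_otsu  # assumes scikit-image is installed

@torch.no_grad()
def score_and_mask(model, img_tensor):
    # Whole-image anomaly score (mean residual) plus a binary defect mask
    # obtained by Otsu thresholding of the residual heatmap.
    recon = model(img_tensor)
    residual = (img_tensor - recon).abs().sum(dim=1)        # [1, H, W]
    heat = residual.squeeze(0).cpu().numpy()
    score = float(heat.mean())                              # image-level anomaly score
    mask = (heat > threshold_otsu(heat)).astype(np.uint8)   # 1 = suspected defect pixel
    return score, mask
```

This could be hooked into the `infer` function above to add a score readout and a mask panel to the Gradio outputs.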
For the next step, I can implement:
- residual-based anomaly scoring + binary mask output;
- fine-tuning support using a small number of anomalous samples;
- or replacing the MAE backbone with PatchCore / AnomalyCLIP.
Pick the direction you want to enhance next and I can provide the code directly.