Training Your Own Video Classification Model with PyTorchVideo

Video classification model overview

The X3D model family

Official repository

https://github.com/facebookresearch/SlowFast

Paper

Facebook Research's "X3D: Expanding Architectures for Efficient Video Recognition"

https://arxiv.org/pdf/2004.04730

How it works

X3D's design is inspired by feature-selection methods in machine learning. Starting from an X2D image-classification model, it extends 2D spatial modeling to 3D spatiotemporal modeling through stepwise expansion: along axes such as network width, depth, frame rate, frame count, and input resolution, it expands only one axis at a time, weighs compute against accuracy at each step, and keeps the expansion that performs best.

X3D expands X2D along six axes; X2D is the starting point, with all six expansion factors equal to 1.
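
To make the procedure concrete, here is a schematic sketch of that greedy expansion loop. This is not the paper's actual implementation: `train_and_eval(config)` is a hypothetical helper that trains a candidate and returns its accuracy and FLOPs, and the doubling factor is illustrative.

# Schematic sketch of X3D's stepwise expansion, assuming a hypothetical
# train_and_eval(config) -> (accuracy, flops) helper.
AXES = ["temporal", "fast", "spatial", "depth", "width", "bottleneck"]

def expand_stepwise(config, target_flops, train_and_eval, factor=2.0):
    # X2D starting point: every expansion factor in `config` is 1.
    while True:
        candidates = []
        for axis in AXES:
            trial = dict(config)
            trial[axis] *= factor  # expand exactly one axis per step
            acc, flops = train_and_eval(trial)
            candidates.append((acc, flops, trial))
        # keep the expansion with the best accuracy at comparable compute
        acc, flops, config = max(candidates, key=lambda t: t[0])
        if flops >= target_flops:  # stop once the compute budget is reached
            return config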

Expansion axes

Axis | Physical meaning | Effect of expansion
X-Temporal | number of sampled frames (clip duration) | stronger long-range temporal context (e.g., gesture recognition)
X-Fast | frame rate (shorter sampling interval) | higher temporal resolution, better capture of fast motion (e.g., decomposing sports actions)
X-Spatial | input spatial resolution (112→224) | finer detail recognition (needs more depth to enlarge the receptive field accordingly)
X-Depth | number of layers (ResNet stages) | stronger feature abstraction, matching higher-resolution input
X-Width | channel count | richer feature representation (compute ≈ channels² × resolution²)
X-Bottleneck | bottleneck-layer channel width | compute-efficiency knob: widening the inner channels balances accuracy and cost (better than widening globally)

Model accuracy and parameter counts
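
The original post embeds the official accuracy/parameter figure here (see the MODEL_ZOO in the SlowFast repository). If you want to check parameter counts yourself, a minimal sketch using the same torch.hub entry points as the training code below (the first run needs network access; weights are cached afterwards):

import torch

# Minimal sketch: load each X3D variant used below and report its size.
for variant in ("x3d_s", "x3d_m"):
    model = torch.hub.load("facebookresearch/pytorchvideo", variant, pretrained=True)
    n_params = sum(p.numel() for p in model.parameters())
    print(f"{variant}: {n_params / 1e6:.2f}M parameters")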

Data preparation

dataset_root/
├── train/                 # training set
│   ├── flow/              # class 1 (normal video stream)
│   │   ├── video1.mp4
│   │   └── video2.avi
│   └── freeze/            # class 2 (frozen video)
│       ├── video3.mp4
│       └── video4.mov
└── val/                   # validation set
    ├── flow/
    │   ├── video5.mp4
    │   └── video6.avi
    └── freeze/
        ├── video7.mp4
        └── video8.mkv
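
Before training, it is worth sanity-checking this layout. A minimal sketch (the dataset_root path is a placeholder) that counts videos per split and class, using the same extensions the training script accepts:

from pathlib import Path

# Minimal sketch: verify the directory layout above by counting videos.
dataset_root = Path("/path/to/dataset")  # placeholder; point at your data
exts = (".mp4", ".avi", ".mov", ".mkv")  # extensions the training script accepts
for split in ("train", "val"):
    for cls in ("flow", "freeze"):
        n = sum(len(list((dataset_root / split / cls).rglob(f"*{e}"))) for e in exts)
        print(f"{split}/{cls}: {n} videos")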

Training code

import os
import sys
import time
import copy
import argparse
import random
import warnings
from pathlib import Path
from typing import List, Tuple

import numpy as np
import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader, WeightedRandomSampler
from torchvision.io import read_video
from torchvision.transforms import functional as TF


# --------------------------- Utilities ---------------------------
def set_seed(seed: int = 42):
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)


def list_videos(root: Path, exts=(".mp4", ".avi", ".mov", ".mkv")) -> List[Path]:
    files = []
    for ext in exts:
        files += list(root.rglob(f"*{ext}"))
    return sorted(files)


def count_labels(samples: List[Tuple[Path, int]], num_classes: int = 2):
    counts = [0] * num_classes
    for _, y in samples:
        counts[y] += 1
    return counts


# --------------------------- Dataset ---------------------------
class VideoFolderDataset(Dataset):
    """Reads root/{split}/{class}/*.mp4
    - Uniformly samples T frames (pads with the last frame when the clip is short)
    - Train: random short-side scaling, random crop, random horizontal flip
      Val: fixed short side, center crop
    - Output: (C,T,H,W) float32 in [0,1], normalized with Kinetics statistics
    """
    def __init__(
        self,
        root: str,
        split: str = "train",
        classes: Tuple[str, str] = ("flow", "freeze"),
        frames: int = 16,
        short_side: int = 256,
        crop_size: int = 224,
        mean: Tuple[float, float, float] = (0.45, 0.45, 0.45),
        std: Tuple[float, float, float] = (0.225, 0.225, 0.225),
        allow_corrupt_skip: bool = True,
        train_scale_jitter: Tuple[float, float] = (0.8, 1.2),
        hflip_prob: float = 0.5,
    ):
        super().__init__()
        self.root = Path(root)
        self.split = split
        self.frames = frames
        self.short_side = short_side
        self.crop_size = crop_size
        self.mean = torch.tensor(mean).view(3, 1, 1, 1)
        self.std = torch.tensor(std).view(3, 1, 1, 1)
        self.classes = tuple(sorted(classes))
        self.class_to_idx = {c: i for i, c in enumerate(self.classes)}
        self.allow_corrupt_skip = allow_corrupt_skip
        self.train_scale_jitter = train_scale_jitter
        self.hflip_prob = hflip_prob if split == "train" else 0.0

        self.samples: List[Tuple[Path, int]] = []
        for c in self.classes:
            cdir = self.root / split / c
            vids = list_videos(cdir)
            for v in vids:
                self.samples.append((v, self.class_to_idx[c]))
        if len(self.samples) == 0:
            raise FileNotFoundError(f"No videos found in {self.root}/{split}/({self.classes}).")

        if self.allow_corrupt_skip:
            keep = []
            for p, y in self.samples:
                try:
                    vframes, _, _ = read_video(str(p), pts_unit="sec", output_format="TCHW",
                                               start_pts=0, end_pts=0.1)
                    if vframes.numel() == 0:
                        continue
                    keep.append((p, y))
                except Exception:
                    print(f"[Warn] Skipping unreadable video: {p}")
            if keep:
                self.samples = keep

        self.label_counts = count_labels(self.samples, num_classes=len(self.classes))

    def __len__(self):
        return len(self.samples)

    @staticmethod
    def _uniform_indices(total: int, num: int) -> np.ndarray:
        if total <= 0:
            return np.zeros((num,), dtype=np.int64)
        if total >= num:
            idx = np.linspace(0, total - 1, num=num)
            return np.round(idx).astype(np.int64)
        else:
            base = list(range(total))
            base += [total - 1] * (num - total)
            return np.array(base, dtype=np.int64)

    def _load_video_tensor(self, path: Path) -> torch.Tensor:
        vframes, _, _ = read_video(str(path), pts_unit="sec", output_format="TCHW")
        if vframes.numel() == 0:
            raise RuntimeError("Empty video tensor.")
        if vframes.shape[1] == 1:
            vframes = vframes.repeat(1, 3, 1, 1)
        return vframes  # (T,C,H,W)

    def __getitem__(self, idx: int):
        path, label = self.samples[idx]
        try:
            v = self._load_video_tensor(path)
        except Exception:
            if self.allow_corrupt_skip:
                new_idx = random.randint(0, len(self.samples) - 1)
                path, label = self.samples[new_idx]
                v = self._load_video_tensor(path)
            else:
                raise
        T, C, H, W = v.shape
        # Uniformly sample `frames` frames
        idxs = self._uniform_indices(T, self.frames)
        v = v[idxs]

        if self.split == "train":
            scale = random.uniform(self.train_scale_jitter[0], self.train_scale_jitter[1])
            target_ss = max(64, int(self.short_side * scale))
            v = TF.resize(v, target_ss, antialias=True)
            _, _, H2, W2 = v.shape
            if H2 < self.crop_size or W2 < self.crop_size:
                min_ss = max(self.crop_size, min(H2, W2))
                v = TF.resize(v, min_ss, antialias=True)
                _, _, H2, W2 = v.shape
            top = random.randint(0, H2 - self.crop_size)
            left = random.randint(0, W2 - self.crop_size)
            v = TF.crop(v, top, left, self.crop_size, self.crop_size)
            if random.random() < self.hflip_prob:
                v = torch.flip(v, dims=[-1])
        else:
            v = TF.resize(v, self.short_side, antialias=True)
            v = TF.center_crop(v, [self.crop_size, self.crop_size])

        v = v.permute(1, 0, 2, 3).contiguous()   # (C,T,H,W)
        v = v.float() / 255.0
        v = (v - self.mean) / self.std
        return v, torch.tensor(label, dtype=torch.long)


# --------------------------- Model construction (with pretrained weights) ---------------------------
def build_model(arch: str, frames: int, crop_size: int, num_classes: int = 2, pretrained: bool = True) -> nn.Module:
    arch = arch.lower()
    if arch in {"x3d_s", "x3d_m"}:
        model = torch.hub.load('facebookresearch/pytorchvideo', arch, pretrained=pretrained)
        if hasattr(model.blocks[-1], "proj") and isinstance(model.blocks[-1].proj, nn.Linear):
            in_feats = model.blocks[-1].proj.in_features
            model.blocks[-1].proj = nn.Linear(in_feats, num_classes)
        else:
            head = model.blocks[-1]
            proj = None
            for _, m in head.named_modules():
                if isinstance(m, nn.Linear):
                    proj = m
                    break
            if proj is None:
                raise RuntimeError("Could not find the X3D classifier head's linear layer; "
                                   "upgrade pytorchvideo or switch to a torchvision model.")
            in_feats = proj.in_features
            model.blocks[-1].proj = nn.Linear(in_feats, num_classes)
        return model
    elif arch in {"r2plus1d_18", "r3d_18"}:
        from torchvision.models.video import r2plus1d_18, r3d_18
        from torchvision.models.video import R2Plus1D_18_Weights, R3D_18_Weights
        if arch == "r2plus1d_18":
            weights = R2Plus1D_18_Weights.KINETICS400_V1 if pretrained else None
            model = r2plus1d_18(weights=weights)
        else:
            weights = R3D_18_Weights.KINETICS400_V1 if pretrained else None
            model = r3d_18(weights=weights)
        in_feats = model.fc.in_features
        model.fc = nn.Linear(in_feats, num_classes)
        return model
    else:
        raise ValueError(f"Unknown arch: {arch}. Options: x3d_s, x3d_m, r2plus1d_18, r3d_18")


def set_backbone_trainable(model: nn.Module, trainable: bool, arch: str):
    for p in model.parameters():
        p.requires_grad = trainable
    if arch.startswith("x3d"):
        for p in model.blocks[-1].parameters():
            p.requires_grad = True
    else:
        for p in model.fc.parameters():
            p.requires_grad = True


def get_head_parameters(model: nn.Module, arch: str):
    return list(model.blocks[-1].parameters()) if arch.startswith("x3d") else list(model.fc.parameters())


# --------------------------- EMA / TTA / Metrics ---------------------------
class ModelEMA:
    """Exponential Moving Average of model parameters."""
    def __init__(self, model: nn.Module, decay: float = 0.999):
        self.ema = copy.deepcopy(model).eval()
        for p in self.ema.parameters():
            p.requires_grad_(False)
        self.decay = decay

    @torch.no_grad()
    def update(self, model: nn.Module):
        d = self.decay
        msd = model.state_dict()
        esd = self.ema.state_dict()
        for k in esd.keys():
            v = esd[k]
            mv = msd[k]
            if isinstance(v, torch.Tensor) and v.dtype.is_floating_point:
                esd[k].mul_(d).add_(mv.detach(), alpha=1 - d)
            else:
                esd[k].copy_(mv)


@torch.no_grad()
def _forward_with_tta(model: nn.Module, x: torch.Tensor, tta_flip: bool):
    logits = model(x)
    if tta_flip:
        x_flip = torch.flip(x, dims=[-1])
        logits = logits + model(x_flip)
        logits = logits / 2.0
    return logits


@torch.no_grad()
def evaluate(model: nn.Module, loader: DataLoader, device: str = "cuda", tta_flip: bool = False):
    model.eval()
    total, correct, loss_sum = 0, 0, 0.0
    criterion = nn.CrossEntropyLoss()
    amp_ctx = torch.amp.autocast(device_type='cuda', dtype=torch.float16, enabled=(device == "cuda"))
    for x, y in loader:
        x = x.to(device, non_blocking=True).float()
        y = y.to(device, non_blocking=True)
        with amp_ctx:
            logits = _forward_with_tta(model, x, tta_flip)
            loss = criterion(logits, y)
        loss_sum += loss.item() * y.size(0)
        pred = logits.argmax(dim=1)
        correct += (pred == y).sum().item()
        total += y.size(0)
    return correct / max(1, total), loss_sum / max(1, total)


@torch.no_grad()
def evaluate_detailed(model: nn.Module, loader: DataLoader, device: str = "cuda", tta_flip: bool = False):
    """Returns and prints detailed metrics: confusion matrix and per-class P/R/F1,
    plus a threshold sweep optimizing the freeze-class F1 and the balanced accuracy."""
    model.eval()
    all_probs1, all_labels = [], []
    amp_ctx = torch.amp.autocast(device_type='cuda', dtype=torch.float16, enabled=(device == "cuda"))
    for x, y in loader:
        x = x.to(device, non_blocking=True).float()
        with amp_ctx:
            logits = _forward_with_tta(model, x, tta_flip)
        probs = torch.softmax(logits.float(), dim=1)
        all_probs1.append(probs[:, 1].cpu())
        all_labels.append(y)
    p1 = torch.cat(all_probs1).numpy()
    y_true = torch.cat(all_labels).numpy().astype(int)

    def metrics_at(th):
        y_pred = (p1 >= th).astype(int)
        tp = int(((y_true == 1) & (y_pred == 1)).sum())
        tn = int(((y_true == 0) & (y_pred == 0)).sum())
        fp = int(((y_true == 0) & (y_pred == 1)).sum())
        fn = int(((y_true == 1) & (y_pred == 0)).sum())
        acc = (tp + tn) / max(1, len(y_true))
        prec1 = tp / max(1, tp + fp)
        rec1 = tp / max(1, tp + fn)
        f1_1 = 2 * prec1 * rec1 / max(1e-12, (prec1 + rec1))
        prec0 = tn / max(1, tn + fn)
        rec0 = tn / max(1, tn + fp)
        f1_0 = 2 * prec0 * rec0 / max(1e-12, (prec0 + rec0))
        bal_acc = 0.5 * (rec0 + rec1)
        cm = np.array([[tn, fp],
                       [fn, tp]], dtype=int)
        return acc, bal_acc, (prec0, rec0, f1_0), (prec1, rec1, f1_1), cm

    # Metrics at the default 0.5 threshold, then sweep for the best thresholds
    acc50, bal50, cls0_50, cls1_50, cm50 = metrics_at(0.5)
    best_f1_th, best_f1 = 0.5, -1
    best_bal_th, best_bal = 0.5, -1
    for th in np.linspace(0.05, 0.95, 91):
        acc, bal, _, cls1, _ = metrics_at(th)
        f1 = cls1[2]
        if f1 > best_f1:
            best_f1, best_f1_th = f1, th
        if bal > best_bal:
            best_bal, best_bal_th = bal, th

    print("== Detailed Validation Metrics ==")
    print(f"Default th=0.50 | Acc={acc50:.4f} | BalancedAcc={bal50:.4f} | "
          f"Class0(P/R/F1)={cls0_50[0]:.3f}/{cls0_50[1]:.3f}/{cls0_50[2]:.3f} | "
          f"Class1(P/R/F1)={cls1_50[0]:.3f}/{cls1_50[1]:.3f}/{cls1_50[2]:.3f}")
    print(f"Confusion Matrix @0.50 (rows=true [0,1]; cols=pred [0,1]):\n{cm50}")
    print(f"Best F1(freeze=1) th={best_f1_th:.2f} | F1={best_f1:.4f}")
    print(f"Best Balanced Acc th={best_bal_th:.2f} | BalancedAcc={best_bal:.4f}")
    return {
        "acc@0.5": acc50,
        "balanced@0.5": bal50,
        "cm@0.5": cm50,
        "best_f1_th": best_f1_th,
        "best_bal_th": best_bal_th,
    }


# --------------------------- Training main function ---------------------------
def main():
    warnings.filterwarnings("once", category=UserWarning)
    parser = argparse.ArgumentParser()
    parser.add_argument("--root", type=str, required=True, help="Data root containing train/ and val/")
    parser.add_argument("--epochs", type=int, default=30)
    parser.add_argument("--freeze_epochs", type=int, default=3, help="Linear-probing epochs (train the head only)")
    parser.add_argument("--batch", type=int, default=8)
    parser.add_argument("--frames", type=int, default=16)
    parser.add_argument("--size", type=int, default=224)
    parser.add_argument("--short_side", type=int, default=256)
    parser.add_argument("--arch", type=str, default="x3d_m", choices=["x3d_s", "x3d_m", "r2plus1d_18", "r3d_18"])
    parser.add_argument("--pretrained", type=int, default=1, help="Use pretrained weights (1/0)")
    parser.add_argument("--lr", type=float, default=3e-4)
    parser.add_argument("--lr_head_mul", type=float, default=10.0, help="LR multiplier for the classifier head")
    parser.add_argument("--wd", type=float, default=1e-4)
    parser.add_argument("--warmup", type=int, default=2, help="Number of warmup epochs")
    parser.add_argument("--clip_grad", type=float, default=1.0, help="Gradient-clipping threshold; <=0 disables")
    parser.add_argument("--ls", type=float, default=0.05, help="Label smoothing")
    parser.add_argument("--balance", type=str, default="auto", choices=["off", "sampler", "class_weight", "auto"],
                        help="Class-imbalance handling strategy")
    parser.add_argument("--workers", type=int, default=4)
    parser.add_argument("--seed", type=int, default=42)
    parser.add_argument("--ckpt", type=str, default="freeze_x3d.pth")
    parser.add_argument("--device", type=str, default="cuda" if torch.cuda.is_available() else "cpu")
    # Extras
    parser.add_argument("--tta_flip", type=int, default=0, help="Horizontal-flip TTA at validation")
    parser.add_argument("--ema", type=int, default=0, help="Enable EMA (1/0)")
    parser.add_argument("--ema_decay", type=float, default=0.999, help="EMA decay")
    args = parser.parse_args()

    set_seed(args.seed)
    device = args.device
    print(f"Device: {device}")
    print("Enabling TF32 for speed (if Ampere+ GPU).")
    torch.backends.cuda.matmul.allow_tf32 = True
    torch.backends.cudnn.allow_tf32 = True
    torch.backends.cudnn.benchmark = True

    # Datasets
    classes = ("flow", "freeze")
    train_set = VideoFolderDataset(root=args.root, split="train", classes=classes,
                                   frames=args.frames, short_side=args.short_side, crop_size=args.size)
    val_set = VideoFolderDataset(root=args.root, split="val", classes=classes,
                                 frames=args.frames, short_side=args.short_side, crop_size=args.size)
    print(f"[Data] train={len(train_set)}  val={len(val_set)}  label_counts(train)={train_set.label_counts}")

    # Class imbalance
    sampler = None
    class_weight_tensor = None
    if args.balance in ("sampler", "auto"):
        counts = np.array(train_set.label_counts, dtype=np.float64) + 1e-6
        inv_freq = 1.0 / counts
        sample_weights = [inv_freq[y] for _, y in train_set.samples]
        sampler = WeightedRandomSampler(sample_weights, num_samples=len(sample_weights), replacement=True)
    if args.balance in ("class_weight",):
        counts = np.array(train_set.label_counts, dtype=np.float64) + 1e-6
        class_weight_tensor = torch.tensor((counts.sum() / counts), dtype=torch.float32)

    train_loader = DataLoader(
        train_set, batch_size=args.batch, shuffle=(sampler is None), sampler=sampler,
        num_workers=args.workers, pin_memory=True, drop_last=True,
        persistent_workers=(args.workers > 0), prefetch_factor=2 if args.workers > 0 else None,
    )
    val_loader = DataLoader(
        val_set, batch_size=max(1, args.batch // 2), shuffle=False,
        num_workers=max(0, args.workers // 2), pin_memory=True, drop_last=False,
        persistent_workers=False,
    )

    # Model
    model = build_model(args.arch, args.frames, args.size, num_classes=2,
                        pretrained=bool(args.pretrained)).to(device)

    # Linear probing: freeze the backbone first
    set_backbone_trainable(model, trainable=False, arch=args.arch)
    head_params = get_head_parameters(model, args.arch)
    head_ids = {id(p) for p in head_params}
    backbone_params = [p for p in model.parameters() if p.requires_grad and id(p) not in head_ids]
    param_groups = [{"params": head_params, "lr": args.lr * args.lr_head_mul}]
    if backbone_params:
        param_groups.append({"params": backbone_params, "lr": args.lr})
    optimizer = torch.optim.AdamW(param_groups, lr=args.lr, weight_decay=args.wd)

    # Scheduler
    from torch.optim.lr_scheduler import LinearLR, CosineAnnealingLR, SequentialLR
    warmup_epochs = max(0, min(args.warmup, args.epochs - 1))
    sched_main = CosineAnnealingLR(optimizer, T_max=max(1, args.epochs - warmup_epochs))
    scheduler = SequentialLR(optimizer, [
        LinearLR(optimizer, start_factor=0.1, total_iters=warmup_epochs),
        sched_main], milestones=[warmup_epochs]) if warmup_epochs > 0 else sched_main

    # Loss
    criterion = nn.CrossEntropyLoss(
        label_smoothing=args.ls,
        weight=class_weight_tensor.to(device) if class_weight_tensor is not None else None)

    # AMP & EMA
    scaler = torch.amp.GradScaler('cuda', enabled=(device == "cuda"))
    amp_ctx = torch.amp.autocast(device_type='cuda', dtype=torch.float16, enabled=(device == "cuda"))
    ema = ModelEMA(model, decay=args.ema_decay) if args.ema else None

    best_acc = 0.0
    os.makedirs(os.path.dirname(args.ckpt) if os.path.dirname(args.ckpt) else ".", exist_ok=True)

    # Training loop
    for epoch in range(1, args.epochs + 1):
        model.train()
        t0 = time.time()
        running_loss = running_acc = seen = 0
        if epoch == args.freeze_epochs + 1:
            print(f"===> Unfreezing backbone for finetuning from epoch {epoch}.")
            set_backbone_trainable(model, trainable=True, arch=args.arch)
            head_params = get_head_parameters(model, args.arch)
            head_ids = {id(p) for p in head_params}
            backbone_params = [p for p in model.parameters() if p.requires_grad and id(p) not in head_ids]
            optimizer = torch.optim.AdamW(
                [{"params": head_params, "lr": args.lr * args.lr_head_mul},
                 {"params": backbone_params, "lr": args.lr}],
                lr=args.lr, weight_decay=args.wd)
            from torch.optim.lr_scheduler import CosineAnnealingLR
            scheduler = CosineAnnealingLR(optimizer, T_max=max(1, args.epochs - epoch + 1))

        for step, (x, y) in enumerate(train_loader, 1):
            x = x.to(device, non_blocking=True).float()
            y = y.to(device, non_blocking=True)
            if step == 1 and epoch == 1:
                print(f"[Sanity] x.dtype={x.dtype}, param.dtype={next(model.parameters()).dtype}, x.shape={x.shape}")
            optimizer.zero_grad(set_to_none=True)
            with amp_ctx:
                logits = model(x)
                loss = criterion(logits, y)
            scaler.scale(loss).backward()
            if args.clip_grad and args.clip_grad > 0:
                scaler.unscale_(optimizer)
                nn.utils.clip_grad_norm_(model.parameters(), max_norm=args.clip_grad)
            scaler.step(optimizer)
            scaler.update()
            if ema:
                ema.update(model)
            bs = y.size(0)
            running_loss += loss.item() * bs
            running_acc += (logits.argmax(dim=1) == y).sum().item()
            seen += bs
            if step % 10 == 0 or step == len(train_loader):
                lr0 = optimizer.param_groups[0]["lr"]
                print(f"Epoch {epoch}/{args.epochs} | Step {step}/{len(train_loader)} | "
                      f"LR {lr0:.2e} | Loss {(running_loss/seen):.4f} | Acc {(running_acc/seen):.4f}")
        scheduler.step()
        train_loss = running_loss / max(1, seen)
        train_acc = running_acc / max(1, seen)

        # Validation (prefer the EMA model)
        eval_model = ema.ema if ema else model
        val_acc, val_loss = evaluate(eval_model, val_loader, device=device, tta_flip=bool(args.tta_flip))
        dt = time.time() - t0
        print(f"[Epoch {epoch}] train_loss={train_loss:.4f} acc={train_acc:.4f} | "
              f"val_loss={val_loss:.4f} acc={val_acc:.4f} | time={dt:.1f}s "
              f"{'(EMA+TTA)' if ema or args.tta_flip else ''}")
        if val_acc > best_acc:
            best_acc = val_acc
            ckpt = {
                "epoch": epoch,
                "state_dict": eval_model.state_dict(),  # saving EMA weights is friendlier for deployment
                "optimizer": optimizer.state_dict(),
                "scaler": scaler.state_dict(),
                "best_acc": best_acc,
                "args": vars(args),
                "classes": classes,
                "arch": args.arch,
                "is_ema": bool(ema),
            }
            torch.save(ckpt, args.ckpt)
            print(f"Saved best checkpoint to {args.ckpt} (acc={best_acc:.4f})")

    print(f"Training done. Best val acc = {best_acc:.4f}")
    # Print detailed metrics at the end (using the EMA+TTA model if enabled)
    eval_model = ema.ema if ema else model
    evaluate_detailed(eval_model, val_loader, device=device, tta_flip=bool(args.tta_flip))


if __name__ == "__main__":
    try:
        main()
    except KeyboardInterrupt:
        sys.exit(1)

Launch command:

python3 train_freeze.py --root /path/to/dataset --epochs 30 --freeze_epochs 3 \
    --arch x3d_m --pretrained 1 --batch 8 --frames 32 --size 224 --short_side 256 \
    --lr 3e-4 --lr_head_mul 10 --wd 1e-4 --warmup 2 \
    --balance auto --ls 0.05 --clip_grad 1.0 --workers 8 \
    --tta_flip 1 --ema 1 --ema_decay 0.999

Key parameters

Parameter | Typical value | Purpose
--frames | 16/32 | controls the size of the temporal receptive field
--short_side | 256 | short-side resize baseline (preserves aspect ratio)
--lr_head_mul | 10 | classifier-head learning rate is 10x the backbone's
--ema_decay | 0.999 | exponential-moving-average coefficient for model weights
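
For intuition about --ema_decay: ModelEMA updates each weight as ema = decay * ema + (1 - decay) * param, so decay = 0.999 gives an effective averaging horizon of roughly 1/(1 - decay) = 1000 steps. A tiny numeric sketch of that rule:

# Tiny sketch of the EMA rule in ModelEMA.update: ema <- d*ema + (1-d)*param
decay = 0.999
ema, param = 1.0, 0.0  # pretend the trained weight jumps from 1 to 0
for _ in range(1000):
    ema = decay * ema + (1 - decay) * param
print(f"after 1000 steps: {ema:.3f}")  # ~0.368 = e^-1; the old value is mostly forgotten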

Inference code

import os
import sys
import argparse
from pathlib import Path
from typing import List, Tuple, Dict, Any

import numpy as np
import torch
import torch.nn as nn
from torchvision.io import read_video
from torchvision.transforms import functional as TF


# --------------------- Helpers ---------------------
def list_videos(root: Path, exts=(".mp4", ".avi", ".mov", ".mkv")) -> List[Path]:
    files = []
    for ext in exts:
        files += list(root.rglob(f"*{ext}"))
    return sorted(files)


def uniform_indices(total: int, num: int) -> np.ndarray:
    if total <= 0:
        return np.zeros((num,), dtype=np.int64)
    if total >= num:
        idx = np.linspace(0, total - 1, num=num)
        return np.round(idx).astype(np.int64)
    else:
        base = list(range(total))
        base += [total - 1] * (num - total)
        return np.array(base, dtype=np.int64)


def segment_indices(total: int, num_frames: int, clip_idx: int, num_clips: int) -> np.ndarray:
    if num_clips <= 1:
        return uniform_indices(total, num_frames)
    start = int(np.floor(clip_idx * total / num_clips))
    end = int(np.floor((clip_idx + 1) * total / num_clips)) - 1
    end = max(start, end)
    seg_len = end - start + 1
    if seg_len >= num_frames:
        idx = np.linspace(start, end, num=num_frames)
        return np.round(idx).astype(np.int64)
    else:
        idx = list(range(start, end + 1))
        idx += [end] * (num_frames - seg_len)
        return np.array(idx, dtype=np.int64)


MEAN = torch.tensor((0.45, 0.45, 0.45)).view(3, 1, 1, 1)
STD = torch.tensor((0.225, 0.225, 0.225)).view(3, 1, 1, 1)


# --------------------- Model construction (offline first) ---------------------
def build_x3d_offline(variant: str, num_classes: int, pretrained: bool = False, repo_dir: str = "") -> nn.Module:
    """Prefers pytorchvideo's local Python API (no network needed); on failure,
    falls back to the local torch.hub cache directory (source='local'), also offline."""
    variant = variant.lower()
    assert variant in {"x3d_s", "x3d_m"}
    # 1) Use pytorchvideo's Python API directly (no torch.hub, works offline)
    try:
        from pytorchvideo.models import hub as pv_hub
        builder = getattr(pv_hub, variant)  # x3d_s / x3d_m
        model = builder(pretrained=pretrained)
        # Replace the classifier head
        if hasattr(model.blocks[-1], "proj") and isinstance(model.blocks[-1].proj, nn.Linear):
            in_feats = model.blocks[-1].proj.in_features
            model.blocks[-1].proj = nn.Linear(in_feats, num_classes)
        else:
            # Fallback: scan the last block for a linear layer
            head = model.blocks[-1]
            proj = None
            for _, m in head.named_modules():
                if isinstance(m, nn.Linear):
                    proj = m
                    break
            if proj is None:
                raise RuntimeError("Could not find the X3D classifier head's linear layer.")
            in_feats = proj.in_features
            model.blocks[-1].proj = nn.Linear(in_feats, num_classes)
        return model
    except Exception as e_api:
        print(f"[Info] Offline build via pytorchvideo.models.hub failed; trying the local hub cache. Reason: {e_api}")
    # 2) Use torch.hub's local cache (no network access)
    try:
        if not repo_dir:
            repo_dir = os.path.join(torch.hub.get_dir(), "facebookresearch_pytorchvideo_main")
        if not os.path.isdir(repo_dir):
            raise FileNotFoundError(f"Local hub cache not found: {repo_dir}")
        # Key point: source='local' guarantees no network access; trust_repo=True skips repo validation
        model = torch.hub.load(repo_dir, variant, pretrained=pretrained, source='local', trust_repo=True)
        # Replace the classifier head
        if hasattr(model.blocks[-1], "proj") and isinstance(model.blocks[-1].proj, nn.Linear):
            in_feats = model.blocks[-1].proj.in_features
            model.blocks[-1].proj = nn.Linear(in_feats, num_classes)
        else:
            head = model.blocks[-1]
            proj = None
            for _, m in head.named_modules():
                if isinstance(m, nn.Linear):
                    proj = m
                    break
            if proj is None:
                raise RuntimeError("Could not find the X3D classifier head's linear layer.")
            in_feats = proj.in_features
            model.blocks[-1].proj = nn.Linear(in_feats, num_classes)
        return model
    except Exception as e_local:
        raise RuntimeError(
            "Could not build the X3D model offline. Make sure pytorchvideo is installed or a local hub cache exists.\n"
            "- pip install: pip install pytorchvideo\n"
            f"- local cache directory (example): {os.path.join(torch.hub.get_dir(), 'facebookresearch_pytorchvideo_main')}\n"
            f"Original error: {e_local}")


def build_model(arch: str, num_classes: int, pretrained: bool = False, repo_dir: str = "") -> nn.Module:
    arch = arch.lower()
    if arch in {"x3d_s", "x3d_m"}:
        return build_x3d_offline(arch, num_classes=num_classes, pretrained=pretrained, repo_dir=repo_dir)
    elif arch in {"r2plus1d_18", "r3d_18"}:
        from torchvision.models.video import r2plus1d_18, r3d_18
        # Pretrained weights don't matter here; load_state_dict follows
        m = r2plus1d_18(weights=None) if arch == "r2plus1d_18" else r3d_18(weights=None)
        in_feats = m.fc.in_features
        m.fc = nn.Linear(in_feats, num_classes)
        return m
    else:
        raise ValueError(f"Unknown arch: {arch}")


def load_ckpt_build_model(ckpt_path: str, device: str = "cuda", override: Dict[str, Any] = None, repo_dir: str = ""):
    # Explicit weights_only=False to avoid surprises if the default changes in a future release
    ckpt = torch.load(ckpt_path, map_location="cpu", weights_only=False)
    args = ckpt.get("args", {}) or {}
    arch = (override or {}).get("arch", args.get("arch", "x3d_m"))
    classes = ckpt.get("classes", ("flow", "freeze"))
    num_classes = len(classes)
    model = build_model(arch, num_classes=num_classes, pretrained=False, repo_dir=repo_dir)
    missing, unexpected = model.load_state_dict(ckpt["state_dict"], strict=False)
    if missing or unexpected:
        print(f"[load_state_dict] missing={missing} unexpected={unexpected}")
    model.to(device).eval()
    meta = {
        "arch": arch,
        "classes": classes,
        "frames": int((override or {}).get("frames", args.get("frames", 16))),
        "size": int((override or {}).get("size", args.get("size", 224))),
        "short_side": int((override or {}).get("short_side", args.get("short_side", 256))),
    }
    return model, meta


# --------------------- Preprocessing & forward pass ---------------------
@torch.no_grad()
def preprocess_clip(vframes: torch.Tensor, frames: int, short_side: int, crop_size: int,
                    idxs: np.ndarray) -> torch.Tensor:
    clip = vframes[idxs]  # (frames,C,H,W)
    if clip.shape[1] == 1:
        clip = clip.repeat(1, 3, 1, 1)
    clip = TF.resize(clip, short_side, antialias=True)
    clip = TF.center_crop(clip, [crop_size, crop_size])
    clip = clip.permute(1, 0, 2, 3).contiguous().float() / 255.0  # (C,T,H,W)
    clip = (clip - MEAN) / STD
    return clip.unsqueeze(0)  # (1,3,T,H,W)


@torch.no_grad()
def _forward_with_tta(model: nn.Module, x: torch.Tensor, tta_flip: bool):
    logits = model(x)
    if tta_flip:
        logits = (logits + model(torch.flip(x, dims=[-1]))) / 2.0
    return logits


@torch.no_grad()
def infer_one_video(model: nn.Module, path: Path, frames: int, short_side: int, crop_size: int,
                    num_clips: int = 1, tta_flip: bool = False, device: str = "cuda") -> Tuple[int, np.ndarray]:
    vframes, _, _ = read_video(str(path), pts_unit="sec", output_format="TCHW")
    if vframes.numel() == 0:
        raise RuntimeError(f"Empty video: {path}")
    if vframes.shape[1] == 1:
        vframes = vframes.repeat(1, 3, 1, 1)
    T = vframes.shape[0]
    logits_sum = torch.zeros((1, 2), dtype=torch.float32, device=device)
    amp_ctx = torch.amp.autocast(device_type='cuda', dtype=torch.float16, enabled=(device == "cuda"))
    for ci in range(max(1, num_clips)):
        idxs = segment_indices(T, frames, ci, num_clips)
        x = preprocess_clip(vframes, frames, short_side, crop_size, idxs).to(device, non_blocking=True)
        with amp_ctx:
            logits = _forward_with_tta(model, x, tta_flip)
        logits_sum += logits.float()
    probs = torch.softmax(logits_sum / max(1, num_clips), dim=1).squeeze(0).cpu().numpy()
    pred = int(np.argmax(probs))
    return pred, probs


# --------------------- Main ---------------------
def main():
    parser = argparse.ArgumentParser()
    parser.add_argument("--ckpt", type=str, required=True, help="Trained .pth checkpoint")
    parser.add_argument("--input", type=str, required=True, help="Video file or directory")
    parser.add_argument("--out", type=str, default="", help="Optional: output CSV path")
    parser.add_argument("--threshold", type=float, default=0.5, help="Threshold for freeze (=1)")
    parser.add_argument("--clips", type=int, default=1, help="Number of temporal clips (temporal TTA)")
    parser.add_argument("--tta_flip", type=int, default=0, help="Horizontal-flip TTA (0/1)")
    parser.add_argument("--device", type=str, default="cuda" if torch.cuda.is_available() else "cpu")
    parser.add_argument("--frames", type=int, default=None, help="Override the ckpt's frames (optional)")
    parser.add_argument("--size", type=int, default=None, help="Override the ckpt's crop size (optional)")
    parser.add_argument("--short_side", type=int, default=None, help="Override the ckpt's short_side (optional)")
    parser.add_argument("--arch", type=str, default=None, help="Override arch (optional)")
    parser.add_argument("--repo_dir", type=str, default="", help="Local pytorchvideo hub cache directory (optional)")
    args = parser.parse_args()

    if args.device.startswith("cuda"):
        torch.backends.cuda.matmul.allow_tf32 = True
        torch.backends.cudnn.allow_tf32 = True
        torch.backends.cudnn.benchmark = True

    override = {}
    if args.arch: override["arch"] = args.arch
    if args.frames is not None: override["frames"] = args.frames
    if args.size is not None: override["size"] = args.size
    if args.short_side is not None: override["short_side"] = args.short_side

    model, meta = load_ckpt_build_model(args.ckpt, device=args.device, override=override, repo_dir=args.repo_dir)
    classes = list(meta["classes"])
    frames = int(meta["frames"])
    crop = int(meta["size"])
    short_side = int(meta["short_side"])
    print(f"[Model] arch={meta['arch']} classes={classes}")
    print(f"[Preprocess] frames={frames} size={crop} short_side={short_side}")
    print(f"[TTA] clips={args.clips} flip={bool(args.tta_flip)}  threshold={args.threshold:.2f}")

    inp = Path(args.input)
    paths: List[Path]
    if inp.is_dir():
        paths = list_videos(inp)
        if not paths:
            print(f"No videos found in {inp}")
            sys.exit(1)
    else:
        if not inp.exists():
            print(f"File not found: {inp}")
            sys.exit(1)
        paths = [inp]

    rows = []
    for p in paths:
        try:
            pred, probs = infer_one_video(model, p, frames, short_side, crop,
                                          num_clips=args.clips, tta_flip=bool(args.tta_flip), device=args.device)
            label = classes[pred] if pred < len(classes) else str(pred)
            prob_freeze = float(probs[1]) if len(probs) > 1 else float('nan')
            is_freeze = int(prob_freeze >= args.threshold)
            print(f"{p.name:40s}  -> pred={label:6s}  probs(flow,freeze)={probs}  "
                  f"freeze@{args.threshold:.2f}={is_freeze}")
            rows.append((str(p), label, probs[0], probs[1] if len(probs) > 1 else float('nan'), is_freeze))
        except Exception as e:
            print(f"[Error] {p}: {e}")
            rows.append((str(p), "ERROR", float('nan'), float('nan'), -1))

    if args.out:
        import csv
        with open(args.out, "w", newline="") as f:
            writer = csv.writer(f)
            writer.writerow(["path", "pred_label", "prob_flow", "prob_freeze", f"freeze@{args.threshold}"])
            writer.writerows(rows)
        print(f"Saved results to {args.out}")


if __name__ == "__main__":
    main()

Launch command:

python3 inference_freeze.py --ckpt ./freeze_x3d.pth --input /path/to/video_or_dir \
    --clips 3 --tta_flip 1

Key parameters

Parameter | Purpose
--ckpt ./freeze_x3d.pth | path to the trained weights
--input /path/to/video_or_dir | input video file or directory
--clips 3 | number of temporal clips to sample (temporal TTA)
--tta_flip 1 | horizontal-flip TTA switch
