Paper Reading Notes: "Curriculum Coarse-to-Fine Selection for High-IPC Dataset Distillation"


    • 1. Background and Motivation
    • 2. Core Contributions
    • 3. Method Details
    • 4. Experimental Results
    • Main Code
    • Overall Algorithm Logic

CVPR 2025 | GitHub

One-sentence summary:

CCFS builds on the hybrid paradigm (trajectory matching plus selection of real images): a curriculum-style "coarse filtering + fine selection" framework dynamically patches the weak spots of the synthetic set, substantially improving dataset distillation under high-IPC settings. It is the current SOTA method for high-IPC scenarios.



1. Background and Motivation

  • Dataset distillation: compress a large training set into a small synthetic dataset such that a model trained on the synthetic set performs close to one trained on the full original data.
  • IPC (images per class): the number of synthetic images per class. At low IPC (a few images per class), existing methods do well; but as IPC grows (more images to synthesize), performance often degrades, sometimes below simple random sampling.
  • Core problem: at high IPC the synthetic set becomes overly "averaged" and misses rare/complex features (hard samples), so coverage is insufficient. Existing hybrid methods that mix distilled and real samples (e.g. SelMatch) select real samples statically in one shot, and thus cannot dynamically complement the synthetic set.

2. Core Contributions

  1. Incompatibility diagnosis: analyzes why, under the "select real samples first, then distill" paradigm, statically chosen real samples complement the evolving distilled set poorly.
  2. CCFS: a curriculum-style coarse-to-fine framework that dynamically selects real samples in two stages:
    • Coarse filtering: a filter model trained on the current synthetic set identifies real samples that are "not yet learned" (i.e. misclassified).
    • Fine selection: among these candidates, rank by a difficulty score (or directly by the filter's logits) and pick the "simplest yet unlearned" samples, gradually adding them to the synthetic set.
  3. Empirical results: under high-IPC settings on CIFAR-10/100 and Tiny-ImageNet (compression ratios 5%–30%), CCFS sets new SOTA on multiple benchmarks, in some cases only 0.3% below full-data training.

3. Method Details

Overall pipeline:

  1. Initialization
    • Obtain an initial synthetic set $D_{distill}$ from any base distillation algorithm (e.g. CDA).
    • Set the current synthetic set $S_0 = D_{distill}$.
  2. Curriculum loop ($J$ stages in total)
    • Train the filter: on $S_{j-1}$, distill-train a filter model $\phi_j$ so it learns the decision boundary of the current synthetic set.
    • Coarse: run $\phi_j$ over the original training set $T$ and collect the misclassified sample set $D_{mis}^j$.
    • Fine: rank the samples within $D_{mis}^j$ and take the top $k_j$ "simplest yet unlearned" samples per class, forming $D_{real}^j$.
    • Update: $S_j = S_{j-1} \cup D_{real}^j$.
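The curriculum update rule can be sketched with index sets standing in for image sets (a toy illustration; the per-round selections $D_{real}^j$ are assumed given rather than computed):

```python
# Toy sketch of the curriculum update S_j = S_{j-1} ∪ D_real^j,
# with index sets standing in for image sets.
def run_curriculum(distill_ids, per_round_selections):
    S = set(distill_ids)            # S_0 = D_distill
    for d_real_j in per_round_selections:
        S |= set(d_real_j)          # S_j = S_{j-1} ∪ D_real^j
    return S

S_final = run_curriculum({"syn0", "syn1"}, [[10, 11], [12], [13]])
# S_final == {"syn0", "syn1", 10, 11, 12, 13}
```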

Why select the "simple yet unlearned" samples?

The misclassified set $D_{mis}$ reflects the current limitations of $S$. Among these limitations, simpler features benefit model training more than complex ones, because they are easier to learn. A precomputed difficulty score measures the relative difficulty of sample features from a global perspective and guides the subsequent fine selection. Selecting the simplest features among the misclassified samples yields the optimal $D_{real}$, while avoiding overly complex features that could hinder the performance of $S$.
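A minimal, training-free sketch of the two stages (all inputs are toy stand-ins, not the real filter model; lower difficulty means a simpler feature):

```python
# Coarse: samples the current filter misclassifies.
# Fine: among them, the k easiest ones.
def coarse_filter(preds, labels):
    return [i for i, (p, y) in enumerate(zip(preds, labels)) if p != y]

def fine_select(candidates, difficulty, k):
    return sorted(candidates, key=lambda i: difficulty[i])[:k]

labels     = [0, 0, 1, 1, 1, 0]
preds      = [0, 1, 1, 0, 0, 1]              # filter predictions on the real set
difficulty = [0.1, 0.9, 0.3, 0.2, 0.8, 0.4]  # lower = easier to learn

mis    = coarse_filter(preds, labels)        # misclassified: indices 1, 3, 4, 5
chosen = fine_select(mis, difficulty, k=2)   # easiest two of them: 3, 5
```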

4. Experimental Results

  • Datasets: CIFAR-10/100, Tiny-ImageNet
  • New SOTA under high-IPC settings
    • CIFAR-10/100 at a 10% compression ratio: roughly +6.1% / +5.8% over the best baseline;
    • Tiny-ImageNet at a 20% compression ratio: only 0.3% below full-data training.
  • Cross-architecture generalization
    Synthetic sets generated with ResNet-18 also train ResNet-50/101, DenseNet-121, RegNet, etc., all outperforming CDA, SelMatch and other methods.
  • Thorough ablations
    • Validates the coarse (misclassified vs. confidently correct) and fine (simple vs. hard vs. random) strategy combinations;
    • Compares difficulty scores; the forgetting score works best;
    • Studies the number of curriculum rounds vs. accuracy and efficiency; 3 rounds is a good trade-off.

Main Code

import os
import datetime
import time
import warnings
import numpy as np
import random
import torch
import torch.utils.data
import torchvision
import utils
from torch import nn
import torchvision.transforms as transforms
from imagenet_ipc import ImageFolderIPC
import torch.nn.functional as F
from tqdm import tqdm
import json
warnings.filterwarnings("ignore", category=UserWarning, module="torch.optim.lr_scheduler")


def get_args_parser(add_help=True):
    import argparse
    parser = argparse.ArgumentParser(description="CCFS on CIFAR-100", add_help=add_help)
    parser.add_argument("--data-path", default=None, type=str, help="path to CIFAR-100 data folder")
    parser.add_argument("--filter-model", default="resnet18", type=str, help="filter model name")
    parser.add_argument("--teacher-model", default="resnet18", type=str, help="teacher model name")
    parser.add_argument("--teacher-path", default=None, type=str, help="path to teacher model")
    parser.add_argument("--eval-model", default="resnet18", type=str, help="model for final evaluation")
    parser.add_argument("--device", default="cuda", type=str, help="device (Use cuda or cpu Default: cuda)")
    parser.add_argument("-b", "--batch-size", default=64, type=int, help="Batch size")
    parser.add_argument("--epochs", default=90, type=int, metavar="N", help="# training epochs for both the filter and the evaluation model")
    parser.add_argument("-j", "--workers", default=4, type=int, metavar="N", help="number of data loading workers (default: 16)")
    parser.add_argument("--opt", default="sgd", type=str, help="optimizer")
    parser.add_argument("--lr", default=0.1, type=float, help="initial learning rate")
    parser.add_argument("--momentum", default=0.9, type=float, metavar="M", help="momentum")
    parser.add_argument("--wd", "--weight-decay", default=1e-4, type=float, metavar="W", help="weight decay (default: 1e-4)", dest="weight_decay")
    parser.add_argument("--lr-scheduler", default="steplr", type=str, help="the lr scheduler (default: steplr)")
    parser.add_argument("--lr-warmup-epochs", default=0, type=int, help="the number of epochs to warmup (default: 0)")
    parser.add_argument("--lr-warmup-method", default="constant", type=str, help="the warmup method (default: constant)")
    parser.add_argument("--lr-warmup-decay", default=0.01, type=float, help="the decay for lr")
    parser.add_argument("--lr-step-size", default=30, type=int, help="decrease lr every step-size epochs")
    # --lr-gamma is used by StepLR/ExponentialLR below; added here since the excerpt omitted it
    parser.add_argument("--lr-gamma", default=0.1, type=float, help="decrease lr by a factor of lr-gamma")
    parser.add_argument("-T", "--temperature", default=20, type=float, help="temperature for distillation loss")
    parser.add_argument("--print-freq", default=1000, type=int, help="print frequency")
    # --- CCFS parameters ---
    # target number of final images per class (IPC)
    parser.add_argument("--image-per-class", default=50, type=int, help="number of synthetic images per class")
    parser.add_argument("--distill-data-path", default=None, type=str, help="path to already distilled data")
    # distillation portion: ratio of distilled (synthetic) to selected real images;
    # cpc = IPC * alpha  : condensed images per class
    # spc = IPC - cpc    : real images to select per class
    parser.add_argument('--alpha', type=float, default=0.2, help='Distillation portion')
    # number of curriculum stages (e.g. 3 rounds)
    parser.add_argument('--curriculum-num', type=int, default=None, help='Number of curricula')
    # coarse stage: select samples the filter misclassifies (True) or classifies correctly (False)
    parser.add_argument('--select-misclassified', action='store_true', help='Selection strategy in coarse stage')
    # fine stage strategy: simple / hard / random
    # ("simplest yet unlearned" / "hardest" / "random" in the paper's terms)
    parser.add_argument('--select-method', type=str, default='simple', choices=['random', 'hard', 'simple'], help='Selection strategy in fine stage')
    # whether to balance the selection across classes
    parser.add_argument('--balance', action='store_true', help='Whether to balance the amount of the synthetic data between classes')
    # which difficulty score to rank with
    parser.add_argument('--score', type=str, default='forgetting', choices=['logits', 'forgetting', 'cscore'], help='Difficulty score used in fine stage')
    # if not using logits, load a precomputed difficulty score (e.g. forgetting score) from this path
    parser.add_argument('--score-path', type=str, default=None, help='Path to the difficulty score')
    parser.add_argument("--output-dir", default=".", type=str, help="path to save outputs")
    parser.add_argument("--num-eval", default=1, type=int, help="number of evaluations")
    return parser


def load_data(args):
    '''Load datasets.
    Returns:
        dataset: distilled dataset
        images_og, labels_og: the full original training samples (used for selection)
        dataset_test: validation (test) set, plus the corresponding samplers
    '''
    # Data loading code
    normalize = transforms.Normalize(mean=[0.4914, 0.4822, 0.4465], std=[0.2023, 0.1994, 0.2010])
    print("Loading distilled data")
    train_transform = transforms.Compose([
        transforms.RandomResizedCrop(32),
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
        normalize,
    ])
    # ImageFolderIPC is a custom loader that takes cpc images per class from a larger distilled set
    # cpc = IPC * alpha: condensed images per class
    dataset = ImageFolderIPC(root=args.distill_data_path, ipc=args.cpc, transform=train_transform)
    print("Loading validation data")
    val_transform = transforms.Compose([
        transforms.ToTensor(),
        normalize,
    ])
    # load the validation (test) set
    dataset_test = torchvision.datasets.CIFAR100(root=args.data_path, train=False, download=True, transform=val_transform)
    print("Loading original training data")
    # load the original training set (for coarse selection / teacher correctness)
    dataset_og = torchvision.datasets.CIFAR100(root=args.data_path, train=True, download=True, transform=val_transform)
    # flatten the full CIFAR-100 into one big tensor images_og plus labels labels_og;
    # feasible while memory allows, but for larger datasets switch to batched or lazy access
    images_og = [torch.unsqueeze(dataset_og[i][0], dim=0) for i in range(len(dataset_og))]
    labels_og = [dataset_og[i][1] for i in range(len(dataset_og))]
    images_og = torch.cat(images_og, dim=0)
    labels_og = torch.tensor(labels_og, dtype=torch.long)
    print("Creating data loaders")
    train_sampler = torch.utils.data.RandomSampler(dataset)
    test_sampler = torch.utils.data.SequentialSampler(dataset_test)
    return dataset, images_og, labels_og, dataset_test, train_sampler, test_sampler


def create_model(model_name, device, num_classes, path=None):
    # build the backbone by name (no pretrained weights by default)
    model = torchvision.models.get_model(model_name, weights=None, num_classes=num_classes)
    # adapt the first conv and pooling layer to CIFAR-style 32x32 inputs
    model.conv1 = nn.Conv2d(3, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
    model.maxpool = nn.Identity()
    # optionally load a checkpoint
    if path is not None:
        checkpoint = torch.load(path, map_location="cpu")
        if "model" in checkpoint:
            checkpoint = checkpoint["model"]
        elif "state_dict" in checkpoint:
            checkpoint = checkpoint["state_dict"]
        if "module." in list(checkpoint.keys())[0]:
            checkpoint = {k.replace("module.", ""): v for k, v in checkpoint.items()}
        model.load_state_dict(checkpoint)
    model.to(device)
    return model


def curriculum_arrangement(spc, curriculum_num):
    '''Curriculum allocation:
    split the spc real images per class evenly over curriculum_num rounds.
    E.g. spc=7, curriculum_num=3 gives [3, 2, 2] (the remainder goes to the earlier rounds).
    '''
    remainder = spc % curriculum_num
    arrangement = [spc // curriculum_num] * curriculum_num
    for i in range(remainder):
        arrangement[i] += 1
    return arrangement


def train_one_epoch(model, teacher_model, criterion, optimizer, data_loader, device, epoch, args):
    """One epoch (one pass over data_loader) of KL-divergence distillation:
    the student (model) learns from teacher_model's soft labels. Both teacher
    and student logits are divided by the temperature T before log_softmax and
    fed to KLDivLoss; the loss is then multiplied by T^2 so the temperature's
    effect on the gradients stays consistent."""
    # student in train mode
    model.train()
    # teacher in eval mode: forward only, never updated
    teacher_model.eval()
    metric_logger = utils.MetricLogger(delimiter="  ")
    metric_logger.add_meter("lr", utils.SmoothedValue(window_size=1, fmt="{value}"))
    metric_logger.add_meter("img/s", utils.SmoothedValue(window_size=10, fmt="{value}"))
    header = f"Epoch: [{epoch}]"
    for i, (image, target) in enumerate(metric_logger.log_every(data_loader, args.print_freq, header)):
        start_time = time.time()
        image, target = image.to(device), target.to(device)
        # 1) teacher forward
        teacher_output = teacher_model(image)
        # 2) student forward
        output = model(image)
        # divide logits by the temperature, then take log_softmax
        teacher_output_log_softmax = F.log_softmax(teacher_output / args.temperature, dim=1)
        output_log_softmax = F.log_softmax(output / args.temperature, dim=1)
        # KL loss, scaled by T^2 to undo the gradient scaling introduced by the temperature
        loss = criterion(output_log_softmax, teacher_output_log_softmax) * (args.temperature ** 2)
        # standard backward pass
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        # student top-1/top-5 accuracy against the original hard labels (target)
        acc1, acc5 = utils.accuracy(output, target, topk=(1, 5))
        batch_size = image.shape[0]
        # update the metric logger
        metric_logger.update(loss=loss.item(), lr=optimizer.param_groups[0]["lr"])
        metric_logger.meters["acc1"].update(acc1.item(), n=batch_size)
        metric_logger.meters["acc5"].update(acc5.item(), n=batch_size)
        metric_logger.meters["img/s"].update(batch_size / (time.time() - start_time))


def evaluate(model, criterion, data_loader, device, log_suffix=""):
    """Full forward pass over the test/validation set: cross-entropy loss and
    top-1/top-5 accuracy. Inference only, no gradient updates."""
    model.eval()
    metric_logger = utils.MetricLogger(delimiter="  ")
    header = f"Test: {log_suffix}"
    num_processed_samples = 0  # running count of processed samples
    with torch.inference_mode():
        for image, target in data_loader:
            image = image.to(device, non_blocking=True)
            target = target.to(device, non_blocking=True)
            # forward
            output = model(image)
            # cross-entropy against hard labels
            loss = criterion(output, target)
            # top-1/top-5
            acc1, acc5 = utils.accuracy(output, target, topk=(1, 5))
            batch_size = image.shape[0]
            metric_logger.update(loss=loss.item())
            metric_logger.meters["acc1"].update(acc1.item(), n=batch_size)
            metric_logger.meters["acc5"].update(acc5.item(), n=batch_size)
            num_processed_samples += batch_size
    # in distributed mode, sum the per-GPU sample counts
    num_processed_samples = utils.reduce_across_processes(num_processed_samples)
    if (
        hasattr(data_loader.dataset, "__len__")
        and len(data_loader.dataset) != num_processed_samples
        and torch.distributed.get_rank() == 0
    ):
        warnings.warn(
            f"It looks like the dataset has {len(data_loader.dataset)} samples, but {num_processed_samples} "
            "samples were used for the validation, which might bias the results. "
            "Try adjusting the batch size and / or the world size. "
            "Setting the world size to 1 is always a safe bet."
        )
    metric_logger.synchronize_between_processes()
    return metric_logger.acc1.global_avg


def curriculum_train(current_curriculum, dst_train, test_loader, model, teacher_model, args):
    """Run one full filter-model training (distillation) on the current
    "synthetic + selected real" set dst_train:
    - batch size adapted to the dataset size
    - build the DataLoader / criterion / optimizer / LR scheduler (with warmup)
    - train for args.epochs epochs; validate in the last 20% and track the best acc1
    Returns: the trained model and the best top-1 accuracy best_acc1
    """
    best_acc1 = 0
    # 1. pick a batch size roughly according to len(dst_train) (synthetic + real)
    if len(dst_train) < 50 * args.num_classes:
        args.batch_size = 32
    elif 50 * args.num_classes <= len(dst_train) < 100 * args.num_classes:
        args.batch_size = 64
    else:
        args.batch_size = 128
    # 2. random sampler so each epoch is reshuffled
    train_sampler = torch.utils.data.RandomSampler(dst_train)
    train_loader = torch.utils.data.DataLoader(
        dst_train,
        batch_size=args.batch_size,
        sampler=train_sampler,
        num_workers=args.workers,
        pin_memory=True,
    )
    # 3. losses: cross-entropy for hard labels, KLDiv for the distillation soft labels
    criterion = nn.CrossEntropyLoss()
    criterion_kl = nn.KLDivLoss(reduction='batchmean', log_target=True)
    parameters = utils.set_weight_decay(model, args.weight_decay)
    # build the optimizer
    opt_name = args.opt.lower()
    if opt_name.startswith("sgd"):
        optimizer = torch.optim.SGD(
            parameters,
            lr=args.lr,
            momentum=args.momentum,
            weight_decay=args.weight_decay,
            nesterov="nesterov" in opt_name,
        )
    elif opt_name == "rmsprop":
        optimizer = torch.optim.RMSprop(parameters, lr=args.lr, momentum=args.momentum, weight_decay=args.weight_decay, eps=0.0316, alpha=0.9)
    elif opt_name == "adamw":
        optimizer = torch.optim.AdamW(parameters, lr=args.lr, weight_decay=args.weight_decay)
    else:
        raise RuntimeError(f"Invalid optimizer {args.opt}. Only SGD, RMSprop and AdamW are supported.")
    # 6. main LR scheduler: StepLR / CosineAnnealingLR / ExponentialLR
    args.lr_scheduler = args.lr_scheduler.lower()
    if args.lr_scheduler == "steplr":
        main_lr_scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=args.lr_step_size, gamma=args.lr_gamma)
    elif args.lr_scheduler == "cosineannealinglr":
        main_lr_scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=args.epochs - args.lr_warmup_epochs, eta_min=0.0)
    elif args.lr_scheduler == "exponentiallr":
        main_lr_scheduler = torch.optim.lr_scheduler.ExponentialLR(optimizer, gamma=args.lr_gamma)
    else:
        raise RuntimeError(
            f"Invalid lr scheduler '{args.lr_scheduler}'. Only StepLR, CosineAnnealingLR and ExponentialLR "
            "are supported."
        )
    # 7. if warmup is configured, chain the warmup scheduler with the main one
    if args.lr_warmup_epochs > 0:
        if args.lr_warmup_method == "linear":
            warmup_lr_scheduler = torch.optim.lr_scheduler.LinearLR(optimizer, start_factor=args.lr_warmup_decay, total_iters=args.lr_warmup_epochs)
        elif args.lr_warmup_method == "constant":
            warmup_lr_scheduler = torch.optim.lr_scheduler.ConstantLR(optimizer, factor=args.lr_warmup_decay, total_iters=args.lr_warmup_epochs)
        else:
            raise RuntimeError(f"Invalid warmup lr method '{args.lr_warmup_method}'. Only linear and constant are supported.")
        # milestones: switch to main_lr_scheduler after args.lr_warmup_epochs epochs
        lr_scheduler = torch.optim.lr_scheduler.SequentialLR(optimizer, schedulers=[warmup_lr_scheduler, main_lr_scheduler], milestones=[args.lr_warmup_epochs])
    else:
        lr_scheduler = main_lr_scheduler
    # 8. training loop
    print("Start training on synthetic dataset...")
    start_time = time.time()
    pbar = tqdm(range(args.epochs), ncols=100)
    for epoch in pbar:
        # each epoch runs the KL-distillation train_one_epoch defined above
        train_one_epoch(model, teacher_model, criterion_kl, optimizer, train_loader, args.device, epoch, args)
        # step the LR schedule after each epoch
        lr_scheduler.step()
        # validate only in the last 20% of epochs to save time
        if epoch > args.epochs * 0.8:
            acc1 = evaluate(model, criterion, test_loader, device=args.device)  # evaluate uses hard-label loss & accuracy
            # track the best accuracy
            if acc1 > best_acc1:
                best_acc1 = acc1
            # show the current/best accuracy on the progress bar
            pbar.set_description(f"Epoch[{epoch}] Test Acc: {acc1:.2f}% Best Acc: {best_acc1:.2f}%")
    print(f"Best Accuracy {best_acc1:.2f}%")
    total_time = time.time() - start_time
    total_time_str = str(datetime.timedelta(seconds=int(total_time)))
    print(f"Training time {total_time_str}")
    return model, best_acc1


def coarse_filtering(images_all, labels_all, filter, batch_size, args, get_correct=True):
    """Run the filter model over the whole original training set images_all:
    - if get_correct=True, return the indices of correctly predicted samples;
      otherwise the indices of misclassified samples.
    - also return the raw (pre-softmax) logits of all samples.
    """
    true_labels = labels_all.cpu()
    filter.eval()   # forward only, no parameter updates
    logits = None
    # batched inference to avoid OOM
    for select_times in range((len(images_all) + batch_size - 1) // batch_size):
        # slice the current batch of images;
        # detach to cut gradient history, then move to the device
        current_data_batch = images_all[batch_size * select_times: batch_size * (select_times + 1)].detach().to(args.device)
        # forward
        batch_logits = filter(current_data_batch)
        # concatenate the batches
        if logits is None:
            logits = batch_logits.detach()
        else:
            logits = torch.cat((logits, batch_logits.detach()), 0)
    # the argmax of each row is the predicted label
    predicted_labels = torch.argmax(logits, dim=1).cpu()
    # keep correct or incorrect indices depending on get_correct
    target_indices = torch.where(true_labels == predicted_labels)[0] if get_correct else torch.where(true_labels != predicted_labels)[0]
    target_indices = target_indices.tolist()
    print('Acc on training set: {:.2f}%'.format(100 * len(target_indices) / len(images_all) if get_correct else 100 * (1 - len(target_indices) / len(images_all))))
    return target_indices, logits


def selection_logits(selected_idx, teacher_correct_idx, images_all, labels_all, filter, args):
    """Fine-stage selection using the filter model's logits:
    - teacher_correct_idx: indices the teacher predicts correctly on the original training set
    - selected_idx: indices already chosen in earlier rounds, to avoid duplicates
    Returns the list of indices newly selected in this round."""
    batch_size = 512
    true_labels = labels_all.cpu()
    filter.eval()
    print('Coarse Filtering...')
    # --- coarse stage: decide which samples enter the fine stage
    # if select_misclassified=True, keep the samples the filter predicts WRONG
    if args.select_misclassified:
        target_indices, logits = coarse_filtering(images_all, labels_all, filter, batch_size, args, get_correct=False)
    else:
        target_indices, logits = coarse_filtering(images_all, labels_all, filter, batch_size, args, get_correct=True)
    # -- cross-filtering: keep only samples the teacher also predicts correctly, minus those already selected
    # (the paper only considers samples the teacher itself gets right)
    if teacher_correct_idx is not None:
        # intersect teacher_correct_idx with target_indices, then subtract selected_idx
        target_indices = list(set(teacher_correct_idx) & set(target_indices) - set(selected_idx))
    else:
        target_indices = list(set(target_indices) - set(selected_idx))
    print('Fine Selection...')
    selection = []
    if args.balance:
        # class-balanced: pick args.curpc per class
        target_idx_per_class = [[] for c in range(args.num_classes)]
        for idx in target_indices:
            target_idx_per_class[true_labels[idx]].append(idx)
        for c in range(args.num_classes):
            if args.select_method == 'random':
                # random sampling (min() guards against classes with few candidates)
                selection += random.sample(target_idx_per_class[c], min(args.curpc, len(target_idx_per_class[c])))
            elif args.select_method == 'hard':
                # ascending logits[c]: lower logit = less confident = "harder"
                selection += sorted(target_idx_per_class[c], key=lambda i: logits[i][c], reverse=False)[:args.curpc]
            elif args.select_method == 'simple':
                # descending logits[c]: higher logit = more confident = "simpler"
                selection += sorted(target_idx_per_class[c], key=lambda i: logits[i][c], reverse=True)[:args.curpc]
    else:
        # no class balance: pick curpc * num_classes from all target_indices
        if args.select_method == 'random':
            selection = random.sample(target_indices, args.curpc * args.num_classes)
        elif args.select_method == 'hard':
            selection = sorted(target_indices, key=lambda i: logits[i][true_labels[i]], reverse=False)[:args.curpc * args.num_classes]
        elif args.select_method == 'simple':
            selection = sorted(target_indices, key=lambda i: logits[i][true_labels[i]], reverse=True)[:args.curpc * args.num_classes]
    return selection


def selection_score(selected_idx, teacher_correct_idx, images_all, labels_all, filter, score, reverse, args):
    """Fine-stage selection using a precomputed difficulty score:
    - score: numpy array; score[i] is the difficulty of sample i
    - reverse: bool
    Same flow as selection_logits, but the ranking key is score instead of logits."""
    batch_size = 512
    true_labels = labels_all.cpu()
    filter.eval()
    print('Coarse Filtering...')
    # coarse stage, as above
    if args.select_misclassified:
        target_indices, _ = coarse_filtering(images_all, labels_all, filter, batch_size, args, get_correct=False)
    else:
        target_indices, _ = coarse_filtering(images_all, labels_all, filter, batch_size, args, get_correct=True)
    # cross-filter with teacher_correct_idx and drop already-selected indices
    if teacher_correct_idx is not None:
        target_indices = list(set(teacher_correct_idx) & set(target_indices) - set(selected_idx))
    else:
        target_indices = list(set(target_indices) - set(selected_idx))
    print('Fine Selection...')
    selection = []
    if args.balance:
        target_idx_per_class = [[] for c in range(args.num_classes)]
        for idx in target_indices:
            target_idx_per_class[true_labels[idx]].append(idx)
        for c in range(args.num_classes):
            if args.select_method == 'random':
                selection += random.sample(target_idx_per_class[c], min(args.curpc, len(target_idx_per_class[c])))
            elif args.select_method == 'hard':
                selection += sorted(target_idx_per_class[c], key=lambda i: score[i], reverse=reverse)[:args.curpc]
            elif args.select_method == 'simple':
                # rank by the external (precomputed) difficulty score
                selection += sorted(target_idx_per_class[c], key=lambda i: score[i], reverse=not reverse)[:args.curpc]
    else:
        if args.select_method == 'random':
            selection = random.sample(target_indices, min(args.curpc * args.num_classes, len(target_indices)))
        elif args.select_method == 'hard':
            selection = sorted(target_indices, key=lambda i: score[i], reverse=reverse)[:args.curpc * args.num_classes]
        elif args.select_method == 'simple':
            selection = sorted(target_indices, key=lambda i: score[i], reverse=not reverse)[:args.curpc * args.num_classes]
    return selection


def main(args):
    '''Preparation'''
    print('=> args.output_dir', args.output_dir)
    start_time = datetime.datetime.now().strftime('%Y-%m-%d_%H-%M-%S')
    log_dir = os.path.join(args.output_dir, 'CIFAR-100', start_time)
    os.makedirs(log_dir, exist_ok=True)
    device = torch.device(args.device)
    if device.type == 'cuda':
        print('Using GPU')
        torch.backends.cudnn.benchmark = True
    # compute cpc (condensed images per class) and spc (real images to select per class)
    args.cpc = int(args.image_per_class * args.alpha)   # condensed images per class
    args.spc = args.image_per_class - args.cpc          # selected real images per class
    args.num_classes = 100
    print('Target IPC: {}, num_classes: {}, distillation portion: {}, distilled images per class: {}, real images to be selected per class: {}'.format(args.image_per_class, args.num_classes, args.alpha, args.cpc, args.spc))
    # load the data
    dataset_dis, images_og, labels_og, dataset_test, train_sampler, test_sampler = load_data(args)
    # load the difficulty score
    if args.score == 'forgetting':
        score = np.load(args.score_path)
        reverse = True
    elif args.score == 'cscore':
        score = np.load(args.score_path)
        reverse = False
    curriculum_num = args.curriculum_num
    # curriculum arrangement: how many real samples to select in each round
    arrangement = curriculum_arrangement(args.spc, curriculum_num)
    # test loader
    test_loader = torch.utils.data.DataLoader(dataset_test, batch_size=512, sampler=test_sampler, num_workers=args.workers, pin_memory=True)
    # load the teacher model
    teacher_model = create_model(args.teacher_model, device, args.num_classes, args.teacher_path)
    # freeze the teacher's parameters
    for p in teacher_model.parameters():
        p.requires_grad = False
    teacher_model.eval()
    # run the teacher once over the original dataset: only samples the teacher
    # predicts correctly are eligible for selection
    teacher_correct_idx, _ = coarse_filtering(images_og, labels_og, teacher_model, 512, args, get_correct=True)
    print('teacher acc@1 on original training data: {:.2f}%'.format(100 * len(teacher_correct_idx) / len(images_og)))

    '''Curriculum selection'''
    idx_selected = []
    dataset_sel = None
    dst_sel_transform = transforms.Compose([
        transforms.RandomResizedCrop(32),
        transforms.RandomHorizontalFlip(),
    ])
    print('Selected images per class arrangement in each curriculum: ', arrangement)
    # curriculum loop
    for i in range(curriculum_num):
        print('----Curriculum [{}/{}]----'.format(i + 1, curriculum_num))
        args.curpc = arrangement[i]
        # round 0 starts from the distilled dataset
        if i == 0:
            print('Begin with distilled dataset')
            syn_dataset = dataset_dis
            dataset_sel = []
        print('Synthetic dataset size:', len(syn_dataset), "distilled data:", len(dataset_dis), "selected data:", len(dataset_sel))
        # train a fresh filter from scratch in every round
        filter = create_model(args.filter_model, device, args.num_classes)
        # curriculum training; the teacher provides the soft labels
        filter, best_acc1 = curriculum_train(i, syn_dataset, test_loader, filter, teacher_model, args)
        print('Selecting real data...')
        if args.score == 'logits':
            selection = selection_logits(idx_selected, teacher_correct_idx, images_og, labels_og, filter, args)
        else:
            selection = selection_score(idx_selected, teacher_correct_idx, images_og, labels_og, filter, score, reverse, args)
        idx_selected += selection
        print('Selected {} in this curriculum'.format(len(selection)))
        imgs_select = images_og[idx_selected]
        labs_select = labels_og[idx_selected]
        dataset_sel = utils.TensorDataset(imgs_select, labs_select, dst_sel_transform)
        syn_dataset = torch.utils.data.ConcatDataset([dataset_dis, dataset_sel])
    print('----All curricula finished----')
    print('Final synthetic dataset size:', len(syn_dataset), "distilled data:", len(dataset_dis), "selected data:", len(dataset_sel))
    print('Saving selected indices...')
    idx_file = os.path.join(log_dir, 'selected_indices.json')
    with open(idx_file, 'w') as f:
        json.dump({'ipc': args.image_per_class, 'alpha': args.alpha, 'idx_selected': idx_selected}, f)

    '''Final evaluation'''
    num_eval = args.num_eval
    accs = []
    for i in range(num_eval):
        print(f'Evaluation {i + 1}/{num_eval}')
        eval_model = create_model(args.eval_model, device, args.num_classes)
        _, best_acc1 = curriculum_train(0, syn_dataset, test_loader, eval_model, teacher_model, args)
        accs.append(best_acc1)
    acc_mean = np.mean(accs)
    acc_std = np.std(accs)
    print('----Evaluation Results----')
    print(f'Acc@1(mean): {acc_mean:.2f}%, std: {acc_std:.2f}')
    print('Saving results...')
    log_file = os.path.join(log_dir, 'exp_log.txt')
    with open(log_file, 'w') as f:
        f.write('EXP Settings: \n')
        f.write(f'IPC: {args.image_per_class},\tdistillation portion: {args.alpha},\tcurriculum_num: {args.curriculum_num}\n')
        f.write(f'filter model: {args.filter_model},\tteacher model: {args.teacher_model},\tbatch_size: {args.batch_size},\tepochs: {args.epochs}\n')
        f.write(f"coarse stage strategy: {'select misclassified' if args.select_misclassified else 'select correctly classified'}\n")
        f.write(f'fine stage strategy: {args.select_method},\tdifficulty score: {args.score},\tbalance: {args.balance}\n')
        f.write(f'eval model: {args.eval_model},\tAcc@1: {acc_mean:.2f}%,\tstd: {acc_std:.2f}\n')


if __name__ == "__main__":
    args = get_args_parser().parse_args()
    main(args)
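As a sanity check on the distillation loss in train_one_epoch (logits divided by T before log_softmax, KL loss scaled by T²), here is a dependency-free re-derivation with made-up logits; the real code uses nn.KLDivLoss(reduction='batchmean', log_target=True), which computes the same quantity on batches:

```python
import math

def log_softmax(logits, T):
    # temperature-scaled log_softmax, computed stably via log-sum-exp
    z = [v / T for v in logits]
    m = max(z)
    lse = m + math.log(sum(math.exp(v - m) for v in z))
    return [v - lse for v in z]

def kd_loss(student_logits, teacher_logits, T):
    s = log_softmax(student_logits, T)
    t = log_softmax(teacher_logits, T)
    # KL(teacher || student) with both arguments in log space (log_target=True)
    kl = sum(math.exp(ti) * (ti - si) for ti, si in zip(t, s))
    return kl * T * T   # T^2 undoes the gradient scaling introduced by the temperature

loss = kd_loss([2.0, 0.5, -1.0], [1.5, 0.8, -0.5], T=20)
```

When student and teacher logits coincide, the loss is exactly zero; otherwise it is strictly positive, as expected of a KL divergence.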

Overall Algorithm Logic

Algorithm inputs

  • Distilled synthetic set $D_{distill}$: a small synthetic dataset produced by some distillation algorithm, with $cpc = IPC \times \alpha$ images per class.
  • Original training set (images_og, labels_og): the full CIFAR-100 training samples, used for selection.
  • Teacher model $\phi_{teacher}$: a fixed model that performs well on the original training set, providing the "correct" soft-label reference.
  • Hyperparameters: $IPC$ (final images per class), $\alpha$ (distillation portion), number of curriculum rounds $J$, coarse strategy, fine strategy, difficulty-score type, etc.

Overall flow

  1. Initialization

    • Parse the command-line arguments and compute $cpc = \lfloor IPC \times \alpha \rfloor$ and $spc = IPC - cpc$.
    • Load $D_{distill}$, the original training set, and the validation set.
    • Load and freeze the teacher $\phi_{teacher}$; run it once over the original training set and record the index set of correctly predicted samples $I_{teacher}$.
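With the script's default values (IPC = 50, α = 0.2, used here purely for illustration), the split in step 1 works out as:

```python
IPC, alpha = 50, 0.2
cpc = int(IPC * alpha)   # condensed (distilled) images per class
spc = IPC - cpc          # real images still to be selected per class
# cpc == 10, spc == 40
```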
  2. Curriculum allocation

    Split the $spc$ real images to be selected per class evenly over $J$ rounds: $[k_1, k_2, \dots, k_J]$ with $\sum_j k_j = spc$.
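This allocation is exactly what curriculum_arrangement does in the code; a compact equivalent:

```python
def curriculum_arrangement(spc, J):
    # distribute spc over J rounds; earlier rounds absorb the remainder
    base, rem = divmod(spc, J)
    return [base + 1 if i < rem else base for i in range(J)]

curriculum_arrangement(40, 3)   # [14, 13, 13]
curriculum_arrangement(7, 3)    # [3, 2, 2]
```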

  3. Multi-round coarse-to-fine selection loop

    Initialize $S_0 = D_{distill}$ and the selected-index set $I_{sel} = \varnothing$.

    For each curriculum stage $j = 1 \dots J$:

    i. Distill-train the filter:

    • On $S_{j-1}$, train a fresh filter model $\phi_j$ using the teacher's soft labels.

    ii. Coarse (coarse filtering)

    • Run $\phi_j$ over the whole original training set to obtain logits and predicted labels for every sample.
    • Depending on select_misclassified, keep either the "misclassified" or the "correctly classified" indices as the candidate set $C$.
    • Cross-filter: keep only indices that are in $C$ and in $I_{teacher}$, but not in $I_{sel}$.
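The cross-filtering step is plain set algebra; with toy index sets:

```python
C         = {1, 3, 4, 5, 9}     # candidate set from the coarse stage
I_teacher = {0, 1, 2, 3, 5, 9}  # indices the teacher predicts correctly
I_sel     = {9}                 # indices already selected in earlier rounds

candidates = (C & I_teacher) - I_sel
# candidates == {1, 3, 5}
```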

    iii. Fine (fine selection)

    • Rank the remaining candidate indices by the filter's logits or by an external precomputed difficulty score, using the simple, hard, or random rule.
    • With --balance, take $k_j$ per class; otherwise take $k_j \times num\_classes$ from the whole pool.
    • Add this round's newly selected indices to $I_{sel}$.
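Fine selection with a precomputed score, mirroring selection_score: for forgetting scores, higher means harder, so "simple" sorts ascending and "hard" descending (toy numbers below):

```python
def fine_select(cands, score, k, method):
    if method == "simple":
        return sorted(cands, key=lambda i: score[i])[:k]                # easiest first
    if method == "hard":
        return sorted(cands, key=lambda i: score[i], reverse=True)[:k]  # hardest first
    raise ValueError(method)

score = {1: 0.9, 3: 0.2, 5: 0.4}   # e.g. forgetting counts; higher = harder
fine_select([1, 3, 5], score, k=2, method="simple")   # [3, 5]
fine_select([1, 3, 5], score, k=2, method="hard")     # [1, 5]
```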

    iv. Update the synthetic set
    $S_j \leftarrow D_{distill} \cup \{\text{real sample } i : i \in I_{sel}\}$.

  4. Saving & final evaluation

    • Write $I_{sel}$ to JSON for later reproduction.
    • Train a fresh evaluation model on the final mixed set $S_J$, measure Top-1 accuracy over several runs, and report the mean and standard deviation.

This loop targets, in each round, exactly what the model has "not yet learned" and supplements it in order, so the final synthetic set covers both common knowledge and the key hard cases.

