一. 環境
- Windows 11
- RTX5060
- CUDA 12.8
- PyTorch 2.9.0dev20250630+cu128
- torchvision 0.23.0dev20250701+cu128
二. 代碼
基于Mobilenet-CustomData 的Mobilenet_Pretrain.ipynb
1. 定義Mobile Net V1
import os
import time

import torch
import torch.nn as nn
import torchvision.datasets as datasets
import torchvision.transforms as transforms

from utils_pretrain import *

# Define your data path here
IMAGENET_PATH = r'E:\AI\tiny-imagenet-200\tiny-imagenet-200'


class Net(nn.Module):
    """MobileNet-V1 backbone (width multiplier 1.0) with a linear classifier.

    Args:
        num_classes: size of the classifier output. Defaults to 1000, matching
            the original hard-coded head, so existing callers and the
            fine-tuned checkpoint are unaffected.
    """

    def __init__(self, num_classes: int = 1000):
        super(Net, self).__init__()

        def conv_bn(inp, oup, stride):
            # Standard 3x3 conv -> BN -> ReLU stem block.
            return nn.Sequential(
                nn.Conv2d(inp, oup, 3, stride, 1, bias=False),
                nn.BatchNorm2d(oup),
                nn.ReLU(inplace=True),
            )

        def conv_dw(inp, oup, stride):
            # Depthwise-separable block: 3x3 depthwise conv + 1x1 pointwise.
            return nn.Sequential(
                nn.Conv2d(inp, inp, 3, stride, 1, groups=inp, bias=False),
                nn.BatchNorm2d(inp),
                nn.ReLU(inplace=True),
                nn.Conv2d(inp, oup, 1, 1, 0, bias=False),
                nn.BatchNorm2d(oup),
                nn.ReLU(inplace=True),
            )

        self.model = nn.Sequential(
            conv_bn(3, 32, 2),
            conv_dw(32, 64, 1),
            conv_dw(64, 128, 2),
            conv_dw(128, 128, 1),
            conv_dw(128, 256, 2),
            conv_dw(256, 256, 1),
            conv_dw(256, 512, 2),
            conv_dw(512, 512, 1),
            conv_dw(512, 512, 1),
            conv_dw(512, 512, 1),
            conv_dw(512, 512, 1),
            conv_dw(512, 512, 1),
            conv_dw(512, 1024, 2),
            conv_dw(1024, 1024, 1),
            # Generalizes the original AvgPool2d(7): for the expected 224x224
            # input the feature map here is 7x7, so the result is identical,
            # but other input sizes now work too. The layer is parameter-free,
            # so checkpoint keys are unchanged.
            nn.AdaptiveAvgPool2d(1),
        )
        self.fc = nn.Linear(1024, num_classes)

    def forward(self, x):
        x = self.model(x)
        x = x.view(x.size(0), -1)  # flatten (B, 1024, 1, 1) -> (B, 1024)
        x = self.fc(x)
        return x


mobilenet_model1 = Net()
2. 加載fine_tune模型
# Wrap in DataParallel so the checkpoint's 'module.'-prefixed keys match,
# then move the model to GPU.
mobilenet_model1 = torch.nn.DataParallel(mobilenet_model1).cuda()

# Raw string: the original plain string contained invalid escape sequences
# such as "\C" and "\M", which emit a SyntaxWarning on recent Pythons.
params = torch.load(r'D:\CodeSpace\MobileNet\Mobilenet-CustomData-master\Mobilenet-CustomData-master\Mobiilenet-finetune\moblienet_30e.pth.tar')
# NOTE(review): strict=False silently ignores missing/unexpected keys --
# consider inspecting the returned IncompatibleKeys object to confirm that
# the weights actually loaded.
mobilenet_model1.load_state_dict(params, strict=False)
3. validate 設置
# Evaluation settings.
criterion = nn.CrossEntropyLoss().cuda()  # classification loss on GPU
batch_size = 10
workers = 4        # DataLoader worker processes
epochs = 1         # single validation pass
print_freq = 100   # log every `print_freq` batches

# Validation split is expected under <IMAGENET_PATH>/val.
valdir = os.path.join(IMAGENET_PATH, 'val')
# Abort early with a clear message if the validation directory is missing.
if not os.path.exists(valdir):
    print(f"val文件夾 {valdir} 不存在,請檢查路徑。")
    raise SystemExit(1)  # safer than exit(): works outside the REPL too

# Standard ImageNet channel statistics.
normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                 std=[0.229, 0.224, 0.225])

val_loader = torch.utils.data.DataLoader(
    datasets.ImageFolder(valdir, transforms.Compose([
        transforms.Resize(256),
        transforms.CenterCrop(224),
        transforms.ToTensor(),
        normalize,
    ])),
    batch_size=batch_size, shuffle=False,
    num_workers=workers, pin_memory=True)


def validate(val_loader, model, criterion):
    """Run one evaluation pass over `val_loader`.

    Args:
        val_loader: DataLoader yielding (images, target) batches.
        model: network to evaluate (moved to GPU by the caller).
        criterion: loss function, e.g. CrossEntropyLoss.

    Returns:
        Average top-1 accuracy over the whole loader (float).
    """
    batch_time = AverageMeter()
    losses = AverageMeter()
    top1 = AverageMeter()
    top5 = AverageMeter()

    # Switch to evaluate mode (disables dropout, freezes BN statistics).
    model.eval()

    end = time.time()
    # no_grad wraps the whole loop: no autograd bookkeeping anywhere during
    # evaluation (the original re-entered the context every batch and also
    # moved `target` to the GPU twice).
    with torch.no_grad():
        for i, (images, target) in enumerate(val_loader):
            target = target.cuda(non_blocking=True)
            images = images.cuda()

            # Compute output and loss.
            output = model(images)
            loss = criterion(output, target)

            # Measure accuracy and record loss (.item(), not the old .data[0]).
            prec1, prec5 = accuracy(output.data, target, topk=(1, 5))
            losses.update(loss.item(), images.size(0))
            top1.update(prec1.item(), images.size(0))
            top5.update(prec5.item(), images.size(0))

            # Measure elapsed time per batch.
            batch_time.update(time.time() - end)
            end = time.time()

            if i % print_freq == 0:
                print('Test: [{0}/{1}]\t'
                      'Time {batch_time.val:.3f} ({batch_time.avg:.3f})\t'
                      'Loss {loss.val:.4f} ({loss.avg:.4f})\t'
                      'Prec@1 {top1.val:.3f} ({top1.avg:.3f})\t'
                      'Prec@5 {top5.val:.3f} ({top5.avg:.3f})'.format(
                          i, len(val_loader), batch_time=batch_time,
                          loss=losses, top1=top1, top5=top5))

    print(' * Prec@1 {top1.avg:.3f} Prec@5 {top5.avg:.3f}'.format(
        top1=top1, top5=top5))

    return top1.avg
4. validate 和結果
# Run the evaluation loop and track the best top-1 accuracy seen.
best_prec1 = 0
for epoch in range(0, epochs):
    # Evaluate on the validation set.
    prec1 = validate(val_loader, mobilenet_model1, criterion)
    # prec3 = validate(val_loader, mobilenet_model3, criterion)

    # Remember best prec@1 and save checkpoint.
    is_best = prec1 > best_prec1
    best_prec1 = max(prec1, best_prec1)
Test: [0/1000] Time 10.112 (10.112) Loss 3.6720 (3.6720) Prec@1 20.000 (20.000) Prec@5 90.000 (90.000)
Test: [100/1000] Time 0.008 (0.109) Loss 2.5464 (3.8335) Prec@1 40.000 (4.752) Prec@5 100.000 (84.851)
Test: [200/1000] Time 0.009 (0.059) Loss 4.6422 (3.8443) Prec@1 0.000 (5.174) Prec@5 70.000 (83.333)
Test: [300/1000] Time 0.009 (0.042) Loss 3.4911 (3.8356) Prec@1 10.000 (5.316) Prec@5 80.000 (83.588)
Test: [400/1000] Time 0.009 (0.034) Loss 4.4792 (3.8361) Prec@1 0.000 (5.162) Prec@5 90.000 (83.691)
Test: [500/1000] Time 0.009 (0.028) Loss 3.6648 (3.8577) Prec@1 0.000 (5.070) Prec@5 90.000 (83.972)
Test: [600/1000] Time 0.008 (0.025) Loss 2.6866 (3.8470) Prec@1 20.000 (5.158) Prec@5 90.000 (84.160)
Test: [700/1000] Time 0.008 (0.023) Loss 3.4880 (3.8465) Prec@1 0.000 (5.235) Prec@5 100.000 (84.379)
Test: [800/1000] Time 0.008 (0.021) Loss 4.0429 (3.8602) Prec@1 0.000 (5.006) Prec@5 70.000 (84.120)
Test: [900/1000] Time 0.009 (0.019) Loss 3.8612 (3.8640) Prec@1 0.000 (5.006) Prec@5 60.000 (84.029)
 * Prec@1 5.060 Prec@5 83.860