1.背景
在網上找了一些用來訓練關鍵點的資料，一般都是人臉或車牌關鍵點訓練，或者是與檢測聯合訓練。很少有基於輕量級網絡單獨訓練關鍵點模型的工程，本文簡單介紹一種方法及代碼。
2.代碼模塊
(1)網絡結構
文件:model.py
import torch.nn as nn
import torch
import torch.nn.functional as F
import torch.nn.init as init
class Fire(nn.Module):
    """SqueezeNet-style "Fire" module.

    A 1x1 "squeeze" convolution (followed by ReLU) feeds two parallel
    "expand" convolutions, a 1x1 and a 3x3 (padding=1), whose outputs are
    concatenated along the channel dimension.

    Args:
        inplanes: number of input channels.
        squeeze_planes: output channels of the 1x1 squeeze convolution.
        expand1x1_planes: output channels of the 1x1 expand branch.
        expand3x3_planes: output channels of the 3x3 expand branch.

    The output has ``expand1x1_planes + expand3x3_planes`` channels and the
    same spatial size as the input. No activation is applied after the
    expand branches; callers add one where they want it.
    """

    def __init__(self, inplanes, squeeze_planes,
                 expand1x1_planes, expand3x3_planes):
        super().__init__()
        self.inplanes = inplanes
        # Submodules are created in the original order so that random
        # parameter initialisation and state_dict keys are unchanged.
        self.squeeze = nn.Conv2d(inplanes, squeeze_planes, kernel_size=1)
        self.squeeze_activation = nn.ReLU(inplace=True)
        self.expand1x1 = nn.Conv2d(squeeze_planes, expand1x1_planes,
                                   kernel_size=1)
        self.expand3x3 = nn.Conv2d(squeeze_planes, expand3x3_planes,
                                   kernel_size=3, padding=1)

    def forward(self, x):
        squeezed = self.squeeze_activation(self.squeeze(x))
        branches = (self.expand1x1(squeezed), self.expand3x3(squeezed))
        return torch.cat(branches, dim=1)
class RegressNet(nn.Module):
    """Lightweight SqueezeNet-style network that regresses keypoint coordinates.

    Args:
        version: 1.0 or 1.1, selecting the feature-extractor layout.
        export: stored as ``self.export``; not read anywhere in this class.
        num_points: number of (x, y) keypoints to regress. The head outputs
            ``2 * num_points`` values; the default 4 keeps the original
            8-value output and checkpoint layout.

    Raises:
        ValueError: if ``version`` is not 1.0 or 1.1.
    """

    def __init__(self, version=1.0, export=False, num_points=4):
        super(RegressNet, self).__init__()
        if version not in [1.0, 1.1]:
            raise ValueError("Unsupported RegressNet version {version}:"
                             "1.0 or 1.1 expected".format(version=version))
        self.export = export
        if version == 1.0:
            # For a 3x64x64 input this stack ends in a 128x1x1 feature map.
            self.features = nn.Sequential(
                nn.Conv2d(3, 16, kernel_size=3, padding=(1, 1), stride=1),
                nn.ReLU(inplace=True),
                nn.MaxPool2d(kernel_size=2, stride=2, ceil_mode=True),
                Fire(16, 16, 32, 32),
                nn.ReLU(inplace=True),
                nn.MaxPool2d(kernel_size=2, stride=2, ceil_mode=True),
                Fire(64, 32, 32, 32),
                nn.ReLU(inplace=True),
                nn.MaxPool2d(kernel_size=2, stride=2, ceil_mode=True),
                Fire(64, 32, 64, 64),
                nn.ReLU(inplace=True),
                Fire(128, 32, 64, 64),
                nn.ReLU(inplace=True),
                nn.MaxPool2d(kernel_size=3, stride=2, ceil_mode=True),
                nn.Conv2d(128, 128, kernel_size=3, padding=(0, 0), stride=2),
            )
            feature_channels = 128
        else:
            self.features = nn.Sequential(
                nn.Conv2d(3, 64, kernel_size=3, stride=2),
                nn.ReLU(inplace=True),
                nn.MaxPool2d(kernel_size=3, stride=2, ceil_mode=True),
                Fire(64, 16, 64, 64),
                Fire(128, 16, 64, 64),
                nn.MaxPool2d(kernel_size=3, stride=2, ceil_mode=True),
                Fire(128, 32, 128, 128),
                Fire(256, 32, 128, 128),
                nn.MaxPool2d(kernel_size=3, stride=2, ceil_mode=True),
                Fire(256, 48, 192, 192),
                Fire(384, 48, 192, 192),
                Fire(384, 64, 256, 256),
                Fire(512, 64, 256, 256),
                # Appended after the original layers (earlier indices are
                # unchanged): this stack leaves a 512xHxW map (3x3 for a
                # 64x64 input), so pool it to 512x1x1 for the linear head.
                nn.AdaptiveAvgPool2d((1, 1)),
            )
            feature_channels = 512
        # Previously hard-coded to Linear(128, 8), which made version 1.1
        # unusable (its features have 512 channels).
        self.fc = nn.Linear(feature_channels, 2 * num_points)
        # Kept as an attribute for backward compatibility; forward() does
        # not use it.
        self.loss = torch.nn.L1Loss()
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                init.kaiming_uniform_(m.weight)
                if m.bias is not None:
                    init.constant_(m.bias, 0)

    def forward(self, x):
        """Map an image batch (N, 3, H, W) to (N, 2 * num_points) coordinates."""
        x = self.features(x)
        # Flatten per sample. The old ``view(-1, 128)`` silently folded any
        # leftover spatial positions into the batch dimension (wrong batch
        # size for non-64x64 inputs). flatten keeps dim 0, so a size
        # mismatch now surfaces as an error in fc.
        x = torch.flatten(x, 1)
        return self.fc(x)
(2)訓練工程
文件:train.py 以訓練四個關鍵點為例
import numpy as np
from math import radians, cos, sin
import torchvision.transforms as transforms
import torchvision.transforms.functional as TF
#import imutils
import torch
from PIL import Image
import random
import cv2
import xml.etree.ElementTree as ET
from torch.utils.data import Dataset
import os
import torch.nn as nn
import torchvision.models as models
import torch.nn.functional as F
import torch.nn.init as init
import torch.optim as optim
import time
from tqdm import tqdm
from model import RegressNet
class Transforms():
? ? def __init__(self):
? ? ? ? pass
? ? def rotate(self, image, landmarks, angle):
? ? ? ? # 隨機生成一個在 -angle 到 +angle 范圍內的旋轉角度
? ? ? ? angle = random.uniform(-angle, +angle)
? ? ? ? # 基于二維平面上的旋轉變換的數學特性構建旋轉矩陣
? ? ? ? transformation_matrix = torch.tensor([
? ? ? ? ? ? [+cos(radians(angle)), -sin(radians(angle))],
? ? ? ? ? ? [+sin(radians(angle)), +cos(radians(angle))]
? ? ? ? ])
? ? ? ? # 對圖像進行旋轉:相比于 PIL 的圖像旋轉計算開銷更小
? ? ? ? image = imutils.rotate(np.array(image), angle)
? ? ? ? # 將關鍵點坐標中心化:簡化旋轉變換的計算,同時確保關鍵點的變換和圖像變換的對應關系
? ? ? ? landmarks = landmarks - 0.5
? ? ? ? # 將關鍵點坐標應用旋轉矩陣
? ? ? ? new_landmarks = np.matmul(landmarks, transformation_matrix)
? ? ? ? # 恢復關鍵點坐標范圍
? ? ? ? new_landmarks = new_landmarks + 0.5
? ? ? ? return Image.fromarray(image), new_landmarks
? ? def resize(self, image, landmarks, img_size):
? ? ? ? # 調整圖像大小
? ? ? ? image = TF.resize(image, img_size)
? ? ? ? return image, landmarks
? ? def color_jitter(self, image, landmarks):
? ? ? ? # 定義顏色調整的參數:亮度、對比度、飽和度和色調
? ? ? ? color_jitter = transforms.ColorJitter(brightness=0.3,
? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? contrast=0.3,
? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? saturation=0.3,
? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? hue=0.1)
? ? ? ? # 對圖像進行顏色調整
? ? ? ? image = color_jitter(image)
? ? ? ? return image, landmarks
? ? def crop_face(self, image, landmarks, crops):
? ? ? ? # 獲取裁剪參數
? ? ? ? left = int(crops['left'])
? ? ? ? top = int(crops['top'])
? ? ? ? width = int(crops['width'])
? ? ? ? height = int(crops['height'])
? ? ? ? # 對圖像進行裁剪
? ? ? ? image = TF.crop(image, top, left, height, width)
? ? ? ? # 獲取裁剪后的圖像形狀
? ? ? ? img_shape = np.array(image).shape
? ? ? ? # 對關鍵點坐標進行裁剪后的調整
? ? ? ? landmarks = torch.tensor(landmarks) - torch.tensor([[left, top]])
? ? ? ? # 歸一化關鍵點坐標
? ? ? ? landmarks = landmarks / torch.tensor([img_shape[1], img_shape[0]])
? ? ? ? return image, landmarks
? ? def __call__(self, image, landmarks):
? ? ? ? # 將圖像從數組轉換為 PIL 圖像對象
? ? ? ? image = Image.fromarray(image)
? ? ? ? # 裁剪圖像并調整關鍵點
? ? ? ? # 調整圖像大小
? ? ? ? image, landmarks = self.resize(image, landmarks, (64, 64))
? ? ? ? # 對圖像進行顏色調整
? ? ? ? image, landmarks = self.color_jitter(image, landmarks)
? ? ? ? # 對圖像和關鍵點進行旋轉變換
? ? ? ? #image, landmarks = self.rotate(image, landmarks, angle=10)
? ? ? ? # 將圖像從 PIL 圖像對象轉換為 Torch 張量
? ? ? ? image = TF.to_tensor(image)
? ? ? ? # 標準化圖像像素值
? ? ? ? image = TF.normalize(image, [0.5], [0.5])
? ? ? ? return image, landmarks
?
(3) dataset 定義：每張圖片的關鍵點數據長度為 8，即 x1,y1,x2,y2,x3,y3,x4,y4
#標籤排列規則（每行一張圖片，字段以空格分隔；坐標已分別除以圖片寬、高做歸一化）
XXX.jpg x1/width y1/height x2/width y2/height x3/width y3/height x4/width y4/height
class FaceLandmarksDataset(Dataset):
    """Keypoint dataset described by a plain-text label file.

    Each non-blank line of ``txt_path`` has the form::

        <image path> x1 y1 x2 y2 ... xN yN

    fields separated by whitespace. The image path is joined onto
    ``root_dir``. Coordinates are parsed as floats, because the label format
    stores them divided by image width/height. ``int()`` could not parse
    those values.

    Args:
        transform: optional callable ``(image, landmarks) -> (image, landmarks)``.
        txt_path: label file (defaults to the original hard-coded path).
        root_dir: directory prepended to every image path.
        num_points: number of (x, y) pairs expected on each line.

    Raises:
        ValueError: if a line has fewer than ``1 + 2 * num_points`` fields.
    """

    def __init__(self, transform=None,
                 txt_path=r"C:\DL_Work\test_pics\path.txt",
                 root_dir=r'C:\DL_Work\test_pics/',
                 num_points=4):
        with open(txt_path, 'r', encoding="utf-8") as r:
            lines = r.readlines()
        self.image_filenames = []
        self.landmarks = []
        self.crops = []  # not populated; kept for attribute compatibility
        self.transform = transform
        self.root_dir = root_dir
        expected_fields = 1 + 2 * num_points
        for line_no, line in enumerate(lines, 1):
            fields = line.split()
            if not fields:
                continue  # tolerate blank / trailing lines
            if len(fields) < expected_fields:
                raise ValueError(
                    "{}:{}: expected an image path and {} coordinates, got {!r}".format(
                        txt_path, line_no, 2 * num_points, line.rstrip("\n")))
            self.image_filenames.append(os.path.join(self.root_dir, fields[0]))
            coords = [float(v) for v in fields[1:expected_fields]]
            self.landmarks.append(
                [[coords[2 * i], coords[2 * i + 1]] for i in range(num_points)])
        # Shape (num_images, num_points, 2); reshape keeps that shape even
        # when the label file is empty.
        self.landmarks = np.array(self.landmarks, dtype='float32').reshape(-1, num_points, 2)
        assert len(self.image_filenames) == len(self.landmarks)

    def __len__(self):
        return len(self.image_filenames)

    def __getitem__(self, index):
        """Return (image, landmarks - 0.5) for one sample, after ``transform`` if set."""
        path = self.image_filenames[index]
        # Colour read; OpenCV returns channels in BGR order.
        # NOTE(review): the transform treats the array as RGB, so channels stay
        # swapped. Keep inference consistent, or convert here.
        image = cv2.imread(path)
        if image is None:
            # cv2.imread signals failure by returning None rather than raising.
            raise FileNotFoundError("Could not read image: {}".format(path))
        landmarks = self.landmarks[index]
        if self.transform:
            image, landmarks = self.transform(image, landmarks)
        landmarks = landmarks - 0.5  # centre normalised coordinates around 0
        return image, landmarks
# Build the dataset with the default preprocessing pipeline.
dataset = FaceLandmarksDataset(Transforms())
# Hold out ~10% for validation. With fewer than 10 samples int(0.1 * n) is 0,
# which leaves an empty validation loader (len(valid_loader) == 0). The
# training loop divides by that length, so keep at least one validation
# sample whenever there are two or more samples.
len_valid_set = max(1, int(0.1 * len(dataset))) if len(dataset) > 1 else 0
len_train_set = len(dataset) - len_valid_set
train_dataset, valid_dataset = torch.utils.data.random_split(
    dataset, [len_train_set, len_valid_set])
# Shuffle and batch the datasets.
train_loader = torch.utils.data.DataLoader(train_dataset, batch_size=4, shuffle=True, num_workers=1)
valid_loader = torch.utils.data.DataLoader(valid_dataset, batch_size=1, shuffle=True, num_workers=1)
(4)train
def train(num_epochs=10, lr=0.0001, save_path='plate_landmark.pth'):
    """Train RegressNet with MSE loss and checkpoint the best epoch.

    Reads the module-level ``train_loader`` and ``valid_loader``. After each
    epoch the mean validation loss is computed. When it improves, the
    network's ``state_dict`` is written to ``save_path``. If the validation
    loader is empty, the mean training loss is used for checkpoint selection.

    Args:
        num_epochs: number of full passes over ``train_loader``.
        lr: Adam learning rate.
        save_path: file the best ``state_dict`` is saved to.

    Returns:
        (train_losses, valid_losses): lists of per-epoch mean losses.
    """
    train_losses = []
    valid_losses = []
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    # torch.autograd.set_detect_anomaly(True) was removed: it is a debugging
    # aid that slows every backward pass. Enable it only when chasing NaNs.
    network = RegressNet().to(device)
    criterion = nn.MSELoss()
    optimizer = optim.Adam(network.parameters(), lr=lr)
    loss_min = np.inf
    start_time = time.time()
    for epoch in range(1, num_epochs + 1):
        loss_train = 0.0
        loss_valid = 0.0

        network.train()
        # Iterate each loader once per epoch. The previous
        # ``next(iter(loader))`` built a fresh iterator every step (and, with
        # num_workers > 0, fresh worker processes) and drew one random batch
        # from it, so an "epoch" sampled batches with replacement instead of
        # covering the dataset.
        for images, landmarks in tqdm(train_loader):
            images = images.to(device)
            landmarks = landmarks.view(landmarks.size(0), -1).to(device)
            predictions = network(images)
            optimizer.zero_grad()
            loss_train_step = criterion(predictions, landmarks)
            loss_train_step.backward()
            optimizer.step()
            loss_train += loss_train_step.item()

        network.eval()
        with torch.no_grad():
            for images, landmarks in valid_loader:
                images = images.to(device)
                landmarks = landmarks.view(landmarks.size(0), -1).to(device)
                predictions = network(images)
                loss_valid += criterion(predictions, landmarks).item()

        # Guard against empty loaders (ZeroDivisionError otherwise).
        loss_train /= max(len(train_loader), 1)
        loss_valid /= max(len(valid_loader), 1)
        train_losses.append(loss_train)
        valid_losses.append(loss_valid)
        print('\n--------------------------------------------------')
        print('Epoch: {}  Train Loss: {:.4f}  Valid Loss: {:.4f}'.format(epoch, loss_train, loss_valid))
        print('--------------------------------------------------')

        monitored = loss_valid if len(valid_loader) > 0 else loss_train
        if monitored < loss_min:
            loss_min = monitored
            torch.save(network.state_dict(), save_path)
            print("\nMinimum Validation Loss of {:.4f} at epoch {}/{}".format(loss_min, epoch, num_epochs))
            print('Model Saved\n')
    print('Training Complete')
    print("Total Elapsed Time: {} s".format(time.time() - start_time))
    return train_losses, valid_losses
if __name__ == "__main__":
    # Start training only when this file is executed directly, not on import.
    train()
3.導出onnx
#export.py
import torch
import torch.nn
import onnx
from onnxsim import simplify
from model import RegressNet
# Load the trained RegressNet checkpoint, export it to ONNX and simplify it.
device = torch.device('cpu')
model = RegressNet()
# NOTE(review): torch.load unpickles the file; only load checkpoints you trust.
model_statedict = torch.load(r'./plate_landmark.pth', map_location=device)
model.load_state_dict(model_statedict)
# Export in inference mode (this call was previously commented out).
model.eval()
input_names = ['input0']
output_names = ['output0']
# Dummy input matching the 3x64x64 images the training pipeline resizes to.
x = torch.randn(1, 3, 64, 64, device=device)
torch.onnx.export(model, x, 'plate_landmark.onnx', opset_version=11, verbose=True,
                  input_names=input_names, output_names=output_names,
                  dynamic_axes={'input0': {0: 'batch'},
                                'output0': {0: 'batch'}})
onnx_model = onnx.load("plate_landmark.onnx")
# Simplify the graph. onnxsim returns a flag saying whether the simplified
# model was validated against the original; don't save one that failed.
simplified_model, check = simplify(onnx_model)
if not check:
    raise RuntimeError("onnxsim could not validate the simplified ONNX model")
onnx.save_model(simplified_model, "plate_landmark_sim.onnx")