Published: August 25, 2021
Project page: https://peterl1n.github.io/RobustVideoMatting/
Paper notes: https://hpg123.blog.csdn.net/article/details/134409222
RVM is a real-time portrait matting model open-sourced by the ByteDance team. It is built on a convolutional recurrent (ConvGRU) architecture and achieves a good trade-off between quality and speed. This post organizes its open-source code so that it can be used for real-time portrait matting on video, and it covers both the torch and the onnx inference code. Note that because RVM relies on recurrent states, the longer the temporal context at inference time, the more stable the output and the lower the chance of flicker. Comparing the two backends, torch inference turned out to be much faster than onnx inference.
1. Environment preparation
1.1 Model download
The project releases its weights in two formats, a torch version and an onnx version. Both the torch model and the onnx model are needed here.
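For reference, a minimal download sketch in Python is shown below. The asset URLs are assumptions based on the repository's Releases page (https://github.com/PeterL1n/RobustVideoMatting/releases) and should be checked against the actual release tag and file names.

import urllib.request

# Hypothetical release asset URLs; verify them on the Releases page before use
urls = {
    'rvm_mobilenetv3.pth':
        'https://github.com/PeterL1n/RobustVideoMatting/releases/download/v1.0.0/rvm_mobilenetv3.pth',
    'rvm_mobilenetv3_fp16.onnx':
        'https://github.com/PeterL1n/RobustVideoMatting/releases/download/v1.0.0/rvm_mobilenetv3_fp16.onnx',
}
for name, url in urls.items():
    urllib.request.urlretrieve(url, name)  # save into the current directory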
1.2 Video read/write code
from torch.utils.data import DataLoader, Dataset
from torchvision.transforms import ToTensor
from torchvision.transforms.functional import to_pil_image
from PIL import Image
import decord
import numpy as np
import cv2


class VideoReader(Dataset):
    def __init__(self, path, transform=None):
        # Initialize the decord video reader; swap in decord.gpu(0) for GPU decoding if available
        self.vr = decord.VideoReader(path, ctx=decord.cpu(0))
        # Average frame rate of the video
        self.rate = self.vr.get_avg_fps()
        self.transform = transform
        # Total number of frames
        self.length = len(self.vr)

    @property
    def frame_rate(self):
        return self.rate

    def __len__(self):
        return self.length

    def __getitem__(self, idx):
        # Read the frame at the given index as a numpy array (H, W, C) in RGB
        frame = self.vr[idx].asnumpy()
        # Convert to a PIL image
        frame = Image.fromarray(frame)
        # Apply the optional transform
        if self.transform is not None:
            frame = self.transform(frame)
        return frame


class VideoWriter:
    def __init__(self, path, frame_rate, bit_rate=1000000):
        self.path = path
        self.frame_rate = frame_rate
        self.bit_rate = bit_rate
        self.writer = None
        self.width = 0
        self.height = 0

    def write(self, frames):
        # frames: [T, C, H, W], values in 0 ~ 1
        self.width = frames.size(3)
        self.height = frames.size(2)
        # Convert grayscale to RGB if needed, e.g. [1, 1, 1280, 720] -> [1, 3, 1280, 720]
        if frames.size(1) == 1:
            frames = frames.repeat(1, 3, 1, 1)
        # Convert to [T, H, W, C] uint8 in 0-255, the layout OpenCV expects
        frames = frames.mul(255).cpu().permute(0, 2, 3, 1).numpy().astype(np.uint8)
        # Lazily create the writer on the first call to write
        if self.writer is None:
            # Pick the codec from the container type: 'mp4v' for mp4, e.g. 'XVID' for avi
            fourcc = cv2.VideoWriter_fourcc(*'mp4v')
            self.writer = cv2.VideoWriter(
                self.path, fourcc, self.frame_rate, (self.width, self.height))
        print(frames.shape, frames.dtype, frames.max(), self.width, self.height)
        for t in range(frames.shape[0]):
            frame = frames[t]
            # OpenCV uses BGR by default, so convert from RGB
            frame = cv2.cvtColor(frame, cv2.COLOR_RGB2BGR)
            self.writer.write(frame)

    def close(self):
        if self.writer is not None:
            self.writer.release()
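The inference loops in the following sections assume a reader built from this class. A minimal usage sketch (the input file name input.mp4 is a placeholder):

from torchvision.transforms import ToTensor

# Frames come back as RGB tensors normalized to 0 ~ 1 thanks to ToTensor()
reader = VideoReader('input.mp4', transform=ToTensor())
print(len(reader), reader.frame_rate)  # frame count and fps of the source video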
1.3 Torch model definition code
Download the project source code together with the weights; note that the model definition is imported from the model directory of the repository.
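If the inference script does not live inside the cloned repository, the repo root has to be on the import path first. A small sketch, assuming the repository was cloned into a folder named RobustVideoMatting:

import sys

# Hypothetical path to the cloned repo; the model package sits at its top level
sys.path.append('RobustVideoMatting')
from model import MattingNetwork  # model definition used in the next section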
2. Video portrait matting (torch version)
2.1 Model loading code
import torch
from model import MattingNetwork

model = MattingNetwork('mobilenetv3').eval().cuda()  # or "resnet50"
model.load_state_dict(torch.load('rvm_mobilenetv3.pth'))
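Before running the full video loop, it can help to sanity-check the call signature that the loop relies on. This is just an illustrative check (the input size is arbitrary), not part of the original post:

with torch.no_grad():
    src = torch.rand(1, 3, 1280, 720).cuda()   # one dummy RGB frame in 0 ~ 1
    rec = [None] * 4                           # initial recurrent states
    fgr, pha, *rec = model(src, *rec, 0.25)    # returns foreground, alpha and 4 new states
    print(fgr.shape, pha.shape)                # fgr: [1, 3, H, W], pha: [1, 1, H, W]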
2.2 Inference code
import torch
bgr = torch.tensor([.47, 1, .6]).view(3, 1, 1).cuda() # Green background.
rec = [None] * 4 # Initial recurrent states.
downsample_ratio = 0.25  # Adjust based on your video.
writer = VideoWriter('output.mp4', frame_rate=30)
batch = 60
with torch.no_grad():
    for src in DataLoader(reader, batch_size=batch):  # RGB tensor normalized to 0 ~ 1.
        # Pad an incomplete last batch by repeating its final frame
        while src.shape[0] < batch:
            src = torch.cat([src, src[-1:]])
        fgr, pha, *rec = model(src.cuda(), *rec, downsample_ratio)  # Cycle the recurrent states.
        fgr = fgr[:batch]
        pha = pha[:batch]
        com = fgr * pha + bgr * (1 - pha)  # Composite to green background.
        writer.write(com)  # Write frames.
writer.close()
2.3 Processing results
On an RTX 3060 (CUDA 12, torch 2.4), processing a 20-second 720p 30 fps video took 14 s.
3. Video portrait matting (onnx version)
3.1 ONNX model loading code
First open the model downloaded from GitHub in https://netron.app/ and confirm that it supports dynamic input sizes.
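The same check can also be done programmatically with onnxruntime (a minimal sketch, assuming the fp16 model file sits in the working directory); symbolic dimensions such as 'height' or 'width' in the printed shapes indicate dynamic sizes:

import onnxruntime as ort

sess = ort.InferenceSession('rvm_mobilenetv3_fp16.onnx', providers=['CPUExecutionProvider'])
for inp in sess.get_inputs():
    print(inp.name, inp.shape, inp.type)  # input names, (possibly symbolic) shapes and dtypes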
The code below is a generic ONNX model inference wrapper.
import onnxruntime as ort
import numpy as np
from typing import Dict, List, Union, Tuple


class ONNXModel:
    """A minimal ONNX Runtime wrapper that mimics the PyTorch calling style.
    Only forward is implemented; inputs and outputs are numpy arrays."""

    def __init__(self, onnx_path: str, device: str = 'cpu'):
        self.onnx_path = onnx_path
        # Choose execution providers according to the requested device
        providers = ['CPUExecutionProvider']
        if device.lower() == 'cuda' and 'CUDAExecutionProvider' in ort.get_available_providers():
            providers = ['CUDAExecutionProvider', 'CPUExecutionProvider']
        # Create the ONNX Runtime session
        self.session = ort.InferenceSession(onnx_path, providers=providers)
        # Record input and output node names
        self.input_names = [input.name for input in self.session.get_inputs()]
        self.output_names = [output.name for output in self.session.get_outputs()]

    def forward(self, *args, **kwargs) -> Union[np.ndarray, Tuple[np.ndarray]]:
        """Forward pass mimicking PyTorch's forward.
        Inputs are numpy arrays, passed positionally (in input order) or by name.
        Returns a numpy array or a tuple of numpy arrays."""
        # Build the input feed dict
        inputs = {}
        # Positional arguments
        if args:
            if len(args) != len(self.input_names):
                raise ValueError(f"Expected {len(self.input_names)} positional inputs, got {len(args)}")
            for name, arg in zip(self.input_names, args):
                inputs[name] = arg
        # Keyword arguments
        if kwargs:
            for name, value in kwargs.items():
                if name not in self.input_names:
                    raise ValueError(f"Unknown input name: {name}, valid names: {self.input_names}")
                inputs[name] = value
        # Make sure every input is provided
        if len(inputs) != len(self.input_names):
            missing = set(self.input_names) - set(inputs.keys())
            raise ValueError(f"Missing inputs: {missing}")
        # for k in inputs.keys():
        #     print(k, inputs[k].shape, inputs[k].dtype)
        # Run inference
        outputs = self.session.run(self.output_names, inputs)
        # Unwrap single outputs
        if len(outputs) == 1:
            return outputs[0]
        return tuple(outputs)

    def __call__(self, *args, **kwargs) -> Union[np.ndarray, Tuple[np.ndarray]]:
        """Make the instance callable like a PyTorch model."""
        return self.forward(*args, **kwargs)
model = ONNXModel('rvm_mobilenetv3_fp16.onnx', 'cuda')
3.2 Inference code
The inference code here closely mirrors the torch version; pay attention to the data types (the fp16 model expects float16 inputs).
import torch
import time
bgr = torch.tensor([.47, 1, .6]).view(3, 1, 1).numpy().astype(np.float16) # Green background.
rec = None # Initial recurrent states.
downsample_ratio = np.array([0.25]).astype(np.float32)  # Adjust based on your video.
writer = VideoWriter('output.mp4', frame_rate=30)
batch = 32
t0 = time.time()
with torch.no_grad():
    for src in DataLoader(reader, batch_size=batch):  # RGB tensor normalized to 0 ~ 1.
        # Pad an incomplete last batch by repeating its final frame
        while src.shape[0] < batch:
            src = torch.cat([src, src[-1:]])
        src = src.numpy().astype(np.float16)
        if rec is None:
            # Initialize the four recurrent states with minimal zero tensors
            rec = [np.zeros((1, 1, 1, 1), dtype=np.float16)] * 4
        fgr, pha, *rec = model(src, *rec, downsample_ratio)  # Cycle the recurrent states.
        fgr = fgr[:batch]
        pha = pha[:batch]
        com = fgr * pha + bgr * (1 - pha)  # Composite to green background.
        com = torch.tensor(com)
        writer.write(com)  # Write frames.
writer.close()
rt = time.time() - t0
print(f"Video processing time: {rt:.4f}")
This run took about 46 s, much slower than the torch version. Console output:
(32, 1280, 720, 3) uint8 255 720 1280
Video processing time: 45.9930
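When the onnx path is unexpectedly slow, one thing worth ruling out is a silent fallback to CPU execution. A small check, using the ONNXModel wrapper defined above:

import onnxruntime as ort

# If 'CUDAExecutionProvider' is absent from either list, inference ran on the CPU
print(ort.get_available_providers())
print(model.session.get_providers())  # providers actually bound to this session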