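1. Model Introduction
Today I'll introduce a lip-sync tool, Wav2Lip. Wav2Lip is a deep-learning algorithm for generating lip-synced videos: given a face video and an audio track, it automatically renders mouth movements that accurately match the audio.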
(Paper)
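The Wav2Lip model is built on a generative adversarial network (GAN) and has two main parts: a generator and discriminators. The generator synthesizes realistic mouth movements on the face frames from the input audio (fed to the model as mel-spectrogram chunks), while the discriminators judge the generated frames against real ones.
Its main structure and working principle are described below: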
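- Discriminator (D_{SyncNet}): the first stage trains a discriminator that judges whether the audio and the lip movements are in sync. Its goal is to become an accurate judge of audio-lip synchronization (a "lip-sync expert").
- Generator (encoder-decoder architecture): the second stage uses an encoder-decoder generator together with two discriminators. The generator tries to produce mouth movements synchronized with the audio, while the two discriminators judge, respectively, the lip-sync accuracy and the visual quality of the generated frames compared with real ones. The SyncNet discriminator pretrained in the first stage stays frozen here.
- Main modules: the Wav2Lip generator consists of three main modules:
  - Identity Encoder: encodes a random reference frame (together with the target frame whose lower half is masked) to extract identity features.
  - Speech Encoder: encodes the input speech segment into an embedding that drives the mouth movements.
  - Face Decoder: upsamples the combined features to produce the final face frame.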
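To make the "lip-sync expert" idea concrete: the Wav2Lip paper scores a (video, audio) pair by the cosine similarity of the two embeddings produced by SyncNet, and trains with a binary cross-entropy on that score. The NumPy sketch below reproduces only this scoring step, assuming the embeddings are already computed and non-negative (as after a ReLU, so the similarity lies in [0, 1]); the function names and the `eps` value are my own choices, not the repository's API.

```python
import numpy as np

def sync_probability(v, s, eps=1e-8):
    """Cosine similarity between video embeddings v and audio embeddings s, one value per row.

    Assumes non-negative embeddings, so the result lies in [0, 1] and can be read as P_sync.
    """
    v = np.asarray(v, dtype=np.float64)
    s = np.asarray(s, dtype=np.float64)
    dot = np.sum(v * s, axis=1)
    norms = np.linalg.norm(v, axis=1) * np.linalg.norm(s, axis=1)
    return dot / np.maximum(norms, eps)

def sync_loss(v, s, labels, eps=1e-8):
    """Mean binary cross-entropy on P_sync; label 1 = in sync, 0 = off sync."""
    p = np.clip(sync_probability(v, s, eps), eps, 1 - eps)  # avoid log(0)
    y = np.asarray(labels, dtype=np.float64)
    return float(np.mean(-(y * np.log(p) + (1 - y) * np.log(1 - p))))
```

When the generator is trained against the frozen expert, the target label is always 1 ("in sync"), so generated frames with a low sync probability are penalized.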
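2. Local Deployment
Next, let's deploy the model, either locally or on ModelScope; here I chose to deploy the project on ModelScope:
2.1 Create a conda virtual environment
According to the README on GitHub, the hardware requires an NVIDIA GPU and the code runs under Python 3.6. A previous post covered installing Miniconda and creating virtual environments on ModelScope in detail, so I won't repeat it here; we simply create a virtual environment named wav2lip.
2.2 Install dependencies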
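For reference, creating and activating the environment typically looks like `conda create -n wav2lip python=3.6` followed by `conda activate wav2lip`. Inside it, a small check like the one below (a hypothetical helper, not part of the repository) can confirm the Python version and whether the NVIDIA driver tool `nvidia-smi` is on the PATH before installing anything. Treating a missing `nvidia-smi` as "no usable GPU" is a heuristic, not a guarantee.

```python
import shutil
import sys

def check_environment(version_info=None, has_nvidia_smi=None):
    """Return a list of problems with the environment (an empty list means it looks OK)."""
    if version_info is None:
        version_info = sys.version_info
    if has_nvidia_smi is None:
        # nvidia-smi ships with the NVIDIA driver; its absence usually means no usable GPU
        has_nvidia_smi = shutil.which("nvidia-smi") is not None
    problems = []
    if tuple(version_info[:2]) != (3, 6):
        problems.append("Python {}.{} found; the README targets Python 3.6".format(*version_info[:2]))
    if not has_nvidia_smi:
        problems.append("nvidia-smi not found; an NVIDIA GPU and driver are expected")
    return problems

if __name__ == "__main__":
    for problem in check_environment():
        print("WARNING:", problem)
```

If you plan to use the Web UI from section 2.4, expect the version warning here, since that part recommends Python 3.7.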
git clone https://github.com/Rudrabha/Wav2Lip.git
cd Wav2Lip
Note: before installing the dependencies, change
opencv-contrib-python>=4.2.0.34 to opencv-contrib-python==4.2.0.34 in requirements.txt.
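The edit can also be scripted. The sketch below is a hypothetical helper (not part of Wav2Lip) that rewrites whatever version specifier requirements.txt has for one package into an exact `==` pin and leaves every other line untouched:

```python
import re

def pin_requirement(text, package, version):
    """Rewrite any specifier for `package` (>=, ~=, ==, <, ...) into an exact `==version` pin."""
    # Match the package name at the start of a line, followed by a comparison operator.
    # Names that merely start with `package` (e.g. "...-headless") are not matched.
    pattern = re.compile(r"^(\s*{})\s*[<>=!~]=?.*$".format(re.escape(package)), re.MULTILINE)
    return pattern.sub(r"\g<1>=={}".format(version), text)
```

Usage: read requirements.txt, pass its contents through `pin_requirement(text, "opencv-contrib-python", "4.2.0.34")`, and write the result back before running pip.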
# Install dependencies
pip install -r requirements.txt
# Download the model weights
git clone https://www.modelscope.cn/GYMaster/Wav2lip.git
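After the clone, it is worth confirming that the checkpoint files actually contain weights. If the ModelScope repository stores them with Git LFS (an assumption; check the repository page), cloning without LFS support leaves only small text pointer files. The hypothetical helper below lists the .pth files and flags pointer files by the standard LFS header:

```python
import os

LFS_HEADER = b"version https://git-lfs.github.com/spec/v1"

def is_lfs_pointer(path):
    """True if `path` looks like a Git LFS pointer file rather than real binary data."""
    if os.path.getsize(path) > 1024:  # pointer files are tiny text files
        return False
    with open(path, "rb") as f:
        return f.read(len(LFS_HEADER)) == LFS_HEADER

def find_checkpoints(root="./Wav2lip"):
    """List (path, size_in_bytes, looks_like_pointer) for every .pth file under `root`."""
    found = []
    for dirpath, _, filenames in os.walk(root):
        for name in sorted(filenames):
            if name.endswith(".pth"):
                full = os.path.join(dirpath, name)
                found.append((full, os.path.getsize(full), is_lfs_pointer(full)))
    return found
```

If a file is flagged, running `git lfs install` and then `git lfs pull` inside the cloned directory typically fetches the real weights.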
2.3 Run
python inference.py --checkpoint_path <ckpt> --face <video.mp4> --audio <an-audio-source>
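where:
--checkpoint_path is the path of the model weights downloaded above (e.g. ./Wav2lip/wav2lip.pth)
--face is the path of the video (or still image) whose lip movements should be synced
--audio is the path of the corresponding audio file (non-.wav inputs, including video files, are first converted to WAV with ffmpeg)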
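If you want to drive inference.py from another Python program, building the argument list explicitly avoids shell-quoting problems with paths that contain spaces. A minimal sketch, where `build_inference_cmd` is a hypothetical helper and only the flags defined in the script's argument parser are used:

```python
import sys

def build_inference_cmd(checkpoint, face, audio, outfile=None, python=sys.executable):
    """Build the argv list for Wav2Lip's inference.py (run from the repository root)."""
    cmd = [python, "inference.py",
           "--checkpoint_path", checkpoint,
           "--face", face,
           "--audio", audio]
    if outfile is not None:
        cmd += ["--outfile", outfile]  # the script defaults to results/result_voice.mp4
    return cmd
```

For example, `subprocess.run(build_inference_cmd("Wav2lip/wav2lip.pth", "face.mp4", "speech.wav"), check=True)` runs inference and raises if the script exits with an error (the input file names here are placeholders).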
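A few points to note:
1. The audio should not be longer than the video;
2. Every frame of the video must contain a clearly visible face; otherwise face detection fails with "Face not detected!" and the offending frame is saved to temp/faulty_frame.jpg.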
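The first point follows from how the script aligns audio and video: main() trims the frames to the number of mel chunks, and datagen() picks frames with `i % len(frames)`, so audio longer than the video makes the video restart from its first frame. A quick pre-check using only the standard `wave` module for WAV input; the helper names and the one-frame tolerance are my own assumptions:

```python
import wave

def wav_duration(path):
    """Duration of a PCM .wav file in seconds."""
    with wave.open(path, "rb") as w:
        return w.getnframes() / float(w.getframerate())

def video_duration(frame_count, fps):
    """Duration of a video in seconds from its frame count and frame rate."""
    return frame_count / float(fps)

def audio_fits_video(audio_seconds, video_seconds, tolerance=0.04):
    """True if the audio is not longer than the video (allowing about one frame at 25 fps)."""
    return audio_seconds <= video_seconds + tolerance
```

For the video side, `cv2.VideoCapture(path)` with `.get(cv2.CAP_PROP_FRAME_COUNT)` and `.get(cv2.CAP_PROP_FPS)` provides the frame count and frame rate to pass to `video_duration`.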
2.4 Web-UI
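The web UI is built with Gradio. In my tests Gradio was poorly compatible with Python 3.6, so for a UI deployment I recommend installing the project dependencies in a Python 3.7 environment instead. Here is the script inference_ui.py that provides the UI: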
# inference_ui.py
from os import listdir, path
import numpy as np
import scipy, cv2, os, sys, argparse, audio
import json, subprocess, random, string
from tqdm import tqdm
from glob import glob
import torch, face_detection
from models import Wav2Lip
import platform
import gradio as gr

parser = argparse.ArgumentParser(description='Inference code to lip-sync videos in the wild using Wav2Lip models')

parser.add_argument('--checkpoint_path', type=str, help='Name of saved checkpoint to load weights from', default=None)
parser.add_argument('--face', type=str, help='Filepath of video/image that contains faces to use', default=None)
parser.add_argument('--audio', type=str, help='Filepath of video/audio file to use as raw audio source', default=None)
parser.add_argument('--outfile', type=str, help='Video path to save result. See default for an e.g.', default='results/result_voice.mp4')
parser.add_argument('--static', type=bool, help='If True, then use only first video frame for inference', default=False)
parser.add_argument('--fps', type=float, help='Can be specified only if input is a static image (default: 25)', default=25., required=False)
parser.add_argument('--pads', nargs='+', type=int, default=[0, 10, 0, 0], help='Padding (top, bottom, left, right). Please adjust to include chin at least')
parser.add_argument('--face_det_batch_size', type=int, help='Batch size for face detection', default=16)
parser.add_argument('--wav2lip_batch_size', type=int, help='Batch size for Wav2Lip model(s)', default=128)
parser.add_argument('--resize_factor', default=1, type=int, help='Reduce the resolution by this factor. Sometimes, best results are obtained at 480p or 720p')
parser.add_argument('--crop', nargs='+', type=int, default=[0, -1, 0, -1],
                    help='Crop video to a smaller region (top, bottom, left, right). Applied after resize_factor and rotate arg. '
                         'Useful if multiple face present. -1 implies the value will be auto-inferred based on height, width')
parser.add_argument('--box', nargs='+', type=int, default=[-1, -1, -1, -1],
                    help='Specify a constant bounding box for the face. Use only as a last resort if the face is not detected. '
                         'Also, might work only if the face is not moving around much. Syntax: (top, bottom, left, right).')
parser.add_argument('--rotate', default=False, action='store_true',
                    help='Sometimes videos taken from a phone can be flipped 90deg. If true, will flip video right by 90deg. '
                         'Use if you get a flipped result, despite feeding a normal looking video')
parser.add_argument('--nosmooth', default=False, action='store_true', help='Prevent smoothing face detections over a short temporal window')

args = parser.parse_args()
args.img_size = 96

def get_smoothened_boxes(boxes, T):
    # Average each face box over a window of T consecutive frames to reduce jitter
    for i in range(len(boxes)):
        if i + T > len(boxes):
            window = boxes[len(boxes) - T:]
        else:
            window = boxes[i : i + T]
        boxes[i] = np.mean(window, axis=0)
    return boxes

def face_detect(images):
    detector = face_detection.FaceAlignment(face_detection.LandmarksType._2D, flip_input=False, device=device)

    batch_size = args.face_det_batch_size

    while 1:
        predictions = []
        try:
            for i in tqdm(range(0, len(images), batch_size)):
                predictions.extend(detector.get_detections_for_batch(np.array(images[i:i + batch_size])))
        except RuntimeError:
            # Out of GPU memory: halve the batch size and retry
            if batch_size == 1:
                raise RuntimeError('Image too big to run face detection on GPU. Please use the --resize_factor argument')
            batch_size //= 2
            print('Recovering from OOM error; New batch size: {}'.format(batch_size))
            continue
        break

    results = []
    pady1, pady2, padx1, padx2 = args.pads
    for rect, image in zip(predictions, images):
        if rect is None:
            cv2.imwrite('temp/faulty_frame.jpg', image)  # check this frame where the face was not detected.
            raise ValueError('Face not detected! Ensure the video contains a face in all the frames.')

        y1 = max(0, rect[1] - pady1)
        y2 = min(image.shape[0], rect[3] + pady2)
        x1 = max(0, rect[0] - padx1)
        x2 = min(image.shape[1], rect[2] + padx2)

        results.append([x1, y1, x2, y2])

    boxes = np.array(results)
    if not args.nosmooth:
        boxes = get_smoothened_boxes(boxes, T=5)
    results = [[image[y1: y2, x1:x2], (y1, y2, x1, x2)] for image, (x1, y1, x2, y2) in zip(images, boxes)]

    del detector
    return results

def datagen(frames, mels):
    img_batch, mel_batch, frame_batch, coords_batch = [], [], [], []

    if args.box[0] == -1:
        if not args.static:
            face_det_results = face_detect(frames)  # BGR2RGB for CNN face detection
        else:
            face_det_results = face_detect([frames[0]])
    else:
        print('Using the specified bounding box instead of face detection...')
        y1, y2, x1, x2 = args.box
        face_det_results = [[f[y1: y2, x1:x2], (y1, y2, x1, x2)] for f in frames]

    for i, m in enumerate(mels):
        # If the audio is longer than the video, frames are reused from the start
        idx = 0 if args.static else i % len(frames)
        frame_to_save = frames[idx].copy()
        face, coords = face_det_results[idx].copy()

        face = cv2.resize(face, (args.img_size, args.img_size))

        img_batch.append(face)
        mel_batch.append(m)
        frame_batch.append(frame_to_save)
        coords_batch.append(coords)

        if len(img_batch) >= args.wav2lip_batch_size:
            img_batch, mel_batch = np.asarray(img_batch), np.asarray(mel_batch)

            # Mask the lower half of the face; the model receives masked + reference frame (6 channels)
            img_masked = img_batch.copy()
            img_masked[:, args.img_size//2:] = 0

            img_batch = np.concatenate((img_masked, img_batch), axis=3) / 255.
            mel_batch = np.reshape(mel_batch, [len(mel_batch), mel_batch.shape[1], mel_batch.shape[2], 1])

            yield img_batch, mel_batch, frame_batch, coords_batch
            img_batch, mel_batch, frame_batch, coords_batch = [], [], [], []

    if len(img_batch) > 0:
        img_batch, mel_batch = np.asarray(img_batch), np.asarray(mel_batch)

        img_masked = img_batch.copy()
        img_masked[:, args.img_size//2:] = 0

        img_batch = np.concatenate((img_masked, img_batch), axis=3) / 255.
        mel_batch = np.reshape(mel_batch, [len(mel_batch), mel_batch.shape[1], mel_batch.shape[2], 1])

        yield img_batch, mel_batch, frame_batch, coords_batch

mel_step_size = 16
device = 'cuda' if torch.cuda.is_available() else 'cpu'
print('Using {} for inference.'.format(device))

def _load(checkpoint_path):
    if device == 'cuda':
        checkpoint = torch.load(checkpoint_path)
    else:
        checkpoint = torch.load(checkpoint_path, map_location=lambda storage, loc: storage)
    return checkpoint

def load_model(path):
    model = Wav2Lip()
    print("Load checkpoint from: {}".format(path))
    checkpoint = _load(path)
    s = checkpoint["state_dict"]
    new_s = {}
    for k, v in s.items():
        # Strip the 'module.' prefix left by DataParallel training
        new_s[k.replace('module.', '')] = v
    model.load_state_dict(new_s)

    model = model.to(device)
    return model.eval()

def main():
    # temp/ holds intermediate files; make sure it and the output directory exist
    os.makedirs('temp', exist_ok=True)
    os.makedirs(os.path.dirname(args.outfile) or '.', exist_ok=True)

    if not os.path.isfile(args.face):
        raise ValueError('--face argument must be a valid path to video/image file')

    elif os.path.splitext(args.face)[1].lower() in ['.jpg', '.png', '.jpeg']:
        full_frames = [cv2.imread(args.face)]
        fps = args.fps

    else:
        video_stream = cv2.VideoCapture(args.face)
        fps = video_stream.get(cv2.CAP_PROP_FPS)

        print('Reading video frames...')

        full_frames = []
        while 1:
            still_reading, frame = video_stream.read()
            if not still_reading:
                video_stream.release()
                break
            if args.resize_factor > 1:
                frame = cv2.resize(frame, (frame.shape[1]//args.resize_factor, frame.shape[0]//args.resize_factor))

            if args.rotate:
                frame = cv2.rotate(frame, cv2.ROTATE_90_CLOCKWISE)

            y1, y2, x1, x2 = args.crop
            if x2 == -1: x2 = frame.shape[1]
            if y2 == -1: y2 = frame.shape[0]

            frame = frame[y1:y2, x1:x2]

            full_frames.append(frame)

    print("Number of frames available for inference: " + str(len(full_frames)))

    if not args.audio.endswith('.wav'):
        print('Extracting raw audio...')
        command = 'ffmpeg -y -i {} -strict -2 {}'.format(args.audio, 'temp/temp.wav')

        subprocess.call(command, shell=True)
        args.audio = 'temp/temp.wav'

    wav = audio.load_wav(args.audio, 16000)
    mel = audio.melspectrogram(wav)
    print(mel.shape)

    if np.isnan(mel.reshape(-1)).sum() > 0:
        raise ValueError('Mel contains nan! Using a TTS voice? Add a small epsilon noise to the wav file and try again')

    # Cut the mel-spectrogram into one 16-frame window per video frame
    mel_chunks = []
    mel_idx_multiplier = 80./fps
    i = 0
    while 1:
        start_idx = int(i * mel_idx_multiplier)
        if start_idx + mel_step_size > len(mel[0]):
            mel_chunks.append(mel[:, len(mel[0]) - mel_step_size:])
            break
        mel_chunks.append(mel[:, start_idx : start_idx + mel_step_size])
        i += 1

    print("Length of mel chunks: {}".format(len(mel_chunks)))

    full_frames = full_frames[:len(mel_chunks)]

    batch_size = args.wav2lip_batch_size
    gen = datagen(full_frames.copy(), mel_chunks)

    for i, (img_batch, mel_batch, frames, coords) in enumerate(tqdm(gen, total=int(np.ceil(float(len(mel_chunks))/batch_size)))):
        if i == 0:
            model = load_model(args.checkpoint_path)
            print("Model loaded")

            frame_h, frame_w = full_frames[0].shape[:-1]
            out = cv2.VideoWriter('temp/result.avi', cv2.VideoWriter_fourcc(*'DIVX'), fps, (frame_w, frame_h))

        img_batch = torch.FloatTensor(np.transpose(img_batch, (0, 3, 1, 2))).to(device)
        mel_batch = torch.FloatTensor(np.transpose(mel_batch, (0, 3, 1, 2))).to(device)

        with torch.no_grad():
            pred = model(mel_batch, img_batch)

        pred = pred.cpu().numpy().transpose(0, 2, 3, 1) * 255.

        # Paste each generated face back into its original frame
        for p, f, c in zip(pred, frames, coords):
            y1, y2, x1, x2 = c
            p = cv2.resize(p.astype(np.uint8), (x2 - x1, y2 - y1))

            f[y1:y2, x1:x2] = p
            out.write(f)

    out.release()

    # Mux the original audio with the silent result video
    command = 'ffmpeg -y -i {} -i {} -strict -2 -q:v 1 {}'.format(args.audio, 'temp/result.avi', args.outfile)
    subprocess.call(command, shell=platform.system() != 'Windows')

#========================================================================
# Run Wav2Lip on the uploaded video (or image) and audio and return the path of the processed video.
# model_name comes from the dropdown below; it is only a placeholder, so the same checkpoint is always used.
def process_video_audio(video, audio, model_name):
    args.checkpoint_path = './Wav2lip/wav2lip.pth'
    args.face = video
    args.audio = audio
    # A still image uses its single frame for every audio chunk
    args.static = os.path.isfile(args.face) and os.path.splitext(args.face)[1].lower() in ['.jpg', '.png', '.jpeg']
    args.outfile = './results/result_voice.mp4'
    main()  # actually run inference; writes the lip-synced video to args.outfile
    return args.outfile

# Available model options (placeholders)
model_choices = ["Model A", "Model B", "Model C"]

# Build the Gradio interface
with gr.Blocks(theme="glass") as demo:
    gr.Markdown("## Video and Audio Processing Service")
    with gr.Row():
        video_input = gr.Video(label="Upload video file", type="filepath")
        audio_input = gr.Audio(label="Upload audio file", type="filepath")
    model_choice = gr.Dropdown(choices=model_choices, label="Select model", value=model_choices[0])
    submit_btn = gr.Button("Submit")
    output_video = gr.Video(label="Processed video")

    # When the submit button is clicked, call process_video_audio
    submit_btn.click(fn=process_video_audio,
                     inputs=[video_input, audio_input, model_choice],
                     outputs=output_video)

if __name__ == '__main__':
    # Launch the Gradio app
    demo.launch(server_name="0.0.0.0", server_port=7860)
Place the script above in the same directory as inference.py, then run:
python inference_ui.py
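The step in main() that pairs each video frame with a window of the mel-spectrogram is easy to misread, so here is the same loop isolated as a pure function. It assumes the repository's audio settings yield 80 mel frames per second (16 kHz audio with a 200-sample hop), which is what the `80./fps` factor in the script implies; `mel_chunk_starts` is a hypothetical name.

```python
MEL_FRAMES_PER_SECOND = 80.0  # assumed: 16 kHz audio with a 200-sample hop
MEL_STEP_SIZE = 16            # mel frames fed to the model per video frame

def mel_chunk_starts(num_mel_frames, fps, step=MEL_STEP_SIZE):
    """Start index of each mel window, mirroring the while-loop in main()."""
    multiplier = MEL_FRAMES_PER_SECOND / fps  # mel frames per video frame
    starts = []
    i = 0
    while True:
        start = int(i * multiplier)
        if start + step > num_mel_frames:
            # The last window is right-aligned to the end of the spectrogram
            starts.append(num_mel_frames - step)
            break
        starts.append(start)
        i += 1
    return starts
```

The number of returned windows equals the number of output video frames. Each window covers 16 / 80 = 0.2 s of audio, and because the final window is right-aligned, its start index can repeat the previous one.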