You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.
 
 
 

186 lines
5.6 KiB

# code was heavily based on https://github.com/Rudrabha/Wav2Lip
# Users should be careful about adopting these functions in any commercial matters.
# https://github.com/Rudrabha/Wav2Lip#license-and-citation
import cv2
import random
import os.path
import numpy as np
from PIL import Image
from glob import glob
from os.path import dirname, join, basename, isfile
from ppgan.utils import audio
from ppgan.utils.audio_config import get_audio_config
import numpy as np
import paddle
from .builder import DATASETS
def get_image_list(data_root, split, filelists_dir='filelists'):
    """Read a split's file list and return the listed video paths.

    Each non-blank line of ``<filelists_dir>/<split>.txt`` names one entry
    relative to ``data_root``. When a line contains a space, only its first
    whitespace-separated token is used as the entry name.

    Args:
        data_root (str): Directory that every listed entry is joined onto.
        split (str): Name of the list file, without the ``.txt`` suffix.
        filelists_dir (str): Directory holding the list files. Defaults to
            ``'filelists'``, resolved against the current working directory.

    Returns:
        list[str]: ``os.path.join(data_root, entry)`` for every entry, in
        file order.

    Raises:
        AssertionError: If a listed path does not exist.
        OSError: If the list file cannot be opened.
    """
    list_file = os.path.join(filelists_dir, '{}.txt'.format(split))
    filelist = []
    with open(list_file) as f:
        for line in f:
            line = line.strip()
            # A blank line would otherwise join to data_root itself, pass the
            # existence check, and be returned as a bogus entry.
            if not line:
                continue
            if ' ' in line:
                line = line.split()[0]
            video_path = os.path.join(data_root, line)
            # Raised explicitly (rather than via `assert`) so the check is
            # not stripped when Python runs with -O.
            if not os.path.exists(video_path):
                raise AssertionError('{} is not found'.format(video_path))
            filelist.append(video_path)
    return filelist
# Number of consecutive frames that make up one sampled window.
syncnet_T = 5
# Number of mel-spectrogram rows cropped for one audio window.
syncnet_mel_step_size = 16
# Audio settings; this file reads its "fps" and "sample_rate" entries.
audio_cfg = get_audio_config()
@DATASETS.register()
class Wav2LipDataset(paddle.io.Dataset):
    """Dataset of randomly sampled frame windows paired with mel chunks.

    Every entry of ``all_videos`` is a directory that is expected to hold
    numbered ``<frame_id>.jpg`` frames and an ``audio.wav`` file. Each call
    to ``__getitem__`` draws a random video and two distinct random start
    frames (a "correct" and a "wrong" one) and keeps retrying until a fully
    valid sample can be assembled.
    """

    def __init__(self, dataroot, img_size, filelists_path, split):
        """Initialize Wav2Lip dataset class.

        Args:
            dataroot (str): Directory of dataset.
            img_size (int): Side length every frame is resized to
                (frames become ``img_size`` x ``img_size``).
            filelists_path: Accepted but not used by this class; the file
                list is read by ``get_image_list`` from
                ``filelists/<split>.txt``.
            split (str): Name of the file list to load, without ``.txt``.
        """
        self.image_path = dataroot
        self.img_size = img_size
        self.split = split
        self.all_videos = get_image_list(self.image_path, self.split)

    def get_frame_id(self, frame):
        """Return the integer frame number encoded in a frame's file name.

        The number is the part of the base name before the first ``'.'``.
        """
        return int(basename(frame).split('.')[0])

    def get_window(self, start_frame):
        """Return paths of ``syncnet_T`` consecutive frames from a start.

        Args:
            start_frame (str): Path of the first frame of the window.

        Returns:
            list[str] | None: Frame paths in order, or ``None`` as soon as
            one of the expected ``<frame_id>.jpg`` files is missing.
        """
        start_id = self.get_frame_id(start_frame)
        vidname = dirname(start_frame)

        window_fnames = []
        for frame_id in range(start_id, start_id + syncnet_T):
            frame = join(vidname, '{}.jpg'.format(frame_id))
            if not isfile(frame):
                return None
            window_fnames.append(frame)
        return window_fnames

    def read_window(self, window_fnames):
        """Load and resize every frame of a window.

        Args:
            window_fnames (list[str] | None): Frame paths to read.

        Returns:
            list | None: Resized images in order, or ``None`` if the input
            is ``None``, a frame cannot be read, or resizing fails.
        """
        if window_fnames is None:
            return None

        window = []
        for fname in window_fnames:
            img = cv2.imread(fname)
            if img is None:
                return None
            try:
                img = cv2.resize(img, (self.img_size, self.img_size))
            except Exception:
                # Best effort: an unusable frame invalidates the window so
                # the caller can draw another sample.
                return None
            window.append(img)
        return window

    def crop_audio_window(self, spec, start_frame):
        """Crop ``syncnet_mel_step_size`` rows of ``spec`` for a frame.

        Args:
            spec: Mel-spectrogram indexed as ``spec[row, :]``.
            start_frame (int | str): Frame number, or a frame path whose
                number is taken with ``get_frame_id``.

        Returns:
            The slice ``spec[start_idx:start_idx + syncnet_mel_step_size]``.
            It has fewer rows when the spectrogram ends early.
        """
        if isinstance(start_frame, (int, np.integer)):
            start_frame_num = start_frame
        else:
            start_frame_num = self.get_frame_id(start_frame)

        # 80. is presumably the number of mel rows per second of audio --
        # TODO confirm against the audio config.
        start_idx = int(80. * (start_frame_num / float(audio_cfg["fps"])))
        end_idx = start_idx + syncnet_mel_step_size
        return spec[start_idx:end_idx, :]

    def get_segmented_mels(self, spec, start_frame):
        """Return one transposed mel chunk per frame of the window.

        Args:
            spec: Mel-spectrogram indexed as ``spec[row, :]``.
            start_frame (str): Path of the first frame of the window.

        Returns:
            numpy.ndarray | None: Stack of ``syncnet_T`` transposed chunks,
            or ``None`` if the window would start before frame 0 or any
            chunk has fewer than ``syncnet_mel_step_size`` rows.
        """
        assert syncnet_T == 5
        start_frame_num = self.get_frame_id(
            start_frame) + 1  # 0-indexing ---> 1-indexing
        if start_frame_num - 2 < 0:
            return None

        mels = []
        for i in range(start_frame_num, start_frame_num + syncnet_T):
            m = self.crop_audio_window(spec, i - 2)
            if m.shape[0] != syncnet_mel_step_size:
                return None
            mels.append(m.T)
        return np.asarray(mels)

    def prepare_window(self, window):
        """Scale a list of frames to [0, 1] and move the last axis first.

        Frames stacked as T x H x W x 3 become an array of 3 x T x H x W.
        """
        x = np.asarray(window) / 255.
        return np.transpose(x, (3, 0, 1, 2))

    def __len__(self):
        """Return the number of video directories in the dataset."""
        return len(self.all_videos)

    def __getitem__(self, idx):
        """Return one randomly drawn training sample.

        ``idx`` is ignored: a video and its frames are chosen at random, and
        the draw is repeated until a complete sample is found. The loop
        never ends if no video can yield a valid sample.

        Returns:
            dict: float32 arrays under the keys
                ``'x'``: masked window and wrong window concatenated on
                    axis 0;
                ``'indiv_mels'``: per-frame mel chunks with an extra axis 1;
                ``'mel'``: transposed mel chunk of the window with an extra
                    leading axis;
                ``'y'``: the unmasked window.
        """
        while True:
            idx = random.randint(0, len(self.all_videos) - 1)
            vidname = self.all_videos[idx]
            img_names = glob(join(vidname, '*.jpg'))
            if len(img_names) <= 3 * syncnet_T:
                continue

            img_name = random.choice(img_names)
            wrong_img_name = random.choice(img_names)
            while wrong_img_name == img_name:
                wrong_img_name = random.choice(img_names)

            window_fnames = self.get_window(img_name)
            wrong_window_fnames = self.get_window(wrong_img_name)
            if window_fnames is None or wrong_window_fnames is None:
                continue

            window = self.read_window(window_fnames)
            if window is None:
                continue

            wrong_window = self.read_window(wrong_window_fnames)
            if wrong_window is None:
                continue

            try:
                wavpath = join(vidname, "audio.wav")
                wav = audio.load_wav(wavpath, audio_cfg["sample_rate"])
                orig_mel = audio.melspectrogram(wav).T
            except Exception:
                # Deliberate best effort: unreadable audio just triggers a
                # new draw instead of failing the whole data loader.
                continue

            mel = self.crop_audio_window(orig_mel.copy(), img_name)
            if mel.shape[0] != syncnet_mel_step_size:
                continue

            indiv_mels = self.get_segmented_mels(orig_mel.copy(), img_name)
            if indiv_mels is None:
                continue

            window = self.prepare_window(window)
            y = window.copy()
            # Zero the second half of axis 2 (the H axis after
            # prepare_window); `y` keeps the unmasked frames.
            window[:, :, window.shape[2] // 2:] = 0.

            wrong_window = self.prepare_window(wrong_window)
            x = np.concatenate([window, wrong_window], axis=0)
            x = np.float32(x)

            mel = np.transpose(mel)
            mel = np.expand_dims(mel, 0)
            indiv_mels = np.expand_dims(indiv_mels, 1)

            return {
                'x': x,
                'indiv_mels': np.float32(indiv_mels),
                'mel': np.float32(mel),
                'y': np.float32(y)
            }