You can not select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
186 lines
5.6 KiB
186 lines
5.6 KiB
# code was heavily based on https://github.com/Rudrabha/Wav2Lip |
|
# Users should be careful about adopting these functions in any commercial matters. |
|
# https://github.com/Rudrabha/Wav2Lip#license-and-citation |
|
|
|
import cv2 |
|
import random |
|
import os.path |
|
import numpy as np |
|
from PIL import Image |
|
from glob import glob |
|
from os.path import dirname, join, basename, isfile |
|
from ppgan.utils import audio |
|
from ppgan.utils.audio_config import get_audio_config |
|
import numpy as np |
|
|
|
import paddle |
|
from .builder import DATASETS |
|
|
|
|
|
def get_image_list(data_root, split, filelists_dir='filelists'):
    """Read the list of video directories for one dataset split.

    Each non-blank line of ``<filelists_dir>/<split>.txt`` names an entry
    relative to ``data_root``; only the first whitespace-separated token of a
    line is used, so trailing columns are ignored.

    Args:
        data_root (str): Directory that the listed entries are relative to.
        split (str): Split name; selects the file ``<split>.txt``.
        filelists_dir (str): Directory containing the split files. The
            default ``'filelists'`` is resolved against the current working
            directory, matching the previously hard-coded location.

    Returns:
        list[str]: ``data_root``-joined entry paths, in file order.

    Raises:
        FileNotFoundError: If the split file or any listed entry is missing.
    """
    split_file = os.path.join(filelists_dir, '{}.txt'.format(split))
    filelist = []
    with open(split_file) as f:
        for line in f:
            fields = line.split()
            if not fields:
                # A blank line would join to data_root itself and be
                # mistaken for a video directory.
                continue
            video_path = os.path.join(data_root, fields[0])
            # Explicit raise instead of assert: asserts vanish under -O.
            if not os.path.exists(video_path):
                raise FileNotFoundError('{} is not found'.format(video_path))
            filelist.append(video_path)
    return filelist
|
|
|
|
|
# Sampling constants used by the dataset below.
#   syncnet_T             - consecutive video frames per sample window
#   syncnet_mel_step_size - mel-spectrogram rows cropped per audio window
syncnet_T, syncnet_mel_step_size = 5, 16

# Project audio settings; the dataset reads its "sample_rate" and "fps"
# entries. Evaluated once at import time and shared by all instances.
audio_cfg = get_audio_config()
|
|
|
|
|
@DATASETS.register()
class Wav2LipDataset(paddle.io.Dataset):
    """Wav2Lip dataset pairing face-frame windows with mel-spectrogram windows.

    Each entry of the split list is a video directory holding frames named
    ``<frame_id>.jpg`` and an ``audio.wav`` file. ``__getitem__`` draws random
    samples and skips unusable candidates until a valid one is found.
    """

    def __init__(self, dataroot, img_size, filelists_path, split):
        """Initialize Wav2Lip dataset class.

        Args:
            dataroot (str): Directory of dataset; split-list entries are
                resolved relative to it.
            img_size (int): Side length every frame is resized to.
            filelists_path (str): Accepted for config compatibility.
                NOTE(review): not used below; the split file location is
                decided by ``get_image_list`` -- confirm this is intended.
            split (str): Split name used to select the split list file.
        """
        self.image_path = dataroot
        self.img_size = img_size
        self.filelists_path = filelists_path
        self.split = split
        self.all_videos = get_image_list(self.image_path, self.split)

    def get_frame_id(self, frame):
        """Return the integer id encoded in a frame path, e.g. ``'v/12.jpg'`` -> 12."""
        return int(basename(frame).split('.')[0])

    def get_window(self, start_frame):
        """Return paths of ``syncnet_T`` consecutive frames starting at ``start_frame``.

        Args:
            start_frame (str): Path of the first frame of the window.

        Returns:
            list[str] | None: Frame paths in order, or ``None`` if any frame
            of the window is not a file on disk.
        """
        start_id = self.get_frame_id(start_frame)
        vidname = dirname(start_frame)

        window_fnames = []
        for frame_id in range(start_id, start_id + syncnet_T):
            frame = join(vidname, '{}.jpg'.format(frame_id))
            if not isfile(frame):
                return None
            window_fnames.append(frame)
        return window_fnames

    def read_window(self, window_fnames):
        """Load and resize every frame of a window.

        Args:
            window_fnames (list[str] | None): Frame paths from ``get_window``.

        Returns:
            list[numpy.ndarray] | None: One ``cv2.imread`` image per path,
            resized to ``(img_size, img_size)``; ``None`` if the input is
            ``None`` or any frame cannot be read or resized.
        """
        if window_fnames is None:
            return None

        window = []
        for fname in window_fnames:
            img = cv2.imread(fname)
            if img is None:
                return None
            try:
                img = cv2.resize(img, (self.img_size, self.img_size))
            except Exception:
                # Treat an unresizable frame as unusable; the caller then
                # draws a different sample.
                return None
            window.append(img)
        return window

    def crop_audio_window(self, spec, start_frame):
        """Crop the mel rows aligned with a video frame.

        Args:
            spec (numpy.ndarray): Mel spectrogram whose axis 0 is indexed by
                mel time frame (``__getitem__`` passes ``melspectrogram(...).T``).
            start_frame (int | str): Video frame id, or a frame path whose id
                is parsed with ``get_frame_id``.

        Returns:
            numpy.ndarray: ``spec[start:start + syncnet_mel_step_size, :]``.
            Near the end of the clip fewer rows are returned; callers check
            the length.
        """
        # isinstance (not an exact type check) so numpy integer ids are not
        # mistaken for frame paths.
        if isinstance(start_frame, (int, np.integer)):
            start_frame_num = start_frame
        else:
            start_frame_num = self.get_frame_id(start_frame)
        # Video frame -> seconds (/ fps) -> mel frame (* 80 per second).
        # NOTE(review): 80 presumably equals sample_rate / hop_size of the
        # audio config -- confirm before changing audio settings.
        start_idx = int(80. * (start_frame_num / float(audio_cfg["fps"])))
        end_idx = start_idx + syncnet_mel_step_size
        return spec[start_idx:end_idx, :]

    def get_segmented_mels(self, spec, start_frame):
        """Crop one mel window per frame of the window starting at ``start_frame``.

        The crop for frame id ``k`` starts at frame ``k - 1``, so a window
        that begins at frame 0 has no valid audio and is rejected.

        Args:
            spec (numpy.ndarray): Mel spectrogram as for ``crop_audio_window``.
            start_frame (str): Path of the first frame of the window.

        Returns:
            numpy.ndarray | None: Array of shape
            ``(syncnet_T, spec.shape[1], syncnet_mel_step_size)`` (each crop
            is transposed), or ``None`` if the window starts at frame 0 or
            any crop is shorter than ``syncnet_mel_step_size``.
        """
        # The one-frame-early offset below was written for 5-frame windows.
        assert syncnet_T == 5
        first_id = self.get_frame_id(start_frame)
        if first_id - 1 < 0:
            return None

        mels = []
        for frame_id in range(first_id, first_id + syncnet_T):
            m = self.crop_audio_window(spec, frame_id - 1)
            if m.shape[0] != syncnet_mel_step_size:
                return None
            mels.append(m.T)
        return np.asarray(mels)

    def prepare_window(self, window):
        """Stack frames into a channel-first array scaled by 1/255.

        Args:
            window (list[numpy.ndarray]): ``syncnet_T`` images of shape (H, W, C).

        Returns:
            numpy.ndarray: Float array of shape (C, T, H, W).
        """
        # (T, H, W, C) -> (C, T, H, W)
        x = np.asarray(window) / 255.
        x = np.transpose(x, (3, 0, 1, 2))
        return x

    def __len__(self):
        """Return the number of videos in the split list."""
        return len(self.all_videos)

    def __getitem__(self, idx):
        """Draw a random training sample.

        ``idx`` is ignored: every attempt picks a random video and random
        frames. A candidate is skipped unless the video has more than
        ``3 * syncnet_T`` frames, both the true window and a different
        "wrong" window exist and load, ``audio.wav`` loads, and every mel
        crop has ``syncnet_mel_step_size`` rows.

        NOTE(review): loops forever if no video can ever yield a valid sample.

        Returns:
            dict: float32 arrays (C channels, T = syncnet_T,
            H = W = img_size, n_mels = mel bins of the spectrogram):
                - ``'x'``: lower-half-masked true window concatenated with
                  the wrong window on axis 0, shape (2*C, T, H, W);
                - ``'indiv_mels'``: per-frame mel crops,
                  shape (T, 1, n_mels, syncnet_mel_step_size);
                - ``'mel'``: mel crop for the whole window,
                  shape (1, n_mels, syncnet_mel_step_size);
                - ``'y'``: unmasked true window, shape (C, T, H, W).
        """
        while True:
            idx = random.randint(0, len(self.all_videos) - 1)
            vidname = self.all_videos[idx]
            img_names = glob(join(vidname, '*.jpg'))
            if len(img_names) <= 3 * syncnet_T:
                continue

            img_name = random.choice(img_names)
            wrong_img_name = random.choice(img_names)
            while wrong_img_name == img_name:
                wrong_img_name = random.choice(img_names)

            window_fnames = self.get_window(img_name)
            wrong_window_fnames = self.get_window(wrong_img_name)
            if window_fnames is None or wrong_window_fnames is None:
                continue

            window = self.read_window(window_fnames)
            if window is None:
                continue
            wrong_window = self.read_window(wrong_window_fnames)
            if wrong_window is None:
                continue

            try:
                wavpath = join(vidname, "audio.wav")
                wav = audio.load_wav(wavpath, audio_cfg["sample_rate"])
                orig_mel = audio.melspectrogram(wav).T
            except Exception:
                # Best effort: skip videos whose audio cannot be loaded.
                continue

            mel = self.crop_audio_window(orig_mel.copy(), img_name)
            if mel.shape[0] != syncnet_mel_step_size:
                continue

            indiv_mels = self.get_segmented_mels(orig_mel.copy(), img_name)
            if indiv_mels is None:
                continue

            window = self.prepare_window(window)
            y = window.copy()
            # Zero the lower half (axis 2 = H) of the input frames; the
            # unmasked copy above is kept as the target.
            window[:, :, window.shape[2] // 2:] = 0.

            wrong_window = self.prepare_window(wrong_window)
            x = np.concatenate([window, wrong_window], axis=0)

            x = np.float32(x)
            mel = np.transpose(mel)
            mel = np.expand_dims(mel, 0)
            indiv_mels = np.expand_dims(indiv_mels, 1)

            return {
                'x': x,
                'indiv_mels': np.float32(indiv_mels),
                'mel': np.float32(mel),
                'y': np.float32(y)
            }
|
|
|