import os
import random

from torch.utils.data import Dataset
from PIL import Image

from .augmentation import MotionAugmentation


class VideoMatteDataset(Dataset):
    """Samples sequences of foreground/alpha frames and pairs them with a
    random still-image or video background for compositing."""

    def __init__(self,
                 videomatte_dir,
                 background_image_dir,
                 background_video_dir,
                 size,
                 seq_length,
                 seq_sampler,
                 transform=None):
        # Background images: a flat directory of still images.
        self.background_image_dir = background_image_dir
        self.background_image_files = os.listdir(background_image_dir)

        # Background videos: one sub-directory of frames per clip.
        self.background_video_dir = background_video_dir
        self.background_video_clips = sorted(os.listdir(background_video_dir))
        self.background_video_frames = [sorted(os.listdir(os.path.join(background_video_dir, clip)))
                                        for clip in self.background_video_clips]

        # VideoMatte clips: 'fgr' (foreground) and 'pha' (alpha) frames per clip.
        self.videomatte_dir = videomatte_dir
        self.videomatte_clips = sorted(os.listdir(os.path.join(videomatte_dir, 'fgr')))
        self.videomatte_frames = [sorted(os.listdir(os.path.join(videomatte_dir, 'fgr', clip)))
                                  for clip in self.videomatte_clips]
        # Index every (clip, start-frame) pair, stepping by seq_length so
        # consecutive samples do not overlap.
        self.videomatte_idx = [(clip_idx, frame_idx)
                               for clip_idx in range(len(self.videomatte_clips))
                               for frame_idx in range(0, len(self.videomatte_frames[clip_idx]), seq_length)]
        self.size = size
        self.seq_length = seq_length
        self.seq_sampler = seq_sampler
        self.transform = transform

    def __len__(self):
        return len(self.videomatte_idx)

    def __getitem__(self, idx):
        # With equal probability, composite over a still image or a video clip.
        if random.random() < 0.5:
            bgrs = self._get_random_image_background()
        else:
            bgrs = self._get_random_video_background()

        fgrs, phas = self._get_videomatte(idx)

        if self.transform is not None:
            return self.transform(fgrs, phas, bgrs)

        return fgrs, phas, bgrs

    def _get_random_image_background(self):
        # A still background is simply repeated for the whole sequence.
        with Image.open(os.path.join(self.background_image_dir, random.choice(self.background_image_files))) as bgr:
            bgr = self._downsample_if_needed(bgr.convert('RGB'))
        bgrs = [bgr] * self.seq_length
        return bgrs

    def _get_random_video_background(self):
        clip_idx = random.choice(range(len(self.background_video_clips)))
        frame_count = len(self.background_video_frames[clip_idx])
        frame_idx = random.choice(range(max(1, frame_count - self.seq_length)))
        clip = self.background_video_clips[clip_idx]
        bgrs = []
        for i in self.seq_sampler(self.seq_length):
            frame_idx_t = frame_idx + i
            # Wrap around if the sampler steps past the end of the clip.
            frame = self.background_video_frames[clip_idx][frame_idx_t % frame_count]
            with Image.open(os.path.join(self.background_video_dir, clip, frame)) as bgr:
                bgr = self._downsample_if_needed(bgr.convert('RGB'))
            bgrs.append(bgr)
        return bgrs

    def _get_videomatte(self, idx):
        clip_idx, frame_idx = self.videomatte_idx[idx]
        clip = self.videomatte_clips[clip_idx]
        frame_count = len(self.videomatte_frames[clip_idx])
        fgrs, phas = [], []
        for i in self.seq_sampler(self.seq_length):
            frame = self.videomatte_frames[clip_idx][(frame_idx + i) % frame_count]
            # Foreground and alpha matte share the same frame filename.
            with Image.open(os.path.join(self.videomatte_dir, 'fgr', clip, frame)) as fgr, \
                 Image.open(os.path.join(self.videomatte_dir, 'pha', clip, frame)) as pha:
                fgr = self._downsample_if_needed(fgr.convert('RGB'))
                pha = self._downsample_if_needed(pha.convert('L'))
            fgrs.append(fgr)
            phas.append(pha)
        return fgrs, phas

    def _downsample_if_needed(self, img):
        # Resize so the shorter side is at most self.size, preserving aspect ratio.
        w, h = img.size
        if min(w, h) > self.size:
            scale = self.size / min(w, h)
            w = int(scale * w)
            h = int(scale * h)
            img = img.resize((w, h))
        return img


class VideoMatteTrainAugmentation(MotionAugmentation):
    def __init__(self, size):
        super().__init__(
            size=size,
            prob_fgr_affine=0.3,
            prob_bgr_affine=0.3,
            prob_noise=0.1,
            prob_color_jitter=0.3,
            prob_grayscale=0.02,
            prob_sharpness=0.1,
            prob_blur=0.02,
            prob_hflip=0.5,
            prob_pause=0.03,
        )


class VideoMatteValidAugmentation(MotionAugmentation):
    def __init__(self, size):
        super().__init__(
            size=size,
            prob_fgr_affine=0,
            prob_bgr_affine=0,
            prob_noise=0,
            prob_color_jitter=0,
            prob_grayscale=0,
            prob_sharpness=0,
            prob_blur=0,
            prob_hflip=0,
            prob_pause=0,
        )
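

# --- Usage sketch (illustrative only, not part of the original module) ---
# The directory paths below are hypothetical placeholders, and the identity
# seq_sampler (lambda n: range(n)) simply takes consecutive frames; a real
# training pipeline may supply its own sampler and wrap the dataset in a
# DataLoader.
if __name__ == '__main__':
    dataset = VideoMatteDataset(
        videomatte_dir='data/VideoMatte240K/train',       # hypothetical path
        background_image_dir='data/Backgrounds/images',   # hypothetical path
        background_video_dir='data/Backgrounds/videos',   # hypothetical path
        size=512,
        seq_length=15,
        seq_sampler=lambda n: range(n),
        transform=None,  # or VideoMatteTrainAugmentation(size=512)
    )
    # With transform=None, each item is three lists of PIL images
    # (foregrounds, alpha mattes, backgrounds), each of length seq_length.
    fgrs, phas, bgrs = dataset[0]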