# videomatte.py (4.8 KB)

  1. import os
  2. import random
  3. from torch.utils.data import Dataset
  4. from PIL import Image
  5. from .augmentation import MotionAugmentation
  6. class VideoMatteDataset(Dataset):
  7. def __init__(self,
  8. videomatte_dir,
  9. background_image_dir,
  10. background_video_dir,
  11. size,
  12. seq_length,
  13. seq_sampler,
  14. transform=None):
  15. self.background_image_dir = background_image_dir
  16. self.background_image_files = os.listdir(background_image_dir)
  17. self.background_video_dir = background_video_dir
  18. self.background_video_clips = sorted(os.listdir(background_video_dir))
  19. self.background_video_frames = [sorted(os.listdir(os.path.join(background_video_dir, clip)))
  20. for clip in self.background_video_clips]
  21. self.videomatte_dir = videomatte_dir
  22. self.videomatte_clips = sorted(os.listdir(os.path.join(videomatte_dir, 'fgr')))
  23. self.videomatte_frames = [sorted(os.listdir(os.path.join(videomatte_dir, 'fgr', clip)))
  24. for clip in self.videomatte_clips]
  25. self.videomatte_idx = [(clip_idx, frame_idx)
  26. for clip_idx in range(len(self.videomatte_clips))
  27. for frame_idx in range(0, len(self.videomatte_frames[clip_idx]), seq_length)]
  28. self.size = size
  29. self.seq_length = seq_length
  30. self.seq_sampler = seq_sampler
  31. self.transform = transform
  32. def __len__(self):
  33. return len(self.videomatte_idx)
  34. def __getitem__(self, idx):
  35. if random.random() < 0.5:
  36. bgrs = self._get_random_image_background()
  37. else:
  38. bgrs = self._get_random_video_background()
  39. fgrs, phas = self._get_videomatte(idx)
  40. if self.transform is not None:
  41. return self.transform(fgrs, phas, bgrs)
  42. return fgrs, phas, bgrs
  43. def _get_random_image_background(self):
  44. with Image.open(os.path.join(self.background_image_dir, random.choice(self.background_image_files))) as bgr:
  45. bgr = self._downsample_if_needed(bgr.convert('RGB'))
  46. bgrs = [bgr] * self.seq_length
  47. return bgrs
  48. def _get_random_video_background(self):
  49. clip_idx = random.choice(range(len(self.background_video_clips)))
  50. frame_count = len(self.background_video_frames[clip_idx])
  51. frame_idx = random.choice(range(max(1, frame_count - self.seq_length)))
  52. clip = self.background_video_clips[clip_idx]
  53. bgrs = []
  54. for i in self.seq_sampler(self.seq_length):
  55. frame_idx_t = frame_idx + i
  56. frame = self.background_video_frames[clip_idx][frame_idx_t % frame_count]
  57. with Image.open(os.path.join(self.background_video_dir, clip, frame)) as bgr:
  58. bgr = self._downsample_if_needed(bgr.convert('RGB'))
  59. bgrs.append(bgr)
  60. return bgrs
  61. def _get_videomatte(self, idx):
  62. clip_idx, frame_idx = self.videomatte_idx[idx]
  63. clip = self.videomatte_clips[clip_idx]
  64. frame_count = len(self.videomatte_frames[clip_idx])
  65. fgrs, phas = [], []
  66. for i in self.seq_sampler(self.seq_length):
  67. frame = self.videomatte_frames[clip_idx][(frame_idx + i) % frame_count]
  68. with Image.open(os.path.join(self.videomatte_dir, 'fgr', clip, frame)) as fgr, \
  69. Image.open(os.path.join(self.videomatte_dir, 'pha', clip, frame)) as pha:
  70. fgr = self._downsample_if_needed(fgr.convert('RGB'))
  71. pha = self._downsample_if_needed(pha.convert('L'))
  72. fgrs.append(fgr)
  73. phas.append(pha)
  74. return fgrs, phas
  75. def _downsample_if_needed(self, img):
  76. w, h = img.size
  77. if min(w, h) > self.size:
  78. scale = self.size / min(w, h)
  79. w = int(scale * w)
  80. h = int(scale * h)
  81. img = img.resize((w, h))
  82. return img
  83. class VideoMatteTrainAugmentation(MotionAugmentation):
  84. def __init__(self, size):
  85. super().__init__(
  86. size=size,
  87. prob_fgr_affine=0.3,
  88. prob_bgr_affine=0.3,
  89. prob_noise=0.1,
  90. prob_color_jitter=0.3,
  91. prob_grayscale=0.02,
  92. prob_sharpness=0.1,
  93. prob_blur=0.02,
  94. prob_hflip=0.5,
  95. prob_pause=0.03,
  96. )
  97. class VideoMatteValidAugmentation(MotionAugmentation):
  98. def __init__(self, size):
  99. super().__init__(
  100. size=size,
  101. prob_fgr_affine=0,
  102. prob_bgr_affine=0,
  103. prob_noise=0,
  104. prob_color_jitter=0,
  105. prob_grayscale=0,
  106. prob_sharpness=0,
  107. prob_blur=0,
  108. prob_hflip=0,
  109. prob_pause=0,
  110. )