imagematte.py

import os
import random

from torch.utils.data import Dataset
from PIL import Image

from .augmentation import MotionAugmentation


class ImageMatteDataset(Dataset):
    """Pairs a still foreground/alpha image matte with a random image or
    video background and repeats it to form a short training sequence."""

    def __init__(self,
                 imagematte_dir,
                 background_image_dir,
                 background_video_dir,
                 size,
                 seq_length,
                 seq_sampler,
                 transform):
        self.imagematte_dir = imagematte_dir
        # Matte files are expected under 'fgr/' (foreground) and 'pha/' (alpha),
        # with matching filenames in both subfolders.
        self.imagematte_files = os.listdir(os.path.join(imagematte_dir, 'fgr'))
        self.background_image_dir = background_image_dir
        self.background_image_files = os.listdir(background_image_dir)
        self.background_video_dir = background_video_dir
        # Each background video clip is a subfolder containing its frames.
        self.background_video_clips = os.listdir(background_video_dir)
        self.background_video_frames = [sorted(os.listdir(os.path.join(background_video_dir, clip)))
                                        for clip in self.background_video_clips]
        self.seq_length = seq_length
        self.seq_sampler = seq_sampler
        self.size = size
        self.transform = transform

    def __len__(self):
        # Large enough to cycle through both the mattes and all backgrounds.
        return max(len(self.imagematte_files),
                   len(self.background_image_files) + len(self.background_video_clips))

    def __getitem__(self, idx):
        # Choose an image or a video background with equal probability.
        if random.random() < 0.5:
            bgrs = self._get_random_image_background()
        else:
            bgrs = self._get_random_video_background()

        fgrs, phas = self._get_imagematte(idx)

        if self.transform is not None:
            return self.transform(fgrs, phas, bgrs)

        return fgrs, phas, bgrs

    def _get_imagematte(self, idx):
        # The same still foreground/alpha pair is repeated for every frame of
        # the sequence; motion is introduced later by the augmentation.
        with Image.open(os.path.join(self.imagematte_dir, 'fgr', self.imagematte_files[idx % len(self.imagematte_files)])) as fgr, \
             Image.open(os.path.join(self.imagematte_dir, 'pha', self.imagematte_files[idx % len(self.imagematte_files)])) as pha:
            fgr = self._downsample_if_needed(fgr.convert('RGB'))
            pha = self._downsample_if_needed(pha.convert('L'))
        fgrs = [fgr] * self.seq_length
        phas = [pha] * self.seq_length
        return fgrs, phas

    def _get_random_image_background(self):
        # A single random background image, repeated for the whole sequence.
        with Image.open(os.path.join(self.background_image_dir, self.background_image_files[random.choice(range(len(self.background_image_files)))])) as bgr:
            bgr = self._downsample_if_needed(bgr.convert('RGB'))
        bgrs = [bgr] * self.seq_length
        return bgrs

    def _get_random_video_background(self):
        # Pick a random clip and start frame, then gather frames at the
        # offsets produced by seq_sampler, wrapping around short clips.
        clip_idx = random.choice(range(len(self.background_video_clips)))
        frame_count = len(self.background_video_frames[clip_idx])
        frame_idx = random.choice(range(max(1, frame_count - self.seq_length)))
        clip = self.background_video_clips[clip_idx]
        bgrs = []
        for i in self.seq_sampler(self.seq_length):
            frame_idx_t = frame_idx + i
            frame = self.background_video_frames[clip_idx][frame_idx_t % frame_count]
            with Image.open(os.path.join(self.background_video_dir, clip, frame)) as bgr:
                bgr = self._downsample_if_needed(bgr.convert('RGB'))
            bgrs.append(bgr)
        return bgrs

    def _downsample_if_needed(self, img):
        # Downscale so the shorter side is no larger than self.size.
        w, h = img.size
        if min(w, h) > self.size:
            scale = self.size / min(w, h)
            w = int(scale * w)
            h = int(scale * h)
            img = img.resize((w, h))
        return img


class ImageMatteAugmentation(MotionAugmentation):
    # MotionAugmentation preset used with ImageMatteDataset: affine motion on
    # the (static) foreground and background, plus occasional noise, color
    # jitter, grayscale, sharpness, blur, horizontal flip, and pause effects.
    def __init__(self, size):
        super().__init__(
            size=size,
            prob_fgr_affine=0.95,
            prob_bgr_affine=0.3,
            prob_noise=0.05,
            prob_color_jitter=0.3,
            prob_grayscale=0.03,
            prob_sharpness=0.05,
            prob_blur=0.02,
            prob_hflip=0.5,
            prob_pause=0.03,
        )
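

# --- Usage sketch (illustrative, not part of the original file) ---
# A minimal example of how this dataset might be wired up from a training
# script. The directory paths and the `uniform_seq_sampler` helper below are
# hypothetical, and the `dataset.imagematte` module path assumes this file
# lives in a package named `dataset` (implied by the relative import above).
#
#     from dataset.imagematte import ImageMatteDataset, ImageMatteAugmentation
#
#     def uniform_seq_sampler(seq_length):
#         # Assumed contract for seq_sampler: return an iterable of frame
#         # offsets relative to the randomly chosen start frame.
#         return range(seq_length)
#
#     dataset = ImageMatteDataset(
#         imagematte_dir='data/imagematte/train',         # contains 'fgr/' and 'pha/'
#         background_image_dir='data/background_images',  # flat folder of images
#         background_video_dir='data/background_videos',  # one subfolder of frames per clip
#         size=512,
#         seq_length=15,
#         seq_sampler=uniform_seq_sampler,
#         transform=ImageMatteAugmentation(size=512),
#     )
#
#     # Returns whatever the transform produces for the three PIL sequences
#     # (foregrounds, alphas, backgrounds).
#     fgrs, phas, bgrs = dataset[0]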