# augmentation.py

import random

import easing_functions as ef
import torch
from torchvision import transforms
from torchvision.transforms import functional as F
  6. class MotionAugmentation:
  7. def __init__(self,
  8. size,
  9. prob_fgr_affine,
  10. prob_bgr_affine,
  11. prob_noise,
  12. prob_color_jitter,
  13. prob_grayscale,
  14. prob_sharpness,
  15. prob_blur,
  16. prob_hflip,
  17. prob_pause,
  18. static_affine=True,
  19. aspect_ratio_range=(0.9, 1.1)):
  20. self.size = size
  21. self.prob_fgr_affine = prob_fgr_affine
  22. self.prob_bgr_affine = prob_bgr_affine
  23. self.prob_noise = prob_noise
  24. self.prob_color_jitter = prob_color_jitter
  25. self.prob_grayscale = prob_grayscale
  26. self.prob_sharpness = prob_sharpness
  27. self.prob_blur = prob_blur
  28. self.prob_hflip = prob_hflip
  29. self.prob_pause = prob_pause
  30. self.static_affine = static_affine
  31. self.aspect_ratio_range = aspect_ratio_range
  32. def __call__(self, fgrs, phas, bgrs):
  33. # Foreground affine
  34. if random.random() < self.prob_fgr_affine:
  35. fgrs, phas = self._motion_affine(fgrs, phas)
  36. # Background affine
  37. if random.random() < self.prob_bgr_affine / 2:
  38. bgrs = self._motion_affine(bgrs)
  39. if random.random() < self.prob_bgr_affine / 2:
  40. fgrs, phas, bgrs = self._motion_affine(fgrs, phas, bgrs)
  41. # Still Affine
  42. if self.static_affine:
  43. fgrs, phas = self._static_affine(fgrs, phas, scale_ranges=(0.5, 1))
  44. bgrs = self._static_affine(bgrs, scale_ranges=(1, 1.5))
  45. # To tensor
  46. fgrs = torch.stack([F.to_tensor(fgr) for fgr in fgrs])
  47. phas = torch.stack([F.to_tensor(pha) for pha in phas])
  48. bgrs = torch.stack([F.to_tensor(bgr) for bgr in bgrs])
  49. # Resize
  50. params = transforms.RandomResizedCrop.get_params(fgrs, scale=(1, 1), ratio=self.aspect_ratio_range)
  51. fgrs = F.resized_crop(fgrs, *params, self.size, interpolation=F.InterpolationMode.BILINEAR)
  52. phas = F.resized_crop(phas, *params, self.size, interpolation=F.InterpolationMode.BILINEAR)
  53. params = transforms.RandomResizedCrop.get_params(bgrs, scale=(1, 1), ratio=self.aspect_ratio_range)
  54. bgrs = F.resized_crop(bgrs, *params, self.size, interpolation=F.InterpolationMode.BILINEAR)
  55. # Horizontal flip
  56. if random.random() < self.prob_hflip:
  57. fgrs = F.hflip(fgrs)
  58. phas = F.hflip(phas)
  59. if random.random() < self.prob_hflip:
  60. bgrs = F.hflip(bgrs)
  61. # Noise
  62. if random.random() < self.prob_noise:
  63. fgrs, bgrs = self._motion_noise(fgrs, bgrs)
  64. # Color jitter
  65. if random.random() < self.prob_color_jitter:
  66. fgrs = self._motion_color_jitter(fgrs)
  67. if random.random() < self.prob_color_jitter:
  68. bgrs = self._motion_color_jitter(bgrs)
  69. # Grayscale
  70. if random.random() < self.prob_grayscale:
  71. fgrs = F.rgb_to_grayscale(fgrs, num_output_channels=3).contiguous()
  72. bgrs = F.rgb_to_grayscale(bgrs, num_output_channels=3).contiguous()
  73. # Sharpen
  74. if random.random() < self.prob_sharpness:
  75. sharpness = random.random() * 8
  76. fgrs = F.adjust_sharpness(fgrs, sharpness)
  77. phas = F.adjust_sharpness(phas, sharpness)
  78. bgrs = F.adjust_sharpness(bgrs, sharpness)
  79. # Blur
  80. if random.random() < self.prob_blur / 3:
  81. fgrs, phas = self._motion_blur(fgrs, phas)
  82. if random.random() < self.prob_blur / 3:
  83. bgrs = self._motion_blur(bgrs)
  84. if random.random() < self.prob_blur / 3:
  85. fgrs, phas, bgrs = self._motion_blur(fgrs, phas, bgrs)
  86. # Pause
  87. if random.random() < self.prob_pause:
  88. fgrs, phas, bgrs = self._motion_pause(fgrs, phas, bgrs)
  89. return fgrs, phas, bgrs
  90. def _static_affine(self, *imgs, scale_ranges):
  91. params = transforms.RandomAffine.get_params(
  92. degrees=(-10, 10), translate=(0.1, 0.1), scale_ranges=scale_ranges,
  93. shears=(-5, 5), img_size=imgs[0][0].size)
  94. imgs = [[F.affine(t, *params, F.InterpolationMode.BILINEAR) for t in img] for img in imgs]
  95. return imgs if len(imgs) > 1 else imgs[0]
  96. def _motion_affine(self, *imgs):
  97. config = dict(degrees=(-10, 10), translate=(0.1, 0.1),
  98. scale_ranges=(0.9, 1.1), shears=(-5, 5), img_size=imgs[0][0].size)
  99. angleA, (transXA, transYA), scaleA, (shearXA, shearYA) = transforms.RandomAffine.get_params(**config)
  100. angleB, (transXB, transYB), scaleB, (shearXB, shearYB) = transforms.RandomAffine.get_params(**config)
  101. T = len(imgs[0])
  102. easing = random_easing_fn()
  103. for t in range(T):
  104. percentage = easing(t / (T - 1))
  105. angle = lerp(angleA, angleB, percentage)
  106. transX = lerp(transXA, transXB, percentage)
  107. transY = lerp(transYA, transYB, percentage)
  108. scale = lerp(scaleA, scaleB, percentage)
  109. shearX = lerp(shearXA, shearXB, percentage)
  110. shearY = lerp(shearYA, shearYB, percentage)
  111. for img in imgs:
  112. img[t] = F.affine(img[t], angle, (transX, transY), scale, (shearX, shearY), F.InterpolationMode.BILINEAR)
  113. return imgs if len(imgs) > 1 else imgs[0]
  114. def _motion_noise(self, *imgs):
  115. grain_size = random.random() * 3 + 1 # range 1 ~ 4
  116. monochrome = random.random() < 0.5
  117. for img in imgs:
  118. T, C, H, W = img.shape
  119. noise = torch.randn((T, 1 if monochrome else C, round(H / grain_size), round(W / grain_size)))
  120. noise.mul_(random.random() * 0.2 / grain_size)
  121. if grain_size != 1:
  122. noise = F.resize(noise, (H, W))
  123. img.add_(noise).clamp_(0, 1)
  124. return imgs if len(imgs) > 1 else imgs[0]
  125. def _motion_color_jitter(self, *imgs):
  126. brightnessA, brightnessB, contrastA, contrastB, saturationA, saturationB, hueA, hueB \
  127. = torch.randn(8).mul(0.1).tolist()
  128. strength = random.random() * 0.2
  129. easing = random_easing_fn()
  130. T = len(imgs[0])
  131. for t in range(T):
  132. percentage = easing(t / (T - 1)) * strength
  133. for img in imgs:
  134. img[t] = F.adjust_brightness(img[t], max(1 + lerp(brightnessA, brightnessB, percentage), 0.1))
  135. img[t] = F.adjust_contrast(img[t], max(1 + lerp(contrastA, contrastB, percentage), 0.1))
  136. img[t] = F.adjust_saturation(img[t], max(1 + lerp(brightnessA, brightnessB, percentage), 0.1))
  137. img[t] = F.adjust_hue(img[t], min(0.5, max(-0.5, lerp(hueA, hueB, percentage) * 0.1)))
  138. return imgs if len(imgs) > 1 else imgs[0]
  139. def _motion_blur(self, *imgs):
  140. blurA = random.random() * 10
  141. blurB = random.random() * 10
  142. T = len(imgs[0])
  143. easing = random_easing_fn()
  144. for t in range(T):
  145. percentage = easing(t / (T - 1))
  146. blur = max(lerp(blurA, blurB, percentage), 0)
  147. if blur != 0:
  148. kernel_size = int(blur * 2)
  149. if kernel_size % 2 == 0:
  150. kernel_size += 1 # Make kernel_size odd
  151. for img in imgs:
  152. img[t] = F.gaussian_blur(img[t], kernel_size, sigma=blur)
  153. return imgs if len(imgs) > 1 else imgs[0]
  154. def _motion_pause(self, *imgs):
  155. T = len(imgs[0])
  156. pause_frame = random.choice(range(T - 1))
  157. pause_length = random.choice(range(T - pause_frame))
  158. for img in imgs:
  159. img[pause_frame + 1 : pause_frame + pause_length] = img[pause_frame]
  160. return imgs if len(imgs) > 1 else imgs[0]
  161. def lerp(a, b, percentage):
  162. return a * (1 - percentage) + b * percentage
  163. def random_easing_fn():
  164. if random.random() < 0.2:
  165. return ef.LinearInOut()
  166. else:
  167. return random.choice([
  168. ef.BackEaseIn,
  169. ef.BackEaseOut,
  170. ef.BackEaseInOut,
  171. ef.BounceEaseIn,
  172. ef.BounceEaseOut,
  173. ef.BounceEaseInOut,
  174. ef.CircularEaseIn,
  175. ef.CircularEaseOut,
  176. ef.CircularEaseInOut,
  177. ef.CubicEaseIn,
  178. ef.CubicEaseOut,
  179. ef.CubicEaseInOut,
  180. ef.ExponentialEaseIn,
  181. ef.ExponentialEaseOut,
  182. ef.ExponentialEaseInOut,
  183. ef.ElasticEaseIn,
  184. ef.ElasticEaseOut,
  185. ef.ElasticEaseInOut,
  186. ef.QuadEaseIn,
  187. ef.QuadEaseOut,
  188. ef.QuadEaseInOut,
  189. ef.QuarticEaseIn,
  190. ef.QuarticEaseOut,
  191. ef.QuarticEaseInOut,
  192. ef.QuinticEaseIn,
  193. ef.QuinticEaseOut,
  194. ef.QuinticEaseInOut,
  195. ef.SineEaseIn,
  196. ef.SineEaseOut,
  197. ef.SineEaseInOut,
  198. Step,
  199. ])()
  200. class Step: # Custom easing function for sudden change.
  201. def __call__(self, value):
  202. return 0 if value < 0.5 else 1
# ---------------------------- Frame Sampler ----------------------------
  204. class TrainFrameSampler:
  205. def __init__(self, speed=[0.5, 1, 2, 3, 4, 5]):
  206. self.speed = speed
  207. def __call__(self, seq_length):
  208. frames = list(range(seq_length))
  209. # Speed up
  210. speed = random.choice(self.speed)
  211. frames = [int(f * speed) for f in frames]
  212. # Shift
  213. shift = random.choice(range(seq_length))
  214. frames = [f + shift for f in frames]
  215. # Reverse
  216. if random.random() < 0.5:
  217. frames = frames[::-1]
  218. return frames
  219. class ValidFrameSampler:
  220. def __call__(self, seq_length):
  221. return range(seq_length)