import easing_functions as ef
import random
import torch
from torchvision import transforms
from torchvision.transforms import functional as F


class MotionAugmentation:
    def __init__(self,
                 size,
                 prob_fgr_affine,
                 prob_bgr_affine,
                 prob_noise,
                 prob_color_jitter,
                 prob_grayscale,
                 prob_sharpness,
                 prob_blur,
                 prob_hflip,
                 prob_pause,
                 static_affine=True,
                 aspect_ratio_range=(0.9, 1.1)):
        self.size = size
        self.prob_fgr_affine = prob_fgr_affine
        self.prob_bgr_affine = prob_bgr_affine
        self.prob_noise = prob_noise
        self.prob_color_jitter = prob_color_jitter
        self.prob_grayscale = prob_grayscale
        self.prob_sharpness = prob_sharpness
        self.prob_blur = prob_blur
        self.prob_hflip = prob_hflip
        self.prob_pause = prob_pause
        self.static_affine = static_affine
        self.aspect_ratio_range = aspect_ratio_range

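    # fgrs, phas, and bgrs are expected to be equal-length sequences of PIL frames
    # (foreground, alpha matte, and background, respectively); each sequence is
    # returned as a float tensor of shape (T, C, *size) with values in [0, 1].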
    def __call__(self, fgrs, phas, bgrs):
        # Foreground affine
        if random.random() < self.prob_fgr_affine:
            fgrs, phas = self._motion_affine(fgrs, phas)

        # Background affine
        if random.random() < self.prob_bgr_affine / 2:
            bgrs = self._motion_affine(bgrs)
        if random.random() < self.prob_bgr_affine / 2:
            fgrs, phas, bgrs = self._motion_affine(fgrs, phas, bgrs)

        # Static affine
        if self.static_affine:
            fgrs, phas = self._static_affine(fgrs, phas, scale_ranges=(0.5, 1))
            bgrs = self._static_affine(bgrs, scale_ranges=(1, 1.5))

        # To tensor
        fgrs = torch.stack([F.to_tensor(fgr) for fgr in fgrs])
        phas = torch.stack([F.to_tensor(pha) for pha in phas])
        bgrs = torch.stack([F.to_tensor(bgr) for bgr in bgrs])

        # Resize
        params = transforms.RandomResizedCrop.get_params(fgrs, scale=(1, 1), ratio=self.aspect_ratio_range)
        fgrs = F.resized_crop(fgrs, *params, self.size, interpolation=F.InterpolationMode.BILINEAR)
        phas = F.resized_crop(phas, *params, self.size, interpolation=F.InterpolationMode.BILINEAR)
        params = transforms.RandomResizedCrop.get_params(bgrs, scale=(1, 1), ratio=self.aspect_ratio_range)
        bgrs = F.resized_crop(bgrs, *params, self.size, interpolation=F.InterpolationMode.BILINEAR)

        # Horizontal flip
        if random.random() < self.prob_hflip:
            fgrs = F.hflip(fgrs)
            phas = F.hflip(phas)
        if random.random() < self.prob_hflip:
            bgrs = F.hflip(bgrs)

        # Noise
        if random.random() < self.prob_noise:
            fgrs, bgrs = self._motion_noise(fgrs, bgrs)

        # Color jitter
        if random.random() < self.prob_color_jitter:
            fgrs = self._motion_color_jitter(fgrs)
        if random.random() < self.prob_color_jitter:
            bgrs = self._motion_color_jitter(bgrs)

        # Grayscale
        if random.random() < self.prob_grayscale:
            fgrs = F.rgb_to_grayscale(fgrs, num_output_channels=3).contiguous()
            bgrs = F.rgb_to_grayscale(bgrs, num_output_channels=3).contiguous()

        # Sharpen
        if random.random() < self.prob_sharpness:
            sharpness = random.random() * 8
            fgrs = F.adjust_sharpness(fgrs, sharpness)
            phas = F.adjust_sharpness(phas, sharpness)
            bgrs = F.adjust_sharpness(bgrs, sharpness)

        # Blur
        if random.random() < self.prob_blur / 3:
            fgrs, phas = self._motion_blur(fgrs, phas)
        if random.random() < self.prob_blur / 3:
            bgrs = self._motion_blur(bgrs)
        if random.random() < self.prob_blur / 3:
            fgrs, phas, bgrs = self._motion_blur(fgrs, phas, bgrs)

        # Pause
        if random.random() < self.prob_pause:
            fgrs, phas, bgrs = self._motion_pause(fgrs, phas, bgrs)

        return fgrs, phas, bgrs

    def _static_affine(self, *imgs, scale_ranges):
        params = transforms.RandomAffine.get_params(
            degrees=(-10, 10), translate=(0.1, 0.1), scale_ranges=scale_ranges,
            shears=(-5, 5), img_size=imgs[0][0].size)
        imgs = [[F.affine(t, *params, F.InterpolationMode.BILINEAR) for t in img] for img in imgs]
        return imgs if len(imgs) > 1 else imgs[0]

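    # Animates an affine warp over the clip: two random affine states A and B are
    # drawn, and each frame t is warped with parameters eased between them, which
    # approximates smooth camera or subject motion.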
    def _motion_affine(self, *imgs):
        config = dict(degrees=(-10, 10), translate=(0.1, 0.1),
                      scale_ranges=(0.9, 1.1), shears=(-5, 5), img_size=imgs[0][0].size)
        angleA, (transXA, transYA), scaleA, (shearXA, shearYA) = transforms.RandomAffine.get_params(**config)
        angleB, (transXB, transYB), scaleB, (shearXB, shearYB) = transforms.RandomAffine.get_params(**config)

        T = len(imgs[0])
        easing = random_easing_fn()
        for t in range(T):
            percentage = easing(t / (T - 1))
            angle = lerp(angleA, angleB, percentage)
            transX = lerp(transXA, transXB, percentage)
            transY = lerp(transYA, transYB, percentage)
            scale = lerp(scaleA, scaleB, percentage)
            shearX = lerp(shearXA, shearXB, percentage)
            shearY = lerp(shearYA, shearYB, percentage)
            for img in imgs:
                img[t] = F.affine(img[t], angle, (transX, transY), scale, (shearX, shearY), F.InterpolationMode.BILINEAR)
        return imgs if len(imgs) > 1 else imgs[0]

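    # Adds film-grain-style noise: Gaussian noise is generated at a reduced
    # resolution (controlled by grain_size), attenuated, upsampled back to the
    # frame size, and added to every frame before clamping to [0, 1].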
    def _motion_noise(self, *imgs):
        grain_size = random.random() * 3 + 1  # range 1 ~ 4
        monochrome = random.random() < 0.5
        for img in imgs:
            T, C, H, W = img.shape
            noise = torch.randn((T, 1 if monochrome else C, round(H / grain_size), round(W / grain_size)))
            noise.mul_(random.random() * 0.2 / grain_size)
            if grain_size != 1:
                noise = F.resize(noise, (H, W))
            img.add_(noise).clamp_(0, 1)
        return imgs if len(imgs) > 1 else imgs[0]

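    # Drifts brightness, contrast, saturation, and hue over the clip: per-frame
    # factors are eased between two random draws and kept close to the identity
    # by the small strength term.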
    def _motion_color_jitter(self, *imgs):
        brightnessA, brightnessB, contrastA, contrastB, saturationA, saturationB, hueA, hueB \
            = torch.randn(8).mul(0.1).tolist()
        strength = random.random() * 0.2
        easing = random_easing_fn()
        T = len(imgs[0])
        for t in range(T):
            percentage = easing(t / (T - 1)) * strength
            for img in imgs:
                img[t] = F.adjust_brightness(img[t], max(1 + lerp(brightnessA, brightnessB, percentage), 0.1))
                img[t] = F.adjust_contrast(img[t], max(1 + lerp(contrastA, contrastB, percentage), 0.1))
                img[t] = F.adjust_saturation(img[t], max(1 + lerp(saturationA, saturationB, percentage), 0.1))
                img[t] = F.adjust_hue(img[t], min(0.5, max(-0.5, lerp(hueA, hueB, percentage) * 0.1)))
        return imgs if len(imgs) > 1 else imgs[0]

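    # Applies a Gaussian blur whose strength is eased between two random values
    # across the clip, loosely mimicking focus changes or motion blur.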
    def _motion_blur(self, *imgs):
        blurA = random.random() * 10
        blurB = random.random() * 10
        T = len(imgs[0])
        easing = random_easing_fn()
        for t in range(T):
            percentage = easing(t / (T - 1))
            blur = max(lerp(blurA, blurB, percentage), 0)
            if blur != 0:
                kernel_size = int(blur * 2)
                if kernel_size % 2 == 0:
                    kernel_size += 1  # gaussian_blur requires an odd kernel size
                for img in imgs:
                    img[t] = F.gaussian_blur(img[t], kernel_size, sigma=blur)

        return imgs if len(imgs) > 1 else imgs[0]

    def _motion_pause(self, *imgs):
        # Freeze a random span of frames to mimic a paused or static segment.
        T = len(imgs[0])
        pause_frame = random.choice(range(T - 1))
        pause_length = random.choice(range(T - pause_frame))
        for img in imgs:
            img[pause_frame + 1 : pause_frame + pause_length] = img[pause_frame]
        return imgs if len(imgs) > 1 else imgs[0]


def lerp(a, b, percentage):
    return a * (1 - percentage) + b * percentage


def random_easing_fn():
    if random.random() < 0.2:
        return ef.LinearInOut()
    else:
        # Step is defined below; the list is only evaluated when this function is called.
        return random.choice([
            ef.BackEaseIn,
            ef.BackEaseOut,
            ef.BackEaseInOut,
            ef.BounceEaseIn,
            ef.BounceEaseOut,
            ef.BounceEaseInOut,
            ef.CircularEaseIn,
            ef.CircularEaseOut,
            ef.CircularEaseInOut,
            ef.CubicEaseIn,
            ef.CubicEaseOut,
            ef.CubicEaseInOut,
            ef.ExponentialEaseIn,
            ef.ExponentialEaseOut,
            ef.ExponentialEaseInOut,
            ef.ElasticEaseIn,
            ef.ElasticEaseOut,
            ef.ElasticEaseInOut,
            ef.QuadEaseIn,
            ef.QuadEaseOut,
            ef.QuadEaseInOut,
            ef.QuarticEaseIn,
            ef.QuarticEaseOut,
            ef.QuarticEaseInOut,
            ef.QuinticEaseIn,
            ef.QuinticEaseOut,
            ef.QuinticEaseInOut,
            ef.SineEaseIn,
            ef.SineEaseOut,
            ef.SineEaseInOut,
            Step,
        ])()


class Step:  # Custom easing function for sudden change.
    def __call__(self, value):
        return 0 if value < 0.5 else 1
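
# The three helpers above (lerp, random_easing_fn, Step) drive the temporal
# animation: an easing function maps normalized time t / (T - 1) in [0, 1] to an
# interpolation weight, and lerp blends two randomly drawn parameter states with
# that weight, so each augmentation parameter sweeps smoothly (or jumps abruptly,
# when Step is chosen) across a clip's frames.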


# ---------------------------- Frame Sampler ----------------------------
class TrainFrameSampler:
    def __init__(self, speed=[0.5, 1, 2, 3, 4, 5]):
        self.speed = speed

    def __call__(self, seq_length):
        frames = list(range(seq_length))

        # Speed up
        speed = random.choice(self.speed)
        frames = [int(f * speed) for f in frames]

        # Shift
        shift = random.choice(range(seq_length))
        frames = [f + shift for f in frames]

        # Reverse
        if random.random() < 0.5:
            frames = frames[::-1]

        # Example: seq_length=4 with speed=2 and shift=1 (no reverse) gives
        # [0, 1, 2, 3] -> [0, 2, 4, 6] -> [1, 3, 5, 7]. Indices can exceed
        # seq_length - 1, so the caller needs to wrap or clamp them when
        # indexing a clip.
        return frames


class ValidFrameSampler:
    def __call__(self, seq_length):
        return range(seq_length)
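

# ----------------------------- Usage sketch ----------------------------
# Illustrative only (not part of the original module): the probability values,
# clip length, and image sizes below are assumptions chosen for a quick smoke
# test, not values taken from any training configuration.
if __name__ == "__main__":
    from PIL import Image

    T = 4  # the easing-based motions assume clips of at least 2 frames
    fgrs = [Image.new("RGB", (320, 320), (0, 128, 0)) for _ in range(T)]
    phas = [Image.new("L", (320, 320), 255) for _ in range(T)]
    bgrs = [Image.new("RGB", (320, 320), (64, 64, 64)) for _ in range(T)]

    aug = MotionAugmentation(
        size=(256, 256),
        prob_fgr_affine=0.3,
        prob_bgr_affine=0.3,
        prob_noise=0.1,
        prob_color_jitter=0.3,
        prob_grayscale=0.02,
        prob_sharpness=0.1,
        prob_blur=0.02,
        prob_hflip=0.5,
        prob_pause=0.03,
    )
    fgrs, phas, bgrs = aug(fgrs, phas, bgrs)
    print(fgrs.shape, phas.shape, bgrs.shape)  # (4, 3, 256, 256) (4, 1, 256, 256) (4, 3, 256, 256)

    print(TrainFrameSampler()(seq_length=T))        # e.g. [7, 5, 3, 1]
    print(list(ValidFrameSampler()(seq_length=T)))  # [0, 1, 2, 3]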