import easing_functions as ef
import random
import torch
from torchvision import transforms
from torchvision.transforms import functional as F


def _progress(t, T):
    """Return the normalized position of frame ``t`` in a clip of ``T`` frames.

    Yields ``t / (T - 1)`` (0 for the first frame, 1 for the last). A
    single-frame clip has no temporal extent, so 0.0 is returned instead of
    dividing by zero.
    """
    return t / (T - 1) if T > 1 else 0.0


class MotionAugmentation:
    """Apply random, temporally varying augmentations to a video clip.

    The instance is called with three frame sequences (foregrounds, alphas and
    backgrounds) and returns them as stacked tensors after a chain of
    randomly-triggered augmentations. Every ``prob_*`` argument is compared
    against ``random.random()`` to decide whether the corresponding
    augmentation runs.

    Args:
        size: Output size forwarded to ``F.resized_crop``.
        prob_fgr_affine: Probability of a moving affine on foreground + alpha.
        prob_bgr_affine: Probability budget for background affine motion; it
            is halved and tested twice (background only, then all three
            sequences together).
        prob_noise: Probability of adding grain noise to foreground and
            background.
        prob_color_jitter: Probability of color jitter, tested independently
            for foreground and background.
        prob_grayscale: Probability of converting foreground and background to
            3-channel grayscale.
        prob_sharpness: Probability of a sharpness adjustment on all three
            sequences.
        prob_blur: Probability budget for blur; it is divided by three and
            tested three times (foreground + alpha, background, all three).
        prob_hflip: Probability of a horizontal flip, tested independently for
            foreground + alpha and for background.
        prob_pause: Probability of freezing a run of frames.
        static_affine: When true, always apply one fixed random affine per
            sequence group before conversion to tensors.
        aspect_ratio_range: ``ratio`` range passed to
            ``RandomResizedCrop.get_params``.
    """

    def __init__(self,
                 size,
                 prob_fgr_affine,
                 prob_bgr_affine,
                 prob_noise,
                 prob_color_jitter,
                 prob_grayscale,
                 prob_sharpness,
                 prob_blur,
                 prob_hflip,
                 prob_pause,
                 static_affine=True,
                 aspect_ratio_range=(0.9, 1.1)):
        self.size = size
        self.prob_fgr_affine = prob_fgr_affine
        self.prob_bgr_affine = prob_bgr_affine
        self.prob_noise = prob_noise
        self.prob_color_jitter = prob_color_jitter
        self.prob_grayscale = prob_grayscale
        self.prob_sharpness = prob_sharpness
        self.prob_blur = prob_blur
        self.prob_hflip = prob_hflip
        self.prob_pause = prob_pause
        self.static_affine = static_affine
        self.aspect_ratio_range = aspect_ratio_range

    def __call__(self, fgrs, phas, bgrs):
        """Augment one clip.

        Args:
            fgrs: Mutable sequence of foreground frames.
            phas: Mutable sequence of alpha frames, same length as ``fgrs``.
            bgrs: Mutable sequence of background frames.
            NOTE(review): frames are presumably PIL images (``.size`` is read
            and ``F.to_tensor`` is applied) - confirm against the dataset.

        Returns:
            Tuple ``(fgrs, phas, bgrs)`` of tensors stacked along a new
            leading frame dimension.
        """
        # Foreground affine
        if random.random() < self.prob_fgr_affine:
            fgrs, phas = self._motion_affine(fgrs, phas)

        # Background affine
        if random.random() < self.prob_bgr_affine / 2:
            bgrs = self._motion_affine(bgrs)
        if random.random() < self.prob_bgr_affine / 2:
            fgrs, phas, bgrs = self._motion_affine(fgrs, phas, bgrs)

        # Still Affine
        if self.static_affine:
            fgrs, phas = self._static_affine(fgrs, phas, scale_ranges=(0.5, 1))
            bgrs = self._static_affine(bgrs, scale_ranges=(1, 1.5))

        # To tensor
        fgrs = torch.stack([F.to_tensor(fgr) for fgr in fgrs])
        phas = torch.stack([F.to_tensor(pha) for pha in phas])
        bgrs = torch.stack([F.to_tensor(bgr) for bgr in bgrs])

        # Resize (foreground and alpha share one crop; background gets its own)
        params = transforms.RandomResizedCrop.get_params(
            fgrs, scale=(1, 1), ratio=self.aspect_ratio_range)
        fgrs = F.resized_crop(fgrs, *params, self.size,
                              interpolation=F.InterpolationMode.BILINEAR)
        phas = F.resized_crop(phas, *params, self.size,
                              interpolation=F.InterpolationMode.BILINEAR)
        params = transforms.RandomResizedCrop.get_params(
            bgrs, scale=(1, 1), ratio=self.aspect_ratio_range)
        bgrs = F.resized_crop(bgrs, *params, self.size,
                              interpolation=F.InterpolationMode.BILINEAR)

        # Horizontal flip
        if random.random() < self.prob_hflip:
            fgrs = F.hflip(fgrs)
            phas = F.hflip(phas)
        if random.random() < self.prob_hflip:
            bgrs = F.hflip(bgrs)

        # Noise
        if random.random() < self.prob_noise:
            fgrs, bgrs = self._motion_noise(fgrs, bgrs)

        # Color jitter
        if random.random() < self.prob_color_jitter:
            fgrs = self._motion_color_jitter(fgrs)
        if random.random() < self.prob_color_jitter:
            bgrs = self._motion_color_jitter(bgrs)

        # Grayscale
        if random.random() < self.prob_grayscale:
            fgrs = F.rgb_to_grayscale(fgrs, num_output_channels=3).contiguous()
            bgrs = F.rgb_to_grayscale(bgrs, num_output_channels=3).contiguous()

        # Sharpen
        if random.random() < self.prob_sharpness:
            sharpness = random.random() * 8
            fgrs = F.adjust_sharpness(fgrs, sharpness)
            phas = F.adjust_sharpness(phas, sharpness)
            bgrs = F.adjust_sharpness(bgrs, sharpness)

        # Blur
        if random.random() < self.prob_blur / 3:
            fgrs, phas = self._motion_blur(fgrs, phas)
        if random.random() < self.prob_blur / 3:
            bgrs = self._motion_blur(bgrs)
        if random.random() < self.prob_blur / 3:
            fgrs, phas, bgrs = self._motion_blur(fgrs, phas, bgrs)

        # Pause
        if random.random() < self.prob_pause:
            fgrs, phas, bgrs = self._motion_pause(fgrs, phas, bgrs)

        return fgrs, phas, bgrs

    def _static_affine(self, *imgs, scale_ranges):
        """Apply one random affine, constant over time, to every sequence.

        The same parameters are used for all sequences passed in, so they stay
        aligned with each other.

        Returns:
            A list with one new frame list per input sequence, or just the
            single frame list when only one sequence was given.
        """
        params = transforms.RandomAffine.get_params(
            degrees=(-10, 10),
            translate=(0.1, 0.1),
            scale_ranges=scale_ranges,
            shears=(-5, 5),
            img_size=imgs[0][0].size)
        imgs = [[F.affine(t, *params, F.InterpolationMode.BILINEAR) for t in img]
                for img in imgs]
        return imgs if len(imgs) > 1 else imgs[0]

    def _motion_affine(self, *imgs):
        """Apply an affine that is interpolated between two random poses.

        Two parameter sets (A and B) are sampled; each frame uses their
        interpolation at an eased position in the clip. Frames are replaced
        in place inside the given sequences.

        Returns:
            The tuple of input sequences, or the single sequence when only one
            was given.
        """
        config = dict(degrees=(-10, 10),
                      translate=(0.1, 0.1),
                      scale_ranges=(0.9, 1.1),
                      shears=(-5, 5),
                      img_size=imgs[0][0].size)
        angleA, (transXA, transYA), scaleA, (shearXA, shearYA) = \
            transforms.RandomAffine.get_params(**config)
        angleB, (transXB, transYB), scaleB, (shearXB, shearYB) = \
            transforms.RandomAffine.get_params(**config)

        T = len(imgs[0])
        easing = random_easing_fn()
        for t in range(T):
            # _progress guards the single-frame case (T - 1 == 0).
            percentage = easing(_progress(t, T))
            angle = lerp(angleA, angleB, percentage)
            transX = lerp(transXA, transXB, percentage)
            transY = lerp(transYA, transYB, percentage)
            scale = lerp(scaleA, scaleB, percentage)
            shearX = lerp(shearXA, shearXB, percentage)
            shearY = lerp(shearYA, shearYB, percentage)
            for img in imgs:
                img[t] = F.affine(img[t], angle, (transX, transY), scale,
                                  (shearX, shearY), F.InterpolationMode.BILINEAR)
        return imgs if len(imgs) > 1 else imgs[0]

    def _motion_noise(self, *imgs):
        """Add random gaussian grain in place to each ``(T, C, H, W)`` tensor.

        Grain size and the monochrome/colored choice are sampled once and
        shared by all inputs; the noise values and their amplitude are sampled
        per input.

        Returns:
            The tuple of input tensors, or the single tensor when only one was
            given.
        """
        grain_size = random.random() * 3 + 1  # range 1 ~ 4
        monochrome = random.random() < 0.5
        for img in imgs:
            T, C, H, W = img.shape
            # Noise is generated at reduced resolution and upsampled so that
            # grains cover more than one pixel.
            noise = torch.randn((T, 1 if monochrome else C,
                                 round(H / grain_size), round(W / grain_size)))
            noise.mul_(random.random() * 0.2 / grain_size)
            if grain_size != 1:
                noise = F.resize(noise, (H, W))
            img.add_(noise).clamp_(0, 1)
        return imgs if len(imgs) > 1 else imgs[0]

    def _motion_color_jitter(self, *imgs):
        """Apply brightness/contrast/saturation/hue jitter that drifts over time.

        Start (A) and end (B) offsets are drawn for each property; every frame
        uses their interpolation at an eased, strength-scaled position.
        Frames are replaced in place inside the given sequences.

        Returns:
            The tuple of input sequences, or the single sequence when only one
            was given.
        """
        brightnessA, brightnessB, contrastA, contrastB, \
            saturationA, saturationB, hueA, hueB \
            = torch.randn(8).mul(0.1).tolist()
        strength = random.random() * 0.2
        easing = random_easing_fn()
        T = len(imgs[0])
        for t in range(T):
            # _progress guards the single-frame case (T - 1 == 0).
            percentage = easing(_progress(t, T)) * strength
            # Factors are floored at 0.1 so a frame is never fully collapsed.
            brightness = max(1 + lerp(brightnessA, brightnessB, percentage), 0.1)
            contrast = max(1 + lerp(contrastA, contrastB, percentage), 0.1)
            # Saturation uses its own sampled offsets (previously the
            # brightness offsets were reused here by mistake).
            saturation = max(1 + lerp(saturationA, saturationB, percentage), 0.1)
            hue = min(0.5, max(-0.5, lerp(hueA, hueB, percentage) * 0.1))
            for img in imgs:
                img[t] = F.adjust_brightness(img[t], brightness)
                img[t] = F.adjust_contrast(img[t], contrast)
                img[t] = F.adjust_saturation(img[t], saturation)
                img[t] = F.adjust_hue(img[t], hue)
        return imgs if len(imgs) > 1 else imgs[0]

    def _motion_blur(self, *imgs):
        """Apply a gaussian blur whose sigma drifts between two random values.

        Frames are replaced in place inside the given sequences.

        Returns:
            The tuple of input sequences, or the single sequence when only one
            was given.
        """
        blurA = random.random() * 10
        blurB = random.random() * 10

        T = len(imgs[0])
        easing = random_easing_fn()
        for t in range(T):
            # _progress guards the single-frame case (T - 1 == 0).
            percentage = easing(_progress(t, T))
            # Some easing functions overshoot, so clamp sigma at zero.
            blur = max(lerp(blurA, blurB, percentage), 0)
            if blur != 0:
                kernel_size = int(blur * 2)
                if kernel_size % 2 == 0:
                    kernel_size += 1  # Make kernel_size odd
                for img in imgs:
                    img[t] = F.gaussian_blur(img[t], kernel_size, sigma=blur)
        return imgs if len(imgs) > 1 else imgs[0]

    def _motion_pause(self, *imgs):
        """Freeze a random run of frames by repeating one frame.

        A pause frame and a pause length are sampled; the frames from
        ``pause_frame + 1`` up to (not including) ``pause_frame + pause_length``
        are overwritten in place with the pause frame. Clips with fewer than
        two frames are returned unchanged because there is nothing to pause.

        Returns:
            The tuple of inputs, or the single input when only one was given.
        """
        T = len(imgs[0])
        if T < 2:
            return imgs if len(imgs) > 1 else imgs[0]
        pause_frame = random.choice(range(T - 1))
        pause_length = random.choice(range(T - pause_frame))
        for img in imgs:
            img[pause_frame + 1: pause_frame + pause_length] = img[pause_frame]
        return imgs if len(imgs) > 1 else imgs[0]


def lerp(a, b, percentage):
    """Linearly interpolate from ``a`` (percentage 0) to ``b`` (percentage 1)."""
    return a * (1 - percentage) + b * percentage


def random_easing_fn():
    """Return a randomly chosen easing function instance.

    With probability 0.2 a linear easing is returned; otherwise one of the
    listed easing classes (including the local ``Step``) is picked uniformly
    and instantiated.
    """
    if random.random() < 0.2:
        return ef.LinearInOut()
    else:
        return random.choice([
            ef.BackEaseIn,
            ef.BackEaseOut,
            ef.BackEaseInOut,
            ef.BounceEaseIn,
            ef.BounceEaseOut,
            ef.BounceEaseInOut,
            ef.CircularEaseIn,
            ef.CircularEaseOut,
            ef.CircularEaseInOut,
            ef.CubicEaseIn,
            ef.CubicEaseOut,
            ef.CubicEaseInOut,
            ef.ExponentialEaseIn,
            ef.ExponentialEaseOut,
            ef.ExponentialEaseInOut,
            ef.ElasticEaseIn,
            ef.ElasticEaseOut,
            ef.ElasticEaseInOut,
            ef.QuadEaseIn,
            ef.QuadEaseOut,
            ef.QuadEaseInOut,
            ef.QuarticEaseIn,
            ef.QuarticEaseOut,
            ef.QuarticEaseInOut,
            ef.QuinticEaseIn,
            ef.QuinticEaseOut,
            ef.QuinticEaseInOut,
            ef.SineEaseIn,
            ef.SineEaseOut,
            ef.SineEaseInOut,
            Step,
        ])()


class Step:  # Custom easing function for sudden change.
    """Easing callable that jumps from 0 to 1 at the midpoint."""

    def __call__(self, value):
        """Return 0 for ``value`` below 0.5 and 1 otherwise."""
        return 0 if value < 0.5 else 1


# ---------------------------- Frame Sampler ----------------------------


class TrainFrameSampler:
    """Sample training frame indices with random speed, shift and direction.

    Args:
        speed: Candidate speed multipliers; one is picked per call.
    """

    def __init__(self, speed=(0.5, 1, 2, 3, 4, 5)):
        # Copy into a fresh list so instances never share (or alias) the
        # default or a caller-owned sequence.
        self.speed = list(speed)

    def __call__(self, seq_length):
        """Return a list of ``seq_length`` integer frame indices.

        Indices are ``int(i * speed) + shift`` for ``i`` in
        ``range(seq_length)``, reversed with probability 0.5. The values can
        exceed ``seq_length - 1``.
        NOTE(review): the caller presumably wraps or clamps out-of-range
        indices - confirm.
        """
        frames = list(range(seq_length))

        # Speed up
        speed = random.choice(self.speed)
        frames = [int(f * speed) for f in frames]

        # Shift
        shift = random.choice(range(seq_length))
        frames = [f + shift for f in frames]

        # Reverse
        if random.random() < 0.5:
            frames = frames[::-1]

        return frames


class ValidFrameSampler:
    """Sample validation frame indices: the first ``seq_length`` in order."""

    def __call__(self, seq_length):
        """Return ``range(seq_length)``."""
        return range(seq_length)