123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141 |
- import random
- import torch
- import numpy as np
- import math
- from torchvision import transforms as T
- from torchvision.transforms import functional as F
- from PIL import Image, ImageFilter
"""
Pair transforms are modified versions of the regular transforms that accept
multiple images and apply exactly the same transform to all of them. This is
especially useful when the same transform must be applied to a pair of images.

Example:
    img1, img2, ..., imgN = transforms(img1, img2, ..., imgN)
"""


class PairCompose(T.Compose):
    """Chain several multi-image transforms.

    Each transform in ``self.transforms`` is called with the current images
    unpacked as positional arguments; whatever it returns becomes the input
    of the next transform.
    """

    def __call__(self, *x):
        """Run every transform in order and return the final result.

        When ``self.transforms`` is empty the positional arguments are
        returned unchanged (as the tuple Python packed them into).
        """
        images = x
        for step in self.transforms:
            images = step(*images)
        return images


class PairApply:
    """Apply one single-image transform to each of several images.

    The wrapped callable is invoked once per image, in argument order.
    """

    def __init__(self, transforms):
        # Despite the plural name, this is called as a single callable.
        self.transforms = transforms

    def __call__(self, *x):
        """Return a list with ``self.transforms(image)`` for every image."""
        apply_one = self.transforms
        return [apply_one(image) for image in x]


class PairApplyOnlyAtIndices:
    """Apply a single-image transform only to images at selected positions.

    Images whose positional index is contained in ``indices`` are passed
    through ``transforms``; all other images are returned untouched.
    """

    def __init__(self, indices, transforms):
        self.indices = indices
        # Despite the plural name, this is called as a single callable.
        self.transforms = transforms

    def __call__(self, *x):
        """Return a list of the images, transformed where the index matches."""
        results = []
        for position, image in enumerate(x):
            if position in self.indices:
                image = self.transforms(image)
            results.append(image)
        return results


class PairRandomAffine(T.RandomAffine):
    """Random affine transform that shares one parameter draw across images.

    Args:
        degrees, translate, scale, shear: forwarded to ``T.RandomAffine``.
        resamples: optional sequence of resampling filters, indexed by image
            position. When falsy, ``self.resample`` is used for every image.
        fillcolor: forwarded to ``T.RandomAffine`` and to ``F.affine``.
    """

    def __init__(self, degrees, translate=None, scale=None, shear=None, resamples=None, fillcolor=0):
        # NOTE(review): resample and fillcolor are passed positionally, so this
        # relies on the parent signature of the torchvision version in use.
        super().__init__(degrees, translate, scale, shear, Image.NEAREST, fillcolor)
        self.resamples = resamples

    def __call__(self, *x):
        """Apply the same randomly drawn affine transform to every image.

        Parameters are sampled once, using the size of the first image.
        Returns a list of transformed images (an empty list for no input).
        """
        if not x:
            return []

        first_size = x[0].size
        params = self.get_params(self.degrees, self.translate, self.scale, self.shear, first_size)
        # Per-image filters if provided, otherwise one shared filter repeated.
        filters = self.resamples or [self.resample] * len(x)
        return [
            F.affine(image, *params, filters[position], self.fillcolor)
            for position, image in enumerate(x)
        ]


class PairRandomHorizontalFlip(T.RandomHorizontalFlip):
    """Horizontally flip all images together, or none of them.

    A single random draw is compared against ``self.p`` (presumably set by
    ``T.RandomHorizontalFlip.__init__``), so every image receives the same
    decision.
    """

    def __call__(self, *x):
        """Return a list of the images, all flipped or all unchanged.

        A list is returned on both branches so that the result type does not
        depend on the random draw (the un-flipped branch previously leaked the
        raw argument tuple).
        """
        if torch.rand(1) < self.p:
            return [F.hflip(image) for image in x]
        return list(x)


class RandomBoxBlur:
    """Randomly apply a PIL box blur to a single image.

    Args:
        prob: threshold compared against one ``torch.rand(1)`` draw; the blur
            is applied when the draw is smaller.
        max_radius: largest blur radius; the radius is picked from
            ``0 .. max_radius`` inclusive.
    """

    def __init__(self, prob, max_radius):
        self.prob = prob
        self.max_radius = max_radius

    def __call__(self, img):
        """Return ``img`` blurred with a random radius, or ``img`` unchanged."""
        if not (torch.rand(1) < self.prob):
            return img

        radius = random.choice(range(self.max_radius + 1))
        return img.filter(ImageFilter.BoxBlur(radius))


class PairRandomBoxBlur(RandomBoxBlur):
    """Randomly apply one identical box blur to several images.

    Uses the ``prob`` and ``max_radius`` attributes set by ``RandomBoxBlur``.
    A single random decision and a single radius are shared by all images.
    """

    def __call__(self, *x):
        """Return a list of the images, all blurred alike or all unchanged.

        A list is returned on both branches so that the result type does not
        depend on the random draw (the un-blurred branch previously leaked the
        raw argument tuple).
        """
        if torch.rand(1) < self.prob:
            # One filter instance so every image gets the same radius.
            radius = random.choice(range(self.max_radius + 1))
            box_blur = ImageFilter.BoxBlur(radius)
            return [image.filter(box_blur) for image in x]
        return list(x)


class RandomSharpen:
    """Randomly apply PIL's ``ImageFilter.SHARPEN`` to a single image.

    Args:
        prob: threshold compared against one ``torch.rand(1)`` draw; the
            filter is applied when the draw is smaller.
    """

    def __init__(self, prob):
        self.prob = prob
        self.filter = ImageFilter.SHARPEN

    def __call__(self, img):
        """Return ``img`` sharpened, or ``img`` unchanged."""
        if not (torch.rand(1) < self.prob):
            return img
        return img.filter(self.filter)


class PairRandomSharpen(RandomSharpen):
    """Randomly sharpen several images together, or none of them.

    Uses the ``prob`` and ``filter`` attributes set by ``RandomSharpen``. A
    single random decision is shared by all images.
    """

    def __call__(self, *x):
        """Return a list of the images, all sharpened or all unchanged.

        A list is returned on both branches so that the result type does not
        depend on the random draw (the un-sharpened branch previously leaked
        the raw argument tuple).
        """
        if torch.rand(1) < self.prob:
            return [image.filter(self.filter) for image in x]
        return list(x)


class PairRandomAffineAndResize:
    """Random affine transform followed by a center crop to a fixed size.

    One set of affine parameters is drawn per call and applied to every
    image, so all outputs stay geometrically consistent.

    Args:
        size: output size, indexed as ``(height, width)``.
        degrees: passed unchanged to ``T.RandomAffine.get_params``.
        translate: pair of translate factors; each is multiplied by the
            resize factor before sampling.
        scale: ``(min, max)`` scale range; each bound is multiplied by the
            resize factor before sampling.
        shear: passed unchanged to ``T.RandomAffine.get_params``.
        ratio: stored on the instance; not read by the code in this class.
        resample: resampling filter passed to ``F.affine``.
        fillcolor: fill value passed to ``F.affine``.
    """

    def __init__(self, size, degrees, translate, scale, shear, ratio=(3./4., 4./3.), resample=Image.BILINEAR, fillcolor=0):
        self.size = size
        self.degrees = degrees
        self.translate = translate
        self.scale = scale
        self.shear = shear
        self.ratio = ratio
        self.resample = resample
        self.fillcolor = fillcolor

    def __call__(self, *x):
        """Transform all images with one shared random affine and crop them.

        Geometry is derived from the first image's ``size``. Returns a list
        of transformed images (an empty list for no input).
        """
        if not x:
            return []

        target_h, target_w = self.size[0], self.size[1]
        src_w, src_h = x[0].size

        # Factor by which the source must grow to cover the target in both
        # dimensions.
        resize_factor = max(target_w / src_w, target_h / src_h)

        # Per-side padding needed to make the canvas at least target-sized.
        pad_x = int(math.ceil((max(src_w, target_w) - src_w) / 2))
        pad_y = int(math.ceil((max(src_h, target_h) - src_h) / 2))
        needs_padding = pad_y > 0 or pad_x > 0

        scale_range = (self.scale[0] * resize_factor, self.scale[1] * resize_factor)
        translate_range = (self.translate[0] * resize_factor, self.translate[1] * resize_factor)
        # Sampled once with the unpadded source size and reused for every image.
        params = T.RandomAffine.get_params(
            self.degrees, translate_range, scale_range, self.shear, (src_w, src_h)
        )

        outputs = []
        for image in x:
            if needs_padding:
                image = F.pad(image, (pad_x, pad_y))
            image = F.affine(image, *params, self.resample, self.fillcolor)
            outputs.append(F.center_crop(image, self.size))
        return outputs


class RandomAffineAndResize(PairRandomAffineAndResize):
    """Single-image variant of ``PairRandomAffineAndResize``."""

    def __call__(self, img):
        """Transform one image and return it directly rather than in a list."""
        outputs = super().__call__(img)
        return outputs[0]
|