augmentation.py 4.3 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141
  1. import random
  2. import torch
  3. import numpy as np
  4. import math
  5. from torchvision import transforms as T
  6. from torchvision.transforms import functional as F
  7. from PIL import Image, ImageFilter
  8. """
  9. Pair transforms are modified versions of the regular transforms that take in multiple
  10. images and apply exactly the same transform to all of them. This is especially useful
  11. when we want identical transforms applied to a pair of images.
  12. Example:
  13. img1, img2, ..., imgN = transforms(img1, img2, ..., imgN)
  14. """
  15. class PairCompose(T.Compose):
  16. def __call__(self, *x):
  17. for transform in self.transforms:
  18. x = transform(*x)
  19. return x
  20. class PairApply:
  21. def __init__(self, transforms):
  22. self.transforms = transforms
  23. def __call__(self, *x):
  24. return [self.transforms(xi) for xi in x]
  25. class PairApplyOnlyAtIndices:
  26. def __init__(self, indices, transforms):
  27. self.indices = indices
  28. self.transforms = transforms
  29. def __call__(self, *x):
  30. return [self.transforms(xi) if i in self.indices else xi for i, xi in enumerate(x)]
  31. class PairRandomAffine(T.RandomAffine):
  32. def __init__(self, degrees, translate=None, scale=None, shear=None, resamples=None, fillcolor=0):
  33. super().__init__(degrees, translate, scale, shear, Image.NEAREST, fillcolor)
  34. self.resamples = resamples
  35. def __call__(self, *x):
  36. if not len(x):
  37. return []
  38. param = self.get_params(self.degrees, self.translate, self.scale, self.shear, x[0].size)
  39. resamples = self.resamples or [self.resample] * len(x)
  40. return [F.affine(xi, *param, resamples[i], self.fillcolor) for i, xi in enumerate(x)]
  41. class PairRandomHorizontalFlip(T.RandomHorizontalFlip):
  42. def __call__(self, *x):
  43. if torch.rand(1) < self.p:
  44. x = [F.hflip(xi) for xi in x]
  45. return x
  46. class RandomBoxBlur:
  47. def __init__(self, prob, max_radius):
  48. self.prob = prob
  49. self.max_radius = max_radius
  50. def __call__(self, img):
  51. if torch.rand(1) < self.prob:
  52. fil = ImageFilter.BoxBlur(random.choice(range(self.max_radius + 1)))
  53. img = img.filter(fil)
  54. return img
  55. class PairRandomBoxBlur(RandomBoxBlur):
  56. def __call__(self, *x):
  57. if torch.rand(1) < self.prob:
  58. fil = ImageFilter.BoxBlur(random.choice(range(self.max_radius + 1)))
  59. x = [xi.filter(fil) for xi in x]
  60. return x
  61. class RandomSharpen:
  62. def __init__(self, prob):
  63. self.prob = prob
  64. self.filter = ImageFilter.SHARPEN
  65. def __call__(self, img):
  66. if torch.rand(1) < self.prob:
  67. img = img.filter(self.filter)
  68. return img
  69. class PairRandomSharpen(RandomSharpen):
  70. def __call__(self, *x):
  71. if torch.rand(1) < self.prob:
  72. x = [xi.filter(self.filter) for xi in x]
  73. return x
  74. class PairRandomAffineAndResize:
  75. def __init__(self, size, degrees, translate, scale, shear, ratio=(3./4., 4./3.), resample=Image.BILINEAR, fillcolor=0):
  76. self.size = size
  77. self.degrees = degrees
  78. self.translate = translate
  79. self.scale = scale
  80. self.shear = shear
  81. self.ratio = ratio
  82. self.resample = resample
  83. self.fillcolor = fillcolor
  84. def __call__(self, *x):
  85. if not len(x):
  86. return []
  87. w, h = x[0].size
  88. scale_factor = max(self.size[1] / w, self.size[0] / h)
  89. w_padded = max(w, self.size[1])
  90. h_padded = max(h, self.size[0])
  91. pad_h = int(math.ceil((h_padded - h) / 2))
  92. pad_w = int(math.ceil((w_padded - w) / 2))
  93. scale = self.scale[0] * scale_factor, self.scale[1] * scale_factor
  94. translate = self.translate[0] * scale_factor, self.translate[1] * scale_factor
  95. affine_params = T.RandomAffine.get_params(self.degrees, translate, scale, self.shear, (w, h))
  96. def transform(img):
  97. if pad_h > 0 or pad_w > 0:
  98. img = F.pad(img, (pad_w, pad_h))
  99. img = F.affine(img, *affine_params, self.resample, self.fillcolor)
  100. img = F.center_crop(img, self.size)
  101. return img
  102. return [transform(xi) for xi in x]
  103. class RandomAffineAndResize(PairRandomAffineAndResize):
  104. def __call__(self, img):
  105. return super().__call__(img)[0]