augmentation.py 4.9 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154
  1. import random
  2. import torch
  3. import numpy as np
  4. import math
  5. from torchvision import transforms as T
  6. from torchvision.transforms import functional as F
  7. from PIL import Image, ImageFilter
  8. """
Pair transforms are modified versions of the regular transforms that take multiple
images and apply exactly the same transform to all of them. This is especially useful
when a pair of images must be transformed identically.
Example:
    img1, img2, ..., imgN = transforms(img1, img2, ..., imgN)
  14. """
  15. class PairCompose(T.Compose):
  16. def __call__(self, *x):
  17. for transform in self.transforms:
  18. x = transform(*x)
  19. return x
  20. class PairApply:
  21. def __init__(self, transforms):
  22. self.transforms = transforms
  23. def __call__(self, *x):
  24. return [self.transforms(xi) for xi in x]
  25. class PairApplyOnlyAtIndices:
  26. def __init__(self, indices, transforms):
  27. self.indices = indices
  28. self.transforms = transforms
  29. def __call__(self, *x):
  30. return [self.transforms(xi) if i in self.indices else xi for i, xi in enumerate(x)]
  31. class PairRandomAffine(T.RandomAffine):
  32. def __init__(self, degrees, translate=None, scale=None, shear=None, resamples=None, fillcolor=0):
  33. super().__init__(degrees, translate, scale, shear, Image.NEAREST, fillcolor)
  34. self.resamples = resamples
  35. def __call__(self, *x):
  36. if not len(x):
  37. return []
  38. param = self.get_params(self.degrees, self.translate, self.scale, self.shear, x[0].size)
  39. resamples = self.resamples or [self.resample] * len(x)
  40. return [F.affine(xi, *param, resamples[i], self.fillcolor) for i, xi in enumerate(x)]
  41. class PairRandomResizedCrop(T.RandomResizedCrop):
  42. def __init__(self, size, scale=(0.08, 1.0), ratio=(3. / 4., 4. / 3.), interpolations=None):
  43. super().__init__(size, scale, ratio, Image.BILINEAR)
  44. self.interpolations = interpolations
  45. def __call__(self, *x):
  46. if not len(x):
  47. return []
  48. i, j, h, w = self.get_params(x[0], self.scale, self.ratio)
  49. interpolations = self.interpolations or [self.interpolation] * len(x)
  50. return [F.resized_crop(xi, i, j, h, w, self.size, interpolations[i]) for i, xi in enumerate(x)]
  51. class PairRandomHorizontalFlip(T.RandomHorizontalFlip):
  52. def __call__(self, *x):
  53. if torch.rand(1) < self.p:
  54. x = [F.hflip(xi) for xi in x]
  55. return x
  56. class RandomBoxBlur:
  57. def __init__(self, prob, max_radius):
  58. self.prob = prob
  59. self.max_radius = max_radius
  60. def __call__(self, img):
  61. if torch.rand(1) < self.prob:
  62. fil = ImageFilter.BoxBlur(random.choice(range(self.max_radius + 1)))
  63. img = img.filter(fil)
  64. return img
  65. class PairRandomBoxBlur(RandomBoxBlur):
  66. def __call__(self, *x):
  67. if torch.rand(1) < self.prob:
  68. fil = ImageFilter.BoxBlur(random.choice(range(self.max_radius + 1)))
  69. x = [xi.filter(fil) for xi in x]
  70. return x
  71. class RandomSharpen:
  72. def __init__(self, prob):
  73. self.prob = prob
  74. self.filter = ImageFilter.SHARPEN
  75. def __call__(self, img):
  76. if torch.rand(1) < self.prob:
  77. img = img.filter(self.filter)
  78. return img
  79. class PairRandomSharpen(RandomSharpen):
  80. def __call__(self, *x):
  81. if torch.rand(1) < self.prob:
  82. x = [xi.filter(self.filter) for xi in x]
  83. return x
  84. class PairRandomAffineAndResize:
  85. def __init__(self, size, degrees, translate, scale, shear, ratio=(3./4., 4./3.), resample=Image.BILINEAR, fillcolor=0):
  86. self.size = size
  87. self.degrees = degrees
  88. self.translate = translate
  89. self.scale = scale
  90. self.shear = shear
  91. self.ratio = ratio
  92. self.resample = resample
  93. self.fillcolor = fillcolor
  94. def __call__(self, *x):
  95. if not len(x):
  96. return []
  97. w, h = x[0].size
  98. scale_factor = max(self.size[1] / w, self.size[0] / h)
  99. w_padded = max(w, self.size[1])
  100. h_padded = max(h, self.size[0])
  101. pad_h = int(math.ceil((h_padded - h) / 2))
  102. pad_w = int(math.ceil((w_padded - w) / 2))
  103. scale = self.scale[0] * scale_factor, self.scale[1] * scale_factor
  104. translate = self.translate[0] * scale_factor, self.translate[1] * scale_factor
  105. affine_params = T.RandomAffine.get_params(self.degrees, translate, scale, self.shear, (w, h))
  106. def transform(img):
  107. if pad_h > 0 or pad_w > 0:
  108. img = F.pad(img, (pad_w, pad_h))
  109. img = F.affine(img, *affine_params, self.resample, self.fillcolor)
  110. img = F.center_crop(img, self.size)
  111. return img
  112. return [transform(xi) for xi in x]
  113. class RandomAffineAndResize(PairRandomAffineAndResize):
  114. def __call__(self, img):
  115. return super().__call__(img)[0]