import json
import os
import random

import numpy as np
from PIL import Image
from torch.utils.data import Dataset
from torchvision import transforms
from torchvision.transforms import functional as F

class CocoPanopticDataset(Dataset):
    def __init__(self,
                 imgdir: str,
                 anndir: str,
                 annfile: str,
                 transform=None):
        with open(annfile) as f:
            self.data = json.load(f)['annotations']
        # Keep only images that contain at least one person segment (category_id == 1).
        self.data = list(filter(lambda data: any(info['category_id'] == 1 for info in data['segments_info']), self.data))
        self.imgdir = imgdir
        self.anndir = anndir
        self.transform = transform

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        data = self.data[idx]
        img = self._load_img(data)
        seg = self._load_seg(data)
        if self.transform is not None:
            img, seg = self.transform(img, seg)
        return img, seg

    def _load_img(self, data):
        # Annotation file names end in .png; the matching images are .jpg.
        with Image.open(os.path.join(self.imgdir, data['file_name'].replace('.png', '.jpg'))) as img:
            return img.convert('RGB')

    def _load_seg(self, data):
        with Image.open(os.path.join(self.anndir, data['file_name'])) as ann:
            ann.load()

        # Decode the COCO panoptic PNG encoding: segment id = R + 256 * G + 256^2 * B.
        ann = np.asarray(ann).astype(np.int32)
        ann = ann[:, :, 0] + 256 * ann[:, :, 1] + 256 * 256 * ann[:, :, 2]

        # Build a binary mask covering the person, backpack and tie segments.
        seg = np.zeros(ann.shape, np.uint8)
        for segments_info in data['segments_info']:
            if segments_info['category_id'] in [1, 27, 32]:  # person, backpack, tie
                seg[ann == segments_info['id']] = 255
        return Image.fromarray(seg)
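
# A quick, hedged usage sketch (kept as a comment so importing this module has
# no side effects). The COCO 2017 panoptic paths below are assumptions, not
# part of the original file:
#
#   dataset = CocoPanopticDataset(
#       imgdir='coco/train2017',
#       anndir='coco/panoptic_train2017',
#       annfile='coco/annotations/panoptic_train2017.json')
#   img, seg = dataset[0]  # both are PIL images until a transform is applied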

class CocoPanopticTrainAugmentation:
    def __init__(self, size):
        self.size = size
        self.jitter = transforms.ColorJitter(0.1, 0.1, 0.1, 0.1)

    def __call__(self, img, seg):
        # Affine: sample one set of parameters and apply it to both the image
        # and the mask so they stay aligned (nearest interpolation for the mask).
        params = transforms.RandomAffine.get_params(degrees=(-20, 20), translate=(0.1, 0.1),
                                                    scale_ranges=(1, 1), shears=(-10, 10), img_size=img.size)
        img = F.affine(img, *params, interpolation=F.InterpolationMode.BILINEAR)
        seg = F.affine(seg, *params, interpolation=F.InterpolationMode.NEAREST)

        # Resized crop, again sharing the sampled parameters.
        params = transforms.RandomResizedCrop.get_params(img, scale=(0.5, 1), ratio=(0.7, 1.3))
        img = F.resized_crop(img, *params, self.size, interpolation=F.InterpolationMode.BILINEAR)
        seg = F.resized_crop(seg, *params, self.size, interpolation=F.InterpolationMode.NEAREST)

        # Horizontal flip
        if random.random() < 0.5:
            img = F.hflip(img)
            seg = F.hflip(seg)

        # Color jitter (image only; the mask is label data)
        img = self.jitter(img)

        # To tensor
        img = F.to_tensor(img)
        seg = F.to_tensor(seg)
        return img, seg

class CocoPanopticValidAugmentation:
    def __init__(self, size):
        self.size = size

    def __call__(self, img, seg):
        # Resize: with scale=(1, 1) and ratio=(1, 1) the sampler reduces to a
        # deterministic centered crop, which is then resized to the target size.
        params = transforms.RandomResizedCrop.get_params(img, scale=(1, 1), ratio=(1., 1.))
        img = F.resized_crop(img, *params, self.size, interpolation=F.InterpolationMode.BILINEAR)
        seg = F.resized_crop(seg, *params, self.size, interpolation=F.InterpolationMode.NEAREST)

        # To tensor
        img = F.to_tensor(img)
        seg = F.to_tensor(seg)
        return img, seg
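
# A minimal wiring sketch, assuming the standard COCO 2017 panoptic download
# layout under ./coco and a 512x512 training resolution; both are assumptions
# to adjust for your setup, not part of the original file.
if __name__ == '__main__':
    from torch.utils.data import DataLoader

    dataset = CocoPanopticDataset(
        imgdir='coco/train2017',                             # assumed path
        anndir='coco/panoptic_train2017',                    # assumed path
        annfile='coco/annotations/panoptic_train2017.json',  # assumed path
        transform=CocoPanopticTrainAugmentation((512, 512)))
    loader = DataLoader(dataset, batch_size=4, shuffle=True, num_workers=2)

    img, seg = next(iter(loader))
    print(img.shape, seg.shape)  # e.g. [4, 3, 512, 512] and [4, 1, 512, 512]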