# youtubevis.py
# YouTube-VIS person sequences: a dataset loader (frames plus binary person
# masks decoded from uncompressed COCO-style RLE) and clip-level augmentation.

import torch
import os
import json
import numpy as np
import random
from torch.utils.data import Dataset
from PIL import Image
from torchvision import transforms
from torchvision.transforms import functional as F


class YouTubeVISDataset(Dataset):
    def __init__(self, videodir, annfile, size, seq_length, seq_sampler, transform=None):
        self.videodir = videodir
        self.size = size
        self.seq_length = seq_length
        self.seq_sampler = seq_sampler
        self.transform = transform

        with open(annfile) as f:
            data = json.load(f)

        # Collect the per-frame RLE masks of every "person" annotation, keyed
        # by video id. A frame may contain several person instances, so each
        # frame entry is a list of RLE dicts.
        self.masks = {}
        for ann in data['annotations']:
            if ann['category_id'] == 26:  # person
                video_id = ann['video_id']
                if video_id not in self.masks:
                    self.masks[video_id] = [[] for _ in range(len(ann['segmentations']))]
                for frame, mask in zip(self.masks[video_id], ann['segmentations']):
                    if mask is not None:
                        frame.append(mask)

        # Keep only videos that have at least one person annotation.
        self.videos = {}
        for video in data['videos']:
            video_id = video['id']
            if video_id in self.masks:
                self.videos[video_id] = video

        # Flat index of (video_id, frame) pairs: every frame of every kept
        # video is a valid starting point for a clip.
        self.index = []
        for video_id in self.videos.keys():
            for frame in range(len(self.videos[video_id]['file_names'])):
                self.index.append((video_id, frame))
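    # Illustrative layout (values invented for this comment, not taken from
    # the dataset): with two person videos of 3 and 2 frames,
    # self.index == [(v1, 0), (v1, 1), (v1, 2), (v2, 0), (v2, 1)],
    # so a clip may start at any frame of any kept video.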
    def __len__(self):
        return len(self.index)
    def __getitem__(self, idx):
        video_id, frame_id = self.index[idx]
        video = self.videos[video_id]
        frame_count = len(self.videos[video_id]['file_names'])
        H, W = video['height'], video['width']

        imgs, segs = [], []
        for t in self.seq_sampler(self.seq_length):
            # Wrap around the video so every starting frame yields a full clip.
            frame = (frame_id + t) % frame_count
            filename = video['file_names'][frame]
            masks = self.masks[video_id][frame]

            with Image.open(os.path.join(self.videodir, filename)) as img:
                imgs.append(self._downsample_if_needed(img.convert('RGB'), Image.BILINEAR))

            # Union of all person instances in this frame, as one binary mask.
            seg = np.zeros((H, W), dtype=np.uint8)
            for mask in masks:
                seg |= self._decode_rle(mask)
            segs.append(self._downsample_if_needed(Image.fromarray(seg), Image.NEAREST))

        if self.transform is not None:
            imgs, segs = self.transform(imgs, segs)

        return imgs, segs
    def _decode_rle(self, rle):
        # Decode uncompressed COCO-style RLE: 'counts' alternates runs of
        # background and foreground pixels over the column-major flattening
        # of the image.
        H, W = rle['size']
        msk = np.zeros(H * W, dtype=np.uint8)
        encoding = rle['counts']
        skip = 0
        for i in range(0, len(encoding) - 1, 2):
            skip += encoding[i]
            draw = encoding[i + 1]
            msk[skip : skip + draw] = 255
            skip += draw
        # Undo the column-major flattening to recover an (H, W) mask.
        return msk.reshape(W, H).transpose()
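    # Worked example (values invented for this comment): for
    # rle = {'size': [2, 2], 'counts': [1, 2, 1]} the counts read as
    # (skip 1, draw 2), so flat pixels 1-2 become 255 and the trailing
    # background run is ignored; reshape(2, 2) followed by transpose()
    # returns [[0, 255], [255, 0]].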
    def _downsample_if_needed(self, img, resample):
        # Shrink (preserving aspect ratio) so the short edge is about
        # self.size; never upscale.
        w, h = img.size
        if min(w, h) > self.size:
            scale = self.size / min(w, h)
            w = int(scale * w)
            h = int(scale * h)
            img = img.resize((w, h), resample)
        return img


class YouTubeVISAugmentation:
    def __init__(self, size):
        self.size = size
        self.jitter = transforms.ColorJitter(0.3, 0.3, 0.3, 0.15)

    def __call__(self, imgs, segs):
        # To tensor: stack the clip into (T, C, H, W).
        imgs = torch.stack([F.to_tensor(img) for img in imgs])
        segs = torch.stack([F.to_tensor(seg) for seg in segs])

        # Resize: one crop is sampled per clip so all frames stay aligned.
        params = transforms.RandomResizedCrop.get_params(imgs, scale=(0.8, 1), ratio=(0.9, 1.1))
        imgs = F.resized_crop(imgs, *params, self.size, interpolation=F.InterpolationMode.BILINEAR)
        segs = F.resized_crop(segs, *params, self.size, interpolation=F.InterpolationMode.BILINEAR)

        # Color jitter
        imgs = self.jitter(imgs)

        # Grayscale
        if random.random() < 0.05:
            imgs = F.rgb_to_grayscale(imgs, num_output_channels=3)

        # Horizontal flip
        if random.random() < 0.5:
            imgs = F.hflip(imgs)
            segs = F.hflip(segs)

        return imgs, segs
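

if __name__ == '__main__':
    # Minimal usage sketch. The paths, resolution, and the consecutive-frame
    # sampler below are assumptions for illustration, not part of the module.
    dataset = YouTubeVISDataset(
        videodir='data/YouTubeVIS/train/JPEGImages',     # hypothetical path
        annfile='data/YouTubeVIS/train/instances.json',  # hypothetical path
        size=512,
        seq_length=15,
        seq_sampler=lambda n: range(n),                  # consecutive frames
        transform=YouTubeVISAugmentation((512, 512)))
    imgs, segs = dataset[0]
    print(imgs.shape, segs.shape)  # torch.Size([15, 3, 512, 512]) torch.Size([15, 1, 512, 512])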