import os
import json
import random

import numpy as np
import torch
from PIL import Image
from torch.utils.data import Dataset
from torchvision import transforms
from torchvision.transforms import functional as F


class YouTubeVISDataset(Dataset):
    def __init__(self, videodir, annfile, size, seq_length, seq_sampler, transform=None):
        self.videodir = videodir
        self.size = size
        self.seq_length = seq_length
        self.seq_sampler = seq_sampler
        self.transform = transform

        with open(annfile) as f:
            data = json.load(f)

        # Group person masks (YouTubeVIS category_id 26) by video id,
        # as one list of RLE masks per frame.
        self.masks = {}
        for ann in data['annotations']:
            if ann['category_id'] == 26:  # person
                video_id = ann['video_id']
                if video_id not in self.masks:
                    self.masks[video_id] = [[] for _ in range(len(ann['segmentations']))]
                for frame, mask in zip(self.masks[video_id], ann['segmentations']):
                    if mask is not None:
                        frame.append(mask)

        # Keep only videos that have at least one person annotation.
        self.videos = {}
        for video in data['videos']:
            video_id = video['id']
            if video_id in self.masks:
                self.videos[video_id] = video

        # Index every (video, start frame) pair so any frame can begin a clip.
        self.index = []
        for video_id in self.videos.keys():
            for frame in range(len(self.videos[video_id]['file_names'])):
                self.index.append((video_id, frame))

    def __len__(self):
        return len(self.index)

    def __getitem__(self, idx):
        video_id, frame_id = self.index[idx]
        video = self.videos[video_id]
        frame_count = len(video['file_names'])
        H, W = video['height'], video['width']

        imgs, segs = [], []
        for t in self.seq_sampler(self.seq_length):
            frame = (frame_id + t) % frame_count  # wrap around short videos
            filename = video['file_names'][frame]
            masks = self.masks[video_id][frame]

            with Image.open(os.path.join(self.videodir, filename)) as img:
                imgs.append(self._downsample_if_needed(img.convert('RGB'), Image.BILINEAR))

            # Union of all person instances in the frame as one binary mask.
            seg = np.zeros((H, W), dtype=np.uint8)
            for mask in masks:
                seg |= self._decode_rle(mask)
            segs.append(self._downsample_if_needed(Image.fromarray(seg), Image.NEAREST))

        if self.transform is not None:
            imgs, segs = self.transform(imgs, segs)

        return imgs, segs

    def _decode_rle(self, rle):
        # Uncompressed COCO-style RLE: `counts` alternates background and
        # foreground run lengths over the mask flattened in column-major
        # order, hence the final reshape(W, H).transpose().
        H, W = rle['size']
        msk = np.zeros(H * W, dtype=np.uint8)
        encoding = rle['counts']
        skip = 0
        for i in range(0, len(encoding) - 1, 2):
            skip += encoding[i]     # background run
            draw = encoding[i + 1]  # foreground run
            msk[skip : skip + draw] = 255
            skip += draw
        return msk.reshape(W, H).transpose()

    def _downsample_if_needed(self, img, resample):
        # Shrink so the short side is at most self.size; never upsample.
        w, h = img.size
        if min(w, h) > self.size:
            scale = self.size / min(w, h)
            w = int(scale * w)
            h = int(scale * h)
            img = img.resize((w, h), resample)
        return img
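

# Hypothetical helper, not part of the original file: a minimal `seq_sampler`
# satisfying the interface the dataset expects above, i.e. a callable that
# takes `seq_length` and returns frame offsets relative to the start frame.
class ExampleFrameSampler:
    def __init__(self, speeds=(1, 2, 4)):
        self.speeds = speeds

    def __call__(self, seq_length):
        # Consecutive offsets at a random playback speed; the modulo in
        # __getitem__ wraps offsets that run past the end of the video.
        speed = random.choice(self.speeds)
        return [t * speed for t in range(seq_length)]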


class YouTubeVISAugmentation:
    def __init__(self, size):
        self.size = size
        self.jitter = transforms.ColorJitter(0.3, 0.3, 0.3, 0.15)

    def __call__(self, imgs, segs):
        # To tensor: (T, 3, H, W) images and (T, 1, H, W) masks.
        imgs = torch.stack([F.to_tensor(img) for img in imgs])
        segs = torch.stack([F.to_tensor(seg) for seg in segs])

        # Random resized crop with identical parameters for images and masks.
        params = transforms.RandomResizedCrop.get_params(imgs, scale=(0.8, 1), ratio=(0.9, 1.1))
        imgs = F.resized_crop(imgs, *params, self.size, interpolation=F.InterpolationMode.BILINEAR)
        segs = F.resized_crop(segs, *params, self.size, interpolation=F.InterpolationMode.BILINEAR)

        # Color jitter (photometric, so images only)
        imgs = self.jitter(imgs)

        # Random grayscale (5% of clips)
        if random.random() < 0.05:
            imgs = F.rgb_to_grayscale(imgs, num_output_channels=3)

        # Horizontal flip, applied consistently to images and masks.
        if random.random() < 0.5:
            imgs = F.hflip(imgs)
            segs = F.hflip(segs)

        return imgs, segs
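

if __name__ == '__main__':
    # Usage sketch with placeholder paths; assumes a YouTubeVIS-style layout
    # of a JPEG frame directory plus an instance-annotation JSON file.
    dataset = YouTubeVISDataset(
        videodir='path/to/YouTubeVIS/train/JPEGImages',
        annfile='path/to/YouTubeVIS/train/instances.json',
        size=512,
        seq_length=8,
        seq_sampler=ExampleFrameSampler(),
        transform=YouTubeVISAugmentation(512))
    imgs, segs = dataset[0]
    print(imgs.shape, segs.shape)  # (T, 3, H, W), (T, 1, H, W)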