```python
from typing import Callable, Optional

import cv2
from PIL import Image
from torch.utils.data import Dataset


class VideoDataset(Dataset):
    """Dataset that exposes the frames of a video file by index."""

    def __init__(self, path: str, transforms: Optional[Callable] = None):
        self.cap = cv2.VideoCapture(path)
        if not self.cap.isOpened():
            raise IOError(f'Could not open video: {path}')
        self.transforms = transforms

        # Basic stream metadata, read once up front.
        self.width = int(self.cap.get(cv2.CAP_PROP_FRAME_WIDTH))
        self.height = int(self.cap.get(cv2.CAP_PROP_FRAME_HEIGHT))
        self.frame_rate = self.cap.get(cv2.CAP_PROP_FPS)
        self.frame_count = int(self.cap.get(cv2.CAP_PROP_FRAME_COUNT))

    def __len__(self):
        return self.frame_count

    def __getitem__(self, idx):
        # Support slicing, e.g. dataset[10:20] -> list of frames.
        if isinstance(idx, slice):
            return [self[i] for i in range(*idx.indices(len(self)))]

        # Seek only when necessary; sequential reads skip the costly seek.
        if self.cap.get(cv2.CAP_PROP_POS_FRAMES) != idx:
            self.cap.set(cv2.CAP_PROP_POS_FRAMES, idx)
        ret, img = self.cap.read()
        if not ret:
            raise IndexError(f'Idx: {idx} out of length: {len(self)}')
        # OpenCV decodes frames as BGR; convert to RGB before wrapping in PIL.
        img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
        img = Image.fromarray(img)
        if self.transforms:
            img = self.transforms(img)
        return img

    def __enter__(self):
        return self

    def __exit__(self, exc_type, exc_value, exc_traceback):
        # Release the underlying capture when used as a context manager.
        self.cap.release()
```
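A minimal usage sketch: the file name `video.mp4` is a hypothetical path, and torchvision's `ToTensor` stands in for any transform that accepts a PIL image.

```python
from torchvision.transforms import ToTensor

# Hypothetical path; substitute any video file readable by OpenCV.
with VideoDataset('video.mp4', transforms=ToTensor()) as dataset:
    print(len(dataset), dataset.width, dataset.height, dataset.frame_rate)
    first_frame = dataset[0]  # single frame, transformed to a tensor
    clip = dataset[0:16]      # slicing returns a list of the first 16 frames
```

Note that a single `cv2.VideoCapture` handle is stateful, so an instance of this class should not be shared across `DataLoader` worker processes without opening a separate capture per worker.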