# video.py
from typing import Any, Callable, Optional

import cv2
import numpy as np
from PIL import Image
from torch.utils.data import Dataset
  5. class VideoDataset(Dataset):
  6. def __init__(self, path: str, transforms: any = None):
  7. self.cap = cv2.VideoCapture(path)
  8. self.transforms = transforms
  9. self.width = int(self.cap.get(cv2.CAP_PROP_FRAME_WIDTH))
  10. self.height = int(self.cap.get(cv2.CAP_PROP_FRAME_HEIGHT))
  11. self.frame_rate = self.cap.get(cv2.CAP_PROP_FPS)
  12. self.frame_count = int(self.cap.get(cv2.CAP_PROP_FRAME_COUNT))
  13. def __len__(self):
  14. return self.frame_count
  15. def __getitem__(self, idx):
  16. if isinstance(idx, slice):
  17. return [self[i] for i in range(*idx.indices(len(self)))]
  18. if self.cap.get(cv2.CAP_PROP_POS_FRAMES) != idx:
  19. self.cap.set(cv2.CAP_PROP_POS_FRAMES, idx)
  20. ret, img = self.cap.read()
  21. if not ret:
  22. raise IndexError(f'Idx: {idx} out of length: {len(self)}')
  23. img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
  24. img = Image.fromarray(img)
  25. if self.transforms:
  26. img = self.transforms(img)
  27. return img
  28. def __enter__(self):
  29. return self
  30. def __exit__(self, exc_type, exc_value, exc_traceback):
  31. self.cap.release()