images.py 716 B

1234567891011121314151617181920212223
  1. import os
  2. import glob
  3. from torch.utils.data import Dataset
  4. from PIL import Image
  5. class ImagesDataset(Dataset):
  6. def __init__(self, root, mode='RGB', transforms=None):
  7. self.transforms = transforms
  8. self.mode = mode
  9. self.filenames = sorted([*glob.glob(os.path.join(root, '**', '*.jpg'), recursive=True),
  10. *glob.glob(os.path.join(root, '**', '*.png'), recursive=True)])
  11. def __len__(self):
  12. return len(self.filenames)
  13. def __getitem__(self, idx):
  14. with Image.open(self.filenames[idx]) as img:
  15. img = img.convert(self.mode)
  16. if self.transforms:
  17. img = self.transforms(img)
  18. return img