1234567891011121314151617181920212223 |
- import os
- import glob
- from torch.utils.data import Dataset
- from PIL import Image
class ImagesDataset(Dataset):
    """Map-style dataset over image files found recursively under a root directory.

    Files are discovered once, at construction time, by globbing
    ``root/**/*.<ext>`` for each extension. The resulting paths are sorted so
    index order is deterministic. Matching follows :func:`glob.glob` rules:
    dot-prefixed (hidden) files and directories are skipped, and case
    sensitivity is platform dependent (case-sensitive on POSIX).
    """

    def __init__(self, root, mode='RGB', transforms=None, extensions=('jpg', 'png')):
        """Collect image paths under ``root``.

        Args:
            root: Directory to search (str or os.PathLike). Glob metacharacters
                in it (``*``, ``?``, ``[``) are escaped, so it is matched as a
                literal path.
            mode: PIL mode that every image is converted to in ``__getitem__``.
            transforms: Optional callable applied to each converted image, or None.
            extensions: File extensions to collect, with or without a leading
                dot. A single string is treated as one extension.
        """
        self.transforms = transforms
        self.mode = mode
        if isinstance(extensions, str):
            # Guard against iterating a bare string character by character.
            extensions = (extensions,)
        # Escape the root so a directory such as "data[1]" is matched literally
        # instead of being interpreted as a glob character class.
        base = glob.escape(os.fspath(root))
        found = set()
        for ext in extensions:
            pattern = os.path.join(base, '**', '*.' + ext.lstrip('.'))
            found.update(glob.glob(pattern, recursive=True))
        # The set drops duplicates when patterns overlap (e.g. 'jpg' and 'JPG'
        # on Windows); sorting gives a stable index -> file mapping.
        self.filenames = sorted(found)

    def __len__(self):
        """Return the number of image files discovered at construction time."""
        return len(self.filenames)

    def __getitem__(self, idx):
        """Load image ``idx``, convert it to ``self.mode``, and apply ``transforms`` if set."""
        with Image.open(self.filenames[idx]) as source:
            # convert() returns a copy, so the result stays usable after the
            # with-block closes the underlying file.
            img = source.convert(self.mode)
        if self.transforms is not None:
            img = self.transforms(img)
        return img
|