1234567891011121314151617181920 |
- from torch.utils.data import Dataset
- from typing import List
class ZipDataset(Dataset):
    """Combine several map-style datasets into one dataset of tuples.

    Item ``i`` is the tuple ``(d[i % len(d)] for d in datasets)``: datasets
    shorter than the longest one are cycled so that every index up to the
    longest length yields one sample from each dataset.

    Args:
        datasets: The datasets to zip. Each must support ``len()`` and
            integer indexing.
        transforms: Optional callable invoked as ``transforms(*samples)``
            with one positional argument per dataset; its return value
            replaces the tuple. Skipped when falsy.
        assert_equal_length: When true, verify at construction time that
            all datasets have the same length.

    Raises:
        AssertionError: If ``assert_equal_length`` is true and the datasets
            differ in length.
    """

    def __init__(self, datasets: List[Dataset], transforms=None, assert_equal_length=False):
        self.datasets = datasets
        self.transforms = transforms

        if assert_equal_length:
            lengths = [len(d) for d in datasets]
            if any(n != lengths[0] for n in lengths[1:]):
                # Raised explicitly (rather than via ``assert``) so the check
                # survives ``python -O`` while keeping the exception type
                # that existing callers may already catch.
                raise AssertionError('Datasets are not equal in length.')

    def __len__(self):
        """Return the length of the longest dataset (0 if there are none)."""
        return max((len(d) for d in self.datasets), default=0)

    def __getitem__(self, idx):
        """Return one sample from every dataset for position ``idx``.

        Negative indices count from the end of the zipped dataset.

        Raises:
            IndexError: If ``idx`` is outside ``[-len(self), len(self))``.
        """
        length = len(self)
        # Without this bound, the modulo below maps every integer to a valid
        # sample, so iteration through the ``__getitem__`` protocol
        # (``for x in ds`` / ``list(ds)``) would never stop.
        if not -length <= idx < length:
            raise IndexError(
                f'index {idx} is out of range for ZipDataset of length {length}'
            )
        if idx < 0:
            # Normalise against the zipped length so ds[-1] == ds[len(ds) - 1]
            # even when the member datasets differ in length.
            idx += length

        x = tuple(d[idx % len(d)] for d in self.datasets)
        if self.transforms:
            x = self.transforms(*x)
        return x
|