zip.py

from typing import List

from torch.utils.data import Dataset


class ZipDataset(Dataset):
    """Zips several datasets so that indexing returns one sample from each.

    If the datasets differ in length, shorter ones are cycled via modulo
    indexing, and the combined length is that of the longest dataset.
    """

    def __init__(self, datasets: List[Dataset], transforms=None, assert_equal_length=False):
        self.datasets = datasets
        self.transforms = transforms
        if assert_equal_length:
            for i in range(1, len(datasets)):
                assert len(datasets[i]) == len(datasets[i - 1]), \
                    'Datasets are not equal in length.'

    def __len__(self):
        return max(len(d) for d in self.datasets)

    def __getitem__(self, idx):
        # Take the idx-th sample from each dataset, wrapping around shorter ones.
        x = tuple(d[idx % len(d)] for d in self.datasets)
        if self.transforms:
            x = self.transforms(*x)
        return x
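
A minimal usage sketch, assuming two TensorDatasets as stand-in data (the tensor shapes and lengths below are illustrative, not part of the original file):

    import torch
    from torch.utils.data import TensorDataset

    # Hypothetical stand-in data; the lengths differ on purpose (100 vs. 60).
    images = TensorDataset(torch.randn(100, 3, 32, 32))
    labels = TensorDataset(torch.randint(0, 10, (60,)))

    paired = ZipDataset([images, labels])
    print(len(paired))    # 100: the length of the longest dataset
    sample = paired[75]   # one sample per dataset; the shorter one wraps (75 % 60 == 15)
    img, lbl = sample     # each element is whatever that dataset's __getitem__ returns

Because __len__ reports the longest dataset, a DataLoader over a ZipDataset will revisit samples from the shorter datasets within a single epoch; pass assert_equal_length=True to enforce matching lengths instead.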