sample.py 413 B

1234567891011121314
  1. from torch.utils.data import Dataset
  2. class SampleDataset(Dataset):
  3. def __init__(self, dataset, samples):
  4. samples = min(samples, len(dataset))
  5. self.dataset = dataset
  6. self.indices = [i * int(len(dataset) / samples) for i in range(samples)]
  7. def __len__(self):
  8. return len(self.indices)
  9. def __getitem__(self, idx):
  10. return self.dataset[self.indices[idx]]