# spd_preprocess.py

# pip install supervisely
import os
from typing import Tuple

import numpy as np
import supervisely_lib as sly
from PIL import Image
from tqdm import tqdm
  7. # Download dataset from <https://supervise.ly/explore/projects/supervisely-person-dataset-23304/datasets>
  8. project_root = 'PATH_TO/Supervisely Person Dataset' # <-- Configure input
  9. project = sly.Project(project_root, sly.OpenMode.READ)
  10. output_path = 'OUTPUT_DIR' # <-- Configure output
  11. os.makedirs(os.path.join(output_path, 'train', 'src'))
  12. os.makedirs(os.path.join(output_path, 'train', 'msk'))
  13. os.makedirs(os.path.join(output_path, 'valid', 'src'))
  14. os.makedirs(os.path.join(output_path, 'valid', 'msk'))
  15. max_size = 2048 # <-- Configure max size
  16. for dataset in project.datasets:
  17. for item in tqdm(dataset):
  18. ann = sly.Annotation.load_json_file(dataset.get_ann_path(item), project.meta)
  19. msk = np.zeros(ann.img_size, dtype=np.uint8)
  20. for label in ann.labels:
  21. label.geometry.draw(msk, color=[255])
  22. msk = Image.fromarray(msk)
  23. img = Image.open(dataset.get_img_path(item)).convert('RGB')
  24. if img.size[0] > max_size or img.size[1] > max_size:
  25. scale = max_size / max(img.size)
  26. img = img.resize((int(img.size[0] * scale), int(img.size[1] * scale)), Image.BILINEAR)
  27. msk = msk.resize((int(msk.size[0] * scale), int(msk.size[1] * scale)), Image.NEAREST)
  28. img.save(os.path.join(output_path, 'train', 'src', item.replace('.png', '.jpg')))
  29. msk.save(os.path.join(output_path, 'train', 'msk', item.replace('.png', '.jpg')))
  30. # Move first 100 to validation set
  31. names = os.listdir(os.path.join(output_path, 'train', 'src'))
  32. for name in tqdm(names[:100]):
  33. os.rename(
  34. os.path.join(output_path, 'train', 'src', name),
  35. os.path.join(output_path, 'valid', 'src', name))
  36. os.rename(
  37. os.path.join(output_path, 'train', 'msk', name),
  38. os.path.join(output_path, 'valid', 'msk', name))