inference_images.py 5.6 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147
  1. """
  2. Inference images: Extract matting on images.
  3. Example:
  4. python inference_images.py \
  5. --model-type mattingrefine \
  6. --model-backbone resnet50 \
  7. --model-backbone-scale 0.25 \
  8. --model-refine-mode sampling \
  9. --model-refine-sample-pixels 80000 \
  10. --model-checkpoint "PATH_TO_CHECKPOINT" \
  11. --images-src "PATH_TO_IMAGES_SRC_DIR" \
  12. --images-bgr "PATH_TO_IMAGES_BGR_DIR" \
  13. --output-dir "PATH_TO_OUTPUT_DIR" \
  14. --output-type com fgr pha
  15. """
  16. import argparse
  17. import torch
  18. import os
  19. import shutil
  20. from torch import nn
  21. from torch.nn import functional as F
  22. from torch.utils.data import DataLoader
  23. from torchvision import transforms as T
  24. from torchvision.transforms.functional import to_pil_image
  25. from threading import Thread
  26. from tqdm import tqdm
  27. from dataset import ImagesDataset, ZipDataset
  28. from dataset import augmentation as A
  29. from model import MattingBase, MattingRefine
  30. from inference_utils import HomographicAlignment
# --------------- Arguments ---------------


parser = argparse.ArgumentParser(description='Inference images')

# Model selection and checkpoint.
parser.add_argument('--model-type', type=str, required=True, choices=['mattingbase', 'mattingrefine'])
parser.add_argument('--model-backbone', type=str, required=True, choices=['resnet101', 'resnet50', 'mobilenetv2'])
parser.add_argument('--model-backbone-scale', type=float, default=0.25)
parser.add_argument('--model-checkpoint', type=str, required=True)

# Refinement options (consumed only when --model-type is mattingrefine).
parser.add_argument('--model-refine-mode', type=str, default='sampling', choices=['full', 'sampling', 'thresholding'])
parser.add_argument('--model-refine-sample-pixels', type=int, default=80_000)
parser.add_argument('--model-refine-threshold', type=float, default=0.7)
parser.add_argument('--model-refine-kernel-size', type=int, default=3)

# Input directories: source images and the matching background plates.
parser.add_argument('--images-src', type=str, required=True)
parser.add_argument('--images-bgr', type=str, required=True)

parser.add_argument('--device', type=str, choices=['cpu', 'cuda'], default='cuda')
# When set, homographically align each background plate to its source image.
parser.add_argument('--preprocess-alignment', action='store_true')

# Output: one subdirectory is created under --output-dir per requested type.
parser.add_argument('--output-dir', type=str, required=True)
parser.add_argument('--output-types', type=str, required=True, nargs='+', choices=['com', 'pha', 'fgr', 'err', 'ref'])
# -y: overwrite an existing output directory without prompting.
parser.add_argument('-y', action='store_true')

args = parser.parse_args()

# NOTE(review): the first assertion is vacuous — argparse `choices` already
# restricts --model-type to exactly these two values. Kept as-is; the second
# assertion does real work ('ref' is produced only by mattingrefine).
assert 'err' not in args.output_types or args.model_type in ['mattingbase', 'mattingrefine'], \
    'Only mattingbase and mattingrefine support err output'
assert 'ref' not in args.output_types or args.model_type in ['mattingrefine'], \
    'Only mattingrefine support ref output'
  53. # --------------- Main ---------------
  54. device = torch.device(args.device)
  55. # Load model
  56. if args.model_type == 'mattingbase':
  57. model = MattingBase(args.model_backbone)
  58. if args.model_type == 'mattingrefine':
  59. model = MattingRefine(
  60. args.model_backbone,
  61. args.model_backbone_scale,
  62. args.model_refine_mode,
  63. args.model_refine_sample_pixels,
  64. args.model_refine_threshold,
  65. args.model_refine_kernel_size)
  66. model = model.to(device).eval()
  67. model.load_state_dict(torch.load(args.model_checkpoint), strict=False)
# Load images: zip the src and bgr directories into (src, bgr) pairs.
# assert_equal_length guards against mismatched directory contents.
dataset = ZipDataset([
    ImagesDataset(args.images_src),
    ImagesDataset(args.images_bgr),
], assert_equal_length=True, transforms=A.PairCompose([
    # Optionally align the background plate to the source via homography;
    # otherwise apply an identity pass-through so the pipeline shape is uniform.
    HomographicAlignment() if args.preprocess_alignment else A.PairApply(nn.Identity()),
    A.PairApply(T.ToTensor())
]))
# batch_size=1: images may have differing resolutions, so no batching.
dataloader = DataLoader(dataset, batch_size=1, num_workers=8, pin_memory=True)
  77. # Create output directory
  78. if os.path.exists(args.output_dir):
  79. if args.y or input(f'Directory {args.output_dir} already exists. Override? [Y/N]: ').lower() == 'y':
  80. shutil.rmtree(args.output_dir)
  81. else:
  82. exit()
  83. for output_type in args.output_types:
  84. os.makedirs(os.path.join(args.output_dir, output_type))
  85. # Worker function
  86. def writer(img, path):
  87. img = to_pil_image(img[0].cpu())
  88. img.save(path)
  89. # Conversion loop
  90. with torch.no_grad():
  91. for i, (src, bgr) in enumerate(tqdm(dataloader)):
  92. filename = dataset.datasets[0].filenames[i]
  93. src = src.to(device, non_blocking=True)
  94. bgr = bgr.to(device, non_blocking=True)
  95. if args.model_type == 'mattingbase':
  96. pha, fgr, err, _ = model(src, bgr)
  97. elif args.model_type == 'mattingrefine':
  98. pha, fgr, _, _, err, ref = model(src, bgr)
  99. elif args.model_type == 'mattingbm':
  100. pha, fgr = model(src, bgr)
  101. if 'com' in args.output_types:
  102. com = torch.cat([fgr * pha.ne(0), pha], dim=1)
  103. Thread(target=writer, args=(com, filename.replace(args.images_src, os.path.join(args.output_dir, 'com')).replace('.jpg', '.png'))).start()
  104. if 'pha' in args.output_types:
  105. Thread(target=writer, args=(pha, filename.replace(args.images_src, os.path.join(args.output_dir, 'pha')).replace('.png', '.jpg'))).start()
  106. if 'fgr' in args.output_types:
  107. Thread(target=writer, args=(fgr, filename.replace(args.images_src, os.path.join(args.output_dir, 'fgr')).replace('.png', '.jpg'))).start()
  108. if 'err' in args.output_types:
  109. err = F.interpolate(err, src.shape[2:], mode='bilinear', align_corners=False)
  110. Thread(target=writer, args=(err, filename.replace(args.images_src, os.path.join(args.output_dir, 'err')).replace('.png', '.jpg'))).start()
  111. if 'ref' in args.output_types:
  112. ref = F.interpolate(ref, src.shape[2:], mode='nearest')
  113. Thread(target=writer, args=(ref, filename.replace(args.images_src, os.path.join(args.output_dir, 'ref')).replace('.png', '.jpg'))).start()