# inference_speed_test.py
  1. """
  2. Inference Speed Test
  3. Example:
  4. Run inference on random noise input for fixed computation setting.
  5. (i.e. mode in ['full', 'sampling'])
  6. python inference_speed_test.py \
  7. --model-type mattingrefine \
  8. --model-backbone resnet50 \
  9. --model-backbone-scale 0.25 \
  10. --model-refine-mode sampling \
  11. --model-refine-sample-pixels 80000 \
  12. --batch-size 1 \
  13. --resolution 1920 1080 \
  14. --backend pytorch \
  15. --precision float32
  16. Run inference on provided image input for dynamic computation setting.
  17. (i.e. mode in ['thresholding'])
  18. python inference_speed_test.py \
  19. --model-type mattingrefine \
  20. --model-backbone resnet50 \
  21. --model-backbone-scale 0.25 \
  22. --model-checkpoint "PATH_TO_CHECKPOINT" \
  23. --model-refine-mode thresholding \
  24. --model-refine-threshold 0.7 \
  25. --batch-size 1 \
  26. --backend pytorch \
  27. --precision float32 \
  28. --image-src "PATH_TO_IMAGE_SRC" \
  29. --image-bgr "PATH_TO_IMAGE_BGR"
  30. """
  31. import argparse
  32. import torch
  33. from torchvision.transforms.functional import to_tensor
  34. from tqdm import tqdm
  35. from PIL import Image
  36. from model import MattingBase, MattingRefine
  37. # --------------- Arguments ---------------
  38. parser = argparse.ArgumentParser()
  39. parser.add_argument('--model-type', type=str, required=True, choices=['mattingbase', 'mattingrefine'])
  40. parser.add_argument('--model-backbone', type=str, required=True, choices=['resnet101', 'resnet50', 'mobilenetv2'])
  41. parser.add_argument('--model-backbone-scale', type=float, default=0.25)
  42. parser.add_argument('--model-checkpoint', type=str, default=None)
  43. parser.add_argument('--model-refine-mode', type=str, default='sampling', choices=['full', 'sampling', 'thresholding'])
  44. parser.add_argument('--model-refine-sample-pixels', type=int, default=80_000)
  45. parser.add_argument('--model-refine-threshold', type=float, default=0.7)
  46. parser.add_argument('--model-refine-kernel-size', type=int, default=3)
  47. parser.add_argument('--batch-size', type=int, default=1)
  48. parser.add_argument('--resolution', type=int, default=None, nargs=2)
  49. parser.add_argument('--precision', type=str, default='float32', choices=['float32', 'float16'])
  50. parser.add_argument('--backend', type=str, default='pytorch', choices=['pytorch', 'torchscript'])
  51. parser.add_argument('--device', type=str, choices=['cpu', 'cuda'], default='cuda')
  52. parser.add_argument('--image-src', type=str, default=None)
  53. parser.add_argument('--image-bgr', type=str, default=None)
  54. args = parser.parse_args()
  55. assert type(args.image_src) == type(args.image_bgr), 'Image source and background must be provided together.'
  56. assert (not args.image_src) != (not args.resolution), 'Must provide either a resolution or an image and not both.'
  57. # --------------- Run Loop ---------------
  58. device = torch.device(args.device)
  59. # Load model
  60. if args.model_type == 'mattingbase':
  61. model = MattingBase(args.model_backbone)
  62. if args.model_type == 'mattingrefine':
  63. model = MattingRefine(
  64. args.model_backbone,
  65. args.model_backbone_scale,
  66. args.model_refine_mode,
  67. args.model_refine_sample_pixels,
  68. args.model_refine_threshold,
  69. args.model_refine_kernel_size,
  70. refine_prevent_oversampling=False)
  71. if args.model_checkpoint:
  72. model.load_state_dict(torch.load(args.model_checkpoint), strict=False)
  73. if args.precision == 'float32':
  74. precision = torch.float32
  75. else:
  76. precision = torch.float16
  77. if args.backend == 'torchscript':
  78. model = torch.jit.script(model)
  79. model = model.eval().to(device=device, dtype=precision)
  80. # Load data
  81. if not args.image_src:
  82. src = torch.rand((args.batch_size, 3, *args.resolution[::-1]), device=device, dtype=precision)
  83. bgr = torch.rand((args.batch_size, 3, *args.resolution[::-1]), device=device, dtype=precision)
  84. else:
  85. src = to_tensor(Image.open(args.image_src)).unsqueeze(0).repeat(args.batch_size, 1, 1, 1).to(device=device, dtype=precision)
  86. bgr = to_tensor(Image.open(args.image_bgr)).unsqueeze(0).repeat(args.batch_size, 1, 1, 1).to(device=device, dtype=precision)
  87. # Loop
  88. with torch.no_grad():
  89. for _ in tqdm(range(1000)):
  90. model(src, bgr)