  1. """
  2. Export MattingRefine as ONNX format
  3. Example:
  4. python export_onnx.py \
  5. --model-type mattingrefine \
  6. --model-checkpoint "PATH_TO_MODEL_CHECKPOINT" \
  7. --model-backbone resnet50 \
  8. --model-backbone-scale 0.25 \
  9. --model-refine-mode sampling \
  10. --model-refine-sample-pixels 80000 \
  11. --onnx-opset-version 11 \
  12. --onnx-constant-folding \
  13. --precision float32 \
  14. --output "model.onnx" \
  15. --validate
  16. """
import argparse

import torch

from model import MattingBase, MattingRefine

# --------------- Arguments ---------------
parser = argparse.ArgumentParser(description='Export ONNX')
parser.add_argument('--model-type', type=str, required=True, choices=['mattingbase', 'mattingrefine'])
parser.add_argument('--model-backbone', type=str, required=True, choices=['resnet101', 'resnet50', 'mobilenetv2'])
parser.add_argument('--model-backbone-scale', type=float, default=0.25)
parser.add_argument('--model-checkpoint', type=str, required=True)
parser.add_argument('--model-refine-mode', type=str, default='sampling', choices=['full', 'sampling', 'thresholding'])
parser.add_argument('--model-refine-sample-pixels', type=int, default=80_000)
parser.add_argument('--model-refine-threshold', type=float, default=0.7)
parser.add_argument('--model-refine-kernel-size', type=int, default=3)
# BooleanOptionalAction (Python 3.9+) keeps the original True defaults while
# making both flags actually toggleable (--no-onnx-verbose,
# --no-onnx-constant-folding). The previous `type=bool` parsed every
# non-empty string as True, and `action='store_true'` combined with
# `default=True` could never be switched off from the CLI.
parser.add_argument('--onnx-verbose', action=argparse.BooleanOptionalAction, default=True)
parser.add_argument('--onnx-opset-version', type=int, default=12)
parser.add_argument('--onnx-constant-folding', action=argparse.BooleanOptionalAction, default=True)
parser.add_argument('--device', type=str, default='cpu')
parser.add_argument('--precision', type=str, default='float32', choices=['float32', 'float16'])
parser.add_argument('--validate', action='store_true')
parser.add_argument('--output', type=str, required=True)
args = parser.parse_args()

# --------------- Main ---------------

# Load model. The 'roi_align' patch cropping and 'scatter_element' patch
# replacement variants are selected here because they map onto operators
# that ONNX can represent.
if args.model_type == 'mattingbase':
    model = MattingBase(args.model_backbone)
elif args.model_type == 'mattingrefine':
    model = MattingRefine(
        args.model_backbone,
        args.model_backbone_scale,
        args.model_refine_mode,
        args.model_refine_sample_pixels,
        args.model_refine_threshold,
        args.model_refine_kernel_size,
        refine_patch_crop_method='roi_align',
        refine_patch_replace_method='scatter_element')

# strict=False tolerates keys that differ between the checkpoint and this
# export-oriented model configuration.
model.load_state_dict(torch.load(args.model_checkpoint, map_location=args.device), strict=False)

precision = {'float32': torch.float32, 'float16': torch.float16}[args.precision]
model.eval().to(precision).to(args.device)
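
# Note (an assumption based on general PyTorch behavior, not this repo's
# docs): tracing at float16 on CPU may fail because some ops lack
# half-precision CPU kernels, so --precision float16 is usually paired
# with --device cuda.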

# Dummy inputs: the exporter traces the model once with these tensors.
# Their concrete 1080p shape is relaxed again by `dynamic_axes` below.
src = torch.randn(2, 3, 1080, 1920).to(precision).to(args.device)
bgr = torch.randn(2, 3, 1080, 1920).to(precision).to(args.device)

# Export ONNX
if args.model_type == 'mattingbase':
    input_names = ['src', 'bgr']
    output_names = ['pha', 'fgr', 'err', 'hid']
elif args.model_type == 'mattingrefine':
    input_names = ['src', 'bgr']
    output_names = ['pha', 'fgr', 'pha_sm', 'fgr_sm', 'err_sm', 'ref_sm']

# Mark batch, height, and width as dynamic on every input and output so a
# single exported graph serves arbitrary batch sizes and resolutions.
torch.onnx.export(
    model=model,
    args=(src, bgr),
    f=args.output,
    verbose=args.onnx_verbose,
    opset_version=args.onnx_opset_version,
    do_constant_folding=args.onnx_constant_folding,
    input_names=input_names,
    output_names=output_names,
    dynamic_axes={name: {0: 'batch', 2: 'height', 3: 'width'} for name in [*input_names, *output_names]})

print(f'ONNX model saved at: {args.output}')
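
# Minimal consumer sketch (illustrative only; 'model.onnx' stands in for
# whatever was passed to --output, and src_np / bgr_np are NCHW numpy
# arrays matching the export precision):
#
#     import onnxruntime
#     sess = onnxruntime.InferenceSession('model.onnx')
#     outputs = sess.run(None, {'src': src_np, 'bgr': bgr_np})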

# Validation
if args.validate:
    import onnxruntime

    print('Validating ONNX model.')

    # Use a batch size and resolution different from the export dummies
    # to exercise the dynamic axes.
    src = torch.randn(1, 3, 720, 1280).to(precision).to(args.device)
    bgr = torch.randn(1, 3, 720, 1280).to(precision).to(args.device)

    with torch.no_grad():
        out_torch = model(src, bgr)

    sess = onnxruntime.InferenceSession(args.output)
    out_onnx = sess.run(None, {
        'src': src.cpu().numpy(),
        'bgr': bgr.cpu().numpy()
    })

    e_max = 0
    for a, b, name in zip(out_torch, out_onnx, output_names):
        b = torch.as_tensor(b)
        e = torch.abs(a.cpu() - b).max().item()
        e_max = max(e_max, e)
        print(f'"{name}" output differs by maximum of {e}')

    if e_max < 0.001:
        print('Validation passed.')
    else:
        # Raising a bare string is a TypeError in Python 3.
        raise RuntimeError('Validation failed.')
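
# The 1e-3 threshold is a reasonable bar for float32; a float16 export
# would likely need a looser tolerance (an assumption, not a tested value).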