# export_onnx.py
"""
Export a matting network to ONNX.

Usage:
    python export_onnx.py \
        --model-variant mobilenetv3 \
        --checkpoint rvm_mobilenetv3.pth \
        --precision float16 \
        --opset 12 \
        --device cuda \
        --output model.onnx

Note:
    The device is only used for exporting. It has nothing to do with the
    final model. Float16 must be exported through cuda; float32 can be
    exported through cpu.
"""
  13. import argparse
  14. import torch
  15. from model import MattingNetwork
  16. class Exporter:
  17. def __init__(self):
  18. self.parse_args()
  19. self.init_model()
  20. self.export()
  21. def parse_args(self):
  22. parser = argparse.ArgumentParser()
  23. parser.add_argument('--model-variant', type=str, required=True, choices=['mobilenetv3', 'resnet50'])
  24. parser.add_argument('--model-refiner', type=str, default='deep_guided_filter', choices=['deep_guided_filter', 'fast_guided_filter'])
  25. parser.add_argument('--precision', type=str, required=True, choices=['float16', 'float32'])
  26. parser.add_argument('--opset', type=int, required=True)
  27. parser.add_argument('--device', type=str, required=True)
  28. parser.add_argument('--checkpoint', type=str, required=False)
  29. parser.add_argument('--output', type=str, required=True)
  30. self.args = parser.parse_args()
  31. def init_model(self):
  32. self.precision = torch.float32 if self.args.precision == 'float32' else torch.float16
  33. self.model = MattingNetwork(self.args.model_variant, self.args.model_refiner).eval().to(self.args.device, self.precision)
  34. if self.args.checkpoint is not None:
  35. self.model.load_state_dict(torch.load(self.args.checkpoint, map_location=self.args.device), strict=False)
  36. def export(self):
  37. rec = (torch.zeros([1, 1, 1, 1]).to(self.args.device, self.precision),) * 4
  38. src = torch.randn(1, 3, 1080, 1920).to(self.args.device, self.precision)
  39. downsample_ratio = torch.tensor([0.25]).to(self.args.device)
  40. dynamic_spatial = {0: 'batch_size', 2: 'height', 3: 'width'}
  41. dynamic_everything = {0: 'batch_size', 1: 'channels', 2: 'height', 3: 'width'}
  42. torch.onnx.export(
  43. self.model,
  44. (src, *rec, downsample_ratio),
  45. self.args.output,
  46. export_params=True,
  47. opset_version=self.args.opset,
  48. do_constant_folding=True,
  49. input_names=['src', 'r1i', 'r2i', 'r3i', 'r4i', 'downsample_ratio'],
  50. output_names=['fgr', 'pha', 'r1o', 'r2o', 'r3o', 'r4o'],
  51. dynamic_axes={
  52. 'src': dynamic_spatial,
  53. 'fgr': dynamic_spatial,
  54. 'pha': dynamic_spatial,
  55. 'r1i': dynamic_everything,
  56. 'r2i': dynamic_everything,
  57. 'r3i': dynamic_everything,
  58. 'r4i': dynamic_everything,
  59. 'r1o': dynamic_spatial,
  60. 'r2o': dynamic_spatial,
  61. 'r3o': dynamic_spatial,
  62. 'r4o': dynamic_spatial,
  63. })
if __name__ == '__main__':
    # Instantiating Exporter runs the full pipeline: parse CLI args,
    # build/load the model, and write the ONNX file.
    Exporter()