onnx_helper.py

"""
We implement custom ONNX export logic because PyTorch does not trace the "Shape" op very well.
The custom export logic adds support for a runtime downsample_ratio input and cleans up the ONNX graph.
"""

import torch
from torch import Tensor
from torch.nn import functional as F
from torch.autograd import Function


class CustomOnnxResizeByFactorOp(Function):
    """
    This implements resize by scale_factor. Unlike PyTorch's default export, which
    requires scale_factor to be a hardcoded constant, this version accepts a
    scale_factor tensor provided at runtime.
    """

    @staticmethod
    def forward(ctx, x, scale_factor):
        assert x.ndim == 4
        # scale_factor is a 0-dim tensor; .item() extracts the Python scalar for eager execution.
        return F.interpolate(x, scale_factor=scale_factor.item(),
                             mode='bilinear', recompute_scale_factor=False, align_corners=False)

    @staticmethod
    def symbolic(g, x, scale_factor):
        # Resize takes an (unused) roi input; pass an empty constant.
        empty_roi = g.op("Constant", value_t=torch.tensor([], dtype=torch.float32))
        # Build the per-axis scales [1, 1, s, s] so batch and channel dims are unchanged.
        scale_factor = g.op("Concat",
                            g.op("Constant", value_t=torch.tensor([1, 1], dtype=torch.float32)),
                            scale_factor, scale_factor, axis_i=0)
        return g.op('Resize',
                    x,
                    empty_roi,
                    scale_factor,
                    coordinate_transformation_mode_s='pytorch_half_pixel',
                    cubic_coeff_a_f=-0.75,
                    mode_s='linear',
                    nearest_mode_s="floor")
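

# Usage sketch (illustrative, not part of the original file): a hypothetical
# wrapper showing how the op might replace F.interpolate during export so the
# downsample ratio remains a graph input. The class name is an assumption.
class _ResizeByFactorExample(torch.nn.Module):
    def forward(self, x: Tensor, scale_factor: Tensor) -> Tensor:
        if torch.onnx.is_in_onnx_export():
            # Traced path: emits the custom Resize node defined above.
            return CustomOnnxResizeByFactorOp.apply(x, scale_factor)
        # Eager path: ordinary bilinear interpolation.
        return F.interpolate(x, scale_factor=scale_factor.item(),
                             mode='bilinear', recompute_scale_factor=False,
                             align_corners=False)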


class CustomOnnxResizeToMatchSizeOp(Function):
    """
    This implements bilinearly resizing a tensor to match the spatial size of another.
    This implementation produces a cleaner ONNX graph than PyTorch's default export.
    """

    @staticmethod
    def forward(ctx, x, y):
        assert x.ndim == 4 and y.ndim == 4
        # Resize x to y's spatial dimensions (H, W).
        return F.interpolate(x, y.shape[2:], mode='bilinear', align_corners=False)

    @staticmethod
    def symbolic(g, x, y):
        # Resize takes (unused) roi and scales inputs when sizes are given; pass empty constants.
        empty_roi = g.op("Constant", value_t=torch.tensor([], dtype=torch.float32))
        empty_scales = g.op("Constant", value_t=torch.tensor([], dtype=torch.float32))
        # Batch and channel sizes come from x ...
        BC = g.op('Slice',
                  g.op('Shape', x),                             # input
                  g.op('Constant', value_t=torch.tensor([0])),  # starts
                  g.op('Constant', value_t=torch.tensor([2])),  # ends
                  g.op('Constant', value_t=torch.tensor([0])),  # axes
                  )
        # ... while the spatial sizes come from y.
        HW = g.op('Slice',
                  g.op('Shape', y),                             # input
                  g.op('Constant', value_t=torch.tensor([2])),  # starts
                  g.op('Constant', value_t=torch.tensor([4])),  # ends
                  g.op('Constant', value_t=torch.tensor([0])),  # axes
                  )
        output_shape = g.op('Concat', BC, HW, axis_i=0)
        return g.op('Resize',
                    x,
                    empty_roi,
                    empty_scales,
                    output_shape,
                    coordinate_transformation_mode_s='pytorch_half_pixel',
                    cubic_coeff_a_f=-0.75,
                    mode_s='linear',
                    nearest_mode_s="floor")
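

# Usage sketch (illustrative, not part of the original file): a hypothetical
# skip-connection helper that resizes features to match another tensor's size,
# e.g. decoder features resized to an encoder feature map. The name is an assumption.
def _resize_to_match_example(x: Tensor, y: Tensor) -> Tensor:
    if torch.onnx.is_in_onnx_export():
        # Traced path: a single Shape/Slice/Concat/Resize chain.
        return CustomOnnxResizeToMatchSizeOp.apply(x, y)
    # Eager path: plain bilinear resize to y's spatial size.
    return F.interpolate(x, y.shape[2:], mode='bilinear', align_corners=False)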


class CustomOnnxCropToMatchSizeOp(Function):
    """
    This implements cropping a tensor to match the spatial size of another.
    This implementation produces a cleaner ONNX graph than PyTorch's default export.
    """

    @staticmethod
    def forward(ctx, x, y):
        assert x.ndim == 4 and y.ndim == 4
        # Keep the top-left y.size(2) x y.size(3) region of x.
        return x[:, :, :y.size(2), :y.size(3)]

    @staticmethod
    def symbolic(g, x, y):
        # Target spatial size (H, W) taken from y's shape.
        size = g.op('Slice',
                    g.op('Shape', y),                             # input
                    g.op('Constant', value_t=torch.tensor([2])),  # starts
                    g.op('Constant', value_t=torch.tensor([4])),  # ends
                    g.op('Constant', value_t=torch.tensor([0])),  # axes
                    )
        # Slice x on the spatial axes (2, 3) from 0 up to y's (H, W).
        return g.op('Slice',
                    x,                                               # input
                    g.op('Constant', value_t=torch.tensor([0, 0])),  # starts
                    size,                                            # ends
                    g.op('Constant', value_t=torch.tensor([2, 3])),  # axes
                    )
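

# Usage sketch (illustrative, not part of the original file): a hypothetical
# helper for cropping upsampled features back to a reference tensor's size,
# e.g. when odd input dimensions leave the upsampled tensor one pixel too large.
def _crop_to_match_example(x: Tensor, y: Tensor) -> Tensor:
    if torch.onnx.is_in_onnx_export():
        # Traced path: a single Shape/Slice chain in the exported graph.
        return CustomOnnxCropToMatchSizeOp.apply(x, y)
    # Eager path: plain tensor slicing.
    return x[:, :, :y.size(2), :y.size(3)]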