export_tensorflow.py 2.2 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061
"""
python export_tensorflow.py \
    --model-variant mobilenetv3 \
    --model-refiner deep_guided_filter \
    --pytorch-checkpoint rvm_mobilenetv3.pth \
    --tensorflow-output rvm_mobilenetv3_tf
"""
  8. import argparse
  9. import torch
  10. import tensorflow as tf
  11. from model import MattingNetwork, load_torch_weights
# Give the model's inputs and outputs explicit names and shapes.
  13. class MattingNetworkWrapper(MattingNetwork):
  14. @tf.function(input_signature=[[
  15. tf.TensorSpec(tf.TensorShape([None, None, None, 3]), tf.float32, 'src'),
  16. tf.TensorSpec(tf.TensorShape(None), tf.float32, 'r1i'),
  17. tf.TensorSpec(tf.TensorShape(None), tf.float32, 'r2i'),
  18. tf.TensorSpec(tf.TensorShape(None), tf.float32, 'r3i'),
  19. tf.TensorSpec(tf.TensorShape(None), tf.float32, 'r4i'),
  20. tf.TensorSpec(tf.TensorShape(None), tf.float32, 'downsample_ratio')
  21. ]])
  22. def call(self, inputs):
  23. fgr, pha, r1o, r2o, r3o, r4o = super().call(inputs)
  24. return {'fgr': fgr, 'pha': pha, 'r1o': r1o, 'r2o': r2o, 'r3o': r3o, 'r4o': r4o}
  25. class Exporter:
  26. def __init__(self):
  27. self.parse_args()
  28. self.init_model()
  29. self.export()
  30. def parse_args(self):
  31. parser = argparse.ArgumentParser()
  32. parser.add_argument('--model-variant', type=str, required=True, choices=['mobilenetv3', 'resnet50'])
  33. parser.add_argument('--model-refiner', type=str, default='deep_guided_filter', choices=['deep_guided_filter', 'fast_guided_filter'])
  34. parser.add_argument('--pytorch-checkpoint', type=str, required=True)
  35. parser.add_argument('--tensorflow-output', type=str, required=True)
  36. self.args = parser.parse_args()
  37. def init_model(self):
  38. # Construct model
  39. self.model = MattingNetworkWrapper(self.args.model_variant, self.args.model_refiner)
  40. # Build model
  41. src = tf.random.normal([1, 1080, 1920, 3])
  42. rec = [ tf.constant(0.) ] * 4
  43. downsample_ratio = tf.constant(0.25)
  44. self.model([src, *rec, downsample_ratio])
  45. # Load PyTorch checkpoint
  46. load_torch_weights(self.model, torch.load(self.args.pytorch_checkpoint, map_location='cpu'))
  47. def export(self):
  48. self.model.save(self.args.tensorflow_output)
  49. if __name__ == '__main__':
  50. Exporter()