12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061 |
- """
- python export_tensorflow.py \
- --model-variant mobilenetv3 \
- --model-variant deep_guided_filter \
- --pytorch-checkpoint rvm_mobilenetv3.pth \
- --tensorflow-output rvm_mobilenetv3_tf
- """
- import argparse
- import torch
- import tensorflow as tf
- from model import MattingNetwork, load_torch_weights
- # Add input output names and shapes
- class MattingNetworkWrapper(MattingNetwork):
- @tf.function(input_signature=[[
- tf.TensorSpec(tf.TensorShape([None, None, None, 3]), tf.float32, 'src'),
- tf.TensorSpec(tf.TensorShape(None), tf.float32, 'r1i'),
- tf.TensorSpec(tf.TensorShape(None), tf.float32, 'r2i'),
- tf.TensorSpec(tf.TensorShape(None), tf.float32, 'r3i'),
- tf.TensorSpec(tf.TensorShape(None), tf.float32, 'r4i'),
- tf.TensorSpec(tf.TensorShape(None), tf.float32, 'downsample_ratio')
- ]])
- def call(self, inputs):
- fgr, pha, r1o, r2o, r3o, r4o = super().call(inputs)
- return {'fgr': fgr, 'pha': pha, 'r1o': r1o, 'r2o': r2o, 'r3o': r3o, 'r4o': r4o}
-
- class Exporter:
- def __init__(self):
- self.parse_args()
- self.init_model()
- self.export()
-
- def parse_args(self):
- parser = argparse.ArgumentParser()
- parser.add_argument('--model-variant', type=str, required=True, choices=['mobilenetv3', 'resnet50'])
- parser.add_argument('--model-refiner', type=str, default='deep_guided_filter', choices=['deep_guided_filter', 'fast_guided_filter'])
- parser.add_argument('--pytorch-checkpoint', type=str, required=True)
- parser.add_argument('--tensorflow-output', type=str, required=True)
- self.args = parser.parse_args()
-
- def init_model(self):
- # Construct model
- self.model = MattingNetworkWrapper(self.args.model_variant, self.args.model_refiner)
- # Build model
- src = tf.random.normal([1, 1080, 1920, 3])
- rec = [ tf.constant(0.) ] * 4
- downsample_ratio = tf.constant(0.25)
- self.model([src, *rec, downsample_ratio])
- # Load PyTorch checkpoint
- load_torch_weights(self.model, torch.load(self.args.pytorch_checkpoint, map_location='cpu'))
- def export(self):
- self.model.save(self.args.tensorflow_output)
- if __name__ == '__main__':
- Exporter()
|