1234567891011121314151617181920212223242526272829303132333435363738394041424344454647484950515253545556575859606162636465666768697071727374757677787980818283 |
"""
Export TorchScript

Usage:
    python export_torchscript.py \
        --model-backbone resnet50 \
        --model-checkpoint "PATH_TO_CHECKPOINT" \
        --precision float32 \
        --output "torchscript.pth"
"""
- import argparse
- import torch
- from torch import nn
- from model import MattingRefine
- # --------------- Arguments ---------------
# Command-line interface of the export script. The backbone, the checkpoint
# path and the output path are mandatory; the precision falls back to float32.
parser = argparse.ArgumentParser(description='Export TorchScript')

# Backbone architecture name, forwarded to the model constructor.
parser.add_argument(
    '--model-backbone',
    type=str,
    required=True,
    choices=['resnet101', 'resnet50', 'mobilenetv2'],
)

# Path of the checkpoint file holding the trained weights.
parser.add_argument(
    '--model-checkpoint',
    type=str,
    required=True,
)

# Numeric precision of the exported model.
parser.add_argument(
    '--precision',
    type=str,
    default='float32',
    choices=['float32', 'float16'],
)

# Destination path of the saved TorchScript file.
parser.add_argument(
    '--output',
    type=str,
    required=True,
)

# Parsed eagerly at module level; the export steps below read from `args`.
args = parser.parse_args()
- # --------------- Utils ---------------
class MattingRefine_TorchScriptWrapper(nn.Module):
    """
    Shell around MattingRefine that exposes its tunable settings directly.

    The configurable attributes of the wrapped model and of its refiner are
    mirrored on the wrapper itself, so they can be changed on a model that was
    restored with torch.jit.load. Every forward call copies the wrapper's
    current values back onto the wrapped model before running it.

    Example:
        model = torch.jit.load('torchscript.pth')
        model.backbone_scale = 0.25
        model.refine_mode = 'sampling'
        model.refine_sample_pixels = 80_000
        pha, fgr = model(src, bgr)[:2]
    """

    def __init__(self, *args, **kwargs):
        """Create the wrapped MattingRefine, passing every argument through."""
        super().__init__()
        self.model = MattingRefine(*args, **kwargs)

        wrapped = self.model
        refiner = wrapped.refiner

        # Seed the top-level copies with the wrapped model's current settings.
        self.backbone_scale = wrapped.backbone_scale
        self.refine_mode = refiner.mode
        self.refine_sample_pixels = refiner.sample_pixels
        self.refine_threshold = refiner.threshold
        self.refine_prevent_oversampling = refiner.prevent_oversampling

    def forward(self, src, bgr):
        # Push the (possibly user-modified) top-level settings down onto the
        # wrapped model, so that changes made after loading take effect.
        self.model.refiner.mode = self.refine_mode
        self.model.refiner.sample_pixels = self.refine_sample_pixels
        self.model.refiner.threshold = self.refine_threshold
        self.model.refiner.prevent_oversampling = self.refine_prevent_oversampling
        self.model.backbone_scale = self.backbone_scale

        # The result of the wrapped model is returned unchanged.
        return self.model(src, bgr)

    def load_state_dict(self, *args, **kwargs):
        """
        Load a state dict into the wrapped model instead of the wrapper.

        All arguments are forwarded as-is to the wrapped model's
        load_state_dict, and its return value is returned.
        """
        wrapped = self.model
        return wrapped.load_state_dict(*args, **kwargs)
-
-
- # --------------- Main ---------------
-
# Build the wrapper for the requested backbone and switch it to eval mode.
model = MattingRefine_TorchScriptWrapper(args.model_backbone)
model = model.eval()

# Read the checkpoint onto the CPU and load it into the wrapped model.
# NOTE(review): torch.load unpickles the file; only use trusted checkpoints.
state_dict = torch.load(args.model_checkpoint, map_location='cpu')
model.load_state_dict(state_dict)

# Turn off gradient tracking for every parameter of the model.
model.requires_grad_(False)

# Convert to half precision only when explicitly requested.
if args.precision == 'float16':
    model = model.half()

# Compile to TorchScript and write the result to the output path.
model = torch.jit.script(model)
model.save(args.output)
|