|
@@ -17,7 +17,8 @@ torch.backends.cudnn.benchmark = True
|
|
|
if __name__ == "__main__":
|
|
|
parser = argparse.ArgumentParser(description='Inference from web-cam')
|
|
|
|
|
|
- parser.add_argument('--model-backbone', type=str, required=True, choices=['resnet50', 'mobilenetv3'])
|
|
|
+ parser.add_argument('--model-backbone', type=str, required=False, choices=['resnet50', 'mobilenetv3'])
|
|
|
+ parser.add_argument('--torchscript-file', type=str, required=False, default=None)
|
|
|
|
|
|
parser.add_argument('--hide-fps', action='store_true')
|
|
|
parser.add_argument('--resolution', type=int, nargs=2, metavar=('width', 'height'), default=(1280, 720))
|
|
@@ -158,12 +159,15 @@ def cv2_frame_to_cuda(frame, datatype = torch.float32):
|
|
|
|
|
|
if __name__ == "__main__":
|
|
|
|
|
|
- model = torch.hub.load("PeterL1n/RobustVideoMatting", args.model_backbone)
|
|
|
- use_cuda = True
|
|
|
- try:
|
|
|
+ if args.torchscript_file:
|
|
|
+ model = torch.jit.load(args.torchscript_file)
|
|
|
+ model = torch.jit.freeze(model)
|
|
|
+ else:
|
|
|
+ model = torch.hub.load("PeterL1n/RobustVideoMatting", args.model_backbone)
|
|
|
+
|
|
|
+ if torch.cuda.is_available():
|
|
|
model = model.cuda().eval()
|
|
|
- except AssertionError:
|
|
|
- use_cuda = False
|
|
|
+
|
|
|
datatype = torch.float32
|
|
|
|
|
|
width, height = args.resolution
|
|
@@ -178,7 +182,7 @@ if __name__ == "__main__":
|
|
|
while True:
|
|
|
frame = cam.read()
|
|
|
src = cv2_frame_to_cuda(frame, datatype)
|
|
|
- fgr, pha, *rec = model(src.cuda() if use_cuda else src, *rec, args.downsampling) # Cycle the recurrent states.
|
|
|
+ fgr, pha, *rec = model(src.cuda() if torch.cuda.is_available() else src, *rec, args.downsampling) # Cycle the recurrent states.
|
|
|
res = pha
|
|
|
res = res.mul(255).byte().cpu().permute(0, 2, 3, 1).numpy()[0]
|
|
|
b_channel, g_channel, r_channel = cv2.split(frame)
|