Browse Source

add torchscript support

subDesTagesMitExtraKaese 2 years ago
parent
commit
c79f2d34ee
3 changed files with 14 additions and 8 deletions
  1. 1 0
      .gitignore
  2. 11 7
      inference_webcam.py
  3. 2 1
      run.sh

+ 1 - 0
.gitignore

@@ -0,0 +1 @@
+*.torchscript

+ 11 - 7
inference_webcam.py

@@ -17,7 +17,8 @@ torch.backends.cudnn.benchmark = True
 if __name__ == "__main__":
     parser = argparse.ArgumentParser(description='Inference from web-cam')
 
-    parser.add_argument('--model-backbone', type=str, required=True, choices=['resnet50', 'mobilenetv3'])
+    parser.add_argument('--model-backbone', type=str, required=False, choices=['resnet50', 'mobilenetv3'])
+    parser.add_argument('--torchscript-file', type=str, required=False, default=None)
 
     parser.add_argument('--hide-fps', action='store_true')
     parser.add_argument('--resolution', type=int, nargs=2, metavar=('width', 'height'), default=(1280, 720))
@@ -158,12 +159,15 @@ def cv2_frame_to_cuda(frame, datatype = torch.float32):
 
 if __name__ == "__main__":
 
-  model = torch.hub.load("PeterL1n/RobustVideoMatting", args.model_backbone)
-  use_cuda = True
-  try: 
+  if args.torchscript_file:
+    model = torch.jit.load(args.torchscript_file)
+    model = torch.jit.freeze(model)
+  else:
+    model = torch.hub.load("PeterL1n/RobustVideoMatting", args.model_backbone)
+
+  if torch.cuda.is_available(): 
     model = model.cuda().eval()
-  except AssertionError:
-    use_cuda = False
+
   datatype = torch.float32
 
   width, height = args.resolution
@@ -178,7 +182,7 @@ if __name__ == "__main__":
     while True:
       frame = cam.read()
       src = cv2_frame_to_cuda(frame, datatype)
-      fgr, pha, *rec = model(src.cuda() if use_cuda else src, *rec, args.downsampling)  # Cycle the recurrent states.
+      fgr, pha, *rec = model(src.cuda() if torch.cuda.is_available() else src, *rec, args.downsampling)  # Cycle the recurrent states.
       res = pha
       res = res.mul(255).byte().cpu().permute(0, 2, 3, 1).numpy()[0]
       b_channel, g_channel, r_channel = cv2.split(frame)

+ 2 - 1
run.sh

@@ -1,8 +1,9 @@
 #!/bin/bash
 
 python3 inference_webcam.py \
-  --model-backbone mobilenetv3 \
+  --torchscript-file rvm_mobilenetv3_fp32.torchscript \
   --downsampling 0.3 \
   --resolution 640 360 \
   --device-id 0 \
+  #--model-backbone mobilenetv3 \
   #--hide-fps