Browse Source

Add `--device` argument option for CPU support

Peter Lin 4 years ago
parent
commit
21702785a8
3 changed files with 21 additions and 12 deletions
  1. inference_images.py (+6 -3)
  2. inference_speed_test.py (+8 -5)
  3. inference_video.py (+7 -4)

+ 6 - 3
inference_images.py

@@ -52,6 +52,7 @@ parser.add_argument('--model-refine-kernel-size', type=int, default=3)
 parser.add_argument('--images-src', type=str, required=True)
 parser.add_argument('--images-bgr', type=str, required=True)
 
+parser.add_argument('--device', type=str, choices=['cpu', 'cuda'], default='cuda')
 parser.add_argument('--preprocess-alignment', action='store_true')
 
 parser.add_argument('--output-dir', type=str, required=True)
@@ -70,6 +71,8 @@ assert 'ref' not in args.output_types or args.model_type in ['mattingrefine'], \
 # --------------- Main ---------------
 
 
+device = torch.device(args.device)
+
 # Load model
 if args.model_type == 'mattingbase':
     model = MattingBase(args.model_backbone)
@@ -82,7 +85,7 @@ if args.model_type == 'mattingrefine':
         args.model_refine_threshold,
         args.model_refine_kernel_size)
 
-model = model.cuda().eval()
+model = model.to(device).eval()
 model.load_state_dict(torch.load(args.model_checkpoint), strict=False)
 
 
@@ -118,8 +121,8 @@ def writer(img, path):
 with torch.no_grad():
     for i, (src, bgr) in enumerate(tqdm(dataloader)):
         filename = dataset.datasets[0].filenames[i]
-        src = src.cuda(non_blocking=True)
-        bgr = bgr.cuda(non_blocking=True)
+        src = src.to(device, non_blocking=True)
+        bgr = bgr.to(device, non_blocking=True)
         
         if args.model_type == 'mattingbase':
             pha, fgr, err, _ = model(src, bgr)

+ 8 - 5
inference_speed_test.py

@@ -62,6 +62,7 @@ parser.add_argument('--batch-size', type=int, default=1)
 parser.add_argument('--resolution', type=int, default=None, nargs=2)
 parser.add_argument('--precision', type=str, default='float32', choices=['float32', 'float16'])
 parser.add_argument('--backend', type=str, default='pytorch', choices=['pytorch', 'torchscript'])
+parser.add_argument('--device', type=str, choices=['cpu', 'cuda'], default='cuda')
 
 parser.add_argument('--image-src', type=str, default=None)
 parser.add_argument('--image-bgr', type=str, default=None)
@@ -76,6 +77,8 @@ assert (not args.image_src) != (not args.resolution), 'Must provide either a res
 # --------------- Run Loop ---------------
 
 
+device = torch.device(args.device)
+
 # Load model
 if args.model_type == 'mattingbase':
     model = MattingBase(args.model_backbone)
@@ -100,15 +103,15 @@ else:
 if args.backend == 'torchscript':
     model = torch.jit.script(model)
 
-model = model.cuda().eval().to(precision)
+model = model.eval().to(device=device, dtype=precision)
 
 # Load data
 if not args.image_src:
-    src = torch.rand((args.batch_size, 3, *args.resolution[::-1]), device='cuda', dtype=precision)
-    bgr = torch.rand((args.batch_size, 3, *args.resolution[::-1]), device='cuda', dtype=precision)
+    src = torch.rand((args.batch_size, 3, *args.resolution[::-1]), device=device, dtype=precision)
+    bgr = torch.rand((args.batch_size, 3, *args.resolution[::-1]), device=device, dtype=precision)
 else:
-    src = to_tensor(Image.open(args.image_src)).unsqueeze(0).repeat(args.batch_size, 1, 1, 1).to(device='cuda', dtype=precision)
-    bgr = to_tensor(Image.open(args.image_bgr)).unsqueeze(0).repeat(args.batch_size, 1, 1, 1).to(device='cuda', dtype=precision)
+    src = to_tensor(Image.open(args.image_src)).unsqueeze(0).repeat(args.batch_size, 1, 1, 1).to(device=device, dtype=precision)
+    bgr = to_tensor(Image.open(args.image_bgr)).unsqueeze(0).repeat(args.batch_size, 1, 1, 1).to(device=device, dtype=precision)
     
 # Loop
 with torch.no_grad():

+ 7 - 4
inference_video.py

@@ -57,6 +57,7 @@ parser.add_argument('--video-src', type=str, required=True)
 parser.add_argument('--video-bgr', type=str, required=True)
 parser.add_argument('--video-resize', type=int, default=None, nargs=2)
 
+parser.add_argument('--device', type=str, choices=['cpu', 'cuda'], default='cuda')
 parser.add_argument('--preprocess-alignment', action='store_true')
 
 parser.add_argument('--output-dir', type=str, required=True)
@@ -109,6 +110,8 @@ class ImageSequenceWriter:
 # --------------- Main ---------------
 
 
+device = torch.device(args.device)
+
 # Load model
 if args.model_type == 'mattingbase':
     model = MattingBase(args.model_backbone)
@@ -121,7 +124,7 @@ if args.model_type == 'mattingrefine':
         args.model_refine_threshold,
         args.model_refine_kernel_size)
 
-model = model.cuda().eval()
+model = model.to(device).eval()
 model.load_state_dict(torch.load(args.model_checkpoint), strict=False)
 
 
@@ -173,8 +176,8 @@ else:
 # Conversion loop
 with torch.no_grad():
     for src, bgr in tqdm(DataLoader(dataset, batch_size=1, pin_memory=True)):
-        src = src.cuda(non_blocking=True)
-        bgr = bgr.cuda(non_blocking=True)
+        src = src.to(device, non_blocking=True)
+        bgr = bgr.to(device, non_blocking=True)
         
         if args.model_type == 'mattingbase':
             pha, fgr, err, _ = model(src, bgr)
@@ -186,7 +189,7 @@ with torch.no_grad():
         if 'com' in args.output_types:
             if args.output_format == 'video':
                 # Output composite with green background
-                bgr_green = torch.tensor([120/255, 255/255, 155/255], device='cuda').view(1, 3, 1, 1)
+                bgr_green = torch.tensor([120/255, 255/255, 155/255], device=device).view(1, 3, 1, 1)
                 com = fgr * pha + bgr_green * (1 - pha)
                 com_writer.add_batch(com)
             else: