|
@@ -57,6 +57,7 @@ parser.add_argument('--video-src', type=str, required=True)
|
|
parser.add_argument('--video-bgr', type=str, required=True)
|
|
parser.add_argument('--video-bgr', type=str, required=True)
|
|
parser.add_argument('--video-resize', type=int, default=None, nargs=2)
|
|
parser.add_argument('--video-resize', type=int, default=None, nargs=2)
|
|
|
|
|
|
|
|
+parser.add_argument('--device', type=str, choices=['cpu', 'cuda'], default='cuda')
|
|
parser.add_argument('--preprocess-alignment', action='store_true')
|
|
parser.add_argument('--preprocess-alignment', action='store_true')
|
|
|
|
|
|
parser.add_argument('--output-dir', type=str, required=True)
|
|
parser.add_argument('--output-dir', type=str, required=True)
|
|
@@ -109,6 +110,8 @@ class ImageSequenceWriter:
|
|
# --------------- Main ---------------
|
|
# --------------- Main ---------------
|
|
|
|
|
|
|
|
|
|
|
|
+device = torch.device(args.device)
|
|
|
|
+
|
|
# Load model
|
|
# Load model
|
|
if args.model_type == 'mattingbase':
|
|
if args.model_type == 'mattingbase':
|
|
model = MattingBase(args.model_backbone)
|
|
model = MattingBase(args.model_backbone)
|
|
@@ -121,7 +124,7 @@ if args.model_type == 'mattingrefine':
|
|
args.model_refine_threshold,
|
|
args.model_refine_threshold,
|
|
args.model_refine_kernel_size)
|
|
args.model_refine_kernel_size)
|
|
|
|
|
|
-model = model.cuda().eval()
|
|
|
|
|
|
+model = model.to(device).eval()
|
|
model.load_state_dict(torch.load(args.model_checkpoint), strict=False)
|
|
model.load_state_dict(torch.load(args.model_checkpoint), strict=False)
|
|
|
|
|
|
|
|
|
|
@@ -173,8 +176,8 @@ else:
|
|
# Conversion loop
|
|
# Conversion loop
|
|
with torch.no_grad():
|
|
with torch.no_grad():
|
|
for src, bgr in tqdm(DataLoader(dataset, batch_size=1, pin_memory=True)):
|
|
for src, bgr in tqdm(DataLoader(dataset, batch_size=1, pin_memory=True)):
|
|
- src = src.cuda(non_blocking=True)
|
|
|
|
- bgr = bgr.cuda(non_blocking=True)
|
|
|
|
|
|
+ src = src.to(device, non_blocking=True)
|
|
|
|
+ bgr = bgr.to(device, non_blocking=True)
|
|
|
|
|
|
if args.model_type == 'mattingbase':
|
|
if args.model_type == 'mattingbase':
|
|
pha, fgr, err, _ = model(src, bgr)
|
|
pha, fgr, err, _ = model(src, bgr)
|
|
@@ -186,7 +189,7 @@ with torch.no_grad():
|
|
if 'com' in args.output_types:
|
|
if 'com' in args.output_types:
|
|
if args.output_format == 'video':
|
|
if args.output_format == 'video':
|
|
# Output composite with green background
|
|
# Output composite with green background
|
|
- bgr_green = torch.tensor([120/255, 255/255, 155/255], device='cuda').view(1, 3, 1, 1)
|
|
|
|
|
|
+ bgr_green = torch.tensor([120/255, 255/255, 155/255], device=device).view(1, 3, 1, 1)
|
|
com = fgr * pha + bgr_green * (1 - pha)
|
|
com = fgr * pha + bgr_green * (1 - pha)
|
|
com_writer.add_batch(com)
|
|
com_writer.add_batch(com)
|
|
else:
|
|
else:
|