|
@@ -125,7 +125,7 @@ if args.model_type == 'mattingrefine':
|
|
|
args.model_refine_kernel_size)
|
|
|
|
|
|
model = model.to(device).eval()
|
|
|
-model.load_state_dict(torch.load(args.model_checkpoint), strict=False)
|
|
|
+model.load_state_dict(torch.load(args.model_checkpoint, map_location=device), strict=False)
|
|
|
|
|
|
|
|
|
# Load video and background
|
|
@@ -203,4 +203,4 @@ with torch.no_grad():
|
|
|
if 'err' in args.output_types:
|
|
|
err_writer.add_batch(F.interpolate(err, src.shape[2:], mode='bilinear', align_corners=False))
|
|
|
if 'ref' in args.output_types:
|
|
|
- ref_writer.add_batch(F.interpolate(ref, src.shape[2:], mode='nearest'))
|
|
|
+ ref_writer.add_batch(F.interpolate(ref, src.shape[2:], mode='nearest'))
|