|
@@ -89,7 +89,7 @@ if args.model_type == 'mattingrefine':
|
|
|
args.model_refine_kernel_size)
|
|
|
|
|
|
model = model.to(device).eval()
|
|
|
-model.load_state_dict(torch.load(args.model_checkpoint), strict=False)
|
|
|
+model.load_state_dict(torch.load(args.model_checkpoint, map_location=device), strict=False)
|
|
|
|
|
|
|
|
|
# Load images
|