Shi Jin 4 жил өмнө
parent
commit
4a2225e561
1 өөрчлөгдсөн 2 нэмэгдсэн , 2 устгасан
  1. 2 2
      inference_video.py

+ 2 - 2
inference_video.py

@@ -125,7 +125,7 @@ if args.model_type == 'mattingrefine':
         args.model_refine_kernel_size)
 
 model = model.to(device).eval()
-model.load_state_dict(torch.load(args.model_checkpoint), strict=False)
+model.load_state_dict(torch.load(args.model_checkpoint, map_location=device), strict=False)
 
 
 # Load video and background
@@ -203,4 +203,4 @@ with torch.no_grad():
         if 'err' in args.output_types:
             err_writer.add_batch(F.interpolate(err, src.shape[2:], mode='bilinear', align_corners=False))
         if 'ref' in args.output_types:
-            ref_writer.add_batch(F.interpolate(ref, src.shape[2:], mode='nearest'))
+            ref_writer.add_batch(F.interpolate(ref, src.shape[2:], mode='nearest'))