Browse Source

add option to composite onto a target video with --video-target-bgr

andreyryabtsev 4 years ago
parent
commit
9c9b9592e8
1 changed files with 13 additions and 4 deletions
  1. 13 4
      inference_video.py

+ 13 - 4
inference_video.py

@@ -14,7 +14,8 @@ Example:
         --video-bgr "PATH_TO_VIDEO_BGR" \
         --video-resize 1920 1080 \
         --output-dir "PATH_TO_OUTPUT_DIR" \
-        --output-type com fgr pha err ref
+        --output-type com fgr pha err ref \
+        --video-target-bgr "PATH_TO_VIDEO_TARGET_BGR"
 
 """
 
@@ -55,6 +56,7 @@ parser.add_argument('--model-refine-kernel-size', type=int, default=3)
 
 parser.add_argument('--video-src', type=str, required=True)
 parser.add_argument('--video-bgr', type=str, required=True)
+parser.add_argument('--video-target-bgr', type=str, default=None, help="Path to video onto which to composite the output (default to flat green)")
 parser.add_argument('--video-resize', type=int, default=None, nargs=2)
 
 parser.add_argument('--device', type=str, choices=['cpu', 'cuda'], default='cuda')
@@ -136,6 +138,8 @@ dataset = ZipDataset([vid, bgr], transforms=A.PairCompose([
     HomographicAlignment() if args.preprocess_alignment else A.PairApply(nn.Identity()),
     A.PairApply(T.ToTensor())
 ]))
+if args.video_target_bgr:
+    dataset = ZipDataset([dataset, VideoDataset(args.video_target_bgr, transforms=T.ToTensor())])
 
 # Create output directory
 if os.path.exists(args.output_dir):
@@ -175,7 +179,13 @@ else:
 
 # Conversion loop
 with torch.no_grad():
-    for src, bgr in tqdm(DataLoader(dataset, batch_size=1, pin_memory=True)):
+    for input_batch in tqdm(DataLoader(dataset, batch_size=1, pin_memory=True)):
+        if args.video_target_bgr:
+            (src, bgr), tgt_bgr = input_batch
+            tgt_bgr = tgt_bgr.to(device, non_blocking=True)
+        else:
+            src, bgr = input_batch
+            tgt_bgr = torch.tensor([120/255, 255/255, 155/255], device=device).view(1, 3, 1, 1)
         src = src.to(device, non_blocking=True)
         bgr = bgr.to(device, non_blocking=True)
         
@@ -189,8 +199,7 @@ with torch.no_grad():
         if 'com' in args.output_types:
             if args.output_format == 'video':
                 # Output composite with green background
-                bgr_green = torch.tensor([120/255, 255/255, 155/255], device=device).view(1, 3, 1, 1)
-                com = fgr * pha + bgr_green * (1 - pha)
+                com = fgr * pha + tgt_bgr * (1 - pha)
                 com_writer.add_batch(com)
             else:
                 # Output composite as rgba png images