inference_video.py

  1. """
  2. Inference video: Extract matting on video.
  3. Example:
  4. python inference_video.py \
  5. --model-type mattingrefine \
  6. --model-backbone resnet50 \
  7. --model-backbone-scale 0.25 \
  8. --model-refine-mode sampling \
  9. --model-refine-sample-pixels 80000 \
  10. --model-checkpoint "PATH_TO_CHECKPOINT" \
  11. --video-src "PATH_TO_VIDEO_SRC" \
  12. --video-bgr "PATH_TO_VIDEO_BGR" \
  13. --video-resize 1920 1080 \
  14. --output-dir "PATH_TO_OUTPUT_DIR" \
  15. --output-type com fgr pha err ref \
  16. --video-target-bgr "PATH_TO_VIDEO_TARGET_BGR"
  17. """

import argparse
import cv2
import torch
import os
import shutil

from torch import nn
from torch.nn import functional as F
from torch.utils.data import DataLoader
from torchvision import transforms as T
from torchvision.transforms.functional import to_pil_image
from threading import Thread
from tqdm import tqdm
from PIL import Image

from dataset import VideoDataset, ZipDataset
from dataset import augmentation as A
from model import MattingBase, MattingRefine
from inference_utils import HomographicAlignment
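
# Note: dataset, model, and inference_utils are local modules shipped with the
# surrounding repository; the third-party dependencies are torch/torchvision,
# OpenCV (cv2), tqdm, and Pillow.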

# --------------- Arguments ---------------


parser = argparse.ArgumentParser(description='Inference video')

parser.add_argument('--model-type', type=str, required=True, choices=['mattingbase', 'mattingrefine'])
parser.add_argument('--model-backbone', type=str, required=True, choices=['resnet101', 'resnet50', 'mobilenetv2'])
parser.add_argument('--model-backbone-scale', type=float, default=0.25)
parser.add_argument('--model-checkpoint', type=str, required=True)
parser.add_argument('--model-refine-mode', type=str, default='sampling', choices=['full', 'sampling', 'thresholding'])
parser.add_argument('--model-refine-sample-pixels', type=int, default=80_000)
parser.add_argument('--model-refine-threshold', type=float, default=0.7)
parser.add_argument('--model-refine-kernel-size', type=int, default=3)

parser.add_argument('--video-src', type=str, required=True)
parser.add_argument('--video-bgr', type=str, required=True)
parser.add_argument('--video-target-bgr', type=str, default=None, help="Path to video onto which to composite the output (defaults to flat green)")
parser.add_argument('--video-resize', type=int, default=None, nargs=2)

parser.add_argument('--device', type=str, choices=['cpu', 'cuda'], default='cuda')
parser.add_argument('--preprocess-alignment', action='store_true')

parser.add_argument('--output-dir', type=str, required=True)
parser.add_argument('--output-types', type=str, required=True, nargs='+', choices=['com', 'pha', 'fgr', 'err', 'ref'])
parser.add_argument('--output-format', type=str, default='video', choices=['video', 'image_sequences'])

args = parser.parse_args()

assert 'err' not in args.output_types or args.model_type in ['mattingbase', 'mattingrefine'], \
    'Only mattingbase and mattingrefine support err output'
assert 'ref' not in args.output_types or args.model_type in ['mattingrefine'], \
    'Only mattingrefine supports ref output'


# --------------- Utils ---------------

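# Both writers consume batches of RGB tensors shaped (N, C, H, W) with values
# in [0, 1]: VideoWriter streams them into an mp4v-encoded .mp4 through OpenCV,
# while ImageSequenceWriter saves numbered image files from a background thread.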

class VideoWriter:
    def __init__(self, path, frame_rate, width, height):
        self.out = cv2.VideoWriter(path, cv2.VideoWriter_fourcc(*'mp4v'), frame_rate, (width, height))

    def add_batch(self, frames):
        frames = frames.mul(255).byte()
        frames = frames.cpu().permute(0, 2, 3, 1).numpy()
        for i in range(frames.shape[0]):
            frame = frames[i]
            frame = cv2.cvtColor(frame, cv2.COLOR_RGB2BGR)
            self.out.write(frame)


class ImageSequenceWriter:
    def __init__(self, path, extension):
        self.path = path
        self.extension = extension
        self.index = 0
        os.makedirs(path)

    def add_batch(self, frames):
        Thread(target=self._add_batch, args=(frames, self.index)).start()
        self.index += frames.shape[0]

    def _add_batch(self, frames, index):
        frames = frames.cpu()
        for i in range(frames.shape[0]):
            frame = frames[i]
            frame = to_pil_image(frame)
            frame.save(os.path.join(self.path, str(index + i).zfill(5) + '.' + self.extension))
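
# Neither writer exposes an explicit close(): cv2.VideoWriter finalizes its
# file when the object is garbage-collected, and the writer threads above are
# non-daemon, so the interpreter waits for pending image writes before exiting.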


# --------------- Main ---------------


device = torch.device(args.device)

# Load model
if args.model_type == 'mattingbase':
    model = MattingBase(args.model_backbone)
elif args.model_type == 'mattingrefine':
    model = MattingRefine(
        args.model_backbone,
        args.model_backbone_scale,
        args.model_refine_mode,
        args.model_refine_sample_pixels,
        args.model_refine_threshold,
        args.model_refine_kernel_size)

model = model.to(device).eval()
model.load_state_dict(torch.load(args.model_checkpoint, map_location=device), strict=False)
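# strict=False tolerates key mismatches between the checkpoint and the model
# (e.g. a checkpoint saved without refiner weights); note that it also silently
# skips any weights that are genuinely absent from the checkpoint.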

# Load video and background
vid = VideoDataset(args.video_src)
bgr = [Image.open(args.video_bgr).convert('RGB')]
dataset = ZipDataset([vid, bgr], transforms=A.PairCompose([
    A.PairApply(T.Resize(args.video_resize[::-1]) if args.video_resize else nn.Identity()),
    HomographicAlignment() if args.preprocess_alignment else A.PairApply(nn.Identity()),
    A.PairApply(T.ToTensor())
]))
if args.video_target_bgr:
    dataset = ZipDataset([dataset, VideoDataset(args.video_target_bgr, transforms=T.ToTensor())])
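
# The background is a single still image wrapped in a one-element list;
# ZipDataset presumably reuses that lone element for every source frame rather
# than requiring equal lengths. The optional target-background video is zipped
# in the same way.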

# Create output directory
if os.path.exists(args.output_dir):
    if input(f'Directory {args.output_dir} already exists. Overwrite? [Y/N]: ').lower() == 'y':
        shutil.rmtree(args.output_dir)
    else:
        exit()
os.makedirs(args.output_dir)

# Prepare writers
if args.output_format == 'video':
    h = args.video_resize[1] if args.video_resize is not None else vid.height
    w = args.video_resize[0] if args.video_resize is not None else vid.width
    if 'com' in args.output_types:
        com_writer = VideoWriter(os.path.join(args.output_dir, 'com.mp4'), vid.frame_rate, w, h)
    if 'pha' in args.output_types:
        pha_writer = VideoWriter(os.path.join(args.output_dir, 'pha.mp4'), vid.frame_rate, w, h)
    if 'fgr' in args.output_types:
        fgr_writer = VideoWriter(os.path.join(args.output_dir, 'fgr.mp4'), vid.frame_rate, w, h)
    if 'err' in args.output_types:
        err_writer = VideoWriter(os.path.join(args.output_dir, 'err.mp4'), vid.frame_rate, w, h)
    if 'ref' in args.output_types:
        ref_writer = VideoWriter(os.path.join(args.output_dir, 'ref.mp4'), vid.frame_rate, w, h)
else:
    if 'com' in args.output_types:
        com_writer = ImageSequenceWriter(os.path.join(args.output_dir, 'com'), 'png')
    if 'pha' in args.output_types:
        pha_writer = ImageSequenceWriter(os.path.join(args.output_dir, 'pha'), 'jpg')
    if 'fgr' in args.output_types:
        fgr_writer = ImageSequenceWriter(os.path.join(args.output_dir, 'fgr'), 'jpg')
    if 'err' in args.output_types:
        err_writer = ImageSequenceWriter(os.path.join(args.output_dir, 'err'), 'jpg')
    if 'ref' in args.output_types:
        ref_writer = ImageSequenceWriter(os.path.join(args.output_dir, 'ref'), 'jpg')
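
# Output types: com = composite over the target background, pha = alpha matte,
# fgr = predicted foreground, err = the model's error map, ref = refinement
# region map (ref is only produced by the refinement model, per the asserts above).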

# Conversion loop
with torch.no_grad():
    for input_batch in tqdm(DataLoader(dataset, batch_size=1, pin_memory=True)):
        if args.video_target_bgr:
            (src, bgr), tgt_bgr = input_batch
            tgt_bgr = tgt_bgr.to(device, non_blocking=True)
        else:
            src, bgr = input_batch
            tgt_bgr = torch.tensor([120/255, 255/255, 155/255], device=device).view(1, 3, 1, 1)
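            # The fallback above is the flat green promised in the
            # --video-target-bgr help text: RGB (120, 255, 155) broadcast over
            # the whole frame.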
        src = src.to(device, non_blocking=True)
        bgr = bgr.to(device, non_blocking=True)

        if args.model_type == 'mattingbase':
            pha, fgr, err, _ = model(src, bgr)
        elif args.model_type == 'mattingrefine':
            pha, fgr, _, _, err, ref = model(src, bgr)
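
        # pha/fgr are written at full resolution; err and ref are upsampled to
        # the source size below, so they evidently come out of the model at a
        # coarser scale.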
        if 'com' in args.output_types:
            if args.output_format == 'video':
                # Composite onto the target background (flat green by default)
                com = fgr * pha + tgt_bgr * (1 - pha)
                com_writer.add_batch(com)
            else:
                # Output the composite as RGBA PNG images
                com = torch.cat([fgr * pha.ne(0), pha], dim=1)
                com_writer.add_batch(com)
        if 'pha' in args.output_types:
            pha_writer.add_batch(pha)
        if 'fgr' in args.output_types:
            fgr_writer.add_batch(fgr)
        if 'err' in args.output_types:
            err_writer.add_batch(F.interpolate(err, src.shape[2:], mode='bilinear', align_corners=False))
        if 'ref' in args.output_types:
            ref_writer.add_batch(F.interpolate(ref, src.shape[2:], mode='nearest'))