inference_video.py 7.8 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203
"""
Inference video: Extract matting on video.

Example:

    python inference_video.py \
        --model-type mattingrefine \
        --model-backbone resnet50 \
        --model-backbone-scale 0.25 \
        --model-refine-mode sampling \
        --model-refine-sample-pixels 80000 \
        --model-checkpoint "PATH_TO_CHECKPOINT" \
        --video-src "PATH_TO_VIDEO_SRC" \
        --video-bgr "PATH_TO_VIDEO_BGR" \
        --video-resize 1920 1080 \
        --output-dir "PATH_TO_OUTPUT_DIR" \
        --output-types com fgr pha err ref

"""
  17. import argparse
  18. import cv2
  19. import torch
  20. import os
  21. import shutil
  22. from torch import nn
  23. from torch.nn import functional as F
  24. from torch.utils.data import DataLoader
  25. from torchvision import transforms as T
  26. from torchvision.transforms.functional import to_pil_image
  27. from threading import Thread
  28. from tqdm import tqdm
  29. from PIL import Image
  30. from dataset import VideoDataset, ZipDataset
  31. from dataset import augmentation as A
  32. from model import MattingBase, MattingRefine
  33. from inference_utils import HomographicAlignment
  34. # --------------- Arguments ---------------
  35. parser = argparse.ArgumentParser(description='Inference video')
  36. parser.add_argument('--model-type', type=str, required=True, choices=['mattingbase', 'mattingrefine'])
  37. parser.add_argument('--model-backbone', type=str, required=True, choices=['resnet101', 'resnet50', 'mobilenetv2'])
  38. parser.add_argument('--model-backbone-scale', type=float, default=0.25)
  39. parser.add_argument('--model-checkpoint', type=str, required=True)
  40. parser.add_argument('--model-refine-mode', type=str, default='sampling', choices=['full', 'sampling', 'thresholding'])
  41. parser.add_argument('--model-refine-sample-pixels', type=int, default=80_000)
  42. parser.add_argument('--model-refine-threshold', type=float, default=0.7)
  43. parser.add_argument('--model-refine-kernel-size', type=int, default=3)
  44. parser.add_argument('--video-src', type=str, required=True)
  45. parser.add_argument('--video-bgr', type=str, required=True)
  46. parser.add_argument('--video-resize', type=int, default=None, nargs=2)
  47. parser.add_argument('--preprocess-alignment', action='store_true')
  48. parser.add_argument('--output-dir', type=str, required=True)
  49. parser.add_argument('--output-types', type=str, required=True, nargs='+', choices=['com', 'pha', 'fgr', 'err', 'ref'])
  50. parser.add_argument('--output-format', type=str, default='video', choices=['video', 'image_sequences'])
  51. args = parser.parse_args()
  52. assert 'err' not in args.output_types or args.model_type in ['mattingbase', 'mattingrefine'], \
  53. 'Only mattingbase and mattingrefine support err output'
  54. assert 'ref' not in args.output_types or args.model_type in ['mattingrefine'], \
  55. 'Only mattingrefine support ref output'
  56. # --------------- Utils ---------------
  57. class VideoWriter:
  58. def __init__(self, path, frame_rate, width, height):
  59. self.out = cv2.VideoWriter(path, cv2.VideoWriter_fourcc(*'mp4v'), frame_rate, (width, height))
  60. def add_batch(self, frames):
  61. frames = frames.mul(255).byte()
  62. frames = frames.cpu().permute(0, 2, 3, 1).numpy()
  63. for i in range(frames.shape[0]):
  64. frame = frames[i]
  65. frame = cv2.cvtColor(frame, cv2.COLOR_RGB2BGR)
  66. self.out.write(frame)
  67. class ImageSequenceWriter:
  68. def __init__(self, path, extension):
  69. self.path = path
  70. self.extension = extension
  71. self.index = 0
  72. os.makedirs(path)
  73. def add_batch(self, frames):
  74. Thread(target=self._add_batch, args=(frames, self.index)).start()
  75. self.index += frames.shape[0]
  76. def _add_batch(self, frames, index):
  77. frames = frames.cpu()
  78. for i in range(frames.shape[0]):
  79. frame = frames[i]
  80. frame = to_pil_image(frame)
  81. frame.save(os.path.join(self.path, str(index + i).zfill(5) + '.' + self.extension))
  82. # --------------- Main ---------------
  83. # Load model
  84. if args.model_type == 'mattingbase':
  85. model = MattingBase(args.model_backbone)
  86. if args.model_type == 'mattingrefine':
  87. model = MattingRefine(
  88. args.model_backbone,
  89. args.model_backbone_scale,
  90. args.model_refine_mode,
  91. args.model_refine_sample_pixels,
  92. args.model_refine_threshold,
  93. args.model_refine_kernel_size)
  94. model = model.cuda().eval()
  95. model.load_state_dict(torch.load(args.model_checkpoint), strict=False)
  96. # Load video and background
  97. vid = VideoDataset(args.video_src)
  98. bgr = [Image.open(args.video_bgr).convert('RGB')]
  99. dataset = ZipDataset([vid, bgr], transforms=A.PairCompose([
  100. A.PairApply(T.Resize(args.video_resize[::-1]) if args.video_resize else nn.Identity()),
  101. HomographicAlignment() if args.preprocess_alignment else A.PairApply(nn.Identity()),
  102. A.PairApply(T.ToTensor())
  103. ]))
  104. # Create output directory
  105. if os.path.exists(args.output_dir):
  106. if input(f'Directory {args.output_dir} already exists. Override? [Y/N]: ').lower() == 'y':
  107. shutil.rmtree(args.output_dir)
  108. else:
  109. exit()
  110. os.makedirs(args.output_dir)
  111. # Prepare writers
  112. if args.output_format == 'video':
  113. h = args.video_resize[1] if args.video_resize is not None else vid.height
  114. w = args.video_resize[0] if args.video_resize is not None else vid.width
  115. if 'com' in args.output_types:
  116. com_writer = VideoWriter(os.path.join(args.output_dir, 'com.mp4'), vid.frame_rate, w, h)
  117. if 'pha' in args.output_types:
  118. pha_writer = VideoWriter(os.path.join(args.output_dir, 'pha.mp4'), vid.frame_rate, w, h)
  119. if 'fgr' in args.output_types:
  120. fgr_writer = VideoWriter(os.path.join(args.output_dir, 'fgr.mp4'), vid.frame_rate, w, h)
  121. if 'err' in args.output_types:
  122. err_writer = VideoWriter(os.path.join(args.output_dir, 'err.mp4'), vid.frame_rate, w, h)
  123. if 'ref' in args.output_types:
  124. ref_writer = VideoWriter(os.path.join(args.output_dir, 'ref.mp4'), vid.frame_rate, w, h)
  125. else:
  126. if 'com' in args.output_types:
  127. com_writer = ImageSequenceWriter(os.path.join(args.output_dir, 'com'), 'png')
  128. if 'pha' in args.output_types:
  129. pha_writer = ImageSequenceWriter(os.path.join(args.output_dir, 'pha'), 'jpg')
  130. if 'fgr' in args.output_types:
  131. fgr_writer = ImageSequenceWriter(os.path.join(args.output_dir, 'fgr'), 'jpg')
  132. if 'err' in args.output_types:
  133. err_writer = ImageSequenceWriter(os.path.join(args.output_dir, 'err'), 'jpg')
  134. if 'ref' in args.output_types:
  135. ref_writer = ImageSequenceWriter(os.path.join(args.output_dir, 'ref'), 'jpg')
  136. # Conversion loop
  137. with torch.no_grad():
  138. for src, bgr in tqdm(DataLoader(dataset, batch_size=1, pin_memory=True)):
  139. src = src.cuda(non_blocking=True)
  140. bgr = bgr.cuda(non_blocking=True)
  141. if args.model_type == 'mattingbase':
  142. pha, fgr, err, _ = model(src, bgr)
  143. elif args.model_type == 'mattingrefine':
  144. pha, fgr, _, _, err, ref = model(src, bgr)
  145. elif args.model_type == 'mattingbm':
  146. pha, fgr = model(src, bgr)
  147. if 'com' in args.output_types:
  148. if args.output_format == 'video':
  149. # Output composite with green background
  150. bgr_green = torch.tensor([120/255, 255/255, 155/255], device='cuda').view(1, 3, 1, 1)
  151. com = fgr * pha + bgr_green * (1 - pha)
  152. com_writer.add_batch(com)
  153. else:
  154. # Output composite as rgba png images
  155. com = torch.cat([fgr * pha.ne(0), pha], dim=1)
  156. com_writer.add_batch(com)
  157. if 'pha' in args.output_types:
  158. pha_writer.add_batch(pha)
  159. if 'fgr' in args.output_types:
  160. fgr_writer.add_batch(fgr)
  161. if 'err' in args.output_types:
  162. err_writer.add_batch(F.interpolate(err, src.shape[2:], mode='bilinear', align_corners=False))
  163. if 'ref' in args.output_types:
  164. ref_writer.add_batch(F.interpolate(ref, src.shape[2:], mode='nearest'))