""" Inference on webcams: Use a model on webcam input. Once launched, the script is in background collection mode. Press B to toggle between background capture mode and matting mode. The frame shown when B is pressed is used as background for matting. Press Q to exit. Example: python inference_webcam.py \ --model-type mattingrefine \ --model-backbone resnet50 \ --model-checkpoint "PATH_TO_CHECKPOINT" \ --resolution 1280 720 """ import argparse, os, shutil, time import cv2 import numpy as np import torch from PyQt5 import QtGui, QtCore, uic from PyQt5 import QtWidgets from PyQt5.QtWidgets import QMainWindow, QApplication from torch import nn from torch.utils.data import DataLoader from torchvision.transforms import Compose, ToTensor, Resize from torchvision.transforms.functional import to_pil_image from threading import Thread, Lock, Condition from tqdm import tqdm from PIL import Image, ImageTk from dataset import VideoDataset from model import MattingBase, MattingRefine from inference_webcam import Displayer, FPSTracker, Camera, cv2_frame_to_cuda # --------------- Arguments --------------- parser = argparse.ArgumentParser(description='Inference from web-cam') parser.add_argument('--model-type', type=str, required=True, choices=['mattingbase', 'mattingrefine']) parser.add_argument('--model-backbone', type=str, required=True, choices=['resnet101', 'resnet50', 'mobilenetv2']) parser.add_argument('--model-backbone-scale', type=float, default=0.25) parser.add_argument('--model-checkpoint', type=str, required=True) parser.add_argument('--model-refine-mode', type=str, default='sampling', choices=['full', 'sampling', 'thresholding']) parser.add_argument('--model-refine-sample-pixels', type=int, default=80_000) parser.add_argument('--model-refine-threshold', type=float, default=0.7) parser.add_argument('--hide-fps', action='store_true') parser.add_argument('--resolution', type=int, nargs=2, metavar=('width', 'height'), default=(1280, 720)) parser.add_argument('--device-id', type=int, default=0) parser.add_argument('--background-image', type=str, default="") parser.add_argument('--fps-limit', type=int, default=1000) args = parser.parse_args() # --------------- Main --------------- default_float_dtype = torch.get_default_dtype() model = torch.jit.load(args.model_checkpoint) model.backbone_scale = args.model_backbone_scale model.refine_mode = args.model_refine_mode model.refine_sample_pixels = args.model_refine_sample_pixels model.refine_threshold = args.model_refine_threshold datatype = torch.float16 if 'fp16' in args.model_checkpoint else torch.float32 model = model.to(torch.device('cuda')) width, height = args.resolution cam = Camera(device_id=args.device_id,width=width, height=height) app = QApplication(['MattingV2']) dsp = Displayer('MattingV2', cam.width, cam.height, show_info=(not args.hide_fps)) dsp.show() with torch.no_grad(): while True: bgr = None while True: # grab bgr frame = cam.read() frameRGB = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB) key = dsp.step(frameRGB) if key == 'b': bgr = cv2_frame_to_cuda(frame, datatype) break elif key == 'w': cam.brighter() elif key == 's': cam.darker() elif key == 'q': exit() #prevent Division by Zero error time.sleep(1/60) if "RGB" in args.background_image: bgRGB = args.background_image.replace("RGB","").strip().split(":") bgImage = torch.zeros_like(bgr) try: red,green,blue = [int(x)/255 for x in bgRGB] bgImage[0,0] = (torch.ones_like(bgr[0,0])) * red bgImage[0,1] = (torch.ones_like(bgr[0,0])) * green bgImage[0,2] = (torch.ones_like(bgr[0,0])) * blue except ValueError: print(args.background_image,"not matching condition, Use default greenscreen") bgImage[0,1] = (torch.ones_like(bgr[0,0])) else: bgImage = cv2.imread(args.background_image, cv2.IMREAD_UNCHANGED) bgImage = cv2.resize(bgImage, (frame.shape[1], frame.shape[0])) bgImage = cv2_frame_to_cuda(bgImage, datatype) while True: # matting frame = cam.read() src = cv2_frame_to_cuda(frame, datatype) pha, fgr = model(src, bgr)[:2] res = pha * fgr + (1 - pha) * bgImage res = res.mul(255).byte().cpu().permute(0, 2, 3, 1).numpy()[0] key = dsp.step(res.copy()) if key == 'b': break elif key == 'w': cam.brighter() elif key == 's': cam.darker() elif key == 'q': exit()