Browse source code

added webcam inference with compositing

subDesTagesMitExtraKaese 3 years ago
parent
commit
77f1bc25c2
2 changed files with 260 additions and 0 deletions
  1. inference_webcam_ts_compositing.py  +246 −0
  2. runCompositing.sh  +14 −0

+ 246 - 0
inference_webcam_ts_compositing.py

@@ -0,0 +1,246 @@
+"""
+Inference on webcams: Use a model on webcam input.
+
+Once launched, the script is in background collection mode.
+Press B to toggle between background capture mode and matting mode. The frame shown when B is pressed is used as background for matting.
+Press Q to exit.
+
+Example:
+
+    python inference_webcam.py \
+        --model-type mattingrefine \
+        --model-backbone resnet50 \
+        --model-checkpoint "PATH_TO_CHECKPOINT" \
+        --resolution 1280 720
+
+"""
+
+import argparse, time
+import cv2
+import torch
+
+from PyQt5 import QtGui, QtCore
+from PyQt5 import QtWidgets
+from PyQt5.QtWidgets import QMainWindow, QApplication
+
+from torchvision.transforms import ToTensor
+from threading import Thread, Condition
+from PIL import Image
+
+
+# --------------- Arguments ---------------
+
+
+parser = argparse.ArgumentParser(description='Inference from webcam')
+
+parser.add_argument('--model-type', type=str, required=True, choices=['mattingbase', 'mattingrefine'])
+parser.add_argument('--model-backbone', type=str, required=True, choices=['resnet101', 'resnet50', 'mobilenetv2'])
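+# Note: with a TorchScript checkpoint the architecture is baked into the file,
+# so --model-type and --model-backbone are parsed here but not used below.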
+parser.add_argument('--model-backbone-scale', type=float, default=0.25)
+parser.add_argument('--model-checkpoint', type=str, required=True)
+parser.add_argument('--model-refine-mode', type=str, default='sampling', choices=['full', 'sampling', 'thresholding'])
+parser.add_argument('--model-refine-sample-pixels', type=int, default=80_000)
+parser.add_argument('--model-refine-threshold', type=float, default=0.7)
+
+parser.add_argument('--hide-fps', action='store_true')
+parser.add_argument('--resolution', type=int, nargs=2, metavar=('width', 'height'), default=(1280, 720))
+
+parser.add_argument('--device-id', type=int, default=0)
+parser.add_argument('--background-image', type=str, default="")
+
+args = parser.parse_args()
+
+
+# ----------- Utility classes -------------
+
+
+# A wrapper that reads frames from cv2.VideoCapture in its own background thread,
+# so that consumers always get the most recent frame. Call .read() in a tight loop.
+class Camera:
+    def __init__(self, device_id=0, width=1280, height=720):
+        self.capture = cv2.VideoCapture(device_id)
+        self.capture.set(cv2.CAP_PROP_FRAME_WIDTH, width)
+        self.capture.set(cv2.CAP_PROP_FRAME_HEIGHT, height)
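+        # The driver may ignore the requested size, so read back the actual resolution.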
+        self.width = int(self.capture.get(cv2.CAP_PROP_FRAME_WIDTH))
+        self.height = int(self.capture.get(cv2.CAP_PROP_FRAME_HEIGHT))
+        # self.capture.set(cv2.CAP_PROP_BUFFERSIZE, 2)
+        # Disable backlight compensation and lock the current exposure, so the
+        # captured background stays consistent with later frames.
+        self.exposure = self.capture.get(cv2.CAP_PROP_EXPOSURE)
+        self.capture.set(cv2.CAP_PROP_BACKLIGHT, 0)
+        self.capture.set(cv2.CAP_PROP_EXPOSURE, self.exposure)
+        self.frameAvailable = False
+        self.success_reading, self.frame = self.capture.read()
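+        # Condition variable hands the newest frame from the capture thread to read().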
+        self.cv = Condition()
+        self.thread = Thread(target=self.__update, args=())
+        self.thread.daemon = True
+        self.thread.start()
+
+    def __update(self):
+        # Capture thread: keep grabbing frames and publish the newest one
+        # under the condition variable so read() can pick it up.
+        while self.success_reading:
+            grabbed, frame = self.capture.read()
+            with self.cv:
+                self.success_reading = grabbed
+                self.frame = frame
+                self.frameAvailable = True
+                self.cv.notify()
+
+    def brighter(self):
+        if self.exposure < -2:
+            self.exposure += 1
+            self.capture.set(cv2.CAP_PROP_EXPOSURE, self.exposure)
+            print(self.exposure)
+
+    def darker(self):
+        if self.exposure > -12:
+            self.exposure -= 1
+            self.capture.set(cv2.CAP_PROP_EXPOSURE, self.exposure)
+            print(self.exposure)
+
+    def read(self):
+        # Block until the capture thread publishes a fresh frame.
+        with self.cv:
+            self.cv.wait_for(lambda: self.frameAvailable)
+            frame = self.frame.copy()
+            self.frameAvailable = False
+        return frame
+
+    def __enter__(self):
+        return self
+
+    def __exit__(self, exc_type, exc_value, traceback):
+        self.capture.release()
+
+# An FPS tracker that computes an exponentially-weighted moving average of the FPS.
+class FPSTracker:
+    def __init__(self, ratio=0.5):
+        self._last_tick = None
+        self._avg_fps = None
+        self.ratio = ratio
+    def tick(self):
+        if self._last_tick is None:
+            self._last_tick = time.time()
+            return None
+        t_new = time.time()
+        fps_sample = 1.0 / (t_new - self._last_tick)
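+        # Exponentially-weighted moving average of the instantaneous FPS.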
+        self._avg_fps = self.ratio * fps_sample + (1 - self.ratio) * self._avg_fps if self._avg_fps is not None else fps_sample
+        self._last_tick = t_new
+        return self.get()
+    def get(self):
+        return self._avg_fps
+
+# Qt-based display window. It accepts an image per call to step() and returns
+# keypress info for basic interactivity.
+# It also tracks FPS and optionally overlays it onto the stream.
+class Displayer(QMainWindow):
+    def __init__(self, title, width, height, show_info=True):
+        QMainWindow.__init__(self)
+        self.width, self.height = width, height
+        self.show_info = show_info
+        self.fps_tracker = FPSTracker()
+
+        self.setWindowTitle(title)
+        self.setFixedSize(width, height)
+        self.setAttribute(QtCore.Qt.WA_TranslucentBackground, True)
+
+        self.image_label = QtWidgets.QLabel(self)
+        self.image_label.resize(width, height)
+
+        self.key = None
+
+    def keyPressEvent(self, event):
+        self.key = event.text()
+
+    def closeEvent(self, event):
+        # Treat closing the window like pressing Q.
+        self.key = 'q'
+
+    # Update the displayed frame; returns the key pressed since the last call, or None.
+    def step(self, image):
+        fps_estimate = self.fps_tracker.tick()
+        if self.show_info and fps_estimate is not None:
+            message = f"{int(fps_estimate)} fps | {self.width}x{self.height}"
+            cv2.putText(image, message, (10, 40), cv2.FONT_HERSHEY_SIMPLEX, 1.0, (0, 0, 0))
+        
+        pix = self.convert_cv_qt(image)
+        self.image_label.setPixmap(pix)
+        
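+        # Pump the Qt event loop manually; the UI is driven by this polling loop
+        # rather than by app.exec_().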
+        QApplication.processEvents()
+
+        key = self.key
+        self.key = None
+        return key
+
+    def convert_cv_qt(self, cv_img):
+        """Convert an OpenCV image (RGB or RGBA) to a QPixmap."""
+        h, w, ch = cv_img.shape
+        bytes_per_line = ch * w
+        if ch == 3:
+            qt_format = QtGui.QImage.Format_RGB888
+        elif ch == 4:
+            qt_format = QtGui.QImage.Format_RGBA8888
+        else:
+            raise ValueError(f"unsupported number of channels: {ch}")
+        image = QtGui.QImage(cv_img.data, w, h, bytes_per_line, qt_format)
+        return QtGui.QPixmap.fromImage(image)
+
+
+# --------------- Main ---------------
+
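+# The checkpoint is a TorchScript module, so it is loaded with torch.jit.load;
+# the refinement hyper-parameters are plain attributes that can be overridden
+# after loading.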
+model = torch.jit.load(args.model_checkpoint)
+model.backbone_scale = args.model_backbone_scale
+model.refine_mode = args.model_refine_mode
+model.refine_sample_pixels = args.model_refine_sample_pixels
+model.refine_threshold = args.model_refine_threshold
+
+model = model.to(torch.device('cuda'))
+
+width, height = args.resolution
+cam = Camera(device_id=args.device_id, width=width, height=height)
+app = QApplication(['MattingV2'])
+dsp = Displayer('MattingV2', cam.width, cam.height, show_info=(not args.hide_fps))
+dsp.show()
+
+
+def cv2_frame_to_cuda(frame):
+    """Convert a BGR OpenCV frame to a normalized CUDA tensor of shape 1x3xHxW."""
+    frame = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
+    # Heuristic: checkpoints exported in half precision carry 'fp16' in their filename.
+    dtype = torch.float16 if 'fp16' in args.model_checkpoint else torch.float32
+    return ToTensor()(Image.fromarray(frame)).unsqueeze_(0).to(dtype).cuda()
+
+with torch.no_grad():
+    while True:
+        bgr = None
+        while True:  # grab the background frame
+            frame = cam.read()
+            frameRGB = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
+            key = dsp.step(frameRGB)
+            if key == 'b':
+                bgr = cv2_frame_to_cuda(frame)
+                break
+            elif key == 'w':
+                cam.brighter()
+            elif key == 's':
+                cam.darker()
+            elif key == 'q':
+                exit()
+
+        if args.background_image == "":
+            # No background image given: composite onto a solid green screen
+            # (all-zero RGB tensor with the green channel set to 1).
+            bgImage = torch.zeros_like(bgr)
+            bgImage[0, 1] = torch.ones_like(bgr[0, 0])
+        else:
+            # Force a 3-channel read; the compositing below expects RGB without alpha.
+            bgImage = cv2.imread(args.background_image, cv2.IMREAD_COLOR)
+            bgImage = cv2.resize(bgImage, (frame.shape[1], frame.shape[0]))
+            bgImage = cv2_frame_to_cuda(bgImage)
+
+        while True:  # matting
+            frame = cam.read()
+            src = cv2_frame_to_cuda(frame)
+            pha, fgr = model(src, bgr)[:2]
+            # Alpha-composite the predicted foreground over the replacement background.
+            res = pha * fgr + (1 - pha) * bgImage
+            res = res.mul(255).byte().cpu().permute(0, 2, 3, 1).numpy()[0]
+            key = dsp.step(res.copy())
+            if key == 'b':
+                # Return to background capture mode.
+                break
+            elif key == 'w':
+                cam.brighter()
+            elif key == 's':
+                cam.darker()
+            elif key == 'q':
+                exit()

+ 14 - 0
runCompositing.sh

@@ -0,0 +1,14 @@
+#!/bin/bash
+
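+# Runs the TorchScript compositing demo with the fp32 MobileNetV2 checkpoint.
+# Adjust --device-id to your webcam index and --background-image to any local image.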
+python3 inference_webcam_ts_compositing.py \
+  --model-type mattingrefine \
+  --model-backbone mobilenetv2 \
+  --model-backbone-scale 0.5 \
+  --model-checkpoint model/TorchScript/torchscript_mobilenetv2_fp32.pth \
+  --model-refine-mode 'thresholding' \
+  --model-refine-threshold 0.75 \
+  --model-refine-sample-pixels 20000 \
+  --resolution 640 360 \
+  --background-image images/House_Calls.webp \
+  --device-id 3
+  #--hide-fps