
add webcam inference

subDesTagesMitExtraKaese committed 2 years ago
commit 550c1c25b8
2 changed files with 200 additions and 0 deletions:
  1. inference_webcam.py (+192, -0)
  2. run.sh (+8, -0)

inference_webcam.py (+192, -0)

@@ -0,0 +1,192 @@
+import argparse, time
+import cv2
+import torch
+
+from PyQt5 import QtGui, QtCore, QtWidgets
+from PyQt5.QtWidgets import QMainWindow, QApplication
+
+from threading import Thread, Lock, Condition
+
+
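+# cudnn.benchmark lets cuDNN auto-tune convolution kernels, which pays off here
+# because the model sees a fixed input size every frame.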
+torch.backends.cudnn.benchmark = True
+
+
+# --------------- Arguments ---------------
+
+if __name__ == "__main__":
+    parser = argparse.ArgumentParser(description='Inference from web-cam')
+
+    parser.add_argument('--model-backbone', type=str, required=True, choices=['resnet50', 'mobilenetv3'])
+
+    parser.add_argument('--hide-fps', action='store_true')
+    parser.add_argument('--resolution', type=int, nargs=2, metavar=('width', 'height'), default=(1280, 720))
+    parser.add_argument('--downsampling', type=float, default=0.25)
+
+    parser.add_argument('--device-id', type=int, default=0)
+
+    args = parser.parse_args()
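+
+    # Example invocation (mirrors run.sh below):
+    #   python3 inference_webcam.py --model-backbone mobilenetv3 --resolution 640 360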
+
+
+# ----------- Utility classes -------------
+
+
+# A wrapper that reads frames from cv2.VideoCapture in a background thread,
+# so the consumer always gets the newest frame with minimal latency.
+# Call .read() in a tight loop to get the newest frame.
+class Camera:
+    def __init__(self, device_id=0, width=1280, height=720):
+        self.capture = cv2.VideoCapture(device_id)
+        self.capture.set(cv2.CAP_PROP_FRAME_WIDTH, width)
+        self.capture.set(cv2.CAP_PROP_FRAME_HEIGHT, height)
+        self.width = int(self.capture.get(cv2.CAP_PROP_FRAME_WIDTH))
+        self.height = int(self.capture.get(cv2.CAP_PROP_FRAME_HEIGHT))
+        # self.capture.set(cv2.CAP_PROP_BUFFERSIZE, 2)
+        # Read back the current exposure and re-apply it (this pins manual
+        # exposure on many webcams), and disable backlight compensation.
+        self.exposure = self.capture.get(cv2.CAP_PROP_EXPOSURE)
+        self.capture.set(cv2.CAP_PROP_BACKLIGHT, 0)
+        self.capture.set(cv2.CAP_PROP_EXPOSURE, self.exposure)
+        self.frameAvailable = False
+        self.success_reading, self.frame = self.capture.read()
+        self.cv = Condition()
+        self.thread = Thread(target=self.__update, args=())
+        self.thread.daemon = True
+        self.thread.start()
+
+    def __update(self):
+        while self.success_reading:
+            grabbed, frame = self.capture.read()
+            with self.cv:
+                self.success_reading = grabbed
+                self.frame = frame
+                self.frameAvailable = True
+                self.cv.notify()
+
+    # Clamp exposure adjustments to a sane range for typical webcams.
+    def brighter(self):
+        if self.exposure < -2:
+            self.exposure += 1
+            self.capture.set(cv2.CAP_PROP_EXPOSURE, self.exposure)
+            print(f"exposure: {self.exposure}")
+
+    def darker(self):
+        if self.exposure > -12:
+            self.exposure -= 1
+            self.capture.set(cv2.CAP_PROP_EXPOSURE, self.exposure)
+            print(f"exposure: {self.exposure}")
+
+    def read(self):
+        with self.cv:
+            self.cv.wait_for(lambda: self.frameAvailable)
+            frame = self.frame.copy()
+            self.frameAvailable = False
+        return frame
+
+    def __enter__(self):
+        return self
+
+    def __exit__(self, exc_type, exc_value, traceback):
+        self.capture.release()
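+
+# Minimal usage sketch of this wrapper:
+#   cam = Camera(device_id=0, width=1280, height=720)
+#   frame = cam.read()  # blocks until the capture thread publishes a new frame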
+
+# An FPS tracker that computes an exponentially weighted moving-average FPS.
+class FPSTracker:
+    def __init__(self, ratio=0.5):
+        self._last_tick = None
+        self._avg_fps = None
+        self.ratio = ratio
+    def tick(self):
+        if self._last_tick is None:
+            self._last_tick = time.time()
+            return None
+        t_new = time.time()
+        fps_sample = 1.0 / (t_new - self._last_tick)
+        if self._avg_fps is None:
+            self._avg_fps = fps_sample
+        else:
+            self._avg_fps = self.ratio * fps_sample + (1 - self.ratio) * self._avg_fps
+        self._last_tick = t_new
+        return self.get()
+    def get(self):
+        return self._avg_fps
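+
+# With the default ratio of 0.5, each tick computes
+#   avg_new = 0.5 * fps_sample + 0.5 * avg_old,
+# e.g. a 60 fps sample against a 30 fps average yields 45 fps. The first
+# sample seeds the average directly.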
+
+# A Qt window for displaying the stream. It accepts an image each step and
+# returns key-press info for basic interactivity.
+# It also tracks FPS and optionally overlays it onto the stream.
+class Displayer(QMainWindow):
+    def __init__(self, title, width, height, show_info=True):
+        self.width, self.height = width, height
+        self.show_info = show_info
+        self.fps_tracker = FPSTracker()
+
+
+        QMainWindow.__init__(self)
+        self.setFixedSize(width, height)
+        self.setAttribute(QtCore.Qt.WA_TranslucentBackground, True)
+
+        self.image_label = QtWidgets.QLabel(self)
+        self.image_label.resize(width, height)
+
+        self.key = None
+
+    def keyPressEvent(self, event):
+        self.key = event.text()
+
+    def closeEvent(self, event):
+        self.key = 'q'
+
+    # Update the currently shown frame and return the pressed key (as text), if any.
+    def step(self, image):
+        fps_estimate = self.fps_tracker.tick()
+        if self.show_info and fps_estimate is not None:
+            message = f"{int(fps_estimate)} fps | {self.width}x{self.height}"
+            cv2.putText(image, message, (10, 40), cv2.FONT_HERSHEY_SIMPLEX, 1.0, (0, 0, 0))
+        
+        pix = self.convert_cv_qt(image)
+        self.image_label.setPixmap(pix)
+        
+        QApplication.processEvents()
+
+        key = self.key
+        self.key = None
+        return key
+    def convert_cv_qt(self, cv_img):
+        """Convert an OpenCV image (RGB or RGBA) to a QPixmap."""
+        h, w, ch = cv_img.shape
+        bytes_per_line = ch * w
+        if ch == 3:
+            qt_format = QtGui.QImage.Format_RGB888
+        elif ch == 4:
+            qt_format = QtGui.QImage.Format_RGBA8888
+        else:
+            raise ValueError(f"unsupported channel count: {ch}")
+        image = QtGui.QImage(cv_img.data, w, h, bytes_per_line, qt_format)
+        return QtGui.QPixmap.fromImage(image)
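+
+    # Note: the main loop below passes 4-channel RGBA images, so the
+    # Format_RGBA8888 branch is the one normally taken.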
+
+
+
+def cv2_frame_to_cuda(frame, datatype=torch.float32):
+    """Convert a BGR frame to a normalized (1, 3, H, W) RGB tensor.
+    Despite the name, moving the tensor to the GPU is left to the caller."""
+    frame = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
+    return torch.as_tensor(frame).type(datatype).unsqueeze(0).permute(0, 3, 1, 2) / 255
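+
+# For example, a 720x1280 BGR frame becomes a float32 tensor of shape
+# (1, 3, 720, 1280) with values in [0, 1].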
+
+# --------------- Main ---------------
+
+if __name__ == "__main__":
+
+  model = torch.hub.load("PeterL1n/RobustVideoMatting", args.model_backbone)
+  use_cuda = torch.cuda.is_available()
+  if use_cuda:
+    model = model.cuda()
+  model = model.eval()
+  datatype = torch.float32
+
+  width, height = args.resolution
+  cam = Camera(device_id=args.device_id, width=width, height=height)
+  app = QApplication(['RobustVideoMatting'])
+  dsp = Displayer('RobustVideoMatting', width, height, show_info=(not args.hide_fps))
+  dsp.show()
+  
+  rec = [None] * 4                                       # Initial recurrent states.
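+  # The model returns four recurrent tensors alongside fgr and pha each call;
+  # feeding them back in carries temporal context from frame to frame.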
+
+  with torch.no_grad():
+    while True:
+      frame = cam.read()
+      src = cv2_frame_to_cuda(frame, datatype)
+      fgr, pha, *rec = model(src.cuda() if use_cuda else src, *rec, args.downsampling)  # Cycle the recurrent states.
+      res = pha  # Use the alpha matte; the predicted foreground fgr is unused here.
+      res = res.mul(255).byte().cpu().permute(0, 2, 3, 1).numpy()[0]
+      b_channel, g_channel, r_channel = cv2.split(frame)
+      img_RGBA = cv2.merge((r_channel, g_channel, b_channel, res))  # BGR -> RGBA with the matte as alpha.
+      key = dsp.step(img_RGBA)
+      if key == 'w':
+        cam.brighter()
+      elif key == 's':
+        cam.darker()
+      elif key == 'q':
+        exit()

run.sh (+8, -0)

@@ -0,0 +1,8 @@
+#!/bin/bash
+
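+# 640x360 with downsampling 0.3 keeps inference light; these are starting
+# values, raise them on a stronger GPU.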
+python3 inference_webcam.py \
+  --model-backbone mobilenetv3 \
+  --downsampling 0.3 \
+  --resolution 640 360 \
+  --device-id 0
+  # add --hide-fps to disable the FPS overlay