Browse Source

merged changes into original script

subDesTagesMitExtraKaese 4 years ago
parent
commit
f9223047ac
1 changed files with 145 additions and 75 deletions
  1. 145 75
      inference_webcam.py

+ 145 - 75
inference_webcam.py

@@ -17,36 +17,45 @@ Example:
 
 import argparse, os, shutil, time
 import cv2
+import numpy as np
 import torch
 
+from PyQt5 import QtGui, QtCore, uic
+from PyQt5 import QtWidgets
+from PyQt5.QtWidgets import QMainWindow, QApplication
+
 from torch import nn
 from torch.utils.data import DataLoader
 from torchvision.transforms import Compose, ToTensor, Resize
 from torchvision.transforms.functional import to_pil_image
-from threading import Thread, Lock
+from threading import Thread, Lock, Condition
 from tqdm import tqdm
-from PIL import Image
+from PIL import Image, ImageTk
 
 from dataset import VideoDataset
 from model import MattingBase, MattingRefine
+torch.backends.cudnn.benchmark = True
 
 
 # --------------- Arguments ---------------
 
+if __name__ == "__main__":
+    parser = argparse.ArgumentParser(description='Inference from web-cam')
+
+    parser.add_argument('--model-type', type=str, required=True, choices=['mattingbase', 'mattingrefine'])
+    parser.add_argument('--model-backbone', type=str, required=True, choices=['resnet101', 'resnet50', 'mobilenetv2'])
+    parser.add_argument('--model-backbone-scale', type=float, default=0.25)
+    parser.add_argument('--model-checkpoint', type=str, required=True)
+    parser.add_argument('--model-refine-mode', type=str, default='sampling', choices=['full', 'sampling', 'thresholding'])
+    parser.add_argument('--model-refine-sample-pixels', type=int, default=80_000)
+    parser.add_argument('--model-refine-threshold', type=float, default=0.7)
 
-parser = argparse.ArgumentParser(description='Inference from web-cam')
+    parser.add_argument('--hide-fps', action='store_true')
+    parser.add_argument('--resolution', type=int, nargs=2, metavar=('width', 'height'), default=(1280, 720))
 
-parser.add_argument('--model-type', type=str, required=True, choices=['mattingbase', 'mattingrefine'])
-parser.add_argument('--model-backbone', type=str, required=True, choices=['resnet101', 'resnet50', 'mobilenetv2'])
-parser.add_argument('--model-backbone-scale', type=float, default=0.25)
-parser.add_argument('--model-checkpoint', type=str, required=True)
-parser.add_argument('--model-refine-mode', type=str, default='sampling', choices=['full', 'sampling', 'thresholding'])
-parser.add_argument('--model-refine-sample-pixels', type=int, default=80_000)
-parser.add_argument('--model-refine-threshold', type=float, default=0.7)
+    parser.add_argument('--device-id', type=int, default=0)
 
-parser.add_argument('--hide-fps', action='store_true')
-parser.add_argument('--resolution', type=int, nargs=2, metavar=('width', 'height'), default=(1280, 720))
-args = parser.parse_args()
+    args = parser.parse_args()
 
 
 # ----------- Utility classes -------------
@@ -62,22 +71,41 @@ class Camera:
         self.width = int(self.capture.get(cv2.CAP_PROP_FRAME_WIDTH))
         self.height = int(self.capture.get(cv2.CAP_PROP_FRAME_HEIGHT))
         # self.capture.set(cv2.CAP_PROP_BUFFERSIZE, 2)
+        self.exposure = self.capture.get(cv2.CAP_PROP_EXPOSURE)
+        self.capture.set(cv2.CAP_PROP_BACKLIGHT, 0)
+        self.capture.set(cv2.CAP_PROP_EXPOSURE,self.exposure)
+        self.frameAvailable = False
         self.success_reading, self.frame = self.capture.read()
-        self.read_lock = Lock()
+        self.cv = Condition()
         self.thread = Thread(target=self.__update, args=())
         self.thread.daemon = True
         self.thread.start()
 
     def __update(self):
-        while self.success_reading:
-            grabbed, frame = self.capture.read()
-            with self.read_lock:
-                self.success_reading = grabbed
-                self.frame = frame
+      while self.success_reading:
+        grabbed, frame = self.capture.read()
+        with self.cv:
+          self.success_reading = grabbed
+          self.frame = frame
+          self.frameAvailable = True
+          self.cv.notify()
+
+    def brighter(self):
+      if self.exposure < -2:
+        self.exposure += 1
+        self.capture.set(cv2.CAP_PROP_EXPOSURE,self.exposure)
+        print(self.exposure)
+    def darker(self):
+      if self.exposure > -12:
+        self.exposure -= 1
+        self.capture.set(cv2.CAP_PROP_EXPOSURE,self.exposure)
+        print(self.exposure)
 
     def read(self):
-        with self.read_lock:
+        with self.cv:
+            self.cv.wait_for(lambda: self.frameAvailable)
             frame = self.frame.copy()
+            self.frameAvailable = False
         return frame
     def __exit__(self, exec_type, exc_value, traceback):
         self.capture.release()
@@ -102,70 +130,112 @@ class FPSTracker:
 
 # Wrapper for playing a stream with cv2.imshow(). It can accept an image and return keypress info for basic interactivity.
 # It also tracks FPS and optionally overlays info onto the stream.
-class Displayer:
-    def __init__(self, title, width=None, height=None, show_info=True):
-        self.title, self.width, self.height = title, width, height
+class Displayer(QMainWindow):
+    def __init__(self, title, width, height, show_info=True):
+        self.width, self.height = width, height
         self.show_info = show_info
         self.fps_tracker = FPSTracker()
-        cv2.namedWindow(self.title, cv2.WINDOW_NORMAL)
-        if width is not None and height is not None:
-            cv2.resizeWindow(self.title, width, height)
-    # Update the currently showing frame and return key press char code
+
+
+        QMainWindow.__init__(self)
+        self.setFixedSize(width, height)
+        self.setAttribute(QtCore.Qt.WA_TranslucentBackground, True)
+
+        self.image_label = QtWidgets.QLabel(self)
+        self.image_label.resize(width, height)
+
+        self.key = None
+
+    def keyPressEvent(self, event):
+      self.key = event.text()
+
+    def closeEvent(self, event):
+      self.key = 'q'
+      
+          # Update the currently showing frame and return key press char code
     def step(self, image):
         fps_estimate = self.fps_tracker.tick()
         if self.show_info and fps_estimate is not None:
             message = f"{int(fps_estimate)} fps | {self.width}x{self.height}"
             cv2.putText(image, message, (10, 40), cv2.FONT_HERSHEY_SIMPLEX, 1.0, (0, 0, 0))
-        cv2.imshow(self.title, image)
-        return cv2.waitKey(1) & 0xFF
+        
+        pix = self.convert_cv_qt(image)
+        self.image_label.setPixmap(pix)
+        
+        QApplication.processEvents()
+
+        key = self.key
+        self.key = None
+        return key
+    def convert_cv_qt(self, cv_img):
+        """Convert from an opencv image to QPixmap"""
+        h, w, ch = cv_img.shape
+        bytes_per_line = ch * w
+        if ch == 3:
+          convert_to_Qt_format = QtGui.QImage(cv_img.data, w, h, bytes_per_line, QtGui.QImage.Format_RGB888)
+        elif ch == 4:
+          convert_to_Qt_format = QtGui.QImage(cv_img.data, w, h, bytes_per_line, QtGui.QImage.Format_RGBA8888)
+
+        return QtGui.QPixmap.fromImage(convert_to_Qt_format)
 
 
 # --------------- Main ---------------
 
-
-# Load model
-if args.model_type == 'mattingbase':
-    model = MattingBase(args.model_backbone)
-if args.model_type == 'mattingrefine':
-    model = MattingRefine(
-        args.model_backbone,
-        args.model_backbone_scale,
-        args.model_refine_mode,
-        args.model_refine_sample_pixels,
-        args.model_refine_threshold)
-
-model = model.cuda().eval()
-model.load_state_dict(torch.load(args.model_checkpoint), strict=False)
-
-
-width, height = args.resolution
-cam = Camera(width=width, height=height)
-dsp = Displayer('MattingV2', cam.width, cam.height, show_info=(not args.hide_fps))
-
-def cv2_frame_to_cuda(frame):
-    frame = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
-    return ToTensor()(Image.fromarray(frame)).unsqueeze_(0).cuda()
-
-with torch.no_grad():
-    while True:
-        bgr = None
-        while True: # grab bgr
-            frame = cam.read()
-            key = dsp.step(frame)
-            if key == ord('b'):
-                bgr = cv2_frame_to_cuda(cam.read())
-                break
-            elif key == ord('q'):
-                exit()
-        while True: # matting
-            frame = cam.read()
-            src = cv2_frame_to_cuda(frame)
-            pha, fgr = model(src, bgr)[:2]
-            res = pha * fgr + (1 - pha) * torch.ones_like(fgr)
-            res = res.mul(255).byte().cpu().permute(0, 2, 3, 1).numpy()[0]
-            res = cv2.cvtColor(res, cv2.COLOR_RGB2BGR)
-            key = dsp.step(res)
-            if key == ord('b'):
+if __name__ == "__main__":
+
+  # Load model
+  if args.model_type == 'mattingbase':
+      model = MattingBase(args.model_backbone)
+  if args.model_type == 'mattingrefine':
+      model = MattingRefine(
+          args.model_backbone,
+          args.model_backbone_scale,
+          args.model_refine_mode,
+          args.model_refine_sample_pixels,
+          args.model_refine_threshold)
+
+  model = model.cuda().eval()
+  model.load_state_dict(torch.load(args.model_checkpoint), strict=False)
+
+
+  width, height = args.resolution
+  cam = Camera(device_id=args.device_id, width=width, height=height)
+  app = QApplication(['MattingV2'])
+  dsp = Displayer('MattingV2', cam.width, cam.height, show_info=(not args.hide_fps))
+  dsp.show()
+
+  def cv2_frame_to_cuda(frame):
+      frame = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
+      return ToTensor()(Image.fromarray(frame)).unsqueeze_(0).cuda()
+
+  with torch.no_grad():
+      while True:
+          bgr = None
+          while True: # grab bgr
+              frame = cam.read()
+              frameRGB = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
+              key = dsp.step(frameRGB)
+              if key == 'b':
+                bgr = cv2_frame_to_cuda(frame)
                 break
-            elif key == ord('q'):
-                exit()
+              elif key == 'w':
+                cam.brighter()
+              elif key == 's':
+                cam.darker()
+              elif key == 'q':
+                  exit()
+          while True: # matting
+              frame = cam.read()
+              src = cv2_frame_to_cuda(frame)
+              pha, fgr = model(src, bgr)[:2]
+              res = pha * fgr + (1 - pha) * torch.ones_like(fgr)
+              res = res.mul(255).byte().cpu().permute(0, 2, 3, 1).numpy()[0]
+              key = dsp.step(res.copy())
+              if key == 'b':
+                  break
+              elif key == 'w':
+                cam.brighter()
+              elif key == 's':
+                cam.darker()
+              elif key == 'q':
+                  exit()