
merged changes into original script

subDesTagesMitExtraKaese 4 years ago
commit f9223047ac
1 changed file with 145 additions and 75 deletions

inference_webcam.py  (+145, -75)

@@ -17,36 +17,45 @@ Example:

 import argparse, os, shutil, time
 import cv2
+import numpy as np
 import torch

+from PyQt5 import QtGui, QtCore, uic
+from PyQt5 import QtWidgets
+from PyQt5.QtWidgets import QMainWindow, QApplication
+
 from torch import nn
 from torch.utils.data import DataLoader
 from torchvision.transforms import Compose, ToTensor, Resize
 from torchvision.transforms.functional import to_pil_image
-from threading import Thread, Lock
+from threading import Thread, Lock, Condition
 from tqdm import tqdm
-from PIL import Image
+from PIL import Image, ImageTk

 from dataset import VideoDataset
 from model import MattingBase, MattingRefine
+torch.backends.cudnn.benchmark = True


 # --------------- Arguments ---------------

+if __name__ == "__main__":
+    parser = argparse.ArgumentParser(description='Inference from web-cam')
+
+    parser.add_argument('--model-type', type=str, required=True, choices=['mattingbase', 'mattingrefine'])
+    parser.add_argument('--model-backbone', type=str, required=True, choices=['resnet101', 'resnet50', 'mobilenetv2'])
+    parser.add_argument('--model-backbone-scale', type=float, default=0.25)
+    parser.add_argument('--model-checkpoint', type=str, required=True)
+    parser.add_argument('--model-refine-mode', type=str, default='sampling', choices=['full', 'sampling', 'thresholding'])
+    parser.add_argument('--model-refine-sample-pixels', type=int, default=80_000)
+    parser.add_argument('--model-refine-threshold', type=float, default=0.7)

-parser = argparse.ArgumentParser(description='Inference from web-cam')
+    parser.add_argument('--hide-fps', action='store_true')
+    parser.add_argument('--resolution', type=int, nargs=2, metavar=('width', 'height'), default=(1280, 720))

-parser.add_argument('--model-type', type=str, required=True, choices=['mattingbase', 'mattingrefine'])
-parser.add_argument('--model-backbone', type=str, required=True, choices=['resnet101', 'resnet50', 'mobilenetv2'])
-parser.add_argument('--model-backbone-scale', type=float, default=0.25)
-parser.add_argument('--model-checkpoint', type=str, required=True)
-parser.add_argument('--model-refine-mode', type=str, default='sampling', choices=['full', 'sampling', 'thresholding'])
-parser.add_argument('--model-refine-sample-pixels', type=int, default=80_000)
-parser.add_argument('--model-refine-threshold', type=float, default=0.7)
+    parser.add_argument('--device-id', type=int, default=0)

-parser.add_argument('--hide-fps', action='store_true')
-parser.add_argument('--resolution', type=int, nargs=2, metavar=('width', 'height'), default=(1280, 720))
-args = parser.parse_args()
+    args = parser.parse_args()


 # ----------- Utility classes -------------
@@ -62,22 +71,41 @@ class Camera:
         self.width = int(self.capture.get(cv2.CAP_PROP_FRAME_WIDTH))
         self.height = int(self.capture.get(cv2.CAP_PROP_FRAME_HEIGHT))
         # self.capture.set(cv2.CAP_PROP_BUFFERSIZE, 2)
+        self.exposure = self.capture.get(cv2.CAP_PROP_EXPOSURE)
+        self.capture.set(cv2.CAP_PROP_BACKLIGHT, 0)
+        self.capture.set(cv2.CAP_PROP_EXPOSURE,self.exposure)
+        self.frameAvailable = False
         self.success_reading, self.frame = self.capture.read()
-        self.read_lock = Lock()
+        self.cv = Condition()
         self.thread = Thread(target=self.__update, args=())
         self.thread.daemon = True
         self.thread.start()

     def __update(self):
-        while self.success_reading:
-            grabbed, frame = self.capture.read()
-            with self.read_lock:
-                self.success_reading = grabbed
-                self.frame = frame
+      while self.success_reading:
+        grabbed, frame = self.capture.read()
+        with self.cv:
+          self.success_reading = grabbed
+          self.frame = frame
+          self.frameAvailable = True
+          self.cv.notify()
+
+    def brighter(self):
+      if self.exposure < -2:
+        self.exposure += 1
+        self.capture.set(cv2.CAP_PROP_EXPOSURE,self.exposure)
+        print(self.exposure)
+    def darker(self):
+      if self.exposure > -12:
+        self.exposure -= 1
+        self.capture.set(cv2.CAP_PROP_EXPOSURE,self.exposure)
+        print(self.exposure)

     def read(self):
-        with self.read_lock:
+        with self.cv:
+            self.cv.wait_for(lambda: self.frameAvailable)
             frame = self.frame.copy()
+            self.frameAvailable = False
         return frame
     def __exit__(self, exec_type, exc_value, traceback):
         self.capture.release()
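
Aside (not part of the commit): the hunk above replaces the Camera's read Lock with a threading.Condition so that read() blocks until the capture thread has published a frame that has not yet been consumed, instead of copying whatever frame is currently cached. A minimal stand-alone sketch of that hand-off pattern, with illustrative names (LatestValue, producer) that do not appear in the script:

from threading import Thread, Condition
import time

class LatestValue:
    """Single-slot hand-off: the consumer waits for a fresh value."""
    def __init__(self):
        self.cv = Condition()
        self.value = None
        self.available = False

    def put(self, value):
        # Producer side (the camera thread in the script).
        with self.cv:
            self.value = value
            self.available = True
            self.cv.notify()

    def get(self):
        # Consumer side (the inference loop); blocks until put() ran again.
        with self.cv:
            self.cv.wait_for(lambda: self.available)
            self.available = False
            return self.value

if __name__ == "__main__":
    slot = LatestValue()

    def producer():
        i = 0
        while True:          # daemon thread, dies with the main thread
            slot.put(i)
            i += 1
            time.sleep(0.03)

    Thread(target=producer, daemon=True).start()
    for _ in range(5):
        print(slot.get())    # each call waits for a value newer than the last

With this pattern a slow consumer simply skips frames (the newest one wins), while a fast consumer never processes the same frame twice.
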
@@ -102,70 +130,112 @@ class FPSTracker:

 # Wrapper for playing a stream with cv2.imshow(). It can accept an image and return keypress info for basic interactivity.
 # It also tracks FPS and optionally overlays info onto the stream.
-class Displayer:
-    def __init__(self, title, width=None, height=None, show_info=True):
-        self.title, self.width, self.height = title, width, height
+class Displayer(QMainWindow):
+    def __init__(self, title, width, height, show_info=True):
+        self.width, self.height = width, height
         self.show_info = show_info
         self.fps_tracker = FPSTracker()
-        cv2.namedWindow(self.title, cv2.WINDOW_NORMAL)
-        if width is not None and height is not None:
-            cv2.resizeWindow(self.title, width, height)
-    # Update the currently showing frame and return key press char code
+
+
+        QMainWindow.__init__(self)
+        self.setFixedSize(width, height)
+        self.setAttribute(QtCore.Qt.WA_TranslucentBackground, True)
+
+        self.image_label = QtWidgets.QLabel(self)
+        self.image_label.resize(width, height)
+
+        self.key = None
+
+    def keyPressEvent(self, event):
+      self.key = event.text()
+
+    def closeEvent(self, event):
+      self.key = 'q'
+
+    # Update the currently showing frame and return the pressed key as a character (or None)
     def step(self, image):
         fps_estimate = self.fps_tracker.tick()
         if self.show_info and fps_estimate is not None:
             message = f"{int(fps_estimate)} fps | {self.width}x{self.height}"
             cv2.putText(image, message, (10, 40), cv2.FONT_HERSHEY_SIMPLEX, 1.0, (0, 0, 0))
-        cv2.imshow(self.title, image)
-        return cv2.waitKey(1) & 0xFF
+        
+        pix = self.convert_cv_qt(image)
+        self.image_label.setPixmap(pix)
+        
+        QApplication.processEvents()
+
+        key = self.key
+        self.key = None
+        return key
+    def convert_cv_qt(self, cv_img):
+        """Convert from an opencv image to QPixmap"""
+        h, w, ch = cv_img.shape
+        bytes_per_line = ch * w
+        if ch == 3:
+          convert_to_Qt_format = QtGui.QImage(cv_img.data, w, h, bytes_per_line, QtGui.QImage.Format_RGB888)
+        elif ch == 4:
+          convert_to_Qt_format = QtGui.QImage(cv_img.data, w, h, bytes_per_line, QtGui.QImage.Format_RGBA8888)
+
+        return QtGui.QPixmap.fromImage(convert_to_Qt_format)


 # --------------- Main ---------------

-
-# Load model
-if args.model_type == 'mattingbase':
-    model = MattingBase(args.model_backbone)
-if args.model_type == 'mattingrefine':
-    model = MattingRefine(
-        args.model_backbone,
-        args.model_backbone_scale,
-        args.model_refine_mode,
-        args.model_refine_sample_pixels,
-        args.model_refine_threshold)
-
-model = model.cuda().eval()
-model.load_state_dict(torch.load(args.model_checkpoint), strict=False)
-
-
-width, height = args.resolution
-cam = Camera(width=width, height=height)
-dsp = Displayer('MattingV2', cam.width, cam.height, show_info=(not args.hide_fps))
-
-def cv2_frame_to_cuda(frame):
-    frame = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
-    return ToTensor()(Image.fromarray(frame)).unsqueeze_(0).cuda()
-
-with torch.no_grad():
-    while True:
-        bgr = None
-        while True: # grab bgr
-            frame = cam.read()
-            key = dsp.step(frame)
-            if key == ord('b'):
-                bgr = cv2_frame_to_cuda(cam.read())
-                break
-            elif key == ord('q'):
-                exit()
-        while True: # matting
-            frame = cam.read()
-            src = cv2_frame_to_cuda(frame)
-            pha, fgr = model(src, bgr)[:2]
-            res = pha * fgr + (1 - pha) * torch.ones_like(fgr)
-            res = res.mul(255).byte().cpu().permute(0, 2, 3, 1).numpy()[0]
-            res = cv2.cvtColor(res, cv2.COLOR_RGB2BGR)
-            key = dsp.step(res)
-            if key == ord('b'):
+if __name__ == "__main__":
+
+  # Load model
+  if args.model_type == 'mattingbase':
+      model = MattingBase(args.model_backbone)
+  if args.model_type == 'mattingrefine':
+      model = MattingRefine(
+          args.model_backbone,
+          args.model_backbone_scale,
+          args.model_refine_mode,
+          args.model_refine_sample_pixels,
+          args.model_refine_threshold)
+
+  model = model.cuda().eval()
+  model.load_state_dict(torch.load(args.model_checkpoint), strict=False)
+
+
+  width, height = args.resolution
+  cam = Camera(device_id=args.device_id, width=width, height=height)
+  app = QApplication(['MattingV2'])
+  dsp = Displayer('MattingV2', cam.width, cam.height, show_info=(not args.hide_fps))
+  dsp.show()
+
+  def cv2_frame_to_cuda(frame):
+      frame = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
+      return ToTensor()(Image.fromarray(frame)).unsqueeze_(0).cuda()
+
+  with torch.no_grad():
+      while True:
+          bgr = None
+          while True: # grab bgr
+              frame = cam.read()
+              frameRGB = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
+              key = dsp.step(frameRGB)
+              if key == 'b':
+                bgr = cv2_frame_to_cuda(frame)
                 break
-            elif key == ord('q'):
-                exit()
+              elif key == 'w':
+                cam.brighter()
+              elif key == 's':
+                cam.darker()
+              elif key == 'q':
+                  exit()
+          while True: # matting
+              frame = cam.read()
+              src = cv2_frame_to_cuda(frame)
+              pha, fgr = model(src, bgr)[:2]
+              res = pha * fgr + (1 - pha) * torch.ones_like(fgr)
+              res = res.mul(255).byte().cpu().permute(0, 2, 3, 1).numpy()[0]
+              key = dsp.step(res.copy())
+              if key == 'b':
+                  break
+              elif key == 'w':
+                cam.brighter()
+              elif key == 's':
+                cam.darker()
+              elif key == 'q':
+                  exit()
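
For quick reference (not part of the diff): an example of launching the merged script with the new argument set. The checkpoint filename below is an assumption; --device-id is the webcam index added in this commit.

python inference_webcam.py \
    --model-type mattingrefine \
    --model-backbone resnet50 \
    --model-checkpoint pytorch_resnet50.pth \
    --resolution 1280 720 \
    --device-id 0

Once the window opens, 'b' grabs the background frame and starts matting, 'w' / 's' step the camera exposure up or down, and 'q' quits, matching the key handling added above.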