subDesTagesMitExtraKaese 4 years ago
parent commit 3521f62c01

+ 3 - 123
inference_webcam_ts.py

@@ -34,7 +34,7 @@ from PIL import Image, ImageTk
 
 from dataset import VideoDataset
 from model import MattingBase, MattingRefine
-torch.backends.cudnn.benchmark = True
+from inference_webcam import Displayer, FPSTracker, Camera
 
 
 # --------------- Arguments ---------------
@@ -54,131 +54,11 @@ parser.add_argument('--hide-fps', action='store_true')
 parser.add_argument('--resolution', type=int, nargs=2, metavar=('width', 'height'), default=(1280, 720))
 
 parser.add_argument('--device-id', type=int, default=0)
+parser.add_argument('--background-image', type=str, default="")
+parser.add_argument('--fps-limit', type=int, default=1000)
 
 args = parser.parse_args()
 
-
-# ----------- Utility classes -------------
-
-
-# A wrapper that reads frames from cv2.VideoCapture in its own thread to avoid blocking.
-# Call .read() in a tight loop to get the newest frame.
-class Camera:
-    def __init__(self, device_id=0, width=1280, height=720):
-        self.capture = cv2.VideoCapture(device_id)
-        self.capture.set(cv2.CAP_PROP_FRAME_WIDTH, width)
-        self.capture.set(cv2.CAP_PROP_FRAME_HEIGHT, height)
-        self.width = int(self.capture.get(cv2.CAP_PROP_FRAME_WIDTH))
-        self.height = int(self.capture.get(cv2.CAP_PROP_FRAME_HEIGHT))
-        # self.capture.set(cv2.CAP_PROP_BUFFERSIZE, 2)
-        self.exposure = self.capture.get(cv2.CAP_PROP_EXPOSURE)
-        self.capture.set(cv2.CAP_PROP_BACKLIGHT, 0)
-        self.capture.set(cv2.CAP_PROP_EXPOSURE,self.exposure)
-        self.frameAvailable = False
-        self.success_reading, self.frame = self.capture.read()
-        self.cv = Condition()
-        self.thread = Thread(target=self.__update, args=())
-        self.thread.daemon = True
-        self.thread.start()
-
-    def __update(self):
-      while self.success_reading:
-        grabbed, frame = self.capture.read()
-        with self.cv:
-          self.success_reading = grabbed
-          self.frame = frame
-          self.frameAvailable = True
-          self.cv.notify()
-
-    def brighter(self):
-      if self.exposure < -2:
-        self.exposure += 1
-        self.capture.set(cv2.CAP_PROP_EXPOSURE,self.exposure)
-        print(self.exposure)
-    def darker(self):
-      if self.exposure > -12:
-        self.exposure -= 1
-        self.capture.set(cv2.CAP_PROP_EXPOSURE,self.exposure)
-        print(self.exposure)
-
-    def read(self):
-        with self.cv:
-            self.cv.wait_for(lambda: self.frameAvailable)
-            frame = self.frame.copy()
-            self.frameAvailable = False
-        return frame
-    def __exit__(self, exec_type, exc_value, traceback):
-        self.capture.release()
-
-# An FPS tracker that computes an exponential moving average of the FPS
-class FPSTracker:
-    def __init__(self, ratio=0.5):
-        self._last_tick = None
-        self._avg_fps = None
-        self.ratio = ratio
-    def tick(self):
-        if self._last_tick is None:
-            self._last_tick = time.time()
-            return None
-        t_new = time.time()
-        fps_sample = 1.0 / (t_new - self._last_tick)
-        self._avg_fps = self.ratio * fps_sample + (1 - self.ratio) * self._avg_fps if self._avg_fps is not None else fps_sample
-        self._last_tick = t_new
-        return self.get()
-    def get(self):
-        return self._avg_fps
-
-# Wrapper for displaying a stream in a Qt window. It accepts an image and returns keypress info for basic interactivity.
-# It also tracks FPS and optionally overlays info onto the stream.
-class Displayer(QMainWindow):
-    def __init__(self, title, width, height, show_info=True):
-        self.width, self.height = width, height
-        self.show_info = show_info
-        self.fps_tracker = FPSTracker()
-
-
-        QMainWindow.__init__(self)
-        self.setFixedSize(width, height)
-        self.setAttribute(QtCore.Qt.WA_TranslucentBackground, True)
-
-        self.image_label = QtWidgets.QLabel(self)
-        self.image_label.resize(width, height)
-
-        self.key = None
-
-    def keyPressEvent(self, event):
-      self.key = event.text()
-
-    def closeEvent(self, event):
-      self.key = 'q'
-
-    # Update the currently showing frame and return key press char code
-    def step(self, image):
-        fps_estimate = self.fps_tracker.tick()
-        if self.show_info and fps_estimate is not None:
-            message = f"{int(fps_estimate)} fps | {self.width}x{self.height}"
-            cv2.putText(image, message, (10, 40), cv2.FONT_HERSHEY_SIMPLEX, 1.0, (0, 0, 0))
-        
-        pix = self.convert_cv_qt(image)
-        self.image_label.setPixmap(pix)
-        
-        QApplication.processEvents()
-
-        key = self.key
-        self.key = None
-        return key
-    def convert_cv_qt(self, cv_img):
-        """Convert from an opencv image to QPixmap"""
-        h, w, ch = cv_img.shape
-        bytes_per_line = ch * w
-        if ch == 3:
-          convert_to_Qt_format = QtGui.QImage(cv_img.data, w, h, bytes_per_line, QtGui.QImage.Format_RGB888)
-        elif ch == 4:
-          convert_to_Qt_format = QtGui.QImage(cv_img.data, w, h, bytes_per_line, QtGui.QImage.Format_RGBA8888)
-
-        return QtGui.QPixmap.fromImage(convert_to_Qt_format)
-
-
 # --------------- Main ---------------
 
 model = torch.jit.load(args.model_checkpoint)
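
The hunks above deduplicate inference_webcam_ts.py: the copy-pasted Camera, FPSTracker and Displayer classes are deleted in favour of a single import from inference_webcam, and --background-image and --fps-limit arguments are added. A minimal sketch of how the slimmed-down script is then expected to use the shared helpers (assumed wiring; the main loop is outside the hunks shown, and the window title string is made up here):

# Hypothetical usage of the shared helpers after the refactor; the actual
# main loop is not visible in this commit.
from inference_webcam import Displayer, FPSTracker, Camera

cam = Camera(device_id=args.device_id,
             width=args.resolution[0], height=args.resolution[1])
dsp = Displayer('matting-preview', cam.width, cam.height,
                show_info=not args.hide_fps)

while True:
    frame = cam.read()        # blocks until the grabber thread publishes a frame
    # ... matting / compositing on `frame` would happen here ...
    key = dsp.step(frame)     # shows the frame, overlays FPS, returns the keypress
    if key == 'q':
        break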

+ 13 - 122
inference_webcam_ts_compositing.py

@@ -34,6 +34,7 @@ from PIL import Image, ImageTk
 
 from dataset import VideoDataset
 from model import MattingBase, MattingRefine
+from inference_webcam import Displayer, FPSTracker, Camera
 
 
 # --------------- Arguments ---------------
@@ -54,133 +55,14 @@ parser.add_argument('--resolution', type=int, nargs=2, metavar=('width', 'height'), default=(1280, 720))
 
 parser.add_argument('--device-id', type=int, default=0)
 parser.add_argument('--background-image', type=str, default="")
+parser.add_argument('--fps-limit', type=int, default=1000)
 
 args = parser.parse_args()
 
 
-# ----------- Utility classes -------------
-
-
-# A wrapper that reads frames from cv2.VideoCapture in its own thread to avoid blocking.
-# Call .read() in a tight loop to get the newest frame.
-class Camera:
-    def __init__(self, device_id=0, width=1280, height=720):
-        self.capture = cv2.VideoCapture(device_id)
-        self.capture.set(cv2.CAP_PROP_FRAME_WIDTH, width)
-        self.capture.set(cv2.CAP_PROP_FRAME_HEIGHT, height)
-        self.width = int(self.capture.get(cv2.CAP_PROP_FRAME_WIDTH))
-        self.height = int(self.capture.get(cv2.CAP_PROP_FRAME_HEIGHT))
-        # self.capture.set(cv2.CAP_PROP_BUFFERSIZE, 2)
-        self.exposure = self.capture.get(cv2.CAP_PROP_EXPOSURE)
-        self.capture.set(cv2.CAP_PROP_BACKLIGHT, 0)
-        self.capture.set(cv2.CAP_PROP_EXPOSURE,self.exposure)
-        self.frameAvailable = False
-        self.success_reading, self.frame = self.capture.read()
-        self.cv = Condition()
-        self.thread = Thread(target=self.__update, args=())
-        self.thread.daemon = True
-        self.thread.start()
-
-    def __update(self):
-      while self.success_reading:
-        grabbed, frame = self.capture.read()
-        with self.cv:
-          self.success_reading = grabbed
-          self.frame = frame
-          self.frameAvailable = True
-          self.cv.notify()
-
-    def brighter(self):
-      if self.exposure < -2:
-        self.exposure += 1
-        self.capture.set(cv2.CAP_PROP_EXPOSURE,self.exposure)
-        print(self.exposure)
-    def darker(self):
-      if self.exposure > -12:
-        self.exposure -= 1
-        self.capture.set(cv2.CAP_PROP_EXPOSURE,self.exposure)
-        print(self.exposure)
-
-    def read(self):
-        with self.cv:
-            self.cv.wait_for(lambda: self.frameAvailable)
-            frame = self.frame.copy()
-            self.frameAvailable = False
-        return frame
-    def __exit__(self, exec_type, exc_value, traceback):
-        self.capture.release()
-
-# An FPS tracker that computes an exponential moving average of the FPS
-class FPSTracker:
-    def __init__(self, ratio=0.5):
-        self._last_tick = None
-        self._avg_fps = None
-        self.ratio = ratio
-    def tick(self):
-        if self._last_tick is None:
-            self._last_tick = time.time()
-            return None
-        t_new = time.time()
-        fps_sample = 1.0 / (t_new - self._last_tick)
-        self._avg_fps = self.ratio * fps_sample + (1 - self.ratio) * self._avg_fps if self._avg_fps is not None else fps_sample
-        self._last_tick = t_new
-        return self.get()
-    def get(self):
-        return self._avg_fps
-
-# Wrapper for displaying a stream in a Qt window. It accepts an image and returns keypress info for basic interactivity.
-# It also tracks FPS and optionally overlays info onto the stream.
-class Displayer(QMainWindow):
-    def __init__(self, title, width, height, show_info=True):
-        self.width, self.height = width, height
-        self.show_info = show_info
-        self.fps_tracker = FPSTracker()
-
-
-        QMainWindow.__init__(self)
-        self.setFixedSize(width, height)
-        self.setAttribute(QtCore.Qt.WA_TranslucentBackground, True)
-
-        self.image_label = QtWidgets.QLabel(self)
-        self.image_label.resize(width, height)
-
-        self.key = None
-
-    def keyPressEvent(self, event):
-      self.key = event.text()
-
-    def closeEvent(self, event):
-      self.key = 'q'
-
-    # Update the currently showing frame and return key press char code
-    def step(self, image):
-        fps_estimate = self.fps_tracker.tick()
-        if self.show_info and fps_estimate is not None:
-            message = f"{int(fps_estimate)} fps | {self.width}x{self.height}"
-            cv2.putText(image, message, (10, 40), cv2.FONT_HERSHEY_SIMPLEX, 1.0, (0, 0, 0))
-        
-        pix = self.convert_cv_qt(image)
-        self.image_label.setPixmap(pix)
-        
-        QApplication.processEvents()
-
-        key = self.key
-        self.key = None
-        return key
-    def convert_cv_qt(self, cv_img):
-        """Convert from an opencv image to QPixmap"""
-        h, w, ch = cv_img.shape
-        bytes_per_line = ch * w
-        if ch == 3:
-          convert_to_Qt_format = QtGui.QImage(cv_img.data, w, h, bytes_per_line, QtGui.QImage.Format_RGB888)
-        elif ch == 4:
-          convert_to_Qt_format = QtGui.QImage(cv_img.data, w, h, bytes_per_line, QtGui.QImage.Format_RGBA8888)
-
-        return QtGui.QPixmap.fromImage(convert_to_Qt_format)
-
-
 # --------------- Main ---------------
 
+default_float_dtype = torch.get_default_dtype()
 model = torch.jit.load(args.model_checkpoint)
 model.backbone_scale = args.model_backbone_scale
 model.refine_mode = args.model_refine_mode
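
inference_webcam_ts_compositing.py gets the same deduplication plus the new --fps-limit argument (default 1000, effectively uncapped). The flag's consumer is not visible in these hunks; a typical limiter, sketched under that assumption, sleeps off whatever is left of each frame budget:

# Hypothetical frame limiter honouring --fps-limit; the real wiring in this
# commit is outside the hunks shown above.
import time

min_frame_time = 1.0 / args.fps_limit
last_tick = time.time()
while True:
    # ... grab, matte, composite, display ...
    spare = min_frame_time - (time.time() - last_tick)
    if spare > 0:
        time.sleep(spare)
    last_tick = time.time()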
@@ -201,7 +83,16 @@ def cv2_frame_to_cuda(frame):
     if 'fp16' in args.model_checkpoint:
       return ToTensor()(Image.fromarray(frame)).unsqueeze_(0).to(torch.float16).cuda()
     else:
-      return ToTensor()(Image.fromarray(frame)).unsqueeze_(0).to(torch.float32).cuda()
+      pic = Image.fromarray(frame)
+      img = torch.ByteTensor(torch.ByteStorage.from_buffer(pic.tobytes()))
+      img = img.cuda()
+      img = img.view(pic.size[1], pic.size[0], len(pic.getbands()))
+      img = img.permute((2, 0, 1)).contiguous()
+      tmp = img.to(dtype=default_float_dtype).div(255)
+      tmp.unsqueeze_(0)
+      tmp = tmp.to(torch.float32)
+      return tmp
+      #return ToTensor()(Image.fromarray(frame)).unsqueeze_(0).to(torch.float32).cuda()
 
 with torch.no_grad():
     while True:
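
The final hunk rewrites the float32 branch of cv2_frame_to_cuda: instead of torchvision's ToTensor, which converts to float on the CPU and only then uploads the result, the frame's raw bytes are moved to the GPU first, and the reshape, HWC-to-CHW permute, and div(255) normalization all run on CUDA (default_float_dtype is captured once at module level to mirror ToTensor's behaviour). The apparent win is that the host-to-device copy carries uint8 data, a quarter the size of float32. An annotated restatement of the new path, assuming an RGB uint8 frame and the imports already present in the script:

# Annotated sketch of the new conversion path (same logic as the hunk above).
def cv2_frame_to_cuda_annotated(frame):
    pic = Image.fromarray(frame)
    # wrap the raw bytes, then upload while still uint8 (4x smaller transfer)
    img = torch.ByteTensor(torch.ByteStorage.from_buffer(pic.tobytes())).cuda()
    img = img.view(pic.size[1], pic.size[0], len(pic.getbands()))
    img = img.permute((2, 0, 1)).contiguous()   # HWC -> CHW, as ToTensor does
    # normalization now runs on the GPU instead of the CPU
    return img.to(default_float_dtype).div(255).unsqueeze_(0)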