import argparse
import time
from threading import Thread, Condition

import cv2
import torch
from PyQt5 import QtGui, QtCore
from PyQt5 import QtWidgets
from PyQt5.QtWidgets import QMainWindow, QApplication

torch.backends.cudnn.benchmark = True

# --------------- Arguments ---------------

if __name__ == "__main__":
    parser = argparse.ArgumentParser(description='Inference from web-cam')
    parser.add_argument('--model-backbone', type=str, required=False, choices=['resnet50', 'mobilenetv3'])
    parser.add_argument('--torchscript-file', type=str, required=False, default=None)
    parser.add_argument('--hide-fps', action='store_true')
    parser.add_argument('--resolution', type=int, nargs=2, metavar=('width', 'height'), default=(1280, 720))
    parser.add_argument('--downsampling', type=float, default=0.25)
    parser.add_argument('--device-id', type=int, default=0)
    args = parser.parse_args()
    if not args.torchscript_file and not args.model_backbone:
        parser.error('either --model-backbone or --torchscript-file must be given')

# ----------- Utility classes -------------

# A wrapper that reads from cv2.VideoCapture in its own thread so capture latency
# does not block inference. Call .read() in a tight loop to get the newest frame.
class Camera:
    def __init__(self, device_id=0, width=1280, height=720):
        self.capture = cv2.VideoCapture(device_id)
        self.capture.set(cv2.CAP_PROP_FRAME_WIDTH, width)
        self.capture.set(cv2.CAP_PROP_FRAME_HEIGHT, height)
        self.width = int(self.capture.get(cv2.CAP_PROP_FRAME_WIDTH))
        self.height = int(self.capture.get(cv2.CAP_PROP_FRAME_HEIGHT))
        # self.capture.set(cv2.CAP_PROP_BUFFERSIZE, 2)
        self.exposure = self.capture.get(cv2.CAP_PROP_EXPOSURE)
        self.capture.set(cv2.CAP_PROP_BACKLIGHT, 0)
        self.capture.set(cv2.CAP_PROP_EXPOSURE, self.exposure)
        self.frameAvailable = False
        self.success_reading, self.frame = self.capture.read()
        self.cv = Condition()
        self.thread = Thread(target=self.__update, args=())
        self.thread.daemon = True
        self.thread.start()

    # Grab frames continuously and signal the consumer whenever a new one arrives.
    def __update(self):
        while self.success_reading:
            grabbed, frame = self.capture.read()
            with self.cv:
                self.success_reading = grabbed
                self.frame = frame
                self.frameAvailable = True
                self.cv.notify()

    def brighter(self):
        if self.exposure < -2:
            self.exposure += 1
            self.capture.set(cv2.CAP_PROP_EXPOSURE, self.exposure)
            print(self.exposure)

    def darker(self):
        if self.exposure > -12:
            self.exposure -= 1
            self.capture.set(cv2.CAP_PROP_EXPOSURE, self.exposure)
            print(self.exposure)

    # Block until a fresh frame is available, then return a copy of it.
    def read(self):
        with self.cv:
            self.cv.wait_for(lambda: self.frameAvailable)
            frame = self.frame.copy()
            self.frameAvailable = False
        return frame

    def __exit__(self, exec_type, exc_value, traceback):
        self.capture.release()


# An FPS tracker that computes an exponentially weighted moving average of the frame rate.
class FPSTracker:
    def __init__(self, ratio=0.5):
        self._last_tick = None
        self._avg_fps = None
        self.ratio = ratio

    def tick(self):
        if self._last_tick is None:
            self._last_tick = time.time()
            return None
        t_new = time.time()
        fps_sample = 1.0 / (t_new - self._last_tick)
        self._avg_fps = (self.ratio * fps_sample + (1 - self.ratio) * self._avg_fps
                         if self._avg_fps is not None else fps_sample)
        self._last_tick = t_new
        return self.get()

    def get(self):
        return self._avg_fps
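
# A minimal sketch of how the two helpers above compose (illustrative only; the
# variable names are hypothetical, and the real loop in the main block below
# also runs the matting model between read() and tick()):
#
#   cam = Camera(device_id=0)
#   fps = FPSTracker()
#   while True:
#       frame = cam.read()      # blocks until the capture thread delivers a new frame
#       estimate = fps.tick()   # None on the first call, smoothed FPS afterwards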
# Wrapper for displaying the stream in a PyQt window. step() accepts an image and
# returns keypress info for basic interactivity.
# It also tracks FPS and optionally overlays it onto the stream.
class Displayer(QMainWindow):
    def __init__(self, title, width, height, show_info=True):
        self.width, self.height = width, height
        self.show_info = show_info
        self.fps_tracker = FPSTracker()
        QMainWindow.__init__(self)
        self.setWindowTitle(title)
        self.setFixedSize(width, height)
        self.setAttribute(QtCore.Qt.WA_TranslucentBackground, True)
        self.image_label = QtWidgets.QLabel(self)
        self.image_label.resize(width, height)
        self.key = None

    def keyPressEvent(self, event):
        self.key = event.text()

    def closeEvent(self, event):
        self.key = 'q'

    # Update the currently showing frame and return the key press char, if any.
    def step(self, image):
        fps_estimate = self.fps_tracker.tick()
        if self.show_info and fps_estimate is not None:
            message = f"{int(fps_estimate)} fps | {self.width}x{self.height}"
            cv2.putText(image, message, (10, 40), cv2.FONT_HERSHEY_SIMPLEX, 1.0, (0, 0, 0))
        pix = self.convert_cv_qt(image)
        self.image_label.setPixmap(pix)
        QApplication.processEvents()
        key = self.key
        self.key = None
        return key

    def convert_cv_qt(self, cv_img):
        """Convert an OpenCV image (RGB or RGBA) to a QPixmap."""
        h, w, ch = cv_img.shape
        bytes_per_line = ch * w
        if ch == 3:
            qt_image = QtGui.QImage(cv_img.data, w, h, bytes_per_line, QtGui.QImage.Format_RGB888)
        elif ch == 4:
            qt_image = QtGui.QImage(cv_img.data, w, h, bytes_per_line, QtGui.QImage.Format_RGBA8888)
        else:
            raise ValueError(f'unsupported channel count: {ch}')
        return QtGui.QPixmap.fromImage(qt_image)


# Convert a BGR OpenCV frame to a normalized NCHW RGB tensor (the transfer to
# the GPU happens in the main loop).
def cv2_frame_to_cuda(frame, datatype=torch.float32):
    frame = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
    return torch.as_tensor(frame).type(datatype).unsqueeze(0).permute(0, 3, 1, 2) / 255

# --------------- Main ---------------

if __name__ == "__main__":
    if args.torchscript_file:
        model = torch.jit.load(args.torchscript_file)
        model = torch.jit.freeze(model.eval())  # freeze requires eval mode
    else:
        model = torch.hub.load("PeterL1n/RobustVideoMatting", args.model_backbone).eval()
    if torch.cuda.is_available():
        model = model.cuda()
    datatype = torch.float32

    width, height = args.resolution
    cam = Camera(device_id=args.device_id, width=width, height=height)
    app = QApplication(['RobustVideoMatting'])
    dsp = Displayer('RobustVideoMatting', width, height, show_info=(not args.hide_fps))
    dsp.show()

    rec = [None] * 4  # Initial recurrent states.
    with torch.no_grad():
        while True:
            frame = cam.read()
            src = cv2_frame_to_cuda(frame, datatype)
            # Run the model and cycle the recurrent states.
            fgr, pha, *rec = model(src.cuda() if torch.cuda.is_available() else src, *rec, args.downsampling)
            # Use the alpha matte as an 8-bit alpha channel on top of the camera frame.
            res = pha.mul(255).byte().cpu().permute(0, 2, 3, 1).numpy()[0]
            b_channel, g_channel, r_channel = cv2.split(frame)
            img_RGBA = cv2.merge((r_channel, g_channel, b_channel, res))
            key = dsp.step(img_RGBA)
            if key == 'w':
                cam.brighter()
            elif key == 's':
                cam.darker()
            elif key == 'q':
                exit()
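
# Example invocation (the filename is an assumption; the flags are the ones
# defined by the argument parser above):
#
#   python inference_webcam.py --model-backbone mobilenetv3 --resolution 1280 720
#
# Keyboard controls in the window: 'w' raises exposure, 's' lowers it, 'q' quits.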