inference_webcam.py

import argparse
import time

import cv2
import torch
from PyQt5 import QtGui, QtCore, QtWidgets
from PyQt5.QtWidgets import QMainWindow, QApplication
from threading import Thread, Condition

# Let cuDNN auto-tune convolution algorithms for the fixed input size.
torch.backends.cudnn.benchmark = True

# --------------- Arguments ---------------

if __name__ == "__main__":
    parser = argparse.ArgumentParser(description='Inference from web-cam')
    parser.add_argument('--model-backbone', type=str, required=False, choices=['resnet50', 'mobilenetv3'])
    parser.add_argument('--torchscript-file', type=str, required=False, default=None)
    parser.add_argument('--hide-fps', action='store_true')
    parser.add_argument('--resolution', type=int, nargs=2, metavar=('width', 'height'), default=(1280, 720))
    parser.add_argument('--downsampling', type=float, default=0.25)
    parser.add_argument('--device-id', type=int, default=0)
    args = parser.parse_args()
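
# Example invocations (illustrative; the TorchScript file name is a placeholder,
# and one of --model-backbone / --torchscript-file must be supplied so a model
# can be loaded):
#   python inference_webcam.py --model-backbone mobilenetv3
#   python inference_webcam.py --torchscript-file rvm.torchscript --resolution 1920 1080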

# ----------- Utility classes -------------

# A wrapper that reads frames from cv2.VideoCapture on a background thread, so the
# main loop never blocks on camera I/O. Call .read() in a tight loop to get the
# newest frame.
class Camera:
    def __init__(self, device_id=0, width=1280, height=720):
        self.capture = cv2.VideoCapture(device_id)
        self.capture.set(cv2.CAP_PROP_FRAME_WIDTH, width)
        self.capture.set(cv2.CAP_PROP_FRAME_HEIGHT, height)
        # The camera may not honor the requested resolution; read back the actual one.
        self.width = int(self.capture.get(cv2.CAP_PROP_FRAME_WIDTH))
        self.height = int(self.capture.get(cv2.CAP_PROP_FRAME_HEIGHT))
        # self.capture.set(cv2.CAP_PROP_BUFFERSIZE, 2)
        self.exposure = self.capture.get(cv2.CAP_PROP_EXPOSURE)
        self.capture.set(cv2.CAP_PROP_BACKLIGHT, 0)
        self.capture.set(cv2.CAP_PROP_EXPOSURE, self.exposure)
        self.frameAvailable = False
        self.success_reading, self.frame = self.capture.read()
        self.cv = Condition()
        self.thread = Thread(target=self.__update, args=())
        self.thread.daemon = True
        self.thread.start()

    def __update(self):
        # Producer: keep grabbing frames and publish the latest one under the condition.
        while self.success_reading:
            grabbed, frame = self.capture.read()
            with self.cv:
                self.success_reading = grabbed
                self.frame = frame
                self.frameAvailable = True
                self.cv.notify()

    # Exposure values are driver-specific (a log scale on many UVC cameras);
    # clamp adjustments to the [-12, -2] range.
    def brighter(self):
        if self.exposure < -2:
            self.exposure += 1
            self.capture.set(cv2.CAP_PROP_EXPOSURE, self.exposure)
            print(self.exposure)

    def darker(self):
        if self.exposure > -12:
            self.exposure -= 1
            self.capture.set(cv2.CAP_PROP_EXPOSURE, self.exposure)
            print(self.exposure)

    def read(self):
        # Consumer: block until a fresh frame has been published, then hand out a copy.
        with self.cv:
            self.cv.wait_for(lambda: self.frameAvailable)
            frame = self.frame.copy()
            self.frameAvailable = False
            return frame

    def __enter__(self):
        return self

    # Release the camera when leaving a `with Camera(...)` block.
    def __exit__(self, exc_type, exc_value, traceback):
        self.capture.release()
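
# Minimal usage sketch (illustrative): the background thread keeps the buffer
# fresh, so each read() returns the most recent frame rather than a stale one.
#   cam = Camera(device_id=0)
#   frame = cam.read()  # blocks until a new frame has been captured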

# An FPS tracker that computes an exponentially moving average of the frame rate.
class FPSTracker:
    def __init__(self, ratio=0.5):
        self._last_tick = None
        self._avg_fps = None
        self.ratio = ratio

    def tick(self):
        # The first tick only primes the timer; there is no interval to measure yet.
        if self._last_tick is None:
            self._last_tick = time.time()
            return None
        t_new = time.time()
        fps_sample = 1.0 / (t_new - self._last_tick)
        if self._avg_fps is None:
            self._avg_fps = fps_sample
        else:
            self._avg_fps = self.ratio * fps_sample + (1 - self.ratio) * self._avg_fps
        self._last_tick = t_new
        return self.get()

    def get(self):
        return self._avg_fps
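
# Usage sketch (illustrative): call tick() once per displayed frame.
#   tracker = FPSTracker()
#   tracker.tick()         # first call primes the timer and returns None
#   time.sleep(1 / 30)
#   print(tracker.tick())  # roughly 30 after the second call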

# A PyQt5 window that shows the stream in a QLabel. step() accepts an image and
# returns key-press info for basic interactivity; it also tracks FPS and can
# overlay it onto the stream.
class Displayer(QMainWindow):
    def __init__(self, title, width, height, show_info=True):
        QMainWindow.__init__(self)
        self.setWindowTitle(title)
        self.width, self.height = width, height
        self.show_info = show_info
        self.fps_tracker = FPSTracker()
        self.setFixedSize(width, height)
        self.setAttribute(QtCore.Qt.WA_TranslucentBackground, True)
        self.image_label = QtWidgets.QLabel(self)
        self.image_label.resize(width, height)
        self.key = None

    def keyPressEvent(self, event):
        self.key = event.text()

    def closeEvent(self, event):
        # Treat closing the window like pressing 'q'.
        self.key = 'q'

    # Update the currently showing frame and return the pressed key, if any.
    def step(self, image):
        fps_estimate = self.fps_tracker.tick()
        if self.show_info and fps_estimate is not None:
            message = f"{int(fps_estimate)} fps | {self.width}x{self.height}"
            cv2.putText(image, message, (10, 40), cv2.FONT_HERSHEY_SIMPLEX, 1.0, (0, 0, 0))
        pix = self.convert_cv_qt(image)
        self.image_label.setPixmap(pix)
        # Pump the Qt event loop manually, since app.exec_() is never called.
        QApplication.processEvents()
        key = self.key
        self.key = None
        return key

    def convert_cv_qt(self, cv_img):
        """Convert an OpenCV image (RGB or RGBA channel order) to a QPixmap."""
        h, w, ch = cv_img.shape
        bytes_per_line = ch * w
        if ch == 3:
            qt_image = QtGui.QImage(cv_img.data, w, h, bytes_per_line, QtGui.QImage.Format_RGB888)
        elif ch == 4:
            qt_image = QtGui.QImage(cv_img.data, w, h, bytes_per_line, QtGui.QImage.Format_RGBA8888)
        else:
            raise ValueError(f"Unsupported channel count: {ch}")
        return QtGui.QPixmap.fromImage(qt_image)

# Convert a BGR uint8 frame to a normalized NCHW tensor in [0, 1]. Despite the
# name, this does not move the tensor to the GPU; the caller does that.
def cv2_frame_to_cuda(frame, datatype=torch.float32):
    frame = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
    return torch.as_tensor(frame).type(datatype).unsqueeze(0).permute(0, 3, 1, 2) / 255
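
# Shape sketch (illustrative): a 720p BGR frame becomes a batched CHW tensor.
#   frame = cam.read()               # (720, 1280, 3) uint8, BGR
#   src = cv2_frame_to_cuda(frame)   # (1, 3, 720, 1280) float32 in [0, 1]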

# --------------- Main ---------------

if __name__ == "__main__":
    if args.torchscript_file:
        # TorchScript exports are saved in eval mode; freezing inlines weights and attributes.
        model = torch.jit.load(args.torchscript_file)
        model = torch.jit.freeze(model)
    else:
        model = torch.hub.load("PeterL1n/RobustVideoMatting", args.model_backbone)
        model = model.eval()
    if torch.cuda.is_available():
        model = model.cuda()
    datatype = torch.float32

    width, height = args.resolution
    cam = Camera(device_id=args.device_id, width=width, height=height)
    app = QApplication(['RobustVideoMatting'])
    dsp = Displayer('RobustVideoMatting', width, height, show_info=(not args.hide_fps))
    dsp.show()

    rec = [None] * 4  # Initial recurrent states.
    with torch.no_grad():
        while True:
            frame = cam.read()
            src = cv2_frame_to_cuda(frame, datatype)
            # Run the model and cycle the updated recurrent states into the next call.
            fgr, pha, *rec = model(src.cuda() if torch.cuda.is_available() else src, *rec, args.downsampling)
            # Use the alpha matte as a transparency channel over the original frame.
            res = pha.mul(255).byte().cpu().permute(0, 2, 3, 1).numpy()[0]
            b_channel, g_channel, r_channel = cv2.split(frame)
            img_RGBA = cv2.merge((r_channel, g_channel, b_channel, res))
            key = dsp.step(img_RGBA)
            if key == 'w':
                cam.brighter()
            elif key == 's':
                cam.darker()
            elif key == 'q':
                exit()
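
# A common variant (sketch, not part of this script): composite the predicted
# foreground over a solid background instead of emitting an RGBA frame, using the
# com = fgr * pha + bgr * (1 - pha) recipe from the RobustVideoMatting docs:
#   bgr = torch.tensor([0.47, 1.0, 0.6], device=src.device).view(3, 1, 1)  # green
#   com = fgr * pha + bgr * (1 - pha)  # (1, 3, H, W), ready for display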