Преглед изворни кода

moved cv2_frame_to_cuda to inference_webcam

subDesTagesMitExtraKaese пре 3 године
родитељ
комит
fb5961f70b
3 измењених фајлова са 19 додато и 38 уклоњено
  1. 8 6
      inference_webcam.py
  2. 5 11
      inference_webcam_ts.py
  3. 6 21
      inference_webcam_ts_compositing.py

+ 8 - 6
inference_webcam.py

@@ -179,6 +179,11 @@ class Displayer(QMainWindow):
         return QtGui.QPixmap.fromImage(convert_to_Qt_format)
 
 
+
+def cv2_frame_to_cuda(frame, datatype = torch.float32):
+    frame = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
+    return ToTensor()(Image.fromarray(frame)).unsqueeze_(0).to(datatype).cuda()
+
 # --------------- Main ---------------
 
 if __name__ == "__main__":
@@ -197,6 +202,7 @@ if __name__ == "__main__":
   model = model.cuda().eval()
   model.load_state_dict(torch.load(args.model_checkpoint), strict=False)
 
+  datatype = torch.float16 if 'fp16' in args.model_checkpoint else torch.float32
 
   width, height = args.resolution
   cam = Camera(device_id=args.device_id, width=width, height=height)
@@ -204,10 +210,6 @@ if __name__ == "__main__":
   dsp = Displayer('MattingV2', cam.width, cam.height, show_info=(not args.hide_fps))
   dsp.show()
 
-  def cv2_frame_to_cuda(frame):
-      frame = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
-      return ToTensor()(Image.fromarray(frame)).unsqueeze_(0).cuda()
-
   with torch.no_grad():
       while True:
           bgr = None
@@ -216,7 +218,7 @@ if __name__ == "__main__":
               frameRGB = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
               key = dsp.step(frameRGB)
               if key == 'b':
-                bgr = cv2_frame_to_cuda(frame)
+                bgr = cv2_frame_to_cuda(frame, datatype)
                 break
               elif key == 'w':
                 cam.brighter()
@@ -226,7 +228,7 @@ if __name__ == "__main__":
                   exit()
           while True: # matting
               frame = cam.read()
-              src = cv2_frame_to_cuda(frame)
+              src = cv2_frame_to_cuda(frame, datatype)
               pha, fgr = model(src, bgr)[:2]
               res = pha * fgr + (1 - pha) * torch.ones_like(fgr)
               res = res.mul(255).byte().cpu().permute(0, 2, 3, 1).numpy()[0]

+ 5 - 11
inference_webcam_ts.py

@@ -34,7 +34,7 @@ from PIL import Image, ImageTk
 
 from dataset import VideoDataset
 from model import MattingBase, MattingRefine
-from inference_webcam import Displayer, FPSTracker, Camera
+from inference_webcam import Displayer, FPSTracker, Camera, cv2_frame_to_cuda
 
 
 # --------------- Arguments ---------------
@@ -67,6 +67,8 @@ model.refine_mode = args.model_refine_mode
 model.refine_sample_pixels = args.model_refine_sample_pixels
 model.refine_threshold = args.model_refine_threshold
 
+datatype = torch.float16 if 'fp16' in args.model_checkpoint else torch.float32
+
 model = model.to(torch.device('cuda'))
 
 width, height = args.resolution
@@ -75,14 +77,6 @@ app = QApplication(['MattingV2'])
 dsp = Displayer('MattingV2', cam.width, cam.height, show_info=(not args.hide_fps))
 dsp.show()
 
-
-def cv2_frame_to_cuda(frame):
-    frame = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
-    if 'fp16'in args.model_checkpoint:
-      return ToTensor()(Image.fromarray(frame)).unsqueeze_(0).to(torch.float16).cuda()
-    else:
-      return ToTensor()(Image.fromarray(frame)).unsqueeze_(0).to(torch.float32).cuda()
-
 with torch.no_grad():
     while True:
         bgr = None
@@ -91,7 +85,7 @@ with torch.no_grad():
             frameRGB = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
             key = dsp.step(frameRGB)
             if key == 'b':
-              bgr = cv2_frame_to_cuda(frame)
+              bgr = cv2_frame_to_cuda(frame, datatype)
               break
             elif key == 'w':
               cam.brighter()
@@ -102,7 +96,7 @@ with torch.no_grad():
 
         while True: # matting
             frame = cam.read()
-            src = cv2_frame_to_cuda(frame)
+            src = cv2_frame_to_cuda(frame, datatype)
             pha, fgr = model(src, bgr)[:2]
             res = pha
             res = res.mul(255).byte().cpu().permute(0, 2, 3, 1).numpy()[0]

+ 6 - 21
inference_webcam_ts_compositing.py

@@ -34,7 +34,7 @@ from PIL import Image, ImageTk
 
 from dataset import VideoDataset
 from model import MattingBase, MattingRefine
-from inference_webcam import Displayer, FPSTracker, Camera
+from inference_webcam import Displayer, FPSTracker, Camera, cv2_frame_to_cuda
 
 
 # --------------- Arguments ---------------
@@ -69,6 +69,8 @@ model.refine_mode = args.model_refine_mode
 model.refine_sample_pixels = args.model_refine_sample_pixels
 model.refine_threshold = args.model_refine_threshold
 
+datatype = torch.float16 if 'fp16' in args.model_checkpoint else torch.float32
+
 model = model.to(torch.device('cuda'))
 
 width, height = args.resolution
@@ -77,23 +79,6 @@ app = QApplication(['MattingV2'])
 dsp = Displayer('MattingV2', cam.width, cam.height, show_info=(not args.hide_fps))
 dsp.show()
 
-
-def cv2_frame_to_cuda(frame):
-    frame = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
-    if 'fp16'in args.model_checkpoint:
-      return ToTensor()(Image.fromarray(frame)).unsqueeze_(0).to(torch.float16).cuda()
-    else:
-      pic = Image.fromarray(frame)
-      img = torch.ByteTensor(torch.ByteStorage.from_buffer(pic.tobytes()))
-      img = img.cuda()
-      img = img.view(pic.size[1], pic.size[0], len(pic.getbands()))
-      img = img.permute((2, 0, 1)).contiguous()
-      tmp = img.to(dtype=default_float_dtype).div(255)
-      tmp.unsqueeze_(0)
-      tmp = tmp.to(torch.float32)
-      return tmp
-      #return ToTensor()(Image.fromarray(frame)).unsqueeze_(0).to(torch.float32).cuda()
-
 with torch.no_grad():
     while True:
         bgr = None
@@ -102,7 +87,7 @@ with torch.no_grad():
             frameRGB = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
             key = dsp.step(frameRGB)
             if key == 'b':
-              bgr = cv2_frame_to_cuda(frame)
+              bgr = cv2_frame_to_cuda(frame, datatype)
               break
             elif key == 'w':
               cam.brighter()
@@ -118,11 +103,11 @@ with torch.no_grad():
         else:
           bgImage = cv2.imread(args.background_image, cv2.IMREAD_UNCHANGED)
           bgImage = cv2.resize(bgImage, (frame.shape[1], frame.shape[0]))
-          bgImage = cv2_frame_to_cuda(bgImage)
+          bgImage = cv2_frame_to_cuda(bgImage, datatype)
 
         while True: # matting
             frame = cam.read()
-            src = cv2_frame_to_cuda(frame)
+            src = cv2_frame_to_cuda(frame, datatype)
             pha, fgr = model(src, bgr)[:2]
             res = pha * fgr + (1 - pha) * bgImage
             res = res.mul(255).byte().cpu().permute(0, 2, 3, 1).numpy()[0]