Explorar o código

src image conversion on gpu

subDesTagesMitExtraKaese %!s(int64=4) %!d(string=hai) anos
pai
achega
c995c0230d
Modificáronse 1 ficheiros con 10 adicións e 1 borrados
  1. 10 1
      inference_webcam.py

+ 10 - 1
inference_webcam.py

@@ -182,7 +182,16 @@ class Displayer(QMainWindow):
 
 def cv2_frame_to_cuda(frame, datatype = torch.float32):
     frame = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
-    return ToTensor()(Image.fromarray(frame)).unsqueeze_(0).to(datatype).cuda()
+    #return ToTensor()(Image.fromarray(frame)).unsqueeze_(0).to(datatype).cuda()
+    pic = Image.fromarray(frame)
+    img = torch.ByteTensor(torch.ByteStorage.from_buffer(pic.tobytes()))
+    img = img.cuda()
+    img = img.view(pic.size[1], pic.size[0], len(pic.getbands()))
+    img = img.permute((2, 0, 1)).contiguous()
+    tmp = img.to(dtype=datatype).div(255)
+    tmp.unsqueeze_(0)
+    tmp = tmp.to(datatype)
+    return tmp
 
 # --------------- Main ---------------