|
@@ -34,7 +34,7 @@ from PIL import Image, ImageTk
|
|
|
|
|
|
from dataset import VideoDataset
|
|
|
from model import MattingBase, MattingRefine
|
|
|
-from inference_webcam import Displayer, FPSTracker, Camera
|
|
|
+from inference_webcam import Displayer, FPSTracker, Camera, cv2_frame_to_cuda
|
|
|
|
|
|
|
|
|
# --------------- Arguments ---------------
|
|
@@ -69,6 +69,8 @@ model.refine_mode = args.model_refine_mode
|
|
|
model.refine_sample_pixels = args.model_refine_sample_pixels
|
|
|
model.refine_threshold = args.model_refine_threshold
|
|
|
|
|
|
+datatype = torch.float16 if 'fp16' in args.model_checkpoint else torch.float32
|
|
|
+
|
|
|
model = model.to(torch.device('cuda'))
|
|
|
|
|
|
width, height = args.resolution
|
|
@@ -77,23 +79,6 @@ app = QApplication(['MattingV2'])
|
|
|
dsp = Displayer('MattingV2', cam.width, cam.height, show_info=(not args.hide_fps))
|
|
|
dsp.show()
|
|
|
|
|
|
-
|
|
|
-def cv2_frame_to_cuda(frame):
|
|
|
- frame = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
|
|
|
- if 'fp16'in args.model_checkpoint:
|
|
|
- return ToTensor()(Image.fromarray(frame)).unsqueeze_(0).to(torch.float16).cuda()
|
|
|
- else:
|
|
|
- pic = Image.fromarray(frame)
|
|
|
- img = torch.ByteTensor(torch.ByteStorage.from_buffer(pic.tobytes()))
|
|
|
- img = img.cuda()
|
|
|
- img = img.view(pic.size[1], pic.size[0], len(pic.getbands()))
|
|
|
- img = img.permute((2, 0, 1)).contiguous()
|
|
|
- tmp = img.to(dtype=default_float_dtype).div(255)
|
|
|
- tmp.unsqueeze_(0)
|
|
|
- tmp = tmp.to(torch.float32)
|
|
|
- return tmp
|
|
|
- #return ToTensor()(Image.fromarray(frame)).unsqueeze_(0).to(torch.float32).cuda()
|
|
|
-
|
|
|
with torch.no_grad():
|
|
|
while True:
|
|
|
bgr = None
|
|
@@ -102,7 +87,7 @@ with torch.no_grad():
|
|
|
frameRGB = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
|
|
|
key = dsp.step(frameRGB)
|
|
|
if key == 'b':
|
|
|
- bgr = cv2_frame_to_cuda(frame)
|
|
|
+ bgr = cv2_frame_to_cuda(frame, datatype)
|
|
|
break
|
|
|
elif key == 'w':
|
|
|
cam.brighter()
|
|
@@ -118,11 +103,11 @@ with torch.no_grad():
|
|
|
else:
|
|
|
bgImage = cv2.imread(args.background_image, cv2.IMREAD_UNCHANGED)
|
|
|
bgImage = cv2.resize(bgImage, (frame.shape[1], frame.shape[0]))
|
|
|
- bgImage = cv2_frame_to_cuda(bgImage)
|
|
|
+ bgImage = cv2_frame_to_cuda(bgImage, datatype)
|
|
|
|
|
|
while True: # matting
|
|
|
frame = cam.read()
|
|
|
- src = cv2_frame_to_cuda(frame)
|
|
|
+ src = cv2_frame_to_cuda(frame, datatype)
|
|
|
pha, fgr = model(src, bgr)[:2]
|
|
|
res = pha * fgr + (1 - pha) * bgImage
|
|
|
res = res.mul(255).byte().cpu().permute(0, 2, 3, 1).numpy()[0]
|