# inference_utils.py
import cv2
import numpy as np
from PIL import Image
  4. class HomographicAlignment:
  5. """
  6. Apply homographic alignment on background to match with the source image.
  7. """
  8. def __init__(self):
  9. self.detector = cv2.ORB_create()
  10. self.matcher = cv2.DescriptorMatcher_create(cv2.DESCRIPTOR_MATCHER_BRUTEFORCE)
  11. def __call__(self, src, bgr):
  12. src = np.asarray(src)
  13. bgr = np.asarray(bgr)
  14. keypoints_src, descriptors_src = self.detector.detectAndCompute(src, None)
  15. keypoints_bgr, descriptors_bgr = self.detector.detectAndCompute(bgr, None)
  16. matches = self.matcher.match(descriptors_bgr, descriptors_src, None)
  17. matches.sort(key=lambda x: x.distance, reverse=False)
  18. num_good_matches = int(len(matches) * 0.15)
  19. matches = matches[:num_good_matches]
  20. points_src = np.zeros((len(matches), 2), dtype=np.float32)
  21. points_bgr = np.zeros((len(matches), 2), dtype=np.float32)
  22. for i, match in enumerate(matches):
  23. points_src[i, :] = keypoints_src[match.trainIdx].pt
  24. points_bgr[i, :] = keypoints_bgr[match.queryIdx].pt
  25. H, _ = cv2.findHomography(points_bgr, points_src, cv2.RANSAC)
  26. h, w = src.shape[:2]
  27. bgr = cv2.warpPerspective(bgr, H, (w, h))
  28. msk = cv2.warpPerspective(np.ones((h, w)), H, (w, h))
  29. # For areas that is outside of the background,
  30. # We just copy pixels from the source.
  31. bgr[msk != 1] = src[msk != 1]
  32. src = Image.fromarray(src)
  33. bgr = Image.fromarray(bgr)
  34. return src, bgr