# refiner.py

import torch
import torchvision
from torch import nn
from torch.nn import functional as F
from typing import Tuple


class Refiner(nn.Module):
  7. """
  8. Refiner refines the coarse output to full resolution.
  9. Args:
  10. mode: area selection mode. Options:
  11. "full" - No area selection, refine everywhere using regular Conv2d.
  12. "sampling" - Refine fixed amount of pixels ranked by the top most errors.
  13. "thresholding" - Refine varying amount of pixels that have greater error than the threshold.
  14. sample_pixels: number of pixels to refine. Only used when mode == "sampling".
  15. threshold: error threshold ranged from 0 ~ 1. Refine where err > threshold. Only used when mode == "thresholding".
  16. kernel_size: The convolution kernel_size. Options: [1, 3]
  17. prevent_oversampling: True for regular cases, False for speedtest.
  18. Compatibility Args:
  19. patch_crop_method: the method for cropping patches. Options:
  20. "unfold" - Best performance for PyTorch and TorchScript.
  21. "roi_align" - Another way for croping patches.
  22. "gather" - Another way for croping patches.
  23. patch_replace_method: the method for replacing patches. Options:
  24. "scatter_nd" - Best performance for PyTorch and TorchScript.
  25. "scatter_element" - Another way for replacing patches.
  26. Input:
  27. src: (B, 3, H, W) full resolution source image.
  28. bgr: (B, 3, H, W) full resolution background image.
  29. pha: (B, 1, Hc, Wc) coarse alpha prediction.
  30. fgr: (B, 3, Hc, Wc) coarse foreground residual prediction.
  31. err: (B, 1, Hc, Hc) coarse error prediction.
  32. hid: (B, 32, Hc, Hc) coarse hidden encoding.
  33. Output:
  34. pha: (B, 1, H, W) full resolution alpha prediction.
  35. fgr: (B, 3, H, W) full resolution foreground residual prediction.
  36. ref: (B, 1, H/4, W/4) quarter resolution refinement selection map. 1 indicates refined 4x4 patch locations.
  37. """

    # For TorchScript export optimization.
    __constants__ = ['kernel_size', 'patch_crop_method', 'patch_replace_method']
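
    # Typical constructions (a sketch with illustrative values, not defaults
    # taken from this file):
    #   Refiner('full', sample_pixels=0, threshold=0.0)           # refine everywhere
    #   Refiner('sampling', sample_pixels=80_000, threshold=0.0)  # fixed refinement budget
    #   Refiner('thresholding', sample_pixels=0, threshold=0.1)   # refine where err > 0.1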

    def __init__(self,
                 mode: str,
                 sample_pixels: int,
                 threshold: float,
                 kernel_size: int = 3,
                 prevent_oversampling: bool = True,
                 patch_crop_method: str = 'unfold',
                 patch_replace_method: str = 'scatter_nd'):
        super().__init__()
        assert mode in ['full', 'sampling', 'thresholding']
        assert kernel_size in [1, 3]
        assert patch_crop_method in ['unfold', 'roi_align', 'gather']
        assert patch_replace_method in ['scatter_nd', 'scatter_element']

        self.mode = mode
        self.sample_pixels = sample_pixels
        self.threshold = threshold
        self.kernel_size = kernel_size
        self.prevent_oversampling = prevent_oversampling
        self.patch_crop_method = patch_crop_method
        self.patch_replace_method = patch_replace_method

        channels = [32, 24, 16, 12, 4]
        self.conv1 = nn.Conv2d(channels[0] + 6 + 4, channels[1], kernel_size, bias=False)
        self.bn1 = nn.BatchNorm2d(channels[1])
        self.conv2 = nn.Conv2d(channels[1], channels[2], kernel_size, bias=False)
        self.bn2 = nn.BatchNorm2d(channels[2])
        self.conv3 = nn.Conv2d(channels[2] + 6, channels[3], kernel_size, bias=False)
        self.bn3 = nn.BatchNorm2d(channels[3])
        self.conv4 = nn.Conv2d(channels[3], channels[4], kernel_size, bias=True)
        self.relu = nn.ReLU(True)
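
    # Channel bookkeeping for the convolutions above (derived from forward()):
    #   conv1 input  = 32 (hid) + 1 (pha) + 3 (fgr) + 3 (src) + 3 (bgr) = channels[0] + 6 + 4
    #   conv3 input  = channels[2] + 3 (src) + 3 (bgr) at full resolution
    #   conv4 output = 4 = 1 (pha) + 3 (fgr)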

    def forward(self,
                src: torch.Tensor,
                bgr: torch.Tensor,
                pha: torch.Tensor,
                fgr: torch.Tensor,
                err: torch.Tensor,
                hid: torch.Tensor):
        H_full, W_full = src.shape[2:]
        H_half, W_half = H_full // 2, W_full // 2
        H_quat, W_quat = H_full // 4, W_full // 4

        src_bgr = torch.cat([src, bgr], dim=1)

        if self.mode != 'full':
            # Select the quarter-resolution locations to refine.
            err = F.interpolate(err, (H_quat, W_quat), mode='bilinear', align_corners=False)
            ref = self.select_refinement_regions(err)
            idx = torch.nonzero(ref.squeeze(1))
            idx = idx[:, 0], idx[:, 1], idx[:, 2]

            if idx[0].size(0) > 0:
                # First stage: refine patches at half resolution.
                x = torch.cat([hid, pha, fgr], dim=1)
                x = F.interpolate(x, (H_half, W_half), mode='bilinear', align_corners=False)
                x = self.crop_patch(x, idx, 2, 3 if self.kernel_size == 3 else 0)

                y = F.interpolate(src_bgr, (H_half, W_half), mode='bilinear', align_corners=False)
                y = self.crop_patch(y, idx, 2, 3 if self.kernel_size == 3 else 0)

                x = self.conv1(torch.cat([x, y], dim=1))
                x = self.bn1(x)
                x = self.relu(x)
                x = self.conv2(x)
                x = self.bn2(x)
                x = self.relu(x)

                # Second stage: upsample the patches and refine at full resolution.
                x = F.interpolate(x, 8 if self.kernel_size == 3 else 4, mode='nearest')
                y = self.crop_patch(src_bgr, idx, 4, 2 if self.kernel_size == 3 else 0)

                x = self.conv3(torch.cat([x, y], dim=1))
                x = self.bn3(x)
                x = self.relu(x)
                x = self.conv4(x)

                # Paste the refined 4x4 patches back into the upsampled coarse output.
                out = torch.cat([pha, fgr], dim=1)
                out = F.interpolate(out, (H_full, W_full), mode='bilinear', align_corners=False)
                out = self.replace_patch(out, x, idx)
                pha = out[:, :1]
                fgr = out[:, 1:]
            else:
                # Nothing selected: just upsample the coarse predictions.
                pha = F.interpolate(pha, (H_full, W_full), mode='bilinear', align_corners=False)
                fgr = F.interpolate(fgr, (H_full, W_full), mode='bilinear', align_corners=False)
        else:
            # Refine everywhere. The padding compensates for the four valid 3x3
            # convolutions so the output returns to full resolution.
            x = torch.cat([hid, pha, fgr], dim=1)
            x = F.interpolate(x, (H_half, W_half), mode='bilinear', align_corners=False)
            y = F.interpolate(src_bgr, (H_half, W_half), mode='bilinear', align_corners=False)
            if self.kernel_size == 3:
                x = F.pad(x, (3, 3, 3, 3))
                y = F.pad(y, (3, 3, 3, 3))

            x = self.conv1(torch.cat([x, y], dim=1))
            x = self.bn1(x)
            x = self.relu(x)
            x = self.conv2(x)
            x = self.bn2(x)
            x = self.relu(x)

            if self.kernel_size == 3:
                x = F.interpolate(x, (H_full + 4, W_full + 4))
                y = F.pad(src_bgr, (2, 2, 2, 2))
            else:
                x = F.interpolate(x, (H_full, W_full), mode='nearest')
                y = src_bgr

            x = self.conv3(torch.cat([x, y], dim=1))
            x = self.bn3(x)
            x = self.relu(x)
            x = self.conv4(x)
            pha = x[:, :1]
            fgr = x[:, 1:]
            ref = torch.ones((src.size(0), 1, H_quat, W_quat), device=src.device, dtype=src.dtype)

        return pha, fgr, ref
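
    # Patch geometry used above, for kernel_size == 3 (my reading of the code):
    # each selected quarter-resolution location maps to a 2x2 center at half
    # resolution, expanded by padding 3 to an 8x8 crop; two valid 3x3 convs
    # shrink it 8 -> 6 -> 4; it is then upsampled to 8x8, matched with an 8x8
    # full-resolution crop (4x4 center, padding 2), and two more valid convs
    # shrink it back to the 4x4 patch that replace_patch() writes out.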

    def select_refinement_regions(self, err: torch.Tensor):
        """
        Select refinement regions.

        Input:
            err: error map (B, 1, H, W)

        Output:
            ref: refinement regions (B, 1, H, W). FloatTensor. 1 is selected, 0 is not.
        """
        if self.mode == 'sampling':
            # Sampling mode.
            b, _, h, w = err.shape
            err = err.view(b, -1)
            idx = err.topk(self.sample_pixels // 16, dim=1, sorted=False).indices
            ref = torch.zeros_like(err)
            ref.scatter_(1, idx, 1.)
            if self.prevent_oversampling:
                ref.mul_(err.gt(0).float())
            ref = ref.view(b, 1, h, w)
        else:
            # Thresholding mode.
            ref = err.gt(self.threshold).float()
        return ref
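
    # Note on the budget arithmetic (as implied by the code above): each selected
    # cell of the quarter-resolution map corresponds to a 4x4 = 16-pixel patch at
    # full resolution, hence topk takes sample_pixels // 16 cells. For example,
    # sample_pixels=80_000 selects 5_000 patches.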

    def crop_patch(self,
                   x: torch.Tensor,
                   idx: Tuple[torch.Tensor, torch.Tensor, torch.Tensor],
                   size: int,
                   padding: int):
        """
        Crops selected patches from image given indices.

        Inputs:
            x: image (B, C, H, W).
            idx: selection indices Tuple[(P,), (P,), (P,)], where the 3 values are (B, H, W) index.
            size: center size of the patch, also stride of the crop.
            padding: expansion size of the patch.

        Output:
            patch: (P, C, h, w), where h = w = size + 2 * padding.
        """
        if padding != 0:
            x = F.pad(x, (padding,) * 4)

        if self.patch_crop_method == 'unfold':
            # Use unfold. Best performance for PyTorch and TorchScript.
            return x.permute(0, 2, 3, 1) \
                    .unfold(1, size + 2 * padding, size) \
                    .unfold(2, size + 2 * padding, size)[idx[0], idx[1], idx[2]]
        elif self.patch_crop_method == 'roi_align':
            # Use roi_align. Best compatibility for ONNX.
            idx = idx[0].type_as(x), idx[1].type_as(x), idx[2].type_as(x)
            b = idx[0]
            x1 = idx[2] * size - 0.5
            y1 = idx[1] * size - 0.5
            x2 = idx[2] * size + size + 2 * padding - 0.5
            y2 = idx[1] * size + size + 2 * padding - 0.5
            boxes = torch.stack([b, x1, y1, x2, y2], dim=1)
            return torchvision.ops.roi_align(x, boxes, size + 2 * padding, sampling_ratio=1)
        else:
            # Use gather. Crops out patches pixel by pixel.
            idx_pix = self.compute_pixel_indices(x, idx, size, padding)
            pat = torch.gather(x.view(-1), 0, idx_pix.view(-1))
            pat = pat.view(-1, x.size(1), size + 2 * padding, size + 2 * padding)
            return pat
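
    # Coordinate convention in crop_patch (as I read it): after padding, the
    # patch for index (b, i, j) starts at row i * size and column j * size of
    # the padded image, i.e. it covers the size x size center plus `padding`
    # pixels on each side of the original image. All three crop methods follow
    # this same convention.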

    def replace_patch(self,
                      x: torch.Tensor,
                      y: torch.Tensor,
                      idx: Tuple[torch.Tensor, torch.Tensor, torch.Tensor]):
        """
        Replaces patches back into image given index.

        Inputs:
            x: image (B, C, H, W)
            y: patches (P, C, h, w)
            idx: selection indices Tuple[(P,), (P,), (P,)] where the 3 values are (B, H, W) index.

        Output:
            image: (B, C, H, W), where patches at idx locations are replaced with y.
        """
        xB, xC, xH, xW = x.shape
        yB, yC, yH, yW = y.shape
        if self.patch_replace_method == 'scatter_nd':
            # Use scatter_nd. Best performance for PyTorch and TorchScript. Replacing patch by patch.
            x = x.view(xB, xC, xH // yH, yH, xW // yW, yW).permute(0, 2, 4, 1, 3, 5)
            x[idx[0], idx[1], idx[2]] = y
            x = x.permute(0, 3, 1, 4, 2, 5).view(xB, xC, xH, xW)
            return x
        else:
            # Use scatter_element. Best compatibility for ONNX. Replacing pixel by pixel.
            idx_pix = self.compute_pixel_indices(x, idx, size=4, padding=0)
            return x.view(-1).scatter_(0, idx_pix.view(-1), y.view(-1)).view(x.shape)
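
    # The scatter_nd path reshapes (B, C, H, W) into a grid of non-overlapping
    # yH x yW tiles, (B, H // yH, W // yW, C, yH, yW), so an entire patch can be
    # written with a single advanced-indexing assignment per selected location.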

    def compute_pixel_indices(self,
                              x: torch.Tensor,
                              idx: Tuple[torch.Tensor, torch.Tensor, torch.Tensor],
                              size: int,
                              padding: int):
        """
        Compute selected pixel indices in the tensor.
        Used for crop_method == 'gather' and replace_method == 'scatter_element', which crop and replace pixel by pixel.

        Input:
            x: image: (B, C, H, W)
            idx: selection indices Tuple[(P,), (P,), (P,)], where the 3 values are (B, H, W) index.
            size: center size of the patch, also stride of the crop.
            padding: expansion size of the patch.

        Output:
            idx: (P, C, O, O) long tensor, where O is the output size: size + 2 * padding, and P is the number of patches.
                 The elements are indices pointing into the input x.view(-1).
        """
        B, C, H, W = x.shape
        S, P = size, padding
        O = S + 2 * P
        b, y, x = idx
        n = b.size(0)
        c = torch.arange(C)
        o = torch.arange(O)
        # Offsets of every pixel within one patch, flattened over (C, H, W)...
        idx_pat = (c * H * W).view(C, 1, 1).expand([C, O, O]) \
                + (o * W).view(1, O, 1).expand([C, O, O]) \
                + o.view(1, 1, O).expand([C, O, O])
        # ...plus the flattened offset of each patch's top-left corner.
        idx_loc = b * W * H + y * W * S + x * S
        idx_pix = idx_loc.view(-1, 1, 1, 1).expand([n, C, O, O]) \
                + idx_pat.view(1, C, O, O).expand([n, C, O, O])
        return idx_pix
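

# ---------------------------------------------------------------------------
# Hedged usage sketch (not part of the original module): exercises Refiner on
# random tensors with shapes taken from the class docstring. The coarse inputs
# are assumed to be at 1/4 of the full resolution, matching the quarter-
# resolution refinement map; the resolutions and sample_pixels value are
# illustrative only.
# ---------------------------------------------------------------------------
if __name__ == '__main__':
    B, H, W = 1, 256, 256           # full resolution
    Hc, Wc = H // 4, W // 4         # assumed coarse resolution

    refiner = Refiner(mode='sampling', sample_pixels=16_000, threshold=0.1)
    refiner.eval()                  # use BatchNorm running stats

    src = torch.rand(B, 3, H, W)    # full resolution source image
    bgr = torch.rand(B, 3, H, W)    # full resolution background image
    pha = torch.rand(B, 1, Hc, Wc)  # coarse alpha prediction
    fgr = torch.rand(B, 3, Hc, Wc)  # coarse foreground residual prediction
    err = torch.rand(B, 1, Hc, Wc)  # coarse error prediction
    hid = torch.rand(B, 32, Hc, Wc) # coarse hidden encoding

    with torch.no_grad():
        pha_full, fgr_full, ref = refiner(src, bgr, pha, fgr, err, hid)

    print(pha_full.shape)  # expected: (B, 1, H, W)
    print(fgr_full.shape)  # expected: (B, 3, H, W)
    print(ref.shape)       # expected: (B, 1, H // 4, W // 4)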