|
@@ -0,0 +1,250 @@
|
|
|
+import torch
|
|
|
+import torchvision
|
|
|
+from torch import nn
|
|
|
+from torch.nn import functional as F
|
|
|
+from typing import Tuple
|
|
|
+
|
|
|
+
|
|
|
+class Refiner(nn.Module):
|
|
|
+ """
|
|
|
+ Refiner refines the coarse output to full resolution.
|
|
|
+
|
|
|
+ Args:
|
|
|
+ mode: area selection mode. Options:
|
|
|
+ "full" - No area selection, refine everywhere using regular Conv2d.
|
|
|
+ "sampling" - Refine fixed amount of pixels ranked by the top most errors.
|
|
|
+ "thresholding" - Refine varying amount of pixels that have greater error than the threshold.
|
|
|
+ sample_pixels: number of pixels to refine. Only used when mode == "sampling".
|
|
|
+ threshold: error threshold ranged from 0 ~ 1. Refine where err > threshold. Only used when mode == "thresholding".
|
|
|
+ kernel_size: The convolution kernel_size. Options: [1, 3]
|
|
|
+ prevent_oversampling: True for regular cases, False for speedtest.
|
|
|
+
|
|
|
+ Compatibility Args:
|
|
|
+ patch_crop_method: the operation for cropping patches. Options:
|
|
|
+ "unfold" - Best performance for PyTorch and TorchScript.
|
|
|
+ "roi_align" - Best compatibility for ONNX export.
|
|
|
+ patch_replace_method: the operation for replacing patches. Options:
|
|
|
+ "scatter_nd" - Best performance for PyTorch and TorchScript.
|
|
|
+ "scatter_element" - Best compatibility for ONNX export.
|
|
|
+
|
|
|
+ Input:
|
|
|
+ src: (B, 3, H, W) full resolution source image.
|
|
|
+ bgr: (B, 3, H, W) full resolution background image.
|
|
|
+ pha: (B, 1, Hc, Wc) coarse alpha prediction.
|
|
|
+ fgr: (B, 3, Hc, Wc) coarse foreground residual prediction.
|
|
|
+ err: (B, 1, Hc, Hc) coarse error prediction.
|
|
|
+ hid: (B, 32, Hc, Hc) coarse hidden encoding.
|
|
|
+
|
|
|
+ Output:
|
|
|
+ pha: (B, 1, H, W) full resolution alpha prediction.
|
|
|
+ fgr: (B, 3, H, W) full resolution foreground residual prediction.
|
|
|
+ ref: (B, 1, H/4, W/4) quarter resolution refinement selection map. 1 indicates refined 4x4 patch locations.
|
|
|
+ """
|
|
|
+
|
|
|
+ # For TorchScript export optimization.
|
|
|
+ __constants__ = ['kernel_size', 'patch_crop_method', 'patch_replace_method']
|
|
|
+
|
|
|
+ def __init__(self,
|
|
|
+ mode: str,
|
|
|
+ sample_pixels: int,
|
|
|
+ threshold: float,
|
|
|
+ kernel_size: int = 3,
|
|
|
+ prevent_oversampling: bool = True,
|
|
|
+ patch_crop_method: str = 'unfold',
|
|
|
+ patch_replace_method: str = 'scatter_nd'):
|
|
|
+ super().__init__()
|
|
|
+ assert mode in ['full', 'sampling', 'thresholding']
|
|
|
+ assert kernel_size in [1, 3]
|
|
|
+ assert patch_crop_method in ['unfold', 'roi_align']
|
|
|
+ assert patch_replace_method in ['scatter_nd', 'scatter_element']
|
|
|
+
|
|
|
+ self.mode = mode
|
|
|
+ self.sample_pixels = sample_pixels
|
|
|
+ self.threshold = threshold
|
|
|
+ self.kernel_size = kernel_size
|
|
|
+ self.prevent_oversampling = prevent_oversampling
|
|
|
+ self.patch_crop_method = patch_crop_method
|
|
|
+ self.patch_replace_method = patch_replace_method
|
|
|
+
|
|
|
+ channels = [32, 24, 16, 12, 4]
|
|
|
+ self.conv1 = nn.Conv2d(channels[0] + 6 + 4, channels[1], kernel_size, bias=False)
|
|
|
+ self.bn1 = nn.BatchNorm2d(channels[1])
|
|
|
+ self.conv2 = nn.Conv2d(channels[1], channels[2], kernel_size, bias=False)
|
|
|
+ self.bn2 = nn.BatchNorm2d(channels[2])
|
|
|
+ self.conv3 = nn.Conv2d(channels[2] + 6, channels[3], kernel_size, bias=False)
|
|
|
+ self.bn3 = nn.BatchNorm2d(channels[3])
|
|
|
+ self.conv4 = nn.Conv2d(channels[3], channels[4], kernel_size, bias=True)
|
|
|
+ self.relu = nn.ReLU(True)
|
|
|
+
|
|
|
+ def forward(self,
|
|
|
+ src: torch.Tensor,
|
|
|
+ bgr: torch.Tensor,
|
|
|
+ pha: torch.Tensor,
|
|
|
+ fgr: torch.Tensor,
|
|
|
+ err: torch.Tensor,
|
|
|
+ hid: torch.Tensor):
|
|
|
+ H_full, W_full = src.shape[2:]
|
|
|
+ H_half, W_half = H_full // 2, W_full // 2
|
|
|
+ H_quat, W_quat = H_full // 4, W_full // 4
|
|
|
+
|
|
|
+ src_bgr = torch.cat([src, bgr], dim=1)
|
|
|
+
|
|
|
+ if self.mode != 'full':
|
|
|
+ err = F.interpolate(err, (H_quat, W_quat), mode='bilinear', align_corners=False)
|
|
|
+ ref = self.select_refinement_regions(err)
|
|
|
+ idx = torch.nonzero(ref.squeeze(1))
|
|
|
+ idx = idx[:, 0], idx[:, 1], idx[:, 2]
|
|
|
+
|
|
|
+ if idx[0].size(0) > 0:
|
|
|
+ x = torch.cat([hid, pha, fgr], dim=1)
|
|
|
+ x = F.interpolate(x, (H_half, W_half), mode='bilinear', align_corners=False)
|
|
|
+ x = self.crop_patch(x, idx, 2, 3 if self.kernel_size == 3 else 0)
|
|
|
+
|
|
|
+ y = F.interpolate(src_bgr, (H_half, W_half), mode='bilinear', align_corners=False)
|
|
|
+ y = self.crop_patch(y, idx, 2, 3 if self.kernel_size == 3 else 0)
|
|
|
+
|
|
|
+ x = self.conv1(torch.cat([x, y], dim=1))
|
|
|
+ x = self.bn1(x)
|
|
|
+ x = self.relu(x)
|
|
|
+ x = self.conv2(x)
|
|
|
+ x = self.bn2(x)
|
|
|
+ x = self.relu(x)
|
|
|
+
|
|
|
+ x = F.interpolate(x, 8 if self.kernel_size == 3 else 4, mode='nearest')
|
|
|
+ y = self.crop_patch(src_bgr, idx, 4, 2 if self.kernel_size == 3 else 0)
|
|
|
+
|
|
|
+ x = self.conv3(torch.cat([x, y], dim=1))
|
|
|
+ x = self.bn3(x)
|
|
|
+ x = self.relu(x)
|
|
|
+ x = self.conv4(x)
|
|
|
+
|
|
|
+ out = torch.cat([pha, fgr], dim=1)
|
|
|
+ out = F.interpolate(out, (H_full, W_full), mode='bilinear', align_corners=False)
|
|
|
+ out = self.replace_patch(out, x, idx)
|
|
|
+ pha = out[:, :1]
|
|
|
+ fgr = out[:, 1:]
|
|
|
+ else:
|
|
|
+ pha = F.interpolate(pha, (H_full, W_full), mode='bilinear', align_corners=False)
|
|
|
+ fgr = F.interpolate(fgr, (H_full, W_full), mode='bilinear', align_corners=False)
|
|
|
+ else:
|
|
|
+ x = torch.cat([hid, pha, fgr], dim=1)
|
|
|
+ x = F.interpolate(x, (H_half, W_half), mode='bilinear', align_corners=False)
|
|
|
+ y = F.interpolate(src_bgr, (H_half, W_half), mode='bilinear', align_corners=False)
|
|
|
+ if self.kernel_size == 3:
|
|
|
+ x = F.pad(x, (3, 3, 3, 3))
|
|
|
+ y = F.pad(y, (3, 3, 3, 3))
|
|
|
+
|
|
|
+ x = self.conv1(torch.cat([x, y], dim=1))
|
|
|
+ x = self.bn1(x)
|
|
|
+ x = self.relu(x)
|
|
|
+ x = self.conv2(x)
|
|
|
+ x = self.bn2(x)
|
|
|
+ x = self.relu(x)
|
|
|
+
|
|
|
+ if self.kernel_size == 3:
|
|
|
+ x = F.interpolate(x, (H_full + 4, W_full + 4))
|
|
|
+ y = F.pad(src_bgr, (2, 2, 2, 2))
|
|
|
+ else:
|
|
|
+ x = F.interpolate(x, (H_full, W_full), mode='nearest')
|
|
|
+ y = src_bgr
|
|
|
+
|
|
|
+ x = self.conv3(torch.cat([x, y], dim=1))
|
|
|
+ x = self.bn3(x)
|
|
|
+ x = self.relu(x)
|
|
|
+ x = self.conv4(x)
|
|
|
+
|
|
|
+ pha = x[:, :1]
|
|
|
+ fgr = x[:, 1:]
|
|
|
+ ref = torch.ones((src.size(0), 1, H_quat, W_quat), device=src.device, dtype=src.dtype)
|
|
|
+
|
|
|
+ return pha, fgr, ref
|
|
|
+
|
|
|
+ def crop_patch(self,
|
|
|
+ x: torch.Tensor,
|
|
|
+ idx: Tuple[torch.Tensor, torch.Tensor, torch.Tensor],
|
|
|
+ size: int,
|
|
|
+ padding: int):
|
|
|
+ """
|
|
|
+ Crops selected patches from image given indices.
|
|
|
+
|
|
|
+ Inputs:
|
|
|
+ x: image (B, C, H, W).
|
|
|
+ idx: selection indices Tuple[(P,), (P,), (P),], where the 3 values are (B, H, W) index.
|
|
|
+ size: center size of the patch, also stride of the crop.
|
|
|
+ padding: expansion size of the patch.
|
|
|
+ Output:
|
|
|
+ patch: (P, C, h, w), where h = w = size + 2 * padding.
|
|
|
+ """
|
|
|
+ if padding != 0:
|
|
|
+ x = F.pad(x, (padding,) * 4)
|
|
|
+
|
|
|
+ if self.patch_crop_method == 'unfold':
|
|
|
+ # Use unfold. Best performance for PyTorch and TorchScript.
|
|
|
+ return x.permute(0, 2, 3, 1) \
|
|
|
+ .unfold(1, size + 2 * padding, size) \
|
|
|
+ .unfold(2, size + 2 * padding, size)[idx[0], idx[1], idx[2]]
|
|
|
+ else:
|
|
|
+ # Use roi_align. Best compatibility for ONNX.
|
|
|
+ idx = idx[0].type_as(x), idx[1].type_as(x), idx[2].type_as(x)
|
|
|
+ b = idx[0]
|
|
|
+ x1 = idx[2] * size - 0.5
|
|
|
+ y1 = idx[1] * size - 0.5
|
|
|
+ x2 = idx[2] * size + size + 2 * padding - 0.5
|
|
|
+ y2 = idx[1] * size + size + 2 * padding - 0.5
|
|
|
+ boxes = torch.stack([b, x1, y1, x2, y2], dim=1)
|
|
|
+ return torchvision.ops.roi_align(x, boxes, size + 2 * padding, sampling_ratio=1)
|
|
|
+
|
|
|
+ def replace_patch(self,
|
|
|
+ x: torch.Tensor,
|
|
|
+ y: torch.Tensor,
|
|
|
+ idx: Tuple[torch.Tensor, torch.Tensor, torch.Tensor]):
|
|
|
+ """
|
|
|
+ Replaces patches back into image given index.
|
|
|
+
|
|
|
+ Inputs:
|
|
|
+ x: image (B, C, H, W)
|
|
|
+ y: patches (P, C, h, w)
|
|
|
+ idx: selection indices Tuple[(P,), (P,), (P,)] where the 3 values are (B, H, W) index.
|
|
|
+
|
|
|
+ Output:
|
|
|
+ image: (B, C, H, W), where patches at idx locations are replaced with y.
|
|
|
+ """
|
|
|
+ xB, xC, xH, xW = x.shape
|
|
|
+ yB, yC, yH, yW = y.shape
|
|
|
+ if self.patch_replace_method == 'scatter_nd':
|
|
|
+ # Use scatter_nd. Best performance for PyTorch and TorchScript. Replacing patch by patch.
|
|
|
+ x = x.view(xB, xC, xH // yH, yH, xW // yW, yW).permute(0, 2, 4, 1, 3, 5)
|
|
|
+ x[idx[0], idx[1], idx[2]] = y
|
|
|
+ x = x.permute(0, 3, 1, 4, 2, 5).view(xB, xC, xH, xW)
|
|
|
+ return x
|
|
|
+ else:
|
|
|
+ # Use scatter_element. Best compatibility for ONNX. Replacing pixel by pixel.
|
|
|
+ iH, iW = xH // yH, xW // yW
|
|
|
+ i = self.crop_patch(torch.arange(0, xB * xC * xH * xW).view(xB, xC, xH, xW).type_as(x), idx, 4, 0)
|
|
|
+ i, x, y = i.view(-1), x.view(-1), y.view(-1)
|
|
|
+ x.scatter_(0, i.long(), y)
|
|
|
+ x = x.view(xB, xC, xH, xW)
|
|
|
+ return x
|
|
|
+
|
|
|
+ def select_refinement_regions(self, err: torch.Tensor):
|
|
|
+ """
|
|
|
+ Select refinement regions.
|
|
|
+ Input:
|
|
|
+ err: error map (B, 1, H, W)
|
|
|
+ Output:
|
|
|
+ ref: refinement regions (B, 1, H, W). FloatTensor. 1 is selected, 0 is not.
|
|
|
+ """
|
|
|
+ if self.mode == 'sampling':
|
|
|
+ # Sampling mode.
|
|
|
+ b, _, h, w = err.shape
|
|
|
+ err = err.view(b, -1)
|
|
|
+ idx = err.topk(self.sample_pixels // 16, dim=1, sorted=False).indices
|
|
|
+ ref = torch.zeros_like(err)
|
|
|
+ ref.scatter_(1, idx, 1.)
|
|
|
+ if self.prevent_oversampling:
|
|
|
+ ref.mul_(err.gt(0).float())
|
|
|
+ ref = ref.view(b, 1, h, w)
|
|
|
+ else:
|
|
|
+ # Thresholding mode.
|
|
|
+ ref = err.gt(self.threshold).float()
|
|
|
+ return ref
|