@@ -20,12 +20,13 @@ class Refiner(nn.Module):
prevent_oversampling: True for regular cases, False for speedtest.

Compatibility Args:
- patch_crop_method: the operation for cropping patches. Options:
+ patch_crop_method: the method for cropping patches. Options:
"unfold" - Best performance for PyTorch and TorchScript.
- "roi_align" - Best compatibility for ONNX export.
- patch_replace_method: the operation for replacing patches. Options:
+ "roi_align" - Another way of cropping patches.
+ "gather" - Another way of cropping patches.
+ patch_replace_method: the method for replacing patches. Options:
"scatter_nd" - Best performance for PyTorch and TorchScript.
- "scatter_element" - Best compatibility for ONNX export.
+ "scatter_element" - Another way of replacing patches.

Input:
src: (B, 3, H, W) full resolution source image.
@@ -55,7 +56,7 @@ class Refiner(nn.Module):
super().__init__()
assert mode in ['full', 'sampling', 'thresholding']
assert kernel_size in [1, 3]
- assert patch_crop_method in ['unfold', 'roi_align']
+ assert patch_crop_method in ['unfold', 'roi_align', 'gather']
assert patch_replace_method in ['scatter_nd', 'scatter_element']

self.mode = mode
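# --- Example (not part of the patch): a minimal sketch of selecting the new
# 'gather' option. Keyword names follow the arguments asserted above, but the
# full constructor signature is an assumption, not shown in this diff.
refiner = Refiner(mode='sampling',
                  sample_pixels=80_000,
                  threshold=0.1,
                  kernel_size=3,
                  prevent_oversampling=True,
                  patch_crop_method='gather',              # new ONNX-friendly crop path
                  patch_replace_method='scatter_element')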
@@ -159,6 +160,29 @@ class Refiner(nn.Module):

return pha, fgr, ref

+ def select_refinement_regions(self, err: torch.Tensor):
+ """
+ Select refinement regions.
+ Input:
+ err: error map (B, 1, H, W)
+ Output:
+ ref: refinement regions (B, 1, H, W). FloatTensor. 1 is selected, 0 is not.
+ """
+ if self.mode == 'sampling':
+ # Sampling mode.
+ b, _, h, w = err.shape
+ err = err.view(b, -1)
+ idx = err.topk(self.sample_pixels // 16, dim=1, sorted=False).indices
+ ref = torch.zeros_like(err)
+ ref.scatter_(1, idx, 1.)
+ if self.prevent_oversampling:
+ ref.mul_(err.gt(0).float())
+ ref = ref.view(b, 1, h, w)
+ else:
+ # Thresholding mode.
+ ref = err.gt(self.threshold).float()
+ return ref
+
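# --- Example (not part of the patch): a toy walk-through of the sampling
# branch above, assuming self.sample_pixels // 16 == 4 pixels are selected.
import torch
err = torch.rand(1, 1, 8, 8)                       # toy error map
flat = err.view(1, -1)
idx = flat.topk(4, dim=1, sorted=False).indices    # 4 highest-error pixels
ref = torch.zeros_like(flat)
ref.scatter_(1, idx, 1.)                           # mark them as selected
ref = ref.view(1, 1, 8, 8)                         # 1 = refine, 0 = skip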
def crop_patch(self,
x: torch.Tensor,
idx: Tuple[torch.Tensor, torch.Tensor, torch.Tensor],
@@ -169,7 +193,7 @@ class Refiner(nn.Module):

Inputs:
x: image (B, C, H, W).
- idx: selection indices Tuple[(P,), (P,), (P),], where the 3 values are (B, H, W) index.
+ idx: selection indices Tuple[(P,), (P,), (P,),], where the 3 values are (B, H, W) index.
size: center size of the patch, also stride of the crop.
padding: expansion size of the patch.
Output:
@@ -183,7 +207,7 @@ class Refiner(nn.Module):
return x.permute(0, 2, 3, 1) \
.unfold(1, size + 2 * padding, size) \
.unfold(2, size + 2 * padding, size)[idx[0], idx[1], idx[2]]
- else:
+ elif self.patch_crop_method == 'roi_align':
# Use roi_align. Best compatibility for ONNX.
idx = idx[0].type_as(x), idx[1].type_as(x), idx[2].type_as(x)
b = idx[0]
@@ -192,7 +216,13 @@ class Refiner(nn.Module):
x2 = idx[2] * size + size + 2 * padding - 0.5
y2 = idx[1] * size + size + 2 * padding - 0.5
boxes = torch.stack([b, x1, y1, x2, y2], dim=1)
- return torchvision.ops.roi_align(x, boxes, size + 2 * padding, sampling_ratio=1)
+ return torchvision.ops.roi_align(x, boxes, size + 2 * padding, sampling_ratio=1)
+ else:
+ # Use gather. Crops out patches pixel by pixel.
+ idx = self.compute_pixel_indices(x, idx, size, padding)
+ pat = torch.gather(x.view(-1), 0, idx.view(-1))
+ pat = pat.view(-1, x.size(1), size + 2 * padding, size + 2 * padding)
+ return pat
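# --- Example (not part of the patch): the gather idea in miniature. With flat
# indices precomputed per output pixel, one torch.gather over x.view(-1) crops
# every patch at once; the indices here are hand-written for a 2x2 patch.
import torch
x = torch.arange(16.).view(1, 1, 4, 4)
flat_idx = torch.tensor([0, 1, 4, 5])              # top-left 2x2 patch of x
patch = torch.gather(x.view(-1), 0, flat_idx).view(1, 1, 2, 2)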

def replace_patch(self,
x: torch.Tensor,
@@ -219,32 +249,34 @@ class Refiner(nn.Module):
return x
else:
# Use scatter_element. Best compatibility for ONNX. Replacing pixel by pixel.
- iH, iW = xH // yH, xW // yW
- i = self.crop_patch(torch.arange(0, xB * xC * xH * xW).view(xB, xC, xH, xW).type_as(x), idx, 4, 0)
- i, x, y = i.view(-1), x.view(-1), y.view(-1)
- x.scatter_(0, i.long(), y)
- x = x.view(xB, xC, xH, xW)
- return x
+ idx = self.compute_pixel_indices(x, idx, size=4, padding=0)
+ return x.view(-1).scatter_(0, idx.view(-1), y.view(-1)).view(x.shape)
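# --- Example (not part of the patch): scatter_element is the mirror image of
# gather. The same flat indices let scatter_ write refined pixels back in place.
import torch
x = torch.zeros(16)                                # flattened destination image
flat_idx = torch.tensor([10, 11, 14, 15])          # pixels of one 2x2 patch
y = torch.tensor([1., 2., 3., 4.])                 # refined patch values
x.scatter_(0, flat_idx, y)                         # write y back pixel by pixel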

- def select_refinement_regions(self, err: torch.Tensor):
+ def compute_pixel_indices(self,
+ x: torch.Tensor,
+ idx: Tuple[torch.Tensor, torch.Tensor, torch.Tensor],
+ size: int,
+ padding: int):
"""
- Select refinement regions.
+ Compute selected pixel indices in the tensor.
+ Used for patch_crop_method == 'gather' and patch_replace_method == 'scatter_element', which crop and replace pixel by pixel.
Input:
- err: error map (B, 1, H, W)
+ x: image: (B, C, H, W)
+ idx: selection indices Tuple[(P,), (P,), (P,),], where the 3 values are (B, H, W) index.
+ size: center size of the patch, also stride of the crop.
+ padding: expansion size of the patch.
Output:
- ref: refinement regions (B, 1, H, W). FloatTensor. 1 is selected, 0 is not.
+ idx: (P, C, O, O) long tensor where O is the output size: size + 2 * padding, P is the number of patches.
+ The elements are indices pointing into the input x.view(-1).
"""
- if self.mode == 'sampling':
- # Sampling mode.
- b, _, h, w = err.shape
- err = err.view(b, -1)
- idx = err.topk(self.sample_pixels // 16, dim=1, sorted=False).indices
- ref = torch.zeros_like(err)
- ref.scatter_(1, idx, 1.)
- if self.prevent_oversampling:
- ref.mul_(err.gt(0).float())
- ref = ref.view(b, 1, h, w)
- else:
- # Thresholding mode.
- ref = err.gt(self.threshold).float()
- return ref
+ B, C, H, W = x.shape
+ S, P = size, padding
+ O = S + 2 * P
+ b, y, x = idx
+ n = b.size(0)
+ c = torch.arange(C)
+ o = torch.arange(O)
+ idx_pat = (c * H * W).view(C, 1, 1).expand([C, O, O]) + (o * W).view(1, O, 1).expand([C, O, O]) + o.view(1, 1, O).expand([C, O, O])
+ idx_loc = b * W * H + y * W * S + x * S
+ idx = idx_loc.view(-1, 1, 1, 1).expand([n, C, O, O]) + idx_pat.view(1, C, O, O).expand([n, C, O, O])
+ return idx
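# --- Example (not part of the patch): the index arithmetic above on a toy case
# with C=1, H=W=4, size=2, padding=0. idx_pat enumerates offsets inside one
# O x O window; idx_loc is the flat offset of each patch's top-left corner.
import torch
C, H, W, S, O = 1, 4, 4, 2, 2
c, o = torch.arange(C), torch.arange(O)
idx_pat = (c * H * W).view(C, 1, 1) + (o * W).view(1, O, 1) + o.view(1, 1, O)
b, y, x_ = torch.tensor([0]), torch.tensor([1]), torch.tensor([1])
idx_loc = b * W * H + y * W * S + x_ * S           # = 10: row 2, col 2 of x
idx = idx_loc.view(-1, 1, 1, 1) + idx_pat.view(1, C, O, O)
# idx -> [[[[10, 11], [14, 15]]]], i.e. the 2x2 patch at grid cell (1, 1)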