|
@@ -219,8 +219,8 @@ class Refiner(nn.Module):
|
|
|
return torchvision.ops.roi_align(x, boxes, size + 2 * padding, sampling_ratio=1)
|
|
|
else:
|
|
|
# Use gather. Crops out patches pixel by pixel.
|
|
|
- idx = self.compute_pixel_indices(x, idx, size, padding)
|
|
|
- pat = torch.gather(x.view(-1), 0, idx.view(-1))
|
|
|
+ idx_pix = self.compute_pixel_indices(x, idx, size, padding)
|
|
|
+ pat = torch.gather(x.view(-1), 0, idx_pix.view(-1))
|
|
|
pat = pat.view(-1, x.size(1), size + 2 * padding, size + 2 * padding)
|
|
|
return pat
|
|
|
|
|
@@ -249,8 +249,8 @@ class Refiner(nn.Module):
|
|
|
return x
|
|
|
else:
|
|
|
# Use scatter_element. Best compatibility for ONNX. Replacing pixel by pixel.
|
|
|
- idx = self.compute_pixel_indices(x, idx, size=4, padding=0)
|
|
|
- return x.view(-1).scatter_(0, idx.view(-1), y.view(-1)).view(x.shape)
|
|
|
+ idx_pix = self.compute_pixel_indices(x, idx, size=4, padding=0)
|
|
|
+ return x.view(-1).scatter_(0, idx_pix.view(-1), y.view(-1)).view(x.shape)
|
|
|
|
|
|
def compute_pixel_indices(self,
|
|
|
x: torch.Tensor,
|
|
@@ -278,5 +278,5 @@ class Refiner(nn.Module):
|
|
|
o = torch.arange(O)
|
|
|
idx_pat = (c * H * W).view(C, 1, 1).expand([C, O, O]) + (o * W).view(1, O, 1).expand([C, O, O]) + o.view(1, 1, O).expand([C, O, O])
|
|
|
idx_loc = b * W * H + y * W * S + x * S
|
|
|
- idx = idx_loc.view(-1, 1, 1, 1).expand([n, C, O, O]) + idx_pat.view(1, C, O, O).expand([n, C, O, O])
|
|
|
- return idx
|
|
|
+ idx_pix = idx_loc.view(-1, 1, 1, 1).expand([n, C, O, O]) + idx_pat.view(1, C, O, O).expand([n, C, O, O])
|
|
|
+ return idx_pix
|