Преглед изворни кода

Rename variable to fix torchscript export

Peter Lin пре 4 година
родитељ
комит
f09e0931c8
1 измењених фајлова са 6 додато и 6 уклоњено
  1. 6 6
      model/refiner.py

+ 6 - 6
model/refiner.py

@@ -219,8 +219,8 @@ class Refiner(nn.Module):
             return torchvision.ops.roi_align(x, boxes, size + 2 * padding, sampling_ratio=1)
         else:
             # Use gather. Crops out patches pixel by pixel.
-            idx = self.compute_pixel_indices(x, idx, size, padding)
-            pat = torch.gather(x.view(-1), 0, idx.view(-1))
+            idx_pix = self.compute_pixel_indices(x, idx, size, padding)
+            pat = torch.gather(x.view(-1), 0, idx_pix.view(-1))
             pat = pat.view(-1, x.size(1), size + 2 * padding, size + 2 * padding)
             return pat
     
@@ -249,8 +249,8 @@ class Refiner(nn.Module):
             return x
         else:
             # Use scatter_element. Best compatibility for ONNX. Replacing pixel by pixel.
-            idx = self.compute_pixel_indices(x, idx, size=4, padding=0)
-            return x.view(-1).scatter_(0, idx.view(-1), y.view(-1)).view(x.shape)
+            idx_pix = self.compute_pixel_indices(x, idx, size=4, padding=0)
+            return x.view(-1).scatter_(0, idx_pix.view(-1), y.view(-1)).view(x.shape)
 
     def compute_pixel_indices(self,
                               x: torch.Tensor,
@@ -278,5 +278,5 @@ class Refiner(nn.Module):
         o = torch.arange(O)
         idx_pat = (c * H * W).view(C, 1, 1).expand([C, O, O]) + (o * W).view(1, O, 1).expand([C, O, O]) + o.view(1, 1, O).expand([C, O, O])
         idx_loc = b * W * H + y * W * S + x * S
-        idx = idx_loc.view(-1, 1, 1, 1).expand([n, C, O, O]) + idx_pat.view(1, C, O, O).expand([n, C, O, O])
-        return idx
+        idx_pix = idx_loc.view(-1, 1, 1, 1).expand([n, C, O, O]) + idx_pat.view(1, C, O, O).expand([n, C, O, O])
+        return idx_pix