
Improve ONNX compatibility

Peter Lin 4 years ago
parent
commit
13cb499ece
3 changed files with 108 additions and 44 deletions
  1. doc/model_usage.md (+8, -1)
  2. export_onnx.py (+36, -11)
  3. model/refiner.py (+64, -32)

doc/model_usage.md (+8, -1)

@@ -115,7 +115,14 @@ bgr = np.random.normal(size=(1, 3, 1080, 1920)).astype(np.float32)
 pha, fgr = sess.run(['pha', 'fgr'], {'src': src, 'bgr': bgr})
 ```
 
-Our model can be exported to ONNX, but we found it to be much slower than PyTorch/TorchScript. We provide pre-exported `HD(backbone_scale=0.25, sample_pixels=80,000)` and `4K(backbone_scale=0.125, sample_pixels=320,000)` with MobileNetV2 backbone. Any other configuration can be exported through `export_onnx.py`. Notes that the ONNX model uses different operatorsthan PyTorch/TorchScript for compatibility. It uses `ROI_Align` rather than `Unfolding` for cropping patches, and `ScatterElement` rather than `ScatterND` for replacing patches. This can be configured inside `export_onnx.py`.
+Our model can be exported to ONNX, but we found it to be much slower than PyTorch/TorchScript. We provide pre-exported `HD(backbone_scale=0.25, sample_pixels=80,000)` and `4K(backbone_scale=0.125, sample_pixels=320,000)` with MobileNetV2 backbone. Any other configuration can be exported through `export_onnx.py`. 
+
+#### Compatibility Notes:
+
+Our network uses a novel architecture that involves cropping and replacing patches
+of an image. This may cause compatibility issues with different inference backends.
+Therefore, we offer different methods for cropping and replacing patches as
+compatibility options. You can try exporting ONNX models with different cropping and replacing methods. More details are in `export_onnx.py`. The provided ONNX models use `roi_align` for cropping and `scatter_element` for replacing patches.
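
For example, a configuration that uses `gather` for cropping and `scatter_element` for replacing patches could be exported roughly as follows (the backbone name and checkpoint path below are placeholders; `export_onnx.py` documents the full set of flags):

```
python export_onnx.py \
    --model-type mattingrefine \
    --model-backbone mobilenetv2 \
    --model-backbone-scale 0.25 \
    --model-checkpoint "PATH/TO/model.pth" \
    --model-refine-mode sampling \
    --model-refine-sample-pixels 80000 \
    --model-refine-patch-crop-method gather \
    --model-refine-patch-replace-method scatter_element \
    --precision float32 \
    --output "model.onnx" \
    --validate
```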
 
  
 

export_onnx.py (+36, -11)

@@ -1,5 +1,6 @@
 """
-Export MattingRefine as ONNX format
+Export MattingRefine as ONNX format.
+Requires onnxruntime, which can be installed with `pip install onnxruntime`.
 
 Example:
 
@@ -10,12 +11,34 @@ Example:
         --model-backbone-scale 0.25 \
         --model-refine-mode sampling \
         --model-refine-sample-pixels 80000 \
+        --model-refine-patch-crop-method gather \
+        --model-refine-patch-replace-method scatter_element \
         --onnx-opset-version 11 \
         --onnx-constant-folding \
         --precision float32 \
         --output "model.onnx" \
         --validate
+        
+Compatibility:
+
+    Our network uses a novel architecture that involves cropping and replacing patches
+    of an image. This may cause compatibility issues with different inference backends.
+    Therefore, we offer different methods for cropping and replacing patches as
+    compatibility options. They all produce the same image output.
+    
+        --model-refine-patch-crop-method:
+            Options: ['unfold', 'roi_align', 'gather']
+                     (unfold is unlikely to work for ONNX, try roi_align or gather)
 
+        --model-refine-patch-replace-method
+            Options: ['scatter_nd', 'scatter_element']
+                     (scatter_nd should be faster when supported)
+
+    Also try using thresholding mode if sampling mode is not supported by the inference backend.
+    
+        --model-refine-mode thresholding \
+        --model-refine-threshold 0.1 \
+    
 """
 
 
@@ -36,8 +59,10 @@ parser.add_argument('--model-backbone-scale', type=float, default=0.25)
 parser.add_argument('--model-checkpoint', type=str, required=True)
 parser.add_argument('--model-refine-mode', type=str, default='sampling', choices=['full', 'sampling', 'thresholding'])
 parser.add_argument('--model-refine-sample-pixels', type=int, default=80_000)
-parser.add_argument('--model-refine-threshold', type=float, default=0.7)
+parser.add_argument('--model-refine-threshold', type=float, default=0.1)
 parser.add_argument('--model-refine-kernel-size', type=int, default=3)
+parser.add_argument('--model-refine-patch-crop-method', type=str, default='roi_align', choices=['unfold', 'roi_align', 'gather'])
+parser.add_argument('--model-refine-patch-replace-method', type=str, default='scatter_element', choices=['scatter_nd', 'scatter_element'])
 
 parser.add_argument('--onnx-verbose', type=bool, default=True)
 parser.add_argument('--onnx-opset-version', type=int, default=12)
@@ -59,14 +84,14 @@ if args.model_type == 'mattingbase':
     model = MattingBase(args.model_backbone)
 if args.model_type == 'mattingrefine':
     model = MattingRefine(
-        args.model_backbone,
-        args.model_backbone_scale,
-        args.model_refine_mode,
-        args.model_refine_sample_pixels,
-        args.model_refine_threshold,
-        args.model_refine_kernel_size,
-        refine_patch_crop_method='roi_align',
-        refine_patch_replace_method='scatter_element')
+        backbone=args.model_backbone,
+        backbone_scale=args.model_backbone_scale,
+        refine_mode=args.model_refine_mode,
+        refine_sample_pixels=args.model_refine_sample_pixels,
+        refine_threshold=args.model_refine_threshold,
+        refine_kernel_size=args.model_refine_kernel_size,
+        refine_patch_crop_method=args.model_refine_patch_crop_method,
+        refine_patch_replace_method=args.model_refine_patch_replace_method)
 
 model.load_state_dict(torch.load(args.model_checkpoint, map_location=args.device), strict=False)
 precision = {'float32': torch.float32, 'float16': torch.float16}[args.precision]
@@ -124,7 +149,7 @@ if args.validate:
         e_max = max(e_max, e.item())
         print(f'"{name}" output differs by maximum of {e}')
         
-    if e_max < 0.001:
+    if e_max < 0.005:
         print('Validation passed.')
     else:
        raise Exception('Validation failed.')
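
For a quick end-to-end check of an exported model outside of `--validate`, something along these lines should work (a minimal sketch: the input/output names `src`, `bgr`, `pha`, `fgr` follow `doc/model_usage.md`, while the model path and resolution are placeholders):

```
import numpy as np
import onnxruntime

# Load the exported model and run a dummy forward pass at 1080p.
sess = onnxruntime.InferenceSession('model.onnx')
src = np.random.normal(size=(1, 3, 1080, 1920)).astype(np.float32)
bgr = np.random.normal(size=(1, 3, 1080, 1920)).astype(np.float32)
pha, fgr = sess.run(['pha', 'fgr'], {'src': src, 'bgr': bgr})
print(pha.shape, fgr.shape)
```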

model/refiner.py (+64, -32)

@@ -20,12 +20,13 @@ class Refiner(nn.Module):
         prevent_oversampling: True for regular cases, False for speedtest.
     
     Compatibility Args:
-        patch_crop_method: the operation for cropping patches. Options:
+        patch_crop_method: the method for cropping patches. Options:
             "unfold"           - Best performance for PyTorch and TorchScript.
-            "roi_align"        - Best compatibility for ONNX export.
-        patch_replace_method: the operation for replacing patches. Options:
+            "roi_align"        - Another way for croping patches.
+            "gather"           - Another way for croping patches.
+        patch_replace_method: the method for replacing patches. Options:
             "scatter_nd"       - Best performance for PyTorch and TorchScript.
-            "scatter_element"  - Best compatibility for ONNX export.
+            "scatter_element"  - Another way for replacing patches.
         
     Input:
         src: (B, 3, H, W) full resolution source image.
@@ -55,7 +56,7 @@ class Refiner(nn.Module):
         super().__init__()
         assert mode in ['full', 'sampling', 'thresholding']
         assert kernel_size in [1, 3]
-        assert patch_crop_method in ['unfold', 'roi_align']
+        assert patch_crop_method in ['unfold', 'roi_align', 'gather']
         assert patch_replace_method in ['scatter_nd', 'scatter_element']
         
         self.mode = mode
@@ -159,6 +160,29 @@ class Refiner(nn.Module):
             
         return pha, fgr, ref
     
+    def select_refinement_regions(self, err: torch.Tensor):
+        """
+        Select refinement regions.
+        Input:
+            err: error map (B, 1, H, W)
+        Output:
+            ref: refinement regions (B, 1, H, W). FloatTensor. 1 is selected, 0 is not.
+        """
+        if self.mode == 'sampling':
+            # Sampling mode.
+            b, _, h, w = err.shape
+            err = err.view(b, -1)
+            idx = err.topk(self.sample_pixels // 16, dim=1, sorted=False).indices
+            ref = torch.zeros_like(err)
+            ref.scatter_(1, idx, 1.)
+            if self.prevent_oversampling:
+                ref.mul_(err.gt(0).float())
+            ref = ref.view(b, 1, h, w)
+        else:
+            # Thresholding mode.
+            ref = err.gt(self.threshold).float()
+        return ref
+    
     def crop_patch(self,
                    x: torch.Tensor,
                    idx: Tuple[torch.Tensor, torch.Tensor, torch.Tensor],
@@ -169,7 +193,7 @@ class Refiner(nn.Module):
         
         Inputs:
             x: image (B, C, H, W).
-            idx: selection indices Tuple[(P,), (P,), (P),], where the 3 values are (B, H, W) index.
+            idx: selection indices Tuple[(P,), (P,), (P,),], where the 3 values are (B, H, W) index.
             size: center size of the patch, also stride of the crop.
             padding: expansion size of the patch.
         Output:
@@ -183,7 +207,7 @@ class Refiner(nn.Module):
             return x.permute(0, 2, 3, 1) \
                     .unfold(1, size + 2 * padding, size) \
                     .unfold(2, size + 2 * padding, size)[idx[0], idx[1], idx[2]]
-        else:
+        elif self.patch_crop_method == 'roi_align':
             # Use roi_align. Best compatibility for ONNX.
             idx = idx[0].type_as(x), idx[1].type_as(x), idx[2].type_as(x)
             b = idx[0]
@@ -192,7 +216,13 @@ class Refiner(nn.Module):
             x2 = idx[2] * size + size + 2 * padding - 0.5
             y2 = idx[1] * size + size + 2 * padding - 0.5
             boxes = torch.stack([b, x1, y1, x2, y2], dim=1)
-            return torchvision.ops.roi_align(x, boxes, size + 2 * padding, sampling_ratio=1)    
+            return torchvision.ops.roi_align(x, boxes, size + 2 * padding, sampling_ratio=1)
+        else:
+            # Use gather. Crops out patches pixel by pixel.
+            idx = self.compute_pixel_indices(x, idx, size, padding)
+            pat = torch.gather(x.view(-1), 0, idx.view(-1))
+            pat = pat.view(-1, x.size(1), size + 2 * padding, size + 2 * padding)
+            return pat
     
     def replace_patch(self,
                       x: torch.Tensor,
@@ -219,32 +249,34 @@ class Refiner(nn.Module):
             return x
         else:
             # Use scatter_element. Best compatibility for ONNX. Replacing pixel by pixel.
-            iH, iW = xH // yH, xW // yW
-            i = self.crop_patch(torch.arange(0, xB * xC * xH * xW).view(xB, xC, xH, xW).type_as(x), idx, 4, 0)
-            i, x, y = i.view(-1), x.view(-1), y.view(-1)
-            x.scatter_(0, i.long(), y)
-            x = x.view(xB, xC, xH, xW)
-            return x
+            idx = self.compute_pixel_indices(x, idx, size=4, padding=0)
+            return x.view(-1).scatter_(0, idx.view(-1), y.view(-1)).view(x.shape)
 
-    def select_refinement_regions(self, err: torch.Tensor):
+    def compute_pixel_indices(self,
+                              x: torch.Tensor,
+                              idx: Tuple[torch.Tensor, torch.Tensor, torch.Tensor],
+                              size: int,
+                              padding: int):
         """
-        Select refinement regions.
+        Compute selected pixel indices in the tensor.
+        Used for crop_method == 'gather' and replace_method == 'scatter_element', which crop and replace pixel by pixel.
         Input:
-            err: error map (B, 1, H, W)
+            x: image: (B, C, H, W)
+            idx: selection indices Tuple[(P,), (P,), (P,),], where the 3 values are (B, H, W) index.
+            size: center size of the patch, also stride of the crop.
+            padding: expansion size of the patch.
         Output:
-            ref: refinement regions (B, 1, H, W). FloatTensor. 1 is selected, 0 is not.
+            idx: (P, C, O, O) long tensor where O is the output size: size + 2 * padding, P is number of patches.
+                 The elements are indices pointing into the input x.view(-1).
         """
-        if self.mode == 'sampling':
-            # Sampling mode.
-            b, _, h, w = err.shape
-            err = err.view(b, -1)
-            idx = err.topk(self.sample_pixels // 16, dim=1, sorted=False).indices
-            ref = torch.zeros_like(err)
-            ref.scatter_(1, idx, 1.)
-            if self.prevent_oversampling:
-                ref.mul_(err.gt(0).float())
-            ref = ref.view(b, 1, h, w)
-        else:
-            # Thresholding mode.
-            ref = err.gt(self.threshold).float()
-        return ref
+        B, C, H, W = x.shape
+        S, P = size, padding
+        O = S + 2 * P
+        b, y, x = idx
+        n = b.size(0)
+        c = torch.arange(C)
+        o = torch.arange(O)
+        idx_pat = (c * H * W).view(C, 1, 1).expand([C, O, O]) + (o * W).view(1, O, 1).expand([C, O, O]) + o.view(1, 1, O).expand([C, O, O])
+        idx_loc = b * W * H + y * W * S + x * S
+        idx = idx_loc.view(-1, 1, 1, 1).expand([n, C, O, O]) + idx_pat.view(1, C, O, O).expand([n, C, O, O])
+        return idx
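
To make the flattened-index arithmetic in `compute_pixel_indices` concrete, here is a small self-contained sanity check (a sketch under the assumption of a single-image batch, with toy sizes and arbitrary patch coordinates) that reproduces the `gather`-style crop by hand and compares it against `unfold`-based cropping:

```
import torch

# Toy sizes; B = 1 matches the single-image case (the b * W * H offset vanishes).
B, C, H, W = 1, 3, 16, 16
S, P = 4, 1                # patch center size (stride) and padding
O = S + 2 * P              # cropped patch size

x = torch.arange(B * C * H * W, dtype=torch.float32).view(B, C, H, W)

# Selected patches, given as (batch, row, col) indices on the stride-S grid.
b = torch.tensor([0, 0])
y = torch.tensor([0, 2])
xi = torch.tensor([1, 2])

# Reference: 'unfold'-based cropping, as in crop_patch.
ref = x.permute(0, 2, 3, 1).unfold(1, O, S).unfold(2, O, S)[b, y, xi]

# 'gather'-based cropping: flattened pixel indices as in compute_pixel_indices.
c = torch.arange(C)
o = torch.arange(O)
idx_pat = (c * H * W).view(C, 1, 1) + (o * W).view(1, O, 1) + o.view(1, 1, O)
idx_loc = b * W * H + y * W * S + xi * S
idx = idx_loc.view(-1, 1, 1, 1) + idx_pat.view(1, C, O, O)
pat = torch.gather(x.view(-1), 0, idx.reshape(-1)).view(-1, C, O, O)

print(torch.equal(pat, ref))  # expected: True
```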