Peter Lin 3 years ago
commit
b485090534

+ 29 - 0
README.md

@@ -0,0 +1,29 @@
+# Export CoreML
+
+## Overview
+
+This branch contains our code for exporting to CoreML models. The `/model` folder is the same as on the `master` branch. The main export logic is in `export_coreml.py`.
+
+At the time of this writing, CoreML's `ResizeBilinear` and `Upsample` ops do not support dynamic scale parameters, so the `downsample_ratio` hyperparameter must be hardcoded.
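+
+The hardcoding is handled by wrapping `MattingNetwork.forward` so that the ratio is baked into the traced graph. Below is a simplified sketch of the pattern used in `export_coreml.py` (the real script additionally rescales the outputs to the CoreML image range):
+```python
+import torch
+from model import MattingNetwork
+
+downsample_ratio = 0.25  # fixed at export time
+
+class Wrapper(MattingNetwork):
+    def forward(self, src, r1=None, r2=None, r3=None, r4=None):
+        # The ratio is captured from the enclosing scope rather than passed as
+        # an input, so torch.jit.trace bakes it into the exported graph.
+        return super().forward(src, r1, r2, r3, r4, downsample_ratio)
+
+model = Wrapper('mobilenetv3').eval()
+src = torch.zeros(1, 3, 1080, 1920)                   # H x W from --resolution
+with torch.no_grad():
+    _, _, r1, r2, r3, r4 = model(src)                 # dummy pass for state shapes
+traced = torch.jit.trace(model, (src, r1, r2, r3, r4))
+```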
+
+Our export script is written with a fixed input size. The output CoreML models require iOS 14+ or macOS 11+. If you have other requirements, feel free to modify the export script. Contributions are welcome.
+
+## Export Yourself
+
+The following procedures were used to generate our CoreML models.
+
+1. Install dependencies
+```sh
+pip install -r requirements.txt
+```
+
+2. Use the export script. You can change `resolution` and `downsample-ratio` to fit your needs, and set `quantize-nbits` to one of `[8, 16, 32]`, denoting `int8`, `fp16`, and `fp32` weights. A quick way to sanity-check the exported model on macOS is sketched after the command below.
+```sh
+python export_coreml.py \
+    --model-variant mobilenetv3 \
+    --checkpoint rvm_mobilenetv3.pth \
+    --resolution 1920 1080 \
+    --downsample-ratio 0.25 \
+    --quantize-nbits 16 \
+    --output model.mlmodel
+```
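+
+A quick smoke test of the exported model (not part of the original export procedure) can be run with `coremltools` on macOS; CoreML prediction is not available on Linux. The sketch below assumes the command above (`--resolution 1920 1080`, `--output model.mlmodel`) and uses a blank placeholder frame; the input/output names come from the export script.
+```python
+import coremltools as ct
+import PIL.Image
+
+model = ct.models.MLModel('model.mlmodel')
+frame = PIL.Image.new('RGB', (1920, 1080))   # placeholder for a real video frame
+
+# The recurrent inputs r1i-r4i are optional, so the first frame omits them.
+out = model.predict({'src': frame})
+fgr, pha = out['fgr'], out['pha']            # foreground and alpha predictions
+
+# Subsequent frames feed the previous recurrent states back in as r1i-r4i.
+states = {f'r{i}i': out[f'r{i}o'] for i in range(1, 5)}
+out = model.predict({'src': frame, **states})
+```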

+ 192 - 0
export_coreml.py

@@ -0,0 +1,192 @@
+"""
+python export_coreml.py \
+    --model-variant mobilenetv3 \
+    --checkpoint rvm_mobilenetv3.pth \
+    --resolution 1920 1080 \
+    --downsample-ratio 0.25 \
+    --quantize-nbits 16 \
+    --output model.mlmodel
+"""
+
+
+import argparse
+import coremltools as ct
+import torch
+
+from coremltools.models.neural_network.quantization_utils import quantize_weights
+from coremltools.converters.mil.mil import Builder as mb
+from coremltools.converters.mil.frontend.torch.torch_op_registry import register_torch_op
+from coremltools.converters.mil.frontend.torch.ops import _get_inputs
+from coremltools.proto import FeatureTypes_pb2 as ft        
+
+from model import MattingNetwork
+
+class Exporter:
+    def __init__(self):
+        self.parse_args()
+        self.init_model()
+        self.register_custom_ops()
+        self.export()
+    
+    def parse_args(self):
+        parser = argparse.ArgumentParser()
+        parser.add_argument('--model-variant', type=str, required=True, choices=['mobilenetv3', 'resnet50'])
+        parser.add_argument('--model-refiner', type=str, default='deep_guided_filter', choices=['deep_guided_filter', 'fast_guided_filter'])
+        parser.add_argument('--checkpoint', type=str, required=False)
+        parser.add_argument('--resolution', type=int, required=True, nargs=2)
+        parser.add_argument('--downsample-ratio', type=float, required=True)
+        parser.add_argument('--quantize-nbits', type=int, required=True, choices=[8, 16, 32])
+        parser.add_argument('--output', type=str, required=True)
+        self.args = parser.parse_args()
+        
+    def init_model(self):
+        downsample_ratio = self.args.downsample_ratio
+        class Wrapper(MattingNetwork):
+            def forward(self, src, r1=None, r2=None, r3=None, r4=None):
+                # Hardcode downsample_ratio into the network instead of taking it as an input. This is needed for TorchScript tracing.
+                # Also, multiply the results by 255 to convert them to the CoreML image format.
+                fgr, pha, r1, r2, r3, r4 = super().forward(src, r1, r2, r3, r4, downsample_ratio)
+                return fgr.mul(255), pha.mul(255), r1, r2, r3, r4
+        
+        self.model = Wrapper(self.args.model_variant, self.args.model_refiner).eval()
+        if self.args.checkpoint is not None:
+            self.model.load_state_dict(torch.load(self.args.checkpoint, map_location='cpu'), strict=False)
+    
+    def register_custom_ops(self):
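+        # PyTorch defines hardsigmoid(x) = clamp(x/6 + 1/2, 0, 1) and
+        # hardswish(x) = x * hardsigmoid(x). CoreML's `sigmoid_hard` op with
+        # alpha=1/6, beta=0.5 computes the same hard sigmoid, so both in-place
+        # ops below are composed from it.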
+        @register_torch_op(override=True)
+        def hardswish_(context, node):
+            inputs = _get_inputs(context, node, expected=1)
+            x = inputs[0]
+            y = mb.sigmoid_hard(x=inputs[0], alpha=1.0/6, beta=0.5)
+            z = mb.mul(x=x, y=y, name=node.name)
+            context.add(z)
+
+        @register_torch_op(override=True)
+        def hardsigmoid_(context, node):
+            inputs = _get_inputs(context, node, expected=1)
+            res = mb.sigmoid_hard(x=inputs[0], alpha=1.0/6, beta=0.5, name=node.name)
+            context.add(res)
+
+        @register_torch_op(override=True)
+        def type_as(context, node):
+            inputs = _get_inputs(context, node)
+            context.add(mb.cast(x=inputs[0], dtype='fp32'), node.name)
+
+        @register_torch_op(override=True)
+        def upsample_bilinear2d(context, node):
+            # Change to use `resize_bilinear` instead to support iOS 13.
+            inputs = _get_inputs(context, node)
+            x = inputs[0]
+            output_size = inputs[1]
+            align_corners = bool(inputs[2].val)
+            scale_factors = inputs[3]
+
+            if scale_factors is not None and scale_factors.val is not None \
+                    and scale_factors.rank == 1 and scale_factors.shape[0] == 2:
+                scale_factors = scale_factors.val
+                resize = mb.resize_bilinear(
+                    x=x,
+                    target_size_height=int(x.shape[-2] * scale_factors[0]),
+                    target_size_width=int(x.shape[-1] * scale_factors[1]),
+                    sampling_mode='ALIGN_CORNERS',
+                    name=node.name,
+                )
+                context.add(resize)
+            else:
+                resize = mb.resize_bilinear(
+                    x=x,
+                    target_size_height=output_size.val[0],
+                    target_size_width=output_size.val[1],
+                    sampling_mode='ALIGN_CORNERS',
+                    name=node.name,
+                )
+                context.add(resize)
+            
+    def export(self):
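+        # Run a dummy forward pass at the target resolution to obtain the
+        # recurrent state tensors r1-r4, whose shapes depend on the resolution
+        # and the hardcoded downsample_ratio; they are needed for tracing.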
+        src = torch.zeros([1, 3, *self.args.resolution[::-1]])
+        _, _, r1, r2, r3, r4 = self.model(src)
+        
+        model_traced = torch.jit.trace(self.model, (src, r1, r2, r3, r4))
+        model_coreml = ct.convert(
+            model_traced,
+            inputs=[
+                ct.ImageType(name='src', shape=(ct.RangeDim(), *src.shape[1:]), channel_first=True, scale=1/255),
+                ct.TensorType(name='r1i', shape=(ct.RangeDim(), *r1.shape[1:])),
+                ct.TensorType(name='r2i', shape=(ct.RangeDim(), *r2.shape[1:])),
+                ct.TensorType(name='r3i', shape=(ct.RangeDim(), *r3.shape[1:])),
+                ct.TensorType(name='r4i', shape=(ct.RangeDim(), *r4.shape[1:])),
+            ],
+        )
+    
+        if self.args.quantize_nbits in [8, 16]:
+            out = quantize_weights(model_coreml, nbits=self.args.quantize_nbits)
+            if isinstance(out, ct.models.model.MLModel):
+                # When the export is done on macOS, the return value is an MLModel.
+                spec = out.get_spec()
+            else:
+                # When the export is done on Linux, the return value is a spec.
+                spec = out
+        else:
+            spec = model_coreml.get_spec()
+        
+        # Some internal outputs are also named 'fgr' and 'pha'. 
+        # We change them to avoid conflicts.
+        for layer in spec.neuralNetwork.layers:
+            for i in range(len(layer.input)):
+                if layer.input[i] == 'fgr':
+                    layer.input[i] = 'fgr_internal'
+                if layer.input[i] == 'pha':
+                    layer.input[i] = 'pha_internal'
+            for i in range(len(layer.output)):
+                if layer.output[i] == 'fgr':
+                    layer.output[i] = 'fgr_internal'
+                if layer.output[i] == 'pha':
+                    layer.output[i] = 'pha_internal'
+        
+        # Update output names
+        ct.utils.rename_feature(spec, spec.description.output[0].name, 'fgr')
+        ct.utils.rename_feature(spec, spec.description.output[1].name, 'pha')
+        ct.utils.rename_feature(spec, spec.description.output[2].name, 'r1o')
+        ct.utils.rename_feature(spec, spec.description.output[3].name, 'r2o')
+        ct.utils.rename_feature(spec, spec.description.output[4].name, 'r3o')
+        ct.utils.rename_feature(spec, spec.description.output[5].name, 'r4o')
+        
+        # Update model description
+        spec.description.metadata.author = 'Shanchuan Lin'
+        spec.description.metadata.shortDescription = 'A robust human video matting model with recurrent architecture. The model has recurrent states that must be passed to subsequent frames. Please refer to paper "Robust High-Resolution Video Matting with Temporal Guidance" for more details.'
+        spec.description.metadata.license = 'Apache License 2.0'
+        spec.description.metadata.versionString = '1.0.0'
+        spec.description.input[0].shortDescription = 'Source frame'
+        spec.description.input[1].shortDescription = 'Recurrent state 1. Initial state is an all zero tensor. Subsequent state is received from r1o.'
+        spec.description.input[2].shortDescription = 'Recurrent state 2. Initial state is an all zero tensor. Subsequent state is received from r2o.'
+        spec.description.input[3].shortDescription = 'Recurrent state 3. Initial state is an all zero tensor. Subsequent state is received from r3o.'
+        spec.description.input[4].shortDescription = 'Recurrent state 4. Initial state is an all zero tensor. Subsequent state is received from r4o.'
+        spec.description.output[0].shortDescription = 'Foreground prediction'
+        spec.description.output[1].shortDescription = 'Alpha prediction'
+        spec.description.output[2].shortDescription = 'Recurrent state 1. Needs to be passed as r1i input in the next time step.'
+        spec.description.output[3].shortDescription = 'Recurrent state 2. Needs to be passed as r2i input in the next time step.'
+        spec.description.output[4].shortDescription = 'Recurrent state 3. Needs to be passed as r3i input in the next time step.'
+        spec.description.output[5].shortDescription = 'Recurrent state 4. Needs to be passed as r4i input in the next time step.'
+
+        # Update output types
+        spec.description.output[0].type.imageType.colorSpace = ft.ImageFeatureType.RGB
+        spec.description.output[0].type.imageType.width = src.size(3)
+        spec.description.output[0].type.imageType.height = src.size(2)
+        spec.description.output[1].type.imageType.colorSpace = ft.ImageFeatureType.GRAYSCALE
+        spec.description.output[1].type.imageType.width = src.size(3)
+        spec.description.output[1].type.imageType.height = src.size(2)
+
+        # Set recurrent states as optional inputs
+        spec.description.input[1].type.isOptional = True
+        spec.description.input[2].type.isOptional = True
+        spec.description.input[3].type.isOptional = True
+        spec.description.input[4].type.isOptional = True
+        
+        # Save output
+        ct.utils.save_spec(spec, self.args.output)
+
+
+if __name__ == '__main__':
+    Exporter()

+ 1 - 0
model/__init__.py

@@ -0,0 +1 @@
+from .model import MattingNetwork

+ 217 - 0
model/decoder.py

@@ -0,0 +1,217 @@
+import torch
+from torch import Tensor
+from torch import nn
+from torch.nn import functional as F
+from typing import Tuple, Optional
+
+class RecurrentDecoder(nn.Module):
+    def __init__(self, feature_channels, decoder_channels):
+        super().__init__()
+        self.avgpool = AvgPool()
+        self.decode4 = BottleneckBlock(feature_channels[3])
+        self.decode3 = UpsamplingBlock(feature_channels[3], feature_channels[2], 3, decoder_channels[0])
+        self.decode2 = UpsamplingBlock(decoder_channels[0], feature_channels[1], 3, decoder_channels[1])
+        self.decode1 = UpsamplingBlock(decoder_channels[1], feature_channels[0], 3, decoder_channels[2])
+        self.decode0 = OutputBlock(decoder_channels[2], 3, decoder_channels[3])
+
+    def forward(self,
+                s0: Tensor, f1: Tensor, f2: Tensor, f3: Tensor, f4: Tensor,
+                r1: Optional[Tensor], r2: Optional[Tensor],
+                r3: Optional[Tensor], r4: Optional[Tensor]):
+        s1, s2, s3 = self.avgpool(s0)
+        x4, r4 = self.decode4(f4, r4)
+        x3, r3 = self.decode3(x4, f3, s3, r3)
+        x2, r2 = self.decode2(x3, f2, s2, r2)
+        x1, r1 = self.decode1(x2, f1, s1, r1)
+        x0 = self.decode0(x1, s0)
+        return x0, r1, r2, r3, r4
+    
+
+class AvgPool(nn.Module):
+    def __init__(self):
+        super().__init__()
+        self.avgpool = nn.AvgPool2d(2, 2, count_include_pad=False, ceil_mode=True)
+        
+    def forward_single_frame(self, s0):
+        s1 = self.avgpool(s0)
+        s2 = self.avgpool(s1)
+        s3 = self.avgpool(s2)
+        return s1, s2, s3
+    
+    def forward_time_series(self, s0):
+        B, T = s0.shape[:2]
+        s0 = s0.flatten(0, 1)
+        s1, s2, s3 = self.forward_single_frame(s0)
+        s1 = s1.unflatten(0, (B, T))
+        s2 = s2.unflatten(0, (B, T))
+        s3 = s3.unflatten(0, (B, T))
+        return s1, s2, s3
+    
+    def forward(self, s0):
+        if s0.ndim == 5:
+            return self.forward_time_series(s0)
+        else:
+            return self.forward_single_frame(s0)
+
+
+class BottleneckBlock(nn.Module):
+    def __init__(self, channels):
+        super().__init__()
+        self.channels = channels
+        self.gru = ConvGRU(channels // 2)
+        
+    def forward(self, x, r: Optional[Tensor]):
+        a, b = x.split(self.channels // 2, dim=-3)
+        b, r = self.gru(b, r)
+        x = torch.cat([a, b], dim=-3)
+        return x, r
+
+    
+class UpsamplingBlock(nn.Module):
+    def __init__(self, in_channels, skip_channels, src_channels, out_channels):
+        super().__init__()
+        self.out_channels = out_channels
+        self.upsample = nn.Upsample(scale_factor=2, mode='bilinear', align_corners=False)
+        self.conv = nn.Sequential(
+            nn.Conv2d(in_channels + skip_channels + src_channels, out_channels, 3, 1, 1, bias=False),
+            nn.BatchNorm2d(out_channels),
+            nn.ReLU(True),
+        )
+        self.gru = ConvGRU(out_channels // 2)
+
+    def forward_single_frame(self, x, f, s, r: Optional[Tensor]):
+        x = self.upsample(x)
+        # Optimized for CoreML export
+        if x.size(2) != s.size(2):
+            x = x[:, :, :s.size(2), :]
+        if x.size(3) != s.size(3):
+            x = x[:, :, :, :s.size(3)]
+        x = torch.cat([x, f, s], dim=1)
+        x = self.conv(x)
+        a, b = x.split(self.out_channels // 2, dim=1)
+        b, r = self.gru(b, r)
+        x = torch.cat([a, b], dim=1)
+        return x, r
+    
+    def forward_time_series(self, x, f, s, r: Optional[Tensor]):
+        B, T, _, H, W = s.shape
+        x = x.flatten(0, 1)
+        f = f.flatten(0, 1)
+        s = s.flatten(0, 1)
+        x = self.upsample(x)
+        x = x[:, :, :H, :W]
+        x = torch.cat([x, f, s], dim=1)
+        x = self.conv(x)
+        x = x.unflatten(0, (B, T))
+        a, b = x.split(self.out_channels // 2, dim=2)
+        b, r = self.gru(b, r)
+        x = torch.cat([a, b], dim=2)
+        return x, r
+    
+    def forward(self, x, f, s, r: Optional[Tensor]):
+        if x.ndim == 5:
+            return self.forward_time_series(x, f, s, r)
+        else:
+            return self.forward_single_frame(x, f, s, r)
+
+
+class OutputBlock(nn.Module):
+    def __init__(self, in_channels, src_channels, out_channels):
+        super().__init__()
+        self.upsample = nn.Upsample(scale_factor=2, mode='bilinear', align_corners=False)
+        self.conv = nn.Sequential(
+            nn.Conv2d(in_channels + src_channels, out_channels, 3, 1, 1, bias=False),
+            nn.BatchNorm2d(out_channels),
+            nn.ReLU(True),
+            nn.Conv2d(out_channels, out_channels, 3, 1, 1, bias=False),
+            nn.BatchNorm2d(out_channels),
+            nn.ReLU(True),
+        )
+        
+    def forward_single_frame(self, x, s):
+        x = self.upsample(x)
+        if x.size(2) != s.size(2):
+            x = x[:, :, :s.size(2), :]
+        if x.size(3) != s.size(3):
+            x = x[:, :, :, :s.size(3)]
+        x = torch.cat([x, s], dim=1)
+        x = self.conv(x)
+        return x
+    
+    def forward_time_series(self, x, s):
+        B, T, _, H, W = s.shape
+        x = x.flatten(0, 1)
+        s = s.flatten(0, 1)
+        x = self.upsample(x)
+        x = x[:, :, :H, :W]
+        x = torch.cat([x, s], dim=1)
+        x = self.conv(x)
+        x = x.unflatten(0, (B, T))
+        return x
+    
+    def forward(self, x, s):
+        if x.ndim == 5:
+            return self.forward_time_series(x, s)
+        else:
+            return self.forward_single_frame(x, s)
+
+
+class ConvGRU(nn.Module):
+    def __init__(self,
+                 channels: int,
+                 kernel_size: int = 3,
+                 padding: int = 1):
+        super().__init__()
+        self.channels = channels
+        self.ih = nn.Sequential(
+            nn.Conv2d(channels * 2, channels * 2, kernel_size, padding=padding),
+            nn.Sigmoid()
+        )
+        self.hh = nn.Sequential(
+            nn.Conv2d(channels * 2, channels, kernel_size, padding=padding),
+            nn.Tanh()
+        )
+        
+    def forward_single_frame(self, x, h):
+        r, z = self.ih(torch.cat([x, h], dim=1)).split(self.channels, dim=1)
+        c = self.hh(torch.cat([x, r * h], dim=1))
+        h = (1 - z) * h + z * c
+        return h, h
+    
+    def forward_time_series(self, x, h):
+        o = []
+        for xt in x.unbind(dim=1):
+            ot, h = self.forward_single_frame(xt, h)
+            o.append(ot)
+        o = torch.stack(o, dim=1)
+        return o, h
+        
+    def forward(self, x, h: Optional[Tensor]):
+        if h is None:
+            h = torch.zeros((x.size(0), x.size(-3), x.size(-2), x.size(-1)),
+                            device=x.device, dtype=x.dtype)
+        
+        if x.ndim == 5:
+            return self.forward_time_series(x, h)
+        else:
+            return self.forward_single_frame(x, h)
+
+
+class Projection(nn.Module):
+    def __init__(self, in_channels, out_channels):
+        super().__init__()
+        self.conv = nn.Conv2d(in_channels, out_channels, 1)
+    
+    def forward_single_frame(self, x):
+        return self.conv(x)
+    
+    def forward_time_series(self, x):
+        B, T = x.shape[:2]
+        return self.conv(x.flatten(0, 1)).unflatten(0, (B, T))
+        
+    def forward(self, x):
+        if x.ndim == 5:
+            return self.forward_time_series(x)
+        else:
+            return self.forward_single_frame(x)
+    

+ 61 - 0
model/deep_guided_filter.py

@@ -0,0 +1,61 @@
+import torch
+from torch import nn
+from torch.nn import functional as F
+
+"""
+Adapted from <https://github.com/wuhuikai/DeepGuidedFilter/>
+"""
+
+class DeepGuidedFilterRefiner(nn.Module):
+    def __init__(self, hid_channels=16):
+        super().__init__()
+        self.box_filter = nn.Conv2d(4, 4, kernel_size=3, padding=1, bias=False, groups=4)
+        self.box_filter.weight.data[...] = 1 / 9
+        self.conv = nn.Sequential(
+            nn.Conv2d(4 * 2 + hid_channels, hid_channels, kernel_size=1, bias=False),
+            nn.BatchNorm2d(hid_channels),
+            nn.ReLU(True),
+            nn.Conv2d(hid_channels, hid_channels, kernel_size=1, bias=False),
+            nn.BatchNorm2d(hid_channels),
+            nn.ReLU(True),
+            nn.Conv2d(hid_channels, 4, kernel_size=1, bias=True)
+        )
+        
+    def forward_single_frame(self, fine_src, base_src, base_fgr, base_pha, base_hid):
+        fine_x = torch.cat([fine_src, fine_src.mean(1, keepdim=True)], dim=1)
+        base_x = torch.cat([base_src, base_src.mean(1, keepdim=True)], dim=1)
+        base_y = torch.cat([base_fgr, base_pha], dim=1)
+        
+        mean_x = self.box_filter(base_x)
+        mean_y = self.box_filter(base_y)
+        cov_xy = self.box_filter(base_x * base_y) - mean_x * mean_y
+        var_x  = self.box_filter(base_x * base_x) - mean_x * mean_x
+        
+        A = self.conv(torch.cat([cov_xy, var_x, base_hid], dim=1))
+        b = mean_y - A * mean_x
+        
+        H, W = fine_src.shape[2:]
+        A = F.interpolate(A, (H, W), mode='bilinear', align_corners=False)
+        b = F.interpolate(b, (H, W), mode='bilinear', align_corners=False)
+        
+        out = A * fine_x + b
+        fgr, pha = out.split([3, 1], dim=1)
+        return fgr, pha
+    
+    def forward_time_series(self, fine_src, base_src, base_fgr, base_pha, base_hid):
+        B, T = fine_src.shape[:2]
+        fgr, pha = self.forward_single_frame(
+            fine_src.flatten(0, 1),
+            base_src.flatten(0, 1),
+            base_fgr.flatten(0, 1),
+            base_pha.flatten(0, 1),
+            base_hid.flatten(0, 1))
+        fgr = fgr.unflatten(0, (B, T))
+        pha = pha.unflatten(0, (B, T))
+        return fgr, pha
+    
+    def forward(self, fine_src, base_src, base_fgr, base_pha, base_hid):
+        if fine_src.ndim == 5:
+            return self.forward_time_series(fine_src, base_src, base_fgr, base_pha, base_hid)
+        else:
+            return self.forward_single_frame(fine_src, base_src, base_fgr, base_pha, base_hid)

+ 76 - 0
model/fast_guided_filter.py

@@ -0,0 +1,76 @@
+import torch
+from torch import nn
+from torch.nn import functional as F
+
+"""
+Adapted from <https://github.com/wuhuikai/DeepGuidedFilter/>
+"""
+
+class FastGuidedFilterRefiner(nn.Module):
+    def __init__(self, *args, **kwargs):
+        super().__init__()
+        self.guilded_filter = FastGuidedFilter(1)
+    
+    def forward_single_frame(self, fine_src, base_src, base_fgr, base_pha):
+        fine_src_gray = fine_src.mean(1, keepdim=True)
+        base_src_gray = base_src.mean(1, keepdim=True)
+        
+        fgr, pha = self.guilded_filter(
+            torch.cat([base_src, base_src_gray], dim=1),
+            torch.cat([base_fgr, base_pha], dim=1),
+            torch.cat([fine_src, fine_src_gray], dim=1)).split([3, 1], dim=1)
+        
+        return fgr, pha
+    
+    def forward_time_series(self, fine_src, base_src, base_fgr, base_pha):
+        B, T = fine_src.shape[:2]
+        fgr, pha = self.forward_single_frame(
+            fine_src.flatten(0, 1),
+            base_src.flatten(0, 1),
+            base_fgr.flatten(0, 1),
+            base_pha.flatten(0, 1))
+        fgr = fgr.unflatten(0, (B, T))
+        pha = pha.unflatten(0, (B, T))
+        return fgr, pha
+    
+    def forward(self, fine_src, base_src, base_fgr, base_pha, base_hid):
+        if fine_src.ndim == 5:
+            return self.forward_time_series(fine_src, base_src, base_fgr, base_pha)
+        else:
+            return self.forward_single_frame(fine_src, base_src, base_fgr, base_pha)
+
+
+class FastGuidedFilter(nn.Module):
+    def __init__(self, r: int, eps: float = 1e-5):
+        super().__init__()
+        self.r = r
+        self.eps = eps
+        self.boxfilter = BoxFilter(r)
+
+    def forward(self, lr_x, lr_y, hr_x):
+        mean_x = self.boxfilter(lr_x)
+        mean_y = self.boxfilter(lr_y)
+        cov_xy = self.boxfilter(lr_x * lr_y) - mean_x * mean_y
+        var_x = self.boxfilter(lr_x * lr_x) - mean_x * mean_x
+        A = cov_xy / (var_x + self.eps)
+        b = mean_y - A * mean_x
+        A = F.interpolate(A, hr_x.shape[2:], mode='bilinear', align_corners=False)
+        b = F.interpolate(b, hr_x.shape[2:], mode='bilinear', align_corners=False)
+        return A * hr_x + b
+
+
+class BoxFilter(nn.Module):
+    def __init__(self, r):
+        super(BoxFilter, self).__init__()
+        self.r = r
+
+    def forward(self, x):
+        # Note: The original implementation at <https://github.com/wuhuikai/DeepGuidedFilter/>
+        #       uses a faster box blur. However, it may not be friendly to ONNX export,
+        #       so we switch to a simple convolution for the box blur.
+        kernel_size = 2 * self.r + 1
+        kernel_x = torch.full((x.data.shape[1], 1, 1, kernel_size), 1 / kernel_size, device=x.device, dtype=x.dtype)
+        kernel_y = torch.full((x.data.shape[1], 1, kernel_size, 1), 1 / kernel_size, device=x.device, dtype=x.dtype)
+        x = F.conv2d(x, kernel_x, padding=(0, self.r), groups=x.data.shape[1])
+        x = F.conv2d(x, kernel_y, padding=(self.r, 0), groups=x.data.shape[1])
+        return x

+ 29 - 0
model/lraspp.py

@@ -0,0 +1,29 @@
+from torch import nn
+
+class LRASPP(nn.Module):
+    def __init__(self, in_channels, out_channels):
+        super().__init__()
+        self.aspp1 = nn.Sequential(
+            nn.Conv2d(in_channels, out_channels, 1, bias=False),
+            nn.BatchNorm2d(out_channels),
+            nn.ReLU(True)
+        )
+        self.aspp2 = nn.Sequential(
+            nn.AdaptiveAvgPool2d(1),
+            nn.Conv2d(in_channels, out_channels, 1, bias=False),
+            nn.Sigmoid()
+        )
+        
+    def forward_single_frame(self, x):
+        return self.aspp1(x) * self.aspp2(x)
+    
+    def forward_time_series(self, x):
+        B, T = x.shape[:2]
+        x = self.forward_single_frame(x.flatten(0, 1)).unflatten(0, (B, T))
+        return x
+    
+    def forward(self, x):
+        if x.ndim == 5:
+            return self.forward_time_series(x)
+        else:
+            return self.forward_single_frame(x)

+ 72 - 0
model/mobilenetv3.py

@@ -0,0 +1,72 @@
+from torch import nn
+from torchvision.models.mobilenetv3 import MobileNetV3, InvertedResidualConfig
+from torchvision.models.utils import load_state_dict_from_url
+from torchvision.transforms.functional import normalize
+
+class MobileNetV3LargeEncoder(MobileNetV3):
+    def __init__(self, pretrained: bool = False):
+        super().__init__(
+            inverted_residual_setting=[
+                InvertedResidualConfig( 16, 3,  16,  16, False, "RE", 1, 1, 1),
+                InvertedResidualConfig( 16, 3,  64,  24, False, "RE", 2, 1, 1),  # C1
+                InvertedResidualConfig( 24, 3,  72,  24, False, "RE", 1, 1, 1),
+                InvertedResidualConfig( 24, 5,  72,  40,  True, "RE", 2, 1, 1),  # C2
+                InvertedResidualConfig( 40, 5, 120,  40,  True, "RE", 1, 1, 1),
+                InvertedResidualConfig( 40, 5, 120,  40,  True, "RE", 1, 1, 1),
+                InvertedResidualConfig( 40, 3, 240,  80, False, "HS", 2, 1, 1),  # C3
+                InvertedResidualConfig( 80, 3, 200,  80, False, "HS", 1, 1, 1),
+                InvertedResidualConfig( 80, 3, 184,  80, False, "HS", 1, 1, 1),
+                InvertedResidualConfig( 80, 3, 184,  80, False, "HS", 1, 1, 1),
+                InvertedResidualConfig( 80, 3, 480, 112,  True, "HS", 1, 1, 1),
+                InvertedResidualConfig(112, 3, 672, 112,  True, "HS", 1, 1, 1),
+                InvertedResidualConfig(112, 5, 672, 160,  True, "HS", 2, 2, 1),  # C4
+                InvertedResidualConfig(160, 5, 960, 160,  True, "HS", 1, 2, 1),
+                InvertedResidualConfig(160, 5, 960, 160,  True, "HS", 1, 2, 1),
+            ],
+            last_channel=1280
+        )
+        
+        if pretrained:
+            self.load_state_dict(load_state_dict_from_url(
+                'https://download.pytorch.org/models/mobilenet_v3_large-8738ca79.pth'))
+
+        del self.avgpool
+        del self.classifier
+        
+    def forward_single_frame(self, x):
+        x = normalize(x, [0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
+        
+        x = self.features[0](x)
+        x = self.features[1](x)
+        f1 = x
+        x = self.features[2](x)
+        x = self.features[3](x)
+        f2 = x
+        x = self.features[4](x)
+        x = self.features[5](x)
+        x = self.features[6](x)
+        f3 = x
+        x = self.features[7](x)
+        x = self.features[8](x)
+        x = self.features[9](x)
+        x = self.features[10](x)
+        x = self.features[11](x)
+        x = self.features[12](x)
+        x = self.features[13](x)
+        x = self.features[14](x)
+        x = self.features[15](x)
+        x = self.features[16](x)
+        f4 = x
+        return [f1, f2, f3, f4]
+    
+    def forward_time_series(self, x):
+        B, T = x.shape[:2]
+        features = self.forward_single_frame(x.flatten(0, 1))
+        features = [f.unflatten(0, (B, T)) for f in features]
+        return features
+
+    def forward(self, x):
+        if x.ndim == 5:
+            return self.forward_time_series(x)
+        else:
+            return self.forward_single_frame(x)

+ 79 - 0
model/model.py

@@ -0,0 +1,79 @@
+import torch
+from torch import Tensor
+from torch import nn
+from torch.nn import functional as F
+from typing import Optional, List
+
+from .mobilenetv3 import MobileNetV3LargeEncoder
+from .resnet import ResNet50Encoder
+from .lraspp import LRASPP
+from .decoder import RecurrentDecoder, Projection
+from .fast_guided_filter import FastGuidedFilterRefiner
+from .deep_guided_filter import DeepGuidedFilterRefiner
+
+class MattingNetwork(nn.Module):
+    def __init__(self,
+                 variant: str = 'mobilenetv3',
+                 refiner: str = 'deep_guided_filter',
+                 pretrained_backbone: bool = False):
+        super().__init__()
+        assert variant in ['mobilenetv3', 'resnet50']
+        assert refiner in ['fast_guided_filter', 'deep_guided_filter']
+        
+        if variant == 'mobilenetv3':
+            self.backbone = MobileNetV3LargeEncoder(pretrained_backbone)
+            self.aspp = LRASPP(960, 128)
+            self.decoder = RecurrentDecoder([16, 24, 40, 128], [80, 40, 32, 16])
+        else:
+            self.backbone = ResNet50Encoder(pretrained_backbone)
+            self.aspp = LRASPP(2048, 256)
+            self.decoder = RecurrentDecoder([64, 256, 512, 256], [128, 64, 32, 16])
+            
+        self.project_mat = Projection(16, 4)
+        self.project_seg = Projection(16, 1)
+
+        if refiner == 'deep_guided_filter':
+            self.refiner = DeepGuidedFilterRefiner()
+        else:
+            self.refiner = FastGuidedFilterRefiner()
+        
+    def forward(self,
+                src: Tensor,
+                r1: Optional[Tensor] = None,
+                r2: Optional[Tensor] = None,
+                r3: Optional[Tensor] = None,
+                r4: Optional[Tensor] = None,
+                downsample_ratio: float = 1,
+                segmentation_pass: bool = False):
+        
+        if downsample_ratio != 1:
+            src_sm = self._interpolate(src, scale_factor=downsample_ratio)
+        else:
+            src_sm = src
+        
+        f1, f2, f3, f4 = self.backbone(src_sm)
+        f4 = self.aspp(f4)
+        hid, *rec = self.decoder(src_sm, f1, f2, f3, f4, r1, r2, r3, r4)
+        
+        if not segmentation_pass:
+            fgr_residual, pha = self.project_mat(hid).split([3, 1], dim=-3)
+            if downsample_ratio != 1:
+                fgr_residual, pha = self.refiner(src, src_sm, fgr_residual, pha, hid)
+            fgr = fgr_residual + src
+            fgr = fgr.clamp(0., 1.)
+            pha = pha.clamp(0., 1.)
+            return [fgr, pha, *rec]
+        else:
+            seg = self.project_seg(hid)
+            return [seg, *rec]
+
+    def _interpolate(self, x: Tensor, scale_factor: float):
+        if x.ndim == 5:
+            B, T = x.shape[:2]
+            x = F.interpolate(x.flatten(0, 1), scale_factor=scale_factor,
+                mode='bilinear', align_corners=False, recompute_scale_factor=False)
+            x = x.unflatten(0, (B, T))
+        else:
+            x = F.interpolate(x, scale_factor=scale_factor,
+                mode='bilinear', align_corners=False, recompute_scale_factor=False)
+        return x

+ 45 - 0
model/resnet.py

@@ -0,0 +1,45 @@
+from torch import nn
+from torchvision.models.resnet import ResNet, Bottleneck
+from torchvision.models.utils import load_state_dict_from_url
+
+class ResNet50Encoder(ResNet):
+    def __init__(self, pretrained: bool = False):
+        super().__init__(
+            block=Bottleneck,
+            layers=[3, 4, 6, 3],
+            replace_stride_with_dilation=[False, False, True],
+            norm_layer=None)
+        
+        if pretrained:
+            self.load_state_dict(load_state_dict_from_url(
+                'https://download.pytorch.org/models/resnet50-0676ba61.pth'))
+        
+        del self.avgpool
+        del self.fc
+        
+    def forward_single_frame(self, x):
+        x = self.conv1(x)
+        x = self.bn1(x)
+        x = self.relu(x)
+        f1 = x  # 1/2
+        x = self.maxpool(x)
+        x = self.layer1(x)
+        f2 = x  # 1/4
+        x = self.layer2(x)
+        f3 = x  # 1/8
+        x = self.layer3(x)
+        x = self.layer4(x)
+        f4 = x  # 1/16
+        return [f1, f2, f3, f4]
+    
+    def forward_time_series(self, x):
+        B, T = x.shape[:2]
+        features = self.forward_single_frame(x.flatten(0, 1))
+        features = [f.unflatten(0, (B, T)) for f in features]
+        return features
+    
+    def forward(self, x):
+        if x.ndim == 5:
+            return self.forward_time_series(x)
+        else:
+            return self.forward_single_frame(x)

+ 3 - 0
requirements.txt

@@ -0,0 +1,3 @@
+torch==1.8.1
+torchvision==0.9.1
+coremltools==5.0b1