Эх сурвалжийг харах

Add TensorFlow export code

Peter Lin 3 жил өмнө
commit
c755c07c40

+ 57 - 0
README.md

@@ -0,0 +1,57 @@
+# TensorFlow Implementation
+
+## Overview
+
+This branch contains our model implemented in TensorFlow 2. We have transferred the weights from the official PyTorch version and provide the model in TensorFlow SavedModel format. If you only need inference, you do not need any source code from this branch. If you need the weights in native TensorFlow model for advanced experiments, we show you how to load weights from PyTorch.
+
+## Transfer PyTorch Weights to TensorFlow
+
+```python
+import tensorflow as tf
+import torch
+
+from model import MattingNetwork, load_torch_weights
+
+# Create a TF model
+model = MattingNetwork('mobilenetv3')
+
+# Create dummpy inputs.
+src = tf.random.normal([1, 1080, 1920, 3])
+rec = [ tf.constant(0.) ] * 4
+downsample_ratio = tf.constant(0.25)
+
+# Do a forward pass to initialize the model.
+out = model([src, *rec, downsample_ratio])
+
+# Transfer PyTorch weights to TF model.
+state_dict = torch.load('rvm_mobilenetv3.pth')
+load_torch_weights(model, state_dict)
+```
+
+## Export TensorFlow SavedModel
+
+We use the following script to generate the official TensorFlow SavedModel from the PyTorch checkpoint.
+
+```sh
+python export_tensorflow.py \
+    --model-variant mobilenetv3 \
+    --model-refiner deep_guided_filter \
+    --pytorch-checkpoint rvm_mobilenetv3.pth \
+    --tensorflow-output rvm_mobilenetv3_tf
+```
+
+## Export TensorFlow.js
+
+We already provide an exported TensorFlow.js model. If you need other configurations, use the export procedure below.
+
+Currently TensorFlow.js only supports Fast Guided Filter. To export to tfjs, first use the `export_tensorflow.py` script above with `--model-refiner fast_guided_filter` to generate a TensorFlow SavedModel. Then convert the SavedModel to tfjs:
+
+```sh
+pip install tensorflowjs
+
+tensorflowjs_converter \
+    --quantize_uint8 \
+    --input_format=tf_saved_model \
+    rvm_mobilenetv3_tf \
+    rvm_mobilenetv3_tfjs_int8
+```

+ 61 - 0
export_tensorflow.py

@@ -0,0 +1,61 @@
+"""
+python export_tensorflow.py \
+    --model-variant mobilenetv3 \
+    --model-variant deep_guided_filter \
+    --pytorch-checkpoint rvm_mobilenetv3.pth \
+    --tensorflow-output rvm_mobilenetv3_tf
+"""
+
+import argparse
+import torch
+import tensorflow as tf
+
+from model import MattingNetwork, load_torch_weights
+
+
+# Add input output names and shapes
+class MattingNetworkWrapper(MattingNetwork):
+    @tf.function(input_signature=[[
+        tf.TensorSpec(tf.TensorShape([None, None, None, 3]), tf.float32, 'src'),
+        tf.TensorSpec(tf.TensorShape(None), tf.float32, 'r1i'),
+        tf.TensorSpec(tf.TensorShape(None), tf.float32, 'r2i'),
+        tf.TensorSpec(tf.TensorShape(None), tf.float32, 'r3i'),
+        tf.TensorSpec(tf.TensorShape(None), tf.float32, 'r4i'),
+        tf.TensorSpec(tf.TensorShape(None), tf.float32, 'downsample_ratio')
+    ]])
+    def call(self, inputs):
+        fgr, pha, r1o, r2o, r3o, r4o = super().call(inputs)
+        return {'fgr': fgr, 'pha': pha, 'r1o': r1o, 'r2o': r2o, 'r3o': r3o, 'r4o': r4o}
+
+        
+class Exporter:
+    def __init__(self):
+        self.parse_args()
+        self.init_model()
+        self.export()
+        
+    def parse_args(self):
+        parser = argparse.ArgumentParser()
+        parser.add_argument('--model-variant', type=str, required=True, choices=['mobilenetv3', 'resnet50'])
+        parser.add_argument('--model-refiner', type=str, default='deep_guided_filter', choices=['deep_guided_filter', 'fast_guided_filter'])
+        parser.add_argument('--pytorch-checkpoint', type=str, required=True)
+        parser.add_argument('--tensorflow-output', type=str, required=True)
+        self.args = parser.parse_args()
+        
+    def init_model(self):
+        # Construct model
+        self.model = MattingNetworkWrapper(self.args.model_variant, self.args.model_refiner)
+        # Build model
+        src = tf.random.normal([1, 1080, 1920, 3])
+        rec = [ tf.constant(0.) ] * 4
+        downsample_ratio = tf.constant(0.25)
+        self.model([src, *rec, downsample_ratio])
+        # Load PyTorch checkpoint
+        load_torch_weights(self.model, torch.load(self.args.pytorch_checkpoint, map_location='cpu'))
+
+    def export(self):
+        self.model.save(self.args.tensorflow_output)
+
+
+if __name__ == '__main__':
+    Exporter()

+ 2 - 0
model/__init__.py

@@ -0,0 +1,2 @@
+from .model import MattingNetwork
+from .load_weights import load_torch_weights

+ 99 - 0
model/decoder.py

@@ -0,0 +1,99 @@
+import tensorflow as tf
+from tensorflow.keras.models import Sequential
+from tensorflow.keras.layers import Layer, Conv2D, BatchNormalization, ReLU, Activation, AveragePooling2D, UpSampling2D
+
+class RecurrentDecoder(Layer):
+    def __init__(self, channels):
+        super().__init__()
+        self.avgpool = AveragePooling2D(padding='SAME')
+        self.decode4 = BottleneckBlock(channels[0])
+        self.decode3 = UpsamplingBlock(channels[1])
+        self.decode2 = UpsamplingBlock(channels[2])
+        self.decode1 = UpsamplingBlock(channels[3])
+        self.decode0 = OutputBlock(channels[4])
+        
+    def call(self, inputs):
+        s0, f1, f2, f3, f4, r1, r2, r3, r4 = inputs
+        s1 = self.avgpool(s0)
+        s2 = self.avgpool(s1)
+        s3 = self.avgpool(s2)
+        x4, r4 = self.decode4([f4, r4])
+        x3, r3 = self.decode3([x4, f3, s3, r3])
+        x2, r2 = self.decode2([x3, f2, s2, r2])
+        x1, r1 = self.decode1([x2, f1, s1, r1])
+        x0 = self.decode0([x1, s0])
+        return x0, r1, r2, r3, r4
+
+
+class BottleneckBlock(Layer):
+    def __init__(self, channels):
+        super().__init__()
+        self.gru = ConvGRU(channels // 2)
+        
+    def call(self, inputs):
+        x, r = inputs
+        a, b = tf.split(x, 2, -1)
+        b, r = self.gru([b, r])
+        x = tf.concat([a, b], -1)
+        return x, r
+
+
+class UpsamplingBlock(Layer):
+    def __init__(self, channels):
+        super().__init__()
+        self.upsample = UpSampling2D(interpolation='bilinear')
+        self.conv = Sequential([
+            Conv2D(channels, 3, padding='SAME', use_bias=False),
+            BatchNormalization(momentum=0.1, epsilon=1e-5),
+            ReLU()
+        ])
+        self.gru = ConvGRU(channels // 2)
+        
+    def call(self, inputs):
+        x, f, s, r = inputs
+        x = self.upsample(x)
+        x = tf.image.crop_to_bounding_box(x, 0, 0, tf.shape(s)[1], tf.shape(s)[2])
+        x = tf.concat([x, f, s], -1)
+        x = self.conv(x)
+        a, b = tf.split(x, 2, -1)
+        b, r = self.gru([b, r])
+        x = tf.concat([a, b], -1)
+        return x, r
+
+
+class OutputBlock(Layer):
+    def __init__(self, channels):
+        super().__init__()
+        self.upsample = UpSampling2D(interpolation='bilinear')
+        self.conv = Sequential([
+            Conv2D(channels, 3, padding='SAME', use_bias=False),
+            BatchNormalization(momentum=0.1, epsilon=1e-5),
+            ReLU(),
+            Conv2D(channels, 3, padding='SAME', use_bias=False),
+            BatchNormalization(momentum=0.1, epsilon=1e-5),
+            ReLU(),
+        ])
+    
+    def call(self, inputs):
+        x, s = inputs
+        x = self.upsample(x)
+        x = tf.image.crop_to_bounding_box(x, 0, 0, tf.shape(s)[1], tf.shape(s)[2])
+        x = tf.concat([x, s], -1)
+        x = self.conv(x)
+        return x
+
+
+class ConvGRU(Layer):
+    def __init__(self, channels, kernel_size=3):
+        super().__init__()
+        self.channels = channels
+        self.ih = Conv2D(channels * 2, kernel_size, padding='SAME', activation='sigmoid')
+        self.hh = Conv2D(channels, kernel_size, padding='SAME', activation='tanh')
+        
+    def call(self, inputs):
+        x, h = inputs
+        h = tf.broadcast_to(h, tf.shape(x))
+        r, z = tf.split(self.ih(tf.concat([x, h], -1)), 2, -1)
+        c = self.hh(tf.concat([x, r * h], -1))
+        h = (1 - z) * h + z * c
+        return h, h

+ 42 - 0
model/deep_guided_filter.py

@@ -0,0 +1,42 @@
+import tensorflow as tf
+from tensorflow.keras.models import Sequential
+from tensorflow.keras.layers import Layer, Conv2D, BatchNormalization, ReLU, DepthwiseConv2D, ZeroPadding2D
+
+class DeepGuidedFilterRefiner(Layer):
+    def __init__(self, hid_channels=16, radius=1):
+        super().__init__()
+        self.box_filter = Sequential([
+            ZeroPadding2D(1),
+            DepthwiseConv2D(3, use_bias=False)
+        ])
+        self.box_filer = Conv2D(4, 3, padding='SAME', dilation_rate=radius, use_bias=False, groups=4)
+        self.conv = Sequential([
+            Conv2D(hid_channels, 1, use_bias=False),
+            BatchNormalization(momentum=0.1, epsilon=1e-5),
+            ReLU(),
+            Conv2D(hid_channels, 1, use_bias=False),
+            BatchNormalization(momentum=0.1, epsilon=1e-5),
+            ReLU(),
+            Conv2D(4, 1)
+        ])
+        
+    def call(self, inputs):
+        fine_src, base_src, base_fgr, base_pha, base_hid = inputs
+        fine_x = tf.concat([fine_src, tf.reduce_mean(fine_src, -1, keepdims=True)], -1)
+        base_x = tf.concat([base_src, tf.reduce_mean(base_src, -1, keepdims=True)], -1)
+        base_y = tf.concat([base_fgr, base_pha], -1)
+        
+        mean_x = self.box_filter(base_x)
+        mean_y = self.box_filter(base_y)
+        cov_xy = self.box_filter(base_x * base_y) - mean_x * mean_y
+        var_x  = self.box_filter(base_x * base_x) - mean_x * mean_x
+        
+        A = self.conv(tf.concat([cov_xy, var_x, base_hid], -1))
+        b = mean_y - A * mean_x
+        
+        H, W = tf.shape(fine_src)[1], tf.shape(fine_src)[2]
+        mean_A = tf.image.resize(A, (H, W))
+        mean_b = tf.image.resize(b, (H, W))
+        out = mean_A * fine_x + mean_b
+        fgr, pha = tf.split(out, [3, 1], -1)
+        return fgr, pha

+ 73 - 0
model/fast_guided_filter.py

@@ -0,0 +1,73 @@
+import tensorflow as tf
+from tensorflow.keras.layers import Layer, Conv2D, BatchNormalization, ReLU, DepthwiseConv2D, ZeroPadding2D
+
+"""
+Adopted from <https://github.com/wuhuikai/DeepGuidedFilter/>
+"""
+
+class BoxFilter(Layer):
+    def __init__(self, r):
+        super(BoxFilter, self).__init__()
+        self.kernel_size = 2 * r + 1
+        self.filter_x = DepthwiseConv2D((1, self.kernel_size), use_bias=False)
+        self.filter_y = DepthwiseConv2D((self.kernel_size, 1), use_bias=False)
+
+    def build(self, input_shape):
+        self.filter_x.build(input_shape)
+        self.filter_y.build(input_shape)
+        weight_x = self.filter_x.get_weights()[0]
+        weight_y = self.filter_y.get_weights()[0]
+        weight_x.fill(1 / self.kernel_size)
+        weight_y.fill(1 / self.kernel_size)
+        self.filter_x.set_weights([weight_x])
+        self.filter_y.set_weights([weight_y])
+
+    def call(self, x):
+        return self.filter_x(self.filter_y(x))
+
+
+class FastGuidedFilter(Layer):
+    def __init__(self, r: int, eps: float = 1e-5):
+        super().__init__()
+        self.r = r
+        self.eps = eps
+        self.boxfilter = BoxFilter(r)
+
+    def call(self, inputs):
+        lr_x, lr_y, hr_x = inputs
+        mean_x = self.boxfilter(lr_x)
+        mean_y = self.boxfilter(lr_y)
+        cov_xy = self.boxfilter(lr_x * lr_y) - mean_x * mean_y
+        var_x = self.boxfilter(lr_x * lr_x) - mean_x * mean_x
+
+        A = cov_xy / (var_x + self.eps)
+        b = mean_y - A * mean_x
+
+        ## mean_A; mean_b
+        H, W = tf.shape(hr_x)[1], tf.shape(hr_x)[2]
+        mean_A = tf.image.resize(A, (H, W))
+        mean_b = tf.image.resize(b, (H, W))
+        return mean_A * hr_x + mean_b
+
+
+class FastGuidedFilterRefiner(Layer):
+    def __init__(self, *args, **kwargs):
+        super().__init__()
+        self.guilded_filter = FastGuidedFilter(1)
+    
+    def call(self, inputs):
+        fine_src, base_src, base_fgr, base_pha = inputs[:4]
+        
+        fine_src_gray = tf.reduce_mean(fine_src, -1, keepdims=True)
+        base_src_gray = tf.reduce_mean(base_src, -1, keepdims=True)
+        
+        out = self.guilded_filter([
+            tf.concat([base_src, base_src_gray], -1),
+            tf.concat([base_fgr, base_pha], -1),
+            tf.concat([fine_src, fine_src_gray], -1)
+        ])
+        
+        fgr, pha = tf.split(out, [3, 1], -1)
+        
+        return fgr, pha
+    

+ 113 - 0
model/load_weights.py

@@ -0,0 +1,113 @@
+from tensorflow.keras.layers import DepthwiseConv2D
+from .mobilenetv3 import *
+from .resnet import ResNet50Encoder
+from .deep_guided_filter import DeepGuidedFilterRefiner
+
+# --------------------------- Load torch weights  ---------------------------
+def load_torch_weights(model, state_dict):
+    if isinstance(model.backbone, MobileNetV3Encoder):
+        load_MobileNetV3_weights(model.backbone, state_dict, 'backbone')
+    if isinstance(model.backbone, ResNet50Encoder):
+        load_ResNetEncoder_weights(model.backbone, state_dict, 'backbone')
+    load_LRASPP_weights(model.aspp, state_dict, 'aspp')
+    load_RecurrentDecoder_weights(model.decoder, state_dict, 'decoder')
+    load_conv_weights(model.project_mat, state_dict, 'project_mat.conv')
+    if isinstance(model.refiner, DeepGuidedFilterRefiner):
+        load_DeepGuidedFilter_weights(model.refiner, state_dict, 'refiner')
+    
+# --------------------------- General  ---------------------------
+def load_conv_weights(conv, state_dict, name):
+    weight = state_dict[name + '.weight']
+    if isinstance(conv, DepthwiseConv2D):
+        weight = weight.permute(2, 3, 0, 1).numpy()
+    else:
+        weight = weight.permute(2, 3, 1, 0).numpy()
+    if name + '.bias' in state_dict:
+        bias = state_dict[name + '.bias'].numpy()
+        conv.set_weights([weight, bias])
+    else:
+        conv.set_weights([weight])
+
+def load_bn_weights(bn, state_dict, name):
+    weight = state_dict[name + '.weight']
+    bias = state_dict[name + '.bias']
+    running_mean = state_dict[name + '.running_mean']
+    running_var = state_dict[name + '.running_var']
+    bn.set_weights([weight, bias, running_mean, running_var])
+        
+# --------------------------- MobileNetV3 ---------------------------
+def load_ConvBNActivation_weights(module, state_dict, name):
+    load_conv_weights(module.conv, state_dict, name + '.0')
+    load_bn_weights(module.bn, state_dict, name + '.1')
+
+def load_InvertedResidual_weights(module, state_dict, name):
+    for i, layer in enumerate(module.block.layers):
+        if isinstance(layer, ConvBNActivation):
+            load_ConvBNActivation_weights(layer, state_dict, f'{name}.block.{i}')
+        if isinstance(layer, SqueezeExcitation):
+            load_conv_weights(layer.fc1, state_dict, f'{name}.block.{i}.fc1')
+            load_conv_weights(layer.fc2, state_dict, f'{name}.block.{i}.fc2')
+
+def load_MobileNetV3_weights(backbone, state_dict, name):
+    for i, module in enumerate(backbone.features):
+        if isinstance(module, ConvBNActivation):
+            load_ConvBNActivation_weights(module, state_dict, f'{name}.features.{i}')
+        if isinstance(module, InvertedResidual):
+            load_InvertedResidual_weights(module, state_dict, f'{name}.features.{i}')
+
+# --------------------------- ResNet ---------------------------
+def load_ResNetEncoder_weights(module, state_dict, name):
+    load_conv_weights(module.conv1, state_dict, f'{name}.conv1')
+    load_bn_weights(module.bn1, state_dict, f'{name}.bn1')
+    for l in range(1, 5):
+        for b, resblock in enumerate(getattr(module, f'layer{l}').layers):
+            if hasattr(resblock, 'convd'):
+                load_conv_weights(resblock.convd, state_dict, f'{name}.layer{l}.{b}.downsample.0')
+                load_bn_weights(resblock.bnd, state_dict, f'{name}.layer{l}.{b}.downsample.1')
+            load_conv_weights(resblock.conv1, state_dict, f'{name}.layer{l}.{b}.conv1')
+            load_conv_weights(resblock.conv2, state_dict, f'{name}.layer{l}.{b}.conv2')
+            load_conv_weights(resblock.conv3, state_dict, f'{name}.layer{l}.{b}.conv3')
+            load_bn_weights(resblock.bn1, state_dict, f'{name}.layer{l}.{b}.bn1')
+            load_bn_weights(resblock.bn2, state_dict, f'{name}.layer{l}.{b}.bn2')
+            load_bn_weights(resblock.bn3, state_dict, f'{name}.layer{l}.{b}.bn3')
+
+# --------------------------- LRASPP ---------------------------
+def load_LRASPP_weights(module, state_dict, name):
+    load_conv_weights(module.aspp1.layers[0], state_dict, f'{name}.aspp1.0')
+    load_bn_weights(module.aspp1.layers[1], state_dict, f'{name}.aspp1.1')
+    load_conv_weights(module.aspp2, state_dict, f'{name}.aspp2.1')
+        
+# --------------------------- RecurrentDecoder ---------------------------
+def load_ConvGRU_weights(module, state_dict, name):
+    load_conv_weights(module.ih, state_dict, f'{name}.ih.0')
+    load_conv_weights(module.hh, state_dict, f'{name}.hh.0')
+
+def load_BottleneckBlock_weights(module, state_dict, name):
+    load_ConvGRU_weights(module.gru, state_dict, f'{name}.gru')
+
+def load_UpsamplingBlock_weights(module, state_dict, name):
+    load_conv_weights(module.conv.layers[0], state_dict, f'{name}.conv.0')
+    load_bn_weights(module.conv.layers[1], state_dict, f'{name}.conv.1')
+    load_ConvGRU_weights(module.gru, state_dict, f'{name}.gru')
+
+def load_OutputBlock_weights(module, state_dict, name):
+    load_conv_weights(module.conv.layers[0], state_dict, f'{name}.conv.0')
+    load_bn_weights(module.conv.layers[1], state_dict, f'{name}.conv.1')
+    load_conv_weights(module.conv.layers[3], state_dict, f'{name}.conv.3')
+    load_bn_weights(module.conv.layers[4], state_dict, f'{name}.conv.4')
+
+def load_RecurrentDecoder_weights(module, state_dict, name):
+    load_BottleneckBlock_weights(module.decode4, state_dict, f'{name}.decode4')
+    load_UpsamplingBlock_weights(module.decode3, state_dict, f'{name}.decode3')
+    load_UpsamplingBlock_weights(module.decode2, state_dict, f'{name}.decode2')
+    load_UpsamplingBlock_weights(module.decode1, state_dict, f'{name}.decode1')
+    load_OutputBlock_weights(module.decode0, state_dict, f'{name}.decode0')
+    
+# --------------------------- DeepGuidedFilter ---------------------------
+def load_DeepGuidedFilter_weights(module, state_dict, name):
+    load_conv_weights(module.box_filter.layers[1], state_dict, f'{name}.box_filter')
+    load_conv_weights(module.conv.layers[0], state_dict, f'{name}.conv.0')
+    load_bn_weights(module.conv.layers[1], state_dict, f'{name}.conv.1')
+    load_conv_weights(module.conv.layers[3], state_dict, f'{name}.conv.3')
+    load_bn_weights(module.conv.layers[4], state_dict, f'{name}.conv.4')
+    load_conv_weights(module.conv.layers[6], state_dict, f'{name}.conv.6')

+ 18 - 0
model/lraspp.py

@@ -0,0 +1,18 @@
+import tensorflow as tf
+from tensorflow.keras.models import Sequential
+from tensorflow.keras.layers import Layer, Conv2D, BatchNormalization, ReLU, Activation
+
+class LRASPP(Layer):
+    def __init__(self, out_channels: int):
+        super().__init__()
+        self.aspp1 = Sequential([
+            Conv2D(out_channels, 1, use_bias=False),
+            BatchNormalization(momentum=0.1, epsilon=1e-5),
+            ReLU()
+        ])
+        self.aspp2 = Conv2D(out_channels, 1, use_bias=False, activation='sigmoid')
+        
+    def call(self, x):
+        x1 = self.aspp1(x)
+        x2 = self.aspp2(tf.reduce_mean(x, axis=[1, 2], keepdims=True))
+        return x1 * x2

+ 168 - 0
model/mobilenetv3.py

@@ -0,0 +1,168 @@
+import tensorflow as tf
+from tensorflow.keras.models import Sequential
+from tensorflow.keras.layers import Layer, Conv2D, BatchNormalization, ReLU, ZeroPadding2D, DepthwiseConv2D
+from typing import Optional
+
+from .utils import normalize
+
+
+def _make_divisible(v: float, divisor: int, min_value: Optional[int] = None) -> int:
+    if min_value is None:
+        min_value = divisor
+    new_v = max(min_value, int(v + divisor / 2) // divisor * divisor)
+    # Make sure that round down does not go down by more than 10%.
+    if new_v < 0.9 * v:
+        new_v += divisor
+    return new_v
+
+
+def _hard_sigmoid(x):
+    return tf.nn.relu6(x + 3) / 6
+
+
+class SqueezeExcitation(Layer):
+    def __init__(self, input_channels: int, squeeze_factor: int = 4):
+        super().__init__()
+        squeeze_channels = _make_divisible(input_channels // squeeze_factor, 8)
+        self.fc1 = Conv2D(squeeze_channels, 1)
+        self.relu = ReLU()
+        self.fc2 = Conv2D(input_channels, 1)
+        
+    def call(self, x):
+        scale = tf.reduce_mean(x, axis=[1,2], keepdims=True)
+        scale = self.fc1(scale)
+        scale = self.relu(scale)
+        scale = self.fc2(scale)
+        scale = _hard_sigmoid(scale)
+        return scale * x
+    
+    
+class Hardswish(Layer):
+    def call(self, x):
+        return x * _hard_sigmoid(x)
+
+
+class ConvBNActivation(Layer):
+    def __init__(self, filters, kernel_size, stride=1, groups=1, dilation=1, activation_layer=None):
+        super().__init__()
+        padding = (kernel_size - 1) // 2 * dilation
+        if padding != 0:
+            self.pad = ZeroPadding2D(padding)
+        if groups == 1:
+            self.conv = Conv2D(filters, kernel_size, stride, dilation_rate=dilation, groups=groups, use_bias=False)
+        else:
+            self.conv = DepthwiseConv2D(kernel_size, stride, dilation_rate=dilation, use_bias=False)
+        self.bn = BatchNormalization(momentum=0.01, epsilon=1e-3)
+        if activation_layer:
+            self.act = activation_layer()
+
+    def call(self, x):
+        if hasattr(self, 'pad'):
+            x = self.pad(x)
+        x = self.conv(x)
+        x = self.bn(x)
+        if hasattr(self, 'act'):
+            x = self.act(x)
+        return x
+
+
+class InvertedResidual(Layer):
+    def __init__(self,
+                 input_channels: int,
+                 kernel: int,
+                 expanded_channels: int,
+                 out_channels: int,
+                 use_se: bool,
+                 activation: str,
+                 stride: int,
+                 dilation: int):
+        super().__init__()
+        if not (1 <= stride <= 2):
+            raise ValueError('illegal stride value')
+
+        self.use_res_connect = stride == 1 and input_channels == out_channels
+
+        layers = []
+        activation_layer = Hardswish if activation == 'HS' else ReLU
+
+        # expand
+        if expanded_channels != input_channels:
+            layers.append(ConvBNActivation(
+                expanded_channels,
+                kernel_size=1,
+                activation_layer=activation_layer))
+        
+        # depthwise
+        stride = 1 if dilation > 1 else stride
+        layers.append(ConvBNActivation(
+            expanded_channels,
+            kernel_size=kernel,
+            stride=stride,
+            dilation=dilation,
+            groups=expanded_channels,
+            activation_layer=activation_layer))
+        if use_se:
+            layers.append(SqueezeExcitation(expanded_channels))
+
+        # project
+        layers.append(ConvBNActivation(
+            out_channels,
+            kernel_size=1,
+            activation_layer=None))
+
+        self.block = Sequential(layers)
+
+    def call(self, input):
+        result = self.block(input)
+        if self.use_res_connect:
+            result += input
+        return result
+
+
+class MobileNetV3Encoder(Layer):
+    def __init__(self):
+        super().__init__()
+        self.features = [
+            ConvBNActivation(16, kernel_size=3, stride=2, activation_layer=Hardswish),
+            InvertedResidual(16, 3, 16, 16, False, 'RE', 1, 1),
+            InvertedResidual(16, 3, 64, 24, False, 'RE', 2, 1), # C1
+            InvertedResidual(24, 3, 72, 24, False, 'RE', 1, 1),
+            InvertedResidual(24, 5, 72, 40, True, 'RE', 2, 1), # C2
+            InvertedResidual(40, 5, 120, 40, True, 'RE', 1, 1),
+            InvertedResidual(40, 5, 120, 40, True, 'RE', 1, 1),
+            InvertedResidual(40, 3, 240, 80, False, 'HS', 2, 1), # C3
+            InvertedResidual(80, 3, 200, 80, False, 'HS', 1, 1),
+            InvertedResidual(80, 3, 184, 80, False, 'HS', 1, 1),
+            InvertedResidual(80, 3, 184, 80, False, 'HS', 1, 1),
+            InvertedResidual(80, 3, 480, 112, True, 'HS', 1, 1),
+            InvertedResidual(112, 3, 672, 112, True, 'HS', 1, 1),
+            InvertedResidual(112, 5, 672, 160, True, 'HS', 2, 2), # C4
+            InvertedResidual(160, 5, 960, 160, True, 'HS', 1, 2),
+            InvertedResidual(160, 5, 960, 160, True, 'HS', 1, 2),
+            ConvBNActivation(960, kernel_size=1, activation_layer=Hardswish)
+        ]
+        
+    def call(self, x):
+        x = normalize(x, mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
+        x = self.features[0](x)
+        x = self.features[1](x)
+        f1 = x
+        x = self.features[2](x)
+        x = self.features[3](x)
+        f2 = x
+        x = self.features[4](x)
+        x = self.features[5](x)
+        x = self.features[6](x)
+        f3 = x
+        x = self.features[7](x)
+        x = self.features[8](x)
+        x = self.features[9](x)
+        x = self.features[10](x)
+        x = self.features[11](x)
+        x = self.features[12](x)
+        x = self.features[13](x)
+        x = self.features[14](x)
+        x = self.features[15](x)
+        x = self.features[16](x)
+        f4 = x
+        return f1, f2, f3, f4

+ 62 - 0
model/model.py

@@ -0,0 +1,62 @@
+import tensorflow as tf
+from tensorflow.keras.models import Model
+from tensorflow.keras.layers import Conv2D
+
+from .mobilenetv3 import MobileNetV3Encoder
+from .resnet import ResNet50Encoder
+from .lraspp import LRASPP
+from .decoder import RecurrentDecoder
+from .fast_guided_filter import FastGuidedFilterRefiner
+from .deep_guided_filter import DeepGuidedFilterRefiner
+
+class MattingNetwork(Model):
+    def __init__(self,
+                 variant: str = 'mobilenetv3',
+                 refiner: str = 'deep_guided_filter'):
+        super().__init__()
+        assert variant in ['mobilenetv3', 'resnet50']
+        assert refiner in ['deep_guided_filter', 'fast_guided_filter']
+        
+        if variant == 'mobilenetv3':
+            self.backbone = MobileNetV3Encoder()
+            self.aspp = LRASPP(128)
+            self.decoder = RecurrentDecoder([128, 80, 40, 32, 16])
+        else:
+            self.backbone = ResNet50Encoder()
+            self.aspp = LRASPP(256)
+            self.decoder = RecurrentDecoder([256, 128, 64, 32, 16])
+            
+        self.project_mat = Conv2D(4, 1)
+            
+        if refiner == 'deep_guided_filter':
+            self.refiner = DeepGuidedFilterRefiner()
+        else:
+            self.refiner = FastGuidedFilterRefiner()
+    
+    def call(self, inputs):
+        src, *rec, downsample_ratio = inputs
+
+        src_sm = tf.cond(downsample_ratio == 1,
+            lambda: src,
+            lambda: self._downsample(src, downsample_ratio))
+        
+        f1, f2, f3, f4 = self.backbone(src_sm)
+        f4 = self.aspp(f4)
+        hid, *rec = self.decoder([src_sm, f1, f2, f3, f4, *rec])
+        out = self.project_mat(hid)
+        fgr_residual, pha = tf.split(out, [3, 1], -1)
+        
+        fgr_residual, pha = tf.cond(downsample_ratio == 1,
+            lambda: (fgr_residual, pha), 
+            lambda: self.refiner([src, src_sm, fgr_residual, pha, hid]))
+        
+        fgr = fgr_residual + src
+        fgr = tf.clip_by_value(fgr, 0, 1)
+        pha = tf.clip_by_value(pha, 0, 1)
+        return fgr, pha, *rec
+    
+    def _downsample(self, x, downsample_ratio):
+        size = tf.shape(x)[1:3]
+        size = tf.cast(size, tf.float32) * tf.cast(downsample_ratio, tf.float32)
+        size = tf.cast(size, tf.int32)
+        return tf.image.resize(x, size)

+ 81 - 0
model/resnet.py

@@ -0,0 +1,81 @@
+from tensorflow.keras.models import Sequential
+from tensorflow.keras.layers import Layer, Conv2D, BatchNormalization, ReLU, MaxPool2D, ZeroPadding2D
+
+
+class ResNet50Encoder(Layer):
+    def __init__(self):
+        super().__init__()
+
+        blocks = [3, 4, 6, 3]
+        filters = [64, 256, 512, 1024, 2048]
+        
+        self.pad1 = ZeroPadding2D(3)
+        self.conv1 = Conv2D(filters[0], 7, 2, use_bias=False)
+        self.bn1 = BatchNormalization(momentum=0.1, epsilon=1e-5)
+        self.relu = ReLU()
+        self.pad2 = ZeroPadding2D(1)
+        self.maxpool = MaxPool2D(3, 2)
+
+        self.layer1 = self._make_layer(filters[1], blocks[0], strides=1, dilation_rate=1)
+        self.layer2 = self._make_layer(filters[2], blocks[1], strides=2, dilation_rate=1)
+        self.layer3 = self._make_layer(filters[3], blocks[2], strides=2, dilation_rate=1)
+        self.layer4 = self._make_layer(filters[4], blocks[3], strides=1, dilation_rate=2)
+
+    def _make_layer(self, filters, blocks, strides, dilation_rate):
+        layers = [ResNetBlock(filters, 3, strides, 1, True)]
+        for _ in range(1, blocks):
+            layers.append(ResNetBlock(filters, 3, 1, dilation_rate, False))
+        return Sequential(layers)
+        
+    def call(self, x, training=None):
+        x = self.pad1(x)
+        x = self.conv1(x, training=training)
+        x = self.bn1(x, training=training)
+        x = self.relu(x, training=training)
+        f1 = x  # 1/2
+        x = self.pad2(x)
+        x = self.maxpool(x, training=training)
+        x = self.layer1(x, training=training)
+        f2 = x  # 1/4
+        x = self.layer2(x, training=training)
+        f3 = x  # 1/8
+        x = self.layer3(x, training=training)
+        x = self.layer4(x, training=training)
+        f4 = x  # 1/16
+        return f1, f2, f3, f4
+
+
+class ResNetBlock(Layer):
+    def __init__(self, filters, kernel_size=3, strides=1, dilation_rate=1, conv_shortcut=True):
+        super().__init__()
+        self.conv1 = Conv2D(filters // 4, 1, use_bias=False)
+        self.bn1 = BatchNormalization(momentum=0.1, epsilon=1e-5)
+        self.pad2 = ZeroPadding2D(dilation_rate)
+        self.conv2 = Conv2D(filters // 4, kernel_size, strides, dilation_rate=dilation_rate, use_bias=False)
+        self.bn2 = BatchNormalization(momentum=0.1, epsilon=1e-5)
+        self.conv3 = Conv2D(filters, 1, use_bias=False)
+        self.bn3 = BatchNormalization(momentum=0.1, epsilon=1e-5)
+        self.relu = ReLU()
+        if conv_shortcut:
+            self.convd = Conv2D(filters, 1, strides, use_bias=False)
+            self.bnd = BatchNormalization(momentum=0.1, epsilon=1e-5)
+
+    def call(self, x, training=None):
+        if hasattr(self, 'convd'):
+            shortcut = self.convd(x, training=training)
+            shortcut = self.bnd(shortcut, training=training)
+        else:
+            shortcut = x
+
+        x = self.conv1(x, training=training)
+        x = self.bn1(x, training=training)
+        x = self.relu(x, training=training)
+        x = self.pad2(x, training=training)
+        x = self.conv2(x, training=training)
+        x = self.bn2(x, training=training)
+        x = self.relu(x, training=training)
+        x = self.conv3(x, training=training)
+        x = self.bn3(x, training=training)
+        x += shortcut
+        x = self.relu(x, training=training)
+        return x

+ 7 - 0
model/utils.py

@@ -0,0 +1,7 @@
+import tensorflow as tf
+
+def normalize(x, mean, std):
+    mean = tf.constant(mean)
+    std = tf.constant(std)
+    return (x - mean) / std
+