1234567891011121314151617181920212223242526272829303132333435363738394041424344454647484950515253545556575859606162 |
- import tensorflow as tf
- from tensorflow.keras.models import Model
- from tensorflow.keras.layers import Conv2D
- from .mobilenetv3 import MobileNetV3Encoder
- from .resnet import ResNet50Encoder
- from .lraspp import LRASPP
- from .decoder import RecurrentDecoder
- from .fast_guided_filter import FastGuidedFilterRefiner
- from .deep_guided_filter import DeepGuidedFilterRefiner
class MattingNetwork(Model):
    """Recurrent matting network that predicts a foreground and an alpha matte.

    The backbone, LRASPP head and recurrent decoder run on a copy of the input
    resized by ``downsample_ratio``. When that ratio is not 1, the selected
    refiner turns the low-resolution prediction back into full-resolution
    foreground residual and alpha.

    Args:
        variant: Backbone to use, ``'mobilenetv3'`` or ``'resnet50'``.
        refiner: Refiner applied when the input was downsampled,
            ``'deep_guided_filter'`` or ``'fast_guided_filter'``.

    Raises:
        ValueError: If ``variant`` or ``refiner`` is not a supported choice.
    """

    _VARIANTS = ('mobilenetv3', 'resnet50')
    _REFINERS = ('deep_guided_filter', 'fast_guided_filter')

    def __init__(self,
                 variant: str = 'mobilenetv3',
                 refiner: str = 'deep_guided_filter'):
        super().__init__()
        # Explicit checks rather than `assert`, which is stripped under `python -O`.
        if variant not in self._VARIANTS:
            raise ValueError(
                f"variant must be one of {self._VARIANTS}, got {variant!r}")
        if refiner not in self._REFINERS:
            raise ValueError(
                f"refiner must be one of {self._REFINERS}, got {refiner!r}")

        if variant == 'mobilenetv3':
            self.backbone = MobileNetV3Encoder()
            self.aspp = LRASPP(128)
            self.decoder = RecurrentDecoder([128, 80, 40, 32, 16])
        else:
            self.backbone = ResNet50Encoder()
            self.aspp = LRASPP(256)
            self.decoder = RecurrentDecoder([256, 128, 64, 32, 16])

        # 1x1 convolution producing 4 channels: 3 foreground-residual + 1 alpha.
        self.project_mat = Conv2D(4, 1)

        if refiner == 'deep_guided_filter':
            self.refiner = DeepGuidedFilterRefiner()
        else:
            self.refiner = FastGuidedFilterRefiner()

    def call(self, inputs):
        """Run one step of the network.

        Args:
            inputs: Sequence ``[src, *rec, downsample_ratio]``.
                ``src`` is indexed as (batch, height, width, channels) by
                ``_downsample`` and is added to the 3-channel foreground
                residual. ``rec`` are the recurrent states forwarded to the
                decoder (how many, and whether they may be omitted, depends on
                RecurrentDecoder — TODO confirm). ``downsample_ratio`` is a
                scalar, either a Python number or a tensor.

        Returns:
            Tuple ``(fgr, pha, *rec)``: foreground and alpha, both clipped to
            [0, 1], followed by the decoder's updated recurrent states.
        """
        src, *rec, downsample_ratio = inputs

        # Build the predicate once as a boolean tensor. Graph-mode tf.cond
        # rejects a plain Python bool, which `downsample_ratio == 1` would be
        # when the ratio is a Python number.
        is_full_res = tf.equal(tf.cast(downsample_ratio, tf.float32), 1.0)

        src_sm = tf.cond(is_full_res,
                         lambda: src,
                         lambda: self._downsample(src, downsample_ratio))

        f1, f2, f3, f4 = self.backbone(src_sm)
        f4 = self.aspp(f4)
        hid, *rec = self.decoder([src_sm, f1, f2, f3, f4, *rec])
        out = self.project_mat(hid)
        fgr_residual, pha = tf.split(out, [3, 1], -1)

        # Refine back to full resolution only when the input was downsampled.
        # tuple(...) makes the false branch return the same nested structure
        # as the true branch, which graph-mode tf.cond requires, regardless of
        # whether the refiner returns a list or a tuple.
        fgr_residual, pha = tf.cond(
            is_full_res,
            lambda: (fgr_residual, pha),
            lambda: tuple(self.refiner([src, src_sm, fgr_residual, pha, hid])))

        fgr = tf.clip_by_value(fgr_residual + src, 0, 1)
        pha = tf.clip_by_value(pha, 0, 1)
        return fgr, pha, *rec

    def _downsample(self, x, downsample_ratio):
        """Resize ``x`` spatially by ``downsample_ratio``.

        Axes 1 and 2 of ``x`` are treated as height and width. The scaled
        sizes are truncated to integers, then ``tf.image.resize`` is applied
        with its default (bilinear) method.
        """
        size = tf.shape(x)[1:3]
        size = tf.cast(size, tf.float32) * tf.cast(downsample_ratio, tf.float32)
        size = tf.cast(size, tf.int32)
        return tf.image.resize(x, size)
|