# model.py
import tensorflow as tf
from tensorflow.keras.models import Model
from tensorflow.keras.layers import Conv2D

from .mobilenetv3 import MobileNetV3Encoder
from .resnet import ResNet50Encoder
from .lraspp import LRASPP
from .decoder import RecurrentDecoder
from .fast_guided_filter import FastGuidedFilterRefiner
from .deep_guided_filter import DeepGuidedFilterRefiner
  10. class MattingNetwork(Model):
  11. def __init__(self,
  12. variant: str = 'mobilenetv3',
  13. refiner: str = 'deep_guided_filter'):
  14. super().__init__()
  15. assert variant in ['mobilenetv3', 'resnet50']
  16. assert refiner in ['deep_guided_filter', 'fast_guided_filter']
  17. if variant == 'mobilenetv3':
  18. self.backbone = MobileNetV3Encoder()
  19. self.aspp = LRASPP(128)
  20. self.decoder = RecurrentDecoder([128, 80, 40, 32, 16])
  21. else:
  22. self.backbone = ResNet50Encoder()
  23. self.aspp = LRASPP(256)
  24. self.decoder = RecurrentDecoder([256, 128, 64, 32, 16])
  25. self.project_mat = Conv2D(4, 1)
  26. if refiner == 'deep_guided_filter':
  27. self.refiner = DeepGuidedFilterRefiner()
  28. else:
  29. self.refiner = FastGuidedFilterRefiner()
  30. def call(self, inputs):
  31. src, *rec, downsample_ratio = inputs
  32. src_sm = tf.cond(downsample_ratio == 1,
  33. lambda: src,
  34. lambda: self._downsample(src, downsample_ratio))
  35. f1, f2, f3, f4 = self.backbone(src_sm)
  36. f4 = self.aspp(f4)
  37. hid, *rec = self.decoder([src_sm, f1, f2, f3, f4, *rec])
  38. out = self.project_mat(hid)
  39. fgr_residual, pha = tf.split(out, [3, 1], -1)
  40. fgr_residual, pha = tf.cond(downsample_ratio == 1,
  41. lambda: (fgr_residual, pha),
  42. lambda: self.refiner([src, src_sm, fgr_residual, pha, hid]))
  43. fgr = fgr_residual + src
  44. fgr = tf.clip_by_value(fgr, 0, 1)
  45. pha = tf.clip_by_value(pha, 0, 1)
  46. return fgr, pha, *rec
  47. def _downsample(self, x, downsample_ratio):
  48. size = tf.shape(x)[1:3]
  49. size = tf.cast(size, tf.float32) * tf.cast(downsample_ratio, tf.float32)
  50. size = tf.cast(size, tf.int32)
  51. return tf.image.resize(x, size)