decoder.py 3.3 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899
  1. import tensorflow as tf
  2. from tensorflow.keras.models import Sequential
  3. from tensorflow.keras.layers import Layer, Conv2D, BatchNormalization, ReLU, Activation, AveragePooling2D, UpSampling2D
  4. class RecurrentDecoder(Layer):
  5. def __init__(self, channels):
  6. super().__init__()
  7. self.avgpool = AveragePooling2D(padding='SAME')
  8. self.decode4 = BottleneckBlock(channels[0])
  9. self.decode3 = UpsamplingBlock(channels[1])
  10. self.decode2 = UpsamplingBlock(channels[2])
  11. self.decode1 = UpsamplingBlock(channels[3])
  12. self.decode0 = OutputBlock(channels[4])
  13. def call(self, inputs):
  14. s0, f1, f2, f3, f4, r1, r2, r3, r4 = inputs
  15. s1 = self.avgpool(s0)
  16. s2 = self.avgpool(s1)
  17. s3 = self.avgpool(s2)
  18. x4, r4 = self.decode4([f4, r4])
  19. x3, r3 = self.decode3([x4, f3, s3, r3])
  20. x2, r2 = self.decode2([x3, f2, s2, r2])
  21. x1, r1 = self.decode1([x2, f1, s1, r1])
  22. x0 = self.decode0([x1, s0])
  23. return x0, r1, r2, r3, r4
  24. class BottleneckBlock(Layer):
  25. def __init__(self, channels):
  26. super().__init__()
  27. self.gru = ConvGRU(channels // 2)
  28. def call(self, inputs):
  29. x, r = inputs
  30. a, b = tf.split(x, 2, -1)
  31. b, r = self.gru([b, r])
  32. x = tf.concat([a, b], -1)
  33. return x, r
  34. class UpsamplingBlock(Layer):
  35. def __init__(self, channels):
  36. super().__init__()
  37. self.upsample = UpSampling2D(interpolation='bilinear')
  38. self.conv = Sequential([
  39. Conv2D(channels, 3, padding='SAME', use_bias=False),
  40. BatchNormalization(momentum=0.1, epsilon=1e-5),
  41. ReLU()
  42. ])
  43. self.gru = ConvGRU(channels // 2)
  44. def call(self, inputs):
  45. x, f, s, r = inputs
  46. x = self.upsample(x)
  47. x = tf.image.crop_to_bounding_box(x, 0, 0, tf.shape(s)[1], tf.shape(s)[2])
  48. x = tf.concat([x, f, s], -1)
  49. x = self.conv(x)
  50. a, b = tf.split(x, 2, -1)
  51. b, r = self.gru([b, r])
  52. x = tf.concat([a, b], -1)
  53. return x, r
  54. class OutputBlock(Layer):
  55. def __init__(self, channels):
  56. super().__init__()
  57. self.upsample = UpSampling2D(interpolation='bilinear')
  58. self.conv = Sequential([
  59. Conv2D(channels, 3, padding='SAME', use_bias=False),
  60. BatchNormalization(momentum=0.1, epsilon=1e-5),
  61. ReLU(),
  62. Conv2D(channels, 3, padding='SAME', use_bias=False),
  63. BatchNormalization(momentum=0.1, epsilon=1e-5),
  64. ReLU(),
  65. ])
  66. def call(self, inputs):
  67. x, s = inputs
  68. x = self.upsample(x)
  69. x = tf.image.crop_to_bounding_box(x, 0, 0, tf.shape(s)[1], tf.shape(s)[2])
  70. x = tf.concat([x, s], -1)
  71. x = self.conv(x)
  72. return x
  73. class ConvGRU(Layer):
  74. def __init__(self, channels, kernel_size=3):
  75. super().__init__()
  76. self.channels = channels
  77. self.ih = Conv2D(channels * 2, kernel_size, padding='SAME', activation='sigmoid')
  78. self.hh = Conv2D(channels, kernel_size, padding='SAME', activation='tanh')
  79. def call(self, inputs):
  80. x, h = inputs
  81. h = tf.broadcast_to(h, tf.shape(x))
  82. r, z = tf.split(self.ih(tf.concat([x, h], -1)), 2, -1)
  83. c = self.hh(tf.concat([x, r * h], -1))
  84. h = (1 - z) * h + z * c
  85. return h, h