123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899 |
- import tensorflow as tf
- from tensorflow.keras.models import Sequential
- from tensorflow.keras.layers import Layer, Conv2D, BatchNormalization, ReLU, Activation, AveragePooling2D, UpSampling2D
class RecurrentDecoder(Layer):
    """Decoder built from one bottleneck stage, three upsampling stages and an output stage.

    Every stage except the output stage carries a recurrent state: it takes a
    state tensor in and hands an updated one back, and the updated states are
    returned next to the decoded output so a caller can pass them in again on
    a later call.
    """

    def __init__(self, channels):
        """Create the pooling layer and the five decoder stages.

        Args:
            channels: Indexable holding the channel count of each stage.
                Index 0 sizes the bottleneck stage, indices 1-3 size the
                three upsampling stages (in the order they run) and index 4
                sizes the output stage.
        """
        super().__init__()
        # One shared pooling layer (Keras default window) is reused to build
        # every level of the source pyramid in `call`.
        self.avgpool = AveragePooling2D(padding='SAME')
        self.decode4 = BottleneckBlock(channels[0])
        self.decode3 = UpsamplingBlock(channels[1])
        self.decode2 = UpsamplingBlock(channels[2])
        self.decode1 = UpsamplingBlock(channels[3])
        self.decode0 = OutputBlock(channels[4])

    def call(self, inputs):
        """Run the decoder once.

        Args:
            inputs: Sequence of nine tensors
                ``(s0, f1, f2, f3, f4, r1, r2, r3, r4)``: the source ``s0``,
                four feature maps ``f1..f4`` and four recurrent states
                ``r1..r4``.
                # assumes f1..f4 go from finest to coarsest resolution and
                # all tensors are channels-last - TODO confirm with caller.

        Returns:
            Tuple ``(x0, r1, r2, r3, r4)``: the output-stage result followed
            by the updated recurrent states.
        """
        src, feat1, feat2, feat3, feat4, state1, state2, state3, state4 = inputs

        # Source pyramid: pyramid[k] is the source pooled k times.
        pyramid = [src]
        for _ in range(3):
            pyramid.append(self.avgpool(pyramid[-1]))

        # Decode from the deepest stage outwards; each stage consumes the
        # previous stage's output plus the matching feature map and pooled
        # source.
        dec4, state4 = self.decode4([feat4, state4])
        dec3, state3 = self.decode3([dec4, feat3, pyramid[3], state3])
        dec2, state2 = self.decode2([dec3, feat2, pyramid[2], state2])
        dec1, state1 = self.decode1([dec2, feat1, pyramid[1], state1])
        out = self.decode0([dec1, pyramid[0]])
        return out, state1, state2, state3, state4
class BottleneckBlock(Layer):
    """Recurrent stage that updates half of its input channels with a ConvGRU.

    The input is split in two along the last axis; the first half passes
    through untouched and the second half is processed by the GRU.
    """

    def __init__(self, channels):
        """Create the GRU.

        Args:
            channels: Channel count of the stage; the GRU is built with
                ``channels // 2`` channels.
        """
        super().__init__()
        self.gru = ConvGRU(channels // 2)

    def call(self, inputs):
        """Apply the stage.

        Args:
            inputs: Pair ``(x, r)`` of the feature tensor and the recurrent
                state handed to the GRU.

        Returns:
            Pair ``(x, r)``: the re-assembled feature tensor and the state
            returned by the GRU.
        """
        features, state = inputs
        # Two equal halves along the last axis: only the second is recurrent.
        passthrough, recurrent = tf.split(features, num_or_size_splits=2, axis=-1)
        recurrent, state = self.gru([recurrent, state])
        merged = tf.concat([passthrough, recurrent], axis=-1)
        return merged, state
class UpsamplingBlock(Layer):
    """Recurrent decoder stage: upsample, fuse with skip inputs, convolve, run a ConvGRU.

    The incoming tensor is bilinearly upsampled, cropped to the spatial size
    of the pooled source, concatenated with a feature map and that source,
    passed through conv -> batch-norm -> ReLU, and finally half of the
    resulting channels are updated by a ConvGRU.
    """

    def __init__(self, channels):
        """Create the stage's sub-layers.

        Args:
            channels: Output channel count of the convolution; the GRU is
                built with ``channels // 2`` channels.
        """
        super().__init__()
        self.upsample = UpSampling2D(interpolation='bilinear')
        self.conv = Sequential([
            Conv2D(channels, 3, padding='SAME', use_bias=False),
            # Keras updates the running statistics as
            #   moving = moving * momentum + batch * (1 - momentum),
            # i.e. `momentum` weights the OLD value. 0.9 is therefore the
            # Keras equivalent of a PyTorch-style momentum of 0.1; passing
            # 0.1 here would make the running mean/variance follow almost
            # only the latest batch during training. Inference is unaffected.
            BatchNormalization(momentum=0.9, epsilon=1e-5),
            ReLU()
        ])
        self.gru = ConvGRU(channels // 2)

    def call(self, inputs):
        """Apply the stage.

        Args:
            inputs: Sequence ``(x, f, s, r)``: the previous stage's output,
                the skip feature map, the pooled source and the recurrent
                state.
                # assumes f and s share spatial size and all tensors are
                # channels-last - TODO confirm with caller.

        Returns:
            Pair ``(x, r)``: the stage output and the state returned by the
            GRU.
        """
        x, f, s, r = inputs
        x = self.upsample(x)
        # Upsampling can overshoot the skip resolution; trim the excess
        # rows/columns from the bottom/right so the tensors can be
        # concatenated.
        target_shape = tf.shape(s)
        x = tf.image.crop_to_bounding_box(x, 0, 0, target_shape[1], target_shape[2])
        x = tf.concat([x, f, s], axis=-1)
        x = self.conv(x)
        # Only the second half of the channels goes through the GRU.
        passthrough, recurrent = tf.split(x, num_or_size_splits=2, axis=-1)
        recurrent, r = self.gru([recurrent, r])
        x = tf.concat([passthrough, recurrent], axis=-1)
        return x, r
class OutputBlock(Layer):
    """Final, non-recurrent decoder stage.

    The incoming tensor is bilinearly upsampled, cropped to the spatial size
    of the source, concatenated with the source and passed through two
    conv -> batch-norm -> ReLU groups.
    """

    def __init__(self, channels):
        """Create the stage's sub-layers.

        Args:
            channels: Output channel count of both convolutions.
        """
        super().__init__()
        self.upsample = UpSampling2D(interpolation='bilinear')
        # Keras updates the running statistics as
        #   moving = moving * momentum + batch * (1 - momentum),
        # i.e. `momentum` weights the OLD value. 0.9 is therefore the Keras
        # equivalent of a PyTorch-style momentum of 0.1; passing 0.1 here
        # would make the running mean/variance follow almost only the latest
        # batch during training. Inference is unaffected.
        self.conv = Sequential([
            Conv2D(channels, 3, padding='SAME', use_bias=False),
            BatchNormalization(momentum=0.9, epsilon=1e-5),
            ReLU(),
            Conv2D(channels, 3, padding='SAME', use_bias=False),
            BatchNormalization(momentum=0.9, epsilon=1e-5),
            ReLU(),
        ])

    def call(self, inputs):
        """Apply the stage.

        Args:
            inputs: Pair ``(x, s)``: the previous stage's output and the
                source tensor.
                # assumes both tensors are channels-last - TODO confirm.

        Returns:
            The output of the convolution stack.
        """
        x, s = inputs
        x = self.upsample(x)
        # Upsampling can overshoot the source resolution; trim the excess
        # rows/columns from the bottom/right so the tensors can be
        # concatenated.
        target_shape = tf.shape(s)
        x = tf.image.crop_to_bounding_box(x, 0, 0, target_shape[1], target_shape[2])
        x = tf.concat([x, s], axis=-1)
        return self.conv(x)
class ConvGRU(Layer):
    """Gated recurrent unit whose gates are computed with 2-D convolutions."""

    def __init__(self, channels, kernel_size=3):
        """Create the gate and candidate convolutions.

        Args:
            channels: Channel count of the candidate convolution; the gate
                convolution produces twice as many channels (reset + update).
            kernel_size: Kernel size used by both convolutions.
        """
        super().__init__()
        self.channels = channels
        # `ih` emits both gates at once (sigmoid-squashed); `hh` emits the
        # tanh-squashed candidate state.
        self.ih = Conv2D(channels * 2, kernel_size, padding='SAME', activation='sigmoid')
        self.hh = Conv2D(channels, kernel_size, padding='SAME', activation='tanh')

    def call(self, inputs):
        """Advance the recurrent state by one step.

        Args:
            inputs: Pair ``(x, h)`` of the input tensor and the previous
                state. ``h`` is broadcast to the shape of ``x``, so any
                tensor broadcastable to that shape is accepted.

        Returns:
            Pair ``(h, h)``: the new state, returned twice (once as the
            step's output and once as the state to carry forward).
        """
        x, prev_state = inputs
        prev_state = tf.broadcast_to(prev_state, tf.shape(x))

        gates = self.ih(tf.concat([x, prev_state], axis=-1))
        reset_gate, update_gate = tf.split(gates, num_or_size_splits=2, axis=-1)

        candidate = self.hh(tf.concat([x, reset_gate * prev_state], axis=-1))

        # Blend the old state and the candidate according to the update gate.
        new_state = (1 - update_gate) * prev_state + update_gate * candidate
        return new_state, new_state
|