# resnet.py

from tensorflow.keras.models import Sequential
from tensorflow.keras.layers import Layer, Conv2D, BatchNormalization, ReLU, MaxPool2D, ZeroPadding2D


class ResNet50Encoder(Layer):
    """ResNet-50 backbone returning feature maps at 1/2, 1/4, 1/8, and 1/16 resolution.
    The final stage uses dilation instead of striding, so the output stride stays at 16."""

    def __init__(self):
        super().__init__()
        blocks = [3, 4, 6, 3]                 # bottleneck blocks per stage (ResNet-50)
        filters = [64, 256, 512, 1024, 2048]  # stem channels followed by stage output channels

        # Stem: 7x7/2 convolution, BN, ReLU, then 3x3/2 max pooling.
        self.pad1 = ZeroPadding2D(3)
        self.conv1 = Conv2D(filters[0], 7, strides=2, use_bias=False)
        # Keras BN momentum is the running-average decay (PyTorch's momentum=0.1 corresponds to 0.9 here).
        self.bn1 = BatchNormalization(momentum=0.9, epsilon=1e-5)
        self.relu = ReLU()
        self.pad2 = ZeroPadding2D(1)
        self.maxpool = MaxPool2D(3, strides=2)

        self.layer1 = self._make_layer(filters[1], blocks[0], strides=1, dilation_rate=1)
        self.layer2 = self._make_layer(filters[2], blocks[1], strides=2, dilation_rate=1)
        self.layer3 = self._make_layer(filters[3], blocks[2], strides=2, dilation_rate=1)
        # The last stage keeps stride 1 and dilates its convolutions instead.
        self.layer4 = self._make_layer(filters[4], blocks[3], strides=1, dilation_rate=2)

    def _make_layer(self, filters, blocks, strides, dilation_rate):
        # The first block handles downsampling and projects the shortcut;
        # the remaining blocks use identity shortcuts.
        layers = [ResNetBlock(filters, 3, strides, 1, conv_shortcut=True)]
        for _ in range(1, blocks):
            layers.append(ResNetBlock(filters, 3, 1, dilation_rate, conv_shortcut=False))
        return Sequential(layers)

    def call(self, x, training=None):
        x = self.pad1(x)
        x = self.conv1(x)
        x = self.bn1(x, training=training)
        x = self.relu(x)
        f1 = x  # 1/2 resolution
        x = self.pad2(x)
        x = self.maxpool(x)
        x = self.layer1(x, training=training)
        f2 = x  # 1/4 resolution
        x = self.layer2(x, training=training)
        f3 = x  # 1/8 resolution
        x = self.layer3(x, training=training)
        x = self.layer4(x, training=training)
        f4 = x  # 1/16 resolution (layer4 is dilated, not strided)
        return f1, f2, f3, f4

class ResNetBlock(Layer):
    """Bottleneck residual block: 1x1 reduce -> 3x3 (optionally strided/dilated) -> 1x1 expand,
    with a projection shortcut when `conv_shortcut` is True."""

    def __init__(self, filters, kernel_size=3, strides=1, dilation_rate=1, conv_shortcut=True):
        super().__init__()
        self.conv1 = Conv2D(filters // 4, 1, use_bias=False)
        self.bn1 = BatchNormalization(momentum=0.9, epsilon=1e-5)
        # Explicit padding keeps the 3x3 convolution shape-preserving for any dilation rate.
        self.pad2 = ZeroPadding2D(dilation_rate)
        self.conv2 = Conv2D(filters // 4, kernel_size, strides, dilation_rate=dilation_rate, use_bias=False)
        self.bn2 = BatchNormalization(momentum=0.9, epsilon=1e-5)
        self.conv3 = Conv2D(filters, 1, use_bias=False)
        self.bn3 = BatchNormalization(momentum=0.9, epsilon=1e-5)
        self.relu = ReLU()
        if conv_shortcut:
            # Projection shortcut matches the main path's channel count and stride.
            self.convd = Conv2D(filters, 1, strides, use_bias=False)
            self.bnd = BatchNormalization(momentum=0.9, epsilon=1e-5)

    def call(self, x, training=None):
        if hasattr(self, 'convd'):
            shortcut = self.convd(x)
            shortcut = self.bnd(shortcut, training=training)
        else:
            shortcut = x
        x = self.conv1(x)
        x = self.bn1(x, training=training)
        x = self.relu(x)
        x = self.pad2(x)
        x = self.conv2(x)
        x = self.bn2(x, training=training)
        x = self.relu(x)
        x = self.conv3(x)
        x = self.bn3(x, training=training)
        x = x + shortcut  # residual connection
        x = self.relu(x)
        return x
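

if __name__ == "__main__":
    # Minimal usage sketch (illustrative only; the 224x224 input size is just an example):
    # run a random batch through the encoder and check the feature-pyramid resolutions.
    import tensorflow as tf

    encoder = ResNet50Encoder()
    images = tf.random.uniform((1, 224, 224, 3))
    f1, f2, f3, f4 = encoder(images, training=False)
    # For 224x224 input the spatial sizes should be 112, 56, 28, and 14.
    print([f.shape for f in (f1, f2, f3, f4)])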