mobilenetv3.py

import tensorflow as tf
from tensorflow.keras.models import Sequential
from tensorflow.keras.layers import Layer, Conv2D, BatchNormalization, ReLU, ZeroPadding2D, DepthwiseConv2D
from typing import Optional

from .utils import normalize


def _make_divisible(v: float, divisor: int, min_value: Optional[int] = None) -> int:
    # Round channel counts to the nearest multiple of `divisor`, as in the MobileNet papers.
    if min_value is None:
        min_value = divisor
    new_v = max(min_value, int(v + divisor / 2) // divisor * divisor)
    # Make sure that rounding down does not reduce the value by more than 10%.
    if new_v < 0.9 * v:
        new_v += divisor
    return new_v


def _hard_sigmoid(x):
    # Piecewise-linear approximation of the sigmoid used throughout MobileNetV3.
    return tf.nn.relu6(x + 3) / 6


class SqueezeExcitation(Layer):
    def __init__(self, input_channels: int, squeeze_factor: int = 4):
        super().__init__()
        squeeze_channels = _make_divisible(input_channels // squeeze_factor, 8)
        self.fc1 = Conv2D(squeeze_channels, 1)
        self.relu = ReLU()
        self.fc2 = Conv2D(input_channels, 1)

    def call(self, x):
        # Global average pool -> bottleneck -> per-channel gating in [0, 1].
        scale = tf.reduce_mean(x, axis=[1, 2], keepdims=True)
        scale = self.fc1(scale)
        scale = self.relu(scale)
        scale = self.fc2(scale)
        scale = _hard_sigmoid(scale)
        return scale * x


class Hardswish(Layer):
    def call(self, x):
        return x * _hard_sigmoid(x)


class ConvBNActivation(Layer):
    def __init__(self, filters, kernel_size, stride=1, groups=1, dilation=1, activation_layer=None):
        super().__init__()
        # Explicit symmetric padding, since the convolutions below use 'valid' padding.
        padding = (kernel_size - 1) // 2 * dilation
        if padding != 0:
            self.pad = ZeroPadding2D(padding)
        if groups == 1:
            self.conv = Conv2D(filters, kernel_size, stride, dilation_rate=dilation, use_bias=False)
        else:
            # groups == input channels: depthwise convolution.
            self.conv = DepthwiseConv2D(kernel_size, stride, dilation_rate=dilation, use_bias=False)
        # Keras momentum is the decay of the running statistics, so 0.99 here
        # corresponds to PyTorch's BatchNorm2d(momentum=0.01) convention.
        self.bn = BatchNormalization(momentum=0.99, epsilon=1e-3)
        if activation_layer:
            self.act = activation_layer()

    def call(self, x):
        if hasattr(self, 'pad'):
            x = self.pad(x)
        x = self.conv(x)
        x = self.bn(x)
        if hasattr(self, 'act'):
            x = self.act(x)
        return x


class InvertedResidual(Layer):
    def __init__(self,
                 input_channels: int,
                 kernel: int,
                 expanded_channels: int,
                 out_channels: int,
                 use_se: bool,
                 activation: str,
                 stride: int,
                 dilation: int):
        super().__init__()
        if not (1 <= stride <= 2):
            raise ValueError('illegal stride value')

        self.use_res_connect = stride == 1 and input_channels == out_channels

        layers = []
        activation_layer = Hardswish if activation == 'HS' else ReLU

        # expand
        if expanded_channels != input_channels:
            layers.append(ConvBNActivation(
                expanded_channels,
                kernel_size=1,
                activation_layer=activation_layer))

        # depthwise (stride is dropped when dilation replaces downsampling)
        stride = 1 if dilation > 1 else stride
        layers.append(ConvBNActivation(
            expanded_channels,
            kernel_size=kernel,
            stride=stride,
            dilation=dilation,
            groups=expanded_channels,
            activation_layer=activation_layer))
        if use_se:
            layers.append(SqueezeExcitation(expanded_channels))

        # project (linear bottleneck, no activation)
        layers.append(ConvBNActivation(
            out_channels,
            kernel_size=1,
            activation_layer=None))

        self.block = Sequential(layers)

    def call(self, inputs):
        result = self.block(inputs)
        if self.use_res_connect:
            result += inputs
        return result


class MobileNetV3Encoder(Layer):
    def __init__(self):
        super().__init__()
        self.features = [
            ConvBNActivation(16, kernel_size=3, stride=2, activation_layer=Hardswish),
            InvertedResidual(16, 3, 16, 16, False, 'RE', 1, 1),
            InvertedResidual(16, 3, 64, 24, False, 'RE', 2, 1),    # C1
            InvertedResidual(24, 3, 72, 24, False, 'RE', 1, 1),
            InvertedResidual(24, 5, 72, 40, True, 'RE', 2, 1),     # C2
            InvertedResidual(40, 5, 120, 40, True, 'RE', 1, 1),
            InvertedResidual(40, 5, 120, 40, True, 'RE', 1, 1),
            InvertedResidual(40, 3, 240, 80, False, 'HS', 2, 1),   # C3
            InvertedResidual(80, 3, 200, 80, False, 'HS', 1, 1),
            InvertedResidual(80, 3, 184, 80, False, 'HS', 1, 1),
            InvertedResidual(80, 3, 184, 80, False, 'HS', 1, 1),
            InvertedResidual(80, 3, 480, 112, True, 'HS', 1, 1),
            InvertedResidual(112, 3, 672, 112, True, 'HS', 1, 1),
            InvertedResidual(112, 5, 672, 160, True, 'HS', 2, 2),  # C4
            InvertedResidual(160, 5, 960, 160, True, 'HS', 1, 2),
            InvertedResidual(160, 5, 960, 160, True, 'HS', 1, 2),
            ConvBNActivation(960, kernel_size=1, activation_layer=Hardswish),
        ]

    def call(self, x):
        # ImageNet normalization, then tap features at four scales.
        x = normalize(x, mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
        x = self.features[0](x)
        x = self.features[1](x)
        f1 = x  # stride 2
        x = self.features[2](x)
        x = self.features[3](x)
        f2 = x  # stride 4
        x = self.features[4](x)
        x = self.features[5](x)
        x = self.features[6](x)
        f3 = x  # stride 8
        for layer in self.features[7:]:
            x = layer(x)
        f4 = x  # stride 16 (the last stage is dilated instead of strided)
        return f1, f2, f3, f4
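

# A minimal smoke-test sketch (not part of the original module). It assumes the
# package-relative `utils.normalize(x, mean, std)` helper imported above is
# available (i.e. the file is run inside its package, e.g.
# `python -m <package>.mobilenetv3`) and that inputs are NHWC float tensors in [0, 1].
if __name__ == '__main__':
    encoder = MobileNetV3Encoder()
    dummy = tf.random.uniform([1, 224, 224, 3])  # one fake 224x224 RGB image
    f1, f2, f3, f4 = encoder(dummy)
    # With the configuration above the taps sit at strides 2, 4, 8 and 16, e.g.
    # (1, 112, 112, 16), (1, 56, 56, 24), (1, 28, 28, 40), (1, 14, 14, 960).
    for name, feat in zip(('f1', 'f2', 'f3', 'f4'), (f1, f2, f3, f4)):
        print(name, feat.shape)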