from typing import Optional

import tensorflow as tf
from tensorflow.keras.models import Sequential
from tensorflow.keras.layers import Layer, Conv2D, BatchNormalization, ReLU, ZeroPadding2D, DepthwiseConv2D

from .utils import normalize


def _make_divisible(v: float, divisor: int, min_value: Optional[int] = None) -> int:
    if min_value is None:
        min_value = divisor
    new_v = max(min_value, int(v + divisor / 2) // divisor * divisor)
    # Make sure that rounding down does not reduce v by more than 10%.
    if new_v < 0.9 * v:
        new_v += divisor
    return new_v
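

# Worked examples of the rounding rule above (illustrative; not part of the
# original file). Values round to the nearest multiple of `divisor`, then are
# bumped up one step if the round-down lost more than 10% of the input:
#   _make_divisible(17, 8) == 16   # 16 >= 0.9 * 17, keep the round-down
#   _make_divisible(11, 8) == 16   # 8 < 0.9 * 11, bump up to 16
#   _make_divisible(37, 8) == 40   # 40 is already the nearest multiple of 8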


def _hard_sigmoid(x):
    # Piecewise-linear approximation of the sigmoid used throughout MobileNetV3.
    return tf.nn.relu6(x + 3) / 6


class SqueezeExcitation(Layer):
    def __init__(self, input_channels: int, squeeze_factor: int = 4):
        super().__init__()
        squeeze_channels = _make_divisible(input_channels // squeeze_factor, 8)
        self.fc1 = Conv2D(squeeze_channels, 1)
        self.relu = ReLU()
        self.fc2 = Conv2D(input_channels, 1)

    def call(self, x):
        # Squeeze: global average pool over the spatial dimensions.
        scale = tf.reduce_mean(x, axis=[1, 2], keepdims=True)
        # Excite: bottleneck of 1x1 convs, gated by a hard sigmoid.
        scale = self.fc1(scale)
        scale = self.relu(scale)
        scale = self.fc2(scale)
        scale = _hard_sigmoid(scale)
        return scale * x
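

# A minimal usage sketch (assumes TF 2.x eager mode and NHWC tensors):
#   se = SqueezeExcitation(input_channels=40)
#   y = se(tf.random.normal([1, 32, 32, 40]))  # shape unchanged: (1, 32, 32, 40)
# The gate is computed from a (1, 1, 1, 40) pooled descriptor and broadcast
# back over the spatial dimensions.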


class Hardswish(Layer):
    # h-swish from MobileNetV3: x * hard_sigmoid(x).
    def call(self, x):
        return x * _hard_sigmoid(x)
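

# Spot values worked from the definitions (illustrative):
#   _hard_sigmoid(-3) = relu6(0) / 6 = 0    _hard_sigmoid(3) = relu6(6) / 6 = 1
#   hardswish(-3) = -3 * 0 = 0              hardswish(3) = 3 * 1 = 3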


class ConvBNActivation(Layer):
    def __init__(self, filters, kernel_size, stride=1, groups=1, dilation=1, activation_layer=None):
        super().__init__()
        # Explicit zero padding emulates PyTorch-style 'same' padding for
        # strided and dilated convolutions.
        padding = (kernel_size - 1) // 2 * dilation
        if padding != 0:
            self.pad = ZeroPadding2D(padding)
        if groups == 1:
            self.conv = Conv2D(filters, kernel_size, stride, dilation_rate=dilation, use_bias=False)
        else:
            # groups == input channels: a depthwise convolution.
            self.conv = DepthwiseConv2D(kernel_size, stride, dilation_rate=dilation, use_bias=False)
        # Keras `momentum` is the moving-average decay, the complement of
        # PyTorch's; 0.99 here corresponds to PyTorch momentum=0.01.
        self.bn = BatchNormalization(momentum=0.99, epsilon=1e-3)
        if activation_layer:
            self.act = activation_layer()

    def call(self, x):
        if hasattr(self, 'pad'):
            x = self.pad(x)
        x = self.conv(x)
        x = self.bn(x)
        if hasattr(self, 'act'):
            x = self.act(x)
        return x
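

# Shape sketch for the manual padding (illustrative): with kernel_size=3,
# stride=2, dilation=1 the layer pads by (3 - 1) // 2 = 1 on every side, so a
# (1, 224, 224, 3) input becomes (1, 226, 226, 3) after padding and
# (1, 112, 112, filters) after the 'valid' conv. Keras padding='same' would pad
# asymmetrically at stride 2, so explicit symmetric padding is used instead.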


class InvertedResidual(Layer):
    def __init__(self,
                 input_channels: int,
                 kernel: int,
                 expanded_channels: int,
                 out_channels: int,
                 use_se: bool,
                 activation: str,
                 stride: int,
                 dilation: int):
        super().__init__()
        if not (1 <= stride <= 2):
            raise ValueError('illegal stride value')

        # The residual shortcut is only valid when the block preserves both
        # spatial resolution and channel count.
        self.use_res_connect = stride == 1 and input_channels == out_channels

        layers = []
        activation_layer = Hardswish if activation == 'HS' else ReLU

        # expand: 1x1 pointwise conv up to the expanded width
        if expanded_channels != input_channels:
            layers.append(ConvBNActivation(
                expanded_channels,
                kernel_size=1,
                activation_layer=activation_layer))

        # depthwise: dilated blocks keep stride 1 so the output stride is capped
        stride = 1 if dilation > 1 else stride
        layers.append(ConvBNActivation(
            expanded_channels,
            kernel_size=kernel,
            stride=stride,
            dilation=dilation,
            groups=expanded_channels,
            activation_layer=activation_layer))
        if use_se:
            layers.append(SqueezeExcitation(expanded_channels))

        # project: linear 1x1 pointwise conv back down (no activation)
        layers.append(ConvBNActivation(
            out_channels,
            kernel_size=1,
            activation_layer=None))
        self.block = Sequential(layers)

    def call(self, x):
        result = self.block(x)
        if self.use_res_connect:
            result += x
        return result
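

# Skip-connection sketch (illustrative): the residual is used only when the
# block changes neither resolution nor width, e.g.
#   InvertedResidual(16, 3, 16, 16, False, 'RE', 1, 1)  # stride 1, 16 -> 16: skip
#   InvertedResidual(16, 3, 64, 24, False, 'RE', 2, 1)  # stride 2, 16 -> 24: no skip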


class MobileNetV3Encoder(Layer):
    # MobileNetV3-Large layout with a dilated last stage, tapped at strides
    # 2, 4, 8 and 16.
    def __init__(self):
        super().__init__()
        self.features = [
            ConvBNActivation(16, kernel_size=3, stride=2, activation_layer=Hardswish),
            InvertedResidual(16, 3, 16, 16, False, 'RE', 1, 1),
            InvertedResidual(16, 3, 64, 24, False, 'RE', 2, 1),    # C1
            InvertedResidual(24, 3, 72, 24, False, 'RE', 1, 1),
            InvertedResidual(24, 5, 72, 40, True, 'RE', 2, 1),     # C2
            InvertedResidual(40, 5, 120, 40, True, 'RE', 1, 1),
            InvertedResidual(40, 5, 120, 40, True, 'RE', 1, 1),
            InvertedResidual(40, 3, 240, 80, False, 'HS', 2, 1),   # C3
            InvertedResidual(80, 3, 200, 80, False, 'HS', 1, 1),
            InvertedResidual(80, 3, 184, 80, False, 'HS', 1, 1),
            InvertedResidual(80, 3, 184, 80, False, 'HS', 1, 1),
            InvertedResidual(80, 3, 480, 112, True, 'HS', 1, 1),
            InvertedResidual(112, 3, 672, 112, True, 'HS', 1, 1),
            InvertedResidual(112, 5, 672, 160, True, 'HS', 2, 2),  # C4
            InvertedResidual(160, 5, 960, 160, True, 'HS', 1, 2),
            InvertedResidual(160, 5, 960, 160, True, 'HS', 1, 2),
            ConvBNActivation(960, kernel_size=1, activation_layer=Hardswish),
        ]

    def call(self, x):
        # ImageNet channel statistics; `normalize` comes from the package utils.
        x = normalize(x, mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
        x = self.features[0](x)
        x = self.features[1](x)
        f1 = x  # stride 2
        x = self.features[2](x)
        x = self.features[3](x)
        f2 = x  # stride 4
        for layer in self.features[4:7]:
            x = layer(x)
        f3 = x  # stride 8
        for layer in self.features[7:]:
            x = layer(x)
        f4 = x  # stride 16 (dilated blocks do not downsample further)
        return f1, f2, f3, f4
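

if __name__ == '__main__':
    # A minimal smoke test (a sketch: it assumes the module is run through its
    # package so the relative `.utils` import resolves, and that inputs are
    # NHWC RGB in [0, 1]).
    encoder = MobileNetV3Encoder()
    f1, f2, f3, f4 = encoder(tf.random.uniform([1, 224, 224, 3]))
    print(f1.shape)  # expected (1, 112, 112, 16)   stride 2
    print(f2.shape)  # expected (1, 56, 56, 24)     stride 4
    print(f3.shape)  # expected (1, 28, 28, 40)     stride 8
    print(f4.shape)  # expected (1, 14, 14, 960)    stride 16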