|
@@ -0,0 +1,113 @@
|
|
|
+from tensorflow.keras.layers import DepthwiseConv2D
|
|
|
+from .mobilenetv3 import *
|
|
|
+from .resnet import ResNet50Encoder
|
|
|
+from .deep_guided_filter import DeepGuidedFilterRefiner
|
|
|
+
|
|
|
+# --------------------------- Load torch weights ---------------------------
|
|
|
+def load_torch_weights(model, state_dict):
|
|
|
+ if isinstance(model.backbone, MobileNetV3Encoder):
|
|
|
+ load_MobileNetV3_weights(model.backbone, state_dict, 'backbone')
|
|
|
+ if isinstance(model.backbone, ResNet50Encoder):
|
|
|
+ load_ResNetEncoder_weights(model.backbone, state_dict, 'backbone')
|
|
|
+ load_LRASPP_weights(model.aspp, state_dict, 'aspp')
|
|
|
+ load_RecurrentDecoder_weights(model.decoder, state_dict, 'decoder')
|
|
|
+ load_conv_weights(model.project_mat, state_dict, 'project_mat.conv')
|
|
|
+ if isinstance(model.refiner, DeepGuidedFilterRefiner):
|
|
|
+ load_DeepGuidedFilter_weights(model.refiner, state_dict, 'refiner')
|
|
|
+
|
|
|
+# --------------------------- General ---------------------------
|
|
|
+def load_conv_weights(conv, state_dict, name):
|
|
|
+ weight = state_dict[name + '.weight']
|
|
|
+ if isinstance(conv, DepthwiseConv2D):
|
|
|
+ weight = weight.permute(2, 3, 0, 1).numpy()
|
|
|
+ else:
|
|
|
+ weight = weight.permute(2, 3, 1, 0).numpy()
|
|
|
+ if name + '.bias' in state_dict:
|
|
|
+ bias = state_dict[name + '.bias'].numpy()
|
|
|
+ conv.set_weights([weight, bias])
|
|
|
+ else:
|
|
|
+ conv.set_weights([weight])
|
|
|
+
|
|
|
+def load_bn_weights(bn, state_dict, name):
|
|
|
+ weight = state_dict[name + '.weight']
|
|
|
+ bias = state_dict[name + '.bias']
|
|
|
+ running_mean = state_dict[name + '.running_mean']
|
|
|
+ running_var = state_dict[name + '.running_var']
|
|
|
+ bn.set_weights([weight, bias, running_mean, running_var])
|
|
|
+
|
|
|
+# --------------------------- MobileNetV3 ---------------------------
|
|
|
+def load_ConvBNActivation_weights(module, state_dict, name):
|
|
|
+ load_conv_weights(module.conv, state_dict, name + '.0')
|
|
|
+ load_bn_weights(module.bn, state_dict, name + '.1')
|
|
|
+
|
|
|
+def load_InvertedResidual_weights(module, state_dict, name):
|
|
|
+ for i, layer in enumerate(module.block.layers):
|
|
|
+ if isinstance(layer, ConvBNActivation):
|
|
|
+ load_ConvBNActivation_weights(layer, state_dict, f'{name}.block.{i}')
|
|
|
+ if isinstance(layer, SqueezeExcitation):
|
|
|
+ load_conv_weights(layer.fc1, state_dict, f'{name}.block.{i}.fc1')
|
|
|
+ load_conv_weights(layer.fc2, state_dict, f'{name}.block.{i}.fc2')
|
|
|
+
|
|
|
+def load_MobileNetV3_weights(backbone, state_dict, name):
|
|
|
+ for i, module in enumerate(backbone.features):
|
|
|
+ if isinstance(module, ConvBNActivation):
|
|
|
+ load_ConvBNActivation_weights(module, state_dict, f'{name}.features.{i}')
|
|
|
+ if isinstance(module, InvertedResidual):
|
|
|
+ load_InvertedResidual_weights(module, state_dict, f'{name}.features.{i}')
|
|
|
+
|
|
|
+# --------------------------- ResNet ---------------------------
|
|
|
+def load_ResNetEncoder_weights(module, state_dict, name):
|
|
|
+ load_conv_weights(module.conv1, state_dict, f'{name}.conv1')
|
|
|
+ load_bn_weights(module.bn1, state_dict, f'{name}.bn1')
|
|
|
+ for l in range(1, 5):
|
|
|
+ for b, resblock in enumerate(getattr(module, f'layer{l}').layers):
|
|
|
+ if hasattr(resblock, 'convd'):
|
|
|
+ load_conv_weights(resblock.convd, state_dict, f'{name}.layer{l}.{b}.downsample.0')
|
|
|
+ load_bn_weights(resblock.bnd, state_dict, f'{name}.layer{l}.{b}.downsample.1')
|
|
|
+ load_conv_weights(resblock.conv1, state_dict, f'{name}.layer{l}.{b}.conv1')
|
|
|
+ load_conv_weights(resblock.conv2, state_dict, f'{name}.layer{l}.{b}.conv2')
|
|
|
+ load_conv_weights(resblock.conv3, state_dict, f'{name}.layer{l}.{b}.conv3')
|
|
|
+ load_bn_weights(resblock.bn1, state_dict, f'{name}.layer{l}.{b}.bn1')
|
|
|
+ load_bn_weights(resblock.bn2, state_dict, f'{name}.layer{l}.{b}.bn2')
|
|
|
+ load_bn_weights(resblock.bn3, state_dict, f'{name}.layer{l}.{b}.bn3')
|
|
|
+
|
|
|
+# --------------------------- LRASPP ---------------------------
|
|
|
+def load_LRASPP_weights(module, state_dict, name):
|
|
|
+ load_conv_weights(module.aspp1.layers[0], state_dict, f'{name}.aspp1.0')
|
|
|
+ load_bn_weights(module.aspp1.layers[1], state_dict, f'{name}.aspp1.1')
|
|
|
+ load_conv_weights(module.aspp2, state_dict, f'{name}.aspp2.1')
|
|
|
+
|
|
|
+# --------------------------- RecurrentDecoder ---------------------------
|
|
|
+def load_ConvGRU_weights(module, state_dict, name):
|
|
|
+ load_conv_weights(module.ih, state_dict, f'{name}.ih.0')
|
|
|
+ load_conv_weights(module.hh, state_dict, f'{name}.hh.0')
|
|
|
+
|
|
|
+def load_BottleneckBlock_weights(module, state_dict, name):
|
|
|
+ load_ConvGRU_weights(module.gru, state_dict, f'{name}.gru')
|
|
|
+
|
|
|
+def load_UpsamplingBlock_weights(module, state_dict, name):
|
|
|
+ load_conv_weights(module.conv.layers[0], state_dict, f'{name}.conv.0')
|
|
|
+ load_bn_weights(module.conv.layers[1], state_dict, f'{name}.conv.1')
|
|
|
+ load_ConvGRU_weights(module.gru, state_dict, f'{name}.gru')
|
|
|
+
|
|
|
+def load_OutputBlock_weights(module, state_dict, name):
|
|
|
+ load_conv_weights(module.conv.layers[0], state_dict, f'{name}.conv.0')
|
|
|
+ load_bn_weights(module.conv.layers[1], state_dict, f'{name}.conv.1')
|
|
|
+ load_conv_weights(module.conv.layers[3], state_dict, f'{name}.conv.3')
|
|
|
+ load_bn_weights(module.conv.layers[4], state_dict, f'{name}.conv.4')
|
|
|
+
|
|
|
+def load_RecurrentDecoder_weights(module, state_dict, name):
|
|
|
+ load_BottleneckBlock_weights(module.decode4, state_dict, f'{name}.decode4')
|
|
|
+ load_UpsamplingBlock_weights(module.decode3, state_dict, f'{name}.decode3')
|
|
|
+ load_UpsamplingBlock_weights(module.decode2, state_dict, f'{name}.decode2')
|
|
|
+ load_UpsamplingBlock_weights(module.decode1, state_dict, f'{name}.decode1')
|
|
|
+ load_OutputBlock_weights(module.decode0, state_dict, f'{name}.decode0')
|
|
|
+
|
|
|
+# --------------------------- DeepGuidedFilter ---------------------------
|
|
|
+def load_DeepGuidedFilter_weights(module, state_dict, name):
|
|
|
+ load_conv_weights(module.box_filter.layers[1], state_dict, f'{name}.box_filter')
|
|
|
+ load_conv_weights(module.conv.layers[0], state_dict, f'{name}.conv.0')
|
|
|
+ load_bn_weights(module.conv.layers[1], state_dict, f'{name}.conv.1')
|
|
|
+ load_conv_weights(module.conv.layers[3], state_dict, f'{name}.conv.3')
|
|
|
+ load_bn_weights(module.conv.layers[4], state_dict, f'{name}.conv.4')
|
|
|
+ load_conv_weights(module.conv.layers[6], state_dict, f'{name}.conv.6')
|