123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113 |
- from tensorflow.keras.layers import DepthwiseConv2D
- from .mobilenetv3 import *
- from .resnet import ResNet50Encoder
- from .deep_guided_filter import DeepGuidedFilterRefiner
- # --------------------------- Load torch weights ---------------------------
- def load_torch_weights(model, state_dict):
- if isinstance(model.backbone, MobileNetV3Encoder):
- load_MobileNetV3_weights(model.backbone, state_dict, 'backbone')
- if isinstance(model.backbone, ResNet50Encoder):
- load_ResNetEncoder_weights(model.backbone, state_dict, 'backbone')
- load_LRASPP_weights(model.aspp, state_dict, 'aspp')
- load_RecurrentDecoder_weights(model.decoder, state_dict, 'decoder')
- load_conv_weights(model.project_mat, state_dict, 'project_mat.conv')
- if isinstance(model.refiner, DeepGuidedFilterRefiner):
- load_DeepGuidedFilter_weights(model.refiner, state_dict, 'refiner')
-
- # --------------------------- General ---------------------------
- def load_conv_weights(conv, state_dict, name):
- weight = state_dict[name + '.weight']
- if isinstance(conv, DepthwiseConv2D):
- weight = weight.permute(2, 3, 0, 1).numpy()
- else:
- weight = weight.permute(2, 3, 1, 0).numpy()
- if name + '.bias' in state_dict:
- bias = state_dict[name + '.bias'].numpy()
- conv.set_weights([weight, bias])
- else:
- conv.set_weights([weight])
- def load_bn_weights(bn, state_dict, name):
- weight = state_dict[name + '.weight']
- bias = state_dict[name + '.bias']
- running_mean = state_dict[name + '.running_mean']
- running_var = state_dict[name + '.running_var']
- bn.set_weights([weight, bias, running_mean, running_var])
-
- # --------------------------- MobileNetV3 ---------------------------
- def load_ConvBNActivation_weights(module, state_dict, name):
- load_conv_weights(module.conv, state_dict, name + '.0')
- load_bn_weights(module.bn, state_dict, name + '.1')
- def load_InvertedResidual_weights(module, state_dict, name):
- for i, layer in enumerate(module.block.layers):
- if isinstance(layer, ConvBNActivation):
- load_ConvBNActivation_weights(layer, state_dict, f'{name}.block.{i}')
- if isinstance(layer, SqueezeExcitation):
- load_conv_weights(layer.fc1, state_dict, f'{name}.block.{i}.fc1')
- load_conv_weights(layer.fc2, state_dict, f'{name}.block.{i}.fc2')
- def load_MobileNetV3_weights(backbone, state_dict, name):
- for i, module in enumerate(backbone.features):
- if isinstance(module, ConvBNActivation):
- load_ConvBNActivation_weights(module, state_dict, f'{name}.features.{i}')
- if isinstance(module, InvertedResidual):
- load_InvertedResidual_weights(module, state_dict, f'{name}.features.{i}')
- # --------------------------- ResNet ---------------------------
- def load_ResNetEncoder_weights(module, state_dict, name):
- load_conv_weights(module.conv1, state_dict, f'{name}.conv1')
- load_bn_weights(module.bn1, state_dict, f'{name}.bn1')
- for l in range(1, 5):
- for b, resblock in enumerate(getattr(module, f'layer{l}').layers):
- if hasattr(resblock, 'convd'):
- load_conv_weights(resblock.convd, state_dict, f'{name}.layer{l}.{b}.downsample.0')
- load_bn_weights(resblock.bnd, state_dict, f'{name}.layer{l}.{b}.downsample.1')
- load_conv_weights(resblock.conv1, state_dict, f'{name}.layer{l}.{b}.conv1')
- load_conv_weights(resblock.conv2, state_dict, f'{name}.layer{l}.{b}.conv2')
- load_conv_weights(resblock.conv3, state_dict, f'{name}.layer{l}.{b}.conv3')
- load_bn_weights(resblock.bn1, state_dict, f'{name}.layer{l}.{b}.bn1')
- load_bn_weights(resblock.bn2, state_dict, f'{name}.layer{l}.{b}.bn2')
- load_bn_weights(resblock.bn3, state_dict, f'{name}.layer{l}.{b}.bn3')
- # --------------------------- LRASPP ---------------------------
- def load_LRASPP_weights(module, state_dict, name):
- load_conv_weights(module.aspp1.layers[0], state_dict, f'{name}.aspp1.0')
- load_bn_weights(module.aspp1.layers[1], state_dict, f'{name}.aspp1.1')
- load_conv_weights(module.aspp2, state_dict, f'{name}.aspp2.1')
-
- # --------------------------- RecurrentDecoder ---------------------------
- def load_ConvGRU_weights(module, state_dict, name):
- load_conv_weights(module.ih, state_dict, f'{name}.ih.0')
- load_conv_weights(module.hh, state_dict, f'{name}.hh.0')
- def load_BottleneckBlock_weights(module, state_dict, name):
- load_ConvGRU_weights(module.gru, state_dict, f'{name}.gru')
- def load_UpsamplingBlock_weights(module, state_dict, name):
- load_conv_weights(module.conv.layers[0], state_dict, f'{name}.conv.0')
- load_bn_weights(module.conv.layers[1], state_dict, f'{name}.conv.1')
- load_ConvGRU_weights(module.gru, state_dict, f'{name}.gru')
- def load_OutputBlock_weights(module, state_dict, name):
- load_conv_weights(module.conv.layers[0], state_dict, f'{name}.conv.0')
- load_bn_weights(module.conv.layers[1], state_dict, f'{name}.conv.1')
- load_conv_weights(module.conv.layers[3], state_dict, f'{name}.conv.3')
- load_bn_weights(module.conv.layers[4], state_dict, f'{name}.conv.4')
- def load_RecurrentDecoder_weights(module, state_dict, name):
- load_BottleneckBlock_weights(module.decode4, state_dict, f'{name}.decode4')
- load_UpsamplingBlock_weights(module.decode3, state_dict, f'{name}.decode3')
- load_UpsamplingBlock_weights(module.decode2, state_dict, f'{name}.decode2')
- load_UpsamplingBlock_weights(module.decode1, state_dict, f'{name}.decode1')
- load_OutputBlock_weights(module.decode0, state_dict, f'{name}.decode0')
-
- # --------------------------- DeepGuidedFilter ---------------------------
- def load_DeepGuidedFilter_weights(module, state_dict, name):
- load_conv_weights(module.box_filter.layers[1], state_dict, f'{name}.box_filter')
- load_conv_weights(module.conv.layers[0], state_dict, f'{name}.conv.0')
- load_bn_weights(module.conv.layers[1], state_dict, f'{name}.conv.1')
- load_conv_weights(module.conv.layers[3], state_dict, f'{name}.conv.3')
- load_bn_weights(module.conv.layers[4], state_dict, f'{name}.conv.4')
- load_conv_weights(module.conv.layers[6], state_dict, f'{name}.conv.6')
|