load_weights.py 6.1 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113
  1. from tensorflow.keras.layers import DepthwiseConv2D
  2. from .mobilenetv3 import *
  3. from .resnet import ResNet50Encoder
  4. from .deep_guided_filter import DeepGuidedFilterRefiner
  5. # --------------------------- Load torch weights ---------------------------
  6. def load_torch_weights(model, state_dict):
  7. if isinstance(model.backbone, MobileNetV3Encoder):
  8. load_MobileNetV3_weights(model.backbone, state_dict, 'backbone')
  9. if isinstance(model.backbone, ResNet50Encoder):
  10. load_ResNetEncoder_weights(model.backbone, state_dict, 'backbone')
  11. load_LRASPP_weights(model.aspp, state_dict, 'aspp')
  12. load_RecurrentDecoder_weights(model.decoder, state_dict, 'decoder')
  13. load_conv_weights(model.project_mat, state_dict, 'project_mat.conv')
  14. if isinstance(model.refiner, DeepGuidedFilterRefiner):
  15. load_DeepGuidedFilter_weights(model.refiner, state_dict, 'refiner')
  16. # --------------------------- General ---------------------------
  17. def load_conv_weights(conv, state_dict, name):
  18. weight = state_dict[name + '.weight']
  19. if isinstance(conv, DepthwiseConv2D):
  20. weight = weight.permute(2, 3, 0, 1).numpy()
  21. else:
  22. weight = weight.permute(2, 3, 1, 0).numpy()
  23. if name + '.bias' in state_dict:
  24. bias = state_dict[name + '.bias'].numpy()
  25. conv.set_weights([weight, bias])
  26. else:
  27. conv.set_weights([weight])
  28. def load_bn_weights(bn, state_dict, name):
  29. weight = state_dict[name + '.weight']
  30. bias = state_dict[name + '.bias']
  31. running_mean = state_dict[name + '.running_mean']
  32. running_var = state_dict[name + '.running_var']
  33. bn.set_weights([weight, bias, running_mean, running_var])
  34. # --------------------------- MobileNetV3 ---------------------------
  35. def load_ConvBNActivation_weights(module, state_dict, name):
  36. load_conv_weights(module.conv, state_dict, name + '.0')
  37. load_bn_weights(module.bn, state_dict, name + '.1')
  38. def load_InvertedResidual_weights(module, state_dict, name):
  39. for i, layer in enumerate(module.block.layers):
  40. if isinstance(layer, ConvBNActivation):
  41. load_ConvBNActivation_weights(layer, state_dict, f'{name}.block.{i}')
  42. if isinstance(layer, SqueezeExcitation):
  43. load_conv_weights(layer.fc1, state_dict, f'{name}.block.{i}.fc1')
  44. load_conv_weights(layer.fc2, state_dict, f'{name}.block.{i}.fc2')
  45. def load_MobileNetV3_weights(backbone, state_dict, name):
  46. for i, module in enumerate(backbone.features):
  47. if isinstance(module, ConvBNActivation):
  48. load_ConvBNActivation_weights(module, state_dict, f'{name}.features.{i}')
  49. if isinstance(module, InvertedResidual):
  50. load_InvertedResidual_weights(module, state_dict, f'{name}.features.{i}')
  51. # --------------------------- ResNet ---------------------------
  52. def load_ResNetEncoder_weights(module, state_dict, name):
  53. load_conv_weights(module.conv1, state_dict, f'{name}.conv1')
  54. load_bn_weights(module.bn1, state_dict, f'{name}.bn1')
  55. for l in range(1, 5):
  56. for b, resblock in enumerate(getattr(module, f'layer{l}').layers):
  57. if hasattr(resblock, 'convd'):
  58. load_conv_weights(resblock.convd, state_dict, f'{name}.layer{l}.{b}.downsample.0')
  59. load_bn_weights(resblock.bnd, state_dict, f'{name}.layer{l}.{b}.downsample.1')
  60. load_conv_weights(resblock.conv1, state_dict, f'{name}.layer{l}.{b}.conv1')
  61. load_conv_weights(resblock.conv2, state_dict, f'{name}.layer{l}.{b}.conv2')
  62. load_conv_weights(resblock.conv3, state_dict, f'{name}.layer{l}.{b}.conv3')
  63. load_bn_weights(resblock.bn1, state_dict, f'{name}.layer{l}.{b}.bn1')
  64. load_bn_weights(resblock.bn2, state_dict, f'{name}.layer{l}.{b}.bn2')
  65. load_bn_weights(resblock.bn3, state_dict, f'{name}.layer{l}.{b}.bn3')
  66. # --------------------------- LRASPP ---------------------------
  67. def load_LRASPP_weights(module, state_dict, name):
  68. load_conv_weights(module.aspp1.layers[0], state_dict, f'{name}.aspp1.0')
  69. load_bn_weights(module.aspp1.layers[1], state_dict, f'{name}.aspp1.1')
  70. load_conv_weights(module.aspp2, state_dict, f'{name}.aspp2.1')
  71. # --------------------------- RecurrentDecoder ---------------------------
  72. def load_ConvGRU_weights(module, state_dict, name):
  73. load_conv_weights(module.ih, state_dict, f'{name}.ih.0')
  74. load_conv_weights(module.hh, state_dict, f'{name}.hh.0')
  75. def load_BottleneckBlock_weights(module, state_dict, name):
  76. load_ConvGRU_weights(module.gru, state_dict, f'{name}.gru')
  77. def load_UpsamplingBlock_weights(module, state_dict, name):
  78. load_conv_weights(module.conv.layers[0], state_dict, f'{name}.conv.0')
  79. load_bn_weights(module.conv.layers[1], state_dict, f'{name}.conv.1')
  80. load_ConvGRU_weights(module.gru, state_dict, f'{name}.gru')
  81. def load_OutputBlock_weights(module, state_dict, name):
  82. load_conv_weights(module.conv.layers[0], state_dict, f'{name}.conv.0')
  83. load_bn_weights(module.conv.layers[1], state_dict, f'{name}.conv.1')
  84. load_conv_weights(module.conv.layers[3], state_dict, f'{name}.conv.3')
  85. load_bn_weights(module.conv.layers[4], state_dict, f'{name}.conv.4')
  86. def load_RecurrentDecoder_weights(module, state_dict, name):
  87. load_BottleneckBlock_weights(module.decode4, state_dict, f'{name}.decode4')
  88. load_UpsamplingBlock_weights(module.decode3, state_dict, f'{name}.decode3')
  89. load_UpsamplingBlock_weights(module.decode2, state_dict, f'{name}.decode2')
  90. load_UpsamplingBlock_weights(module.decode1, state_dict, f'{name}.decode1')
  91. load_OutputBlock_weights(module.decode0, state_dict, f'{name}.decode0')
  92. # --------------------------- DeepGuidedFilter ---------------------------
  93. def load_DeepGuidedFilter_weights(module, state_dict, name):
  94. load_conv_weights(module.box_filter.layers[1], state_dict, f'{name}.box_filter')
  95. load_conv_weights(module.conv.layers[0], state_dict, f'{name}.conv.0')
  96. load_bn_weights(module.conv.layers[1], state_dict, f'{name}.conv.1')
  97. load_conv_weights(module.conv.layers[3], state_dict, f'{name}.conv.3')
  98. load_bn_weights(module.conv.layers[4], state_dict, f'{name}.conv.4')
  99. load_conv_weights(module.conv.layers[6], state_dict, f'{name}.conv.6')