# mobilenet.py
from torch import nn
from torchvision.models import MobileNetV2
  3. class MobileNetV2Encoder(MobileNetV2):
  4. """
  5. MobileNetV2Encoder inherits from torchvision's official MobileNetV2. It is modified to
  6. use dilation on the last block to maintain output stride 16, and deleted the
  7. classifier block that was originally used for classification. The forward method
  8. additionally returns the feature maps at all resolutions for decoder's use.
  9. """
  10. def __init__(self, in_channels, norm_layer=None):
  11. super().__init__()
  12. # Replace first conv layer if in_channels doesn't match.
  13. if in_channels != 3:
  14. self.features[0][0] = nn.Conv2d(in_channels, 32, 3, 2, 1, bias=False)
  15. # Remove last block
  16. self.features = self.features[:-1]
  17. # Change to use dilation to maintain output stride = 16
  18. self.features[14].conv[1][0].stride = (1, 1)
  19. for feature in self.features[15:]:
  20. feature.conv[1][0].dilation = (2, 2)
  21. feature.conv[1][0].padding = (2, 2)
  22. # Delete classifier
  23. del self.classifier
  24. def forward(self, x):
  25. x0 = x # 1/1
  26. x = self.features[0](x)
  27. x = self.features[1](x)
  28. x1 = x # 1/2
  29. x = self.features[2](x)
  30. x = self.features[3](x)
  31. x2 = x # 1/4
  32. x = self.features[4](x)
  33. x = self.features[5](x)
  34. x = self.features[6](x)
  35. x3 = x # 1/8
  36. x = self.features[7](x)
  37. x = self.features[8](x)
  38. x = self.features[9](x)
  39. x = self.features[10](x)
  40. x = self.features[11](x)
  41. x = self.features[12](x)
  42. x = self.features[13](x)
  43. x = self.features[14](x)
  44. x = self.features[15](x)
  45. x = self.features[16](x)
  46. x = self.features[17](x)
  47. x4 = x # 1/16
  48. return x4, x3, x2, x1, x0