from torch import nn
from torchvision.models import MobileNetV2


class MobileNetV2Encoder(MobileNetV2):
- """
- MobileNetV2Encoder inherits from torchvision's official MobileNetV2. It is modified to
- use dilation on the last block to maintain output stride 16, and deleted the
- classifier block that was originally used for classification. The forward method
- additionally returns the feature maps at all resolutions for decoder's use.
- """

    def __init__(self, in_channels, norm_layer=None):
        # Pass norm_layer through so a custom normalization layer is honored.
        super().__init__(norm_layer=norm_layer)

        # Replace the first conv layer if in_channels doesn't match.
        if in_channels != 3:
            self.features[0][0] = nn.Conv2d(in_channels, 32, 3, 2, 1, bias=False)

        # Remove the last block (the final 1x1 conv expanding to 1280 channels).
        self.features = self.features[:-1]

        # Use dilation to maintain output stride = 16: features[14] is the last
        # stride-2 block, so making it stride 1 keeps the deepest features at
        # 1/16 resolution, and dilating the depthwise convs of the remaining
        # blocks preserves their receptive field.
        self.features[14].conv[1][0].stride = (1, 1)
        for feature in self.features[15:]:
            feature.conv[1][0].dilation = (2, 2)
            feature.conv[1][0].padding = (2, 2)

        # Delete classifier
        del self.classifier

    def forward(self, x):
        x0 = x  # 1/1
        x = self.features[0](x)
        x = self.features[1](x)
        x1 = x  # 1/2
        x = self.features[2](x)
        x = self.features[3](x)
        x2 = x  # 1/4
        x = self.features[4](x)
        x = self.features[5](x)
        x = self.features[6](x)
        x3 = x  # 1/8
        x = self.features[7](x)
        x = self.features[8](x)
        x = self.features[9](x)
        x = self.features[10](x)
        x = self.features[11](x)
        x = self.features[12](x)
        x = self.features[13](x)
        x = self.features[14](x)
        x = self.features[15](x)
        x = self.features[16](x)
        x = self.features[17](x)
        x4 = x  # 1/16
        return x4, x3, x2, x1, x0
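
A minimal usage sketch follows. The batch size, the 4-channel input, and the 256x256 resolution are illustrative assumptions, not requirements of the class; they show the stem replacement in action and the resolutions and channel counts of the returned feature pyramid (channel counts assume the default width multiplier of torchvision's MobileNetV2).

# Usage sketch (assumptions: batch of 1, 4 input channels, 256x256 input).
import torch

encoder = MobileNetV2Encoder(in_channels=4)
encoder.eval()

x = torch.randn(1, 4, 256, 256)
with torch.no_grad():
    x4, x3, x2, x1, x0 = encoder(x)

# With a 256x256 input and the default width multiplier, the shapes are:
#   x0: (1, 4,   256, 256)  # 1/1, the input itself
#   x1: (1, 16,  128, 128)  # 1/2
#   x2: (1, 24,   64,  64)  # 1/4
#   x3: (1, 32,   32,  32)  # 1/8
#   x4: (1, 320,  16,  16)  # 1/16, held at stride 16 by the dilated blocks

Keeping the deepest features at output stride 16 rather than 32 is the same trick used by DeepLab-style segmentation networks: the dilated convolutions retain the pretrained receptive field while the decoder receives a higher-resolution feature map.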