123456789101112131415161718192021222324252627282930313233343536373839404142434445464748 |
- from torch import nn
- from torchvision.models.resnet import ResNet, Bottleneck
- class ResNetEncoder(ResNet):
- """
- ResNetEncoder inherits from torchvision's official ResNet. It is modified to
- use dilation on the last block to maintain output stride 16, and deleted the
- global average pooling layer and the fully connected layer that was originally
- used for classification. The forward method additionally returns the feature
- maps at all resolutions for decoder's use.
- """
-
- layers = {
- 'resnet50': [3, 4, 6, 3],
- 'resnet101': [3, 4, 23, 3],
- }
-
- def __init__(self, in_channels, variant='resnet101', norm_layer=None):
- super().__init__(
- block=Bottleneck,
- layers=self.layers[variant],
- replace_stride_with_dilation=[False, False, True],
- norm_layer=norm_layer)
-
- # Replace first conv layer if in_channels doesn't match.
- if in_channels != 3:
- self.conv1 = nn.Conv2d(in_channels, 64, 7, 2, 3, bias=False)
-
- # Delete fully-connected layer
- del self.avgpool
- del self.fc
-
- def forward(self, x):
- x0 = x # 1/1
- x = self.conv1(x)
- x = self.bn1(x)
- x = self.relu(x)
- x1 = x # 1/2
- x = self.maxpool(x)
- x = self.layer1(x)
- x2 = x # 1/4
- x = self.layer2(x)
- x3 = x # 1/8
- x = self.layer3(x)
- x = self.layer4(x)
- x4 = x # 1/16
- return x4, x3, x2, x1, x0
|