123456789101112131415161718192021222324252627282930313233343536373839404142434445 |
- from torch import nn
- from torchvision.models.resnet import ResNet, Bottleneck
- from torchvision.models.utils import load_state_dict_from_url
class ResNet50Encoder(ResNet):
    """ResNet-50 backbone that returns intermediate feature maps.

    The network is built from ``Bottleneck`` blocks with the ``[3, 4, 6, 3]``
    layer layout, and the last stage is configured with
    ``replace_stride_with_dilation=[False, False, True]``. The classification
    head (``avgpool`` and ``fc``) is removed after construction, so only the
    convolutional trunk remains.
    """

    def __init__(self, pretrained: bool = False):
        """Build the encoder.

        Args:
            pretrained: When true, load the ResNet-50 checkpoint from
                ``download.pytorch.org`` before the head is stripped.
        """
        super().__init__(
            block=Bottleneck,
            layers=[3, 4, 6, 3],
            replace_stride_with_dilation=[False, False, True],
            norm_layer=None,
        )

        if pretrained:
            checkpoint = load_state_dict_from_url(
                'https://download.pytorch.org/models/resnet50-0676ba61.pth')
            # Load while ``avgpool``/``fc`` still exist so the checkpoint keys
            # line up with the module's own state dict.
            self.load_state_dict(checkpoint)

        # The classification head is never used by the forward methods.
        del self.avgpool
        del self.fc

    def forward_single_frame(self, x):
        """Run the trunk on one batch of frames.

        Returns:
            A list ``[f1, f2, f3, f4]`` of feature maps at 1/2, 1/4, 1/8 and
            1/16 of the input resolution respectively.
        """
        half = self.relu(self.bn1(self.conv1(x)))        # 1/2
        quarter = self.layer1(self.maxpool(half))        # 1/4
        eighth = self.layer2(quarter)                    # 1/8
        sixteenth = self.layer4(self.layer3(eighth))     # 1/16
        return [half, quarter, eighth, sixteenth]

    def forward_time_series(self, x):
        """Run the trunk on a sequence input with two leading dimensions.

        The first two dimensions are merged so every frame goes through
        ``forward_single_frame`` in one pass, then split back apart on each
        returned feature map.
        """
        batch, steps = x.shape[:2]
        merged = self.forward_single_frame(x.flatten(0, 1))
        return [feat.unflatten(0, (batch, steps)) for feat in merged]

    def forward(self, x):
        """Dispatch on rank: 5-D inputs are sequences, anything else is frames."""
        handler = (
            self.forward_time_series if x.ndim == 5
            else self.forward_single_frame
        )
        return handler(x)
|