resnet.py 1.4 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445
  1. from torch import nn
  2. from torchvision.models.resnet import ResNet, Bottleneck
  3. from torchvision.models.utils import load_state_dict_from_url
  4. class ResNet50Encoder(ResNet):
  5. def __init__(self, pretrained: bool = False):
  6. super().__init__(
  7. block=Bottleneck,
  8. layers=[3, 4, 6, 3],
  9. replace_stride_with_dilation=[False, False, True],
  10. norm_layer=None)
  11. if pretrained:
  12. self.load_state_dict(load_state_dict_from_url(
  13. 'https://download.pytorch.org/models/resnet50-0676ba61.pth'))
  14. del self.avgpool
  15. del self.fc
  16. def forward_single_frame(self, x):
  17. x = self.conv1(x)
  18. x = self.bn1(x)
  19. x = self.relu(x)
  20. f1 = x # 1/2
  21. x = self.maxpool(x)
  22. x = self.layer1(x)
  23. f2 = x # 1/4
  24. x = self.layer2(x)
  25. f3 = x # 1/8
  26. x = self.layer3(x)
  27. x = self.layer4(x)
  28. f4 = x # 1/16
  29. return [f1, f2, f3, f4]
  30. def forward_time_series(self, x):
  31. B, T = x.shape[:2]
  32. features = self.forward_single_frame(x.flatten(0, 1))
  33. features = [f.unflatten(0, (B, T)) for f in features]
  34. return features
  35. def forward(self, x):
  36. if x.ndim == 5:
  37. return self.forward_time_series(x)
  38. else:
  39. return self.forward_single_frame(x)