|
@@ -1,6 +1,6 @@
|
|
|
+import torch
|
|
|
from torch import nn
|
|
|
from torchvision.models.resnet import ResNet, Bottleneck
|
|
|
-from torchvision.models.utils import load_state_dict_from_url
|
|
|
|
|
|
class ResNet50Encoder(ResNet):
|
|
|
def __init__(self, pretrained: bool = False):
|
|
@@ -11,7 +11,7 @@ class ResNet50Encoder(ResNet):
|
|
|
norm_layer=None)
|
|
|
|
|
|
if pretrained:
|
|
|
- self.load_state_dict(load_state_dict_from_url(
|
|
|
+ self.load_state_dict(torch.hub.load_state_dict_from_url(
|
|
|
'https://download.pytorch.org/models/resnet50-0676ba61.pth'))
|
|
|
|
|
|
del self.avgpool
|