Browse Source

Fix issue with torchvision 0.11.0

Peter Lin 3 years ago
parent
commit
918a97f347
2 changed files with 4 additions and 4 deletions
  1. 2 2
      model/mobilenetv3.py
  2. 2 2
      model/resnet.py

+ 2 - 2
model/mobilenetv3.py

@@ -1,6 +1,6 @@
+import torch
 from torch import nn
 from torchvision.models.mobilenetv3 import MobileNetV3, InvertedResidualConfig
-from torchvision.models.utils import load_state_dict_from_url
 from torchvision.transforms.functional import normalize
 
 class MobileNetV3LargeEncoder(MobileNetV3):
@@ -27,7 +27,7 @@ class MobileNetV3LargeEncoder(MobileNetV3):
         )
         
         if pretrained:
-            self.load_state_dict(load_state_dict_from_url(
+            self.load_state_dict(torch.hub.load_state_dict_from_url(
                 'https://download.pytorch.org/models/mobilenet_v3_large-8738ca79.pth'))
 
         del self.avgpool

+ 2 - 2
model/resnet.py

@@ -1,6 +1,6 @@
+import torch
 from torch import nn
 from torchvision.models.resnet import ResNet, Bottleneck
-from torchvision.models.utils import load_state_dict_from_url
 
 class ResNet50Encoder(ResNet):
     def __init__(self, pretrained: bool = False):
@@ -11,7 +11,7 @@ class ResNet50Encoder(ResNet):
             norm_layer=None)
         
         if pretrained:
-            self.load_state_dict(load_state_dict_from_url(
+            self.load_state_dict(torch.hub.load_state_dict_from_url(
                 'https://download.pytorch.org/models/resnet50-0676ba61.pth'))
         
         del self.avgpool