123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778 |
- import torch
- from torch import Tensor
- from torch import nn
- from torch.nn import functional as F
- from typing import Optional, List
- from .mobilenetv3 import MobileNetV3LargeEncoder
- from .resnet import ResNet50Encoder
- from .lraspp import LRASPP
- from .decoder import RecurrentDecoder, Projection
- from .fast_guided_filter import FastGuidedFilterRefiner
- from .deep_guided_filter import DeepGuidedFilterRefiner
- from .onnx_helper import CustomOnnxResizeByFactorOp
class MattingNetwork(nn.Module):
    """Network assembled from a backbone encoder, LR-ASPP, recurrent decoder,
    two projection heads and a guided-filter refiner.

    The pipeline in ``forward`` is: optionally resize the input, encode it
    with the backbone, pass the deepest feature through ``LRASPP``, decode
    with ``RecurrentDecoder`` (which also receives ``r1``..``r4``), then
    project the decoder output either to 4 channels (split 3 + 1) or to a
    single channel, depending on ``segmentation_pass``.

    Args:
        variant: ``'mobilenetv3'`` or ``'resnet50'``; selects the backbone and
            the matching ASPP / decoder channel configuration.
        refiner: ``'deep_guided_filter'`` or ``'fast_guided_filter'``; selects
            the refiner module.
        pretrained_backbone: forwarded as-is to the backbone constructor.

    Raises:
        ValueError: if ``variant`` or ``refiner`` is not a supported name.
    """

    _VARIANTS = ('mobilenetv3', 'resnet50')
    _REFINERS = ('fast_guided_filter', 'deep_guided_filter')

    def __init__(self,
                 variant: str = 'mobilenetv3',
                 refiner: str = 'deep_guided_filter',
                 pretrained_backbone: bool = False):
        super().__init__()
        # Explicit raises instead of ``assert``: asserts vanish under
        # ``python -O``, which would let an unknown name silently fall
        # through to the ``else`` branches below.
        if variant not in self._VARIANTS:
            raise ValueError(
                f"variant must be one of {list(self._VARIANTS)}, got {variant!r}")
        if refiner not in self._REFINERS:
            raise ValueError(
                f"refiner must be one of {list(self._REFINERS)}, got {refiner!r}")

        # Sub-modules are registered in the same order as before so that
        # parameter / state_dict ordering is unchanged.
        if variant == 'mobilenetv3':
            self.backbone = MobileNetV3LargeEncoder(pretrained_backbone)
            self.aspp = LRASPP(960, 128)
            self.decoder = RecurrentDecoder([16, 24, 40, 128], [80, 40, 32, 16])
        else:
            self.backbone = ResNet50Encoder(pretrained_backbone)
            self.aspp = LRASPP(2048, 256)
            self.decoder = RecurrentDecoder([64, 256, 512, 256], [128, 64, 32, 16])

        self.project_mat = Projection(16, 4)
        self.project_seg = Projection(16, 1)

        if refiner == 'deep_guided_filter':
            self.refiner = DeepGuidedFilterRefiner()
        else:
            self.refiner = FastGuidedFilterRefiner()

    def forward(self, src, r1, r2, r3, r4,
                downsample_ratio: float = 1,
                segmentation_pass: bool = False):
        """Run the network on ``src``.

        Args:
            src: input tensor. NOTE(review): presumably an image batch
                (4-D) or a video batch (5-D), since ``_interpolate`` handles
                both ranks - confirm against callers.
            r1, r2, r3, r4: passed straight to the decoder. NOTE(review):
                presumably the recurrent states from the previous call -
                confirm against ``RecurrentDecoder``.
            downsample_ratio: scale factor applied to ``src`` before the
                backbone. A value of ``1`` skips the resize (outside of ONNX
                export).
            segmentation_pass: when true, return the single-channel
                projection instead of the 3 + 1 channel output.

        Returns:
            ``[fgr, pha, *rec]`` when ``segmentation_pass`` is false, where
            both ``fgr`` and ``pha`` are clamped to ``[0, 1]``; otherwise
            ``[seg, *rec]``. ``rec`` is whatever the decoder returns after
            its first output.
        """
        exporting = torch.onnx.is_in_onnx_export()

        # During ONNX export the resize always goes through the custom op,
        # regardless of the ratio's value.
        if exporting:
            src_sm = CustomOnnxResizeByFactorOp.apply(src, downsample_ratio)
        elif downsample_ratio != 1:
            src_sm = self._interpolate(src, scale_factor=downsample_ratio)
        else:
            src_sm = src

        f1, f2, f3, f4 = self.backbone(src_sm)
        f4 = self.aspp(f4)
        hid, *rec = self.decoder(src_sm, f1, f2, f3, f4, r1, r2, r3, r4)

        if segmentation_pass:
            seg = self.project_seg(hid)
            return [seg, *rec]

        # The 4-channel projection is split into a 3-channel residual and a
        # 1-channel map along the third-from-last dimension.
        fgr_residual, pha = self.project_mat(hid).split([3, 1], dim=-3)

        # The refiner runs whenever the input was (or, under export, may have
        # been) resized; ``exporting`` short-circuits the ratio comparison.
        if exporting or downsample_ratio != 1:
            fgr_residual, pha = self.refiner(src, src_sm, fgr_residual, pha, hid)

        # The residual is added back onto the original-resolution input.
        fgr = (fgr_residual + src).clamp(0., 1.)
        pha = pha.clamp(0., 1.)
        return [fgr, pha, *rec]

    def _interpolate(self, x: Tensor, scale_factor: float):
        """Bilinearly resize ``x`` by ``scale_factor``.

        A 5-D input has its first two dimensions merged for the resize and
        restored afterwards; any other input is handed to ``F.interpolate``
        unchanged.
        """
        def resize(t: Tensor) -> Tensor:
            return F.interpolate(t, scale_factor=scale_factor,
                                 mode='bilinear', align_corners=False,
                                 recompute_scale_factor=False)

        if x.ndim == 5:
            B, T = x.shape[:2]
            resized = resize(x.flatten(0, 1))
            return resized.reshape(B, T, resized.size(1), resized.size(2), resized.size(3))
        return resize(x)
|