model.py

import torch
from torch import nn
from torch.nn import functional as F
from torchvision.models.segmentation.deeplabv3 import ASPP

from .decoder import Decoder
from .mobilenet import MobileNetV2Encoder
from .refiner import Refiner
from .resnet import ResNetEncoder
from .utils import load_matched_state_dict


class Base(nn.Module):
    """
    A generic implementation of the base encoder-decoder network inspired by DeepLab.
    Accepts arbitrary channels for input and output.
    """

    def __init__(self, backbone: str, in_channels: int, out_channels: int):
        super().__init__()
        assert backbone in ['resnet50', 'resnet101', 'mobilenetv2']
        if backbone in ['resnet50', 'resnet101']:
            self.backbone = ResNetEncoder(in_channels, variant=backbone)
            self.aspp = ASPP(2048, [3, 6, 9])
            self.decoder = Decoder([256, 128, 64, 48, out_channels], [512, 256, 64, in_channels])
        else:
            self.backbone = MobileNetV2Encoder(in_channels)
            self.aspp = ASPP(320, [3, 6, 9])
            self.decoder = Decoder([256, 128, 64, 48, out_channels], [32, 24, 16, in_channels])

    def forward(self, x):
        x, *shortcuts = self.backbone(x)
        x = self.aspp(x)
        x = self.decoder(x, *shortcuts)
        return x

    def load_pretrained_deeplabv3_state_dict(self, state_dict, print_stats=True):
        # Pretrained DeepLabV3 models are provided by <https://github.com/VainF/DeepLabV3Plus-Pytorch>.
        # This method converts and loads their pretrained state_dict to match our model structure.
        # It is not needed unless you plan to train from DeepLab weights.
        # Use load_state_dict() for normal weight loading.

        # Convert state_dict naming for the aspp module.
        state_dict = {k.replace('classifier.classifier.0', 'aspp'): v for k, v in state_dict.items()}

        if isinstance(self.backbone, ResNetEncoder):
            # ResNet backbone does not need change.
            load_matched_state_dict(self, state_dict, print_stats)
        else:
            # Change MobileNetV2 backbone to the state_dict's format, then change back after loading.
            backbone_features = self.backbone.features
            self.backbone.low_level_features = backbone_features[:4]
            self.backbone.high_level_features = backbone_features[4:]
            del self.backbone.features
            load_matched_state_dict(self, state_dict, print_stats)
            self.backbone.features = backbone_features
            del self.backbone.low_level_features
            del self.backbone.high_level_features
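

# A minimal sketch of initializing Base from a DeepLabV3 checkpoint released by
# <https://github.com/VainF/DeepLabV3Plus-Pytorch>. The checkpoint filename and
# its 'model_state' key are assumptions about that repository's release format;
# adapt them to the checkpoint you actually downloaded. The channel counts here
# mirror MattingBase (6 in, 1 + 3 + 1 + 32 out).
#
#   model = Base(backbone='resnet50', in_channels=6, out_channels=37)
#   checkpoint = torch.load('best_deeplabv3plus_resnet50_voc_os16.pth', map_location='cpu')  # hypothetical path
#   model.load_pretrained_deeplabv3_state_dict(checkpoint['model_state'])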


class MattingBase(Base):
    """
    MattingBase is used to produce coarse global results at a lower resolution.
    MattingBase extends Base.

    Args:
        backbone: ["resnet50", "resnet101", "mobilenetv2"]

    Input:
        src: (B, 3, H, W) the source image. Channels are RGB values normalized to 0 ~ 1.
        bgr: (B, 3, H, W) the background image. Channels are RGB values normalized to 0 ~ 1.

    Output:
        pha: (B, 1, H, W) the alpha prediction. Normalized to 0 ~ 1.
        fgr: (B, 3, H, W) the foreground prediction. Channels are RGB values normalized to 0 ~ 1.
        err: (B, 1, H, W) the error prediction. Normalized to 0 ~ 1.
        hid: (B, 32, H, W) the hidden encoding. Used for connecting the refiner module.

    Example:
        model = MattingBase(backbone='resnet50')

        pha, fgr, err, hid = model(src, bgr)    # for training
        pha, fgr = model(src, bgr)[:2]          # for inference
    """

    def __init__(self, backbone: str):
        super().__init__(backbone, in_channels=6, out_channels=(1 + 3 + 1 + 32))

    def forward(self, src, bgr):
        x = torch.cat([src, bgr], dim=1)
        x, *shortcuts = self.backbone(x)
        x = self.aspp(x)
        x = self.decoder(x, *shortcuts)
        pha = x[:, 0:1].clamp_(0., 1.)              # alpha matte
        fgr = x[:, 1:4].add(src).clamp_(0., 1.)     # foreground predicted as a residual over src
        err = x[:, 4:5].clamp_(0., 1.)              # error map
        hid = x[:, 5:].relu_()                      # hidden features for the refiner
        return pha, fgr, err, hid
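

# A minimal smoke-test sketch for MattingBase; the random tensors are
# placeholders for real images and the 256x256 resolution is illustrative only.
#
#   model = MattingBase(backbone='mobilenetv2').eval()
#   src = torch.rand(1, 3, 256, 256)    # source image, RGB in 0 ~ 1
#   bgr = torch.rand(1, 3, 256, 256)    # background image, RGB in 0 ~ 1
#   with torch.no_grad():
#       pha, fgr, err, hid = model(src, bgr)
#   # pha: (1, 1, 256, 256)  fgr: (1, 3, 256, 256)  err: (1, 1, 256, 256)  hid: (1, 32, 256, 256)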


class MattingRefine(MattingBase):
    """
    MattingRefine includes the refiner module to upsample the coarse result to full resolution.
    MattingRefine extends MattingBase.

    Args:
        backbone: ["resnet50", "resnet101", "mobilenetv2"]
        backbone_scale: The image downsample scale for passing through the backbone,
                        default 1/4 or 0.25. Must not be greater than 1/2.
        refine_mode: refine area selection mode. Options:
            "full"         - No area selection; refine everywhere using regular Conv2d.
            "sampling"     - Refine a fixed number of pixels ranked by the highest errors.
            "thresholding" - Refine a varying number of pixels that have more error than the threshold.
        refine_sample_pixels: number of pixels to refine. Only used when mode == "sampling".
        refine_threshold: error threshold in the range 0 ~ 1. Refine where err > threshold. Only used when mode == "thresholding".
        refine_kernel_size: the refiner's convolutional kernel size. Options: [1, 3]
        refine_prevent_oversampling: prevent sampling more pixels than needed in sampling mode. Set to False only for speed testing.

    Input:
        src: (B, 3, H, W) the source image. Channels are RGB values normalized to 0 ~ 1.
        bgr: (B, 3, H, W) the background image. Channels are RGB values normalized to 0 ~ 1.

    Output:
        pha: (B, 1, H, W) the alpha prediction. Normalized to 0 ~ 1.
        fgr: (B, 3, H, W) the foreground prediction. Channels are RGB values normalized to 0 ~ 1.
        pha_sm: (B, 1, Hc, Wc) the coarse alpha prediction from the matting base. Normalized to 0 ~ 1.
        fgr_sm: (B, 3, Hc, Wc) the coarse foreground prediction from the matting base. Normalized to 0 ~ 1.
        err_sm: (B, 1, Hc, Wc) the coarse error prediction from the matting base. Normalized to 0 ~ 1.
        ref_sm: (B, 1, H/4, W/4) the quarter-resolution refinement map. 1 indicates refined 4x4 patch locations.

    Example:
        model = MattingRefine(backbone='resnet50', backbone_scale=1/4, refine_mode='sampling', refine_sample_pixels=80_000)
        model = MattingRefine(backbone='resnet50', backbone_scale=1/4, refine_mode='thresholding', refine_threshold=0.1)
        model = MattingRefine(backbone='resnet50', backbone_scale=1/4, refine_mode='full')

        pha, fgr, pha_sm, fgr_sm, err_sm, ref_sm = model(src, bgr)   # for training
        pha, fgr = model(src, bgr)[:2]                               # for inference
    """

    def __init__(self,
                 backbone: str,
                 backbone_scale: float = 1/4,
                 refine_mode: str = 'sampling',
                 refine_sample_pixels: int = 80_000,
                 refine_threshold: float = 0.1,
                 refine_kernel_size: int = 3,
                 refine_prevent_oversampling: bool = True,
                 refine_patch_crop_method: str = 'unfold',
                 refine_patch_replace_method: str = 'scatter_nd'):
        assert backbone_scale <= 1/2, 'backbone_scale should not be greater than 1/2'
        super().__init__(backbone)
        self.backbone_scale = backbone_scale
        self.refiner = Refiner(refine_mode,
                               refine_sample_pixels,
                               refine_threshold,
                               refine_kernel_size,
                               refine_prevent_oversampling,
                               refine_patch_crop_method,
                               refine_patch_replace_method)

    def forward(self, src, bgr):
        assert src.size() == bgr.size(), 'src and bgr must have the same shape'
        assert src.size(2) // 4 * 4 == src.size(2) and src.size(3) // 4 * 4 == src.size(3), \
            'src and bgr must have width and height that are divisible by 4'

        # Downsample src and bgr for the backbone.
        src_sm = F.interpolate(src,
                               scale_factor=self.backbone_scale,
                               mode='bilinear',
                               align_corners=False,
                               recompute_scale_factor=True)
        bgr_sm = F.interpolate(bgr,
                               scale_factor=self.backbone_scale,
                               mode='bilinear',
                               align_corners=False,
                               recompute_scale_factor=True)

        # Base pass at low resolution.
        x = torch.cat([src_sm, bgr_sm], dim=1)
        x, *shortcuts = self.backbone(x)
        x = self.aspp(x)
        x = self.decoder(x, *shortcuts)
        pha_sm = x[:, 0:1].clamp_(0., 1.)
        fgr_sm = x[:, 1:4]
        err_sm = x[:, 4:5].clamp_(0., 1.)
        hid_sm = x[:, 5:].relu_()

        # Refiner pass at full resolution on the selected patches.
        pha, fgr, ref_sm = self.refiner(src, bgr, pha_sm, fgr_sm, err_sm, hid_sm)

        # Clamp outputs; foregrounds are residuals over the (downsampled) source.
        pha = pha.clamp_(0., 1.)
        fgr = fgr.add_(src).clamp_(0., 1.)
        fgr_sm = src_sm.add_(fgr_sm).clamp_(0., 1.)

        return pha, fgr, pha_sm, fgr_sm, err_sm, ref_sm
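

# A minimal end-to-end inference sketch for MattingRefine; the random tensors
# stand in for a real frame and its captured background plate, and the 512x512
# size is an arbitrary choice that satisfies the divisible-by-4 requirement.
#
#   model = MattingRefine(backbone='resnet50',
#                         backbone_scale=1/4,
#                         refine_mode='sampling',
#                         refine_sample_pixels=80_000).eval()
#   src = torch.rand(1, 3, 512, 512)    # source frame, RGB in 0 ~ 1
#   bgr = torch.rand(1, 3, 512, 512)    # captured background, RGB in 0 ~ 1
#   with torch.no_grad():
#       pha, fgr = model(src, bgr)[:2]
#   # Composite the predicted foreground over a white background:
#   com = pha * fgr + (1 - pha) * torch.ones_like(fgr)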