Peter Lin 4 years ago
commit
ee86e2636f

+ 82 - 0
README.md

@@ -0,0 +1,82 @@
+# Real-Time High-Resolution Background Matting
+
+![Teaser](https://github.com/PeterL1n/Matting-PyTorch/blob/master/images/teaser.jpg?raw=true)
+
+Official repository for the paper [Real-Time High-Resolution Background Matting](https://grail.cs.washington.edu/projects/background-matting-v2/bgm_v2.pdf). Our model produces state-of-the-art matting results at 4K 30fps and HD 60fps on an Nvidia RTX 2080 TI GPU.
+
+* [Visit project site](https://grail.cs.washington.edu/projects/background-matting-v2/)
+* [Watch project video](https://www.youtube.com/watch?v=oMfPTeYDF9g)
+
+ 
+
+## Overview
+* [Download](#download)
+    * [Model / Weights](#model--weights)
+    * [Video / Image Examples](#video--image-examples)
+    * [Datasets](#datasets)
+* [Demo](#demo)
+    * [Scripts](#scripts)
+    * [Notebooks](#notebooks)
+* [Usage / Documentation](#usage--documentation)
+* [Training](#training)
+* [Project members](#project-members)
+
+ 
+
+## Download
+
+### Model / Weights
+
+* [Download model / weights](https://drive.google.com/drive/folders/1cbetlrKREitIgjnIikG1HdM4x72FtgBh?usp=sharing)
+
+### Video / Image Examples
+
+* [HD videos](https://drive.google.com/drive/folders/1j3BMrRFhFpfzJAe6P2WDtfanoeSCLPiq) (by [Sengupta et al.](https://github.com/senguptaumd/Background-Matting)) (Our model is more robust on HD footage)
+* [4K videos and images](https://drive.google.com/drive/folders/16H6Vz3294J-DEzauw06j4IUARRqYGgRD?usp=sharing)
+
+
+### Datasets
+
+* VideoMatte240K (Coming soon)
+* PhotoMatte85 (Coming soon)
+
+ 
+
+## Demo
+
+#### Scripts
+
+We provide several scripts in this repo for you to experiment with our model. More detailed instructions are included in the files.
+* `inference_images.py`: Perform matting on a directory of images.
+* `inference_video.py`: Perform matting on a video.
+* `inference_webcam.py`: An interactive matting demo using your webcam.
+
+#### Notebooks
+Additionally, you can try our notebooks in Google Colab for performing matting on images and videos.
+
+* [Image matting (Colab)](https://colab.research.google.com/drive/1cTxFq1YuoJ5QPqaTcnskwlHDolnjBkB9?usp=sharing)
+* [Video matting (Colab)](https://colab.research.google.com/drive/1Y9zWfULc8-DDTSsCH-pX6Utw8skiJG5s?usp=sharing)
+
+ 
+
+## Usage / Documentation
+
+You can run our model using **PyTorch**, **TorchScript**, **TensorFlow**, and **ONNX**. For details about using our model, please check out the [Usage / Documentation](doc/model_usage.md) page.
+
+ 
+
+## Training
+
+Training code will be released upon acceptance of the paper.
+
+ 
+
+## Project members
+* [Shanchuan Lin](https://www.linkedin.com/in/shanchuanlin/)*, University of Washington
+* [Andrey Ryabtsev](http://andreyryabtsev.com/)*, University of Washington
+* [Soumyadip Sengupta](https://homes.cs.washington.edu/~soumya91/), University of Washington
+* [Brian Curless](https://homes.cs.washington.edu/~curless/), University of Washington
+* [Steve Seitz](https://homes.cs.washington.edu/~seitz/), University of Washington
+* [Ira Kemelmacher-Shlizerman](https://sites.google.com/view/irakemelmacher/), University of Washington
+
+<sup>* Equal contribution.</sup>

+ 4 - 0
dataset/__init__.py

@@ -0,0 +1,4 @@
+from .images import ImagesDataset
+from .video import VideoDataset
+from .sample import SampleDataset
+from .zip import ZipDataset

+ 154 - 0
dataset/augmentation.py

@@ -0,0 +1,154 @@
+import random
+import torch
+import numpy as np
+import math
+from torchvision import transforms as T
+from torchvision.transforms import functional as F
+from PIL import Image, ImageFilter
+
+"""
+Pair transforms are modified versions of regular transforms: they take in multiple images
+and apply the exact same transform to all of them. This is especially useful when we want
+identical transforms on a pair of images.
+
+Example:
+    img1, img2, ..., imgN = transforms(img1, img2, ..., imgN)
+"""
+
+class PairCompose(T.Compose):
+    def __call__(self, *x):
+        for transform in self.transforms:
+            x = transform(*x)
+        return x
+    
+
+class PairApply:
+    def __init__(self, transforms):
+        self.transforms = transforms
+        
+    def __call__(self, *x):
+        return [self.transforms(xi) for xi in x]
+
+
+class PairApplyOnlyAtIndices:
+    def __init__(self, indices, transforms):
+        self.indices = indices
+        self.transforms = transforms
+    
+    def __call__(self, *x):
+        return [self.transforms(xi) if i in self.indices else xi for i, xi in enumerate(x)]
+
+
+class PairRandomAffine(T.RandomAffine):
+    def __init__(self, degrees, translate=None, scale=None, shear=None, resamples=None, fillcolor=0):
+        super().__init__(degrees, translate, scale, shear, Image.NEAREST, fillcolor)
+        self.resamples = resamples
+    
+    def __call__(self, *x):
+        if not len(x):
+            return []
+        param = self.get_params(self.degrees, self.translate, self.scale, self.shear, x[0].size)
+        resamples = self.resamples or [self.resample] * len(x)
+        return [F.affine(xi, *param, resamples[i], self.fillcolor) for i, xi in enumerate(x)]
+
+
+class PairRandomResizedCrop(T.RandomResizedCrop):
+    def __init__(self, size, scale=(0.08, 1.0), ratio=(3. / 4., 4. / 3.), interpolations=None):
+        super().__init__(size, scale, ratio, Image.BILINEAR)
+        self.interpolations = interpolations
+    
+    def __call__(self, *x):
+        if not len(x):
+            return []
+        i, j, h, w = self.get_params(x[0], self.scale, self.ratio)
+        interpolations = self.interpolations or [self.interpolation] * len(x)
+        return [F.resized_crop(xi, i, j, h, w, self.size, interpolations[k]) for k, xi in enumerate(x)]
+    
+
+class PairRandomHorizontalFlip(T.RandomHorizontalFlip):
+    def __call__(self, *x):
+        if torch.rand(1) < self.p:
+            x = [F.hflip(xi) for xi in x]
+        return x
+
+
+class RandomBoxBlur:
+    def __init__(self, prob, max_radius):
+        self.prob = prob
+        self.max_radius = max_radius
+    
+    def __call__(self, img):
+        if torch.rand(1) < self.prob:
+            fil = ImageFilter.BoxBlur(random.choice(range(self.max_radius + 1)))
+            img = img.filter(fil)
+        return img
+
+
+class PairRandomBoxBlur(RandomBoxBlur):
+    def __call__(self, *x):
+        if torch.rand(1) < self.prob:
+            fil = ImageFilter.BoxBlur(random.choice(range(self.max_radius + 1)))
+            x = [xi.filter(fil) for xi in x]
+        return x
+
+
+class RandomSharpen:
+    def __init__(self, prob):
+        self.prob = prob
+        self.filter = ImageFilter.SHARPEN
+    
+    def __call__(self, img):
+        if torch.rand(1) < self.prob:
+            img = img.filter(self.filter)
+        return img
+    
+    
+class PairRandomSharpen(RandomSharpen):
+    def __call__(self, *x):
+        if torch.rand(1) < self.prob:
+            x = [xi.filter(self.filter) for xi in x]
+        return x
+    
+
+class PairRandomAffineAndResize:
+    def __init__(self, size, degrees, translate, scale, shear, ratio=(3./4., 4./3.), resample=Image.BILINEAR, fillcolor=0):
+        self.size = size
+        self.degrees = degrees
+        self.translate = translate
+        self.scale = scale
+        self.shear = shear
+        self.ratio = ratio
+        self.resample = resample
+        self.fillcolor = fillcolor
+    
+    def __call__(self, *x):
+        if not len(x):
+            return []
+        
+        w, h = x[0].size
+        scale_factor = max(self.size[1] / w, self.size[0] / h)
+        
+        w_padded = max(w, self.size[1])
+        h_padded = max(h, self.size[0])
+        
+        pad_h = int(math.ceil((h_padded - h) / 2))
+        pad_w = int(math.ceil((w_padded - w) / 2))
+        
+        scale = self.scale[0] * scale_factor, self.scale[1] * scale_factor
+        translate = self.translate[0] * scale_factor, self.translate[1] * scale_factor
+        affine_params = T.RandomAffine.get_params(self.degrees, translate, scale, self.shear, (w, h))
+        
+        def transform(img):
+            if pad_h > 0 or pad_w > 0:
+                img = F.pad(img, (pad_w, pad_h))
+            
+            img = F.affine(img, *affine_params, self.resample, self.fillcolor)
+            img = F.center_crop(img, self.size)
+            return img
+            
+        return [transform(xi) for xi in x]
+
+
+class RandomAffineAndResize(PairRandomAffineAndResize):
+    def __call__(self, img):
+        return super().__call__(img)[0]

+ 23 - 0
dataset/images.py

@@ -0,0 +1,23 @@
+import os
+import glob
+from torch.utils.data import Dataset
+from PIL import Image
+
+class ImagesDataset(Dataset):
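+    """Loads all .jpg/.png images under `root` (searched recursively), in sorted filename order, converted to `mode`."""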
+    def __init__(self, root, mode='RGB', transforms=None):
+        self.transforms = transforms
+        self.mode = mode
+        self.filenames = sorted([*glob.glob(os.path.join(root, '**', '*.jpg'), recursive=True),
+                                 *glob.glob(os.path.join(root, '**', '*.png'), recursive=True)])
+
+    def __len__(self):
+        return len(self.filenames)
+
+    def __getitem__(self, idx):
+        with Image.open(self.filenames[idx]) as img:
+            img = img.convert(self.mode)
+        
+        if self.transforms:
+            img = self.transforms(img)
+        
+        return img

+ 14 - 0
dataset/sample.py

@@ -0,0 +1,14 @@
+from torch.utils.data import Dataset
+
+
+class SampleDataset(Dataset):
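+    """Evenly sub-samples a fixed number of items from `dataset` by taking indices at a regular stride."""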
+    def __init__(self, dataset, samples):
+        samples = min(samples, len(dataset))
+        self.dataset = dataset
+        self.indices = [i * int(len(dataset) / samples) for i in range(samples)]
+    
+    def __len__(self):
+        return len(self.indices)
+    
+    def __getitem__(self, idx):
+        return self.dataset[self.indices[idx]]

+ 38 - 0
dataset/video.py

@@ -0,0 +1,38 @@
+import cv2
+import numpy as np
+from torch.utils.data import Dataset
+from PIL import Image
+
+class VideoDataset(Dataset):
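+    """Exposes a video file as an indexable dataset of RGB PIL frames; also usable as a context manager to release the capture."""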
+    def __init__(self, path: str, transforms: any = None):
+        self.cap = cv2.VideoCapture(path)
+        self.transforms = transforms
+        
+        self.width = int(self.cap.get(cv2.CAP_PROP_FRAME_WIDTH))
+        self.height = int(self.cap.get(cv2.CAP_PROP_FRAME_HEIGHT))
+        self.frame_rate = self.cap.get(cv2.CAP_PROP_FPS)
+        self.frame_count = int(self.cap.get(cv2.CAP_PROP_FRAME_COUNT))
+    
+    def __len__(self):
+        return self.frame_count
+    
+    def __getitem__(self, idx):
+        if isinstance(idx, slice):
+            return [self[i] for i in range(*idx.indices(len(self)))]
+        
+        if self.cap.get(cv2.CAP_PROP_POS_FRAMES) != idx:
+            self.cap.set(cv2.CAP_PROP_POS_FRAMES, idx)
+        ret, img = self.cap.read()
+        if not ret:
+            raise IndexError(f'Idx: {idx} out of length: {len(self)}')
+        img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
+        img = Image.fromarray(img)
+        if self.transforms:
+            img = self.transforms(img)
+        return img
+    
+    def __enter__(self):
+        return self
+    
+    def __exit__(self, exc_type, exc_value, exc_traceback):
+        self.cap.release()

+ 20 - 0
dataset/zip.py

@@ -0,0 +1,20 @@
+from torch.utils.data import Dataset
+from typing import List
+
+class ZipDataset(Dataset):
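+    """Zips multiple datasets together; its length is that of the longest dataset, and shorter datasets wrap around (idx % len)."""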
+    def __init__(self, datasets: List[Dataset], transforms=None, assert_equal_length=False):
+        self.datasets = datasets
+        self.transforms = transforms
+        
+        if assert_equal_length:
+            for i in range(1, len(datasets)):
+                assert len(datasets[i]) == len(datasets[i - 1]), 'Datasets are not equal in length.'
+    
+    def __len__(self):
+        return max(len(d) for d in self.datasets)
+    
+    def __getitem__(self, idx):
+        x = tuple(d[idx % len(d)] for d in self.datasets)
+        if self.transforms:
+            x = self.transforms(*x)
+        return x

+ 153 - 0
doc/model_usage.md

@@ -0,0 +1,153 @@
+# Use our model
+Our model supports multiple inference backends and provides flexible settings to trade off quality and computation at inference time.
+
+## Overview
+* [Usage](#usage)
+    * [PyTorch (Research)](#pytorch-research)
+    * [TorchScript (Production)](#torchscript-production)
+    * [TensorFlow (Experimental)](#tensorflow-experimental)
+    * [ONNX (Experimental)](#onnx-experimental)
+* [Documentation](#documentation)
+
+&nbsp;
+
+## Usage
+
+
+### PyTorch (Research)
+
+The `/model` directory contains all the scripts that define the architecture. Follow the example to run inference using our model.
+
+#### Python
+
+```python
+import torch
+from model import MattingRefine
+
+device = torch.device('cuda')
+precision = torch.float32
+
+model = MattingRefine(backbone='mobilenetv2',
+                      backbone_scale=0.25,
+                      refine_mode='sampling',
+                      refine_sample_pixels=80_000)
+
+model.load_state_dict(torch.load('PATH_TO_CHECKPOINT.pth'))
+model = model.eval().to(precision).to(device)
+
+src = torch.rand(1, 3, 1080, 1920).to(precision).to(device)
+bgr = torch.rand(1, 3, 1080, 1920).to(precision).to(device)
+
+with torch.no_grad():
+    pha, fgr = model(src, bgr)[:2]
+```
+
+&nbsp;
+
+### TorchScript (Production)
+
+Inference with TorchScript does not need any script from this repo! Simply download the model file that has both the architecture and weights baked in. Follow the example to run our model in a Python or C++ environment.
+
+#### Python
+
+```python
+import torch
+
+device = torch.device('cuda')
+precision = torch.float16
+
+model = torch.jit.load('PATH_TO_MODEL.pth')
+model.backbone_scale = 0.25
+model.refine_mode = 'sampling'
+model.refine_sample_pixels = 80_000
+
+model = model.to(device)
+
+src = torch.rand(1, 3, 1080, 1920).to(precision).to(device)
+bgr = torch.rand(1, 3, 1080, 1920).to(precision).to(device)
+
+pha, fgr = model(src, bgr)[:2]
+```
+
+#### C++
+
+```cpp
+#include <torch/script.h>
+
+int main() {
+    auto device = torch::Device("cuda");
+    auto precision = torch::kFloat16;
+
+    auto model = torch::jit::load("PATH_TO_MODEL.pth");
+    model.setattr("backbone_scale", 0.25);
+    model.setattr("refine_mode", "sampling");
+    model.setattr("refine_sample_pixels", 80000);
+    model.to(device);
+
+    auto src = torch::rand({1, 3, 1080, 1920}).to(device).to(precision);
+    auto bgr = torch::rand({1, 3, 1080, 1920}).to(device).to(precision);
+
+    auto outputs = model.forward({src, bgr}).toTuple()->elements();
+    auto pha = outputs[0].toTensor();
+    auto fgr = outputs[1].toTensor();
+}
+```
+&nbsp;
+
+### TensorFlow (Experimental)
+
+Please visit [BackgroundMattingV2-TensorFlow](https://github.com/PeterL1n/BackgroundMattingV2-TensorFlow) repo for more detail.
+
+&nbsp;
+
+### ONNX (Experimental)
+
+#### Python
+```python
+import onnxruntime
+import numpy as np
+
+sess = onnxruntime.InferenceSession('PATH_TO_MODEL.onnx')
+
+src = np.random.normal(size=(1, 3, 1080, 1920)).astype(np.float32)
+bgr = np.random.normal(size=(1, 3, 1080, 1920)).astype(np.float32)
+
+pha, fgr = sess.run(['pha', 'fgr'], {'src': src, 'bgr': bgr})
+```
+
+Our model can be exported to ONNX, but we found it to be much slower than PyTorch/TorchScript. We provide pre-exported HD (`backbone_scale=0.25, sample_pixels=80,000`) and 4K (`backbone_scale=0.125, sample_pixels=320,000`) models with the MobileNetV2 backbone. Any other configuration can be exported through `export_onnx.py`. Note that the ONNX model uses different operators than PyTorch/TorchScript for compatibility. It uses `ROI_Align` rather than `Unfolding` for cropping patches, and `ScatterElement` rather than `ScatterND` for replacing patches. This can be configured inside `export_onnx.py`.
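+
+A minimal sketch of such an export, mirroring the arguments used by `export_onnx.py` in this repo (the checkpoint and output paths are placeholders):
+
+```python
+import torch
+from model import MattingRefine
+
+model = MattingRefine(backbone='mobilenetv2',
+                      backbone_scale=0.25,
+                      refine_mode='sampling',
+                      refine_sample_pixels=80_000,
+                      refine_patch_crop_method='roi_align',           # ONNX-friendly alternative to Unfolding
+                      refine_patch_replace_method='scatter_element')  # ONNX-friendly alternative to ScatterND
+
+model.load_state_dict(torch.load('PATH_TO_CHECKPOINT.pth', map_location='cpu'), strict=False)
+model.eval()
+
+src = torch.randn(1, 3, 1080, 1920)
+bgr = torch.randn(1, 3, 1080, 1920)
+
+torch.onnx.export(model, (src, bgr), 'model.onnx',
+                  opset_version=11,
+                  input_names=['src', 'bgr'],
+                  output_names=['pha', 'fgr', 'pha_sm', 'fgr_sm', 'err_sm', 'ref_sm'])
+```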
+
+&nbsp;
+
+## Documentation
+
+![Architecture](https://github.com/PeterL1n/Matting-PyTorch/blob/master/images/architecture.svg?raw=true)
+
+Our architecture consists of two network components. The base network operates at a downsampled resolution to produce coarse results, and the refinement network refines only the error-prone patches to produce the full-resolution output. This avoids redundant computation and allows the quality/computation trade-off to be adjusted at inference time.
+
+#### Model Arguments:
+* `backbone_scale` (float, default: 0.25): The downsampling scale at which the backbone operates. E.g., with `backbone_scale=0.25`, the backbone operates at 480x270 resolution for a 1920x1080 input.
+* `refine_mode` (string, default: `sampling`, options: [`sampling`, `thresholding`, `full`]): Mode of refinement.
+    * `sampling` sets a fixed maximum number of pixels to refine, defined by `refine_sample_pixels`. It is suitable for live applications where the computation and memory consumption per frame must have a fixed upper bound.
+    * `thresholding` dynamically refines all pixels whose predicted error is above the threshold, defined by `refine_threshold`. It is suitable for image-editing applications where quality outweighs computation speed.
+    * `full` refines the entire image. Only used for debugging.
+* `refine_sample_pixels` (int, default: 80,000): The fixed number of pixels to refine. Used in `sampling` mode.
+* `refine_threshold` (float, default: 0.1): The threshold for refinement. Used in `thresholding` mode.
+* `prevent_oversampling` (bool, default: true): Used only in `sampling` mode. When false, unnecessary pixels are also refined so that exactly `refine_sample_pixels` pixels get refined. This is only used for speed testing.
+
+#### Model Inputs:
+* `src`: (B, 3, H, W): The source image with RGB channels normalized to 0 ~ 1.
+* `bgr`: (B, 3, H, W): The background image with RGB channels normalized to 0 ~ 1.
+
+#### Model Outputs:
+* `pha`: (B, 1, H, W): The alpha matte normalized to 0 ~ 1.
+* `fgr`: (B, 3, H, W): The foreground with RGB channels normalized to 0 ~ 1.
+* `pha_sm`: (B, 1, Hc, Wc): The coarse alpha matte normalized to 0 ~ 1.
+* `fgr_sm`: (B, 3, Hc, Wc): The coarse foreground with RGB channels normalized to 0 ~ 1.
+* `err_sm`: (B, 1, Hc, Wc): The coarse error prediction map normalized to 0 ~ 1.
+* `ref_sm`: (B, 1, H/4, W/4): The refinement regions, where 1 denotes a refined 4x4 patch.
+
+Only the `pha` and `fgr` outputs are needed for regular use cases. You can composite the alpha and foreground onto a new background using `com = pha * fgr + (1 - pha) * bgr`. The additional outputs are intermediate results used for training and debugging.
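+
+A minimal compositing sketch (assuming `pha` and `fgr` come from the examples above, and that the new background has been loaded as a 0 ~ 1 RGB tensor):
+
+```python
+import torch
+from torch.nn import functional as F
+
+new_bgr = torch.rand(1, 3, 720, 1280)  # placeholder for a real background image tensor
+new_bgr = F.interpolate(new_bgr, size=pha.shape[2:], mode='bilinear', align_corners=False)
+new_bgr = new_bgr.to(pha.device).to(pha.dtype)
+
+com = pha * fgr + (1 - pha) * new_bgr
+```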
+
+
+We recommend `backbone_scale=0.25, refine_sample_pixels=80000` for HD and `backbone_scale=0.125, refine_sample_pixels=320000` for 4K.
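+
+For example, on a loaded TorchScript model (see the TorchScript section above), these presets map to:
+
+```python
+# For HD (e.g. 1920x1080) inputs
+model.backbone_scale = 0.25
+model.refine_sample_pixels = 80_000
+
+# For 4K (e.g. 3840x2160) inputs
+model.backbone_scale = 0.125
+model.refine_sample_pixels = 320_000
+```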

+ 130 - 0
export_onnx.py

@@ -0,0 +1,130 @@
+"""
+Export MattingRefine as ONNX format
+
+Example:
+
+    python export_onnx.py \
+        --model-type mattingrefine \
+        --model-checkpoint "PATH_TO_MODEL_CHECKPOINT" \
+        --model-backbone resnet50 \
+        --model-backbone-scale 0.25 \
+        --model-refine-mode sampling \
+        --model-refine-sample-pixels 80000 \
+        --onnx-opset-version 11 \
+        --onnx-constant-folding \
+        --precision float32 \
+        --output "model.onnx" \
+        --validate
+
+"""
+
+
+import argparse
+import torch
+
+from model import MattingBase, MattingRefine
+
+
+# --------------- Arguments ---------------
+
+
+parser = argparse.ArgumentParser(description='Export ONNX')
+
+parser.add_argument('--model-type', type=str, required=True, choices=['mattingbase', 'mattingrefine'])
+parser.add_argument('--model-backbone', type=str, required=True, choices=['resnet101', 'resnet50', 'mobilenetv2'])
+parser.add_argument('--model-backbone-scale', type=float, default=0.25)
+parser.add_argument('--model-checkpoint', type=str, required=True)
+parser.add_argument('--model-refine-mode', type=str, default='sampling', choices=['full', 'sampling', 'thresholding'])
+parser.add_argument('--model-refine-sample-pixels', type=int, default=80_000)
+parser.add_argument('--model-refine-threshold', type=float, default=0.7)
+parser.add_argument('--model-refine-kernel-size', type=int, default=3)
+
+parser.add_argument('--onnx-verbose', type=bool, default=True)
+parser.add_argument('--onnx-opset-version', type=int, default=12)
+parser.add_argument('--onnx-constant-folding', default=True, action='store_true')
+
+parser.add_argument('--device', type=str, default='cpu')
+parser.add_argument('--precision', type=str, default='float32', choices=['float32', 'float16'])
+parser.add_argument('--validate', action='store_true')
+parser.add_argument('--output', type=str, required=True)
+
+args = parser.parse_args()
+
+
+# --------------- Main ---------------
+
+
+# Load model
+if args.model_type == 'mattingbase':
+    model = MattingBase(args.model_backbone)
+if args.model_type == 'mattingrefine':
+    model = MattingRefine(
+        args.model_backbone,
+        args.model_backbone_scale,
+        args.model_refine_mode,
+        args.model_refine_sample_pixels,
+        args.model_refine_threshold,
+        args.model_refine_kernel_size,
+        refine_patch_crop_method='roi_align',
+        refine_patch_replace_method='scatter_element')
+
+model.load_state_dict(torch.load(args.model_checkpoint, map_location=args.device), strict=False)
+precision = {'float32': torch.float32, 'float16': torch.float16}[args.precision]
+model.eval().to(precision).to(args.device)
+
+# Dummy Inputs
+src = torch.randn(2, 3, 1080, 1920).to(precision).to(args.device)
+bgr = torch.randn(2, 3, 1080, 1920).to(precision).to(args.device)
+
+# Export ONNX
+if args.model_type == 'mattingbase':
+    input_names=['src', 'bgr']
+    output_names = ['pha', 'fgr', 'err', 'hid']
+if args.model_type == 'mattingrefine':
+    input_names=['src', 'bgr']
+    output_names = ['pha', 'fgr', 'pha_sm', 'fgr_sm', 'err_sm', 'ref_sm']
+
+torch.onnx.export(
+    model=model,
+    args=(src, bgr),
+    f=args.output,
+    verbose=args.onnx_verbose,
+    opset_version=args.onnx_opset_version,
+    do_constant_folding=args.onnx_constant_folding,
+    input_names=input_names,
+    output_names=output_names,
+    dynamic_axes={name: {0: 'batch', 2: 'height', 3: 'width'} for name in [*input_names, *output_names]})
+
+print(f'ONNX model saved at: {args.output}')
+
+# Validation
+if args.validate:
+    import onnxruntime
+    import numpy as np
+    
+    print(f'Validating ONNX model.')
+    
+    # Test with different inputs.
+    src = torch.randn(1, 3, 720, 1280).to(precision).to(args.device)
+    bgr = torch.randn(1, 3, 720, 1280).to(precision).to(args.device)
+    
+    with torch.no_grad():
+        out_torch = model(src, bgr)
+    
+    sess = onnxruntime.InferenceSession(args.output)
+    out_onnx = sess.run(None, {
+        'src': src.cpu().numpy(),
+        'bgr': bgr.cpu().numpy()
+    })
+    
+    e_max = 0
+    for a, b, name in zip(out_torch, out_onnx, output_names):
+        b = torch.as_tensor(b)
+        e = torch.abs(a.cpu() - b).max()
+        e_max = max(e_max, e.item())
+        print(f'"{name}" output differs by maximum of {e}')
+        
+    if e_max < 0.001:
+        print('Validation passed.')
+    else:
+        raise RuntimeError('Validation failed.')

+ 83 - 0
export_torchscript.py

@@ -0,0 +1,83 @@
+"""
+Export TorchScript
+
+    python export_torchscript.py \
+        --model-backbone resnet50 \
+        --model-checkpoint "PATH_TO_CHECKPOINT" \
+        --precision float32 \
+        --output "torchscript.pth"
+"""
+
+import argparse
+import torch
+from torch import nn
+from model import MattingRefine
+
+
+# --------------- Arguments ---------------
+
+
+parser = argparse.ArgumentParser(description='Export TorchScript')
+
+parser.add_argument('--model-backbone', type=str, required=True, choices=['resnet101', 'resnet50', 'mobilenetv2'])
+parser.add_argument('--model-checkpoint', type=str, required=True)
+parser.add_argument('--precision', type=str, default='float32', choices=['float32', 'float16'])
+parser.add_argument('--output', type=str, required=True)
+
+args = parser.parse_args()
+
+
+# --------------- Utils ---------------
+
+
+class MattingRefine_TorchScriptWrapper(nn.Module):
+    """
+    The purpose of this wrapper is to hoist all the configurable attributes to the top level,
+    so that the user can easily change them after loading the saved TorchScript model.
+    
+    Example:
+        model = torch.jit.load('torchscript.pth')
+        model.backbone_scale = 0.25
+        model.refine_mode = 'sampling'
+        model.refine_sample_pixels = 80_000
+        pha, fgr = model(src, bgr)[:2]
+    """
+    
+    def __init__(self, *args, **kwargs):
+        super().__init__()
+        self.model = MattingRefine(*args, **kwargs)
+        
+        # Hoist the attributes to the top level.
+        self.backbone_scale = self.model.backbone_scale
+        self.refine_mode = self.model.refiner.mode
+        self.refine_sample_pixels = self.model.refiner.sample_pixels
+        self.refine_threshold = self.model.refiner.threshold
+        self.refine_prevent_oversampling = self.model.refiner.prevent_oversampling
+    
+    def forward(self, src, bgr):
+        # Reset the attributes.
+        self.model.backbone_scale = self.backbone_scale
+        self.model.refiner.mode = self.refine_mode
+        self.model.refiner.sample_pixels = self.refine_sample_pixels
+        self.model.refiner.threshold = self.refine_threshold
+        self.model.refiner.prevent_oversampling = self.refine_prevent_oversampling
+        
+        return self.model(src, bgr)
+    
+    def load_state_dict(self, *args, **kwargs):
+        return self.model.load_state_dict(*args, **kwargs)
+    
+    
+# --------------- Main ---------------
+
+    
+model = MattingRefine_TorchScriptWrapper(args.model_backbone).eval()
+model.load_state_dict(torch.load(args.model_checkpoint, map_location='cpu'))
+for p in model.parameters():
+    p.requires_grad = False
+    
+if args.precision == 'float16':
+    model = model.half()
+    
+model = torch.jit.script(model)
+model.save(args.output)

File diff not shown because of its large size
+ 2 - 0
images/architecture.svg


BIN
images/teaser.jpg


+ 143 - 0
inference_images.py

@@ -0,0 +1,143 @@
+"""
+Inference images: Extract matting on images.
+
+Example:
+
+    python inference_images.py \
+        --model-type mattingrefine \
+        --model-backbone resnet50 \
+        --model-backbone-scale 0.25 \
+        --model-refine-mode sampling \
+        --model-refine-sample-pixels 80000 \
+        --model-checkpoint "PATH_TO_CHECKPOINT" \
+        --images-src "PATH_TO_IMAGES_SRC_DIR" \
+        --images-bgr "PATH_TO_IMAGES_BGR_DIR" \
+        --output-dir "PATH_TO_OUTPUT_DIR" \
+        --output-type com fgr pha
+
+"""
+
+import argparse
+import torch
+import os
+import shutil
+
+from torch import nn
+from torch.nn import functional as F
+from torch.utils.data import DataLoader
+from torchvision import transforms as T
+from torchvision.transforms.functional import to_pil_image
+from threading import Thread
+from tqdm import tqdm
+
+from dataset import ImagesDataset, ZipDataset
+from dataset import augmentation as A
+from model import MattingBase, MattingRefine
+from inference_utils import HomographicAlignment
+
+
+# --------------- Arguments ---------------
+
+
+parser = argparse.ArgumentParser(description='Inference images')
+
+parser.add_argument('--model-type', type=str, required=True, choices=['mattingbase', 'mattingrefine'])
+parser.add_argument('--model-backbone', type=str, required=True, choices=['resnet101', 'resnet50', 'mobilenetv2'])
+parser.add_argument('--model-backbone-scale', type=float, default=0.25)
+parser.add_argument('--model-checkpoint', type=str, required=True)
+parser.add_argument('--model-refine-mode', type=str, default='sampling', choices=['full', 'sampling', 'thresholding'])
+parser.add_argument('--model-refine-sample-pixels', type=int, default=80_000)
+parser.add_argument('--model-refine-threshold', type=float, default=0.7)
+parser.add_argument('--model-refine-kernel-size', type=int, default=3)
+
+parser.add_argument('--images-src', type=str, required=True)
+parser.add_argument('--images-bgr', type=str, required=True)
+
+parser.add_argument('--preprocess-alignment', action='store_true')
+
+parser.add_argument('--output-dir', type=str, required=True)
+parser.add_argument('--output-types', type=str, required=True, nargs='+', choices=['com', 'pha', 'fgr', 'err', 'ref'])
+parser.add_argument('-y', action='store_true')
+
+args = parser.parse_args()
+
+
+assert 'err' not in args.output_types or args.model_type in ['mattingbase', 'mattingrefine'], \
+    'Only mattingbase and mattingrefine support err output'
+assert 'ref' not in args.output_types or args.model_type in ['mattingrefine'], \
+    'Only mattingrefine supports ref output'
+
+
+# --------------- Main ---------------
+
+
+# Load model
+if args.model_type == 'mattingbase':
+    model = MattingBase(args.model_backbone)
+if args.model_type == 'mattingrefine':
+    model = MattingRefine(
+        args.model_backbone,
+        args.model_backbone_scale,
+        args.model_refine_mode,
+        args.model_refine_sample_pixels,
+        args.model_refine_threshold,
+        args.model_refine_kernel_size)
+
+model = model.cuda().eval()
+model.load_state_dict(torch.load(args.model_checkpoint), strict=False)
+
+
+# Load images
+dataset = ZipDataset([
+    ImagesDataset(args.images_src),
+    ImagesDataset(args.images_bgr),
+], assert_equal_length=True, transforms=A.PairCompose([
+    HomographicAlignment() if args.preprocess_alignment else A.PairApply(nn.Identity()),
+    A.PairApply(T.ToTensor())
+]))
+dataloader = DataLoader(dataset, batch_size=1, num_workers=8, pin_memory=True)
+
+
+# Create output directory
+if os.path.exists(args.output_dir):
+    if args.y or input(f'Directory {args.output_dir} already exists. Override? [Y/N]: ').lower() == 'y':
+        shutil.rmtree(args.output_dir)
+    else:
+        exit()
+
+for output_type in args.output_types:
+    os.makedirs(os.path.join(args.output_dir, output_type))
+    
+
+# Worker function
+def writer(img, path):
+    img = to_pil_image(img[0].cpu())
+    img.save(path)
+    
+    
+# Conversion loop
+with torch.no_grad():
+    for i, (src, bgr) in enumerate(tqdm(dataloader)):
+        filename = dataset.datasets[0].filenames[i]
+        src = src.cuda(non_blocking=True)
+        bgr = bgr.cuda(non_blocking=True)
+        
+        if args.model_type == 'mattingbase':
+            pha, fgr, err, _ = model(src, bgr)
+        elif args.model_type == 'mattingrefine':
+            pha, fgr, _, _, err, ref = model(src, bgr)
+        elif args.model_type == 'mattingbm':
+            pha, fgr = model(src, bgr)
+            
+        if 'com' in args.output_types:
+            com = torch.cat([fgr * pha.ne(0), pha], dim=1)
+            Thread(target=writer, args=(com, filename.replace(args.images_src, os.path.join(args.output_dir, 'com')).replace('.jpg', '.png'))).start()
+        if 'pha' in args.output_types:
+            Thread(target=writer, args=(pha, filename.replace(args.images_src, os.path.join(args.output_dir, 'pha')).replace('.png', '.jpg'))).start()
+        if 'fgr' in args.output_types:
+            Thread(target=writer, args=(fgr, filename.replace(args.images_src, os.path.join(args.output_dir, 'fgr')).replace('.png', '.jpg'))).start()
+        if 'err' in args.output_types:
+            err = F.interpolate(err, src.shape[2:], mode='bilinear', align_corners=False)
+            Thread(target=writer, args=(err, filename.replace(args.images_src, os.path.join(args.output_dir, 'err')).replace('.png', '.jpg'))).start()
+        if 'ref' in args.output_types:
+            ref = F.interpolate(ref, src.shape[2:], mode='nearest')
+            Thread(target=writer, args=(ref, filename.replace(args.images_src, os.path.join(args.output_dir, 'ref')).replace('.png', '.jpg'))).start()

+ 116 - 0
inference_speed_test.py

@@ -0,0 +1,116 @@
+"""
+Inference Speed Test
+
+Example:
+
+Run inference on random noise input for fixed computation setting.
+(i.e. mode in ['full', 'sampling'])
+
+    python inference_speed_test.py \
+        --model-type mattingrefine \
+        --model-backbone resnet50 \
+        --model-backbone-scale 0.25 \
+        --model-refine-mode sampling \
+        --model-refine-sample-pixels 80000 \
+        --batch-size 1 \
+        --resolution 1920 1080 \
+        --backend pytorch \
+        --precision float32
+
+Run inference on provided image input for dynamic computation setting.
+(i.e. mode in ['thresholding'])
+
+    python inference_speed_test.py \
+        --model-type mattingrefine \
+        --model-backbone resnet50 \
+        --model-backbone-scale 0.25 \
+        --model-checkpoint "PATH_TO_CHECKPOINT" \
+        --model-refine-mode thresholding \
+        --model-refine-threshold 0.7 \
+        --batch-size 1 \
+        --backend pytorch \
+        --precision float32 \
+        --image-src "PATH_TO_IMAGE_SRC" \
+        --image-bgr "PATH_TO_IMAGE_BGR"
+    
+"""
+
+import argparse
+import torch
+from torchvision.transforms.functional import to_tensor
+from tqdm import tqdm
+from PIL import Image
+
+from model import MattingBase, MattingRefine
+
+
+# --------------- Arguments ---------------
+
+
+parser = argparse.ArgumentParser()
+
+parser.add_argument('--model-type', type=str, required=True, choices=['mattingbase', 'mattingrefine'])
+parser.add_argument('--model-backbone', type=str, required=True, choices=['resnet101', 'resnet50', 'mobilenetv2'])
+parser.add_argument('--model-backbone-scale', type=float, default=0.25)
+parser.add_argument('--model-checkpoint', type=str, default=None)
+parser.add_argument('--model-refine-mode', type=str, default='sampling', choices=['full', 'sampling', 'thresholding'])
+parser.add_argument('--model-refine-sample-pixels', type=int, default=80_000)
+parser.add_argument('--model-refine-threshold', type=float, default=0.7)
+parser.add_argument('--model-refine-kernel-size', type=int, default=3)
+
+parser.add_argument('--batch-size', type=int, default=1)
+parser.add_argument('--resolution', type=int, default=None, nargs=2)
+parser.add_argument('--precision', type=str, default='float32', choices=['float32', 'float16'])
+parser.add_argument('--backend', type=str, default='pytorch', choices=['pytorch', 'torchscript'])
+
+parser.add_argument('--image-src', type=str, default=None)
+parser.add_argument('--image-bgr', type=str, default=None)
+
+args = parser.parse_args()
+
+
+assert type(args.image_src) == type(args.image_bgr),  'Image source and background must be provided together.'
+assert (not args.image_src) != (not args.resolution), 'Must provide either a resolution or an image and not both.'
+
+
+# --------------- Run Loop ---------------
+
+
+# Load model
+if args.model_type == 'mattingbase':
+    model = MattingBase(args.model_backbone)
+if args.model_type == 'mattingrefine':
+    model = MattingRefine(
+        args.model_backbone,
+        args.model_backbone_scale,
+        args.model_refine_mode,
+        args.model_refine_sample_pixels,
+        args.model_refine_threshold,
+        args.model_refine_kernel_size,
+        refine_prevent_oversampling=False)
+
+if args.model_checkpoint:
+    model.load_state_dict(torch.load(args.model_checkpoint), strict=False)
+    
+if args.precision == 'float32':
+    precision = torch.float32
+else:
+    precision = torch.float16
+    
+if args.backend == 'torchscript':
+    model = torch.jit.script(model)
+
+model = model.cuda().eval().to(precision)
+
+# Load data
+if not args.image_src:
+    src = torch.rand((args.batch_size, 3, *args.resolution[::-1]), device='cuda', dtype=precision)
+    bgr = torch.rand((args.batch_size, 3, *args.resolution[::-1]), device='cuda', dtype=precision)
+else:
+    src = to_tensor(Image.open(args.image_src)).unsqueeze(0).repeat(args.batch_size, 1, 1, 1).to(device='cuda', dtype=precision)
+    bgr = to_tensor(Image.open(args.image_bgr)).unsqueeze(0).repeat(args.batch_size, 1, 1, 1).to(device='cuda', dtype=precision)
+    
+# Loop
+with torch.no_grad():
+    for _ in tqdm(range(1000)):
+        model(src, bgr)

+ 46 - 0
inference_utils.py

@@ -0,0 +1,46 @@
+import numpy as np
+import cv2
+from PIL import Image
+
+
+class HomographicAlignment:
+    """
+    Apply homographic alignment to the background so that it matches the source image.
+    """
+    
+    def __init__(self):
+        self.detector = cv2.ORB_create()
+        self.matcher = cv2.DescriptorMatcher_create(cv2.DESCRIPTOR_MATCHER_BRUTEFORCE)
+
+    def __call__(self, src, bgr):
+        src = np.asarray(src)
+        bgr = np.asarray(bgr)
+
+        keypoints_src, descriptors_src = self.detector.detectAndCompute(src, None)
+        keypoints_bgr, descriptors_bgr = self.detector.detectAndCompute(bgr, None)
+
+        matches = self.matcher.match(descriptors_bgr, descriptors_src, None)
+        matches.sort(key=lambda x: x.distance, reverse=False)
+        num_good_matches = int(len(matches) * 0.15)
+        matches = matches[:num_good_matches]
+
+        points_src = np.zeros((len(matches), 2), dtype=np.float32)
+        points_bgr = np.zeros((len(matches), 2), dtype=np.float32)
+        for i, match in enumerate(matches):
+            points_src[i, :] = keypoints_src[match.trainIdx].pt
+            points_bgr[i, :] = keypoints_bgr[match.queryIdx].pt
+
+        H, _ = cv2.findHomography(points_bgr, points_src, cv2.RANSAC)
+
+        h, w = src.shape[:2]
+        bgr = cv2.warpPerspective(bgr, H, (w, h))
+        msk = cv2.warpPerspective(np.ones((h, w)), H, (w, h))
+
+        # For areas that are outside of the background,
+        # we just copy pixels from the source.
+        bgr[msk != 1] = src[msk != 1]
+
+        src = Image.fromarray(src)
+        bgr = Image.fromarray(bgr)
+        
+        return src, bgr

+ 203 - 0
inference_video.py

@@ -0,0 +1,203 @@
+"""
+Inference video: Extract matting on video.
+
+Example:
+
+    python inference_video.py \
+        --model-type mattingrefine \
+        --model-backbone resnet50 \
+        --model-backbone-scale 0.25 \
+        --model-refine-mode sampling \
+        --model-refine-sample-pixels 80000 \
+        --model-checkpoint "PATH_TO_CHECKPOINT" \
+        --video-src "PATH_TO_VIDEO_SRC" \
+        --video-bgr "PATH_TO_VIDEO_BGR" \
+        --video-resize 1920 1080 \
+        --output-dir "PATH_TO_OUTPUT_DIR" \
+        --output-type com fgr pha err ref
+
+"""
+
+import argparse
+import cv2
+import torch
+import os
+import shutil
+
+from torch import nn
+from torch.nn import functional as F
+from torch.utils.data import DataLoader
+from torchvision import transforms as T
+from torchvision.transforms.functional import to_pil_image
+from threading import Thread
+from tqdm import tqdm
+from PIL import Image
+
+from dataset import VideoDataset, ZipDataset
+from dataset import augmentation as A
+from model import MattingBase, MattingRefine
+from inference_utils import HomographicAlignment
+
+
+# --------------- Arguments ---------------
+
+
+parser = argparse.ArgumentParser(description='Inference video')
+
+parser.add_argument('--model-type', type=str, required=True, choices=['mattingbase', 'mattingrefine'])
+parser.add_argument('--model-backbone', type=str, required=True, choices=['resnet101', 'resnet50', 'mobilenetv2'])
+parser.add_argument('--model-backbone-scale', type=float, default=0.25)
+parser.add_argument('--model-checkpoint', type=str, required=True)
+parser.add_argument('--model-refine-mode', type=str, default='sampling', choices=['full', 'sampling', 'thresholding'])
+parser.add_argument('--model-refine-sample-pixels', type=int, default=80_000)
+parser.add_argument('--model-refine-threshold', type=float, default=0.7)
+parser.add_argument('--model-refine-kernel-size', type=int, default=3)
+
+parser.add_argument('--video-src', type=str, required=True)
+parser.add_argument('--video-bgr', type=str, required=True)
+parser.add_argument('--video-resize', type=int, default=None, nargs=2)
+
+parser.add_argument('--preprocess-alignment', action='store_true')
+
+parser.add_argument('--output-dir', type=str, required=True)
+parser.add_argument('--output-types', type=str, required=True, nargs='+', choices=['com', 'pha', 'fgr', 'err', 'ref'])
+parser.add_argument('--output-format', type=str, default='video', choices=['video', 'image_sequences'])
+
+args = parser.parse_args()
+
+
+assert 'err' not in args.output_types or args.model_type in ['mattingbase', 'mattingrefine'], \
+    'Only mattingbase and mattingrefine support err output'
+assert 'ref' not in args.output_types or args.model_type in ['mattingrefine'], \
+    'Only mattingrefine supports ref output'
+
+# --------------- Utils ---------------
+
+
+class VideoWriter:
+    def __init__(self, path, frame_rate, width, height):
+        self.out = cv2.VideoWriter(path, cv2.VideoWriter_fourcc(*'mp4v'), frame_rate, (width, height))
+        
+    def add_batch(self, frames):
+        frames = frames.mul(255).byte()
+        frames = frames.cpu().permute(0, 2, 3, 1).numpy()
+        for i in range(frames.shape[0]):
+            frame = frames[i]
+            frame = cv2.cvtColor(frame, cv2.COLOR_RGB2BGR)
+            self.out.write(frame)
+            
+
+class ImageSequenceWriter:
+    def __init__(self, path, extension):
+        self.path = path
+        self.extension = extension
+        self.index = 0
+        os.makedirs(path)
+        
+    def add_batch(self, frames):
+        Thread(target=self._add_batch, args=(frames, self.index)).start()
+        self.index += frames.shape[0]
+            
+    def _add_batch(self, frames, index):
+        frames = frames.cpu()
+        for i in range(frames.shape[0]):
+            frame = frames[i]
+            frame = to_pil_image(frame)
+            frame.save(os.path.join(self.path, str(index + i).zfill(5) + '.' + self.extension))
+
+
+# --------------- Main ---------------
+
+
+# Load model
+if args.model_type == 'mattingbase':
+    model = MattingBase(args.model_backbone)
+if args.model_type == 'mattingrefine':
+    model = MattingRefine(
+        args.model_backbone,
+        args.model_backbone_scale,
+        args.model_refine_mode,
+        args.model_refine_sample_pixels,
+        args.model_refine_threshold,
+        args.model_refine_kernel_size)
+
+model = model.cuda().eval()
+model.load_state_dict(torch.load(args.model_checkpoint), strict=False)
+
+
+# Load video and background
+vid = VideoDataset(args.video_src)
+bgr = [Image.open(args.video_bgr).convert('RGB')]
+dataset = ZipDataset([vid, bgr], transforms=A.PairCompose([
+    A.PairApply(T.Resize(args.video_resize[::-1]) if args.video_resize else nn.Identity()),
+    HomographicAlignment() if args.preprocess_alignment else A.PairApply(nn.Identity()),
+    A.PairApply(T.ToTensor())
+]))
+
+# Create output directory
+if os.path.exists(args.output_dir):
+    if input(f'Directory {args.output_dir} already exists. Override? [Y/N]: ').lower() == 'y':
+        shutil.rmtree(args.output_dir)
+    else:
+        exit()
+os.makedirs(args.output_dir)
+
+
+# Prepare writers
+if args.output_format == 'video':
+    h = args.video_resize[1] if args.video_resize is not None else vid.height
+    w = args.video_resize[0] if args.video_resize is not None else vid.width
+    if 'com' in args.output_types:
+        com_writer = VideoWriter(os.path.join(args.output_dir, 'com.mp4'), vid.frame_rate, w, h)
+    if 'pha' in args.output_types:
+        pha_writer = VideoWriter(os.path.join(args.output_dir, 'pha.mp4'), vid.frame_rate, w, h)
+    if 'fgr' in args.output_types:
+        fgr_writer = VideoWriter(os.path.join(args.output_dir, 'fgr.mp4'), vid.frame_rate, w, h)
+    if 'err' in args.output_types:
+        err_writer = VideoWriter(os.path.join(args.output_dir, 'err.mp4'), vid.frame_rate, w, h)
+    if 'ref' in args.output_types:
+        ref_writer = VideoWriter(os.path.join(args.output_dir, 'ref.mp4'), vid.frame_rate, w, h)
+else:
+    if 'com' in args.output_types:
+        com_writer = ImageSequenceWriter(os.path.join(args.output_dir, 'com'), 'png')
+    if 'pha' in args.output_types:
+        pha_writer = ImageSequenceWriter(os.path.join(args.output_dir, 'pha'), 'jpg')
+    if 'fgr' in args.output_types:
+        fgr_writer = ImageSequenceWriter(os.path.join(args.output_dir, 'fgr'), 'jpg')
+    if 'err' in args.output_types:
+        err_writer = ImageSequenceWriter(os.path.join(args.output_dir, 'err'), 'jpg')
+    if 'ref' in args.output_types:
+        ref_writer = ImageSequenceWriter(os.path.join(args.output_dir, 'ref'), 'jpg')
+    
+
+# Conversion loop
+with torch.no_grad():
+    for src, bgr in tqdm(DataLoader(dataset, batch_size=1, pin_memory=True)):
+        src = src.cuda(non_blocking=True)
+        bgr = bgr.cuda(non_blocking=True)
+        
+        if args.model_type == 'mattingbase':
+            pha, fgr, err, _ = model(src, bgr)
+        elif args.model_type == 'mattingrefine':
+            pha, fgr, _, _, err, ref = model(src, bgr)
+        elif args.model_type == 'mattingbm':
+            pha, fgr = model(src, bgr)
+
+        if 'com' in args.output_types:
+            if args.output_format == 'video':
+                # Output composite with green background
+                bgr_green = torch.tensor([120/255, 255/255, 155/255], device='cuda').view(1, 3, 1, 1)
+                com = fgr * pha + bgr_green * (1 - pha)
+                com_writer.add_batch(com)
+            else:
+                # Output composite as rgba png images
+                com = torch.cat([fgr * pha.ne(0), pha], dim=1)
+                com_writer.add_batch(com)
+        if 'pha' in args.output_types:
+            pha_writer.add_batch(pha)
+        if 'fgr' in args.output_types:
+            fgr_writer.add_batch(fgr)
+        if 'err' in args.output_types:
+            err_writer.add_batch(F.interpolate(err, src.shape[2:], mode='bilinear', align_corners=False))
+        if 'ref' in args.output_types:
+            ref_writer.add_batch(F.interpolate(ref, src.shape[2:], mode='nearest'))

+ 171 - 0
inference_webcam.py

@@ -0,0 +1,171 @@
+"""
+Inference on webcams: Use a model on webcam input.
+
+Once launched, the script starts in background capture mode.
+Press B to toggle between background capture mode and matting mode. The frame shown when B is pressed is used as the background for matting.
+Press Q to exit.
+
+Example:
+
+    python inference_webcam.py \
+        --model-type mattingrefine \
+        --model-backbone resnet50 \
+        --model-checkpoint "PATH_TO_CHECKPOINT" \
+        --resolution 1280 720
+
+"""
+
+import argparse, os, shutil, time
+import cv2
+import torch
+
+from torch import nn
+from torch.utils.data import DataLoader
+from torchvision.transforms import Compose, ToTensor, Resize
+from torchvision.transforms.functional import to_pil_image
+from threading import Thread, Lock
+from tqdm import tqdm
+from PIL import Image
+
+from dataset import VideoDataset
+from model import MattingBase, MattingRefine
+
+
+# --------------- Arguments ---------------
+
+
+parser = argparse.ArgumentParser(description='Inference from web-cam')
+
+parser.add_argument('--model-type', type=str, required=True, choices=['mattingbase', 'mattingrefine'])
+parser.add_argument('--model-backbone', type=str, required=True, choices=['resnet101', 'resnet50', 'mobilenetv2'])
+parser.add_argument('--model-backbone-scale', type=float, default=0.25)
+parser.add_argument('--model-checkpoint', type=str, required=True)
+parser.add_argument('--model-refine-mode', type=str, default='sampling', choices=['full', 'sampling', 'thresholding'])
+parser.add_argument('--model-refine-sample-pixels', type=int, default=80_000)
+parser.add_argument('--model-refine-threshold', type=float, default=0.7)
+
+parser.add_argument('--hide-fps', action='store_true')
+parser.add_argument('--resolution', type=int, nargs=2, metavar=('width', 'height'), default=(1280, 720))
+args = parser.parse_args()
+
+
+# ----------- Utility classes -------------
+
+
+# A wrapper that reads frames from cv2.VideoCapture in its own thread to avoid blocking.
+# Call .read() in a tight loop to get the newest frame.
+class Camera:
+    def __init__(self, device_id=0, width=1280, height=720):
+        self.capture = cv2.VideoCapture(device_id)
+        self.capture.set(cv2.CAP_PROP_FRAME_WIDTH, width)
+        self.capture.set(cv2.CAP_PROP_FRAME_HEIGHT, height)
+        self.width = int(self.capture.get(cv2.CAP_PROP_FRAME_WIDTH))
+        self.height = int(self.capture.get(cv2.CAP_PROP_FRAME_HEIGHT))
+        # self.capture.set(cv2.CAP_PROP_BUFFERSIZE, 2)
+        self.success_reading, self.frame = self.capture.read()
+        self.read_lock = Lock()
+        self.thread = Thread(target=self.__update, args=())
+        self.thread.daemon = True
+        self.thread.start()
+
+    def __update(self):
+        while self.success_reading:
+            grabbed, frame = self.capture.read()
+            with self.read_lock:
+                self.success_reading = grabbed
+                self.frame = frame
+
+    def read(self):
+        with self.read_lock:
+            frame = self.frame.copy()
+        return frame
+    def __exit__(self, exec_type, exc_value, traceback):
+        self.capture.release()
+
+# An FPS tracker that computes an exponentially weighted moving average of the FPS
+class FPSTracker:
+    def __init__(self, ratio=0.5):
+        self._last_tick = None
+        self._avg_fps = None
+        self.ratio = ratio
+    def tick(self):
+        if self._last_tick is None:
+            self._last_tick = time.time()
+            return None
+        t_new = time.time()
+        fps_sample = 1.0 / (t_new - self._last_tick)
+        self._avg_fps = self.ratio * fps_sample + (1 - self.ratio) * self._avg_fps if self._avg_fps is not None else fps_sample
+        self._last_tick = t_new
+        return self.get()
+    def get(self):
+        return self._avg_fps
+
+# Wrapper for playing a stream with cv2.imshow(). It can accept an image and return keypress info for basic interactivity.
+# It also tracks FPS and optionally overlays info onto the stream.
+class Displayer:
+    def __init__(self, title, width=None, height=None, show_info=True):
+        self.title, self.width, self.height = title, width, height
+        self.show_info = show_info
+        self.fps_tracker = FPSTracker()
+        cv2.namedWindow(self.title, cv2.WINDOW_NORMAL)
+        if width is not None and height is not None:
+            cv2.resizeWindow(self.title, width, height)
+    # Update the currently showing frame and return key press char code
+    def step(self, image):
+        fps_estimate = self.fps_tracker.tick()
+        if self.show_info and fps_estimate is not None:
+            message = f"{int(fps_estimate)} fps | {self.width}x{self.height}"
+            cv2.putText(image, message, (10, 40), cv2.FONT_HERSHEY_SIMPLEX, 1.0, (0, 0, 0))
+        cv2.imshow(self.title, image)
+        return cv2.waitKey(1) & 0xFF
+
+
+# --------------- Main ---------------
+
+
+# Load model
+if args.model_type == 'mattingbase':
+    model = MattingBase(args.model_backbone)
+if args.model_type == 'mattingrefine':
+    model = MattingRefine(
+        args.model_backbone,
+        args.model_backbone_scale,
+        args.model_refine_mode,
+        args.model_refine_sample_pixels,
+        args.model_refine_threshold)
+
+model = model.cuda().eval()
+model.load_state_dict(torch.load(args.model_checkpoint), strict=False)
+
+
+width, height = args.resolution
+cam = Camera(width=width, height=height)
+dsp = Displayer('MattingV2', cam.width, cam.height, show_info=(not args.hide_fps))
+
+def cv2_frame_to_cuda(frame):
+    frame = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
+    return ToTensor()(Image.fromarray(frame)).unsqueeze_(0).cuda()
+
+with torch.no_grad():
+    while True:
+        bgr = None
+        while True: # grab bgr
+            frame = cam.read()
+            key = dsp.step(frame)
+            if key == ord('b'):
+                bgr = cv2_frame_to_cuda(cam.read())
+                break
+            elif key == ord('q'):
+                exit()
+        while True: # matting
+            frame = cam.read()
+            src = cv2_frame_to_cuda(frame)
+            pha, fgr = model(src, bgr)[:2]
+            res = pha * fgr + (1 - pha) * torch.ones_like(fgr)
+            res = res.mul(255).byte().cpu().permute(0, 2, 3, 1).numpy()[0]
+            res = cv2.cvtColor(res, cv2.COLOR_RGB2BGR)
+            key = dsp.step(res)
+            if key == ord('b'):
+                break
+            elif key == ord('q'):
+                exit()

+ 1 - 0
model/__init__.py

@@ -0,0 +1 @@
+from .model import Base, MattingBase, MattingRefine

+ 51 - 0
model/decoder.py

@@ -0,0 +1,51 @@
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+
+
+class Decoder(nn.Module):
+    """
+    Decoder upsamples the image by combining the feature maps at all resolutions from the encoder.
+    
+    Input:
+        x4: (B, C, H/16, W/16) feature map at 1/16 resolution.
+        x3: (B, C, H/8, W/8) feature map at 1/8 resolution.
+        x2: (B, C, H/4, W/4) feature map at 1/4 resolution.
+        x1: (B, C, H/2, W/2) feature map at 1/2 resolution.
+        x0: (B, C, H, W) feature map at full resolution.
+        
+    Output:
+        x: (B, C, H, W) upsampled output at full resolution.
+    """
+    
+    def __init__(self, channels, feature_channels):
+        super().__init__()
+        self.conv1 = nn.Conv2d(feature_channels[0] + channels[0], channels[1], 3, padding=1, bias=False)
+        self.bn1 = nn.BatchNorm2d(channels[1])
+        self.conv2 = nn.Conv2d(feature_channels[1] + channels[1], channels[2], 3, padding=1, bias=False)
+        self.bn2 = nn.BatchNorm2d(channels[2])
+        self.conv3 = nn.Conv2d(feature_channels[2] + channels[2], channels[3], 3, padding=1, bias=False)
+        self.bn3 = nn.BatchNorm2d(channels[3])
+        self.conv4 = nn.Conv2d(feature_channels[3] + channels[3], channels[4], 3, padding=1)
+        self.relu = nn.ReLU(True)
+
+    def forward(self, x4, x3, x2, x1, x0):
+        x = F.interpolate(x4, size=x3.shape[2:], mode='bilinear', align_corners=False)
+        x = torch.cat([x, x3], dim=1)
+        x = self.conv1(x)
+        x = self.bn1(x)
+        x = self.relu(x)
+        x = F.interpolate(x, size=x2.shape[2:], mode='bilinear', align_corners=False)
+        x = torch.cat([x, x2], dim=1)
+        x = self.conv2(x)
+        x = self.bn2(x)
+        x = self.relu(x)
+        x = F.interpolate(x, size=x1.shape[2:], mode='bilinear', align_corners=False)
+        x = torch.cat([x, x1], dim=1)
+        x = self.conv3(x)
+        x = self.bn3(x)
+        x = self.relu(x)
+        x = F.interpolate(x, size=x0.shape[2:], mode='bilinear', align_corners=False)
+        x = torch.cat([x, x0], dim=1)
+        x = self.conv4(x)
+        return x

+ 56 - 0
model/mobilenet.py

@@ -0,0 +1,56 @@
+from torch import nn
+from torchvision.models import MobileNetV2
+
+
+class MobileNetV2Encoder(MobileNetV2):
+    """
+    MobileNetV2Encoder inherits from torchvision's official MobileNetV2. It is modified to
+    use dilation in the last block to maintain an output stride of 16, and the classifier
+    block originally used for classification is removed. The forward method additionally
+    returns the feature maps at all resolutions for the decoder's use.
+    """
+    
+    def __init__(self, in_channels, norm_layer=None):
+        super().__init__()
+        
+        # Replace first conv layer if in_channels doesn't match.
+        if in_channels != 3:
+            self.features[0][0] = nn.Conv2d(in_channels, 32, 3, 2, 1, bias=False)
+       
+        # Remove last block
+        self.features = self.features[:-1]
+        
+        # Change to use dilation to maintain output stride = 16
+        self.features[14].conv[1][0].stride = (1, 1)
+        for feature in self.features[15:]:
+            feature.conv[1][0].dilation = (2, 2)
+            feature.conv[1][0].padding = (2, 2)
+        
+        # Delete classifier
+        del self.classifier
+        
+    def forward(self, x):
+        x0 = x  # 1/1
+        x = self.features[0](x)
+        x = self.features[1](x)
+        x1 = x  # 1/2
+        x = self.features[2](x)
+        x = self.features[3](x)
+        x2 = x  # 1/4
+        x = self.features[4](x)
+        x = self.features[5](x)
+        x = self.features[6](x)
+        x3 = x  # 1/8
+        x = self.features[7](x)
+        x = self.features[8](x)
+        x = self.features[9](x)
+        x = self.features[10](x)
+        x = self.features[11](x)
+        x = self.features[12](x)
+        x = self.features[13](x)
+        x = self.features[14](x)
+        x = self.features[15](x)
+        x = self.features[16](x)
+        x = self.features[17](x)
+        x4 = x  # 1/16
+        return x4, x3, x2, x1, x0

+ 196 - 0
model/model.py

@@ -0,0 +1,196 @@
+import torch
+from torch import nn
+from torch.nn import functional as F
+from torchvision.models.segmentation.deeplabv3 import ASPP
+
+from .decoder import Decoder
+from .mobilenet import MobileNetV2Encoder
+from .refiner import Refiner
+from .resnet import ResNetEncoder
+from .utils import load_matched_state_dict
+
+
+class Base(nn.Module):
+    """
+    A generic implementation of the base encoder-decoder network inspired by DeepLab.
+    Accepts arbitrary channels for input and output.
+    """
+    
+    def __init__(self, backbone: str, in_channels: int, out_channels: int):
+        super().__init__()
+        assert backbone in ["resnet50", "resnet101", "mobilenetv2"]
+        if backbone in ['resnet50', 'resnet101']:
+            self.backbone = ResNetEncoder(in_channels, variant=backbone)
+            self.aspp = ASPP(2048, [3, 6, 9])
+            self.decoder = Decoder([256, 128, 64, 48, out_channels], [512, 256, 64, in_channels])
+        else:
+            self.backbone = MobileNetV2Encoder(in_channels)
+            self.aspp = ASPP(320, [3, 6, 9])
+            self.decoder = Decoder([256, 128, 64, 48, out_channels], [32, 24, 16, in_channels])
+
+    def forward(self, x):
+        x, *shortcuts = self.backbone(x)
+        x = self.aspp(x)
+        x = self.decoder(x, *shortcuts)
+        return x
+    
+    def load_pretrained_deeplabv3_state_dict(self, state_dict, print_stats=True):
+        # Pretrained DeepLabV3 models are provided by <https://github.com/VainF/DeepLabV3Plus-Pytorch>.
+        # This method converts and loads their pretrained state_dict to match our model structure.
+        # It is not needed unless you plan to train from DeepLab weights.
+        # Use load_state_dict() for normal weight loading.
+        
+        # Convert state_dict naming for aspp module
+        state_dict = {k.replace('classifier.classifier.0', 'aspp'): v for k, v in state_dict.items()}
+
+        if isinstance(self.backbone, ResNetEncoder):
+            # ResNet backbone does not need change.
+            load_matched_state_dict(self, state_dict, print_stats)
+        else:
+            # Change MobileNetV2 backbone to state_dict format, then change back after loading.
+            backbone_features = self.backbone.features
+            self.backbone.low_level_features = backbone_features[:4]
+            self.backbone.high_level_features = backbone_features[4:]
+            del self.backbone.features
+            load_matched_state_dict(self, state_dict, print_stats)
+            self.backbone.features = backbone_features
+            del self.backbone.low_level_features
+            del self.backbone.high_level_features
+
+
+class MattingBase(Base):
+    """
+    MattingBase is used to produce coarse global results at a lower resolution.
+    MattingBase extends Base.
+    
+    Args:
+        backbone: ["resnet50", "resnet101", "mobilenetv2"]
+        
+    Input:
+        src: (B, 3, H, W) the source image. Channels are RGB values normalized to 0 ~ 1.
+        bgr: (B, 3, H, W) the background image. Channels are RGB values normalized to 0 ~ 1.
+    
+    Output:
+        pha: (B, 1, H, W) the alpha prediction. Normalized to 0 ~ 1.
+        fgr: (B, 3, H, W) the foreground prediction. Channels are RGB values normalized to 0 ~ 1.
+        err: (B, 1, H, W) the error prediction. Normalized to 0 ~ 1.
+        hid: (B, 32, H, W) the hidden encoding. Used for connecting refiner module.
+        
+    Example:
+        model = MattingBase(backbone='resnet50')
+        
+        pha, fgr, err, hid = model(src, bgr)    # for training
+        pha, fgr = model(src, bgr)[:2]          # for inference
+    """
+    
+    def __init__(self, backbone: str):
+        super().__init__(backbone, in_channels=6, out_channels=(1 + 3 + 1 + 32))
+        
+    def forward(self, src, bgr):
+        x = torch.cat([src, bgr], dim=1)
+        x, *shortcuts = self.backbone(x)
+        x = self.aspp(x)
+        x = self.decoder(x, *shortcuts)
+        pha = x[:, 0:1].clamp_(0., 1.)
+        fgr = x[:, 1:4].add(src).clamp_(0., 1.)
+        err = x[:, 4:5].clamp_(0., 1.)
+        hid = x[:, 5: ].relu_()
+        return pha, fgr, err, hid
+
+
+class MattingRefine(MattingBase):
+    """
+    MattingRefine includes the refiner module to upsample the coarse result to full resolution.
+    MattingRefine extends MattingBase.
+    
+    Args:
+        backbone: ["resnet50", "resnet101", "mobilenetv2"]
+        backbone_scale: The downsampling scale applied to the input before it is passed through the
+                        backbone, default 1/4 (0.25). Must not be greater than 1/2.
+        refine_mode: refine area selection mode. Options:
+            "full"         - No area selection, refine everywhere using regular Conv2d.
+            "sampling"     - Refine fixed amount of pixels ranked by the top most errors.
+            "thresholding" - Refine varying amount of pixels that has more error than the threshold.
+        refine_sample_pixels: number of pixels to refine. Only used when mode == "sampling".
+        refine_threshold: error threshold ranged from 0 ~ 1. Refine where err > threshold. Only used when mode == "thresholding".
+        refine_kernel_size: the refiner's convolutional kernel size. Options: [1, 3]
+        refine_prevent_oversampling: prevent sampling more pixels than needed for sampling mode. Set False only for speedtest.
+
+    Input:
+        src: (B, 3, H, W) the source image. Channels are RGB values normalized to 0 ~ 1.
+        bgr: (B, 3, H, W) the background image. Channels are RGB values normalized to 0 ~ 1.
+    
+    Output:
+        pha: (B, 1, H, W) the alpha prediction. Normalized to 0 ~ 1.
+        fgr: (B, 3, H, W) the foreground prediction. Channels are RGB values normalized to 0 ~ 1.
+        pha_sm: (B, 1, Hc, Wc) the coarse alpha prediction from matting base. Normalized to 0 ~ 1.
+        fgr_sm: (B, 3, Hc, Wc) the coarse foreground prediction from matting base. Normalized to 0 ~ 1.
+        err_sm: (B, 1, Hc, Wc) the coarse error prediction from matting base. Normalized to 0 ~ 1.
+        ref_sm: (B, 1, H/4, W/4) the quarter resolution refinement map. 1 indicates refined 4x4 patch locations.
+        
+    Example:
+        model = MattingRefine(backbone='resnet50', backbone_scale=1/4, refine_mode='sampling', refine_sample_pixels=80_000)
+        model = MattingRefine(backbone='resnet50', backbone_scale=1/4, refine_mode='thresholding', refine_threshold=0.1)
+        model = MattingRefine(backbone='resnet50', backbone_scale=1/4, refine_mode='full')
+        
+        pha, fgr, pha_sm, fgr_sm, err_sm, ref_sm = model(src, bgr)   # for training
+        pha, fgr = model(src, bgr)[:2]                               # for inference
+    """
+    
+    def __init__(self,
+                 backbone: str,
+                 backbone_scale: float = 1/4,
+                 refine_mode: str = 'sampling',
+                 refine_sample_pixels: int = 80_000,
+                 refine_threshold: float = 0.1,
+                 refine_kernel_size: int = 3,
+                 refine_prevent_oversampling: bool = True,
+                 refine_patch_crop_method: str = 'unfold',
+                 refine_patch_replace_method: str = 'scatter_nd'):
+        assert backbone_scale <= 1/2, 'backbone_scale should not be greater than 1/2'
+        super().__init__(backbone)
+        self.backbone_scale = backbone_scale
+        self.refiner = Refiner(refine_mode,
+                               refine_sample_pixels,
+                               refine_threshold,
+                               refine_kernel_size,
+                               refine_prevent_oversampling,
+                               refine_patch_crop_method,
+                               refine_patch_replace_method)
+    
+    def forward(self, src, bgr):
+        assert src.size() == bgr.size(), 'src and bgr must have the same shape'
+        assert src.size(2) // 4 * 4 == src.size(2) and src.size(3) // 4 * 4 == src.size(3), \
+            'src and bgr must have width and height that are divisible by 4'
+        
+        # Downsample src and bgr for backbone
+        src_sm = F.interpolate(src,
+                               scale_factor=self.backbone_scale,
+                               mode='bilinear',
+                               align_corners=False,
+                               recompute_scale_factor=True)
+        bgr_sm = F.interpolate(bgr,
+                               scale_factor=self.backbone_scale,
+                               mode='bilinear',
+                               align_corners=False,
+                               recompute_scale_factor=True)
+        
+        # Base
+        x = torch.cat([src_sm, bgr_sm], dim=1)
+        x, *shortcuts = self.backbone(x)
+        x = self.aspp(x)
+        x = self.decoder(x, *shortcuts)
+        pha_sm = x[:, 0:1].clamp_(0., 1.)
+        fgr_sm = x[:, 1:4]
+        err_sm = x[:, 4:5].clamp_(0., 1.)
+        hid_sm = x[:, 5: ].relu_()
+
+        # Refiner
+        pha, fgr, ref_sm = self.refiner(src, bgr, pha_sm, fgr_sm, err_sm, hid_sm)
+        
+        # Clamp outputs
+        pha = pha.clamp_(0., 1.)
+        fgr = fgr.add_(src).clamp_(0., 1.)
+        fgr_sm = src_sm.add_(fgr_sm).clamp_(0., 1.)
+        
+        return pha, fgr, pha_sm, fgr_sm, err_sm, ref_sm
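
A hedged end-to-end sketch of how `MattingRefine` is typically driven at inference time follows; the random tensors stand in for real frames, and the commented-out checkpoint path is a placeholder for the weights linked in the README, not a file shipped with the repository:

```python
import torch

from model.model import MattingRefine

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model = MattingRefine(backbone='resnet50',
                      backbone_scale=1/4,
                      refine_mode='sampling',
                      refine_sample_pixels=80_000).eval().to(device)
# model.load_state_dict(torch.load('path/to/downloaded_weights.pth', map_location=device))

with torch.no_grad():
    # Height and width must both be divisible by 4.
    src = torch.rand(1, 3, 1080, 1920, device=device)
    bgr = torch.rand(1, 3, 1080, 1920, device=device)
    pha, fgr = model(src, bgr)[:2]

print(pha.shape, fgr.shape)  # both returned at the full 1080 x 1920 resolution
```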

+ 250 - 0
model/refiner.py

@@ -0,0 +1,250 @@
+import torch
+import torchvision
+from torch import nn
+from torch.nn import functional as F
+from typing import Tuple
+
+
+class Refiner(nn.Module):
+    """
+    Refiner refines the coarse output to full resolution.
+    
+    Args:
+        mode: area selection mode. Options:
+            "full"         - No area selection, refine everywhere using regular Conv2d.
+            "sampling"     - Refine fixed amount of pixels ranked by the top most errors.
+            "thresholding" - Refine varying amount of pixels that have greater error than the threshold.
+        sample_pixels: number of pixels to refine. Only used when mode == "sampling".
+        threshold: error threshold ranged from 0 ~ 1. Refine where err > threshold. Only used when mode == "thresholding".
+        kernel_size: The convolution kernel_size. Options: [1, 3]
+        prevent_oversampling: True for regular cases, False for speedtest.
+    
+    Compatibility Args:
+        patch_crop_method: the operation for cropping patches. Options:
+            "unfold"           - Best performance for PyTorch and TorchScript.
+            "roi_align"        - Best compatibility for ONNX export.
+        patch_replace_method: the operation for replacing patches. Options:
+            "scatter_nd"       - Best performance for PyTorch and TorchScript.
+            "scatter_element"  - Best compatibility for ONNX export.
+        
+    Input:
+        src: (B, 3, H, W) full resolution source image.
+        bgr: (B, 3, H, W) full resolution background image.
+        pha: (B, 1, Hc, Wc) coarse alpha prediction.
+        fgr: (B, 3, Hc, Wc) coarse foreground residual prediction.
+        err: (B, 1, Hc, Wc) coarse error prediction.
+        hid: (B, 32, Hc, Wc) coarse hidden encoding.
+        
+    Output:
+        pha: (B, 1, H, W) full resolution alpha prediction.
+        fgr: (B, 3, H, W) full resolution foreground residual prediction.
+        ref: (B, 1, H/4, W/4) quarter resolution refinement selection map. 1 indicates refined 4x4 patch locations.
+    """
+    
+    # For TorchScript export optimization.
+    __constants__ = ['kernel_size', 'patch_crop_method', 'patch_replace_method']
+    
+    def __init__(self,
+                 mode: str,
+                 sample_pixels: int,
+                 threshold: float,
+                 kernel_size: int = 3,
+                 prevent_oversampling: bool = True,
+                 patch_crop_method: str = 'unfold',
+                 patch_replace_method: str = 'scatter_nd'):
+        super().__init__()
+        assert mode in ['full', 'sampling', 'thresholding']
+        assert kernel_size in [1, 3]
+        assert patch_crop_method in ['unfold', 'roi_align']
+        assert patch_replace_method in ['scatter_nd', 'scatter_element']
+        
+        self.mode = mode
+        self.sample_pixels = sample_pixels
+        self.threshold = threshold
+        self.kernel_size = kernel_size
+        self.prevent_oversampling = prevent_oversampling
+        self.patch_crop_method = patch_crop_method
+        self.patch_replace_method = patch_replace_method
+
+        channels = [32, 24, 16, 12, 4]
+        self.conv1 = nn.Conv2d(channels[0] + 6 + 4, channels[1], kernel_size, bias=False)
+        self.bn1 = nn.BatchNorm2d(channels[1])
+        self.conv2 = nn.Conv2d(channels[1], channels[2], kernel_size, bias=False)
+        self.bn2 = nn.BatchNorm2d(channels[2])
+        self.conv3 = nn.Conv2d(channels[2] + 6, channels[3], kernel_size, bias=False)
+        self.bn3 = nn.BatchNorm2d(channels[3])
+        self.conv4 = nn.Conv2d(channels[3], channels[4], kernel_size, bias=True)
+        self.relu = nn.ReLU(True)
+    
+    def forward(self,
+                src: torch.Tensor,
+                bgr: torch.Tensor,
+                pha: torch.Tensor,
+                fgr: torch.Tensor,
+                err: torch.Tensor,
+                hid: torch.Tensor):
+        H_full, W_full = src.shape[2:]
+        H_half, W_half = H_full // 2, W_full // 2
+        H_quat, W_quat = H_full // 4, W_full // 4
+        
+        src_bgr = torch.cat([src, bgr], dim=1)
+        
+        if self.mode != 'full':
+            err = F.interpolate(err, (H_quat, W_quat), mode='bilinear', align_corners=False)
+            ref = self.select_refinement_regions(err)
+            idx = torch.nonzero(ref.squeeze(1))
+            idx = idx[:, 0], idx[:, 1], idx[:, 2]
+            
+            if idx[0].size(0) > 0:
+                x = torch.cat([hid, pha, fgr], dim=1)
+                x = F.interpolate(x, (H_half, W_half), mode='bilinear', align_corners=False)
+                x = self.crop_patch(x, idx, 2, 3 if self.kernel_size == 3 else 0)
+
+                y = F.interpolate(src_bgr, (H_half, W_half), mode='bilinear', align_corners=False)
+                y = self.crop_patch(y, idx, 2, 3 if self.kernel_size == 3 else 0)
+
+                x = self.conv1(torch.cat([x, y], dim=1))
+                x = self.bn1(x)
+                x = self.relu(x)
+                x = self.conv2(x)
+                x = self.bn2(x)
+                x = self.relu(x)
+
+                x = F.interpolate(x, 8 if self.kernel_size == 3 else 4, mode='nearest')
+                y = self.crop_patch(src_bgr, idx, 4, 2 if self.kernel_size == 3 else 0)
+
+                x = self.conv3(torch.cat([x, y], dim=1))
+                x = self.bn3(x)
+                x = self.relu(x)
+                x = self.conv4(x)
+                
+                out = torch.cat([pha, fgr], dim=1)
+                out = F.interpolate(out, (H_full, W_full), mode='bilinear', align_corners=False)
+                out = self.replace_patch(out, x, idx)
+                pha = out[:, :1]
+                fgr = out[:, 1:]
+            else:
+                pha = F.interpolate(pha, (H_full, W_full), mode='bilinear', align_corners=False)
+                fgr = F.interpolate(fgr, (H_full, W_full), mode='bilinear', align_corners=False)
+        else:
+            x = torch.cat([hid, pha, fgr], dim=1)
+            x = F.interpolate(x, (H_half, W_half), mode='bilinear', align_corners=False)
+            y = F.interpolate(src_bgr, (H_half, W_half), mode='bilinear', align_corners=False)
+            if self.kernel_size == 3:
+                x = F.pad(x, (3, 3, 3, 3))
+                y = F.pad(y, (3, 3, 3, 3))
+
+            x = self.conv1(torch.cat([x, y], dim=1))
+            x = self.bn1(x)
+            x = self.relu(x)
+            x = self.conv2(x)
+            x = self.bn2(x)
+            x = self.relu(x)
+            
+            if self.kernel_size == 3:
+                x = F.interpolate(x, (H_full + 4, W_full + 4))
+                y = F.pad(src_bgr, (2, 2, 2, 2))
+            else:
+                x = F.interpolate(x, (H_full, W_full), mode='nearest')
+                y = src_bgr
+            
+            x = self.conv3(torch.cat([x, y], dim=1))
+            x = self.bn3(x)
+            x = self.relu(x)
+            x = self.conv4(x)
+            
+            pha = x[:, :1]
+            fgr = x[:, 1:]
+            ref = torch.ones((src.size(0), 1, H_quat, W_quat), device=src.device, dtype=src.dtype)
+            
+        return pha, fgr, ref
+    
+    def crop_patch(self,
+                   x: torch.Tensor,
+                   idx: Tuple[torch.Tensor, torch.Tensor, torch.Tensor],
+                   size: int,
+                   padding: int):
+        """
+        Crops selected patches from image given indices.
+        
+        Inputs:
+            x: image (B, C, H, W).
+            idx: selection indices Tuple[(P,), (P,), (P,)], where the 3 values are the (B, H, W) indices.
+            size: center size of the patch, also stride of the crop.
+            padding: expansion size of the patch.
+        Output:
+            patch: (P, C, h, w), where h = w = size + 2 * padding.
+        """
+        if padding != 0:
+            x = F.pad(x, (padding,) * 4)
+        
+        if self.patch_crop_method == 'unfold':
+            # Use unfold. Best performance for PyTorch and TorchScript.
+            return x.permute(0, 2, 3, 1) \
+                    .unfold(1, size + 2 * padding, size) \
+                    .unfold(2, size + 2 * padding, size)[idx[0], idx[1], idx[2]]
+        else:
+            # Use roi_align. Best compatibility for ONNX.
+            idx = idx[0].type_as(x), idx[1].type_as(x), idx[2].type_as(x)
+            b = idx[0]
+            x1 = idx[2] * size - 0.5
+            y1 = idx[1] * size - 0.5
+            x2 = idx[2] * size + size + 2 * padding - 0.5
+            y2 = idx[1] * size + size + 2 * padding - 0.5
+            boxes = torch.stack([b, x1, y1, x2, y2], dim=1)
+            return torchvision.ops.roi_align(x, boxes, size + 2 * padding, sampling_ratio=1)    
+    
+    def replace_patch(self,
+                      x: torch.Tensor,
+                      y: torch.Tensor,
+                      idx: Tuple[torch.Tensor, torch.Tensor, torch.Tensor]):
+        """
+        Replaces patches back into image given index.
+        
+        Inputs:
+            x: image (B, C, H, W)
+            y: patches (P, C, h, w)
+            idx: selection indices Tuple[(P,), (P,), (P,)], where the 3 values are the (B, H, W) indices.
+        
+        Output:
+            image: (B, C, H, W), where patches at idx locations are replaced with y.
+        """
+        xB, xC, xH, xW = x.shape
+        yB, yC, yH, yW = y.shape
+        if self.patch_replace_method == 'scatter_nd':
+            # Use scatter_nd. Best performance for PyTorch and TorchScript. Replacing patch by patch.
+            x = x.view(xB, xC, xH // yH, yH, xW // yW, yW).permute(0, 2, 4, 1, 3, 5)
+            x[idx[0], idx[1], idx[2]] = y
+            x = x.permute(0, 3, 1, 4, 2, 5).view(xB, xC, xH, xW)
+            return x
+        else:
+            # Use scatter_element. Best compatibility for ONNX. Replacing pixel by pixel.
+            iH, iW = xH // yH, xW // yW
+            i = self.crop_patch(torch.arange(0, xB * xC * xH * xW).view(xB, xC, xH, xW).type_as(x), idx, 4, 0)
+            i, x, y = i.view(-1), x.view(-1), y.view(-1)
+            x.scatter_(0, i.long(), y)
+            x = x.view(xB, xC, xH, xW)
+            return x
+
+    def select_refinement_regions(self, err: torch.Tensor):
+        """
+        Select refinement regions.
+        Input:
+            err: error map (B, 1, H, W)
+        Output:
+            ref: refinement regions (B, 1, H, W). FloatTensor. 1 is selected, 0 is not.
+        """
+        if self.mode == 'sampling':
+            # Sampling mode.
+            b, _, h, w = err.shape
+            err = err.view(b, -1)
+            idx = err.topk(self.sample_pixels // 16, dim=1, sorted=False).indices
+            ref = torch.zeros_like(err)
+            ref.scatter_(1, idx, 1.)
+            if self.prevent_oversampling:
+                ref.mul_(err.gt(0).float())
+            ref = ref.view(b, 1, h, w)
+        else:
+            # Thresholding mode.
+            ref = err.gt(self.threshold).float()
+        return ref
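
The two selection modes can be compared directly on a dummy error map; this is an illustrative sketch only, with arbitrary values chosen for `sample_pixels` and `threshold`:

```python
import torch

from model.refiner import Refiner

err = torch.rand(1, 1, 64, 64)  # stand-in for the coarse error prediction

sampler = Refiner('sampling', sample_pixels=1600, threshold=0.0)
thresholder = Refiner('thresholding', sample_pixels=0, threshold=0.9)

ref_sampling = sampler.select_refinement_regions(err)
ref_threshold = thresholder.select_refinement_regions(err)

# Sampling selects sample_pixels / 16 patches (here 100, since every error value is > 0);
# thresholding selects however many locations exceed the threshold.
print(int(ref_sampling.sum()), int(ref_threshold.sum()))
```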

+ 48 - 0
model/resnet.py

@@ -0,0 +1,48 @@
+from torch import nn
+from torchvision.models.resnet import ResNet, Bottleneck
+
+
+class ResNetEncoder(ResNet):
+    """
+    ResNetEncoder inherits from torchvision's official ResNet. It is modified to
+    use dilation on the last block to maintain an output stride of 16, and the global
+    average pooling layer and fully connected layer originally used for classification
+    have been removed. The forward method additionally returns the feature maps at all
+    resolutions for the decoder's use.
+    """
+    
+    layers = {
+        'resnet50':  [3, 4, 6, 3],
+        'resnet101': [3, 4, 23, 3],
+    }
+    
+    def __init__(self, in_channels, variant='resnet101', norm_layer=None):
+        super().__init__(
+            block=Bottleneck,
+            layers=self.layers[variant],
+            replace_stride_with_dilation=[False, False, True],
+            norm_layer=norm_layer)
+        
+        # Replace first conv layer if in_channels doesn't match.
+        if in_channels != 3:
+            self.conv1 = nn.Conv2d(in_channels, 64, 7, 2, 3, bias=False)
+            
+        # Delete fully-connected layer
+        del self.avgpool
+        del self.fc
+    
+    def forward(self, x):
+        x0 = x  # 1/1
+        x = self.conv1(x)
+        x = self.bn1(x)
+        x = self.relu(x)
+        x1 = x  # 1/2
+        x = self.maxpool(x)
+        x = self.layer1(x)
+        x2 = x  # 1/4
+        x = self.layer2(x)
+        x3 = x  # 1/8
+        x = self.layer3(x)
+        x = self.layer4(x)
+        x4 = x  # 1/16
+        return x4, x3, x2, x1, x0
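
A quick shape check (illustrative only) confirms the dilated last block keeps the deepest feature map at 1/16 of the input resolution:

```python
import torch

from model.resnet import ResNetEncoder

encoder = ResNetEncoder(in_channels=6, variant='resnet50')
x4, x3, x2, x1, x0 = encoder(torch.randn(1, 6, 256, 256))
print(x4.shape)  # torch.Size([1, 2048, 16, 16]), i.e. stride 16 despite the dilated layer4
```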

+ 14 - 0
model/utils.py

@@ -0,0 +1,14 @@
+def load_matched_state_dict(model, state_dict, print_stats=True):
+    """
+    Only loads weights whose key and shape match. Other weights are ignored.
+    """
+    num_matched, num_total = 0, 0
+    curr_state_dict = model.state_dict()
+    for key in curr_state_dict.keys():
+        num_total += 1
+        if key in state_dict and curr_state_dict[key].shape == state_dict[key].shape:
+            curr_state_dict[key] = state_dict[key]
+            num_matched += 1
+    model.load_state_dict(curr_state_dict)
+    if print_stats:
+        print(f'Loaded state_dict: {num_matched}/{num_total} matched')
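
As a small illustration (not part of the repository), the helper can transfer overlapping weights between two models whose heads differ:

```python
from torchvision.models import resnet50

from model.utils import load_matched_state_dict

src = resnet50()                  # 1000-class head
dst = resnet50(num_classes=10)    # different fc shape, everything else identical
load_matched_state_dict(dst, src.state_dict())
# Prints the matched/total count; fc.weight and fc.bias are skipped because their shapes differ.
```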

+ 7 - 0
requirements.txt

@@ -0,0 +1,7 @@
+kornia==0.4.1
+tensorboard==2.3.0
+torch==1.7.0
+torchvision==0.8.1
+tqdm==4.51.0
+opencv-python==4.4.0.44
+onnxruntime==1.6.0

Some files were not shown because of the large number of changes