Browse Source

Publish train scripts

Peter Lin 4 years ago
parent
commit
d993eaac3e
5 changed files with 646 additions and 15 deletions
  1. 4 2
      README.md
  2. 68 0
      data_path.py
  3. 0 13
      dataset/augmentation.py
  4. 265 0
      train_base.py
  5. 309 0
      train_refine.py

+ 4 - 2
README.md

@@ -28,6 +28,7 @@ Official repository for the paper [Real-Time High-Resolution Background Matting]
 
 ## Updates
 
+* [Mar 06 2021] Training script is published.
 * [Feb 28 2021] Paper is accepted to CVPR 2021.
 * [Jan 09 2021] PhotoMatte85 dataset is now published.
 * [Dec 21 2020] We updated our project to MIT License, which permits commercial use.
@@ -48,8 +49,9 @@ Official repository for the paper [Real-Time High-Resolution Background Matting]
 
 ### Datasets
 
-* VideoMatte240K (Coming soon)
 * [PhotoMatte85](https://drive.google.com/file/d/1KpHKYW986Dax9-ZIM7I-HyBoWVcLPuaQ/view?usp=sharing)
+* VideoMatte240K (We are still dealing with licensing. In the meantime, you can visit [storyblocks.com](https://www.storyblocks.com/video/search/green+screen+human?max_duration=10000&sort=most_relevant&video_quality=HD) to download raw green screen videos and recreate the dataset yourself.)
+
 
  
 
@@ -85,7 +87,7 @@ You can run our model using **PyTorch**, **TorchScript**, **TensorFlow**, and **
 
 ## Training
 
-Training code will be released upon acceptance of the paper.
+Configure `data_path.pth` to point to your dataset. The original paper uses `train_base.pth` to train only the base model till convergence then use `train_refine.pth` to train the entire network end-to-end. More details are specified in the paper.
 
  
 

+ 68 - 0
data_path.py

@@ -0,0 +1,68 @@
+"""
+This file records the directory paths to the different datasets.
+You will need to configure it for training the model.
+
+All datasets follows the following format, where fgr and pha points to directory that contains jpg or png.
+Inside the directory could be any nested formats, but fgr and pha structure must match. You can add your own
+dataset to the list as long as it follows the format. 'fgr' should point to foreground images with RGB channels,
+'pha' should point to alpha images with only 1 grey channel.
+{
+    'YOUR_DATASET': {
+        'train': {
+            'fgr': 'PATH_TO_IMAGES_DIR',
+            'pha': 'PATH_TO_IMAGES_DIR',
+        },
+        'valid': {
+            'fgr': 'PATH_TO_IMAGES_DIR',
+            'pha': 'PATH_TO_IMAGES_DIR',
+        }
+    }
+}
+"""
+
+DATA_PATH = {
+    'videomatte240k': {
+        'train': {
+            'fgr': 'PATH_TO_IMAGES_DIR',
+            'pha': 'PATH_TO_IMAGES_DIR'
+        },
+        'valid': {
+            'fgr': 'PATH_TO_IMAGES_DIR',
+            'pha': 'PATH_TO_IMAGES_DIR'
+        }
+    },
+    'photomatte13k': {
+        'train': {
+            'fgr': 'PATH_TO_IMAGES_DIR',
+            'pha': 'PATH_TO_IMAGES_DIR'
+        },
+        'valid': {
+            'fgr': 'PATH_TO_IMAGES_DIR',
+            'pha': 'PATH_TO_IMAGES_DIR'
+        }
+    },
+    'distinction': {
+        'train': {
+            'fgr': 'PATH_TO_IMAGES_DIR',
+            'pha': 'PATH_TO_IMAGES_DIR',
+        },
+        'valid': {
+            'fgr': 'PATH_TO_IMAGES_DIR',
+            'pha': 'PATH_TO_IMAGES_DIR'
+        },
+    },
+    'adobe': {
+        'train': {
+            'fgr': 'PATH_TO_IMAGES_DIR',
+            'pha': 'PATH_TO_IMAGES_DIR',
+        },
+        'valid': {
+            'fgr': 'PATH_TO_IMAGES_DIR',
+            'pha': 'PATH_TO_IMAGES_DIR'
+        },
+    },
+    'backgrounds': {
+        'train': 'PATH_TO_IMAGES_DIR',
+        'valid': 'PATH_TO_IMAGES_DIR'
+    },
+}

+ 0 - 13
dataset/augmentation.py

@@ -52,19 +52,6 @@ class PairRandomAffine(T.RandomAffine):
         return [F.affine(xi, *param, resamples[i], self.fillcolor) for i, xi in enumerate(x)]
 
 
-class PairRandomResizedCrop(T.RandomResizedCrop):
-    def __init__(self, size, scale=(0.08, 1.0), ratio=(3. / 4., 4. / 3.), interpolations=None):
-        super().__init__(size, scale, ratio, Image.BILINEAR)
-        self.interpolations = interpolations
-    
-    def __call__(self, *x):
-        if not len(x):
-            return []
-        i, j, h, w = self.get_params(x[0], self.scale, self.ratio)
-        interpolations = self.interpolations or [self.interpolation] * len(x)
-        return [F.resized_crop(xi, i, j, h, w, self.size, interpolations[i]) for i, xi in enumerate(x)]
-    
-
 class PairRandomHorizontalFlip(T.RandomHorizontalFlip):
     def __call__(self, *x):
         if torch.rand(1) < self.p:

+ 265 - 0
train_base.py

@@ -0,0 +1,265 @@
+"""
+Train MattingBase
+
+You can download pretrained DeepLabV3 weights from <https://github.com/VainF/DeepLabV3Plus-Pytorch>
+
+Example:
+
+    CUDA_VISIBLE_DEVICES=0 python train_base.py \
+        --dataset-name videomatte240k \
+        --model-backbone resnet50 \
+        --model-name mattingbase-resnet50-videomatte240k \
+        --model-pretrain-initialization "pretraining/best_deeplabv3_resnet50_voc_os16.pth" \
+        --epoch-end 8
+
+"""
+
+import argparse
+import kornia
+import torch
+import os
+import random
+
+from torch import nn
+from torch.nn import functional as F
+from torch.cuda.amp import autocast, GradScaler
+from torch.utils.tensorboard import SummaryWriter
+from torch.utils.data import DataLoader
+from torch.optim import Adam
+from torchvision.utils import make_grid
+from tqdm import tqdm
+from torchvision import transforms as T
+from PIL import Image
+
+from data_path import DATA_PATH
+from dataset import ImagesDataset, ZipDataset, VideoDataset, SampleDataset
+from dataset import augmentation as A
+from model import MattingBase
+from model.utils import load_matched_state_dict
+
+
+# --------------- Arguments ---------------
+
+
+parser = argparse.ArgumentParser()
+
+parser.add_argument('--dataset-name', type=str, required=True, choices=DATA_PATH.keys())
+
+parser.add_argument('--model-backbone', type=str, required=True, choices=['resnet101', 'resnet50', 'mobilenetv2'])
+parser.add_argument('--model-name', type=str, required=True)
+parser.add_argument('--model-pretrain-initialization', type=str, default=None)
+parser.add_argument('--model-last-checkpoint', type=str, default=None)
+
+parser.add_argument('--batch-size', type=int, default=8)
+parser.add_argument('--num-workers', type=int, default=16)
+parser.add_argument('--epoch-start', type=int, default=0)
+parser.add_argument('--epoch-end', type=int, required=True)
+
+parser.add_argument('--log-train-loss-interval', type=int, default=10)
+parser.add_argument('--log-train-images-interval', type=int, default=2000)
+parser.add_argument('--log-valid-interval', type=int, default=5000)
+
+parser.add_argument('--checkpoint-interval', type=int, default=5000)
+
+args = parser.parse_args()
+
+
+# --------------- Loading ---------------
+
+
+def train():
+    
+    # Training DataLoader
+    dataset_train = ZipDataset([
+        ZipDataset([
+            ImagesDataset(DATA_PATH[args.dataset_name]['train']['pha'], mode='L'),
+            ImagesDataset(DATA_PATH[args.dataset_name]['train']['fgr'], mode='RGB'),
+        ], transforms=A.PairCompose([
+            A.PairRandomAffineAndResize((512, 512), degrees=(-5, 5), translate=(0.1, 0.1), scale=(0.4, 1), shear=(-5, 5)),
+            A.PairRandomHorizontalFlip(),
+            A.PairRandomBoxBlur(0.1, 5),
+            A.PairRandomSharpen(0.1),
+            A.PairApplyOnlyAtIndices([1], T.ColorJitter(0.15, 0.15, 0.15, 0.05)),
+            A.PairApply(T.ToTensor())
+        ]), assert_equal_length=True),
+        ImagesDataset(DATA_PATH['backgrounds']['train'], mode='RGB', transforms=T.Compose([
+            A.RandomAffineAndResize((512, 512), degrees=(-5, 5), translate=(0.1, 0.1), scale=(1, 2), shear=(-5, 5)),
+            T.RandomHorizontalFlip(),
+            A.RandomBoxBlur(0.1, 5),
+            A.RandomSharpen(0.1),
+            T.ColorJitter(0.15, 0.15, 0.15, 0.05),
+            T.ToTensor()
+        ])),
+    ])
+    dataloader_train = DataLoader(dataset_train,
+                                  shuffle=True,
+                                  batch_size=args.batch_size,
+                                  num_workers=args.num_workers,
+                                  pin_memory=True)
+    
+    # Validation DataLoader
+    dataset_valid = ZipDataset([
+        ZipDataset([
+            ImagesDataset(DATA_PATH[args.dataset_name]['valid']['pha'], mode='L'),
+            ImagesDataset(DATA_PATH[args.dataset_name]['valid']['fgr'], mode='RGB')
+        ], transforms=A.PairCompose([
+            A.PairRandomAffineAndResize((512, 512), degrees=(-5, 5), translate=(0.1, 0.1), scale=(0.3, 1), shear=(-5, 5)),
+            A.PairApply(T.ToTensor())
+        ]), assert_equal_length=True),
+        ImagesDataset(DATA_PATH['backgrounds']['valid'], mode='RGB', transforms=T.Compose([
+            A.RandomAffineAndResize((512, 512), degrees=(-5, 5), translate=(0.1, 0.1), scale=(1, 1.2), shear=(-5, 5)),
+            T.ToTensor()
+        ])),
+    ])
+    dataset_valid = SampleDataset(dataset_valid, 50)
+    dataloader_valid = DataLoader(dataset_valid,
+                                  pin_memory=True,
+                                  batch_size=args.batch_size,
+                                  num_workers=args.num_workers)
+
+    # Model
+    model = MattingBase(args.model_backbone).cuda()
+
+    if args.model_last_checkpoint is not None:
+        load_matched_state_dict(model, torch.load(args.model_last_checkpoint))
+    elif args.model_pretrain_initialization is not None:
+        model.load_pretrained_deeplabv3_state_dict(torch.load(args.model_pretrain_initialization)['model_state'])
+
+    optimizer = Adam([
+        {'params': model.backbone.parameters(), 'lr': 1e-4},
+        {'params': model.aspp.parameters(), 'lr': 5e-4},
+        {'params': model.decoder.parameters(), 'lr': 5e-4}
+    ])
+    scaler = GradScaler()
+
+    # Logging and checkpoints
+    if not os.path.exists(f'checkpoint/{args.model_name}'):
+        os.makedirs(f'checkpoint/{args.model_name}')
+    writer = SummaryWriter(f'log/{args.model_name}')
+    
+    # Run loop
+    for epoch in range(args.epoch_start, args.epoch_end):
+        for i, ((true_pha, true_fgr), true_bgr) in enumerate(tqdm(dataloader_train)):
+            step = epoch * len(dataloader_train) + i
+
+            true_pha = true_pha.cuda(non_blocking=True)
+            true_fgr = true_fgr.cuda(non_blocking=True)
+            true_bgr = true_bgr.cuda(non_blocking=True)
+            true_pha, true_fgr, true_bgr = random_crop(true_pha, true_fgr, true_bgr)
+            
+            true_src = true_bgr.clone()
+            
+            # Augment with shadow
+            aug_shadow_idx = torch.rand(len(true_src)) < 0.3
+            if aug_shadow_idx.any():
+                aug_shadow = true_pha[aug_shadow_idx].mul(0.3 * random.random())
+                aug_shadow = T.RandomAffine(degrees=(-5, 5), translate=(0.2, 0.2), scale=(0.5, 1.5), shear=(-5, 5))(aug_shadow)
+                aug_shadow = kornia.filters.box_blur(aug_shadow, (random.choice(range(20, 40)),) * 2)
+                true_src[aug_shadow_idx] = true_src[aug_shadow_idx].sub_(aug_shadow).clamp_(0, 1)
+                del aug_shadow
+            del aug_shadow_idx
+            
+            # Composite foreground onto source
+            true_src = true_fgr * true_pha + true_src * (1 - true_pha)
+
+            # Augment with noise
+            aug_noise_idx = torch.rand(len(true_src)) < 0.4
+            if aug_noise_idx.any():
+                true_src[aug_noise_idx] = true_src[aug_noise_idx].add_(torch.randn_like(true_src[aug_noise_idx]).mul_(0.03 * random.random())).clamp_(0, 1)
+                true_bgr[aug_noise_idx] = true_bgr[aug_noise_idx].add_(torch.randn_like(true_bgr[aug_noise_idx]).mul_(0.03 * random.random())).clamp_(0, 1)
+            del aug_noise_idx
+            
+            # Augment background with jitter
+            aug_jitter_idx = torch.rand(len(true_src)) < 0.8
+            if aug_jitter_idx.any():
+                true_bgr[aug_jitter_idx] = kornia.augmentation.ColorJitter(0.18, 0.18, 0.18, 0.1)(true_bgr[aug_jitter_idx])
+            del aug_jitter_idx
+            
+            # Augment background with affine
+            aug_affine_idx = torch.rand(len(true_bgr)) < 0.3
+            if aug_affine_idx.any():
+                true_bgr[aug_affine_idx] = T.RandomAffine(degrees=(-1, 1), translate=(0.01, 0.01))(true_bgr[aug_affine_idx])
+            del aug_affine_idx
+
+            with autocast():
+                pred_pha, pred_fgr, pred_err = model(true_src, true_bgr)[:3]
+                loss = compute_loss(pred_pha, pred_fgr, pred_err, true_pha, true_fgr)
+
+            scaler.scale(loss).backward()
+            scaler.step(optimizer)
+            scaler.update()
+            optimizer.zero_grad()
+
+            if (i + 1) % args.log_train_loss_interval == 0:
+                writer.add_scalar('loss', loss, step)
+
+            if (i + 1) % args.log_train_images_interval == 0:
+                writer.add_image('train_pred_pha', make_grid(pred_pha, nrow=5), step)
+                writer.add_image('train_pred_fgr', make_grid(pred_fgr, nrow=5), step)
+                writer.add_image('train_pred_com', make_grid(pred_fgr * pred_pha, nrow=5), step)
+                writer.add_image('train_pred_err', make_grid(pred_err, nrow=5), step)
+                writer.add_image('train_true_src', make_grid(true_src, nrow=5), step)
+                writer.add_image('train_true_bgr', make_grid(true_bgr, nrow=5), step)
+                
+            del true_pha, true_fgr, true_bgr
+            del pred_pha, pred_fgr, pred_err
+
+            if (i + 1) % args.log_valid_interval == 0:
+                valid(model, dataloader_valid, writer, step)
+
+            if (step + 1) % args.checkpoint_interval == 0:
+                torch.save(model.state_dict(), f'checkpoint/{args.model_name}/epoch-{epoch}-iter-{step}.pth')
+
+        torch.save(model.state_dict(), f'checkpoint/{args.model_name}/epoch-{epoch}.pth')
+
+
+# --------------- Utils ---------------
+
+
+def compute_loss(pred_pha, pred_fgr, pred_err, true_pha, true_fgr):
+    true_err = torch.abs(pred_pha.detach() - true_pha)
+    true_msk = true_pha != 0
+    return F.l1_loss(pred_pha, true_pha) + \
+           F.l1_loss(kornia.sobel(pred_pha), kornia.sobel(true_pha)) + \
+           F.l1_loss(pred_fgr * true_msk, true_fgr * true_msk) + \
+           F.mse_loss(pred_err, true_err)
+
+
+def random_crop(*imgs):
+    w = random.choice(range(256, 512))
+    h = random.choice(range(256, 512))
+    results = []
+    for img in imgs:
+        img = kornia.resize(img, (max(h, w), max(h, w)))
+        img = kornia.center_crop(img, (h, w))
+        results.append(img)
+    return results
+
+
+def valid(model, dataloader, writer, step):
+    model.eval()
+    loss_total = 0
+    loss_count = 0
+    with torch.no_grad():
+        for (true_pha, true_fgr), true_bgr in dataloader:
+            batch_size = true_pha.size(0)
+            
+            true_pha = true_pha.cuda(non_blocking=True)
+            true_fgr = true_fgr.cuda(non_blocking=True)
+            true_bgr = true_bgr.cuda(non_blocking=True)
+            true_src = true_pha * true_fgr + (1 - true_pha) * true_bgr
+
+            pred_pha, pred_fgr, pred_err = model(true_src, true_bgr)[:3]
+            loss = compute_loss(pred_pha, pred_fgr, pred_err, true_pha, true_fgr)
+            loss_total += loss.cpu().item() * batch_size
+            loss_count += batch_size
+
+    writer.add_scalar('valid_loss', loss_total / loss_count, step)
+    model.train()
+
+
+# --------------- Start ---------------
+
+
+if __name__ == '__main__':
+    train()

+ 309 - 0
train_refine.py

@@ -0,0 +1,309 @@
+"""
+Train MattingRefine
+
+Supports multi-GPU training with DistributedDataParallel() and SyncBatchNorm.
+Select GPUs through CUDA_VISIBLE_DEVICES environment variable.
+
+Example:
+
+    CUDA_VISIBLE_DEVICES=0,1 python train_refine.py \
+        --dataset-name videomatte240k \
+        --model-backbone resnet50 \
+        --model-name mattingrefine-resnet50-videomatte240k \
+        --model-last-checkpoint "PATH_TO_LAST_CHECKPOINT" \
+        --epoch-end 1
+
+"""
+
+import argparse
+import kornia
+import torch
+import os
+import random
+
+from torch import nn
+from torch import distributed as dist
+from torch import multiprocessing as mp
+from torch.nn import functional as F
+from torch.cuda.amp import autocast, GradScaler
+from torch.utils.tensorboard import SummaryWriter
+from torch.utils.data import DataLoader, Subset
+from torch.optim import Adam
+from torchvision.utils import make_grid
+from tqdm import tqdm
+from torchvision import transforms as T
+from PIL import Image
+
+from data_path import DATA_PATH
+from dataset import ImagesDataset, ZipDataset, VideoDataset, SampleDataset
+from dataset import augmentation as A
+from model import MattingRefine
+from model.utils import load_matched_state_dict
+
+
+# --------------- Arguments ---------------
+
+
+parser = argparse.ArgumentParser()
+
+parser.add_argument('--dataset-name', type=str, required=True, choices=DATA_PATH.keys())
+
+parser.add_argument('--model-backbone', type=str, required=True, choices=['resnet101', 'resnet50', 'mobilenetv2'])
+parser.add_argument('--model-backbone-scale', type=float, default=0.25)
+parser.add_argument('--model-refine-mode', type=str, default='sampling', choices=['full', 'sampling', 'thresholding'])
+parser.add_argument('--model-refine-sample-pixels', type=int, default=80_000)
+parser.add_argument('--model-refine-thresholding', type=float, default=0.7)
+parser.add_argument('--model-refine-kernel-size', type=int, default=3, choices=[1, 3])
+parser.add_argument('--model-name', type=str, required=True)
+parser.add_argument('--model-last-checkpoint', type=str, default=None)
+
+parser.add_argument('--batch-size', type=int, default=4)
+parser.add_argument('--num-workers', type=int, default=16)
+parser.add_argument('--epoch-start', type=int, default=0)
+parser.add_argument('--epoch-end', type=int, required=True)
+
+parser.add_argument('--log-train-loss-interval', type=int, default=10)
+parser.add_argument('--log-train-images-interval', type=int, default=1000)
+parser.add_argument('--log-valid-interval', type=int, default=2000)
+
+parser.add_argument('--checkpoint-interval', type=int, default=2000)
+
+args = parser.parse_args()
+
+
+distributed_num_gpus = torch.cuda.device_count()
+assert args.batch_size % distributed_num_gpus == 0
+
+
+# --------------- Main ---------------
+
+def train_worker(rank, addr, port):
+    
+    # Distributed Setup
+    os.environ['MASTER_ADDR'] = addr
+    os.environ['MASTER_PORT'] = port
+    dist.init_process_group("nccl", rank=rank, world_size=distributed_num_gpus)
+    
+    # Training DataLoader
+    dataset_train = ZipDataset([
+        ZipDataset([
+            ImagesDataset(DATA_PATH[args.dataset_name]['train']['pha'], mode='L'),
+            ImagesDataset(DATA_PATH[args.dataset_name]['train']['fgr'], mode='RGB'),
+        ], transforms=A.PairCompose([
+            A.PairRandomAffineAndResize((2048, 2048), degrees=(-5, 5), translate=(0.1, 0.1), scale=(0.3, 1), shear=(-5, 5)),
+            A.PairRandomHorizontalFlip(),
+            A.PairRandomBoxBlur(0.1, 5),
+            A.PairRandomSharpen(0.1),
+            A.PairApplyOnlyAtIndices([1], T.ColorJitter(0.15, 0.15, 0.15, 0.05)),
+            A.PairApply(T.ToTensor())
+        ]), assert_equal_length=True),
+        ImagesDataset(DATA_PATH['backgrounds']['train'], mode='RGB', transforms=T.Compose([
+            A.RandomAffineAndResize((2048, 2048), degrees=(-5, 5), translate=(0.1, 0.1), scale=(1, 2), shear=(-5, 5)),
+            T.RandomHorizontalFlip(),
+            A.RandomBoxBlur(0.1, 5),
+            A.RandomSharpen(0.1),
+            T.ColorJitter(0.15, 0.15, 0.15, 0.05),
+            T.ToTensor()
+        ])),
+    ])
+    dataset_train_len_per_gpu_worker = int(len(dataset_train) / distributed_num_gpus)
+    dataset_train = Subset(dataset_train, range(rank * dataset_train_len_per_gpu_worker, (rank + 1) * dataset_train_len_per_gpu_worker))
+    dataloader_train = DataLoader(dataset_train,
+                                  shuffle=True,
+                                  pin_memory=True,
+                                  drop_last=True,
+                                  batch_size=args.batch_size // distributed_num_gpus,
+                                  num_workers=args.num_workers // distributed_num_gpus)
+    
+    # Validation DataLoader
+    if rank == 0:
+        dataset_valid = ZipDataset([
+            ZipDataset([
+                ImagesDataset(DATA_PATH[args.dataset_name]['valid']['pha'], mode='L'),
+                ImagesDataset(DATA_PATH[args.dataset_name]['valid']['fgr'], mode='RGB')
+            ], transforms=A.PairCompose([
+                A.PairRandomAffineAndResize((2048, 2048), degrees=(-5, 5), translate=(0.1, 0.1), scale=(0.3, 1), shear=(-5, 5)),
+                A.PairApply(T.ToTensor())
+            ]), assert_equal_length=True),
+            ImagesDataset(DATA_PATH['backgrounds']['valid'], mode='RGB', transforms=T.Compose([
+                A.RandomAffineAndResize((2048, 2048), degrees=(-5, 5), translate=(0.1, 0.1), scale=(1, 1.2), shear=(-5, 5)),
+                T.ToTensor()
+            ])),
+        ])
+        dataset_valid = SampleDataset(dataset_valid, 50)
+        dataloader_valid = DataLoader(dataset_valid,
+                                      pin_memory=True,
+                                      drop_last=True,
+                                      batch_size=args.batch_size // distributed_num_gpus,
+                                      num_workers=args.num_workers // distributed_num_gpus)
+    
+    # Model
+    model = MattingRefine(args.model_backbone,
+                          args.model_backbone_scale,
+                          args.model_refine_mode,
+                          args.model_refine_sample_pixels,
+                          args.model_refine_thresholding,
+                          args.model_refine_kernel_size).to(rank)
+    model = nn.SyncBatchNorm.convert_sync_batchnorm(model)
+    model_distributed = nn.parallel.DistributedDataParallel(model, device_ids=[rank])
+    
+    if args.model_last_checkpoint is not None:
+        load_matched_state_dict(model, torch.load(args.model_last_checkpoint))
+
+    optimizer = Adam([
+        {'params': model.backbone.parameters(), 'lr': 5e-5},
+        {'params': model.aspp.parameters(), 'lr': 5e-5},
+        {'params': model.decoder.parameters(), 'lr': 1e-4},
+        {'params': model.refiner.parameters(), 'lr': 3e-4},
+    ])
+    scaler = GradScaler()
+    
+    # Logging and checkpoints
+    if rank == 0:
+        if not os.path.exists(f'checkpoint/{args.model_name}'):
+            os.makedirs(f'checkpoint/{args.model_name}')
+        writer = SummaryWriter(f'log/{args.model_name}')
+    
+    # Run loop
+    for epoch in range(args.epoch_start, args.epoch_end):
+        for i, ((true_pha, true_fgr), true_bgr) in enumerate(tqdm(dataloader_train)):
+            step = epoch * len(dataloader_train) + i
+
+            true_pha = true_pha.to(rank, non_blocking=True)
+            true_fgr = true_fgr.to(rank, non_blocking=True)
+            true_bgr = true_bgr.to(rank, non_blocking=True)
+            true_pha, true_fgr, true_bgr = random_crop(true_pha, true_fgr, true_bgr)
+            
+            true_src = true_bgr.clone()
+            
+            # Augment with shadow
+            aug_shadow_idx = torch.rand(len(true_src)) < 0.3
+            if aug_shadow_idx.any():
+                aug_shadow = true_pha[aug_shadow_idx].mul(0.3 * random.random())
+                aug_shadow = T.RandomAffine(degrees=(-5, 5), translate=(0.2, 0.2), scale=(0.5, 1.5), shear=(-5, 5))(aug_shadow)
+                aug_shadow = kornia.filters.box_blur(aug_shadow, (random.choice(range(20, 40)),) * 2)
+                true_src[aug_shadow_idx] = true_src[aug_shadow_idx].sub_(aug_shadow).clamp_(0, 1)
+                del aug_shadow
+            del aug_shadow_idx
+            
+            # Composite foreground onto source
+            true_src = true_fgr * true_pha + true_src * (1 - true_pha)
+
+            # Augment with noise
+            aug_noise_idx = torch.rand(len(true_src)) < 0.4
+            if aug_noise_idx.any():
+                true_src[aug_noise_idx] = true_src[aug_noise_idx].add_(torch.randn_like(true_src[aug_noise_idx]).mul_(0.03 * random.random())).clamp_(0, 1)
+                true_bgr[aug_noise_idx] = true_bgr[aug_noise_idx].add_(torch.randn_like(true_bgr[aug_noise_idx]).mul_(0.03 * random.random())).clamp_(0, 1)
+            del aug_noise_idx
+            
+            # Augment background with jitter
+            aug_jitter_idx = torch.rand(len(true_src)) < 0.8
+            if aug_jitter_idx.any():
+                true_bgr[aug_jitter_idx] = kornia.augmentation.ColorJitter(0.18, 0.18, 0.18, 0.1)(true_bgr[aug_jitter_idx])
+            del aug_jitter_idx
+            
+            # Augment background with affine
+            aug_affine_idx = torch.rand(len(true_bgr)) < 0.3
+            if aug_affine_idx.any():
+                true_bgr[aug_affine_idx] = T.RandomAffine(degrees=(-1, 1), translate=(0.01, 0.01))(true_bgr[aug_affine_idx])
+            del aug_affine_idx
+            
+            with autocast():
+                pred_pha, pred_fgr, pred_pha_sm, pred_fgr_sm, pred_err_sm, _ = model_distributed(true_src, true_bgr)
+                loss = compute_loss(pred_pha, pred_fgr, pred_pha_sm, pred_fgr_sm, pred_err_sm, true_pha, true_fgr)
+
+            scaler.scale(loss).backward()
+            scaler.step(optimizer)
+            scaler.update()
+            optimizer.zero_grad()
+
+            if rank == 0:
+                if (i + 1) % args.log_train_loss_interval == 0:
+                    writer.add_scalar('loss', loss, step)
+
+                if (i + 1) % args.log_train_images_interval == 0:
+                    writer.add_image('train_pred_pha', make_grid(pred_pha, nrow=5), step)
+                    writer.add_image('train_pred_fgr', make_grid(pred_fgr, nrow=5), step)
+                    writer.add_image('train_pred_com', make_grid(pred_fgr * pred_pha, nrow=5), step)
+                    writer.add_image('train_pred_err', make_grid(pred_err_sm, nrow=5), step)
+                    writer.add_image('train_true_src', make_grid(true_src, nrow=5), step)
+
+                del true_pha, true_fgr, true_src, true_bgr
+                del pred_pha, pred_fgr, pred_pha_sm, pred_fgr_sm, pred_err_sm
+
+                if (i + 1) % args.log_valid_interval == 0:
+                    valid(model, dataloader_valid, writer, step)
+
+                if (step + 1) % args.checkpoint_interval == 0:
+                    torch.save(model.state_dict(), f'checkpoint/{args.model_name}/epoch-{epoch}-iter-{step}.pth')
+                    
+        if rank == 0:
+            torch.save(model.state_dict(), f'checkpoint/{args.model_name}/epoch-{epoch}.pth')
+            
+    # Clean up
+    dist.destroy_process_group()
+            
+            
+# --------------- Utils ---------------
+
+
+def compute_loss(pred_pha_lg, pred_fgr_lg, pred_pha_sm, pred_fgr_sm, pred_err_sm, true_pha_lg, true_fgr_lg):
+    true_pha_sm = kornia.resize(true_pha_lg, pred_pha_sm.shape[2:])
+    true_fgr_sm = kornia.resize(true_fgr_lg, pred_fgr_sm.shape[2:])
+    true_msk_lg = true_pha_lg != 0
+    true_msk_sm = true_pha_sm != 0
+    return F.l1_loss(pred_pha_lg, true_pha_lg) + \
+           F.l1_loss(pred_pha_sm, true_pha_sm) + \
+           F.l1_loss(kornia.sobel(pred_pha_lg), kornia.sobel(true_pha_lg)) + \
+           F.l1_loss(kornia.sobel(pred_pha_sm), kornia.sobel(true_pha_sm)) + \
+           F.l1_loss(pred_fgr_lg * true_msk_lg, true_fgr_lg * true_msk_lg) + \
+           F.l1_loss(pred_fgr_sm * true_msk_sm, true_fgr_sm * true_msk_sm) + \
+           F.mse_loss(kornia.resize(pred_err_sm, true_pha_lg.shape[2:]), \
+                      kornia.resize(pred_pha_sm, true_pha_lg.shape[2:]).sub(true_pha_lg).abs())
+
+
+def random_crop(*imgs):
+    H_src, W_src = imgs[0].shape[2:]
+    W_tgt = random.choice(range(1024, 2048)) // 4 * 4
+    H_tgt = random.choice(range(1024, 2048)) // 4 * 4
+    scale = max(W_tgt / W_src, H_tgt / H_src)
+    results = []
+    for img in imgs:
+        img = kornia.resize(img, (int(H_src * scale), int(W_src * scale)))
+        img = kornia.center_crop(img, (H_tgt, W_tgt))
+        results.append(img)
+    return results
+
+
+def valid(model, dataloader, writer, step):
+    model.eval()
+    loss_total = 0
+    loss_count = 0
+    with torch.no_grad():
+        for (true_pha, true_fgr), true_bgr in dataloader:
+            batch_size = true_pha.size(0)
+            
+            true_pha = true_pha.cuda(non_blocking=True)
+            true_fgr = true_fgr.cuda(non_blocking=True)
+            true_bgr = true_bgr.cuda(non_blocking=True)
+            true_src = true_pha * true_fgr + (1 - true_pha) * true_bgr
+
+            pred_pha, pred_fgr, pred_pha_sm, pred_fgr_sm, pred_err_sm, _ = model(true_src, true_bgr)
+            loss = compute_loss(pred_pha, pred_fgr, pred_pha_sm, pred_fgr_sm, pred_err_sm, true_pha, true_fgr)
+            loss_total += loss.cpu().item() * batch_size
+            loss_count += batch_size
+
+    writer.add_scalar('valid_loss', loss_total / loss_count, step)
+    model.train()
+
+
+# --------------- Start ---------------
+
+
+if __name__ == '__main__':
+    addr = 'localhost'
+    port = str(random.choice(range(12300, 12400))) # pick a random port.
+    mp.spawn(train_worker,
+             nprocs=distributed_num_gpus,
+             args=(addr, port),
+             join=True)