|
@@ -0,0 +1,309 @@
|
|
|
+"""
|
|
|
+Train MattingRefine
|
|
|
+
|
|
|
+Supports multi-GPU training with DistributedDataParallel() and SyncBatchNorm.
|
|
|
+Select GPUs through CUDA_VISIBLE_DEVICES environment variable.
|
|
|
+
|
|
|
+Example:
|
|
|
+
|
|
|
+ CUDA_VISIBLE_DEVICES=0,1 python train_refine.py \
|
|
|
+ --dataset-name videomatte240k \
|
|
|
+ --model-backbone resnet50 \
|
|
|
+ --model-name mattingrefine-resnet50-videomatte240k \
|
|
|
+ --model-last-checkpoint "PATH_TO_LAST_CHECKPOINT" \
|
|
|
+ --epoch-end 1
|
|
|
+
|
|
|
+"""
|
|
|
+
|
|
|
+import argparse
|
|
|
+import kornia
|
|
|
+import torch
|
|
|
+import os
|
|
|
+import random
|
|
|
+
|
|
|
+from torch import nn
|
|
|
+from torch import distributed as dist
|
|
|
+from torch import multiprocessing as mp
|
|
|
+from torch.nn import functional as F
|
|
|
+from torch.cuda.amp import autocast, GradScaler
|
|
|
+from torch.utils.tensorboard import SummaryWriter
|
|
|
+from torch.utils.data import DataLoader, Subset
|
|
|
+from torch.optim import Adam
|
|
|
+from torchvision.utils import make_grid
|
|
|
+from tqdm import tqdm
|
|
|
+from torchvision import transforms as T
|
|
|
+from PIL import Image
|
|
|
+
|
|
|
+from data_path import DATA_PATH
|
|
|
+from dataset import ImagesDataset, ZipDataset, VideoDataset, SampleDataset
|
|
|
+from dataset import augmentation as A
|
|
|
+from model import MattingRefine
|
|
|
+from model.utils import load_matched_state_dict
|
|
|
+
|
|
|
+
|
|
|
+# --------------- Arguments ---------------
|
|
|
+
|
|
|
+
|
|
|
parser = argparse.ArgumentParser()

# Dataset selection; the accepted names are the keys of DATA_PATH.
parser.add_argument('--dataset-name', type=str, required=True, choices=DATA_PATH.keys())

# Model configuration. The backbone/refine options are handed to the
# MattingRefine constructor; --model-name names the checkpoint and log
# directories; --model-last-checkpoint optionally points at weights to resume from.
parser.add_argument('--model-backbone', type=str, required=True, choices=['resnet101', 'resnet50', 'mobilenetv2'])
parser.add_argument('--model-backbone-scale', type=float, default=0.25)
parser.add_argument('--model-refine-mode', type=str, default='sampling', choices=['full', 'sampling', 'thresholding'])
parser.add_argument('--model-refine-sample-pixels', type=int, default=80_000)
parser.add_argument('--model-refine-thresholding', type=float, default=0.7)
parser.add_argument('--model-refine-kernel-size', type=int, default=3, choices=[1, 3])
parser.add_argument('--model-name', type=str, required=True)
parser.add_argument('--model-last-checkpoint', type=str, default=None)

# Training schedule. --batch-size and --num-workers are totals that get
# divided by the number of GPUs; epochs run over range(epoch_start, epoch_end).
parser.add_argument('--batch-size', type=int, default=4)
parser.add_argument('--num-workers', type=int, default=16)
parser.add_argument('--epoch-start', type=int, default=0)
parser.add_argument('--epoch-end', type=int, required=True)

# Logging cadence, counted in iterations within an epoch.
parser.add_argument('--log-train-loss-interval', type=int, default=10)
parser.add_argument('--log-train-images-interval', type=int, default=1000)
parser.add_argument('--log-valid-interval', type=int, default=2000)

# Intermediate checkpoint cadence, counted in global steps.
parser.add_argument('--checkpoint-interval', type=int, default=2000)

args = parser.parse_args()

# One training process is spawned per visible CUDA device.
distributed_num_gpus = torch.cuda.device_count()

# Validate explicitly instead of with `assert`: asserts are stripped under
# `python -O`, and with zero visible GPUs the modulo below would otherwise
# fail with an unhelpful ZeroDivisionError.
if distributed_num_gpus == 0:
    parser.error('no CUDA device is visible; check CUDA_VISIBLE_DEVICES')
if args.batch_size % distributed_num_gpus != 0:
    parser.error(f'--batch-size ({args.batch_size}) must be divisible by '
                 f'the number of visible GPUs ({distributed_num_gpus})')
|
|
|
+
|
|
|
+
|
|
|
+# --------------- Main ---------------
|
|
|
+
|
|
|
def _composite_with_augmentation(true_pha, true_fgr, true_bgr):
    """
    Build the composited source image for a batch and perturb the background.

    Args:
        true_pha: ground-truth alpha batch.
        true_fgr: ground-truth foreground batch.
        true_bgr: background batch; selected samples are modified in place.

    Returns:
        A ``(true_src, true_bgr)`` tuple: the composite fed to the model as the
        source image, and the (possibly perturbed) background fed alongside it.
    """
    true_src = true_bgr.clone()

    # Augment with shadow: a dimmed, warped and blurred copy of the alpha is
    # subtracted from the background of roughly 30% of the samples.
    aug_shadow_idx = torch.rand(len(true_src)) < 0.3
    if aug_shadow_idx.any():
        aug_shadow = true_pha[aug_shadow_idx].mul(0.3 * random.random())
        aug_shadow = T.RandomAffine(degrees=(-5, 5), translate=(0.2, 0.2), scale=(0.5, 1.5), shear=(-5, 5))(aug_shadow)
        aug_shadow = kornia.filters.box_blur(aug_shadow, (random.choice(range(20, 40)),) * 2)
        true_src[aug_shadow_idx] = true_src[aug_shadow_idx].sub_(aug_shadow).clamp_(0, 1)

    # Composite foreground onto source
    true_src = true_fgr * true_pha + true_src * (1 - true_pha)

    # Augment with noise: source and background get independently drawn noise
    # (and independently drawn strengths) on the same subset of samples.
    aug_noise_idx = torch.rand(len(true_src)) < 0.4
    if aug_noise_idx.any():
        true_src[aug_noise_idx] = true_src[aug_noise_idx].add_(torch.randn_like(true_src[aug_noise_idx]).mul_(0.03 * random.random())).clamp_(0, 1)
        true_bgr[aug_noise_idx] = true_bgr[aug_noise_idx].add_(torch.randn_like(true_bgr[aug_noise_idx]).mul_(0.03 * random.random())).clamp_(0, 1)

    # Augment background with jitter
    aug_jitter_idx = torch.rand(len(true_src)) < 0.8
    if aug_jitter_idx.any():
        true_bgr[aug_jitter_idx] = kornia.augmentation.ColorJitter(0.18, 0.18, 0.18, 0.1)(true_bgr[aug_jitter_idx])

    # Augment background with affine: a slight warp applied only to the
    # background, after compositing, so it no longer lines up exactly with the source.
    aug_affine_idx = torch.rand(len(true_bgr)) < 0.3
    if aug_affine_idx.any():
        true_bgr[aug_affine_idx] = T.RandomAffine(degrees=(-1, 1), translate=(0.01, 0.01))(true_bgr[aug_affine_idx])

    return true_src, true_bgr


def train_worker(rank, addr, port):
    """
    Run the training loop for one process of the distributed job.

    Each process joins the NCCL process group, trains on its own slice of the
    training set, and uses ``rank`` as its CUDA device index. Rank 0 is
    additionally responsible for TensorBoard logging, validation and
    checkpointing.

    Args:
        rank: index of this process within the group; also the device index.
        addr: rendezvous address, exported as ``MASTER_ADDR``.
        port: rendezvous port as a string, exported as ``MASTER_PORT``.
    """
    # Distributed Setup
    os.environ['MASTER_ADDR'] = addr
    os.environ['MASTER_PORT'] = port
    dist.init_process_group("nccl", rank=rank, world_size=distributed_num_gpus)

    # Training DataLoader
    dataset_train = ZipDataset([
        ZipDataset([
            ImagesDataset(DATA_PATH[args.dataset_name]['train']['pha'], mode='L'),
            ImagesDataset(DATA_PATH[args.dataset_name]['train']['fgr'], mode='RGB'),
        ], transforms=A.PairCompose([
            A.PairRandomAffineAndResize((2048, 2048), degrees=(-5, 5), translate=(0.1, 0.1), scale=(0.3, 1), shear=(-5, 5)),
            A.PairRandomHorizontalFlip(),
            A.PairRandomBoxBlur(0.1, 5),
            A.PairRandomSharpen(0.1),
            A.PairApplyOnlyAtIndices([1], T.ColorJitter(0.15, 0.15, 0.15, 0.05)),
            A.PairApply(T.ToTensor())
        ]), assert_equal_length=True),
        ImagesDataset(DATA_PATH['backgrounds']['train'], mode='RGB', transforms=T.Compose([
            A.RandomAffineAndResize((2048, 2048), degrees=(-5, 5), translate=(0.1, 0.1), scale=(1, 2), shear=(-5, 5)),
            T.RandomHorizontalFlip(),
            A.RandomBoxBlur(0.1, 5),
            A.RandomSharpen(0.1),
            T.ColorJitter(0.15, 0.15, 0.15, 0.05),
            T.ToTensor()
        ])),
    ])
    # Every rank takes a contiguous, equally sized slice; any remainder at the
    # end of the dataset is left unused so all ranks run the same number of steps.
    dataset_train_len_per_gpu_worker = len(dataset_train) // distributed_num_gpus
    dataset_train = Subset(dataset_train, range(rank * dataset_train_len_per_gpu_worker, (rank + 1) * dataset_train_len_per_gpu_worker))
    dataloader_train = DataLoader(dataset_train,
                                  shuffle=True,
                                  pin_memory=True,
                                  drop_last=True,
                                  batch_size=args.batch_size // distributed_num_gpus,
                                  num_workers=args.num_workers // distributed_num_gpus)

    # Validation DataLoader (rank 0 only; the other ranks never validate)
    if rank == 0:
        dataset_valid = ZipDataset([
            ZipDataset([
                ImagesDataset(DATA_PATH[args.dataset_name]['valid']['pha'], mode='L'),
                ImagesDataset(DATA_PATH[args.dataset_name]['valid']['fgr'], mode='RGB')
            ], transforms=A.PairCompose([
                A.PairRandomAffineAndResize((2048, 2048), degrees=(-5, 5), translate=(0.1, 0.1), scale=(0.3, 1), shear=(-5, 5)),
                A.PairApply(T.ToTensor())
            ]), assert_equal_length=True),
            ImagesDataset(DATA_PATH['backgrounds']['valid'], mode='RGB', transforms=T.Compose([
                A.RandomAffineAndResize((2048, 2048), degrees=(-5, 5), translate=(0.1, 0.1), scale=(1, 1.2), shear=(-5, 5)),
                T.ToTensor()
            ])),
        ])
        dataset_valid = SampleDataset(dataset_valid, 50)
        dataloader_valid = DataLoader(dataset_valid,
                                      pin_memory=True,
                                      drop_last=True,
                                      batch_size=args.batch_size // distributed_num_gpus,
                                      num_workers=args.num_workers // distributed_num_gpus)

    # Model
    model = MattingRefine(args.model_backbone,
                          args.model_backbone_scale,
                          args.model_refine_mode,
                          args.model_refine_sample_pixels,
                          args.model_refine_thresholding,
                          args.model_refine_kernel_size).to(rank)
    model = nn.SyncBatchNorm.convert_sync_batchnorm(model)

    if args.model_last_checkpoint is not None:
        # map_location keeps each process from materialising the checkpoint on
        # the device it was saved from (typically GPU 0). Loading happens before
        # the DDP wrapper is built so the wrapper's initial parameter broadcast
        # covers the restored weights.
        load_matched_state_dict(model, torch.load(args.model_last_checkpoint, map_location=f'cuda:{rank}'))

    model_distributed = nn.parallel.DistributedDataParallel(model, device_ids=[rank])

    # Per-module learning rates.
    optimizer = Adam([
        {'params': model.backbone.parameters(), 'lr': 5e-5},
        {'params': model.aspp.parameters(), 'lr': 5e-5},
        {'params': model.decoder.parameters(), 'lr': 1e-4},
        {'params': model.refiner.parameters(), 'lr': 3e-4},
    ])
    scaler = GradScaler()

    # Logging and checkpoints
    writer = None
    if rank == 0:
        os.makedirs(f'checkpoint/{args.model_name}', exist_ok=True)
        writer = SummaryWriter(f'log/{args.model_name}')

    try:
        # Run loop
        for epoch in range(args.epoch_start, args.epoch_end):
            for i, ((true_pha, true_fgr), true_bgr) in enumerate(tqdm(dataloader_train)):
                step = epoch * len(dataloader_train) + i

                true_pha = true_pha.to(rank, non_blocking=True)
                true_fgr = true_fgr.to(rank, non_blocking=True)
                true_bgr = true_bgr.to(rank, non_blocking=True)
                true_pha, true_fgr, true_bgr = random_crop(true_pha, true_fgr, true_bgr)

                true_src, true_bgr = _composite_with_augmentation(true_pha, true_fgr, true_bgr)

                with autocast():
                    pred_pha, pred_fgr, pred_pha_sm, pred_fgr_sm, pred_err_sm, _ = model_distributed(true_src, true_bgr)
                    loss = compute_loss(pred_pha, pred_fgr, pred_pha_sm, pred_fgr_sm, pred_err_sm, true_pha, true_fgr)

                scaler.scale(loss).backward()
                scaler.step(optimizer)
                scaler.update()
                optimizer.zero_grad()

                if rank == 0:
                    if (i + 1) % args.log_train_loss_interval == 0:
                        writer.add_scalar('loss', loss, step)

                    if (i + 1) % args.log_train_images_interval == 0:
                        writer.add_image('train_pred_pha', make_grid(pred_pha, nrow=5), step)
                        writer.add_image('train_pred_fgr', make_grid(pred_fgr, nrow=5), step)
                        writer.add_image('train_pred_com', make_grid(pred_fgr * pred_pha, nrow=5), step)
                        writer.add_image('train_pred_err', make_grid(pred_err_sm, nrow=5), step)
                        writer.add_image('train_true_src', make_grid(true_src, nrow=5), step)

                    # Drop the references to this batch before validation
                    # allocates its own tensors.
                    del true_pha, true_fgr, true_src, true_bgr
                    del pred_pha, pred_fgr, pred_pha_sm, pred_fgr_sm, pred_err_sm

                    if (i + 1) % args.log_valid_interval == 0:
                        valid(model, dataloader_valid, writer, step)

                    if (step + 1) % args.checkpoint_interval == 0:
                        torch.save(model.state_dict(), f'checkpoint/{args.model_name}/epoch-{epoch}-iter-{step}.pth')

            if rank == 0:
                torch.save(model.state_dict(), f'checkpoint/{args.model_name}/epoch-{epoch}.pth')
    finally:
        # Flush pending TensorBoard events even if training aborts.
        if writer is not None:
            writer.close()

    # Clean up
    dist.destroy_process_group()
|
|
|
+
|
|
|
+
|
|
|
+# --------------- Utils ---------------
|
|
|
+
|
|
|
+
|
|
|
def compute_loss(pred_pha_lg, pred_fgr_lg, pred_pha_sm, pred_fgr_sm, pred_err_sm, true_pha_lg, true_fgr_lg):
    """
    Sum the training loss terms for one batch.

    The ``*_lg`` predictions are compared with the ground truth as given; the
    ``*_sm`` predictions are compared with the ground truth resized to their
    own spatial size. Seven terms are added without weighting:

    * L1 on alpha (lg and sm),
    * L1 on the Sobel response of alpha (lg and sm),
    * L1 on foreground, restricted to pixels where the matching ground-truth
      alpha is non-zero (lg and sm),
    * MSE between the predicted error map and the absolute difference between
      the upsampled ``pred_pha_sm`` and ``true_pha_lg``.

    Returns:
        The scalar loss tensor.
    """
    size_lg = true_pha_lg.shape[2:]

    # Ground truth resized to the spatial size of each "sm" prediction.
    gt_pha_sm = kornia.resize(true_pha_lg, pred_pha_sm.shape[2:])
    gt_fgr_sm = kornia.resize(true_fgr_lg, pred_fgr_sm.shape[2:])

    # Foreground colour is only supervised where ground-truth alpha is non-zero.
    fg_mask_lg = true_pha_lg != 0
    fg_mask_sm = gt_pha_sm != 0

    alpha_lg = F.l1_loss(pred_pha_lg, true_pha_lg)
    alpha_sm = F.l1_loss(pred_pha_sm, gt_pha_sm)
    edge_lg = F.l1_loss(kornia.sobel(pred_pha_lg), kornia.sobel(true_pha_lg))
    edge_sm = F.l1_loss(kornia.sobel(pred_pha_sm), kornia.sobel(gt_pha_sm))
    colour_lg = F.l1_loss(pred_fgr_lg * fg_mask_lg, true_fgr_lg * fg_mask_lg)
    colour_sm = F.l1_loss(pred_fgr_sm * fg_mask_sm, gt_fgr_sm * fg_mask_sm)

    # The error map is trained to match how far the upsampled "sm" alpha is
    # from the ground truth.
    err_target = kornia.resize(pred_pha_sm, size_lg).sub(true_pha_lg).abs()
    err_term = F.mse_loss(kornia.resize(pred_err_sm, size_lg), err_target)

    total = alpha_lg
    for term in (alpha_sm, edge_lg, edge_sm, colour_lg, colour_sm, err_term):
        total = total + term
    return total
|
|
|
+
|
|
|
+
|
|
|
def random_crop(*imgs):
    """
    Resize and center-crop every image batch to one randomly chosen size.

    A target width and height are each drawn from [1024, 2048) and rounded
    down to a multiple of 4. Every input is scaled by the same factor (the
    larger of the two target/source ratios, taken from the first input's
    spatial size) and then center-cropped to the target size.

    Args:
        *imgs: tensors whose trailing two dimensions are (height, width).

    Returns:
        A list with one cropped tensor per input, in the same order.
    """
    src_h, src_w = imgs[0].shape[2:]

    # Width is drawn before height; keep this order so results stay
    # reproducible under a fixed random seed.
    tgt_w = random.choice(range(1024, 2048)) // 4 * 4
    tgt_h = random.choice(range(1024, 2048)) // 4 * 4

    # Scale so the resized image is large enough for the crop in both dimensions.
    ratio = max(tgt_w / src_w, tgt_h / src_h)
    resized_size = (int(src_h * ratio), int(src_w * ratio))

    return [
        kornia.center_crop(kornia.resize(img, resized_size), (tgt_h, tgt_w))
        for img in imgs
    ]
|
|
|
+
|
|
|
+
|
|
|
def valid(model, dataloader, writer, step):
    """
    Evaluate the model on a validation loader and log the mean loss.

    The model is switched to eval mode for the duration of the pass and is
    always switched back to train mode afterwards, even if evaluation raises.

    Args:
        model: the network to evaluate; called as ``model(true_src, true_bgr)``.
        dataloader: iterable yielding ``((true_pha, true_fgr), true_bgr)`` batches.
        writer: object with an ``add_scalar(tag, value, step)`` method.
        step: global step under which ``valid_loss`` is recorded.
    """
    model.eval()
    loss_total = 0
    loss_count = 0
    try:
        with torch.no_grad():
            for (true_pha, true_fgr), true_bgr in dataloader:
                batch_size = true_pha.size(0)

                true_pha = true_pha.cuda(non_blocking=True)
                true_fgr = true_fgr.cuda(non_blocking=True)
                true_bgr = true_bgr.cuda(non_blocking=True)
                true_src = true_pha * true_fgr + (1 - true_pha) * true_bgr

                pred_pha, pred_fgr, pred_pha_sm, pred_fgr_sm, pred_err_sm, _ = model(true_src, true_bgr)
                loss = compute_loss(pred_pha, pred_fgr, pred_pha_sm, pred_fgr_sm, pred_err_sm, true_pha, true_fgr)
                # Weight by batch size so the logged value is a per-sample mean.
                loss_total += loss.item() * batch_size
                loss_count += batch_size
    finally:
        model.train()

    # A loader can yield no batches (e.g. drop_last with fewer samples than the
    # batch size); skip logging rather than dividing by zero.
    if loss_count > 0:
        writer.add_scalar('valid_loss', loss_total / loss_count, step)
|
|
|
+
|
|
|
+
|
|
|
+# --------------- Start ---------------
|
|
|
+
|
|
|
+
|
|
|
if __name__ == '__main__':
    # All workers rendezvous on this machine, on a port drawn at random from
    # 12300-12399.
    master_addr = 'localhost'
    master_port = str(random.choice(range(12300, 12400)))

    # Launch one train_worker per visible GPU and block until all have exited.
    mp.spawn(train_worker,
             args=(master_addr, master_port),
             nprocs=distributed_num_gpus,
             join=True)
|