123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309 |
- """
Train MattingRefine.

Supports multi-GPU training with DistributedDataParallel() and SyncBatchNorm.
Select GPUs through the CUDA_VISIBLE_DEVICES environment variable; one training
process is spawned per visible GPU, and --batch-size (the total across all GPUs)
must be divisible by the number of visible GPUs.

Example:

    CUDA_VISIBLE_DEVICES=0,1 python train_refine.py \
        --dataset-name videomatte240k \
        --model-backbone resnet50 \
        --model-name mattingrefine-resnet50-videomatte240k \
        --model-last-checkpoint "PATH_TO_LAST_CHECKPOINT" \
        --epoch-end 1
- """
- import argparse
- import kornia
- import torch
- import os
- import random
- from torch import nn
- from torch import distributed as dist
- from torch import multiprocessing as mp
- from torch.nn import functional as F
- from torch.cuda.amp import autocast, GradScaler
- from torch.utils.tensorboard import SummaryWriter
- from torch.utils.data import DataLoader, Subset
- from torch.optim import Adam
- from torchvision.utils import make_grid
- from tqdm import tqdm
- from torchvision import transforms as T
- from PIL import Image
- from data_path import DATA_PATH
- from dataset import ImagesDataset, ZipDataset, VideoDataset, SampleDataset
- from dataset import augmentation as A
- from model import MattingRefine
- from model.utils import load_matched_state_dict
# --------------- Arguments ---------------
parser = argparse.ArgumentParser()
parser.add_argument('--dataset-name', type=str, required=True, choices=DATA_PATH.keys())
parser.add_argument('--model-backbone', type=str, required=True, choices=['resnet101', 'resnet50', 'mobilenetv2'])
parser.add_argument('--model-backbone-scale', type=float, default=0.25)
parser.add_argument('--model-refine-mode', type=str, default='sampling', choices=['full', 'sampling', 'thresholding'])
parser.add_argument('--model-refine-sample-pixels', type=int, default=80_000)
parser.add_argument('--model-refine-thresholding', type=float, default=0.7)
parser.add_argument('--model-refine-kernel-size', type=int, default=3, choices=[1, 3])
parser.add_argument('--model-name', type=str, required=True)
parser.add_argument('--model-last-checkpoint', type=str, default=None)
# Global values: each spawned per-GPU process uses batch_size // num_gpus
# samples per step and num_workers // num_gpus loader workers.
parser.add_argument('--batch-size', type=int, default=4)
parser.add_argument('--num-workers', type=int, default=16)
parser.add_argument('--epoch-start', type=int, default=0)
parser.add_argument('--epoch-end', type=int, required=True)
parser.add_argument('--log-train-loss-interval', type=int, default=10)
parser.add_argument('--log-train-images-interval', type=int, default=1000)
parser.add_argument('--log-valid-interval', type=int, default=2000)
parser.add_argument('--checkpoint-interval', type=int, default=2000)
args = parser.parse_args()

# One training process is spawned per visible GPU. Validate with parser.error
# rather than `assert`: asserts disappear under `python -O`, and with zero GPUs
# the modulo below would otherwise die with an unhelpful ZeroDivisionError.
distributed_num_gpus = torch.cuda.device_count()
if distributed_num_gpus == 0:
    parser.error('no CUDA devices are visible; set CUDA_VISIBLE_DEVICES to at least one GPU')
if args.batch_size <= 0 or args.batch_size % distributed_num_gpus != 0:
    parser.error(f'--batch-size ({args.batch_size}) must be a positive multiple of '
                 f'the number of visible GPUs ({distributed_num_gpus})')
- # --------------- Main ---------------
def train_worker(rank, addr, port):
    """Run the MattingRefine training loop for one DistributedDataParallel process.

    Spawned once per visible GPU by ``mp.spawn``. ``rank`` is both the
    process rank and the CUDA device index this process trains on.

    Args:
        rank: Process rank in ``[0, distributed_num_gpus)``; also the GPU index.
        addr: Rendezvous host, exported as ``MASTER_ADDR``.
        port: Rendezvous port as a string, exported as ``MASTER_PORT``.

    Rank 0 additionally runs validation, writes TensorBoard logs under
    ``log/<model-name>`` and saves checkpoints under ``checkpoint/<model-name>``.
    The TensorBoard writer and the process group are torn down even if
    training raises.
    """
    # Pin this process to its own GPU before creating the NCCL group, so that
    # implicit current-device work (NCCL setup, `.cuda()` calls) lands on GPU
    # `rank` instead of every process also touching GPU 0.
    torch.cuda.set_device(rank)
    os.environ['MASTER_ADDR'] = addr
    os.environ['MASTER_PORT'] = port
    dist.init_process_group("nccl", rank=rank, world_size=distributed_num_gpus)

    writer = None
    try:
        dataloader_train = _build_train_dataloader(rank)
        # Validation, logging and checkpointing are done by rank 0 only.
        dataloader_valid = _build_valid_dataloader() if rank == 0 else None

        model, model_distributed = _build_model(rank)
        # Per-module learning rates; `model` is the unwrapped module, whose
        # parameters are the same objects DDP trains.
        optimizer = Adam([
            {'params': model.backbone.parameters(), 'lr': 5e-5},
            {'params': model.aspp.parameters(), 'lr': 5e-5},
            {'params': model.decoder.parameters(), 'lr': 1e-4},
            {'params': model.refiner.parameters(), 'lr': 3e-4},
        ])
        scaler = GradScaler()

        if rank == 0:
            os.makedirs(f'checkpoint/{args.model_name}', exist_ok=True)
            writer = SummaryWriter(f'log/{args.model_name}')

        for epoch in range(args.epoch_start, args.epoch_end):
            for i, ((true_pha, true_fgr), true_bgr) in enumerate(tqdm(dataloader_train)):
                step = epoch * len(dataloader_train) + i

                true_pha = true_pha.to(rank, non_blocking=True)
                true_fgr = true_fgr.to(rank, non_blocking=True)
                true_bgr = true_bgr.to(rank, non_blocking=True)
                true_pha, true_fgr, true_bgr = random_crop(true_pha, true_fgr, true_bgr)
                true_src, true_bgr = _augment_batch(true_pha, true_fgr, true_bgr)

                # Mixed-precision forward + loss; backward/step go through the
                # GradScaler to avoid fp16 gradient underflow.
                with autocast():
                    pred_pha, pred_fgr, pred_pha_sm, pred_fgr_sm, pred_err_sm, _ = model_distributed(true_src, true_bgr)
                    loss = compute_loss(pred_pha, pred_fgr, pred_pha_sm, pred_fgr_sm, pred_err_sm, true_pha, true_fgr)

                scaler.scale(loss).backward()
                scaler.step(optimizer)
                scaler.update()
                optimizer.zero_grad()

                if rank == 0:
                    if (i + 1) % args.log_train_loss_interval == 0:
                        writer.add_scalar('loss', loss, step)
                    if (i + 1) % args.log_train_images_interval == 0:
                        writer.add_image('train_pred_pha', make_grid(pred_pha, nrow=5), step)
                        writer.add_image('train_pred_fgr', make_grid(pred_fgr, nrow=5), step)
                        writer.add_image('train_pred_com', make_grid(pred_fgr * pred_pha, nrow=5), step)
                        writer.add_image('train_pred_err', make_grid(pred_err_sm, nrow=5), step)
                        writer.add_image('train_true_src', make_grid(true_src, nrow=5), step)

                # Release this step's large tensors on every rank so they can be
                # freed before validation or the next forward pass allocates.
                del true_pha, true_fgr, true_src, true_bgr
                del pred_pha, pred_fgr, pred_pha_sm, pred_fgr_sm, pred_err_sm

                if rank == 0:
                    # Validation cadence counts iterations within the epoch;
                    # checkpoint cadence counts global steps.
                    if (i + 1) % args.log_valid_interval == 0:
                        valid(model, dataloader_valid, writer, step)
                    if (step + 1) % args.checkpoint_interval == 0:
                        torch.save(model.state_dict(), f'checkpoint/{args.model_name}/epoch-{epoch}-iter-{step}.pth')

            if rank == 0:
                torch.save(model.state_dict(), f'checkpoint/{args.model_name}/epoch-{epoch}.pth')
    finally:
        # Flush pending TensorBoard events and tear down NCCL even on failure.
        if writer is not None:
            writer.close()
        dist.destroy_process_group()


def _build_train_dataloader(rank):
    """Build this rank's training DataLoader over its own contiguous 1/N shard.

    Each rank always sees the same fixed slice of the zipped dataset (shuffled
    within the slice); the ``len % N`` trailing samples are never used.
    """
    dataset_train = ZipDataset([
        ZipDataset([
            ImagesDataset(DATA_PATH[args.dataset_name]['train']['pha'], mode='L'),
            ImagesDataset(DATA_PATH[args.dataset_name]['train']['fgr'], mode='RGB'),
        ], transforms=A.PairCompose([
            A.PairRandomAffineAndResize((2048, 2048), degrees=(-5, 5), translate=(0.1, 0.1), scale=(0.3, 1), shear=(-5, 5)),
            A.PairRandomHorizontalFlip(),
            A.PairRandomBoxBlur(0.1, 5),
            A.PairRandomSharpen(0.1),
            # Colour jitter only on index 1 (the foreground), never on the alpha.
            A.PairApplyOnlyAtIndices([1], T.ColorJitter(0.15, 0.15, 0.15, 0.05)),
            A.PairApply(T.ToTensor())
        ]), assert_equal_length=True),
        ImagesDataset(DATA_PATH['backgrounds']['train'], mode='RGB', transforms=T.Compose([
            A.RandomAffineAndResize((2048, 2048), degrees=(-5, 5), translate=(0.1, 0.1), scale=(1, 2), shear=(-5, 5)),
            T.RandomHorizontalFlip(),
            A.RandomBoxBlur(0.1, 5),
            A.RandomSharpen(0.1),
            T.ColorJitter(0.15, 0.15, 0.15, 0.05),
            T.ToTensor()
        ])),
    ])
    shard_len = len(dataset_train) // distributed_num_gpus
    dataset_train = Subset(dataset_train, range(rank * shard_len, (rank + 1) * shard_len))
    return DataLoader(dataset_train,
                      shuffle=True,
                      pin_memory=True,
                      drop_last=True,
                      batch_size=args.batch_size // distributed_num_gpus,
                      num_workers=args.num_workers // distributed_num_gpus)


def _build_valid_dataloader():
    """Build the (rank-0-only) validation DataLoader over a 50-item SampleDataset.

    NOTE(review): SampleDataset presumably draws 50 samples from the zipped
    validation set — confirm against dataset.SampleDataset.
    """
    dataset_valid = ZipDataset([
        ZipDataset([
            ImagesDataset(DATA_PATH[args.dataset_name]['valid']['pha'], mode='L'),
            ImagesDataset(DATA_PATH[args.dataset_name]['valid']['fgr'], mode='RGB')
        ], transforms=A.PairCompose([
            A.PairRandomAffineAndResize((2048, 2048), degrees=(-5, 5), translate=(0.1, 0.1), scale=(0.3, 1), shear=(-5, 5)),
            A.PairApply(T.ToTensor())
        ]), assert_equal_length=True),
        ImagesDataset(DATA_PATH['backgrounds']['valid'], mode='RGB', transforms=T.Compose([
            A.RandomAffineAndResize((2048, 2048), degrees=(-5, 5), translate=(0.1, 0.1), scale=(1, 1.2), shear=(-5, 5)),
            T.ToTensor()
        ])),
    ])
    dataset_valid = SampleDataset(dataset_valid, 50)
    return DataLoader(dataset_valid,
                      pin_memory=True,
                      drop_last=True,
                      batch_size=args.batch_size // distributed_num_gpus,
                      num_workers=args.num_workers // distributed_num_gpus)


def _build_model(rank):
    """Create the model on GPU ``rank``, optionally resume it, and wrap it in DDP.

    Returns:
        (model, model_distributed): the plain module (used for the optimizer,
        validation and checkpoints) and its DistributedDataParallel wrapper
        (used for the training forward pass).
    """
    model = MattingRefine(args.model_backbone,
                          args.model_backbone_scale,
                          args.model_refine_mode,
                          args.model_refine_sample_pixels,
                          args.model_refine_thresholding,
                          args.model_refine_kernel_size).to(rank)
    model = nn.SyncBatchNorm.convert_sync_batchnorm(model)

    if args.model_last_checkpoint is not None:
        # map_location: without it every rank would materialise the checkpoint
        # on the GPU it was saved from (typically GPU 0), spiking its memory.
        state_dict = torch.load(args.model_last_checkpoint, map_location=f'cuda:{rank}')
        load_matched_state_dict(model, state_dict)

    # Wrap after loading so DDP's construction-time broadcast from rank 0
    # guarantees every replica starts from identical weights.
    model_distributed = nn.parallel.DistributedDataParallel(model, device_ids=[rank])
    return model, model_distributed


def _augment_batch(true_pha, true_fgr, true_bgr):
    """Composite the training input and apply random batch augmentations.

    Each augmentation draws its own random boolean mask over the batch:
    shadow on the source (p=0.3), noise on source and background (p=0.4),
    background colour jitter (p=0.8) and a small background affine (p=0.3).

    Returns:
        (true_src, true_bgr): the composited source image and the perturbed
        background. ``true_bgr`` is updated in place and returned.
    """
    true_src = true_bgr.clone()

    # Shadow: darken the background by a blurred, randomly transformed and
    # scaled copy of the alpha matte, before the foreground is composited on top.
    aug_shadow_idx = torch.rand(len(true_src)) < 0.3
    if aug_shadow_idx.any():
        aug_shadow = true_pha[aug_shadow_idx].mul(0.3 * random.random())
        aug_shadow = T.RandomAffine(degrees=(-5, 5), translate=(0.2, 0.2), scale=(0.5, 1.5), shear=(-5, 5))(aug_shadow)
        aug_shadow = kornia.filters.box_blur(aug_shadow, (random.choice(range(20, 40)),) * 2)
        true_src[aug_shadow_idx] = true_src[aug_shadow_idx].sub_(aug_shadow).clamp_(0, 1)

    # Composite foreground onto source.
    true_src = true_fgr * true_pha + true_src * (1 - true_pha)

    # Gaussian noise, drawn independently for source and background.
    aug_noise_idx = torch.rand(len(true_src)) < 0.4
    if aug_noise_idx.any():
        true_src[aug_noise_idx] = true_src[aug_noise_idx].add_(torch.randn_like(true_src[aug_noise_idx]).mul_(0.03 * random.random())).clamp_(0, 1)
        true_bgr[aug_noise_idx] = true_bgr[aug_noise_idx].add_(torch.randn_like(true_bgr[aug_noise_idx]).mul_(0.03 * random.random())).clamp_(0, 1)

    # Colour jitter on the background only, so it no longer matches the source exactly.
    aug_jitter_idx = torch.rand(len(true_src)) < 0.8
    if aug_jitter_idx.any():
        true_bgr[aug_jitter_idx] = kornia.augmentation.ColorJitter(0.18, 0.18, 0.18, 0.1)(true_bgr[aug_jitter_idx])

    # Slight affine misalignment of the background.
    aug_affine_idx = torch.rand(len(true_bgr)) < 0.3
    if aug_affine_idx.any():
        true_bgr[aug_affine_idx] = T.RandomAffine(degrees=(-1, 1), translate=(0.01, 0.01))(true_bgr[aug_affine_idx])

    return true_src, true_bgr
-
-
- # --------------- Utils ---------------
def compute_loss(pred_pha_lg, pred_fgr_lg, pred_pha_sm, pred_fgr_sm, pred_err_sm, true_pha_lg, true_fgr_lg):
    """Return the total MattingRefine loss as the unweighted sum of seven terms.

    Ground truth is resized to the spatial size of the ``*_sm`` predictions
    for the small-scale terms. In order, the terms are:

    1-2. L1 on alpha (``*_lg`` scale, then ``*_sm`` scale);
    3-4. L1 on Sobel responses of alpha at both scales;
    5-6. L1 on foreground at both scales, restricted to pixels whose true
         alpha is non-zero;
    7.   MSE between the upsampled predicted error map and the absolute
         difference between the upsampled small-scale alpha and the true alpha.
    """
    full_size = true_pha_lg.shape[2:]
    true_pha_sm = kornia.resize(true_pha_lg, pred_pha_sm.shape[2:])
    true_fgr_sm = kornia.resize(true_fgr_lg, pred_fgr_sm.shape[2:])

    # Foreground colour carries no information where alpha is exactly 0.
    fg_mask_lg = true_pha_lg != 0
    fg_mask_sm = true_pha_sm != 0

    alpha_lg = F.l1_loss(pred_pha_lg, true_pha_lg)
    alpha_sm = F.l1_loss(pred_pha_sm, true_pha_sm)
    grad_lg = F.l1_loss(kornia.sobel(pred_pha_lg), kornia.sobel(true_pha_lg))
    grad_sm = F.l1_loss(kornia.sobel(pred_pha_sm), kornia.sobel(true_pha_sm))
    fgr_lg = F.l1_loss(pred_fgr_lg * fg_mask_lg, true_fgr_lg * fg_mask_lg)
    fgr_sm = F.l1_loss(pred_fgr_sm * fg_mask_sm, true_fgr_sm * fg_mask_sm)

    # The error head is supervised with the actual error of the upsampled coarse alpha.
    actual_err = kornia.resize(pred_pha_sm, full_size).sub(true_pha_lg).abs()
    err = F.mse_loss(kornia.resize(pred_err_sm, full_size), actual_err)

    # Accumulate left to right, matching the original summation order.
    total = alpha_lg
    for term in (alpha_sm, grad_lg, grad_sm, fgr_lg, fgr_sm, err):
        total = total + term
    return total
def random_crop(*imgs):
    """Resize and center-crop every tensor in ``imgs`` to one shared random size.

    A target size is drawn with each side chosen uniformly from
    ``range(1024, 2048)`` and rounded down to a multiple of 4 (so 1024..2044).
    Every image is resized by the same factor, the smallest that makes the
    first image cover the target, then center-cropped to exactly that size.

    Args:
        *imgs: batched image tensors; the source size is read from
            ``imgs[0].shape[2:]``, so all inputs are assumed to share it
            (TODO confirm for new callers).

    Returns:
        A list of cropped tensors, in input order.
    """
    src_h, src_w = imgs[0].shape[2:]
    tgt_w = random.choice(range(1024, 2048)) // 4 * 4
    tgt_h = random.choice(range(1024, 2048)) // 4 * 4
    scale = max(tgt_w / src_w, tgt_h / src_h)

    # int() truncates, and src * (tgt / src) can land a hair below tgt in
    # floating point for non-power-of-two sources (e.g. 1023.999...). Clamp so
    # the resized image is never smaller than the crop window.
    resized_h = max(tgt_h, int(src_h * scale))
    resized_w = max(tgt_w, int(src_w * scale))

    results = []
    for img in imgs:
        img = kornia.resize(img, (resized_h, resized_w))
        img = kornia.center_crop(img, (tgt_h, tgt_w))
        results.append(img)
    return results
def valid(model, dataloader, writer, step):
    """Evaluate ``model`` on ``dataloader`` and log the mean loss as ``valid_loss``.

    The logged value is a per-sample average: each batch's mean loss is
    weighted by its batch size. The model runs in eval mode without gradients
    and is always switched back to train mode afterwards, even if evaluation
    raises. If the dataloader yields no batches, nothing is logged instead of
    dividing by zero.

    Args:
        model: the unwrapped matting model; batches are moved to the device of
            its parameters.
        dataloader: yields ``((true_pha, true_fgr), true_bgr)`` batches.
        writer: TensorBoard ``SummaryWriter`` that receives the scalar.
        step: global training step used as the scalar's x-coordinate.
    """
    device = next(model.parameters()).device
    loss_total = 0.0
    loss_count = 0
    model.eval()
    try:
        with torch.no_grad():
            for (true_pha, true_fgr), true_bgr in dataloader:
                batch_size = true_pha.size(0)

                true_pha = true_pha.to(device, non_blocking=True)
                true_fgr = true_fgr.to(device, non_blocking=True)
                true_bgr = true_bgr.to(device, non_blocking=True)
                # Validation uses the plain composite, without training-time augmentation.
                true_src = true_pha * true_fgr + (1 - true_pha) * true_bgr

                pred_pha, pred_fgr, pred_pha_sm, pred_fgr_sm, pred_err_sm, _ = model(true_src, true_bgr)
                loss = compute_loss(pred_pha, pred_fgr, pred_pha_sm, pred_fgr_sm, pred_err_sm, true_pha, true_fgr)
                loss_total += loss.item() * batch_size
                loss_count += batch_size
    finally:
        model.train()

    # An empty loader (e.g. drop_last=True with a batch larger than the
    # dataset) must not crash training with a ZeroDivisionError.
    if loss_count > 0:
        writer.add_scalar('valid_loss', loss_total / loss_count, step)
- # --------------- Start ---------------
def find_free_port(host='localhost'):
    """Return a TCP port number on ``host`` that the OS reports as currently free.

    The port is obtained by binding to port 0 and is released before
    returning. Another process could still claim it before the rendezvous
    binds it, but this is far less collision-prone than guessing a port at random.
    """
    import socket
    with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as sock:
        sock.bind((host, 0))
        return sock.getsockname()[1]


if __name__ == '__main__':
    addr = 'localhost'
    port = str(find_free_port(addr))  # MASTER_PORT must be a string
    # One training process per visible GPU; join=True blocks until all exit.
    mp.spawn(train_worker,
             nprocs=distributed_num_gpus,
             args=(addr, port),
             join=True)
|