# train_refine.py
  1. """
  2. Train MattingRefine
  3. Supports multi-GPU training with DistributedDataParallel() and SyncBatchNorm.
  4. Select GPUs through CUDA_VISIBLE_DEVICES environment variable.
  5. Example:
  6. CUDA_VISIBLE_DEVICES=0,1 python train_refine.py \
  7. --dataset-name videomatte240k \
  8. --model-backbone resnet50 \
  9. --model-name mattingrefine-resnet50-videomatte240k \
  10. --model-last-checkpoint "PATH_TO_LAST_CHECKPOINT" \
  11. --epoch-end 1
  12. """
  13. import argparse
  14. import kornia
  15. import torch
  16. import os
  17. import random
  18. from torch import nn
  19. from torch import distributed as dist
  20. from torch import multiprocessing as mp
  21. from torch.nn import functional as F
  22. from torch.cuda.amp import autocast, GradScaler
  23. from torch.utils.tensorboard import SummaryWriter
  24. from torch.utils.data import DataLoader, Subset
  25. from torch.optim import Adam
  26. from torchvision.utils import make_grid
  27. from tqdm import tqdm
  28. from torchvision import transforms as T
  29. from PIL import Image
  30. from data_path import DATA_PATH
  31. from dataset import ImagesDataset, ZipDataset, VideoDataset, SampleDataset
  32. from dataset import augmentation as A
  33. from model import MattingRefine
  34. from model.utils import load_matched_state_dict
  35. # --------------- Arguments ---------------
  36. parser = argparse.ArgumentParser()
  37. parser.add_argument('--dataset-name', type=str, required=True, choices=DATA_PATH.keys())
  38. parser.add_argument('--model-backbone', type=str, required=True, choices=['resnet101', 'resnet50', 'mobilenetv2'])
  39. parser.add_argument('--model-backbone-scale', type=float, default=0.25)
  40. parser.add_argument('--model-refine-mode', type=str, default='sampling', choices=['full', 'sampling', 'thresholding'])
  41. parser.add_argument('--model-refine-sample-pixels', type=int, default=80_000)
  42. parser.add_argument('--model-refine-thresholding', type=float, default=0.7)
  43. parser.add_argument('--model-refine-kernel-size', type=int, default=3, choices=[1, 3])
  44. parser.add_argument('--model-name', type=str, required=True)
  45. parser.add_argument('--model-last-checkpoint', type=str, default=None)
  46. parser.add_argument('--batch-size', type=int, default=4)
  47. parser.add_argument('--num-workers', type=int, default=16)
  48. parser.add_argument('--epoch-start', type=int, default=0)
  49. parser.add_argument('--epoch-end', type=int, required=True)
  50. parser.add_argument('--log-train-loss-interval', type=int, default=10)
  51. parser.add_argument('--log-train-images-interval', type=int, default=1000)
  52. parser.add_argument('--log-valid-interval', type=int, default=2000)
  53. parser.add_argument('--checkpoint-interval', type=int, default=2000)
  54. args = parser.parse_args()
  55. distributed_num_gpus = torch.cuda.device_count()
  56. assert args.batch_size % distributed_num_gpus == 0
  57. # --------------- Main ---------------
  58. def train_worker(rank, addr, port):
  59. # Distributed Setup
  60. os.environ['MASTER_ADDR'] = addr
  61. os.environ['MASTER_PORT'] = port
  62. dist.init_process_group("nccl", rank=rank, world_size=distributed_num_gpus)
  63. # Training DataLoader
  64. dataset_train = ZipDataset([
  65. ZipDataset([
  66. ImagesDataset(DATA_PATH[args.dataset_name]['train']['pha'], mode='L'),
  67. ImagesDataset(DATA_PATH[args.dataset_name]['train']['fgr'], mode='RGB'),
  68. ], transforms=A.PairCompose([
  69. A.PairRandomAffineAndResize((2048, 2048), degrees=(-5, 5), translate=(0.1, 0.1), scale=(0.3, 1), shear=(-5, 5)),
  70. A.PairRandomHorizontalFlip(),
  71. A.PairRandomBoxBlur(0.1, 5),
  72. A.PairRandomSharpen(0.1),
  73. A.PairApplyOnlyAtIndices([1], T.ColorJitter(0.15, 0.15, 0.15, 0.05)),
  74. A.PairApply(T.ToTensor())
  75. ]), assert_equal_length=True),
  76. ImagesDataset(DATA_PATH['backgrounds']['train'], mode='RGB', transforms=T.Compose([
  77. A.RandomAffineAndResize((2048, 2048), degrees=(-5, 5), translate=(0.1, 0.1), scale=(1, 2), shear=(-5, 5)),
  78. T.RandomHorizontalFlip(),
  79. A.RandomBoxBlur(0.1, 5),
  80. A.RandomSharpen(0.1),
  81. T.ColorJitter(0.15, 0.15, 0.15, 0.05),
  82. T.ToTensor()
  83. ])),
  84. ])
  85. dataset_train_len_per_gpu_worker = int(len(dataset_train) / distributed_num_gpus)
  86. dataset_train = Subset(dataset_train, range(rank * dataset_train_len_per_gpu_worker, (rank + 1) * dataset_train_len_per_gpu_worker))
  87. dataloader_train = DataLoader(dataset_train,
  88. shuffle=True,
  89. pin_memory=True,
  90. drop_last=True,
  91. batch_size=args.batch_size // distributed_num_gpus,
  92. num_workers=args.num_workers // distributed_num_gpus)
  93. # Validation DataLoader
  94. if rank == 0:
  95. dataset_valid = ZipDataset([
  96. ZipDataset([
  97. ImagesDataset(DATA_PATH[args.dataset_name]['valid']['pha'], mode='L'),
  98. ImagesDataset(DATA_PATH[args.dataset_name]['valid']['fgr'], mode='RGB')
  99. ], transforms=A.PairCompose([
  100. A.PairRandomAffineAndResize((2048, 2048), degrees=(-5, 5), translate=(0.1, 0.1), scale=(0.3, 1), shear=(-5, 5)),
  101. A.PairApply(T.ToTensor())
  102. ]), assert_equal_length=True),
  103. ImagesDataset(DATA_PATH['backgrounds']['valid'], mode='RGB', transforms=T.Compose([
  104. A.RandomAffineAndResize((2048, 2048), degrees=(-5, 5), translate=(0.1, 0.1), scale=(1, 1.2), shear=(-5, 5)),
  105. T.ToTensor()
  106. ])),
  107. ])
  108. dataset_valid = SampleDataset(dataset_valid, 50)
  109. dataloader_valid = DataLoader(dataset_valid,
  110. pin_memory=True,
  111. drop_last=True,
  112. batch_size=args.batch_size // distributed_num_gpus,
  113. num_workers=args.num_workers // distributed_num_gpus)
  114. # Model
  115. model = MattingRefine(args.model_backbone,
  116. args.model_backbone_scale,
  117. args.model_refine_mode,
  118. args.model_refine_sample_pixels,
  119. args.model_refine_thresholding,
  120. args.model_refine_kernel_size).to(rank)
  121. model = nn.SyncBatchNorm.convert_sync_batchnorm(model)
  122. model_distributed = nn.parallel.DistributedDataParallel(model, device_ids=[rank])
  123. if args.model_last_checkpoint is not None:
  124. load_matched_state_dict(model, torch.load(args.model_last_checkpoint))
  125. optimizer = Adam([
  126. {'params': model.backbone.parameters(), 'lr': 5e-5},
  127. {'params': model.aspp.parameters(), 'lr': 5e-5},
  128. {'params': model.decoder.parameters(), 'lr': 1e-4},
  129. {'params': model.refiner.parameters(), 'lr': 3e-4},
  130. ])
  131. scaler = GradScaler()
  132. # Logging and checkpoints
  133. if rank == 0:
  134. if not os.path.exists(f'checkpoint/{args.model_name}'):
  135. os.makedirs(f'checkpoint/{args.model_name}')
  136. writer = SummaryWriter(f'log/{args.model_name}')
  137. # Run loop
  138. for epoch in range(args.epoch_start, args.epoch_end):
  139. for i, ((true_pha, true_fgr), true_bgr) in enumerate(tqdm(dataloader_train)):
  140. step = epoch * len(dataloader_train) + i
  141. true_pha = true_pha.to(rank, non_blocking=True)
  142. true_fgr = true_fgr.to(rank, non_blocking=True)
  143. true_bgr = true_bgr.to(rank, non_blocking=True)
  144. true_pha, true_fgr, true_bgr = random_crop(true_pha, true_fgr, true_bgr)
  145. true_src = true_bgr.clone()
  146. # Augment with shadow
  147. aug_shadow_idx = torch.rand(len(true_src)) < 0.3
  148. if aug_shadow_idx.any():
  149. aug_shadow = true_pha[aug_shadow_idx].mul(0.3 * random.random())
  150. aug_shadow = T.RandomAffine(degrees=(-5, 5), translate=(0.2, 0.2), scale=(0.5, 1.5), shear=(-5, 5))(aug_shadow)
  151. aug_shadow = kornia.filters.box_blur(aug_shadow, (random.choice(range(20, 40)),) * 2)
  152. true_src[aug_shadow_idx] = true_src[aug_shadow_idx].sub_(aug_shadow).clamp_(0, 1)
  153. del aug_shadow
  154. del aug_shadow_idx
  155. # Composite foreground onto source
  156. true_src = true_fgr * true_pha + true_src * (1 - true_pha)
  157. # Augment with noise
  158. aug_noise_idx = torch.rand(len(true_src)) < 0.4
  159. if aug_noise_idx.any():
  160. true_src[aug_noise_idx] = true_src[aug_noise_idx].add_(torch.randn_like(true_src[aug_noise_idx]).mul_(0.03 * random.random())).clamp_(0, 1)
  161. true_bgr[aug_noise_idx] = true_bgr[aug_noise_idx].add_(torch.randn_like(true_bgr[aug_noise_idx]).mul_(0.03 * random.random())).clamp_(0, 1)
  162. del aug_noise_idx
  163. # Augment background with jitter
  164. aug_jitter_idx = torch.rand(len(true_src)) < 0.8
  165. if aug_jitter_idx.any():
  166. true_bgr[aug_jitter_idx] = kornia.augmentation.ColorJitter(0.18, 0.18, 0.18, 0.1)(true_bgr[aug_jitter_idx])
  167. del aug_jitter_idx
  168. # Augment background with affine
  169. aug_affine_idx = torch.rand(len(true_bgr)) < 0.3
  170. if aug_affine_idx.any():
  171. true_bgr[aug_affine_idx] = T.RandomAffine(degrees=(-1, 1), translate=(0.01, 0.01))(true_bgr[aug_affine_idx])
  172. del aug_affine_idx
  173. with autocast():
  174. pred_pha, pred_fgr, pred_pha_sm, pred_fgr_sm, pred_err_sm, _ = model_distributed(true_src, true_bgr)
  175. loss = compute_loss(pred_pha, pred_fgr, pred_pha_sm, pred_fgr_sm, pred_err_sm, true_pha, true_fgr)
  176. scaler.scale(loss).backward()
  177. scaler.step(optimizer)
  178. scaler.update()
  179. optimizer.zero_grad()
  180. if rank == 0:
  181. if (i + 1) % args.log_train_loss_interval == 0:
  182. writer.add_scalar('loss', loss, step)
  183. if (i + 1) % args.log_train_images_interval == 0:
  184. writer.add_image('train_pred_pha', make_grid(pred_pha, nrow=5), step)
  185. writer.add_image('train_pred_fgr', make_grid(pred_fgr, nrow=5), step)
  186. writer.add_image('train_pred_com', make_grid(pred_fgr * pred_pha, nrow=5), step)
  187. writer.add_image('train_pred_err', make_grid(pred_err_sm, nrow=5), step)
  188. writer.add_image('train_true_src', make_grid(true_src, nrow=5), step)
  189. del true_pha, true_fgr, true_src, true_bgr
  190. del pred_pha, pred_fgr, pred_pha_sm, pred_fgr_sm, pred_err_sm
  191. if (i + 1) % args.log_valid_interval == 0:
  192. valid(model, dataloader_valid, writer, step)
  193. if (step + 1) % args.checkpoint_interval == 0:
  194. torch.save(model.state_dict(), f'checkpoint/{args.model_name}/epoch-{epoch}-iter-{step}.pth')
  195. if rank == 0:
  196. torch.save(model.state_dict(), f'checkpoint/{args.model_name}/epoch-{epoch}.pth')
  197. # Clean up
  198. dist.destroy_process_group()
  199. # --------------- Utils ---------------
  200. def compute_loss(pred_pha_lg, pred_fgr_lg, pred_pha_sm, pred_fgr_sm, pred_err_sm, true_pha_lg, true_fgr_lg):
  201. true_pha_sm = kornia.resize(true_pha_lg, pred_pha_sm.shape[2:])
  202. true_fgr_sm = kornia.resize(true_fgr_lg, pred_fgr_sm.shape[2:])
  203. true_msk_lg = true_pha_lg != 0
  204. true_msk_sm = true_pha_sm != 0
  205. return F.l1_loss(pred_pha_lg, true_pha_lg) + \
  206. F.l1_loss(pred_pha_sm, true_pha_sm) + \
  207. F.l1_loss(kornia.sobel(pred_pha_lg), kornia.sobel(true_pha_lg)) + \
  208. F.l1_loss(kornia.sobel(pred_pha_sm), kornia.sobel(true_pha_sm)) + \
  209. F.l1_loss(pred_fgr_lg * true_msk_lg, true_fgr_lg * true_msk_lg) + \
  210. F.l1_loss(pred_fgr_sm * true_msk_sm, true_fgr_sm * true_msk_sm) + \
  211. F.mse_loss(kornia.resize(pred_err_sm, true_pha_lg.shape[2:]), \
  212. kornia.resize(pred_pha_sm, true_pha_lg.shape[2:]).sub(true_pha_lg).abs())
  213. def random_crop(*imgs):
  214. H_src, W_src = imgs[0].shape[2:]
  215. W_tgt = random.choice(range(1024, 2048)) // 4 * 4
  216. H_tgt = random.choice(range(1024, 2048)) // 4 * 4
  217. scale = max(W_tgt / W_src, H_tgt / H_src)
  218. results = []
  219. for img in imgs:
  220. img = kornia.resize(img, (int(H_src * scale), int(W_src * scale)))
  221. img = kornia.center_crop(img, (H_tgt, W_tgt))
  222. results.append(img)
  223. return results
  224. def valid(model, dataloader, writer, step):
  225. model.eval()
  226. loss_total = 0
  227. loss_count = 0
  228. with torch.no_grad():
  229. for (true_pha, true_fgr), true_bgr in dataloader:
  230. batch_size = true_pha.size(0)
  231. true_pha = true_pha.cuda(non_blocking=True)
  232. true_fgr = true_fgr.cuda(non_blocking=True)
  233. true_bgr = true_bgr.cuda(non_blocking=True)
  234. true_src = true_pha * true_fgr + (1 - true_pha) * true_bgr
  235. pred_pha, pred_fgr, pred_pha_sm, pred_fgr_sm, pred_err_sm, _ = model(true_src, true_bgr)
  236. loss = compute_loss(pred_pha, pred_fgr, pred_pha_sm, pred_fgr_sm, pred_err_sm, true_pha, true_fgr)
  237. loss_total += loss.cpu().item() * batch_size
  238. loss_count += batch_size
  239. writer.add_scalar('valid_loss', loss_total / loss_count, step)
  240. model.train()
  241. # --------------- Start ---------------
  242. if __name__ == '__main__':
  243. addr = 'localhost'
  244. port = str(random.choice(range(12300, 12400))) # pick a random port.
  245. mp.spawn(train_worker,
  246. nprocs=distributed_num_gpus,
  247. args=(addr, port),
  248. join=True)