train_base.py

  1. """
  2. Train MattingBase
  3. You can download pretrained DeepLabV3 weights from <https://github.com/VainF/DeepLabV3Plus-Pytorch>
  4. Example:
  5. CUDA_VISIBLE_DEVICES=0 python train_base.py \
  6. --dataset-name videomatte240k \
  7. --model-backbone resnet50 \
  8. --model-name mattingbase-resnet50-videomatte240k \
  9. --model-pretrain-initialization "pretraining/best_deeplabv3_resnet50_voc_os16.pth" \
  10. --epoch-end 8
  11. """
  12. import argparse
  13. import kornia
  14. import torch
  15. import os
  16. import random
  17. from torch import nn
  18. from torch.nn import functional as F
  19. from torch.cuda.amp import autocast, GradScaler
  20. from torch.utils.tensorboard import SummaryWriter
  21. from torch.utils.data import DataLoader
  22. from torch.optim import Adam
  23. from torchvision.utils import make_grid
  24. from tqdm import tqdm
  25. from torchvision import transforms as T
  26. from PIL import Image
  27. from data_path import DATA_PATH
  28. from dataset import ImagesDataset, ZipDataset, VideoDataset, SampleDataset
  29. from dataset import augmentation as A
  30. from model import MattingBase
  31. from model.utils import load_matched_state_dict
  32. # --------------- Arguments ---------------
  33. parser = argparse.ArgumentParser()
  34. parser.add_argument('--dataset-name', type=str, required=True, choices=DATA_PATH.keys())
  35. parser.add_argument('--model-backbone', type=str, required=True, choices=['resnet101', 'resnet50', 'mobilenetv2'])
  36. parser.add_argument('--model-name', type=str, required=True)
  37. parser.add_argument('--model-pretrain-initialization', type=str, default=None)
  38. parser.add_argument('--model-last-checkpoint', type=str, default=None)
  39. parser.add_argument('--batch-size', type=int, default=8)
  40. parser.add_argument('--num-workers', type=int, default=16)
  41. parser.add_argument('--epoch-start', type=int, default=0)
  42. parser.add_argument('--epoch-end', type=int, required=True)
  43. parser.add_argument('--log-train-loss-interval', type=int, default=10)
  44. parser.add_argument('--log-train-images-interval', type=int, default=2000)
  45. parser.add_argument('--log-valid-interval', type=int, default=5000)
  46. parser.add_argument('--checkpoint-interval', type=int, default=5000)
  47. args = parser.parse_args()
  48. # --------------- Loading ---------------
  49. def train():
  50. # Training DataLoader
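    # Each sample is ((alpha, foreground), background): the inner ZipDataset
    # zips matching pha/fgr frames and applies geometric transforms to the
    # pair jointly via PairCompose (color jitter hits only the foreground,
    # index 1), while the outer ZipDataset attaches an independently
    # augmented background image.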
  51. dataset_train = ZipDataset([
  52. ZipDataset([
  53. ImagesDataset(DATA_PATH[args.dataset_name]['train']['pha'], mode='L'),
  54. ImagesDataset(DATA_PATH[args.dataset_name]['train']['fgr'], mode='RGB'),
  55. ], transforms=A.PairCompose([
  56. A.PairRandomAffineAndResize((512, 512), degrees=(-5, 5), translate=(0.1, 0.1), scale=(0.4, 1), shear=(-5, 5)),
  57. A.PairRandomHorizontalFlip(),
  58. A.PairRandomBoxBlur(0.1, 5),
  59. A.PairRandomSharpen(0.1),
  60. A.PairApplyOnlyAtIndices([1], T.ColorJitter(0.15, 0.15, 0.15, 0.05)),
  61. A.PairApply(T.ToTensor())
  62. ]), assert_equal_length=True),
  63. ImagesDataset(DATA_PATH['backgrounds']['train'], mode='RGB', transforms=T.Compose([
  64. A.RandomAffineAndResize((512, 512), degrees=(-5, 5), translate=(0.1, 0.1), scale=(1, 2), shear=(-5, 5)),
  65. T.RandomHorizontalFlip(),
  66. A.RandomBoxBlur(0.1, 5),
  67. A.RandomSharpen(0.1),
  68. T.ColorJitter(0.15, 0.15, 0.15, 0.05),
  69. T.ToTensor()
  70. ])),
  71. ])
  72. dataloader_train = DataLoader(dataset_train,
  73. shuffle=True,
  74. batch_size=args.batch_size,
  75. num_workers=args.num_workers,
  76. pin_memory=True)
  77. # Validation DataLoader
  78. dataset_valid = ZipDataset([
  79. ZipDataset([
  80. ImagesDataset(DATA_PATH[args.dataset_name]['valid']['pha'], mode='L'),
  81. ImagesDataset(DATA_PATH[args.dataset_name]['valid']['fgr'], mode='RGB')
  82. ], transforms=A.PairCompose([
  83. A.PairRandomAffineAndResize((512, 512), degrees=(-5, 5), translate=(0.1, 0.1), scale=(0.3, 1), shear=(-5, 5)),
  84. A.PairApply(T.ToTensor())
  85. ]), assert_equal_length=True),
  86. ImagesDataset(DATA_PATH['backgrounds']['valid'], mode='RGB', transforms=T.Compose([
  87. A.RandomAffineAndResize((512, 512), degrees=(-5, 5), translate=(0.1, 0.1), scale=(1, 1.2), shear=(-5, 5)),
  88. T.ToTensor()
  89. ])),
  90. ])
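    # Validate on a fixed small subset (presumably SampleDataset draws 50
    # samples) so the periodic validation pass stays cheap.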
    dataset_valid = SampleDataset(dataset_valid, 50)
    dataloader_valid = DataLoader(dataset_valid,
                                  pin_memory=True,
                                  batch_size=args.batch_size,
                                  num_workers=args.num_workers)

    # Model
    model = MattingBase(args.model_backbone).cuda()

    if args.model_last_checkpoint is not None:
        load_matched_state_dict(model, torch.load(args.model_last_checkpoint))
    elif args.model_pretrain_initialization is not None:
        model.load_pretrained_deeplabv3_state_dict(torch.load(args.model_pretrain_initialization)['model_state'])
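
    # Train the (possibly pretrained) backbone with a smaller learning rate
    # than the freshly initialized ASPP and decoder heads.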
    optimizer = Adam([
        {'params': model.backbone.parameters(), 'lr': 1e-4},
        {'params': model.aspp.parameters(), 'lr': 5e-4},
        {'params': model.decoder.parameters(), 'lr': 5e-4}
    ])
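    # GradScaler scales the loss to keep fp16 gradients from underflowing;
    # it pairs with the autocast() context in the training loop below.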
    scaler = GradScaler()

    # Logging and checkpoints
    if not os.path.exists(f'checkpoint/{args.model_name}'):
        os.makedirs(f'checkpoint/{args.model_name}')
    writer = SummaryWriter(f'log/{args.model_name}')

    # Run loop
    for epoch in range(args.epoch_start, args.epoch_end):
        for i, ((true_pha, true_fgr), true_bgr) in enumerate(tqdm(dataloader_train)):
            step = epoch * len(dataloader_train) + i

            true_pha = true_pha.cuda(non_blocking=True)
            true_fgr = true_fgr.cuda(non_blocking=True)
            true_bgr = true_bgr.cuda(non_blocking=True)
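            # Crop all three maps to the same random size so they stay
            # spatially aligned (see random_crop below).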
            true_pha, true_fgr, true_bgr = random_crop(true_pha, true_fgr, true_bgr)

            true_src = true_bgr.clone()

            # Augment with shadow
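            # For ~30% of samples, subtract a randomly scaled (up to 0.3),
            # affine-jittered, box-blurred copy of the alpha from the
            # background, faking a soft shadow cast by the subject.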
            aug_shadow_idx = torch.rand(len(true_src)) < 0.3
            if aug_shadow_idx.any():
                aug_shadow = true_pha[aug_shadow_idx].mul(0.3 * random.random())
                aug_shadow = T.RandomAffine(degrees=(-5, 5), translate=(0.2, 0.2), scale=(0.5, 1.5), shear=(-5, 5))(aug_shadow)
                aug_shadow = kornia.filters.box_blur(aug_shadow, (random.choice(range(20, 40)),) * 2)
                true_src[aug_shadow_idx] = true_src[aug_shadow_idx].sub_(aug_shadow).clamp_(0, 1)
                del aug_shadow
            del aug_shadow_idx

            # Composite foreground onto source
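            # Standard compositing equation: src = alpha * fgr + (1 - alpha) * bgr,
            # where bgr may have been shadow-darkened above.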
            true_src = true_fgr * true_pha + true_src * (1 - true_pha)

            # Augment with noise
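            # Add independent Gaussian noise (std up to 0.03) to ~40% of
            # src/bgr pairs, likely to mimic sensor noise.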
            aug_noise_idx = torch.rand(len(true_src)) < 0.4
            if aug_noise_idx.any():
                true_src[aug_noise_idx] = true_src[aug_noise_idx].add_(torch.randn_like(true_src[aug_noise_idx]).mul_(0.03 * random.random())).clamp_(0, 1)
                true_bgr[aug_noise_idx] = true_bgr[aug_noise_idx].add_(torch.randn_like(true_bgr[aug_noise_idx]).mul_(0.03 * random.random())).clamp_(0, 1)
            del aug_noise_idx

            # Augment background with jitter
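            # Color-jitter the captured background only, so it no longer
            # matches the composite exactly, as tends to happen with real
            # background shots.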
            aug_jitter_idx = torch.rand(len(true_src)) < 0.8
            if aug_jitter_idx.any():
                true_bgr[aug_jitter_idx] = kornia.augmentation.ColorJitter(0.18, 0.18, 0.18, 0.1)(true_bgr[aug_jitter_idx])
            del aug_jitter_idx

            # Augment background with affine
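            # Slightly shift/rotate the background to simulate camera
            # misalignment between the background shot and the source frame.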
            aug_affine_idx = torch.rand(len(true_bgr)) < 0.3
            if aug_affine_idx.any():
                true_bgr[aug_affine_idx] = T.RandomAffine(degrees=(-1, 1), translate=(0.01, 0.01))(true_bgr[aug_affine_idx])
            del aug_affine_idx
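            # Forward pass and loss run in mixed precision under autocast();
            # backward runs on the scaled loss, and scaler.step() unscales the
            # gradients before the optimizer update.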
            with autocast():
                pred_pha, pred_fgr, pred_err = model(true_src, true_bgr)[:3]
                loss = compute_loss(pred_pha, pred_fgr, pred_err, true_pha, true_fgr)
            scaler.scale(loss).backward()
            scaler.step(optimizer)
            scaler.update()
            optimizer.zero_grad()

            if (i + 1) % args.log_train_loss_interval == 0:
                writer.add_scalar('loss', loss, step)

            if (i + 1) % args.log_train_images_interval == 0:
                writer.add_image('train_pred_pha', make_grid(pred_pha, nrow=5), step)
                writer.add_image('train_pred_fgr', make_grid(pred_fgr, nrow=5), step)
                writer.add_image('train_pred_com', make_grid(pred_fgr * pred_pha, nrow=5), step)
                writer.add_image('train_pred_err', make_grid(pred_err, nrow=5), step)
                writer.add_image('train_true_src', make_grid(true_src, nrow=5), step)
                writer.add_image('train_true_bgr', make_grid(true_bgr, nrow=5), step)

            del true_pha, true_fgr, true_bgr
            del pred_pha, pred_fgr, pred_err

            if (i + 1) % args.log_valid_interval == 0:
                valid(model, dataloader_valid, writer, step)

            if (step + 1) % args.checkpoint_interval == 0:
                torch.save(model.state_dict(), f'checkpoint/{args.model_name}/epoch-{epoch}-iter-{step}.pth')

        torch.save(model.state_dict(), f'checkpoint/{args.model_name}/epoch-{epoch}.pth')


# --------------- Utils ---------------


def compute_loss(pred_pha, pred_fgr, pred_err, true_pha, true_fgr):
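    """
    Loss = L1(alpha) + L1(sobel gradients of alpha)
         + L1(foreground, only where true alpha > 0)
         + MSE(predicted error map, actual |pred_alpha - true_alpha|)
    """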
    true_err = torch.abs(pred_pha.detach() - true_pha)
    true_msk = true_pha != 0
    return F.l1_loss(pred_pha, true_pha) + \
           F.l1_loss(kornia.sobel(pred_pha), kornia.sobel(true_pha)) + \
           F.l1_loss(pred_fgr * true_msk, true_fgr * true_msk) + \
           F.mse_loss(pred_err, true_err)


def random_crop(*imgs):
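    """
    Resize then center-crop every input to the same randomly chosen
    h x w in [256, 512), so spatially aligned maps stay aligned.
    """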
    w = random.choice(range(256, 512))
    h = random.choice(range(256, 512))
    results = []
    for img in imgs:
        img = kornia.resize(img, (max(h, w), max(h, w)))
        img = kornia.center_crop(img, (h, w))
        results.append(img)
    return results


def valid(model, dataloader, writer, step):
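    """
    Compute the batch-size-weighted mean loss over the validation set with
    the model in eval mode, log it to TensorBoard, then restore train mode.
    """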
    model.eval()
    loss_total = 0
    loss_count = 0
    with torch.no_grad():
        for (true_pha, true_fgr), true_bgr in dataloader:
            batch_size = true_pha.size(0)

            true_pha = true_pha.cuda(non_blocking=True)
            true_fgr = true_fgr.cuda(non_blocking=True)
            true_bgr = true_bgr.cuda(non_blocking=True)
            true_src = true_pha * true_fgr + (1 - true_pha) * true_bgr

            pred_pha, pred_fgr, pred_err = model(true_src, true_bgr)[:3]
            loss = compute_loss(pred_pha, pred_fgr, pred_err, true_pha, true_fgr)
            loss_total += loss.cpu().item() * batch_size
            loss_count += batch_size

    writer.add_scalar('valid_loss', loss_total / loss_count, step)
    model.train()


# --------------- Start ---------------


if __name__ == '__main__':
    train()