|
@@ -121,7 +121,6 @@ def writer(img, path):
|
|
|
# Conversion loop
|
|
|
with torch.no_grad():
|
|
|
for i, (src, bgr) in enumerate(tqdm(dataloader)):
|
|
|
- filename = dataset.datasets[0].filenames[i]
|
|
|
src = src.to(device, non_blocking=True)
|
|
|
bgr = bgr.to(device, non_blocking=True)
|
|
|
|
|
@@ -129,19 +128,21 @@ with torch.no_grad():
|
|
|
pha, fgr, err, _ = model(src, bgr)
|
|
|
elif args.model_type == 'mattingrefine':
|
|
|
pha, fgr, _, _, err, ref = model(src, bgr)
|
|
|
- elif args.model_type == 'mattingbm':
|
|
|
- pha, fgr = model(src, bgr)
|
|
|
+
|
|
|
+ pathname = dataset.datasets[0].filenames[i]
|
|
|
+ pathname = os.path.relpath(pathname, args.images_src)
|
|
|
+ pathname = os.path.splitext(pathname)[0]
|
|
|
|
|
|
if 'com' in args.output_types:
|
|
|
com = torch.cat([fgr * pha.ne(0), pha], dim=1)
|
|
|
- Thread(target=writer, args=(com, filename.replace(args.images_src, os.path.join(args.output_dir, 'com')).replace('.jpg', '.png'))).start()
|
|
|
+ Thread(target=writer, args=(com, os.path.join(args.output_dir, 'com', pathname + '.png'))).start()
|
|
|
if 'pha' in args.output_types:
|
|
|
- Thread(target=writer, args=(pha, filename.replace(args.images_src, os.path.join(args.output_dir, 'pha')).replace('.png', '.jpg'))).start()
|
|
|
+ Thread(target=writer, args=(pha, os.path.join(args.output_dir, 'pha', pathname + '.jpg'))).start()
|
|
|
if 'fgr' in args.output_types:
|
|
|
- Thread(target=writer, args=(fgr, filename.replace(args.images_src, os.path.join(args.output_dir, 'fgr')).replace('.png', '.jpg'))).start()
|
|
|
+ Thread(target=writer, args=(fgr, os.path.join(args.output_dir, 'fgr', pathname + '.jpg'))).start()
|
|
|
if 'err' in args.output_types:
|
|
|
err = F.interpolate(err, src.shape[2:], mode='bilinear', align_corners=False)
|
|
|
- Thread(target=writer, args=(err, filename.replace(args.images_src, os.path.join(args.output_dir, 'err')).replace('.png', '.jpg'))).start()
|
|
|
+ Thread(target=writer, args=(err, os.path.join(args.output_dir, 'err', pathname + '.jpg'))).start()
|
|
|
if 'ref' in args.output_types:
|
|
|
ref = F.interpolate(ref, src.shape[2:], mode='nearest')
|
|
|
- Thread(target=writer, args=(ref, filename.replace(args.images_src, os.path.join(args.output_dir, 'ref')).replace('.png', '.jpg'))).start()
|
|
|
+ Thread(target=writer, args=(ref, os.path.join(args.output_dir, 'ref', pathname + '.jpg'))).start()
|