Peter Lin 4 жил өмнө
parent
commit
eb14b80e6a
1 өөрчлөгдсөн 9 нэмэгдсэн , 8 устгасан
  1. 9 8
      inference_images.py

+ 9 - 8
inference_images.py

@@ -121,7 +121,6 @@ def writer(img, path):
 # Conversion loop
 with torch.no_grad():
     for i, (src, bgr) in enumerate(tqdm(dataloader)):
-        filename = dataset.datasets[0].filenames[i]
         src = src.to(device, non_blocking=True)
         bgr = bgr.to(device, non_blocking=True)
         
@@ -129,19 +128,21 @@ with torch.no_grad():
             pha, fgr, err, _ = model(src, bgr)
         elif args.model_type == 'mattingrefine':
             pha, fgr, _, _, err, ref = model(src, bgr)
-        elif args.model_type == 'mattingbm':
-            pha, fgr = model(src, bgr)
+
+        pathname = dataset.datasets[0].filenames[i]
+        pathname = os.path.relpath(pathname, args.images_src)
+        pathname = os.path.splitext(pathname)[0]
             
         if 'com' in args.output_types:
             com = torch.cat([fgr * pha.ne(0), pha], dim=1)
-            Thread(target=writer, args=(com, filename.replace(args.images_src, os.path.join(args.output_dir, 'com')).replace('.jpg', '.png'))).start()
+            Thread(target=writer, args=(com, os.path.join(args.output_dir, 'com', pathname + '.png'))).start()
         if 'pha' in args.output_types:
-            Thread(target=writer, args=(pha, filename.replace(args.images_src, os.path.join(args.output_dir, 'pha')).replace('.png', '.jpg'))).start()
+            Thread(target=writer, args=(pha, os.path.join(args.output_dir, 'pha', pathname + '.jpg'))).start()
         if 'fgr' in args.output_types:
-            Thread(target=writer, args=(fgr, filename.replace(args.images_src, os.path.join(args.output_dir, 'fgr')).replace('.png', '.jpg'))).start()
+            Thread(target=writer, args=(fgr, os.path.join(args.output_dir, 'fgr', pathname + '.jpg'))).start()
         if 'err' in args.output_types:
             err = F.interpolate(err, src.shape[2:], mode='bilinear', align_corners=False)
-            Thread(target=writer, args=(err, filename.replace(args.images_src, os.path.join(args.output_dir, 'err')).replace('.png', '.jpg'))).start()
+            Thread(target=writer, args=(err, os.path.join(args.output_dir, 'err', pathname + '.jpg'))).start()
         if 'ref' in args.output_types:
             ref = F.interpolate(ref, src.shape[2:], mode='nearest')
-            Thread(target=writer, args=(ref, filename.replace(args.images_src, os.path.join(args.output_dir, 'ref')).replace('.png', '.jpg'))).start()
+            Thread(target=writer, args=(ref, os.path.join(args.output_dir, 'ref', pathname + '.jpg'))).start()