benchmark.m 2.2 KB

1234567891011121314151617181920212223242526272829303132333435363738394041424344454647484950515253545556575859606162
  1. #!/usr/bin/octave
  % Usage: benchmark.m <bench_path> <result_path>
  %   bench_path  - ground-truth root; pha/, fgr/ and trimap/ subfolders are read
  %   result_path - prediction root; pha/ and fgr/ subfolders are read
  arg_list = argv ();
  if numel (arg_list) < 2
    error ('usage: benchmark.m <bench_path> <result_path>');
  end
  bench_path = arg_list{1};
  result_path = arg_list{2};
  % Every ground-truth alpha PNG defines one evaluated image.
  gt_files = dir (fullfile (bench_path, 'pha', '*.png'));
  % Without this guard the averages below become 0/0 = NaN and are printed
  % as if they were real results.
  if isempty (gt_files)
    error ('benchmark: no ground-truth *.png files found in %s', ...
           fullfile (bench_path, 'pha'));
  end
  % Running sums of the per-image metrics; divided by the image count at the end.
  total_loss_mse = 0;
  total_loss_sad = 0;
  total_loss_gradient = 0;
  total_loss_connectivity = 0;
  total_fg_mse = 0;
  % Evaluate every ground-truth image against the prediction with the same file name.
  for i = 1:numel (gt_files)
    filename = gt_files(i).name;

    gt_alpha = imread (fullfile (bench_path, 'pha', filename));
    trimap = imread (fullfile (bench_path, 'trimap', filename));

    % Crop height and width down to multiples of 4. size() returns doubles and
    % Octave's idivide rejects calls where neither argument is an integer type,
    % so plain floor division is used (identical result for positive sizes).
    crop_edge = floor (size (gt_alpha) / 4) * 4;
    row_idx = 1:crop_edge(1);
    col_idx = 1:crop_edge(2);
    gt_alpha = gt_alpha(row_idx, col_idx);
    trimap = trimap(row_idx, col_idx);

    % Predictions are looked up under the same name in result_path/pha and
    % result_path/fgr (use strrep (filename, '.png', '.jpg') if they are JPEGs).
    hat_alpha = imread (fullfile (result_path, 'pha', filename))(row_idx, col_idx);
    hat_fgr = imread (fullfile (result_path, 'fgr', filename))(row_idx, col_idx, :);
    gt_fgr = imread (fullfile (bench_path, 'fgr', filename))(row_idx, col_idx, :);

    % Foreground error only looks at pixels where the ground-truth alpha is
    % non-zero; the 2-D mask is broadcast over any trailing (colour) dimension.
    nonzero_alpha = gt_alpha > 0;
    fg_mse = mean (compute_mse_loss (hat_fgr .* nonzero_alpha, gt_fgr .* nonzero_alpha, trimap));
    mse = compute_mse_loss (hat_alpha, gt_alpha, trimap);
    sad = compute_sad_loss (hat_alpha, gt_alpha, trimap);
    grad = compute_gradient_loss (hat_alpha, gt_alpha, trimap);
    conn = compute_connectivity_error (hat_alpha, gt_alpha, trimap, 0.1);

    % Per-image CSV row on stderr (fid 2). The file name is a %s argument rather
    % than part of the format template, so '%' or '\' in a name cannot corrupt it.
    fprintf (2, '%s,%.6f,%.3f,%.0f,%.0f,%.6f\n', filename, mse, sad, grad, conn, fg_mse);
    fflush (stderr);

    total_loss_mse += mse;
    total_loss_sad += sad;
    total_loss_gradient += grad;
    total_loss_connectivity += conn;
    total_fg_mse += fg_mse;
  end
  % Average each accumulated metric over the evaluated images and print a
  % single summary line to stdout.
  num_images = numel (gt_files);
  avg_loss_mse          = total_loss_mse / num_images;
  avg_loss_sad          = total_loss_sad / num_images;
  avg_loss_gradient     = total_loss_gradient / num_images;
  avg_loss_connectivity = total_loss_connectivity / num_images;
  avg_loss_fg_mse       = total_fg_mse / num_images;
  fprintf ('mse:%.6f,sad:%.3f,grad:%.0f,conn:%.0f,fg_mse:%.6f\n', ...
           avg_loss_mse, avg_loss_sad, avg_loss_gradient, ...
           avg_loss_connectivity, avg_loss_fg_mse);