# GwcNet/utils/metrics.py
# 2019-04-14 17:34:58 +08:00
#
# 66 lines
# 2.4 KiB
# Python

import torch
import torch.nn.functional as F
from utils.experiment import make_nograd_func
from torch.autograd import Variable
from torch import Tensor
# D1 metric: updated from ">3px" to ">3px AND >5% of |D_gt|" (tau = [3, 0.05]),
# following this MATLAB reference code:
# E = abs(D_gt - D_est);
# n_err = length(find(D_gt > 0 & E > tau(1) & E ./ abs(D_gt) > tau(2)));
# n_total = length(find(D_gt > 0));
# d_err = n_err / n_total;
def check_shape_for_metric_computation(*tensors):
    """Check that every argument is 3-D and has the same size as the first one.

    Calling with no arguments is a no-op.

    Raises:
        AssertionError: if an argument is not 3-D, or its size differs from
            the size of the first argument.
    """
    # ``*tensors`` is always a tuple, so no isinstance check is needed.
    for i, tensor in enumerate(tensors):
        assert len(tensor.size()) == 3, "argument {} must be 3-D, got size {}".format(
            i, tuple(tensor.size())
        )
        assert tensor.size() == tensors[0].size(), "argument {} has size {}, expected {}".format(
            i, tuple(tensor.size()), tuple(tensors[0].size())
        )
# a wrapper to compute metrics for each image individually
def compute_metric_for_each_image(metric_func):
    """Turn a per-image metric into a batch metric that averages over images.

    The returned ``wrapper(D_ests, D_gts, masks, *nargs)`` does the following:

    * It checks that ``D_ests``, ``D_gts`` and ``masks`` are 3-D and share one
      size.
    * It calls ``metric_func(D_est, D_gt, mask, *cur_nargs)`` once per index
      along dim 0. Extra arguments that are tensors are indexed by the same
      image index. Any other value (e.g. a threshold) is passed unchanged.
    * It skips an image when its mask covers less than 10% of the pixels
      that have a positive ground truth. It also skips an image that has no
      positive ground truth at all, where the ratio is undefined.
    * It returns the mean of the per-image results. If every image was
      skipped, it returns a float32 zero on ``D_gts.device``.
    """
    from functools import wraps

    @wraps(metric_func)
    def wrapper(D_ests, D_gts, masks, *nargs):
        check_shape_for_metric_computation(D_ests, D_gts, masks)
        bn = D_gts.shape[0]  # batch size
        results = []  # per-image metric values
        for idx in range(bn):
            # if tensor, then pick idx, else pass the same value
            cur_nargs = [x[idx] if isinstance(x, (Tensor, Variable)) else x for x in nargs]
            ratio = masks[idx].float().mean() / (D_gts[idx] > 0).float().mean()
            # With no positive ground truth (and an empty mask) the ratio is 0/0 = NaN.
            # A plain ``ratio < 0.1`` is False for NaN, which used to send an empty
            # selection to metric_func. Its NaN result then poisoned the batch mean.
            if torch.isnan(ratio) or ratio < 0.1:
                print("masks[idx].float().mean() too small, skip")
            else:
                ret = metric_func(D_ests[idx], D_gts[idx], masks[idx], *cur_nargs)
                results.append(ret)
        if len(results) == 0:
            print("masks[idx].float().mean() too small for all images in this batch, return 0")
            return torch.tensor(0, dtype=torch.float32, device=D_gts.device)
        # assumes metric_func returns same-shaped (e.g. 0-dim) tensors so they can be stacked
        return torch.stack(results).mean()

    return wrapper
@make_nograd_func
@compute_metric_for_each_image
def D1_metric(D_est, D_gt, mask):
    """Return the fraction of masked pixels that are D1 outliers (one image).

    A pixel counts as an outlier when its absolute error |D_gt - D_est| is
    both greater than 3 and greater than 5% of |D_gt|. The
    ``compute_metric_for_each_image`` decorator applies this per image to
    a batch. ``make_nograd_func`` (from utils.experiment) presumably
    disables autograd; verify there.
    """
    est_vals = D_est[mask]
    gt_vals = D_gt[mask]
    abs_err = torch.abs(gt_vals - est_vals)
    exceeds_px = abs_err > 3
    exceeds_rel = abs_err / gt_vals.abs() > 0.05
    outliers = exceeds_px & exceeds_rel
    return outliers.float().mean()
@make_nograd_func
@compute_metric_for_each_image
def Thres_metric(D_est, D_gt, mask, thres):
    """Return the fraction of masked pixels whose absolute error exceeds ``thres``.

    ``thres`` must be a plain Python ``int`` or ``float``. The
    ``compute_metric_for_each_image`` decorator applies this per image to
    a batch.
    """
    assert isinstance(thres, (int, float))
    abs_err = (D_gt[mask] - D_est[mask]).abs()
    over_thres = abs_err > thres
    return over_thres.float().mean()
# NOTE: please do not use this to build up training loss
@make_nograd_func
@compute_metric_for_each_image
def EPE_metric(D_est, D_gt, mask):
    """Return the mean absolute error between estimate and ground truth over masked pixels.

    This is computed for one image. The ``compute_metric_for_each_image``
    decorator averages it over a batch.
    """
    D_est, D_gt = D_est[mask], D_gt[mask]
    # reduction="mean" is the supported equivalent of the deprecated
    # size_average=True, which emits a deprecation warning on every call.
    return F.l1_loss(D_est, D_gt, reduction="mean")