GwcNet/models/loss.py

10 lines
314 B
Python
Raw Normal View History

2019-04-14 17:34:58 +08:00
import torch.nn.functional as F
def model_loss(disp_ests, disp_gt, mask, weights=None):
    """Return the weighted sum of masked Smooth-L1 losses over disparity estimates.

    Each estimate in ``disp_ests`` is compared with ``disp_gt`` only at the
    positions selected by ``mask``. The mean Smooth-L1 loss of each estimate is
    scaled by the matching entry of ``weights``, and the terms are summed.

    Args:
        disp_ests: Sequence of disparity prediction tensors. Each one is indexed
            with ``mask``, so presumably each has the same shape as ``disp_gt``
            (TODO confirm against the model outputs).
        disp_gt: Ground-truth disparity tensor.
        mask: Boolean tensor selecting the elements that contribute to the loss.
        weights: Optional per-estimate loss weights. Defaults to
            ``(0.5, 0.5, 0.7, 1.0)``, the original hard-coded values.

    Returns:
        The weighted sum of the per-estimate losses (a scalar tensor), or the
        integer ``0`` when ``disp_ests`` is empty.

    Note:
        Estimates and weights are paired with ``zip``, so the shorter sequence
        decides how many terms are summed. Estimates beyond the number of
        weights are ignored, and fewer estimates use only the leading weights.
        For example, a single estimate is weighted by 0.5.

        NOTE(review): if ``mask`` selects no elements, the "mean" reduction
        averages over zero elements, which PyTorch presumably returns as NaN.
        Confirm, and guard at the call site if that can happen.
    """
    if weights is None:
        weights = (0.5, 0.5, 0.7, 1.0)
    # The masked ground truth is the same for every estimate, so select it once.
    gt_valid = disp_gt[mask]
    # reduction="mean" is the non-deprecated equivalent of size_average=True.
    return sum(
        weight * F.smooth_l1_loss(disp_est[mask], gt_valid, reduction="mean")
        for disp_est, weight in zip(disp_ests, weights)
    )