import torch.nn.functional as F


def model_loss(disp_ests, disp_gt, mask):
    """Return the weighted sum of smooth-L1 losses over several estimates.

    Each estimate in ``disp_ests`` is compared against ``disp_gt`` on the
    elements selected by ``mask``; the per-estimate mean smooth-L1 loss is
    scaled by a fixed weight and the scaled losses are summed.

    Args:
        disp_ests: Iterable of estimate tensors. Estimates are paired with
            the weights ``(0.5, 0.5, 0.7, 1.0)`` in order; because the
            pairing uses ``zip``, any estimates beyond the fourth are
            ignored, and fewer than four simply use the leading weights.
        disp_gt: Ground-truth tensor, indexed with ``mask``.
        mask: Index applied identically to every estimate and to
            ``disp_gt`` (presumably a boolean validity mask -- confirm
            against the caller).

    Returns:
        The weighted sum of the per-estimate losses. When ``disp_ests`` is
        empty, ``sum`` yields the plain integer ``0``.
    """
    weights = (0.5, 0.5, 0.7, 1.0)

    # The masked ground truth is the same for every estimate; index it once.
    masked_gt = disp_gt[mask]

    # ``reduction="mean"`` replaces the deprecated ``size_average=True``
    # argument; both average the loss over the selected elements.
    # NOTE(review): a mask that selects no elements makes this a mean over
    # an empty tensor -- confirm callers never pass an all-False mask.
    return sum(
        weight * F.smooth_l1_loss(disp_est[mask], masked_gt, reduction="mean")
        for disp_est, weight in zip(disp_ests, weights)
    )