10 lines
314 B
Python
10 lines
314 B
Python
|
import torch.nn.functional as F
|
||
|
|
||
|
|
||
|
def model_loss(disp_ests, disp_gt, mask, weights=(0.5, 0.5, 0.7, 1.0)):
    """Return the weighted sum of masked smooth-L1 losses over disparity estimates.

    Each estimate in ``disp_ests`` is compared with ``disp_gt`` only at the
    positions selected by ``mask``. Each comparison uses the mean smooth-L1
    loss. That loss is scaled by the weight paired with the estimate, and the
    scaled losses are summed.

    Args:
        disp_ests: Sequence of predicted disparity tensors. Each is indexed
            with ``mask``, so presumably each has the same shape as
            ``disp_gt`` (TODO: confirm with callers).
        disp_gt: Ground-truth disparity tensor.
        mask: Index (presumably a boolean tensor) selecting the positions
            that contribute to the loss.
        weights: Per-estimate loss weights, paired positionally with
            ``disp_ests``. The default reproduces the original hard-coded
            values ``(0.5, 0.5, 0.7, 1.0)``.

    Returns:
        The weighted sum of the per-estimate losses. If ``disp_ests`` is
        empty, this is the integer ``0``.

    Notes:
        Estimates and weights are paired with ``zip``, so any extra items in
        the longer of the two are ignored. For example, a single estimate is
        weighted by ``weights[0]`` only.

        If ``mask`` selects no elements, the mean over an empty tensor is
        NaN in PyTorch. Callers may want to skip such batches.
    """
    # The masked target is the same for every estimate, so select it once.
    target = disp_gt[mask]
    # reduction="mean" replaces the deprecated size_average=True and gives
    # identical results without the deprecation warning.
    return sum(
        weight * F.smooth_l1_loss(disp_est[mask], target, reduction="mean")
        for disp_est, weight in zip(disp_ests, weights)
    )
|