wtf?
This commit is contained in:
parent
0a3613711b
commit
1080b823a5
@ -182,6 +182,7 @@ class IGEVStereo(nn.Module):
|
||||
spx_pred = self.spx(xspx)
|
||||
spx_pred = F.softmax(spx_pred, 1)
|
||||
|
||||
# Content Network
|
||||
cnet_list = self.cnet(image1, num_layers=self.args.n_gru_layers)
|
||||
net_list = [torch.tanh(x[0]) for x in cnet_list]
|
||||
inp_list = [torch.relu(x[1]) for x in cnet_list]
|
||||
|
@ -20,7 +20,7 @@ def count_parameters(model):
|
||||
return sum(p.numel() for p in model.parameters() if p.requires_grad)
|
||||
|
||||
@torch.no_grad()
|
||||
def validate_eth3d(model, iters=32, mixed_prec=False):
|
||||
def validate_eth3d(model, iters=32, mixed_prec=False, max_disp=192):
|
||||
""" Peform validation using the ETH3D (train) split """
|
||||
model.eval()
|
||||
aug_params = {}
|
||||
@ -67,7 +67,7 @@ def validate_eth3d(model, iters=32, mixed_prec=False):
|
||||
|
||||
|
||||
@torch.no_grad()
|
||||
def validate_kitti(model, iters=32, mixed_prec=False):
|
||||
def validate_kitti(model, iters=32, mixed_prec=False, max_disp=192):
|
||||
""" Peform validation using the KITTI-2015 (train) split """
|
||||
model.eval()
|
||||
aug_params = {}
|
||||
@ -96,7 +96,7 @@ def validate_kitti(model, iters=32, mixed_prec=False):
|
||||
epe = torch.sum((flow_pr - flow_gt)**2, dim=0).sqrt()
|
||||
|
||||
epe_flattened = epe.flatten()
|
||||
val = (valid_gt.flatten() >= 0.5) & (flow_gt.abs().flatten() < 192)
|
||||
val = (valid_gt.flatten() >= 0.5) & (flow_gt.abs().flatten() < max_disp)
|
||||
# val = valid_gt.flatten() >= 0.5
|
||||
|
||||
out = (epe_flattened > 3.0)
|
||||
@ -120,7 +120,7 @@ def validate_kitti(model, iters=32, mixed_prec=False):
|
||||
|
||||
|
||||
@torch.no_grad()
|
||||
def validate_sceneflow(model, iters=32, mixed_prec=False):
|
||||
def validate_sceneflow(model, iters=32, mixed_prec=False, max_disp=192):
|
||||
""" Peform validation using the Scene Flow (TEST) split """
|
||||
model.eval()
|
||||
val_dataset = datasets.SceneFlowDatasets(dstype='frames_finalpass', things_test=True)
|
||||
@ -144,7 +144,7 @@ def validate_sceneflow(model, iters=32, mixed_prec=False):
|
||||
epe = torch.abs(flow_pr - flow_gt)
|
||||
|
||||
epe = epe.flatten()
|
||||
val = (valid_gt.flatten() >= 0.5) & (flow_gt.abs().flatten() < 192)
|
||||
val = (valid_gt.flatten() >= 0.5) & (flow_gt.abs().flatten() < max_disp)
|
||||
|
||||
if(np.isnan(epe[val].mean().item())):
|
||||
continue
|
||||
@ -169,7 +169,7 @@ def validate_sceneflow(model, iters=32, mixed_prec=False):
|
||||
|
||||
|
||||
@torch.no_grad()
|
||||
def validate_middlebury(model, iters=32, split='F', mixed_prec=False):
|
||||
def validate_middlebury(model, iters=32, split='F', mixed_prec=False, max_disp=192):
|
||||
""" Peform validation using the Middlebury-V3 dataset """
|
||||
model.eval()
|
||||
aug_params = {}
|
||||
@ -196,7 +196,7 @@ def validate_middlebury(model, iters=32, split='F', mixed_prec=False):
|
||||
occ_mask = Image.open(imageL_file.replace('im0.png', 'mask0nocc.png')).convert('L')
|
||||
occ_mask = np.ascontiguousarray(occ_mask, dtype=np.float32).flatten()
|
||||
|
||||
val = (valid_gt.reshape(-1) >= 0.5) & (flow_gt[0].reshape(-1) < 192) & (occ_mask==255)
|
||||
val = (valid_gt.reshape(-1) >= 0.5) & (flow_gt[0].reshape(-1) < max_disp) & (occ_mask==255)
|
||||
out = (epe_flattened > 2.0)
|
||||
image_out = out[val].float().mean().item()
|
||||
image_epe = epe_flattened[val].mean().item()
|
||||
@ -252,10 +252,10 @@ if __name__ == '__main__':
|
||||
use_mixed_precision = args.corr_implementation.endswith("_cuda")
|
||||
|
||||
if args.dataset == 'eth3d':
|
||||
validate_eth3d(model, iters=args.valid_iters, mixed_prec=use_mixed_precision)
|
||||
validate_eth3d(model, iters=args.valid_iters, mixed_prec=use_mixed_precision, max_disp=args.max_disp)
|
||||
|
||||
elif args.dataset == 'kitti':
|
||||
validate_kitti(model, iters=args.valid_iters, mixed_prec=use_mixed_precision)
|
||||
validate_kitti(model, iters=args.valid_iters, mixed_prec=use_mixed_precision, max_disp=args.max_disp)
|
||||
|
||||
elif args.dataset in [f"middlebury_{s}" for s in 'FHQ']:
|
||||
validate_middlebury(model, iters=args.valid_iters, split=args.dataset[-1], mixed_prec=use_mixed_precision)
|
||||
|
@ -187,7 +187,7 @@ def train(args):
|
||||
save_path = Path(ckpt_path + '/%d_%s.pth' % (total_steps + 1, args.name))
|
||||
logging.info(f"Saving file {save_path.absolute()}")
|
||||
torch.save(model.state_dict(), save_path)
|
||||
results = validate_middlebury(model.module, iters=args.valid_iters)
|
||||
results = validate_middlebury(model.module, iters=args.valid_iters, max_disp=args.max_disp)
|
||||
logger.write_dict(results)
|
||||
model.train()
|
||||
model.module.freeze_bn()
|
||||
|
Loading…
Reference in New Issue
Block a user