From 1080b823a5fd1387b83f01b197cf36811b072291 Mon Sep 17 00:00:00 2001 From: HTensor <27774381+HTensor@users.noreply.github.com> Date: Wed, 26 Apr 2023 19:50:17 +0800 Subject: [PATCH] wtf? --- IGEV-Stereo/core/igev_stereo.py | 1 + IGEV-Stereo/evaluate_stereo.py | 18 +++++++++--------- IGEV-Stereo/train_stereo.py | 2 +- 3 files changed, 11 insertions(+), 10 deletions(-) diff --git a/IGEV-Stereo/core/igev_stereo.py b/IGEV-Stereo/core/igev_stereo.py index acb316e..18d31ef 100644 --- a/IGEV-Stereo/core/igev_stereo.py +++ b/IGEV-Stereo/core/igev_stereo.py @@ -182,6 +182,7 @@ class IGEVStereo(nn.Module): spx_pred = self.spx(xspx) spx_pred = F.softmax(spx_pred, 1) + # Content Network cnet_list = self.cnet(image1, num_layers=self.args.n_gru_layers) net_list = [torch.tanh(x[0]) for x in cnet_list] inp_list = [torch.relu(x[1]) for x in cnet_list] diff --git a/IGEV-Stereo/evaluate_stereo.py b/IGEV-Stereo/evaluate_stereo.py index 73d35dd..a13d026 100644 --- a/IGEV-Stereo/evaluate_stereo.py +++ b/IGEV-Stereo/evaluate_stereo.py @@ -20,7 +20,7 @@ def count_parameters(model): return sum(p.numel() for p in model.parameters() if p.requires_grad) @torch.no_grad() -def validate_eth3d(model, iters=32, mixed_prec=False): +def validate_eth3d(model, iters=32, mixed_prec=False, max_disp=192): """ Peform validation using the ETH3D (train) split """ model.eval() aug_params = {} @@ -67,7 +67,7 @@ def validate_eth3d(model, iters=32, mixed_prec=False): @torch.no_grad() -def validate_kitti(model, iters=32, mixed_prec=False): +def validate_kitti(model, iters=32, mixed_prec=False, max_disp=192): """ Peform validation using the KITTI-2015 (train) split """ model.eval() aug_params = {} @@ -96,7 +96,7 @@ def validate_kitti(model, iters=32, mixed_prec=False): epe = torch.sum((flow_pr - flow_gt)**2, dim=0).sqrt() epe_flattened = epe.flatten() - val = (valid_gt.flatten() >= 0.5) & (flow_gt.abs().flatten() < 192) + val = (valid_gt.flatten() >= 0.5) & (flow_gt.abs().flatten() < max_disp) # val = valid_gt.flatten() >= 0.5 out = (epe_flattened > 3.0) @@ -120,7 +120,7 @@ def validate_kitti(model, iters=32, mixed_prec=False): @torch.no_grad() -def validate_sceneflow(model, iters=32, mixed_prec=False): +def validate_sceneflow(model, iters=32, mixed_prec=False, max_disp=192): """ Peform validation using the Scene Flow (TEST) split """ model.eval() val_dataset = datasets.SceneFlowDatasets(dstype='frames_finalpass', things_test=True) @@ -144,7 +144,7 @@ def validate_sceneflow(model, iters=32, mixed_prec=False): epe = torch.abs(flow_pr - flow_gt) epe = epe.flatten() - val = (valid_gt.flatten() >= 0.5) & (flow_gt.abs().flatten() < 192) + val = (valid_gt.flatten() >= 0.5) & (flow_gt.abs().flatten() < max_disp) if(np.isnan(epe[val].mean().item())): continue @@ -169,7 +169,7 @@ def validate_sceneflow(model, iters=32, mixed_prec=False): @torch.no_grad() -def validate_middlebury(model, iters=32, split='F', mixed_prec=False): +def validate_middlebury(model, iters=32, split='F', mixed_prec=False, max_disp=192): """ Peform validation using the Middlebury-V3 dataset """ model.eval() aug_params = {} @@ -196,7 +196,7 @@ def validate_middlebury(model, iters=32, split='F', mixed_prec=False): occ_mask = Image.open(imageL_file.replace('im0.png', 'mask0nocc.png')).convert('L') occ_mask = np.ascontiguousarray(occ_mask, dtype=np.float32).flatten() - val = (valid_gt.reshape(-1) >= 0.5) & (flow_gt[0].reshape(-1) < 192) & (occ_mask==255) + val = (valid_gt.reshape(-1) >= 0.5) & (flow_gt[0].reshape(-1) < max_disp) & (occ_mask==255) out = (epe_flattened > 2.0) image_out = out[val].float().mean().item() image_epe = epe_flattened[val].mean().item() @@ -252,10 +252,10 @@ if __name__ == '__main__': use_mixed_precision = args.corr_implementation.endswith("_cuda") if args.dataset == 'eth3d': - validate_eth3d(model, iters=args.valid_iters, mixed_prec=use_mixed_precision) + validate_eth3d(model, iters=args.valid_iters, mixed_prec=use_mixed_precision, max_disp=args.max_disp) elif args.dataset == 'kitti': - validate_kitti(model, iters=args.valid_iters, mixed_prec=use_mixed_precision) + validate_kitti(model, iters=args.valid_iters, mixed_prec=use_mixed_precision, max_disp=args.max_disp) elif args.dataset in [f"middlebury_{s}" for s in 'FHQ']: validate_middlebury(model, iters=args.valid_iters, split=args.dataset[-1], mixed_prec=use_mixed_precision) diff --git a/IGEV-Stereo/train_stereo.py b/IGEV-Stereo/train_stereo.py index 19262c5..00e2deb 100644 --- a/IGEV-Stereo/train_stereo.py +++ b/IGEV-Stereo/train_stereo.py @@ -187,7 +187,7 @@ def train(args): save_path = Path(ckpt_path + '/%d_%s.pth' % (total_steps + 1, args.name)) logging.info(f"Saving file {save_path.absolute()}") torch.save(model.state_dict(), save_path) - results = validate_middlebury(model.module, iters=args.valid_iters) + results = validate_middlebury(model.module, iters=args.valid_iters, max_disp=args.max_disp) logger.write_dict(results) model.train() model.module.freeze_bn()