From 75735c8faed79fdeb862fe7475ee8f5898ba1bf5 Mon Sep 17 00:00:00 2001 From: HTensor Date: Thu, 27 Apr 2023 19:16:45 +0800 Subject: [PATCH] compatibility changes --- IGEV-Stereo/evaluate_stereo.py | 11 +++++++++-- 1 file changed, 9 insertions(+), 2 deletions(-) diff --git a/IGEV-Stereo/evaluate_stereo.py b/IGEV-Stereo/evaluate_stereo.py index ae61eea..ba436d4 100644 --- a/IGEV-Stereo/evaluate_stereo.py +++ b/IGEV-Stereo/evaluate_stereo.py @@ -231,6 +231,7 @@ if __name__ == '__main__': parser.add_argument('--slow_fast_gru', action='store_true', help="iterate the low-res GRUs more frequently") parser.add_argument('--n_gru_layers', type=int, default=3, help="number of hidden GRU levels") parser.add_argument('--max_disp', type=int, default=192, help="max disp of geometry encoding volume") + parser.add_argument('--freeze_backbone_params', action="store_true", help="freeze backbone parameters") args = parser.parse_args() model = torch.nn.DataParallel(IGEVStereo(args), device_ids=[0]) @@ -242,6 +243,12 @@ if __name__ == '__main__': assert args.restore_ckpt.endswith(".pth") logging.info("Loading checkpoint...") checkpoint = torch.load(args.restore_ckpt) + + unwanted_prefix = '_orig_mod.' + for k, v in list(checkpoint.items()): + if k.startswith(unwanted_prefix): + checkpoint[k[len(unwanted_prefix):]] = checkpoint.pop(k) + model.load_state_dict(checkpoint, strict=True) logging.info(f"Done loading checkpoint") @@ -258,7 +265,7 @@ if __name__ == '__main__': validate_kitti(model, iters=args.valid_iters, mixed_prec=use_mixed_precision, max_disp=args.max_disp) elif args.dataset in [f"middlebury_{s}" for s in 'FHQ']: - validate_middlebury(model, iters=args.valid_iters, split=args.dataset[-1], mixed_prec=use_mixed_precision) + validate_middlebury(model, iters=args.valid_iters, split=args.dataset[-1], mixed_prec=use_mixed_precision, max_disp=args.max_disp) elif args.dataset == 'sceneflow': - validate_sceneflow(model, iters=args.valid_iters, mixed_prec=use_mixed_precision) + validate_sceneflow(model, iters=args.valid_iters, mixed_prec=use_mixed_precision, max_disp=args.max_disp)