compatibility changes

This commit is contained in:
HTensor 2023-04-27 19:16:45 +08:00
parent 21e3f92461
commit 75735c8fae

View File

@ -231,6 +231,7 @@ if __name__ == '__main__':
parser.add_argument('--slow_fast_gru', action='store_true', help="iterate the low-res GRUs more frequently") parser.add_argument('--slow_fast_gru', action='store_true', help="iterate the low-res GRUs more frequently")
parser.add_argument('--n_gru_layers', type=int, default=3, help="number of hidden GRU levels") parser.add_argument('--n_gru_layers', type=int, default=3, help="number of hidden GRU levels")
parser.add_argument('--max_disp', type=int, default=192, help="max disp of geometry encoding volume") parser.add_argument('--max_disp', type=int, default=192, help="max disp of geometry encoding volume")
parser.add_argument('--freeze_backbone_params', action="store_true", help="freeze backbone parameters")
args = parser.parse_args() args = parser.parse_args()
model = torch.nn.DataParallel(IGEVStereo(args), device_ids=[0]) model = torch.nn.DataParallel(IGEVStereo(args), device_ids=[0])
@ -242,6 +243,12 @@ if __name__ == '__main__':
assert args.restore_ckpt.endswith(".pth") assert args.restore_ckpt.endswith(".pth")
logging.info("Loading checkpoint...") logging.info("Loading checkpoint...")
checkpoint = torch.load(args.restore_ckpt) checkpoint = torch.load(args.restore_ckpt)
unwanted_prefix = '_orig_mod.'
for k, v in list(checkpoint.items()):
if k.startswith(unwanted_prefix):
checkpoint[k[len(unwanted_prefix):]] = checkpoint.pop(k)
model.load_state_dict(checkpoint, strict=True) model.load_state_dict(checkpoint, strict=True)
logging.info(f"Done loading checkpoint") logging.info(f"Done loading checkpoint")
@ -258,7 +265,7 @@ if __name__ == '__main__':
validate_kitti(model, iters=args.valid_iters, mixed_prec=use_mixed_precision, max_disp=args.max_disp) validate_kitti(model, iters=args.valid_iters, mixed_prec=use_mixed_precision, max_disp=args.max_disp)
elif args.dataset in [f"middlebury_{s}" for s in 'FHQ']: elif args.dataset in [f"middlebury_{s}" for s in 'FHQ']:
validate_middlebury(model, iters=args.valid_iters, split=args.dataset[-1], mixed_prec=use_mixed_precision) validate_middlebury(model, iters=args.valid_iters, split=args.dataset[-1], mixed_prec=use_mixed_precision, max_disp=args.max_disp)
elif args.dataset == 'sceneflow': elif args.dataset == 'sceneflow':
validate_sceneflow(model, iters=args.valid_iters, mixed_prec=use_mixed_precision) validate_sceneflow(model, iters=args.valid_iters, mixed_prec=use_mixed_precision, max_disp=args.max_disp)