compatibility changes
This commit is contained in:
parent
21e3f92461
commit
75735c8fae
@ -231,6 +231,7 @@ if __name__ == '__main__':
|
||||
parser.add_argument('--slow_fast_gru', action='store_true', help="iterate the low-res GRUs more frequently")
|
||||
parser.add_argument('--n_gru_layers', type=int, default=3, help="number of hidden GRU levels")
|
||||
parser.add_argument('--max_disp', type=int, default=192, help="max disp of geometry encoding volume")
|
||||
parser.add_argument('--freeze_backbone_params', action="store_true", help="freeze backbone parameters")
|
||||
args = parser.parse_args()
|
||||
|
||||
model = torch.nn.DataParallel(IGEVStereo(args), device_ids=[0])
|
||||
@ -242,6 +243,12 @@ if __name__ == '__main__':
|
||||
assert args.restore_ckpt.endswith(".pth")
|
||||
logging.info("Loading checkpoint...")
|
||||
checkpoint = torch.load(args.restore_ckpt)
|
||||
|
||||
unwanted_prefix = '_orig_mod.'
|
||||
for k, v in list(checkpoint.items()):
|
||||
if k.startswith(unwanted_prefix):
|
||||
checkpoint[k[len(unwanted_prefix):]] = checkpoint.pop(k)
|
||||
|
||||
model.load_state_dict(checkpoint, strict=True)
|
||||
logging.info(f"Done loading checkpoint")
|
||||
|
||||
@ -258,7 +265,7 @@ if __name__ == '__main__':
|
||||
validate_kitti(model, iters=args.valid_iters, mixed_prec=use_mixed_precision, max_disp=args.max_disp)
|
||||
|
||||
elif args.dataset in [f"middlebury_{s}" for s in 'FHQ']:
|
||||
validate_middlebury(model, iters=args.valid_iters, split=args.dataset[-1], mixed_prec=use_mixed_precision)
|
||||
validate_middlebury(model, iters=args.valid_iters, split=args.dataset[-1], mixed_prec=use_mixed_precision, max_disp=args.max_disp)
|
||||
|
||||
elif args.dataset == 'sceneflow':
|
||||
validate_sceneflow(model, iters=args.valid_iters, mixed_prec=use_mixed_precision)
|
||||
validate_sceneflow(model, iters=args.valid_iters, mixed_prec=use_mixed_precision, max_disp=args.max_disp)
|
||||
|
Loading…
Reference in New Issue
Block a user