compatibility changes
This commit is contained in:
parent
21e3f92461
commit
75735c8fae
@ -231,6 +231,7 @@ if __name__ == '__main__':
|
|||||||
parser.add_argument('--slow_fast_gru', action='store_true', help="iterate the low-res GRUs more frequently")
|
parser.add_argument('--slow_fast_gru', action='store_true', help="iterate the low-res GRUs more frequently")
|
||||||
parser.add_argument('--n_gru_layers', type=int, default=3, help="number of hidden GRU levels")
|
parser.add_argument('--n_gru_layers', type=int, default=3, help="number of hidden GRU levels")
|
||||||
parser.add_argument('--max_disp', type=int, default=192, help="max disp of geometry encoding volume")
|
parser.add_argument('--max_disp', type=int, default=192, help="max disp of geometry encoding volume")
|
||||||
|
parser.add_argument('--freeze_backbone_params', action="store_true", help="freeze backbone parameters")
|
||||||
args = parser.parse_args()
|
args = parser.parse_args()
|
||||||
|
|
||||||
model = torch.nn.DataParallel(IGEVStereo(args), device_ids=[0])
|
model = torch.nn.DataParallel(IGEVStereo(args), device_ids=[0])
|
||||||
@ -242,6 +243,12 @@ if __name__ == '__main__':
|
|||||||
assert args.restore_ckpt.endswith(".pth")
|
assert args.restore_ckpt.endswith(".pth")
|
||||||
logging.info("Loading checkpoint...")
|
logging.info("Loading checkpoint...")
|
||||||
checkpoint = torch.load(args.restore_ckpt)
|
checkpoint = torch.load(args.restore_ckpt)
|
||||||
|
|
||||||
|
unwanted_prefix = '_orig_mod.'
|
||||||
|
for k, v in list(checkpoint.items()):
|
||||||
|
if k.startswith(unwanted_prefix):
|
||||||
|
checkpoint[k[len(unwanted_prefix):]] = checkpoint.pop(k)
|
||||||
|
|
||||||
model.load_state_dict(checkpoint, strict=True)
|
model.load_state_dict(checkpoint, strict=True)
|
||||||
logging.info(f"Done loading checkpoint")
|
logging.info(f"Done loading checkpoint")
|
||||||
|
|
||||||
@ -258,7 +265,7 @@ if __name__ == '__main__':
|
|||||||
validate_kitti(model, iters=args.valid_iters, mixed_prec=use_mixed_precision, max_disp=args.max_disp)
|
validate_kitti(model, iters=args.valid_iters, mixed_prec=use_mixed_precision, max_disp=args.max_disp)
|
||||||
|
|
||||||
elif args.dataset in [f"middlebury_{s}" for s in 'FHQ']:
|
elif args.dataset in [f"middlebury_{s}" for s in 'FHQ']:
|
||||||
validate_middlebury(model, iters=args.valid_iters, split=args.dataset[-1], mixed_prec=use_mixed_precision)
|
validate_middlebury(model, iters=args.valid_iters, split=args.dataset[-1], mixed_prec=use_mixed_precision, max_disp=args.max_disp)
|
||||||
|
|
||||||
elif args.dataset == 'sceneflow':
|
elif args.dataset == 'sceneflow':
|
||||||
validate_sceneflow(model, iters=args.valid_iters, mixed_prec=use_mixed_precision)
|
validate_sceneflow(model, iters=args.valid_iters, mixed_prec=use_mixed_precision, max_disp=args.max_disp)
|
||||||
|
Loading…
Reference in New Issue
Block a user