diff --git a/datasets/kitti_dataset.py b/datasets/kitti_dataset.py index a174a98..3349c60 100644 --- a/datasets/kitti_dataset.py +++ b/datasets/kitti_dataset.py @@ -96,4 +96,6 @@ class KITTIDataset(Dataset): return {"left": left_img, "right": right_img, "top_pad": top_pad, - "right_pad": right_pad} + "right_pad": right_pad, + "left_filename": self.left_filenames[index], + "right_filename": self.right_filenames[index]} diff --git a/kitti12_save.sh b/kitti12_save.sh new file mode 100755 index 0000000..f122ffd --- /dev/null +++ b/kitti12_save.sh @@ -0,0 +1,4 @@ +#!/usr/bin/env bash +set -x +DATAPATH="/home/xyguo/data/kitti_2012/" +python save_disp.py --datapath $DATAPATH --testlist ./filenames/kitti12_test.txt --model gwcnet-gc --loadckpt ./checkpoints/kitti12/gwcnet-gc/best.ckpt diff --git a/kitti15_save.sh b/kitti15_save.sh new file mode 100755 index 0000000..1b78c0d --- /dev/null +++ b/kitti15_save.sh @@ -0,0 +1,4 @@ +#!/usr/bin/env bash +set -x +DATAPATH="/home/xyguo/data/kitti_2015/" +python save_disp.py --datapath $DATAPATH --testlist ./filenames/kitti15_test.txt --model gwcnet-g --loadckpt ./checkpoints/kitti15/gwcnet-g/best.ckpt diff --git a/save_disp.py b/save_disp.py new file mode 100644 index 0000000..7daf836 --- /dev/null +++ b/save_disp.py @@ -0,0 +1,82 @@ +from __future__ import print_function, division +import argparse +import os +import torch +import torch.nn as nn +import torch.nn.parallel +import torch.backends.cudnn as cudnn +import torch.optim as optim +import torch.utils.data +from torch.autograd import Variable +import torchvision.utils as vutils +import torch.nn.functional as F +import numpy as np +import time +from tensorboardX import SummaryWriter +from datasets import __datasets__ +from models import __models__ +from utils import * +from torch.utils.data import DataLoader +import gc +import skimage + +cudnn.benchmark = True + +parser = argparse.ArgumentParser(description='Group-wise Correlation Stereo Network (GwcNet)') +parser.add_argument('--model', default='gwcnet-g', help='select a model structure', choices=__models__.keys()) +parser.add_argument('--maxdisp', type=int, default=192, help='maximum disparity') + +parser.add_argument('--dataset', default='kitti', help='dataset name', choices=__datasets__.keys()) +parser.add_argument('--datapath', required=True, help='data path') +parser.add_argument('--testlist', required=True, help='testing list') +parser.add_argument('--loadckpt', required=True, help='load the weights from a specific checkpoint') + +# parse arguments +args = parser.parse_args() + +# dataset, dataloader +StereoDataset = __datasets__[args.dataset] +test_dataset = StereoDataset(args.datapath, args.testlist, False) +TestImgLoader = DataLoader(test_dataset, 1, shuffle=False, num_workers=4, drop_last=False) + +# model, optimizer +model = __models__[args.model](args.maxdisp) +model = nn.DataParallel(model) +model.cuda() + +# load parameters +print("loading model {}".format(args.loadckpt)) +state_dict = torch.load(args.loadckpt) +model.load_state_dict(state_dict['model']) + + +def test(): + os.makedirs('./predictions', exist_ok=True) + for batch_idx, sample in enumerate(TestImgLoader): + start_time = time.time() + disp_est_np = tensor2numpy(test_sample(sample)) + top_pad_np = tensor2numpy(sample["top_pad"]) + right_pad_np = tensor2numpy(sample["right_pad"]) + left_filenames = sample["left_filename"] + print('Iter {}/{}, time = {:3f}'.format(batch_idx, len(TestImgLoader), + time.time() - start_time)) + + for disp_est, top_pad, right_pad, fn in zip(disp_est_np, top_pad_np, right_pad_np, left_filenames): + assert len(disp_est.shape) == 2 + disp_est = np.array(disp_est[top_pad:, :-right_pad], dtype=np.float32) + fn = os.path.join("predictions", fn.split('/')[-1]) + print("saving to", fn, disp_est.shape) + disp_est_uint = np.round(disp_est * 256).astype(np.uint16) + skimage.io.imsave(fn, disp_est_uint) + + +# test one sample +@make_nograd_func +def test_sample(sample): + model.eval() + disp_ests = model(sample['left'].cuda(), sample['right'].cuda()) + return disp_ests[-1] + + +if __name__ == '__main__': + test()