from __future__ import print_function, division import argparse import os import torch import torch.nn as nn import torch.nn.parallel import torch.backends.cudnn as cudnn import torch.optim as optim import torch.utils.data from torch.autograd import Variable import torchvision.utils as vutils import torch.nn.functional as F import numpy as np import time from tensorboardX import SummaryWriter from datasets import __datasets__ from models import __models__ from utils import * from torch.utils.data import DataLoader import gc import skimage cudnn.benchmark = True parser = argparse.ArgumentParser(description='Group-wise Correlation Stereo Network (GwcNet)') parser.add_argument('--model', default='gwcnet-g', help='select a model structure', choices=__models__.keys()) parser.add_argument('--maxdisp', type=int, default=192, help='maximum disparity') parser.add_argument('--dataset', default='kitti', help='dataset name', choices=__datasets__.keys()) parser.add_argument('--datapath', required=True, help='data path') parser.add_argument('--testlist', required=True, help='testing list') parser.add_argument('--loadckpt', required=True, help='load the weights from a specific checkpoint') # parse arguments args = parser.parse_args() # dataset, dataloader StereoDataset = __datasets__[args.dataset] test_dataset = StereoDataset(args.datapath, args.testlist, False) TestImgLoader = DataLoader(test_dataset, 1, shuffle=False, num_workers=4, drop_last=False) # model, optimizer model = __models__[args.model](args.maxdisp) model = nn.DataParallel(model) model.cuda() # load parameters print("loading model {}".format(args.loadckpt)) state_dict = torch.load(args.loadckpt) model.load_state_dict(state_dict['model']) def test(): os.makedirs('./predictions', exist_ok=True) for batch_idx, sample in enumerate(TestImgLoader): start_time = time.time() disp_est_np = tensor2numpy(test_sample(sample)) top_pad_np = tensor2numpy(sample["top_pad"]) right_pad_np = tensor2numpy(sample["right_pad"]) left_filenames = sample["left_filename"] print('Iter {}/{}, time = {:3f}'.format(batch_idx, len(TestImgLoader), time.time() - start_time)) for disp_est, top_pad, right_pad, fn in zip(disp_est_np, top_pad_np, right_pad_np, left_filenames): assert len(disp_est.shape) == 2 disp_est = np.array(disp_est[top_pad:, :-right_pad], dtype=np.float32) fn = os.path.join("predictions", fn.split('/')[-1]) print("saving to", fn, disp_est.shape) disp_est_uint = np.round(disp_est * 256).astype(np.uint16) skimage.io.imsave(fn, disp_est_uint) # test one sample @make_nograd_func def test_sample(sample): model.eval() disp_ests = model(sample['left'].cuda(), sample['right'].cuda()) return disp_ests[-1] if __name__ == '__main__': test()