add codes for saving submission images
This commit is contained in:
parent
71575036e5
commit
ec23805fa7
@ -96,4 +96,6 @@ class KITTIDataset(Dataset):
|
|||||||
return {"left": left_img,
|
return {"left": left_img,
|
||||||
"right": right_img,
|
"right": right_img,
|
||||||
"top_pad": top_pad,
|
"top_pad": top_pad,
|
||||||
"right_pad": right_pad}
|
"right_pad": right_pad,
|
||||||
|
"left_filename": self.left_filenames[index],
|
||||||
|
"right_filename": self.right_filenames[index]}
|
||||||
|
4
kitti12_save.sh
Executable file
4
kitti12_save.sh
Executable file
@ -0,0 +1,4 @@
|
|||||||
|
#!/usr/bin/env bash
|
||||||
|
set -x
|
||||||
|
DATAPATH="/home/xyguo/data/kitti_2012/"
|
||||||
|
python save_disp.py --datapath $DATAPATH --testlist ./filenames/kitti12_test.txt --model gwcnet-gc --loadckpt ./checkpoints/kitti12/gwcnet-gc/best.ckpt
|
4
kitti15_save.sh
Executable file
4
kitti15_save.sh
Executable file
@ -0,0 +1,4 @@
|
|||||||
|
#!/usr/bin/env bash
|
||||||
|
set -x
|
||||||
|
DATAPATH="/home/xyguo/data/kitti_2015/"
|
||||||
|
python save_disp.py --datapath $DATAPATH --testlist ./filenames/kitti15_test.txt --model gwcnet-g --loadckpt ./checkpoints/kitti15/gwcnet-g/best.ckpt
|
82
save_disp.py
Normal file
82
save_disp.py
Normal file
@ -0,0 +1,82 @@
|
|||||||
|
from __future__ import print_function, division
|
||||||
|
import argparse
|
||||||
|
import os
|
||||||
|
import torch
|
||||||
|
import torch.nn as nn
|
||||||
|
import torch.nn.parallel
|
||||||
|
import torch.backends.cudnn as cudnn
|
||||||
|
import torch.optim as optim
|
||||||
|
import torch.utils.data
|
||||||
|
from torch.autograd import Variable
|
||||||
|
import torchvision.utils as vutils
|
||||||
|
import torch.nn.functional as F
|
||||||
|
import numpy as np
|
||||||
|
import time
|
||||||
|
from tensorboardX import SummaryWriter
|
||||||
|
from datasets import __datasets__
|
||||||
|
from models import __models__
|
||||||
|
from utils import *
|
||||||
|
from torch.utils.data import DataLoader
|
||||||
|
import gc
|
||||||
|
import skimage
|
||||||
|
|
||||||
|
cudnn.benchmark = True
|
||||||
|
|
||||||
|
parser = argparse.ArgumentParser(description='Group-wise Correlation Stereo Network (GwcNet)')
|
||||||
|
parser.add_argument('--model', default='gwcnet-g', help='select a model structure', choices=__models__.keys())
|
||||||
|
parser.add_argument('--maxdisp', type=int, default=192, help='maximum disparity')
|
||||||
|
|
||||||
|
parser.add_argument('--dataset', default='kitti', help='dataset name', choices=__datasets__.keys())
|
||||||
|
parser.add_argument('--datapath', required=True, help='data path')
|
||||||
|
parser.add_argument('--testlist', required=True, help='testing list')
|
||||||
|
parser.add_argument('--loadckpt', required=True, help='load the weights from a specific checkpoint')
|
||||||
|
|
||||||
|
# parse arguments
|
||||||
|
args = parser.parse_args()
|
||||||
|
|
||||||
|
# dataset, dataloader
|
||||||
|
StereoDataset = __datasets__[args.dataset]
|
||||||
|
test_dataset = StereoDataset(args.datapath, args.testlist, False)
|
||||||
|
TestImgLoader = DataLoader(test_dataset, 1, shuffle=False, num_workers=4, drop_last=False)
|
||||||
|
|
||||||
|
# model, optimizer
|
||||||
|
model = __models__[args.model](args.maxdisp)
|
||||||
|
model = nn.DataParallel(model)
|
||||||
|
model.cuda()
|
||||||
|
|
||||||
|
# load parameters
|
||||||
|
print("loading model {}".format(args.loadckpt))
|
||||||
|
state_dict = torch.load(args.loadckpt)
|
||||||
|
model.load_state_dict(state_dict['model'])
|
||||||
|
|
||||||
|
|
||||||
|
def test():
|
||||||
|
os.makedirs('./predictions', exist_ok=True)
|
||||||
|
for batch_idx, sample in enumerate(TestImgLoader):
|
||||||
|
start_time = time.time()
|
||||||
|
disp_est_np = tensor2numpy(test_sample(sample))
|
||||||
|
top_pad_np = tensor2numpy(sample["top_pad"])
|
||||||
|
right_pad_np = tensor2numpy(sample["right_pad"])
|
||||||
|
left_filenames = sample["left_filename"]
|
||||||
|
print('Iter {}/{}, time = {:3f}'.format(batch_idx, len(TestImgLoader),
|
||||||
|
time.time() - start_time))
|
||||||
|
|
||||||
|
for disp_est, top_pad, right_pad, fn in zip(disp_est_np, top_pad_np, right_pad_np, left_filenames):
|
||||||
|
assert len(disp_est.shape) == 2
|
||||||
|
disp_est = np.array(disp_est[top_pad:, :-right_pad], dtype=np.float32)
|
||||||
|
fn = os.path.join("predictions", fn.split('/')[-1])
|
||||||
|
print("saving to", fn, disp_est.shape)
|
||||||
|
disp_est_uint = np.round(disp_est * 256).astype(np.uint16)
|
||||||
|
skimage.io.imsave(fn, disp_est_uint)
|
||||||
|
|
||||||
|
|
||||||
|
# test one sample
|
||||||
|
@make_nograd_func
|
||||||
|
def test_sample(sample):
|
||||||
|
model.eval()
|
||||||
|
disp_ests = model(sample['left'].cuda(), sample['right'].cuda())
|
||||||
|
return disp_ests[-1]
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == '__main__':
|
||||||
|
test()
|
Loading…
Reference in New Issue
Block a user