add codes for saving submission images
This commit is contained in:
parent
71575036e5
commit
ec23805fa7
@ -96,4 +96,6 @@ class KITTIDataset(Dataset):
|
||||
return {"left": left_img,
|
||||
"right": right_img,
|
||||
"top_pad": top_pad,
|
||||
"right_pad": right_pad}
|
||||
"right_pad": right_pad,
|
||||
"left_filename": self.left_filenames[index],
|
||||
"right_filename": self.right_filenames[index]}
|
||||
|
4
kitti12_save.sh
Executable file
4
kitti12_save.sh
Executable file
@ -0,0 +1,4 @@
|
||||
#!/usr/bin/env bash
|
||||
set -x
|
||||
DATAPATH="/home/xyguo/data/kitti_2012/"
|
||||
python save_disp.py --datapath $DATAPATH --testlist ./filenames/kitti12_test.txt --model gwcnet-gc --loadckpt ./checkpoints/kitti12/gwcnet-gc/best.ckpt
|
4
kitti15_save.sh
Executable file
4
kitti15_save.sh
Executable file
@ -0,0 +1,4 @@
|
||||
#!/usr/bin/env bash
|
||||
set -x
|
||||
DATAPATH="/home/xyguo/data/kitti_2015/"
|
||||
python save_disp.py --datapath $DATAPATH --testlist ./filenames/kitti15_test.txt --model gwcnet-g --loadckpt ./checkpoints/kitti15/gwcnet-g/best.ckpt
|
82
save_disp.py
Normal file
82
save_disp.py
Normal file
@ -0,0 +1,82 @@
|
||||
from __future__ import print_function, division
|
||||
import argparse
|
||||
import os
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.parallel
|
||||
import torch.backends.cudnn as cudnn
|
||||
import torch.optim as optim
|
||||
import torch.utils.data
|
||||
from torch.autograd import Variable
|
||||
import torchvision.utils as vutils
|
||||
import torch.nn.functional as F
|
||||
import numpy as np
|
||||
import time
|
||||
from tensorboardX import SummaryWriter
|
||||
from datasets import __datasets__
|
||||
from models import __models__
|
||||
from utils import *
|
||||
from torch.utils.data import DataLoader
|
||||
import gc
|
||||
import skimage
|
||||
|
||||
cudnn.benchmark = True
|
||||
|
||||
parser = argparse.ArgumentParser(description='Group-wise Correlation Stereo Network (GwcNet)')
|
||||
parser.add_argument('--model', default='gwcnet-g', help='select a model structure', choices=__models__.keys())
|
||||
parser.add_argument('--maxdisp', type=int, default=192, help='maximum disparity')
|
||||
|
||||
parser.add_argument('--dataset', default='kitti', help='dataset name', choices=__datasets__.keys())
|
||||
parser.add_argument('--datapath', required=True, help='data path')
|
||||
parser.add_argument('--testlist', required=True, help='testing list')
|
||||
parser.add_argument('--loadckpt', required=True, help='load the weights from a specific checkpoint')
|
||||
|
||||
# parse arguments
|
||||
args = parser.parse_args()
|
||||
|
||||
# dataset, dataloader
|
||||
StereoDataset = __datasets__[args.dataset]
|
||||
test_dataset = StereoDataset(args.datapath, args.testlist, False)
|
||||
TestImgLoader = DataLoader(test_dataset, 1, shuffle=False, num_workers=4, drop_last=False)
|
||||
|
||||
# model, optimizer
|
||||
model = __models__[args.model](args.maxdisp)
|
||||
model = nn.DataParallel(model)
|
||||
model.cuda()
|
||||
|
||||
# load parameters
|
||||
print("loading model {}".format(args.loadckpt))
|
||||
state_dict = torch.load(args.loadckpt)
|
||||
model.load_state_dict(state_dict['model'])
|
||||
|
||||
|
||||
def test():
|
||||
os.makedirs('./predictions', exist_ok=True)
|
||||
for batch_idx, sample in enumerate(TestImgLoader):
|
||||
start_time = time.time()
|
||||
disp_est_np = tensor2numpy(test_sample(sample))
|
||||
top_pad_np = tensor2numpy(sample["top_pad"])
|
||||
right_pad_np = tensor2numpy(sample["right_pad"])
|
||||
left_filenames = sample["left_filename"]
|
||||
print('Iter {}/{}, time = {:3f}'.format(batch_idx, len(TestImgLoader),
|
||||
time.time() - start_time))
|
||||
|
||||
for disp_est, top_pad, right_pad, fn in zip(disp_est_np, top_pad_np, right_pad_np, left_filenames):
|
||||
assert len(disp_est.shape) == 2
|
||||
disp_est = np.array(disp_est[top_pad:, :-right_pad], dtype=np.float32)
|
||||
fn = os.path.join("predictions", fn.split('/')[-1])
|
||||
print("saving to", fn, disp_est.shape)
|
||||
disp_est_uint = np.round(disp_est * 256).astype(np.uint16)
|
||||
skimage.io.imsave(fn, disp_est_uint)
|
||||
|
||||
|
||||
# test one sample
|
||||
@make_nograd_func
|
||||
def test_sample(sample):
|
||||
model.eval()
|
||||
disp_ests = model(sample['left'].cuda(), sample['right'].cuda())
|
||||
return disp_ests[-1]
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
test()
|
Loading…
Reference in New Issue
Block a user