265 lines
11 KiB
265 lines
11 KiB
from __future__ import print_function, division
import sys
import os
# os.environ['CUDA_VISIBLE_DEVICES'] = '0'
import argparse
import time
import logging
import numpy as np
import torch
from tqdm import tqdm
from igev_stereo import IGEVStereo, autocast
import stereo_datasets as datasets
from utils.utils import InputPadder
from PIL import Image
def count_parameters(model):
    """Return the total number of scalar entries in *model*'s trainable parameters.

    Only parameters with ``requires_grad`` set are counted.
    """
    total = 0
    for param in model.parameters():
        if not param.requires_grad:
            continue
        total += param.numel()
    return total
@torch.no_grad()
def validate_eth3d(model, iters=32, mixed_prec=False, max_disp=192):
    """Perform validation using the ETH3D (train) split.

    For every image pair the disparity is predicted, un-padded and compared with
    the ground truth. Only pixels that are valid (``valid_gt >= 0.5``) and marked
    non-occluded (value 255 in the matching ``mask0nocc.png``) are scored.

    Args:
        model: stereo network, called as ``model(img1, img2, iters=..., test_mode=True)``.
        iters: number of refinement iterations passed to the model.
        mixed_prec: run the forward pass under ``autocast``.
        max_disp: accepted for signature parity with the other validators;
            not used by the ETH3D metric.

    Returns:
        dict with ``'eth3d-epe'`` (mean of per-image EPE) and ``'eth3d-d1'``
        (mean of per-image percentage of pixels with error > 1 px).
    """
    aug_params = {}
    val_dataset = datasets.ETH3D(aug_params)

    out_list, epe_list = [], []
    for val_id in range(len(val_dataset)):
        (imageL_file, imageR_file, GT_file), image1, image2, flow_gt, valid_gt = val_dataset[val_id]
        image1 = image1[None].cuda()
        image2 = image2[None].cuda()

        # The network needs spatial sizes divisible by 32; padding is undone below.
        padder = InputPadder(image1.shape, divis_by=32)
        image1, image2 = padder.pad(image1, image2)

        with autocast(enabled=mixed_prec):
            flow_pr = model(image1, image2, iters=iters, test_mode=True)
        flow_pr = padder.unpad(flow_pr.float()).cpu().squeeze(0)
        assert flow_pr.shape == flow_gt.shape, (flow_pr.shape, flow_gt.shape)

        epe = torch.sum((flow_pr - flow_gt)**2, dim=0).sqrt()
        epe_flattened = epe.flatten()

        occ_mask = Image.open(GT_file.replace('disp0GT.pfm', 'mask0nocc.png'))
        occ_mask = np.ascontiguousarray(occ_mask).flatten()
        # Convert the numpy occlusion mask explicitly so `val` is a torch bool tensor.
        val = (valid_gt.flatten() >= 0.5) & torch.from_numpy(occ_mask == 255)

        out = (epe_flattened > 1.0)
        image_out = out[val].float().mean().item()
        image_epe = epe_flattened[val].mean().item()
        logging.info(f"ETH3D {val_id+1} out of {len(val_dataset)}. EPE {round(image_epe,4)} D1 {round(image_out,4)}")
        # Per-image metrics; previously never recorded, so the summary was NaN.
        epe_list.append(image_epe)
        out_list.append(image_out)

    epe_list = np.array(epe_list)
    out_list = np.array(out_list)

    epe = np.mean(epe_list)
    d1 = 100 * np.mean(out_list)

    print("Validation ETH3D: EPE %f, D1 %f" % (epe, d1))
    return {'eth3d-epe': epe, 'eth3d-d1': d1}
@torch.no_grad()
def validate_kitti(model, iters=32, mixed_prec=False, max_disp=192):
    """Perform validation using the KITTI-2015 (train) split.

    Scored pixels are those with ``valid_gt >= 0.5`` and ``|gt| < max_disp``.
    D1 is pixel-weighted: per-pixel outlier flags (error > 3 px) from all
    images are concatenated before averaging. The model's forward-pass time is
    also measured and reported as FPS.

    Args:
        model: stereo network, called as ``model(img1, img2, iters=..., test_mode=True)``.
        iters: number of refinement iterations passed to the model.
        mixed_prec: run the forward pass under ``autocast``.
        max_disp: ground-truth disparities at or above this value are ignored.

    Returns:
        dict with ``'kitti-epe'`` (mean of per-image EPE) and ``'kitti-d1'``
        (percentage of scored pixels with error > 3 px).
    """
    aug_params = {}
    val_dataset = datasets.KITTI(aug_params, image_set='training')
    torch.backends.cudnn.benchmark = True

    out_list, epe_list, elapsed_list = [], [], []
    all_elapsed = []
    for val_id in range(len(val_dataset)):
        _, image1, image2, flow_gt, valid_gt = val_dataset[val_id]
        image1 = image1[None].cuda()
        image2 = image2[None].cuda()

        padder = InputPadder(image1.shape, divis_by=32)
        image1, image2 = padder.pad(image1, image2)

        with autocast(enabled=mixed_prec):
            # CUDA work is asynchronous: synchronize before reading the clock so
            # the interval measures actual inference, not just kernel launches.
            torch.cuda.synchronize()
            start = time.time()
            flow_pr = model(image1, image2, iters=iters, test_mode=True)
            torch.cuda.synchronize()
            end = time.time()
        elapsed = end - start
        all_elapsed.append(elapsed)
        # Presumably skips warm-up / cudnn.benchmark autotuning iterations.
        if val_id > 50:
            elapsed_list.append(elapsed)

        flow_pr = padder.unpad(flow_pr.float()).cpu().squeeze(0)
        assert flow_pr.shape == flow_gt.shape, (flow_pr.shape, flow_gt.shape)

        epe = torch.sum((flow_pr - flow_gt)**2, dim=0).sqrt()
        epe_flattened = epe.flatten()
        val = (valid_gt.flatten() >= 0.5) & (flow_gt.abs().flatten() < max_disp)

        out = (epe_flattened > 3.0)
        image_out = out[val].float().mean().item()
        image_epe = epe_flattened[val].mean().item()
        if val_id < 9 or (val_id+1)%10 == 0:
            logging.info(f"KITTI Iter {val_id+1} out of {len(val_dataset)}. EPE {round(image_epe,4)} D1 {round(image_out,4)}. Runtime: {format(elapsed, '.3f')}s ({format(1/elapsed, '.2f')}-FPS)")
        epe_list.append(image_epe)
        # Keep per-pixel flags so D1 is weighted by pixel count across images.
        out_list.append(out[val].numpy())

    epe_list = np.array(epe_list)
    out_list = np.concatenate(out_list)

    epe = np.mean(epe_list)
    d1 = 100 * np.mean(out_list)

    # Fall back to all timings when the split is too short to have post-warm-up samples.
    avg_runtime = np.mean(elapsed_list if elapsed_list else all_elapsed)

    print(f"Validation KITTI: EPE {epe}, D1 {d1}, {format(1/avg_runtime, '.2f')}-FPS ({format(avg_runtime, '.3f')}s)")
    return {'kitti-epe': epe, 'kitti-d1': d1}
@torch.no_grad()
def validate_sceneflow(model, iters=32, mixed_prec=False, max_disp=192):
    """Perform validation using the Scene Flow (TEST) split.

    The error is the absolute disparity difference. Scored pixels are those with
    ``valid_gt >= 0.5`` and ``|gt| < max_disp``; images with no scored pixel
    are skipped. D1 is pixel-weighted, with an outlier threshold of 3 px. The
    summary line is also appended to ``test.txt``.

    Args:
        model: stereo network, called as ``model(img1, img2, iters=..., test_mode=True)``.
        iters: number of refinement iterations passed to the model.
        mixed_prec: run the forward pass under ``autocast``.
        max_disp: ground-truth disparities at or above this value are ignored.

    Returns:
        dict with ``'scene-flow-epe'`` and ``'scene-flow-d1'``.
    """
    val_dataset = datasets.SceneFlowDatasets(dstype='frames_finalpass', things_test=True)

    out_list, epe_list = [], []
    for val_id in tqdm(range(len(val_dataset))):
        _, image1, image2, flow_gt, valid_gt = val_dataset[val_id]

        image1 = image1[None].cuda()
        image2 = image2[None].cuda()

        padder = InputPadder(image1.shape, divis_by=32)
        image1, image2 = padder.pad(image1, image2)

        with autocast(enabled=mixed_prec):
            flow_pr = model(image1, image2, iters=iters, test_mode=True)
        flow_pr = padder.unpad(flow_pr.float()).cpu().squeeze(0)
        assert flow_pr.shape == flow_gt.shape, (flow_pr.shape, flow_gt.shape)

        epe = torch.abs(flow_pr - flow_gt).flatten()
        val = (valid_gt.flatten() >= 0.5) & (flow_gt.abs().flatten() < max_disp)
        # An image without scored pixels would contribute NaN and poison the mean.
        if not val.any():
            continue

        out = (epe > 3.0)
        epe_list.append(epe[val].mean().item())
        out_list.append(out[val].numpy())

    epe_list = np.array(epe_list)
    out_list = np.concatenate(out_list)

    epe = np.mean(epe_list)
    d1 = 100 * np.mean(out_list)

    # Context manager guarantees the log file is flushed and closed.
    with open('test.txt', 'a') as f:
        f.write("Validation Scene Flow: %f, %f\n" % (epe, d1))

    print("Validation Scene Flow: %f, %f" % (epe, d1))
    return {'scene-flow-epe': epe, 'scene-flow-d1': d1}
@torch.no_grad()
def validate_middlebury(model, iters=32, split='F', mixed_prec=False, max_disp=192):
    """Perform validation using the Middlebury-V3 dataset.

    Scored pixels are those with ``valid_gt >= 0.5``, ground-truth disparity
    ``< max_disp``, and value 255 in the grayscale ``mask0nocc.png``.
    The outlier threshold is 2 px.

    Args:
        model: stereo network, called as ``model(img1, img2, iters=..., test_mode=True)``.
        iters: number of refinement iterations passed to the model.
        split: resolution split forwarded to ``datasets.Middlebury`` ('F', 'H' or 'Q' from the CLI).
        mixed_prec: run the forward pass under ``autocast``.
        max_disp: ground-truth disparities at or above this value are ignored.

    Returns:
        dict with ``f'middlebury{split}-epe'`` and ``f'middlebury{split}-d1'``
        (means of per-image EPE / outlier percentage).
    """
    aug_params = {}
    val_dataset = datasets.Middlebury(aug_params, split=split)

    out_list, epe_list = [], []
    for val_id in range(len(val_dataset)):
        (imageL_file, _, _), image1, image2, flow_gt, valid_gt = val_dataset[val_id]
        image1 = image1[None].cuda()
        image2 = image2[None].cuda()

        padder = InputPadder(image1.shape, divis_by=32)
        image1, image2 = padder.pad(image1, image2)

        with autocast(enabled=mixed_prec):
            flow_pr = model(image1, image2, iters=iters, test_mode=True)
        flow_pr = padder.unpad(flow_pr.float()).cpu().squeeze(0)
        assert flow_pr.shape == flow_gt.shape, (flow_pr.shape, flow_gt.shape)

        epe = torch.sum((flow_pr - flow_gt)**2, dim=0).sqrt()
        epe_flattened = epe.flatten()

        occ_mask = Image.open(imageL_file.replace('im0.png', 'mask0nocc.png')).convert('L')
        occ_mask = np.ascontiguousarray(occ_mask, dtype=np.float32).flatten()
        # Convert the numpy occlusion mask explicitly so `val` is a torch bool tensor.
        val = (valid_gt.reshape(-1) >= 0.5) & (flow_gt[0].reshape(-1) < max_disp) & torch.from_numpy(occ_mask == 255)

        out = (epe_flattened > 2.0)
        image_out = out[val].float().mean().item()
        image_epe = epe_flattened[val].mean().item()
        logging.info(f"Middlebury Iter {val_id+1} out of {len(val_dataset)}. EPE {round(image_epe,4)} D1 {round(image_out,4)}")
        # Per-image metrics; previously never recorded, so the summary was NaN.
        epe_list.append(image_epe)
        out_list.append(image_out)

    epe_list = np.array(epe_list)
    out_list = np.array(out_list)

    epe = np.mean(epe_list)
    d1 = 100 * np.mean(out_list)

    print(f"Validation Middlebury{split}: EPE {epe}, D1 {d1}")
    return {f'middlebury{split}-epe': epe, f'middlebury{split}-d1': d1}
if __name__ == '__main__':
    parser = argparse.ArgumentParser()
    parser.add_argument('--restore_ckpt', help="restore checkpoint", default='./pretrained_models/sceneflow/sceneflow.pth')
    parser.add_argument('--dataset', help="dataset for evaluation", default='sceneflow', choices=["eth3d", "kitti", "sceneflow"] + [f"middlebury_{s}" for s in 'FHQ'])
    parser.add_argument('--mixed_precision', default=False, action='store_true', help='use mixed precision')
    parser.add_argument('--valid_iters', type=int, default=32, help='number of flow-field updates during forward pass')

    # Architecture choices
    parser.add_argument('--hidden_dims', nargs='+', type=int, default=[128]*3, help="hidden state and context dimensions")
    parser.add_argument('--corr_implementation', choices=["reg", "alt", "reg_cuda", "alt_cuda"], default="reg", help="correlation volume implementation")
    parser.add_argument('--shared_backbone', action='store_true', help="use a single backbone for the context and feature encoders")
    parser.add_argument('--corr_levels', type=int, default=2, help="number of levels in the correlation pyramid")
    parser.add_argument('--corr_radius', type=int, default=4, help="width of the correlation pyramid")
    parser.add_argument('--n_downsample', type=int, default=2, help="resolution of the disparity field (1/2^K)")
    parser.add_argument('--slow_fast_gru', action='store_true', help="iterate the low-res GRUs more frequently")
    parser.add_argument('--n_gru_layers', type=int, default=3, help="number of hidden GRU levels")
    parser.add_argument('--max_disp', type=int, default=192, help="max disp of geometry encoding volume")
    args = parser.parse_args()

    logging.basicConfig(level=logging.INFO,
                        format='%(asctime)s %(levelname)-8s [%(filename)s:%(lineno)d] %(message)s')

    model = torch.nn.DataParallel(IGEVStereo(args), device_ids=[0])

    if args.restore_ckpt is not None:
        # Explicit CLI validation instead of `assert` (which is stripped under -O).
        if not args.restore_ckpt.endswith(".pth"):
            parser.error("--restore_ckpt must point to a .pth file")
        logging.info("Loading checkpoint...")
        checkpoint = torch.load(args.restore_ckpt)
        model.load_state_dict(checkpoint, strict=True)
        logging.info(f"Done loading checkpoint")

    # The validators send inputs to CUDA, so the weights must be there too.
    # eval() disables train-mode behaviour such as BatchNorm statistic updates.
    model.cuda()
    model.eval()

    print(f"The model has {format(count_parameters(model)/1e6, '.2f')}M learnable parameters.")

    # Honour --mixed_precision; the CUDA correlation kernels also require it.
    use_mixed_precision = args.mixed_precision or args.corr_implementation.endswith("_cuda")

    if args.dataset == 'eth3d':
        validate_eth3d(model, iters=args.valid_iters, mixed_prec=use_mixed_precision, max_disp=args.max_disp)

    elif args.dataset == 'kitti':
        validate_kitti(model, iters=args.valid_iters, mixed_prec=use_mixed_precision, max_disp=args.max_disp)

    elif args.dataset in [f"middlebury_{s}" for s in 'FHQ']:
        validate_middlebury(model, iters=args.valid_iters, split=args.dataset[-1], mixed_prec=use_mixed_precision, max_disp=args.max_disp)

    elif args.dataset == 'sceneflow':
        validate_sceneflow(model, iters=args.valid_iters, mixed_prec=use_mixed_precision, max_disp=args.max_disp)