2023-03-12 20:19:58 +08:00
from __future__ import print_function , division
import sys
sys . path . append ( ' core ' )
import os
# os.environ['CUDA_VISIBLE_DEVICES'] = '0'
import argparse
import time
import logging
import numpy as np
import torch
from tqdm import tqdm
from igev_stereo import IGEVStereo , autocast
import stereo_datasets as datasets
from utils . utils import InputPadder
from PIL import Image
def count_parameters(model):
    """Return the total element count of the trainable parameters of ``model``.

    Only parameters whose ``requires_grad`` attribute is truthy contribute;
    each contributes ``numel()`` elements.
    """
    total = 0
    for param in model.parameters():
        if param.requires_grad:
            total += param.numel()
    return total
@torch.no_grad()
def validate_eth3d(model, iters=32, mixed_prec=False, max_disp=192):
    """Perform validation using the ETH3D (train) split.

    Args:
        model: network called as ``model(image1, image2, iters=iters, test_mode=True)``.
        iters: iteration count forwarded to the model.
        mixed_prec: whether the forward pass runs inside an enabled ``autocast``.
        max_disp: accepted so all validators share one signature; not used here.

    Returns:
        dict with ``'eth3d-epe'`` (mean of the per-image mean errors) and
        ``'eth3d-d1'`` (mean of the per-image fraction of evaluated pixels whose
        error exceeds 1.0, multiplied by 100).
    """
    model.eval()
    aug_params = {}
    val_dataset = datasets.ETH3D(aug_params)
    num_samples = len(val_dataset)

    out_list, epe_list = [], []
    for val_id in range(num_samples):
        (_, _, GT_file), image1, image2, flow_gt, valid_gt = val_dataset[val_id]
        # Add a leading batch dimension and move the inputs to the GPU.
        image1 = image1[None].cuda()
        image2 = image2[None].cuda()

        padder = InputPadder(image1.shape, divis_by=32)
        image1, image2 = padder.pad(image1, image2)

        with autocast(enabled=mixed_prec):
            flow_pr = model(image1, image2, iters=iters, test_mode=True)
        flow_pr = padder.unpad(flow_pr.float()).cpu().squeeze(0)
        assert flow_pr.shape == flow_gt.shape, (flow_pr.shape, flow_gt.shape)

        # Per-pixel error: L2 norm of (prediction - ground truth) over dim 0.
        epe = torch.sum((flow_pr - flow_gt) ** 2, dim=0).sqrt()
        epe_flattened = epe.flatten()

        # The mask file sits next to the ground-truth file. Read it inside a
        # context manager so the file handle is released deterministically.
        with Image.open(GT_file.replace('disp0GT.pfm', 'mask0nocc.png')) as mask_img:
            occ_mask = np.ascontiguousarray(mask_img).flatten()

        # Evaluate only pixels that are valid in the GT and equal to 255 in the mask.
        val = (valid_gt.flatten() >= 0.5) & (occ_mask == 255)
        out = (epe_flattened > 1.0)
        image_out = out[val].float().mean().item()
        image_epe = epe_flattened[val].mean().item()
        logging.info(f"ETH3D {val_id+1} out of {num_samples}. EPE {round(image_epe,4)} D1 {round(image_out,4)}")
        epe_list.append(image_epe)
        out_list.append(image_out)

    epe_list = np.array(epe_list)
    out_list = np.array(out_list)

    epe = np.mean(epe_list)
    d1 = 100 * np.mean(out_list)

    print("Validation ETH3D: EPE %f, D1 %f" % (epe, d1))
    return {'eth3d-epe': epe, 'eth3d-d1': d1}
@torch.no_grad()
def validate_kitti(model, iters=32, mixed_prec=False, max_disp=192):
    """Perform validation using the KITTI-2015 (train) split.

    Args:
        model: network called as ``model(image1, image2, iters=iters, test_mode=True)``.
        iters: iteration count forwarded to the model.
        mixed_prec: whether the forward pass runs inside an enabled ``autocast``.
        max_disp: ground-truth pixels with ``|gt| >= max_disp`` are excluded.

    Returns:
        dict with ``'kitti-epe'`` (mean of the per-image mean errors) and
        ``'kitti-d1'`` (fraction of all evaluated pixels, pooled over images,
        whose error exceeds 3.0, multiplied by 100).

    Side effects:
        Sets ``torch.backends.cudnn.benchmark = True`` and prints/logs the
        metrics together with the measured forward-pass runtime.
    """
    model.eval()
    aug_params = {}
    val_dataset = datasets.KITTI(aug_params, image_set='training')
    torch.backends.cudnn.benchmark = True
    num_samples = len(val_dataset)

    out_list, epe_list, elapsed_list = [], [], []
    for val_id in range(num_samples):
        _, image1, image2, flow_gt, valid_gt = val_dataset[val_id]
        # Add a leading batch dimension and move the inputs to the GPU.
        image1 = image1[None].cuda()
        image2 = image2[None].cuda()

        padder = InputPadder(image1.shape, divis_by=32)
        image1, image2 = padder.pad(image1, image2)

        with autocast(enabled=mixed_prec):
            # CUDA kernels are launched asynchronously; synchronise before and
            # after the forward pass so the timer covers the actual GPU work.
            torch.cuda.synchronize()
            start = time.perf_counter()
            flow_pr = model(image1, image2, iters=iters, test_mode=True)
            torch.cuda.synchronize()
            end = time.perf_counter()
        runtime = end - start

        # The first 51 samples are left out of the runtime average
        # (presumably warm-up, e.g. cuDNN autotuning).
        if val_id > 50:
            elapsed_list.append(runtime)

        # Cast to float32 so the metrics are never computed in reduced precision.
        flow_pr = padder.unpad(flow_pr.float()).cpu().squeeze(0)
        assert flow_pr.shape == flow_gt.shape, (flow_pr.shape, flow_gt.shape)

        # Per-pixel error: L2 norm of (prediction - ground truth) over dim 0.
        epe = torch.sum((flow_pr - flow_gt) ** 2, dim=0).sqrt()
        epe_flattened = epe.flatten()

        val = (valid_gt.flatten() >= 0.5) & (flow_gt.abs().flatten() < max_disp)
        out = (epe_flattened > 3.0)
        image_out = out[val].float().mean().item()
        image_epe = epe_flattened[val].mean().item()
        if val_id < 9 or (val_id + 1) % 10 == 0:
            logging.info(f"KITTI Iter {val_id+1} out of {num_samples}. EPE {round(image_epe,4)} D1 {round(image_out,4)}. Runtime: {format(runtime, '.3f')}s ({format(1/runtime, '.2f')}-FPS)")
        epe_list.append(image_epe)
        out_list.append(out[val].cpu().numpy())

    epe_list = np.array(epe_list)
    out_list = np.concatenate(out_list)

    epe = np.mean(epe_list)
    d1 = 100 * np.mean(out_list)

    # With 51 or fewer samples nothing was timed; report NaN without the
    # NumPy "mean of empty slice" warning.
    avg_runtime = np.mean(elapsed_list) if elapsed_list else float('nan')

    print(f"Validation KITTI: EPE {epe}, D1 {d1}, {format(1/avg_runtime, '.2f')}-FPS ({format(avg_runtime, '.3f')}s)")
    return {'kitti-epe': epe, 'kitti-d1': d1}
@torch.no_grad()
def validate_sceneflow(model, iters=32, mixed_prec=False, max_disp=192):
    """Perform validation using the Scene Flow (TEST) split.

    Args:
        model: network called as ``model(image1, image2, iters=iters, test_mode=True)``.
        iters: iteration count forwarded to the model.
        mixed_prec: whether the forward pass runs inside an enabled ``autocast``.
        max_disp: ground-truth pixels with ``|gt| >= max_disp`` are excluded.

    Returns:
        dict with ``'scene-flow-epe'`` (mean of the per-image mean absolute
        errors) and ``'scene-flow-d1'`` (fraction of all evaluated pixels,
        pooled over images, whose error exceeds 3.0, multiplied by 100).

    Side effects:
        Appends one summary line to ``test.txt`` in the current working
        directory and prints the same summary.
    """
    model.eval()
    val_dataset = datasets.SceneFlowDatasets(dstype='frames_finalpass', things_test=True)

    out_list, epe_list = [], []
    for val_id in tqdm(range(len(val_dataset))):
        _, image1, image2, flow_gt, valid_gt = val_dataset[val_id]
        # Add a leading batch dimension and move the inputs to the GPU.
        image1 = image1[None].cuda()
        image2 = image2[None].cuda()

        padder = InputPadder(image1.shape, divis_by=32)
        image1, image2 = padder.pad(image1, image2)

        with autocast(enabled=mixed_prec):
            flow_pr = model(image1, image2, iters=iters, test_mode=True)
        # Cast to float32 so the metrics are never computed in reduced precision.
        flow_pr = padder.unpad(flow_pr.float()).cpu().squeeze(0)
        assert flow_pr.shape == flow_gt.shape, (flow_pr.shape, flow_gt.shape)

        # Per-pixel absolute error, flattened to line up with the validity mask.
        epe = torch.abs(flow_pr - flow_gt).flatten()
        val = (valid_gt.flatten() >= 0.5) & (flow_gt.abs().flatten() < max_disp)

        image_epe = epe[val].mean().item()
        # The mean is NaN e.g. when no pixel passes the mask; skip such samples.
        if np.isnan(image_epe):
            continue

        out = (epe > 3.0)
        epe_list.append(image_epe)
        out_list.append(out[val].cpu().numpy())

    epe_list = np.array(epe_list)
    out_list = np.concatenate(out_list)

    epe = np.mean(epe_list)
    d1 = 100 * np.mean(out_list)

    # Use a context manager so the log file is flushed and closed.
    with open('test.txt', 'a') as f:
        f.write("Validation Scene Flow: %f, %f\n" % (epe, d1))

    print("Validation Scene Flow: %f, %f" % (epe, d1))
    return {'scene-flow-epe': epe, 'scene-flow-d1': d1}
@torch.no_grad()
def validate_middlebury(model, iters=32, split='F', mixed_prec=False, max_disp=192):
    """Perform validation using the Middlebury-V3 dataset.

    Args:
        model: network called as ``model(image1, image2, iters=iters, test_mode=True)``.
        iters: iteration count forwarded to the model.
        split: dataset split identifier passed to ``datasets.Middlebury``; also
            embedded in the keys of the returned dict.
        mixed_prec: whether the forward pass runs inside an enabled ``autocast``.
        max_disp: ground-truth pixels with ``gt[0] >= max_disp`` are excluded.

    Returns:
        dict with ``f'middlebury{split}-epe'`` (mean of the per-image mean
        errors) and ``f'middlebury{split}-err2.0'`` (mean of the per-image
        fraction of evaluated pixels whose error exceeds 2.0, multiplied by 100).
    """
    model.eval()
    aug_params = {}
    val_dataset = datasets.Middlebury(aug_params, split=split)
    num_samples = len(val_dataset)

    out_list, epe_list = [], []
    for val_id in range(num_samples):
        (imageL_file, _, _), image1, image2, flow_gt, valid_gt = val_dataset[val_id]
        # Add a leading batch dimension and move the inputs to the GPU.
        image1 = image1[None].cuda()
        image2 = image2[None].cuda()

        padder = InputPadder(image1.shape, divis_by=32)
        image1, image2 = padder.pad(image1, image2)

        with autocast(enabled=mixed_prec):
            flow_pr = model(image1, image2, iters=iters, test_mode=True)
        # Cast to float32 so the metrics are never computed in reduced precision.
        flow_pr = padder.unpad(flow_pr.float()).cpu().squeeze(0)
        assert flow_pr.shape == flow_gt.shape, (flow_pr.shape, flow_gt.shape)

        # Per-pixel error: L2 norm of (prediction - ground truth) over dim 0.
        epe = torch.sum((flow_pr - flow_gt) ** 2, dim=0).sqrt()
        epe_flattened = epe.flatten()

        # The mask file sits next to the left image. Read it inside a context
        # manager so the file handle is released deterministically.
        with Image.open(imageL_file.replace('im0.png', 'mask0nocc.png')) as mask_img:
            occ_mask = np.ascontiguousarray(mask_img.convert('L'), dtype=np.float32).flatten()

        # Evaluate only pixels that are valid, below max_disp and 255 in the mask.
        val = (valid_gt.reshape(-1) >= 0.5) & (flow_gt[0].reshape(-1) < max_disp) & (occ_mask == 255)
        out = (epe_flattened > 2.0)
        image_out = out[val].float().mean().item()
        image_epe = epe_flattened[val].mean().item()
        logging.info(f"Middlebury Iter {val_id+1} out of {num_samples}. EPE {round(image_epe,4)} Err2.0 {round(image_out,4)}")
        epe_list.append(image_epe)
        out_list.append(image_out)

    epe_list = np.array(epe_list)
    out_list = np.array(out_list)

    epe = np.mean(epe_list)
    err2 = 100 * np.mean(out_list)

    print(f"Validation Middlebury{split}: EPE {epe}, Err2.0 {err2}")
    return {f'middlebury{split}-epe': epe, f'middlebury{split}-err2.0': err2}
2023-03-12 20:19:58 +08:00
if __name__ == '__main__':
    # Command-line entry point: build the model, optionally restore a
    # checkpoint, then run the validator selected by --dataset.
    parser = argparse.ArgumentParser()
    parser.add_argument('--restore_ckpt', help="restore checkpoint", default='./pretrained_models/sceneflow/sceneflow.pth')
    parser.add_argument('--dataset', help="dataset for evaluation", default='sceneflow', choices=["eth3d", "kitti", "sceneflow"] + [f"middlebury_{s}" for s in 'FHQ'])
    parser.add_argument('--mixed_precision', default=False, action='store_true', help='use mixed precision')
    parser.add_argument('--valid_iters', type=int, default=32, help='number of flow-field updates during forward pass')

    # Architecture choices
    parser.add_argument('--hidden_dims', nargs='+', type=int, default=[128]*3, help="hidden state and context dimensions")
    parser.add_argument('--corr_implementation', choices=["reg", "alt", "reg_cuda", "alt_cuda"], default="reg", help="correlation volume implementation")
    parser.add_argument('--shared_backbone', action='store_true', help="use a single backbone for the context and feature encoders")
    parser.add_argument('--corr_levels', type=int, default=2, help="number of levels in the correlation pyramid")
    parser.add_argument('--corr_radius', type=int, default=4, help="width of the correlation pyramid")
    parser.add_argument('--n_downsample', type=int, default=2, help="resolution of the disparity field (1/2^K)")
    parser.add_argument('--slow_fast_gru', action='store_true', help="iterate the low-res GRUs more frequently")
    parser.add_argument('--n_gru_layers', type=int, default=3, help="number of hidden GRU levels")
    parser.add_argument('--max_disp', type=int, default=192, help="max disp of geometry encoding volume")
    parser.add_argument('--freeze_backbone_params', action="store_true", help="freeze backbone parameters")
    args = parser.parse_args()

    model = torch.nn.DataParallel(IGEVStereo(args), device_ids=[0])

    logging.basicConfig(level=logging.INFO,
                        format='%(asctime)s %(levelname)-8s [%(filename)s:%(lineno)d] %(message)s')

    if args.restore_ckpt is not None:
        # Validate user input explicitly: an ``assert`` is stripped under ``-O``.
        if not args.restore_ckpt.endswith(".pth"):
            parser.error(f"--restore_ckpt must be a .pth file, got {args.restore_ckpt!r}")
        logging.info("Loading checkpoint...")
        # NOTE: torch.load unpickles the file; only load checkpoints you trust.
        checkpoint = torch.load(args.restore_ckpt)
        # Strip the '_orig_mod.' prefix from any key that carries it so the
        # names match the model built above. Iterate over a snapshot of the
        # keys because the dict is modified inside the loop.
        unwanted_prefix = '_orig_mod.'
        for key in list(checkpoint):
            if key.startswith(unwanted_prefix):
                checkpoint[key[len(unwanted_prefix):]] = checkpoint.pop(key)
        model.load_state_dict(checkpoint, strict=True)
        logging.info("Done loading checkpoint")

    model.cuda()
    model.eval()

    print(f"The model has {format(count_parameters(model)/1e6, '.2f')}M learnable parameters.")

    # NOTE(review): the autocast around evaluation is enabled only for the
    # "*_cuda" correlation implementations; --mixed_precision is not consulted
    # here and reaches the model only through ``args`` - confirm this is intended.
    use_mixed_precision = args.corr_implementation.endswith("_cuda")

    # Keyword arguments shared by every validator.
    eval_kwargs = dict(iters=args.valid_iters, mixed_prec=use_mixed_precision, max_disp=args.max_disp)

    if args.dataset == 'eth3d':
        validate_eth3d(model, **eval_kwargs)

    elif args.dataset == 'kitti':
        validate_kitti(model, **eval_kwargs)

    elif args.dataset.startswith('middlebury_'):
        # The last character of the dataset name ('F', 'H' or 'Q') is the split.
        validate_middlebury(model, split=args.dataset[-1], **eval_kwargs)

    elif args.dataset == 'sceneflow':
        validate_sceneflow(model, **eval_kwargs)