# 2023-03-12 20:19:58 +08:00
import numpy as np
import torch
import torch . utils . data as data
import torch . nn . functional as F
import logging
import os
import re
import copy
import math
import random
from pathlib import Path
from glob import glob
import os . path as osp
from core . utils import frame_utils
from core . utils . augmentor import FlowAugmentor , SparseFlowAugmentor
class StereoDataset(data.Dataset):
    """Base class for stereo datasets: image pairs plus a disparity target.

    Subclasses populate ``image_list`` with ``[img1_path, img2_path]`` entries
    and ``disparity_list`` with the matching input for ``disparity_reader``.
    """

    def __init__(self, aug_params=None, sparse=False, reader=None):
        """Configure augmentation and the disparity reader.

        Args:
            aug_params: Optional dict of keyword arguments for the augmentor.
                An ``"img_pad"`` entry is extracted here and never forwarded.
                No augmentor is created unless ``"crop_size"`` is present.
                The dict passed in is left unmodified.
            sparse: Selects ``SparseFlowAugmentor`` instead of
                ``FlowAugmentor`` and makes ``__getitem__`` keep the validity
                mask coming from the reader/augmentor rather than recomputing
                it from the disparity values.
            reader: Callable that loads the disparity for one
                ``disparity_list`` entry; defaults to ``frame_utils.read_gen``.
        """
        # Work on a private copy: callers may reuse one dict for several
        # datasets, and popping from the shared dict would strip "img_pad"
        # for every dataset constructed after the first one.
        aug_params = dict(aug_params) if aug_params is not None else None

        self.augmentor = None
        self.sparse = sparse
        self.img_pad = aug_params.pop("img_pad", None) if aug_params is not None else None
        if aug_params is not None and "crop_size" in aug_params:
            if sparse:
                self.augmentor = SparseFlowAugmentor(**aug_params)
            else:
                self.augmentor = FlowAugmentor(**aug_params)

        if reader is None:
            self.disparity_reader = frame_utils.read_gen
        else:
            self.disparity_reader = reader

        self.is_test = False
        self.init_seed = False
        self.flow_list = []
        self.disparity_list = []
        self.image_list = []
        self.extra_info = []

    def __getitem__(self, index):
        """Load one sample.

        Returns:
            In test mode (``is_test`` set): ``(img1, img2, extra_info[index])``.
            Otherwise: ``(paths, img1, img2, flow, valid)`` where ``paths`` is
            ``[img1_path, img2_path, disparity_entry]``, the images are float
            tensors with the channel axis moved to the front, ``flow`` keeps
            only its first channel (the disparity) and ``valid`` is a float
            mask.
        """
        if self.is_test:
            img1 = frame_utils.read_gen(self.image_list[index][0])
            img2 = frame_utils.read_gen(self.image_list[index][1])
            img1 = np.array(img1).astype(np.uint8)[..., :3]
            img2 = np.array(img2).astype(np.uint8)[..., :3]
            img1 = torch.from_numpy(img1).permute(2, 0, 1).float()
            img2 = torch.from_numpy(img2).permute(2, 0, 1).float()
            return img1, img2, self.extra_info[index]

        # Seed the RNGs once per dataset copy, using the DataLoader worker id
        # when running inside a worker process.
        if not self.init_seed:
            worker_info = torch.utils.data.get_worker_info()
            if worker_info is not None:
                torch.manual_seed(worker_info.id)
                np.random.seed(worker_info.id)
                random.seed(worker_info.id)
                self.init_seed = True

        # Wrap around so indices beyond the list length are still served.
        index = index % len(self.image_list)
        disp = self.disparity_reader(self.disparity_list[index])
        # Readers return either (disparity, valid) or the disparity alone.
        if isinstance(disp, tuple):
            disp, valid = disp
        else:
            valid = disp < 512

        img1 = frame_utils.read_gen(self.image_list[index][0])
        img2 = frame_utils.read_gen(self.image_list[index][1])

        img1 = np.array(img1).astype(np.uint8)
        img2 = np.array(img2).astype(np.uint8)

        disp = np.array(disp).astype(np.float32)
        # The augmentors operate on 2-channel flow; the second channel is zero.
        flow = np.stack([disp, np.zeros_like(disp)], axis=-1)

        # grayscale images
        if len(img1.shape) == 2:
            img1 = np.tile(img1[..., None], (1, 1, 3))
            img2 = np.tile(img2[..., None], (1, 1, 3))
        else:
            img1 = img1[..., :3]
            img2 = img2[..., :3]

        if self.augmentor is not None:
            if self.sparse:
                img1, img2, flow, valid = self.augmentor(img1, img2, flow, valid)
            else:
                img1, img2, flow = self.augmentor(img1, img2, flow)

        img1 = torch.from_numpy(img1).permute(2, 0, 1).float()
        img2 = torch.from_numpy(img2).permute(2, 0, 1).float()
        flow = torch.from_numpy(flow).permute(2, 0, 1).float()

        if self.sparse:
            valid = torch.from_numpy(valid)
        else:
            # Dense targets: recompute validity after augmentation.
            valid = (flow[0].abs() < 512) & (flow[1].abs() < 512)

        if self.img_pad is not None:
            # Only the images are padded; flow and valid keep their size.
            padH, padW = self.img_pad
            img1 = F.pad(img1, [padW] * 2 + [padH] * 2)
            img2 = F.pad(img2, [padW] * 2 + [padH] * 2)

        # Drop the all-zero second channel again.
        flow = flow[:1]
        return self.image_list[index] + [self.disparity_list[index]], img1, img2, flow, valid.float()

    def __mul__(self, v):
        """Return a deep copy whose sample lists are repeated ``v`` times."""
        copy_of_self = copy.deepcopy(self)
        copy_of_self.flow_list = v * copy_of_self.flow_list
        copy_of_self.image_list = v * copy_of_self.image_list
        copy_of_self.disparity_list = v * copy_of_self.disparity_list
        copy_of_self.extra_info = v * copy_of_self.extra_info
        return copy_of_self

    def __len__(self):
        """Return the number of image pairs."""
        return len(self.image_list)
class SceneFlowDatasets(StereoDataset):
    """Stereo pairs from the SceneFlow subsets (FlyingThings, Monkaa, Driving).

    All subsets are globbed below ``<root>/<dstype>/<split>``; they differ
    only in how deep the ``left`` image folders sit.
    """

    def __init__(self, aug_params=None, root='/data/sceneflow/', dstype='frames_finalpass', things_test=False):
        """Index the dataset.

        Args:
            aug_params: Augmentation settings forwarded to ``StereoDataset``.
            root: Directory containing the ``dstype`` and ``disparity`` trees.
            dstype: Name of the image tree to read.
            things_test: When true only the FlyingThings ``TEST`` split is
                added; otherwise the ``TRAIN`` split of all three subsets.
        """
        super(SceneFlowDatasets, self).__init__(aug_params)
        self.root = root
        self.dstype = dstype

        if things_test:
            self._add_things("TEST")
        else:
            self._add_things("TRAIN")
            self._add_monkaa("TRAIN")
            self._add_driving("TRAIN")

    def _add_matching(self, split, pattern, label):
        """Append every pair whose left image matches ``pattern`` and log the count."""
        original_length = len(self.disparity_list)
        left_images = sorted(glob(osp.join(self.root, self.dstype, split, pattern)))
        for left in left_images:
            # Right image and disparity paths are derived by substring
            # replacement on the left-image path.
            right = left.replace('left', 'right')
            disp = left.replace(self.dstype, 'disparity').replace('.png', '.pfm')
            self.image_list.append([left, right])
            self.disparity_list.append(disp)
        logging.info(f"Added {len(self.disparity_list) - original_length} from {label} {self.dstype}")

    def _add_things(self, split='TRAIN'):
        """Add FlyingThings3D data."""
        # The former fixed-seed validation sampling selected every index, so
        # all TEST images are used; any other split name adds nothing.
        if split in ('TRAIN', 'TEST'):
            self._add_matching(split, '*/*/left/*.png', 'FlyingThings')
        else:
            logging.info(f"Added 0 from FlyingThings {self.dstype}")

    def _add_monkaa(self, split="TRAIN"):
        """Add Monkaa data."""
        self._add_matching(split, '*/left/*.png', 'Monkaa')

    def _add_driving(self, split="TRAIN"):
        """Add Driving data."""
        self._add_matching(split, '*/*/*/left/*.png', 'Driving')
class ETH3D(StereoDataset):
    """ETH3D two-view stereo pairs (sparse ground truth)."""

    def __init__(self, aug_params=None, root='/data/ETH3D', split='training'):
        """Index ``two_view_<split>`` under ``root``.

        For any split other than ``'training'`` a single training
        ground-truth file is repeated as the disparity entry of every pair.
        """
        super(ETH3D, self).__init__(aug_params, sparse=True)

        lefts = sorted(glob(osp.join(root, f'two_view_{split}/*/im0.png')))
        rights = sorted(glob(osp.join(root, f'two_view_{split}/*/im1.png')))
        if split == 'training':
            disparities = sorted(glob(osp.join(root, 'two_view_training_gt/*/disp0GT.pfm')))
        else:
            placeholder = osp.join(root, 'two_view_training_gt/playground_1l/disp0GT.pfm')
            disparities = [placeholder] * len(lefts)

        for left, right, disparity in zip(lefts, rights, disparities):
            self.image_list.append([left, right])
            self.disparity_list.append(disparity)
class SintelStereo(StereoDataset):
    """Sintel Stereo training pairs (sparse ground truth)."""

    def __init__(self, aug_params=None, root='datasets/SintelStereo'):
        """Index ``training/*_left`` and ``training/*_right`` under ``root``."""
        super().__init__(aug_params, sparse=True, reader=frame_utils.readDispSintelStereo)

        lefts = sorted(glob(osp.join(root, 'training/*_left/*/frame_*.png')))
        rights = sorted(glob(osp.join(root, 'training/*_right/*/frame_*.png')))
        disparities = sorted(glob(osp.join(root, 'training/disparities/*/frame_*.png')))
        # The disparity listing is used twice in a row so that it lines up
        # with every "*_left" folder matched above.
        disparities = disparities + disparities

        for left, right, disparity in zip(lefts, rights, disparities):
            # Scene folder and frame name must agree between image and disparity.
            assert left.split('/')[-2:] == disparity.split('/')[-2:]
            self.image_list.append([left, right])
            self.disparity_list.append(disparity)
class FallingThings(StereoDataset):
    """FallingThings pairs listed in ``<root>/filenames.txt``."""

    def __init__(self, aug_params=None, root='datasets/FallingThings'):
        """Read the file listing and derive right-image and depth paths."""
        super().__init__(aug_params, reader=frame_utils.readDispFallingThings)
        assert os.path.exists(root)

        listing = os.path.join(root, 'filenames.txt')
        with open(listing, 'r') as f:
            names = f.read().splitlines()
        names.sort()

        for name in names:
            left = osp.join(root, name)
            right = osp.join(root, name.replace('left.jpg', 'right.jpg'))
            depth = osp.join(root, name.replace('left.jpg', 'left.depth.png'))
            self.image_list.append([left, right])
            self.disparity_list.append(depth)
class TartanAir(StereoDataset):
    """TartanAir pairs listed in ``<root>/tartanair_filenames.txt``."""

    def __init__(self, aug_params=None, root='datasets', keywords=()):
        """Index the listed files, optionally restricted by keywords.

        Args:
            aug_params: Augmentation settings forwarded to ``StereoDataset``.
            root: Directory holding the listing and the data it refers to.
            keywords: Iterable of substrings; a file is kept only if its
                lower-cased name contains every one of them.
        """
        super().__init__(aug_params, reader=frame_utils.readDispTartanAir)
        assert os.path.exists(root)

        with open(os.path.join(root, 'tartanair_filenames.txt'), 'r') as f:
            # Entries from 'seasonsforest_winter/Easy' are always skipped.
            filenames = sorted(s for s in f.read().splitlines() if 'seasonsforest_winter/Easy' not in s)

        # Filtering keeps the sorted order, so no re-sorting is needed.
        for kw in keywords:
            filenames = [s for s in filenames if kw in s.lower()]

        image1_list = [osp.join(root, e) for e in filenames]
        image2_list = [osp.join(root, e.replace('_left', '_right')) for e in filenames]
        disp_list = [osp.join(root, e.replace('image_left', 'depth_left').replace('left.png', 'left_depth.npy')) for e in filenames]

        for img1, img2, disp in zip(image1_list, image2_list, disp_list):
            self.image_list += [[img1, img2]]
            self.disparity_list += [disp]
class KITTI(StereoDataset):
    """KITTI 2012 and KITTI 2015 stereo pairs (sparse ground truth)."""

    def __init__(self, aug_params=None, root='/data/KITTI/KITTI_2015', image_set='training',
                 root_12='/data/KITTI/KITTI_2012'):
        """Index both KITTI benchmarks.

        Args:
            aug_params: Augmentation settings forwarded to ``StereoDataset``.
            root: KITTI 2015 directory; must exist.
            image_set: Sub-directory to read images from. For anything other
                than ``'training'`` one fixed training disparity file per
                benchmark is repeated as the disparity entry of every pair.
            root_12: KITTI 2012 directory. Its default equals the path that
                was previously hard-coded.
        """
        super(KITTI, self).__init__(aug_params, sparse=True, reader=frame_utils.readDispKITTI)
        assert os.path.exists(root)

        # (benchmark root, left image dir, right image dir, disparity dir)
        layouts = [
            (root_12, 'colored_0', 'colored_1', 'disp_occ'),
            (root, 'image_2', 'image_3', 'disp_occ_0'),
        ]
        for base, left_dir, right_dir, disp_dir in layouts:
            image1_list = sorted(glob(os.path.join(base, image_set, f'{left_dir}/*_10.png')))
            image2_list = sorted(glob(os.path.join(base, image_set, f'{right_dir}/*_10.png')))
            if image_set == 'training':
                disp_list = sorted(glob(os.path.join(base, 'training', f'{disp_dir}/*_10.png')))
            else:
                # Placeholder taken from the same benchmark, one per image pair.
                disp_list = [osp.join(base, f'training/{disp_dir}/000085_10.png')] * len(image1_list)

            # Pair up per benchmark so a count mismatch in one benchmark
            # cannot shift the pairing of the other.
            for img1, img2, disp in zip(image1_list, image2_list, disp_list):
                self.image_list += [[img1, img2]]
                self.disparity_list += [disp]
# 2023-04-22 11:32:10 +08:00
class CREStereo(StereoDataset):
    """CREStereo pairs gathered from four sub-folders of ``root``."""

    def __init__(self, aug_params=None, root='/data/CREStereo'):
        """Index left/right JPEGs and their left disparity PNGs."""
        super(CREStereo, self).__init__(aug_params, sparse=True, reader=frame_utils.readDispCREStereo)
        self.root = root
        assert os.path.exists(root)

        lefts = self.selector('_left.jpg')
        rights = self.selector('_right.jpg')
        disparities = self.selector('_left.disp.png')
        assert len(lefts) == len(rights) == len(disparities) > 0

        for left, right, disparity in zip(lefts, rights, disparities):
            self.image_list.append([left, right])
            self.disparity_list.append(disparity)

    def selector(self, suffix):
        """Return the sorted paths ending in ``suffix`` across all sub-folders."""
        matches = []
        for folder in ("hole", "shapenet", "tree", "reflective"):
            matches.extend(glob(os.path.join(self.root, f"{folder}/*{suffix}")))
        matches.sort()
        return matches
# 2023-03-12 20:19:58 +08:00
class Middlebury(StereoDataset):
    """Middlebury training pairs at one resolution (sparse ground truth)."""

    def __init__(self, aug_params=None, root='/data/Middlebury', split='F'):
        """Index ``<root>/training<split>``.

        Args:
            aug_params: Augmentation settings forwarded to ``StereoDataset``.
            root: Dataset directory; must exist.
            split: Resolution letter, exactly one of ``'F'``, ``'H'``, ``'Q'``.
        """
        super(Middlebury, self).__init__(aug_params, sparse=True, reader=frame_utils.readDispMiddlebury)
        assert os.path.exists(root)
        # Tuple membership: a substring test against "FHQ" would also accept
        # "", "FH" and "HQ".
        assert split in ("F", "H", "Q")

        # Scene names are always taken from the "trainingH" folder, whatever
        # split is requested.
        # NOTE(review): presumably every split holds the same scenes - confirm.
        lines = list(map(osp.basename, glob(os.path.join(root, "trainingH/*"))))
        image1_list = sorted([os.path.join(root, f'training{split}', f'{name}/im0.png') for name in lines])
        image2_list = sorted([os.path.join(root, f'training{split}', f'{name}/im1.png') for name in lines])
        disp_list = sorted([os.path.join(root, f'training{split}', f'{name}/disp0GT.pfm') for name in lines])

        assert len(image1_list) == len(image2_list) == len(disp_list) > 0, [image1_list, split]
        for img1, img2, disp in zip(image1_list, image2_list, disp_list):
            self.image_list += [[img1, img2]]
            self.disparity_list += [disp]
def fetch_dataloader(args):
    """Create the data loader for the corresponding training set.

    Args:
        args: Namespace providing ``image_size``, ``spatial_scale`` (two
            values), ``noyjitter``, ``train_datasets`` (iterable of dataset
            names) and ``batch_size``. ``saturation_range``, ``img_gamma``
            and ``do_flip`` are optional and only used when present and not
            ``None``.

    Returns:
        A shuffling ``DataLoader`` over all requested datasets combined.

    Raises:
        ValueError: If a dataset name is not recognised or no dataset was
            requested.
    """
    aug_params = {
        'crop_size': args.image_size,
        'min_scale': args.spatial_scale[0],
        'max_scale': args.spatial_scale[1],
        'do_flip': False,
        'yjitter': not args.noyjitter,
    }
    if hasattr(args, "saturation_range") and args.saturation_range is not None:
        aug_params["saturation_range"] = args.saturation_range
    if hasattr(args, "img_gamma") and args.img_gamma is not None:
        aug_params["gamma"] = args.img_gamma
    if hasattr(args, "do_flip") and args.do_flip is not None:
        aug_params["do_flip"] = args.do_flip

    train_dataset = None
    for dataset_name in args.train_datasets:
        if re.fullmatch("middlebury_.*", dataset_name):
            new_dataset = Middlebury(aug_params, split=dataset_name.replace('middlebury_', ''))
        elif dataset_name == 'sceneflow':
            new_dataset = SceneFlowDatasets(aug_params, dstype='frames_finalpass')
            logging.info(f"Adding {len(new_dataset)} samples from SceneFlow")
        elif 'kitti' in dataset_name:
            new_dataset = KITTI(aug_params)
            logging.info(f"Adding {len(new_dataset)} samples from KITTI")
        elif dataset_name == 'sintel_stereo':
            new_dataset = SintelStereo(aug_params) * 140
            logging.info(f"Adding {len(new_dataset)} samples from Sintel Stereo")
        elif dataset_name == 'falling_things':
            new_dataset = FallingThings(aug_params) * 5
            logging.info(f"Adding {len(new_dataset)} samples from FallingThings")
        elif dataset_name.startswith('tartan_air'):
            new_dataset = TartanAir(aug_params, keywords=dataset_name.split('_')[2:])
            logging.info(f"Adding {len(new_dataset)} samples from Tartain Air")
        elif dataset_name.startswith('crestereo'):
            new_dataset = CREStereo(aug_params)
            logging.info(f"Adding {len(new_dataset)} samples from CREStereo")
        else:
            # Without this branch an unknown name either crashed with a
            # NameError or silently added the previous dataset a second time.
            raise ValueError(f"Unknown training dataset: {dataset_name!r}")
        train_dataset = new_dataset if train_dataset is None else train_dataset + new_dataset

    if train_dataset is None:
        raise ValueError("args.train_datasets is empty; no training data selected")

    # Two CPUs are left for the main process; clamp at zero because
    # DataLoader rejects a negative worker count.
    num_workers = max(0, int(os.environ.get('SLURM_CPUS_PER_TASK', 6)) - 2)
    train_loader = data.DataLoader(train_dataset, batch_size=args.batch_size,
                                   pin_memory=True, shuffle=True, num_workers=num_workers, drop_last=True)

    logging.info('Training with %d image pairs' % len(train_dataset))
    return train_loader