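# IGEV/IGEV-Stereo/core/stereo_datasets.py
#
# Stereo datasets (SceneFlow, ETH3D, Sintel Stereo, FallingThings, TartanAir,
# KITTI, CREStereo, Middlebury) and the training DataLoader factory.
# Repository web-view metadata: 361 lines, 16 KiB, Python; last modified
# 2023-03-12 20:19:58 +08:00.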
import numpy as np
import torch
import torch.utils.data as data
import torch.nn.functional as F
import logging
import os
import re
import copy
import math
import random
from pathlib import Path
from glob import glob
import os.path as osp
from core.utils import frame_utils
from core.utils.augmentor import FlowAugmentor, SparseFlowAugmentor
class StereoDataset(data.Dataset):
    """Base class for stereo datasets.

    Subclasses fill ``image_list`` with ``[left_path, right_path]`` pairs and
    ``disparity_list`` with one ground-truth path per pair. Training samples are
    ``(paths, img1, img2, disparity, valid)``; in test mode (``is_test = True``)
    only ``(img1, img2, extra_info[index])`` is returned.
    """

    def __init__(self, aug_params=None, sparse=False, reader=None):
        """
        Args:
            aug_params: optional augmentation kwargs. An ``"img_pad"`` entry
                (``(padH, padW)``) is consumed here. When ``"crop_size"`` is present,
                the remaining entries are forwarded to the augmentor. The caller's
                dict is never modified.
            sparse: use ``SparseFlowAugmentor`` and keep the reader-provided
                validity mask instead of recomputing it after augmentation.
            reader: callable mapping a disparity path to either an array or a
                ``(disparity, valid)`` tuple; defaults to ``frame_utils.read_gen``.
        """
        self.augmentor = None
        self.sparse = sparse
        self.img_pad = None
        if aug_params is not None:
            # Work on a copy: callers (e.g. fetch_dataloader) share one dict between
            # several datasets, and popping "img_pad" from it would silently drop
            # the padding for every dataset constructed afterwards.
            aug_params = dict(aug_params)
            self.img_pad = aug_params.pop("img_pad", None)
            if "crop_size" in aug_params:
                if sparse:
                    self.augmentor = SparseFlowAugmentor(**aug_params)
                else:
                    self.augmentor = FlowAugmentor(**aug_params)

        self.disparity_reader = frame_utils.read_gen if reader is None else reader
        self.is_test = False
        self.init_seed = False
        self.flow_list = []
        self.disparity_list = []
        self.image_list = []
        self.extra_info = []

    @staticmethod
    def _as_rgb_uint8(img):
        """Return ``img`` as an ``(H, W, 3)`` uint8 array.

        2-D (grayscale) inputs are replicated to three channels; inputs with more
        than three channels (e.g. RGBA) are truncated to the first three.
        """
        img = np.array(img).astype(np.uint8)
        if img.ndim == 2:
            return np.tile(img[..., None], (1, 1, 3))
        return img[..., :3]

    def _read_image(self, path):
        """Read ``path`` with ``frame_utils.read_gen`` and convert it to (H, W, 3) uint8."""
        return self._as_rgb_uint8(frame_utils.read_gen(path))

    def __getitem__(self, index):
        if self.is_test:
            # Same grayscale handling as the training path (plain `[..., :3]` would
            # slice columns of a 2-D grayscale image).
            img1 = self._read_image(self.image_list[index][0])
            img2 = self._read_image(self.image_list[index][1])
            img1 = torch.from_numpy(img1).permute(2, 0, 1).float()
            img2 = torch.from_numpy(img2).permute(2, 0, 1).float()
            return img1, img2, self.extra_info[index]

        if not self.init_seed:
            # Seed torch / numpy / random once per DataLoader worker, using the worker id.
            # NOTE(review): the seed depends only on the worker id, so if workers are
            # re-created each epoch the augmentation stream presumably repeats — confirm
            # whether that is intended.
            worker_info = torch.utils.data.get_worker_info()
            if worker_info is not None:
                torch.manual_seed(worker_info.id)
                np.random.seed(worker_info.id)
                random.seed(worker_info.id)
                self.init_seed = True

        index = index % len(self.image_list)
        disp = self.disparity_reader(self.disparity_list[index])
        if isinstance(disp, tuple):
            disp, valid = disp
        else:
            # No explicit mask from the reader: disparities >= 512, +inf and NaN are
            # invalid. np.asarray also accepts readers that return PIL images.
            valid = np.asarray(disp) < 512

        img1 = self._read_image(self.image_list[index][0])
        img2 = self._read_image(self.image_list[index][1])

        disp = np.array(disp).astype(np.float32)
        if np.isnan(disp).any():
            raise ValueError(f"NaN values in disparity map {self.disparity_list[index]}")

        # The augmentors operate on 2-channel flow; the second channel is all zeros.
        flow = np.stack([disp, np.zeros_like(disp)], axis=-1)

        if self.augmentor is not None:
            if self.sparse:
                img1, img2, flow, valid = self.augmentor(img1, img2, flow, valid)
            else:
                img1, img2, flow = self.augmentor(img1, img2, flow)

        img1 = torch.from_numpy(img1).permute(2, 0, 1).float()
        img2 = torch.from_numpy(img2).permute(2, 0, 1).float()
        flow = torch.from_numpy(flow).permute(2, 0, 1).float()

        if self.sparse:
            valid = torch.from_numpy(valid)
        else:
            # Dense datasets: recompute validity after augmentation.
            valid = (flow[0].abs() < 512) & (flow[1].abs() < 512)

        if self.img_pad is not None:
            padH, padW = self.img_pad
            img1 = F.pad(img1, [padW] * 2 + [padH] * 2)
            img2 = F.pad(img2, [padW] * 2 + [padH] * 2)

        # Only the disparity channel is returned.
        flow = flow[:1]
        return self.image_list[index] + [self.disparity_list[index]], img1, img2, flow, valid.float()

    def __mul__(self, v):
        """Return a deep copy whose sample lists are repeated ``v`` times."""
        copy_of_self = copy.deepcopy(self)
        copy_of_self.flow_list = v * copy_of_self.flow_list
        copy_of_self.image_list = v * copy_of_self.image_list
        copy_of_self.disparity_list = v * copy_of_self.disparity_list
        copy_of_self.extra_info = v * copy_of_self.extra_info
        return copy_of_self

    def __len__(self):
        return len(self.image_list)
class SceneFlowDatasets(StereoDataset):
    """SceneFlow stereo pairs (FlyingThings3D, Monkaa, Driving) with dense PFM disparity.

    Expected layout under ``root``::

        <dstype>/<split>/.../left/<frame>.png     left image
        <dstype>/<split>/.../right/<frame>.png    right image
        disparity/<split>/.../left/<frame>.pfm    ground-truth disparity

    With ``things_test=True`` only the FlyingThings3D TEST split is loaded;
    otherwise the TRAIN splits of all three subsets are loaded.
    """

    def __init__(self, aug_params=None, root='/data/sceneflow/', dstype='frames_finalpass', things_test=False):
        super(SceneFlowDatasets, self).__init__(aug_params)
        self.root = root
        self.dstype = dstype
        if things_test:
            self._add_things("TEST")
        else:
            self._add_things("TRAIN")
            self._add_monkaa("TRAIN")
            self._add_driving("TRAIN")

    def _add_things(self, split='TRAIN'):
        """Add every FlyingThings3D pair of ``split`` (no validation subset is held out)."""
        self._add_subset(split, '*/*/left/*.png', 'FlyingThings')

    def _add_monkaa(self, split="TRAIN"):
        """Add every Monkaa pair of ``split``."""
        self._add_subset(split, '*/left/*.png', 'Monkaa')

    def _add_driving(self, split="TRAIN"):
        """Add every Driving pair of ``split``."""
        self._add_subset(split, '*/*/*/left/*.png', 'Driving')

    @staticmethod
    def _right_image_path(left_path):
        """Map ``.../left/<frame>.png`` to ``.../right/<frame>.png``.

        Only the ``left`` directory component is swapped, so other occurrences of
        "left" in the path (e.g. in ``root``) are left untouched.
        """
        side_dir, frame = osp.split(left_path)
        return osp.join(osp.dirname(side_dir), 'right', frame)

    def _add_subset(self, split, pattern, name):
        """Glob ``<root>/<dstype>/<split>/<pattern>`` and append image pairs plus disparity paths."""
        image_root = osp.join(self.root, self.dstype)
        disparity_root = osp.join(self.root, 'disparity')
        left_images = sorted(glob(osp.join(image_root, split, pattern)))
        for left in left_images:
            # Replace only the dstype directory and the file extension; a global
            # str.replace would also rewrite matching text elsewhere in the path.
            relative = osp.relpath(left, image_root)
            disparity = osp.join(disparity_root, osp.splitext(relative)[0] + '.pfm')
            self.image_list += [[left, self._right_image_path(left)]]
            self.disparity_list += [disparity]
        logging.info(f"Added {len(left_images)} from {name} {self.dstype}")
class ETH3D(StereoDataset):
    """ETH3D two-view stereo pairs with sparse ``disp0GT.pfm`` ground truth.

    For any split other than ``'training'`` every pair is assigned the same
    training map (``playground_1l``), presumably as a placeholder because no
    ground truth is available for that split.
    """

    def __init__(self, aug_params=None, root='/data/ETH3D', split='training'):
        super(ETH3D, self).__init__(aug_params, sparse=True)
        split_dir = osp.join(root, f'two_view_{split}')
        left_images = sorted(glob(osp.join(split_dir, '*/im0.png')))
        right_images = sorted(glob(osp.join(split_dir, '*/im1.png')))
        if split == 'training':
            disparities = sorted(glob(osp.join(root, 'two_view_training_gt/*/disp0GT.pfm')))
        else:
            placeholder = osp.join(root, 'two_view_training_gt/playground_1l/disp0GT.pfm')
            disparities = [placeholder] * len(left_images)
        for left, right, disparity in zip(left_images, right_images, disparities):
            self.image_list.append([left, right])
            self.disparity_list.append(disparity)
class SintelStereo(StereoDataset):
    """Sintel stereo training pairs with sparse disparity read by ``readDispSintelStereo``."""

    def __init__(self, aug_params=None, root='datasets/SintelStereo'):
        super().__init__(aug_params, sparse=True, reader=frame_utils.readDispSintelStereo)
        training_dir = osp.join(root, 'training')
        lefts = sorted(glob(osp.join(training_dir, '*_left/*/frame_*.png')))
        rights = sorted(glob(osp.join(training_dir, '*_right/*/frame_*.png')))
        # The `*_left` glob is expected to match two image passes, so the single
        # disparity list is repeated to line up with both of them.
        disparities = sorted(glob(osp.join(training_dir, 'disparities/*/frame_*.png'))) * 2
        for left, right, disparity in zip(lefts, rights, disparities):
            # Sanity check: the "<scene>/<frame>" suffix must agree.
            assert left.split('/')[-2:] == disparity.split('/')[-2:]
            self.image_list.append([left, right])
            self.disparity_list.append(disparity)
class FallingThings(StereoDataset):
    """FallingThings pairs listed in ``<root>/filenames.txt``.

    Each listed entry is a path relative to ``root``. The right image and the
    depth file are derived by substituting its ``left.jpg`` suffix.
    """

    def __init__(self, aug_params=None, root='datasets/FallingThings'):
        super().__init__(aug_params, reader=frame_utils.readDispFallingThings)
        assert os.path.exists(root)
        with open(os.path.join(root, 'filenames.txt'), 'r') as listing:
            entries = sorted(listing.read().splitlines())
        for entry in entries:
            left = osp.join(root, entry)
            right = osp.join(root, entry.replace('left.jpg', 'right.jpg'))
            depth = osp.join(root, entry.replace('left.jpg', 'left.depth.png'))
            self.image_list.append([left, right])
            self.disparity_list.append(depth)
class TartanAir(StereoDataset):
    """TartanAir pairs listed in ``<root>/tartanair_filenames.txt``.

    Args:
        aug_params: augmentation parameters forwarded to ``StereoDataset``.
        root: dataset root containing the file list; must exist.
        keywords: optional iterable of substrings. An entry is kept only if it
            contains every keyword; matching is case-insensitive.
    """

    def __init__(self, aug_params=None, root='datasets', keywords=()):
        super().__init__(aug_params, reader=frame_utils.readDispTartanAir)
        assert os.path.exists(root)
        with open(os.path.join(root, 'tartanair_filenames.txt'), 'r') as f:
            # The seasonsforest_winter/Easy sequence is always excluded.
            filenames = [s for s in f.read().splitlines() if 'seasonsforest_winter/Easy' not in s]
        # Lower-case both sides: paths were already compared lower-cased, so a keyword
        # containing an upper-case letter previously could never match anything.
        needles = [kw.lower() for kw in keywords]
        filenames = sorted(s for s in filenames if all(n in s.lower() for n in needles))

        for name in filenames:
            left = osp.join(root, name)
            right = osp.join(root, name.replace('_left', '_right'))
            depth = osp.join(root, name.replace('image_left', 'depth_left').replace('left.png', 'left_depth.npy'))
            self.image_list += [[left, right]]
            self.disparity_list += [depth]
class KITTI(StereoDataset):
    """KITTI 2012 + KITTI 2015 stereo pairs with sparse ``disp_occ`` ground truth.

    KITTI 2012 pairs come first, then KITTI 2015. For ``image_set != 'training'``,
    every pair gets frame ``000085_10`` from the same dataset's training
    ground-truth folder, presumably as a placeholder so each pair still has a
    readable disparity map.

    Args:
        aug_params: augmentation parameters forwarded to ``StereoDataset``.
        root: KITTI 2015 root directory; must exist.
        image_set: image sub-folder, e.g. ``'training'``.
        root_12: KITTI 2012 root directory. If it does not exist, glob finds
            nothing and it contributes no pairs.
    """

    def __init__(self, aug_params=None, root='/data/KITTI/KITTI_2015', image_set='training',
                 root_12='/data/KITTI/KITTI_2012'):
        super(KITTI, self).__init__(aug_params, sparse=True, reader=frame_utils.readDispKITTI)
        assert os.path.exists(root)
        # (dataset root, left-image dir, right-image dir, ground-truth dir). The 2015
        # root now honours `root` instead of a hard-coded path, and each placeholder
        # is built from its own dataset (KITTI 2015 has no `disp_occ` folder).
        layouts = (
            (root_12, 'colored_0', 'colored_1', 'disp_occ'),
            (root, 'image_2', 'image_3', 'disp_occ_0'),
        )
        for kitti_root, left_dir, right_dir, disp_dir in layouts:
            image1_list = sorted(glob(os.path.join(kitti_root, image_set, left_dir, '*_10.png')))
            image2_list = sorted(glob(os.path.join(kitti_root, image_set, right_dir, '*_10.png')))
            if image_set == 'training':
                disp_list = sorted(glob(os.path.join(kitti_root, 'training', disp_dir, '*_10.png')))
            else:
                placeholder = os.path.join(kitti_root, 'training', disp_dir, '000085_10.png')
                disp_list = [placeholder] * len(image1_list)
            # Zip per dataset so a count mismatch in one dataset cannot shift the other.
            for img1, img2, disp in zip(image1_list, image2_list, disp_list):
                self.image_list += [[img1, img2]]
                self.disparity_list += [disp]
class CREStereo(StereoDataset):
    """CREStereo synthetic pairs from the hole, shapenet, tree and reflective subsets."""

    def __init__(self, aug_params=None, root='/data/CREStereo'):
        super(CREStereo, self).__init__(aug_params, sparse=True, reader=frame_utils.readDispCREStereo)
        self.root = root
        assert os.path.exists(root)
        disparities = self.selector('_left.disp.png')
        lefts = self.selector('_left.jpg')
        rights = self.selector('_right.jpg')
        assert len(lefts) == len(rights) == len(disparities) > 0
        for left, right, disparity in zip(lefts, rights, disparities):
            self.image_list.append([left, right])
            self.disparity_list.append(disparity)

    def selector(self, suffix):
        """Return the sorted paths in the four subset folders whose names end in ``suffix``."""
        matches = []
        for subset in ("hole", "shapenet", "tree", "reflective"):
            matches.extend(glob(os.path.join(self.root, f"{subset}/*{suffix}")))
        return sorted(matches)
class Middlebury(StereoDataset):
    """Middlebury training scenes in a MiddEval3-style folder layout (sparse ground truth).

    Reads ``<root>/training<split>/<scene>/{im0.png, im1.png, disp0GT.pfm}``.
    Scene names are listed from ``<root>/trainingH`` regardless of ``split``.

    Args:
        aug_params: augmentation parameters forwarded to ``StereoDataset``.
        root: dataset root; must exist.
        split: resolution folder suffix, one of ``'F'``, ``'H'``, ``'Q'``.

    Raises:
        ValueError: if ``split`` is not exactly one of ``'F'``, ``'H'``, ``'Q'``.
    """

    def __init__(self, aug_params=None, root='/data/Middlebury', split='H'):
        super(Middlebury, self).__init__(aug_params, sparse=True, reader=frame_utils.readDispMiddlebury)
        assert os.path.exists(root)
        # Exact membership test: the former `split in "FHQ"` was a substring check
        # that also accepted "", "FH", "HQ" and "FHQ", producing paths to
        # non-existent folders.
        if split not in ("F", "H", "Q"):
            raise ValueError(f"Middlebury split must be one of 'F', 'H', 'Q'; got {split!r}")
        lines = list(map(osp.basename, glob(os.path.join(root, "trainingH/*"))))
        split_dir = os.path.join(root, f'training{split}')
        image1_list = sorted([os.path.join(split_dir, f'{name}/im0.png') for name in lines])
        image2_list = sorted([os.path.join(split_dir, f'{name}/im1.png') for name in lines])
        disp_list = sorted([os.path.join(split_dir, f'{name}/disp0GT.pfm') for name in lines])
        assert len(image1_list) == len(image2_list) == len(disp_list) > 0, [image1_list, split]
        for img1, img2, disp in zip(image1_list, image2_list, disp_list):
            self.image_list += [[img1, img2]]
            self.disparity_list += [disp]
def fetch_dataloader(args):
    """Create the training DataLoader for the datasets named in ``args.train_datasets``.

    Recognised names (checked in this order): ``middlebury_<F|H|Q>``,
    ``sceneflow``, any name containing ``kitti``, ``sintel_stereo``,
    ``falling_things``, ``tartan_air[_<keyword>...]`` and ``crestereo*``. The
    datasets are concatenated in the order given.

    Args:
        args: namespace providing ``image_size``, ``spatial_scale`` (min, max),
            ``noyjitter``, ``batch_size`` and ``train_datasets``, plus optional
            ``saturation_range``, ``img_gamma`` and ``do_flip``.

    Returns:
        A shuffling ``torch.utils.data.DataLoader`` with ``drop_last=True``.

    Raises:
        ValueError: if ``train_datasets`` is empty or contains an unknown name.
    """
    aug_params = {'crop_size': args.image_size, 'min_scale': args.spatial_scale[0], 'max_scale': args.spatial_scale[1], 'do_flip': False, 'yjitter': not args.noyjitter}
    # Optional overrides: attributes that are missing or None keep the defaults above.
    for attr, key in (("saturation_range", "saturation_range"), ("img_gamma", "gamma"), ("do_flip", "do_flip")):
        value = getattr(args, attr, None)
        if value is not None:
            aug_params[key] = value

    middlebury_pattern = re.compile("middlebury_.*")
    train_dataset = None
    for dataset_name in args.train_datasets:
        if middlebury_pattern.fullmatch(dataset_name):
            new_dataset = Middlebury(aug_params, split=dataset_name.replace('middlebury_', ''))
            logging.info(f"Adding {len(new_dataset)} samples from Middlebury")
        elif dataset_name == 'sceneflow':
            # Only the final-pass renders are used (a clean+final mix was disabled).
            new_dataset = SceneFlowDatasets(aug_params, dstype='frames_finalpass')
            logging.info(f"Adding {len(new_dataset)} samples from SceneFlow")
        elif 'kitti' in dataset_name:
            new_dataset = KITTI(aug_params)
            logging.info(f"Adding {len(new_dataset)} samples from KITTI")
        elif dataset_name == 'sintel_stereo':
            new_dataset = SintelStereo(aug_params) * 140
            logging.info(f"Adding {len(new_dataset)} samples from Sintel Stereo")
        elif dataset_name == 'falling_things':
            new_dataset = FallingThings(aug_params) * 5
            logging.info(f"Adding {len(new_dataset)} samples from FallingThings")
        elif dataset_name.startswith('tartan_air'):
            new_dataset = TartanAir(aug_params, keywords=dataset_name.split('_')[2:])
            logging.info(f"Adding {len(new_dataset)} samples from Tartain Air")
        elif dataset_name.startswith('crestereo'):
            new_dataset = CREStereo(aug_params)
            logging.info(f"Adding {len(new_dataset)} samples from CREStereo")
        else:
            # Previously an unknown name raised UnboundLocalError (first entry) or
            # silently re-added the previous dataset.
            raise ValueError(f"Unknown training dataset: {dataset_name!r}")
        train_dataset = new_dataset if train_dataset is None else train_dataset + new_dataset

    if train_dataset is None:
        raise ValueError("args.train_datasets is empty; no training data selected")

    # Reserve two of the allotted CPUs (presumably for the main process), but never
    # pass a negative worker count, which DataLoader rejects.
    num_workers = max(0, int(os.environ.get('SLURM_CPUS_PER_TASK', 6)) - 2)
    train_loader = data.DataLoader(train_dataset, batch_size=args.batch_size,
                                   pin_memory=True, shuffle=True, num_workers=num_workers, drop_last=True)

    logging.info('Training with %d image pairs' % len(train_dataset))
    return train_loader