import copy
import logging
import math
import os
import os.path as osp
import random
import re
from glob import glob
from pathlib import Path

import numpy as np
import torch
import torch.nn.functional as F
import torch.utils.data as data

from core.utils import frame_utils
from core.utils.augmentor import FlowAugmentor, SparseFlowAugmentor


class StereoDataset(data.Dataset):
    """Base class for stereo datasets made of (left, right) image pairs.

    Subclasses fill ``self.image_list`` with ``[left_path, right_path]``
    entries and ``self.disparity_list`` with the matching disparity paths.

    Args:
        aug_params: optional dict of augmentor keyword arguments. An
            ``"img_pad"`` entry, if present, is consumed here and not passed
            on to the augmentor. An augmentor is only built when the dict
            contains ``"crop_size"``.
        sparse: selects ``SparseFlowAugmentor`` instead of ``FlowAugmentor``
            and makes ``__getitem__`` use the validity mask that comes with
            the disparity instead of deriving one from the flow magnitude.
        reader: callable used to load a disparity file; defaults to
            ``frame_utils.read_gen``.
    """

    def __init__(self, aug_params=None, sparse=False, reader=None):
        self.augmentor = None
        self.sparse = sparse

        # Work on a private copy: popping "img_pad" from the caller's dict
        # would drop the padding for every later dataset built from the same
        # dict (fetch_dataloader shares one dict across all datasets).
        if aug_params is not None:
            aug_params = dict(aug_params)
        self.img_pad = aug_params.pop("img_pad", None) if aug_params is not None else None

        if aug_params is not None and "crop_size" in aug_params:
            if sparse:
                self.augmentor = SparseFlowAugmentor(**aug_params)
            else:
                self.augmentor = FlowAugmentor(**aug_params)

        if reader is None:
            self.disparity_reader = frame_utils.read_gen
        else:
            self.disparity_reader = reader

        self.is_test = False
        self.init_seed = False
        self.flow_list = []
        self.disparity_list = []
        self.image_list = []
        self.extra_info = []

    def __getitem__(self, index):
        """Load one sample.

        In test mode (``self.is_test``) returns ``(img1, img2, extra_info)``.
        Otherwise returns ``(paths, img1, img2, flow, valid)`` where ``paths``
        is ``[left_path, right_path, disparity_path]``, ``flow`` keeps only
        the first (disparity) channel and ``valid`` is a float mask.
        """
        if self.is_test:
            img1 = frame_utils.read_gen(self.image_list[index][0])
            img2 = frame_utils.read_gen(self.image_list[index][1])
            img1 = np.array(img1).astype(np.uint8)[..., :3]
            img2 = np.array(img2).astype(np.uint8)[..., :3]
            img1 = torch.from_numpy(img1).permute(2, 0, 1).float()
            img2 = torch.from_numpy(img2).permute(2, 0, 1).float()
            return img1, img2, self.extra_info[index]

        # Seed the RNGs once per DataLoader worker, using the worker id.
        if not self.init_seed:
            worker_info = torch.utils.data.get_worker_info()
            if worker_info is not None:
                torch.manual_seed(worker_info.id)
                np.random.seed(worker_info.id)
                random.seed(worker_info.id)
                self.init_seed = True

        index = index % len(self.image_list)
        disp = self.disparity_reader(self.disparity_list[index])
        if isinstance(disp, tuple):
            # Reader supplied its own validity mask.
            disp, valid = disp
        else:
            valid = disp < 512

        img1 = frame_utils.read_gen(self.image_list[index][0])
        img2 = frame_utils.read_gen(self.image_list[index][1])

        img1 = np.array(img1).astype(np.uint8)
        img2 = np.array(img2).astype(np.uint8)

        disp = np.array(disp).astype(np.float32)
        assert not np.isnan(disp).any()

        # Store the disparity as a two-channel flow whose second channel is 0.
        flow = np.stack([disp, np.zeros_like(disp)], axis=-1)

        # grayscale images
        if len(img1.shape) == 2:
            img1 = np.tile(img1[..., None], (1, 1, 3))
            img2 = np.tile(img2[..., None], (1, 1, 3))
        else:
            img1 = img1[..., :3]
            img2 = img2[..., :3]

        if self.augmentor is not None:
            if self.sparse:
                img1, img2, flow, valid = self.augmentor(img1, img2, flow, valid)
            else:
                img1, img2, flow = self.augmentor(img1, img2, flow)

        img1 = torch.from_numpy(img1).permute(2, 0, 1).float()
        img2 = torch.from_numpy(img2).permute(2, 0, 1).float()
        flow = torch.from_numpy(flow).permute(2, 0, 1).float()

        if self.sparse:
            valid = torch.from_numpy(valid)
        else:
            valid = (flow[0].abs() < 512) & (flow[1].abs() < 512)

        # Only the two images are padded; flow and valid are left as they are.
        if self.img_pad is not None:
            padH, padW = self.img_pad
            img1 = F.pad(img1, [padW] * 2 + [padH] * 2)
            img2 = F.pad(img2, [padW] * 2 + [padH] * 2)

        flow = flow[:1]
        return self.image_list[index] + [self.disparity_list[index]], img1, img2, flow, valid.float()

    def __mul__(self, v):
        """Return a deep copy of the dataset whose sample lists are repeated ``v`` times."""
        copy_of_self = copy.deepcopy(self)
        copy_of_self.flow_list = v * copy_of_self.flow_list
        copy_of_self.image_list = v * copy_of_self.image_list
        copy_of_self.disparity_list = v * copy_of_self.disparity_list
        copy_of_self.extra_info = v * copy_of_self.extra_info
        return copy_of_self

    def __len__(self):
        return len(self.image_list)


class SceneFlowDatasets(StereoDataset):
    """SceneFlow samples (FlyingThings3D, Monkaa, Driving) found under ``root``.

    With ``things_test=True`` only the FlyingThings3D ``TEST`` split is added;
    otherwise the ``TRAIN`` split is globbed with the FlyingThings3D, Monkaa
    and Driving directory layouts.
    """

    def __init__(self, aug_params=None, root='/data/sceneflow/', dstype='frames_finalpass', things_test=False):
        super(SceneFlowDatasets, self).__init__(aug_params)
        self.root = root
        self.dstype = dstype

        if things_test:
            self._add_things("TEST")
        else:
            self._add_things("TRAIN")
            self._add_monkaa("TRAIN")
            self._add_driving("TRAIN")

    def _add_things(self, split='TRAIN'):
        """ Add FlyingThings3D data """

        original_length = len(self.disparity_list)
        root = self.root
        left_images = sorted(glob(osp.join(root, self.dstype, split, '*/*/left/*.png')))
        right_images = [im.replace('left', 'right') for im in left_images]
        disparity_images = [im.replace(self.dstype, 'disparity').replace('.png', '.pfm') for im in left_images]

        # Validation indices come from a fixed-seed permutation; the global
        # NumPy RNG state is saved and restored around it. The permutation is
        # not truncated, so every index is selected and the TEST split keeps
        # all of its images.
        state = np.random.get_state()
        np.random.seed(1000)
        val_idxs = set(np.random.permutation(len(left_images)))
        np.random.set_state(state)

        for idx, (img1, img2, disp) in enumerate(zip(left_images, right_images, disparity_images)):
            if (split == 'TEST' and idx in val_idxs) or split == 'TRAIN':
                self.image_list += [[img1, img2]]
                self.disparity_list += [disp]
        logging.info(f"Added {len(self.disparity_list) - original_length} from FlyingThings {self.dstype}")

    def _add_monkaa(self, split="TRAIN"):
        """ Add Monkaa data """

        original_length = len(self.disparity_list)
        root = self.root
        left_images = sorted(glob(osp.join(root, self.dstype, split, '*/left/*.png')))
        right_images = [image_file.replace('left', 'right') for image_file in left_images]
        disparity_images = [im.replace(self.dstype, 'disparity').replace('.png', '.pfm') for im in left_images]

        for img1, img2, disp in zip(left_images, right_images, disparity_images):
            self.image_list += [[img1, img2]]
            self.disparity_list += [disp]
        logging.info(f"Added {len(self.disparity_list) - original_length} from Monkaa {self.dstype}")

    def _add_driving(self, split="TRAIN"):
        """ Add Driving data """

        original_length = len(self.disparity_list)
        root = self.root
        left_images = sorted(glob(osp.join(root, self.dstype, split, '*/*/*/left/*.png')))
        right_images = [image_file.replace('left', 'right') for image_file in left_images]
        disparity_images = [im.replace(self.dstype, 'disparity').replace('.png', '.pfm') for im in left_images]

        for img1, img2, disp in zip(left_images, right_images, disparity_images):
            self.image_list += [[img1, img2]]
            self.disparity_list += [disp]
        logging.info(f"Added {len(self.disparity_list) - original_length} from Driving {self.dstype}")


class ETH3D(StereoDataset):
    """ETH3D two-view pairs; non-training splits reuse one training disparity file as a placeholder."""

    def __init__(self, aug_params=None, root='/data/ETH3D', split='training'):
        super(ETH3D, self).__init__(aug_params, sparse=True)

        image1_list = sorted(glob(osp.join(root, f'two_view_{split}/*/im0.png')))
        image2_list = sorted(glob(osp.join(root, f'two_view_{split}/*/im1.png')))
        if split == 'training':
            disp_list = sorted(glob(osp.join(root, 'two_view_training_gt/*/disp0GT.pfm')))
        else:
            disp_list = [osp.join(root, 'two_view_training_gt/playground_1l/disp0GT.pfm')] * len(image1_list)

        for img1, img2, disp in zip(image1_list, image2_list, disp_list):
            self.image_list += [[img1, img2]]
            self.disparity_list += [disp]


class SintelStereo(StereoDataset):
    """Sintel Stereo training pairs; the disparity list is repeated twice to pair with the globbed image passes."""

    def __init__(self, aug_params=None, root='datasets/SintelStereo'):
        super().__init__(aug_params, sparse=True, reader=frame_utils.readDispSintelStereo)

        image1_list = sorted(glob(osp.join(root, 'training/*_left/*/frame_*.png')))
        image2_list = sorted(glob(osp.join(root, 'training/*_right/*/frame_*.png')))
        disp_list = sorted(glob(osp.join(root, 'training/disparities/*/frame_*.png'))) * 2

        for img1, img2, disp in zip(image1_list, image2_list, disp_list):
            # Image and disparity must share the same scene dir and frame name.
            assert img1.split('/')[-2:] == disp.split('/')[-2:]
            self.image_list += [[img1, img2]]
            self.disparity_list += [disp]


class FallingThings(StereoDataset):
    """FallingThings pairs listed in ``<root>/filenames.txt`` (one left-image path per line)."""

    def __init__(self, aug_params=None, root='datasets/FallingThings'):
        super().__init__(aug_params, reader=frame_utils.readDispFallingThings)
        assert os.path.exists(root)

        with open(os.path.join(root, 'filenames.txt'), 'r') as f:
            filenames = sorted(f.read().splitlines())

        image1_list = [osp.join(root, e) for e in filenames]
        image2_list = [osp.join(root, e.replace('left.jpg', 'right.jpg')) for e in filenames]
        disp_list = [osp.join(root, e.replace('left.jpg', 'left.depth.png')) for e in filenames]

        for img1, img2, disp in zip(image1_list, image2_list, disp_list):
            self.image_list += [[img1, img2]]
            self.disparity_list += [disp]


class TartanAir(StereoDataset):
    """TartanAir pairs listed in ``<root>/tartanair_filenames.txt``.

    Args:
        aug_params: forwarded to ``StereoDataset``.
        root: directory holding the filename list and the data.
        keywords: optional iterable of lowercase substrings; only files whose
            lowercased path contains every keyword are kept. ``None`` (the
            default) applies no keyword filter.
    """

    def __init__(self, aug_params=None, root='datasets', keywords=None):
        super().__init__(aug_params, reader=frame_utils.readDispTartanAir)
        assert os.path.exists(root)

        with open(os.path.join(root, 'tartanair_filenames.txt'), 'r') as f:
            filenames = sorted(s for s in f.read().splitlines() if 'seasonsforest_winter/Easy' not in s)
            for kw in (keywords or ()):
                filenames = sorted(s for s in filenames if kw in s.lower())

        image1_list = [osp.join(root, e) for e in filenames]
        image2_list = [osp.join(root, e.replace('_left', '_right')) for e in filenames]
        disp_list = [osp.join(root, e.replace('image_left', 'depth_left').replace('left.png', 'left_depth.npy')) for e in filenames]

        for img1, img2, disp in zip(image1_list, image2_list, disp_list):
            self.image_list += [[img1, img2]]
            self.disparity_list += [disp]


class KITTI(StereoDataset):
    """KITTI 2012 and KITTI 2015 pairs combined into one dataset.

    NOTE(review): images are always read from the hard-coded
    ``/data/KITTI/KITTI_2012`` and ``/data/KITTI/KITTI_2015`` directories;
    ``root`` is only checked for existence and used to build the placeholder
    disparity path when ``image_set`` is not ``'training'``.
    """

    def __init__(self, aug_params=None, root='/data/KITTI/KITTI_2015', image_set='training'):
        super(KITTI, self).__init__(aug_params, sparse=True, reader=frame_utils.readDispKITTI)
        assert os.path.exists(root)
        is_training = image_set == 'training'

        root_12 = '/data/KITTI/KITTI_2012'
        image1_list = sorted(glob(os.path.join(root_12, image_set, 'colored_0/*_10.png')))
        image2_list = sorted(glob(os.path.join(root_12, image_set, 'colored_1/*_10.png')))
        if is_training:
            disp_list = sorted(glob(os.path.join(root_12, 'training', 'disp_occ/*_10.png')))
        else:
            disp_list = [osp.join(root, 'training/disp_occ/000085_10.png')] * len(image1_list)

        root_15 = '/data/KITTI/KITTI_2015'
        left_15 = sorted(glob(os.path.join(root_15, image_set, 'image_2/*_10.png')))
        image1_list += left_15
        image2_list += sorted(glob(os.path.join(root_15, image_set, 'image_3/*_10.png')))
        if is_training:
            disp_list += sorted(glob(os.path.join(root_15, 'training', 'disp_occ_0/*_10.png')))
        else:
            # One placeholder per 2015 image (sizing by the combined list
            # only produced surplus entries that zip() discarded).
            disp_list += [osp.join(root, 'training/disp_occ_0/000085_10.png')] * len(left_15)

        for img1, img2, disp in zip(image1_list, image2_list, disp_list):
            self.image_list += [[img1, img2]]
            self.disparity_list += [disp]


class CREStereo(StereoDataset):
    """CREStereo pairs gathered from the hole/shapenet/tree/reflective sub-directories of ``root``."""

    def __init__(self, aug_params=None, root='/data/CREStereo'):
        super(CREStereo, self).__init__(aug_params, sparse=True, reader=frame_utils.readDispCREStereo)
        self.root = root
        assert os.path.exists(root)

        disp_list = self.selector('_left.disp.png')
        image1_list = self.selector('_left.jpg')
        image2_list = self.selector('_right.jpg')

        assert len(image1_list) == len(image2_list) == len(disp_list) > 0
        for img1, img2, disp in zip(image1_list, image2_list, disp_list):
            self.image_list += [[img1, img2]]
            self.disparity_list += [disp]

    def selector(self, suffix):
        """Return the sorted paths ending in ``suffix`` across the four sub-directories."""
        files = list(glob(os.path.join(self.root, f"hole/*{suffix}")))
        files += list(glob(os.path.join(self.root, f"shapenet/*{suffix}")))
        files += list(glob(os.path.join(self.root, f"tree/*{suffix}")))
        files += list(glob(os.path.join(self.root, f"reflective/*{suffix}")))
        return sorted(files)


class Middlebury(StereoDataset):
    """Middlebury training pairs at resolution ``split`` (one of ``'F'``, ``'H'``, ``'Q'``).

    Scene names are always enumerated from ``<root>/trainingH``, whatever
    ``split`` is; the image and disparity paths then point into
    ``<root>/training<split>``.
    """

    def __init__(self, aug_params=None, root='/data/Middlebury', split='F'):
        super(Middlebury, self).__init__(aug_params, sparse=True, reader=frame_utils.readDispMiddlebury)
        assert os.path.exists(root)
        # Tuple membership: a substring test against "FHQ" would also accept
        # "" and "FH".
        assert split in ("F", "H", "Q")

        lines = list(map(osp.basename, glob(os.path.join(root, "trainingH/*"))))
        image1_list = sorted([os.path.join(root, f'training{split}', f'{name}/im0.png') for name in lines])
        image2_list = sorted([os.path.join(root, f'training{split}', f'{name}/im1.png') for name in lines])
        disp_list = sorted([os.path.join(root, f'training{split}', f'{name}/disp0GT.pfm') for name in lines])

        assert len(image1_list) == len(image2_list) == len(disp_list) > 0, [image1_list, split]
        for img1, img2, disp in zip(image1_list, image2_list, disp_list):
            self.image_list += [[img1, img2]]
            self.disparity_list += [disp]


def fetch_dataloader(args):
    """ Create the data loader for the corresponding training set

    Reads ``image_size``, ``spatial_scale``, ``noyjitter``, ``train_datasets``
    and ``batch_size`` from ``args``, plus the optional ``saturation_range``,
    ``img_gamma`` and ``do_flip``. All requested datasets are concatenated
    into one shuffled ``DataLoader``.

    Raises:
        ValueError: if a dataset name is not recognised or no dataset is
            requested.
    """

    aug_params = {'crop_size': args.image_size, 'min_scale': args.spatial_scale[0], 'max_scale': args.spatial_scale[1], 'do_flip': False, 'yjitter': not args.noyjitter}
    if hasattr(args, "saturation_range") and args.saturation_range is not None:
        aug_params["saturation_range"] = args.saturation_range
    if hasattr(args, "img_gamma") and args.img_gamma is not None:
        aug_params["gamma"] = args.img_gamma
    if hasattr(args, "do_flip") and args.do_flip is not None:
        aug_params["do_flip"] = args.do_flip

    train_dataset = None
    for dataset_name in args.train_datasets:
        if re.fullmatch("middlebury_.*", dataset_name):
            new_dataset = Middlebury(aug_params, split=dataset_name.replace('middlebury_', ''))
        elif dataset_name == 'sceneflow':
            final_dataset = SceneFlowDatasets(aug_params, dstype='frames_finalpass')
            new_dataset = final_dataset
            logging.info(f"Adding {len(new_dataset)} samples from SceneFlow")
        elif 'kitti' in dataset_name:
            new_dataset = KITTI(aug_params)
            logging.info(f"Adding {len(new_dataset)} samples from KITTI")
        elif dataset_name == 'sintel_stereo':
            new_dataset = SintelStereo(aug_params) * 140
            logging.info(f"Adding {len(new_dataset)} samples from Sintel Stereo")
        elif dataset_name == 'falling_things':
            new_dataset = FallingThings(aug_params) * 5
            logging.info(f"Adding {len(new_dataset)} samples from FallingThings")
        elif dataset_name.startswith('tartan_air'):
            new_dataset = TartanAir(aug_params, keywords=dataset_name.split('_')[2:])
            logging.info(f"Adding {len(new_dataset)} samples from Tartain Air")
        elif dataset_name.startswith('crestereo'):
            new_dataset = CREStereo(aug_params)
            logging.info(f"Adding {len(new_dataset)} samples from CREStereo")
        else:
            # Without this branch an unknown name would either crash with an
            # unbound local or silently re-add the previous dataset.
            raise ValueError(f"Unknown training dataset: {dataset_name!r}")
        train_dataset = new_dataset if train_dataset is None else train_dataset + new_dataset

    if train_dataset is None:
        raise ValueError("args.train_datasets is empty; nothing to train on")

    # Leave two CPUs free, but never ask for a negative worker count
    # (DataLoader rejects it), e.g. when SLURM_CPUS_PER_TASK is 1.
    num_workers = max(0, int(os.environ.get('SLURM_CPUS_PER_TASK', 6)) - 2)
    train_loader = data.DataLoader(train_dataset, batch_size=args.batch_size,
        pin_memory=False, shuffle=True, num_workers=num_workers, drop_last=True)

    logging.info('Training with %d image pairs' % len(train_dataset))
    return train_loader