Update create_crestereo_subsets.py

This commit is contained in:
HTensor 2023-04-29 13:39:09 +08:00
parent 59ff17e149
commit a1cc25351d

View File

@ -1,35 +1,76 @@
import os import os
import sys
sys.path.append("..")
import numpy as np import numpy as np
import torch
import torch.utils.data as data
import torch.nn.functional as F
import logging import logging
import os import os
import re import shutil
import copy
import math
import random import random
from pathlib import Path from pathlib import Path
from glob import glob from glob import glob
import os.path as osp import matplotlib.pyplot as plt
from core.utils import frame_utils from core.utils import frame_utils
def unique(lst):
return dict(zip(*np.unique(lst, return_counts=True)))
def ensure_path_exists(path):
if not os.path.exists(path):
os.makedirs(path)
class CREStereo(): class CREStereo():
def __init__(self, aug_params=None, root='/data/CREStereo'): def __init__(self, aug_params=None, root='/data/CREStereo'):
self.root = root self.root = root
assert os.path.exists(root) assert os.path.exists(root)
disp_list = self.selector('_left.disp.png') # disp_list = self.selector('_left.disp.png')
image1_list = self.selector('_left.jpg') # image1_list = self.selector('_left.jpg')
image2_list = self.selector('_right.jpg') # image2_list = self.selector('_right.jpg')
# assert len(image1_list) == len(image2_list) == len(disp_list) > 0
# for img1, img2, disp in zip(image1_list, image2_list, disp_list):
# # if random.randint(1, 20000) != 1:
# # continue
# self.image_list += [[img1, img2]]
# self.disparity_list += [disp]
def get_path_info(self, path):
position, filename = os.path.split(path)
root, sub_folder = os.path.split(position)
return root, sub_folder, filename
def get_new_file(self, path):
root, sub_folder, filename = self.get_path_info(path)
return os.path.join(root, 'subset', sub_folder, filename)
def divide(self, num):
ensure_path_exists(os.path.join(self.root, 'subset'))
for sub_folder in ['tree', 'shapenet', 'reflective', 'hole']:
ensure_path_exists(os.path.join(self.root, 'subset', sub_folder))
disp1_list = self.single_folder_selector(sub_folder, '_left.disp.png')
disp2_list = self.single_folder_selector(sub_folder, '_right.disp.png')
image1_list = self.single_folder_selector(sub_folder, '_left.jpg')
image2_list = self.single_folder_selector(sub_folder, '_right.jpg')
assert len(image1_list) == len(image2_list) == len(disp1_list) == len(disp2_list) > 0
lists = []
for img1, img2, disp1, disp2 in zip(image1_list, image2_list, disp1_list, disp2_list):
lists += [[img1, img2, disp1, disp2]]
subset = random.sample(lists, num)
for s in subset:
for element in s:
print(element)
print(self.get_new_file(element))
shutil.copy(element, self.get_new_file(element))
assert len(image1_list) == len(image2_list) == len(disp_list) > 0
for img1, img2, disp in zip(image1_list, image2_list, disp_list):
# if random.randint(1, 20000) != 1:
# continue
self.image_list += [[img1, img2]]
self.disparity_list += [disp]
def selector(self, suffix): def selector(self, suffix):
files = list(glob(os.path.join(self.root, f"hole/*{suffix}"))) files = list(glob(os.path.join(self.root, f"hole/*{suffix}")))
@ -38,12 +79,34 @@ class CREStereo():
files += list(glob(os.path.join(self.root, f"reflective/*{suffix}"))) files += list(glob(os.path.join(self.root, f"reflective/*{suffix}")))
return sorted(files) return sorted(files)
def single_folder_selector(self, sub_folder, suffix):
return sorted(list(glob(os.path.join(self.root, f"{sub_folder}/*{suffix}"))))
def disparity_distribution(self): def disparity_distribution(self):
disp_lists = self.selector('_left.disp.png') disp_lists = self.selector('_left.disp.png')
disparities = [] disparities = {}
for filename in disp_lists: for filename in disp_lists:
disp_gt = frame_utils.readDispCREStereo(filename) print(filename)
print(disp_gt.shape) disp_gt, _ = frame_utils.readDispCREStereo(filename)
[rows, cols] = disp_gt.shape
disp_gt = (disp_gt * 32).astype(int)
cnt = unique(disp_gt)
for i in cnt:
if i in disparities:
disparities[i] += cnt[i]
else:
disparities[i] = cnt[i]
x = []
y = []
for key in disparities.keys():
x.append(key / 32)
y.append(disparities[key])
plt.scatter(x, y)
plt.show()
CREStereo.disparity_distribution() c = CREStereo()
c.divide(10000)