Compare commits: add-licens ... main

22 commits (SHA1):
74bc06d58e, 0404f5b5b0, 3b68318ed5, a7d89bd95c, e7033dabf9, 9df896bb70, 73e65f99b8, a1cc25351d, 59ff17e149, 75735c8fae, 21e3f92461, 8591c2edad, 875a1eec05, 1080b823a5, 0a3613711b, 3f60e691f8, 0a20c0a001, c3b4812e99, 25753af380, a301ebb020, 93371a3c4d, 54a2eabbff
.gitignore (vendored, new file, 161 lines)
@@ -0,0 +1,161 @@
# Byte-compiled / optimized / DLL files
__pycache__/
*.py[cod]
*$py.class

# C extensions
*.so

# Distribution / packaging
.Python
build/
develop-eggs/
dist/
downloads/
eggs/
.eggs/
lib/
lib64/
parts/
sdist/
var/
wheels/
share/python-wheels/
*.egg-info/
.installed.cfg
*.egg
MANIFEST

# PyInstaller
# Usually these files are written by a python script from a template
# before PyInstaller builds the exe, so as to inject date/other infos into it.
*.manifest
*.spec

# Installer logs
pip-log.txt
pip-delete-this-directory.txt

# Unit test / coverage reports
htmlcov/
.tox/
.nox/
.coverage
.coverage.*
.cache
nosetests.xml
coverage.xml
*.cover
*.py,cover
.hypothesis/
.pytest_cache/
cover/

# Translations
*.mo
*.pot

# Django stuff:
*.log
local_settings.py
db.sqlite3
db.sqlite3-journal

# Flask stuff:
instance/
.webassets-cache

# Scrapy stuff:
.scrapy

# Sphinx documentation
docs/_build/

# PyBuilder
.pybuilder/
target/

# Jupyter Notebook
.ipynb_checkpoints

# IPython
profile_default/
ipython_config.py

# pyenv
# For a library or package, you might want to ignore these files since the code is
# intended to run in multiple environments; otherwise, check them in:
# .python-version

# pipenv
# According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
# However, in case of collaboration, if having platform-specific dependencies or dependencies
# having no cross-platform support, pipenv may install dependencies that don't work, or not
# install all needed dependencies.
#Pipfile.lock

# poetry
# Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control.
# This is especially recommended for binary packages to ensure reproducibility, and is more
# commonly ignored for libraries.
# https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control
#poetry.lock

# pdm
# Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control.
#pdm.lock
# pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it
# in version control.
# https://pdm.fming.dev/#use-with-ide
.pdm.toml

# PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm
__pypackages__/

# Celery stuff
celerybeat-schedule
celerybeat.pid

# SageMath parsed files
*.sage.py

# Environments
.env
.venv
env/
venv/
ENV/
env.bak/
venv.bak/

# Spyder project settings
.spyderproject
.spyproject

# Rope project settings
.ropeproject

# mkdocs documentation
/site

# mypy
.mypy_cache/
.dmypy.json
dmypy.json

# Pyre type checker
.pyre/

# pytype static type analyzer
.pytype/

# Cython debug symbols
cython_debug/

# PyCharm
# JetBrains specific template is maintained in a separate JetBrains.gitignore that can
# be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore
# and can be added to the global gitignore or merged into this file. For a more nuclear
# option (not recommended) you can uncomment the following to ignore the entire idea folder.
.idea/

.DS_Store
IGEV-Stereo/.gitignore (vendored, new file, 4 lines)
@@ -0,0 +1,4 @@
checkpoints/
demo-output/
pretrained_models/
/trace.json
@@ -325,21 +325,24 @@ class SubModule(nn.Module):
 
 
 class Feature(SubModule):
-    def __init__(self):
+    def __init__(self, freeze):
         super(Feature, self).__init__()
         pretrained = True
-        model = timm.create_model('mobilenetv2_100', pretrained=pretrained, features_only=True)
+        self.model = timm.create_model('mobilenetv2_100', pretrained=pretrained, features_only=True)
+        if freeze:
+            for p in self.model.parameters():
+                p.requires_grad = False
         layers = [1,2,3,5,6]
         chans = [16, 24, 32, 96, 160]
-        self.conv_stem = model.conv_stem
-        self.bn1 = model.bn1
-        self.act1 = model.act1
+        self.conv_stem = self.model.conv_stem
+        self.bn1 = self.model.bn1
+        self.act1 = self.model.act1
 
-        self.block0 = torch.nn.Sequential(*model.blocks[0:layers[0]])
-        self.block1 = torch.nn.Sequential(*model.blocks[layers[0]:layers[1]])
-        self.block2 = torch.nn.Sequential(*model.blocks[layers[1]:layers[2]])
-        self.block3 = torch.nn.Sequential(*model.blocks[layers[2]:layers[3]])
-        self.block4 = torch.nn.Sequential(*model.blocks[layers[3]:layers[4]])
+        self.block0 = torch.nn.Sequential(*self.model.blocks[0:layers[0]])
+        self.block1 = torch.nn.Sequential(*self.model.blocks[layers[0]:layers[1]])
+        self.block2 = torch.nn.Sequential(*self.model.blocks[layers[1]:layers[2]])
+        self.block3 = torch.nn.Sequential(*self.model.blocks[layers[2]:layers[3]])
+        self.block4 = torch.nn.Sequential(*self.model.blocks[layers[3]:layers[4]])
 
         self.deconv32_16 = Conv2x_IN(chans[4], chans[3], deconv=True, concat=True)
         self.deconv16_8 = Conv2x_IN(chans[3]*2, chans[2], deconv=True, concat=True)
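This hunk threads a freeze switch into the feature extractor: the timm backbone is kept as self.model and every one of its parameters has requires_grad turned off before training starts. A minimal standalone sketch of the same pattern (pretrained=False here only to avoid a weight download; the repo uses pretrained=True):

```python
import timm

# Build the MobileNetV2 feature extractor and freeze it, as the patched
# Feature.__init__ does when freeze=True.
model = timm.create_model('mobilenetv2_100', pretrained=False, features_only=True)
for p in model.parameters():
    p.requires_grad = False

# Sanity check: nothing in the backbone should remain trainable.
assert sum(p.numel() for p in model.parameters() if p.requires_grad) == 0
```

Freezing alone is not enough if the optimizer was already handed these parameters; the fetch_optimizer change further down handles that side.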
@@ -37,8 +37,8 @@ class Combined_Geo_Encoding_Volume:
         out_pyramid = []
         for i in range(self.num_levels):
             geo_volume = self.geo_volume_pyramid[i]
-            dx = torch.linspace(-r, r, 2*r+1)
-            dx = dx.view(1, 1, 2*r+1, 1).to(disp.device)
+            dx = torch.linspace(-r, r, 2*r+1, device=disp.device)
+            dx = dx.view(1, 1, 2*r+1, 1)
             x0 = dx + disp.reshape(b*h*w, 1, 1, 1) / 2**i
             y0 = torch.zeros_like(x0)
 
|
|||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def corr(fmap1, fmap2):
|
def corr(fmap1, fmap2):
|
||||||
|
# batch, dim, ht, wd
|
||||||
B, D, H, W1 = fmap1.shape
|
B, D, H, W1 = fmap1.shape
|
||||||
_, _, _, W2 = fmap2.shape
|
_, _, _, W2 = fmap2.shape
|
||||||
fmap1 = fmap1.view(B, D, H, W1)
|
fmap1 = fmap1.view(B, D, H, W1)
|
||||||
|
@@ -100,7 +100,7 @@ class IGEVStereo(nn.Module):
         self.context_zqr_convs = nn.ModuleList([nn.Conv2d(context_dims[i], args.hidden_dims[i]*3, 3, padding=3//2) for i in range(self.args.n_gru_layers)])
 
-        self.feature = Feature()
+        self.feature = Feature(args.freeze_backbone_params)
 
         self.stem_2 = nn.Sequential(
             BasicConv_IN(3, 32, kernel_size=3, stride=2, padding=1),
@@ -165,8 +165,9 @@ class IGEVStereo(nn.Module):
 
         match_left = self.desc(self.conv(features_left[0]))
         match_right = self.desc(self.conv(features_right[0]))
-        gwc_volume = build_gwc_volume(match_left, match_right, 192//4, 8)
+        gwc_volume = build_gwc_volume(match_left, match_right, self.args.max_disp//4, 8)
         gwc_volume = self.corr_stem(gwc_volume)
+        # 3d unet
         gwc_volume = self.corr_feature_att(gwc_volume, features_left[0])
         geo_encoding_volume = self.cost_agg(gwc_volume, features_left)
 
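The cost volume is built on quarter-resolution features, so its disparity axis holds max_disp//4 samples; replacing the hard-coded 192//4 with self.args.max_disp//4 lets the --max_disp flag actually widen the volume. A shape sketch with hypothetical feature dimensions:

```python
import torch

max_disp = 384                       # e.g. the value passed in evaluate-history.sh below
B, num_groups, H, W = 1, 8, 80, 184  # hypothetical quarter-resolution feature map size

# Same layout as build_gwc_volume's output: [batch, groups, max_disp, height, width].
volume = torch.zeros(B, num_groups, max_disp // 4, H, W)
print(volume.shape)  # torch.Size([1, 8, 96, 80, 184])
```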
@@ -182,6 +183,7 @@ class IGEVStereo(nn.Module):
         spx_pred = self.spx(xspx)
         spx_pred = F.softmax(spx_pred, 1)
 
+        # Content Network
         cnet_list = self.cnet(image1, num_layers=self.args.n_gru_layers)
         net_list = [torch.tanh(x[0]) for x in cnet_list]
         inp_list = [torch.relu(x[1]) for x in cnet_list]
@@ -191,7 +193,7 @@ class IGEVStereo(nn.Module):
         geo_block = Combined_Geo_Encoding_Volume
         geo_fn = geo_block(match_left.float(), match_right.float(), geo_encoding_volume.float(), radius=self.args.corr_radius, num_levels=self.args.corr_levels)
         b, c, h, w = match_left.shape
-        coords = torch.arange(w).float().to(match_left.device).reshape(1,1,w,1).repeat(b, h, 1, 1)
+        coords = torch.arange(w, device=match_left.device).float().reshape(1,1,w,1).repeat(b, h, 1, 1)
         disp = init_disp
         disp_preds = []
 
@@ -266,9 +266,34 @@ class KITTI(StereoDataset):
         self.image_list += [ [img1, img2] ]
         self.disparity_list += [ disp ]
 
+class CREStereo(StereoDataset):
+    def __init__(self, aug_params=None, root='/data/CREStereo'):
+        super(CREStereo, self).__init__(aug_params, sparse=True, reader=frame_utils.readDispCREStereo)
+        self.root = root
+        assert os.path.exists(root)
+
+        disp_list = self.selector('_left.disp.png')
+        image1_list = self.selector('_left.jpg')
+        image2_list = self.selector('_right.jpg')
+
+        assert len(image1_list) == len(image2_list) == len(disp_list) > 0
+        for img1, img2, disp in zip(image1_list, image2_list, disp_list):
+            # if random.randint(1, 20000) != 1:
+            #     continue
+            self.image_list += [[img1, img2]]
+            self.disparity_list += [disp]
+
+    def selector(self, suffix):
+        files = list(glob(os.path.join(self.root, f"hole/*{suffix}")))
+        files += list(glob(os.path.join(self.root, f"shapenet/*{suffix}")))
+        files += list(glob(os.path.join(self.root, f"tree/*{suffix}")))
+        files += list(glob(os.path.join(self.root, f"reflective/*{suffix}")))
+        return sorted(files)
+
+
 class Middlebury(StereoDataset):
-    def __init__(self, aug_params=None, root='/data/Middlebury', split='F'):
+    def __init__(self, aug_params=None, root='/data/Middlebury', split='H'):
         super(Middlebury, self).__init__(aug_params, sparse=True, reader=frame_utils.readDispMiddlebury)
         assert os.path.exists(root)
         assert split in "FHQ"
@@ -321,10 +346,13 @@ def fetch_dataloader(args):
         elif dataset_name.startswith('tartan_air'):
             new_dataset = TartanAir(aug_params, keywords=dataset_name.split('_')[2:])
             logging.info(f"Adding {len(new_dataset)} samples from Tartain Air")
+        elif dataset_name.startswith('crestereo'):
+            new_dataset = CREStereo(aug_params)
+            logging.info(f"Adding {len(new_dataset)} samples from CREStereo")
         train_dataset = new_dataset if train_dataset is None else train_dataset + new_dataset
 
     train_loader = data.DataLoader(train_dataset, batch_size=args.batch_size,
-        pin_memory=True, shuffle=True, num_workers=int(os.environ.get('SLURM_CPUS_PER_TASK', 6))-2, drop_last=True)
+        pin_memory=True, shuffle=True, num_workers=12, drop_last=True)
 
     logging.info('Training with %d image pairs' % len(train_dataset))
     return train_loader
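With this branch, 'crestereo' becomes a valid entry in --train_datasets and is mixed into the training set like the existing datasets. A sketch of constructing the new dataset directly (the module path core.stereo_datasets is an assumption from the repo layout; the root must contain the hole/, shapenet/, tree/ and reflective/ folders that selector() globs):

```python
from core.stereo_datasets import CREStereo  # assumed module path

ds = CREStereo(aug_params=None, root='/data/CREStereo')
print(f"{len(ds.image_list)} stereo pairs, {len(ds.disparity_list)} disparity maps")
```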
@@ -157,6 +157,7 @@ def groupwise_correlation(fea1, fea2, num_groups):
     return cost
 
 def build_gwc_volume(refimg_fea, targetimg_fea, maxdisp, num_groups):
+    # batch, groups, max_disp, height, width
     B, C, H, W = refimg_fea.shape
     volume = refimg_fea.new_zeros([B, num_groups, maxdisp, H, W])
     for i in range(maxdisp):
@@ -5,7 +5,7 @@ import os
 import time
 from glob import glob
 from skimage import color, io
-from PIL import Image
+from PIL import Image, ImageEnhance
 
 import cv2
 cv2.setNumThreads(0)
@@ -198,21 +198,40 @@ class SparseFlowAugmentor:
             self.v_flip_prob = 0.1
 
         # photometric augmentation params
-        self.photo_aug = Compose([ColorJitter(brightness=0.3, contrast=0.3, saturation=saturation_range, hue=0.3/3.14), AdjustGamma(*gamma)])
-        self.asymmetric_color_aug_prob = 0.2
+        # self.photo_aug = Compose([ColorJitter(brightness=0.3, contrast=0.3, saturation=saturation_range, hue=0.3/3.14), AdjustGamma(*gamma)])
         self.eraser_aug_prob = 0.5
 
+    def chromatic_augmentation(self, img):
+        random_brightness = np.random.uniform(0.8, 1.2)
+        random_contrast = np.random.uniform(0.8, 1.2)
+        random_gamma = np.random.uniform(0.8, 1.2)
+
+        img = Image.fromarray(img)
+
+        enhancer = ImageEnhance.Brightness(img)
+        img = enhancer.enhance(random_brightness)
+        enhancer = ImageEnhance.Contrast(img)
+        img = enhancer.enhance(random_contrast)
+
+        gamma_map = [
+            255 * 1.0 * pow(ele / 255.0, random_gamma) for ele in range(256)
+        ] * 3
+        img = img.point(gamma_map)  # use PIL's point-function to accelerate this part
+
+        img_ = np.array(img)
+
+        return img_
+
     def color_transform(self, img1, img2):
-        image_stack = np.concatenate([img1, img2], axis=0)
-        image_stack = np.array(self.photo_aug(Image.fromarray(image_stack)), dtype=np.uint8)
-        img1, img2 = np.split(image_stack, 2, axis=0)
+        img1 = self.chromatic_augmentation(img1)
+        img2 = self.chromatic_augmentation(img2)
+
         return img1, img2
 
     def eraser_transform(self, img1, img2):
         ht, wd = img1.shape[:2]
         if np.random.rand() < self.eraser_aug_prob:
             mean_color = np.mean(img2.reshape(-1, 3), axis=0)
-            for _ in range(np.random.randint(1, 3)):
+            for _ in range(1):
                 x0 = np.random.randint(0, wd)
                 y0 = np.random.randint(0, ht)
                 dx = np.random.randint(50, 100)
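The new chromatic_augmentation jitters brightness, contrast, and gamma per image with independent random draws, replacing the old photo_aug pipeline that augmented the concatenated stereo pair identically. img.point expects one lookup-table entry per channel value, which is why the 256-entry gamma curve is repeated three times for RGB. An equivalent NumPy formulation of the gamma step, as a cross-check:

```python
import numpy as np

random_gamma = 1.1  # example draw from the U(0.8, 1.2) range used above
lut = np.array([255.0 * pow(i / 255.0, random_gamma) for i in range(256)])

img = np.random.randint(0, 256, (64, 64, 3), dtype=np.uint8)
out = lut[img].astype(np.uint8)  # same mapping as img.point(gamma_map) on an RGB image
```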
@@ -152,6 +152,11 @@ def readDispTartanAir(file_name):
     valid = disp > 0
     return disp, valid
 
+def readDispCREStereo(file_name):
+    disp = cv2.imread(file_name, cv2.IMREAD_UNCHANGED)
+    disp = disp.astype(np.float32) / 32.0
+    valid = disp > 0.0
+    return disp, valid
+
 def readDispMiddlebury(file_name):
     assert basename(file_name) == 'disp0GT.pfm'
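CREStereo ground truth is stored in 16-bit PNGs holding disparity scaled by 32, so the reader loads with cv2.IMREAD_UNCHANGED (which preserves the 16-bit depth) and divides by 32.0; zeros mark pixels without ground truth. The same decoding as a standalone sketch, with a hypothetical file path:

```python
import cv2
import numpy as np

def read_crestereo_disp(path):
    raw = cv2.imread(path, cv2.IMREAD_UNCHANGED)  # keeps uint16, no 8-bit conversion
    disp = raw.astype(np.float32) / 32.0          # stored value = disparity * 32
    valid = disp > 0.0                            # zero means missing ground truth
    return disp, valid

# disp, valid = read_crestereo_disp('/data/CREStereo/tree/000000_left.disp.png')  # hypothetical file
```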
IGEV-Stereo/dataset_utils/create_crestereo_subsets.py (new file, 112 lines)
@@ -0,0 +1,112 @@
import os
import sys

sys.path.append("..")

import numpy as np
import logging
import os
import shutil
import random
from pathlib import Path
from glob import glob
import matplotlib.pyplot as plt

from core.utils import frame_utils


def unique(lst):
    return dict(zip(*np.unique(lst, return_counts=True)))


def ensure_path_exists(path):
    if not os.path.exists(path):
        os.makedirs(path)


class CREStereo():
    def __init__(self, aug_params=None, root='/data/CREStereo'):
        self.root = root
        assert os.path.exists(root)

        # disp_list = self.selector('_left.disp.png')
        # image1_list = self.selector('_left.jpg')
        # image2_list = self.selector('_right.jpg')

        # assert len(image1_list) == len(image2_list) == len(disp_list) > 0
        # for img1, img2, disp in zip(image1_list, image2_list, disp_list):
        #     # if random.randint(1, 20000) != 1:
        #     #     continue
        #     self.image_list += [[img1, img2]]
        #     self.disparity_list += [disp]

    def get_path_info(self, path):
        position, filename = os.path.split(path)
        root, sub_folder = os.path.split(position)
        return root, sub_folder, filename

    def get_new_file(self, path):
        root, sub_folder, filename = self.get_path_info(path)
        return os.path.join(root, 'subset', sub_folder, filename)

    def divide(self, num):
        ensure_path_exists(os.path.join(self.root, 'subset'))
        for sub_folder in ['tree', 'shapenet', 'reflective', 'hole']:
            ensure_path_exists(os.path.join(self.root, 'subset', sub_folder))
            disp1_list = self.single_folder_selector(sub_folder, '_left.disp.png')
            disp2_list = self.single_folder_selector(sub_folder, '_right.disp.png')
            image1_list = self.single_folder_selector(sub_folder, '_left.jpg')
            image2_list = self.single_folder_selector(sub_folder, '_right.jpg')

            assert len(image1_list) == len(image2_list) == len(disp1_list) == len(disp2_list) > 0
            lists = []
            for img1, img2, disp1, disp2 in zip(image1_list, image2_list, disp1_list, disp2_list):
                lists += [[img1, img2, disp1, disp2]]
            subset = random.sample(lists, num)

            for s in subset:
                for element in s:
                    print(element)
                    print(self.get_new_file(element))
                    shutil.copy(element, self.get_new_file(element))


    def selector(self, suffix):
        files = list(glob(os.path.join(self.root, f"hole/*{suffix}")))
        files += list(glob(os.path.join(self.root, f"shapenet/*{suffix}")))
        files += list(glob(os.path.join(self.root, f"tree/*{suffix}")))
        files += list(glob(os.path.join(self.root, f"reflective/*{suffix}")))
        return sorted(files)

    def single_folder_selector(self, sub_folder, suffix):
        return sorted(list(glob(os.path.join(self.root, f"{sub_folder}/*{suffix}"))))

    def disparity_distribution(self):
        disp_lists = self.selector('_left.disp.png')
        disparities = {}
        for filename in disp_lists:
            print(filename)
            disp_gt, _ = frame_utils.readDispCREStereo(filename)
            [rows, cols] = disp_gt.shape
            disp_gt = (disp_gt * 32).astype(int)

            cnt = unique(disp_gt)
            for i in cnt:
                if i in disparities:
                    disparities[i] += cnt[i]
                else:
                    disparities[i] = cnt[i]

        x = []
        y = []
        for key in disparities.keys():
            x.append(key / 32)
            y.append(disparities[key])

        plt.scatter(x, y)
        plt.show()


c = CREStereo()
c.divide(10000)
@@ -26,6 +26,7 @@ def demo(args):
     model.load_state_dict(torch.load(args.restore_ckpt))
 
     model = model.module
+    # model = torch.compile(model)
     model.to(DEVICE)
     model.eval()
 
IGEV-Stereo/evaluate-history.sh (new executable file, 8 lines)
@@ -0,0 +1,8 @@
#!/bin/bash
for iter in {1..100}
do
    iter=$((iter * 1000))
    echo "These are the results of ${iter} iterations."
    python evaluate_stereo.py --dataset middlebury_H --max_disp 384 --freeze_backbone_params --restore_ckpt ~/checkpoints/igev_stereo/${iter}_fix-validation.pth
done
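The script sweeps the checkpoints saved every 1000 steps (matching the validation_frequency change further down) and evaluates each on Middlebury half resolution. The same loop in Python, in case a shell is unavailable (paths taken from the script):

```python
import os
import subprocess

for i in range(1, 101):
    step = i * 1000
    ckpt = os.path.expanduser(f"~/checkpoints/igev_stereo/{step}_fix-validation.pth")
    print(f"These are the results of {step} iterations.")
    subprocess.run(["python", "evaluate_stereo.py", "--dataset", "middlebury_H",
                    "--max_disp", "384", "--freeze_backbone_params",
                    "--restore_ckpt", ckpt], check=True)
```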
@@ -20,7 +20,7 @@ def count_parameters(model):
     return sum(p.numel() for p in model.parameters() if p.requires_grad)
 
 @torch.no_grad()
-def validate_eth3d(model, iters=32, mixed_prec=False):
+def validate_eth3d(model, iters=32, mixed_prec=False, max_disp=192):
     """ Peform validation using the ETH3D (train) split """
     model.eval()
     aug_params = {}
@@ -67,7 +67,7 @@ def validate_eth3d(model, iters=32, mixed_prec=False):
 
 
 @torch.no_grad()
-def validate_kitti(model, iters=32, mixed_prec=False):
+def validate_kitti(model, iters=32, mixed_prec=False, max_disp=192):
     """ Peform validation using the KITTI-2015 (train) split """
     model.eval()
     aug_params = {}
@@ -96,7 +96,7 @@ def validate_kitti(model, iters=32, mixed_prec=False):
         epe = torch.sum((flow_pr - flow_gt)**2, dim=0).sqrt()
 
         epe_flattened = epe.flatten()
-        val = (valid_gt.flatten() >= 0.5) & (flow_gt.abs().flatten() < 192)
+        val = (valid_gt.flatten() >= 0.5) & (flow_gt.abs().flatten() < max_disp)
         # val = valid_gt.flatten() >= 0.5
 
         out = (epe_flattened > 3.0)
@@ -120,7 +120,7 @@ def validate_kitti(model, iters=32, mixed_prec=False):
 
 
 @torch.no_grad()
-def validate_sceneflow(model, iters=32, mixed_prec=False):
+def validate_sceneflow(model, iters=32, mixed_prec=False, max_disp=192):
     """ Peform validation using the Scene Flow (TEST) split """
     model.eval()
     val_dataset = datasets.SceneFlowDatasets(dstype='frames_finalpass', things_test=True)
@@ -144,7 +144,7 @@ def validate_sceneflow(model, iters=32, mixed_prec=False):
         epe = torch.abs(flow_pr - flow_gt)
 
         epe = epe.flatten()
-        val = (valid_gt.flatten() >= 0.5) & (flow_gt.abs().flatten() < 192)
+        val = (valid_gt.flatten() >= 0.5) & (flow_gt.abs().flatten() < max_disp)
 
         if(np.isnan(epe[val].mean().item())):
             continue
@@ -169,7 +169,7 @@ def validate_sceneflow(model, iters=32, mixed_prec=False):
 
 
 @torch.no_grad()
-def validate_middlebury(model, iters=32, split='F', mixed_prec=False):
+def validate_middlebury(model, iters=32, split='H', mixed_prec=False, max_disp=192):
     """ Peform validation using the Middlebury-V3 dataset """
     model.eval()
     aug_params = {}
@@ -196,11 +196,11 @@ def validate_middlebury(model, iters=32, split='F', mixed_prec=False):
         occ_mask = Image.open(imageL_file.replace('im0.png', 'mask0nocc.png')).convert('L')
         occ_mask = np.ascontiguousarray(occ_mask, dtype=np.float32).flatten()
 
-        val = (valid_gt.reshape(-1) >= 0.5) & (flow_gt[0].reshape(-1) < 192) & (occ_mask==255)
+        val = (valid_gt.reshape(-1) >= 0.5) & (flow_gt[0].reshape(-1) < max_disp) & (occ_mask==255)
         out = (epe_flattened > 2.0)
         image_out = out[val].float().mean().item()
         image_epe = epe_flattened[val].mean().item()
-        logging.info(f"Middlebury Iter {val_id+1} out of {len(val_dataset)}. EPE {round(image_epe,4)} D1 {round(image_out,4)}")
+        logging.info(f"Middlebury Iter {val_id+1} out of {len(val_dataset)}. EPE {round(image_epe,4)} Err2.0 {round(image_out,4)}")
         epe_list.append(image_epe)
         out_list.append(image_out)
 
@@ -208,10 +208,10 @@ def validate_middlebury(model, iters=32, split='F', mixed_prec=False):
     out_list = np.array(out_list)
 
     epe = np.mean(epe_list)
-    d1 = 100 * np.mean(out_list)
+    err2 = 100 * np.mean(out_list)
 
-    print(f"Validation Middlebury{split}: EPE {epe}, D1 {d1}")
-    return {f'middlebury{split}-epe': epe, f'middlebury{split}-d1': d1}
+    print(f"Validation Middlebury{split}: EPE {epe}, Err2.0 {err2}")
+    return {f'middlebury{split}-epe': epe, f'middlebury{split}-err2.0': err2}
 
 
 if __name__ == '__main__':
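The rename from d1 to err2 makes the reported number match what is actually computed: the percentage of pixels whose error exceeds the 2 px threshold used for Middlebury, whereas the KITTI path above thresholds at 3 px. In miniature:

```python
import numpy as np

epe = np.abs(np.random.randn(1000)) * 2  # stand-in per-pixel endpoint errors
err2 = 100 * (epe > 2.0).mean()          # Middlebury-style bad-2.0 percentage
err3 = 100 * (epe > 3.0).mean()          # threshold used by the KITTI path above
print(f"Err2.0 {err2:.2f}%  Err3.0 {err3:.2f}%")
```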
@@ -231,6 +231,7 @@ if __name__ == '__main__':
     parser.add_argument('--slow_fast_gru', action='store_true', help="iterate the low-res GRUs more frequently")
     parser.add_argument('--n_gru_layers', type=int, default=3, help="number of hidden GRU levels")
     parser.add_argument('--max_disp', type=int, default=192, help="max disp of geometry encoding volume")
+    parser.add_argument('--freeze_backbone_params', action="store_true", help="freeze backbone parameters")
     args = parser.parse_args()
 
     model = torch.nn.DataParallel(IGEVStereo(args), device_ids=[0])
@@ -242,6 +243,12 @@ if __name__ == '__main__':
     assert args.restore_ckpt.endswith(".pth")
     logging.info("Loading checkpoint...")
     checkpoint = torch.load(args.restore_ckpt)
+
+    unwanted_prefix = '_orig_mod.'
+    for k, v in list(checkpoint.items()):
+        if k.startswith(unwanted_prefix):
+            checkpoint[k[len(unwanted_prefix):]] = checkpoint.pop(k)
+
     model.load_state_dict(checkpoint, strict=True)
     logging.info(f"Done loading checkpoint")
 
|
|||||||
use_mixed_precision = args.corr_implementation.endswith("_cuda")
|
use_mixed_precision = args.corr_implementation.endswith("_cuda")
|
||||||
|
|
||||||
if args.dataset == 'eth3d':
|
if args.dataset == 'eth3d':
|
||||||
validate_eth3d(model, iters=args.valid_iters, mixed_prec=use_mixed_precision)
|
validate_eth3d(model, iters=args.valid_iters, mixed_prec=use_mixed_precision, max_disp=args.max_disp)
|
||||||
|
|
||||||
elif args.dataset == 'kitti':
|
elif args.dataset == 'kitti':
|
||||||
validate_kitti(model, iters=args.valid_iters, mixed_prec=use_mixed_precision)
|
validate_kitti(model, iters=args.valid_iters, mixed_prec=use_mixed_precision, max_disp=args.max_disp)
|
||||||
|
|
||||||
elif args.dataset in [f"middlebury_{s}" for s in 'FHQ']:
|
elif args.dataset in [f"middlebury_{s}" for s in 'FHQ']:
|
||||||
validate_middlebury(model, iters=args.valid_iters, split=args.dataset[-1], mixed_prec=use_mixed_precision)
|
validate_middlebury(model, iters=args.valid_iters, split=args.dataset[-1], mixed_prec=use_mixed_precision, max_disp=args.max_disp)
|
||||||
|
|
||||||
elif args.dataset == 'sceneflow':
|
elif args.dataset == 'sceneflow':
|
||||||
validate_sceneflow(model, iters=args.valid_iters, mixed_prec=use_mixed_precision)
|
validate_sceneflow(model, iters=args.valid_iters, mixed_prec=use_mixed_precision, max_disp=args.max_disp)
|
||||||
|
@@ -1,6 +1,7 @@
 from __future__ import print_function, division
 
+import math
 import os
 os.environ['CUDA_VISIBLE_DEVICES'] = '0, 1'
 
@@ -53,7 +54,7 @@ def sequence_loss(disp_preds, disp_init_pred, disp_gt, valid, loss_gamma=0.9, ma
     assert not torch.isinf(disp_gt[valid.bool()]).any()
 
-    disp_loss += 1.0 * F.smooth_l1_loss(disp_init_pred[valid.bool()], disp_gt[valid.bool()], size_average=True)
+    disp_loss += 1.0 * F.smooth_l1_loss(disp_init_pred[valid.bool()], disp_gt[valid.bool()], reduction='mean')
     for i in range(n_predictions):
         adjusted_loss_gamma = loss_gamma**(15/(n_predictions - 1))
         i_weight = adjusted_loss_gamma**(n_predictions - i - 1)
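size_average was deprecated in PyTorch in favor of the reduction argument; size_average=True corresponds to reduction='mean', so the loss value is unchanged and only the deprecation warning goes away. A quick equivalence check:

```python
import torch
import torch.nn.functional as F

pred = torch.randn(8, requires_grad=True)
target = torch.randn(8)

loss = F.smooth_l1_loss(pred, target, reduction='mean')
# 'mean' averages over all elements, exactly what size_average=True did.
assert torch.isclose(loss, F.smooth_l1_loss(pred, target, reduction='sum') / pred.numel())
```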
@@ -66,9 +67,10 @@ def sequence_loss(disp_preds, disp_init_pred, disp_gt, valid, loss_gamma=0.9, ma
 
     metrics = {
         'epe': epe.mean().item(),
-        '1px': (epe < 1).float().mean().item(),
-        '3px': (epe < 3).float().mean().item(),
-        '5px': (epe < 5).float().mean().item(),
+        '0.5px': (epe < 0.5).float().mean().item(),
+        '1.0px': (epe < 1.0).float().mean().item(),
+        '2.0px': (epe < 2.0).float().mean().item(),
+        '4.0px': (epe < 4.0).float().mean().item(),
     }
 
     return disp_loss, metrics
@@ -76,8 +78,9 @@ def sequence_loss(disp_preds, disp_init_pred, disp_gt, valid, loss_gamma=0.9, ma
 
 def fetch_optimizer(args, model):
     """ Create the optimizer and learning rate scheduler """
-    optimizer = optim.AdamW(model.parameters(), lr=args.lr, weight_decay=args.wdecay, eps=1e-8)
+    optimizer = optim.AdamW(filter(lambda p: p.requires_grad, model.parameters()), lr=args.lr, weight_decay=args.wdecay, eps=1e-8)
 
+    # todo: cosine scheduler, warm-up
     scheduler = optim.lr_scheduler.OneCycleLR(optimizer, args.lr, args.num_steps+100,
         pct_start=0.01, cycle_momentum=False, anneal_strategy='linear')
     return optimizer, scheduler
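Passing only tensors with requires_grad=True keeps the frozen backbone out of AdamW entirely, so no optimizer state is allocated for it and weight decay cannot touch it. A minimal sketch of the pattern:

```python
import torch.nn as nn
import torch.optim as optim

model = nn.Sequential(nn.Linear(4, 4), nn.Linear(4, 2))
for p in model[0].parameters():  # pretend the first layer is the frozen backbone
    p.requires_grad = False

optimizer = optim.AdamW(filter(lambda p: p.requires_grad, model.parameters()),
                        lr=2e-4, weight_decay=1e-5, eps=1e-8)
# Only the second layer's weight and bias end up in the optimizer.
assert len(optimizer.param_groups[0]['params']) == 2
```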
@@ -132,7 +135,7 @@ class Logger:
 
 
 def train(args):
-
+    # todo: compile the model to speed up at pytorch 2.0.
     model = nn.DataParallel(IGEVStereo(args))
     print("Parameter Count: %d" % count_parameters(model))
 
@@ -151,14 +154,13 @@ def train(args):
     model.train()
     model.module.freeze_bn() # We keep BatchNorm frozen
 
-    validation_frequency = 10000
+    validation_frequency = 1000
 
     scaler = GradScaler(enabled=args.mixed_precision)
 
     should_keep_training = True
     global_batch_num = 0
     while should_keep_training:
 
         for i_batch, (_, *data_blob) in enumerate(tqdm(train_loader)):
             optimizer.zero_grad()
             image1, image2, disp_gt, valid = [x.cuda() for x in data_blob]
@@ -175,6 +177,7 @@ def train(args):
             scaler.unscale_(optimizer)
             torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
 
+            #warning
             scaler.step(optimizer)
             scheduler.step()
             scaler.update()
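The bare #warning comment presumably flags a known AMP quirk: when the scaler detects inf/NaN gradients it silently skips optimizer.step(), but scheduler.step() here still runs, so PyTorch warns about the scheduler stepping before the optimizer on those iterations. A self-contained sketch of the update sequence (CPU-safe; the scaler is a pass-through when CUDA is absent):

```python
import torch
import torch.nn as nn
import torch.optim as optim
from torch.cuda.amp import GradScaler

model = nn.Linear(4, 1)
optimizer = optim.AdamW(model.parameters(), lr=2e-4)
scheduler = optim.lr_scheduler.OneCycleLR(optimizer, 2e-4, total_steps=100,
                                          pct_start=0.01, cycle_momentum=False,
                                          anneal_strategy='linear')
scaler = GradScaler(enabled=torch.cuda.is_available())

loss = model(torch.randn(8, 4)).mean()
scaler.scale(loss).backward()
scaler.unscale_(optimizer)
torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
scaler.step(optimizer)  # skipped internally if unscaled grads contain inf/NaN
scaler.update()
scheduler.step()        # runs unconditionally, hence the possible warning
```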
@@ -184,7 +187,7 @@ def train(args):
                 save_path = Path(ckpt_path + '/%d_%s.pth' % (total_steps + 1, args.name))
                 logging.info(f"Saving file {save_path.absolute()}")
                 torch.save(model.state_dict(), save_path)
-                results = validate_sceneflow(model.module, iters=args.valid_iters)
+                results = validate_middlebury(model.module, iters=args.valid_iters, max_disp=args.max_disp)
                 logger.write_dict(results)
                 model.train()
                 model.module.freeze_bn()
@@ -212,13 +215,14 @@ if __name__ == '__main__':
     parser = argparse.ArgumentParser()
     parser.add_argument('--name', default='igev-stereo', help="name your experiment")
     parser.add_argument('--restore_ckpt', default=None, help="")
-    parser.add_argument('--mixed_precision', default=True, action='store_true', help='use mixed precision')
+    parser.add_argument('--mixed_precision', default=False, action='store_true', help='use mixed precision')
 
     # Training parameters
     parser.add_argument('--batch_size', type=int, default=8, help="batch size used during training.")
     parser.add_argument('--train_datasets', nargs='+', default=['sceneflow'], help="training datasets.")
     parser.add_argument('--lr', type=float, default=0.0002, help="max learning rate.")
     parser.add_argument('--num_steps', type=int, default=200000, help="length of training schedule.")
+    parser.add_argument('--freeze_backbone_params', action="store_true", help="freeze backbone parameters")
     parser.add_argument('--image_size', type=int, nargs='+', default=[320, 736], help="size of the random image crops used during training.")
     parser.add_argument('--train_iters', type=int, default=22, help="number of updates to the disparity field in each forward pass.")
     parser.add_argument('--wdecay', type=float, default=.00001, help="Weight decay in optimizer.")
@@ -238,8 +242,8 @@ if __name__ == '__main__':
     parser.add_argument('--max_disp', type=int, default=192, help="max disp of geometry encoding volume")
 
     # Data augmentation
-    parser.add_argument('--img_gamma', type=float, nargs='+', default=None, help="gamma range")
-    parser.add_argument('--saturation_range', type=float, nargs='+', default=[0, 1.4], help='color saturation')
+    # parser.add_argument('--img_gamma', type=float, nargs='+', default=None, help="gamma range")
+    # parser.add_argument('--saturation_range', type=float, nargs='+', default=[0, 1.4], help='color saturation')
     parser.add_argument('--do_flip', default=False, choices=['h', 'v'], help='flip the images horizontally or vertically')
     parser.add_argument('--spatial_scale', type=float, nargs='+', default=[-0.2, 0.4], help='re-scale the images randomly')
     parser.add_argument('--noyjitter', action='store_true', help='don\'t simulate imperfect rectification')