Compare commits

..

No commits in common. "main" and "add-license-1" have entirely different histories.

14 changed files with 53 additions and 409 deletions

161
.gitignore vendored
View File

@ -1,161 +0,0 @@
# Byte-compiled / optimized / DLL files
__pycache__/
*.py[cod]
*$py.class
# C extensions
*.so
# Distribution / packaging
.Python
build/
develop-eggs/
dist/
downloads/
eggs/
.eggs/
lib/
lib64/
parts/
sdist/
var/
wheels/
share/python-wheels/
*.egg-info/
.installed.cfg
*.egg
MANIFEST
# PyInstaller
# Usually these files are written by a python script from a template
# before PyInstaller builds the exe, so as to inject date/other infos into it.
*.manifest
*.spec
# Installer logs
pip-log.txt
pip-delete-this-directory.txt
# Unit test / coverage reports
htmlcov/
.tox/
.nox/
.coverage
.coverage.*
.cache
nosetests.xml
coverage.xml
*.cover
*.py,cover
.hypothesis/
.pytest_cache/
cover/
# Translations
*.mo
*.pot
# Django stuff:
*.log
local_settings.py
db.sqlite3
db.sqlite3-journal
# Flask stuff:
instance/
.webassets-cache
# Scrapy stuff:
.scrapy
# Sphinx documentation
docs/_build/
# PyBuilder
.pybuilder/
target/
# Jupyter Notebook
.ipynb_checkpoints
# IPython
profile_default/
ipython_config.py
# pyenv
# For a library or package, you might want to ignore these files since the code is
# intended to run in multiple environments; otherwise, check them in:
# .python-version
# pipenv
# According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
# However, in case of collaboration, if having platform-specific dependencies or dependencies
# having no cross-platform support, pipenv may install dependencies that don't work, or not
# install all needed dependencies.
#Pipfile.lock
# poetry
# Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control.
# This is especially recommended for binary packages to ensure reproducibility, and is more
# commonly ignored for libraries.
# https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control
#poetry.lock
# pdm
# Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control.
#pdm.lock
# pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it
# in version control.
# https://pdm.fming.dev/#use-with-ide
.pdm.toml
# PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm
__pypackages__/
# Celery stuff
celerybeat-schedule
celerybeat.pid
# SageMath parsed files
*.sage.py
# Environments
.env
.venv
env/
venv/
ENV/
env.bak/
venv.bak/
# Spyder project settings
.spyderproject
.spyproject
# Rope project settings
.ropeproject
# mkdocs documentation
/site
# mypy
.mypy_cache/
.dmypy.json
dmypy.json
# Pyre type checker
.pyre/
# pytype static type analyzer
.pytype/
# Cython debug symbols
cython_debug/
# PyCharm
# JetBrains specific template is maintained in a separate JetBrains.gitignore that can
# be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore
# and can be added to the global gitignore or merged into this file. For a more nuclear
# option (not recommended) you can uncomment the following to ignore the entire idea folder.
.idea/
.DS_Store

View File

@ -1,4 +0,0 @@
checkpoints/
demo-output/
pretrained_models/
/trace.json

View File

@ -325,24 +325,21 @@ class SubModule(nn.Module):
class Feature(SubModule): class Feature(SubModule):
def __init__(self, freeze): def __init__(self):
super(Feature, self).__init__() super(Feature, self).__init__()
pretrained = True pretrained = True
self.model = timm.create_model('mobilenetv2_100', pretrained=pretrained, features_only=True) model = timm.create_model('mobilenetv2_100', pretrained=pretrained, features_only=True)
if freeze:
for p in self.model.parameters():
p.requires_grad = False
layers = [1,2,3,5,6] layers = [1,2,3,5,6]
chans = [16, 24, 32, 96, 160] chans = [16, 24, 32, 96, 160]
self.conv_stem = self.model.conv_stem self.conv_stem = model.conv_stem
self.bn1 = self.model.bn1 self.bn1 = model.bn1
self.act1 = self.model.act1 self.act1 = model.act1
self.block0 = torch.nn.Sequential(*self.model.blocks[0:layers[0]]) self.block0 = torch.nn.Sequential(*model.blocks[0:layers[0]])
self.block1 = torch.nn.Sequential(*self.model.blocks[layers[0]:layers[1]]) self.block1 = torch.nn.Sequential(*model.blocks[layers[0]:layers[1]])
self.block2 = torch.nn.Sequential(*self.model.blocks[layers[1]:layers[2]]) self.block2 = torch.nn.Sequential(*model.blocks[layers[1]:layers[2]])
self.block3 = torch.nn.Sequential(*self.model.blocks[layers[2]:layers[3]]) self.block3 = torch.nn.Sequential(*model.blocks[layers[2]:layers[3]])
self.block4 = torch.nn.Sequential(*self.model.blocks[layers[3]:layers[4]]) self.block4 = torch.nn.Sequential(*model.blocks[layers[3]:layers[4]])
self.deconv32_16 = Conv2x_IN(chans[4], chans[3], deconv=True, concat=True) self.deconv32_16 = Conv2x_IN(chans[4], chans[3], deconv=True, concat=True)
self.deconv16_8 = Conv2x_IN(chans[3]*2, chans[2], deconv=True, concat=True) self.deconv16_8 = Conv2x_IN(chans[3]*2, chans[2], deconv=True, concat=True)

View File

@ -37,8 +37,8 @@ class Combined_Geo_Encoding_Volume:
out_pyramid = [] out_pyramid = []
for i in range(self.num_levels): for i in range(self.num_levels):
geo_volume = self.geo_volume_pyramid[i] geo_volume = self.geo_volume_pyramid[i]
dx = torch.linspace(-r, r, 2*r+1, device=disp.device) dx = torch.linspace(-r, r, 2*r+1)
dx = dx.view(1, 1, 2*r+1, 1) dx = dx.view(1, 1, 2*r+1, 1).to(disp.device)
x0 = dx + disp.reshape(b*h*w, 1, 1, 1) / 2**i x0 = dx + disp.reshape(b*h*w, 1, 1, 1) / 2**i
y0 = torch.zeros_like(x0) y0 = torch.zeros_like(x0)
@ -60,7 +60,6 @@ class Combined_Geo_Encoding_Volume:
@staticmethod @staticmethod
def corr(fmap1, fmap2): def corr(fmap1, fmap2):
# batch, dim, ht, wd
B, D, H, W1 = fmap1.shape B, D, H, W1 = fmap1.shape
_, _, _, W2 = fmap2.shape _, _, _, W2 = fmap2.shape
fmap1 = fmap1.view(B, D, H, W1) fmap1 = fmap1.view(B, D, H, W1)

View File

@ -100,7 +100,7 @@ class IGEVStereo(nn.Module):
self.context_zqr_convs = nn.ModuleList([nn.Conv2d(context_dims[i], args.hidden_dims[i]*3, 3, padding=3//2) for i in range(self.args.n_gru_layers)]) self.context_zqr_convs = nn.ModuleList([nn.Conv2d(context_dims[i], args.hidden_dims[i]*3, 3, padding=3//2) for i in range(self.args.n_gru_layers)])
self.feature = Feature(args.freeze_backbone_params) self.feature = Feature()
self.stem_2 = nn.Sequential( self.stem_2 = nn.Sequential(
BasicConv_IN(3, 32, kernel_size=3, stride=2, padding=1), BasicConv_IN(3, 32, kernel_size=3, stride=2, padding=1),
@ -165,9 +165,8 @@ class IGEVStereo(nn.Module):
match_left = self.desc(self.conv(features_left[0])) match_left = self.desc(self.conv(features_left[0]))
match_right = self.desc(self.conv(features_right[0])) match_right = self.desc(self.conv(features_right[0]))
gwc_volume = build_gwc_volume(match_left, match_right, self.args.max_disp//4, 8) gwc_volume = build_gwc_volume(match_left, match_right, 192//4, 8)
gwc_volume = self.corr_stem(gwc_volume) gwc_volume = self.corr_stem(gwc_volume)
# 3d unet
gwc_volume = self.corr_feature_att(gwc_volume, features_left[0]) gwc_volume = self.corr_feature_att(gwc_volume, features_left[0])
geo_encoding_volume = self.cost_agg(gwc_volume, features_left) geo_encoding_volume = self.cost_agg(gwc_volume, features_left)
@ -183,7 +182,6 @@ class IGEVStereo(nn.Module):
spx_pred = self.spx(xspx) spx_pred = self.spx(xspx)
spx_pred = F.softmax(spx_pred, 1) spx_pred = F.softmax(spx_pred, 1)
# Content Network
cnet_list = self.cnet(image1, num_layers=self.args.n_gru_layers) cnet_list = self.cnet(image1, num_layers=self.args.n_gru_layers)
net_list = [torch.tanh(x[0]) for x in cnet_list] net_list = [torch.tanh(x[0]) for x in cnet_list]
inp_list = [torch.relu(x[1]) for x in cnet_list] inp_list = [torch.relu(x[1]) for x in cnet_list]
@ -193,7 +191,7 @@ class IGEVStereo(nn.Module):
geo_block = Combined_Geo_Encoding_Volume geo_block = Combined_Geo_Encoding_Volume
geo_fn = geo_block(match_left.float(), match_right.float(), geo_encoding_volume.float(), radius=self.args.corr_radius, num_levels=self.args.corr_levels) geo_fn = geo_block(match_left.float(), match_right.float(), geo_encoding_volume.float(), radius=self.args.corr_radius, num_levels=self.args.corr_levels)
b, c, h, w = match_left.shape b, c, h, w = match_left.shape
coords = torch.arange(w, device=match_left.device).float().reshape(1,1,w,1).repeat(b, h, 1, 1) coords = torch.arange(w).float().to(match_left.device).reshape(1,1,w,1).repeat(b, h, 1, 1)
disp = init_disp disp = init_disp
disp_preds = [] disp_preds = []

View File

@ -266,34 +266,9 @@ class KITTI(StereoDataset):
self.image_list += [ [img1, img2] ] self.image_list += [ [img1, img2] ]
self.disparity_list += [ disp ] self.disparity_list += [ disp ]
class CREStereo(StereoDataset):
def __init__(self, aug_params=None, root='/data/CREStereo'):
super(CREStereo, self).__init__(aug_params, sparse=True, reader=frame_utils.readDispCREStereo)
self.root = root
assert os.path.exists(root)
disp_list = self.selector('_left.disp.png')
image1_list = self.selector('_left.jpg')
image2_list = self.selector('_right.jpg')
assert len(image1_list) == len(image2_list) == len(disp_list) > 0
for img1, img2, disp in zip(image1_list, image2_list, disp_list):
# if random.randint(1, 20000) != 1:
# continue
self.image_list += [[img1, img2]]
self.disparity_list += [disp]
def selector(self, suffix):
files = list(glob(os.path.join(self.root, f"hole/*{suffix}")))
files += list(glob(os.path.join(self.root, f"shapenet/*{suffix}")))
files += list(glob(os.path.join(self.root, f"tree/*{suffix}")))
files += list(glob(os.path.join(self.root, f"reflective/*{suffix}")))
return sorted(files)
class Middlebury(StereoDataset): class Middlebury(StereoDataset):
def __init__(self, aug_params=None, root='/data/Middlebury', split='H'): def __init__(self, aug_params=None, root='/data/Middlebury', split='F'):
super(Middlebury, self).__init__(aug_params, sparse=True, reader=frame_utils.readDispMiddlebury) super(Middlebury, self).__init__(aug_params, sparse=True, reader=frame_utils.readDispMiddlebury)
assert os.path.exists(root) assert os.path.exists(root)
assert split in "FHQ" assert split in "FHQ"
@ -346,13 +321,10 @@ def fetch_dataloader(args):
elif dataset_name.startswith('tartan_air'): elif dataset_name.startswith('tartan_air'):
new_dataset = TartanAir(aug_params, keywords=dataset_name.split('_')[2:]) new_dataset = TartanAir(aug_params, keywords=dataset_name.split('_')[2:])
logging.info(f"Adding {len(new_dataset)} samples from Tartain Air") logging.info(f"Adding {len(new_dataset)} samples from Tartain Air")
elif dataset_name.startswith('crestereo'):
new_dataset = CREStereo(aug_params)
logging.info(f"Adding {len(new_dataset)} samples from CREStereo")
train_dataset = new_dataset if train_dataset is None else train_dataset + new_dataset train_dataset = new_dataset if train_dataset is None else train_dataset + new_dataset
train_loader = data.DataLoader(train_dataset, batch_size=args.batch_size, train_loader = data.DataLoader(train_dataset, batch_size=args.batch_size,
pin_memory=True, shuffle=True, num_workers=12, drop_last=True) pin_memory=True, shuffle=True, num_workers=int(os.environ.get('SLURM_CPUS_PER_TASK', 6))-2, drop_last=True)
logging.info('Training with %d image pairs' % len(train_dataset)) logging.info('Training with %d image pairs' % len(train_dataset))
return train_loader return train_loader

View File

@ -157,7 +157,6 @@ def groupwise_correlation(fea1, fea2, num_groups):
return cost return cost
def build_gwc_volume(refimg_fea, targetimg_fea, maxdisp, num_groups): def build_gwc_volume(refimg_fea, targetimg_fea, maxdisp, num_groups):
# batch, groups, max_disp, height, width
B, C, H, W = refimg_fea.shape B, C, H, W = refimg_fea.shape
volume = refimg_fea.new_zeros([B, num_groups, maxdisp, H, W]) volume = refimg_fea.new_zeros([B, num_groups, maxdisp, H, W])
for i in range(maxdisp): for i in range(maxdisp):

View File

@ -5,7 +5,7 @@ import os
import time import time
from glob import glob from glob import glob
from skimage import color, io from skimage import color, io
from PIL import Image, ImageEnhance from PIL import Image
import cv2 import cv2
cv2.setNumThreads(0) cv2.setNumThreads(0)
@ -198,40 +198,21 @@ class SparseFlowAugmentor:
self.v_flip_prob = 0.1 self.v_flip_prob = 0.1
# photometric augmentation params # photometric augmentation params
# self.photo_aug = Compose([ColorJitter(brightness=0.3, contrast=0.3, saturation=saturation_range, hue=0.3/3.14), AdjustGamma(*gamma)]) self.photo_aug = Compose([ColorJitter(brightness=0.3, contrast=0.3, saturation=saturation_range, hue=0.3/3.14), AdjustGamma(*gamma)])
self.asymmetric_color_aug_prob = 0.2
self.eraser_aug_prob = 0.5 self.eraser_aug_prob = 0.5
def chromatic_augmentation(self, img):
random_brightness = np.random.uniform(0.8, 1.2)
random_contrast = np.random.uniform(0.8, 1.2)
random_gamma = np.random.uniform(0.8, 1.2)
img = Image.fromarray(img)
enhancer = ImageEnhance.Brightness(img)
img = enhancer.enhance(random_brightness)
enhancer = ImageEnhance.Contrast(img)
img = enhancer.enhance(random_contrast)
gamma_map = [
255 * 1.0 * pow(ele / 255.0, random_gamma) for ele in range(256)
] * 3
img = img.point(gamma_map) # use PIL's point-function to accelerate this part
img_ = np.array(img)
return img_
def color_transform(self, img1, img2): def color_transform(self, img1, img2):
img1 = self.chromatic_augmentation(img1) image_stack = np.concatenate([img1, img2], axis=0)
img2 = self.chromatic_augmentation(img2) image_stack = np.array(self.photo_aug(Image.fromarray(image_stack)), dtype=np.uint8)
img1, img2 = np.split(image_stack, 2, axis=0)
return img1, img2 return img1, img2
def eraser_transform(self, img1, img2): def eraser_transform(self, img1, img2):
ht, wd = img1.shape[:2] ht, wd = img1.shape[:2]
if np.random.rand() < self.eraser_aug_prob: if np.random.rand() < self.eraser_aug_prob:
mean_color = np.mean(img2.reshape(-1, 3), axis=0) mean_color = np.mean(img2.reshape(-1, 3), axis=0)
for _ in range(1): for _ in range(np.random.randint(1, 3)):
x0 = np.random.randint(0, wd) x0 = np.random.randint(0, wd)
y0 = np.random.randint(0, ht) y0 = np.random.randint(0, ht)
dx = np.random.randint(50, 100) dx = np.random.randint(50, 100)

View File

@ -152,11 +152,6 @@ def readDispTartanAir(file_name):
valid = disp > 0 valid = disp > 0
return disp, valid return disp, valid
def readDispCREStereo(file_name):
disp = cv2.imread(file_name, cv2.IMREAD_UNCHANGED)
disp = disp.astype(np.float32) / 32.0
valid = disp > 0.0
return disp, valid
def readDispMiddlebury(file_name): def readDispMiddlebury(file_name):
assert basename(file_name) == 'disp0GT.pfm' assert basename(file_name) == 'disp0GT.pfm'
@ -173,7 +168,7 @@ def writeFlowKITTI(filename, uv):
valid = np.ones([uv.shape[0], uv.shape[1], 1]) valid = np.ones([uv.shape[0], uv.shape[1], 1])
uv = np.concatenate([uv, valid], axis=-1).astype(np.uint16) uv = np.concatenate([uv, valid], axis=-1).astype(np.uint16)
cv2.imwrite(filename, uv[..., ::-1]) cv2.imwrite(filename, uv[..., ::-1])
def read_gen(file_name, pil=False): def read_gen(file_name, pil=False):
ext = splitext(file_name)[-1] ext = splitext(file_name)[-1]

View File

@ -1,112 +0,0 @@
import os
import sys
sys.path.append("..")
import numpy as np
import logging
import os
import shutil
import random
from pathlib import Path
from glob import glob
import matplotlib.pyplot as plt
from core.utils import frame_utils
def unique(lst):
return dict(zip(*np.unique(lst, return_counts=True)))
def ensure_path_exists(path):
if not os.path.exists(path):
os.makedirs(path)
class CREStereo():
def __init__(self, aug_params=None, root='/data/CREStereo'):
self.root = root
assert os.path.exists(root)
# disp_list = self.selector('_left.disp.png')
# image1_list = self.selector('_left.jpg')
# image2_list = self.selector('_right.jpg')
# assert len(image1_list) == len(image2_list) == len(disp_list) > 0
# for img1, img2, disp in zip(image1_list, image2_list, disp_list):
# # if random.randint(1, 20000) != 1:
# # continue
# self.image_list += [[img1, img2]]
# self.disparity_list += [disp]
def get_path_info(self, path):
position, filename = os.path.split(path)
root, sub_folder = os.path.split(position)
return root, sub_folder, filename
def get_new_file(self, path):
root, sub_folder, filename = self.get_path_info(path)
return os.path.join(root, 'subset', sub_folder, filename)
def divide(self, num):
ensure_path_exists(os.path.join(self.root, 'subset'))
for sub_folder in ['tree', 'shapenet', 'reflective', 'hole']:
ensure_path_exists(os.path.join(self.root, 'subset', sub_folder))
disp1_list = self.single_folder_selector(sub_folder, '_left.disp.png')
disp2_list = self.single_folder_selector(sub_folder, '_right.disp.png')
image1_list = self.single_folder_selector(sub_folder, '_left.jpg')
image2_list = self.single_folder_selector(sub_folder, '_right.jpg')
assert len(image1_list) == len(image2_list) == len(disp1_list) == len(disp2_list) > 0
lists = []
for img1, img2, disp1, disp2 in zip(image1_list, image2_list, disp1_list, disp2_list):
lists += [[img1, img2, disp1, disp2]]
subset = random.sample(lists, num)
for s in subset:
for element in s:
print(element)
print(self.get_new_file(element))
shutil.copy(element, self.get_new_file(element))
def selector(self, suffix):
files = list(glob(os.path.join(self.root, f"hole/*{suffix}")))
files += list(glob(os.path.join(self.root, f"shapenet/*{suffix}")))
files += list(glob(os.path.join(self.root, f"tree/*{suffix}")))
files += list(glob(os.path.join(self.root, f"reflective/*{suffix}")))
return sorted(files)
def single_folder_selector(self, sub_folder, suffix):
return sorted(list(glob(os.path.join(self.root, f"{sub_folder}/*{suffix}"))))
def disparity_distribution(self):
disp_lists = self.selector('_left.disp.png')
disparities = {}
for filename in disp_lists:
print(filename)
disp_gt, _ = frame_utils.readDispCREStereo(filename)
[rows, cols] = disp_gt.shape
disp_gt = (disp_gt * 32).astype(int)
cnt = unique(disp_gt)
for i in cnt:
if i in disparities:
disparities[i] += cnt[i]
else:
disparities[i] = cnt[i]
x = []
y = []
for key in disparities.keys():
x.append(key / 32)
y.append(disparities[key])
plt.scatter(x, y)
plt.show()
c = CREStereo()
c.divide(10000)

View File

@ -26,7 +26,6 @@ def demo(args):
model.load_state_dict(torch.load(args.restore_ckpt)) model.load_state_dict(torch.load(args.restore_ckpt))
model = model.module model = model.module
# model = torch.compile(model)
model.to(DEVICE) model.to(DEVICE)
model.eval() model.eval()

View File

@ -1,8 +0,0 @@
#!/bin/bash
for iter in {1..100}
do
iter=$((iter * 1000))
echo "These are the results of ${iter} iterations."
python evaluate_stereo.py --dataset middlebury_H --max_disp 384 --freeze_backbone_params --restore_ckpt ~/checkpoints/igev_stereo/${iter}_fix-validation.pth
done

View File

@ -20,7 +20,7 @@ def count_parameters(model):
return sum(p.numel() for p in model.parameters() if p.requires_grad) return sum(p.numel() for p in model.parameters() if p.requires_grad)
@torch.no_grad() @torch.no_grad()
def validate_eth3d(model, iters=32, mixed_prec=False, max_disp=192): def validate_eth3d(model, iters=32, mixed_prec=False):
""" Peform validation using the ETH3D (train) split """ """ Peform validation using the ETH3D (train) split """
model.eval() model.eval()
aug_params = {} aug_params = {}
@ -67,7 +67,7 @@ def validate_eth3d(model, iters=32, mixed_prec=False, max_disp=192):
@torch.no_grad() @torch.no_grad()
def validate_kitti(model, iters=32, mixed_prec=False, max_disp=192): def validate_kitti(model, iters=32, mixed_prec=False):
""" Peform validation using the KITTI-2015 (train) split """ """ Peform validation using the KITTI-2015 (train) split """
model.eval() model.eval()
aug_params = {} aug_params = {}
@ -96,7 +96,7 @@ def validate_kitti(model, iters=32, mixed_prec=False, max_disp=192):
epe = torch.sum((flow_pr - flow_gt)**2, dim=0).sqrt() epe = torch.sum((flow_pr - flow_gt)**2, dim=0).sqrt()
epe_flattened = epe.flatten() epe_flattened = epe.flatten()
val = (valid_gt.flatten() >= 0.5) & (flow_gt.abs().flatten() < max_disp) val = (valid_gt.flatten() >= 0.5) & (flow_gt.abs().flatten() < 192)
# val = valid_gt.flatten() >= 0.5 # val = valid_gt.flatten() >= 0.5
out = (epe_flattened > 3.0) out = (epe_flattened > 3.0)
@ -120,7 +120,7 @@ def validate_kitti(model, iters=32, mixed_prec=False, max_disp=192):
@torch.no_grad() @torch.no_grad()
def validate_sceneflow(model, iters=32, mixed_prec=False, max_disp=192): def validate_sceneflow(model, iters=32, mixed_prec=False):
""" Peform validation using the Scene Flow (TEST) split """ """ Peform validation using the Scene Flow (TEST) split """
model.eval() model.eval()
val_dataset = datasets.SceneFlowDatasets(dstype='frames_finalpass', things_test=True) val_dataset = datasets.SceneFlowDatasets(dstype='frames_finalpass', things_test=True)
@ -144,7 +144,7 @@ def validate_sceneflow(model, iters=32, mixed_prec=False, max_disp=192):
epe = torch.abs(flow_pr - flow_gt) epe = torch.abs(flow_pr - flow_gt)
epe = epe.flatten() epe = epe.flatten()
val = (valid_gt.flatten() >= 0.5) & (flow_gt.abs().flatten() < max_disp) val = (valid_gt.flatten() >= 0.5) & (flow_gt.abs().flatten() < 192)
if(np.isnan(epe[val].mean().item())): if(np.isnan(epe[val].mean().item())):
continue continue
@ -169,7 +169,7 @@ def validate_sceneflow(model, iters=32, mixed_prec=False, max_disp=192):
@torch.no_grad() @torch.no_grad()
def validate_middlebury(model, iters=32, split='H', mixed_prec=False, max_disp=192): def validate_middlebury(model, iters=32, split='F', mixed_prec=False):
""" Peform validation using the Middlebury-V3 dataset """ """ Peform validation using the Middlebury-V3 dataset """
model.eval() model.eval()
aug_params = {} aug_params = {}
@ -196,11 +196,11 @@ def validate_middlebury(model, iters=32, split='H', mixed_prec=False, max_disp=1
occ_mask = Image.open(imageL_file.replace('im0.png', 'mask0nocc.png')).convert('L') occ_mask = Image.open(imageL_file.replace('im0.png', 'mask0nocc.png')).convert('L')
occ_mask = np.ascontiguousarray(occ_mask, dtype=np.float32).flatten() occ_mask = np.ascontiguousarray(occ_mask, dtype=np.float32).flatten()
val = (valid_gt.reshape(-1) >= 0.5) & (flow_gt[0].reshape(-1) < max_disp) & (occ_mask==255) val = (valid_gt.reshape(-1) >= 0.5) & (flow_gt[0].reshape(-1) < 192) & (occ_mask==255)
out = (epe_flattened > 2.0) out = (epe_flattened > 2.0)
image_out = out[val].float().mean().item() image_out = out[val].float().mean().item()
image_epe = epe_flattened[val].mean().item() image_epe = epe_flattened[val].mean().item()
logging.info(f"Middlebury Iter {val_id+1} out of {len(val_dataset)}. EPE {round(image_epe,4)} Err2.0 {round(image_out,4)}") logging.info(f"Middlebury Iter {val_id+1} out of {len(val_dataset)}. EPE {round(image_epe,4)} D1 {round(image_out,4)}")
epe_list.append(image_epe) epe_list.append(image_epe)
out_list.append(image_out) out_list.append(image_out)
@ -208,10 +208,10 @@ def validate_middlebury(model, iters=32, split='H', mixed_prec=False, max_disp=1
out_list = np.array(out_list) out_list = np.array(out_list)
epe = np.mean(epe_list) epe = np.mean(epe_list)
err2 = 100 * np.mean(out_list) d1 = 100 * np.mean(out_list)
print(f"Validation Middlebury{split}: EPE {epe}, Err2.0 {err2}") print(f"Validation Middlebury{split}: EPE {epe}, D1 {d1}")
return {f'middlebury{split}-epe': epe, f'middlebury{split}-err2.0': err2} return {f'middlebury{split}-epe': epe, f'middlebury{split}-d1': d1}
if __name__ == '__main__': if __name__ == '__main__':
@ -231,7 +231,6 @@ if __name__ == '__main__':
parser.add_argument('--slow_fast_gru', action='store_true', help="iterate the low-res GRUs more frequently") parser.add_argument('--slow_fast_gru', action='store_true', help="iterate the low-res GRUs more frequently")
parser.add_argument('--n_gru_layers', type=int, default=3, help="number of hidden GRU levels") parser.add_argument('--n_gru_layers', type=int, default=3, help="number of hidden GRU levels")
parser.add_argument('--max_disp', type=int, default=192, help="max disp of geometry encoding volume") parser.add_argument('--max_disp', type=int, default=192, help="max disp of geometry encoding volume")
parser.add_argument('--freeze_backbone_params', action="store_true", help="freeze backbone parameters")
args = parser.parse_args() args = parser.parse_args()
model = torch.nn.DataParallel(IGEVStereo(args), device_ids=[0]) model = torch.nn.DataParallel(IGEVStereo(args), device_ids=[0])
@ -243,12 +242,6 @@ if __name__ == '__main__':
assert args.restore_ckpt.endswith(".pth") assert args.restore_ckpt.endswith(".pth")
logging.info("Loading checkpoint...") logging.info("Loading checkpoint...")
checkpoint = torch.load(args.restore_ckpt) checkpoint = torch.load(args.restore_ckpt)
unwanted_prefix = '_orig_mod.'
for k, v in list(checkpoint.items()):
if k.startswith(unwanted_prefix):
checkpoint[k[len(unwanted_prefix):]] = checkpoint.pop(k)
model.load_state_dict(checkpoint, strict=True) model.load_state_dict(checkpoint, strict=True)
logging.info(f"Done loading checkpoint") logging.info(f"Done loading checkpoint")
@ -259,13 +252,13 @@ if __name__ == '__main__':
use_mixed_precision = args.corr_implementation.endswith("_cuda") use_mixed_precision = args.corr_implementation.endswith("_cuda")
if args.dataset == 'eth3d': if args.dataset == 'eth3d':
validate_eth3d(model, iters=args.valid_iters, mixed_prec=use_mixed_precision, max_disp=args.max_disp) validate_eth3d(model, iters=args.valid_iters, mixed_prec=use_mixed_precision)
elif args.dataset == 'kitti': elif args.dataset == 'kitti':
validate_kitti(model, iters=args.valid_iters, mixed_prec=use_mixed_precision, max_disp=args.max_disp) validate_kitti(model, iters=args.valid_iters, mixed_prec=use_mixed_precision)
elif args.dataset in [f"middlebury_{s}" for s in 'FHQ']: elif args.dataset in [f"middlebury_{s}" for s in 'FHQ']:
validate_middlebury(model, iters=args.valid_iters, split=args.dataset[-1], mixed_prec=use_mixed_precision, max_disp=args.max_disp) validate_middlebury(model, iters=args.valid_iters, split=args.dataset[-1], mixed_prec=use_mixed_precision)
elif args.dataset == 'sceneflow': elif args.dataset == 'sceneflow':
validate_sceneflow(model, iters=args.valid_iters, mixed_prec=use_mixed_precision, max_disp=args.max_disp) validate_sceneflow(model, iters=args.valid_iters, mixed_prec=use_mixed_precision)

View File

@ -1,7 +1,6 @@
from __future__ import print_function, division from __future__ import print_function, division
import math
import os import os
os.environ['CUDA_VISIBLE_DEVICES'] = '0, 1' os.environ['CUDA_VISIBLE_DEVICES'] = '0, 1'
@ -54,7 +53,7 @@ def sequence_loss(disp_preds, disp_init_pred, disp_gt, valid, loss_gamma=0.9, ma
assert not torch.isinf(disp_gt[valid.bool()]).any() assert not torch.isinf(disp_gt[valid.bool()]).any()
disp_loss += 1.0 * F.smooth_l1_loss(disp_init_pred[valid.bool()], disp_gt[valid.bool()], reduction='mean') disp_loss += 1.0 * F.smooth_l1_loss(disp_init_pred[valid.bool()], disp_gt[valid.bool()], size_average=True)
for i in range(n_predictions): for i in range(n_predictions):
adjusted_loss_gamma = loss_gamma**(15/(n_predictions - 1)) adjusted_loss_gamma = loss_gamma**(15/(n_predictions - 1))
i_weight = adjusted_loss_gamma**(n_predictions - i - 1) i_weight = adjusted_loss_gamma**(n_predictions - i - 1)
@ -67,10 +66,9 @@ def sequence_loss(disp_preds, disp_init_pred, disp_gt, valid, loss_gamma=0.9, ma
metrics = { metrics = {
'epe': epe.mean().item(), 'epe': epe.mean().item(),
'0.5px': (epe < 0.5).float().mean().item(), '1px': (epe < 1).float().mean().item(),
'1.0px': (epe < 1.0).float().mean().item(), '3px': (epe < 3).float().mean().item(),
'2.0px': (epe < 2.0).float().mean().item(), '5px': (epe < 5).float().mean().item(),
'4.0px': (epe < 4.0).float().mean().item(),
} }
return disp_loss, metrics return disp_loss, metrics
@ -78,9 +76,8 @@ def sequence_loss(disp_preds, disp_init_pred, disp_gt, valid, loss_gamma=0.9, ma
def fetch_optimizer(args, model): def fetch_optimizer(args, model):
""" Create the optimizer and learning rate scheduler """ """ Create the optimizer and learning rate scheduler """
optimizer = optim.AdamW(filter(lambda p: p.requires_grad, model.parameters()), lr=args.lr, weight_decay=args.wdecay, eps=1e-8) optimizer = optim.AdamW(model.parameters(), lr=args.lr, weight_decay=args.wdecay, eps=1e-8)
# todo: cosine scheduler, warm-up
scheduler = optim.lr_scheduler.OneCycleLR(optimizer, args.lr, args.num_steps+100, scheduler = optim.lr_scheduler.OneCycleLR(optimizer, args.lr, args.num_steps+100,
pct_start=0.01, cycle_momentum=False, anneal_strategy='linear') pct_start=0.01, cycle_momentum=False, anneal_strategy='linear')
return optimizer, scheduler return optimizer, scheduler
@ -135,7 +132,7 @@ class Logger:
def train(args): def train(args):
# todo: compile the model to speed up at pytorch 2.0.
model = nn.DataParallel(IGEVStereo(args)) model = nn.DataParallel(IGEVStereo(args))
print("Parameter Count: %d" % count_parameters(model)) print("Parameter Count: %d" % count_parameters(model))
@ -154,13 +151,14 @@ def train(args):
model.train() model.train()
model.module.freeze_bn() # We keep BatchNorm frozen model.module.freeze_bn() # We keep BatchNorm frozen
validation_frequency = 1000 validation_frequency = 10000
scaler = GradScaler(enabled=args.mixed_precision) scaler = GradScaler(enabled=args.mixed_precision)
should_keep_training = True should_keep_training = True
global_batch_num = 0 global_batch_num = 0
while should_keep_training: while should_keep_training:
for i_batch, (_, *data_blob) in enumerate(tqdm(train_loader)): for i_batch, (_, *data_blob) in enumerate(tqdm(train_loader)):
optimizer.zero_grad() optimizer.zero_grad()
image1, image2, disp_gt, valid = [x.cuda() for x in data_blob] image1, image2, disp_gt, valid = [x.cuda() for x in data_blob]
@ -177,7 +175,6 @@ def train(args):
scaler.unscale_(optimizer) scaler.unscale_(optimizer)
torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0) torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
#warning
scaler.step(optimizer) scaler.step(optimizer)
scheduler.step() scheduler.step()
scaler.update() scaler.update()
@ -187,7 +184,7 @@ def train(args):
save_path = Path(ckpt_path + '/%d_%s.pth' % (total_steps + 1, args.name)) save_path = Path(ckpt_path + '/%d_%s.pth' % (total_steps + 1, args.name))
logging.info(f"Saving file {save_path.absolute()}") logging.info(f"Saving file {save_path.absolute()}")
torch.save(model.state_dict(), save_path) torch.save(model.state_dict(), save_path)
results = validate_middlebury(model.module, iters=args.valid_iters, max_disp=args.max_disp) results = validate_sceneflow(model.module, iters=args.valid_iters)
logger.write_dict(results) logger.write_dict(results)
model.train() model.train()
model.module.freeze_bn() model.module.freeze_bn()
@ -215,14 +212,13 @@ if __name__ == '__main__':
parser = argparse.ArgumentParser() parser = argparse.ArgumentParser()
parser.add_argument('--name', default='igev-stereo', help="name your experiment") parser.add_argument('--name', default='igev-stereo', help="name your experiment")
parser.add_argument('--restore_ckpt', default=None, help="") parser.add_argument('--restore_ckpt', default=None, help="")
parser.add_argument('--mixed_precision', default=False, action='store_true', help='use mixed precision') parser.add_argument('--mixed_precision', default=True, action='store_true', help='use mixed precision')
# Training parameters # Training parameters
parser.add_argument('--batch_size', type=int, default=8, help="batch size used during training.") parser.add_argument('--batch_size', type=int, default=8, help="batch size used during training.")
parser.add_argument('--train_datasets', nargs='+', default=['sceneflow'], help="training datasets.") parser.add_argument('--train_datasets', nargs='+', default=['sceneflow'], help="training datasets.")
parser.add_argument('--lr', type=float, default=0.0002, help="max learning rate.") parser.add_argument('--lr', type=float, default=0.0002, help="max learning rate.")
parser.add_argument('--num_steps', type=int, default=200000, help="length of training schedule.") parser.add_argument('--num_steps', type=int, default=200000, help="length of training schedule.")
parser.add_argument('--freeze_backbone_params', action="store_true", help="freeze backbone parameters")
parser.add_argument('--image_size', type=int, nargs='+', default=[320, 736], help="size of the random image crops used during training.") parser.add_argument('--image_size', type=int, nargs='+', default=[320, 736], help="size of the random image crops used during training.")
parser.add_argument('--train_iters', type=int, default=22, help="number of updates to the disparity field in each forward pass.") parser.add_argument('--train_iters', type=int, default=22, help="number of updates to the disparity field in each forward pass.")
parser.add_argument('--wdecay', type=float, default=.00001, help="Weight decay in optimizer.") parser.add_argument('--wdecay', type=float, default=.00001, help="Weight decay in optimizer.")
@ -242,8 +238,8 @@ if __name__ == '__main__':
parser.add_argument('--max_disp', type=int, default=192, help="max disp of geometry encoding volume") parser.add_argument('--max_disp', type=int, default=192, help="max disp of geometry encoding volume")
# Data augmentation # Data augmentation
# parser.add_argument('--img_gamma', type=float, nargs='+', default=None, help="gamma range") parser.add_argument('--img_gamma', type=float, nargs='+', default=None, help="gamma range")
# parser.add_argument('--saturation_range', type=float, nargs='+', default=[0, 1.4], help='color saturation') parser.add_argument('--saturation_range', type=float, nargs='+', default=[0, 1.4], help='color saturation')
parser.add_argument('--do_flip', default=False, choices=['h', 'v'], help='flip the images horizontally or vertically') parser.add_argument('--do_flip', default=False, choices=['h', 'v'], help='flip the images horizontally or vertically')
parser.add_argument('--spatial_scale', type=float, nargs='+', default=[-0.2, 0.4], help='re-scale the images randomly') parser.add_argument('--spatial_scale', type=float, nargs='+', default=[-0.2, 0.4], help='re-scale the images randomly')
parser.add_argument('--noyjitter', action='store_true', help='don\'t simulate imperfect rectification') parser.add_argument('--noyjitter', action='store_true', help='don\'t simulate imperfect rectification')