Compare commits

...

22 Commits

Author SHA1 Message Date
HTensor  74bc06d58e  Merge branch 'original-test'  2023-05-03 04:30:45 +08:00
HTensor  0404f5b5b0  set 12 threads to prepare datasets  2023-05-03 04:14:05 +08:00
HTensor  3b68318ed5  disabled cudnn benchmark  2023-05-03 04:13:28 +08:00
HTensor  a7d89bd95c  change resolution of middlebury test images  2023-05-02 13:39:36 +08:00
HTensor  e7033dabf9  added comments  2023-05-02 01:19:05 +08:00
HTensor  9df896bb70  Update stereo_datasets.py  2023-05-01 12:25:21 +08:00
HTensor  73e65f99b8  Update evaluate_stereo.py  2023-04-30 16:03:58 +08:00
a1cc25351d  Update create_crestereo_subsets.py  2023-04-29 13:39:09 +08:00
59ff17e149  Create create_crestereo_subsets.py but haven't finished  2023-04-28 00:52:54 +08:00
75735c8fae  compatibility changes  2023-04-27 19:16:45 +08:00
21e3f92461  changed metrics indicator  2023-04-27 19:15:49 +08:00
8591c2edad  Create evaluate-history.sh  2023-04-27 19:13:51 +08:00
875a1eec05  change d1 to err2.0 in middlebury  2023-04-27 13:29:57 +08:00
HTensor  1080b823a5  wtf?  2023-04-26 19:50:17 +08:00
0a3613711b  locked backbone parameters  2023-04-25 20:19:43 +08:00
3f60e691f8  added asymmetric chromatic augmentation & adjusted augmentor param  2023-04-25 16:20:22 +08:00
0a20c0a001  Create .gitignore  2023-04-25 10:36:53 +08:00
c3b4812e99  optimize  2023-04-24 16:37:30 +08:00
HTensor  25753af380  Create .gitignore  2023-04-22 16:38:11 +08:00
a301ebb020  added crestereo datasets & fixed some bugs  2023-04-22 11:32:10 +08:00
HTensor  93371a3c4d  todo  2023-04-19 00:58:31 +08:00
Gangwei Xu  54a2eabbff  Merge pull request #7 from gangweiX/add-license-1 (Create LICENSE)  2023-04-13 15:46:17 +08:00
14 changed files with 409 additions and 53 deletions

.gitignore vendored Normal file
View File

@@ -0,0 +1,161 @@
# Byte-compiled / optimized / DLL files
__pycache__/
*.py[cod]
*$py.class
# C extensions
*.so
# Distribution / packaging
.Python
build/
develop-eggs/
dist/
downloads/
eggs/
.eggs/
lib/
lib64/
parts/
sdist/
var/
wheels/
share/python-wheels/
*.egg-info/
.installed.cfg
*.egg
MANIFEST
# PyInstaller
# Usually these files are written by a python script from a template
# before PyInstaller builds the exe, so as to inject date/other infos into it.
*.manifest
*.spec
# Installer logs
pip-log.txt
pip-delete-this-directory.txt
# Unit test / coverage reports
htmlcov/
.tox/
.nox/
.coverage
.coverage.*
.cache
nosetests.xml
coverage.xml
*.cover
*.py,cover
.hypothesis/
.pytest_cache/
cover/
# Translations
*.mo
*.pot
# Django stuff:
*.log
local_settings.py
db.sqlite3
db.sqlite3-journal
# Flask stuff:
instance/
.webassets-cache
# Scrapy stuff:
.scrapy
# Sphinx documentation
docs/_build/
# PyBuilder
.pybuilder/
target/
# Jupyter Notebook
.ipynb_checkpoints
# IPython
profile_default/
ipython_config.py
# pyenv
# For a library or package, you might want to ignore these files since the code is
# intended to run in multiple environments; otherwise, check them in:
# .python-version
# pipenv
# According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
# However, in case of collaboration, if having platform-specific dependencies or dependencies
# having no cross-platform support, pipenv may install dependencies that don't work, or not
# install all needed dependencies.
#Pipfile.lock
# poetry
# Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control.
# This is especially recommended for binary packages to ensure reproducibility, and is more
# commonly ignored for libraries.
# https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control
#poetry.lock
# pdm
# Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control.
#pdm.lock
# pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it
# in version control.
# https://pdm.fming.dev/#use-with-ide
.pdm.toml
# PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm
__pypackages__/
# Celery stuff
celerybeat-schedule
celerybeat.pid
# SageMath parsed files
*.sage.py
# Environments
.env
.venv
env/
venv/
ENV/
env.bak/
venv.bak/
# Spyder project settings
.spyderproject
.spyproject
# Rope project settings
.ropeproject
# mkdocs documentation
/site
# mypy
.mypy_cache/
.dmypy.json
dmypy.json
# Pyre type checker
.pyre/
# pytype static type analyzer
.pytype/
# Cython debug symbols
cython_debug/
# PyCharm
# JetBrains specific template is maintained in a separate JetBrains.gitignore that can
# be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore
# and can be added to the global gitignore or merged into this file. For a more nuclear
# option (not recommended) you can uncomment the following to ignore the entire idea folder.
.idea/
.DS_Store

IGEV-Stereo/.gitignore vendored Normal file
View File

@@ -0,0 +1,4 @@
checkpoints/
demo-output/
pretrained_models/
/trace.json

View File

@@ -325,21 +325,24 @@ class SubModule(nn.Module):
 class Feature(SubModule):
-    def __init__(self):
+    def __init__(self, freeze):
         super(Feature, self).__init__()
         pretrained = True
-        model = timm.create_model('mobilenetv2_100', pretrained=pretrained, features_only=True)
+        self.model = timm.create_model('mobilenetv2_100', pretrained=pretrained, features_only=True)
+        if freeze:
+            for p in self.model.parameters():
+                p.requires_grad = False
         layers = [1,2,3,5,6]
         chans = [16, 24, 32, 96, 160]
-        self.conv_stem = model.conv_stem
-        self.bn1 = model.bn1
-        self.act1 = model.act1
+        self.conv_stem = self.model.conv_stem
+        self.bn1 = self.model.bn1
+        self.act1 = self.model.act1

-        self.block0 = torch.nn.Sequential(*model.blocks[0:layers[0]])
-        self.block1 = torch.nn.Sequential(*model.blocks[layers[0]:layers[1]])
-        self.block2 = torch.nn.Sequential(*model.blocks[layers[1]:layers[2]])
-        self.block3 = torch.nn.Sequential(*model.blocks[layers[2]:layers[3]])
-        self.block4 = torch.nn.Sequential(*model.blocks[layers[3]:layers[4]])
+        self.block0 = torch.nn.Sequential(*self.model.blocks[0:layers[0]])
+        self.block1 = torch.nn.Sequential(*self.model.blocks[layers[0]:layers[1]])
+        self.block2 = torch.nn.Sequential(*self.model.blocks[layers[1]:layers[2]])
+        self.block3 = torch.nn.Sequential(*self.model.blocks[layers[2]:layers[3]])
+        self.block4 = torch.nn.Sequential(*self.model.blocks[layers[3]:layers[4]])

         self.deconv32_16 = Conv2x_IN(chans[4], chans[3], deconv=True, concat=True)
         self.deconv16_8 = Conv2x_IN(chans[3]*2, chans[2], deconv=True, concat=True)

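The freeze switch above only flips requires_grad on the timm backbone; the forward pass is unchanged. A minimal standalone sketch of what that does (pretrained=False here is just to avoid a weight download):

import timm

# Sketch of the freeze above: requires_grad=False removes the MobileNetV2
# weights from autograd while leaving inference behavior untouched.
backbone = timm.create_model('mobilenetv2_100', pretrained=False, features_only=True)
for p in backbone.parameters():
    p.requires_grad = False

frozen = sum(p.numel() for p in backbone.parameters() if not p.requires_grad)
print(f"{frozen} backbone parameters frozen")

Freezing alone does not shrink the optimizer; the matching fetch_optimizer change further down filters the frozen parameters out.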
View File

@@ -37,8 +37,8 @@ class Combined_Geo_Encoding_Volume:
         out_pyramid = []
         for i in range(self.num_levels):
             geo_volume = self.geo_volume_pyramid[i]
-            dx = torch.linspace(-r, r, 2*r+1)
-            dx = dx.view(1, 1, 2*r+1, 1).to(disp.device)
+            dx = torch.linspace(-r, r, 2*r+1, device=disp.device)
+            dx = dx.view(1, 1, 2*r+1, 1)
             x0 = dx + disp.reshape(b*h*w, 1, 1, 1) / 2**i
             y0 = torch.zeros_like(x0)
@@ -60,6 +60,7 @@ class Combined_Geo_Encoding_Volume:
     @staticmethod
     def corr(fmap1, fmap2):
+        # batch, dim, ht, wd
         B, D, H, W1 = fmap1.shape
         _, _, _, W2 = fmap2.shape
         fmap1 = fmap1.view(B, D, H, W1)

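Both hunks above build lookup tensors directly on the target device instead of allocating on the CPU and copying with .to(device), removing a host allocation and a host-to-device transfer from the inner lookup loop. An illustrative sketch of the equivalence (values are made up, not from the repo):

import torch

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
r = 4  # stand-in for the correlation radius

# before: allocate on the host, then copy to the device
dx_slow = torch.linspace(-r, r, 2*r+1).view(1, 1, 2*r+1, 1).to(device)
# after: allocate on the device in one step
dx_fast = torch.linspace(-r, r, 2*r+1, device=device).view(1, 1, 2*r+1, 1)

assert torch.equal(dx_slow, dx_fast)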
View File

@@ -100,7 +100,7 @@ class IGEVStereo(nn.Module):
         self.context_zqr_convs = nn.ModuleList([nn.Conv2d(context_dims[i], args.hidden_dims[i]*3, 3, padding=3//2) for i in range(self.args.n_gru_layers)])

-        self.feature = Feature()
+        self.feature = Feature(args.freeze_backbone_params)

         self.stem_2 = nn.Sequential(
             BasicConv_IN(3, 32, kernel_size=3, stride=2, padding=1),
@@ -165,8 +165,9 @@ class IGEVStereo(nn.Module):
         match_left = self.desc(self.conv(features_left[0]))
         match_right = self.desc(self.conv(features_right[0]))

-        gwc_volume = build_gwc_volume(match_left, match_right, 192//4, 8)
+        gwc_volume = build_gwc_volume(match_left, match_right, self.args.max_disp//4, 8)
         gwc_volume = self.corr_stem(gwc_volume)
+        # 3d unet
         gwc_volume = self.corr_feature_att(gwc_volume, features_left[0])
         geo_encoding_volume = self.cost_agg(gwc_volume, features_left)
@@ -182,6 +183,7 @@ class IGEVStereo(nn.Module):
         spx_pred = self.spx(xspx)
         spx_pred = F.softmax(spx_pred, 1)

+        # Content Network
         cnet_list = self.cnet(image1, num_layers=self.args.n_gru_layers)
         net_list = [torch.tanh(x[0]) for x in cnet_list]
         inp_list = [torch.relu(x[1]) for x in cnet_list]
@@ -191,7 +193,7 @@ class IGEVStereo(nn.Module):
         geo_block = Combined_Geo_Encoding_Volume
         geo_fn = geo_block(match_left.float(), match_right.float(), geo_encoding_volume.float(), radius=self.args.corr_radius, num_levels=self.args.corr_levels)
         b, c, h, w = match_left.shape
-        coords = torch.arange(w).float().to(match_left.device).reshape(1,1,w,1).repeat(b, h, 1, 1)
+        coords = torch.arange(w, device=match_left.device).float().reshape(1,1,w,1).repeat(b, h, 1, 1)

         disp = init_disp
         disp_preds = []

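Replacing the hard-coded 192//4 with self.args.max_disp//4 makes the depth of the group-wise correlation volume follow the --max_disp flag; the //4 is there because the matching features sit at quarter resolution. A shape-only sketch with illustrative sizes (not measured from the repo):

# illustrative sizes only
B, num_groups = 2, 8
H, W = 320 // 4, 736 // 4   # feature maps at 1/4 of the crop resolution
max_disp = 384              # e.g. evaluate-history.sh passes --max_disp 384
volume_shape = (B, num_groups, max_disp // 4, H, W)
print(volume_shape)         # (2, 8, 96, 80, 184)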
View File

@@ -266,9 +266,34 @@ class KITTI(StereoDataset):
             self.image_list += [ [img1, img2] ]
             self.disparity_list += [ disp ]

+class CREStereo(StereoDataset):
+    def __init__(self, aug_params=None, root='/data/CREStereo'):
+        super(CREStereo, self).__init__(aug_params, sparse=True, reader=frame_utils.readDispCREStereo)
+        self.root = root
+        assert os.path.exists(root)
+
+        disp_list = self.selector('_left.disp.png')
+        image1_list = self.selector('_left.jpg')
+        image2_list = self.selector('_right.jpg')
+        assert len(image1_list) == len(image2_list) == len(disp_list) > 0
+
+        for img1, img2, disp in zip(image1_list, image2_list, disp_list):
+            # if random.randint(1, 20000) != 1:
+            #     continue
+            self.image_list += [[img1, img2]]
+            self.disparity_list += [disp]
+
+    def selector(self, suffix):
+        files = list(glob(os.path.join(self.root, f"hole/*{suffix}")))
+        files += list(glob(os.path.join(self.root, f"shapenet/*{suffix}")))
+        files += list(glob(os.path.join(self.root, f"tree/*{suffix}")))
+        files += list(glob(os.path.join(self.root, f"reflective/*{suffix}")))
+        return sorted(files)
+
 class Middlebury(StereoDataset):
-    def __init__(self, aug_params=None, root='/data/Middlebury', split='F'):
+    def __init__(self, aug_params=None, root='/data/Middlebury', split='H'):
         super(Middlebury, self).__init__(aug_params, sparse=True, reader=frame_utils.readDispMiddlebury)
         assert os.path.exists(root)
         assert split in "FHQ"
@@ -321,10 +346,13 @@ def fetch_dataloader(args):
     elif dataset_name.startswith('tartan_air'):
         new_dataset = TartanAir(aug_params, keywords=dataset_name.split('_')[2:])
         logging.info(f"Adding {len(new_dataset)} samples from Tartain Air")
+    elif dataset_name.startswith('crestereo'):
+        new_dataset = CREStereo(aug_params)
+        logging.info(f"Adding {len(new_dataset)} samples from CREStereo")

     train_dataset = new_dataset if train_dataset is None else train_dataset + new_dataset

     train_loader = data.DataLoader(train_dataset, batch_size=args.batch_size,
-        pin_memory=True, shuffle=True, num_workers=int(os.environ.get('SLURM_CPUS_PER_TASK', 6))-2, drop_last=True)
+        pin_memory=True, shuffle=True, num_workers=12, drop_last=True)

     logging.info('Training with %d image pairs' % len(train_dataset))
     return train_loader

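With the new fetch_dataloader branch, --train_datasets crestereo pulls the four CREStereo subsets (hole, shapenet, tree, reflective) into training, and the loader now uses a fixed num_workers=12, matching the "set 12 threads to prepare datasets" commit. A hypothetical smoke test, assuming the data sits at the default /data/CREStereo root:

from core import stereo_datasets as datasets

# Hypothetical check: build the dataset directly and peek at one sample.
ds = datasets.CREStereo(aug_params=None, root='/data/CREStereo')
print(len(ds))                                 # number of stereo pairs
print(ds.image_list[0], ds.disparity_list[0])  # [left, right] and disparity paths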
View File

@@ -157,6 +157,7 @@ def groupwise_correlation(fea1, fea2, num_groups):
     return cost

 def build_gwc_volume(refimg_fea, targetimg_fea, maxdisp, num_groups):
+    # batch, groups, max_disp, height, width
     B, C, H, W = refimg_fea.shape
     volume = refimg_fea.new_zeros([B, num_groups, maxdisp, H, W])
     for i in range(maxdisp):

View File

@@ -5,7 +5,7 @@ import os
 import time
 from glob import glob
 from skimage import color, io
-from PIL import Image
+from PIL import Image, ImageEnhance

 import cv2
 cv2.setNumThreads(0)
@@ -198,21 +198,40 @@ class SparseFlowAugmentor:
             self.v_flip_prob = 0.1

         # photometric augmentation params
-        self.photo_aug = Compose([ColorJitter(brightness=0.3, contrast=0.3, saturation=saturation_range, hue=0.3/3.14), AdjustGamma(*gamma)])
-        self.asymmetric_color_aug_prob = 0.2
+        # self.photo_aug = Compose([ColorJitter(brightness=0.3, contrast=0.3, saturation=saturation_range, hue=0.3/3.14), AdjustGamma(*gamma)])
         self.eraser_aug_prob = 0.5

+    def chromatic_augmentation(self, img):
+        random_brightness = np.random.uniform(0.8, 1.2)
+        random_contrast = np.random.uniform(0.8, 1.2)
+        random_gamma = np.random.uniform(0.8, 1.2)
+
+        img = Image.fromarray(img)
+
+        enhancer = ImageEnhance.Brightness(img)
+        img = enhancer.enhance(random_brightness)
+        enhancer = ImageEnhance.Contrast(img)
+        img = enhancer.enhance(random_contrast)
+        gamma_map = [
+            255 * 1.0 * pow(ele / 255.0, random_gamma) for ele in range(256)
+        ] * 3
+        img = img.point(gamma_map)  # use PIL's point-function to accelerate this part
+
+        img_ = np.array(img)
+        return img_
+
     def color_transform(self, img1, img2):
-        image_stack = np.concatenate([img1, img2], axis=0)
-        image_stack = np.array(self.photo_aug(Image.fromarray(image_stack)), dtype=np.uint8)
-        img1, img2 = np.split(image_stack, 2, axis=0)
+        img1 = self.chromatic_augmentation(img1)
+        img2 = self.chromatic_augmentation(img2)
         return img1, img2

     def eraser_transform(self, img1, img2):
         ht, wd = img1.shape[:2]
         if np.random.rand() < self.eraser_aug_prob:
             mean_color = np.mean(img2.reshape(-1, 3), axis=0)
-            for _ in range(np.random.randint(1, 3)):
+            for _ in range(1):
                 x0 = np.random.randint(0, wd)
                 y0 = np.random.randint(0, ht)
                 dx = np.random.randint(50, 100)

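The rewritten color_transform drops the shared ColorJitter pipeline and instead jitters brightness, contrast, and gamma per image; because fresh random factors are drawn on every call, the left and right views are perturbed independently, which is the "asymmetric chromatic augmentation" named in the commit log. A self-contained sketch of that behavior:

import numpy as np
from PIL import Image, ImageEnhance

def chromatic_augmentation(img):
    # fresh random factors per call -> asymmetric across the two views
    brightness, contrast, gamma = np.random.uniform(0.8, 1.2, size=3)
    out = Image.fromarray(img)
    out = ImageEnhance.Brightness(out).enhance(brightness)
    out = ImageEnhance.Contrast(out).enhance(contrast)
    gamma_map = [255 * pow(v / 255.0, gamma) for v in range(256)] * 3
    return np.array(out.point(gamma_map))

left = np.random.randint(0, 256, (64, 64, 3), dtype=np.uint8)
right = left.copy()
# identical inputs almost surely come back different for each view
assert not np.array_equal(chromatic_augmentation(left), chromatic_augmentation(right))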
View File

@@ -152,6 +152,11 @@ def readDispTartanAir(file_name):
     valid = disp > 0
     return disp, valid

+def readDispCREStereo(file_name):
+    disp = cv2.imread(file_name, cv2.IMREAD_UNCHANGED)
+    disp = disp.astype(np.float32) / 32.0
+    valid = disp > 0.0
+    return disp, valid

 def readDispMiddlebury(file_name):
     assert basename(file_name) == 'disp0GT.pfm'

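The reader implies that CREStereo ground truth is stored as 16-bit PNGs holding disparity scaled by 32 (i.e. 1/32-pixel precision), with 0 marking invalid pixels. A round-trip sketch of that encoding (the file name is hypothetical):

import cv2
import numpy as np

disp = np.random.uniform(1.0, 256.0, (480, 640)).astype(np.float32)
cv2.imwrite('example_left.disp.png', (disp * 32).astype(np.uint16))   # encode

decoded = cv2.imread('example_left.disp.png', cv2.IMREAD_UNCHANGED)  # uint16
decoded = decoded.astype(np.float32) / 32.0                          # decode
assert np.allclose(decoded, disp, atol=1/32)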
View File

@@ -0,0 +1,112 @@
import os
import sys
sys.path.append("..")

import numpy as np
import logging
import os
import shutil
import random
from pathlib import Path
from glob import glob
import matplotlib.pyplot as plt

from core.utils import frame_utils


def unique(lst):
    return dict(zip(*np.unique(lst, return_counts=True)))


def ensure_path_exists(path):
    if not os.path.exists(path):
        os.makedirs(path)


class CREStereo():
    def __init__(self, aug_params=None, root='/data/CREStereo'):
        self.root = root
        assert os.path.exists(root)
        # disp_list = self.selector('_left.disp.png')
        # image1_list = self.selector('_left.jpg')
        # image2_list = self.selector('_right.jpg')
        # assert len(image1_list) == len(image2_list) == len(disp_list) > 0

        # for img1, img2, disp in zip(image1_list, image2_list, disp_list):
        #     # if random.randint(1, 20000) != 1:
        #     #     continue
        #     self.image_list += [[img1, img2]]
        #     self.disparity_list += [disp]

    def get_path_info(self, path):
        position, filename = os.path.split(path)
        root, sub_folder = os.path.split(position)
        return root, sub_folder, filename

    def get_new_file(self, path):
        root, sub_folder, filename = self.get_path_info(path)
        return os.path.join(root, 'subset', sub_folder, filename)

    def divide(self, num):
        ensure_path_exists(os.path.join(self.root, 'subset'))
        for sub_folder in ['tree', 'shapenet', 'reflective', 'hole']:
            ensure_path_exists(os.path.join(self.root, 'subset', sub_folder))
            disp1_list = self.single_folder_selector(sub_folder, '_left.disp.png')
            disp2_list = self.single_folder_selector(sub_folder, '_right.disp.png')
            image1_list = self.single_folder_selector(sub_folder, '_left.jpg')
            image2_list = self.single_folder_selector(sub_folder, '_right.jpg')
            assert len(image1_list) == len(image2_list) == len(disp1_list) == len(disp2_list) > 0

            lists = []
            for img1, img2, disp1, disp2 in zip(image1_list, image2_list, disp1_list, disp2_list):
                lists += [[img1, img2, disp1, disp2]]

            subset = random.sample(lists, num)
            for s in subset:
                for element in s:
                    print(element)
                    print(self.get_new_file(element))
                    shutil.copy(element, self.get_new_file(element))

    def selector(self, suffix):
        files = list(glob(os.path.join(self.root, f"hole/*{suffix}")))
        files += list(glob(os.path.join(self.root, f"shapenet/*{suffix}")))
        files += list(glob(os.path.join(self.root, f"tree/*{suffix}")))
        files += list(glob(os.path.join(self.root, f"reflective/*{suffix}")))
        return sorted(files)

    def single_folder_selector(self, sub_folder, suffix):
        return sorted(list(glob(os.path.join(self.root, f"{sub_folder}/*{suffix}"))))

    def disparity_distribution(self):
        disp_lists = self.selector('_left.disp.png')
        disparities = {}
        for filename in disp_lists:
            print(filename)
            disp_gt, _ = frame_utils.readDispCREStereo(filename)
            [rows, cols] = disp_gt.shape
            disp_gt = (disp_gt * 32).astype(int)
            cnt = unique(disp_gt)
            for i in cnt:
                if i in disparities:
                    disparities[i] += cnt[i]
                else:
                    disparities[i] = cnt[i]

        x = []
        y = []
        for key in disparities.keys():
            x.append(key / 32)
            y.append(disparities[key])
        plt.scatter(x, y)
        plt.show()


c = CREStereo()
c.divide(10000)

View File

@@ -26,6 +26,7 @@ def demo(args):
     model.load_state_dict(torch.load(args.restore_ckpt))

     model = model.module
+    # model = torch.compile(model)
     model.to(DEVICE)
     model.eval()

View File

@@ -0,0 +1,8 @@
#!/bin/bash
for iter in {1..100}
do
    iter=$((iter * 1000))
    echo "These are the results of ${iter} iterations."
    python evaluate_stereo.py --dataset middlebury_H --max_disp 384 --freeze_backbone_params --restore_ckpt ~/checkpoints/igev_stereo/${iter}_fix-validation.pth
done

View File

@@ -20,7 +20,7 @@ def count_parameters(model):
     return sum(p.numel() for p in model.parameters() if p.requires_grad)

 @torch.no_grad()
-def validate_eth3d(model, iters=32, mixed_prec=False):
+def validate_eth3d(model, iters=32, mixed_prec=False, max_disp=192):
     """ Peform validation using the ETH3D (train) split """
     model.eval()
     aug_params = {}
@@ -67,7 +67,7 @@ def validate_eth3d(model, iters=32, mixed_prec=False):
 @torch.no_grad()
-def validate_kitti(model, iters=32, mixed_prec=False):
+def validate_kitti(model, iters=32, mixed_prec=False, max_disp=192):
     """ Peform validation using the KITTI-2015 (train) split """
     model.eval()
     aug_params = {}
@@ -96,7 +96,7 @@ def validate_kitti(model, iters=32, mixed_prec=False):
         epe = torch.sum((flow_pr - flow_gt)**2, dim=0).sqrt()
         epe_flattened = epe.flatten()
-        val = (valid_gt.flatten() >= 0.5) & (flow_gt.abs().flatten() < 192)
+        val = (valid_gt.flatten() >= 0.5) & (flow_gt.abs().flatten() < max_disp)
         # val = valid_gt.flatten() >= 0.5
         out = (epe_flattened > 3.0)
@@ -120,7 +120,7 @@ def validate_kitti(model, iters=32, mixed_prec=False):
 @torch.no_grad()
-def validate_sceneflow(model, iters=32, mixed_prec=False):
+def validate_sceneflow(model, iters=32, mixed_prec=False, max_disp=192):
     """ Peform validation using the Scene Flow (TEST) split """
     model.eval()
     val_dataset = datasets.SceneFlowDatasets(dstype='frames_finalpass', things_test=True)
@@ -144,7 +144,7 @@ def validate_sceneflow(model, iters=32, mixed_prec=False):
         epe = torch.abs(flow_pr - flow_gt)
         epe = epe.flatten()
-        val = (valid_gt.flatten() >= 0.5) & (flow_gt.abs().flatten() < 192)
+        val = (valid_gt.flatten() >= 0.5) & (flow_gt.abs().flatten() < max_disp)

         if(np.isnan(epe[val].mean().item())):
             continue
@@ -169,7 +169,7 @@ def validate_sceneflow(model, iters=32, mixed_prec=False):
 @torch.no_grad()
-def validate_middlebury(model, iters=32, split='F', mixed_prec=False):
+def validate_middlebury(model, iters=32, split='H', mixed_prec=False, max_disp=192):
     """ Peform validation using the Middlebury-V3 dataset """
     model.eval()
     aug_params = {}
@@ -196,11 +196,11 @@ def validate_middlebury(model, iters=32, split='F', mixed_prec=False):
         occ_mask = Image.open(imageL_file.replace('im0.png', 'mask0nocc.png')).convert('L')
         occ_mask = np.ascontiguousarray(occ_mask, dtype=np.float32).flatten()

-        val = (valid_gt.reshape(-1) >= 0.5) & (flow_gt[0].reshape(-1) < 192) & (occ_mask==255)
+        val = (valid_gt.reshape(-1) >= 0.5) & (flow_gt[0].reshape(-1) < max_disp) & (occ_mask==255)
         out = (epe_flattened > 2.0)
         image_out = out[val].float().mean().item()
         image_epe = epe_flattened[val].mean().item()
-        logging.info(f"Middlebury Iter {val_id+1} out of {len(val_dataset)}. EPE {round(image_epe,4)} D1 {round(image_out,4)}")
+        logging.info(f"Middlebury Iter {val_id+1} out of {len(val_dataset)}. EPE {round(image_epe,4)} Err2.0 {round(image_out,4)}")
         epe_list.append(image_epe)
         out_list.append(image_out)
@@ -208,10 +208,10 @@ def validate_middlebury(model, iters=32, split='F', mixed_prec=False):
     out_list = np.array(out_list)

     epe = np.mean(epe_list)
-    d1 = 100 * np.mean(out_list)
+    err2 = 100 * np.mean(out_list)

-    print(f"Validation Middlebury{split}: EPE {epe}, D1 {d1}")
-    return {f'middlebury{split}-epe': epe, f'middlebury{split}-d1': d1}
+    print(f"Validation Middlebury{split}: EPE {epe}, Err2.0 {err2}")
+    return {f'middlebury{split}-epe': epe, f'middlebury{split}-err2.0': err2}

 if __name__ == '__main__':
@@ -231,6 +231,7 @@
     parser.add_argument('--slow_fast_gru', action='store_true', help="iterate the low-res GRUs more frequently")
     parser.add_argument('--n_gru_layers', type=int, default=3, help="number of hidden GRU levels")
     parser.add_argument('--max_disp', type=int, default=192, help="max disp of geometry encoding volume")
+    parser.add_argument('--freeze_backbone_params', action="store_true", help="freeze backbone parameters")

     args = parser.parse_args()
     model = torch.nn.DataParallel(IGEVStereo(args), device_ids=[0])
@@ -242,6 +243,12 @@
         assert args.restore_ckpt.endswith(".pth")
         logging.info("Loading checkpoint...")
         checkpoint = torch.load(args.restore_ckpt)
+
+        unwanted_prefix = '_orig_mod.'
+        for k, v in list(checkpoint.items()):
+            if k.startswith(unwanted_prefix):
+                checkpoint[k[len(unwanted_prefix):]] = checkpoint.pop(k)
+
         model.load_state_dict(checkpoint, strict=True)
         logging.info(f"Done loading checkpoint")
@@ -252,13 +259,13 @@
     use_mixed_precision = args.corr_implementation.endswith("_cuda")

     if args.dataset == 'eth3d':
-        validate_eth3d(model, iters=args.valid_iters, mixed_prec=use_mixed_precision)
+        validate_eth3d(model, iters=args.valid_iters, mixed_prec=use_mixed_precision, max_disp=args.max_disp)

     elif args.dataset == 'kitti':
-        validate_kitti(model, iters=args.valid_iters, mixed_prec=use_mixed_precision)
+        validate_kitti(model, iters=args.valid_iters, mixed_prec=use_mixed_precision, max_disp=args.max_disp)

     elif args.dataset in [f"middlebury_{s}" for s in 'FHQ']:
-        validate_middlebury(model, iters=args.valid_iters, split=args.dataset[-1], mixed_prec=use_mixed_precision)
+        validate_middlebury(model, iters=args.valid_iters, split=args.dataset[-1], mixed_prec=use_mixed_precision, max_disp=args.max_disp)

     elif args.dataset == 'sceneflow':
-        validate_sceneflow(model, iters=args.valid_iters, mixed_prec=use_mixed_precision)
+        validate_sceneflow(model, iters=args.valid_iters, mixed_prec=use_mixed_precision, max_disp=args.max_disp)

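The _orig_mod. rename handles checkpoints saved from a torch.compile-d model: on PyTorch 2.x the compiled wrapper prefixes every state_dict key with _orig_mod., so loading into an uncompiled model with strict=True would fail without the rename. A standalone sketch (assumes PyTorch >= 2.0):

import torch

net = torch.nn.Linear(4, 2)
ckpt = torch.compile(net).state_dict()   # keys: '_orig_mod.weight', '_orig_mod.bias'

unwanted_prefix = '_orig_mod.'
for k in list(ckpt.keys()):
    if k.startswith(unwanted_prefix):
        ckpt[k[len(unwanted_prefix):]] = ckpt.pop(k)

torch.nn.Linear(4, 2).load_state_dict(ckpt, strict=True)  # now loads cleanly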
View File

@@ -1,6 +1,7 @@
 from __future__ import print_function, division
+import math
 import os

 os.environ['CUDA_VISIBLE_DEVICES'] = '0, 1'
@@ -53,7 +54,7 @@ def sequence_loss(disp_preds, disp_init_pred, disp_gt, valid, loss_gamma=0.9, ma
     assert not torch.isinf(disp_gt[valid.bool()]).any()
-    disp_loss += 1.0 * F.smooth_l1_loss(disp_init_pred[valid.bool()], disp_gt[valid.bool()], size_average=True)
+    disp_loss += 1.0 * F.smooth_l1_loss(disp_init_pred[valid.bool()], disp_gt[valid.bool()], reduction='mean')
     for i in range(n_predictions):
         adjusted_loss_gamma = loss_gamma**(15/(n_predictions - 1))
         i_weight = adjusted_loss_gamma**(n_predictions - i - 1)
@@ -66,9 +67,10 @@
     metrics = {
         'epe': epe.mean().item(),
-        '1px': (epe < 1).float().mean().item(),
-        '3px': (epe < 3).float().mean().item(),
-        '5px': (epe < 5).float().mean().item(),
+        '0.5px': (epe < 0.5).float().mean().item(),
+        '1.0px': (epe < 1.0).float().mean().item(),
+        '2.0px': (epe < 2.0).float().mean().item(),
+        '4.0px': (epe < 4.0).float().mean().item(),
     }
     return disp_loss, metrics
@@ -76,8 +78,9 @@ def sequence_loss(disp_preds, disp_init_pred, disp_gt, valid, loss_gamma=0.9, ma
 def fetch_optimizer(args, model):
     """ Create the optimizer and learning rate scheduler """
-    optimizer = optim.AdamW(model.parameters(), lr=args.lr, weight_decay=args.wdecay, eps=1e-8)
+    optimizer = optim.AdamW(filter(lambda p: p.requires_grad, model.parameters()), lr=args.lr, weight_decay=args.wdecay, eps=1e-8)
+    # todo: cosine scheduler, warm-up
     scheduler = optim.lr_scheduler.OneCycleLR(optimizer, args.lr, args.num_steps+100,
         pct_start=0.01, cycle_momentum=False, anneal_strategy='linear')

     return optimizer, scheduler
@@ -132,7 +135,7 @@ class Logger:
 def train(args):
+    # todo: compile the model to speed up at pytorch 2.0.
     model = nn.DataParallel(IGEVStereo(args))
     print("Parameter Count: %d" % count_parameters(model))
@@ -151,14 +154,13 @@ def train(args):
     model.train()
     model.module.freeze_bn() # We keep BatchNorm frozen

-    validation_frequency = 10000
+    validation_frequency = 1000

     scaler = GradScaler(enabled=args.mixed_precision)

     should_keep_training = True
     global_batch_num = 0
     while should_keep_training:

         for i_batch, (_, *data_blob) in enumerate(tqdm(train_loader)):
             optimizer.zero_grad()
             image1, image2, disp_gt, valid = [x.cuda() for x in data_blob]
@@ -175,6 +177,7 @@
             scaler.unscale_(optimizer)
             torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
+            #warning
             scaler.step(optimizer)
             scheduler.step()
             scaler.update()
@@ -184,7 +187,7 @@
                 save_path = Path(ckpt_path + '/%d_%s.pth' % (total_steps + 1, args.name))
                 logging.info(f"Saving file {save_path.absolute()}")
                 torch.save(model.state_dict(), save_path)
-                results = validate_sceneflow(model.module, iters=args.valid_iters)
+                results = validate_middlebury(model.module, iters=args.valid_iters, max_disp=args.max_disp)
                 logger.write_dict(results)

                 model.train()
                 model.module.freeze_bn()
@@ -212,13 +215,14 @@
     parser = argparse.ArgumentParser()
     parser.add_argument('--name', default='igev-stereo', help="name your experiment")
     parser.add_argument('--restore_ckpt', default=None, help="")
-    parser.add_argument('--mixed_precision', default=True, action='store_true', help='use mixed precision')
+    parser.add_argument('--mixed_precision', default=False, action='store_true', help='use mixed precision')

     # Training parameters
     parser.add_argument('--batch_size', type=int, default=8, help="batch size used during training.")
     parser.add_argument('--train_datasets', nargs='+', default=['sceneflow'], help="training datasets.")
     parser.add_argument('--lr', type=float, default=0.0002, help="max learning rate.")
     parser.add_argument('--num_steps', type=int, default=200000, help="length of training schedule.")
+    parser.add_argument('--freeze_backbone_params', action="store_true", help="freeze backbone parameters")
     parser.add_argument('--image_size', type=int, nargs='+', default=[320, 736], help="size of the random image crops used during training.")
     parser.add_argument('--train_iters', type=int, default=22, help="number of updates to the disparity field in each forward pass.")
     parser.add_argument('--wdecay', type=float, default=.00001, help="Weight decay in optimizer.")
@@ -238,8 +242,8 @@
     parser.add_argument('--max_disp', type=int, default=192, help="max disp of geometry encoding volume")

     # Data augmentation
-    parser.add_argument('--img_gamma', type=float, nargs='+', default=None, help="gamma range")
-    parser.add_argument('--saturation_range', type=float, nargs='+', default=[0, 1.4], help='color saturation')
+    # parser.add_argument('--img_gamma', type=float, nargs='+', default=None, help="gamma range")
+    # parser.add_argument('--saturation_range', type=float, nargs='+', default=[0, 1.4], help='color saturation')
     parser.add_argument('--do_flip', default=False, choices=['h', 'v'], help='flip the images horizontally or vertically')
     parser.add_argument('--spatial_scale', type=float, nargs='+', default=[-0.2, 0.4], help='re-scale the images randomly')
     parser.add_argument('--noyjitter', action='store_true', help='don\'t simulate imperfect rectification')
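Several training changes work together with the frozen backbone: fetch_optimizer now hands AdamW only the parameters that still require gradients, so the frozen MobileNetV2 weights carry no optimizer state and are guaranteed not to drift; in-training validation switches from Scene Flow to Middlebury; and validation_frequency drops to 1000, which is the 1,000-step spacing that evaluate-history.sh assumes when it walks saved checkpoints. A minimal sketch of the parameter filtering on a toy model (sizes are illustrative):

import torch
from torch import optim

model = torch.nn.Sequential(torch.nn.Linear(8, 8), torch.nn.Linear(8, 2))
for p in model[0].parameters():   # stand-in for the frozen backbone
    p.requires_grad = False

trainable = [p for p in model.parameters() if p.requires_grad]
optimizer = optim.AdamW(trainable, lr=2e-4, weight_decay=1e-5, eps=1e-8)

n = sum(p.numel() for g in optimizer.param_groups for p in g['params'])
total = sum(p.numel() for p in model.parameters())
print(f"optimizing {n} of {total} parameters")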