Compare commits

...

14 Commits

Author SHA1 Message Date
HTensor
74bc06d58e Merge branch 'original-test' 2023-05-03 04:30:45 +08:00
HTensor
0404f5b5b0 set 12 threads to prepare datasets 2023-05-03 04:14:05 +08:00
HTensor
3b68318ed5 disabled cudnn benchmark 2023-05-03 04:13:28 +08:00
HTensor
a7d89bd95c change resolution of middlebury test images 2023-05-02 13:39:36 +08:00
HTensor
e7033dabf9 added comments 2023-05-02 01:19:05 +08:00
a1cc25351d Update create_crestereo_subsets.py 2023-04-29 13:39:09 +08:00
59ff17e149 Create create_crestereo_subsets.py but haven't finished 2023-04-28 00:52:54 +08:00
75735c8fae compatibility changes 2023-04-27 19:16:45 +08:00
21e3f92461 changed metrics indicator 2023-04-27 19:15:49 +08:00
8591c2edad Create evaluate-history.sh 2023-04-27 19:13:51 +08:00
875a1eec05 change d1 to err2.0 in middlebury 2023-04-27 13:29:57 +08:00
HTensor
1080b823a5 wtf? 2023-04-26 19:50:17 +08:00
0a3613711b locked backbone parameters 2023-04-25 20:19:43 +08:00
3f60e691f8 added asymmetric chromatic augmentation & adjusted augmentor param 2023-04-25 16:20:22 +08:00
10 changed files with 198 additions and 44 deletions

View File

@ -325,21 +325,24 @@ class SubModule(nn.Module):
class Feature(SubModule): class Feature(SubModule):
def __init__(self): def __init__(self, freeze):
super(Feature, self).__init__() super(Feature, self).__init__()
pretrained = True pretrained = True
model = timm.create_model('mobilenetv2_100', pretrained=pretrained, features_only=True) self.model = timm.create_model('mobilenetv2_100', pretrained=pretrained, features_only=True)
if freeze:
for p in self.model.parameters():
p.requires_grad = False
layers = [1,2,3,5,6] layers = [1,2,3,5,6]
chans = [16, 24, 32, 96, 160] chans = [16, 24, 32, 96, 160]
self.conv_stem = model.conv_stem self.conv_stem = self.model.conv_stem
self.bn1 = model.bn1 self.bn1 = self.model.bn1
self.act1 = model.act1 self.act1 = self.model.act1
self.block0 = torch.nn.Sequential(*model.blocks[0:layers[0]]) self.block0 = torch.nn.Sequential(*self.model.blocks[0:layers[0]])
self.block1 = torch.nn.Sequential(*model.blocks[layers[0]:layers[1]]) self.block1 = torch.nn.Sequential(*self.model.blocks[layers[0]:layers[1]])
self.block2 = torch.nn.Sequential(*model.blocks[layers[1]:layers[2]]) self.block2 = torch.nn.Sequential(*self.model.blocks[layers[1]:layers[2]])
self.block3 = torch.nn.Sequential(*model.blocks[layers[2]:layers[3]]) self.block3 = torch.nn.Sequential(*self.model.blocks[layers[2]:layers[3]])
self.block4 = torch.nn.Sequential(*model.blocks[layers[3]:layers[4]]) self.block4 = torch.nn.Sequential(*self.model.blocks[layers[3]:layers[4]])
self.deconv32_16 = Conv2x_IN(chans[4], chans[3], deconv=True, concat=True) self.deconv32_16 = Conv2x_IN(chans[4], chans[3], deconv=True, concat=True)
self.deconv16_8 = Conv2x_IN(chans[3]*2, chans[2], deconv=True, concat=True) self.deconv16_8 = Conv2x_IN(chans[3]*2, chans[2], deconv=True, concat=True)

View File

@ -60,6 +60,7 @@ class Combined_Geo_Encoding_Volume:
@staticmethod @staticmethod
def corr(fmap1, fmap2): def corr(fmap1, fmap2):
# batch, dim, ht, wd
B, D, H, W1 = fmap1.shape B, D, H, W1 = fmap1.shape
_, _, _, W2 = fmap2.shape _, _, _, W2 = fmap2.shape
fmap1 = fmap1.view(B, D, H, W1) fmap1 = fmap1.view(B, D, H, W1)

View File

@ -100,7 +100,7 @@ class IGEVStereo(nn.Module):
self.context_zqr_convs = nn.ModuleList([nn.Conv2d(context_dims[i], args.hidden_dims[i]*3, 3, padding=3//2) for i in range(self.args.n_gru_layers)]) self.context_zqr_convs = nn.ModuleList([nn.Conv2d(context_dims[i], args.hidden_dims[i]*3, 3, padding=3//2) for i in range(self.args.n_gru_layers)])
self.feature = Feature() self.feature = Feature(args.freeze_backbone_params)
self.stem_2 = nn.Sequential( self.stem_2 = nn.Sequential(
BasicConv_IN(3, 32, kernel_size=3, stride=2, padding=1), BasicConv_IN(3, 32, kernel_size=3, stride=2, padding=1),
@ -167,6 +167,7 @@ class IGEVStereo(nn.Module):
match_right = self.desc(self.conv(features_right[0])) match_right = self.desc(self.conv(features_right[0]))
gwc_volume = build_gwc_volume(match_left, match_right, self.args.max_disp//4, 8) gwc_volume = build_gwc_volume(match_left, match_right, self.args.max_disp//4, 8)
gwc_volume = self.corr_stem(gwc_volume) gwc_volume = self.corr_stem(gwc_volume)
# 3d unet
gwc_volume = self.corr_feature_att(gwc_volume, features_left[0]) gwc_volume = self.corr_feature_att(gwc_volume, features_left[0])
geo_encoding_volume = self.cost_agg(gwc_volume, features_left) geo_encoding_volume = self.cost_agg(gwc_volume, features_left)
@ -182,6 +183,7 @@ class IGEVStereo(nn.Module):
spx_pred = self.spx(xspx) spx_pred = self.spx(xspx)
spx_pred = F.softmax(spx_pred, 1) spx_pred = F.softmax(spx_pred, 1)
# Content Network
cnet_list = self.cnet(image1, num_layers=self.args.n_gru_layers) cnet_list = self.cnet(image1, num_layers=self.args.n_gru_layers)
net_list = [torch.tanh(x[0]) for x in cnet_list] net_list = [torch.tanh(x[0]) for x in cnet_list]
inp_list = [torch.relu(x[1]) for x in cnet_list] inp_list = [torch.relu(x[1]) for x in cnet_list]

View File

@ -293,7 +293,7 @@ class CREStereo(StereoDataset):
class Middlebury(StereoDataset): class Middlebury(StereoDataset):
def __init__(self, aug_params=None, root='/data/Middlebury', split='F'): def __init__(self, aug_params=None, root='/data/Middlebury', split='H'):
super(Middlebury, self).__init__(aug_params, sparse=True, reader=frame_utils.readDispMiddlebury) super(Middlebury, self).__init__(aug_params, sparse=True, reader=frame_utils.readDispMiddlebury)
assert os.path.exists(root) assert os.path.exists(root)
assert split in "FHQ" assert split in "FHQ"
@ -352,7 +352,7 @@ def fetch_dataloader(args):
train_dataset = new_dataset if train_dataset is None else train_dataset + new_dataset train_dataset = new_dataset if train_dataset is None else train_dataset + new_dataset
train_loader = data.DataLoader(train_dataset, batch_size=args.batch_size, train_loader = data.DataLoader(train_dataset, batch_size=args.batch_size,
pin_memory=True, shuffle=True, num_workers=int(os.environ.get('SLURM_CPUS_PER_TASK', 6))-2, drop_last=True) pin_memory=True, shuffle=True, num_workers=12, drop_last=True)
logging.info('Training with %d image pairs' % len(train_dataset)) logging.info('Training with %d image pairs' % len(train_dataset))
return train_loader return train_loader

View File

@ -157,6 +157,7 @@ def groupwise_correlation(fea1, fea2, num_groups):
return cost return cost
def build_gwc_volume(refimg_fea, targetimg_fea, maxdisp, num_groups): def build_gwc_volume(refimg_fea, targetimg_fea, maxdisp, num_groups):
# batch, groups, max_disp, height, width
B, C, H, W = refimg_fea.shape B, C, H, W = refimg_fea.shape
volume = refimg_fea.new_zeros([B, num_groups, maxdisp, H, W]) volume = refimg_fea.new_zeros([B, num_groups, maxdisp, H, W])
for i in range(maxdisp): for i in range(maxdisp):

View File

@ -5,7 +5,7 @@ import os
import time import time
from glob import glob from glob import glob
from skimage import color, io from skimage import color, io
from PIL import Image from PIL import Image, ImageEnhance
import cv2 import cv2
cv2.setNumThreads(0) cv2.setNumThreads(0)
@ -198,21 +198,40 @@ class SparseFlowAugmentor:
self.v_flip_prob = 0.1 self.v_flip_prob = 0.1
# photometric augmentation params # photometric augmentation params
self.photo_aug = Compose([ColorJitter(brightness=0.3, contrast=0.3, saturation=saturation_range, hue=0.3/3.14), AdjustGamma(*gamma)]) # self.photo_aug = Compose([ColorJitter(brightness=0.3, contrast=0.3, saturation=saturation_range, hue=0.3/3.14), AdjustGamma(*gamma)])
self.asymmetric_color_aug_prob = 0.2
self.eraser_aug_prob = 0.5 self.eraser_aug_prob = 0.5
def chromatic_augmentation(self, img):
random_brightness = np.random.uniform(0.8, 1.2)
random_contrast = np.random.uniform(0.8, 1.2)
random_gamma = np.random.uniform(0.8, 1.2)
img = Image.fromarray(img)
enhancer = ImageEnhance.Brightness(img)
img = enhancer.enhance(random_brightness)
enhancer = ImageEnhance.Contrast(img)
img = enhancer.enhance(random_contrast)
gamma_map = [
255 * 1.0 * pow(ele / 255.0, random_gamma) for ele in range(256)
] * 3
img = img.point(gamma_map) # use PIL's point-function to accelerate this part
img_ = np.array(img)
return img_
def color_transform(self, img1, img2): def color_transform(self, img1, img2):
image_stack = np.concatenate([img1, img2], axis=0) img1 = self.chromatic_augmentation(img1)
image_stack = np.array(self.photo_aug(Image.fromarray(image_stack)), dtype=np.uint8) img2 = self.chromatic_augmentation(img2)
img1, img2 = np.split(image_stack, 2, axis=0)
return img1, img2 return img1, img2
def eraser_transform(self, img1, img2): def eraser_transform(self, img1, img2):
ht, wd = img1.shape[:2] ht, wd = img1.shape[:2]
if np.random.rand() < self.eraser_aug_prob: if np.random.rand() < self.eraser_aug_prob:
mean_color = np.mean(img2.reshape(-1, 3), axis=0) mean_color = np.mean(img2.reshape(-1, 3), axis=0)
for _ in range(np.random.randint(1, 3)): for _ in range(1):
x0 = np.random.randint(0, wd) x0 = np.random.randint(0, wd)
y0 = np.random.randint(0, ht) y0 = np.random.randint(0, ht)
dx = np.random.randint(50, 100) dx = np.random.randint(50, 100)

View File

@ -0,0 +1,112 @@
import os
import sys
sys.path.append("..")
import numpy as np
import logging
import os
import shutil
import random
from pathlib import Path
from glob import glob
import matplotlib.pyplot as plt
from core.utils import frame_utils
def unique(lst):
return dict(zip(*np.unique(lst, return_counts=True)))
def ensure_path_exists(path):
if not os.path.exists(path):
os.makedirs(path)
class CREStereo():
def __init__(self, aug_params=None, root='/data/CREStereo'):
self.root = root
assert os.path.exists(root)
# disp_list = self.selector('_left.disp.png')
# image1_list = self.selector('_left.jpg')
# image2_list = self.selector('_right.jpg')
# assert len(image1_list) == len(image2_list) == len(disp_list) > 0
# for img1, img2, disp in zip(image1_list, image2_list, disp_list):
# # if random.randint(1, 20000) != 1:
# # continue
# self.image_list += [[img1, img2]]
# self.disparity_list += [disp]
def get_path_info(self, path):
position, filename = os.path.split(path)
root, sub_folder = os.path.split(position)
return root, sub_folder, filename
def get_new_file(self, path):
root, sub_folder, filename = self.get_path_info(path)
return os.path.join(root, 'subset', sub_folder, filename)
def divide(self, num):
ensure_path_exists(os.path.join(self.root, 'subset'))
for sub_folder in ['tree', 'shapenet', 'reflective', 'hole']:
ensure_path_exists(os.path.join(self.root, 'subset', sub_folder))
disp1_list = self.single_folder_selector(sub_folder, '_left.disp.png')
disp2_list = self.single_folder_selector(sub_folder, '_right.disp.png')
image1_list = self.single_folder_selector(sub_folder, '_left.jpg')
image2_list = self.single_folder_selector(sub_folder, '_right.jpg')
assert len(image1_list) == len(image2_list) == len(disp1_list) == len(disp2_list) > 0
lists = []
for img1, img2, disp1, disp2 in zip(image1_list, image2_list, disp1_list, disp2_list):
lists += [[img1, img2, disp1, disp2]]
subset = random.sample(lists, num)
for s in subset:
for element in s:
print(element)
print(self.get_new_file(element))
shutil.copy(element, self.get_new_file(element))
def selector(self, suffix):
files = list(glob(os.path.join(self.root, f"hole/*{suffix}")))
files += list(glob(os.path.join(self.root, f"shapenet/*{suffix}")))
files += list(glob(os.path.join(self.root, f"tree/*{suffix}")))
files += list(glob(os.path.join(self.root, f"reflective/*{suffix}")))
return sorted(files)
def single_folder_selector(self, sub_folder, suffix):
return sorted(list(glob(os.path.join(self.root, f"{sub_folder}/*{suffix}"))))
def disparity_distribution(self):
disp_lists = self.selector('_left.disp.png')
disparities = {}
for filename in disp_lists:
print(filename)
disp_gt, _ = frame_utils.readDispCREStereo(filename)
[rows, cols] = disp_gt.shape
disp_gt = (disp_gt * 32).astype(int)
cnt = unique(disp_gt)
for i in cnt:
if i in disparities:
disparities[i] += cnt[i]
else:
disparities[i] = cnt[i]
x = []
y = []
for key in disparities.keys():
x.append(key / 32)
y.append(disparities[key])
plt.scatter(x, y)
plt.show()
c = CREStereo()
c.divide(10000)

View File

@ -0,0 +1,8 @@
#!/bin/bash
for iter in {1..100}
do
iter=$((iter * 1000))
echo "These are the results of ${iter} iterations."
python evaluate_stereo.py --dataset middlebury_H --max_disp 384 --freeze_backbone_params --restore_ckpt ~/checkpoints/igev_stereo/${iter}_fix-validation.pth
done

View File

@ -20,7 +20,7 @@ def count_parameters(model):
return sum(p.numel() for p in model.parameters() if p.requires_grad) return sum(p.numel() for p in model.parameters() if p.requires_grad)
@torch.no_grad() @torch.no_grad()
def validate_eth3d(model, iters=32, mixed_prec=False): def validate_eth3d(model, iters=32, mixed_prec=False, max_disp=192):
""" Peform validation using the ETH3D (train) split """ """ Peform validation using the ETH3D (train) split """
model.eval() model.eval()
aug_params = {} aug_params = {}
@ -67,7 +67,7 @@ def validate_eth3d(model, iters=32, mixed_prec=False):
@torch.no_grad() @torch.no_grad()
def validate_kitti(model, iters=32, mixed_prec=False): def validate_kitti(model, iters=32, mixed_prec=False, max_disp=192):
""" Peform validation using the KITTI-2015 (train) split """ """ Peform validation using the KITTI-2015 (train) split """
model.eval() model.eval()
aug_params = {} aug_params = {}
@ -96,7 +96,7 @@ def validate_kitti(model, iters=32, mixed_prec=False):
epe = torch.sum((flow_pr - flow_gt)**2, dim=0).sqrt() epe = torch.sum((flow_pr - flow_gt)**2, dim=0).sqrt()
epe_flattened = epe.flatten() epe_flattened = epe.flatten()
val = (valid_gt.flatten() >= 0.5) & (flow_gt.abs().flatten() < 192) val = (valid_gt.flatten() >= 0.5) & (flow_gt.abs().flatten() < max_disp)
# val = valid_gt.flatten() >= 0.5 # val = valid_gt.flatten() >= 0.5
out = (epe_flattened > 3.0) out = (epe_flattened > 3.0)
@ -120,7 +120,7 @@ def validate_kitti(model, iters=32, mixed_prec=False):
@torch.no_grad() @torch.no_grad()
def validate_sceneflow(model, iters=32, mixed_prec=False): def validate_sceneflow(model, iters=32, mixed_prec=False, max_disp=192):
""" Peform validation using the Scene Flow (TEST) split """ """ Peform validation using the Scene Flow (TEST) split """
model.eval() model.eval()
val_dataset = datasets.SceneFlowDatasets(dstype='frames_finalpass', things_test=True) val_dataset = datasets.SceneFlowDatasets(dstype='frames_finalpass', things_test=True)
@ -144,7 +144,7 @@ def validate_sceneflow(model, iters=32, mixed_prec=False):
epe = torch.abs(flow_pr - flow_gt) epe = torch.abs(flow_pr - flow_gt)
epe = epe.flatten() epe = epe.flatten()
val = (valid_gt.flatten() >= 0.5) & (flow_gt.abs().flatten() < 192) val = (valid_gt.flatten() >= 0.5) & (flow_gt.abs().flatten() < max_disp)
if(np.isnan(epe[val].mean().item())): if(np.isnan(epe[val].mean().item())):
continue continue
@ -169,7 +169,7 @@ def validate_sceneflow(model, iters=32, mixed_prec=False):
@torch.no_grad() @torch.no_grad()
def validate_middlebury(model, iters=32, split='H', mixed_prec=False): def validate_middlebury(model, iters=32, split='H', mixed_prec=False, max_disp=192):
""" Peform validation using the Middlebury-V3 dataset """ """ Peform validation using the Middlebury-V3 dataset """
model.eval() model.eval()
aug_params = {} aug_params = {}
@ -196,11 +196,11 @@ def validate_middlebury(model, iters=32, split='H', mixed_prec=False):
occ_mask = Image.open(imageL_file.replace('im0.png', 'mask0nocc.png')).convert('L') occ_mask = Image.open(imageL_file.replace('im0.png', 'mask0nocc.png')).convert('L')
occ_mask = np.ascontiguousarray(occ_mask, dtype=np.float32).flatten() occ_mask = np.ascontiguousarray(occ_mask, dtype=np.float32).flatten()
val = (valid_gt.reshape(-1) >= 0.5) & (flow_gt[0].reshape(-1) < 192) & (occ_mask==255) val = (valid_gt.reshape(-1) >= 0.5) & (flow_gt[0].reshape(-1) < max_disp) & (occ_mask==255)
out = (epe_flattened > 2.0) out = (epe_flattened > 2.0)
image_out = out[val].float().mean().item() image_out = out[val].float().mean().item()
image_epe = epe_flattened[val].mean().item() image_epe = epe_flattened[val].mean().item()
logging.info(f"Middlebury Iter {val_id+1} out of {len(val_dataset)}. EPE {round(image_epe,4)} D1 {round(image_out,4)}") logging.info(f"Middlebury Iter {val_id+1} out of {len(val_dataset)}. EPE {round(image_epe,4)} Err2.0 {round(image_out,4)}")
epe_list.append(image_epe) epe_list.append(image_epe)
out_list.append(image_out) out_list.append(image_out)
@ -208,10 +208,10 @@ def validate_middlebury(model, iters=32, split='H', mixed_prec=False):
out_list = np.array(out_list) out_list = np.array(out_list)
epe = np.mean(epe_list) epe = np.mean(epe_list)
d1 = 100 * np.mean(out_list) err2 = 100 * np.mean(out_list)
print(f"Validation Middlebury{split}: EPE {epe}, D1 {d1}") print(f"Validation Middlebury{split}: EPE {epe}, Err2.0 {err2}")
return {f'middlebury{split}-epe': epe, f'middlebury{split}-d1': d1} return {f'middlebury{split}-epe': epe, f'middlebury{split}-err2.0': err2}
if __name__ == '__main__': if __name__ == '__main__':
@ -231,6 +231,7 @@ if __name__ == '__main__':
parser.add_argument('--slow_fast_gru', action='store_true', help="iterate the low-res GRUs more frequently") parser.add_argument('--slow_fast_gru', action='store_true', help="iterate the low-res GRUs more frequently")
parser.add_argument('--n_gru_layers', type=int, default=3, help="number of hidden GRU levels") parser.add_argument('--n_gru_layers', type=int, default=3, help="number of hidden GRU levels")
parser.add_argument('--max_disp', type=int, default=192, help="max disp of geometry encoding volume") parser.add_argument('--max_disp', type=int, default=192, help="max disp of geometry encoding volume")
parser.add_argument('--freeze_backbone_params', action="store_true", help="freeze backbone parameters")
args = parser.parse_args() args = parser.parse_args()
model = torch.nn.DataParallel(IGEVStereo(args), device_ids=[0]) model = torch.nn.DataParallel(IGEVStereo(args), device_ids=[0])
@ -242,6 +243,12 @@ if __name__ == '__main__':
assert args.restore_ckpt.endswith(".pth") assert args.restore_ckpt.endswith(".pth")
logging.info("Loading checkpoint...") logging.info("Loading checkpoint...")
checkpoint = torch.load(args.restore_ckpt) checkpoint = torch.load(args.restore_ckpt)
unwanted_prefix = '_orig_mod.'
for k, v in list(checkpoint.items()):
if k.startswith(unwanted_prefix):
checkpoint[k[len(unwanted_prefix):]] = checkpoint.pop(k)
model.load_state_dict(checkpoint, strict=True) model.load_state_dict(checkpoint, strict=True)
logging.info(f"Done loading checkpoint") logging.info(f"Done loading checkpoint")
@ -252,13 +259,13 @@ if __name__ == '__main__':
use_mixed_precision = args.corr_implementation.endswith("_cuda") use_mixed_precision = args.corr_implementation.endswith("_cuda")
if args.dataset == 'eth3d': if args.dataset == 'eth3d':
validate_eth3d(model, iters=args.valid_iters, mixed_prec=use_mixed_precision) validate_eth3d(model, iters=args.valid_iters, mixed_prec=use_mixed_precision, max_disp=args.max_disp)
elif args.dataset == 'kitti': elif args.dataset == 'kitti':
validate_kitti(model, iters=args.valid_iters, mixed_prec=use_mixed_precision) validate_kitti(model, iters=args.valid_iters, mixed_prec=use_mixed_precision, max_disp=args.max_disp)
elif args.dataset in [f"middlebury_{s}" for s in 'FHQ']: elif args.dataset in [f"middlebury_{s}" for s in 'FHQ']:
validate_middlebury(model, iters=args.valid_iters, split=args.dataset[-1], mixed_prec=use_mixed_precision) validate_middlebury(model, iters=args.valid_iters, split=args.dataset[-1], mixed_prec=use_mixed_precision, max_disp=args.max_disp)
elif args.dataset == 'sceneflow': elif args.dataset == 'sceneflow':
validate_sceneflow(model, iters=args.valid_iters, mixed_prec=use_mixed_precision) validate_sceneflow(model, iters=args.valid_iters, mixed_prec=use_mixed_precision, max_disp=args.max_disp)

View File

@ -22,7 +22,6 @@ from evaluate_stereo import *
import core.stereo_datasets as datasets import core.stereo_datasets as datasets
import torch.nn.functional as F import torch.nn.functional as F
torch.backends.cudnn.benchmark = True
ckpt_path = './checkpoints/igev_stereo' ckpt_path = './checkpoints/igev_stereo'
log_path = './checkpoints/igev_stereo' log_path = './checkpoints/igev_stereo'
@ -68,9 +67,10 @@ def sequence_loss(disp_preds, disp_init_pred, disp_gt, valid, loss_gamma=0.9, ma
metrics = { metrics = {
'epe': epe.mean().item(), 'epe': epe.mean().item(),
'1px': (epe < 1).float().mean().item(), '0.5px': (epe < 0.5).float().mean().item(),
'3px': (epe < 3).float().mean().item(), '1.0px': (epe < 1.0).float().mean().item(),
'5px': (epe < 5).float().mean().item(), '2.0px': (epe < 2.0).float().mean().item(),
'4.0px': (epe < 4.0).float().mean().item(),
} }
return disp_loss, metrics return disp_loss, metrics
@ -78,7 +78,7 @@ def sequence_loss(disp_preds, disp_init_pred, disp_gt, valid, loss_gamma=0.9, ma
def fetch_optimizer(args, model): def fetch_optimizer(args, model):
""" Create the optimizer and learning rate scheduler """ """ Create the optimizer and learning rate scheduler """
optimizer = optim.AdamW(model.parameters(), lr=args.lr, weight_decay=args.wdecay, eps=1e-8) optimizer = optim.AdamW(filter(lambda p: p.requires_grad, model.parameters()), lr=args.lr, weight_decay=args.wdecay, eps=1e-8)
# todo: cosine scheduler, warm-up # todo: cosine scheduler, warm-up
scheduler = optim.lr_scheduler.OneCycleLR(optimizer, args.lr, args.num_steps+100, scheduler = optim.lr_scheduler.OneCycleLR(optimizer, args.lr, args.num_steps+100,
@ -187,7 +187,7 @@ def train(args):
save_path = Path(ckpt_path + '/%d_%s.pth' % (total_steps + 1, args.name)) save_path = Path(ckpt_path + '/%d_%s.pth' % (total_steps + 1, args.name))
logging.info(f"Saving file {save_path.absolute()}") logging.info(f"Saving file {save_path.absolute()}")
torch.save(model.state_dict(), save_path) torch.save(model.state_dict(), save_path)
results = validate_middlebury(model.module, iters=args.valid_iters) results = validate_middlebury(model.module, iters=args.valid_iters, max_disp=args.max_disp)
logger.write_dict(results) logger.write_dict(results)
model.train() model.train()
model.module.freeze_bn() model.module.freeze_bn()
@ -222,6 +222,7 @@ if __name__ == '__main__':
parser.add_argument('--train_datasets', nargs='+', default=['sceneflow'], help="training datasets.") parser.add_argument('--train_datasets', nargs='+', default=['sceneflow'], help="training datasets.")
parser.add_argument('--lr', type=float, default=0.0002, help="max learning rate.") parser.add_argument('--lr', type=float, default=0.0002, help="max learning rate.")
parser.add_argument('--num_steps', type=int, default=200000, help="length of training schedule.") parser.add_argument('--num_steps', type=int, default=200000, help="length of training schedule.")
parser.add_argument('--freeze_backbone_params', action="store_true", help="freeze backbone parameters")
parser.add_argument('--image_size', type=int, nargs='+', default=[320, 736], help="size of the random image crops used during training.") parser.add_argument('--image_size', type=int, nargs='+', default=[320, 736], help="size of the random image crops used during training.")
parser.add_argument('--train_iters', type=int, default=22, help="number of updates to the disparity field in each forward pass.") parser.add_argument('--train_iters', type=int, default=22, help="number of updates to the disparity field in each forward pass.")
parser.add_argument('--wdecay', type=float, default=.00001, help="Weight decay in optimizer.") parser.add_argument('--wdecay', type=float, default=.00001, help="Weight decay in optimizer.")
@ -241,8 +242,8 @@ if __name__ == '__main__':
parser.add_argument('--max_disp', type=int, default=192, help="max disp of geometry encoding volume") parser.add_argument('--max_disp', type=int, default=192, help="max disp of geometry encoding volume")
# Data augmentation # Data augmentation
parser.add_argument('--img_gamma', type=float, nargs='+', default=None, help="gamma range") # parser.add_argument('--img_gamma', type=float, nargs='+', default=None, help="gamma range")
parser.add_argument('--saturation_range', type=float, nargs='+', default=[0, 1.4], help='color saturation') # parser.add_argument('--saturation_range', type=float, nargs='+', default=[0, 1.4], help='color saturation')
parser.add_argument('--do_flip', default=False, choices=['h', 'v'], help='flip the images horizontally or vertically') parser.add_argument('--do_flip', default=False, choices=['h', 'v'], help='flip the images horizontally or vertically')
parser.add_argument('--spatial_scale', type=float, nargs='+', default=[-0.2, 0.4], help='re-scale the images randomly') parser.add_argument('--spatial_scale', type=float, nargs='+', default=[-0.2, 0.4], help='re-scale the images randomly')
parser.add_argument('--noyjitter', action='store_true', help='don\'t simulate imperfect rectification') parser.add_argument('--noyjitter', action='store_true', help='don\'t simulate imperfect rectification')