Added CREStereo datasets and fixed some bugs
This commit is contained in:
parent
93371a3c4d
commit
a301ebb020
@ -165,7 +165,7 @@ class IGEVStereo(nn.Module):
|
|||||||
|
|
||||||
match_left = self.desc(self.conv(features_left[0]))
|
match_left = self.desc(self.conv(features_left[0]))
|
||||||
match_right = self.desc(self.conv(features_right[0]))
|
match_right = self.desc(self.conv(features_right[0]))
|
||||||
gwc_volume = build_gwc_volume(match_left, match_right, 192//4, 8)
|
gwc_volume = build_gwc_volume(match_left, match_right, self.args.max_disp//4, 8)
|
||||||
gwc_volume = self.corr_stem(gwc_volume)
|
gwc_volume = self.corr_stem(gwc_volume)
|
||||||
gwc_volume = self.corr_feature_att(gwc_volume, features_left[0])
|
gwc_volume = self.corr_feature_att(gwc_volume, features_left[0])
|
||||||
geo_encoding_volume = self.cost_agg(gwc_volume, features_left)
|
geo_encoding_volume = self.cost_agg(gwc_volume, features_left)
|
||||||
|
@ -73,6 +73,7 @@ class StereoDataset(data.Dataset):
|
|||||||
img2 = np.array(img2).astype(np.uint8)
|
img2 = np.array(img2).astype(np.uint8)
|
||||||
|
|
||||||
disp = np.array(disp).astype(np.float32)
|
disp = np.array(disp).astype(np.float32)
|
||||||
|
assert not (True in np.isnan(disp))
|
||||||
|
|
||||||
flow = np.stack([disp, np.zeros_like(disp)], axis=-1)
|
flow = np.stack([disp, np.zeros_like(disp)], axis=-1)
|
||||||
|
|
||||||
@ -266,6 +267,31 @@ class KITTI(StereoDataset):
|
|||||||
self.image_list += [ [img1, img2] ]
|
self.image_list += [ [img1, img2] ]
|
||||||
self.disparity_list += [ disp ]
|
self.disparity_list += [ disp ]
|
||||||
|
|
||||||
|
class CREStereo(StereoDataset):
    """CREStereo synthetic stereo dataset (hole / shapenet / tree / reflective scenes).

    Expects ``*_left.jpg`` / ``*_right.jpg`` image pairs with matching
    ``*_left.disp.png`` disparity maps under ``root``. Disparities are sparse
    and decoded by ``frame_utils.readDispCREStereo``.
    """

    def __init__(self, aug_params=None, root='/data/CREStereo'):
        super(CREStereo, self).__init__(aug_params, sparse=True, reader=frame_utils.readDispCREStereo)
        self.root = root
        # Fail fast with a clear error; a bare `assert` is stripped under `python -O`.
        if not os.path.exists(root):
            raise FileNotFoundError(f"CREStereo dataset root not found: {root}")

        disp_list = self.selector('_left.disp.png')
        image1_list = self.selector('_left.jpg')
        image2_list = self.selector('_right.jpg')

        # selector() returns sorted lists, so the three lists stay aligned
        # sample-by-sample as long as every sample has all three files.
        assert len(image1_list) == len(image2_list) == len(disp_list) > 0

        for img1, img2, disp in zip(image1_list, image2_list, disp_list):
            self.image_list += [[img1, img2]]
            self.disparity_list += [disp]

    def selector(self, suffix):
        """Return the sorted list of files ending in *suffix* across all four scene folders."""
        files = []
        for subset in ('hole', 'shapenet', 'tree', 'reflective'):
            files += list(glob(os.path.join(self.root, f"{subset}/*{suffix}")))
        return sorted(files)
|
|
||||||
|
|
||||||
|
|
||||||
class Middlebury(StereoDataset):
|
class Middlebury(StereoDataset):
|
||||||
def __init__(self, aug_params=None, root='/data/Middlebury', split='F'):
|
def __init__(self, aug_params=None, root='/data/Middlebury', split='F'):
|
||||||
@ -321,10 +347,13 @@ def fetch_dataloader(args):
|
|||||||
elif dataset_name.startswith('tartan_air'):
|
elif dataset_name.startswith('tartan_air'):
|
||||||
new_dataset = TartanAir(aug_params, keywords=dataset_name.split('_')[2:])
|
new_dataset = TartanAir(aug_params, keywords=dataset_name.split('_')[2:])
|
||||||
logging.info(f"Adding {len(new_dataset)} samples from Tartain Air")
|
logging.info(f"Adding {len(new_dataset)} samples from Tartain Air")
|
||||||
|
elif dataset_name.startswith('crestereo'):
|
||||||
|
new_dataset = CREStereo(aug_params)
|
||||||
|
logging.info(f"Adding {len(new_dataset)} samples from CREStereo")
|
||||||
train_dataset = new_dataset if train_dataset is None else train_dataset + new_dataset
|
train_dataset = new_dataset if train_dataset is None else train_dataset + new_dataset
|
||||||
|
|
||||||
train_loader = data.DataLoader(train_dataset, batch_size=args.batch_size,
|
train_loader = data.DataLoader(train_dataset, batch_size=args.batch_size,
|
||||||
pin_memory=True, shuffle=True, num_workers=int(os.environ.get('SLURM_CPUS_PER_TASK', 6))-2, drop_last=True)
|
pin_memory=False, shuffle=True, num_workers=int(os.environ.get('SLURM_CPUS_PER_TASK', 6))-2, drop_last=True)
|
||||||
|
|
||||||
logging.info('Training with %d image pairs' % len(train_dataset))
|
logging.info('Training with %d image pairs' % len(train_dataset))
|
||||||
return train_loader
|
return train_loader
|
||||||
|
@ -152,6 +152,11 @@ def readDispTartanAir(file_name):
|
|||||||
valid = disp > 0
|
valid = disp > 0
|
||||||
return disp, valid
|
return disp, valid
|
||||||
|
|
||||||
|
def readDispCREStereo(file_name):
    """Read a CREStereo disparity map from a 16-bit PNG.

    CREStereo stores disparity scaled by 32, so the raw integer image is
    divided by 32.0 to recover float disparity in pixels.

    Args:
        file_name: path to a ``*_left.disp.png`` file.

    Returns:
        (disp, valid): float32 disparity array and a boolean mask of
        pixels with a positive (i.e. known) disparity.

    Raises:
        FileNotFoundError: if the file is missing or cannot be decoded
        (``cv2.imread`` signals failure by returning ``None``).
    """
    disp = cv2.imread(file_name, cv2.IMREAD_UNCHANGED)
    if disp is None:
        raise FileNotFoundError(f"Cannot read disparity file: {file_name}")
    disp = disp.astype(np.float32) / 32.0
    valid = disp > 0.0
    return disp, valid
|
||||||
|
|
||||||
def readDispMiddlebury(file_name):
|
def readDispMiddlebury(file_name):
|
||||||
assert basename(file_name) == 'disp0GT.pfm'
|
assert basename(file_name) == 'disp0GT.pfm'
|
||||||
|
@ -26,6 +26,7 @@ def demo(args):
|
|||||||
model.load_state_dict(torch.load(args.restore_ckpt))
|
model.load_state_dict(torch.load(args.restore_ckpt))
|
||||||
|
|
||||||
model = model.module
|
model = model.module
|
||||||
|
# model = torch.compile(model)
|
||||||
model.to(DEVICE)
|
model.to(DEVICE)
|
||||||
model.eval()
|
model.eval()
|
||||||
|
|
||||||
|
@ -1,6 +1,7 @@
|
|||||||
|
|
||||||
from __future__ import print_function, division
|
from __future__ import print_function, division
|
||||||
|
|
||||||
|
import math
|
||||||
import os
|
import os
|
||||||
os.environ['CUDA_VISIBLE_DEVICES'] = '0, 1'
|
os.environ['CUDA_VISIBLE_DEVICES'] = '0, 1'
|
||||||
|
|
||||||
@ -53,7 +54,7 @@ def sequence_loss(disp_preds, disp_init_pred, disp_gt, valid, loss_gamma=0.9, ma
|
|||||||
assert not torch.isinf(disp_gt[valid.bool()]).any()
|
assert not torch.isinf(disp_gt[valid.bool()]).any()
|
||||||
|
|
||||||
|
|
||||||
disp_loss += 1.0 * F.smooth_l1_loss(disp_init_pred[valid.bool()], disp_gt[valid.bool()], size_average=True)
|
disp_loss += 1.0 * F.smooth_l1_loss(disp_init_pred[valid.bool()], disp_gt[valid.bool()], reduction='mean')
|
||||||
for i in range(n_predictions):
|
for i in range(n_predictions):
|
||||||
adjusted_loss_gamma = loss_gamma**(15/(n_predictions - 1))
|
adjusted_loss_gamma = loss_gamma**(15/(n_predictions - 1))
|
||||||
i_weight = adjusted_loss_gamma**(n_predictions - i - 1)
|
i_weight = adjusted_loss_gamma**(n_predictions - i - 1)
|
||||||
@ -78,6 +79,7 @@ def fetch_optimizer(args, model):
|
|||||||
""" Create the optimizer and learning rate scheduler """
|
""" Create the optimizer and learning rate scheduler """
|
||||||
optimizer = optim.AdamW(model.parameters(), lr=args.lr, weight_decay=args.wdecay, eps=1e-8)
|
optimizer = optim.AdamW(model.parameters(), lr=args.lr, weight_decay=args.wdecay, eps=1e-8)
|
||||||
|
|
||||||
|
# todo: cosine scheduler, warm-up
|
||||||
scheduler = optim.lr_scheduler.OneCycleLR(optimizer, args.lr, args.num_steps+100,
|
scheduler = optim.lr_scheduler.OneCycleLR(optimizer, args.lr, args.num_steps+100,
|
||||||
pct_start=0.01, cycle_momentum=False, anneal_strategy='linear')
|
pct_start=0.01, cycle_momentum=False, anneal_strategy='linear')
|
||||||
return optimizer, scheduler
|
return optimizer, scheduler
|
||||||
@ -151,14 +153,13 @@ def train(args):
|
|||||||
model.train()
|
model.train()
|
||||||
model.module.freeze_bn() # We keep BatchNorm frozen
|
model.module.freeze_bn() # We keep BatchNorm frozen
|
||||||
|
|
||||||
validation_frequency = 10000
|
validation_frequency = 1000
|
||||||
|
|
||||||
scaler = GradScaler(enabled=args.mixed_precision)
|
scaler = GradScaler(enabled=args.mixed_precision)
|
||||||
|
|
||||||
should_keep_training = True
|
should_keep_training = True
|
||||||
global_batch_num = 0
|
global_batch_num = 0
|
||||||
while should_keep_training:
|
while should_keep_training:
|
||||||
|
|
||||||
for i_batch, (_, *data_blob) in enumerate(tqdm(train_loader)):
|
for i_batch, (_, *data_blob) in enumerate(tqdm(train_loader)):
|
||||||
optimizer.zero_grad()
|
optimizer.zero_grad()
|
||||||
image1, image2, disp_gt, valid = [x.cuda() for x in data_blob]
|
image1, image2, disp_gt, valid = [x.cuda() for x in data_blob]
|
||||||
@ -175,6 +176,7 @@ def train(args):
|
|||||||
scaler.unscale_(optimizer)
|
scaler.unscale_(optimizer)
|
||||||
torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
|
torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
|
||||||
|
|
||||||
|
#warning
|
||||||
scaler.step(optimizer)
|
scaler.step(optimizer)
|
||||||
scheduler.step()
|
scheduler.step()
|
||||||
scaler.update()
|
scaler.update()
|
||||||
@ -184,7 +186,7 @@ def train(args):
|
|||||||
save_path = Path(ckpt_path + '/%d_%s.pth' % (total_steps + 1, args.name))
|
save_path = Path(ckpt_path + '/%d_%s.pth' % (total_steps + 1, args.name))
|
||||||
logging.info(f"Saving file {save_path.absolute()}")
|
logging.info(f"Saving file {save_path.absolute()}")
|
||||||
torch.save(model.state_dict(), save_path)
|
torch.save(model.state_dict(), save_path)
|
||||||
results = validate_sceneflow(model.module, iters=args.valid_iters)
|
results = validate_middlebury(model.module, iters=args.valid_iters)
|
||||||
logger.write_dict(results)
|
logger.write_dict(results)
|
||||||
model.train()
|
model.train()
|
||||||
model.module.freeze_bn()
|
model.module.freeze_bn()
|
||||||
@ -212,7 +214,7 @@ if __name__ == '__main__':
|
|||||||
parser = argparse.ArgumentParser()
|
parser = argparse.ArgumentParser()
|
||||||
parser.add_argument('--name', default='igev-stereo', help="name your experiment")
|
parser.add_argument('--name', default='igev-stereo', help="name your experiment")
|
||||||
parser.add_argument('--restore_ckpt', default=None, help="")
|
parser.add_argument('--restore_ckpt', default=None, help="")
|
||||||
parser.add_argument('--mixed_precision', default=True, action='store_true', help='use mixed precision')
|
parser.add_argument('--mixed_precision', default=False, action='store_true', help='use mixed precision')
|
||||||
|
|
||||||
# Training parameters
|
# Training parameters
|
||||||
parser.add_argument('--batch_size', type=int, default=8, help="batch size used during training.")
|
parser.add_argument('--batch_size', type=int, default=8, help="batch size used during training.")
|
||||||
|
Loading…
Reference in New Issue
Block a user