commit 1080b823a5
parent 0a3613711b
Author: HTensor
Date:   2023-04-26 19:50:17 +08:00

3 changed files with 11 additions and 10 deletions

View File

@@ -182,6 +182,7 @@ class IGEVStereo(nn.Module):
             spx_pred = self.spx(xspx)
             spx_pred = F.softmax(spx_pred, 1)

+            # Content Network
             cnet_list = self.cnet(image1, num_layers=self.args.n_gru_layers)
             net_list = [torch.tanh(x[0]) for x in cnet_list]
             inp_list = [torch.relu(x[1]) for x in cnet_list]

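For orientation, the hunk above only inserts a comment: cnet_list is the context network's output, a list of per-scale feature pairs whose first element is squashed with tanh to initialize the GRU hidden states (net_list) and whose second element passes through relu to become the context inputs (inp_list). A minimal, self-contained sketch of that contract; the stub network and tensor shapes are hypothetical stand-ins, not IGEV-Stereo's actual ones:

```python
import torch

# Hypothetical stand-in for self.cnet(image1, num_layers=...): one
# (hidden, context) feature pair per GRU scale, halving resolution per level.
n_gru_layers = 3
cnet_list = [(torch.randn(1, 128, 40 // 2**i, 80 // 2**i),
              torch.randn(1, 128, 40 // 2**i, 80 // 2**i))
             for i in range(n_gru_layers)]

net_list = [torch.tanh(x[0]) for x in cnet_list]  # hidden states, squashed to (-1, 1)
inp_list = [torch.relu(x[1]) for x in cnet_list]  # context features, clamped to >= 0
```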
View File

@@ -20,7 +20,7 @@ def count_parameters(model):
     return sum(p.numel() for p in model.parameters() if p.requires_grad)

 @torch.no_grad()
-def validate_eth3d(model, iters=32, mixed_prec=False):
+def validate_eth3d(model, iters=32, mixed_prec=False, max_disp=192):
     """ Perform validation using the ETH3D (train) split """
     model.eval()
     aug_params = {}
@@ -67,7 +67,7 @@ def validate_eth3d(model, iters=32, mixed_prec=False):

 @torch.no_grad()
-def validate_kitti(model, iters=32, mixed_prec=False):
+def validate_kitti(model, iters=32, mixed_prec=False, max_disp=192):
     """ Perform validation using the KITTI-2015 (train) split """
     model.eval()
     aug_params = {}
@@ -96,7 +96,7 @@ def validate_kitti(model, iters=32, mixed_prec=False):
         epe = torch.sum((flow_pr - flow_gt)**2, dim=0).sqrt()

         epe_flattened = epe.flatten()
-        val = (valid_gt.flatten() >= 0.5) & (flow_gt.abs().flatten() < 192)
+        val = (valid_gt.flatten() >= 0.5) & (flow_gt.abs().flatten() < max_disp)
         # val = valid_gt.flatten() >= 0.5

         out = (epe_flattened > 3.0)
@@ -120,7 +120,7 @@ def validate_kitti(model, iters=32, mixed_prec=False):

 @torch.no_grad()
-def validate_sceneflow(model, iters=32, mixed_prec=False):
+def validate_sceneflow(model, iters=32, mixed_prec=False, max_disp=192):
     """ Perform validation using the Scene Flow (TEST) split """
     model.eval()
     val_dataset = datasets.SceneFlowDatasets(dstype='frames_finalpass', things_test=True)
@@ -144,7 +144,7 @@ def validate_sceneflow(model, iters=32, mixed_prec=False):
         epe = torch.abs(flow_pr - flow_gt)

         epe = epe.flatten()
-        val = (valid_gt.flatten() >= 0.5) & (flow_gt.abs().flatten() < 192)
+        val = (valid_gt.flatten() >= 0.5) & (flow_gt.abs().flatten() < max_disp)

         if(np.isnan(epe[val].mean().item())):
             continue
@@ -169,7 +169,7 @@ def validate_sceneflow(model, iters=32, mixed_prec=False):

 @torch.no_grad()
-def validate_middlebury(model, iters=32, split='F', mixed_prec=False):
+def validate_middlebury(model, iters=32, split='F', mixed_prec=False, max_disp=192):
     """ Perform validation using the Middlebury-V3 dataset """
     model.eval()
     aug_params = {}
@@ -196,7 +196,7 @@ def validate_middlebury(model, iters=32, split='F', mixed_prec=False):
         occ_mask = Image.open(imageL_file.replace('im0.png', 'mask0nocc.png')).convert('L')
         occ_mask = np.ascontiguousarray(occ_mask, dtype=np.float32).flatten()

-        val = (valid_gt.reshape(-1) >= 0.5) & (flow_gt[0].reshape(-1) < 192) & (occ_mask==255)
+        val = (valid_gt.reshape(-1) >= 0.5) & (flow_gt[0].reshape(-1) < max_disp) & (occ_mask==255)
         out = (epe_flattened > 2.0)
         image_out = out[val].float().mean().item()
         image_epe = epe_flattened[val].mean().item()
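An aside on the pattern these hunks keep touching: each validator builds a boolean val mask selecting pixels that have valid ground truth, fall below the disparity cap, and (for Middlebury) are non-occluded, then reports the mean end-point error and a bad-pixel rate over that mask only. A toy, self-contained illustration of the mask-and-metric pattern; the tensors are made up, and the 2.0 px threshold matches the Middlebury hunk above:

```python
import torch

max_disp = 192
flow_gt = torch.tensor([10.0, 50.0, 250.0, 30.0])   # ground-truth disparities
flow_pr = torch.tensor([11.0, 58.0, 240.0, 30.5])   # predictions
valid_gt = torch.tensor([1.0, 1.0, 1.0, 0.0])       # last pixel has no ground truth

epe = (flow_pr - flow_gt).abs()                      # [1.0, 8.0, 10.0, 0.5]
val = (valid_gt >= 0.5) & (flow_gt.abs() < max_disp) # drops pixel 2 (250 >= 192) and pixel 3

out = epe > 2.0                                # "bad" pixels: error above 2 px
print(epe[val].mean().item())                  # mean EPE over evaluated pixels -> 4.5
print(100 * out[val].float().mean().item())    # bad-2.0 rate in percent -> 50.0
```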
@@ -252,10 +252,10 @@ if __name__ == '__main__':
     use_mixed_precision = args.corr_implementation.endswith("_cuda")

     if args.dataset == 'eth3d':
-        validate_eth3d(model, iters=args.valid_iters, mixed_prec=use_mixed_precision)
+        validate_eth3d(model, iters=args.valid_iters, mixed_prec=use_mixed_precision, max_disp=args.max_disp)
     elif args.dataset == 'kitti':
-        validate_kitti(model, iters=args.valid_iters, mixed_prec=use_mixed_precision)
+        validate_kitti(model, iters=args.valid_iters, mixed_prec=use_mixed_precision, max_disp=args.max_disp)
     elif args.dataset in [f"middlebury_{s}" for s in 'FHQ']:
         validate_middlebury(model, iters=args.valid_iters, split=args.dataset[-1], mixed_prec=use_mixed_precision)
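Note that the calls above forward args.max_disp, so the script's argument parser presumably gains a matching flag in a hunk not shown here. A hedged sketch of what such a flag could look like; the flag name follows from args.max_disp, but the help text and placement are assumptions:

```python
import argparse

parser = argparse.ArgumentParser()
# Assumed flag implied by args.max_disp; the default of 192 mirrors the
# hard-coded constant this commit replaces.
parser.add_argument('--max_disp', type=int, default=192,
                    help='exclude ground-truth disparities at or above this value from evaluation')

args = parser.parse_args(['--max_disp', '256'])
print(args.max_disp)  # -> 256
```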

View File

@@ -187,7 +187,7 @@ def train(args):
                 save_path = Path(ckpt_path + '/%d_%s.pth' % (total_steps + 1, args.name))
                 logging.info(f"Saving file {save_path.absolute()}")
                 torch.save(model.state_dict(), save_path)
-                results = validate_middlebury(model.module, iters=args.valid_iters)
+                results = validate_middlebury(model.module, iters=args.valid_iters, max_disp=args.max_disp)
                 logger.write_dict(results)
                 model.train()
                 model.module.freeze_bn()
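A closing observation: every new signature defaults max_disp to 192, the previously hard-coded constant, so call sites that were not updated, such as the validate_middlebury branch in the evaluation script's __main__ block above, keep their old behavior. A toy stand-in showing why:

```python
# Toy stand-in for the updated validators: the max_disp=192 default keeps
# old-style calls (no keyword) behaving exactly as before this commit.
def validate(model=None, iters=32, mixed_prec=False, max_disp=192):
    return {'max_disp_used': max_disp}

print(validate())              # old call site -> {'max_disp_used': 192}
print(validate(max_disp=256))  # updated call site -> {'max_disp_used': 256}
```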