wtf?
This commit is contained in:
parent
0a3613711b
commit
1080b823a5
@ -182,6 +182,7 @@ class IGEVStereo(nn.Module):
|
|||||||
spx_pred = self.spx(xspx)
|
spx_pred = self.spx(xspx)
|
||||||
spx_pred = F.softmax(spx_pred, 1)
|
spx_pred = F.softmax(spx_pred, 1)
|
||||||
|
|
||||||
|
# Content Network
|
||||||
cnet_list = self.cnet(image1, num_layers=self.args.n_gru_layers)
|
cnet_list = self.cnet(image1, num_layers=self.args.n_gru_layers)
|
||||||
net_list = [torch.tanh(x[0]) for x in cnet_list]
|
net_list = [torch.tanh(x[0]) for x in cnet_list]
|
||||||
inp_list = [torch.relu(x[1]) for x in cnet_list]
|
inp_list = [torch.relu(x[1]) for x in cnet_list]
|
||||||
|
@ -20,7 +20,7 @@ def count_parameters(model):
|
|||||||
return sum(p.numel() for p in model.parameters() if p.requires_grad)
|
return sum(p.numel() for p in model.parameters() if p.requires_grad)
|
||||||
|
|
||||||
@torch.no_grad()
|
@torch.no_grad()
|
||||||
def validate_eth3d(model, iters=32, mixed_prec=False):
|
def validate_eth3d(model, iters=32, mixed_prec=False, max_disp=192):
|
||||||
""" Peform validation using the ETH3D (train) split """
|
""" Peform validation using the ETH3D (train) split """
|
||||||
model.eval()
|
model.eval()
|
||||||
aug_params = {}
|
aug_params = {}
|
||||||
@ -67,7 +67,7 @@ def validate_eth3d(model, iters=32, mixed_prec=False):
|
|||||||
|
|
||||||
|
|
||||||
@torch.no_grad()
|
@torch.no_grad()
|
||||||
def validate_kitti(model, iters=32, mixed_prec=False):
|
def validate_kitti(model, iters=32, mixed_prec=False, max_disp=192):
|
||||||
""" Peform validation using the KITTI-2015 (train) split """
|
""" Peform validation using the KITTI-2015 (train) split """
|
||||||
model.eval()
|
model.eval()
|
||||||
aug_params = {}
|
aug_params = {}
|
||||||
@ -96,7 +96,7 @@ def validate_kitti(model, iters=32, mixed_prec=False):
|
|||||||
epe = torch.sum((flow_pr - flow_gt)**2, dim=0).sqrt()
|
epe = torch.sum((flow_pr - flow_gt)**2, dim=0).sqrt()
|
||||||
|
|
||||||
epe_flattened = epe.flatten()
|
epe_flattened = epe.flatten()
|
||||||
val = (valid_gt.flatten() >= 0.5) & (flow_gt.abs().flatten() < 192)
|
val = (valid_gt.flatten() >= 0.5) & (flow_gt.abs().flatten() < max_disp)
|
||||||
# val = valid_gt.flatten() >= 0.5
|
# val = valid_gt.flatten() >= 0.5
|
||||||
|
|
||||||
out = (epe_flattened > 3.0)
|
out = (epe_flattened > 3.0)
|
||||||
@ -120,7 +120,7 @@ def validate_kitti(model, iters=32, mixed_prec=False):
|
|||||||
|
|
||||||
|
|
||||||
@torch.no_grad()
|
@torch.no_grad()
|
||||||
def validate_sceneflow(model, iters=32, mixed_prec=False):
|
def validate_sceneflow(model, iters=32, mixed_prec=False, max_disp=192):
|
||||||
""" Peform validation using the Scene Flow (TEST) split """
|
""" Peform validation using the Scene Flow (TEST) split """
|
||||||
model.eval()
|
model.eval()
|
||||||
val_dataset = datasets.SceneFlowDatasets(dstype='frames_finalpass', things_test=True)
|
val_dataset = datasets.SceneFlowDatasets(dstype='frames_finalpass', things_test=True)
|
||||||
@ -144,7 +144,7 @@ def validate_sceneflow(model, iters=32, mixed_prec=False):
|
|||||||
epe = torch.abs(flow_pr - flow_gt)
|
epe = torch.abs(flow_pr - flow_gt)
|
||||||
|
|
||||||
epe = epe.flatten()
|
epe = epe.flatten()
|
||||||
val = (valid_gt.flatten() >= 0.5) & (flow_gt.abs().flatten() < 192)
|
val = (valid_gt.flatten() >= 0.5) & (flow_gt.abs().flatten() < max_disp)
|
||||||
|
|
||||||
if(np.isnan(epe[val].mean().item())):
|
if(np.isnan(epe[val].mean().item())):
|
||||||
continue
|
continue
|
||||||
@ -169,7 +169,7 @@ def validate_sceneflow(model, iters=32, mixed_prec=False):
|
|||||||
|
|
||||||
|
|
||||||
@torch.no_grad()
|
@torch.no_grad()
|
||||||
def validate_middlebury(model, iters=32, split='F', mixed_prec=False):
|
def validate_middlebury(model, iters=32, split='F', mixed_prec=False, max_disp=192):
|
||||||
""" Peform validation using the Middlebury-V3 dataset """
|
""" Peform validation using the Middlebury-V3 dataset """
|
||||||
model.eval()
|
model.eval()
|
||||||
aug_params = {}
|
aug_params = {}
|
||||||
@ -196,7 +196,7 @@ def validate_middlebury(model, iters=32, split='F', mixed_prec=False):
|
|||||||
occ_mask = Image.open(imageL_file.replace('im0.png', 'mask0nocc.png')).convert('L')
|
occ_mask = Image.open(imageL_file.replace('im0.png', 'mask0nocc.png')).convert('L')
|
||||||
occ_mask = np.ascontiguousarray(occ_mask, dtype=np.float32).flatten()
|
occ_mask = np.ascontiguousarray(occ_mask, dtype=np.float32).flatten()
|
||||||
|
|
||||||
val = (valid_gt.reshape(-1) >= 0.5) & (flow_gt[0].reshape(-1) < 192) & (occ_mask==255)
|
val = (valid_gt.reshape(-1) >= 0.5) & (flow_gt[0].reshape(-1) < max_disp) & (occ_mask==255)
|
||||||
out = (epe_flattened > 2.0)
|
out = (epe_flattened > 2.0)
|
||||||
image_out = out[val].float().mean().item()
|
image_out = out[val].float().mean().item()
|
||||||
image_epe = epe_flattened[val].mean().item()
|
image_epe = epe_flattened[val].mean().item()
|
||||||
@ -252,10 +252,10 @@ if __name__ == '__main__':
|
|||||||
use_mixed_precision = args.corr_implementation.endswith("_cuda")
|
use_mixed_precision = args.corr_implementation.endswith("_cuda")
|
||||||
|
|
||||||
if args.dataset == 'eth3d':
|
if args.dataset == 'eth3d':
|
||||||
validate_eth3d(model, iters=args.valid_iters, mixed_prec=use_mixed_precision)
|
validate_eth3d(model, iters=args.valid_iters, mixed_prec=use_mixed_precision, max_disp=args.max_disp)
|
||||||
|
|
||||||
elif args.dataset == 'kitti':
|
elif args.dataset == 'kitti':
|
||||||
validate_kitti(model, iters=args.valid_iters, mixed_prec=use_mixed_precision)
|
validate_kitti(model, iters=args.valid_iters, mixed_prec=use_mixed_precision, max_disp=args.max_disp)
|
||||||
|
|
||||||
elif args.dataset in [f"middlebury_{s}" for s in 'FHQ']:
|
elif args.dataset in [f"middlebury_{s}" for s in 'FHQ']:
|
||||||
validate_middlebury(model, iters=args.valid_iters, split=args.dataset[-1], mixed_prec=use_mixed_precision)
|
validate_middlebury(model, iters=args.valid_iters, split=args.dataset[-1], mixed_prec=use_mixed_precision)
|
||||||
|
@ -187,7 +187,7 @@ def train(args):
|
|||||||
save_path = Path(ckpt_path + '/%d_%s.pth' % (total_steps + 1, args.name))
|
save_path = Path(ckpt_path + '/%d_%s.pth' % (total_steps + 1, args.name))
|
||||||
logging.info(f"Saving file {save_path.absolute()}")
|
logging.info(f"Saving file {save_path.absolute()}")
|
||||||
torch.save(model.state_dict(), save_path)
|
torch.save(model.state_dict(), save_path)
|
||||||
results = validate_middlebury(model.module, iters=args.valid_iters)
|
results = validate_middlebury(model.module, iters=args.valid_iters, max_disp=args.max_disp)
|
||||||
logger.write_dict(results)
|
logger.write_dict(results)
|
||||||
model.train()
|
model.train()
|
||||||
model.module.freeze_bn()
|
model.module.freeze_bn()
|
||||||
|
Loading…
Reference in New Issue
Block a user