Compare commits

...

7 Commits

Author SHA1 Message Date
HTensor
74bc06d58e Merge branch 'original-test' 2023-05-03 04:30:45 +08:00
HTensor
0404f5b5b0 set 12 threads to prepare datasets 2023-05-03 04:14:05 +08:00
HTensor
3b68318ed5 disabled cudnn benchmark 2023-05-03 04:13:28 +08:00
HTensor
a7d89bd95c change resolution of middlebury test images 2023-05-02 13:39:36 +08:00
HTensor
e7033dabf9 added comments 2023-05-02 01:19:05 +08:00
HTensor
9df896bb70 Update stereo_datasets.py 2023-05-01 12:25:21 +08:00
HTensor
73e65f99b8 Update evaluate_stereo.py 2023-04-30 16:03:58 +08:00
5 changed files with 5 additions and 5 deletions

View File

@ -60,6 +60,7 @@ class Combined_Geo_Encoding_Volume:
@staticmethod @staticmethod
def corr(fmap1, fmap2): def corr(fmap1, fmap2):
# batch, dim, ht, wd
B, D, H, W1 = fmap1.shape B, D, H, W1 = fmap1.shape
_, _, _, W2 = fmap2.shape _, _, _, W2 = fmap2.shape
fmap1 = fmap1.view(B, D, H, W1) fmap1 = fmap1.view(B, D, H, W1)

View File

@ -167,6 +167,7 @@ class IGEVStereo(nn.Module):
match_right = self.desc(self.conv(features_right[0])) match_right = self.desc(self.conv(features_right[0]))
gwc_volume = build_gwc_volume(match_left, match_right, self.args.max_disp//4, 8) gwc_volume = build_gwc_volume(match_left, match_right, self.args.max_disp//4, 8)
gwc_volume = self.corr_stem(gwc_volume) gwc_volume = self.corr_stem(gwc_volume)
# 3d unet
gwc_volume = self.corr_feature_att(gwc_volume, features_left[0]) gwc_volume = self.corr_feature_att(gwc_volume, features_left[0])
geo_encoding_volume = self.cost_agg(gwc_volume, features_left) geo_encoding_volume = self.cost_agg(gwc_volume, features_left)

View File

@ -73,7 +73,6 @@ class StereoDataset(data.Dataset):
img2 = np.array(img2).astype(np.uint8) img2 = np.array(img2).astype(np.uint8)
disp = np.array(disp).astype(np.float32) disp = np.array(disp).astype(np.float32)
assert not (True in np.isnan(disp))
flow = np.stack([disp, np.zeros_like(disp)], axis=-1) flow = np.stack([disp, np.zeros_like(disp)], axis=-1)
@ -294,7 +293,7 @@ class CREStereo(StereoDataset):
class Middlebury(StereoDataset): class Middlebury(StereoDataset):
def __init__(self, aug_params=None, root='/data/Middlebury', split='F'): def __init__(self, aug_params=None, root='/data/Middlebury', split='H'):
super(Middlebury, self).__init__(aug_params, sparse=True, reader=frame_utils.readDispMiddlebury) super(Middlebury, self).__init__(aug_params, sparse=True, reader=frame_utils.readDispMiddlebury)
assert os.path.exists(root) assert os.path.exists(root)
assert split in "FHQ" assert split in "FHQ"
@ -353,7 +352,7 @@ def fetch_dataloader(args):
train_dataset = new_dataset if train_dataset is None else train_dataset + new_dataset train_dataset = new_dataset if train_dataset is None else train_dataset + new_dataset
train_loader = data.DataLoader(train_dataset, batch_size=args.batch_size, train_loader = data.DataLoader(train_dataset, batch_size=args.batch_size,
pin_memory=True, shuffle=True, num_workers=int(os.environ.get('SLURM_CPUS_PER_TASK', 6))-2, drop_last=True) pin_memory=True, shuffle=True, num_workers=12, drop_last=True)
logging.info('Training with %d image pairs' % len(train_dataset)) logging.info('Training with %d image pairs' % len(train_dataset))
return train_loader return train_loader

View File

@ -169,7 +169,7 @@ def validate_sceneflow(model, iters=32, mixed_prec=False, max_disp=192):
@torch.no_grad() @torch.no_grad()
def validate_middlebury(model, iters=32, split='F', mixed_prec=False, max_disp=192): def validate_middlebury(model, iters=32, split='H', mixed_prec=False, max_disp=192):
""" Peform validation using the Middlebury-V3 dataset """ """ Peform validation using the Middlebury-V3 dataset """
model.eval() model.eval()
aug_params = {} aug_params = {}

View File

@ -22,7 +22,6 @@ from evaluate_stereo import *
import core.stereo_datasets as datasets import core.stereo_datasets as datasets
import torch.nn.functional as F import torch.nn.functional as F
torch.backends.cudnn.benchmark = True
ckpt_path = './checkpoints/igev_stereo' ckpt_path = './checkpoints/igev_stereo'
log_path = './checkpoints/igev_stereo' log_path = './checkpoints/igev_stereo'