Compare commits

..

7 Commits

Author SHA1 Message Date
HTensor
74bc06d58e Merge branch 'original-test' 2023-05-03 04:30:45 +08:00
HTensor
0404f5b5b0 set 12 threads to prepare datasets 2023-05-03 04:14:05 +08:00
HTensor
3b68318ed5 disabled cudnn benchmark 2023-05-03 04:13:28 +08:00
HTensor
a7d89bd95c change resolution of middlebury test images 2023-05-02 13:39:36 +08:00
HTensor
e7033dabf9 added comments 2023-05-02 01:19:05 +08:00
HTensor
9df896bb70 Update stereo_datasets.py 2023-05-01 12:25:21 +08:00
HTensor
73e65f99b8 Update evaluate_stereo.py 2023-04-30 16:03:58 +08:00
6 changed files with 8 additions and 16 deletions

View File

@@ -60,6 +60,7 @@ class Combined_Geo_Encoding_Volume:
@staticmethod
def corr(fmap1, fmap2):
# batch, dim, ht, wd
B, D, H, W1 = fmap1.shape
_, _, _, W2 = fmap2.shape
fmap1 = fmap1.view(B, D, H, W1)

View File

@@ -167,6 +167,7 @@ class IGEVStereo(nn.Module):
match_right = self.desc(self.conv(features_right[0]))
gwc_volume = build_gwc_volume(match_left, match_right, self.args.max_disp//4, 8)
gwc_volume = self.corr_stem(gwc_volume)
# 3d unet
gwc_volume = self.corr_feature_att(gwc_volume, features_left[0])
geo_encoding_volume = self.cost_agg(gwc_volume, features_left)

View File

@@ -73,7 +73,6 @@ class StereoDataset(data.Dataset):
img2 = np.array(img2).astype(np.uint8)
disp = np.array(disp).astype(np.float32)
assert not (True in np.isnan(disp))
flow = np.stack([disp, np.zeros_like(disp)], axis=-1)
@@ -294,7 +293,7 @@ class CREStereo(StereoDataset):
class Middlebury(StereoDataset):
-def __init__(self, aug_params=None, root='/data/Middlebury', split='F'):
+def __init__(self, aug_params=None, root='/data/Middlebury', split='H'):
super(Middlebury, self).__init__(aug_params, sparse=True, reader=frame_utils.readDispMiddlebury)
assert os.path.exists(root)
assert split in "FHQ"
@@ -353,7 +352,7 @@ def fetch_dataloader(args):
train_dataset = new_dataset if train_dataset is None else train_dataset + new_dataset
train_loader = data.DataLoader(train_dataset, batch_size=args.batch_size,
-pin_memory=True, shuffle=True, num_workers=int(os.environ.get('SLURM_CPUS_PER_TASK', 6))-2, drop_last=True)
+pin_memory=True, shuffle=True, num_workers=12, drop_last=True)
logging.info('Training with %d image pairs' % len(train_dataset))
return train_loader

View File

@@ -15,7 +15,6 @@ from PIL import Image
from matplotlib import pyplot as plt
import os
import cv2
from torch.profiler import profile, record_function, ProfilerActivity
def load_image(imfile):
img = np.array(Image.open(imfile)).astype(np.uint8)
@@ -45,14 +44,8 @@ def demo(args):
padder = InputPadder(image1.shape, divis_by=32)
image1, image2 = padder.pad(image1, image2)
-with profile(
-activities=[ProfilerActivity.CPU, ProfilerActivity.CUDA],
-with_stack=True,
-) as prof:
-disp = model(image1, image2, iters=args.valid_iters, test_mode=True)
-print(prof.key_averages(group_by_stack_n=5).table(sort_by="self_cuda_time_total", row_limit=10))
-prof.export_chrome_trace("./trace.json")
-# disp = model(image1, image2, iters=args.valid_iters, test_mode=True)
+disp = model(image1, image2, iters=args.valid_iters, test_mode=True)
disp = disp.cpu().numpy()
disp = padder.unpad(disp)
file_stem = imfile1.split('/')[-2]
@@ -88,7 +81,6 @@ if __name__ == '__main__':
parser.add_argument('--slow_fast_gru', action='store_true', help="iterate the low-res GRUs more frequently")
parser.add_argument('--n_gru_layers', type=int, default=3, help="number of hidden GRU levels")
parser.add_argument('--max_disp', type=int, default=192, help="max disp of geometry encoding volume")
parser.add_argument('--freeze_backbone_params', action="store_true", help="freeze backbone parameters")
args = parser.parse_args()

View File

@@ -169,7 +169,7 @@ def validate_sceneflow(model, iters=32, mixed_prec=False, max_disp=192):
@torch.no_grad()
-def validate_middlebury(model, iters=32, split='F', mixed_prec=False, max_disp=192):
+def validate_middlebury(model, iters=32, split='H', mixed_prec=False, max_disp=192):
""" Perform validation using the Middlebury-V3 dataset """
model.eval()
aug_params = {}

View File

@@ -22,7 +22,6 @@ from evaluate_stereo import *
import core.stereo_datasets as datasets
import torch.nn.functional as F
-torch.backends.cudnn.benchmark = True
ckpt_path = './checkpoints/igev_stereo'
log_path = './checkpoints/igev_stereo'