Compare commits
7 Commits
Author | SHA1 | Date | |
---|---|---|---|
|
74bc06d58e | ||
|
0404f5b5b0 | ||
|
3b68318ed5 | ||
|
a7d89bd95c | ||
|
e7033dabf9 | ||
|
9df896bb70 | ||
|
73e65f99b8 |
@ -60,6 +60,7 @@ class Combined_Geo_Encoding_Volume:
|
|||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def corr(fmap1, fmap2):
|
def corr(fmap1, fmap2):
|
||||||
|
# batch, dim, ht, wd
|
||||||
B, D, H, W1 = fmap1.shape
|
B, D, H, W1 = fmap1.shape
|
||||||
_, _, _, W2 = fmap2.shape
|
_, _, _, W2 = fmap2.shape
|
||||||
fmap1 = fmap1.view(B, D, H, W1)
|
fmap1 = fmap1.view(B, D, H, W1)
|
||||||
|
@ -167,6 +167,7 @@ class IGEVStereo(nn.Module):
|
|||||||
match_right = self.desc(self.conv(features_right[0]))
|
match_right = self.desc(self.conv(features_right[0]))
|
||||||
gwc_volume = build_gwc_volume(match_left, match_right, self.args.max_disp//4, 8)
|
gwc_volume = build_gwc_volume(match_left, match_right, self.args.max_disp//4, 8)
|
||||||
gwc_volume = self.corr_stem(gwc_volume)
|
gwc_volume = self.corr_stem(gwc_volume)
|
||||||
|
# 3d unet
|
||||||
gwc_volume = self.corr_feature_att(gwc_volume, features_left[0])
|
gwc_volume = self.corr_feature_att(gwc_volume, features_left[0])
|
||||||
geo_encoding_volume = self.cost_agg(gwc_volume, features_left)
|
geo_encoding_volume = self.cost_agg(gwc_volume, features_left)
|
||||||
|
|
||||||
|
@ -73,7 +73,6 @@ class StereoDataset(data.Dataset):
|
|||||||
img2 = np.array(img2).astype(np.uint8)
|
img2 = np.array(img2).astype(np.uint8)
|
||||||
|
|
||||||
disp = np.array(disp).astype(np.float32)
|
disp = np.array(disp).astype(np.float32)
|
||||||
assert not (True in np.isnan(disp))
|
|
||||||
|
|
||||||
flow = np.stack([disp, np.zeros_like(disp)], axis=-1)
|
flow = np.stack([disp, np.zeros_like(disp)], axis=-1)
|
||||||
|
|
||||||
@ -294,7 +293,7 @@ class CREStereo(StereoDataset):
|
|||||||
|
|
||||||
|
|
||||||
class Middlebury(StereoDataset):
|
class Middlebury(StereoDataset):
|
||||||
def __init__(self, aug_params=None, root='/data/Middlebury', split='F'):
|
def __init__(self, aug_params=None, root='/data/Middlebury', split='H'):
|
||||||
super(Middlebury, self).__init__(aug_params, sparse=True, reader=frame_utils.readDispMiddlebury)
|
super(Middlebury, self).__init__(aug_params, sparse=True, reader=frame_utils.readDispMiddlebury)
|
||||||
assert os.path.exists(root)
|
assert os.path.exists(root)
|
||||||
assert split in "FHQ"
|
assert split in "FHQ"
|
||||||
@ -353,7 +352,7 @@ def fetch_dataloader(args):
|
|||||||
train_dataset = new_dataset if train_dataset is None else train_dataset + new_dataset
|
train_dataset = new_dataset if train_dataset is None else train_dataset + new_dataset
|
||||||
|
|
||||||
train_loader = data.DataLoader(train_dataset, batch_size=args.batch_size,
|
train_loader = data.DataLoader(train_dataset, batch_size=args.batch_size,
|
||||||
pin_memory=True, shuffle=True, num_workers=int(os.environ.get('SLURM_CPUS_PER_TASK', 6))-2, drop_last=True)
|
pin_memory=True, shuffle=True, num_workers=12, drop_last=True)
|
||||||
|
|
||||||
logging.info('Training with %d image pairs' % len(train_dataset))
|
logging.info('Training with %d image pairs' % len(train_dataset))
|
||||||
return train_loader
|
return train_loader
|
||||||
|
@ -15,7 +15,6 @@ from PIL import Image
|
|||||||
from matplotlib import pyplot as plt
|
from matplotlib import pyplot as plt
|
||||||
import os
|
import os
|
||||||
import cv2
|
import cv2
|
||||||
from torch.profiler import profile, record_function, ProfilerActivity
|
|
||||||
|
|
||||||
def load_image(imfile):
|
def load_image(imfile):
|
||||||
img = np.array(Image.open(imfile)).astype(np.uint8)
|
img = np.array(Image.open(imfile)).astype(np.uint8)
|
||||||
@ -45,14 +44,8 @@ def demo(args):
|
|||||||
|
|
||||||
padder = InputPadder(image1.shape, divis_by=32)
|
padder = InputPadder(image1.shape, divis_by=32)
|
||||||
image1, image2 = padder.pad(image1, image2)
|
image1, image2 = padder.pad(image1, image2)
|
||||||
with profile(
|
|
||||||
activities=[ProfilerActivity.CPU, ProfilerActivity.CUDA],
|
|
||||||
with_stack=True,
|
|
||||||
) as prof:
|
|
||||||
disp = model(image1, image2, iters=args.valid_iters, test_mode=True)
|
disp = model(image1, image2, iters=args.valid_iters, test_mode=True)
|
||||||
print(prof.key_averages(group_by_stack_n=5).table(sort_by="self_cuda_time_total", row_limit=10))
|
|
||||||
prof.export_chrome_trace("./trace.json")
|
|
||||||
# disp = model(image1, image2, iters=args.valid_iters, test_mode=True)
|
|
||||||
disp = disp.cpu().numpy()
|
disp = disp.cpu().numpy()
|
||||||
disp = padder.unpad(disp)
|
disp = padder.unpad(disp)
|
||||||
file_stem = imfile1.split('/')[-2]
|
file_stem = imfile1.split('/')[-2]
|
||||||
@ -88,7 +81,6 @@ if __name__ == '__main__':
|
|||||||
parser.add_argument('--slow_fast_gru', action='store_true', help="iterate the low-res GRUs more frequently")
|
parser.add_argument('--slow_fast_gru', action='store_true', help="iterate the low-res GRUs more frequently")
|
||||||
parser.add_argument('--n_gru_layers', type=int, default=3, help="number of hidden GRU levels")
|
parser.add_argument('--n_gru_layers', type=int, default=3, help="number of hidden GRU levels")
|
||||||
parser.add_argument('--max_disp', type=int, default=192, help="max disp of geometry encoding volume")
|
parser.add_argument('--max_disp', type=int, default=192, help="max disp of geometry encoding volume")
|
||||||
parser.add_argument('--freeze_backbone_params', action="store_true", help="freeze backbone parameters")
|
|
||||||
|
|
||||||
args = parser.parse_args()
|
args = parser.parse_args()
|
||||||
|
|
||||||
|
@ -169,7 +169,7 @@ def validate_sceneflow(model, iters=32, mixed_prec=False, max_disp=192):
|
|||||||
|
|
||||||
|
|
||||||
@torch.no_grad()
|
@torch.no_grad()
|
||||||
def validate_middlebury(model, iters=32, split='F', mixed_prec=False, max_disp=192):
|
def validate_middlebury(model, iters=32, split='H', mixed_prec=False, max_disp=192):
|
||||||
""" Peform validation using the Middlebury-V3 dataset """
|
""" Peform validation using the Middlebury-V3 dataset """
|
||||||
model.eval()
|
model.eval()
|
||||||
aug_params = {}
|
aug_params = {}
|
||||||
|
@ -22,7 +22,6 @@ from evaluate_stereo import *
|
|||||||
import core.stereo_datasets as datasets
|
import core.stereo_datasets as datasets
|
||||||
import torch.nn.functional as F
|
import torch.nn.functional as F
|
||||||
|
|
||||||
torch.backends.cudnn.benchmark = True
|
|
||||||
|
|
||||||
ckpt_path = './checkpoints/igev_stereo'
|
ckpt_path = './checkpoints/igev_stereo'
|
||||||
log_path = './checkpoints/igev_stereo'
|
log_path = './checkpoints/igev_stereo'
|
||||||
|
Loading…
Reference in New Issue
Block a user