diff --git a/IGEV-Stereo/core/geometry.py b/IGEV-Stereo/core/geometry.py index a4519dd..12f1a85 100644 --- a/IGEV-Stereo/core/geometry.py +++ b/IGEV-Stereo/core/geometry.py @@ -37,8 +37,8 @@ class Combined_Geo_Encoding_Volume: out_pyramid = [] for i in range(self.num_levels): geo_volume = self.geo_volume_pyramid[i] - dx = torch.linspace(-r, r, 2*r+1) - dx = dx.view(1, 1, 2*r+1, 1).to(disp.device) + dx = torch.linspace(-r, r, 2*r+1, device=disp.device) + dx = dx.view(1, 1, 2*r+1, 1) x0 = dx + disp.reshape(b*h*w, 1, 1, 1) / 2**i y0 = torch.zeros_like(x0) diff --git a/IGEV-Stereo/core/igev_stereo.py b/IGEV-Stereo/core/igev_stereo.py index e8ce791..c432092 100644 --- a/IGEV-Stereo/core/igev_stereo.py +++ b/IGEV-Stereo/core/igev_stereo.py @@ -191,7 +191,7 @@ class IGEVStereo(nn.Module): geo_block = Combined_Geo_Encoding_Volume geo_fn = geo_block(match_left.float(), match_right.float(), geo_encoding_volume.float(), radius=self.args.corr_radius, num_levels=self.args.corr_levels) b, c, h, w = match_left.shape - coords = torch.arange(w).float().to(match_left.device).reshape(1,1,w,1).repeat(b, h, 1, 1) + coords = torch.arange(w, device=match_left.device).float().reshape(1,1,w,1).repeat(b, h, 1, 1) disp = init_disp disp_preds = [] diff --git a/IGEV-Stereo/core/stereo_datasets.py b/IGEV-Stereo/core/stereo_datasets.py index 206d802..21fc63f 100644 --- a/IGEV-Stereo/core/stereo_datasets.py +++ b/IGEV-Stereo/core/stereo_datasets.py @@ -353,7 +353,7 @@ def fetch_dataloader(args): train_dataset = new_dataset if train_dataset is None else train_dataset + new_dataset train_loader = data.DataLoader(train_dataset, batch_size=args.batch_size, - pin_memory=False, shuffle=True, num_workers=int(os.environ.get('SLURM_CPUS_PER_TASK', 6))-2, drop_last=True) + pin_memory=True, shuffle=True, num_workers=int(os.environ.get('SLURM_CPUS_PER_TASK', 6))-2, drop_last=True) logging.info('Training with %d image pairs' % len(train_dataset)) return train_loader diff --git a/IGEV-Stereo/train_stereo.py b/IGEV-Stereo/train_stereo.py index ad0b8b5..a8aa3c8 100644 --- a/IGEV-Stereo/train_stereo.py +++ b/IGEV-Stereo/train_stereo.py @@ -22,6 +22,7 @@ from evaluate_stereo import * import core.stereo_datasets as datasets import torch.nn.functional as F +torch.backends.cudnn.benchmark = True ckpt_path = './checkpoints/igev_stereo' log_path = './checkpoints/igev_stereo'