This commit is contained in:
HTensor 2023-04-24 16:37:30 +08:00
parent 25753af380
commit c3b4812e99
4 changed files with 5 additions and 4 deletions

View File

@ -37,8 +37,8 @@ class Combined_Geo_Encoding_Volume:
out_pyramid = [] out_pyramid = []
for i in range(self.num_levels): for i in range(self.num_levels):
geo_volume = self.geo_volume_pyramid[i] geo_volume = self.geo_volume_pyramid[i]
dx = torch.linspace(-r, r, 2*r+1) dx = torch.linspace(-r, r, 2*r+1, device=disp.device)
dx = dx.view(1, 1, 2*r+1, 1).to(disp.device) dx = dx.view(1, 1, 2*r+1, 1)
x0 = dx + disp.reshape(b*h*w, 1, 1, 1) / 2**i x0 = dx + disp.reshape(b*h*w, 1, 1, 1) / 2**i
y0 = torch.zeros_like(x0) y0 = torch.zeros_like(x0)

View File

@ -191,7 +191,7 @@ class IGEVStereo(nn.Module):
geo_block = Combined_Geo_Encoding_Volume geo_block = Combined_Geo_Encoding_Volume
geo_fn = geo_block(match_left.float(), match_right.float(), geo_encoding_volume.float(), radius=self.args.corr_radius, num_levels=self.args.corr_levels) geo_fn = geo_block(match_left.float(), match_right.float(), geo_encoding_volume.float(), radius=self.args.corr_radius, num_levels=self.args.corr_levels)
b, c, h, w = match_left.shape b, c, h, w = match_left.shape
coords = torch.arange(w).float().to(match_left.device).reshape(1,1,w,1).repeat(b, h, 1, 1) coords = torch.arange(w, device=match_left.device).float().reshape(1,1,w,1).repeat(b, h, 1, 1)
disp = init_disp disp = init_disp
disp_preds = [] disp_preds = []

View File

@ -353,7 +353,7 @@ def fetch_dataloader(args):
train_dataset = new_dataset if train_dataset is None else train_dataset + new_dataset train_dataset = new_dataset if train_dataset is None else train_dataset + new_dataset
train_loader = data.DataLoader(train_dataset, batch_size=args.batch_size, train_loader = data.DataLoader(train_dataset, batch_size=args.batch_size,
pin_memory=False, shuffle=True, num_workers=int(os.environ.get('SLURM_CPUS_PER_TASK', 6))-2, drop_last=True) pin_memory=True, shuffle=True, num_workers=int(os.environ.get('SLURM_CPUS_PER_TASK', 6))-2, drop_last=True)
logging.info('Training with %d image pairs' % len(train_dataset)) logging.info('Training with %d image pairs' % len(train_dataset))
return train_loader return train_loader

View File

@ -22,6 +22,7 @@ from evaluate_stereo import *
import core.stereo_datasets as datasets import core.stereo_datasets as datasets
import torch.nn.functional as F import torch.nn.functional as F
torch.backends.cudnn.benchmark = True
ckpt_path = './checkpoints/igev_stereo' ckpt_path = './checkpoints/igev_stereo'
log_path = './checkpoints/igev_stereo' log_path = './checkpoints/igev_stereo'