optimize
This commit is contained in:
parent
25753af380
commit
c3b4812e99
@ -37,8 +37,8 @@ class Combined_Geo_Encoding_Volume:
|
|||||||
out_pyramid = []
|
out_pyramid = []
|
||||||
for i in range(self.num_levels):
|
for i in range(self.num_levels):
|
||||||
geo_volume = self.geo_volume_pyramid[i]
|
geo_volume = self.geo_volume_pyramid[i]
|
||||||
dx = torch.linspace(-r, r, 2*r+1)
|
dx = torch.linspace(-r, r, 2*r+1, device=disp.device)
|
||||||
dx = dx.view(1, 1, 2*r+1, 1).to(disp.device)
|
dx = dx.view(1, 1, 2*r+1, 1)
|
||||||
x0 = dx + disp.reshape(b*h*w, 1, 1, 1) / 2**i
|
x0 = dx + disp.reshape(b*h*w, 1, 1, 1) / 2**i
|
||||||
y0 = torch.zeros_like(x0)
|
y0 = torch.zeros_like(x0)
|
||||||
|
|
||||||
|
@ -191,7 +191,7 @@ class IGEVStereo(nn.Module):
|
|||||||
geo_block = Combined_Geo_Encoding_Volume
|
geo_block = Combined_Geo_Encoding_Volume
|
||||||
geo_fn = geo_block(match_left.float(), match_right.float(), geo_encoding_volume.float(), radius=self.args.corr_radius, num_levels=self.args.corr_levels)
|
geo_fn = geo_block(match_left.float(), match_right.float(), geo_encoding_volume.float(), radius=self.args.corr_radius, num_levels=self.args.corr_levels)
|
||||||
b, c, h, w = match_left.shape
|
b, c, h, w = match_left.shape
|
||||||
coords = torch.arange(w).float().to(match_left.device).reshape(1,1,w,1).repeat(b, h, 1, 1)
|
coords = torch.arange(w, device=match_left.device).float().reshape(1,1,w,1).repeat(b, h, 1, 1)
|
||||||
disp = init_disp
|
disp = init_disp
|
||||||
disp_preds = []
|
disp_preds = []
|
||||||
|
|
||||||
|
@ -353,7 +353,7 @@ def fetch_dataloader(args):
|
|||||||
train_dataset = new_dataset if train_dataset is None else train_dataset + new_dataset
|
train_dataset = new_dataset if train_dataset is None else train_dataset + new_dataset
|
||||||
|
|
||||||
train_loader = data.DataLoader(train_dataset, batch_size=args.batch_size,
|
train_loader = data.DataLoader(train_dataset, batch_size=args.batch_size,
|
||||||
pin_memory=False, shuffle=True, num_workers=int(os.environ.get('SLURM_CPUS_PER_TASK', 6))-2, drop_last=True)
|
pin_memory=True, shuffle=True, num_workers=int(os.environ.get('SLURM_CPUS_PER_TASK', 6))-2, drop_last=True)
|
||||||
|
|
||||||
logging.info('Training with %d image pairs' % len(train_dataset))
|
logging.info('Training with %d image pairs' % len(train_dataset))
|
||||||
return train_loader
|
return train_loader
|
||||||
|
@ -22,6 +22,7 @@ from evaluate_stereo import *
|
|||||||
import core.stereo_datasets as datasets
|
import core.stereo_datasets as datasets
|
||||||
import torch.nn.functional as F
|
import torch.nn.functional as F
|
||||||
|
|
||||||
|
torch.backends.cudnn.benchmark = True
|
||||||
|
|
||||||
ckpt_path = './checkpoints/igev_stereo'
|
ckpt_path = './checkpoints/igev_stereo'
|
||||||
log_path = './checkpoints/igev_stereo'
|
log_path = './checkpoints/igev_stereo'
|
||||||
|
Loading…
Reference in New Issue
Block a user