IGEV/IGEV-MVS/core/corr.py

61 lines
2.0 KiB
Python
Raw Permalink Normal View History

2023-03-20 19:52:04 +08:00
import torch
import torch.nn as nn
import torch.nn.functional as F
from .submodule import *
class CorrBlock1D_Cost_Volume:
    """Multi-level 1-D window lookup into a correlation volume and an initial one.

    At construction, ``corr`` and ``init_corr`` are flattened to one length-``d``
    row per pixel and average-pooled by 2 along that row ``num_levels`` times.
    This builds two pyramids with ``num_levels + 1`` entries each.

    Calling the block with per-pixel positions samples ``2 * radius + 1`` values
    centred on each position. It does this at each of the first ``num_levels``
    pyramid entries, for both volumes, and concatenates the results.

    Args:
        init_corr: tensor reshaped with ``corr``'s ``b, h, w, d``, so it must hold
            the same number of elements. Presumably it has the same shape as
            ``corr`` — TODO confirm against callers.
        corr: 5-D tensor unpacked as ``(b, c, d, h, w)``. The flattening reshape
            only succeeds when ``c == 1``.
        num_levels: number of pyramid levels sampled per call (default 2).
        radius: half-width of the sampling window.
        inverse_depth_min, inverse_depth_max, num_sample: stored on the
            instance but not used inside this class.
    """

    def __init__(self, init_corr, corr, num_levels=2, radius=4, inverse_depth_min=None, inverse_depth_max=None, num_sample=None):
        # Honor the caller's value; it used to be overwritten with a hard-coded 2.
        self.num_levels = num_levels
        self.radius = radius
        self.inverse_depth_min = inverse_depth_min
        self.inverse_depth_max = inverse_depth_max
        self.num_sample = num_sample

        b, c, d, h, w = corr.shape
        # (b, c, d, h, w) -> (b, h, w, c, d) -> one (1, 1, d) row per pixel.
        corr = corr.permute(0, 3, 4, 1, 2).reshape(b * h * w, 1, 1, d)
        init_corr = init_corr.permute(0, 3, 4, 1, 2).reshape(b * h * w, 1, 1, d)

        # Each pyramid holds num_levels + 1 entries; __call__ reads only the
        # first num_levels of them.
        self.corr_pyramid = self._build_pyramid(corr, num_levels)
        self.init_corr_pyramid = self._build_pyramid(init_corr, num_levels)

    @staticmethod
    def _build_pyramid(volume, num_levels):
        """Return ``[volume]`` plus ``num_levels`` successive 2x poolings along the last axis."""
        pyramid = [volume]
        for _ in range(num_levels):
            volume = F.avg_pool2d(volume, [1, 2], stride=[1, 2])
            pyramid.append(volume)
        return pyramid

    def __call__(self, disp):
        """Sample a ``2 * radius + 1`` window around each pixel's position.

        Args:
            disp: 4-D tensor unpacked as ``(b, _, h, w)``. It is reshaped to one
                scalar per pixel, so its second dimension must be 1. The values
                are level-0 positions along the ``d`` axis (presumably index
                units, as interpreted by ``bilinear_sampler`` — confirm).

        Returns:
            float32 tensor of shape ``(b, C, h, w)``. Channels run per level:
            the ``corr`` window, then the ``init_corr`` window. ``C`` is
            ``2 * num_levels * (2 * radius + 1)``, assuming ``bilinear_sampler``
            returns one value per sampling point.
        """
        r = self.radius
        b, _, h, w = disp.shape
        # Window offsets -r..r; identical for every level, so built once.
        dx = torch.linspace(-r, r, 2 * r + 1).view(1, 1, 2 * r + 1, 1).to(disp.device)
        positions = disp.reshape(b * h * w, 1, 1, 1)
        out_pyramid = []
        for i in range(self.num_levels):
            # Level i has been halved i times along d, so scale positions to match.
            x0 = dx + positions / 2 ** i
            # Each row has height 1, so the y coordinate stays 0.
            y0 = torch.zeros_like(x0)
            coords = torch.cat([x0, y0], dim=-1)
            corr = bilinear_sampler(self.corr_pyramid[i], coords).view(b, h, w, -1)
            init_corr = bilinear_sampler(self.init_corr_pyramid[i], coords).view(b, h, w, -1)
            out_pyramid.append(corr)
            out_pyramid.append(init_corr)
        out = torch.cat(out_pyramid, dim=-1)
        return out.permute(0, 3, 1, 2).contiguous().float()