import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F


class BasicConv(nn.Module):
    """Bias-free (transposed) convolution with optional BatchNorm and LeakyReLU.

    Args:
        in_channels: number of channels of the input tensor.
        out_channels: number of channels produced by the convolution.
        deconv: build a transposed convolution instead of a regular one.
        is_3d: build 3-D layers (``Conv3d``/``BatchNorm3d``) instead of 2-D ones.
        bn: apply batch normalisation after the convolution.
        relu: apply LeakyReLU (default negative slope) as the last step.
        **kwargs: forwarded unchanged to the convolution constructor
            (``kernel_size``, ``stride``, ``padding``, ...).
    """

    def __init__(self, in_channels, out_channels, deconv=False, is_3d=False, bn=True, relu=True, **kwargs):
        super(BasicConv, self).__init__()

        self.relu = relu
        self.use_bn = bn

        if is_3d:
            conv_cls = nn.ConvTranspose3d if deconv else nn.Conv3d
            norm_cls = nn.BatchNorm3d
        else:
            conv_cls = nn.ConvTranspose2d if deconv else nn.Conv2d
            norm_cls = nn.BatchNorm2d

        self.conv = conv_cls(in_channels, out_channels, bias=False, **kwargs)
        # The norm layer is registered even when ``bn`` is False so that the
        # module's state-dict keys do not depend on the flag.
        self.bn = norm_cls(out_channels)

    def forward(self, x):
        """Run conv -> (BatchNorm) -> (LeakyReLU) on ``x`` and return the result."""
        x = self.conv(x)
        if self.use_bn:
            x = self.bn(x)
        if self.relu:
            # Functional form: same defaults as nn.LeakyReLU(), but no module
            # object is allocated on every forward pass.
            x = F.leaky_relu(x)
        return x
class Conv2x(nn.Module):
    """Stride-2 (de)convolution whose output is fused with a skip tensor.

    ``conv1`` changes the resolution with stride 2 (a transposed convolution
    when ``deconv`` is True). Its output is resized to the skip tensor's
    spatial size if needed, merged with the skip tensor (channel concatenation
    or addition) and passed through a 3x3, stride-1 ``conv2``.

    Args:
        in_channels: channels of the tensor fed to ``conv1``.
        out_channels: channels produced by ``conv1``.
        deconv: use transposed convolutions in ``conv1``.
        is_3d: build 3-D layers instead of 2-D ones.
        concat: concatenate with the skip tensor along dim 1 when True,
            otherwise add the two tensors.
        keep_concat: with ``concat``, ``conv2`` outputs ``2 * out_channels``
            channels when True and ``out_channels`` otherwise.
        bn: batch-norm flag for ``conv2`` (``conv1`` always normalises).
        relu: activation flag for ``conv2`` (``conv1`` always activates).
        keep_dispc: for 3-D deconvolution only, use kernel ``(1, 4, 4)`` with
            stride ``(1, 2, 2)`` so the first spatial axis keeps its size.
    """

    def __init__(self, in_channels, out_channels, deconv=False, is_3d=False, concat=True, keep_concat=True, bn=True, relu=True, keep_dispc=False):
        super(Conv2x, self).__init__()
        self.concat = concat
        self.is_3d = is_3d

        if deconv and is_3d and keep_dispc:
            conv1_kwargs = dict(kernel_size=(1, 4, 4), stride=(1, 2, 2), padding=(0, 1, 1))
        else:
            if deconv:
                kernel = (4, 4, 4) if is_3d else 4
            else:
                kernel = 3
            conv1_kwargs = dict(kernel_size=kernel, stride=2, padding=1)
        self.conv1 = BasicConv(in_channels, out_channels, deconv, is_3d, bn=True, relu=True, **conv1_kwargs)

        if self.concat:
            fused_in = out_channels * 2
            fused_out = out_channels * (2 if keep_concat else 1)
        else:
            fused_in = fused_out = out_channels
        self.conv2 = BasicConv(fused_in, fused_out, False, is_3d, bn, relu, kernel_size=3, stride=1, padding=1)

    def forward(self, x, rem):
        """Resample ``x`` with ``conv1``, fuse it with ``rem`` and apply ``conv2``.

        Args:
            x: input tensor for ``conv1``.
            rem: skip tensor; its spatial size (all dims after the channel
                dim) is the size the ``conv1`` output is brought to.
        """
        x = self.conv1(x)
        target_size = tuple(rem.shape[2:])
        if tuple(x.shape[2:]) != target_size:
            # Use every spatial dim of ``rem``: a two-element size is rejected
            # by F.interpolate for the 5-D tensors of the 3-D variant.
            x = F.interpolate(x, size=target_size, mode='nearest')
        if self.concat:
            x = torch.cat((x, rem), 1)
        else:
            x = x + rem
        return self.conv2(x)
class BasicConv_IN(nn.Module):
    """Bias-free (transposed) convolution with optional InstanceNorm and LeakyReLU.

    Args:
        in_channels: number of channels of the input tensor.
        out_channels: number of channels produced by the convolution.
        deconv: build a transposed convolution instead of a regular one.
        is_3d: build 3-D layers (``Conv3d``/``InstanceNorm3d``) instead of 2-D ones.
        IN: apply instance normalisation after the convolution.
        relu: apply LeakyReLU (default negative slope) as the last step.
        **kwargs: forwarded unchanged to the convolution constructor
            (``kernel_size``, ``stride``, ``padding``, ...).
    """

    def __init__(self, in_channels, out_channels, deconv=False, is_3d=False, IN=True, relu=True, **kwargs):
        super(BasicConv_IN, self).__init__()

        self.relu = relu
        self.use_in = IN

        if is_3d:
            conv_cls = nn.ConvTranspose3d if deconv else nn.Conv3d
            norm_cls = nn.InstanceNorm3d
        else:
            conv_cls = nn.ConvTranspose2d if deconv else nn.Conv2d
            norm_cls = nn.InstanceNorm2d

        self.conv = conv_cls(in_channels, out_channels, bias=False, **kwargs)
        # Created regardless of the ``IN`` flag, mirroring how the flag is
        # only consulted in forward().
        self.IN = norm_cls(out_channels)

    def forward(self, x):
        """Run conv -> (InstanceNorm) -> (LeakyReLU) on ``x`` and return the result."""
        x = self.conv(x)
        if self.use_in:
            x = self.IN(x)
        if self.relu:
            # Functional form: same defaults as nn.LeakyReLU(), but no module
            # object is allocated on every forward pass.
            x = F.leaky_relu(x)
        return x
class Conv2x_IN(nn.Module):
    """Stride-2 (de)convolution with InstanceNorm, fused with a skip tensor.

    ``conv1`` changes the resolution with stride 2 (a transposed convolution
    when ``deconv`` is True). Its output is resized to the skip tensor's
    spatial size if needed, merged with the skip tensor (channel concatenation
    or addition) and passed through a 3x3, stride-1 ``conv2``.

    Args:
        in_channels: channels of the tensor fed to ``conv1``.
        out_channels: channels produced by ``conv1``.
        deconv: use transposed convolutions in ``conv1``.
        is_3d: build 3-D layers instead of 2-D ones.
        concat: concatenate with the skip tensor along dim 1 when True,
            otherwise add the two tensors.
        keep_concat: with ``concat``, ``conv2`` outputs ``2 * out_channels``
            channels when True and ``out_channels`` otherwise.
        IN: instance-norm flag for ``conv2`` (``conv1`` always normalises).
        relu: activation flag for ``conv2`` (``conv1`` always activates).
        keep_dispc: for 3-D deconvolution only, use kernel ``(1, 4, 4)`` with
            stride ``(1, 2, 2)`` so the first spatial axis keeps its size.
    """

    def __init__(self, in_channels, out_channels, deconv=False, is_3d=False, concat=True, keep_concat=True, IN=True, relu=True, keep_dispc=False):
        super(Conv2x_IN, self).__init__()
        self.concat = concat
        self.is_3d = is_3d

        if deconv and is_3d and keep_dispc:
            conv1_kwargs = dict(kernel_size=(1, 4, 4), stride=(1, 2, 2), padding=(0, 1, 1))
        else:
            if deconv:
                kernel = (4, 4, 4) if is_3d else 4
            else:
                kernel = 3
            conv1_kwargs = dict(kernel_size=kernel, stride=2, padding=1)
        self.conv1 = BasicConv_IN(in_channels, out_channels, deconv, is_3d, IN=True, relu=True, **conv1_kwargs)

        if self.concat:
            fused_in = out_channels * 2
            fused_out = out_channels * (2 if keep_concat else 1)
        else:
            fused_in = fused_out = out_channels
        self.conv2 = BasicConv_IN(fused_in, fused_out, False, is_3d, IN, relu, kernel_size=3, stride=1, padding=1)

    def forward(self, x, rem):
        """Resample ``x`` with ``conv1``, fuse it with ``rem`` and apply ``conv2``.

        Args:
            x: input tensor for ``conv1``.
            rem: skip tensor; its spatial size (all dims after the channel
                dim) is the size the ``conv1`` output is brought to.
        """
        x = self.conv1(x)
        target_size = tuple(rem.shape[2:])
        if tuple(x.shape[2:]) != target_size:
            # Use every spatial dim of ``rem``: a two-element size is rejected
            # by F.interpolate for the 5-D tensors of the 3-D variant.
            x = F.interpolate(x, size=target_size, mode='nearest')
        if self.concat:
            x = torch.cat((x, rem), 1)
        else:
            x = x + rem
        return self.conv2(x)
def groupwise_correlation(fea1, fea2, num_groups):
    """Return the per-group mean of the element-wise product of two feature maps.

    The channel axis of ``fea1 * fea2`` is split into ``num_groups``
    consecutive groups and averaged inside each group.

    Args:
        fea1: 4-D tensor ``(B, C, H, W)``.
        fea2: tensor multiplied element-wise with ``fea1``.
        num_groups: number of channel groups; must divide ``C``.

    Returns:
        Tensor of shape ``(B, num_groups, H, W)``.

    Raises:
        ValueError: if ``C`` is not divisible by ``num_groups``.
    """
    B, C, H, W = fea1.shape
    if C % num_groups != 0:
        raise ValueError(
            f"channel count {C} is not divisible by num_groups={num_groups}")
    channels_per_group = C // num_groups
    grouped = (fea1 * fea2).view([B, num_groups, channels_per_group, H, W])
    return grouped.mean(dim=2)
def build_gwc_volume(refimg_fea, targetimg_fea, maxdisp, num_groups):
    """Build a group-wise correlation cost volume over horizontal shifts.

    For each shift ``i`` in ``[0, maxdisp)``, reference columns ``i:`` are
    correlated with target columns ``:W - i`` and written to columns ``i:`` of
    slice ``i``. Columns that have no counterpart stay zero.

    Args:
        refimg_fea: reference features, ``(B, C, H, W)``.
        targetimg_fea: target features with the same layout.
        maxdisp: number of shift hypotheses.
        num_groups: number of channel groups used by the correlation.

    Returns:
        Tensor laid out as (batch, groups, max_disp, height, width).
    """
    # batch, groups, max_disp, height, width
    B, _, H, W = refimg_fea.shape
    volume = refimg_fea.new_zeros([B, num_groups, maxdisp, H, W])
    # Shifts of W or more leave no overlapping columns, so their slices stay zero.
    for i in range(min(maxdisp, W)):
        volume[:, :, i, :, i:] = groupwise_correlation(
            refimg_fea[:, :, :, i:], targetimg_fea[:, :, :, :W - i], num_groups)
    return volume
def norm_correlation(fea1, fea2):
    """Return the channel-mean of the product of two L2-normalised feature maps.

    Each input is divided by its L2 norm over dim 1 (plus ``1e-05`` so that
    all-zero vectors do not divide by zero) before the element-wise product
    is averaged over dim 1.

    Args:
        fea1: first feature tensor; channels on dim 1.
        fea2: second feature tensor, same layout.

    Returns:
        Tensor with dim 1 reduced to size 1 (``keepdim=True``).
    """
    unit1 = fea1 / (torch.norm(fea1, p=2, dim=1, keepdim=True) + 1e-05)
    unit2 = fea2 / (torch.norm(fea2, p=2, dim=1, keepdim=True) + 1e-05)
    return torch.mean(unit1 * unit2, dim=1, keepdim=True)
def build_norm_correlation_volume(refimg_fea, targetimg_fea, maxdisp):
    """Build a single-channel normalised-correlation cost volume.

    For each shift ``i`` in ``[0, maxdisp)``, reference columns ``i:`` are
    compared with target columns ``:W - i`` via ``norm_correlation`` and
    written to columns ``i:`` of slice ``i``. Other entries stay zero.

    Args:
        refimg_fea: reference features, ``(B, C, H, W)``.
        targetimg_fea: target features with the same layout.
        maxdisp: number of shift hypotheses.

    Returns:
        Tensor of shape ``(B, 1, maxdisp, H, W)``.
    """
    B, _, H, W = refimg_fea.shape
    volume = refimg_fea.new_zeros([B, 1, maxdisp, H, W])
    # Shifts of W or more leave no overlapping columns, so their slices stay zero.
    for i in range(min(maxdisp, W)):
        volume[:, :, i, :, i:] = norm_correlation(
            refimg_fea[:, :, :, i:], targetimg_fea[:, :, :, :W - i])
    return volume
def correlation(fea1, fea2):
    """Return the sum over dim 1 of the element-wise product of two tensors.

    Args:
        fea1: first feature tensor; channels on dim 1.
        fea2: second feature tensor, multiplied element-wise with ``fea1``.

    Returns:
        Tensor with dim 1 reduced to size 1 (``keepdim=True``).
    """
    return torch.sum(fea1 * fea2, dim=1, keepdim=True)
def build_correlation_volume(refimg_fea, targetimg_fea, maxdisp):
    """Build a single-channel dot-product correlation cost volume.

    For each shift ``i`` in ``[0, maxdisp)``, reference columns ``i:`` are
    compared with target columns ``:W - i`` via ``correlation`` and written
    to columns ``i:`` of slice ``i``. Other entries stay zero.

    Args:
        refimg_fea: reference features, ``(B, C, H, W)``.
        targetimg_fea: target features with the same layout.
        maxdisp: number of shift hypotheses.

    Returns:
        Tensor of shape ``(B, 1, maxdisp, H, W)``.
    """
    B, _, H, W = refimg_fea.shape
    volume = refimg_fea.new_zeros([B, 1, maxdisp, H, W])
    # Shifts of W or more leave no overlapping columns, so their slices stay zero.
    for i in range(min(maxdisp, W)):
        volume[:, :, i, :, i:] = correlation(
            refimg_fea[:, :, :, i:], targetimg_fea[:, :, :, :W - i])
    return volume
def build_concat_volume(refimg_fea, targetimg_fea, maxdisp):
    """Build a concatenation cost volume over horizontal shifts.

    The first ``C`` channels of every slice hold the reference features
    unchanged. The last ``C`` channels of slice ``i`` hold the target features
    moved right by ``i`` columns; the vacated columns stay zero.

    Args:
        refimg_fea: reference features, ``(B, C, H, W)``.
        targetimg_fea: target features with the same layout.
        maxdisp: number of shift hypotheses.

    Returns:
        Tensor of shape ``(B, 2 * C, maxdisp, H, W)``.
    """
    B, C, H, W = refimg_fea.shape
    volume = refimg_fea.new_zeros([B, 2 * C, maxdisp, H, W])
    # The reference half is identical for every shift: write it once,
    # broadcast along the disparity axis.
    volume[:, :C] = refimg_fea.unsqueeze(2)
    # Shifts of W or more move the target features entirely out of view.
    for i in range(min(maxdisp, W)):
        volume[:, C:, i, :, i:] = targetimg_fea[:, :, :, :W - i]
    return volume
def disparity_regression(x, maxdisp):
    """Return the weighted sum of disparity indices along dim 1.

    Each slice ``x[:, d]`` is multiplied by ``d`` (for ``d`` in
    ``[0, maxdisp)``) and the products are summed over dim 1.

    Args:
        x: 4-D tensor of per-disparity weights with ``maxdisp`` entries on
            dim 1 (presumably a probability distribution - confirm at the
            call site).
        maxdisp: number of disparity hypotheses.

    Returns:
        Tensor with dim 1 reduced to size 1 (``keepdim=True``).

    Raises:
        ValueError: if ``x`` is not 4-dimensional.
    """
    if x.dim() != 4:
        raise ValueError(f"expected a 4-D tensor, got {x.dim()} dimensions")
    disp_values = torch.arange(0, maxdisp, dtype=x.dtype, device=x.device)
    disp_values = disp_values.view(1, maxdisp, 1, 1)
    return torch.sum(x * disp_values, 1, keepdim=True)
class FeatureAtt(nn.Module):
    """Gate a cost volume with sigmoid weights computed from a 2-D feature map.

    Args:
        cv_chan: number of channels of the cost volume to be gated.
        feat_chan: number of channels of the guiding feature map.
    """

    def __init__(self, cv_chan, feat_chan):
        super(FeatureAtt, self).__init__()

        # 1x1 bottleneck (feat_chan -> feat_chan // 2) followed by a 1x1
        # projection to one weight map per cost-volume channel.
        self.feat_att = nn.Sequential(
            BasicConv(feat_chan, feat_chan//2, kernel_size=1, stride=1, padding=0),
            nn.Conv2d(feat_chan//2, cv_chan, 1))

    def forward(self, cv, feat):
        """Return ``cv`` scaled by ``sigmoid(feat_att(feat))``.

        Args:
            cv: cost volume; assumed to be ``(B, cv_chan, D, H, W)`` so that
                the weights broadcast over dim 2 - confirm at the call site.
            feat: feature map fed to the 2-D attention branch,
                ``(B, feat_chan, H, W)``.
        """
        # Insert a singleton axis at dim 2 so the same weights apply to every
        # index along that axis of ``cv``.
        feat_att = self.feat_att(feat).unsqueeze(2)
        cv = torch.sigmoid(feat_att)*cv
        return cv
def context_upsample(disp_low, up_weights):
    """Upsample a low-resolution map 4x as a weighted sum of 3x3 neighbours.

    Every output pixel combines the nine values of the 3x3 window (zero
    padded at the border) around its nearest low-resolution pixel, weighted
    by ``up_weights``.

    Args:
        disp_low: low-resolution map, ``(b, 1, h, w)``; the reshape to nine
            channels below only succeeds for a single input channel.
        up_weights: per-pixel neighbour weights, ``(b, 9, 4*h, 4*w)``.

    Returns:
        Tensor of shape ``(b, 4*h, 4*w)`` (the neighbour axis is summed away).
    """
    scale = 4       # fixed upsampling factor
    neighbours = 9  # entries of a 3x3 window
    b, c, h, w = disp_low.shape

    # (b, c*9, h*w) -> (b, c*9, h, w): channel k holds the k-th window entry.
    windows = F.unfold(disp_low, kernel_size=3, dilation=1, padding=1).reshape(b, -1, h, w)
    # Replicate each low-resolution window over its scale x scale output block.
    windows = F.interpolate(windows, (h * scale, w * scale), mode='nearest')
    windows = windows.reshape(b, neighbours, h * scale, w * scale)

    return (windows * up_weights).sum(1)