396 lines
16 KiB
Python
396 lines
16 KiB
Python
import torch
|
|
import torch.nn as nn
|
|
import torch.nn.functional as F
|
|
import math
|
|
|
|
class SubModule(nn.Module):
    """Base ``nn.Module`` that provides a shared ``weight_init`` routine to subclasses."""

    def __init__(self):
        super(SubModule, self).__init__()

    def weight_init(self):
        """Re-initialise every conv / batch-norm layer reachable through ``self.modules()``.

        * ``nn.Conv2d`` / ``nn.Conv3d``: weights ~ N(0, sqrt(2 / n)), where n is the
          product of the kernel size and ``out_channels``. Conv biases are not touched.
        * ``nn.BatchNorm2d`` / ``nn.BatchNorm3d``: weight set to 1, bias set to 0.

        Modules of any other type (including transposed convs, which are not
        ``nn.Conv2d``/``nn.Conv3d`` instances) are left unchanged.
        """
        for layer in self.modules():
            if isinstance(layer, (nn.Conv2d, nn.Conv3d)):
                fan = layer.out_channels
                for k in layer.kernel_size:
                    fan *= k
                layer.weight.data.normal_(0, math.sqrt(2. / fan))
            elif isinstance(layer, (nn.BatchNorm2d, nn.BatchNorm3d)):
                layer.weight.data.fill_(1)
                layer.bias.data.zero_()
|
|
|
|
class BasicConv(nn.Module):
    """Convolution (2D/3D, optionally transposed) -> optional BatchNorm -> optional LeakyReLU.

    Args:
        in_channels, out_channels: channel counts of the convolution.
        deconv: use ``ConvTranspose{2,3}d`` instead of ``Conv{2,3}d``.
        is_3d: use the 3D variants of the conv and BatchNorm layers.
        bn: apply the BatchNorm layer in ``forward``.
        relu: apply LeakyReLU (negative slope 0.01) in ``forward``.
        **kwargs: forwarded to the conv constructor (kernel_size, stride, padding, ...).
    """

    def __init__(self, in_channels, out_channels, deconv=False, is_3d=False, bn=True, relu=True, **kwargs):
        super(BasicConv, self).__init__()
        self.relu = relu
        self.use_bn = bn

        if is_3d:
            conv_cls = nn.ConvTranspose3d if deconv else nn.Conv3d
            norm_cls = nn.BatchNorm3d
        else:
            conv_cls = nn.ConvTranspose2d if deconv else nn.Conv2d
            norm_cls = nn.BatchNorm2d

        # The conv never has a bias, regardless of ``bn``.
        self.conv = conv_cls(in_channels, out_channels, bias=False, **kwargs)
        # The norm layer is created even when bn=False (forward just skips it),
        # so its parameters/buffers are always part of the state_dict.
        self.bn = norm_cls(out_channels)

    def forward(self, x):
        out = self.conv(x)
        if self.use_bn:
            out = self.bn(out)
        return F.leaky_relu(out, negative_slope=0.01) if self.relu else out
|
|
|
|
class BasicConv_IN(nn.Module):
    """Convolution (2D/3D, optionally transposed) -> optional InstanceNorm -> optional LeakyReLU.

    Same layout as ``BasicConv`` but with ``nn.InstanceNorm{2,3}d`` (default
    constructor arguments) in place of BatchNorm.

    Args:
        in_channels, out_channels: channel counts of the convolution.
        deconv: use ``ConvTranspose{2,3}d`` instead of ``Conv{2,3}d``.
        is_3d: use the 3D variants of the conv and InstanceNorm layers.
        IN: apply the InstanceNorm layer in ``forward``.
        relu: apply LeakyReLU (negative slope 0.01) in ``forward``.
        **kwargs: forwarded to the conv constructor.
    """

    def __init__(self, in_channels, out_channels, deconv=False, is_3d=False, IN=True, relu=True, **kwargs):
        super(BasicConv_IN, self).__init__()
        self.relu = relu
        self.use_in = IN

        if not is_3d:
            layer_type = nn.ConvTranspose2d if deconv else nn.Conv2d
            norm_type = nn.InstanceNorm2d
        else:
            layer_type = nn.ConvTranspose3d if deconv else nn.Conv3d
            norm_type = nn.InstanceNorm3d

        self.conv = layer_type(in_channels, out_channels, bias=False, **kwargs)
        # Built unconditionally; ``forward`` only applies it when ``IN`` is True.
        self.IN = norm_type(out_channels)

    def forward(self, x):
        y = self.conv(x)
        if self.use_in:
            y = self.IN(y)
        if not self.relu:
            return y
        return F.leaky_relu(y, negative_slope=0.01)
|
|
|
|
class Conv2x(nn.Module):
    """Resample ``x`` by a stride-2 (transposed) convolution, then fuse it with ``rem``.

    ``conv1`` is a ``BasicConv`` that always has BatchNorm and LeakyReLU enabled;
    the ``bn``/``relu`` arguments only configure ``conv2``.

    Args:
        in_channels: channels of ``x``.
        out_channels: channels produced by ``conv1`` (and expected in ``rem``).
        deconv: transposed conv (kernel 4, stride 2, padding 1) instead of a
            strided conv (kernel 3, stride 2, padding 1).
        is_3d: use 3D layers.
        concat: fuse by channel concatenation (else by addition).
        keep_concat: with ``concat``, keep 2*out_channels after ``conv2``
            (otherwise reduce back to out_channels).
        bn, relu: options for ``conv2``.
        keep_dispc: only with ``deconv and is_3d``: use kernel (1, 4, 4),
            stride (1, 2, 2), padding (0, 1, 1) so the first spatial dim keeps its
            size (presumably the disparity axis -- confirm with callers).
    """

    def __init__(self, in_channels, out_channels, deconv=False, is_3d=False, concat=True, keep_concat=True, bn=True, relu=True, keep_dispc=False):
        super(Conv2x, self).__init__()
        self.concat = concat
        self.is_3d = is_3d

        if deconv and is_3d:
            kernel = (4, 4, 4)
        elif deconv:
            kernel = 4
        else:
            kernel = 3

        if deconv and is_3d and keep_dispc:
            kernel = (1, 4, 4)
            stride = (1, 2, 2)
            padding = (0, 1, 1)
            self.conv1 = BasicConv(in_channels, out_channels, deconv, is_3d, bn=True, relu=True, kernel_size=kernel, stride=stride, padding=padding)
        else:
            self.conv1 = BasicConv(in_channels, out_channels, deconv, is_3d, bn=True, relu=True, kernel_size=kernel, stride=2, padding=1)

        if self.concat:
            mul = 2 if keep_concat else 1
            self.conv2 = BasicConv(out_channels*2, out_channels*mul, False, is_3d, bn, relu, kernel_size=3, stride=1, padding=1)
        else:
            self.conv2 = BasicConv(out_channels, out_channels, False, is_3d, bn, relu, kernel_size=3, stride=1, padding=1)

    def forward(self, x, rem):
        """Return ``conv2(fuse(conv1(x), rem))``.

        If ``conv1(x)`` and ``rem`` differ in spatial size, ``conv1(x)`` is resized
        to ``rem``'s spatial size with nearest interpolation first.
        """
        x = self.conv1(x)
        # Compare and resize over *all* trailing spatial dims. A hard-coded (H, W)
        # size makes F.interpolate raise for 5-D (3D-conv) tensors.
        if x.shape[2:] != rem.shape[2:]:
            x = F.interpolate(x, size=tuple(rem.shape[2:]), mode='nearest')
        if self.concat:
            x = torch.cat((x, rem), 1)
        else:
            x = x + rem
        x = self.conv2(x)
        return x
|
|
|
|
class Conv2x_IN(nn.Module):
    """InstanceNorm variant of ``Conv2x``: stride-2 (transposed) conv on ``x``, then fuse with ``rem``.

    ``conv1`` is a ``BasicConv_IN`` that always has InstanceNorm and LeakyReLU
    enabled; the ``IN``/``relu`` arguments only configure ``conv2``.

    Args:
        in_channels: channels of ``x``.
        out_channels: channels produced by ``conv1`` (and expected in ``rem``).
        deconv: transposed conv (kernel 4, stride 2, padding 1) instead of a
            strided conv (kernel 3, stride 2, padding 1).
        is_3d: use 3D layers.
        concat: fuse by channel concatenation (else by addition).
        keep_concat: with ``concat``, keep 2*out_channels after ``conv2``.
        IN, relu: options for ``conv2``.
        keep_dispc: only with ``deconv and is_3d``: use kernel (1, 4, 4),
            stride (1, 2, 2), padding (0, 1, 1) so the first spatial dim keeps its
            size (presumably the disparity axis -- confirm with callers).
    """

    def __init__(self, in_channels, out_channels, deconv=False, is_3d=False, concat=True, keep_concat=True, IN=True, relu=True, keep_dispc=False):
        super(Conv2x_IN, self).__init__()
        self.concat = concat
        self.is_3d = is_3d

        if deconv and is_3d:
            kernel = (4, 4, 4)
        elif deconv:
            kernel = 4
        else:
            kernel = 3

        if deconv and is_3d and keep_dispc:
            kernel = (1, 4, 4)
            stride = (1, 2, 2)
            padding = (0, 1, 1)
            self.conv1 = BasicConv_IN(in_channels, out_channels, deconv, is_3d, IN=True, relu=True, kernel_size=kernel, stride=stride, padding=padding)
        else:
            self.conv1 = BasicConv_IN(in_channels, out_channels, deconv, is_3d, IN=True, relu=True, kernel_size=kernel, stride=2, padding=1)

        if self.concat:
            mul = 2 if keep_concat else 1
            self.conv2 = BasicConv_IN(out_channels*2, out_channels*mul, False, is_3d, IN, relu, kernel_size=3, stride=1, padding=1)
        else:
            self.conv2 = BasicConv_IN(out_channels, out_channels, False, is_3d, IN, relu, kernel_size=3, stride=1, padding=1)

    def forward(self, x, rem):
        """Return ``conv2(fuse(conv1(x), rem))``.

        If ``conv1(x)`` and ``rem`` differ in spatial size, ``conv1(x)`` is resized
        to ``rem``'s spatial size with nearest interpolation first.
        """
        x = self.conv1(x)
        # Compare and resize over *all* trailing spatial dims. A hard-coded (H, W)
        # size makes F.interpolate raise for 5-D (3D-conv) tensors.
        if x.shape[2:] != rem.shape[2:]:
            x = F.interpolate(x, size=tuple(rem.shape[2:]), mode='nearest')
        if self.concat:
            x = torch.cat((x, rem), 1)
        else:
            x = x + rem
        x = self.conv2(x)
        return x
|
|
|
|
class ConvReLU(nn.Module):
    """Bias-free ``nn.Conv2d`` followed by a ReLU applied in place on its output."""

    def __init__(self, in_channels, out_channels, kernel_size=3, stride=1, pad=1, dilation=1):
        super(ConvReLU, self).__init__()
        self.conv = nn.Conv2d(
            in_channels,
            out_channels,
            kernel_size,
            stride=stride,
            padding=pad,
            dilation=dilation,
            bias=False,
        )

    def forward(self, x):
        activated = self.conv(x)
        return activated.relu_()
|
|
|
|
class DepthInitialization(nn.Module):
    """Generate ``num_sample`` per-pixel depth hypotheses evenly spaced in inverse depth."""

    def __init__(self, num_sample):
        super(DepthInitialization, self).__init__()
        self.num_sample = num_sample

    def forward(self, inverse_depth_min, inverse_depth_max, height, width, device):
        """Return depth samples of shape [B, num_sample, height, width].

        Sample k has inverse depth
        ``inverse_depth_max + k / (num_sample - 1) * (inverse_depth_min - inverse_depth_max)``,
        so sample 0 is ``1 / inverse_depth_max`` and the last one is ``1 / inverse_depth_min``.
        B is taken from ``inverse_depth_min.size(0)``; both bounds are assumed to
        broadcast against [B, num_sample, H, W] (e.g. shape [B, 1, 1, 1]) -- TODO confirm.

        NOTE(review): num_sample must be >= 2; with 1 the division by
        (num_sample - 1) yields NaN.
        """
        count = self.num_sample
        batch = inverse_depth_min.size()[0]
        ramp = torch.arange(0, count, 1, device=device).float().view(1, count, 1, 1)
        fraction = ramp.expand(batch, count, height, width) / (count - 1)
        inverse_depth = inverse_depth_max + fraction * (inverse_depth_min - inverse_depth_max)
        return 1.0 / inverse_depth
|
|
|
|
class PixelViewWeight(nn.Module):
    """Per-pixel view weight from a grouped similarity volume.

    Each depth plane is scored by ``ConvReLU(G, 16)`` followed by a 1x1 conv to one
    channel. The scores are softmax-normalised over depth, and the weight is the
    maximum of that distribution.
    """

    def __init__(self, G):
        super(PixelViewWeight, self).__init__()
        self.conv = nn.Sequential(
            ConvReLU(G, 16),
            nn.Conv2d(16, 1, 1, stride=1, padding=0),
        )

    def forward(self, x):
        """x: [B, G, N, H, W] -> weight [B, 1, H, W]."""
        b, groups, n_depth, h, w = x.size()
        # Fold the depth axis into the batch so the 2D convs see one depth plane at a time.
        planes = x.permute(0, 2, 1, 3, 4).contiguous().view(b * n_depth, groups, h, w)  # [B*N, G, H, W]
        scores = self.conv(planes).view(b, n_depth, h, w)
        probs = torch.softmax(scores, dim=1)
        return torch.max(probs, dim=1, keepdim=True)[0]
|
|
|
|
class FeatureAtt(nn.Module):
    """Gate a cost volume with a sigmoid attention map computed from 2D features.

    ``feat`` goes through a 1x1 ``BasicConv`` (feat_chan -> feat_chan // 2, with the
    BasicConv defaults BN + LeakyReLU), then a 1x1 ``nn.Conv2d`` (-> cv_chan). The
    sigmoid of the result is inserted as a singleton dim 2 and multiplied into
    ``cv``, so the same gate is broadcast along dim 2 of ``cv``.
    """

    def __init__(self, cv_chan, feat_chan):
        super(FeatureAtt, self).__init__()
        hidden = feat_chan // 2
        self.feat_att = nn.Sequential(
            BasicConv(feat_chan, hidden, kernel_size=1, stride=1, padding=0),
            nn.Conv2d(hidden, cv_chan, 1),
        )

    def forward(self, cv, feat):
        """Return ``sigmoid(feat_att(feat))[:, :, None] * cv``.

        Assumes ``feat`` is a 2D feature map [B, feat_chan, H, W] and ``cv`` is
        [B, cv_chan, D, H, W] -- TODO confirm with callers.
        """
        gate = torch.sigmoid(self.feat_att(feat)).unsqueeze(2)
        return gate * cv
|
|
|
|
class hourglass(nn.Module):
    """3D encoder-decoder over a cost volume, gated by multi-scale 2D features.

    Encoder: three stages, each a stride-2 3D ``BasicConv`` plus a stride-1 one
    (channels c -> 2c -> 4c -> 6c). Each stage is gated by a ``FeatureAtt`` using
    ``features[1]``, ``features[2]`` and ``features[3]`` respectively.
    Decoder: stride-2 transposed 3D convs, concatenation with the matching encoder
    output, a 3-layer aggregation, and another feature gate. The last transposed
    conv outputs a single channel with no BN/activation.

    NOTE(review): the feature channel counts (64, 192, 160) are hard-coded and must
    match whatever produces ``features``.
    """

    def __init__(self, in_channels):
        super(hourglass, self).__init__()
        c = in_channels

        def down_stage(cin, cout):
            # stride-2 downsampling conv followed by a stride-1 refinement conv
            return nn.Sequential(
                BasicConv(cin, cout, is_3d=True, bn=True, relu=True, kernel_size=3,
                          padding=1, stride=2, dilation=1),
                BasicConv(cout, cout, is_3d=True, bn=True, relu=True, kernel_size=3,
                          padding=1, stride=1, dilation=1),
            )

        def up_conv(cin, cout, final=False):
            # 2x transposed 3D conv; the final one has no BN and no activation
            return BasicConv(cin, cout, deconv=True, is_3d=True, bn=not final,
                             relu=not final, kernel_size=(4, 4, 4), padding=(1, 1, 1), stride=(2, 2, 2))

        def aggregate(cin, cout):
            # 1x1x1 channel reduction followed by two 3x3x3 convs
            return nn.Sequential(
                BasicConv(cin, cout, is_3d=True, kernel_size=1, padding=0, stride=1),
                BasicConv(cout, cout, is_3d=True, kernel_size=3, padding=1, stride=1),
                BasicConv(cout, cout, is_3d=True, kernel_size=3, padding=1, stride=1),
            )

        # Creation order is kept identical to the original (parameter init RNG / state_dict order).
        self.conv1 = down_stage(c, c * 2)
        self.conv2 = down_stage(c * 2, c * 4)
        self.conv3 = down_stage(c * 4, c * 6)

        self.conv3_up = up_conv(c * 6, c * 4)
        self.conv2_up = up_conv(c * 4, c * 2)
        self.conv1_up = up_conv(c * 2, 1, final=True)

        self.agg_0 = aggregate(c * 8, c * 4)
        self.agg_1 = aggregate(c * 4, c * 2)

        self.feature_att_8 = FeatureAtt(c * 2, 64)
        self.feature_att_16 = FeatureAtt(c * 4, 192)
        self.feature_att_32 = FeatureAtt(c * 6, 160)
        self.feature_att_up_16 = FeatureAtt(c * 4, 192)
        self.feature_att_up_8 = FeatureAtt(c * 2, 64)

    def forward(self, x, features):
        """Run the hourglass on volume ``x`` using ``features[1..3]`` for gating."""
        # encoder
        enc1 = self.feature_att_8(self.conv1(x), features[1])
        enc2 = self.feature_att_16(self.conv2(enc1), features[2])
        enc3 = self.feature_att_32(self.conv3(enc2), features[3])

        # decoder, level 2
        dec2 = torch.cat((self.conv3_up(enc3), enc2), dim=1)
        dec2 = self.feature_att_up_16(self.agg_0(dec2), features[2])

        # decoder, level 1
        dec1 = torch.cat((self.conv2_up(dec2), enc1), dim=1)
        dec1 = self.feature_att_up_8(self.agg_1(dec1), features[1])

        return self.conv1_up(dec1)
|
|
|
|
def bilinear_sampler(img, coords, mode='bilinear', mask=False):
    """Wrapper for ``F.grid_sample`` that takes pixel coordinates.

    Only single-row images are supported (the assert requires H == 1 and a
    constant y coordinate).

    Args:
        img: [N, C, H, W] tensor with H == 1.
        coords: [N, H_out, W_out, 2] pixel coordinates, (x, y) on the last axis.
        mode: interpolation mode forwarded to ``F.grid_sample``
            (previously ignored, so sampling was always bilinear).
        mask: if True, also return a float mask that is 1 where the normalised x
            and the raw y both lie strictly inside (-1, 1).

    Returns:
        sampled tensor, or ``(sampled, mask)`` when ``mask`` is True.
    """
    H, W = img.shape[-2:]
    xgrid, ygrid = coords.split([1, 1], dim=-1)
    # Pixel x in [0, W-1] -> grid_sample's [-1, 1] (consistent with align_corners=True).
    xgrid = 2 * xgrid / (W - 1) - 1

    # y is not normalised; with H == 1 and align_corners=True grid_sample maps any y to row 0.
    assert torch.unique(ygrid).numel() == 1 and H == 1  # This is a stereo problem

    grid = torch.cat([xgrid, ygrid], dim=-1)
    img = F.grid_sample(img, grid, mode=mode, align_corners=True)

    if mask:
        valid = (xgrid > -1) & (ygrid > -1) & (xgrid < 1) & (ygrid < 1)
        return img, valid.float()

    return img
|
|
|
|
def context_upsample(disp_low, up_weights):
    """Upsample a single-channel map as a per-pixel weighted sum of 3x3 neighbours.

    Each high-resolution pixel takes the 3x3 zero-padded neighbourhood of the
    low-resolution pixel it falls in (nearest mapping) and combines it with its own
    9 weights.

    Args:
        disp_low: [b, 1, h, w] low-resolution map. The reshape to 9 channels below
            requires exactly one channel.
        up_weights: [b, 9, H, W] neighbourhood weights in ``F.unfold`` order
            (index 4 is the centre). The output size is taken from this tensor;
            the original design used H, W = 4h, 4w.

    Returns:
        [b, H, W] upsampled map.
    """
    b, c, h, w = disp_low.shape
    out_h, out_w = up_weights.shape[-2:]

    # [b, c*9, h*w] -> [b, 9c, h, w]: channel k holds neighbour k of each 3x3 window.
    disp_unfold = F.unfold(disp_low.reshape(b, c, h, w), 3, 1, 1).reshape(b, -1, h, w)
    # Nearest resize copies each low-res neighbourhood to every high-res pixel it covers.
    disp_unfold = F.interpolate(disp_unfold, (out_h, out_w), mode='nearest').reshape(b, 9, out_h, out_w)

    disp = (disp_unfold * up_weights).sum(1)

    return disp
|
|
|
|
def pool2x(x):
    """Halve spatial resolution with a 3x3 average pool (stride 2, zero padding 1).

    Uses ``F.avg_pool2d`` defaults otherwise, so padded zeros count toward each mean.
    """
    pooled = F.avg_pool2d(x, kernel_size=3, stride=2, padding=1)
    return pooled
|
|
|
|
def interp(x, dest):
    """Bilinearly resize ``x`` (align_corners=True) to the spatial size of ``dest``."""
    target_size = dest.shape[2:]
    return F.interpolate(x, target_size, mode='bilinear', align_corners=True)
|
|
|
|
def differentiable_warping(src_fea, src_proj, ref_proj, depth_samples, return_mask=False):
    """Warp source-view features into the reference view at every depth hypothesis.

    Reference pixel coordinates on the depth grid (height x width) are rescaled to
    the source-feature resolution (height1 x width1) before projection. Projected
    coordinates are therefore in src_fea pixel units, and the projection matrices
    are presumably given at that resolution (TODO confirm with callers).

    Args:
        src_fea: [B, C, H1, W1] source features.
        src_proj: [B, 4, 4] source projection matrix.
        ref_proj: [B, 4, 4] reference projection matrix.
        depth_samples: [B, Ndepth, H, W] depth hypotheses.
        return_mask: also return the validity mask.

    Returns:
        warped features [B, C, Ndepth, H, W]; with ``return_mask`` also a bool mask
        [B, Ndepth, H, W] that is True where the projected depth is > 1e-2 and the
        projected point lies inside the source feature map.
    """
    batch, num_depth, height, width = depth_samples.size()
    height1, width1 = src_fea.size()[2:]
    device = depth_samples.device

    # The projection matrix and the rotated pixel rays are computed without autograd.
    with torch.no_grad():
        if batch == 2:
            # Per-sample inverses (special case kept from the original implementation).
            inv_ref_proj = torch.stack([torch.inverse(ref_proj[i]) for i in range(batch)], dim=0)
            assert (not torch.isnan(inv_ref_proj).any()), "nan in inverse(ref_proj)"
            proj = torch.matmul(src_proj, inv_ref_proj)
        else:
            proj = torch.matmul(src_proj, torch.inverse(ref_proj))
            assert (not torch.isnan(proj).any()), "nan in proj"

        rot = proj[:, :3, :3]  # [B,3,3]
        trans = proj[:, :3, 3:4]  # [B,3,1]

        # Row-major (i, j) pixel grid flattened to H*W, scaled to src_fea resolution.
        # Built by broadcasting instead of torch.meshgrid, which warns when `indexing` is omitted.
        ys = torch.arange(0, height, dtype=torch.float32, device=device)
        xs = torch.arange(0, width, dtype=torch.float32, device=device)
        y = ys.view(height, 1).expand(height, width).reshape(height * width) * (height1 / height)
        x = xs.view(1, width).expand(height, width).reshape(height * width) * (width1 / width)
        xyz = torch.stack((x, y, torch.ones_like(x)))  # [3, H*W]
        xyz = torch.unsqueeze(xyz, 0).repeat(batch, 1, 1)  # [B, 3, H*W]
        rot_xyz = torch.matmul(rot, xyz)  # [B, 3, H*W]

    # Scale each ray by every depth hypothesis (broadcast over Ndepth), then translate.
    rot_depth_xyz = rot_xyz.unsqueeze(2) * depth_samples.view(batch, 1, num_depth, height * width)  # [B, 3, Ndepth, H*W]
    proj_xyz = rot_depth_xyz + trans.view(batch, 3, 1, 1)  # [B, 3, Ndepth, H*W]

    # Avoid negative / near-zero depth: move such points to a sentinel that projects
    # outside the *source feature map*, so grid_sample's zero padding applies there.
    # (Using the depth-grid size here would land inside the image when width1 > width.)
    valid_mask = proj_xyz[:, 2:] > 1e-2
    proj_xyz[:, 0:1][~valid_mask] = width1
    proj_xyz[:, 1:2][~valid_mask] = height1
    proj_xyz[:, 2:3][~valid_mask] = 1
    proj_xy = proj_xyz[:, :2, :, :] / proj_xyz[:, 2:3, :, :]  # [B, 2, Ndepth, H*W]
    # Bounds are in src_fea pixel units, matching the normalisation below.
    valid_mask = valid_mask & (proj_xy[:, 0:1] >= 0) & (proj_xy[:, 0:1] < width1) \
        & (proj_xy[:, 1:2] >= 0) & (proj_xy[:, 1:2] < height1)

    proj_x_normalized = proj_xy[:, 0, :, :] / ((width1 - 1) / 2) - 1  # [B, Ndepth, H*W]
    proj_y_normalized = proj_xy[:, 1, :, :] / ((height1 - 1) / 2) - 1
    grid = torch.stack((proj_x_normalized, proj_y_normalized), dim=3)  # [B, Ndepth, H*W, 2]

    dim = src_fea.size()[1]
    warped_src_fea = F.grid_sample(src_fea, grid.view(batch, num_depth * height, width, 2), mode='bilinear',
                                   padding_mode='zeros', align_corners=True)
    warped_src_fea = warped_src_fea.view(batch, dim, num_depth, height, width)
    if return_mask:
        valid_mask = valid_mask.view(batch, num_depth, height, width)
        return warped_src_fea, valid_mask
    else:
        return warped_src_fea
|
|
|
|
def depth_normalization(depth, inverse_depth_min, inverse_depth_max):
    """Map depth to its normalised position in the inverse-depth range.

    Returns ``(1/(depth + 1e-5) - inverse_depth_max) / (inverse_depth_min - inverse_depth_max)``,
    i.e. 0 at ``1/inverse_depth_max`` and 1 at ``1/inverse_depth_min``. The 1e-5
    offset guards against division by zero for zero depth.
    """
    inv = 1.0 / (depth + 1e-5)
    span = inverse_depth_min - inverse_depth_max
    return (inv - inverse_depth_max) / span
|
|
|
|
def depth_unnormalization(normalized_depth, inverse_depth_min, inverse_depth_max):
    """Inverse of the normalisation: map a position in the inverse-depth range back to depth.

    0 maps to ``1/inverse_depth_max`` and 1 maps to ``1/inverse_depth_min``.
    """
    span = inverse_depth_min - inverse_depth_max
    return 1.0 / (inverse_depth_max + normalized_depth * span)