# IGEV/IGEV-MVS/core/submodule.py
import torch
import torch.nn as nn
import torch.nn.functional as F
import math
class SubModule(nn.Module):
def __init__(self):
super(SubModule, self).__init__()
def weight_init(self):
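        # Kaiming/He-style initialization: zero-mean Gaussian with
        # std = sqrt(2 / n), where n = kernel area (or volume) * out_channels;
        # norm layers start as the identity (weight 1, bias 0).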
for m in self.modules():
if isinstance(m, nn.Conv2d):
n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
m.weight.data.normal_(0, math.sqrt(2. / n))
elif isinstance(m, nn.Conv3d):
n = m.kernel_size[0] * m.kernel_size[1] * m.kernel_size[2] * m.out_channels
m.weight.data.normal_(0, math.sqrt(2. / n))
elif isinstance(m, nn.BatchNorm2d):
m.weight.data.fill_(1)
m.bias.data.zero_()
elif isinstance(m, nn.BatchNorm3d):
m.weight.data.fill_(1)
m.bias.data.zero_()
class BasicConv(nn.Module):
def __init__(self, in_channels, out_channels, deconv=False, is_3d=False, bn=True, relu=True, **kwargs):
super(BasicConv, self).__init__()
self.relu = relu
self.use_bn = bn
if is_3d:
if deconv:
self.conv = nn.ConvTranspose3d(in_channels, out_channels, bias=False, **kwargs)
else:
self.conv = nn.Conv3d(in_channels, out_channels, bias=False, **kwargs)
self.bn = nn.BatchNorm3d(out_channels)
else:
if deconv:
self.conv = nn.ConvTranspose2d(in_channels, out_channels, bias=False, **kwargs)
else:
self.conv = nn.Conv2d(in_channels, out_channels, bias=False, **kwargs)
self.bn = nn.BatchNorm2d(out_channels)
def forward(self, x):
x = self.conv(x)
if self.use_bn:
x = self.bn(x)
if self.relu:
            x = F.leaky_relu(x)  # functional form; avoids constructing a module each call
return x
class BasicConv_IN(nn.Module):
def __init__(self, in_channels, out_channels, deconv=False, is_3d=False, IN=True, relu=True, **kwargs):
super(BasicConv_IN, self).__init__()
self.relu = relu
self.use_in = IN
if is_3d:
if deconv:
self.conv = nn.ConvTranspose3d(in_channels, out_channels, bias=False, **kwargs)
else:
self.conv = nn.Conv3d(in_channels, out_channels, bias=False, **kwargs)
self.IN = nn.InstanceNorm3d(out_channels)
else:
if deconv:
self.conv = nn.ConvTranspose2d(in_channels, out_channels, bias=False, **kwargs)
else:
self.conv = nn.Conv2d(in_channels, out_channels, bias=False, **kwargs)
self.IN = nn.InstanceNorm2d(out_channels)
def forward(self, x):
x = self.conv(x)
if self.use_in:
x = self.IN(x)
if self.relu:
            x = F.leaky_relu(x)  # functional form; avoids constructing a module each call
return x
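# Conv2x / Conv2x_IN: one stride-2 (de)convolution followed by fusion with a
# skip tensor `rem`, either by channel concatenation or by addition.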
class Conv2x(nn.Module):
def __init__(self, in_channels, out_channels, deconv=False, is_3d=False, concat=True, keep_concat=True, bn=True, relu=True, keep_dispc=False):
super(Conv2x, self).__init__()
self.concat = concat
self.is_3d = is_3d
if deconv and is_3d:
kernel = (4, 4, 4)
elif deconv:
kernel = 4
else:
kernel = 3
if deconv and is_3d and keep_dispc:
kernel = (1, 4, 4)
stride = (1, 2, 2)
padding = (0, 1, 1)
self.conv1 = BasicConv(in_channels, out_channels, deconv, is_3d, bn=True, relu=True, kernel_size=kernel, stride=stride, padding=padding)
else:
self.conv1 = BasicConv(in_channels, out_channels, deconv, is_3d, bn=True, relu=True, kernel_size=kernel, stride=2, padding=1)
if self.concat:
mul = 2 if keep_concat else 1
self.conv2 = BasicConv(out_channels*2, out_channels*mul, False, is_3d, bn, relu, kernel_size=3, stride=1, padding=1)
else:
self.conv2 = BasicConv(out_channels, out_channels, False, is_3d, bn, relu, kernel_size=3, stride=1, padding=1)
def forward(self, x, rem):
x = self.conv1(x)
        if x.shape != rem.shape:
            # match every spatial dim so the resize also works for 5-D (3-D conv) tensors
            x = F.interpolate(x, size=rem.shape[2:], mode='nearest')
if self.concat:
x = torch.cat((x, rem), 1)
else:
x = x + rem
x = self.conv2(x)
return x
class Conv2x_IN(nn.Module):
def __init__(self, in_channels, out_channels, deconv=False, is_3d=False, concat=True, keep_concat=True, IN=True, relu=True, keep_dispc=False):
super(Conv2x_IN, self).__init__()
self.concat = concat
self.is_3d = is_3d
if deconv and is_3d:
kernel = (4, 4, 4)
elif deconv:
kernel = 4
else:
kernel = 3
if deconv and is_3d and keep_dispc:
kernel = (1, 4, 4)
stride = (1, 2, 2)
padding = (0, 1, 1)
self.conv1 = BasicConv_IN(in_channels, out_channels, deconv, is_3d, IN=True, relu=True, kernel_size=kernel, stride=stride, padding=padding)
else:
self.conv1 = BasicConv_IN(in_channels, out_channels, deconv, is_3d, IN=True, relu=True, kernel_size=kernel, stride=2, padding=1)
if self.concat:
mul = 2 if keep_concat else 1
self.conv2 = BasicConv_IN(out_channels*2, out_channels*mul, False, is_3d, IN, relu, kernel_size=3, stride=1, padding=1)
else:
self.conv2 = BasicConv_IN(out_channels, out_channels, False, is_3d, IN, relu, kernel_size=3, stride=1, padding=1)
def forward(self, x, rem):
x = self.conv1(x)
        if x.shape != rem.shape:
            # match every spatial dim so the resize also works for 5-D (3-D conv) tensors
            x = F.interpolate(x, size=rem.shape[2:], mode='nearest')
if self.concat:
x = torch.cat((x, rem), 1)
else:
x = x + rem
x = self.conv2(x)
return x
class ConvReLU(nn.Module):
def __init__(self, in_channels, out_channels, kernel_size=3, stride=1, pad=1, dilation=1):
super(ConvReLU, self).__init__()
self.conv = nn.Conv2d(in_channels, out_channels, kernel_size, stride=stride, padding=pad, dilation=dilation, bias=False)
    def forward(self, x):
        return F.relu(self.conv(x), inplace=True)
class DepthInitialization(nn.Module):
def __init__(self, num_sample):
super(DepthInitialization, self).__init__()
self.num_sample = num_sample
def forward(self, inverse_depth_min, inverse_depth_max, height, width, device):
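        # num_sample hypotheses spaced uniformly in inverse depth between
        # inverse_depth_max and inverse_depth_min, returned as metric depth.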
batch = inverse_depth_min.size()[0]
index = torch.arange(0, self.num_sample, 1, device=device).view(1, self.num_sample, 1, 1).float()
normalized_sample = index.repeat(batch, 1, height, width) / (self.num_sample-1)
depth_sample = inverse_depth_max + normalized_sample * (inverse_depth_min - inverse_depth_max)
depth_sample = 1.0 / depth_sample
return depth_sample
class PixelViewWeight(nn.Module):
def __init__(self, G):
super(PixelViewWeight, self).__init__()
self.conv = nn.Sequential(
ConvReLU(G, 16),
nn.Conv2d(16, 1, 1, stride=1, padding=0),
)
def forward(self, x):
# x: [B, G, N, H, W]
batch, dim, num_depth, height, width = x.size()
x = x.permute(0, 2, 1, 3, 4).contiguous()
x = x.view(batch*num_depth, dim, height, width) # [B*N,G,H,W]
        x = self.conv(x).view(batch, num_depth, height, width)
        x = torch.softmax(x, dim=1)
x = torch.max(x, dim=1)[0]
return x.unsqueeze(1)
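# PixelViewWeight produces a [B, 1, H, W] per-view confidence map: the peak of
# the softmax over the N depth hypotheses of the grouped similarity volume.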
class FeatureAtt(nn.Module):
def __init__(self, cv_chan, feat_chan):
super(FeatureAtt, self).__init__()
self.feat_att = nn.Sequential(
BasicConv(feat_chan, feat_chan//2, kernel_size=1, stride=1, padding=0),
nn.Conv2d(feat_chan//2, cv_chan, 1))
def forward(self, cv, feat):
        '''Modulate cost volume `cv` with a sigmoid attention map computed
        from the 2-D feature map `feat`, broadcast over the depth axis.'''
feat_att = self.feat_att(feat).unsqueeze(2)
cv = torch.sigmoid(feat_att)*cv
return cv
class hourglass(nn.Module):
def __init__(self, in_channels):
super(hourglass, self).__init__()
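        # Encoder-decoder ("hourglass") over the cost volume: three stride-2
        # 3-D conv stages down, two skip-connected deconv stages up, each scale
        # modulated by 2-D image features via FeatureAtt; conv1_up emits a
        # single-channel volume.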
self.conv1 = nn.Sequential(BasicConv(in_channels, in_channels*2, is_3d=True, bn=True, relu=True, kernel_size=3,
padding=1, stride=2, dilation=1),
BasicConv(in_channels*2, in_channels*2, is_3d=True, bn=True, relu=True, kernel_size=3,
padding=1, stride=1, dilation=1))
self.conv2 = nn.Sequential(BasicConv(in_channels*2, in_channels*4, is_3d=True, bn=True, relu=True, kernel_size=3,
padding=1, stride=2, dilation=1),
BasicConv(in_channels*4, in_channels*4, is_3d=True, bn=True, relu=True, kernel_size=3,
padding=1, stride=1, dilation=1))
self.conv3 = nn.Sequential(BasicConv(in_channels*4, in_channels*6, is_3d=True, bn=True, relu=True, kernel_size=3,
padding=1, stride=2, dilation=1),
BasicConv(in_channels*6, in_channels*6, is_3d=True, bn=True, relu=True, kernel_size=3,
padding=1, stride=1, dilation=1))
self.conv3_up = BasicConv(in_channels*6, in_channels*4, deconv=True, is_3d=True, bn=True,
relu=True, kernel_size=(4, 4, 4), padding=(1, 1, 1), stride=(2, 2, 2))
self.conv2_up = BasicConv(in_channels*4, in_channels*2, deconv=True, is_3d=True, bn=True,
relu=True, kernel_size=(4, 4, 4), padding=(1, 1, 1), stride=(2, 2, 2))
self.conv1_up = BasicConv(in_channels*2, 1, deconv=True, is_3d=True, bn=False,
relu=False, kernel_size=(4, 4, 4), padding=(1, 1, 1), stride=(2, 2, 2))
self.agg_0 = nn.Sequential(BasicConv(in_channels*8, in_channels*4, is_3d=True, kernel_size=1, padding=0, stride=1),
BasicConv(in_channels*4, in_channels*4, is_3d=True, kernel_size=3, padding=1, stride=1),
BasicConv(in_channels*4, in_channels*4, is_3d=True, kernel_size=3, padding=1, stride=1),)
self.agg_1 = nn.Sequential(BasicConv(in_channels*4, in_channels*2, is_3d=True, kernel_size=1, padding=0, stride=1),
BasicConv(in_channels*2, in_channels*2, is_3d=True, kernel_size=3, padding=1, stride=1),
BasicConv(in_channels*2, in_channels*2, is_3d=True, kernel_size=3, padding=1, stride=1))
self.feature_att_8 = FeatureAtt(in_channels*2, 64)
self.feature_att_16 = FeatureAtt(in_channels*4, 192)
self.feature_att_32 = FeatureAtt(in_channels*6, 160)
self.feature_att_up_16 = FeatureAtt(in_channels*4, 192)
self.feature_att_up_8 = FeatureAtt(in_channels*2, 64)
def forward(self, x, features):
conv1 = self.conv1(x)
conv1 = self.feature_att_8(conv1, features[1])
conv2 = self.conv2(conv1)
conv2 = self.feature_att_16(conv2, features[2])
conv3 = self.conv3(conv2)
conv3 = self.feature_att_32(conv3, features[3])
conv3_up = self.conv3_up(conv3)
conv2 = torch.cat((conv3_up, conv2), dim=1)
conv2 = self.agg_0(conv2)
conv2 = self.feature_att_up_16(conv2, features[2])
conv2_up = self.conv2_up(conv2)
conv1 = torch.cat((conv2_up, conv1), dim=1)
conv1 = self.agg_1(conv1)
conv1 = self.feature_att_up_8(conv1, features[1])
conv = self.conv1_up(conv1)
return conv
def bilinear_sampler(img, coords, mode='bilinear', mask=False):
""" Wrapper for grid_sample, uses pixel coordinates """
H, W = img.shape[-2:]
xgrid, ygrid = coords.split([1,1], dim=-1)
xgrid = 2*xgrid/(W-1) - 1
assert torch.unique(ygrid).numel() == 1 and H == 1 # This is a stereo problem
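    # H == 1 makes the y-axis degenerate, so only x is normalized to [-1, 1];
    # ygrid (constant) is expected to already lie in grid_sample's normalized range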
grid = torch.cat([xgrid, ygrid], dim=-1)
img = F.grid_sample(img, grid, align_corners=True)
if mask:
mask = (xgrid > -1) & (ygrid > -1) & (xgrid < 1) & (ygrid < 1)
return img, mask.float()
return img
def context_upsample(disp_low, up_weights):
    # disp_low:   (b, 1, h, w)     low-resolution disparity/depth
    # up_weights: (b, 9, 4*h, 4*w) per-pixel weights over the 3x3 neighborhood
    # returns:    (b, 4*h, 4*w)
    b, c, h, w = disp_low.shape
    disp_unfold = F.unfold(disp_low, 3, 1, 1).reshape(b, -1, h, w)            # (b, 9, h, w)
    disp_unfold = F.interpolate(disp_unfold, (h * 4, w * 4), mode='nearest')  # (b, 9, 4h, 4w)
    disp = (disp_unfold * up_weights).sum(1)
    return disp
def pool2x(x):
return F.avg_pool2d(x, 3, stride=2, padding=1)
def interp(x, dest):
interp_args = {'mode': 'bilinear', 'align_corners': True}
return F.interpolate(x, dest.shape[2:], **interp_args)
def differentiable_warping(src_fea, src_proj, ref_proj, depth_samples, return_mask=False):
# src_fea: [B, C, H, W]
# src_proj: [B, 4, 4]
# ref_proj: [B, 4, 4]
# depth_samples: [B, Ndepth, H, W]
# out: [B, C, Ndepth, H, W]
batch, num_depth, height, width = depth_samples.size()
height1, width1 = src_fea.size()[2:]
with torch.no_grad():
if batch==2:
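            # invert per sample; presumably a workaround for numerical issues
            # with batched torch.inverse at this batch size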
inv_ref_proj = []
for i in range(batch):
inv_ref_proj.append(torch.inverse(ref_proj[i]).unsqueeze(0))
inv_ref_proj = torch.cat(inv_ref_proj, dim=0)
assert (not torch.isnan(inv_ref_proj).any()), "nan in inverse(ref_proj)"
proj = torch.matmul(src_proj, inv_ref_proj)
else:
proj = torch.matmul(src_proj, torch.inverse(ref_proj))
assert (not torch.isnan(proj).any()), "nan in proj"
rot = proj[:, :3, :3] # [B,3,3]
trans = proj[:, :3, 3:4] # [B,3,1]
y, x = torch.meshgrid([torch.arange(0, height, dtype=torch.float32, device=depth_samples.device),
torch.arange(0, width, dtype=torch.float32, device=depth_samples.device)])
y, x = y.contiguous(), x.contiguous()
y, x = y.view(height * width), x.view(height * width)
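        # rescale the reference pixel grid to the source-feature resolution
        # (width1, height1) before projecting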
y = y*(height1/height)
x = x*(width1/width)
xyz = torch.stack((x, y, torch.ones_like(x))) # [3, H*W]
xyz = torch.unsqueeze(xyz, 0).repeat(batch, 1, 1) # [B, 3, H*W]
rot_xyz = torch.matmul(rot, xyz) # [B, 3, H*W]
rot_depth_xyz = rot_xyz.unsqueeze(2).repeat(1, 1, num_depth, 1) * depth_samples.view(batch, 1, num_depth,
height * width) # [B, 3, Ndepth, H*W]
proj_xyz = rot_depth_xyz + trans.view(batch, 3, 1, 1) # [B, 3, Ndepth, H*W]
# avoid negative depth
valid_mask = proj_xyz[:, 2:] > 1e-2
proj_xyz[:, 0:1][~valid_mask] = width
proj_xyz[:, 1:2][~valid_mask] = height
proj_xyz[:, 2:3][~valid_mask] = 1
proj_xy = proj_xyz[:, :2, :, :] / proj_xyz[:, 2:3, :, :] # [B, 2, Ndepth, H*W]
valid_mask = valid_mask & (proj_xy[:, 0:1] >=0) & (proj_xy[:, 0:1] < width) \
& (proj_xy[:, 1:2] >=0) & (proj_xy[:, 1:2] < height)
proj_x_normalized = proj_xy[:, 0, :, :] / ((width1 - 1) / 2) - 1 # [B, Ndepth, H*W]
proj_y_normalized = proj_xy[:, 1, :, :] / ((height1 - 1) / 2) - 1
proj_xy = torch.stack((proj_x_normalized, proj_y_normalized), dim=3) # [B, Ndepth, H*W, 2]
grid = proj_xy
dim = src_fea.size()[1]
warped_src_fea = F.grid_sample(src_fea, grid.view(batch, num_depth * height, width, 2), mode='bilinear',
padding_mode='zeros',align_corners=True)
warped_src_fea = warped_src_fea.view(batch, dim, num_depth, height, width)
if return_mask:
valid_mask = valid_mask.view(batch,num_depth,height,width)
return warped_src_fea, valid_mask
else:
return warped_src_fea
def depth_normalization(depth, inverse_depth_min, inverse_depth_max):
'''convert depth map to the index in inverse range'''
inverse_depth = 1.0 / (depth+1e-5)
normalized_depth = (inverse_depth - inverse_depth_max) / (inverse_depth_min - inverse_depth_max)
return normalized_depth
def depth_unnormalization(normalized_depth, inverse_depth_min, inverse_depth_max):
'''convert the index in inverse range to depth map'''
inverse_depth = inverse_depth_max + normalized_depth * (inverse_depth_min - inverse_depth_max) # [B,1,H,W]
depth = 1.0 / inverse_depth
return depth
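

if __name__ == "__main__":
    # Minimal smoke-test sketch for the utilities above: checks that
    # DepthInitialization samples lie within [depth_min, depth_max], that
    # depth_normalization / depth_unnormalization round-trip, and that
    # differentiable_warping with identity projections reproduces the input.
    device = torch.device('cpu')
    depth_min, depth_max = 0.5, 10.0
    inv_min = torch.full((2, 1, 1, 1), 1.0 / depth_min)
    inv_max = torch.full((2, 1, 1, 1), 1.0 / depth_max)
    samples = DepthInitialization(num_sample=8)(inv_min, inv_max, 4, 5, device)
    assert samples.shape == (2, 8, 4, 5)
    assert samples.min() >= depth_min - 1e-4 and samples.max() <= depth_max + 1e-4

    depth = torch.rand(2, 1, 4, 5) * (depth_max - depth_min) + depth_min
    norm = depth_normalization(depth, inv_min, inv_max)
    assert torch.allclose(depth_unnormalization(norm, inv_min, inv_max), depth, atol=1e-3)

    fea = torch.rand(2, 3, 4, 5)
    eye = torch.eye(4).unsqueeze(0).repeat(2, 1, 1)
    warped = differentiable_warping(fea, eye, eye, samples)
    assert torch.allclose(warped, fea.unsqueeze(2).expand(-1, -1, 8, -1, -1), atol=1e-4)
    print("submodule smoke test passed")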