import math

import torch
import torch.nn as nn
import torch.nn.functional as F


class SubModule(nn.Module):
    """Base module providing Kaiming-style weight initialization."""

    def __init__(self):
        super(SubModule, self).__init__()

    def weight_init(self):
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
                m.weight.data.normal_(0, math.sqrt(2. / n))
            elif isinstance(m, nn.Conv3d):
                n = m.kernel_size[0] * m.kernel_size[1] * m.kernel_size[2] * m.out_channels
                m.weight.data.normal_(0, math.sqrt(2. / n))
            elif isinstance(m, (nn.BatchNorm2d, nn.BatchNorm3d)):
                m.weight.data.fill_(1)
                m.bias.data.zero_()


class BasicConv(nn.Module):
    """Conv (or deconv), 2D or 3D, with optional BatchNorm and LeakyReLU."""

    def __init__(self, in_channels, out_channels, deconv=False, is_3d=False,
                 bn=True, relu=True, **kwargs):
        super(BasicConv, self).__init__()
        self.relu = relu
        self.use_bn = bn
        if is_3d:
            if deconv:
                self.conv = nn.ConvTranspose3d(in_channels, out_channels, bias=False, **kwargs)
            else:
                self.conv = nn.Conv3d(in_channels, out_channels, bias=False, **kwargs)
            self.bn = nn.BatchNorm3d(out_channels)
        else:
            if deconv:
                self.conv = nn.ConvTranspose2d(in_channels, out_channels, bias=False, **kwargs)
            else:
                self.conv = nn.Conv2d(in_channels, out_channels, bias=False, **kwargs)
            self.bn = nn.BatchNorm2d(out_channels)

    def forward(self, x):
        x = self.conv(x)
        if self.use_bn:
            x = self.bn(x)
        if self.relu:
            x = F.leaky_relu(x)
        return x


class BasicConv_IN(nn.Module):
    """Same as BasicConv, but with InstanceNorm instead of BatchNorm."""

    def __init__(self, in_channels, out_channels, deconv=False, is_3d=False,
                 IN=True, relu=True, **kwargs):
        super(BasicConv_IN, self).__init__()
        self.relu = relu
        self.use_in = IN
        if is_3d:
            if deconv:
                self.conv = nn.ConvTranspose3d(in_channels, out_channels, bias=False, **kwargs)
            else:
                self.conv = nn.Conv3d(in_channels, out_channels, bias=False, **kwargs)
            self.IN = nn.InstanceNorm3d(out_channels)
        else:
            if deconv:
                self.conv = nn.ConvTranspose2d(in_channels, out_channels, bias=False, **kwargs)
            else:
                self.conv = nn.Conv2d(in_channels, out_channels, bias=False, **kwargs)
            self.IN = nn.InstanceNorm2d(out_channels)

    def forward(self, x):
        x = self.conv(x)
        if self.use_in:
            x = self.IN(x)
        if self.relu:
            x = F.leaky_relu(x)
        return x


class Conv2x(nn.Module):
    """Downsample (or upsample with deconv) by 2x, then fuse with a skip tensor."""

    def __init__(self, in_channels, out_channels, deconv=False, is_3d=False,
                 concat=True, keep_concat=True, bn=True, relu=True, keep_dispc=False):
        super(Conv2x, self).__init__()
        self.concat = concat
        self.is_3d = is_3d
        if deconv and is_3d:
            kernel = (4, 4, 4)
        elif deconv:
            kernel = 4
        else:
            kernel = 3

        if deconv and is_3d and keep_dispc:
            # keep the disparity dimension: stride only over height/width
            kernel = (1, 4, 4)
            stride = (1, 2, 2)
            padding = (0, 1, 1)
            self.conv1 = BasicConv(in_channels, out_channels, deconv, is_3d,
                                   bn=True, relu=True, kernel_size=kernel,
                                   stride=stride, padding=padding)
        else:
            self.conv1 = BasicConv(in_channels, out_channels, deconv, is_3d,
                                   bn=True, relu=True, kernel_size=kernel,
                                   stride=2, padding=1)

        if self.concat:
            mul = 2 if keep_concat else 1
            self.conv2 = BasicConv(out_channels * 2, out_channels * mul, False, is_3d,
                                   bn, relu, kernel_size=3, stride=1, padding=1)
        else:
            self.conv2 = BasicConv(out_channels, out_channels, False, is_3d,
                                   bn, relu, kernel_size=3, stride=1, padding=1)

    def forward(self, x, rem):
        x = self.conv1(x)
        if x.shape != rem.shape:
            x = F.interpolate(x, size=(rem.shape[-2], rem.shape[-1]), mode='nearest')
        if self.concat:
            x = torch.cat((x, rem), 1)
        else:
            x = x + rem
        x = self.conv2(x)
        return x
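
# --- Illustrative usage sketch (not part of the original module) ------------
# Conv2x pairs a stride-2 conv (or deconv) with skip fusion: conv1 changes
# resolution by 2x, then the result is concatenated with (or added to) `rem`
# and mixed by conv2. The tensor shapes below are assumptions for the demo.
def _demo_conv2x():
    down = Conv2x(16, 32)                 # halves H and W
    up = Conv2x(64, 16, deconv=True)      # doubles H and W
    x = torch.randn(1, 16, 64, 64)
    rem = torch.randn(1, 32, 32, 32)      # skip feature at half resolution
    y = down(x, rem)                      # [1, 64, 32, 32]; concat doubles channels
    z = up(y, x)                          # [1, 32, 64, 64]
    return y.shape, z.shape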
class Conv2x_IN(nn.Module):
    """Same as Conv2x, but built from BasicConv_IN (InstanceNorm) blocks."""

    def __init__(self, in_channels, out_channels, deconv=False, is_3d=False,
                 concat=True, keep_concat=True, IN=True, relu=True, keep_dispc=False):
        super(Conv2x_IN, self).__init__()
        self.concat = concat
        self.is_3d = is_3d
        if deconv and is_3d:
            kernel = (4, 4, 4)
        elif deconv:
            kernel = 4
        else:
            kernel = 3

        if deconv and is_3d and keep_dispc:
            # keep the disparity dimension: stride only over height/width
            kernel = (1, 4, 4)
            stride = (1, 2, 2)
            padding = (0, 1, 1)
            self.conv1 = BasicConv_IN(in_channels, out_channels, deconv, is_3d,
                                      IN=True, relu=True, kernel_size=kernel,
                                      stride=stride, padding=padding)
        else:
            self.conv1 = BasicConv_IN(in_channels, out_channels, deconv, is_3d,
                                      IN=True, relu=True, kernel_size=kernel,
                                      stride=2, padding=1)

        if self.concat:
            mul = 2 if keep_concat else 1
            self.conv2 = BasicConv_IN(out_channels * 2, out_channels * mul, False, is_3d,
                                      IN, relu, kernel_size=3, stride=1, padding=1)
        else:
            self.conv2 = BasicConv_IN(out_channels, out_channels, False, is_3d,
                                      IN, relu, kernel_size=3, stride=1, padding=1)

    def forward(self, x, rem):
        x = self.conv1(x)
        if x.shape != rem.shape:
            x = F.interpolate(x, size=(rem.shape[-2], rem.shape[-1]), mode='nearest')
        if self.concat:
            x = torch.cat((x, rem), 1)
        else:
            x = x + rem
        x = self.conv2(x)
        return x


class ConvReLU(nn.Module):
    """2D conv followed by in-place ReLU."""

    def __init__(self, in_channels, out_channels, kernel_size=3, stride=1, pad=1, dilation=1):
        super(ConvReLU, self).__init__()
        self.conv = nn.Conv2d(in_channels, out_channels, kernel_size,
                              stride=stride, padding=pad, dilation=dilation, bias=False)

    def forward(self, x):
        return F.relu(self.conv(x), inplace=True)


class DepthInitialization(nn.Module):
    """Generate depth hypotheses uniformly spaced in inverse depth."""

    def __init__(self, num_sample):
        super(DepthInitialization, self).__init__()
        self.num_sample = num_sample

    def forward(self, inverse_depth_min, inverse_depth_max, height, width, device):
        batch = inverse_depth_min.size()[0]
        index = torch.arange(0, self.num_sample, 1, device=device).view(1, self.num_sample, 1, 1).float()
        normalized_sample = index.repeat(batch, 1, height, width) / (self.num_sample - 1)
        depth_sample = inverse_depth_max + normalized_sample * (inverse_depth_min - inverse_depth_max)
        depth_sample = 1.0 / depth_sample
        return depth_sample  # [B, num_sample, H, W]


class PixelViewWeight(nn.Module):
    """Per-pixel view weight: max over depth hypotheses of softmaxed matching scores."""

    def __init__(self, G):
        super(PixelViewWeight, self).__init__()
        self.conv = nn.Sequential(
            ConvReLU(G, 16),
            nn.Conv2d(16, 1, 1, stride=1, padding=0),
        )

    def forward(self, x):
        # x: [B, G, N, H, W]
        batch, dim, num_depth, height, width = x.size()
        x = x.permute(0, 2, 1, 3, 4).contiguous()
        x = x.view(batch * num_depth, dim, height, width)  # [B*N, G, H, W]
        x = self.conv(x).view(batch, num_depth, height, width)
        x = torch.softmax(x, dim=1)
        x = torch.max(x, dim=1)[0]
        return x.unsqueeze(1)  # [B, 1, H, W]


class FeatureAtt(nn.Module):
    """Modulate a cost volume with attention weights derived from image features."""

    def __init__(self, cv_chan, feat_chan):
        super(FeatureAtt, self).__init__()
        self.feat_att = nn.Sequential(
            BasicConv(feat_chan, feat_chan // 2, kernel_size=1, stride=1, padding=0),
            nn.Conv2d(feat_chan // 2, cv_chan, 1))

    def forward(self, cv, feat):
        # cv: [B, cv_chan, D, H, W], feat: [B, feat_chan, H, W]
        feat_att = self.feat_att(feat).unsqueeze(2)
        cv = torch.sigmoid(feat_att) * cv
        return cv
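
# --- Illustrative usage sketch (not part of the original module) ------------
# DepthInitialization spaces hypotheses uniformly in inverse depth, so the
# first sample lands on the far plane and the last on the near plane. The
# depth range below is an assumed example.
def _demo_depth_initialization():
    depth_min, depth_max = 0.5, 10.0
    inv_min = torch.full((1, 1, 1, 1), 1.0 / depth_min)
    inv_max = torch.full((1, 1, 1, 1), 1.0 / depth_max)
    samples = DepthInitialization(num_sample=8)(inv_min, inv_max, height=4, width=5,
                                                device=torch.device('cpu'))
    # samples: [1, 8, 4, 5]; samples[:, 0] == depth_max, samples[:, -1] == depth_min
    return samples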
class hourglass(nn.Module):
    """3D hourglass aggregation over a cost volume with feature-guided attention.

    The feature channel counts (64/192/160) are tied to the companion feature
    extractor that produces the `features` pyramid.
    """

    def __init__(self, in_channels):
        super(hourglass, self).__init__()

        self.conv1 = nn.Sequential(
            BasicConv(in_channels, in_channels * 2, is_3d=True, bn=True, relu=True,
                      kernel_size=3, padding=1, stride=2, dilation=1),
            BasicConv(in_channels * 2, in_channels * 2, is_3d=True, bn=True, relu=True,
                      kernel_size=3, padding=1, stride=1, dilation=1))

        self.conv2 = nn.Sequential(
            BasicConv(in_channels * 2, in_channels * 4, is_3d=True, bn=True, relu=True,
                      kernel_size=3, padding=1, stride=2, dilation=1),
            BasicConv(in_channels * 4, in_channels * 4, is_3d=True, bn=True, relu=True,
                      kernel_size=3, padding=1, stride=1, dilation=1))

        self.conv3 = nn.Sequential(
            BasicConv(in_channels * 4, in_channels * 6, is_3d=True, bn=True, relu=True,
                      kernel_size=3, padding=1, stride=2, dilation=1),
            BasicConv(in_channels * 6, in_channels * 6, is_3d=True, bn=True, relu=True,
                      kernel_size=3, padding=1, stride=1, dilation=1))

        self.conv3_up = BasicConv(in_channels * 6, in_channels * 4, deconv=True, is_3d=True,
                                  bn=True, relu=True, kernel_size=(4, 4, 4),
                                  padding=(1, 1, 1), stride=(2, 2, 2))

        self.conv2_up = BasicConv(in_channels * 4, in_channels * 2, deconv=True, is_3d=True,
                                  bn=True, relu=True, kernel_size=(4, 4, 4),
                                  padding=(1, 1, 1), stride=(2, 2, 2))

        self.conv1_up = BasicConv(in_channels * 2, 1, deconv=True, is_3d=True,
                                  bn=False, relu=False, kernel_size=(4, 4, 4),
                                  padding=(1, 1, 1), stride=(2, 2, 2))

        self.agg_0 = nn.Sequential(
            BasicConv(in_channels * 8, in_channels * 4, is_3d=True, kernel_size=1, padding=0, stride=1),
            BasicConv(in_channels * 4, in_channels * 4, is_3d=True, kernel_size=3, padding=1, stride=1),
            BasicConv(in_channels * 4, in_channels * 4, is_3d=True, kernel_size=3, padding=1, stride=1))

        self.agg_1 = nn.Sequential(
            BasicConv(in_channels * 4, in_channels * 2, is_3d=True, kernel_size=1, padding=0, stride=1),
            BasicConv(in_channels * 2, in_channels * 2, is_3d=True, kernel_size=3, padding=1, stride=1),
            BasicConv(in_channels * 2, in_channels * 2, is_3d=True, kernel_size=3, padding=1, stride=1))

        self.feature_att_8 = FeatureAtt(in_channels * 2, 64)
        self.feature_att_16 = FeatureAtt(in_channels * 4, 192)
        self.feature_att_32 = FeatureAtt(in_channels * 6, 160)
        self.feature_att_up_16 = FeatureAtt(in_channels * 4, 192)
        self.feature_att_up_8 = FeatureAtt(in_channels * 2, 64)

    def forward(self, x, features):
        # encoder: 1/8 -> 1/16 -> 1/32, each level gated by image features
        conv1 = self.conv1(x)
        conv1 = self.feature_att_8(conv1, features[1])

        conv2 = self.conv2(conv1)
        conv2 = self.feature_att_16(conv2, features[2])

        conv3 = self.conv3(conv2)
        conv3 = self.feature_att_32(conv3, features[3])

        # decoder: upsample, concatenate with the encoder level, aggregate
        conv3_up = self.conv3_up(conv3)
        conv2 = torch.cat((conv3_up, conv2), dim=1)
        conv2 = self.agg_0(conv2)
        conv2 = self.feature_att_up_16(conv2, features[2])

        conv2_up = self.conv2_up(conv2)
        conv1 = torch.cat((conv2_up, conv1), dim=1)
        conv1 = self.agg_1(conv1)
        conv1 = self.feature_att_up_8(conv1, features[1])

        conv = self.conv1_up(conv1)
        return conv


def bilinear_sampler(img, coords, mode='bilinear', mask=False):
    """Wrapper for grid_sample that takes pixel coordinates."""
    H, W = img.shape[-2:]
    xgrid, ygrid = coords.split([1, 1], dim=-1)
    xgrid = 2 * xgrid / (W - 1) - 1
    # this is a stereo problem: sampling happens along a single rectified row
    assert torch.unique(ygrid).numel() == 1 and H == 1

    grid = torch.cat([xgrid, ygrid], dim=-1)
    img = F.grid_sample(img, grid, align_corners=True)
    if mask:
        mask = (xgrid > -1) & (ygrid > -1) & (xgrid < 1) & (ygrid < 1)
        return img, mask.float()
    return img


def context_upsample(disp_low, up_weights):
    # disp_low:   [B, 1, H, W]     coarse disparity
    # up_weights: [B, 9, 4H, 4W]   convex weights over the 3x3 coarse neighborhood
    # returns:    [B, 4H, 4W]
    b, c, h, w = disp_low.shape
    disp_unfold = F.unfold(disp_low, 3, 1, 1).reshape(b, -1, h, w)
    disp_unfold = F.interpolate(disp_unfold, (h * 4, w * 4), mode='nearest').reshape(b, 9, h * 4, w * 4)
    disp = (disp_unfold * up_weights).sum(1)
    return disp


def pool2x(x):
    return F.avg_pool2d(x, 3, stride=2, padding=1)


def interp(x, dest):
    interp_args = {'mode': 'bilinear', 'align_corners': True}
    return F.interpolate(x, dest.shape[2:], **interp_args)
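
# --- Illustrative usage sketch (not part of the original module) ------------
# context_upsample expects the 9 weights per output pixel to be convex (e.g.
# softmax over the channel dim), so the output is a weighted blend of the 3x3
# coarse neighborhood. Shapes below are assumptions for the demo.
def _demo_context_upsample():
    b, h, w = 1, 8, 8
    disp_low = torch.rand(b, 1, h, w)                   # coarse disparity
    up_weights = torch.softmax(torch.randn(b, 9, h * 4, w * 4), dim=1)
    disp = context_upsample(disp_low, up_weights)       # [1, 32, 32]
    return disp.shape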
def differentiable_warping(src_fea, src_proj, ref_proj, depth_samples, return_mask=False):
    # src_fea: [B, C, H1, W1]
    # src_proj: [B, 4, 4]
    # ref_proj: [B, 4, 4]
    # depth_samples: [B, Ndepth, H, W]
    # out: [B, C, Ndepth, H, W]
    batch, num_depth, height, width = depth_samples.size()
    height1, width1 = src_fea.size()[2:]

    with torch.no_grad():
        if batch == 2:
            # invert per sample as a workaround for batched torch.inverse
            inv_ref_proj = []
            for i in range(batch):
                inv_ref_proj.append(torch.inverse(ref_proj[i]).unsqueeze(0))
            inv_ref_proj = torch.cat(inv_ref_proj, dim=0)
            assert (not torch.isnan(inv_ref_proj).any()), "nan in inverse(ref_proj)"
            proj = torch.matmul(src_proj, inv_ref_proj)
        else:
            proj = torch.matmul(src_proj, torch.inverse(ref_proj))
            assert (not torch.isnan(proj).any()), "nan in proj"

        rot = proj[:, :3, :3]  # [B, 3, 3]
        trans = proj[:, :3, 3:4]  # [B, 3, 1]
        y, x = torch.meshgrid([torch.arange(0, height, dtype=torch.float32, device=depth_samples.device),
                               torch.arange(0, width, dtype=torch.float32, device=depth_samples.device)])
        y, x = y.contiguous(), x.contiguous()
        y, x = y.view(height * width), x.view(height * width)
        # scale pixel coordinates to the source feature resolution
        y = y * (height1 / height)
        x = x * (width1 / width)
        xyz = torch.stack((x, y, torch.ones_like(x)))  # [3, H*W]
        xyz = torch.unsqueeze(xyz, 0).repeat(batch, 1, 1)  # [B, 3, H*W]
        rot_xyz = torch.matmul(rot, xyz)  # [B, 3, H*W]

        rot_depth_xyz = rot_xyz.unsqueeze(2).repeat(1, 1, num_depth, 1) \
            * depth_samples.view(batch, 1, num_depth, height * width)  # [B, 3, Ndepth, H*W]
        proj_xyz = rot_depth_xyz + trans.view(batch, 3, 1, 1)  # [B, 3, Ndepth, H*W]
        # avoid negative depth: push invalid points outside the image
        valid_mask = proj_xyz[:, 2:] > 1e-2
        proj_xyz[:, 0:1][~valid_mask] = width
        proj_xyz[:, 1:2][~valid_mask] = height
        proj_xyz[:, 2:3][~valid_mask] = 1
        proj_xy = proj_xyz[:, :2, :, :] / proj_xyz[:, 2:3, :, :]  # [B, 2, Ndepth, H*W]
        valid_mask = valid_mask & (proj_xy[:, 0:1] >= 0) & (proj_xy[:, 0:1] < width) \
            & (proj_xy[:, 1:2] >= 0) & (proj_xy[:, 1:2] < height)
        proj_x_normalized = proj_xy[:, 0, :, :] / ((width1 - 1) / 2) - 1  # [B, Ndepth, H*W]
        proj_y_normalized = proj_xy[:, 1, :, :] / ((height1 - 1) / 2) - 1
        proj_xy = torch.stack((proj_x_normalized, proj_y_normalized), dim=3)  # [B, Ndepth, H*W, 2]
        grid = proj_xy

    # sampling itself stays differentiable w.r.t. src_fea
    dim = src_fea.size()[1]
    warped_src_fea = F.grid_sample(src_fea, grid.view(batch, num_depth * height, width, 2),
                                   mode='bilinear', padding_mode='zeros', align_corners=True)
    warped_src_fea = warped_src_fea.view(batch, dim, num_depth, height, width)
    if return_mask:
        valid_mask = valid_mask.view(batch, num_depth, height, width)
        return warped_src_fea, valid_mask
    else:
        return warped_src_fea


def depth_normalization(depth, inverse_depth_min, inverse_depth_max):
    """Convert a depth map to its normalized index in the inverse-depth range."""
    inverse_depth = 1.0 / (depth + 1e-5)
    normalized_depth = (inverse_depth - inverse_depth_max) / (inverse_depth_min - inverse_depth_max)
    return normalized_depth


def depth_unnormalization(normalized_depth, inverse_depth_min, inverse_depth_max):
    """Convert a normalized index in the inverse-depth range back to a depth map."""
    inverse_depth = inverse_depth_max + normalized_depth * (inverse_depth_min - inverse_depth_max)  # [B, 1, H, W]
    depth = 1.0 / inverse_depth
    return depth
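
# --- Illustrative usage sketch (not part of the original module) ------------
# depth_normalization and depth_unnormalization are inverses of each other
# (up to the 1e-5 offset): the near plane maps to index ~1, the far plane to
# index ~0. The depth range here is an assumed example.
def _demo_depth_normalization_roundtrip():
    inv_min = torch.tensor(1.0 / 0.5)    # near plane at 0.5
    inv_max = torch.tensor(1.0 / 10.0)   # far plane at 10.0
    depth = torch.tensor([[[[0.5, 1.0, 10.0]]]])
    idx = depth_normalization(depth, inv_min, inv_max)
    back = depth_unnormalization(idx, inv_min, inv_max)
    return torch.allclose(back, depth, rtol=1e-3)       # True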