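"""Shared building blocks for the stereo matching network: 2-D/3-D convolution +
batch-norm helpers, concatenation and group-wise correlation cost-volume
construction, soft-argmin disparity regression, and a residual BasicBlock used
by the feature extractor."""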
from __future__ import print_function
import torch
import torch.nn as nn
import torch.utils.data
from torch.autograd import Variable
from torch.autograd.function import Function
import torch.nn.functional as F
import numpy as np


def convbn(in_channels, out_channels, kernel_size, stride, pad, dilation):
    """2-D convolution (bias-free) followed by batch normalization."""
    return nn.Sequential(nn.Conv2d(in_channels, out_channels, kernel_size=kernel_size, stride=stride,
                                   padding=dilation if dilation > 1 else pad, dilation=dilation, bias=False),
                         nn.BatchNorm2d(out_channels))


def convbn_3d(in_channels, out_channels, kernel_size, stride, pad):
    """3-D convolution (bias-free) followed by batch normalization."""
    return nn.Sequential(nn.Conv3d(in_channels, out_channels, kernel_size=kernel_size, stride=stride,
                                   padding=pad, bias=False),
                         nn.BatchNorm3d(out_channels))


def disparity_regression(x, maxdisp):
    """Soft-argmin disparity regression.

    x is a probability volume of shape [B, maxdisp, H, W] (e.g. a softmax over
    the disparity dimension); the expected disparity sum_d d * p(d) is returned
    with shape [B, H, W].
    """
    assert len(x.shape) == 4
    disp_values = torch.arange(0, maxdisp, dtype=x.dtype, device=x.device)
    disp_values = disp_values.view(1, maxdisp, 1, 1)
    return torch.sum(x * disp_values, 1, keepdim=False)


def build_concat_volume(refimg_fea, targetimg_fea, maxdisp):
    """Concatenation cost volume of shape [B, 2C, maxdisp, H, W].

    For each disparity i, the target (right) features are shifted by i pixels
    and concatenated with the reference (left) features along the channel
    dimension; columns without a valid match remain zero.
    """
    B, C, H, W = refimg_fea.shape
    volume = refimg_fea.new_zeros([B, 2 * C, maxdisp, H, W])
    for i in range(maxdisp):
        if i > 0:
            volume[:, :C, i, :, i:] = refimg_fea[:, :, :, i:]
            volume[:, C:, i, :, i:] = targetimg_fea[:, :, :, :-i]
        else:
            volume[:, :C, i, :, :] = refimg_fea
            volume[:, C:, i, :, :] = targetimg_fea
    volume = volume.contiguous()
    return volume


def groupwise_correlation(fea1, fea2, num_groups):
    """Group-wise correlation: split the channels into num_groups groups and
    average the per-channel products within each group, giving [B, num_groups, H, W]."""
    B, C, H, W = fea1.shape
    assert C % num_groups == 0
    channels_per_group = C // num_groups
    cost = (fea1 * fea2).view([B, num_groups, channels_per_group, H, W]).mean(dim=2)
    assert cost.shape == (B, num_groups, H, W)
    return cost


def build_gwc_volume(refimg_fea, targetimg_fea, maxdisp, num_groups):
    """Group-wise correlation cost volume of shape [B, num_groups, maxdisp, H, W].

    For each disparity i, the reference features are correlated with the target
    features shifted by i pixels; columns without a valid match remain zero.
    """
    B, C, H, W = refimg_fea.shape
    volume = refimg_fea.new_zeros([B, num_groups, maxdisp, H, W])
    for i in range(maxdisp):
        if i > 0:
            volume[:, :, i, :, i:] = groupwise_correlation(refimg_fea[:, :, :, i:], targetimg_fea[:, :, :, :-i],
                                                           num_groups)
        else:
            volume[:, :, i, :, :] = groupwise_correlation(refimg_fea, targetimg_fea, num_groups)
    volume = volume.contiguous()
    return volume


class BasicBlock(nn.Module):
    """Residual block with two 3x3 conv+BN layers and an identity (or downsample)
    skip connection; no ReLU is applied after the addition."""
    expansion = 1

    def __init__(self, inplanes, planes, stride, downsample, pad, dilation):
        super(BasicBlock, self).__init__()

        self.conv1 = nn.Sequential(convbn(inplanes, planes, 3, stride, pad, dilation),
                                   nn.ReLU(inplace=True))

        self.conv2 = convbn(planes, planes, 3, 1, pad, dilation)

        self.downsample = downsample
        self.stride = stride

    def forward(self, x):
        out = self.conv1(x)
        out = self.conv2(out)

        if self.downsample is not None:
            x = self.downsample(x)

        out += x

        return out
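

if __name__ == "__main__":
    # Minimal usage sketch (illustrative only): the tensor sizes below and the
    # mean over the group dimension are assumptions standing in for a real 3-D
    # cost-aggregation network, which this file does not define.
    torch.manual_seed(0)
    B, C, H, W, maxdisp, num_groups = 1, 32, 16, 32, 12, 8

    left_fea = torch.randn(B, C, H, W)
    right_fea = torch.randn(B, C, H, W)

    concat_vol = build_concat_volume(left_fea, right_fea, maxdisp)        # [B, 2C, D, H, W]
    gwc_vol = build_gwc_volume(left_fea, right_fea, maxdisp, num_groups)  # [B, G, D, H, W]

    # Collapse the group dimension as a stand-in for cost aggregation, then
    # softmax over disparity and regress a soft disparity map.
    cost = gwc_vol.mean(dim=1)                  # [B, D, H, W]
    prob = F.softmax(cost, dim=1)               # probabilities over disparity
    disp = disparity_regression(prob, maxdisp)  # [B, H, W]

    # The residual block keeps the spatial size when stride=1 and no downsample
    # module is given.
    block = BasicBlock(C, C, stride=1, downsample=None, pad=1, dilation=1)
    assert block(left_fea).shape == left_fea.shape

    print(concat_vol.shape, gwc_vol.shape, disp.shape)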