Initial Commit.
This commit is contained in:
parent
3bc74984bd
commit
e651d84ed8
0
IGEV-Stereo/core/__init__.py
Normal file
0
IGEV-Stereo/core/__init__.py
Normal file
362
IGEV-Stereo/core/extractor.py
Normal file
362
IGEV-Stereo/core/extractor.py
Normal file
@ -0,0 +1,362 @@
|
|||||||
|
import torch
|
||||||
|
import torch.nn as nn
|
||||||
|
import torch.nn.functional as F
|
||||||
|
from core.submodule import *
|
||||||
|
import timm
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
class ResidualBlock(nn.Module):
|
||||||
|
def __init__(self, in_planes, planes, norm_fn='group', stride=1):
|
||||||
|
super(ResidualBlock, self).__init__()
|
||||||
|
|
||||||
|
self.conv1 = nn.Conv2d(in_planes, planes, kernel_size=3, padding=1, stride=stride)
|
||||||
|
self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, padding=1)
|
||||||
|
self.relu = nn.ReLU(inplace=True)
|
||||||
|
|
||||||
|
num_groups = planes // 8
|
||||||
|
|
||||||
|
if norm_fn == 'group':
|
||||||
|
self.norm1 = nn.GroupNorm(num_groups=num_groups, num_channels=planes)
|
||||||
|
self.norm2 = nn.GroupNorm(num_groups=num_groups, num_channels=planes)
|
||||||
|
if not (stride == 1 and in_planes == planes):
|
||||||
|
self.norm3 = nn.GroupNorm(num_groups=num_groups, num_channels=planes)
|
||||||
|
|
||||||
|
elif norm_fn == 'batch':
|
||||||
|
self.norm1 = nn.BatchNorm2d(planes)
|
||||||
|
self.norm2 = nn.BatchNorm2d(planes)
|
||||||
|
if not (stride == 1 and in_planes == planes):
|
||||||
|
self.norm3 = nn.BatchNorm2d(planes)
|
||||||
|
|
||||||
|
elif norm_fn == 'instance':
|
||||||
|
self.norm1 = nn.InstanceNorm2d(planes)
|
||||||
|
self.norm2 = nn.InstanceNorm2d(planes)
|
||||||
|
if not (stride == 1 and in_planes == planes):
|
||||||
|
self.norm3 = nn.InstanceNorm2d(planes)
|
||||||
|
|
||||||
|
elif norm_fn == 'none':
|
||||||
|
self.norm1 = nn.Sequential()
|
||||||
|
self.norm2 = nn.Sequential()
|
||||||
|
if not (stride == 1 and in_planes == planes):
|
||||||
|
self.norm3 = nn.Sequential()
|
||||||
|
|
||||||
|
if stride == 1 and in_planes == planes:
|
||||||
|
self.downsample = None
|
||||||
|
|
||||||
|
else:
|
||||||
|
self.downsample = nn.Sequential(
|
||||||
|
nn.Conv2d(in_planes, planes, kernel_size=1, stride=stride), self.norm3)
|
||||||
|
|
||||||
|
|
||||||
|
def forward(self, x):
|
||||||
|
y = x
|
||||||
|
y = self.conv1(y)
|
||||||
|
y = self.norm1(y)
|
||||||
|
y = self.relu(y)
|
||||||
|
y = self.conv2(y)
|
||||||
|
y = self.norm2(y)
|
||||||
|
y = self.relu(y)
|
||||||
|
|
||||||
|
if self.downsample is not None:
|
||||||
|
x = self.downsample(x)
|
||||||
|
|
||||||
|
return self.relu(x+y)
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
class BottleneckBlock(nn.Module):
|
||||||
|
def __init__(self, in_planes, planes, norm_fn='group', stride=1):
|
||||||
|
super(BottleneckBlock, self).__init__()
|
||||||
|
|
||||||
|
self.conv1 = nn.Conv2d(in_planes, planes//4, kernel_size=1, padding=0)
|
||||||
|
self.conv2 = nn.Conv2d(planes//4, planes//4, kernel_size=3, padding=1, stride=stride)
|
||||||
|
self.conv3 = nn.Conv2d(planes//4, planes, kernel_size=1, padding=0)
|
||||||
|
self.relu = nn.ReLU(inplace=True)
|
||||||
|
|
||||||
|
num_groups = planes // 8
|
||||||
|
|
||||||
|
if norm_fn == 'group':
|
||||||
|
self.norm1 = nn.GroupNorm(num_groups=num_groups, num_channels=planes//4)
|
||||||
|
self.norm2 = nn.GroupNorm(num_groups=num_groups, num_channels=planes//4)
|
||||||
|
self.norm3 = nn.GroupNorm(num_groups=num_groups, num_channels=planes)
|
||||||
|
if not stride == 1:
|
||||||
|
self.norm4 = nn.GroupNorm(num_groups=num_groups, num_channels=planes)
|
||||||
|
|
||||||
|
elif norm_fn == 'batch':
|
||||||
|
self.norm1 = nn.BatchNorm2d(planes//4)
|
||||||
|
self.norm2 = nn.BatchNorm2d(planes//4)
|
||||||
|
self.norm3 = nn.BatchNorm2d(planes)
|
||||||
|
if not stride == 1:
|
||||||
|
self.norm4 = nn.BatchNorm2d(planes)
|
||||||
|
|
||||||
|
elif norm_fn == 'instance':
|
||||||
|
self.norm1 = nn.InstanceNorm2d(planes//4)
|
||||||
|
self.norm2 = nn.InstanceNorm2d(planes//4)
|
||||||
|
self.norm3 = nn.InstanceNorm2d(planes)
|
||||||
|
if not stride == 1:
|
||||||
|
self.norm4 = nn.InstanceNorm2d(planes)
|
||||||
|
|
||||||
|
elif norm_fn == 'none':
|
||||||
|
self.norm1 = nn.Sequential()
|
||||||
|
self.norm2 = nn.Sequential()
|
||||||
|
self.norm3 = nn.Sequential()
|
||||||
|
if not stride == 1:
|
||||||
|
self.norm4 = nn.Sequential()
|
||||||
|
|
||||||
|
if stride == 1:
|
||||||
|
self.downsample = None
|
||||||
|
|
||||||
|
else:
|
||||||
|
self.downsample = nn.Sequential(
|
||||||
|
nn.Conv2d(in_planes, planes, kernel_size=1, stride=stride), self.norm4)
|
||||||
|
|
||||||
|
|
||||||
|
def forward(self, x):
|
||||||
|
y = x
|
||||||
|
y = self.relu(self.norm1(self.conv1(y)))
|
||||||
|
y = self.relu(self.norm2(self.conv2(y)))
|
||||||
|
y = self.relu(self.norm3(self.conv3(y)))
|
||||||
|
|
||||||
|
if self.downsample is not None:
|
||||||
|
x = self.downsample(x)
|
||||||
|
|
||||||
|
return self.relu(x+y)
|
||||||
|
|
||||||
|
class BasicEncoder(nn.Module):
|
||||||
|
def __init__(self, output_dim=128, norm_fn='batch', dropout=0.0, downsample=3):
|
||||||
|
super(BasicEncoder, self).__init__()
|
||||||
|
self.norm_fn = norm_fn
|
||||||
|
self.downsample = downsample
|
||||||
|
|
||||||
|
if self.norm_fn == 'group':
|
||||||
|
self.norm1 = nn.GroupNorm(num_groups=8, num_channels=64)
|
||||||
|
|
||||||
|
elif self.norm_fn == 'batch':
|
||||||
|
self.norm1 = nn.BatchNorm2d(64)
|
||||||
|
|
||||||
|
elif self.norm_fn == 'instance':
|
||||||
|
self.norm1 = nn.InstanceNorm2d(64)
|
||||||
|
|
||||||
|
elif self.norm_fn == 'none':
|
||||||
|
self.norm1 = nn.Sequential()
|
||||||
|
|
||||||
|
self.conv1 = nn.Conv2d(3, 64, kernel_size=7, stride=1 + (downsample > 2), padding=3)
|
||||||
|
self.relu1 = nn.ReLU(inplace=True)
|
||||||
|
|
||||||
|
self.in_planes = 64
|
||||||
|
self.layer1 = self._make_layer(64, stride=1)
|
||||||
|
self.layer2 = self._make_layer(96, stride=1 + (downsample > 1))
|
||||||
|
self.layer3 = self._make_layer(128, stride=1 + (downsample > 0))
|
||||||
|
|
||||||
|
# output convolution
|
||||||
|
self.conv2 = nn.Conv2d(128, output_dim, kernel_size=1)
|
||||||
|
|
||||||
|
self.dropout = None
|
||||||
|
if dropout > 0:
|
||||||
|
self.dropout = nn.Dropout2d(p=dropout)
|
||||||
|
|
||||||
|
for m in self.modules():
|
||||||
|
if isinstance(m, nn.Conv2d):
|
||||||
|
nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
|
||||||
|
elif isinstance(m, (nn.BatchNorm2d, nn.InstanceNorm2d, nn.GroupNorm)):
|
||||||
|
if m.weight is not None:
|
||||||
|
nn.init.constant_(m.weight, 1)
|
||||||
|
if m.bias is not None:
|
||||||
|
nn.init.constant_(m.bias, 0)
|
||||||
|
|
||||||
|
def _make_layer(self, dim, stride=1):
|
||||||
|
layer1 = ResidualBlock(self.in_planes, dim, self.norm_fn, stride=stride)
|
||||||
|
layer2 = ResidualBlock(dim, dim, self.norm_fn, stride=1)
|
||||||
|
layers = (layer1, layer2)
|
||||||
|
|
||||||
|
self.in_planes = dim
|
||||||
|
return nn.Sequential(*layers)
|
||||||
|
|
||||||
|
|
||||||
|
def forward(self, x, dual_inp=False):
|
||||||
|
|
||||||
|
# if input is list, combine batch dimension
|
||||||
|
is_list = isinstance(x, tuple) or isinstance(x, list)
|
||||||
|
if is_list:
|
||||||
|
batch_dim = x[0].shape[0]
|
||||||
|
x = torch.cat(x, dim=0)
|
||||||
|
x = self.conv1(x)
|
||||||
|
x = self.norm1(x)
|
||||||
|
x = self.relu1(x)
|
||||||
|
x = self.layer1(x)
|
||||||
|
x = self.layer2(x)
|
||||||
|
x = self.layer3(x)
|
||||||
|
x = self.conv2(x)
|
||||||
|
|
||||||
|
if self.training and self.dropout is not None:
|
||||||
|
x = self.dropout(x)
|
||||||
|
|
||||||
|
if is_list:
|
||||||
|
x = x.split(split_size=batch_dim, dim=0)
|
||||||
|
|
||||||
|
return x
|
||||||
|
|
||||||
|
class MultiBasicEncoder(nn.Module):
|
||||||
|
def __init__(self, output_dim=[128], norm_fn='batch', dropout=0.0, downsample=3):
|
||||||
|
super(MultiBasicEncoder, self).__init__()
|
||||||
|
self.norm_fn = norm_fn
|
||||||
|
self.downsample = downsample
|
||||||
|
|
||||||
|
# self.norm_111 = nn.BatchNorm2d(128, affine=False, track_running_stats=False)
|
||||||
|
# self.norm_222 = nn.BatchNorm2d(128, affine=False, track_running_stats=False)
|
||||||
|
|
||||||
|
if self.norm_fn == 'group':
|
||||||
|
self.norm1 = nn.GroupNorm(num_groups=8, num_channels=64)
|
||||||
|
|
||||||
|
elif self.norm_fn == 'batch':
|
||||||
|
self.norm1 = nn.BatchNorm2d(64)
|
||||||
|
|
||||||
|
elif self.norm_fn == 'instance':
|
||||||
|
self.norm1 = nn.InstanceNorm2d(64)
|
||||||
|
|
||||||
|
elif self.norm_fn == 'none':
|
||||||
|
self.norm1 = nn.Sequential()
|
||||||
|
|
||||||
|
self.conv1 = nn.Conv2d(3, 64, kernel_size=7, stride=1 + (downsample > 2), padding=3)
|
||||||
|
self.relu1 = nn.ReLU(inplace=True)
|
||||||
|
|
||||||
|
self.in_planes = 64
|
||||||
|
self.layer1 = self._make_layer(64, stride=1)
|
||||||
|
self.layer2 = self._make_layer(96, stride=1 + (downsample > 1))
|
||||||
|
self.layer3 = self._make_layer(128, stride=1 + (downsample > 0))
|
||||||
|
self.layer4 = self._make_layer(128, stride=2)
|
||||||
|
self.layer5 = self._make_layer(128, stride=2)
|
||||||
|
|
||||||
|
output_list = []
|
||||||
|
|
||||||
|
for dim in output_dim:
|
||||||
|
conv_out = nn.Sequential(
|
||||||
|
ResidualBlock(128, 128, self.norm_fn, stride=1),
|
||||||
|
nn.Conv2d(128, dim[2], 3, padding=1))
|
||||||
|
output_list.append(conv_out)
|
||||||
|
|
||||||
|
self.outputs04 = nn.ModuleList(output_list)
|
||||||
|
|
||||||
|
output_list = []
|
||||||
|
for dim in output_dim:
|
||||||
|
conv_out = nn.Sequential(
|
||||||
|
ResidualBlock(128, 128, self.norm_fn, stride=1),
|
||||||
|
nn.Conv2d(128, dim[1], 3, padding=1))
|
||||||
|
output_list.append(conv_out)
|
||||||
|
|
||||||
|
self.outputs08 = nn.ModuleList(output_list)
|
||||||
|
|
||||||
|
output_list = []
|
||||||
|
for dim in output_dim:
|
||||||
|
conv_out = nn.Conv2d(128, dim[0], 3, padding=1)
|
||||||
|
output_list.append(conv_out)
|
||||||
|
|
||||||
|
self.outputs16 = nn.ModuleList(output_list)
|
||||||
|
|
||||||
|
if dropout > 0:
|
||||||
|
self.dropout = nn.Dropout2d(p=dropout)
|
||||||
|
else:
|
||||||
|
self.dropout = None
|
||||||
|
|
||||||
|
for m in self.modules():
|
||||||
|
if isinstance(m, nn.Conv2d):
|
||||||
|
nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
|
||||||
|
elif isinstance(m, (nn.BatchNorm2d, nn.InstanceNorm2d, nn.GroupNorm)):
|
||||||
|
if m.weight is not None:
|
||||||
|
nn.init.constant_(m.weight, 1)
|
||||||
|
if m.bias is not None:
|
||||||
|
nn.init.constant_(m.bias, 0)
|
||||||
|
|
||||||
|
def _make_layer(self, dim, stride=1):
|
||||||
|
layer1 = ResidualBlock(self.in_planes, dim, self.norm_fn, stride=stride)
|
||||||
|
layer2 = ResidualBlock(dim, dim, self.norm_fn, stride=1)
|
||||||
|
layers = (layer1, layer2)
|
||||||
|
|
||||||
|
self.in_planes = dim
|
||||||
|
return nn.Sequential(*layers)
|
||||||
|
|
||||||
|
def forward(self, x, dual_inp=False, num_layers=3):
|
||||||
|
|
||||||
|
x = self.conv1(x)
|
||||||
|
x = self.norm1(x)
|
||||||
|
x = self.relu1(x)
|
||||||
|
x = self.layer1(x)
|
||||||
|
x = self.layer2(x)
|
||||||
|
x = self.layer3(x)
|
||||||
|
if dual_inp:
|
||||||
|
v = x
|
||||||
|
x = x[:(x.shape[0]//2)]
|
||||||
|
|
||||||
|
outputs04 = [f(x) for f in self.outputs04]
|
||||||
|
if num_layers == 1:
|
||||||
|
return (outputs04, v) if dual_inp else (outputs04,)
|
||||||
|
|
||||||
|
y = self.layer4(x)
|
||||||
|
outputs08 = [f(y) for f in self.outputs08]
|
||||||
|
|
||||||
|
if num_layers == 2:
|
||||||
|
return (outputs04, outputs08, v) if dual_inp else (outputs04, outputs08)
|
||||||
|
|
||||||
|
z = self.layer5(y)
|
||||||
|
outputs16 = [f(z) for f in self.outputs16]
|
||||||
|
|
||||||
|
return (outputs04, outputs08, outputs16, v) if dual_inp else (outputs04, outputs08, outputs16)
|
||||||
|
|
||||||
|
|
||||||
|
class SubModule(nn.Module):
|
||||||
|
def __init__(self):
|
||||||
|
super(SubModule, self).__init__()
|
||||||
|
|
||||||
|
def weight_init(self):
|
||||||
|
for m in self.modules():
|
||||||
|
if isinstance(m, nn.Conv2d):
|
||||||
|
n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
|
||||||
|
m.weight.data.normal_(0, math.sqrt(2. / n))
|
||||||
|
elif isinstance(m, nn.Conv3d):
|
||||||
|
n = m.kernel_size[0] * m.kernel_size[1] * m.kernel_size[2] * m.out_channels
|
||||||
|
m.weight.data.normal_(0, math.sqrt(2. / n))
|
||||||
|
elif isinstance(m, nn.BatchNorm2d):
|
||||||
|
m.weight.data.fill_(1)
|
||||||
|
m.bias.data.zero_()
|
||||||
|
elif isinstance(m, nn.BatchNorm3d):
|
||||||
|
m.weight.data.fill_(1)
|
||||||
|
m.bias.data.zero_()
|
||||||
|
|
||||||
|
|
||||||
|
class Feature(SubModule):
|
||||||
|
def __init__(self):
|
||||||
|
super(Feature, self).__init__()
|
||||||
|
pretrained = True
|
||||||
|
model = timm.create_model('mobilenetv2_100', pretrained=pretrained, features_only=True)
|
||||||
|
layers = [1,2,3,5,6]
|
||||||
|
chans = [16, 24, 32, 96, 160]
|
||||||
|
self.conv_stem = model.conv_stem
|
||||||
|
self.bn1 = model.bn1
|
||||||
|
self.act1 = model.act1
|
||||||
|
|
||||||
|
self.block0 = torch.nn.Sequential(*model.blocks[0:layers[0]])
|
||||||
|
self.block1 = torch.nn.Sequential(*model.blocks[layers[0]:layers[1]])
|
||||||
|
self.block2 = torch.nn.Sequential(*model.blocks[layers[1]:layers[2]])
|
||||||
|
self.block3 = torch.nn.Sequential(*model.blocks[layers[2]:layers[3]])
|
||||||
|
self.block4 = torch.nn.Sequential(*model.blocks[layers[3]:layers[4]])
|
||||||
|
|
||||||
|
self.deconv32_16 = Conv2x_IN(chans[4], chans[3], deconv=True, concat=True)
|
||||||
|
self.deconv16_8 = Conv2x_IN(chans[3]*2, chans[2], deconv=True, concat=True)
|
||||||
|
self.deconv8_4 = Conv2x_IN(chans[2]*2, chans[1], deconv=True, concat=True)
|
||||||
|
self.conv4 = BasicConv_IN(chans[1]*2, chans[1]*2, kernel_size=3, stride=1, padding=1)
|
||||||
|
|
||||||
|
def forward(self, x):
|
||||||
|
x = self.act1(self.bn1(self.conv_stem(x)))
|
||||||
|
x2 = self.block0(x)
|
||||||
|
x4 = self.block1(x2)
|
||||||
|
x8 = self.block2(x4)
|
||||||
|
x16 = self.block3(x8)
|
||||||
|
x32 = self.block4(x16)
|
||||||
|
|
||||||
|
x16 = self.deconv32_16(x32, x16)
|
||||||
|
x8 = self.deconv16_8(x16, x8)
|
||||||
|
x4 = self.deconv8_4(x8, x4)
|
||||||
|
x4 = self.conv4(x4)
|
||||||
|
return [x4, x8, x16, x32]
|
||||||
|
|
69
IGEV-Stereo/core/geometry.py
Normal file
69
IGEV-Stereo/core/geometry.py
Normal file
@ -0,0 +1,69 @@
|
|||||||
|
import torch
|
||||||
|
import torch.nn.functional as F
|
||||||
|
from core.utils.utils import bilinear_sampler
|
||||||
|
|
||||||
|
|
||||||
|
class Combined_Geo_Encoding_Volume:
|
||||||
|
def __init__(self, init_fmap1, init_fmap2, geo_volume, num_levels=2, radius=4):
|
||||||
|
self.num_levels = num_levels
|
||||||
|
self.radius = radius
|
||||||
|
self.geo_volume_pyramid = []
|
||||||
|
self.init_corr_pyramid = []
|
||||||
|
|
||||||
|
# all pairs correlation
|
||||||
|
init_corr = Combined_Geo_Encoding_Volume.corr(init_fmap1, init_fmap2)
|
||||||
|
|
||||||
|
b, h, w, _, w2 = init_corr.shape
|
||||||
|
b, c, d, h, w = geo_volume.shape
|
||||||
|
geo_volume = geo_volume.permute(0, 3, 4, 1, 2).reshape(b*h*w, c, 1, d)
|
||||||
|
|
||||||
|
init_corr = init_corr.reshape(b*h*w, 1, 1, w2)
|
||||||
|
self.geo_volume_pyramid.append(geo_volume)
|
||||||
|
self.init_corr_pyramid.append(init_corr)
|
||||||
|
for i in range(self.num_levels-1):
|
||||||
|
geo_volume = F.avg_pool2d(geo_volume, [1,2], stride=[1,2])
|
||||||
|
self.geo_volume_pyramid.append(geo_volume)
|
||||||
|
|
||||||
|
for i in range(self.num_levels-1):
|
||||||
|
init_corr = F.avg_pool2d(init_corr, [1,2], stride=[1,2])
|
||||||
|
self.init_corr_pyramid.append(init_corr)
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
def __call__(self, disp, coords):
|
||||||
|
r = self.radius
|
||||||
|
b, _, h, w = disp.shape
|
||||||
|
out_pyramid = []
|
||||||
|
for i in range(self.num_levels):
|
||||||
|
geo_volume = self.geo_volume_pyramid[i]
|
||||||
|
dx = torch.linspace(-r, r, 2*r+1)
|
||||||
|
dx = dx.view(1, 1, 2*r+1, 1).to(disp.device)
|
||||||
|
x0 = dx + disp.reshape(b*h*w, 1, 1, 1) / 2**i
|
||||||
|
y0 = torch.zeros_like(x0)
|
||||||
|
|
||||||
|
disp_lvl = torch.cat([x0,y0], dim=-1)
|
||||||
|
geo_volume = bilinear_sampler(geo_volume, disp_lvl)
|
||||||
|
geo_volume = geo_volume.view(b, h, w, -1)
|
||||||
|
|
||||||
|
init_corr = self.init_corr_pyramid[i]
|
||||||
|
init_x0 = coords.reshape(b*h*w, 1, 1, 1)/2**i - disp.reshape(b*h*w, 1, 1, 1) / 2**i + dx
|
||||||
|
init_coords_lvl = torch.cat([init_x0,y0], dim=-1)
|
||||||
|
init_corr = bilinear_sampler(init_corr, init_coords_lvl)
|
||||||
|
init_corr = init_corr.view(b, h, w, -1)
|
||||||
|
|
||||||
|
out_pyramid.append(geo_volume)
|
||||||
|
out_pyramid.append(init_corr)
|
||||||
|
out = torch.cat(out_pyramid, dim=-1)
|
||||||
|
return out.permute(0, 3, 1, 2).contiguous().float()
|
||||||
|
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def corr(fmap1, fmap2):
|
||||||
|
B, D, H, W1 = fmap1.shape
|
||||||
|
_, _, _, W2 = fmap2.shape
|
||||||
|
fmap1 = fmap1.view(B, D, H, W1)
|
||||||
|
fmap2 = fmap2.view(B, D, H, W2)
|
||||||
|
corr = torch.einsum('aijk,aijh->ajkh', fmap1, fmap2)
|
||||||
|
corr = corr.reshape(B, H, W1, 1, W2).contiguous()
|
||||||
|
return corr
|
221
IGEV-Stereo/core/igev_stereo.py
Normal file
221
IGEV-Stereo/core/igev_stereo.py
Normal file
@ -0,0 +1,221 @@
|
|||||||
|
import torch
|
||||||
|
import torch.nn as nn
|
||||||
|
import torch.nn.functional as F
|
||||||
|
from core.update import BasicMultiUpdateBlock
|
||||||
|
from core.extractor import MultiBasicEncoder, Feature
|
||||||
|
from core.geometry import Combined_Geo_Encoding_Volume
|
||||||
|
from core.submodule import *
|
||||||
|
import time
|
||||||
|
|
||||||
|
|
||||||
|
try:
|
||||||
|
autocast = torch.cuda.amp.autocast
|
||||||
|
except:
|
||||||
|
class autocast:
|
||||||
|
def __init__(self, enabled):
|
||||||
|
pass
|
||||||
|
def __enter__(self):
|
||||||
|
pass
|
||||||
|
def __exit__(self, *args):
|
||||||
|
pass
|
||||||
|
|
||||||
|
class hourglass(nn.Module):
|
||||||
|
def __init__(self, in_channels):
|
||||||
|
super(hourglass, self).__init__()
|
||||||
|
|
||||||
|
self.conv1 = nn.Sequential(BasicConv(in_channels, in_channels*2, is_3d=True, bn=True, relu=True, kernel_size=3,
|
||||||
|
padding=1, stride=2, dilation=1),
|
||||||
|
BasicConv(in_channels*2, in_channels*2, is_3d=True, bn=True, relu=True, kernel_size=3,
|
||||||
|
padding=1, stride=1, dilation=1))
|
||||||
|
|
||||||
|
self.conv2 = nn.Sequential(BasicConv(in_channels*2, in_channels*4, is_3d=True, bn=True, relu=True, kernel_size=3,
|
||||||
|
padding=1, stride=2, dilation=1),
|
||||||
|
BasicConv(in_channels*4, in_channels*4, is_3d=True, bn=True, relu=True, kernel_size=3,
|
||||||
|
padding=1, stride=1, dilation=1))
|
||||||
|
|
||||||
|
self.conv3 = nn.Sequential(BasicConv(in_channels*4, in_channels*6, is_3d=True, bn=True, relu=True, kernel_size=3,
|
||||||
|
padding=1, stride=2, dilation=1),
|
||||||
|
BasicConv(in_channels*6, in_channels*6, is_3d=True, bn=True, relu=True, kernel_size=3,
|
||||||
|
padding=1, stride=1, dilation=1))
|
||||||
|
|
||||||
|
|
||||||
|
self.conv3_up = BasicConv(in_channels*6, in_channels*4, deconv=True, is_3d=True, bn=True,
|
||||||
|
relu=True, kernel_size=(4, 4, 4), padding=(1, 1, 1), stride=(2, 2, 2))
|
||||||
|
|
||||||
|
self.conv2_up = BasicConv(in_channels*4, in_channels*2, deconv=True, is_3d=True, bn=True,
|
||||||
|
relu=True, kernel_size=(4, 4, 4), padding=(1, 1, 1), stride=(2, 2, 2))
|
||||||
|
|
||||||
|
self.conv1_up = BasicConv(in_channels*2, 8, deconv=True, is_3d=True, bn=False,
|
||||||
|
relu=False, kernel_size=(4, 4, 4), padding=(1, 1, 1), stride=(2, 2, 2))
|
||||||
|
|
||||||
|
self.agg_0 = nn.Sequential(BasicConv(in_channels*8, in_channels*4, is_3d=True, kernel_size=1, padding=0, stride=1),
|
||||||
|
BasicConv(in_channels*4, in_channels*4, is_3d=True, kernel_size=3, padding=1, stride=1),
|
||||||
|
BasicConv(in_channels*4, in_channels*4, is_3d=True, kernel_size=3, padding=1, stride=1),)
|
||||||
|
|
||||||
|
self.agg_1 = nn.Sequential(BasicConv(in_channels*4, in_channels*2, is_3d=True, kernel_size=1, padding=0, stride=1),
|
||||||
|
BasicConv(in_channels*2, in_channels*2, is_3d=True, kernel_size=3, padding=1, stride=1),
|
||||||
|
BasicConv(in_channels*2, in_channels*2, is_3d=True, kernel_size=3, padding=1, stride=1))
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
self.feature_att_8 = FeatureAtt(in_channels*2, 64)
|
||||||
|
self.feature_att_16 = FeatureAtt(in_channels*4, 192)
|
||||||
|
self.feature_att_32 = FeatureAtt(in_channels*6, 160)
|
||||||
|
self.feature_att_up_16 = FeatureAtt(in_channels*4, 192)
|
||||||
|
self.feature_att_up_8 = FeatureAtt(in_channels*2, 64)
|
||||||
|
|
||||||
|
def forward(self, x, features):
|
||||||
|
conv1 = self.conv1(x)
|
||||||
|
conv1 = self.feature_att_8(conv1, features[1])
|
||||||
|
|
||||||
|
conv2 = self.conv2(conv1)
|
||||||
|
conv2 = self.feature_att_16(conv2, features[2])
|
||||||
|
|
||||||
|
conv3 = self.conv3(conv2)
|
||||||
|
conv3 = self.feature_att_32(conv3, features[3])
|
||||||
|
|
||||||
|
conv3_up = self.conv3_up(conv3)
|
||||||
|
conv2 = torch.cat((conv3_up, conv2), dim=1)
|
||||||
|
conv2 = self.agg_0(conv2)
|
||||||
|
conv2 = self.feature_att_up_16(conv2, features[2])
|
||||||
|
|
||||||
|
conv2_up = self.conv2_up(conv2)
|
||||||
|
conv1 = torch.cat((conv2_up, conv1), dim=1)
|
||||||
|
conv1 = self.agg_1(conv1)
|
||||||
|
conv1 = self.feature_att_up_8(conv1, features[1])
|
||||||
|
|
||||||
|
conv = self.conv1_up(conv1)
|
||||||
|
|
||||||
|
return conv
|
||||||
|
|
||||||
|
class IGEVStereo(nn.Module):
|
||||||
|
def __init__(self, args):
|
||||||
|
super().__init__()
|
||||||
|
self.args = args
|
||||||
|
|
||||||
|
context_dims = args.hidden_dims
|
||||||
|
|
||||||
|
self.cnet = MultiBasicEncoder(output_dim=[args.hidden_dims, context_dims], norm_fn="batch", downsample=args.n_downsample)
|
||||||
|
self.update_block = BasicMultiUpdateBlock(self.args, hidden_dims=args.hidden_dims)
|
||||||
|
|
||||||
|
self.context_zqr_convs = nn.ModuleList([nn.Conv2d(context_dims[i], args.hidden_dims[i]*3, 3, padding=3//2) for i in range(self.args.n_gru_layers)])
|
||||||
|
|
||||||
|
self.feature = Feature()
|
||||||
|
|
||||||
|
self.stem_2 = nn.Sequential(
|
||||||
|
BasicConv_IN(3, 32, kernel_size=3, stride=2, padding=1),
|
||||||
|
nn.Conv2d(32, 32, 3, 1, 1, bias=False),
|
||||||
|
nn.InstanceNorm2d(32), nn.ReLU()
|
||||||
|
)
|
||||||
|
self.stem_4 = nn.Sequential(
|
||||||
|
BasicConv_IN(32, 48, kernel_size=3, stride=2, padding=1),
|
||||||
|
nn.Conv2d(48, 48, 3, 1, 1, bias=False),
|
||||||
|
nn.InstanceNorm2d(48), nn.ReLU()
|
||||||
|
)
|
||||||
|
|
||||||
|
self.spx = nn.Sequential(nn.ConvTranspose2d(2*32, 9, kernel_size=4, stride=2, padding=1),)
|
||||||
|
self.spx_2 = Conv2x_IN(24, 32, True)
|
||||||
|
self.spx_4 = nn.Sequential(
|
||||||
|
BasicConv_IN(96, 24, kernel_size=3, stride=1, padding=1),
|
||||||
|
nn.Conv2d(24, 24, 3, 1, 1, bias=False),
|
||||||
|
nn.InstanceNorm2d(24), nn.ReLU()
|
||||||
|
)
|
||||||
|
|
||||||
|
self.spx_2_gru = Conv2x(32, 32, True)
|
||||||
|
self.spx_gru = nn.Sequential(nn.ConvTranspose2d(2*32, 9, kernel_size=4, stride=2, padding=1),)
|
||||||
|
|
||||||
|
self.conv = BasicConv_IN(96, 96, kernel_size=3, padding=1, stride=1)
|
||||||
|
self.desc = nn.Conv2d(96, 96, kernel_size=1, padding=0, stride=1)
|
||||||
|
|
||||||
|
self.corr_stem = BasicConv(8, 8, is_3d=True, kernel_size=3, stride=1, padding=1)
|
||||||
|
self.corr_feature_att = FeatureAtt(8, 96)
|
||||||
|
self.cost_agg = hourglass(8)
|
||||||
|
self.classifier = nn.Conv3d(8, 1, 3, 1, 1, bias=False)
|
||||||
|
|
||||||
|
def freeze_bn(self):
|
||||||
|
for m in self.modules():
|
||||||
|
if isinstance(m, nn.BatchNorm2d):
|
||||||
|
m.eval()
|
||||||
|
|
||||||
|
def upsample_disp(self, disp, mask_feat_4, stem_2x):
|
||||||
|
|
||||||
|
with autocast(enabled=self.args.mixed_precision):
|
||||||
|
xspx = self.spx_2_gru(mask_feat_4, stem_2x)
|
||||||
|
spx_pred = self.spx_gru(xspx)
|
||||||
|
spx_pred = F.softmax(spx_pred, 1)
|
||||||
|
up_disp = context_upsample(disp*4., spx_pred).unsqueeze(1)
|
||||||
|
|
||||||
|
return up_disp
|
||||||
|
|
||||||
|
|
||||||
|
def forward(self, image1, image2, iters=12, flow_init=None, test_mode=False):
|
||||||
|
""" Estimate disparity between pair of frames """
|
||||||
|
|
||||||
|
image1 = (2 * (image1 / 255.0) - 1.0).contiguous()
|
||||||
|
image2 = (2 * (image2 / 255.0) - 1.0).contiguous()
|
||||||
|
with autocast(enabled=self.args.mixed_precision):
|
||||||
|
features_left = self.feature(image1)
|
||||||
|
features_right = self.feature(image2)
|
||||||
|
stem_2x = self.stem_2(image1)
|
||||||
|
stem_4x = self.stem_4(stem_2x)
|
||||||
|
stem_2y = self.stem_2(image2)
|
||||||
|
stem_4y = self.stem_4(stem_2y)
|
||||||
|
features_left[0] = torch.cat((features_left[0], stem_4x), 1)
|
||||||
|
features_right[0] = torch.cat((features_right[0], stem_4y), 1)
|
||||||
|
|
||||||
|
match_left = self.desc(self.conv(features_left[0]))
|
||||||
|
match_right = self.desc(self.conv(features_right[0]))
|
||||||
|
gwc_volume = build_gwc_volume(match_left, match_right, 192//4, 8)
|
||||||
|
gwc_volume = self.corr_stem(gwc_volume)
|
||||||
|
gwc_volume = self.corr_feature_att(gwc_volume, features_left[0])
|
||||||
|
geo_encoding_volume = self.cost_agg(gwc_volume, features_left)
|
||||||
|
|
||||||
|
# Init disp from geometry encoding volume
|
||||||
|
prob = F.softmax(self.classifier(geo_encoding_volume).squeeze(1), dim=1)
|
||||||
|
init_disp = disparity_regression(prob, self.args.max_disp//4)
|
||||||
|
|
||||||
|
del prob, gwc_volume
|
||||||
|
|
||||||
|
if not test_mode:
|
||||||
|
xspx = self.spx_4(features_left[0])
|
||||||
|
xspx = self.spx_2(xspx, stem_2x)
|
||||||
|
spx_pred = self.spx(xspx)
|
||||||
|
spx_pred = F.softmax(spx_pred, 1)
|
||||||
|
|
||||||
|
cnet_list = self.cnet(image1, num_layers=self.args.n_gru_layers)
|
||||||
|
net_list = [torch.tanh(x[0]) for x in cnet_list]
|
||||||
|
inp_list = [torch.relu(x[1]) for x in cnet_list]
|
||||||
|
inp_list = [list(conv(i).split(split_size=conv.out_channels//3, dim=1)) for i,conv in zip(inp_list, self.context_zqr_convs)]
|
||||||
|
|
||||||
|
|
||||||
|
geo_block = Combined_Geo_Encoding_Volume
|
||||||
|
geo_fn = geo_block(match_left.float(), match_right.float(), geo_encoding_volume.float(), radius=self.args.corr_radius, num_levels=self.args.corr_levels)
|
||||||
|
b, c, h, w = match_left.shape
|
||||||
|
coords = torch.arange(w).float().to(match_left.device).reshape(1,1,w,1).repeat(b, h, 1, 1)
|
||||||
|
disp = init_disp
|
||||||
|
disp_preds = []
|
||||||
|
|
||||||
|
# GRUs iterations to update disparity
|
||||||
|
for itr in range(iters):
|
||||||
|
disp = disp.detach()
|
||||||
|
geo_feat = geo_fn(disp, coords)
|
||||||
|
with autocast(enabled=self.args.mixed_precision):
|
||||||
|
if self.args.n_gru_layers == 3 and self.args.slow_fast_gru: # Update low-res ConvGRU
|
||||||
|
net_list = self.update_block(net_list, inp_list, iter16=True, iter08=False, iter04=False, update=False)
|
||||||
|
if self.args.n_gru_layers >= 2 and self.args.slow_fast_gru:# Update low-res ConvGRU and mid-res ConvGRU
|
||||||
|
net_list = self.update_block(net_list, inp_list, iter16=self.args.n_gru_layers==3, iter08=True, iter04=False, update=False)
|
||||||
|
net_list, mask_feat_4, delta_disp = self.update_block(net_list, inp_list, geo_feat, disp, iter16=self.args.n_gru_layers==3, iter08=self.args.n_gru_layers>=2)
|
||||||
|
|
||||||
|
disp = disp + delta_disp
|
||||||
|
if test_mode and itr < iters-1:
|
||||||
|
continue
|
||||||
|
|
||||||
|
# upsample predictions
|
||||||
|
disp_up = self.upsample_disp(disp, mask_feat_4, stem_2x)
|
||||||
|
disp_preds.append(disp_up)
|
||||||
|
|
||||||
|
if test_mode:
|
||||||
|
return disp_up
|
||||||
|
|
||||||
|
init_disp = context_upsample(init_disp*4., spx_pred.float()).unsqueeze(1)
|
||||||
|
return init_disp, disp_preds
|
331
IGEV-Stereo/core/stereo_datasets.py
Normal file
331
IGEV-Stereo/core/stereo_datasets.py
Normal file
@ -0,0 +1,331 @@
|
|||||||
|
import numpy as np
|
||||||
|
import torch
|
||||||
|
import torch.utils.data as data
|
||||||
|
import torch.nn.functional as F
|
||||||
|
import logging
|
||||||
|
import os
|
||||||
|
import re
|
||||||
|
import copy
|
||||||
|
import math
|
||||||
|
import random
|
||||||
|
from pathlib import Path
|
||||||
|
from glob import glob
|
||||||
|
import os.path as osp
|
||||||
|
|
||||||
|
from core.utils import frame_utils
|
||||||
|
from core.utils.augmentor import FlowAugmentor, SparseFlowAugmentor
|
||||||
|
|
||||||
|
|
||||||
|
class StereoDataset(data.Dataset):
|
||||||
|
def __init__(self, aug_params=None, sparse=False, reader=None):
|
||||||
|
self.augmentor = None
|
||||||
|
self.sparse = sparse
|
||||||
|
self.img_pad = aug_params.pop("img_pad", None) if aug_params is not None else None
|
||||||
|
if aug_params is not None and "crop_size" in aug_params:
|
||||||
|
if sparse:
|
||||||
|
self.augmentor = SparseFlowAugmentor(**aug_params)
|
||||||
|
else:
|
||||||
|
self.augmentor = FlowAugmentor(**aug_params)
|
||||||
|
|
||||||
|
if reader is None:
|
||||||
|
self.disparity_reader = frame_utils.read_gen
|
||||||
|
else:
|
||||||
|
self.disparity_reader = reader
|
||||||
|
|
||||||
|
self.is_test = False
|
||||||
|
self.init_seed = False
|
||||||
|
self.flow_list = []
|
||||||
|
self.disparity_list = []
|
||||||
|
self.image_list = []
|
||||||
|
self.extra_info = []
|
||||||
|
|
||||||
|
def __getitem__(self, index):
|
||||||
|
|
||||||
|
if self.is_test:
|
||||||
|
img1 = frame_utils.read_gen(self.image_list[index][0])
|
||||||
|
img2 = frame_utils.read_gen(self.image_list[index][1])
|
||||||
|
img1 = np.array(img1).astype(np.uint8)[..., :3]
|
||||||
|
img2 = np.array(img2).astype(np.uint8)[..., :3]
|
||||||
|
img1 = torch.from_numpy(img1).permute(2, 0, 1).float()
|
||||||
|
img2 = torch.from_numpy(img2).permute(2, 0, 1).float()
|
||||||
|
return img1, img2, self.extra_info[index]
|
||||||
|
|
||||||
|
if not self.init_seed:
|
||||||
|
worker_info = torch.utils.data.get_worker_info()
|
||||||
|
if worker_info is not None:
|
||||||
|
torch.manual_seed(worker_info.id)
|
||||||
|
np.random.seed(worker_info.id)
|
||||||
|
random.seed(worker_info.id)
|
||||||
|
self.init_seed = True
|
||||||
|
|
||||||
|
index = index % len(self.image_list)
|
||||||
|
disp = self.disparity_reader(self.disparity_list[index])
|
||||||
|
|
||||||
|
if isinstance(disp, tuple):
|
||||||
|
disp, valid = disp
|
||||||
|
else:
|
||||||
|
valid = disp < 512
|
||||||
|
|
||||||
|
img1 = frame_utils.read_gen(self.image_list[index][0])
|
||||||
|
img2 = frame_utils.read_gen(self.image_list[index][1])
|
||||||
|
|
||||||
|
img1 = np.array(img1).astype(np.uint8)
|
||||||
|
img2 = np.array(img2).astype(np.uint8)
|
||||||
|
|
||||||
|
disp = np.array(disp).astype(np.float32)
|
||||||
|
|
||||||
|
flow = np.stack([disp, np.zeros_like(disp)], axis=-1)
|
||||||
|
|
||||||
|
# grayscale images
|
||||||
|
if len(img1.shape) == 2:
|
||||||
|
img1 = np.tile(img1[...,None], (1, 1, 3))
|
||||||
|
img2 = np.tile(img2[...,None], (1, 1, 3))
|
||||||
|
else:
|
||||||
|
img1 = img1[..., :3]
|
||||||
|
img2 = img2[..., :3]
|
||||||
|
|
||||||
|
if self.augmentor is not None:
|
||||||
|
if self.sparse:
|
||||||
|
img1, img2, flow, valid = self.augmentor(img1, img2, flow, valid)
|
||||||
|
else:
|
||||||
|
|
||||||
|
img1, img2, flow = self.augmentor(img1, img2, flow)
|
||||||
|
|
||||||
|
img1 = torch.from_numpy(img1).permute(2, 0, 1).float()
|
||||||
|
img2 = torch.from_numpy(img2).permute(2, 0, 1).float()
|
||||||
|
flow = torch.from_numpy(flow).permute(2, 0, 1).float()
|
||||||
|
|
||||||
|
if self.sparse:
|
||||||
|
valid = torch.from_numpy(valid)
|
||||||
|
else:
|
||||||
|
valid = (flow[0].abs() < 512) & (flow[1].abs() < 512)
|
||||||
|
|
||||||
|
if self.img_pad is not None:
|
||||||
|
|
||||||
|
padH, padW = self.img_pad
|
||||||
|
img1 = F.pad(img1, [padW]*2 + [padH]*2)
|
||||||
|
img2 = F.pad(img2, [padW]*2 + [padH]*2)
|
||||||
|
|
||||||
|
flow = flow[:1]
|
||||||
|
return self.image_list[index] + [self.disparity_list[index]], img1, img2, flow, valid.float()
|
||||||
|
|
||||||
|
|
||||||
|
def __mul__(self, v):
|
||||||
|
copy_of_self = copy.deepcopy(self)
|
||||||
|
copy_of_self.flow_list = v * copy_of_self.flow_list
|
||||||
|
copy_of_self.image_list = v * copy_of_self.image_list
|
||||||
|
copy_of_self.disparity_list = v * copy_of_self.disparity_list
|
||||||
|
copy_of_self.extra_info = v * copy_of_self.extra_info
|
||||||
|
return copy_of_self
|
||||||
|
|
||||||
|
def __len__(self):
|
||||||
|
return len(self.image_list)
|
||||||
|
|
||||||
|
|
||||||
|
class SceneFlowDatasets(StereoDataset):
|
||||||
|
def __init__(self, aug_params=None, root='/data/sceneflow/', dstype='frames_finalpass', things_test=False):
|
||||||
|
super(SceneFlowDatasets, self).__init__(aug_params)
|
||||||
|
self.root = root
|
||||||
|
self.dstype = dstype
|
||||||
|
|
||||||
|
if things_test:
|
||||||
|
self._add_things("TEST")
|
||||||
|
else:
|
||||||
|
self._add_things("TRAIN")
|
||||||
|
self._add_monkaa("TRAIN")
|
||||||
|
self._add_driving("TRAIN")
|
||||||
|
|
||||||
|
def _add_things(self, split='TRAIN'):
|
||||||
|
""" Add FlyingThings3D data """
|
||||||
|
|
||||||
|
original_length = len(self.disparity_list)
|
||||||
|
# root = osp.join(self.root, 'FlyingThings3D')
|
||||||
|
root = self.root
|
||||||
|
left_images = sorted( glob(osp.join(root, self.dstype, split, '*/*/left/*.png')) )
|
||||||
|
right_images = [ im.replace('left', 'right') for im in left_images ]
|
||||||
|
disparity_images = [ im.replace(self.dstype, 'disparity').replace('.png', '.pfm') for im in left_images ]
|
||||||
|
|
||||||
|
# Choose a random subset of 400 images for validation
|
||||||
|
state = np.random.get_state()
|
||||||
|
np.random.seed(1000)
|
||||||
|
# val_idxs = set(np.random.permutation(len(left_images))[:100])
|
||||||
|
val_idxs = set(np.random.permutation(len(left_images)))
|
||||||
|
np.random.set_state(state)
|
||||||
|
|
||||||
|
for idx, (img1, img2, disp) in enumerate(zip(left_images, right_images, disparity_images)):
|
||||||
|
if (split == 'TEST' and idx in val_idxs) or split == 'TRAIN':
|
||||||
|
self.image_list += [ [img1, img2] ]
|
||||||
|
self.disparity_list += [ disp ]
|
||||||
|
logging.info(f"Added {len(self.disparity_list) - original_length} from FlyingThings {self.dstype}")
|
||||||
|
|
||||||
|
def _add_monkaa(self, split="TRAIN"):
|
||||||
|
""" Add FlyingThings3D data """
|
||||||
|
|
||||||
|
original_length = len(self.disparity_list)
|
||||||
|
root = self.root
|
||||||
|
left_images = sorted( glob(osp.join(root, self.dstype, split, '*/left/*.png')) )
|
||||||
|
right_images = [ image_file.replace('left', 'right') for image_file in left_images ]
|
||||||
|
disparity_images = [ im.replace(self.dstype, 'disparity').replace('.png', '.pfm') for im in left_images ]
|
||||||
|
|
||||||
|
for img1, img2, disp in zip(left_images, right_images, disparity_images):
|
||||||
|
self.image_list += [ [img1, img2] ]
|
||||||
|
self.disparity_list += [ disp ]
|
||||||
|
logging.info(f"Added {len(self.disparity_list) - original_length} from Monkaa {self.dstype}")
|
||||||
|
|
||||||
|
|
||||||
|
def _add_driving(self, split="TRAIN"):
|
||||||
|
""" Add FlyingThings3D data """
|
||||||
|
|
||||||
|
original_length = len(self.disparity_list)
|
||||||
|
root = self.root
|
||||||
|
left_images = sorted( glob(osp.join(root, self.dstype, split, '*/*/*/left/*.png')) )
|
||||||
|
right_images = [ image_file.replace('left', 'right') for image_file in left_images ]
|
||||||
|
disparity_images = [ im.replace(self.dstype, 'disparity').replace('.png', '.pfm') for im in left_images ]
|
||||||
|
|
||||||
|
for img1, img2, disp in zip(left_images, right_images, disparity_images):
|
||||||
|
self.image_list += [ [img1, img2] ]
|
||||||
|
self.disparity_list += [ disp ]
|
||||||
|
logging.info(f"Added {len(self.disparity_list) - original_length} from Driving {self.dstype}")
|
||||||
|
|
||||||
|
|
||||||
|
class ETH3D(StereoDataset):
|
||||||
|
def __init__(self, aug_params=None, root='/data/ETH3D', split='training'):
|
||||||
|
super(ETH3D, self).__init__(aug_params, sparse=True)
|
||||||
|
|
||||||
|
image1_list = sorted( glob(osp.join(root, f'two_view_{split}/*/im0.png')) )
|
||||||
|
image2_list = sorted( glob(osp.join(root, f'two_view_{split}/*/im1.png')) )
|
||||||
|
disp_list = sorted( glob(osp.join(root, 'two_view_training_gt/*/disp0GT.pfm')) ) if split == 'training' else [osp.join(root, 'two_view_training_gt/playground_1l/disp0GT.pfm')]*len(image1_list)
|
||||||
|
|
||||||
|
for img1, img2, disp in zip(image1_list, image2_list, disp_list):
|
||||||
|
self.image_list += [ [img1, img2] ]
|
||||||
|
self.disparity_list += [ disp ]
|
||||||
|
|
||||||
|
class SintelStereo(StereoDataset):
|
||||||
|
def __init__(self, aug_params=None, root='datasets/SintelStereo'):
|
||||||
|
super().__init__(aug_params, sparse=True, reader=frame_utils.readDispSintelStereo)
|
||||||
|
|
||||||
|
image1_list = sorted( glob(osp.join(root, 'training/*_left/*/frame_*.png')) )
|
||||||
|
image2_list = sorted( glob(osp.join(root, 'training/*_right/*/frame_*.png')) )
|
||||||
|
disp_list = sorted( glob(osp.join(root, 'training/disparities/*/frame_*.png')) ) * 2
|
||||||
|
|
||||||
|
for img1, img2, disp in zip(image1_list, image2_list, disp_list):
|
||||||
|
assert img1.split('/')[-2:] == disp.split('/')[-2:]
|
||||||
|
self.image_list += [ [img1, img2] ]
|
||||||
|
self.disparity_list += [ disp ]
|
||||||
|
|
||||||
|
class FallingThings(StereoDataset):
|
||||||
|
def __init__(self, aug_params=None, root='datasets/FallingThings'):
|
||||||
|
super().__init__(aug_params, reader=frame_utils.readDispFallingThings)
|
||||||
|
assert os.path.exists(root)
|
||||||
|
|
||||||
|
with open(os.path.join(root, 'filenames.txt'), 'r') as f:
|
||||||
|
filenames = sorted(f.read().splitlines())
|
||||||
|
|
||||||
|
image1_list = [osp.join(root, e) for e in filenames]
|
||||||
|
image2_list = [osp.join(root, e.replace('left.jpg', 'right.jpg')) for e in filenames]
|
||||||
|
disp_list = [osp.join(root, e.replace('left.jpg', 'left.depth.png')) for e in filenames]
|
||||||
|
|
||||||
|
for img1, img2, disp in zip(image1_list, image2_list, disp_list):
|
||||||
|
self.image_list += [ [img1, img2] ]
|
||||||
|
self.disparity_list += [ disp ]
|
||||||
|
|
||||||
|
class TartanAir(StereoDataset):
|
||||||
|
def __init__(self, aug_params=None, root='datasets', keywords=[]):
|
||||||
|
super().__init__(aug_params, reader=frame_utils.readDispTartanAir)
|
||||||
|
assert os.path.exists(root)
|
||||||
|
|
||||||
|
with open(os.path.join(root, 'tartanair_filenames.txt'), 'r') as f:
|
||||||
|
filenames = sorted(list(filter(lambda s: 'seasonsforest_winter/Easy' not in s, f.read().splitlines())))
|
||||||
|
for kw in keywords:
|
||||||
|
filenames = sorted(list(filter(lambda s: kw in s.lower(), filenames)))
|
||||||
|
|
||||||
|
image1_list = [osp.join(root, e) for e in filenames]
|
||||||
|
image2_list = [osp.join(root, e.replace('_left', '_right')) for e in filenames]
|
||||||
|
disp_list = [osp.join(root, e.replace('image_left', 'depth_left').replace('left.png', 'left_depth.npy')) for e in filenames]
|
||||||
|
|
||||||
|
for img1, img2, disp in zip(image1_list, image2_list, disp_list):
|
||||||
|
self.image_list += [ [img1, img2] ]
|
||||||
|
self.disparity_list += [ disp ]
|
||||||
|
|
||||||
|
class KITTI(StereoDataset):
|
||||||
|
def __init__(self, aug_params=None, root='/data/KITTI/KITTI_2015', image_set='training'):
|
||||||
|
super(KITTI, self).__init__(aug_params, sparse=True, reader=frame_utils.readDispKITTI)
|
||||||
|
assert os.path.exists(root)
|
||||||
|
|
||||||
|
root_12 = '/data/KITTI/KITTI_2012'
|
||||||
|
image1_list = sorted(glob(os.path.join(root_12, image_set, 'colored_0/*_10.png')))
|
||||||
|
image2_list = sorted(glob(os.path.join(root_12, image_set, 'colored_1/*_10.png')))
|
||||||
|
disp_list = sorted(glob(os.path.join(root_12, 'training', 'disp_occ/*_10.png'))) if image_set == 'training' else [osp.join(root, 'training/disp_occ/000085_10.png')]*len(image1_list)
|
||||||
|
|
||||||
|
root_15 = '/data/KITTI/KITTI_2015'
|
||||||
|
image1_list += sorted(glob(os.path.join(root_15, image_set, 'image_2/*_10.png')))
|
||||||
|
image2_list += sorted(glob(os.path.join(root_15, image_set, 'image_3/*_10.png')))
|
||||||
|
disp_list += sorted(glob(os.path.join(root_15, 'training', 'disp_occ_0/*_10.png'))) if image_set == 'training' else [osp.join(root, 'training/disp_occ_0/000085_10.png')]*len(image1_list)
|
||||||
|
|
||||||
|
for idx, (img1, img2, disp) in enumerate(zip(image1_list, image2_list, disp_list)):
|
||||||
|
self.image_list += [ [img1, img2] ]
|
||||||
|
self.disparity_list += [ disp ]
|
||||||
|
|
||||||
|
|
||||||
|
class Middlebury(StereoDataset):
|
||||||
|
def __init__(self, aug_params=None, root='/data/Middlebury', split='F'):
|
||||||
|
super(Middlebury, self).__init__(aug_params, sparse=True, reader=frame_utils.readDispMiddlebury)
|
||||||
|
assert os.path.exists(root)
|
||||||
|
assert split in "FHQ"
|
||||||
|
lines = list(map(osp.basename, glob(os.path.join(root, "trainingH/*"))))
|
||||||
|
# lines = list(filter(lambda p: any(s in p.split('/') for s in Path(os.path.join(root, "MiddEval3/official_train.txt")).read_text().splitlines()), lines))
|
||||||
|
# image1_list = sorted([os.path.join(root, "MiddEval3", f'training{split}', f'{name}/im0.png') for name in lines])
|
||||||
|
# image2_list = sorted([os.path.join(root, "MiddEval3", f'training{split}', f'{name}/im1.png') for name in lines])
|
||||||
|
# disp_list = sorted([os.path.join(root, "MiddEval3", f'training{split}', f'{name}/disp0GT.pfm') for name in lines])
|
||||||
|
image1_list = sorted([os.path.join(root, f'training{split}', f'{name}/im0.png') for name in lines])
|
||||||
|
image2_list = sorted([os.path.join(root, f'training{split}', f'{name}/im1.png') for name in lines])
|
||||||
|
disp_list = sorted([os.path.join(root, f'training{split}', f'{name}/disp0GT.pfm') for name in lines])
|
||||||
|
|
||||||
|
assert len(image1_list) == len(image2_list) == len(disp_list) > 0, [image1_list, split]
|
||||||
|
for img1, img2, disp in zip(image1_list, image2_list, disp_list):
|
||||||
|
self.image_list += [ [img1, img2] ]
|
||||||
|
self.disparity_list += [ disp ]
|
||||||
|
|
||||||
|
|
||||||
|
def fetch_dataloader(args):
|
||||||
|
""" Create the data loader for the corresponding trainign set """
|
||||||
|
|
||||||
|
aug_params = {'crop_size': args.image_size, 'min_scale': args.spatial_scale[0], 'max_scale': args.spatial_scale[1], 'do_flip': False, 'yjitter': not args.noyjitter}
|
||||||
|
if hasattr(args, "saturation_range") and args.saturation_range is not None:
|
||||||
|
aug_params["saturation_range"] = args.saturation_range
|
||||||
|
if hasattr(args, "img_gamma") and args.img_gamma is not None:
|
||||||
|
aug_params["gamma"] = args.img_gamma
|
||||||
|
if hasattr(args, "do_flip") and args.do_flip is not None:
|
||||||
|
aug_params["do_flip"] = args.do_flip
|
||||||
|
|
||||||
|
|
||||||
|
train_dataset = None
|
||||||
|
for dataset_name in args.train_datasets:
|
||||||
|
if re.compile("middlebury_.*").fullmatch(dataset_name):
|
||||||
|
new_dataset = Middlebury(aug_params, split=dataset_name.replace('middlebury_',''))
|
||||||
|
elif dataset_name == 'sceneflow':
|
||||||
|
#clean_dataset = SceneFlowDatasets(aug_params, dstype='frames_cleanpass')
|
||||||
|
final_dataset = SceneFlowDatasets(aug_params, dstype='frames_finalpass')
|
||||||
|
#new_dataset = (clean_dataset*4) + (final_dataset*4)
|
||||||
|
new_dataset = final_dataset
|
||||||
|
logging.info(f"Adding {len(new_dataset)} samples from SceneFlow")
|
||||||
|
elif 'kitti' in dataset_name:
|
||||||
|
new_dataset = KITTI(aug_params)
|
||||||
|
logging.info(f"Adding {len(new_dataset)} samples from KITTI")
|
||||||
|
elif dataset_name == 'sintel_stereo':
|
||||||
|
new_dataset = SintelStereo(aug_params)*140
|
||||||
|
logging.info(f"Adding {len(new_dataset)} samples from Sintel Stereo")
|
||||||
|
elif dataset_name == 'falling_things':
|
||||||
|
new_dataset = FallingThings(aug_params)*5
|
||||||
|
logging.info(f"Adding {len(new_dataset)} samples from FallingThings")
|
||||||
|
elif dataset_name.startswith('tartan_air'):
|
||||||
|
new_dataset = TartanAir(aug_params, keywords=dataset_name.split('_')[2:])
|
||||||
|
logging.info(f"Adding {len(new_dataset)} samples from Tartain Air")
|
||||||
|
train_dataset = new_dataset if train_dataset is None else train_dataset + new_dataset
|
||||||
|
|
||||||
|
train_loader = data.DataLoader(train_dataset, batch_size=args.batch_size,
|
||||||
|
pin_memory=True, shuffle=True, num_workers=int(os.environ.get('SLURM_CPUS_PER_TASK', 6))-2, drop_last=True)
|
||||||
|
|
||||||
|
logging.info('Training with %d image pairs' % len(train_dataset))
|
||||||
|
return train_loader
|
||||||
|
|
253
IGEV-Stereo/core/submodule.py
Normal file
253
IGEV-Stereo/core/submodule.py
Normal file
@ -0,0 +1,253 @@
|
|||||||
|
import torch
|
||||||
|
import torch.nn as nn
|
||||||
|
import torch.nn.functional as F
|
||||||
|
import numpy as np
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
class BasicConv(nn.Module):
|
||||||
|
|
||||||
|
def __init__(self, in_channels, out_channels, deconv=False, is_3d=False, bn=True, relu=True, **kwargs):
|
||||||
|
super(BasicConv, self).__init__()
|
||||||
|
|
||||||
|
self.relu = relu
|
||||||
|
self.use_bn = bn
|
||||||
|
if is_3d:
|
||||||
|
if deconv:
|
||||||
|
self.conv = nn.ConvTranspose3d(in_channels, out_channels, bias=False, **kwargs)
|
||||||
|
else:
|
||||||
|
self.conv = nn.Conv3d(in_channels, out_channels, bias=False, **kwargs)
|
||||||
|
self.bn = nn.BatchNorm3d(out_channels)
|
||||||
|
else:
|
||||||
|
if deconv:
|
||||||
|
self.conv = nn.ConvTranspose2d(in_channels, out_channels, bias=False, **kwargs)
|
||||||
|
else:
|
||||||
|
self.conv = nn.Conv2d(in_channels, out_channels, bias=False, **kwargs)
|
||||||
|
self.bn = nn.BatchNorm2d(out_channels)
|
||||||
|
|
||||||
|
def forward(self, x):
|
||||||
|
x = self.conv(x)
|
||||||
|
if self.use_bn:
|
||||||
|
x = self.bn(x)
|
||||||
|
if self.relu:
|
||||||
|
x = nn.LeakyReLU()(x)#, inplace=True)
|
||||||
|
return x
|
||||||
|
|
||||||
|
|
||||||
|
class Conv2x(nn.Module):
|
||||||
|
|
||||||
|
def __init__(self, in_channels, out_channels, deconv=False, is_3d=False, concat=True, keep_concat=True, bn=True, relu=True, keep_dispc=False):
|
||||||
|
super(Conv2x, self).__init__()
|
||||||
|
self.concat = concat
|
||||||
|
self.is_3d = is_3d
|
||||||
|
if deconv and is_3d:
|
||||||
|
kernel = (4, 4, 4)
|
||||||
|
elif deconv:
|
||||||
|
kernel = 4
|
||||||
|
else:
|
||||||
|
kernel = 3
|
||||||
|
|
||||||
|
if deconv and is_3d and keep_dispc:
|
||||||
|
kernel = (1, 4, 4)
|
||||||
|
stride = (1, 2, 2)
|
||||||
|
padding = (0, 1, 1)
|
||||||
|
self.conv1 = BasicConv(in_channels, out_channels, deconv, is_3d, bn=True, relu=True, kernel_size=kernel, stride=stride, padding=padding)
|
||||||
|
else:
|
||||||
|
self.conv1 = BasicConv(in_channels, out_channels, deconv, is_3d, bn=True, relu=True, kernel_size=kernel, stride=2, padding=1)
|
||||||
|
|
||||||
|
if self.concat:
|
||||||
|
mul = 2 if keep_concat else 1
|
||||||
|
self.conv2 = BasicConv(out_channels*2, out_channels*mul, False, is_3d, bn, relu, kernel_size=3, stride=1, padding=1)
|
||||||
|
else:
|
||||||
|
self.conv2 = BasicConv(out_channels, out_channels, False, is_3d, bn, relu, kernel_size=3, stride=1, padding=1)
|
||||||
|
|
||||||
|
def forward(self, x, rem):
|
||||||
|
x = self.conv1(x)
|
||||||
|
if x.shape != rem.shape:
|
||||||
|
x = F.interpolate(
|
||||||
|
x,
|
||||||
|
size=(rem.shape[-2], rem.shape[-1]),
|
||||||
|
mode='nearest')
|
||||||
|
if self.concat:
|
||||||
|
x = torch.cat((x, rem), 1)
|
||||||
|
else:
|
||||||
|
x = x + rem
|
||||||
|
x = self.conv2(x)
|
||||||
|
return x
|
||||||
|
|
||||||
|
|
||||||
|
class BasicConv_IN(nn.Module):
|
||||||
|
|
||||||
|
def __init__(self, in_channels, out_channels, deconv=False, is_3d=False, IN=True, relu=True, **kwargs):
|
||||||
|
super(BasicConv_IN, self).__init__()
|
||||||
|
|
||||||
|
self.relu = relu
|
||||||
|
self.use_in = IN
|
||||||
|
if is_3d:
|
||||||
|
if deconv:
|
||||||
|
self.conv = nn.ConvTranspose3d(in_channels, out_channels, bias=False, **kwargs)
|
||||||
|
else:
|
||||||
|
self.conv = nn.Conv3d(in_channels, out_channels, bias=False, **kwargs)
|
||||||
|
self.IN = nn.InstanceNorm3d(out_channels)
|
||||||
|
else:
|
||||||
|
if deconv:
|
||||||
|
self.conv = nn.ConvTranspose2d(in_channels, out_channels, bias=False, **kwargs)
|
||||||
|
else:
|
||||||
|
self.conv = nn.Conv2d(in_channels, out_channels, bias=False, **kwargs)
|
||||||
|
self.IN = nn.InstanceNorm2d(out_channels)
|
||||||
|
|
||||||
|
def forward(self, x):
|
||||||
|
x = self.conv(x)
|
||||||
|
if self.use_in:
|
||||||
|
x = self.IN(x)
|
||||||
|
if self.relu:
|
||||||
|
x = nn.LeakyReLU()(x)#, inplace=True)
|
||||||
|
return x
|
||||||
|
|
||||||
|
|
||||||
|
class Conv2x_IN(nn.Module):
|
||||||
|
|
||||||
|
def __init__(self, in_channels, out_channels, deconv=False, is_3d=False, concat=True, keep_concat=True, IN=True, relu=True, keep_dispc=False):
|
||||||
|
super(Conv2x_IN, self).__init__()
|
||||||
|
self.concat = concat
|
||||||
|
self.is_3d = is_3d
|
||||||
|
if deconv and is_3d:
|
||||||
|
kernel = (4, 4, 4)
|
||||||
|
elif deconv:
|
||||||
|
kernel = 4
|
||||||
|
else:
|
||||||
|
kernel = 3
|
||||||
|
|
||||||
|
if deconv and is_3d and keep_dispc:
|
||||||
|
kernel = (1, 4, 4)
|
||||||
|
stride = (1, 2, 2)
|
||||||
|
padding = (0, 1, 1)
|
||||||
|
self.conv1 = BasicConv_IN(in_channels, out_channels, deconv, is_3d, IN=True, relu=True, kernel_size=kernel, stride=stride, padding=padding)
|
||||||
|
else:
|
||||||
|
self.conv1 = BasicConv_IN(in_channels, out_channels, deconv, is_3d, IN=True, relu=True, kernel_size=kernel, stride=2, padding=1)
|
||||||
|
|
||||||
|
if self.concat:
|
||||||
|
mul = 2 if keep_concat else 1
|
||||||
|
self.conv2 = BasicConv_IN(out_channels*2, out_channels*mul, False, is_3d, IN, relu, kernel_size=3, stride=1, padding=1)
|
||||||
|
else:
|
||||||
|
self.conv2 = BasicConv_IN(out_channels, out_channels, False, is_3d, IN, relu, kernel_size=3, stride=1, padding=1)
|
||||||
|
|
||||||
|
def forward(self, x, rem):
|
||||||
|
x = self.conv1(x)
|
||||||
|
if x.shape != rem.shape:
|
||||||
|
x = F.interpolate(
|
||||||
|
x,
|
||||||
|
size=(rem.shape[-2], rem.shape[-1]),
|
||||||
|
mode='nearest')
|
||||||
|
if self.concat:
|
||||||
|
x = torch.cat((x, rem), 1)
|
||||||
|
else:
|
||||||
|
x = x + rem
|
||||||
|
x = self.conv2(x)
|
||||||
|
return x
|
||||||
|
|
||||||
|
|
||||||
|
def groupwise_correlation(fea1, fea2, num_groups):
|
||||||
|
B, C, H, W = fea1.shape
|
||||||
|
assert C % num_groups == 0
|
||||||
|
channels_per_group = C // num_groups
|
||||||
|
cost = (fea1 * fea2).view([B, num_groups, channels_per_group, H, W]).mean(dim=2)
|
||||||
|
assert cost.shape == (B, num_groups, H, W)
|
||||||
|
return cost
|
||||||
|
|
||||||
|
def build_gwc_volume(refimg_fea, targetimg_fea, maxdisp, num_groups):
|
||||||
|
B, C, H, W = refimg_fea.shape
|
||||||
|
volume = refimg_fea.new_zeros([B, num_groups, maxdisp, H, W])
|
||||||
|
for i in range(maxdisp):
|
||||||
|
if i > 0:
|
||||||
|
volume[:, :, i, :, i:] = groupwise_correlation(refimg_fea[:, :, :, i:], targetimg_fea[:, :, :, :-i],
|
||||||
|
num_groups)
|
||||||
|
else:
|
||||||
|
volume[:, :, i, :, :] = groupwise_correlation(refimg_fea, targetimg_fea, num_groups)
|
||||||
|
volume = volume.contiguous()
|
||||||
|
return volume
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
def norm_correlation(fea1, fea2):
|
||||||
|
cost = torch.mean(((fea1/(torch.norm(fea1, 2, 1, True)+1e-05)) * (fea2/(torch.norm(fea2, 2, 1, True)+1e-05))), dim=1, keepdim=True)
|
||||||
|
return cost
|
||||||
|
|
||||||
|
def build_norm_correlation_volume(refimg_fea, targetimg_fea, maxdisp):
|
||||||
|
B, C, H, W = refimg_fea.shape
|
||||||
|
volume = refimg_fea.new_zeros([B, 1, maxdisp, H, W])
|
||||||
|
for i in range(maxdisp):
|
||||||
|
if i > 0:
|
||||||
|
volume[:, :, i, :, i:] = norm_correlation(refimg_fea[:, :, :, i:], targetimg_fea[:, :, :, :-i])
|
||||||
|
else:
|
||||||
|
volume[:, :, i, :, :] = norm_correlation(refimg_fea, targetimg_fea)
|
||||||
|
volume = volume.contiguous()
|
||||||
|
return volume
|
||||||
|
|
||||||
|
def correlation(fea1, fea2):
|
||||||
|
cost = torch.sum((fea1 * fea2), dim=1, keepdim=True)
|
||||||
|
return cost
|
||||||
|
|
||||||
|
def build_correlation_volume(refimg_fea, targetimg_fea, maxdisp):
|
||||||
|
B, C, H, W = refimg_fea.shape
|
||||||
|
volume = refimg_fea.new_zeros([B, 1, maxdisp, H, W])
|
||||||
|
for i in range(maxdisp):
|
||||||
|
if i > 0:
|
||||||
|
volume[:, :, i, :, i:] = correlation(refimg_fea[:, :, :, i:], targetimg_fea[:, :, :, :-i])
|
||||||
|
else:
|
||||||
|
volume[:, :, i, :, :] = correlation(refimg_fea, targetimg_fea)
|
||||||
|
volume = volume.contiguous()
|
||||||
|
return volume
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
def build_concat_volume(refimg_fea, targetimg_fea, maxdisp):
|
||||||
|
B, C, H, W = refimg_fea.shape
|
||||||
|
volume = refimg_fea.new_zeros([B, 2 * C, maxdisp, H, W])
|
||||||
|
for i in range(maxdisp):
|
||||||
|
if i > 0:
|
||||||
|
volume[:, :C, i, :, :] = refimg_fea[:, :, :, :]
|
||||||
|
volume[:, C:, i, :, i:] = targetimg_fea[:, :, :, :-i]
|
||||||
|
else:
|
||||||
|
volume[:, :C, i, :, :] = refimg_fea
|
||||||
|
volume[:, C:, i, :, :] = targetimg_fea
|
||||||
|
volume = volume.contiguous()
|
||||||
|
return volume
|
||||||
|
|
||||||
|
def disparity_regression(x, maxdisp):
|
||||||
|
assert len(x.shape) == 4
|
||||||
|
disp_values = torch.arange(0, maxdisp, dtype=x.dtype, device=x.device)
|
||||||
|
disp_values = disp_values.view(1, maxdisp, 1, 1)
|
||||||
|
return torch.sum(x * disp_values, 1, keepdim=True)
|
||||||
|
|
||||||
|
|
||||||
|
class FeatureAtt(nn.Module):
|
||||||
|
def __init__(self, cv_chan, feat_chan):
|
||||||
|
super(FeatureAtt, self).__init__()
|
||||||
|
|
||||||
|
self.feat_att = nn.Sequential(
|
||||||
|
BasicConv(feat_chan, feat_chan//2, kernel_size=1, stride=1, padding=0),
|
||||||
|
nn.Conv2d(feat_chan//2, cv_chan, 1))
|
||||||
|
|
||||||
|
def forward(self, cv, feat):
|
||||||
|
'''
|
||||||
|
'''
|
||||||
|
feat_att = self.feat_att(feat).unsqueeze(2)
|
||||||
|
cv = torch.sigmoid(feat_att)*cv
|
||||||
|
return cv
|
||||||
|
|
||||||
|
def context_upsample(disp_low, up_weights):
|
||||||
|
###
|
||||||
|
# cv (b,1,h,w)
|
||||||
|
# sp (b,9,4*h,4*w)
|
||||||
|
###
|
||||||
|
b, c, h, w = disp_low.shape
|
||||||
|
|
||||||
|
disp_unfold = F.unfold(disp_low.reshape(b,c,h,w),3,1,1).reshape(b,-1,h,w)
|
||||||
|
disp_unfold = F.interpolate(disp_unfold,(h*4,w*4),mode='nearest').reshape(b,9,h*4,w*4)
|
||||||
|
|
||||||
|
disp = (disp_unfold*up_weights).sum(1)
|
||||||
|
|
||||||
|
return disp
|
142
IGEV-Stereo/core/update.py
Normal file
142
IGEV-Stereo/core/update.py
Normal file
@ -0,0 +1,142 @@
|
|||||||
|
import torch
|
||||||
|
import torch.nn as nn
|
||||||
|
import torch.nn.functional as F
|
||||||
|
from opt_einsum import contract
|
||||||
|
|
||||||
|
class FlowHead(nn.Module):
|
||||||
|
def __init__(self, input_dim=128, hidden_dim=256, output_dim=2):
|
||||||
|
super(FlowHead, self).__init__()
|
||||||
|
self.conv1 = nn.Conv2d(input_dim, hidden_dim, 3, padding=1)
|
||||||
|
self.conv2 = nn.Conv2d(hidden_dim, output_dim, 3, padding=1)
|
||||||
|
self.relu = nn.ReLU(inplace=True)
|
||||||
|
|
||||||
|
def forward(self, x):
|
||||||
|
return self.conv2(self.relu(self.conv1(x)))
|
||||||
|
|
||||||
|
class DispHead(nn.Module):
|
||||||
|
def __init__(self, input_dim=128, hidden_dim=256, output_dim=1):
|
||||||
|
super(DispHead, self).__init__()
|
||||||
|
self.conv1 = nn.Conv2d(input_dim, hidden_dim, 3, padding=1)
|
||||||
|
self.conv2 = nn.Conv2d(hidden_dim, output_dim, 3, padding=1)
|
||||||
|
self.relu = nn.ReLU(inplace=True)
|
||||||
|
|
||||||
|
def forward(self, x):
|
||||||
|
return self.conv2(self.relu(self.conv1(x)))
|
||||||
|
|
||||||
|
class ConvGRU(nn.Module):
|
||||||
|
def __init__(self, hidden_dim, input_dim, kernel_size=3):
|
||||||
|
super(ConvGRU, self).__init__()
|
||||||
|
self.convz = nn.Conv2d(hidden_dim+input_dim, hidden_dim, kernel_size, padding=kernel_size//2)
|
||||||
|
self.convr = nn.Conv2d(hidden_dim+input_dim, hidden_dim, kernel_size, padding=kernel_size//2)
|
||||||
|
self.convq = nn.Conv2d(hidden_dim+input_dim, hidden_dim, kernel_size, padding=kernel_size//2)
|
||||||
|
|
||||||
|
def forward(self, h, cz, cr, cq, *x_list):
|
||||||
|
|
||||||
|
x = torch.cat(x_list, dim=1)
|
||||||
|
hx = torch.cat([h, x], dim=1)
|
||||||
|
z = torch.sigmoid(self.convz(hx) + cz)
|
||||||
|
r = torch.sigmoid(self.convr(hx) + cr)
|
||||||
|
q = torch.tanh(self.convq(torch.cat([r*h, x], dim=1)) + cq)
|
||||||
|
h = (1-z) * h + z * q
|
||||||
|
return h
|
||||||
|
|
||||||
|
class SepConvGRU(nn.Module):
|
||||||
|
def __init__(self, hidden_dim=128, input_dim=192+128):
|
||||||
|
super(SepConvGRU, self).__init__()
|
||||||
|
self.convz1 = nn.Conv2d(hidden_dim+input_dim, hidden_dim, (1,5), padding=(0,2))
|
||||||
|
self.convr1 = nn.Conv2d(hidden_dim+input_dim, hidden_dim, (1,5), padding=(0,2))
|
||||||
|
self.convq1 = nn.Conv2d(hidden_dim+input_dim, hidden_dim, (1,5), padding=(0,2))
|
||||||
|
|
||||||
|
self.convz2 = nn.Conv2d(hidden_dim+input_dim, hidden_dim, (5,1), padding=(2,0))
|
||||||
|
self.convr2 = nn.Conv2d(hidden_dim+input_dim, hidden_dim, (5,1), padding=(2,0))
|
||||||
|
self.convq2 = nn.Conv2d(hidden_dim+input_dim, hidden_dim, (5,1), padding=(2,0))
|
||||||
|
|
||||||
|
|
||||||
|
def forward(self, h, *x):
|
||||||
|
# horizontal
|
||||||
|
x = torch.cat(x, dim=1)
|
||||||
|
hx = torch.cat([h, x], dim=1)
|
||||||
|
z = torch.sigmoid(self.convz1(hx))
|
||||||
|
r = torch.sigmoid(self.convr1(hx))
|
||||||
|
q = torch.tanh(self.convq1(torch.cat([r*h, x], dim=1)))
|
||||||
|
h = (1-z) * h + z * q
|
||||||
|
|
||||||
|
# vertical
|
||||||
|
hx = torch.cat([h, x], dim=1)
|
||||||
|
z = torch.sigmoid(self.convz2(hx))
|
||||||
|
r = torch.sigmoid(self.convr2(hx))
|
||||||
|
q = torch.tanh(self.convq2(torch.cat([r*h, x], dim=1)))
|
||||||
|
h = (1-z) * h + z * q
|
||||||
|
|
||||||
|
return h
|
||||||
|
|
||||||
|
class BasicMotionEncoder(nn.Module):
|
||||||
|
def __init__(self, args):
|
||||||
|
super(BasicMotionEncoder, self).__init__()
|
||||||
|
self.args = args
|
||||||
|
cor_planes = args.corr_levels * (2*args.corr_radius + 1) * (8+1)
|
||||||
|
self.convc1 = nn.Conv2d(cor_planes, 64, 1, padding=0)
|
||||||
|
self.convc2 = nn.Conv2d(64, 64, 3, padding=1)
|
||||||
|
self.convd1 = nn.Conv2d(1, 64, 7, padding=3)
|
||||||
|
self.convd2 = nn.Conv2d(64, 64, 3, padding=1)
|
||||||
|
self.conv = nn.Conv2d(64+64, 128-1, 3, padding=1)
|
||||||
|
|
||||||
|
def forward(self, disp, corr):
|
||||||
|
cor = F.relu(self.convc1(corr))
|
||||||
|
cor = F.relu(self.convc2(cor))
|
||||||
|
disp_ = F.relu(self.convd1(disp))
|
||||||
|
disp_ = F.relu(self.convd2(disp_))
|
||||||
|
|
||||||
|
cor_disp = torch.cat([cor, disp_], dim=1)
|
||||||
|
out = F.relu(self.conv(cor_disp))
|
||||||
|
return torch.cat([out, disp], dim=1)
|
||||||
|
|
||||||
|
def pool2x(x):
|
||||||
|
return F.avg_pool2d(x, 3, stride=2, padding=1)
|
||||||
|
|
||||||
|
def pool4x(x):
|
||||||
|
return F.avg_pool2d(x, 5, stride=4, padding=1)
|
||||||
|
|
||||||
|
def interp(x, dest):
|
||||||
|
interp_args = {'mode': 'bilinear', 'align_corners': True}
|
||||||
|
return F.interpolate(x, dest.shape[2:], **interp_args)
|
||||||
|
|
||||||
|
class BasicMultiUpdateBlock(nn.Module):
|
||||||
|
def __init__(self, args, hidden_dims=[]):
|
||||||
|
super().__init__()
|
||||||
|
self.args = args
|
||||||
|
self.encoder = BasicMotionEncoder(args)
|
||||||
|
encoder_output_dim = 128
|
||||||
|
|
||||||
|
self.gru04 = ConvGRU(hidden_dims[2], encoder_output_dim + hidden_dims[1] * (args.n_gru_layers > 1))
|
||||||
|
self.gru08 = ConvGRU(hidden_dims[1], hidden_dims[0] * (args.n_gru_layers == 3) + hidden_dims[2])
|
||||||
|
self.gru16 = ConvGRU(hidden_dims[0], hidden_dims[1])
|
||||||
|
self.disp_head = DispHead(hidden_dims[2], hidden_dim=256, output_dim=1)
|
||||||
|
factor = 2**self.args.n_downsample
|
||||||
|
|
||||||
|
self.mask_feat_4 = nn.Sequential(
|
||||||
|
nn.Conv2d(hidden_dims[2], 32, 3, padding=1),
|
||||||
|
nn.ReLU(inplace=True))
|
||||||
|
|
||||||
|
def forward(self, net, inp, corr=None, disp=None, iter04=True, iter08=True, iter16=True, update=True):
|
||||||
|
|
||||||
|
if iter16:
|
||||||
|
net[2] = self.gru16(net[2], *(inp[2]), pool2x(net[1]))
|
||||||
|
if iter08:
|
||||||
|
if self.args.n_gru_layers > 2:
|
||||||
|
net[1] = self.gru08(net[1], *(inp[1]), pool2x(net[0]), interp(net[2], net[1]))
|
||||||
|
else:
|
||||||
|
net[1] = self.gru08(net[1], *(inp[1]), pool2x(net[0]))
|
||||||
|
if iter04:
|
||||||
|
motion_features = self.encoder(disp, corr)
|
||||||
|
if self.args.n_gru_layers > 1:
|
||||||
|
net[0] = self.gru04(net[0], *(inp[0]), motion_features, interp(net[1], net[0]))
|
||||||
|
else:
|
||||||
|
net[0] = self.gru04(net[0], *(inp[0]), motion_features)
|
||||||
|
|
||||||
|
if not update:
|
||||||
|
return net
|
||||||
|
|
||||||
|
delta_disp = self.disp_head(net[0])
|
||||||
|
mask_feat_4 = self.mask_feat_4(net[0])
|
||||||
|
return net, mask_feat_4, delta_disp
|
0
IGEV-Stereo/core/utils/__init__.py
Normal file
0
IGEV-Stereo/core/utils/__init__.py
Normal file
321
IGEV-Stereo/core/utils/augmentor.py
Normal file
321
IGEV-Stereo/core/utils/augmentor.py
Normal file
@ -0,0 +1,321 @@
|
|||||||
|
import numpy as np
|
||||||
|
import random
|
||||||
|
import warnings
|
||||||
|
import os
|
||||||
|
import time
|
||||||
|
from glob import glob
|
||||||
|
from skimage import color, io
|
||||||
|
from PIL import Image
|
||||||
|
|
||||||
|
import cv2
|
||||||
|
cv2.setNumThreads(0)
|
||||||
|
cv2.ocl.setUseOpenCL(False)
|
||||||
|
|
||||||
|
import torch
|
||||||
|
from torchvision.transforms import ColorJitter, functional, Compose
|
||||||
|
import torch.nn.functional as F
|
||||||
|
|
||||||
|
def get_middlebury_images():
|
||||||
|
root = "datasets/Middlebury/MiddEval3"
|
||||||
|
with open(os.path.join(root, "official_train.txt"), 'r') as f:
|
||||||
|
lines = f.read().splitlines()
|
||||||
|
return sorted([os.path.join(root, 'trainingQ', f'{name}/im0.png') for name in lines])
|
||||||
|
|
||||||
|
def get_eth3d_images():
|
||||||
|
return sorted(glob('datasets/ETH3D/two_view_training/*/im0.png'))
|
||||||
|
|
||||||
|
def get_kitti_images():
|
||||||
|
return sorted(glob('datasets/KITTI/training/image_2/*_10.png'))
|
||||||
|
|
||||||
|
def transfer_color(image, style_mean, style_stddev):
|
||||||
|
reference_image_lab = color.rgb2lab(image)
|
||||||
|
reference_stddev = np.std(reference_image_lab, axis=(0,1), keepdims=True)# + 1
|
||||||
|
reference_mean = np.mean(reference_image_lab, axis=(0,1), keepdims=True)
|
||||||
|
|
||||||
|
reference_image_lab = reference_image_lab - reference_mean
|
||||||
|
lamb = style_stddev/reference_stddev
|
||||||
|
style_image_lab = lamb * reference_image_lab
|
||||||
|
output_image_lab = style_image_lab + style_mean
|
||||||
|
l, a, b = np.split(output_image_lab, 3, axis=2)
|
||||||
|
l = l.clip(0, 100)
|
||||||
|
output_image_lab = np.concatenate((l,a,b), axis=2)
|
||||||
|
with warnings.catch_warnings():
|
||||||
|
warnings.simplefilter("ignore", category=UserWarning)
|
||||||
|
output_image_rgb = color.lab2rgb(output_image_lab) * 255
|
||||||
|
return output_image_rgb
|
||||||
|
|
||||||
|
class AdjustGamma(object):
|
||||||
|
|
||||||
|
def __init__(self, gamma_min, gamma_max, gain_min=1.0, gain_max=1.0):
|
||||||
|
self.gamma_min, self.gamma_max, self.gain_min, self.gain_max = gamma_min, gamma_max, gain_min, gain_max
|
||||||
|
|
||||||
|
def __call__(self, sample):
|
||||||
|
gain = random.uniform(self.gain_min, self.gain_max)
|
||||||
|
gamma = random.uniform(self.gamma_min, self.gamma_max)
|
||||||
|
return functional.adjust_gamma(sample, gamma, gain)
|
||||||
|
|
||||||
|
def __repr__(self):
|
||||||
|
return f"Adjust Gamma {self.gamma_min}, ({self.gamma_max}) and Gain ({self.gain_min}, {self.gain_max})"
|
||||||
|
|
||||||
|
class FlowAugmentor:
|
||||||
|
def __init__(self, crop_size, min_scale=-0.2, max_scale=0.5, do_flip=True, yjitter=False, saturation_range=[0.6,1.4], gamma=[1,1,1,1]):
|
||||||
|
|
||||||
|
# spatial augmentation params
|
||||||
|
self.crop_size = crop_size
|
||||||
|
self.min_scale = min_scale
|
||||||
|
self.max_scale = max_scale
|
||||||
|
self.spatial_aug_prob = 1.0
|
||||||
|
self.stretch_prob = 0.8
|
||||||
|
self.max_stretch = 0.2
|
||||||
|
|
||||||
|
# flip augmentation params
|
||||||
|
self.yjitter = yjitter
|
||||||
|
self.do_flip = do_flip
|
||||||
|
self.h_flip_prob = 0.5
|
||||||
|
self.v_flip_prob = 0.1
|
||||||
|
|
||||||
|
# photometric augmentation params
|
||||||
|
self.photo_aug = Compose([ColorJitter(brightness=0.4, contrast=0.4, saturation=saturation_range, hue=0.5/3.14), AdjustGamma(*gamma)])
|
||||||
|
self.asymmetric_color_aug_prob = 0.2
|
||||||
|
self.eraser_aug_prob = 0.5
|
||||||
|
|
||||||
|
def color_transform(self, img1, img2):
|
||||||
|
""" Photometric augmentation """
|
||||||
|
|
||||||
|
# asymmetric
|
||||||
|
if np.random.rand() < self.asymmetric_color_aug_prob:
|
||||||
|
#print("#####44444", img1.shape)
|
||||||
|
img1 = np.array(self.photo_aug(Image.fromarray(img1)), dtype=np.uint8)
|
||||||
|
img2 = np.array(self.photo_aug(Image.fromarray(img2)), dtype=np.uint8)
|
||||||
|
|
||||||
|
# symmetric
|
||||||
|
else:
|
||||||
|
image_stack = np.concatenate([img1, img2], axis=0)
|
||||||
|
image_stack = np.array(self.photo_aug(Image.fromarray(image_stack)), dtype=np.uint8)
|
||||||
|
img1, img2 = np.split(image_stack, 2, axis=0)
|
||||||
|
|
||||||
|
return img1, img2
|
||||||
|
|
||||||
|
def eraser_transform(self, img1, img2, bounds=[50, 100]):
|
||||||
|
""" Occlusion augmentation """
|
||||||
|
|
||||||
|
ht, wd = img1.shape[:2]
|
||||||
|
if np.random.rand() < self.eraser_aug_prob:
|
||||||
|
mean_color = np.mean(img2.reshape(-1, 3), axis=0)
|
||||||
|
for _ in range(np.random.randint(1, 3)):
|
||||||
|
x0 = np.random.randint(0, wd)
|
||||||
|
y0 = np.random.randint(0, ht)
|
||||||
|
dx = np.random.randint(bounds[0], bounds[1])
|
||||||
|
dy = np.random.randint(bounds[0], bounds[1])
|
||||||
|
img2[y0:y0+dy, x0:x0+dx, :] = mean_color
|
||||||
|
|
||||||
|
return img1, img2
|
||||||
|
|
||||||
|
def spatial_transform(self, img1, img2, flow):
|
||||||
|
# randomly sample scale
|
||||||
|
ht, wd = img1.shape[:2]
|
||||||
|
min_scale = np.maximum(
|
||||||
|
(self.crop_size[0] + 8) / float(ht),
|
||||||
|
(self.crop_size[1] + 8) / float(wd))
|
||||||
|
|
||||||
|
scale = 2 ** np.random.uniform(self.min_scale, self.max_scale)
|
||||||
|
scale_x = scale
|
||||||
|
scale_y = scale
|
||||||
|
if np.random.rand() < self.stretch_prob:
|
||||||
|
scale_x *= 2 ** np.random.uniform(-self.max_stretch, self.max_stretch)
|
||||||
|
scale_y *= 2 ** np.random.uniform(-self.max_stretch, self.max_stretch)
|
||||||
|
|
||||||
|
scale_x = np.clip(scale_x, min_scale, None)
|
||||||
|
scale_y = np.clip(scale_y, min_scale, None)
|
||||||
|
|
||||||
|
# print("####22222", flow.shape, scale_x, scale_y)
|
||||||
|
|
||||||
|
if np.random.rand() < self.spatial_aug_prob:
|
||||||
|
# rescale the images
|
||||||
|
img1 = cv2.resize(img1, None, fx=scale_x, fy=scale_y, interpolation=cv2.INTER_LINEAR)
|
||||||
|
img2 = cv2.resize(img2, None, fx=scale_x, fy=scale_y, interpolation=cv2.INTER_LINEAR)
|
||||||
|
flow = cv2.resize(flow, None, fx=scale_x, fy=scale_y, interpolation=cv2.INTER_LINEAR)
|
||||||
|
|
||||||
|
flow = flow * [scale_x, scale_y]
|
||||||
|
|
||||||
|
if self.do_flip:
|
||||||
|
if np.random.rand() < self.h_flip_prob and self.do_flip == 'hf': # h-flip
|
||||||
|
img1 = img1[:, ::-1]
|
||||||
|
img2 = img2[:, ::-1]
|
||||||
|
flow = flow[:, ::-1] * [-1.0, 1.0]
|
||||||
|
|
||||||
|
if np.random.rand() < self.h_flip_prob and self.do_flip == 'h': # h-flip for stereo
|
||||||
|
tmp = img1[:, ::-1]
|
||||||
|
img1 = img2[:, ::-1]
|
||||||
|
img2 = tmp
|
||||||
|
|
||||||
|
if np.random.rand() < self.v_flip_prob and self.do_flip == 'v': # v-flip
|
||||||
|
img1 = img1[::-1, :]
|
||||||
|
img2 = img2[::-1, :]
|
||||||
|
flow = flow[::-1, :] * [1.0, -1.0]
|
||||||
|
|
||||||
|
if self.yjitter:
|
||||||
|
y0 = np.random.randint(2, img1.shape[0] - self.crop_size[0] - 2)
|
||||||
|
x0 = np.random.randint(2, img1.shape[1] - self.crop_size[1] - 2)
|
||||||
|
|
||||||
|
y1 = y0 + np.random.randint(-2, 2 + 1)
|
||||||
|
img1 = img1[y0:y0+self.crop_size[0], x0:x0+self.crop_size[1]]
|
||||||
|
img2 = img2[y1:y1+self.crop_size[0], x0:x0+self.crop_size[1]]
|
||||||
|
flow = flow[y0:y0+self.crop_size[0], x0:x0+self.crop_size[1]]
|
||||||
|
|
||||||
|
else:
|
||||||
|
y0 = np.random.randint(0, img1.shape[0] - self.crop_size[0])
|
||||||
|
x0 = np.random.randint(0, img1.shape[1] - self.crop_size[1])
|
||||||
|
|
||||||
|
img1 = img1[y0:y0+self.crop_size[0], x0:x0+self.crop_size[1]]
|
||||||
|
img2 = img2[y0:y0+self.crop_size[0], x0:x0+self.crop_size[1]]
|
||||||
|
flow = flow[y0:y0+self.crop_size[0], x0:x0+self.crop_size[1]]
|
||||||
|
|
||||||
|
return img1, img2, flow
|
||||||
|
|
||||||
|
|
||||||
|
def __call__(self, img1, img2, flow):
|
||||||
|
img1, img2 = self.color_transform(img1, img2)
|
||||||
|
img1, img2 = self.eraser_transform(img1, img2)
|
||||||
|
img1, img2, flow = self.spatial_transform(img1, img2, flow)
|
||||||
|
|
||||||
|
img1 = np.ascontiguousarray(img1)
|
||||||
|
img2 = np.ascontiguousarray(img2)
|
||||||
|
flow = np.ascontiguousarray(flow)
|
||||||
|
|
||||||
|
return img1, img2, flow
|
||||||
|
|
||||||
|
class SparseFlowAugmentor:
|
||||||
|
def __init__(self, crop_size, min_scale=-0.2, max_scale=0.5, do_flip=False, yjitter=False, saturation_range=[0.7,1.3], gamma=[1,1,1,1]):
|
||||||
|
# spatial augmentation params
|
||||||
|
self.crop_size = crop_size
|
||||||
|
self.min_scale = min_scale
|
||||||
|
self.max_scale = max_scale
|
||||||
|
self.spatial_aug_prob = 0.8
|
||||||
|
self.stretch_prob = 0.8
|
||||||
|
self.max_stretch = 0.2
|
||||||
|
|
||||||
|
# flip augmentation params
|
||||||
|
self.do_flip = do_flip
|
||||||
|
self.h_flip_prob = 0.5
|
||||||
|
self.v_flip_prob = 0.1
|
||||||
|
|
||||||
|
# photometric augmentation params
|
||||||
|
self.photo_aug = Compose([ColorJitter(brightness=0.3, contrast=0.3, saturation=saturation_range, hue=0.3/3.14), AdjustGamma(*gamma)])
|
||||||
|
self.asymmetric_color_aug_prob = 0.2
|
||||||
|
self.eraser_aug_prob = 0.5
|
||||||
|
|
||||||
|
def color_transform(self, img1, img2):
|
||||||
|
image_stack = np.concatenate([img1, img2], axis=0)
|
||||||
|
image_stack = np.array(self.photo_aug(Image.fromarray(image_stack)), dtype=np.uint8)
|
||||||
|
img1, img2 = np.split(image_stack, 2, axis=0)
|
||||||
|
return img1, img2
|
||||||
|
|
||||||
|
def eraser_transform(self, img1, img2):
|
||||||
|
ht, wd = img1.shape[:2]
|
||||||
|
if np.random.rand() < self.eraser_aug_prob:
|
||||||
|
mean_color = np.mean(img2.reshape(-1, 3), axis=0)
|
||||||
|
for _ in range(np.random.randint(1, 3)):
|
||||||
|
x0 = np.random.randint(0, wd)
|
||||||
|
y0 = np.random.randint(0, ht)
|
||||||
|
dx = np.random.randint(50, 100)
|
||||||
|
dy = np.random.randint(50, 100)
|
||||||
|
img2[y0:y0+dy, x0:x0+dx, :] = mean_color
|
||||||
|
|
||||||
|
return img1, img2
|
||||||
|
|
||||||
|
def resize_sparse_flow_map(self, flow, valid, fx=1.0, fy=1.0):
|
||||||
|
ht, wd = flow.shape[:2]
|
||||||
|
coords = np.meshgrid(np.arange(wd), np.arange(ht))
|
||||||
|
coords = np.stack(coords, axis=-1)
|
||||||
|
|
||||||
|
coords = coords.reshape(-1, 2).astype(np.float32)
|
||||||
|
flow = flow.reshape(-1, 2).astype(np.float32)
|
||||||
|
valid = valid.reshape(-1).astype(np.float32)
|
||||||
|
|
||||||
|
coords0 = coords[valid>=1]
|
||||||
|
flow0 = flow[valid>=1]
|
||||||
|
|
||||||
|
ht1 = int(round(ht * fy))
|
||||||
|
wd1 = int(round(wd * fx))
|
||||||
|
|
||||||
|
coords1 = coords0 * [fx, fy]
|
||||||
|
flow1 = flow0 * [fx, fy]
|
||||||
|
|
||||||
|
xx = np.round(coords1[:,0]).astype(np.int32)
|
||||||
|
yy = np.round(coords1[:,1]).astype(np.int32)
|
||||||
|
|
||||||
|
v = (xx > 0) & (xx < wd1) & (yy > 0) & (yy < ht1)
|
||||||
|
xx = xx[v]
|
||||||
|
yy = yy[v]
|
||||||
|
flow1 = flow1[v]
|
||||||
|
|
||||||
|
flow_img = np.zeros([ht1, wd1, 2], dtype=np.float32)
|
||||||
|
valid_img = np.zeros([ht1, wd1], dtype=np.int32)
|
||||||
|
|
||||||
|
flow_img[yy, xx] = flow1
|
||||||
|
valid_img[yy, xx] = 1
|
||||||
|
|
||||||
|
return flow_img, valid_img
|
||||||
|
|
||||||
|
def spatial_transform(self, img1, img2, flow, valid):
|
||||||
|
# randomly sample scale
|
||||||
|
|
||||||
|
ht, wd = img1.shape[:2]
|
||||||
|
min_scale = np.maximum(
|
||||||
|
(self.crop_size[0] + 1) / float(ht),
|
||||||
|
(self.crop_size[1] + 1) / float(wd))
|
||||||
|
|
||||||
|
scale = 2 ** np.random.uniform(self.min_scale, self.max_scale)
|
||||||
|
scale_x = np.clip(scale, min_scale, None)
|
||||||
|
scale_y = np.clip(scale, min_scale, None)
|
||||||
|
|
||||||
|
if np.random.rand() < self.spatial_aug_prob:
|
||||||
|
# rescale the images
|
||||||
|
img1 = cv2.resize(img1, None, fx=scale_x, fy=scale_y, interpolation=cv2.INTER_LINEAR)
|
||||||
|
img2 = cv2.resize(img2, None, fx=scale_x, fy=scale_y, interpolation=cv2.INTER_LINEAR)
|
||||||
|
flow, valid = self.resize_sparse_flow_map(flow, valid, fx=scale_x, fy=scale_y)
|
||||||
|
|
||||||
|
if self.do_flip:
|
||||||
|
if np.random.rand() < self.h_flip_prob and self.do_flip == 'hf': # h-flip
|
||||||
|
img1 = img1[:, ::-1]
|
||||||
|
img2 = img2[:, ::-1]
|
||||||
|
flow = flow[:, ::-1] * [-1.0, 1.0]
|
||||||
|
|
||||||
|
if np.random.rand() < self.h_flip_prob and self.do_flip == 'h': # h-flip for stereo
|
||||||
|
tmp = img1[:, ::-1]
|
||||||
|
img1 = img2[:, ::-1]
|
||||||
|
img2 = tmp
|
||||||
|
|
||||||
|
if np.random.rand() < self.v_flip_prob and self.do_flip == 'v': # v-flip
|
||||||
|
img1 = img1[::-1, :]
|
||||||
|
img2 = img2[::-1, :]
|
||||||
|
flow = flow[::-1, :] * [1.0, -1.0]
|
||||||
|
|
||||||
|
margin_y = 20
|
||||||
|
margin_x = 50
|
||||||
|
|
||||||
|
y0 = np.random.randint(0, img1.shape[0] - self.crop_size[0] + margin_y)
|
||||||
|
x0 = np.random.randint(-margin_x, img1.shape[1] - self.crop_size[1] + margin_x)
|
||||||
|
|
||||||
|
y0 = np.clip(y0, 0, img1.shape[0] - self.crop_size[0])
|
||||||
|
x0 = np.clip(x0, 0, img1.shape[1] - self.crop_size[1])
|
||||||
|
|
||||||
|
img1 = img1[y0:y0+self.crop_size[0], x0:x0+self.crop_size[1]]
|
||||||
|
img2 = img2[y0:y0+self.crop_size[0], x0:x0+self.crop_size[1]]
|
||||||
|
flow = flow[y0:y0+self.crop_size[0], x0:x0+self.crop_size[1]]
|
||||||
|
valid = valid[y0:y0+self.crop_size[0], x0:x0+self.crop_size[1]]
|
||||||
|
return img1, img2, flow, valid
|
||||||
|
|
||||||
|
|
||||||
|
def __call__(self, img1, img2, flow, valid):
|
||||||
|
img1, img2 = self.color_transform(img1, img2)
|
||||||
|
img1, img2 = self.eraser_transform(img1, img2)
|
||||||
|
img1, img2, flow, valid = self.spatial_transform(img1, img2, flow, valid)
|
||||||
|
|
||||||
|
img1 = np.ascontiguousarray(img1)
|
||||||
|
img2 = np.ascontiguousarray(img2)
|
||||||
|
flow = np.ascontiguousarray(flow)
|
||||||
|
valid = np.ascontiguousarray(valid)
|
||||||
|
|
||||||
|
return img1, img2, flow, valid
|
187
IGEV-Stereo/core/utils/frame_utils.py
Normal file
187
IGEV-Stereo/core/utils/frame_utils.py
Normal file
@ -0,0 +1,187 @@
|
|||||||
|
import numpy as np
|
||||||
|
from PIL import Image
|
||||||
|
from os.path import *
|
||||||
|
import re
|
||||||
|
import json
|
||||||
|
import imageio
|
||||||
|
import cv2
|
||||||
|
cv2.setNumThreads(0)
|
||||||
|
cv2.ocl.setUseOpenCL(False)
|
||||||
|
|
||||||
|
TAG_CHAR = np.array([202021.25], np.float32)
|
||||||
|
|
||||||
|
def readFlow(fn):
|
||||||
|
""" Read .flo file in Middlebury format"""
|
||||||
|
# Code adapted from:
|
||||||
|
# http://stackoverflow.com/questions/28013200/reading-middlebury-flow-files-with-python-bytes-array-numpy
|
||||||
|
|
||||||
|
# WARNING: this will work on little-endian architectures (eg Intel x86) only!
|
||||||
|
# print 'fn = %s'%(fn)
|
||||||
|
with open(fn, 'rb') as f:
|
||||||
|
magic = np.fromfile(f, np.float32, count=1)
|
||||||
|
if 202021.25 != magic:
|
||||||
|
print('Magic number incorrect. Invalid .flo file')
|
||||||
|
return None
|
||||||
|
else:
|
||||||
|
w = np.fromfile(f, np.int32, count=1)
|
||||||
|
h = np.fromfile(f, np.int32, count=1)
|
||||||
|
# print 'Reading %d x %d flo file\n' % (w, h)
|
||||||
|
data = np.fromfile(f, np.float32, count=2*int(w)*int(h))
|
||||||
|
# Reshape data into 3D array (columns, rows, bands)
|
||||||
|
# The reshape here is for visualization, the original code is (w,h,2)
|
||||||
|
return np.resize(data, (int(h), int(w), 2))
|
||||||
|
|
||||||
|
def readPFM(file):
|
||||||
|
file = open(file, 'rb')
|
||||||
|
|
||||||
|
color = None
|
||||||
|
width = None
|
||||||
|
height = None
|
||||||
|
scale = None
|
||||||
|
endian = None
|
||||||
|
|
||||||
|
header = file.readline().rstrip()
|
||||||
|
if header == b'PF':
|
||||||
|
color = True
|
||||||
|
elif header == b'Pf':
|
||||||
|
color = False
|
||||||
|
else:
|
||||||
|
raise Exception('Not a PFM file.')
|
||||||
|
|
||||||
|
dim_match = re.match(rb'^(\d+)\s(\d+)\s$', file.readline())
|
||||||
|
if dim_match:
|
||||||
|
width, height = map(int, dim_match.groups())
|
||||||
|
else:
|
||||||
|
raise Exception('Malformed PFM header.')
|
||||||
|
|
||||||
|
scale = float(file.readline().rstrip())
|
||||||
|
if scale < 0: # little-endian
|
||||||
|
endian = '<'
|
||||||
|
scale = -scale
|
||||||
|
else:
|
||||||
|
endian = '>' # big-endian
|
||||||
|
|
||||||
|
data = np.fromfile(file, endian + 'f')
|
||||||
|
shape = (height, width, 3) if color else (height, width)
|
||||||
|
|
||||||
|
data = np.reshape(data, shape)
|
||||||
|
data = np.flipud(data)
|
||||||
|
return data
|
||||||
|
|
||||||
|
def writePFM(file, array):
|
||||||
|
import os
|
||||||
|
assert type(file) is str and type(array) is np.ndarray and \
|
||||||
|
os.path.splitext(file)[1] == ".pfm"
|
||||||
|
with open(file, 'wb') as f:
|
||||||
|
H, W = array.shape
|
||||||
|
headers = ["Pf\n", f"{W} {H}\n", "-1\n"]
|
||||||
|
for header in headers:
|
||||||
|
f.write(str.encode(header))
|
||||||
|
array = np.flip(array, axis=0).astype(np.float32)
|
||||||
|
f.write(array.tobytes())
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
def writeFlow(filename,uv,v=None):
|
||||||
|
""" Write optical flow to file.
|
||||||
|
|
||||||
|
If v is None, uv is assumed to contain both u and v channels,
|
||||||
|
stacked in depth.
|
||||||
|
Original code by Deqing Sun, adapted from Daniel Scharstein.
|
||||||
|
"""
|
||||||
|
nBands = 2
|
||||||
|
|
||||||
|
if v is None:
|
||||||
|
assert(uv.ndim == 3)
|
||||||
|
assert(uv.shape[2] == 2)
|
||||||
|
u = uv[:,:,0]
|
||||||
|
v = uv[:,:,1]
|
||||||
|
else:
|
||||||
|
u = uv
|
||||||
|
|
||||||
|
assert(u.shape == v.shape)
|
||||||
|
height,width = u.shape
|
||||||
|
f = open(filename,'wb')
|
||||||
|
# write the header
|
||||||
|
f.write(TAG_CHAR)
|
||||||
|
np.array(width).astype(np.int32).tofile(f)
|
||||||
|
np.array(height).astype(np.int32).tofile(f)
|
||||||
|
# arrange into matrix form
|
||||||
|
tmp = np.zeros((height, width*nBands))
|
||||||
|
tmp[:,np.arange(width)*2] = u
|
||||||
|
tmp[:,np.arange(width)*2 + 1] = v
|
||||||
|
tmp.astype(np.float32).tofile(f)
|
||||||
|
f.close()
|
||||||
|
|
||||||
|
|
||||||
|
def readFlowKITTI(filename):
|
||||||
|
flow = cv2.imread(filename, cv2.IMREAD_ANYDEPTH|cv2.IMREAD_COLOR)
|
||||||
|
flow = flow[:,:,::-1].astype(np.float32)
|
||||||
|
flow, valid = flow[:, :, :2], flow[:, :, 2]
|
||||||
|
flow = (flow - 2**15) / 64.0
|
||||||
|
return flow, valid
|
||||||
|
|
||||||
|
def readDispKITTI(filename):
|
||||||
|
disp = cv2.imread(filename, cv2.IMREAD_ANYDEPTH) / 256.0
|
||||||
|
valid = disp > 0.0
|
||||||
|
return disp, valid
|
||||||
|
|
||||||
|
# Method taken from /n/fs/raft-depth/RAFT-Stereo/datasets/SintelStereo/sdk/python/sintel_io.py
|
||||||
|
def readDispSintelStereo(file_name):
|
||||||
|
a = np.array(Image.open(file_name))
|
||||||
|
d_r, d_g, d_b = np.split(a, axis=2, indices_or_sections=3)
|
||||||
|
disp = (d_r * 4 + d_g / (2**6) + d_b / (2**14))[..., 0]
|
||||||
|
mask = np.array(Image.open(file_name.replace('disparities', 'occlusions')))
|
||||||
|
valid = ((mask == 0) & (disp > 0))
|
||||||
|
return disp, valid
|
||||||
|
|
||||||
|
# Method taken from https://research.nvidia.com/sites/default/files/pubs/2018-06_Falling-Things/readme_0.txt
|
||||||
|
def readDispFallingThings(file_name):
|
||||||
|
a = np.array(Image.open(file_name))
|
||||||
|
with open('/'.join(file_name.split('/')[:-1] + ['_camera_settings.json']), 'r') as f:
|
||||||
|
intrinsics = json.load(f)
|
||||||
|
fx = intrinsics['camera_settings'][0]['intrinsic_settings']['fx']
|
||||||
|
disp = (fx * 6.0 * 100) / a.astype(np.float32)
|
||||||
|
valid = disp > 0
|
||||||
|
return disp, valid
|
||||||
|
|
||||||
|
# Method taken from https://github.com/castacks/tartanair_tools/blob/master/data_type.md
|
||||||
|
def readDispTartanAir(file_name):
|
||||||
|
depth = np.load(file_name)
|
||||||
|
disp = 80.0 / depth
|
||||||
|
valid = disp > 0
|
||||||
|
return disp, valid
|
||||||
|
|
||||||
|
|
||||||
|
def readDispMiddlebury(file_name):
|
||||||
|
assert basename(file_name) == 'disp0GT.pfm'
|
||||||
|
disp = readPFM(file_name).astype(np.float32)
|
||||||
|
assert len(disp.shape) == 2
|
||||||
|
nocc_pix = file_name.replace('disp0GT.pfm', 'mask0nocc.png')
|
||||||
|
assert exists(nocc_pix)
|
||||||
|
nocc_pix = imageio.imread(nocc_pix) == 255
|
||||||
|
assert np.any(nocc_pix)
|
||||||
|
return disp, nocc_pix
|
||||||
|
|
||||||
|
def writeFlowKITTI(filename, uv):
|
||||||
|
uv = 64.0 * uv + 2**15
|
||||||
|
valid = np.ones([uv.shape[0], uv.shape[1], 1])
|
||||||
|
uv = np.concatenate([uv, valid], axis=-1).astype(np.uint16)
|
||||||
|
cv2.imwrite(filename, uv[..., ::-1])
|
||||||
|
|
||||||
|
|
||||||
|
def read_gen(file_name, pil=False):
|
||||||
|
ext = splitext(file_name)[-1]
|
||||||
|
if ext == '.png' or ext == '.jpeg' or ext == '.ppm' or ext == '.jpg':
|
||||||
|
return Image.open(file_name)
|
||||||
|
elif ext == '.bin' or ext == '.raw':
|
||||||
|
return np.load(file_name)
|
||||||
|
elif ext == '.flo':
|
||||||
|
return readFlow(file_name).astype(np.float32)
|
||||||
|
elif ext == '.pfm':
|
||||||
|
flow = readPFM(file_name).astype(np.float32)
|
||||||
|
if len(flow.shape) == 2:
|
||||||
|
return flow
|
||||||
|
else:
|
||||||
|
return flow[:, :, :-1]
|
||||||
|
return []
|
97
IGEV-Stereo/core/utils/utils.py
Normal file
97
IGEV-Stereo/core/utils/utils.py
Normal file
@ -0,0 +1,97 @@
|
|||||||
|
import torch
|
||||||
|
import torch.nn.functional as F
|
||||||
|
import numpy as np
|
||||||
|
from scipy import interpolate
|
||||||
|
|
||||||
|
|
||||||
|
class InputPadder:
|
||||||
|
""" Pads images such that dimensions are divisible by 8 """
|
||||||
|
def __init__(self, dims, mode='sintel', divis_by=8):
|
||||||
|
self.ht, self.wd = dims[-2:]
|
||||||
|
pad_ht = (((self.ht // divis_by) + 1) * divis_by - self.ht) % divis_by
|
||||||
|
pad_wd = (((self.wd // divis_by) + 1) * divis_by - self.wd) % divis_by
|
||||||
|
if mode == 'sintel':
|
||||||
|
self._pad = [pad_wd//2, pad_wd - pad_wd//2, pad_ht//2, pad_ht - pad_ht//2]
|
||||||
|
else:
|
||||||
|
self._pad = [pad_wd//2, pad_wd - pad_wd//2, 0, pad_ht]
|
||||||
|
|
||||||
|
def pad(self, *inputs):
|
||||||
|
assert all((x.ndim == 4) for x in inputs)
|
||||||
|
return [F.pad(x, self._pad, mode='replicate') for x in inputs]
|
||||||
|
|
||||||
|
def unpad(self, x):
|
||||||
|
assert x.ndim == 4
|
||||||
|
ht, wd = x.shape[-2:]
|
||||||
|
c = [self._pad[2], ht-self._pad[3], self._pad[0], wd-self._pad[1]]
|
||||||
|
return x[..., c[0]:c[1], c[2]:c[3]]
|
||||||
|
|
||||||
|
def forward_interpolate(flow):
|
||||||
|
flow = flow.detach().cpu().numpy()
|
||||||
|
dx, dy = flow[0], flow[1]
|
||||||
|
|
||||||
|
ht, wd = dx.shape
|
||||||
|
x0, y0 = np.meshgrid(np.arange(wd), np.arange(ht))
|
||||||
|
|
||||||
|
x1 = x0 + dx
|
||||||
|
y1 = y0 + dy
|
||||||
|
|
||||||
|
x1 = x1.reshape(-1)
|
||||||
|
y1 = y1.reshape(-1)
|
||||||
|
dx = dx.reshape(-1)
|
||||||
|
dy = dy.reshape(-1)
|
||||||
|
|
||||||
|
valid = (x1 > 0) & (x1 < wd) & (y1 > 0) & (y1 < ht)
|
||||||
|
x1 = x1[valid]
|
||||||
|
y1 = y1[valid]
|
||||||
|
dx = dx[valid]
|
||||||
|
dy = dy[valid]
|
||||||
|
|
||||||
|
flow_x = interpolate.griddata(
|
||||||
|
(x1, y1), dx, (x0, y0), method='nearest', fill_value=0)
|
||||||
|
|
||||||
|
flow_y = interpolate.griddata(
|
||||||
|
(x1, y1), dy, (x0, y0), method='nearest', fill_value=0)
|
||||||
|
|
||||||
|
flow = np.stack([flow_x, flow_y], axis=0)
|
||||||
|
return torch.from_numpy(flow).float()
|
||||||
|
|
||||||
|
|
||||||
|
def bilinear_sampler(img, coords, mode='bilinear', mask=False):
|
||||||
|
""" Wrapper for grid_sample, uses pixel coordinates """
|
||||||
|
H, W = img.shape[-2:]
|
||||||
|
|
||||||
|
# print("$$$55555", img.shape, coords.shape)
|
||||||
|
xgrid, ygrid = coords.split([1,1], dim=-1)
|
||||||
|
xgrid = 2*xgrid/(W-1) - 1
|
||||||
|
|
||||||
|
# print("######88888", xgrid)
|
||||||
|
assert torch.unique(ygrid).numel() == 1 and H == 1 # This is a stereo problem
|
||||||
|
|
||||||
|
grid = torch.cat([xgrid, ygrid], dim=-1)
|
||||||
|
# print("###37777", grid.shape)
|
||||||
|
img = F.grid_sample(img, grid, align_corners=True)
|
||||||
|
if mask:
|
||||||
|
mask = (xgrid > -1) & (ygrid > -1) & (xgrid < 1) & (ygrid < 1)
|
||||||
|
return img, mask.float()
|
||||||
|
|
||||||
|
return img
|
||||||
|
|
||||||
|
|
||||||
|
def coords_grid(batch, ht, wd):
|
||||||
|
coords = torch.meshgrid(torch.arange(ht), torch.arange(wd))
|
||||||
|
coords = torch.stack(coords[::-1], dim=0).float()
|
||||||
|
return coords[None].repeat(batch, 1, 1, 1)
|
||||||
|
|
||||||
|
|
||||||
|
def upflow8(flow, mode='bilinear'):
|
||||||
|
new_size = (8 * flow.shape[2], 8 * flow.shape[3])
|
||||||
|
return 8 * F.interpolate(flow, size=new_size, mode=mode, align_corners=True)
|
||||||
|
|
||||||
|
def gauss_blur(input, N=5, std=1):
|
||||||
|
B, D, H, W = input.shape
|
||||||
|
x, y = torch.meshgrid(torch.arange(N).float() - N//2, torch.arange(N).float() - N//2)
|
||||||
|
unnormalized_gaussian = torch.exp(-(x.pow(2) + y.pow(2)) / (2 * std ** 2))
|
||||||
|
weights = unnormalized_gaussian / unnormalized_gaussian.sum().clamp(min=1e-4)
|
||||||
|
weights = weights.view(1,1,N,N).to(input)
|
||||||
|
output = F.conv2d(input.reshape(B*D,1,H,W), weights, padding=N//2)
|
||||||
|
return output.view(B, D, H, W)
|
BIN
IGEV-Stereo/demo-imgs/Motorcycle/im0.png
Normal file
BIN
IGEV-Stereo/demo-imgs/Motorcycle/im0.png
Normal file
Binary file not shown.
After Width: | Height: | Size: 2.1 MiB |
BIN
IGEV-Stereo/demo-imgs/Motorcycle/im1.png
Normal file
BIN
IGEV-Stereo/demo-imgs/Motorcycle/im1.png
Normal file
Binary file not shown.
After Width: | Height: | Size: 2.1 MiB |
BIN
IGEV-Stereo/demo-imgs/PlaytableP/im0.png
Normal file
BIN
IGEV-Stereo/demo-imgs/PlaytableP/im0.png
Normal file
Binary file not shown.
After Width: | Height: | Size: 1.9 MiB |
BIN
IGEV-Stereo/demo-imgs/PlaytableP/im1.png
Normal file
BIN
IGEV-Stereo/demo-imgs/PlaytableP/im1.png
Normal file
Binary file not shown.
After Width: | Height: | Size: 1.9 MiB |
88
IGEV-Stereo/demo_imgs.py
Normal file
88
IGEV-Stereo/demo_imgs.py
Normal file
@ -0,0 +1,88 @@
|
|||||||
|
import sys
|
||||||
|
sys.path.append('core')
|
||||||
|
DEVICE = 'cuda'
|
||||||
|
import os
|
||||||
|
os.environ['CUDA_VISIBLE_DEVICES'] = '0'
|
||||||
|
import argparse
|
||||||
|
import glob
|
||||||
|
import numpy as np
|
||||||
|
import torch
|
||||||
|
from tqdm import tqdm
|
||||||
|
from pathlib import Path
|
||||||
|
from igev_stereo import IGEVStereo
|
||||||
|
from utils.utils import InputPadder
|
||||||
|
from PIL import Image
|
||||||
|
from matplotlib import pyplot as plt
|
||||||
|
import os
|
||||||
|
import cv2
|
||||||
|
|
||||||
|
def load_image(imfile):
|
||||||
|
img = np.array(Image.open(imfile)).astype(np.uint8)
|
||||||
|
img = torch.from_numpy(img).permute(2, 0, 1).float()
|
||||||
|
return img[None].to(DEVICE)
|
||||||
|
|
||||||
|
def demo(args):
|
||||||
|
model = torch.nn.DataParallel(IGEVStereo(args), device_ids=[0])
|
||||||
|
model.load_state_dict(torch.load(args.restore_ckpt))
|
||||||
|
|
||||||
|
model = model.module
|
||||||
|
model.to(DEVICE)
|
||||||
|
model.eval()
|
||||||
|
|
||||||
|
output_directory = Path(args.output_directory)
|
||||||
|
output_directory.mkdir(exist_ok=True)
|
||||||
|
|
||||||
|
with torch.no_grad():
|
||||||
|
left_images = sorted(glob.glob(args.left_imgs, recursive=True))
|
||||||
|
right_images = sorted(glob.glob(args.right_imgs, recursive=True))
|
||||||
|
print(f"Found {len(left_images)} images. Saving files to {output_directory}/")
|
||||||
|
|
||||||
|
for (imfile1, imfile2) in tqdm(list(zip(left_images, right_images))):
|
||||||
|
image1 = load_image(imfile1)
|
||||||
|
image2 = load_image(imfile2)
|
||||||
|
|
||||||
|
padder = InputPadder(image1.shape, divis_by=32)
|
||||||
|
image1, image2 = padder.pad(image1, image2)
|
||||||
|
|
||||||
|
disp = model(image1, image2, iters=args.valid_iters, test_mode=True)
|
||||||
|
disp = disp.cpu().numpy()
|
||||||
|
disp = padder.unpad(disp)
|
||||||
|
file_stem = imfile1.split('/')[-2]
|
||||||
|
filename = os.path.join(output_directory, f"{file_stem}.png")
|
||||||
|
plt.imsave(output_directory / f"{file_stem}.png", disp.squeeze(), cmap='jet')
|
||||||
|
# disp = np.round(disp * 256).astype(np.uint16)
|
||||||
|
# cv2.imwrite(filename, cv2.applyColorMap(cv2.convertScaleAbs(disp.squeeze(), alpha=0.01),cv2.COLORMAP_JET), [int(cv2.IMWRITE_PNG_COMPRESSION), 0])
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == '__main__':
|
||||||
|
parser = argparse.ArgumentParser()
|
||||||
|
parser.add_argument('--restore_ckpt', help="restore checkpoint", default='./pretrained_models/sceneflow/sceneflow.pth')
|
||||||
|
parser.add_argument('--save_numpy', action='store_true', help='save output as numpy arrays')
|
||||||
|
|
||||||
|
parser.add_argument('-l', '--left_imgs', help="path to all first (left) frames", default="./demo-imgs/*/im0.png")
|
||||||
|
parser.add_argument('-r', '--right_imgs', help="path to all second (right) frames", default="./demo-imgs/*/im1.png")
|
||||||
|
|
||||||
|
# parser.add_argument('-l', '--left_imgs', help="path to all first (left) frames", default="/data/Middlebury/trainingH/*/im0.png")
|
||||||
|
# parser.add_argument('-r', '--right_imgs', help="path to all second (right) frames", default="/data/Middlebury/trainingH/*/im1.png")
|
||||||
|
# parser.add_argument('-l', '--left_imgs', help="path to all first (left) frames", default="/data/ETH3D/two_view_training/*/im0.png")
|
||||||
|
# parser.add_argument('-r', '--right_imgs', help="path to all second (right) frames", default="/data/ETH3D/two_view_training/*/im1.png")
|
||||||
|
parser.add_argument('--output_directory', help="directory to save output", default="./demo-output/")
|
||||||
|
parser.add_argument('--mixed_precision', action='store_true', help='use mixed precision')
|
||||||
|
parser.add_argument('--valid_iters', type=int, default=32, help='number of flow-field updates during forward pass')
|
||||||
|
|
||||||
|
# Architecture choices
|
||||||
|
parser.add_argument('--hidden_dims', nargs='+', type=int, default=[128]*3, help="hidden state and context dimensions")
|
||||||
|
parser.add_argument('--corr_implementation', choices=["reg", "alt", "reg_cuda", "alt_cuda"], default="reg", help="correlation volume implementation")
|
||||||
|
parser.add_argument('--shared_backbone', action='store_true', help="use a single backbone for the context and feature encoders")
|
||||||
|
parser.add_argument('--corr_levels', type=int, default=2, help="number of levels in the correlation pyramid")
|
||||||
|
parser.add_argument('--corr_radius', type=int, default=4, help="width of the correlation pyramid")
|
||||||
|
parser.add_argument('--n_downsample', type=int, default=2, help="resolution of the disparity field (1/2^K)")
|
||||||
|
parser.add_argument('--slow_fast_gru', action='store_true', help="iterate the low-res GRUs more frequently")
|
||||||
|
parser.add_argument('--n_gru_layers', type=int, default=3, help="number of hidden GRU levels")
|
||||||
|
parser.add_argument('--max_disp', type=int, default=192, help="max disp of geometry encoding volume")
|
||||||
|
|
||||||
|
args = parser.parse_args()
|
||||||
|
|
||||||
|
Path(args.output_directory).mkdir(exist_ok=True, parents=True)
|
||||||
|
|
||||||
|
demo(args)
|
94
IGEV-Stereo/demo_video.py
Normal file
94
IGEV-Stereo/demo_video.py
Normal file
@ -0,0 +1,94 @@
|
|||||||
|
import sys
|
||||||
|
sys.path.append('core')
|
||||||
|
import cv2
|
||||||
|
import numpy as np
|
||||||
|
import glob
|
||||||
|
from pathlib import Path
|
||||||
|
from tqdm import tqdm
|
||||||
|
import torch
|
||||||
|
from PIL import Image
|
||||||
|
from igev_stereo import IGEVStereo
|
||||||
|
import os
|
||||||
|
import argparse
|
||||||
|
from utils.utils import InputPadder
|
||||||
|
torch.backends.cudnn.benchmark = True
|
||||||
|
half_precision = True
|
||||||
|
|
||||||
|
|
||||||
|
DEVICE = 'cuda'
|
||||||
|
os.environ['CUDA_VISIBLE_DEVICES'] = '0'
|
||||||
|
|
||||||
|
parser = argparse.ArgumentParser(description='Iterative Geometry Encoding Volume for Stereo Matching and Multi-View Stereo (IGEV-Stereo)')
|
||||||
|
parser.add_argument('--restore_ckpt', help="restore checkpoint", default='./pretrained_models/kitti/kitti15.pth')
|
||||||
|
parser.add_argument('--save_numpy', action='store_true', help='save output as numpy arrays')
|
||||||
|
parser.add_argument('-l', '--left_imgs', help="path to all first (left) frames", default="/data/KITTI_raw/2011_09_26/2011_09_26_drive_0005_sync/image_02/data/*.png")
|
||||||
|
parser.add_argument('-r', '--right_imgs', help="path to all second (right) frames", default="/data/KITTI_raw/2011_09_26/2011_09_26_drive_0005_sync/image_03/data/*.png")
|
||||||
|
parser.add_argument('--mixed_precision', default=True, action='store_true', help='use mixed precision')
|
||||||
|
parser.add_argument('--valid_iters', type=int, default=16, help='number of flow-field updates during forward pass')
|
||||||
|
parser.add_argument('--hidden_dims', nargs='+', type=int, default=[128]*3, help="hidden state and context dimensions")
|
||||||
|
parser.add_argument('--corr_implementation', choices=["reg", "alt", "reg_cuda", "alt_cuda"], default="reg", help="correlation volume implementation")
|
||||||
|
parser.add_argument('--shared_backbone', action='store_true', help="use a single backbone for the context and feature encoders")
|
||||||
|
parser.add_argument('--corr_levels', type=int, default=2, help="number of levels in the correlation pyramid")
|
||||||
|
parser.add_argument('--corr_radius', type=int, default=4, help="width of the correlation pyramid")
|
||||||
|
parser.add_argument('--n_downsample', type=int, default=2, help="resolution of the disparity field (1/2^K)")
|
||||||
|
parser.add_argument('--slow_fast_gru', action='store_true', help="iterate the low-res GRUs more frequently")
|
||||||
|
parser.add_argument('--n_gru_layers', type=int, default=3, help="number of hidden GRU levels")
|
||||||
|
|
||||||
|
args = parser.parse_args()
|
||||||
|
model = torch.nn.DataParallel(IGEVStereo(args), device_ids=[0])
|
||||||
|
model.load_state_dict(torch.load(args.restore_ckpt))
|
||||||
|
model = model.module
|
||||||
|
model.to(DEVICE)
|
||||||
|
model.eval()
|
||||||
|
|
||||||
|
left_images = sorted(glob.glob(args.left_imgs, recursive=True))
|
||||||
|
right_images = sorted(glob.glob(args.right_imgs, recursive=True))
|
||||||
|
print(f"Found {len(left_images)} images.")
|
||||||
|
|
||||||
|
|
||||||
|
def load_image(imfile):
|
||||||
|
img = np.array(Image.open(imfile)).astype(np.uint8)
|
||||||
|
img = torch.from_numpy(img).permute(2, 0, 1).float()
|
||||||
|
return img[None].to(DEVICE)
|
||||||
|
|
||||||
|
if __name__ == '__main__':
|
||||||
|
|
||||||
|
fps_list = np.array([])
|
||||||
|
videoWrite = cv2.VideoWriter('./IGEV_Stereo.mp4', cv2.VideoWriter_fourcc(*'mp4v'), 10, (1242, 750))
|
||||||
|
for (imfile1, imfile2) in tqdm(list(zip(left_images, right_images))):
|
||||||
|
image1 = load_image(imfile1)
|
||||||
|
image2 = load_image(imfile2)
|
||||||
|
padder = InputPadder(image1.shape, divis_by=32)
|
||||||
|
image1_pad, image2_pad = padder.pad(image1, image2)
|
||||||
|
torch.cuda.synchronize()
|
||||||
|
|
||||||
|
start = torch.cuda.Event(enable_timing=True)
|
||||||
|
end = torch.cuda.Event(enable_timing=True)
|
||||||
|
start.record()
|
||||||
|
with torch.no_grad():
|
||||||
|
with torch.cuda.amp.autocast(enabled=half_precision):
|
||||||
|
disp = model(image1_pad, image2_pad, iters=16, test_mode=True)
|
||||||
|
disp = padder.unpad(disp)
|
||||||
|
end.record()
|
||||||
|
torch.cuda.synchronize()
|
||||||
|
runtime = start.elapsed_time(end)
|
||||||
|
fps = 1000/runtime
|
||||||
|
fps_list = np.append(fps_list, fps)
|
||||||
|
if len(fps_list) > 5:
|
||||||
|
fps_list = fps_list[-5:]
|
||||||
|
avg_fps = np.mean(fps_list)
|
||||||
|
print('Stereo runtime: {:.3f}'.format(1000/avg_fps))
|
||||||
|
|
||||||
|
disp_np = (2*disp).data.cpu().numpy().squeeze().astype(np.uint8)
|
||||||
|
disp_np = cv2.applyColorMap(disp_np, cv2.COLORMAP_PLASMA)
|
||||||
|
image_np = np.array(Image.open(imfile1)).astype(np.uint8)
|
||||||
|
out_img = np.concatenate((image_np, disp_np), 0)
|
||||||
|
cv2.putText(
|
||||||
|
out_img,
|
||||||
|
"%.1f fps" % (avg_fps),
|
||||||
|
(10, image_np.shape[0]+30),
|
||||||
|
cv2.FONT_HERSHEY_SIMPLEX, 1, (255, 255, 255), 2, cv2.LINE_AA)
|
||||||
|
cv2.imshow('img', out_img)
|
||||||
|
cv2.waitKey(1)
|
||||||
|
videoWrite.write(out_img)
|
||||||
|
videoWrite.release()
|
276
IGEV-Stereo/evaluate_stereo.py
Normal file
276
IGEV-Stereo/evaluate_stereo.py
Normal file
@ -0,0 +1,276 @@
|
|||||||
|
from __future__ import print_function, division
|
||||||
|
import sys
|
||||||
|
sys.path.append('core')
|
||||||
|
|
||||||
|
import os
|
||||||
|
# os.environ['CUDA_VISIBLE_DEVICES'] = '0'
|
||||||
|
|
||||||
|
import argparse
|
||||||
|
import time
|
||||||
|
import logging
|
||||||
|
import numpy as np
|
||||||
|
import torch
|
||||||
|
from tqdm import tqdm
|
||||||
|
from igev_stereo import IGEVStereo, autocast
|
||||||
|
import stereo_datasets as datasets
|
||||||
|
from utils.utils import InputPadder
|
||||||
|
from PIL import Image
|
||||||
|
|
||||||
|
def count_parameters(model):
|
||||||
|
return sum(p.numel() for p in model.parameters() if p.requires_grad)
|
||||||
|
|
||||||
|
@torch.no_grad()
|
||||||
|
def validate_eth3d(model, iters=32, mixed_prec=False):
|
||||||
|
""" Peform validation using the ETH3D (train) split """
|
||||||
|
model.eval()
|
||||||
|
aug_params = {}
|
||||||
|
val_dataset = datasets.ETH3D(aug_params)
|
||||||
|
|
||||||
|
out_list, epe_list = [], []
|
||||||
|
for val_id in range(len(val_dataset)):
|
||||||
|
(imageL_file, imageR_file, GT_file), image1, image2, flow_gt, valid_gt = val_dataset[val_id]
|
||||||
|
image1 = image1[None].cuda()
|
||||||
|
image2 = image2[None].cuda()
|
||||||
|
|
||||||
|
padder = InputPadder(image1.shape, divis_by=32)
|
||||||
|
image1, image2 = padder.pad(image1, image2)
|
||||||
|
|
||||||
|
with autocast(enabled=mixed_prec):
|
||||||
|
if iters == 0:
|
||||||
|
flow_pr = model(image1, image2, iters=iters, test_mode=True)
|
||||||
|
else:
|
||||||
|
_, flow_pr = model(image1, image2, iters=iters, test_mode=True)
|
||||||
|
flow_pr = padder.unpad(flow_pr.float()).cpu().squeeze(0)
|
||||||
|
assert flow_pr.shape == flow_gt.shape, (flow_pr.shape, flow_gt.shape)
|
||||||
|
epe = torch.sum((flow_pr - flow_gt)**2, dim=0).sqrt()
|
||||||
|
|
||||||
|
epe_flattened = epe.flatten()
|
||||||
|
|
||||||
|
occ_mask = Image.open(GT_file.replace('disp0GT.pfm', 'mask0nocc.png'))
|
||||||
|
|
||||||
|
occ_mask = np.ascontiguousarray(occ_mask).flatten()
|
||||||
|
|
||||||
|
val = (valid_gt.flatten() >= 0.5) & (occ_mask == 255)
|
||||||
|
# val = (valid_gt.flatten() >= 0.5)
|
||||||
|
out = (epe_flattened > 1.0)
|
||||||
|
image_out = out[val].float().mean().item()
|
||||||
|
image_epe = epe_flattened[val].mean().item()
|
||||||
|
logging.info(f"ETH3D {val_id+1} out of {len(val_dataset)}. EPE {round(image_epe,4)} D1 {round(image_out,4)}")
|
||||||
|
epe_list.append(image_epe)
|
||||||
|
out_list.append(image_out)
|
||||||
|
|
||||||
|
epe_list = np.array(epe_list)
|
||||||
|
out_list = np.array(out_list)
|
||||||
|
|
||||||
|
epe = np.mean(epe_list)
|
||||||
|
d1 = 100 * np.mean(out_list)
|
||||||
|
|
||||||
|
print("Validation ETH3D: EPE %f, D1 %f" % (epe, d1))
|
||||||
|
return {'eth3d-epe': epe, 'eth3d-d1': d1}
|
||||||
|
|
||||||
|
|
||||||
|
@torch.no_grad()
|
||||||
|
def validate_kitti(model, iters=32, mixed_prec=False):
|
||||||
|
""" Peform validation using the KITTI-2015 (train) split """
|
||||||
|
model.eval()
|
||||||
|
aug_params = {}
|
||||||
|
val_dataset = datasets.KITTI(aug_params, image_set='training')
|
||||||
|
torch.backends.cudnn.benchmark = True
|
||||||
|
|
||||||
|
out_list, epe_list, elapsed_list = [], [], []
|
||||||
|
for val_id in range(len(val_dataset)):
|
||||||
|
_, image1, image2, flow_gt, valid_gt = val_dataset[val_id]
|
||||||
|
image1 = image1[None].cuda()
|
||||||
|
image2 = image2[None].cuda()
|
||||||
|
|
||||||
|
padder = InputPadder(image1.shape, divis_by=32)
|
||||||
|
image1, image2 = padder.pad(image1, image2)
|
||||||
|
|
||||||
|
with autocast(enabled=mixed_prec):
|
||||||
|
start = time.time()
|
||||||
|
if iters == 0:
|
||||||
|
flow_pr = model(image1, image2, iters=iters, test_mode=True)
|
||||||
|
else:
|
||||||
|
_, flow_pr = model(image1, image2, iters=iters, test_mode=True)
|
||||||
|
end = time.time()
|
||||||
|
|
||||||
|
if val_id > 50:
|
||||||
|
elapsed_list.append(end-start)
|
||||||
|
flow_pr = padder.unpad(flow_pr).cpu().squeeze(0)
|
||||||
|
|
||||||
|
assert flow_pr.shape == flow_gt.shape, (flow_pr.shape, flow_gt.shape)
|
||||||
|
epe = torch.sum((flow_pr - flow_gt)**2, dim=0).sqrt()
|
||||||
|
|
||||||
|
epe_flattened = epe.flatten()
|
||||||
|
val = (valid_gt.flatten() >= 0.5) & (flow_gt.abs().flatten() < 192)
|
||||||
|
# val = valid_gt.flatten() >= 0.5
|
||||||
|
|
||||||
|
out = (epe_flattened > 3.0)
|
||||||
|
image_out = out[val].float().mean().item()
|
||||||
|
image_epe = epe_flattened[val].mean().item()
|
||||||
|
if val_id < 9 or (val_id+1)%10 == 0:
|
||||||
|
logging.info(f"KITTI Iter {val_id+1} out of {len(val_dataset)}. EPE {round(image_epe,4)} D1 {round(image_out,4)}. Runtime: {format(end-start, '.3f')}s ({format(1/(end-start), '.2f')}-FPS)")
|
||||||
|
epe_list.append(epe_flattened[val].mean().item())
|
||||||
|
out_list.append(out[val].cpu().numpy())
|
||||||
|
|
||||||
|
epe_list = np.array(epe_list)
|
||||||
|
out_list = np.concatenate(out_list)
|
||||||
|
|
||||||
|
epe = np.mean(epe_list)
|
||||||
|
d1 = 100 * np.mean(out_list)
|
||||||
|
|
||||||
|
avg_runtime = np.mean(elapsed_list)
|
||||||
|
|
||||||
|
print(f"Validation KITTI: EPE {epe}, D1 {d1}, {format(1/avg_runtime, '.2f')}-FPS ({format(avg_runtime, '.3f')}s)")
|
||||||
|
return {'kitti-epe': epe, 'kitti-d1': d1}
|
||||||
|
|
||||||
|
|
||||||
|
@torch.no_grad()
|
||||||
|
def validate_sceneflow(model, iters=32, mixed_prec=False):
|
||||||
|
""" Peform validation using the Scene Flow (TEST) split """
|
||||||
|
model.eval()
|
||||||
|
val_dataset = datasets.SceneFlowDatasets(dstype='frames_finalpass', things_test=True)
|
||||||
|
|
||||||
|
out_list, epe_list = [], []
|
||||||
|
for val_id in tqdm(range(len(val_dataset))):
|
||||||
|
_, image1, image2, flow_gt, valid_gt = val_dataset[val_id]
|
||||||
|
|
||||||
|
image1 = image1[None].cuda()
|
||||||
|
image2 = image2[None].cuda()
|
||||||
|
|
||||||
|
padder = InputPadder(image1.shape, divis_by=32)
|
||||||
|
image1, image2 = padder.pad(image1, image2)
|
||||||
|
|
||||||
|
with autocast(enabled=mixed_prec):
|
||||||
|
if iters == 0:
|
||||||
|
flow_pr = model(image1, image2, iters=iters, test_mode=True)
|
||||||
|
else:
|
||||||
|
flow_pr = model(image1, image2, iters=iters, test_mode=True)
|
||||||
|
flow_pr = padder.unpad(flow_pr).cpu().squeeze(0)
|
||||||
|
assert flow_pr.shape == flow_gt.shape, (flow_pr.shape, flow_gt.shape)
|
||||||
|
|
||||||
|
# epe = torch.sum((flow_pr - flow_gt)**2, dim=0).sqrt()
|
||||||
|
epe = torch.abs(flow_pr - flow_gt)
|
||||||
|
|
||||||
|
epe = epe.flatten()
|
||||||
|
val = (valid_gt.flatten() >= 0.5) & (flow_gt.abs().flatten() < 192)
|
||||||
|
|
||||||
|
if(np.isnan(epe[val].mean().item())):
|
||||||
|
continue
|
||||||
|
|
||||||
|
out = (epe > 3.0)
|
||||||
|
epe_list.append(epe[val].mean().item())
|
||||||
|
out_list.append(out[val].cpu().numpy())
|
||||||
|
# if val_id == 400:
|
||||||
|
# break
|
||||||
|
|
||||||
|
epe_list = np.array(epe_list)
|
||||||
|
out_list = np.concatenate(out_list)
|
||||||
|
|
||||||
|
epe = np.mean(epe_list)
|
||||||
|
d1 = 100 * np.mean(out_list)
|
||||||
|
|
||||||
|
f = open('test.txt', 'a')
|
||||||
|
f.write("Validation Scene Flow: %f, %f\n" % (epe, d1))
|
||||||
|
|
||||||
|
print("Validation Scene Flow: %f, %f" % (epe, d1))
|
||||||
|
return {'scene-flow-epe': epe, 'scene-flow-d1': d1}
|
||||||
|
|
||||||
|
|
||||||
|
@torch.no_grad()
|
||||||
|
def validate_middlebury(model, iters=32, split='F', mixed_prec=False):
|
||||||
|
""" Peform validation using the Middlebury-V3 dataset """
|
||||||
|
model.eval()
|
||||||
|
aug_params = {}
|
||||||
|
val_dataset = datasets.Middlebury(aug_params, split=split)
|
||||||
|
|
||||||
|
out_list, epe_list = [], []
|
||||||
|
for val_id in range(len(val_dataset)):
|
||||||
|
(imageL_file, _, _), image1, image2, flow_gt, valid_gt = val_dataset[val_id]
|
||||||
|
image1 = image1[None].cuda()
|
||||||
|
image2 = image2[None].cuda()
|
||||||
|
|
||||||
|
padder = InputPadder(image1.shape, divis_by=32)
|
||||||
|
image1, image2 = padder.pad(image1, image2)
|
||||||
|
|
||||||
|
with autocast(enabled=mixed_prec):
|
||||||
|
if iters == 0:
|
||||||
|
flow_pr = model(image1, image2, iters=iters, test_mode=True)
|
||||||
|
else:
|
||||||
|
_, flow_pr = model(image1, image2, iters=iters, test_mode=True)
|
||||||
|
flow_pr = padder.unpad(flow_pr).cpu().squeeze(0)
|
||||||
|
|
||||||
|
assert flow_pr.shape == flow_gt.shape, (flow_pr.shape, flow_gt.shape)
|
||||||
|
epe = torch.sum((flow_pr - flow_gt)**2, dim=0).sqrt()
|
||||||
|
|
||||||
|
epe_flattened = epe.flatten()
|
||||||
|
|
||||||
|
occ_mask = Image.open(imageL_file.replace('im0.png', 'mask0nocc.png')).convert('L')
|
||||||
|
occ_mask = np.ascontiguousarray(occ_mask, dtype=np.float32).flatten()
|
||||||
|
|
||||||
|
val = (valid_gt.reshape(-1) >= 0.5) & (flow_gt[0].reshape(-1) < 192) & (occ_mask==255)
|
||||||
|
out = (epe_flattened > 2.0)
|
||||||
|
image_out = out[val].float().mean().item()
|
||||||
|
image_epe = epe_flattened[val].mean().item()
|
||||||
|
logging.info(f"Middlebury Iter {val_id+1} out of {len(val_dataset)}. EPE {round(image_epe,4)} D1 {round(image_out,4)}")
|
||||||
|
epe_list.append(image_epe)
|
||||||
|
out_list.append(image_out)
|
||||||
|
|
||||||
|
epe_list = np.array(epe_list)
|
||||||
|
out_list = np.array(out_list)
|
||||||
|
|
||||||
|
epe = np.mean(epe_list)
|
||||||
|
d1 = 100 * np.mean(out_list)
|
||||||
|
|
||||||
|
print(f"Validation Middlebury{split}: EPE {epe}, D1 {d1}")
|
||||||
|
return {f'middlebury{split}-epe': epe, f'middlebury{split}-d1': d1}
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == '__main__':
|
||||||
|
parser = argparse.ArgumentParser()
|
||||||
|
parser.add_argument('--restore_ckpt', help="restore checkpoint", default='./pretrained_models/sceneflow/sceneflow.pth')
|
||||||
|
parser.add_argument('--dataset', help="dataset for evaluation", default='sceneflow', choices=["eth3d", "kitti", "sceneflow"] + [f"middlebury_{s}" for s in 'FHQ'])
|
||||||
|
parser.add_argument('--mixed_precision', default=False, action='store_true', help='use mixed precision')
|
||||||
|
parser.add_argument('--valid_iters', type=int, default=32, help='number of flow-field updates during forward pass')
|
||||||
|
|
||||||
|
# Architecure choices
|
||||||
|
parser.add_argument('--hidden_dims', nargs='+', type=int, default=[128]*3, help="hidden state and context dimensions")
|
||||||
|
parser.add_argument('--corr_implementation', choices=["reg", "alt", "reg_cuda", "alt_cuda"], default="reg", help="correlation volume implementation")
|
||||||
|
parser.add_argument('--shared_backbone', action='store_true', help="use a single backbone for the context and feature encoders")
|
||||||
|
parser.add_argument('--corr_levels', type=int, default=2, help="number of levels in the correlation pyramid")
|
||||||
|
parser.add_argument('--corr_radius', type=int, default=4, help="width of the correlation pyramid")
|
||||||
|
parser.add_argument('--n_downsample', type=int, default=2, help="resolution of the disparity field (1/2^K)")
|
||||||
|
parser.add_argument('--slow_fast_gru', action='store_true', help="iterate the low-res GRUs more frequently")
|
||||||
|
parser.add_argument('--n_gru_layers', type=int, default=3, help="number of hidden GRU levels")
|
||||||
|
parser.add_argument('--max_disp', type=int, default=192, help="max disp of geometry encoding volume")
|
||||||
|
args = parser.parse_args()
|
||||||
|
|
||||||
|
model = torch.nn.DataParallel(IGEVStereo(args), device_ids=[0])
|
||||||
|
|
||||||
|
logging.basicConfig(level=logging.INFO,
|
||||||
|
format='%(asctime)s %(levelname)-8s [%(filename)s:%(lineno)d] %(message)s')
|
||||||
|
|
||||||
|
if args.restore_ckpt is not None:
|
||||||
|
assert args.restore_ckpt.endswith(".pth")
|
||||||
|
logging.info("Loading checkpoint...")
|
||||||
|
checkpoint = torch.load(args.restore_ckpt)
|
||||||
|
model.load_state_dict(checkpoint, strict=True)
|
||||||
|
logging.info(f"Done loading checkpoint")
|
||||||
|
|
||||||
|
model.cuda()
|
||||||
|
model.eval()
|
||||||
|
|
||||||
|
print(f"The model has {format(count_parameters(model)/1e6, '.2f')}M learnable parameters.")
|
||||||
|
use_mixed_precision = args.corr_implementation.endswith("_cuda")
|
||||||
|
|
||||||
|
if args.dataset == 'eth3d':
|
||||||
|
validate_eth3d(model, iters=args.valid_iters, mixed_prec=use_mixed_precision)
|
||||||
|
|
||||||
|
elif args.dataset == 'kitti':
|
||||||
|
validate_kitti(model, iters=args.valid_iters, mixed_prec=use_mixed_precision)
|
||||||
|
|
||||||
|
elif args.dataset in [f"middlebury_{s}" for s in 'FHQ']:
|
||||||
|
validate_middlebury(model, iters=args.valid_iters, split=args.dataset[-1], mixed_prec=use_mixed_precision)
|
||||||
|
|
||||||
|
elif args.dataset == 'sceneflow':
|
||||||
|
validate_sceneflow(model, iters=args.valid_iters, mixed_prec=use_mixed_precision)
|
82
IGEV-Stereo/save_disp.py
Normal file
82
IGEV-Stereo/save_disp.py
Normal file
@ -0,0 +1,82 @@
|
|||||||
|
import sys
|
||||||
|
sys.path.append('core')
|
||||||
|
|
||||||
|
import argparse
|
||||||
|
import glob
|
||||||
|
import numpy as np
|
||||||
|
import torch
|
||||||
|
from tqdm import tqdm
|
||||||
|
from pathlib import Path
|
||||||
|
from igev_stereo import IGEVStereo
|
||||||
|
from utils.utils import InputPadder
|
||||||
|
from PIL import Image
|
||||||
|
from matplotlib import pyplot as plt
|
||||||
|
import os
|
||||||
|
import skimage.io
|
||||||
|
import cv2
|
||||||
|
|
||||||
|
|
||||||
|
DEVICE = 'cuda'
|
||||||
|
|
||||||
|
os.environ['CUDA_VISIBLE_DEVICES'] = '0'
|
||||||
|
|
||||||
|
def load_image(imfile):
|
||||||
|
img = np.array(Image.open(imfile)).astype(np.uint8)
|
||||||
|
img = torch.from_numpy(img).permute(2, 0, 1).float()
|
||||||
|
return img[None].to(DEVICE)
|
||||||
|
|
||||||
|
def demo(args):
|
||||||
|
model = torch.nn.DataParallel(IGEVStereo(args), device_ids=[0])
|
||||||
|
model.load_state_dict(torch.load(args.restore_ckpt))
|
||||||
|
|
||||||
|
model = model.module
|
||||||
|
model.to(DEVICE)
|
||||||
|
model.eval()
|
||||||
|
|
||||||
|
output_directory = Path(args.output_directory)
|
||||||
|
output_directory.mkdir(exist_ok=True)
|
||||||
|
|
||||||
|
with torch.no_grad():
|
||||||
|
left_images = sorted(glob.glob(args.left_imgs, recursive=True))
|
||||||
|
right_images = sorted(glob.glob(args.right_imgs, recursive=True))
|
||||||
|
print(f"Found {len(left_images)} images. Saving files to {output_directory}/")
|
||||||
|
|
||||||
|
for (imfile1, imfile2) in tqdm(list(zip(left_images, right_images))):
|
||||||
|
image1 = load_image(imfile1)
|
||||||
|
image2 = load_image(imfile2)
|
||||||
|
padder = InputPadder(image1.shape, divis_by=32)
|
||||||
|
image1, image2 = padder.pad(image1, image2)
|
||||||
|
disp = model(image1, image2, iters=args.valid_iters, test_mode=True)
|
||||||
|
disp = padder.unpad(disp)
|
||||||
|
file_stem = os.path.join(output_directory, imfile1.split('/')[-1])
|
||||||
|
disp = disp.cpu().numpy().squeeze()
|
||||||
|
disp = np.round(disp * 256).astype(np.uint16)
|
||||||
|
skimage.io.imsave(file_stem, disp)
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == '__main__':
|
||||||
|
parser = argparse.ArgumentParser()
|
||||||
|
parser.add_argument('--restore_ckpt', help="restore checkpoint", default='./pretrained_models/kitti/kitti15.pth')
|
||||||
|
parser.add_argument('--save_numpy', action='store_true', help='save output as numpy arrays')
|
||||||
|
parser.add_argument('-l', '--left_imgs', help="path to all first (left) frames", default="/data/KITTI/KITTI_2015/testing/image_2/*_10.png")
|
||||||
|
parser.add_argument('-r', '--right_imgs', help="path to all second (right) frames", default="/data/KITTI/KITTI_2015/testing/image_3/*_10.png")
|
||||||
|
# parser.add_argument('-l', '--left_imgs', help="path to all first (left) frames", default="/data/KITTI/KITTI_2012/testing/colored_0/*_10.png")
|
||||||
|
# parser.add_argument('-r', '--right_imgs', help="path to all second (right) frames", default="/data/KITTI/KITTI_2012/testing/colored_1/*_10.png")
|
||||||
|
parser.add_argument('--output_directory', help="directory to save output", default="output")
|
||||||
|
parser.add_argument('--mixed_precision', action='store_true', help='use mixed precision')
|
||||||
|
parser.add_argument('--valid_iters', type=int, default=16, help='number of flow-field updates during forward pass')
|
||||||
|
|
||||||
|
# Architecture choices
|
||||||
|
parser.add_argument('--hidden_dims', nargs='+', type=int, default=[128]*3, help="hidden state and context dimensions")
|
||||||
|
parser.add_argument('--corr_implementation', choices=["reg", "alt", "reg_cuda", "alt_cuda"], default="reg", help="correlation volume implementation")
|
||||||
|
parser.add_argument('--shared_backbone', action='store_true', help="use a single backbone for the context and feature encoders")
|
||||||
|
parser.add_argument('--corr_levels', type=int, default=2, help="number of levels in the correlation pyramid")
|
||||||
|
parser.add_argument('--corr_radius', type=int, default=4, help="width of the correlation pyramid")
|
||||||
|
parser.add_argument('--n_downsample', type=int, default=2, help="resolution of the disparity field (1/2^K)")
|
||||||
|
parser.add_argument('--slow_fast_gru', action='store_true', help="iterate the low-res GRUs more frequently")
|
||||||
|
parser.add_argument('--n_gru_layers', type=int, default=3, help="number of hidden GRU levels")
|
||||||
|
parser.add_argument('--max_disp', type=int, default=192, help="max disp of geometry encoding volume")
|
||||||
|
|
||||||
|
args = parser.parse_args()
|
||||||
|
|
||||||
|
demo(args)
|
256
IGEV-Stereo/train_stereo.py
Normal file
256
IGEV-Stereo/train_stereo.py
Normal file
@ -0,0 +1,256 @@
|
|||||||
|
|
||||||
|
from __future__ import print_function, division
|
||||||
|
|
||||||
|
import os
|
||||||
|
os.environ['CUDA_VISIBLE_DEVICES'] = '0, 1'
|
||||||
|
|
||||||
|
import argparse
|
||||||
|
import logging
|
||||||
|
import numpy as np
|
||||||
|
from pathlib import Path
|
||||||
|
from tqdm import tqdm
|
||||||
|
|
||||||
|
from torch.utils.tensorboard import SummaryWriter
|
||||||
|
|
||||||
|
|
||||||
|
import torch
|
||||||
|
import torch.nn as nn
|
||||||
|
import torch.optim as optim
|
||||||
|
from core.igev_stereo import IGEVStereo
|
||||||
|
from evaluate_stereo import *
|
||||||
|
import core.stereo_datasets as datasets
|
||||||
|
import torch.nn.functional as F
|
||||||
|
|
||||||
|
|
||||||
|
ckpt_path = './checkpoints/igev_stereo'
|
||||||
|
log_path = './checkpoints/igev_stereo'
|
||||||
|
|
||||||
|
try:
|
||||||
|
from torch.cuda.amp import GradScaler
|
||||||
|
except:
|
||||||
|
class GradScaler:
|
||||||
|
def __init__(self):
|
||||||
|
pass
|
||||||
|
def scale(self, loss):
|
||||||
|
return loss
|
||||||
|
def unscale_(self, optimizer):
|
||||||
|
pass
|
||||||
|
def step(self, optimizer):
|
||||||
|
optimizer.step()
|
||||||
|
def update(self):
|
||||||
|
pass
|
||||||
|
|
||||||
|
|
||||||
|
def sequence_loss(disp_preds, disp_init_pred, disp_gt, valid, loss_gamma=0.9, max_disp=192):
|
||||||
|
""" Loss function defined over sequence of flow predictions """
|
||||||
|
|
||||||
|
n_predictions = len(disp_preds)
|
||||||
|
assert n_predictions >= 1
|
||||||
|
disp_loss = 0.0
|
||||||
|
mag = torch.sum(disp_gt**2, dim=1).sqrt()
|
||||||
|
valid = ((valid >= 0.5) & (mag < max_disp)).unsqueeze(1)
|
||||||
|
assert valid.shape == disp_gt.shape, [valid.shape, disp_gt.shape]
|
||||||
|
assert not torch.isinf(disp_gt[valid.bool()]).any()
|
||||||
|
|
||||||
|
|
||||||
|
disp_loss += 1.0 * F.smooth_l1_loss(disp_init_pred[valid.bool()], disp_gt[valid.bool()], size_average=True)
|
||||||
|
for i in range(n_predictions):
|
||||||
|
adjusted_loss_gamma = loss_gamma**(15/(n_predictions - 1))
|
||||||
|
i_weight = adjusted_loss_gamma**(n_predictions - i - 1)
|
||||||
|
i_loss = (disp_preds[i] - disp_gt).abs()
|
||||||
|
assert i_loss.shape == valid.shape, [i_loss.shape, valid.shape, disp_gt.shape, disp_preds[i].shape]
|
||||||
|
disp_loss += i_weight * i_loss[valid.bool()].mean()
|
||||||
|
|
||||||
|
epe = torch.sum((disp_preds[-1] - disp_gt)**2, dim=1).sqrt()
|
||||||
|
epe = epe.view(-1)[valid.view(-1)]
|
||||||
|
|
||||||
|
metrics = {
|
||||||
|
'epe': epe.mean().item(),
|
||||||
|
'1px': (epe < 1).float().mean().item(),
|
||||||
|
'3px': (epe < 3).float().mean().item(),
|
||||||
|
'5px': (epe < 5).float().mean().item(),
|
||||||
|
}
|
||||||
|
|
||||||
|
return disp_loss, metrics
|
||||||
|
|
||||||
|
|
||||||
|
def fetch_optimizer(args, model):
|
||||||
|
""" Create the optimizer and learning rate scheduler """
|
||||||
|
optimizer = optim.AdamW(model.parameters(), lr=args.lr, weight_decay=args.wdecay, eps=1e-8)
|
||||||
|
|
||||||
|
scheduler = optim.lr_scheduler.OneCycleLR(optimizer, args.lr, args.num_steps+100,
|
||||||
|
pct_start=0.01, cycle_momentum=False, anneal_strategy='linear')
|
||||||
|
return optimizer, scheduler
|
||||||
|
|
||||||
|
|
||||||
|
class Logger:
|
||||||
|
SUM_FREQ = 100
|
||||||
|
def __init__(self, model, scheduler):
|
||||||
|
self.model = model
|
||||||
|
self.scheduler = scheduler
|
||||||
|
self.total_steps = 0
|
||||||
|
self.running_loss = {}
|
||||||
|
self.writer = SummaryWriter(log_dir=log_path)
|
||||||
|
|
||||||
|
def _print_training_status(self):
|
||||||
|
metrics_data = [self.running_loss[k]/Logger.SUM_FREQ for k in sorted(self.running_loss.keys())]
|
||||||
|
training_str = "[{:6d}, {:10.7f}] ".format(self.total_steps+1, self.scheduler.get_last_lr()[0])
|
||||||
|
metrics_str = ("{:10.4f}, "*len(metrics_data)).format(*metrics_data)
|
||||||
|
|
||||||
|
# print the training status
|
||||||
|
logging.info(f"Training Metrics ({self.total_steps}): {training_str + metrics_str}")
|
||||||
|
|
||||||
|
if self.writer is None:
|
||||||
|
self.writer = SummaryWriter(log_dir=log_path)
|
||||||
|
|
||||||
|
for k in self.running_loss:
|
||||||
|
self.writer.add_scalar(k, self.running_loss[k]/Logger.SUM_FREQ, self.total_steps)
|
||||||
|
self.running_loss[k] = 0.0
|
||||||
|
|
||||||
|
def push(self, metrics):
|
||||||
|
self.total_steps += 1
|
||||||
|
|
||||||
|
for key in metrics:
|
||||||
|
if key not in self.running_loss:
|
||||||
|
self.running_loss[key] = 0.0
|
||||||
|
|
||||||
|
self.running_loss[key] += metrics[key]
|
||||||
|
|
||||||
|
if self.total_steps % Logger.SUM_FREQ == Logger.SUM_FREQ-1:
|
||||||
|
self._print_training_status()
|
||||||
|
self.running_loss = {}
|
||||||
|
|
||||||
|
def write_dict(self, results):
|
||||||
|
if self.writer is None:
|
||||||
|
self.writer = SummaryWriter(log_dir=log_path)
|
||||||
|
|
||||||
|
for key in results:
|
||||||
|
self.writer.add_scalar(key, results[key], self.total_steps)
|
||||||
|
|
||||||
|
def close(self):
|
||||||
|
self.writer.close()
|
||||||
|
|
||||||
|
|
||||||
|
def train(args):
|
||||||
|
|
||||||
|
model = nn.DataParallel(IGEVStereo(args))
|
||||||
|
print("Parameter Count: %d" % count_parameters(model))
|
||||||
|
|
||||||
|
train_loader = datasets.fetch_dataloader(args)
|
||||||
|
optimizer, scheduler = fetch_optimizer(args, model)
|
||||||
|
total_steps = 0
|
||||||
|
logger = Logger(model, scheduler)
|
||||||
|
|
||||||
|
if args.restore_ckpt is not None:
|
||||||
|
assert args.restore_ckpt.endswith(".pth")
|
||||||
|
logging.info("Loading checkpoint...")
|
||||||
|
checkpoint = torch.load(args.restore_ckpt)
|
||||||
|
model.load_state_dict(checkpoint, strict=True)
|
||||||
|
logging.info(f"Done loading checkpoint")
|
||||||
|
model.cuda()
|
||||||
|
model.train()
|
||||||
|
model.module.freeze_bn() # We keep BatchNorm frozen
|
||||||
|
|
||||||
|
validation_frequency = 10000
|
||||||
|
|
||||||
|
scaler = GradScaler(enabled=args.mixed_precision)
|
||||||
|
|
||||||
|
should_keep_training = True
|
||||||
|
global_batch_num = 0
|
||||||
|
while should_keep_training:
|
||||||
|
|
||||||
|
for i_batch, (_, *data_blob) in enumerate(tqdm(train_loader)):
|
||||||
|
optimizer.zero_grad()
|
||||||
|
image1, image2, disp_gt, valid = [x.cuda() for x in data_blob]
|
||||||
|
|
||||||
|
assert model.training
|
||||||
|
disp_init_pred, disp_preds = model(image1, image2, iters=args.train_iters)
|
||||||
|
assert model.training
|
||||||
|
|
||||||
|
loss, metrics = sequence_loss(disp_preds, disp_init_pred, disp_gt, valid, max_disp=args.max_disp)
|
||||||
|
logger.writer.add_scalar("live_loss", loss.item(), global_batch_num)
|
||||||
|
logger.writer.add_scalar(f'learning_rate', optimizer.param_groups[0]['lr'], global_batch_num)
|
||||||
|
global_batch_num += 1
|
||||||
|
scaler.scale(loss).backward()
|
||||||
|
scaler.unscale_(optimizer)
|
||||||
|
torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
|
||||||
|
|
||||||
|
scaler.step(optimizer)
|
||||||
|
scheduler.step()
|
||||||
|
scaler.update()
|
||||||
|
logger.push(metrics)
|
||||||
|
|
||||||
|
if total_steps % validation_frequency == validation_frequency - 1:
|
||||||
|
save_path = Path(ckpt_path + '/%d_%s.pth' % (total_steps + 1, args.name))
|
||||||
|
logging.info(f"Saving file {save_path.absolute()}")
|
||||||
|
torch.save(model.state_dict(), save_path)
|
||||||
|
results = validate_sceneflow(model.module, iters=args.valid_iters)
|
||||||
|
logger.write_dict(results)
|
||||||
|
model.train()
|
||||||
|
model.module.freeze_bn()
|
||||||
|
|
||||||
|
total_steps += 1
|
||||||
|
|
||||||
|
if total_steps > args.num_steps:
|
||||||
|
should_keep_training = False
|
||||||
|
break
|
||||||
|
|
||||||
|
if len(train_loader) >= 10000:
|
||||||
|
save_path = Path(ckpt_path + '/%d_epoch_%s.pth.gz' % (total_steps + 1, args.name))
|
||||||
|
logging.info(f"Saving file {save_path}")
|
||||||
|
torch.save(model.state_dict(), save_path)
|
||||||
|
|
||||||
|
print("FINISHED TRAINING")
|
||||||
|
logger.close()
|
||||||
|
PATH = ckpt_path + '/%s.pth' % args.name
|
||||||
|
torch.save(model.state_dict(), PATH)
|
||||||
|
|
||||||
|
return PATH
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == '__main__':
|
||||||
|
parser = argparse.ArgumentParser()
|
||||||
|
parser.add_argument('--name', default='igev-stereo', help="name your experiment")
|
||||||
|
parser.add_argument('--restore_ckpt', default=None, help="")
|
||||||
|
parser.add_argument('--mixed_precision', default=True, action='store_true', help='use mixed precision')
|
||||||
|
|
||||||
|
# Training parameters
|
||||||
|
parser.add_argument('--batch_size', type=int, default=8, help="batch size used during training.")
|
||||||
|
parser.add_argument('--train_datasets', nargs='+', default=['sceneflow'], help="training datasets.")
|
||||||
|
parser.add_argument('--lr', type=float, default=0.0002, help="max learning rate.")
|
||||||
|
parser.add_argument('--num_steps', type=int, default=200000, help="length of training schedule.")
|
||||||
|
parser.add_argument('--image_size', type=int, nargs='+', default=[320, 736], help="size of the random image crops used during training.")
|
||||||
|
parser.add_argument('--train_iters', type=int, default=22, help="number of updates to the disparity field in each forward pass.")
|
||||||
|
parser.add_argument('--wdecay', type=float, default=.00001, help="Weight decay in optimizer.")
|
||||||
|
|
||||||
|
# Validation parameters
|
||||||
|
parser.add_argument('--valid_iters', type=int, default=32, help='number of flow-field updates during validation forward pass')
|
||||||
|
|
||||||
|
# Architecure choices
|
||||||
|
parser.add_argument('--corr_implementation', choices=["reg", "alt", "reg_cuda", "alt_cuda"], default="reg", help="correlation volume implementation")
|
||||||
|
parser.add_argument('--shared_backbone', action='store_true', help="use a single backbone for the context and feature encoders")
|
||||||
|
parser.add_argument('--corr_levels', type=int, default=2, help="number of levels in the correlation pyramid")
|
||||||
|
parser.add_argument('--corr_radius', type=int, default=4, help="width of the correlation pyramid")
|
||||||
|
parser.add_argument('--n_downsample', type=int, default=2, help="resolution of the disparity field (1/2^K)")
|
||||||
|
parser.add_argument('--slow_fast_gru', action='store_true', help="iterate the low-res GRUs more frequently")
|
||||||
|
parser.add_argument('--n_gru_layers', type=int, default=3, help="number of hidden GRU levels")
|
||||||
|
parser.add_argument('--hidden_dims', nargs='+', type=int, default=[128]*3, help="hidden state and context dimensions")
|
||||||
|
parser.add_argument('--max_disp', type=int, default=192, help="max disp of geometry encoding volume")
|
||||||
|
|
||||||
|
# Data augmentation
|
||||||
|
parser.add_argument('--img_gamma', type=float, nargs='+', default=None, help="gamma range")
|
||||||
|
parser.add_argument('--saturation_range', type=float, nargs='+', default=[0, 1.4], help='color saturation')
|
||||||
|
parser.add_argument('--do_flip', default=False, choices=['h', 'v'], help='flip the images horizontally or vertically')
|
||||||
|
parser.add_argument('--spatial_scale', type=float, nargs='+', default=[-0.2, 0.4], help='re-scale the images randomly')
|
||||||
|
parser.add_argument('--noyjitter', action='store_true', help='don\'t simulate imperfect rectification')
|
||||||
|
args = parser.parse_args()
|
||||||
|
|
||||||
|
torch.manual_seed(666)
|
||||||
|
np.random.seed(666)
|
||||||
|
|
||||||
|
logging.basicConfig(level=logging.INFO,
|
||||||
|
format='%(asctime)s %(levelname)-8s [%(filename)s:%(lineno)d] %(message)s')
|
||||||
|
|
||||||
|
Path(ckpt_path).mkdir(exist_ok=True, parents=True)
|
||||||
|
|
||||||
|
train(args)
|
Loading…
Reference in New Issue
Block a user