Initial Commit.

Gangwei Xu 2023-03-12 20:19:58 +08:00 committed by GitHub
parent 3bc74984bd
commit e651d84ed8
20 changed files with 2779 additions and 0 deletions

@@ -0,0 +1,362 @@
import torch
import torch.nn as nn
import torch.nn.functional as F
from core.submodule import *
import math
import timm
class ResidualBlock(nn.Module):
def __init__(self, in_planes, planes, norm_fn='group', stride=1):
super(ResidualBlock, self).__init__()
self.conv1 = nn.Conv2d(in_planes, planes, kernel_size=3, padding=1, stride=stride)
self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, padding=1)
self.relu = nn.ReLU(inplace=True)
num_groups = planes // 8
if norm_fn == 'group':
self.norm1 = nn.GroupNorm(num_groups=num_groups, num_channels=planes)
self.norm2 = nn.GroupNorm(num_groups=num_groups, num_channels=planes)
if not (stride == 1 and in_planes == planes):
self.norm3 = nn.GroupNorm(num_groups=num_groups, num_channels=planes)
elif norm_fn == 'batch':
self.norm1 = nn.BatchNorm2d(planes)
self.norm2 = nn.BatchNorm2d(planes)
if not (stride == 1 and in_planes == planes):
self.norm3 = nn.BatchNorm2d(planes)
elif norm_fn == 'instance':
self.norm1 = nn.InstanceNorm2d(planes)
self.norm2 = nn.InstanceNorm2d(planes)
if not (stride == 1 and in_planes == planes):
self.norm3 = nn.InstanceNorm2d(planes)
elif norm_fn == 'none':
self.norm1 = nn.Sequential()
self.norm2 = nn.Sequential()
if not (stride == 1 and in_planes == planes):
self.norm3 = nn.Sequential()
if stride == 1 and in_planes == planes:
self.downsample = None
else:
self.downsample = nn.Sequential(
nn.Conv2d(in_planes, planes, kernel_size=1, stride=stride), self.norm3)
def forward(self, x):
y = x
y = self.conv1(y)
y = self.norm1(y)
y = self.relu(y)
y = self.conv2(y)
y = self.norm2(y)
y = self.relu(y)
if self.downsample is not None:
x = self.downsample(x)
return self.relu(x+y)
class BottleneckBlock(nn.Module):
def __init__(self, in_planes, planes, norm_fn='group', stride=1):
super(BottleneckBlock, self).__init__()
self.conv1 = nn.Conv2d(in_planes, planes//4, kernel_size=1, padding=0)
self.conv2 = nn.Conv2d(planes//4, planes//4, kernel_size=3, padding=1, stride=stride)
self.conv3 = nn.Conv2d(planes//4, planes, kernel_size=1, padding=0)
self.relu = nn.ReLU(inplace=True)
num_groups = planes // 8
if norm_fn == 'group':
self.norm1 = nn.GroupNorm(num_groups=num_groups, num_channels=planes//4)
self.norm2 = nn.GroupNorm(num_groups=num_groups, num_channels=planes//4)
self.norm3 = nn.GroupNorm(num_groups=num_groups, num_channels=planes)
if not stride == 1:
self.norm4 = nn.GroupNorm(num_groups=num_groups, num_channels=planes)
elif norm_fn == 'batch':
self.norm1 = nn.BatchNorm2d(planes//4)
self.norm2 = nn.BatchNorm2d(planes//4)
self.norm3 = nn.BatchNorm2d(planes)
if not stride == 1:
self.norm4 = nn.BatchNorm2d(planes)
elif norm_fn == 'instance':
self.norm1 = nn.InstanceNorm2d(planes//4)
self.norm2 = nn.InstanceNorm2d(planes//4)
self.norm3 = nn.InstanceNorm2d(planes)
if not stride == 1:
self.norm4 = nn.InstanceNorm2d(planes)
elif norm_fn == 'none':
self.norm1 = nn.Sequential()
self.norm2 = nn.Sequential()
self.norm3 = nn.Sequential()
if not stride == 1:
self.norm4 = nn.Sequential()
if stride == 1:
self.downsample = None
else:
self.downsample = nn.Sequential(
nn.Conv2d(in_planes, planes, kernel_size=1, stride=stride), self.norm4)
def forward(self, x):
y = x
y = self.relu(self.norm1(self.conv1(y)))
y = self.relu(self.norm2(self.conv2(y)))
y = self.relu(self.norm3(self.conv3(y)))
if self.downsample is not None:
x = self.downsample(x)
return self.relu(x+y)
class BasicEncoder(nn.Module):
def __init__(self, output_dim=128, norm_fn='batch', dropout=0.0, downsample=3):
super(BasicEncoder, self).__init__()
self.norm_fn = norm_fn
self.downsample = downsample
if self.norm_fn == 'group':
self.norm1 = nn.GroupNorm(num_groups=8, num_channels=64)
elif self.norm_fn == 'batch':
self.norm1 = nn.BatchNorm2d(64)
elif self.norm_fn == 'instance':
self.norm1 = nn.InstanceNorm2d(64)
elif self.norm_fn == 'none':
self.norm1 = nn.Sequential()
self.conv1 = nn.Conv2d(3, 64, kernel_size=7, stride=1 + (downsample > 2), padding=3)
self.relu1 = nn.ReLU(inplace=True)
self.in_planes = 64
self.layer1 = self._make_layer(64, stride=1)
self.layer2 = self._make_layer(96, stride=1 + (downsample > 1))
self.layer3 = self._make_layer(128, stride=1 + (downsample > 0))
# output convolution
self.conv2 = nn.Conv2d(128, output_dim, kernel_size=1)
self.dropout = None
if dropout > 0:
self.dropout = nn.Dropout2d(p=dropout)
for m in self.modules():
if isinstance(m, nn.Conv2d):
nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
elif isinstance(m, (nn.BatchNorm2d, nn.InstanceNorm2d, nn.GroupNorm)):
if m.weight is not None:
nn.init.constant_(m.weight, 1)
if m.bias is not None:
nn.init.constant_(m.bias, 0)
def _make_layer(self, dim, stride=1):
layer1 = ResidualBlock(self.in_planes, dim, self.norm_fn, stride=stride)
layer2 = ResidualBlock(dim, dim, self.norm_fn, stride=1)
layers = (layer1, layer2)
self.in_planes = dim
return nn.Sequential(*layers)
def forward(self, x, dual_inp=False):
# if input is list, combine batch dimension
is_list = isinstance(x, tuple) or isinstance(x, list)
if is_list:
batch_dim = x[0].shape[0]
x = torch.cat(x, dim=0)
x = self.conv1(x)
x = self.norm1(x)
x = self.relu1(x)
x = self.layer1(x)
x = self.layer2(x)
x = self.layer3(x)
x = self.conv2(x)
if self.training and self.dropout is not None:
x = self.dropout(x)
if is_list:
x = x.split(split_size=batch_dim, dim=0)
return x
class MultiBasicEncoder(nn.Module):
def __init__(self, output_dim=[128], norm_fn='batch', dropout=0.0, downsample=3):
super(MultiBasicEncoder, self).__init__()
self.norm_fn = norm_fn
self.downsample = downsample
# self.norm_111 = nn.BatchNorm2d(128, affine=False, track_running_stats=False)
# self.norm_222 = nn.BatchNorm2d(128, affine=False, track_running_stats=False)
if self.norm_fn == 'group':
self.norm1 = nn.GroupNorm(num_groups=8, num_channels=64)
elif self.norm_fn == 'batch':
self.norm1 = nn.BatchNorm2d(64)
elif self.norm_fn == 'instance':
self.norm1 = nn.InstanceNorm2d(64)
elif self.norm_fn == 'none':
self.norm1 = nn.Sequential()
self.conv1 = nn.Conv2d(3, 64, kernel_size=7, stride=1 + (downsample > 2), padding=3)
self.relu1 = nn.ReLU(inplace=True)
self.in_planes = 64
self.layer1 = self._make_layer(64, stride=1)
self.layer2 = self._make_layer(96, stride=1 + (downsample > 1))
self.layer3 = self._make_layer(128, stride=1 + (downsample > 0))
self.layer4 = self._make_layer(128, stride=2)
self.layer5 = self._make_layer(128, stride=2)
output_list = []
for dim in output_dim:
conv_out = nn.Sequential(
ResidualBlock(128, 128, self.norm_fn, stride=1),
nn.Conv2d(128, dim[2], 3, padding=1))
output_list.append(conv_out)
self.outputs04 = nn.ModuleList(output_list)
output_list = []
for dim in output_dim:
conv_out = nn.Sequential(
ResidualBlock(128, 128, self.norm_fn, stride=1),
nn.Conv2d(128, dim[1], 3, padding=1))
output_list.append(conv_out)
self.outputs08 = nn.ModuleList(output_list)
output_list = []
for dim in output_dim:
conv_out = nn.Conv2d(128, dim[0], 3, padding=1)
output_list.append(conv_out)
self.outputs16 = nn.ModuleList(output_list)
if dropout > 0:
self.dropout = nn.Dropout2d(p=dropout)
else:
self.dropout = None
for m in self.modules():
if isinstance(m, nn.Conv2d):
nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
elif isinstance(m, (nn.BatchNorm2d, nn.InstanceNorm2d, nn.GroupNorm)):
if m.weight is not None:
nn.init.constant_(m.weight, 1)
if m.bias is not None:
nn.init.constant_(m.bias, 0)
def _make_layer(self, dim, stride=1):
layer1 = ResidualBlock(self.in_planes, dim, self.norm_fn, stride=stride)
layer2 = ResidualBlock(dim, dim, self.norm_fn, stride=1)
layers = (layer1, layer2)
self.in_planes = dim
return nn.Sequential(*layers)
def forward(self, x, dual_inp=False, num_layers=3):
x = self.conv1(x)
x = self.norm1(x)
x = self.relu1(x)
x = self.layer1(x)
x = self.layer2(x)
x = self.layer3(x)
if dual_inp:
v = x
x = x[:(x.shape[0]//2)]
outputs04 = [f(x) for f in self.outputs04]
if num_layers == 1:
return (outputs04, v) if dual_inp else (outputs04,)
y = self.layer4(x)
outputs08 = [f(y) for f in self.outputs08]
if num_layers == 2:
return (outputs04, outputs08, v) if dual_inp else (outputs04, outputs08)
z = self.layer5(y)
outputs16 = [f(z) for f in self.outputs16]
return (outputs04, outputs08, outputs16, v) if dual_inp else (outputs04, outputs08, outputs16)
class SubModule(nn.Module):
def __init__(self):
super(SubModule, self).__init__()
def weight_init(self):
for m in self.modules():
if isinstance(m, nn.Conv2d):
n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
m.weight.data.normal_(0, math.sqrt(2. / n))
elif isinstance(m, nn.Conv3d):
n = m.kernel_size[0] * m.kernel_size[1] * m.kernel_size[2] * m.out_channels
m.weight.data.normal_(0, math.sqrt(2. / n))
elif isinstance(m, nn.BatchNorm2d):
m.weight.data.fill_(1)
m.bias.data.zero_()
elif isinstance(m, nn.BatchNorm3d):
m.weight.data.fill_(1)
m.bias.data.zero_()
class Feature(SubModule):
def __init__(self):
super(Feature, self).__init__()
pretrained = True
model = timm.create_model('mobilenetv2_100', pretrained=pretrained, features_only=True)
layers = [1,2,3,5,6]
chans = [16, 24, 32, 96, 160]
self.conv_stem = model.conv_stem
self.bn1 = model.bn1
self.act1 = model.act1
self.block0 = torch.nn.Sequential(*model.blocks[0:layers[0]])
self.block1 = torch.nn.Sequential(*model.blocks[layers[0]:layers[1]])
self.block2 = torch.nn.Sequential(*model.blocks[layers[1]:layers[2]])
self.block3 = torch.nn.Sequential(*model.blocks[layers[2]:layers[3]])
self.block4 = torch.nn.Sequential(*model.blocks[layers[3]:layers[4]])
self.deconv32_16 = Conv2x_IN(chans[4], chans[3], deconv=True, concat=True)
self.deconv16_8 = Conv2x_IN(chans[3]*2, chans[2], deconv=True, concat=True)
self.deconv8_4 = Conv2x_IN(chans[2]*2, chans[1], deconv=True, concat=True)
self.conv4 = BasicConv_IN(chans[1]*2, chans[1]*2, kernel_size=3, stride=1, padding=1)
def forward(self, x):
x = self.act1(self.bn1(self.conv_stem(x)))
x2 = self.block0(x)
x4 = self.block1(x2)
x8 = self.block2(x4)
x16 = self.block3(x8)
x32 = self.block4(x16)
x16 = self.deconv32_16(x32, x16)
x8 = self.deconv16_8(x16, x8)
x4 = self.deconv8_4(x8, x4)
x4 = self.conv4(x4)
return [x4, x8, x16, x32]
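
A minimal shape-check sketch of the context encoder above. The hidden/context widths and the input size are assumptions chosen for illustration (they mirror how the stereo model later in this commit builds its cnet); the Feature backbone is left out here because it downloads pretrained MobileNetV2 weights through timm.

import torch
from core.extractor import MultiBasicEncoder

hidden_dims = [128, 128, 128]                      # assumed widths, one per GRU scale
cnet = MultiBasicEncoder(output_dim=[hidden_dims, hidden_dims],
                         norm_fn='batch', downsample=2)
img = torch.randn(1, 3, 256, 512)                  # H and W divisible by 32
out04, out08, out16 = cnet(img, num_layers=3)      # feature lists at 1/4, 1/8, 1/16 resolution
print(out04[0].shape, out08[0].shape, out16[0].shape)
# torch.Size([1, 128, 64, 128]) torch.Size([1, 128, 32, 64]) torch.Size([1, 128, 16, 32])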


@@ -0,0 +1,69 @@
import torch
import torch.nn.functional as F
from core.utils.utils import bilinear_sampler
class Combined_Geo_Encoding_Volume:
def __init__(self, init_fmap1, init_fmap2, geo_volume, num_levels=2, radius=4):
self.num_levels = num_levels
self.radius = radius
self.geo_volume_pyramid = []
self.init_corr_pyramid = []
# all pairs correlation
init_corr = Combined_Geo_Encoding_Volume.corr(init_fmap1, init_fmap2)
b, h, w, _, w2 = init_corr.shape
b, c, d, h, w = geo_volume.shape
geo_volume = geo_volume.permute(0, 3, 4, 1, 2).reshape(b*h*w, c, 1, d)
init_corr = init_corr.reshape(b*h*w, 1, 1, w2)
self.geo_volume_pyramid.append(geo_volume)
self.init_corr_pyramid.append(init_corr)
for i in range(self.num_levels-1):
geo_volume = F.avg_pool2d(geo_volume, [1,2], stride=[1,2])
self.geo_volume_pyramid.append(geo_volume)
for i in range(self.num_levels-1):
init_corr = F.avg_pool2d(init_corr, [1,2], stride=[1,2])
self.init_corr_pyramid.append(init_corr)
def __call__(self, disp, coords):
r = self.radius
b, _, h, w = disp.shape
out_pyramid = []
for i in range(self.num_levels):
geo_volume = self.geo_volume_pyramid[i]
dx = torch.linspace(-r, r, 2*r+1)
dx = dx.view(1, 1, 2*r+1, 1).to(disp.device)
x0 = dx + disp.reshape(b*h*w, 1, 1, 1) / 2**i
y0 = torch.zeros_like(x0)
disp_lvl = torch.cat([x0,y0], dim=-1)
geo_volume = bilinear_sampler(geo_volume, disp_lvl)
geo_volume = geo_volume.view(b, h, w, -1)
init_corr = self.init_corr_pyramid[i]
init_x0 = coords.reshape(b*h*w, 1, 1, 1)/2**i - disp.reshape(b*h*w, 1, 1, 1) / 2**i + dx
init_coords_lvl = torch.cat([init_x0,y0], dim=-1)
init_corr = bilinear_sampler(init_corr, init_coords_lvl)
init_corr = init_corr.view(b, h, w, -1)
out_pyramid.append(geo_volume)
out_pyramid.append(init_corr)
out = torch.cat(out_pyramid, dim=-1)
return out.permute(0, 3, 1, 2).contiguous().float()
@staticmethod
def corr(fmap1, fmap2):
B, D, H, W1 = fmap1.shape
_, _, _, W2 = fmap2.shape
fmap1 = fmap1.view(B, D, H, W1)
fmap2 = fmap2.view(B, D, H, W2)
corr = torch.einsum('aijk,aijh->ajkh', fmap1, fmap2)
corr = corr.reshape(B, H, W1, 1, W2).contiguous()
return corr
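
A quick shape check of the all-pairs correlation that seeds the pyramid (tensor sizes here are placeholders; 96 channels matches the descriptor width used by the model later in this commit).

import torch
from core.geometry import Combined_Geo_Encoding_Volume

fmap1 = torch.randn(2, 96, 8, 16)    # (B, C, H, W1) left descriptors
fmap2 = torch.randn(2, 96, 8, 16)    # (B, C, H, W2) right descriptors
corr = Combined_Geo_Encoding_Volume.corr(fmap1, fmap2)
print(corr.shape)                    # torch.Size([2, 8, 16, 1, 16]) = (B, H, W1, 1, W2)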


@@ -0,0 +1,221 @@
import torch
import torch.nn as nn
import torch.nn.functional as F
from core.update import BasicMultiUpdateBlock
from core.extractor import MultiBasicEncoder, Feature
from core.geometry import Combined_Geo_Encoding_Volume
from core.submodule import *
import time
try:
autocast = torch.cuda.amp.autocast
except:
class autocast:
def __init__(self, enabled):
pass
def __enter__(self):
pass
def __exit__(self, *args):
pass
class hourglass(nn.Module):
def __init__(self, in_channels):
super(hourglass, self).__init__()
self.conv1 = nn.Sequential(BasicConv(in_channels, in_channels*2, is_3d=True, bn=True, relu=True, kernel_size=3,
padding=1, stride=2, dilation=1),
BasicConv(in_channels*2, in_channels*2, is_3d=True, bn=True, relu=True, kernel_size=3,
padding=1, stride=1, dilation=1))
self.conv2 = nn.Sequential(BasicConv(in_channels*2, in_channels*4, is_3d=True, bn=True, relu=True, kernel_size=3,
padding=1, stride=2, dilation=1),
BasicConv(in_channels*4, in_channels*4, is_3d=True, bn=True, relu=True, kernel_size=3,
padding=1, stride=1, dilation=1))
self.conv3 = nn.Sequential(BasicConv(in_channels*4, in_channels*6, is_3d=True, bn=True, relu=True, kernel_size=3,
padding=1, stride=2, dilation=1),
BasicConv(in_channels*6, in_channels*6, is_3d=True, bn=True, relu=True, kernel_size=3,
padding=1, stride=1, dilation=1))
self.conv3_up = BasicConv(in_channels*6, in_channels*4, deconv=True, is_3d=True, bn=True,
relu=True, kernel_size=(4, 4, 4), padding=(1, 1, 1), stride=(2, 2, 2))
self.conv2_up = BasicConv(in_channels*4, in_channels*2, deconv=True, is_3d=True, bn=True,
relu=True, kernel_size=(4, 4, 4), padding=(1, 1, 1), stride=(2, 2, 2))
self.conv1_up = BasicConv(in_channels*2, 8, deconv=True, is_3d=True, bn=False,
relu=False, kernel_size=(4, 4, 4), padding=(1, 1, 1), stride=(2, 2, 2))
self.agg_0 = nn.Sequential(BasicConv(in_channels*8, in_channels*4, is_3d=True, kernel_size=1, padding=0, stride=1),
BasicConv(in_channels*4, in_channels*4, is_3d=True, kernel_size=3, padding=1, stride=1),
BasicConv(in_channels*4, in_channels*4, is_3d=True, kernel_size=3, padding=1, stride=1),)
self.agg_1 = nn.Sequential(BasicConv(in_channels*4, in_channels*2, is_3d=True, kernel_size=1, padding=0, stride=1),
BasicConv(in_channels*2, in_channels*2, is_3d=True, kernel_size=3, padding=1, stride=1),
BasicConv(in_channels*2, in_channels*2, is_3d=True, kernel_size=3, padding=1, stride=1))
self.feature_att_8 = FeatureAtt(in_channels*2, 64)
self.feature_att_16 = FeatureAtt(in_channels*4, 192)
self.feature_att_32 = FeatureAtt(in_channels*6, 160)
self.feature_att_up_16 = FeatureAtt(in_channels*4, 192)
self.feature_att_up_8 = FeatureAtt(in_channels*2, 64)
def forward(self, x, features):
conv1 = self.conv1(x)
conv1 = self.feature_att_8(conv1, features[1])
conv2 = self.conv2(conv1)
conv2 = self.feature_att_16(conv2, features[2])
conv3 = self.conv3(conv2)
conv3 = self.feature_att_32(conv3, features[3])
conv3_up = self.conv3_up(conv3)
conv2 = torch.cat((conv3_up, conv2), dim=1)
conv2 = self.agg_0(conv2)
conv2 = self.feature_att_up_16(conv2, features[2])
conv2_up = self.conv2_up(conv2)
conv1 = torch.cat((conv2_up, conv1), dim=1)
conv1 = self.agg_1(conv1)
conv1 = self.feature_att_up_8(conv1, features[1])
conv = self.conv1_up(conv1)
return conv
class IGEVStereo(nn.Module):
def __init__(self, args):
super().__init__()
self.args = args
context_dims = args.hidden_dims
self.cnet = MultiBasicEncoder(output_dim=[args.hidden_dims, context_dims], norm_fn="batch", downsample=args.n_downsample)
self.update_block = BasicMultiUpdateBlock(self.args, hidden_dims=args.hidden_dims)
self.context_zqr_convs = nn.ModuleList([nn.Conv2d(context_dims[i], args.hidden_dims[i]*3, 3, padding=3//2) for i in range(self.args.n_gru_layers)])
self.feature = Feature()
self.stem_2 = nn.Sequential(
BasicConv_IN(3, 32, kernel_size=3, stride=2, padding=1),
nn.Conv2d(32, 32, 3, 1, 1, bias=False),
nn.InstanceNorm2d(32), nn.ReLU()
)
self.stem_4 = nn.Sequential(
BasicConv_IN(32, 48, kernel_size=3, stride=2, padding=1),
nn.Conv2d(48, 48, 3, 1, 1, bias=False),
nn.InstanceNorm2d(48), nn.ReLU()
)
self.spx = nn.Sequential(nn.ConvTranspose2d(2*32, 9, kernel_size=4, stride=2, padding=1),)
self.spx_2 = Conv2x_IN(24, 32, True)
self.spx_4 = nn.Sequential(
BasicConv_IN(96, 24, kernel_size=3, stride=1, padding=1),
nn.Conv2d(24, 24, 3, 1, 1, bias=False),
nn.InstanceNorm2d(24), nn.ReLU()
)
self.spx_2_gru = Conv2x(32, 32, True)
self.spx_gru = nn.Sequential(nn.ConvTranspose2d(2*32, 9, kernel_size=4, stride=2, padding=1),)
self.conv = BasicConv_IN(96, 96, kernel_size=3, padding=1, stride=1)
self.desc = nn.Conv2d(96, 96, kernel_size=1, padding=0, stride=1)
self.corr_stem = BasicConv(8, 8, is_3d=True, kernel_size=3, stride=1, padding=1)
self.corr_feature_att = FeatureAtt(8, 96)
self.cost_agg = hourglass(8)
self.classifier = nn.Conv3d(8, 1, 3, 1, 1, bias=False)
def freeze_bn(self):
for m in self.modules():
if isinstance(m, nn.BatchNorm2d):
m.eval()
def upsample_disp(self, disp, mask_feat_4, stem_2x):
with autocast(enabled=self.args.mixed_precision):
xspx = self.spx_2_gru(mask_feat_4, stem_2x)
spx_pred = self.spx_gru(xspx)
spx_pred = F.softmax(spx_pred, 1)
up_disp = context_upsample(disp*4., spx_pred).unsqueeze(1)
return up_disp
def forward(self, image1, image2, iters=12, flow_init=None, test_mode=False):
""" Estimate disparity between pair of frames """
image1 = (2 * (image1 / 255.0) - 1.0).contiguous()
image2 = (2 * (image2 / 255.0) - 1.0).contiguous()
with autocast(enabled=self.args.mixed_precision):
features_left = self.feature(image1)
features_right = self.feature(image2)
stem_2x = self.stem_2(image1)
stem_4x = self.stem_4(stem_2x)
stem_2y = self.stem_2(image2)
stem_4y = self.stem_4(stem_2y)
features_left[0] = torch.cat((features_left[0], stem_4x), 1)
features_right[0] = torch.cat((features_right[0], stem_4y), 1)
match_left = self.desc(self.conv(features_left[0]))
match_right = self.desc(self.conv(features_right[0]))
gwc_volume = build_gwc_volume(match_left, match_right, 192//4, 8)
gwc_volume = self.corr_stem(gwc_volume)
gwc_volume = self.corr_feature_att(gwc_volume, features_left[0])
geo_encoding_volume = self.cost_agg(gwc_volume, features_left)
# Init disp from geometry encoding volume
prob = F.softmax(self.classifier(geo_encoding_volume).squeeze(1), dim=1)
init_disp = disparity_regression(prob, self.args.max_disp//4)
del prob, gwc_volume
if not test_mode:
xspx = self.spx_4(features_left[0])
xspx = self.spx_2(xspx, stem_2x)
spx_pred = self.spx(xspx)
spx_pred = F.softmax(spx_pred, 1)
cnet_list = self.cnet(image1, num_layers=self.args.n_gru_layers)
net_list = [torch.tanh(x[0]) for x in cnet_list]
inp_list = [torch.relu(x[1]) for x in cnet_list]
inp_list = [list(conv(i).split(split_size=conv.out_channels//3, dim=1)) for i,conv in zip(inp_list, self.context_zqr_convs)]
geo_block = Combined_Geo_Encoding_Volume
geo_fn = geo_block(match_left.float(), match_right.float(), geo_encoding_volume.float(), radius=self.args.corr_radius, num_levels=self.args.corr_levels)
b, c, h, w = match_left.shape
coords = torch.arange(w).float().to(match_left.device).reshape(1,1,w,1).repeat(b, h, 1, 1)
disp = init_disp
disp_preds = []
# GRUs iterations to update disparity
for itr in range(iters):
disp = disp.detach()
geo_feat = geo_fn(disp, coords)
with autocast(enabled=self.args.mixed_precision):
if self.args.n_gru_layers == 3 and self.args.slow_fast_gru: # Update low-res ConvGRU
net_list = self.update_block(net_list, inp_list, iter16=True, iter08=False, iter04=False, update=False)
if self.args.n_gru_layers >= 2 and self.args.slow_fast_gru:# Update low-res ConvGRU and mid-res ConvGRU
net_list = self.update_block(net_list, inp_list, iter16=self.args.n_gru_layers==3, iter08=True, iter04=False, update=False)
net_list, mask_feat_4, delta_disp = self.update_block(net_list, inp_list, geo_feat, disp, iter16=self.args.n_gru_layers==3, iter08=self.args.n_gru_layers>=2)
disp = disp + delta_disp
if test_mode and itr < iters-1:
continue
# upsample predictions
disp_up = self.upsample_disp(disp, mask_feat_4, stem_2x)
disp_preds.append(disp_up)
if test_mode:
return disp_up
init_disp = context_upsample(init_disp*4., spx_pred.float()).unsqueeze(1)
return init_disp, disp_preds
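
A hedged end-to-end inference sketch. The argument names follow the constructor above; the numeric values (hidden widths, correlation pyramid size, max disparity) are assumed typical settings rather than values taken from this commit, and the module path core.igev_stereo is an assumption based on the other core.* imports. Building the model downloads MobileNetV2 weights through timm.

import argparse
import torch
from core.igev_stereo import IGEVStereo   # assumed module path

args = argparse.Namespace(
    hidden_dims=[128, 128, 128], n_downsample=2, n_gru_layers=3,
    corr_levels=2, corr_radius=4, max_disp=192,
    mixed_precision=False, slow_fast_gru=False)
model = IGEVStereo(args).eval()
left = torch.randint(0, 256, (1, 3, 256, 512)).float()     # H, W divisible by 32
right = torch.randint(0, 256, (1, 3, 256, 512)).float()
with torch.no_grad():
    disp = model(left, right, iters=16, test_mode=True)    # (1, 1, 256, 512), pixels at full resolution
print(disp.shape)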


@@ -0,0 +1,331 @@
import numpy as np
import torch
import torch.utils.data as data
import torch.nn.functional as F
import logging
import os
import re
import copy
import math
import random
from pathlib import Path
from glob import glob
import os.path as osp
from core.utils import frame_utils
from core.utils.augmentor import FlowAugmentor, SparseFlowAugmentor
class StereoDataset(data.Dataset):
def __init__(self, aug_params=None, sparse=False, reader=None):
self.augmentor = None
self.sparse = sparse
self.img_pad = aug_params.pop("img_pad", None) if aug_params is not None else None
if aug_params is not None and "crop_size" in aug_params:
if sparse:
self.augmentor = SparseFlowAugmentor(**aug_params)
else:
self.augmentor = FlowAugmentor(**aug_params)
if reader is None:
self.disparity_reader = frame_utils.read_gen
else:
self.disparity_reader = reader
self.is_test = False
self.init_seed = False
self.flow_list = []
self.disparity_list = []
self.image_list = []
self.extra_info = []
def __getitem__(self, index):
if self.is_test:
img1 = frame_utils.read_gen(self.image_list[index][0])
img2 = frame_utils.read_gen(self.image_list[index][1])
img1 = np.array(img1).astype(np.uint8)[..., :3]
img2 = np.array(img2).astype(np.uint8)[..., :3]
img1 = torch.from_numpy(img1).permute(2, 0, 1).float()
img2 = torch.from_numpy(img2).permute(2, 0, 1).float()
return img1, img2, self.extra_info[index]
if not self.init_seed:
worker_info = torch.utils.data.get_worker_info()
if worker_info is not None:
torch.manual_seed(worker_info.id)
np.random.seed(worker_info.id)
random.seed(worker_info.id)
self.init_seed = True
index = index % len(self.image_list)
disp = self.disparity_reader(self.disparity_list[index])
if isinstance(disp, tuple):
disp, valid = disp
else:
valid = disp < 512
img1 = frame_utils.read_gen(self.image_list[index][0])
img2 = frame_utils.read_gen(self.image_list[index][1])
img1 = np.array(img1).astype(np.uint8)
img2 = np.array(img2).astype(np.uint8)
disp = np.array(disp).astype(np.float32)
flow = np.stack([disp, np.zeros_like(disp)], axis=-1)
# grayscale images
if len(img1.shape) == 2:
img1 = np.tile(img1[...,None], (1, 1, 3))
img2 = np.tile(img2[...,None], (1, 1, 3))
else:
img1 = img1[..., :3]
img2 = img2[..., :3]
if self.augmentor is not None:
if self.sparse:
img1, img2, flow, valid = self.augmentor(img1, img2, flow, valid)
else:
img1, img2, flow = self.augmentor(img1, img2, flow)
img1 = torch.from_numpy(img1).permute(2, 0, 1).float()
img2 = torch.from_numpy(img2).permute(2, 0, 1).float()
flow = torch.from_numpy(flow).permute(2, 0, 1).float()
if self.sparse:
valid = torch.from_numpy(valid)
else:
valid = (flow[0].abs() < 512) & (flow[1].abs() < 512)
if self.img_pad is not None:
padH, padW = self.img_pad
img1 = F.pad(img1, [padW]*2 + [padH]*2)
img2 = F.pad(img2, [padW]*2 + [padH]*2)
flow = flow[:1]
return self.image_list[index] + [self.disparity_list[index]], img1, img2, flow, valid.float()
def __mul__(self, v):
copy_of_self = copy.deepcopy(self)
copy_of_self.flow_list = v * copy_of_self.flow_list
copy_of_self.image_list = v * copy_of_self.image_list
copy_of_self.disparity_list = v * copy_of_self.disparity_list
copy_of_self.extra_info = v * copy_of_self.extra_info
return copy_of_self
def __len__(self):
return len(self.image_list)
class SceneFlowDatasets(StereoDataset):
def __init__(self, aug_params=None, root='/data/sceneflow/', dstype='frames_finalpass', things_test=False):
super(SceneFlowDatasets, self).__init__(aug_params)
self.root = root
self.dstype = dstype
if things_test:
self._add_things("TEST")
else:
self._add_things("TRAIN")
self._add_monkaa("TRAIN")
self._add_driving("TRAIN")
def _add_things(self, split='TRAIN'):
""" Add FlyingThings3D data """
original_length = len(self.disparity_list)
# root = osp.join(self.root, 'FlyingThings3D')
root = self.root
left_images = sorted( glob(osp.join(root, self.dstype, split, '*/*/left/*.png')) )
right_images = [ im.replace('left', 'right') for im in left_images ]
disparity_images = [ im.replace(self.dstype, 'disparity').replace('.png', '.pfm') for im in left_images ]
        # Seed the permutation for reproducibility; all TEST images are currently kept for validation
state = np.random.get_state()
np.random.seed(1000)
# val_idxs = set(np.random.permutation(len(left_images))[:100])
val_idxs = set(np.random.permutation(len(left_images)))
np.random.set_state(state)
for idx, (img1, img2, disp) in enumerate(zip(left_images, right_images, disparity_images)):
if (split == 'TEST' and idx in val_idxs) or split == 'TRAIN':
self.image_list += [ [img1, img2] ]
self.disparity_list += [ disp ]
logging.info(f"Added {len(self.disparity_list) - original_length} from FlyingThings {self.dstype}")
def _add_monkaa(self, split="TRAIN"):
""" Add FlyingThings3D data """
original_length = len(self.disparity_list)
root = self.root
left_images = sorted( glob(osp.join(root, self.dstype, split, '*/left/*.png')) )
right_images = [ image_file.replace('left', 'right') for image_file in left_images ]
disparity_images = [ im.replace(self.dstype, 'disparity').replace('.png', '.pfm') for im in left_images ]
for img1, img2, disp in zip(left_images, right_images, disparity_images):
self.image_list += [ [img1, img2] ]
self.disparity_list += [ disp ]
logging.info(f"Added {len(self.disparity_list) - original_length} from Monkaa {self.dstype}")
def _add_driving(self, split="TRAIN"):
""" Add FlyingThings3D data """
original_length = len(self.disparity_list)
root = self.root
left_images = sorted( glob(osp.join(root, self.dstype, split, '*/*/*/left/*.png')) )
right_images = [ image_file.replace('left', 'right') for image_file in left_images ]
disparity_images = [ im.replace(self.dstype, 'disparity').replace('.png', '.pfm') for im in left_images ]
for img1, img2, disp in zip(left_images, right_images, disparity_images):
self.image_list += [ [img1, img2] ]
self.disparity_list += [ disp ]
logging.info(f"Added {len(self.disparity_list) - original_length} from Driving {self.dstype}")
class ETH3D(StereoDataset):
def __init__(self, aug_params=None, root='/data/ETH3D', split='training'):
super(ETH3D, self).__init__(aug_params, sparse=True)
image1_list = sorted( glob(osp.join(root, f'two_view_{split}/*/im0.png')) )
image2_list = sorted( glob(osp.join(root, f'two_view_{split}/*/im1.png')) )
disp_list = sorted( glob(osp.join(root, 'two_view_training_gt/*/disp0GT.pfm')) ) if split == 'training' else [osp.join(root, 'two_view_training_gt/playground_1l/disp0GT.pfm')]*len(image1_list)
for img1, img2, disp in zip(image1_list, image2_list, disp_list):
self.image_list += [ [img1, img2] ]
self.disparity_list += [ disp ]
class SintelStereo(StereoDataset):
def __init__(self, aug_params=None, root='datasets/SintelStereo'):
super().__init__(aug_params, sparse=True, reader=frame_utils.readDispSintelStereo)
image1_list = sorted( glob(osp.join(root, 'training/*_left/*/frame_*.png')) )
image2_list = sorted( glob(osp.join(root, 'training/*_right/*/frame_*.png')) )
disp_list = sorted( glob(osp.join(root, 'training/disparities/*/frame_*.png')) ) * 2
for img1, img2, disp in zip(image1_list, image2_list, disp_list):
assert img1.split('/')[-2:] == disp.split('/')[-2:]
self.image_list += [ [img1, img2] ]
self.disparity_list += [ disp ]
class FallingThings(StereoDataset):
def __init__(self, aug_params=None, root='datasets/FallingThings'):
super().__init__(aug_params, reader=frame_utils.readDispFallingThings)
assert os.path.exists(root)
with open(os.path.join(root, 'filenames.txt'), 'r') as f:
filenames = sorted(f.read().splitlines())
image1_list = [osp.join(root, e) for e in filenames]
image2_list = [osp.join(root, e.replace('left.jpg', 'right.jpg')) for e in filenames]
disp_list = [osp.join(root, e.replace('left.jpg', 'left.depth.png')) for e in filenames]
for img1, img2, disp in zip(image1_list, image2_list, disp_list):
self.image_list += [ [img1, img2] ]
self.disparity_list += [ disp ]
class TartanAir(StereoDataset):
def __init__(self, aug_params=None, root='datasets', keywords=[]):
super().__init__(aug_params, reader=frame_utils.readDispTartanAir)
assert os.path.exists(root)
with open(os.path.join(root, 'tartanair_filenames.txt'), 'r') as f:
filenames = sorted(list(filter(lambda s: 'seasonsforest_winter/Easy' not in s, f.read().splitlines())))
for kw in keywords:
filenames = sorted(list(filter(lambda s: kw in s.lower(), filenames)))
image1_list = [osp.join(root, e) for e in filenames]
image2_list = [osp.join(root, e.replace('_left', '_right')) for e in filenames]
disp_list = [osp.join(root, e.replace('image_left', 'depth_left').replace('left.png', 'left_depth.npy')) for e in filenames]
for img1, img2, disp in zip(image1_list, image2_list, disp_list):
self.image_list += [ [img1, img2] ]
self.disparity_list += [ disp ]
class KITTI(StereoDataset):
def __init__(self, aug_params=None, root='/data/KITTI/KITTI_2015', image_set='training'):
super(KITTI, self).__init__(aug_params, sparse=True, reader=frame_utils.readDispKITTI)
assert os.path.exists(root)
root_12 = '/data/KITTI/KITTI_2012'
image1_list = sorted(glob(os.path.join(root_12, image_set, 'colored_0/*_10.png')))
image2_list = sorted(glob(os.path.join(root_12, image_set, 'colored_1/*_10.png')))
disp_list = sorted(glob(os.path.join(root_12, 'training', 'disp_occ/*_10.png'))) if image_set == 'training' else [osp.join(root, 'training/disp_occ/000085_10.png')]*len(image1_list)
root_15 = '/data/KITTI/KITTI_2015'
image1_list += sorted(glob(os.path.join(root_15, image_set, 'image_2/*_10.png')))
image2_list += sorted(glob(os.path.join(root_15, image_set, 'image_3/*_10.png')))
disp_list += sorted(glob(os.path.join(root_15, 'training', 'disp_occ_0/*_10.png'))) if image_set == 'training' else [osp.join(root, 'training/disp_occ_0/000085_10.png')]*len(image1_list)
for idx, (img1, img2, disp) in enumerate(zip(image1_list, image2_list, disp_list)):
self.image_list += [ [img1, img2] ]
self.disparity_list += [ disp ]
class Middlebury(StereoDataset):
def __init__(self, aug_params=None, root='/data/Middlebury', split='F'):
super(Middlebury, self).__init__(aug_params, sparse=True, reader=frame_utils.readDispMiddlebury)
assert os.path.exists(root)
assert split in "FHQ"
lines = list(map(osp.basename, glob(os.path.join(root, "trainingH/*"))))
# lines = list(filter(lambda p: any(s in p.split('/') for s in Path(os.path.join(root, "MiddEval3/official_train.txt")).read_text().splitlines()), lines))
# image1_list = sorted([os.path.join(root, "MiddEval3", f'training{split}', f'{name}/im0.png') for name in lines])
# image2_list = sorted([os.path.join(root, "MiddEval3", f'training{split}', f'{name}/im1.png') for name in lines])
# disp_list = sorted([os.path.join(root, "MiddEval3", f'training{split}', f'{name}/disp0GT.pfm') for name in lines])
image1_list = sorted([os.path.join(root, f'training{split}', f'{name}/im0.png') for name in lines])
image2_list = sorted([os.path.join(root, f'training{split}', f'{name}/im1.png') for name in lines])
disp_list = sorted([os.path.join(root, f'training{split}', f'{name}/disp0GT.pfm') for name in lines])
assert len(image1_list) == len(image2_list) == len(disp_list) > 0, [image1_list, split]
for img1, img2, disp in zip(image1_list, image2_list, disp_list):
self.image_list += [ [img1, img2] ]
self.disparity_list += [ disp ]
def fetch_dataloader(args):
""" Create the data loader for the corresponding trainign set """
aug_params = {'crop_size': args.image_size, 'min_scale': args.spatial_scale[0], 'max_scale': args.spatial_scale[1], 'do_flip': False, 'yjitter': not args.noyjitter}
if hasattr(args, "saturation_range") and args.saturation_range is not None:
aug_params["saturation_range"] = args.saturation_range
if hasattr(args, "img_gamma") and args.img_gamma is not None:
aug_params["gamma"] = args.img_gamma
if hasattr(args, "do_flip") and args.do_flip is not None:
aug_params["do_flip"] = args.do_flip
train_dataset = None
for dataset_name in args.train_datasets:
if re.compile("middlebury_.*").fullmatch(dataset_name):
new_dataset = Middlebury(aug_params, split=dataset_name.replace('middlebury_',''))
elif dataset_name == 'sceneflow':
#clean_dataset = SceneFlowDatasets(aug_params, dstype='frames_cleanpass')
final_dataset = SceneFlowDatasets(aug_params, dstype='frames_finalpass')
#new_dataset = (clean_dataset*4) + (final_dataset*4)
new_dataset = final_dataset
logging.info(f"Adding {len(new_dataset)} samples from SceneFlow")
elif 'kitti' in dataset_name:
new_dataset = KITTI(aug_params)
logging.info(f"Adding {len(new_dataset)} samples from KITTI")
elif dataset_name == 'sintel_stereo':
new_dataset = SintelStereo(aug_params)*140
logging.info(f"Adding {len(new_dataset)} samples from Sintel Stereo")
elif dataset_name == 'falling_things':
new_dataset = FallingThings(aug_params)*5
logging.info(f"Adding {len(new_dataset)} samples from FallingThings")
elif dataset_name.startswith('tartan_air'):
new_dataset = TartanAir(aug_params, keywords=dataset_name.split('_')[2:])
logging.info(f"Adding {len(new_dataset)} samples from Tartain Air")
train_dataset = new_dataset if train_dataset is None else train_dataset + new_dataset
train_loader = data.DataLoader(train_dataset, batch_size=args.batch_size,
pin_memory=True, shuffle=True, num_workers=int(os.environ.get('SLURM_CPUS_PER_TASK', 6))-2, drop_last=True)
logging.info('Training with %d image pairs' % len(train_dataset))
return train_loader
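
A small sketch of using one of the loaders directly. The root path is the default hard-coded above and only works if SceneFlow is present there; the crop size and scale range are assumed values in the spirit of fetch_dataloader, and the module path core.stereo_datasets is an assumption.

from core.stereo_datasets import SceneFlowDatasets   # assumed module path

aug_params = {'crop_size': [320, 736], 'min_scale': -0.2, 'max_scale': 0.4,
              'do_flip': False, 'yjitter': True}
dataset = SceneFlowDatasets(aug_params, root='/data/sceneflow/', dstype='frames_finalpass')
paths, img1, img2, disp_gt, valid = dataset[0]
print(img1.shape, disp_gt.shape, valid.shape)   # (3, 320, 736), (1, 320, 736), (320, 736)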


@@ -0,0 +1,253 @@
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
class BasicConv(nn.Module):
def __init__(self, in_channels, out_channels, deconv=False, is_3d=False, bn=True, relu=True, **kwargs):
super(BasicConv, self).__init__()
self.relu = relu
self.use_bn = bn
if is_3d:
if deconv:
self.conv = nn.ConvTranspose3d(in_channels, out_channels, bias=False, **kwargs)
else:
self.conv = nn.Conv3d(in_channels, out_channels, bias=False, **kwargs)
self.bn = nn.BatchNorm3d(out_channels)
else:
if deconv:
self.conv = nn.ConvTranspose2d(in_channels, out_channels, bias=False, **kwargs)
else:
self.conv = nn.Conv2d(in_channels, out_channels, bias=False, **kwargs)
self.bn = nn.BatchNorm2d(out_channels)
def forward(self, x):
x = self.conv(x)
if self.use_bn:
x = self.bn(x)
if self.relu:
x = nn.LeakyReLU()(x)#, inplace=True)
return x
class Conv2x(nn.Module):
def __init__(self, in_channels, out_channels, deconv=False, is_3d=False, concat=True, keep_concat=True, bn=True, relu=True, keep_dispc=False):
super(Conv2x, self).__init__()
self.concat = concat
self.is_3d = is_3d
if deconv and is_3d:
kernel = (4, 4, 4)
elif deconv:
kernel = 4
else:
kernel = 3
if deconv and is_3d and keep_dispc:
kernel = (1, 4, 4)
stride = (1, 2, 2)
padding = (0, 1, 1)
self.conv1 = BasicConv(in_channels, out_channels, deconv, is_3d, bn=True, relu=True, kernel_size=kernel, stride=stride, padding=padding)
else:
self.conv1 = BasicConv(in_channels, out_channels, deconv, is_3d, bn=True, relu=True, kernel_size=kernel, stride=2, padding=1)
if self.concat:
mul = 2 if keep_concat else 1
self.conv2 = BasicConv(out_channels*2, out_channels*mul, False, is_3d, bn, relu, kernel_size=3, stride=1, padding=1)
else:
self.conv2 = BasicConv(out_channels, out_channels, False, is_3d, bn, relu, kernel_size=3, stride=1, padding=1)
def forward(self, x, rem):
x = self.conv1(x)
if x.shape != rem.shape:
x = F.interpolate(
x,
size=(rem.shape[-2], rem.shape[-1]),
mode='nearest')
if self.concat:
x = torch.cat((x, rem), 1)
else:
x = x + rem
x = self.conv2(x)
return x
class BasicConv_IN(nn.Module):
def __init__(self, in_channels, out_channels, deconv=False, is_3d=False, IN=True, relu=True, **kwargs):
super(BasicConv_IN, self).__init__()
self.relu = relu
self.use_in = IN
if is_3d:
if deconv:
self.conv = nn.ConvTranspose3d(in_channels, out_channels, bias=False, **kwargs)
else:
self.conv = nn.Conv3d(in_channels, out_channels, bias=False, **kwargs)
self.IN = nn.InstanceNorm3d(out_channels)
else:
if deconv:
self.conv = nn.ConvTranspose2d(in_channels, out_channels, bias=False, **kwargs)
else:
self.conv = nn.Conv2d(in_channels, out_channels, bias=False, **kwargs)
self.IN = nn.InstanceNorm2d(out_channels)
def forward(self, x):
x = self.conv(x)
if self.use_in:
x = self.IN(x)
if self.relu:
x = nn.LeakyReLU()(x)#, inplace=True)
return x
class Conv2x_IN(nn.Module):
def __init__(self, in_channels, out_channels, deconv=False, is_3d=False, concat=True, keep_concat=True, IN=True, relu=True, keep_dispc=False):
super(Conv2x_IN, self).__init__()
self.concat = concat
self.is_3d = is_3d
if deconv and is_3d:
kernel = (4, 4, 4)
elif deconv:
kernel = 4
else:
kernel = 3
if deconv and is_3d and keep_dispc:
kernel = (1, 4, 4)
stride = (1, 2, 2)
padding = (0, 1, 1)
self.conv1 = BasicConv_IN(in_channels, out_channels, deconv, is_3d, IN=True, relu=True, kernel_size=kernel, stride=stride, padding=padding)
else:
self.conv1 = BasicConv_IN(in_channels, out_channels, deconv, is_3d, IN=True, relu=True, kernel_size=kernel, stride=2, padding=1)
if self.concat:
mul = 2 if keep_concat else 1
self.conv2 = BasicConv_IN(out_channels*2, out_channels*mul, False, is_3d, IN, relu, kernel_size=3, stride=1, padding=1)
else:
self.conv2 = BasicConv_IN(out_channels, out_channels, False, is_3d, IN, relu, kernel_size=3, stride=1, padding=1)
def forward(self, x, rem):
x = self.conv1(x)
if x.shape != rem.shape:
x = F.interpolate(
x,
size=(rem.shape[-2], rem.shape[-1]),
mode='nearest')
if self.concat:
x = torch.cat((x, rem), 1)
else:
x = x + rem
x = self.conv2(x)
return x
def groupwise_correlation(fea1, fea2, num_groups):
B, C, H, W = fea1.shape
assert C % num_groups == 0
channels_per_group = C // num_groups
cost = (fea1 * fea2).view([B, num_groups, channels_per_group, H, W]).mean(dim=2)
assert cost.shape == (B, num_groups, H, W)
return cost
def build_gwc_volume(refimg_fea, targetimg_fea, maxdisp, num_groups):
B, C, H, W = refimg_fea.shape
volume = refimg_fea.new_zeros([B, num_groups, maxdisp, H, W])
for i in range(maxdisp):
if i > 0:
volume[:, :, i, :, i:] = groupwise_correlation(refimg_fea[:, :, :, i:], targetimg_fea[:, :, :, :-i],
num_groups)
else:
volume[:, :, i, :, :] = groupwise_correlation(refimg_fea, targetimg_fea, num_groups)
volume = volume.contiguous()
return volume
def norm_correlation(fea1, fea2):
cost = torch.mean(((fea1/(torch.norm(fea1, 2, 1, True)+1e-05)) * (fea2/(torch.norm(fea2, 2, 1, True)+1e-05))), dim=1, keepdim=True)
return cost
def build_norm_correlation_volume(refimg_fea, targetimg_fea, maxdisp):
B, C, H, W = refimg_fea.shape
volume = refimg_fea.new_zeros([B, 1, maxdisp, H, W])
for i in range(maxdisp):
if i > 0:
volume[:, :, i, :, i:] = norm_correlation(refimg_fea[:, :, :, i:], targetimg_fea[:, :, :, :-i])
else:
volume[:, :, i, :, :] = norm_correlation(refimg_fea, targetimg_fea)
volume = volume.contiguous()
return volume
def correlation(fea1, fea2):
cost = torch.sum((fea1 * fea2), dim=1, keepdim=True)
return cost
def build_correlation_volume(refimg_fea, targetimg_fea, maxdisp):
B, C, H, W = refimg_fea.shape
volume = refimg_fea.new_zeros([B, 1, maxdisp, H, W])
for i in range(maxdisp):
if i > 0:
volume[:, :, i, :, i:] = correlation(refimg_fea[:, :, :, i:], targetimg_fea[:, :, :, :-i])
else:
volume[:, :, i, :, :] = correlation(refimg_fea, targetimg_fea)
volume = volume.contiguous()
return volume
def build_concat_volume(refimg_fea, targetimg_fea, maxdisp):
B, C, H, W = refimg_fea.shape
volume = refimg_fea.new_zeros([B, 2 * C, maxdisp, H, W])
for i in range(maxdisp):
if i > 0:
volume[:, :C, i, :, :] = refimg_fea[:, :, :, :]
volume[:, C:, i, :, i:] = targetimg_fea[:, :, :, :-i]
else:
volume[:, :C, i, :, :] = refimg_fea
volume[:, C:, i, :, :] = targetimg_fea
volume = volume.contiguous()
return volume
def disparity_regression(x, maxdisp):
assert len(x.shape) == 4
disp_values = torch.arange(0, maxdisp, dtype=x.dtype, device=x.device)
disp_values = disp_values.view(1, maxdisp, 1, 1)
return torch.sum(x * disp_values, 1, keepdim=True)
class FeatureAtt(nn.Module):
def __init__(self, cv_chan, feat_chan):
super(FeatureAtt, self).__init__()
self.feat_att = nn.Sequential(
BasicConv(feat_chan, feat_chan//2, kernel_size=1, stride=1, padding=0),
nn.Conv2d(feat_chan//2, cv_chan, 1))
def forward(self, cv, feat):
        '''
        Modulate the cost volume `cv` with per-channel attention predicted from
        the 2D feature map `feat`, broadcast over the disparity dimension.
        '''
feat_att = self.feat_att(feat).unsqueeze(2)
cv = torch.sigmoid(feat_att)*cv
return cv
def context_upsample(disp_low, up_weights):
    # disp_low:   (b, 1, h, w)    low-resolution disparity
    # up_weights: (b, 9, 4h, 4w)  per-pixel softmax weights over the 3x3 neighborhood
b, c, h, w = disp_low.shape
disp_unfold = F.unfold(disp_low.reshape(b,c,h,w),3,1,1).reshape(b,-1,h,w)
disp_unfold = F.interpolate(disp_unfold,(h*4,w*4),mode='nearest').reshape(b,9,h*4,w*4)
disp = (disp_unfold*up_weights).sum(1)
return disp
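
A minimal shape check of the cost-volume helpers defined above (tensor sizes are placeholders; 96 channels and 8 groups match how the network builds its group-wise correlation volume).

import torch
from core.submodule import build_gwc_volume, disparity_regression

left = torch.randn(1, 96, 32, 64)
right = torch.randn(1, 96, 32, 64)
volume = build_gwc_volume(left, right, maxdisp=48, num_groups=8)
print(volume.shape)                           # torch.Size([1, 8, 48, 32, 64])

prob = torch.softmax(torch.randn(1, 48, 32, 64), dim=1)
print(disparity_regression(prob, 48).shape)   # torch.Size([1, 1, 32, 64])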

IGEV-Stereo/core/update.py

@@ -0,0 +1,142 @@
import torch
import torch.nn as nn
import torch.nn.functional as F
from opt_einsum import contract
class FlowHead(nn.Module):
def __init__(self, input_dim=128, hidden_dim=256, output_dim=2):
super(FlowHead, self).__init__()
self.conv1 = nn.Conv2d(input_dim, hidden_dim, 3, padding=1)
self.conv2 = nn.Conv2d(hidden_dim, output_dim, 3, padding=1)
self.relu = nn.ReLU(inplace=True)
def forward(self, x):
return self.conv2(self.relu(self.conv1(x)))
class DispHead(nn.Module):
def __init__(self, input_dim=128, hidden_dim=256, output_dim=1):
super(DispHead, self).__init__()
self.conv1 = nn.Conv2d(input_dim, hidden_dim, 3, padding=1)
self.conv2 = nn.Conv2d(hidden_dim, output_dim, 3, padding=1)
self.relu = nn.ReLU(inplace=True)
def forward(self, x):
return self.conv2(self.relu(self.conv1(x)))
class ConvGRU(nn.Module):
def __init__(self, hidden_dim, input_dim, kernel_size=3):
super(ConvGRU, self).__init__()
self.convz = nn.Conv2d(hidden_dim+input_dim, hidden_dim, kernel_size, padding=kernel_size//2)
self.convr = nn.Conv2d(hidden_dim+input_dim, hidden_dim, kernel_size, padding=kernel_size//2)
self.convq = nn.Conv2d(hidden_dim+input_dim, hidden_dim, kernel_size, padding=kernel_size//2)
def forward(self, h, cz, cr, cq, *x_list):
x = torch.cat(x_list, dim=1)
hx = torch.cat([h, x], dim=1)
z = torch.sigmoid(self.convz(hx) + cz)
r = torch.sigmoid(self.convr(hx) + cr)
q = torch.tanh(self.convq(torch.cat([r*h, x], dim=1)) + cq)
h = (1-z) * h + z * q
return h
class SepConvGRU(nn.Module):
def __init__(self, hidden_dim=128, input_dim=192+128):
super(SepConvGRU, self).__init__()
self.convz1 = nn.Conv2d(hidden_dim+input_dim, hidden_dim, (1,5), padding=(0,2))
self.convr1 = nn.Conv2d(hidden_dim+input_dim, hidden_dim, (1,5), padding=(0,2))
self.convq1 = nn.Conv2d(hidden_dim+input_dim, hidden_dim, (1,5), padding=(0,2))
self.convz2 = nn.Conv2d(hidden_dim+input_dim, hidden_dim, (5,1), padding=(2,0))
self.convr2 = nn.Conv2d(hidden_dim+input_dim, hidden_dim, (5,1), padding=(2,0))
self.convq2 = nn.Conv2d(hidden_dim+input_dim, hidden_dim, (5,1), padding=(2,0))
def forward(self, h, *x):
# horizontal
x = torch.cat(x, dim=1)
hx = torch.cat([h, x], dim=1)
z = torch.sigmoid(self.convz1(hx))
r = torch.sigmoid(self.convr1(hx))
q = torch.tanh(self.convq1(torch.cat([r*h, x], dim=1)))
h = (1-z) * h + z * q
# vertical
hx = torch.cat([h, x], dim=1)
z = torch.sigmoid(self.convz2(hx))
r = torch.sigmoid(self.convr2(hx))
q = torch.tanh(self.convq2(torch.cat([r*h, x], dim=1)))
h = (1-z) * h + z * q
return h
class BasicMotionEncoder(nn.Module):
def __init__(self, args):
super(BasicMotionEncoder, self).__init__()
self.args = args
cor_planes = args.corr_levels * (2*args.corr_radius + 1) * (8+1)
self.convc1 = nn.Conv2d(cor_planes, 64, 1, padding=0)
self.convc2 = nn.Conv2d(64, 64, 3, padding=1)
self.convd1 = nn.Conv2d(1, 64, 7, padding=3)
self.convd2 = nn.Conv2d(64, 64, 3, padding=1)
self.conv = nn.Conv2d(64+64, 128-1, 3, padding=1)
def forward(self, disp, corr):
cor = F.relu(self.convc1(corr))
cor = F.relu(self.convc2(cor))
disp_ = F.relu(self.convd1(disp))
disp_ = F.relu(self.convd2(disp_))
cor_disp = torch.cat([cor, disp_], dim=1)
out = F.relu(self.conv(cor_disp))
return torch.cat([out, disp], dim=1)
def pool2x(x):
return F.avg_pool2d(x, 3, stride=2, padding=1)
def pool4x(x):
return F.avg_pool2d(x, 5, stride=4, padding=1)
def interp(x, dest):
interp_args = {'mode': 'bilinear', 'align_corners': True}
return F.interpolate(x, dest.shape[2:], **interp_args)
class BasicMultiUpdateBlock(nn.Module):
def __init__(self, args, hidden_dims=[]):
super().__init__()
self.args = args
self.encoder = BasicMotionEncoder(args)
encoder_output_dim = 128
self.gru04 = ConvGRU(hidden_dims[2], encoder_output_dim + hidden_dims[1] * (args.n_gru_layers > 1))
self.gru08 = ConvGRU(hidden_dims[1], hidden_dims[0] * (args.n_gru_layers == 3) + hidden_dims[2])
self.gru16 = ConvGRU(hidden_dims[0], hidden_dims[1])
self.disp_head = DispHead(hidden_dims[2], hidden_dim=256, output_dim=1)
factor = 2**self.args.n_downsample
self.mask_feat_4 = nn.Sequential(
nn.Conv2d(hidden_dims[2], 32, 3, padding=1),
nn.ReLU(inplace=True))
def forward(self, net, inp, corr=None, disp=None, iter04=True, iter08=True, iter16=True, update=True):
if iter16:
net[2] = self.gru16(net[2], *(inp[2]), pool2x(net[1]))
if iter08:
if self.args.n_gru_layers > 2:
net[1] = self.gru08(net[1], *(inp[1]), pool2x(net[0]), interp(net[2], net[1]))
else:
net[1] = self.gru08(net[1], *(inp[1]), pool2x(net[0]))
if iter04:
motion_features = self.encoder(disp, corr)
if self.args.n_gru_layers > 1:
net[0] = self.gru04(net[0], *(inp[0]), motion_features, interp(net[1], net[0]))
else:
net[0] = self.gru04(net[0], *(inp[0]), motion_features)
if not update:
return net
delta_disp = self.disp_head(net[0])
mask_feat_4 = self.mask_feat_4(net[0])
return net, mask_feat_4, delta_disp
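
A hedged sketch of one refinement step with all three GRU scales. The spatial sizes correspond to a 256x512 input (1/4, 1/8, 1/16 resolution) and the hidden/context widths of 128 are assumptions; the correlation feature has corr_levels * (2*corr_radius + 1) * (8 + 1) channels, as required by BasicMotionEncoder.

import argparse
import torch
from core.update import BasicMultiUpdateBlock

args = argparse.Namespace(corr_levels=2, corr_radius=4, n_gru_layers=3, n_downsample=2)
block = BasicMultiUpdateBlock(args, hidden_dims=[128, 128, 128])

net = [torch.randn(1, 128, 64, 128),          # hidden states at 1/4, 1/8, 1/16
       torch.randn(1, 128, 32, 64),
       torch.randn(1, 128, 16, 32)]
inp = [[torch.randn(1, 128, s[0], s[1]) for _ in range(3)]   # cz, cr, cq per scale
       for s in [(64, 128), (32, 64), (16, 32)]]
corr = torch.randn(1, 2 * (2 * 4 + 1) * 9, 64, 128)          # 162 channels
disp = torch.zeros(1, 1, 64, 128)

net, mask_feat_4, delta_disp = block(net, inp, corr, disp)
print(delta_disp.shape, mask_feat_4.shape)   # torch.Size([1, 1, 64, 128]) torch.Size([1, 32, 64, 128])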


@@ -0,0 +1,321 @@
import numpy as np
import random
import warnings
import os
import time
from glob import glob
from skimage import color, io
from PIL import Image
import cv2
cv2.setNumThreads(0)
cv2.ocl.setUseOpenCL(False)
import torch
from torchvision.transforms import ColorJitter, functional, Compose
import torch.nn.functional as F
def get_middlebury_images():
root = "datasets/Middlebury/MiddEval3"
with open(os.path.join(root, "official_train.txt"), 'r') as f:
lines = f.read().splitlines()
return sorted([os.path.join(root, 'trainingQ', f'{name}/im0.png') for name in lines])
def get_eth3d_images():
return sorted(glob('datasets/ETH3D/two_view_training/*/im0.png'))
def get_kitti_images():
return sorted(glob('datasets/KITTI/training/image_2/*_10.png'))
def transfer_color(image, style_mean, style_stddev):
reference_image_lab = color.rgb2lab(image)
reference_stddev = np.std(reference_image_lab, axis=(0,1), keepdims=True)# + 1
reference_mean = np.mean(reference_image_lab, axis=(0,1), keepdims=True)
reference_image_lab = reference_image_lab - reference_mean
lamb = style_stddev/reference_stddev
style_image_lab = lamb * reference_image_lab
output_image_lab = style_image_lab + style_mean
l, a, b = np.split(output_image_lab, 3, axis=2)
l = l.clip(0, 100)
output_image_lab = np.concatenate((l,a,b), axis=2)
with warnings.catch_warnings():
warnings.simplefilter("ignore", category=UserWarning)
output_image_rgb = color.lab2rgb(output_image_lab) * 255
return output_image_rgb
class AdjustGamma(object):
def __init__(self, gamma_min, gamma_max, gain_min=1.0, gain_max=1.0):
self.gamma_min, self.gamma_max, self.gain_min, self.gain_max = gamma_min, gamma_max, gain_min, gain_max
def __call__(self, sample):
gain = random.uniform(self.gain_min, self.gain_max)
gamma = random.uniform(self.gamma_min, self.gamma_max)
return functional.adjust_gamma(sample, gamma, gain)
def __repr__(self):
return f"Adjust Gamma {self.gamma_min}, ({self.gamma_max}) and Gain ({self.gain_min}, {self.gain_max})"
class FlowAugmentor:
def __init__(self, crop_size, min_scale=-0.2, max_scale=0.5, do_flip=True, yjitter=False, saturation_range=[0.6,1.4], gamma=[1,1,1,1]):
# spatial augmentation params
self.crop_size = crop_size
self.min_scale = min_scale
self.max_scale = max_scale
self.spatial_aug_prob = 1.0
self.stretch_prob = 0.8
self.max_stretch = 0.2
# flip augmentation params
self.yjitter = yjitter
self.do_flip = do_flip
self.h_flip_prob = 0.5
self.v_flip_prob = 0.1
# photometric augmentation params
self.photo_aug = Compose([ColorJitter(brightness=0.4, contrast=0.4, saturation=saturation_range, hue=0.5/3.14), AdjustGamma(*gamma)])
self.asymmetric_color_aug_prob = 0.2
self.eraser_aug_prob = 0.5
def color_transform(self, img1, img2):
""" Photometric augmentation """
# asymmetric
if np.random.rand() < self.asymmetric_color_aug_prob:
#print("#####44444", img1.shape)
img1 = np.array(self.photo_aug(Image.fromarray(img1)), dtype=np.uint8)
img2 = np.array(self.photo_aug(Image.fromarray(img2)), dtype=np.uint8)
# symmetric
else:
image_stack = np.concatenate([img1, img2], axis=0)
image_stack = np.array(self.photo_aug(Image.fromarray(image_stack)), dtype=np.uint8)
img1, img2 = np.split(image_stack, 2, axis=0)
return img1, img2
def eraser_transform(self, img1, img2, bounds=[50, 100]):
""" Occlusion augmentation """
ht, wd = img1.shape[:2]
if np.random.rand() < self.eraser_aug_prob:
mean_color = np.mean(img2.reshape(-1, 3), axis=0)
for _ in range(np.random.randint(1, 3)):
x0 = np.random.randint(0, wd)
y0 = np.random.randint(0, ht)
dx = np.random.randint(bounds[0], bounds[1])
dy = np.random.randint(bounds[0], bounds[1])
img2[y0:y0+dy, x0:x0+dx, :] = mean_color
return img1, img2
def spatial_transform(self, img1, img2, flow):
# randomly sample scale
ht, wd = img1.shape[:2]
min_scale = np.maximum(
(self.crop_size[0] + 8) / float(ht),
(self.crop_size[1] + 8) / float(wd))
scale = 2 ** np.random.uniform(self.min_scale, self.max_scale)
scale_x = scale
scale_y = scale
if np.random.rand() < self.stretch_prob:
scale_x *= 2 ** np.random.uniform(-self.max_stretch, self.max_stretch)
scale_y *= 2 ** np.random.uniform(-self.max_stretch, self.max_stretch)
scale_x = np.clip(scale_x, min_scale, None)
scale_y = np.clip(scale_y, min_scale, None)
# print("####22222", flow.shape, scale_x, scale_y)
if np.random.rand() < self.spatial_aug_prob:
# rescale the images
img1 = cv2.resize(img1, None, fx=scale_x, fy=scale_y, interpolation=cv2.INTER_LINEAR)
img2 = cv2.resize(img2, None, fx=scale_x, fy=scale_y, interpolation=cv2.INTER_LINEAR)
flow = cv2.resize(flow, None, fx=scale_x, fy=scale_y, interpolation=cv2.INTER_LINEAR)
flow = flow * [scale_x, scale_y]
if self.do_flip:
if np.random.rand() < self.h_flip_prob and self.do_flip == 'hf': # h-flip
img1 = img1[:, ::-1]
img2 = img2[:, ::-1]
flow = flow[:, ::-1] * [-1.0, 1.0]
if np.random.rand() < self.h_flip_prob and self.do_flip == 'h': # h-flip for stereo
tmp = img1[:, ::-1]
img1 = img2[:, ::-1]
img2 = tmp
if np.random.rand() < self.v_flip_prob and self.do_flip == 'v': # v-flip
img1 = img1[::-1, :]
img2 = img2[::-1, :]
flow = flow[::-1, :] * [1.0, -1.0]
if self.yjitter:
y0 = np.random.randint(2, img1.shape[0] - self.crop_size[0] - 2)
x0 = np.random.randint(2, img1.shape[1] - self.crop_size[1] - 2)
y1 = y0 + np.random.randint(-2, 2 + 1)
img1 = img1[y0:y0+self.crop_size[0], x0:x0+self.crop_size[1]]
img2 = img2[y1:y1+self.crop_size[0], x0:x0+self.crop_size[1]]
flow = flow[y0:y0+self.crop_size[0], x0:x0+self.crop_size[1]]
else:
y0 = np.random.randint(0, img1.shape[0] - self.crop_size[0])
x0 = np.random.randint(0, img1.shape[1] - self.crop_size[1])
img1 = img1[y0:y0+self.crop_size[0], x0:x0+self.crop_size[1]]
img2 = img2[y0:y0+self.crop_size[0], x0:x0+self.crop_size[1]]
flow = flow[y0:y0+self.crop_size[0], x0:x0+self.crop_size[1]]
return img1, img2, flow
def __call__(self, img1, img2, flow):
img1, img2 = self.color_transform(img1, img2)
img1, img2 = self.eraser_transform(img1, img2)
img1, img2, flow = self.spatial_transform(img1, img2, flow)
img1 = np.ascontiguousarray(img1)
img2 = np.ascontiguousarray(img2)
flow = np.ascontiguousarray(flow)
return img1, img2, flow
class SparseFlowAugmentor:
def __init__(self, crop_size, min_scale=-0.2, max_scale=0.5, do_flip=False, yjitter=False, saturation_range=[0.7,1.3], gamma=[1,1,1,1]):
# spatial augmentation params
self.crop_size = crop_size
self.min_scale = min_scale
self.max_scale = max_scale
self.spatial_aug_prob = 0.8
self.stretch_prob = 0.8
self.max_stretch = 0.2
# flip augmentation params
self.do_flip = do_flip
self.h_flip_prob = 0.5
self.v_flip_prob = 0.1
# photometric augmentation params
self.photo_aug = Compose([ColorJitter(brightness=0.3, contrast=0.3, saturation=saturation_range, hue=0.3/3.14), AdjustGamma(*gamma)])
self.asymmetric_color_aug_prob = 0.2
self.eraser_aug_prob = 0.5
def color_transform(self, img1, img2):
image_stack = np.concatenate([img1, img2], axis=0)
image_stack = np.array(self.photo_aug(Image.fromarray(image_stack)), dtype=np.uint8)
img1, img2 = np.split(image_stack, 2, axis=0)
return img1, img2
def eraser_transform(self, img1, img2):
ht, wd = img1.shape[:2]
if np.random.rand() < self.eraser_aug_prob:
mean_color = np.mean(img2.reshape(-1, 3), axis=0)
for _ in range(np.random.randint(1, 3)):
x0 = np.random.randint(0, wd)
y0 = np.random.randint(0, ht)
dx = np.random.randint(50, 100)
dy = np.random.randint(50, 100)
img2[y0:y0+dy, x0:x0+dx, :] = mean_color
return img1, img2
def resize_sparse_flow_map(self, flow, valid, fx=1.0, fy=1.0):
ht, wd = flow.shape[:2]
coords = np.meshgrid(np.arange(wd), np.arange(ht))
coords = np.stack(coords, axis=-1)
coords = coords.reshape(-1, 2).astype(np.float32)
flow = flow.reshape(-1, 2).astype(np.float32)
valid = valid.reshape(-1).astype(np.float32)
coords0 = coords[valid>=1]
flow0 = flow[valid>=1]
ht1 = int(round(ht * fy))
wd1 = int(round(wd * fx))
coords1 = coords0 * [fx, fy]
flow1 = flow0 * [fx, fy]
xx = np.round(coords1[:,0]).astype(np.int32)
yy = np.round(coords1[:,1]).astype(np.int32)
v = (xx > 0) & (xx < wd1) & (yy > 0) & (yy < ht1)
xx = xx[v]
yy = yy[v]
flow1 = flow1[v]
flow_img = np.zeros([ht1, wd1, 2], dtype=np.float32)
valid_img = np.zeros([ht1, wd1], dtype=np.int32)
flow_img[yy, xx] = flow1
valid_img[yy, xx] = 1
return flow_img, valid_img
def spatial_transform(self, img1, img2, flow, valid):
# randomly sample scale
ht, wd = img1.shape[:2]
min_scale = np.maximum(
(self.crop_size[0] + 1) / float(ht),
(self.crop_size[1] + 1) / float(wd))
scale = 2 ** np.random.uniform(self.min_scale, self.max_scale)
scale_x = np.clip(scale, min_scale, None)
scale_y = np.clip(scale, min_scale, None)
if np.random.rand() < self.spatial_aug_prob:
# rescale the images
img1 = cv2.resize(img1, None, fx=scale_x, fy=scale_y, interpolation=cv2.INTER_LINEAR)
img2 = cv2.resize(img2, None, fx=scale_x, fy=scale_y, interpolation=cv2.INTER_LINEAR)
flow, valid = self.resize_sparse_flow_map(flow, valid, fx=scale_x, fy=scale_y)
if self.do_flip:
if np.random.rand() < self.h_flip_prob and self.do_flip == 'hf': # h-flip
img1 = img1[:, ::-1]
img2 = img2[:, ::-1]
flow = flow[:, ::-1] * [-1.0, 1.0]
if np.random.rand() < self.h_flip_prob and self.do_flip == 'h': # h-flip for stereo
tmp = img1[:, ::-1]
img1 = img2[:, ::-1]
img2 = tmp
if np.random.rand() < self.v_flip_prob and self.do_flip == 'v': # v-flip
img1 = img1[::-1, :]
img2 = img2[::-1, :]
flow = flow[::-1, :] * [1.0, -1.0]
margin_y = 20
margin_x = 50
y0 = np.random.randint(0, img1.shape[0] - self.crop_size[0] + margin_y)
x0 = np.random.randint(-margin_x, img1.shape[1] - self.crop_size[1] + margin_x)
y0 = np.clip(y0, 0, img1.shape[0] - self.crop_size[0])
x0 = np.clip(x0, 0, img1.shape[1] - self.crop_size[1])
img1 = img1[y0:y0+self.crop_size[0], x0:x0+self.crop_size[1]]
img2 = img2[y0:y0+self.crop_size[0], x0:x0+self.crop_size[1]]
flow = flow[y0:y0+self.crop_size[0], x0:x0+self.crop_size[1]]
valid = valid[y0:y0+self.crop_size[0], x0:x0+self.crop_size[1]]
return img1, img2, flow, valid
def __call__(self, img1, img2, flow, valid):
img1, img2 = self.color_transform(img1, img2)
img1, img2 = self.eraser_transform(img1, img2)
img1, img2, flow, valid = self.spatial_transform(img1, img2, flow, valid)
img1 = np.ascontiguousarray(img1)
img2 = np.ascontiguousarray(img2)
flow = np.ascontiguousarray(flow)
valid = np.ascontiguousarray(valid)
return img1, img2, flow, valid
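
A small sketch of the dense augmentor on synthetic data (array sizes are placeholders; the flow array carries disparity in channel 0 and zeros in channel 1, as in StereoDataset.__getitem__).

import numpy as np
from core.utils.augmentor import FlowAugmentor   # module path per the import in stereo_datasets

aug = FlowAugmentor(crop_size=[256, 512], min_scale=-0.2, max_scale=0.4,
                    do_flip=False, yjitter=True)
img1 = np.random.randint(0, 256, (540, 960, 3), dtype=np.uint8)
img2 = np.random.randint(0, 256, (540, 960, 3), dtype=np.uint8)
flow = np.stack([np.random.rand(540, 960).astype(np.float32) * 50,
                 np.zeros((540, 960), dtype=np.float32)], axis=-1)
img1, img2, flow = aug(img1, img2, flow)
print(img1.shape, flow.shape)                     # (256, 512, 3) (256, 512, 2)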


@@ -0,0 +1,187 @@
import numpy as np
from PIL import Image
from os.path import *
import re
import json
import imageio
import cv2
cv2.setNumThreads(0)
cv2.ocl.setUseOpenCL(False)
TAG_CHAR = np.array([202021.25], np.float32)
def readFlow(fn):
""" Read .flo file in Middlebury format"""
# Code adapted from:
# http://stackoverflow.com/questions/28013200/reading-middlebury-flow-files-with-python-bytes-array-numpy
# WARNING: this will work on little-endian architectures (eg Intel x86) only!
# print 'fn = %s'%(fn)
with open(fn, 'rb') as f:
magic = np.fromfile(f, np.float32, count=1)
if 202021.25 != magic:
print('Magic number incorrect. Invalid .flo file')
return None
else:
w = np.fromfile(f, np.int32, count=1)
h = np.fromfile(f, np.int32, count=1)
# print 'Reading %d x %d flo file\n' % (w, h)
data = np.fromfile(f, np.float32, count=2*int(w)*int(h))
# Reshape data into 3D array (columns, rows, bands)
# The reshape here is for visualization, the original code is (w,h,2)
return np.resize(data, (int(h), int(w), 2))
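# Illustrative usage (path is hypothetical):
#   flow = readFlow('frame_0001.flo')  # (H, W, 2) float32, channel 0 = u, channel 1 = v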
def readPFM(file):
file = open(file, 'rb')
color = None
width = None
height = None
scale = None
endian = None
header = file.readline().rstrip()
if header == b'PF':
color = True
elif header == b'Pf':
color = False
else:
raise Exception('Not a PFM file.')
dim_match = re.match(rb'^(\d+)\s(\d+)\s$', file.readline())
if dim_match:
width, height = map(int, dim_match.groups())
else:
raise Exception('Malformed PFM header.')
scale = float(file.readline().rstrip())
if scale < 0: # little-endian
endian = '<'
scale = -scale
else:
endian = '>' # big-endian
data = np.fromfile(file, endian + 'f')
shape = (height, width, 3) if color else (height, width)
data = np.reshape(data, shape)
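# PFM stores scanlines bottom-to-top, so flip to the usual top-to-bottom order.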
data = np.flipud(data)
return data
def writePFM(file, array):
import os
assert type(file) is str and type(array) is np.ndarray and \
os.path.splitext(file)[1] == ".pfm"
with open(file, 'wb') as f:
H, W = array.shape
headers = ["Pf\n", f"{W} {H}\n", "-1\n"]
for header in headers:
f.write(str.encode(header))
array = np.flip(array, axis=0).astype(np.float32)
f.write(array.tobytes())
def writeFlow(filename,uv,v=None):
""" Write optical flow to file.
If v is None, uv is assumed to contain both u and v channels,
stacked in depth.
Original code by Deqing Sun, adapted from Daniel Scharstein.
"""
nBands = 2
if v is None:
assert(uv.ndim == 3)
assert(uv.shape[2] == 2)
u = uv[:,:,0]
v = uv[:,:,1]
else:
u = uv
assert(u.shape == v.shape)
height,width = u.shape
f = open(filename,'wb')
# write the header
f.write(TAG_CHAR)
np.array(width).astype(np.int32).tofile(f)
np.array(height).astype(np.int32).tofile(f)
# arrange into matrix form
tmp = np.zeros((height, width*nBands))
tmp[:,np.arange(width)*2] = u
tmp[:,np.arange(width)*2 + 1] = v
tmp.astype(np.float32).tofile(f)
f.close()
def readFlowKITTI(filename):
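# KITTI flow PNGs are 16-bit: u and v are stored as uint16 = 64 * flow + 2**15, and the
# third channel flags valid pixels. cv2 loads channels as BGR, hence the reversal below;
# writeFlowKITTI at the bottom of this file applies the inverse encoding.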
flow = cv2.imread(filename, cv2.IMREAD_ANYDEPTH|cv2.IMREAD_COLOR)
flow = flow[:,:,::-1].astype(np.float32)
flow, valid = flow[:, :, :2], flow[:, :, 2]
flow = (flow - 2**15) / 64.0
return flow, valid
def readDispKITTI(filename):
disp = cv2.imread(filename, cv2.IMREAD_ANYDEPTH) / 256.0
valid = disp > 0.0
return disp, valid
# Method taken from /n/fs/raft-depth/RAFT-Stereo/datasets/SintelStereo/sdk/python/sintel_io.py
def readDispSintelStereo(file_name):
a = np.array(Image.open(file_name))
d_r, d_g, d_b = np.split(a, axis=2, indices_or_sections=3)
disp = (d_r * 4 + d_g / (2**6) + d_b / (2**14))[..., 0]
mask = np.array(Image.open(file_name.replace('disparities', 'occlusions')))
valid = ((mask == 0) & (disp > 0))
return disp, valid
# Method taken from https://research.nvidia.com/sites/default/files/pubs/2018-06_Falling-Things/readme_0.txt
def readDispFallingThings(file_name):
a = np.array(Image.open(file_name))
with open('/'.join(file_name.split('/')[:-1] + ['_camera_settings.json']), 'r') as f:
intrinsics = json.load(f)
fx = intrinsics['camera_settings'][0]['intrinsic_settings']['fx']
disp = (fx * 6.0 * 100) / a.astype(np.float32)
valid = disp > 0
return disp, valid
# Method taken from https://github.com/castacks/tartanair_tools/blob/master/data_type.md
def readDispTartanAir(file_name):
depth = np.load(file_name)
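# disparity = fx * baseline / depth; per the TartanAir docs the stereo rig has
# fx = 320 px and a 0.25 m baseline, hence the constant 80.0.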
disp = 80.0 / depth
valid = disp > 0
return disp, valid
def readDispMiddlebury(file_name):
assert basename(file_name) == 'disp0GT.pfm'
disp = readPFM(file_name).astype(np.float32)
assert len(disp.shape) == 2
nocc_pix = file_name.replace('disp0GT.pfm', 'mask0nocc.png')
assert exists(nocc_pix)
nocc_pix = imageio.imread(nocc_pix) == 255
assert np.any(nocc_pix)
return disp, nocc_pix
def writeFlowKITTI(filename, uv):
uv = 64.0 * uv + 2**15
valid = np.ones([uv.shape[0], uv.shape[1], 1])
uv = np.concatenate([uv, valid], axis=-1).astype(np.uint16)
cv2.imwrite(filename, uv[..., ::-1])
def read_gen(file_name, pil=False):
ext = splitext(file_name)[-1]
if ext == '.png' or ext == '.jpeg' or ext == '.ppm' or ext == '.jpg':
return Image.open(file_name)
elif ext == '.bin' or ext == '.raw':
return np.load(file_name)
elif ext == '.flo':
return readFlow(file_name).astype(np.float32)
elif ext == '.pfm':
flow = readPFM(file_name).astype(np.float32)
if len(flow.shape) == 2:
return flow
else:
return flow[:, :, :-1]
return []

View File

@ -0,0 +1,97 @@
import torch
import torch.nn.functional as F
import numpy as np
from scipy import interpolate
class InputPadder:
""" Pads images such that dimensions are divisible by 8 """
def __init__(self, dims, mode='sintel', divis_by=8):
self.ht, self.wd = dims[-2:]
pad_ht = (((self.ht // divis_by) + 1) * divis_by - self.ht) % divis_by
pad_wd = (((self.wd // divis_by) + 1) * divis_by - self.wd) % divis_by
if mode == 'sintel':
self._pad = [pad_wd//2, pad_wd - pad_wd//2, pad_ht//2, pad_ht - pad_ht//2]
else:
self._pad = [pad_wd//2, pad_wd - pad_wd//2, 0, pad_ht]
def pad(self, *inputs):
assert all((x.ndim == 4) for x in inputs)
return [F.pad(x, self._pad, mode='replicate') for x in inputs]
def unpad(self, x):
assert x.ndim == 4
ht, wd = x.shape[-2:]
c = [self._pad[2], ht-self._pad[3], self._pad[0], wd-self._pad[1]]
return x[..., c[0]:c[1], c[2]:c[3]]
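# Typical usage (illustrative, mirroring the demo scripts):
#   padder = InputPadder(image1.shape, divis_by=32)
#   image1, image2 = padder.pad(image1, image2)
#   disp = padder.unpad(model(image1, image2, test_mode=True))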
def forward_interpolate(flow):
flow = flow.detach().cpu().numpy()
dx, dy = flow[0], flow[1]
ht, wd = dx.shape
x0, y0 = np.meshgrid(np.arange(wd), np.arange(ht))
x1 = x0 + dx
y1 = y0 + dy
x1 = x1.reshape(-1)
y1 = y1.reshape(-1)
dx = dx.reshape(-1)
dy = dy.reshape(-1)
valid = (x1 > 0) & (x1 < wd) & (y1 > 0) & (y1 < ht)
x1 = x1[valid]
y1 = y1[valid]
dx = dx[valid]
dy = dy[valid]
flow_x = interpolate.griddata(
(x1, y1), dx, (x0, y0), method='nearest', fill_value=0)
flow_y = interpolate.griddata(
(x1, y1), dy, (x0, y0), method='nearest', fill_value=0)
flow = np.stack([flow_x, flow_y], axis=0)
return torch.from_numpy(flow).float()
def bilinear_sampler(img, coords, mode='bilinear', mask=False):
""" Wrapper for grid_sample, uses pixel coordinates """
H, W = img.shape[-2:]
# print("$$$55555", img.shape, coords.shape)
xgrid, ygrid = coords.split([1,1], dim=-1)
xgrid = 2*xgrid/(W-1) - 1
# print("######88888", xgrid)
assert torch.unique(ygrid).numel() == 1 and H == 1 # This is a stereo problem
grid = torch.cat([xgrid, ygrid], dim=-1)
# print("###37777", grid.shape)
img = F.grid_sample(img, grid, align_corners=True)
if mask:
mask = (xgrid > -1) & (ygrid > -1) & (xgrid < 1) & (ygrid < 1)
return img, mask.float()
return img
def coords_grid(batch, ht, wd):
coords = torch.meshgrid(torch.arange(ht), torch.arange(wd))
coords = torch.stack(coords[::-1], dim=0).float()
return coords[None].repeat(batch, 1, 1, 1)
def upflow8(flow, mode='bilinear'):
new_size = (8 * flow.shape[2], 8 * flow.shape[3])
return 8 * F.interpolate(flow, size=new_size, mode=mode, align_corners=True)
def gauss_blur(input, N=5, std=1):
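# Depth-wise Gaussian blur: each of the D channels is filtered independently with an
# N x N kernel of standard deviation `std`; padding keeps the spatial size unchanged.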
B, D, H, W = input.shape
x, y = torch.meshgrid(torch.arange(N).float() - N//2, torch.arange(N).float() - N//2)
unnormalized_gaussian = torch.exp(-(x.pow(2) + y.pow(2)) / (2 * std ** 2))
weights = unnormalized_gaussian / unnormalized_gaussian.sum().clamp(min=1e-4)
weights = weights.view(1,1,N,N).to(input)
output = F.conv2d(input.reshape(B*D,1,H,W), weights, padding=N//2)
return output.view(B, D, H, W)

Binary files not shown (four demo images added: 2.1 MiB, 2.1 MiB, 1.9 MiB, 1.9 MiB).

IGEV-Stereo/demo_imgs.py Normal file
View File

@ -0,0 +1,88 @@
import sys
sys.path.append('core')
DEVICE = 'cuda'
import os
os.environ['CUDA_VISIBLE_DEVICES'] = '0'
import argparse
import glob
import numpy as np
import torch
from tqdm import tqdm
from pathlib import Path
from igev_stereo import IGEVStereo
from utils.utils import InputPadder
from PIL import Image
from matplotlib import pyplot as plt
import os
import cv2
def load_image(imfile):
img = np.array(Image.open(imfile)).astype(np.uint8)
img = torch.from_numpy(img).permute(2, 0, 1).float()
return img[None].to(DEVICE)
def demo(args):
model = torch.nn.DataParallel(IGEVStereo(args), device_ids=[0])
model.load_state_dict(torch.load(args.restore_ckpt))
model = model.module
model.to(DEVICE)
model.eval()
output_directory = Path(args.output_directory)
output_directory.mkdir(exist_ok=True)
with torch.no_grad():
left_images = sorted(glob.glob(args.left_imgs, recursive=True))
right_images = sorted(glob.glob(args.right_imgs, recursive=True))
print(f"Found {len(left_images)} images. Saving files to {output_directory}/")
for (imfile1, imfile2) in tqdm(list(zip(left_images, right_images))):
image1 = load_image(imfile1)
image2 = load_image(imfile2)
padder = InputPadder(image1.shape, divis_by=32)
image1, image2 = padder.pad(image1, image2)
disp = model(image1, image2, iters=args.valid_iters, test_mode=True)
disp = disp.cpu().numpy()
disp = padder.unpad(disp)
file_stem = imfile1.split('/')[-2]
filename = os.path.join(output_directory, f"{file_stem}.png")
plt.imsave(output_directory / f"{file_stem}.png", disp.squeeze(), cmap='jet')
# disp = np.round(disp * 256).astype(np.uint16)
# cv2.imwrite(filename, cv2.applyColorMap(cv2.convertScaleAbs(disp.squeeze(), alpha=0.01),cv2.COLORMAP_JET), [int(cv2.IMWRITE_PNG_COMPRESSION), 0])
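# Note: plt.imsave rescales the disparity to its own min/max before applying the 'jet'
# colormap, so these PNGs are visualizations only and not comparable across images.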
if __name__ == '__main__':
parser = argparse.ArgumentParser()
parser.add_argument('--restore_ckpt', help="restore checkpoint", default='./pretrained_models/sceneflow/sceneflow.pth')
parser.add_argument('--save_numpy', action='store_true', help='save output as numpy arrays')
parser.add_argument('-l', '--left_imgs', help="path to all first (left) frames", default="./demo-imgs/*/im0.png")
parser.add_argument('-r', '--right_imgs', help="path to all second (right) frames", default="./demo-imgs/*/im1.png")
# parser.add_argument('-l', '--left_imgs', help="path to all first (left) frames", default="/data/Middlebury/trainingH/*/im0.png")
# parser.add_argument('-r', '--right_imgs', help="path to all second (right) frames", default="/data/Middlebury/trainingH/*/im1.png")
# parser.add_argument('-l', '--left_imgs', help="path to all first (left) frames", default="/data/ETH3D/two_view_training/*/im0.png")
# parser.add_argument('-r', '--right_imgs', help="path to all second (right) frames", default="/data/ETH3D/two_view_training/*/im1.png")
parser.add_argument('--output_directory', help="directory to save output", default="./demo-output/")
parser.add_argument('--mixed_precision', action='store_true', help='use mixed precision')
parser.add_argument('--valid_iters', type=int, default=32, help='number of flow-field updates during forward pass')
# Architecture choices
parser.add_argument('--hidden_dims', nargs='+', type=int, default=[128]*3, help="hidden state and context dimensions")
parser.add_argument('--corr_implementation', choices=["reg", "alt", "reg_cuda", "alt_cuda"], default="reg", help="correlation volume implementation")
parser.add_argument('--shared_backbone', action='store_true', help="use a single backbone for the context and feature encoders")
parser.add_argument('--corr_levels', type=int, default=2, help="number of levels in the correlation pyramid")
parser.add_argument('--corr_radius', type=int, default=4, help="width of the correlation pyramid")
parser.add_argument('--n_downsample', type=int, default=2, help="resolution of the disparity field (1/2^K)")
parser.add_argument('--slow_fast_gru', action='store_true', help="iterate the low-res GRUs more frequently")
parser.add_argument('--n_gru_layers', type=int, default=3, help="number of hidden GRU levels")
parser.add_argument('--max_disp', type=int, default=192, help="max disp of geometry encoding volume")
args = parser.parse_args()
Path(args.output_directory).mkdir(exist_ok=True, parents=True)
demo(args)

IGEV-Stereo/demo_video.py Normal file
View File

@ -0,0 +1,94 @@
import sys
sys.path.append('core')
import cv2
import numpy as np
import glob
from pathlib import Path
from tqdm import tqdm
import torch
from PIL import Image
from igev_stereo import IGEVStereo
import os
import argparse
from utils.utils import InputPadder
torch.backends.cudnn.benchmark = True
half_precision = True
DEVICE = 'cuda'
os.environ['CUDA_VISIBLE_DEVICES'] = '0'
parser = argparse.ArgumentParser(description='Iterative Geometry Encoding Volume for Stereo Matching and Multi-View Stereo (IGEV-Stereo)')
parser.add_argument('--restore_ckpt', help="restore checkpoint", default='./pretrained_models/kitti/kitti15.pth')
parser.add_argument('--save_numpy', action='store_true', help='save output as numpy arrays')
parser.add_argument('-l', '--left_imgs', help="path to all first (left) frames", default="/data/KITTI_raw/2011_09_26/2011_09_26_drive_0005_sync/image_02/data/*.png")
parser.add_argument('-r', '--right_imgs', help="path to all second (right) frames", default="/data/KITTI_raw/2011_09_26/2011_09_26_drive_0005_sync/image_03/data/*.png")
parser.add_argument('--mixed_precision', default=True, action='store_true', help='use mixed precision')
parser.add_argument('--valid_iters', type=int, default=16, help='number of flow-field updates during forward pass')
parser.add_argument('--hidden_dims', nargs='+', type=int, default=[128]*3, help="hidden state and context dimensions")
parser.add_argument('--corr_implementation', choices=["reg", "alt", "reg_cuda", "alt_cuda"], default="reg", help="correlation volume implementation")
parser.add_argument('--shared_backbone', action='store_true', help="use a single backbone for the context and feature encoders")
parser.add_argument('--corr_levels', type=int, default=2, help="number of levels in the correlation pyramid")
parser.add_argument('--corr_radius', type=int, default=4, help="width of the correlation pyramid")
parser.add_argument('--n_downsample', type=int, default=2, help="resolution of the disparity field (1/2^K)")
parser.add_argument('--slow_fast_gru', action='store_true', help="iterate the low-res GRUs more frequently")
parser.add_argument('--n_gru_layers', type=int, default=3, help="number of hidden GRU levels")
args = parser.parse_args()
model = torch.nn.DataParallel(IGEVStereo(args), device_ids=[0])
model.load_state_dict(torch.load(args.restore_ckpt))
model = model.module
model.to(DEVICE)
model.eval()
left_images = sorted(glob.glob(args.left_imgs, recursive=True))
right_images = sorted(glob.glob(args.right_imgs, recursive=True))
print(f"Found {len(left_images)} images.")
def load_image(imfile):
img = np.array(Image.open(imfile)).astype(np.uint8)
img = torch.from_numpy(img).permute(2, 0, 1).float()
return img[None].to(DEVICE)
if __name__ == '__main__':
fps_list = np.array([])
videoWrite = cv2.VideoWriter('./IGEV_Stereo.mp4', cv2.VideoWriter_fourcc(*'mp4v'), 10, (1242, 750))
for (imfile1, imfile2) in tqdm(list(zip(left_images, right_images))):
image1 = load_image(imfile1)
image2 = load_image(imfile2)
padder = InputPadder(image1.shape, divis_by=32)
image1_pad, image2_pad = padder.pad(image1, image2)
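# Time the forward pass on the GPU with CUDA events; the surrounding synchronize calls
# keep previously queued work from leaking into the measurement.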
torch.cuda.synchronize()
start = torch.cuda.Event(enable_timing=True)
end = torch.cuda.Event(enable_timing=True)
start.record()
with torch.no_grad():
with torch.cuda.amp.autocast(enabled=half_precision):
disp = model(image1_pad, image2_pad, iters=16, test_mode=True)
disp = padder.unpad(disp)
end.record()
torch.cuda.synchronize()
runtime = start.elapsed_time(end)
fps = 1000/runtime
fps_list = np.append(fps_list, fps)
if len(fps_list) > 5:
fps_list = fps_list[-5:]
avg_fps = np.mean(fps_list)
print('Stereo runtime: {:.3f}'.format(1000/avg_fps))
disp_np = (2*disp).data.cpu().numpy().squeeze().astype(np.uint8)
disp_np = cv2.applyColorMap(disp_np, cv2.COLORMAP_PLASMA)
image_np = np.array(Image.open(imfile1)).astype(np.uint8)
out_img = np.concatenate((image_np, disp_np), 0)
cv2.putText(
out_img,
"%.1f fps" % (avg_fps),
(10, image_np.shape[0]+30),
cv2.FONT_HERSHEY_SIMPLEX, 1, (255, 255, 255), 2, cv2.LINE_AA)
cv2.imshow('img', out_img)
cv2.waitKey(1)
videoWrite.write(out_img)
videoWrite.release()

View File

@ -0,0 +1,276 @@
from __future__ import print_function, division
import sys
sys.path.append('core')
import os
# os.environ['CUDA_VISIBLE_DEVICES'] = '0'
import argparse
import time
import logging
import numpy as np
import torch
from tqdm import tqdm
from igev_stereo import IGEVStereo, autocast
import stereo_datasets as datasets
from utils.utils import InputPadder
from PIL import Image
def count_parameters(model):
return sum(p.numel() for p in model.parameters() if p.requires_grad)
@torch.no_grad()
def validate_eth3d(model, iters=32, mixed_prec=False):
""" Peform validation using the ETH3D (train) split """
model.eval()
aug_params = {}
val_dataset = datasets.ETH3D(aug_params)
out_list, epe_list = [], []
for val_id in range(len(val_dataset)):
(imageL_file, imageR_file, GT_file), image1, image2, flow_gt, valid_gt = val_dataset[val_id]
image1 = image1[None].cuda()
image2 = image2[None].cuda()
padder = InputPadder(image1.shape, divis_by=32)
image1, image2 = padder.pad(image1, image2)
with autocast(enabled=mixed_prec):
if iters == 0:
flow_pr = model(image1, image2, iters=iters, test_mode=True)
else:
_, flow_pr = model(image1, image2, iters=iters, test_mode=True)
flow_pr = padder.unpad(flow_pr.float()).cpu().squeeze(0)
assert flow_pr.shape == flow_gt.shape, (flow_pr.shape, flow_gt.shape)
epe = torch.sum((flow_pr - flow_gt)**2, dim=0).sqrt()
epe_flattened = epe.flatten()
occ_mask = Image.open(GT_file.replace('disp0GT.pfm', 'mask0nocc.png'))
occ_mask = np.ascontiguousarray(occ_mask).flatten()
val = (valid_gt.flatten() >= 0.5) & (occ_mask == 255)
# val = (valid_gt.flatten() >= 0.5)
out = (epe_flattened > 1.0)
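# A pixel counts as an outlier when its end-point error exceeds 1 px for ETH3D
# (KITTI uses 3 px and Middlebury 2 px below); D1 is the mean outlier rate over
# valid, non-occluded pixels.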
image_out = out[val].float().mean().item()
image_epe = epe_flattened[val].mean().item()
logging.info(f"ETH3D {val_id+1} out of {len(val_dataset)}. EPE {round(image_epe,4)} D1 {round(image_out,4)}")
epe_list.append(image_epe)
out_list.append(image_out)
epe_list = np.array(epe_list)
out_list = np.array(out_list)
epe = np.mean(epe_list)
d1 = 100 * np.mean(out_list)
print("Validation ETH3D: EPE %f, D1 %f" % (epe, d1))
return {'eth3d-epe': epe, 'eth3d-d1': d1}
@torch.no_grad()
def validate_kitti(model, iters=32, mixed_prec=False):
""" Peform validation using the KITTI-2015 (train) split """
model.eval()
aug_params = {}
val_dataset = datasets.KITTI(aug_params, image_set='training')
torch.backends.cudnn.benchmark = True
out_list, epe_list, elapsed_list = [], [], []
for val_id in range(len(val_dataset)):
_, image1, image2, flow_gt, valid_gt = val_dataset[val_id]
image1 = image1[None].cuda()
image2 = image2[None].cuda()
padder = InputPadder(image1.shape, divis_by=32)
image1, image2 = padder.pad(image1, image2)
with autocast(enabled=mixed_prec):
start = time.time()
if iters == 0:
flow_pr = model(image1, image2, iters=iters, test_mode=True)
else:
_, flow_pr = model(image1, image2, iters=iters, test_mode=True)
end = time.time()
if val_id > 50:
elapsed_list.append(end-start)
flow_pr = padder.unpad(flow_pr).cpu().squeeze(0)
assert flow_pr.shape == flow_gt.shape, (flow_pr.shape, flow_gt.shape)
epe = torch.sum((flow_pr - flow_gt)**2, dim=0).sqrt()
epe_flattened = epe.flatten()
val = (valid_gt.flatten() >= 0.5) & (flow_gt.abs().flatten() < 192)
# val = valid_gt.flatten() >= 0.5
out = (epe_flattened > 3.0)
image_out = out[val].float().mean().item()
image_epe = epe_flattened[val].mean().item()
if val_id < 9 or (val_id+1)%10 == 0:
logging.info(f"KITTI Iter {val_id+1} out of {len(val_dataset)}. EPE {round(image_epe,4)} D1 {round(image_out,4)}. Runtime: {format(end-start, '.3f')}s ({format(1/(end-start), '.2f')}-FPS)")
epe_list.append(epe_flattened[val].mean().item())
out_list.append(out[val].cpu().numpy())
epe_list = np.array(epe_list)
out_list = np.concatenate(out_list)
epe = np.mean(epe_list)
d1 = 100 * np.mean(out_list)
avg_runtime = np.mean(elapsed_list)
print(f"Validation KITTI: EPE {epe}, D1 {d1}, {format(1/avg_runtime, '.2f')}-FPS ({format(avg_runtime, '.3f')}s)")
return {'kitti-epe': epe, 'kitti-d1': d1}
@torch.no_grad()
def validate_sceneflow(model, iters=32, mixed_prec=False):
""" Peform validation using the Scene Flow (TEST) split """
model.eval()
val_dataset = datasets.SceneFlowDatasets(dstype='frames_finalpass', things_test=True)
out_list, epe_list = [], []
for val_id in tqdm(range(len(val_dataset))):
_, image1, image2, flow_gt, valid_gt = val_dataset[val_id]
image1 = image1[None].cuda()
image2 = image2[None].cuda()
padder = InputPadder(image1.shape, divis_by=32)
image1, image2 = padder.pad(image1, image2)
with autocast(enabled=mixed_prec):
if iters == 0:
flow_pr = model(image1, image2, iters=iters, test_mode=True)
else:
flow_pr = model(image1, image2, iters=iters, test_mode=True)
flow_pr = padder.unpad(flow_pr).cpu().squeeze(0)
assert flow_pr.shape == flow_gt.shape, (flow_pr.shape, flow_gt.shape)
# epe = torch.sum((flow_pr - flow_gt)**2, dim=0).sqrt()
epe = torch.abs(flow_pr - flow_gt)
epe = epe.flatten()
val = (valid_gt.flatten() >= 0.5) & (flow_gt.abs().flatten() < 192)
if np.isnan(epe[val].mean().item()):
continue
out = (epe > 3.0)
epe_list.append(epe[val].mean().item())
out_list.append(out[val].cpu().numpy())
# if val_id == 400:
# break
epe_list = np.array(epe_list)
out_list = np.concatenate(out_list)
epe = np.mean(epe_list)
d1 = 100 * np.mean(out_list)
with open('test.txt', 'a') as f:
f.write("Validation Scene Flow: %f, %f\n" % (epe, d1))
print("Validation Scene Flow: %f, %f" % (epe, d1))
return {'scene-flow-epe': epe, 'scene-flow-d1': d1}
@torch.no_grad()
def validate_middlebury(model, iters=32, split='F', mixed_prec=False):
""" Peform validation using the Middlebury-V3 dataset """
model.eval()
aug_params = {}
val_dataset = datasets.Middlebury(aug_params, split=split)
out_list, epe_list = [], []
for val_id in range(len(val_dataset)):
(imageL_file, _, _), image1, image2, flow_gt, valid_gt = val_dataset[val_id]
image1 = image1[None].cuda()
image2 = image2[None].cuda()
padder = InputPadder(image1.shape, divis_by=32)
image1, image2 = padder.pad(image1, image2)
with autocast(enabled=mixed_prec):
if iters == 0:
flow_pr = model(image1, image2, iters=iters, test_mode=True)
else:
_, flow_pr = model(image1, image2, iters=iters, test_mode=True)
flow_pr = padder.unpad(flow_pr).cpu().squeeze(0)
assert flow_pr.shape == flow_gt.shape, (flow_pr.shape, flow_gt.shape)
epe = torch.sum((flow_pr - flow_gt)**2, dim=0).sqrt()
epe_flattened = epe.flatten()
occ_mask = Image.open(imageL_file.replace('im0.png', 'mask0nocc.png')).convert('L')
occ_mask = np.ascontiguousarray(occ_mask, dtype=np.float32).flatten()
val = (valid_gt.reshape(-1) >= 0.5) & (flow_gt[0].reshape(-1) < 192) & (occ_mask==255)
out = (epe_flattened > 2.0)
image_out = out[val].float().mean().item()
image_epe = epe_flattened[val].mean().item()
logging.info(f"Middlebury Iter {val_id+1} out of {len(val_dataset)}. EPE {round(image_epe,4)} D1 {round(image_out,4)}")
epe_list.append(image_epe)
out_list.append(image_out)
epe_list = np.array(epe_list)
out_list = np.array(out_list)
epe = np.mean(epe_list)
d1 = 100 * np.mean(out_list)
print(f"Validation Middlebury{split}: EPE {epe}, D1 {d1}")
return {f'middlebury{split}-epe': epe, f'middlebury{split}-d1': d1}
if __name__ == '__main__':
parser = argparse.ArgumentParser()
parser.add_argument('--restore_ckpt', help="restore checkpoint", default='./pretrained_models/sceneflow/sceneflow.pth')
parser.add_argument('--dataset', help="dataset for evaluation", default='sceneflow', choices=["eth3d", "kitti", "sceneflow"] + [f"middlebury_{s}" for s in 'FHQ'])
parser.add_argument('--mixed_precision', default=False, action='store_true', help='use mixed precision')
parser.add_argument('--valid_iters', type=int, default=32, help='number of flow-field updates during forward pass')
# Architecture choices
parser.add_argument('--hidden_dims', nargs='+', type=int, default=[128]*3, help="hidden state and context dimensions")
parser.add_argument('--corr_implementation', choices=["reg", "alt", "reg_cuda", "alt_cuda"], default="reg", help="correlation volume implementation")
parser.add_argument('--shared_backbone', action='store_true', help="use a single backbone for the context and feature encoders")
parser.add_argument('--corr_levels', type=int, default=2, help="number of levels in the correlation pyramid")
parser.add_argument('--corr_radius', type=int, default=4, help="width of the correlation pyramid")
parser.add_argument('--n_downsample', type=int, default=2, help="resolution of the disparity field (1/2^K)")
parser.add_argument('--slow_fast_gru', action='store_true', help="iterate the low-res GRUs more frequently")
parser.add_argument('--n_gru_layers', type=int, default=3, help="number of hidden GRU levels")
parser.add_argument('--max_disp', type=int, default=192, help="max disp of geometry encoding volume")
args = parser.parse_args()
model = torch.nn.DataParallel(IGEVStereo(args), device_ids=[0])
logging.basicConfig(level=logging.INFO,
format='%(asctime)s %(levelname)-8s [%(filename)s:%(lineno)d] %(message)s')
if args.restore_ckpt is not None:
assert args.restore_ckpt.endswith(".pth")
logging.info("Loading checkpoint...")
checkpoint = torch.load(args.restore_ckpt)
model.load_state_dict(checkpoint, strict=True)
logging.info(f"Done loading checkpoint")
model.cuda()
model.eval()
print(f"The model has {format(count_parameters(model)/1e6, '.2f')}M learnable parameters.")
use_mixed_precision = args.corr_implementation.endswith("_cuda")
if args.dataset == 'eth3d':
validate_eth3d(model, iters=args.valid_iters, mixed_prec=use_mixed_precision)
elif args.dataset == 'kitti':
validate_kitti(model, iters=args.valid_iters, mixed_prec=use_mixed_precision)
elif args.dataset in [f"middlebury_{s}" for s in 'FHQ']:
validate_middlebury(model, iters=args.valid_iters, split=args.dataset[-1], mixed_prec=use_mixed_precision)
elif args.dataset == 'sceneflow':
validate_sceneflow(model, iters=args.valid_iters, mixed_prec=use_mixed_precision)

IGEV-Stereo/save_disp.py Normal file
View File

@ -0,0 +1,82 @@
import sys
sys.path.append('core')
import argparse
import glob
import numpy as np
import torch
from tqdm import tqdm
from pathlib import Path
from igev_stereo import IGEVStereo
from utils.utils import InputPadder
from PIL import Image
from matplotlib import pyplot as plt
import os
import skimage.io
import cv2
DEVICE = 'cuda'
os.environ['CUDA_VISIBLE_DEVICES'] = '0'
def load_image(imfile):
img = np.array(Image.open(imfile)).astype(np.uint8)
img = torch.from_numpy(img).permute(2, 0, 1).float()
return img[None].to(DEVICE)
def demo(args):
model = torch.nn.DataParallel(IGEVStereo(args), device_ids=[0])
model.load_state_dict(torch.load(args.restore_ckpt))
model = model.module
model.to(DEVICE)
model.eval()
output_directory = Path(args.output_directory)
output_directory.mkdir(exist_ok=True)
with torch.no_grad():
left_images = sorted(glob.glob(args.left_imgs, recursive=True))
right_images = sorted(glob.glob(args.right_imgs, recursive=True))
print(f"Found {len(left_images)} images. Saving files to {output_directory}/")
for (imfile1, imfile2) in tqdm(list(zip(left_images, right_images))):
image1 = load_image(imfile1)
image2 = load_image(imfile2)
padder = InputPadder(image1.shape, divis_by=32)
image1, image2 = padder.pad(image1, image2)
disp = model(image1, image2, iters=args.valid_iters, test_mode=True)
disp = padder.unpad(disp)
file_stem = os.path.join(output_directory, imfile1.split('/')[-1])
disp = disp.cpu().numpy().squeeze()
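# KITTI submission format: disparity values are multiplied by 256, rounded, and saved
# as 16-bit PNGs (readDispKITTI divides by 256 when loading them back).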
disp = np.round(disp * 256).astype(np.uint16)
skimage.io.imsave(file_stem, disp)
if __name__ == '__main__':
parser = argparse.ArgumentParser()
parser.add_argument('--restore_ckpt', help="restore checkpoint", default='./pretrained_models/kitti/kitti15.pth')
parser.add_argument('--save_numpy', action='store_true', help='save output as numpy arrays')
parser.add_argument('-l', '--left_imgs', help="path to all first (left) frames", default="/data/KITTI/KITTI_2015/testing/image_2/*_10.png")
parser.add_argument('-r', '--right_imgs', help="path to all second (right) frames", default="/data/KITTI/KITTI_2015/testing/image_3/*_10.png")
# parser.add_argument('-l', '--left_imgs', help="path to all first (left) frames", default="/data/KITTI/KITTI_2012/testing/colored_0/*_10.png")
# parser.add_argument('-r', '--right_imgs', help="path to all second (right) frames", default="/data/KITTI/KITTI_2012/testing/colored_1/*_10.png")
parser.add_argument('--output_directory', help="directory to save output", default="output")
parser.add_argument('--mixed_precision', action='store_true', help='use mixed precision')
parser.add_argument('--valid_iters', type=int, default=16, help='number of flow-field updates during forward pass')
# Architecture choices
parser.add_argument('--hidden_dims', nargs='+', type=int, default=[128]*3, help="hidden state and context dimensions")
parser.add_argument('--corr_implementation', choices=["reg", "alt", "reg_cuda", "alt_cuda"], default="reg", help="correlation volume implementation")
parser.add_argument('--shared_backbone', action='store_true', help="use a single backbone for the context and feature encoders")
parser.add_argument('--corr_levels', type=int, default=2, help="number of levels in the correlation pyramid")
parser.add_argument('--corr_radius', type=int, default=4, help="width of the correlation pyramid")
parser.add_argument('--n_downsample', type=int, default=2, help="resolution of the disparity field (1/2^K)")
parser.add_argument('--slow_fast_gru', action='store_true', help="iterate the low-res GRUs more frequently")
parser.add_argument('--n_gru_layers', type=int, default=3, help="number of hidden GRU levels")
parser.add_argument('--max_disp', type=int, default=192, help="max disp of geometry encoding volume")
args = parser.parse_args()
demo(args)

IGEV-Stereo/train_stereo.py Normal file
View File

@ -0,0 +1,256 @@
from __future__ import print_function, division
import os
os.environ['CUDA_VISIBLE_DEVICES'] = '0, 1'
import argparse
import logging
import numpy as np
from pathlib import Path
from tqdm import tqdm
from torch.utils.tensorboard import SummaryWriter
import torch
import torch.nn as nn
import torch.optim as optim
from core.igev_stereo import IGEVStereo
from evaluate_stereo import *
import core.stereo_datasets as datasets
import torch.nn.functional as F
ckpt_path = './checkpoints/igev_stereo'
log_path = './checkpoints/igev_stereo'
try:
from torch.cuda.amp import GradScaler
except:
class GradScaler:
def __init__(self, enabled=False):  # accept the same kwargs as torch.cuda.amp.GradScaler
pass
def scale(self, loss):
return loss
def unscale_(self, optimizer):
pass
def step(self, optimizer):
optimizer.step()
def update(self):
pass
def sequence_loss(disp_preds, disp_init_pred, disp_gt, valid, loss_gamma=0.9, max_disp=192):
""" Loss function defined over sequence of flow predictions """
n_predictions = len(disp_preds)
assert n_predictions >= 1
disp_loss = 0.0
mag = torch.sum(disp_gt**2, dim=1).sqrt()
valid = ((valid >= 0.5) & (mag < max_disp)).unsqueeze(1)
assert valid.shape == disp_gt.shape, [valid.shape, disp_gt.shape]
assert not torch.isinf(disp_gt[valid.bool()]).any()
disp_loss += 1.0 * F.smooth_l1_loss(disp_init_pred[valid.bool()], disp_gt[valid.bool()], reduction='mean')
for i in range(n_predictions):
adjusted_loss_gamma = loss_gamma**(15/(n_predictions - 1))
i_weight = adjusted_loss_gamma**(n_predictions - i - 1)
i_loss = (disp_preds[i] - disp_gt).abs()
assert i_loss.shape == valid.shape, [i_loss.shape, valid.shape, disp_gt.shape, disp_preds[i].shape]
disp_loss += i_weight * i_loss[valid.bool()].mean()
epe = torch.sum((disp_preds[-1] - disp_gt)**2, dim=1).sqrt()
epe = epe.view(-1)[valid.view(-1)]
metrics = {
'epe': epe.mean().item(),
'1px': (epe < 1).float().mean().item(),
'3px': (epe < 3).float().mean().item(),
'5px': (epe < 5).float().mean().item(),
}
return disp_loss, metrics
def fetch_optimizer(args, model):
""" Create the optimizer and learning rate scheduler """
optimizer = optim.AdamW(model.parameters(), lr=args.lr, weight_decay=args.wdecay, eps=1e-8)
scheduler = optim.lr_scheduler.OneCycleLR(optimizer, args.lr, args.num_steps+100,
pct_start=0.01, cycle_momentum=False, anneal_strategy='linear')
return optimizer, scheduler
class Logger:
SUM_FREQ = 100
def __init__(self, model, scheduler):
self.model = model
self.scheduler = scheduler
self.total_steps = 0
self.running_loss = {}
self.writer = SummaryWriter(log_dir=log_path)
def _print_training_status(self):
metrics_data = [self.running_loss[k]/Logger.SUM_FREQ for k in sorted(self.running_loss.keys())]
training_str = "[{:6d}, {:10.7f}] ".format(self.total_steps+1, self.scheduler.get_last_lr()[0])
metrics_str = ("{:10.4f}, "*len(metrics_data)).format(*metrics_data)
# print the training status
logging.info(f"Training Metrics ({self.total_steps}): {training_str + metrics_str}")
if self.writer is None:
self.writer = SummaryWriter(log_dir=log_path)
for k in self.running_loss:
self.writer.add_scalar(k, self.running_loss[k]/Logger.SUM_FREQ, self.total_steps)
self.running_loss[k] = 0.0
def push(self, metrics):
self.total_steps += 1
for key in metrics:
if key not in self.running_loss:
self.running_loss[key] = 0.0
self.running_loss[key] += metrics[key]
if self.total_steps % Logger.SUM_FREQ == Logger.SUM_FREQ-1:
self._print_training_status()
self.running_loss = {}
def write_dict(self, results):
if self.writer is None:
self.writer = SummaryWriter(log_dir=log_path)
for key in results:
self.writer.add_scalar(key, results[key], self.total_steps)
def close(self):
self.writer.close()
def train(args):
model = nn.DataParallel(IGEVStereo(args))
print("Parameter Count: %d" % count_parameters(model))
train_loader = datasets.fetch_dataloader(args)
optimizer, scheduler = fetch_optimizer(args, model)
total_steps = 0
logger = Logger(model, scheduler)
if args.restore_ckpt is not None:
assert args.restore_ckpt.endswith(".pth")
logging.info("Loading checkpoint...")
checkpoint = torch.load(args.restore_ckpt)
model.load_state_dict(checkpoint, strict=True)
logging.info(f"Done loading checkpoint")
model.cuda()
model.train()
model.module.freeze_bn() # We keep BatchNorm frozen
validation_frequency = 10000
scaler = GradScaler(enabled=args.mixed_precision)
should_keep_training = True
global_batch_num = 0
while should_keep_training:
for i_batch, (_, *data_blob) in enumerate(tqdm(train_loader)):
optimizer.zero_grad()
image1, image2, disp_gt, valid = [x.cuda() for x in data_blob]
assert model.training
disp_init_pred, disp_preds = model(image1, image2, iters=args.train_iters)
assert model.training
loss, metrics = sequence_loss(disp_preds, disp_init_pred, disp_gt, valid, max_disp=args.max_disp)
logger.writer.add_scalar("live_loss", loss.item(), global_batch_num)
logger.writer.add_scalar(f'learning_rate', optimizer.param_groups[0]['lr'], global_batch_num)
global_batch_num += 1
scaler.scale(loss).backward()
scaler.unscale_(optimizer)
torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
scaler.step(optimizer)
scheduler.step()
scaler.update()
logger.push(metrics)
if total_steps % validation_frequency == validation_frequency - 1:
save_path = Path(ckpt_path + '/%d_%s.pth' % (total_steps + 1, args.name))
logging.info(f"Saving file {save_path.absolute()}")
torch.save(model.state_dict(), save_path)
results = validate_sceneflow(model.module, iters=args.valid_iters)
logger.write_dict(results)
model.train()
model.module.freeze_bn()
total_steps += 1
if total_steps > args.num_steps:
should_keep_training = False
break
if len(train_loader) >= 10000:
save_path = Path(ckpt_path + '/%d_epoch_%s.pth.gz' % (total_steps + 1, args.name))
logging.info(f"Saving file {save_path}")
torch.save(model.state_dict(), save_path)
print("FINISHED TRAINING")
logger.close()
PATH = ckpt_path + '/%s.pth' % args.name
torch.save(model.state_dict(), PATH)
return PATH
if __name__ == '__main__':
parser = argparse.ArgumentParser()
parser.add_argument('--name', default='igev-stereo', help="name your experiment")
parser.add_argument('--restore_ckpt', default=None, help="restore checkpoint")
parser.add_argument('--mixed_precision', default=True, action='store_true', help='use mixed precision')
# Training parameters
parser.add_argument('--batch_size', type=int, default=8, help="batch size used during training.")
parser.add_argument('--train_datasets', nargs='+', default=['sceneflow'], help="training datasets.")
parser.add_argument('--lr', type=float, default=0.0002, help="max learning rate.")
parser.add_argument('--num_steps', type=int, default=200000, help="length of training schedule.")
parser.add_argument('--image_size', type=int, nargs='+', default=[320, 736], help="size of the random image crops used during training.")
parser.add_argument('--train_iters', type=int, default=22, help="number of updates to the disparity field in each forward pass.")
parser.add_argument('--wdecay', type=float, default=.00001, help="Weight decay in optimizer.")
# Validation parameters
parser.add_argument('--valid_iters', type=int, default=32, help='number of flow-field updates during validation forward pass')
# Architecture choices
parser.add_argument('--corr_implementation', choices=["reg", "alt", "reg_cuda", "alt_cuda"], default="reg", help="correlation volume implementation")
parser.add_argument('--shared_backbone', action='store_true', help="use a single backbone for the context and feature encoders")
parser.add_argument('--corr_levels', type=int, default=2, help="number of levels in the correlation pyramid")
parser.add_argument('--corr_radius', type=int, default=4, help="width of the correlation pyramid")
parser.add_argument('--n_downsample', type=int, default=2, help="resolution of the disparity field (1/2^K)")
parser.add_argument('--slow_fast_gru', action='store_true', help="iterate the low-res GRUs more frequently")
parser.add_argument('--n_gru_layers', type=int, default=3, help="number of hidden GRU levels")
parser.add_argument('--hidden_dims', nargs='+', type=int, default=[128]*3, help="hidden state and context dimensions")
parser.add_argument('--max_disp', type=int, default=192, help="max disp of geometry encoding volume")
# Data augmentation
parser.add_argument('--img_gamma', type=float, nargs='+', default=None, help="gamma range")
parser.add_argument('--saturation_range', type=float, nargs='+', default=[0, 1.4], help='color saturation')
parser.add_argument('--do_flip', default=False, choices=['h', 'v'], help='flip the images horizontally or vertically')
parser.add_argument('--spatial_scale', type=float, nargs='+', default=[-0.2, 0.4], help='re-scale the images randomly')
parser.add_argument('--noyjitter', action='store_true', help='don\'t simulate imperfect rectification')
args = parser.parse_args()
torch.manual_seed(666)
np.random.seed(666)
logging.basicConfig(level=logging.INFO,
format='%(asctime)s %(levelname)-8s [%(filename)s:%(lineno)d] %(message)s')
Path(ckpt_path).mkdir(exist_ok=True, parents=True)
train(args)