Add files via upload

Gangwei Xu 2023-03-20 19:52:04 +08:00 committed by GitHub
parent 84baaf7d8e
commit e079f027ff
29 changed files with 4062 additions and 0 deletions

IGEV-MVS/core/corr.py (new file)
@@ -0,0 +1,61 @@
import torch
import torch.nn as nn
import torch.nn.functional as F
from .submodule import *
class CorrBlock1D_Cost_Volume:
def __init__(self, init_corr, corr, num_levels=2, radius=4, inverse_depth_min=None, inverse_depth_max=None, num_sample=None):
self.num_levels = num_levels
self.radius = radius
self.inverse_depth_min = inverse_depth_min
self.inverse_depth_max = inverse_depth_max
self.num_sample = num_sample
self.corr_pyramid = []
self.init_corr_pyramid = []
# cost volumes arrive as [B, 1, D, H, W]; flatten pixels so D becomes the sampling axis
b, c, d, h, w = corr.shape
corr = corr.permute(0, 3, 4, 1, 2).reshape(b*h*w, 1, 1, d)
init_corr = init_corr.permute(0, 3, 4, 1, 2).reshape(b*h*w, 1, 1, d)
self.corr_pyramid.append(corr)
self.init_corr_pyramid.append(init_corr)
for i in range(self.num_levels):
corr = F.avg_pool2d(corr, [1,2], stride=[1,2])
self.corr_pyramid.append(corr)
for i in range(self.num_levels):
init_corr = F.avg_pool2d(init_corr, [1,2], stride=[1,2])
self.init_corr_pyramid.append(init_corr)
def __call__(self, disp):
r = self.radius
b, _, h, w = disp.shape
out_pyramid = []
for i in range(self.num_levels):
corr = self.corr_pyramid[i]
init_corr = self.init_corr_pyramid[i]
dx = torch.linspace(-r, r, 2*r+1)
dx = dx.view(1, 1, 2*r+1, 1).to(disp.device)
x0 = dx + disp.reshape(b*h*w, 1, 1, 1) / 2**i
y0 = torch.zeros_like(x0)
disp_lvl = torch.cat([x0,y0], dim=-1)
corr = bilinear_sampler(corr, disp_lvl)
corr = corr.view(b, h, w, -1)
init_corr = bilinear_sampler(init_corr, disp_lvl)
init_corr = init_corr.view(b, h, w, -1)
out_pyramid.append(corr)
out_pyramid.append(init_corr)
out = torch.cat(out_pyramid, dim=-1)
return out.permute(0, 3, 1, 2).contiguous().float()
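A minimal usage sketch with hypothetical shapes (bilinear_sampler is pulled in from submodule.py by the wildcard import above): the block is built once from the raw and regularized cost volumes, then queried with the current disparity index map at every GRU iteration.

import torch

b, d, h, w = 1, 64, 32, 40                # batch, depth samples, quarter-res H, W
init_corr = torch.randn(b, 1, d, h, w)    # raw correlation volume [B, 1, D, H, W]
corr = torch.randn(b, 1, d, h, w)         # regularized cost volume, same layout
corr_fn = CorrBlock1D_Cost_Volume(init_corr, corr, num_levels=2, radius=4, num_sample=d)
disp = torch.full((b, 1, h, w), d / 2.0)  # current index into the D depth samples
feat = corr_fn(disp)                      # [B, 2*num_levels*(2*radius+1), H, W] = [1, 36, 32, 40]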

IGEV-MVS/core/extractor.py (new file)
@@ -0,0 +1,212 @@
import torch
import torch.nn as nn
import torch.nn.functional as F
import timm
import math
from .submodule import *
class ResidualBlock(nn.Module):
def __init__(self, in_planes, planes, norm_fn='group', stride=1):
super(ResidualBlock, self).__init__()
self.conv1 = nn.Conv2d(in_planes, planes, kernel_size=3, padding=1, stride=stride)
self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, padding=1)
self.relu = nn.ReLU(inplace=True)
num_groups = planes // 8
if norm_fn == 'group':
self.norm1 = nn.GroupNorm(num_groups=num_groups, num_channels=planes)
self.norm2 = nn.GroupNorm(num_groups=num_groups, num_channels=planes)
if not (stride == 1 and in_planes == planes):
self.norm3 = nn.GroupNorm(num_groups=num_groups, num_channels=planes)
elif norm_fn == 'batch':
self.norm1 = nn.BatchNorm2d(planes)
self.norm2 = nn.BatchNorm2d(planes)
if not (stride == 1 and in_planes == planes):
self.norm3 = nn.BatchNorm2d(planes)
elif norm_fn == 'instance':
self.norm1 = nn.InstanceNorm2d(planes)
self.norm2 = nn.InstanceNorm2d(planes)
if not (stride == 1 and in_planes == planes):
self.norm3 = nn.InstanceNorm2d(planes)
elif norm_fn == 'none':
self.norm1 = nn.Sequential()
self.norm2 = nn.Sequential()
if not (stride == 1 and in_planes == planes):
self.norm3 = nn.Sequential()
if stride == 1 and in_planes == planes:
self.downsample = None
else:
self.downsample = nn.Sequential(
nn.Conv2d(in_planes, planes, kernel_size=1, stride=stride), self.norm3)
def forward(self, x):
y = x
y = self.conv1(y)
y = self.norm1(y)
y = self.relu(y)
y = self.conv2(y)
y = self.norm2(y)
y = self.relu(y)
if self.downsample is not None:
x = self.downsample(x)
return self.relu(x+y)
class MultiBasicEncoder(nn.Module):
def __init__(self, output_dim=[128], norm_fn='batch', dropout=0.0, downsample=3):
super(MultiBasicEncoder, self).__init__()
self.norm_fn = norm_fn
self.downsample = downsample
if self.norm_fn == 'group':
self.norm1 = nn.GroupNorm(num_groups=8, num_channels=64)
elif self.norm_fn == 'batch':
self.norm1 = nn.BatchNorm2d(64)
elif self.norm_fn == 'instance':
self.norm1 = nn.InstanceNorm2d(64)
elif self.norm_fn == 'none':
self.norm1 = nn.Sequential()
self.conv1 = nn.Conv2d(3, 64, kernel_size=7, stride=1 + (downsample > 2), padding=3)
self.relu1 = nn.ReLU(inplace=True)
self.in_planes = 64
self.layer1 = self._make_layer(64, stride=1)
self.layer2 = self._make_layer(96, stride=1 + (downsample > 1))
self.layer3 = self._make_layer(128, stride=1 + (downsample > 0))
self.layer4 = self._make_layer(128, stride=2)
self.layer5 = self._make_layer(128, stride=2)
output_list = []
for dim in output_dim:
conv_out = nn.Sequential(
ResidualBlock(128, 128, self.norm_fn, stride=1),
nn.Conv2d(128, dim[2], 3, padding=1))
output_list.append(conv_out)
self.outputs04 = nn.ModuleList(output_list)
output_list = []
for dim in output_dim:
conv_out = nn.Sequential(
ResidualBlock(128, 128, self.norm_fn, stride=1),
nn.Conv2d(128, dim[1], 3, padding=1))
output_list.append(conv_out)
self.outputs08 = nn.ModuleList(output_list)
output_list = []
for dim in output_dim:
conv_out = nn.Conv2d(128, dim[0], 3, padding=1)
output_list.append(conv_out)
self.outputs16 = nn.ModuleList(output_list)
if dropout > 0:
self.dropout = nn.Dropout2d(p=dropout)
else:
self.dropout = None
for m in self.modules():
if isinstance(m, nn.Conv2d):
nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
elif isinstance(m, (nn.BatchNorm2d, nn.InstanceNorm2d, nn.GroupNorm)):
if m.weight is not None:
nn.init.constant_(m.weight, 1)
if m.bias is not None:
nn.init.constant_(m.bias, 0)
def _make_layer(self, dim, stride=1):
layer1 = ResidualBlock(self.in_planes, dim, self.norm_fn, stride=stride)
layer2 = ResidualBlock(dim, dim, self.norm_fn, stride=1)
layers = (layer1, layer2)
self.in_planes = dim
return nn.Sequential(*layers)
def forward(self, x, dual_inp=False, num_layers=3):
x = self.conv1(x)
x = self.norm1(x)
x = self.relu1(x)
x = self.layer1(x)
x = self.layer2(x)
x = self.layer3(x)
if dual_inp:
v = x
x = x[:(x.shape[0]//2)]
outputs04 = [f(x) for f in self.outputs04]
if num_layers == 1:
return (outputs04, v) if dual_inp else (outputs04,)
y = self.layer4(x)
outputs08 = [f(y) for f in self.outputs08]
if num_layers == 2:
return (outputs04, outputs08, v) if dual_inp else (outputs04, outputs08)
z = self.layer5(y)
outputs16 = [f(z) for f in self.outputs16]
return (outputs04, outputs08, outputs16, v) if dual_inp else (outputs04, outputs08, outputs16)
class Feature(SubModule):
def __init__(self):
super(Feature, self).__init__()
pretrained = True
model = timm.create_model('mobilenetv2_100', pretrained=pretrained, features_only=True)
layers = [1,2,3,5,6]
chans = [16, 24, 32, 96, 160]
self.conv_stem = model.conv_stem
self.bn1 = model.bn1
self.block0 = torch.nn.Sequential(*model.blocks[0:layers[0]])
self.block1 = torch.nn.Sequential(*model.blocks[layers[0]:layers[1]])
self.block2 = torch.nn.Sequential(*model.blocks[layers[1]:layers[2]])
self.block3 = torch.nn.Sequential(*model.blocks[layers[2]:layers[3]])
self.block4 = torch.nn.Sequential(*model.blocks[layers[3]:layers[4]])
self.deconv32_16 = Conv2x_IN(chans[4], chans[3], deconv=True, concat=True)
self.deconv16_8 = Conv2x_IN(chans[3]*2, chans[2], deconv=True, concat=True)
self.deconv8_4 = Conv2x_IN(chans[2]*2, chans[1], deconv=True, concat=True)
self.conv4 = BasicConv_IN(chans[1]*2, chans[1]*2, kernel_size=3, stride=1, padding=1)
def forward(self, x):
B, V, _, H, W = x.size()
x = x.view(B * V, -1, H, W)
#x = self.act1(self.bn1(self.conv_stem(x)))
x = self.bn1(self.conv_stem(x))
x2 = self.block0(x)
x4 = self.block1(x2)
# return x4,x4,x4,x4
x8 = self.block2(x4)
x16 = self.block3(x8)
x32 = self.block4(x16)
x16 = self.deconv32_16(x32, x16)
x8 = self.deconv16_8(x16, x8)
x4 = self.deconv8_4(x8, x4)
x4 = self.conv4(x4)
x4 = x4.view(B, V, -1, H // 4, W // 4)
x8 = x8.view(B, V, -1, H // 8, W // 8)
x16 = x16.view(B, V, -1, H // 16, W // 16)
x32 = x32.view(B, V, -1, H // 32, W // 32)
return [x4, x8, x16, x32]
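A shape sketch for the feature pyramid, assuming timm can fetch the mobilenetv2_100 weights and that H and W are divisible by 32:

import torch

feat = Feature()
imgs = torch.randn(1, 3, 3, 64, 96)   # [B, V, 3, H, W]: 1 batch, 3 views
x4, x8, x16, x32 = feat(imgs)
# x4: [1, 3, 48, 16, 24], x8: [1, 3, 64, 8, 12],
# x16: [1, 3, 192, 4, 6], x32: [1, 3, 160, 2, 3]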

IGEV-MVS/core/igev_mvs.py (new file)
@@ -0,0 +1,195 @@
import torch
import torch.nn as nn
import torch.nn.functional as F
from .submodule import *
from .corr import *
from .extractor import *
from .update import *
try:
autocast = torch.cuda.amp.autocast
except:
class autocast:
def __init__(self, enabled):
pass
def __enter__(self):
pass
def __exit__(self, *args):
pass
class IGEVMVS(nn.Module):
def __init__(self, args):
super().__init__()
context_dims = [128, 128, 128]
self.n_gru_layers = 3
self.slow_fast_gru = False
self.mixed_precision = True
self.num_sample = 64
self.G = 1
self.corr_radius = 4
self.corr_levels = 2
self.iters = args.iteration
self.update_block = BasicMultiUpdateBlock(hidden_dims=context_dims)
self.conv_hidden_1 = nn.Conv2d(64, 128, kernel_size=3, padding=1, stride=1)
self.conv_hidden_2 = nn.Conv2d(128, 128, kernel_size=3, padding=1, stride=2)
self.conv_hidden_4 = nn.Conv2d(128, 128, kernel_size=3, padding=1, stride=2)
self.feature = Feature()
self.stem_2 = nn.Sequential(
BasicConv_IN(3, 32, kernel_size=3, stride=2, padding=1),
nn.Conv2d(32, 32, 3, 1, 1, bias=False),
nn.InstanceNorm2d(32), nn.ReLU()
)
self.stem_4 = nn.Sequential(
BasicConv_IN(32, 48, kernel_size=3, stride=2, padding=1),
nn.Conv2d(48, 48, 3, 1, 1, bias=False),
nn.InstanceNorm2d(48), nn.ReLU()
)
self.conv = BasicConv_IN(96, 48, kernel_size=3, padding=1, stride=1)
self.desc = nn.Conv2d(48, 48, kernel_size=1, padding=0, stride=1)
self.spx = nn.Sequential(nn.ConvTranspose2d(2*32, 9, kernel_size=4, stride=2, padding=1),)
self.spx_2 = Conv2x_IN(32, 32, True)
self.spx_4 = nn.Sequential(
BasicConv_IN(96, 32, kernel_size=3, stride=1, padding=1),
nn.Conv2d(32, 32, 3, 1, 1, bias=False),
nn.InstanceNorm2d(32), nn.ReLU()
)
self.depth_initialization = DepthInitialization(self.num_sample)
self.pixel_view_weight = PixelViewWeight(self.G)
self.corr_stem = BasicConv(1, 8, is_3d=True, kernel_size=3, stride=1, padding=1)
self.corr_feature_att = FeatureAtt(8, 96)
self.cost_agg = hourglass(8)
self.spx_2_gru = Conv2x(32, 32, True)
self.spx_gru = nn.Sequential(nn.ConvTranspose2d(2*32, 9, kernel_size=4, stride=2, padding=1),)
def upsample_disp(self, depth, mask_feat_4, stem_2x):
with autocast(enabled=self.mixed_precision):
xspx = self.spx_2_gru(mask_feat_4, stem_2x)
spx_pred = self.spx_gru(xspx)
spx_pred = F.softmax(spx_pred, 1)
up_depth = context_upsample(depth, spx_pred).unsqueeze(1)
return up_depth
def forward(self, imgs, proj_matrices, depth_min, depth_max, test_mode=False):
proj_matrices_2 = torch.unbind(proj_matrices['level_2'].float(), 1)
depth_min = depth_min.float()
depth_max = depth_max.float()
ref_proj = proj_matrices_2[0]
src_projs = proj_matrices_2[1:]
with autocast(enabled=self.mixed_precision):
images = torch.unbind(imgs['level_0'], dim=1)
features = self.feature(imgs['level_0'])
ref_feature = []
for fea in features:
ref_feature.append(torch.unbind(fea, dim=1)[0])
src_features = [src_fea for src_fea in torch.unbind(features[0], dim=1)[1:]]
stem_2x = self.stem_2(images[0])
stem_4x = self.stem_4(stem_2x)
ref_feature[0] = torch.cat((ref_feature[0], stem_4x), 1)
for idx, src_fea in enumerate(src_features):
stem_2y = self.stem_2(images[idx + 1])
stem_4y = self.stem_4(stem_2y)
src_features[idx] = torch.cat((src_fea, stem_4y), 1)
match_left = self.desc(self.conv(ref_feature[0]))
match_left = match_left / torch.norm(match_left, 2, 1, True)
match_rights = [self.desc(self.conv(src_fea)) for src_fea in src_features]
match_rights = [match_right / torch.norm(match_right, 2, 1, True) for match_right in match_rights]
xspx = self.spx_4(ref_feature[0])
xspx = self.spx_2(xspx, stem_2x)
spx_pred = self.spx(xspx)
spx_pred = F.softmax(spx_pred, 1)
batch, dim, height, width = match_left.size()
inverse_depth_min = (1.0 / depth_min).view(batch, 1, 1, 1)
inverse_depth_max = (1.0 / depth_max).view(batch, 1, 1, 1)
device = match_left.device
correlation_sum = 0
view_weight_sum = 1e-5
match_left = match_left.float()
depth_samples = self.depth_initialization(inverse_depth_min, inverse_depth_max, height, width, device)
for src_feature, src_proj in zip(match_rights, src_projs):
src_feature = src_feature.float()
warped_feature = differentiable_warping(src_feature, src_proj, ref_proj, depth_samples)
warped_feature = warped_feature.view(batch, self.G, dim // self.G, self.num_sample, height, width)
correlation = torch.mean(warped_feature * match_left.view(batch, self.G, dim // self.G, 1, height, width), dim=2, keepdim=False)
view_weight = self.pixel_view_weight(correlation)
del warped_feature, src_feature, src_proj
correlation_sum += correlation * view_weight.unsqueeze(1)
view_weight_sum += view_weight.unsqueeze(1)
del correlation, view_weight
del match_left, match_rights, src_projs
with autocast(enabled=self.mixed_precision):
init_corr_volume = correlation_sum.div_(view_weight_sum)
corr_volume = self.corr_stem(init_corr_volume)
corr_volume = self.corr_feature_att(corr_volume, ref_feature[0])
regularized_cost_volume = self.cost_agg(corr_volume, ref_feature)
GEV_hidden = self.conv_hidden_1(regularized_cost_volume.squeeze(1))
GEV_hidden_2 = self.conv_hidden_2(GEV_hidden)
GEV_hidden_4 = self.conv_hidden_4(GEV_hidden_2)
net_list = [GEV_hidden, GEV_hidden_2, GEV_hidden_4]
net_list = [torch.tanh(x) for x in net_list]
corr_block = CorrBlock1D_Cost_Volume
init_corr_volume = init_corr_volume.float()
regularized_cost_volume = regularized_cost_volume.float()
probability = F.softmax(regularized_cost_volume.squeeze(1), dim=1)
index = torch.arange(0, self.num_sample, 1, device=probability.device).view(1, self.num_sample, 1, 1).float()
disp_init = torch.sum(index * probability, dim = 1, keepdim=True)
corr_fn = corr_block(init_corr_volume, regularized_cost_volume, radius=self.corr_radius, num_levels=self.corr_levels, inverse_depth_min=inverse_depth_min, inverse_depth_max=inverse_depth_max, num_sample=self.num_sample)
disp_predictions = []
disp = disp_init
for itr in range(self.iters):
disp = disp.detach()
corr = corr_fn(disp)
with autocast(enabled=self.mixed_precision):
if self.n_gru_layers == 3 and self.slow_fast_gru: # Update low-res GRU
net_list = self.update_block(net_list, iter16=True, iter08=False, iter04=False, update=False)
if self.n_gru_layers >= 2 and self.slow_fast_gru:# Update low-res GRU and mid-res GRU
net_list = self.update_block(net_list, iter16=self.n_gru_layers==3, iter08=True, iter04=False, update=False)
net_list, mask_feat_4, delta_disp = self.update_block(net_list, corr, disp, iter16=self.n_gru_layers==3, iter08=self.n_gru_layers>=2)
disp = disp + delta_disp
if test_mode and itr < self.iters-1:
continue
disp_up = self.upsample_disp(disp, mask_feat_4, stem_2x) / (self.num_sample-1)
disp_predictions.append(disp_up)
disp_init = context_upsample(disp_init, spx_pred.float()).unsqueeze(1) / (self.num_sample-1)
if test_mode:
return disp_up
return disp_init, disp_predictions
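A minimal end-to-end sketch with dummy inputs, assuming a CUDA device and identity projection matrices (real callers pass the 'level_2' matrices built by the dataset classes below); only args.iteration is read from args:

import argparse
import torch

args = argparse.Namespace(iteration=4)
model = IGEVMVS(args).cuda().eval()
B, V, H, W = 1, 3, 64, 96              # H, W must be multiples of 32
imgs = {'level_0': torch.randn(B, V, 3, H, W).cuda()}
proj = {'level_2': torch.eye(4).view(1, 1, 4, 4).repeat(B, V, 1, 1).cuda()}
depth_min = torch.tensor([425.0]).cuda()
depth_max = torch.tensor([935.0]).cuda()
with torch.no_grad():
    disp = model(imgs, proj, depth_min, depth_max, test_mode=True)
print(disp.shape)                      # [1, 1, 64, 96], indices normalized by num_sample-1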

IGEV-MVS/core/submodule.py (new file)
@@ -0,0 +1,396 @@
import torch
import torch.nn as nn
import torch.nn.functional as F
import math
class SubModule(nn.Module):
def __init__(self):
super(SubModule, self).__init__()
def weight_init(self):
for m in self.modules():
if isinstance(m, nn.Conv2d):
n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
m.weight.data.normal_(0, math.sqrt(2. / n))
elif isinstance(m, nn.Conv3d):
n = m.kernel_size[0] * m.kernel_size[1] * m.kernel_size[2] * m.out_channels
m.weight.data.normal_(0, math.sqrt(2. / n))
elif isinstance(m, nn.BatchNorm2d):
m.weight.data.fill_(1)
m.bias.data.zero_()
elif isinstance(m, nn.BatchNorm3d):
m.weight.data.fill_(1)
m.bias.data.zero_()
class BasicConv(nn.Module):
def __init__(self, in_channels, out_channels, deconv=False, is_3d=False, bn=True, relu=True, **kwargs):
super(BasicConv, self).__init__()
self.relu = relu
self.use_bn = bn
if is_3d:
if deconv:
self.conv = nn.ConvTranspose3d(in_channels, out_channels, bias=False, **kwargs)
else:
self.conv = nn.Conv3d(in_channels, out_channels, bias=False, **kwargs)
self.bn = nn.BatchNorm3d(out_channels)
else:
if deconv:
self.conv = nn.ConvTranspose2d(in_channels, out_channels, bias=False, **kwargs)
else:
self.conv = nn.Conv2d(in_channels, out_channels, bias=False, **kwargs)
self.bn = nn.BatchNorm2d(out_channels)
def forward(self, x):
x = self.conv(x)
if self.use_bn:
x = self.bn(x)
if self.relu:
x = nn.LeakyReLU()(x)
return x
class BasicConv_IN(nn.Module):
def __init__(self, in_channels, out_channels, deconv=False, is_3d=False, IN=True, relu=True, **kwargs):
super(BasicConv_IN, self).__init__()
self.relu = relu
self.use_in = IN
if is_3d:
if deconv:
self.conv = nn.ConvTranspose3d(in_channels, out_channels, bias=False, **kwargs)
else:
self.conv = nn.Conv3d(in_channels, out_channels, bias=False, **kwargs)
self.IN = nn.InstanceNorm3d(out_channels)
else:
if deconv:
self.conv = nn.ConvTranspose2d(in_channels, out_channels, bias=False, **kwargs)
else:
self.conv = nn.Conv2d(in_channels, out_channels, bias=False, **kwargs)
self.IN = nn.InstanceNorm2d(out_channels)
def forward(self, x):
x = self.conv(x)
if self.use_in:
x = self.IN(x)
if self.relu:
x = nn.LeakyReLU()(x)
return x
class Conv2x(nn.Module):
def __init__(self, in_channels, out_channels, deconv=False, is_3d=False, concat=True, keep_concat=True, bn=True, relu=True, keep_dispc=False):
super(Conv2x, self).__init__()
self.concat = concat
self.is_3d = is_3d
if deconv and is_3d:
kernel = (4, 4, 4)
elif deconv:
kernel = 4
else:
kernel = 3
if deconv and is_3d and keep_dispc:
kernel = (1, 4, 4)
stride = (1, 2, 2)
padding = (0, 1, 1)
self.conv1 = BasicConv(in_channels, out_channels, deconv, is_3d, bn=True, relu=True, kernel_size=kernel, stride=stride, padding=padding)
else:
self.conv1 = BasicConv(in_channels, out_channels, deconv, is_3d, bn=True, relu=True, kernel_size=kernel, stride=2, padding=1)
if self.concat:
mul = 2 if keep_concat else 1
self.conv2 = BasicConv(out_channels*2, out_channels*mul, False, is_3d, bn, relu, kernel_size=3, stride=1, padding=1)
else:
self.conv2 = BasicConv(out_channels, out_channels, False, is_3d, bn, relu, kernel_size=3, stride=1, padding=1)
def forward(self, x, rem):
x = self.conv1(x)
if x.shape != rem.shape:
x = F.interpolate(
x,
size=(rem.shape[-2], rem.shape[-1]),
mode='nearest')
if self.concat:
x = torch.cat((x, rem), 1)
else:
x = x + rem
x = self.conv2(x)
return x
class Conv2x_IN(nn.Module):
def __init__(self, in_channels, out_channels, deconv=False, is_3d=False, concat=True, keep_concat=True, IN=True, relu=True, keep_dispc=False):
super(Conv2x_IN, self).__init__()
self.concat = concat
self.is_3d = is_3d
if deconv and is_3d:
kernel = (4, 4, 4)
elif deconv:
kernel = 4
else:
kernel = 3
if deconv and is_3d and keep_dispc:
kernel = (1, 4, 4)
stride = (1, 2, 2)
padding = (0, 1, 1)
self.conv1 = BasicConv_IN(in_channels, out_channels, deconv, is_3d, IN=True, relu=True, kernel_size=kernel, stride=stride, padding=padding)
else:
self.conv1 = BasicConv_IN(in_channels, out_channels, deconv, is_3d, IN=True, relu=True, kernel_size=kernel, stride=2, padding=1)
if self.concat:
mul = 2 if keep_concat else 1
self.conv2 = BasicConv_IN(out_channels*2, out_channels*mul, False, is_3d, IN, relu, kernel_size=3, stride=1, padding=1)
else:
self.conv2 = BasicConv_IN(out_channels, out_channels, False, is_3d, IN, relu, kernel_size=3, stride=1, padding=1)
def forward(self, x, rem):
x = self.conv1(x)
if x.shape != rem.shape:
x = F.interpolate(
x,
size=(rem.shape[-2], rem.shape[-1]),
mode='nearest')
if self.concat:
x = torch.cat((x, rem), 1)
else:
x = x + rem
x = self.conv2(x)
return x
class ConvReLU(nn.Module):
def __init__(self, in_channels, out_channels, kernel_size=3, stride=1, pad=1, dilation=1):
super(ConvReLU, self).__init__()
self.conv = nn.Conv2d(in_channels, out_channels, kernel_size, stride=stride, padding=pad, dilation=dilation, bias=False)
def forward(self,x):
return F.relu(self.conv(x), inplace=True)
class DepthInitialization(nn.Module):
def __init__(self, num_sample):
super(DepthInitialization, self).__init__()
self.num_sample = num_sample
def forward(self, inverse_depth_min, inverse_depth_max, height, width, device):
batch = inverse_depth_min.size()[0]
index = torch.arange(0, self.num_sample, 1, device=device).view(1, self.num_sample, 1, 1).float()
normalized_sample = index.repeat(batch, 1, height, width) / (self.num_sample-1)
depth_sample = inverse_depth_max + normalized_sample * (inverse_depth_min - inverse_depth_max)
depth_sample = 1.0 / depth_sample
return depth_sample
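A tiny numeric check of the initialization above: the samples are spaced uniformly in inverse depth (denser near the camera), running far to near.

import torch

init = DepthInitialization(num_sample=5)
inv_min = torch.tensor(1.0 / 425.0).view(1, 1, 1, 1)   # 1/depth_min (near plane)
inv_max = torch.tensor(1.0 / 935.0).view(1, 1, 1, 1)   # 1/depth_max (far plane)
d = init(inv_min, inv_max, height=1, width=1, device='cpu')
print(d.flatten())   # [935.0, ~719.2, ~584.4, ~492.1, 425.0]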
class PixelViewWeight(nn.Module):
def __init__(self, G):
super(PixelViewWeight, self).__init__()
self.conv = nn.Sequential(
ConvReLU(G, 16),
nn.Conv2d(16, 1, 1, stride=1, padding=0),
)
def forward(self, x):
# x: [B, G, N, H, W]
batch, dim, num_depth, height, width = x.size()
x = x.permute(0, 2, 1, 3, 4).contiguous()
x = x.view(batch*num_depth, dim, height, width) # [B*N,G,H,W]
x = self.conv(x).view(batch, num_depth, height, width)
x = torch.softmax(x, dim=1)
x = torch.max(x, dim=1)[0]
return x.unsqueeze(1)
class FeatureAtt(nn.Module):
def __init__(self, cv_chan, feat_chan):
super(FeatureAtt, self).__init__()
self.feat_att = nn.Sequential(
BasicConv(feat_chan, feat_chan//2, kernel_size=1, stride=1, padding=0),
nn.Conv2d(feat_chan//2, cv_chan, 1))
def forward(self, cv, feat):
'''scale the cost volume cv by a sigmoid attention map derived from 2D image features'''
feat_att = self.feat_att(feat).unsqueeze(2)
cv = torch.sigmoid(feat_att)*cv
return cv
class hourglass(nn.Module):
def __init__(self, in_channels):
super(hourglass, self).__init__()
self.conv1 = nn.Sequential(BasicConv(in_channels, in_channels*2, is_3d=True, bn=True, relu=True, kernel_size=3,
padding=1, stride=2, dilation=1),
BasicConv(in_channels*2, in_channels*2, is_3d=True, bn=True, relu=True, kernel_size=3,
padding=1, stride=1, dilation=1))
self.conv2 = nn.Sequential(BasicConv(in_channels*2, in_channels*4, is_3d=True, bn=True, relu=True, kernel_size=3,
padding=1, stride=2, dilation=1),
BasicConv(in_channels*4, in_channels*4, is_3d=True, bn=True, relu=True, kernel_size=3,
padding=1, stride=1, dilation=1))
self.conv3 = nn.Sequential(BasicConv(in_channels*4, in_channels*6, is_3d=True, bn=True, relu=True, kernel_size=3,
padding=1, stride=2, dilation=1),
BasicConv(in_channels*6, in_channels*6, is_3d=True, bn=True, relu=True, kernel_size=3,
padding=1, stride=1, dilation=1))
self.conv3_up = BasicConv(in_channels*6, in_channels*4, deconv=True, is_3d=True, bn=True,
relu=True, kernel_size=(4, 4, 4), padding=(1, 1, 1), stride=(2, 2, 2))
self.conv2_up = BasicConv(in_channels*4, in_channels*2, deconv=True, is_3d=True, bn=True,
relu=True, kernel_size=(4, 4, 4), padding=(1, 1, 1), stride=(2, 2, 2))
self.conv1_up = BasicConv(in_channels*2, 1, deconv=True, is_3d=True, bn=False,
relu=False, kernel_size=(4, 4, 4), padding=(1, 1, 1), stride=(2, 2, 2))
self.agg_0 = nn.Sequential(BasicConv(in_channels*8, in_channels*4, is_3d=True, kernel_size=1, padding=0, stride=1),
BasicConv(in_channels*4, in_channels*4, is_3d=True, kernel_size=3, padding=1, stride=1),
BasicConv(in_channels*4, in_channels*4, is_3d=True, kernel_size=3, padding=1, stride=1),)
self.agg_1 = nn.Sequential(BasicConv(in_channels*4, in_channels*2, is_3d=True, kernel_size=1, padding=0, stride=1),
BasicConv(in_channels*2, in_channels*2, is_3d=True, kernel_size=3, padding=1, stride=1),
BasicConv(in_channels*2, in_channels*2, is_3d=True, kernel_size=3, padding=1, stride=1))
self.feature_att_8 = FeatureAtt(in_channels*2, 64)
self.feature_att_16 = FeatureAtt(in_channels*4, 192)
self.feature_att_32 = FeatureAtt(in_channels*6, 160)
self.feature_att_up_16 = FeatureAtt(in_channels*4, 192)
self.feature_att_up_8 = FeatureAtt(in_channels*2, 64)
def forward(self, x, features):
conv1 = self.conv1(x)
conv1 = self.feature_att_8(conv1, features[1])
conv2 = self.conv2(conv1)
conv2 = self.feature_att_16(conv2, features[2])
conv3 = self.conv3(conv2)
conv3 = self.feature_att_32(conv3, features[3])
conv3_up = self.conv3_up(conv3)
conv2 = torch.cat((conv3_up, conv2), dim=1)
conv2 = self.agg_0(conv2)
conv2 = self.feature_att_up_16(conv2, features[2])
conv2_up = self.conv2_up(conv2)
conv1 = torch.cat((conv2_up, conv1), dim=1)
conv1 = self.agg_1(conv1)
conv1 = self.feature_att_up_8(conv1, features[1])
conv = self.conv1_up(conv1)
return conv
def bilinear_sampler(img, coords, mode='bilinear', mask=False):
""" Wrapper for grid_sample, uses pixel coordinates """
H, W = img.shape[-2:]
xgrid, ygrid = coords.split([1,1], dim=-1)
xgrid = 2*xgrid/(W-1) - 1
assert torch.unique(ygrid).numel() == 1 and H == 1 # This is a stereo problem
grid = torch.cat([xgrid, ygrid], dim=-1)
img = F.grid_sample(img, grid, align_corners=True)
if mask:
mask = (xgrid > -1) & (ygrid > -1) & (xgrid < 1) & (ygrid < 1)
return img, mask.float()
return img
def context_upsample(disp_low, up_weights):
###
# disp_low: (b,1,h,w) low-resolution map
# up_weights: (b,9,4*h,4*w) convex weights over each 3x3 low-res neighborhood
###
b, c, h, w = disp_low.shape
disp_unfold = F.unfold(disp_low.reshape(b,c,h,w),3,1,1).reshape(b,-1,h,w)
disp_unfold = F.interpolate(disp_unfold,(h*4,w*4),mode='nearest').reshape(b,9,h*4,w*4)
disp = (disp_unfold*up_weights).sum(1)
return disp
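A shape sketch for context_upsample with hypothetical tensors: each full-resolution pixel is a convex combination of the 3x3 low-resolution neighborhood around it.

import torch
import torch.nn.functional as F

b, h, w = 1, 8, 10
disp_low = torch.randn(b, 1, h, w)
up_weights = F.softmax(torch.randn(b, 9, 4 * h, 4 * w), dim=1)  # weights sum to 1
disp_full = context_upsample(disp_low, up_weights)              # [b, 4*h, 4*w], no channel dim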
def pool2x(x):
return F.avg_pool2d(x, 3, stride=2, padding=1)
def interp(x, dest):
interp_args = {'mode': 'bilinear', 'align_corners': True}
return F.interpolate(x, dest.shape[2:], **interp_args)
def differentiable_warping(src_fea, src_proj, ref_proj, depth_samples, return_mask=False):
# src_fea: [B, C, H, W]
# src_proj: [B, 4, 4]
# ref_proj: [B, 4, 4]
# depth_samples: [B, Ndepth, H, W]
# out: [B, C, Ndepth, H, W]
batch, num_depth, height, width = depth_samples.size()
height1, width1 = src_fea.size()[2:]
with torch.no_grad():
if batch==2:
inv_ref_proj = []
for i in range(batch):
inv_ref_proj.append(torch.inverse(ref_proj[i]).unsqueeze(0))
inv_ref_proj = torch.cat(inv_ref_proj, dim=0)
assert (not torch.isnan(inv_ref_proj).any()), "nan in inverse(ref_proj)"
proj = torch.matmul(src_proj, inv_ref_proj)
else:
proj = torch.matmul(src_proj, torch.inverse(ref_proj))
assert (not torch.isnan(proj).any()), "nan in proj"
rot = proj[:, :3, :3] # [B,3,3]
trans = proj[:, :3, 3:4] # [B,3,1]
y, x = torch.meshgrid([torch.arange(0, height, dtype=torch.float32, device=depth_samples.device),
torch.arange(0, width, dtype=torch.float32, device=depth_samples.device)])
y, x = y.contiguous(), x.contiguous()
y, x = y.view(height * width), x.view(height * width)
y = y*(height1/height)
x = x*(width1/width)
xyz = torch.stack((x, y, torch.ones_like(x))) # [3, H*W]
xyz = torch.unsqueeze(xyz, 0).repeat(batch, 1, 1) # [B, 3, H*W]
rot_xyz = torch.matmul(rot, xyz) # [B, 3, H*W]
rot_depth_xyz = rot_xyz.unsqueeze(2).repeat(1, 1, num_depth, 1) * depth_samples.view(batch, 1, num_depth,
height * width) # [B, 3, Ndepth, H*W]
proj_xyz = rot_depth_xyz + trans.view(batch, 3, 1, 1) # [B, 3, Ndepth, H*W]
# avoid negative depth
valid_mask = proj_xyz[:, 2:] > 1e-2
proj_xyz[:, 0:1][~valid_mask] = width
proj_xyz[:, 1:2][~valid_mask] = height
proj_xyz[:, 2:3][~valid_mask] = 1
proj_xy = proj_xyz[:, :2, :, :] / proj_xyz[:, 2:3, :, :] # [B, 2, Ndepth, H*W]
valid_mask = valid_mask & (proj_xy[:, 0:1] >=0) & (proj_xy[:, 0:1] < width) \
& (proj_xy[:, 1:2] >=0) & (proj_xy[:, 1:2] < height)
proj_x_normalized = proj_xy[:, 0, :, :] / ((width1 - 1) / 2) - 1 # [B, Ndepth, H*W]
proj_y_normalized = proj_xy[:, 1, :, :] / ((height1 - 1) / 2) - 1
proj_xy = torch.stack((proj_x_normalized, proj_y_normalized), dim=3) # [B, Ndepth, H*W, 2]
grid = proj_xy
dim = src_fea.size()[1]
warped_src_fea = F.grid_sample(src_fea, grid.view(batch, num_depth * height, width, 2), mode='bilinear',
padding_mode='zeros',align_corners=True)
warped_src_fea = warped_src_fea.view(batch, dim, num_depth, height, width)
if return_mask:
valid_mask = valid_mask.view(batch,num_depth,height,width)
return warped_src_fea, valid_mask
else:
return warped_src_fea
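A sanity-check sketch for the warping, using hypothetical tensors: with src_proj == ref_proj the homography is the identity, so each depth slice simply resamples src_fea.

import torch

B, C, H, W, D = 1, 8, 12, 16, 4
src_fea = torch.randn(B, C, H, W)
proj = torch.eye(4).unsqueeze(0)               # [B, 4, 4], src_proj == ref_proj
depth_samples = 5.0 * torch.ones(B, D, H, W)   # any positive depth works here
warped = differentiable_warping(src_fea, proj, proj, depth_samples)
print(warped.shape)                            # [1, 8, 4, 12, 16]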
def depth_normalization(depth, inverse_depth_min, inverse_depth_max):
'''convert depth map to the index in inverse range'''
inverse_depth = 1.0 / (depth+1e-5)
normalized_depth = (inverse_depth - inverse_depth_max) / (inverse_depth_min - inverse_depth_max)
return normalized_depth
def depth_unnormalization(normalized_depth, inverse_depth_min, inverse_depth_max):
'''convert the index in inverse range to depth map'''
inverse_depth = inverse_depth_max + normalized_depth * (inverse_depth_min - inverse_depth_max) # [B,1,H,W]
depth = 1.0 / inverse_depth
return depth
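The two helpers above are inverses; a quick round trip with hypothetical values, matching the depth range used elsewhere in this commit:

import torch

inv_min = torch.tensor(1.0 / 425.0)   # 1/depth_min
inv_max = torch.tensor(1.0 / 935.0)   # 1/depth_max
depth = torch.tensor([[[[600.0]]]])
idx = depth_normalization(depth, inv_min, inv_max)    # index in [0, 1]
back = depth_unnormalization(idx, inv_min, inv_max)   # ~600, up to the 1e-5 guard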

IGEV-MVS/core/update.py (new file)
@@ -0,0 +1,94 @@
import torch
import torch.nn as nn
import torch.nn.functional as F
from .submodule import *
class BasicMotionEncoder(nn.Module):
def __init__(self):
super(BasicMotionEncoder, self).__init__()
self.corr_levels = 2
self.corr_radius = 4
cor_planes = 2 * self.corr_levels * (2*self.corr_radius + 1)
self.convc1 = nn.Conv2d(cor_planes, 64, 1, padding=0)
self.convc2 = nn.Conv2d(64, 64, 3, padding=1)
self.convd1 = nn.Conv2d(1, 64, 7, padding=3)
self.convd2 = nn.Conv2d(64, 64, 3, padding=1)
self.conv = nn.Conv2d(64+64, 128-1, 3, padding=1)
def forward(self, disp, corr):
cor = F.relu(self.convc1(corr))
cor = F.relu(self.convc2(cor))
disp_ = F.relu(self.convd1(disp))
disp_ = F.relu(self.convd2(disp_))
cor_disp = torch.cat([cor, disp_], dim=1)
out = F.relu(self.conv(cor_disp))
return torch.cat([out, disp], dim=1)
class ConvGRU(nn.Module):
def __init__(self, hidden_dim, input_dim, kernel_size=3):
super(ConvGRU, self).__init__()
self.convz = nn.Conv2d(hidden_dim+input_dim, hidden_dim, kernel_size, padding=kernel_size//2)
self.convr = nn.Conv2d(hidden_dim+input_dim, hidden_dim, kernel_size, padding=kernel_size//2)
self.convq = nn.Conv2d(hidden_dim+input_dim, hidden_dim, kernel_size, padding=kernel_size//2)
def forward(self, h, *x_list):
x = torch.cat(x_list, dim=1)
hx = torch.cat([h, x], dim=1)
z = torch.sigmoid(self.convz(hx))
r = torch.sigmoid(self.convr(hx))
q = torch.tanh(self.convq(torch.cat([r*h, x], dim=1)))
h = (1-z) * h + z * q
return h
class DispHead(nn.Module):
def __init__(self, input_dim=128, hidden_dim=256, output_dim=1):
super(DispHead, self).__init__()
self.conv1 = nn.Conv2d(input_dim, hidden_dim, 3, padding=1)
self.conv2 = nn.Conv2d(hidden_dim, output_dim, 3, padding=1)
self.relu = nn.ReLU(inplace=True)
def forward(self, x):
return self.conv2(self.relu(self.conv1(x)))
class BasicMultiUpdateBlock(nn.Module):
def __init__(self, hidden_dims=[]):
super().__init__()
self.n_gru_layers = 3
self.n_downsample = 2
self.encoder = BasicMotionEncoder()
encoder_output_dim = 128
self.gru04 = ConvGRU(hidden_dims[2], encoder_output_dim + hidden_dims[1] * (self.n_gru_layers > 1))
self.gru08 = ConvGRU(hidden_dims[1], hidden_dims[0] * (self.n_gru_layers == 3) + hidden_dims[2])
self.gru16 = ConvGRU(hidden_dims[0], hidden_dims[1])
self.disp_head = DispHead(hidden_dims[2], hidden_dim=256, output_dim=1)
factor = 2**self.n_downsample
self.mask_feat_4 = nn.Sequential(
nn.Conv2d(hidden_dims[2], 32, 3, padding=1),
nn.ReLU(inplace=True))
def forward(self, net, corr=None, disp=None, iter04=True, iter08=True, iter16=True, update=True):
if iter16:
net[2] = self.gru16(net[2], pool2x(net[1]))
if iter08:
if self.n_gru_layers > 2:
net[1] = self.gru08(net[1], pool2x(net[0]), interp(net[2], net[1]))
else:
net[1] = self.gru08(net[1], pool2x(net[0]))
if iter04:
motion_features = self.encoder(disp, corr)
if self.n_gru_layers > 1:
net[0] = self.gru04(net[0], motion_features, interp(net[1], net[0]))
else:
net[0] = self.gru04(net[0], motion_features)
if not update:
return net
delta_disp = self.disp_head(net[0])
mask_feat_4 = self.mask_feat_4(net[0])
return net, mask_feat_4, delta_disp
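A shape sketch for one update step with hypothetical tensors: net holds the three hidden states at 1/4, 1/8 and 1/16 resolution, and corr is the 36-channel lookup produced by CorrBlock1D_Cost_Volume.

import torch

block = BasicMultiUpdateBlock(hidden_dims=[128, 128, 128])
b, h, w = 1, 16, 24                                   # quarter-resolution grid
net = [torch.zeros(b, 128, h, w),                     # 1/4 resolution
       torch.zeros(b, 128, h // 2, w // 2),           # 1/8 resolution
       torch.zeros(b, 128, h // 4, w // 4)]           # 1/16 resolution
corr = torch.randn(b, 36, h, w)                       # 2 levels * 2 volumes * (2*4+1)
disp = torch.zeros(b, 1, h, w)
net, mask_feat_4, delta_disp = block(net, corr, disp) # delta_disp: [b, 1, h, w]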

IGEV-MVS/datasets/__init__.py (new file)
@@ -0,0 +1,8 @@
import importlib
# find the dataset definition by name, for example dtu_yao (dtu_yao.py)
def find_dataset_def(dataset_name):
module_name = 'datasets.{}'.format(dataset_name)
module = importlib.import_module(module_name)
return getattr(module, "MVSDataset")
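A hypothetical call (paths are placeholders): the string names a module under datasets/, and the returned class is constructed with that module's own signature.

MVSDataset = find_dataset_def('dtu_yao')   # resolves datasets/dtu_yao.py
train_set = MVSDataset('<DTU_ROOT>', 'lists/dtu/train.txt', 'train', 5)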

IGEV-MVS/datasets/blendedmvs.py (new file)
@@ -0,0 +1,208 @@
from torch.utils.data import Dataset
from datasets.data_io import *
import os
import numpy as np
import cv2
from PIL import Image
from torchvision import transforms as T
import random
class MVSDataset(Dataset):
def __init__(self, datapath, listfile, split, nviews, img_wh=(768, 576), robust_train=True):
super(MVSDataset, self).__init__()
self.levels = 4
self.datapath = datapath
self.split = split
self.listfile = listfile
self.robust_train = robust_train
assert self.split in ['train', 'val', 'all'], \
'split must be either "train", "val" or "all"!'
self.img_wh = img_wh
if img_wh is not None:
assert img_wh[0]%32==0 and img_wh[1]%32==0, \
'img_wh dimensions must both be multiples of 32!'
self.nviews = nviews
self.scale_factors = {} # depth scale factors for each scan
self.build_metas()
self.color_augment = T.ColorJitter(brightness=0.5, contrast=0.5)
def build_metas(self):
self.metas = []
with open(self.listfile) as f:
self.scans = [line.rstrip() for line in f.readlines()]
for scan in self.scans:
with open(os.path.join(self.datapath, scan, "cams/pair.txt")) as f:
num_viewpoint = int(f.readline())
for _ in range(num_viewpoint):
ref_view = int(f.readline().rstrip())
src_views = [int(x) for x in f.readline().rstrip().split()[1::2]]
if len(src_views) >= self.nviews-1:
self.metas += [(scan, ref_view, src_views)]
def read_cam_file(self, scan, filename):
with open(filename) as f:
lines = f.readlines()
lines = [line.rstrip() for line in lines]
# extrinsics: line [1,5), 4x4 matrix
extrinsics = np.fromstring(' '.join(lines[1:5]), dtype=np.float32, sep=' ').reshape((4, 4))
# intrinsics: line [7-10), 3x3 matrix
intrinsics = np.fromstring(' '.join(lines[7:10]), dtype=np.float32, sep=' ').reshape((3, 3))
depth_min = float(lines[11].split()[0])
depth_max = float(lines[11].split()[-1])
if scan not in self.scale_factors:
self.scale_factors[scan] = 100.0/depth_min
depth_min *= self.scale_factors[scan]
depth_max *= self.scale_factors[scan]
extrinsics[:3, 3] *= self.scale_factors[scan]
return intrinsics, extrinsics, depth_min, depth_max
def read_depth_mask(self, scan, filename, depth_min, depth_max, scale):
depth = np.array(read_pfm(filename)[0], dtype=np.float32)
depth = depth * self.scale_factors[scan] * scale
depth = np.squeeze(depth,2)
mask = (depth>=depth_min) & (depth<=depth_max)
mask = mask.astype(np.float32)
if self.img_wh is not None:
depth = cv2.resize(depth, self.img_wh,
interpolation=cv2.INTER_NEAREST)
h, w = depth.shape
depth_ms = {}
mask_ms = {}
for i in range(4):
depth_cur = cv2.resize(depth, (w//(2**i), h//(2**i)), interpolation=cv2.INTER_NEAREST)
mask_cur = cv2.resize(mask, (w//(2**i), h//(2**i)), interpolation=cv2.INTER_NEAREST)
depth_ms[f"level_{i}"] = depth_cur
mask_ms[f"level_{i}"] = mask_cur
return depth_ms, mask_ms
def read_img(self, filename):
img = Image.open(filename)
if self.split=='train':
img = self.color_augment(img)
# scale 0~255 to -1~1
np_img = 2*np.array(img, dtype=np.float32) / 255. - 1
if self.img_wh is not None:
np_img = cv2.resize(np_img, self.img_wh,
interpolation=cv2.INTER_LINEAR)
h, w, _ = np_img.shape
np_img_ms = {
"level_3": cv2.resize(np_img, (w//8, h//8), interpolation=cv2.INTER_LINEAR),
"level_2": cv2.resize(np_img, (w//4, h//4), interpolation=cv2.INTER_LINEAR),
"level_1": cv2.resize(np_img, (w//2, h//2), interpolation=cv2.INTER_LINEAR),
"level_0": np_img
}
return np_img_ms
def __len__(self):
return len(self.metas)
def __getitem__(self, idx):
meta = self.metas[idx]
scan, ref_view, src_views = meta
if self.robust_train:
num_src_views = len(src_views)
index = random.sample(range(num_src_views), self.nviews - 1)
view_ids = [ref_view] + [src_views[i] for i in index]
scale = random.uniform(0.8, 1.25)
else:
view_ids = [ref_view] + src_views[:self.nviews - 1]
scale = 1
imgs_0 = []
imgs_1 = []
imgs_2 = []
imgs_3 = []
mask = None
depth = None
depth_min = None
depth_max = None
proj_matrices_0 = []
proj_matrices_1 = []
proj_matrices_2 = []
proj_matrices_3 = []
for i, vid in enumerate(view_ids):
img_filename = os.path.join(self.datapath, '{}/blended_images/{:0>8}.jpg'.format(scan, vid))
depth_filename = os.path.join(self.datapath, '{}/rendered_depth_maps/{:0>8}.pfm'.format(scan, vid))
proj_mat_filename = os.path.join(self.datapath, '{}/cams/{:0>8}_cam.txt'.format(scan, vid))
imgs = self.read_img(img_filename)
imgs_0.append(imgs['level_0'])
imgs_1.append(imgs['level_1'])
imgs_2.append(imgs['level_2'])
imgs_3.append(imgs['level_3'])
# the intrinsics in the cam file are at full image resolution; the ladder below rescales them to 1/8, 1/4, 1/2 and 1/1 for the four pyramid levels
intrinsics, extrinsics, depth_min_, depth_max_ = self.read_cam_file(scan, proj_mat_filename)
extrinsics[:3, 3] *= scale
proj_mat = extrinsics.copy()
intrinsics[:2,:] *= 0.125
proj_mat[:3, :4] = np.matmul(intrinsics, proj_mat[:3, :4])
proj_matrices_3.append(proj_mat)
proj_mat = extrinsics.copy()
intrinsics[:2,:] *= 2
proj_mat[:3, :4] = np.matmul(intrinsics, proj_mat[:3, :4])
proj_matrices_2.append(proj_mat)
proj_mat = extrinsics.copy()
intrinsics[:2,:] *= 2
proj_mat[:3, :4] = np.matmul(intrinsics, proj_mat[:3, :4])
proj_matrices_1.append(proj_mat)
proj_mat = extrinsics.copy()
intrinsics[:2,:] *= 2
proj_mat[:3, :4] = np.matmul(intrinsics, proj_mat[:3, :4])
proj_matrices_0.append(proj_mat)
if i == 0: # reference view
depth_min = depth_min_ * scale
depth_max = depth_max_ * scale
depth, mask = self.read_depth_mask(scan, depth_filename, depth_min, depth_max, scale)
for l in range(self.levels):
mask[f'level_{l}'] = np.expand_dims(mask[f'level_{l}'],2)
mask[f'level_{l}'] = mask[f'level_{l}'].transpose([2,0,1])
depth[f'level_{l}'] = np.expand_dims(depth[f'level_{l}'],2)
depth[f'level_{l}'] = depth[f'level_{l}'].transpose([2,0,1])
# imgs: N*3*H0*W0, N is number of images
imgs_0 = np.stack(imgs_0).transpose([0, 3, 1, 2])
imgs_1 = np.stack(imgs_1).transpose([0, 3, 1, 2])
imgs_2 = np.stack(imgs_2).transpose([0, 3, 1, 2])
imgs_3 = np.stack(imgs_3).transpose([0, 3, 1, 2])
imgs = {}
imgs['level_0'] = imgs_0
imgs['level_1'] = imgs_1
imgs['level_2'] = imgs_2
imgs['level_3'] = imgs_3
# proj_matrices: N*4*4
proj_matrices_0 = np.stack(proj_matrices_0)
proj_matrices_1 = np.stack(proj_matrices_1)
proj_matrices_2 = np.stack(proj_matrices_2)
proj_matrices_3 = np.stack(proj_matrices_3)
proj={}
proj['level_3']=proj_matrices_3
proj['level_2']=proj_matrices_2
proj['level_1']=proj_matrices_1
proj['level_0']=proj_matrices_0
# data is numpy array
return {"imgs": imgs, # [N, 3, H, W]
"proj_matrices": proj, # [N,4,4]
"depth": depth, # [1, H, W]
"depth_min": depth_min, # scalar
"depth_max": depth_max, # scalar
"mask": mask} # [1, H, W]

IGEV-MVS/datasets/custom.py (new file)
@@ -0,0 +1,145 @@
from torch.utils.data import Dataset
from datasets.data_io import *
import os
import numpy as np
import cv2
from PIL import Image
from torchvision import transforms as T
import math
class MVSDataset(Dataset):
def __init__(self, datapath, n_views=5, img_wh=(640,480)):
self.levels = 4
self.datapath = datapath
self.img_wh = img_wh
self.build_metas()
self.n_views = n_views
def build_metas(self):
self.metas = []
with open(os.path.join(self.datapath, 'pair.txt')) as f:
num_viewpoint = int(f.readline())
for view_idx in range(num_viewpoint):
ref_view = int(f.readline().rstrip())
src_views = [int(x) for x in f.readline().rstrip().split()[1::2]]
if len(src_views) != 0:
self.metas += [(ref_view, src_views)]
def read_cam_file(self, filename):
with open(filename) as f:
lines = [line.rstrip() for line in f.readlines()]
# extrinsics: line [1,5), 4x4 matrix
extrinsics = np.fromstring(' '.join(lines[1:5]), dtype=np.float32, sep=' ')
extrinsics = extrinsics.reshape((4, 4))
# intrinsics: line [7-10), 3x3 matrix
intrinsics = np.fromstring(' '.join(lines[7:10]), dtype=np.float32, sep=' ')
intrinsics = intrinsics.reshape((3, 3))
depth_min = float(lines[11].split()[0])
depth_max = float(lines[11].split()[-1])
return intrinsics, extrinsics, depth_min, depth_max
def read_img(self, filename, h, w):
img = Image.open(filename)
# scale 0~255 to -1~1
np_img = 2*np.array(img, dtype=np.float32) / 255. - 1
original_h, original_w, _ = np_img.shape
np_img = cv2.resize(np_img, self.img_wh, interpolation=cv2.INTER_LINEAR)
np_img_ms = {
"level_3": cv2.resize(np_img, (w//8, h//8), interpolation=cv2.INTER_LINEAR),
"level_2": cv2.resize(np_img, (w//4, h//4), interpolation=cv2.INTER_LINEAR),
"level_1": cv2.resize(np_img, (w//2, h//2), interpolation=cv2.INTER_LINEAR),
"level_0": np_img
}
return np_img_ms, original_h, original_w
def __len__(self):
return len(self.metas)
def __getitem__(self, idx):
ref_view, src_views = self.metas[idx]
# use only the reference view and first nviews-1 source views
view_ids = [ref_view] + src_views[:self.n_views-1]
imgs_0 = []
imgs_1 = []
imgs_2 = []
imgs_3 = []
# depth = None
depth_min = None
depth_max = None
proj_matrices_0 = []
proj_matrices_1 = []
proj_matrices_2 = []
proj_matrices_3 = []
for i, vid in enumerate(view_ids):
img_filename = os.path.join(self.datapath, f'images/{vid:08d}.jpg')
proj_mat_filename = os.path.join(self.datapath, f'cams_1/{vid:08d}_cam.txt')
imgs, original_h, original_w = self.read_img(img_filename,self.img_wh[1], self.img_wh[0])
imgs_0.append(imgs['level_0'])
imgs_1.append(imgs['level_1'])
imgs_2.append(imgs['level_2'])
imgs_3.append(imgs['level_3'])
intrinsics, extrinsics, depth_min_, depth_max_ = self.read_cam_file(proj_mat_filename)
intrinsics[0] *= self.img_wh[0]/original_w
intrinsics[1] *= self.img_wh[1]/original_h
proj_mat = extrinsics.copy()
intrinsics[:2,:] *= 0.125
proj_mat[:3, :4] = np.matmul(intrinsics, proj_mat[:3, :4])
proj_matrices_3.append(proj_mat)
proj_mat = extrinsics.copy()
intrinsics[:2,:] *= 2
proj_mat[:3, :4] = np.matmul(intrinsics, proj_mat[:3, :4])
proj_matrices_2.append(proj_mat)
proj_mat = extrinsics.copy()
intrinsics[:2,:] *= 2
proj_mat[:3, :4] = np.matmul(intrinsics, proj_mat[:3, :4])
proj_matrices_1.append(proj_mat)
proj_mat = extrinsics.copy()
intrinsics[:2,:] *= 2
proj_mat[:3, :4] = np.matmul(intrinsics, proj_mat[:3, :4])
proj_matrices_0.append(proj_mat)
if i == 0: # reference view
depth_min = depth_min_
depth_max = depth_max_
# imgs: N*3*H0*W0, N is number of images
imgs_0 = np.stack(imgs_0).transpose([0, 3, 1, 2])
imgs_1 = np.stack(imgs_1).transpose([0, 3, 1, 2])
imgs_2 = np.stack(imgs_2).transpose([0, 3, 1, 2])
imgs_3 = np.stack(imgs_3).transpose([0, 3, 1, 2])
imgs = {}
imgs['level_0'] = imgs_0
imgs['level_1'] = imgs_1
imgs['level_2'] = imgs_2
imgs['level_3'] = imgs_3
# proj_matrices: N*4*4
proj_matrices_0 = np.stack(proj_matrices_0)
proj_matrices_1 = np.stack(proj_matrices_1)
proj_matrices_2 = np.stack(proj_matrices_2)
proj_matrices_3 = np.stack(proj_matrices_3)
proj={}
proj['level_3']=proj_matrices_3
proj['level_2']=proj_matrices_2
proj['level_1']=proj_matrices_1
proj['level_0']=proj_matrices_0
return {"imgs": imgs, # N*3*H0*W0
"proj_matrices": proj, # N*4*4
"depth_min": depth_min, # scalar
"depth_max": depth_max,
"filename": '{}/' + '{:0>8}'.format(view_ids[0]) + "{}"
}

IGEV-MVS/datasets/data_io.py (new file)
@@ -0,0 +1,73 @@
import numpy as np
import re
import sys
def read_pfm(filename):
# rb: binary file and read only
file = open(filename, 'rb')
color = None
width = None
height = None
scale = None
endian = None
header = file.readline().decode('utf-8').rstrip()
if header == 'PF':
color = True
elif header == 'Pf': # depth is Pf
color = False
else:
raise Exception('Not a PFM file.')
dim_match = re.match(r'^(\d+)\s(\d+)\s$', file.readline().decode('utf-8')) # parse the 'width height' dimension line
if dim_match:
width, height = map(int, dim_match.groups())
else:
raise Exception('Malformed PFM header.')
scale = float(file.readline().rstrip())
if scale < 0: # little-endian
endian = '<'
scale = -scale
else:
endian = '>' # big-endian
data = np.fromfile(file, endian + 'f')
shape = (height, width, 3) if color else (height, width, 1)
# depth: H*W
data = np.reshape(data, shape)
data = np.flipud(data)
file.close()
return data, scale
def save_pfm(filename, image, scale=1):
file = open(filename, "wb")
color = None
image = np.flipud(image)
# print(image.shape)
if image.dtype.name != 'float32':
raise Exception('Image dtype must be float32.')
if len(image.shape) == 3 and image.shape[2] == 3: # color image
color = True
elif len(image.shape) == 2 or len(image.shape) == 3 and image.shape[2] == 1: # greyscale
color = False
else:
raise Exception('Image must have H x W x 3, H x W x 1 or H x W dimensions.')
file.write('PF\n'.encode('utf-8') if color else 'Pf\n'.encode('utf-8'))
file.write('{} {}\n'.format(image.shape[1], image.shape[0]).encode('utf-8'))
endian = image.dtype.byteorder
if endian == '<' or endian == '=' and sys.byteorder == 'little':
scale = -scale
file.write(('%f\n' % scale).encode('utf-8'))
image.tofile(file)
file.close()
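A round-trip sketch for the two helpers above: PFM stores float32 maps bottom-up, which is why both functions flip vertically.

import numpy as np

depth = np.random.rand(120, 160).astype(np.float32)
save_pfm('/tmp/depth.pfm', depth)
loaded, scale = read_pfm('/tmp/depth.pfm')
assert loaded.shape == (120, 160, 1)   # grayscale keeps a channel axis
assert np.allclose(loaded[..., 0], depth)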

IGEV-MVS/datasets/dtu_yao.py (new file)
@@ -0,0 +1,236 @@
from torch.utils.data import Dataset
import numpy as np
import os
from PIL import Image
from datasets.data_io import *
import cv2
import random
from torchvision import transforms
class MVSDataset(Dataset):
def __init__(self, datapath, listfile, mode, nviews, robust_train = False):
super(MVSDataset, self).__init__()
self.levels = 4
self.datapath = datapath
self.listfile = listfile
self.mode = mode
self.nviews = nviews
self.img_wh = (640, 512)
# self.img_wh = (1440, 1056)
self.robust_train = robust_train
assert self.mode in ["train", "val", "test"]
self.metas = self.build_list()
self.color_augment = transforms.ColorJitter(brightness=0.5, contrast=0.5)
def build_list(self):
metas = []
with open(self.listfile) as f:
scans = f.readlines()
scans = [line.rstrip() for line in scans]
for scan in scans:
pair_file = "Cameras_1/pair.txt"
with open(os.path.join(self.datapath, pair_file)) as f:
self.num_viewpoint = int(f.readline())
# viewpoints (49)
for view_idx in range(self.num_viewpoint):
ref_view = int(f.readline().rstrip())
src_views = [int(x) for x in f.readline().rstrip().split()[1::2]]
# light conditions 0-6
for light_idx in range(7):
metas.append((scan, light_idx, ref_view, src_views))
print("dataset", self.mode, "metas:", len(metas))
return metas
def __len__(self):
return len(self.metas)
def read_cam_file(self, filename):
with open(filename) as f:
lines = f.readlines()
lines = [line.rstrip() for line in lines]
# extrinsics: line [1,5), 4x4 matrix
extrinsics = np.fromstring(' '.join(lines[1:5]), dtype=np.float32, sep=' ').reshape((4, 4))
# intrinsics: line [7-10), 3x3 matrix
intrinsics = np.fromstring(' '.join(lines[7:10]), dtype=np.float32, sep=' ').reshape((3, 3))
depth_min = float(lines[11].split()[0])
depth_max = float(lines[11].split()[-1])
return intrinsics, extrinsics, depth_min, depth_max
def read_img(self, filename):
img = Image.open(filename)
if self.mode=='train':
img = self.color_augment(img)
# scale 0~255 to -1~1
np_img = 2*np.array(img, dtype=np.float32) / 255. - 1
h, w, _ = np_img.shape
np_img_ms = {
"level_3": cv2.resize(np_img, (w//8, h//8), interpolation=cv2.INTER_LINEAR),
"level_2": cv2.resize(np_img, (w//4, h//4), interpolation=cv2.INTER_LINEAR),
"level_1": cv2.resize(np_img, (w//2, h//2), interpolation=cv2.INTER_LINEAR),
"level_0": np_img
}
return np_img_ms
def prepare_img(self, hr_img):
#downsample
h, w = hr_img.shape
# original w,h: 1600, 1200; downsample -> 800, 600 ; crop -> 640, 512
hr_img = cv2.resize(hr_img, (w//2, h//2), interpolation=cv2.INTER_NEAREST)
#crop
h, w = hr_img.shape
target_h, target_w = self.img_wh[1], self.img_wh[0]
start_h, start_w = (h - target_h)//2, (w - target_w)//2
hr_img_crop = hr_img[start_h: start_h + target_h, start_w: start_w + target_w]
return hr_img_crop
def read_mask(self, filename):
img = Image.open(filename)
np_img = np.array(img, dtype=np.float32)
np_img = (np_img > 10).astype(np.float32)
return np_img
def read_depth_mask(self, filename, mask_filename, scale):
depth_hr = np.array(read_pfm(filename)[0], dtype=np.float32) * scale
depth_hr = np.squeeze(depth_hr,2)
depth_lr = self.prepare_img(depth_hr)
mask = self.read_mask(mask_filename)
mask = self.prepare_img(mask)
mask = mask.astype(np.float32)
h, w = depth_lr.shape
depth_lr_ms = {}
mask_ms = {}
for i in range(self.levels):
depth_cur = cv2.resize(depth_lr, (w//(2**i), h//(2**i)), interpolation=cv2.INTER_NEAREST)
mask_cur = cv2.resize(mask, (w//(2**i), h//(2**i)), interpolation=cv2.INTER_NEAREST)
depth_lr_ms[f"level_{i}"] = depth_cur
mask_ms[f"level_{i}"] = mask_cur
return depth_lr_ms, mask_ms
def __getitem__(self, idx):
meta = self.metas[idx]
scan, light_idx, ref_view, src_views = meta
# robust training strategy
if self.robust_train:
num_src_views = len(src_views)
index = random.sample(range(num_src_views), self.nviews - 1)
view_ids = [ref_view] + [src_views[i] for i in index]
scale = random.uniform(0.8, 1.25)
else:
view_ids = [ref_view] + src_views[:self.nviews - 1]
scale = 1
imgs_0 = []
imgs_1 = []
imgs_2 = []
imgs_3 = []
mask = None
depth = None
depth_min = None
depth_max = None
proj_matrices_0 = []
proj_matrices_1 = []
proj_matrices_2 = []
proj_matrices_3 = []
for i, vid in enumerate(view_ids):
img_filename = os.path.join(self.datapath,
'Rectified/{}_train/rect_{:0>3}_{}_r5000.png'.format(scan, vid + 1, light_idx))
proj_mat_filename = os.path.join(self.datapath, 'Cameras_1/{}_train/{:0>8}_cam.txt').format(scan, vid)
mask_filename = os.path.join(self.datapath, 'Depths_raw/{}/depth_visual_{:0>4}.png'.format(scan, vid))
depth_filename = os.path.join(self.datapath, 'Depths_raw/{}/depth_map_{:0>4}.pfm'.format(scan, vid))
imgs = self.read_img(img_filename)
imgs_0.append(imgs['level_0'])
imgs_1.append(imgs['level_1'])
imgs_2.append(imgs['level_2'])
imgs_3.append(imgs['level_3'])
intrinsics, extrinsics, depth_min_, depth_max_ = self.read_cam_file(proj_mat_filename)
extrinsics[:3,3] *= scale
intrinsics[0] *= 4
intrinsics[1] *= 4
proj_mat = extrinsics.copy()
intrinsics[:2,:] *= 0.125
proj_mat[:3, :4] = np.matmul(intrinsics, proj_mat[:3, :4])
proj_matrices_3.append(proj_mat)
proj_mat = extrinsics.copy()
intrinsics[:2,:] *= 2
proj_mat[:3, :4] = np.matmul(intrinsics, proj_mat[:3, :4])
proj_matrices_2.append(proj_mat)
proj_mat = extrinsics.copy()
intrinsics[:2,:] *= 2
proj_mat[:3, :4] = np.matmul(intrinsics, proj_mat[:3, :4])
proj_matrices_1.append(proj_mat)
proj_mat = extrinsics.copy()
intrinsics[:2,:] *= 2
proj_mat[:3, :4] = np.matmul(intrinsics, proj_mat[:3, :4])
proj_matrices_0.append(proj_mat)
if i == 0: # reference view
depth_min = depth_min_ * scale
depth_max = depth_max_ * scale
depth, mask = self.read_depth_mask(depth_filename, mask_filename, scale)
for l in range(self.levels):
mask[f'level_{l}'] = np.expand_dims(mask[f'level_{l}'],2)
mask[f'level_{l}'] = mask[f'level_{l}'].transpose([2,0,1])
depth[f'level_{l}'] = np.expand_dims(depth[f'level_{l}'],2)
depth[f'level_{l}'] = depth[f'level_{l}'].transpose([2,0,1])
# imgs: N*3*H0*W0, N is number of images
imgs_0 = np.stack(imgs_0).transpose([0, 3, 1, 2])
imgs_1 = np.stack(imgs_1).transpose([0, 3, 1, 2])
imgs_2 = np.stack(imgs_2).transpose([0, 3, 1, 2])
imgs_3 = np.stack(imgs_3).transpose([0, 3, 1, 2])
imgs = {}
imgs['level_0'] = imgs_0
imgs['level_1'] = imgs_1
imgs['level_2'] = imgs_2
imgs['level_3'] = imgs_3
# proj_matrices: N*4*4
proj_matrices_0 = np.stack(proj_matrices_0)
proj_matrices_1 = np.stack(proj_matrices_1)
proj_matrices_2 = np.stack(proj_matrices_2)
proj_matrices_3 = np.stack(proj_matrices_3)
proj={}
proj['level_3']=proj_matrices_3
proj['level_2']=proj_matrices_2
proj['level_1']=proj_matrices_1
proj['level_0']=proj_matrices_0
# data is numpy array
return {"imgs": imgs, # [N, 3, H, W]
"proj_matrices": proj, # [N,4,4]
"depth": depth, # [1, H, W]
"depth_min": depth_min, # scalar
"depth_max": depth_max, # scalar
"mask": mask} # [1, H, W]

IGEV-MVS/datasets/dtu_yao_eval.py (new file)
@@ -0,0 +1,158 @@
from torch.utils.data import Dataset
import numpy as np
import os
from PIL import Image
from datasets.data_io import *
import cv2
class MVSDataset(Dataset):
def __init__(self, datapath, listfile, nviews=5, img_wh=(1600, 1152)):
super(MVSDataset, self).__init__()
self.levels = 4
self.datapath = datapath
self.listfile = listfile
self.nviews = nviews
self.img_wh = img_wh
self.metas = self.build_list()
def build_list(self):
metas = []
with open(self.listfile) as f:
scans = f.readlines()
scans = [line.rstrip() for line in scans]
for scan in scans:
pair_file = "{}/pair.txt".format(scan)
# read the pair file
with open(os.path.join(self.datapath, pair_file)) as f:
num_viewpoint = int(f.readline())
# viewpoints (49)
for view_idx in range(num_viewpoint):
ref_view = int(f.readline().rstrip())
src_views = [int(x) for x in f.readline().rstrip().split()[1::2]]
metas.append((scan, ref_view, src_views))
print("dataset", "metas:", len(metas))
return metas
def __len__(self):
return len(self.metas)
def read_cam_file(self, filename):
with open(filename) as f:
lines = f.readlines()
lines = [line.rstrip() for line in lines]
# extrinsics: line [1,5), 4x4 matrix
extrinsics = np.fromstring(' '.join(lines[1:5]), dtype=np.float32, sep=' ').reshape((4, 4))
# intrinsics: line [7-10), 3x3 matrix
intrinsics = np.fromstring(' '.join(lines[7:10]), dtype=np.float32, sep=' ').reshape((3, 3))
depth_min = float(lines[11].split()[0])
depth_max = float(lines[11].split()[-1])
return intrinsics, extrinsics, depth_min, depth_max
def read_mask(self, filename):
img = Image.open(filename)
np_img = np.array(img, dtype=np.float32)
np_img = (np_img > 10).astype(np.float32)
return np_img
def read_img(self, filename):
img = Image.open(filename)
# scale 0~255 to -1~1
np_img = 2*np.array(img, dtype=np.float32) / 255. - 1
np_img = cv2.resize(np_img, self.img_wh, interpolation=cv2.INTER_LINEAR)
h, w, _ = np_img.shape
np_img_ms = {
"level_3": cv2.resize(np_img, (w//8, h//8), interpolation=cv2.INTER_LINEAR),
"level_2": cv2.resize(np_img, (w//4, h//4), interpolation=cv2.INTER_LINEAR),
"level_1": cv2.resize(np_img, (w//2, h//2), interpolation=cv2.INTER_LINEAR),
"level_0": np_img
}
return np_img_ms
def __getitem__(self, idx):
scan, ref_view, src_views = self.metas[idx]
# use only the reference view and first nviews-1 source views
view_ids = [ref_view] + src_views[:self.nviews - 1]
img_w = 1600
img_h = 1200
imgs_0 = []
imgs_1 = []
imgs_2 = []
imgs_3 = []
depth_min = None
depth_max = None
proj_matrices_0 = []
proj_matrices_1 = []
proj_matrices_2 = []
proj_matrices_3 = []
for i, vid in enumerate(view_ids):
img_filename = os.path.join(self.datapath, '{}/images/{:0>8}.jpg'.format(scan, vid))
proj_mat_filename = os.path.join(self.datapath, '{}/cams_1/{:0>8}_cam.txt'.format(scan, vid))
imgs = self.read_img(img_filename)
imgs_0.append(imgs['level_0'])
imgs_1.append(imgs['level_1'])
imgs_2.append(imgs['level_2'])
imgs_3.append(imgs['level_3'])
intrinsics, extrinsics, depth_min_, depth_max_ = self.read_cam_file(proj_mat_filename)
intrinsics[0] *= self.img_wh[0]/img_w
intrinsics[1] *= self.img_wh[1]/img_h
proj_mat = extrinsics.copy()
intrinsics[:2,:] *= 0.125
proj_mat[:3, :4] = np.matmul(intrinsics, proj_mat[:3, :4])
proj_matrices_3.append(proj_mat)
proj_mat = extrinsics.copy()
intrinsics[:2,:] *= 2
proj_mat[:3, :4] = np.matmul(intrinsics, proj_mat[:3, :4])
proj_matrices_2.append(proj_mat)
proj_mat = extrinsics.copy()
intrinsics[:2,:] *= 2
proj_mat[:3, :4] = np.matmul(intrinsics, proj_mat[:3, :4])
proj_matrices_1.append(proj_mat)
proj_mat = extrinsics.copy()
intrinsics[:2,:] *= 2
proj_mat[:3, :4] = np.matmul(intrinsics, proj_mat[:3, :4])
proj_matrices_0.append(proj_mat)
if i == 0: # reference view
depth_min = depth_min_
depth_max = depth_max_
imgs_0 = np.stack(imgs_0).transpose([0, 3, 1, 2])
imgs_1 = np.stack(imgs_1).transpose([0, 3, 1, 2])
imgs_2 = np.stack(imgs_2).transpose([0, 3, 1, 2])
imgs_3 = np.stack(imgs_3).transpose([0, 3, 1, 2])
imgs = {}
imgs['level_0'] = imgs_0
imgs['level_1'] = imgs_1
imgs['level_2'] = imgs_2
imgs['level_3'] = imgs_3
# proj_matrices: N*4*4
proj_matrices_0 = np.stack(proj_matrices_0)
proj_matrices_1 = np.stack(proj_matrices_1)
proj_matrices_2 = np.stack(proj_matrices_2)
proj_matrices_3 = np.stack(proj_matrices_3)
proj={}
proj['level_3']=proj_matrices_3
proj['level_2']=proj_matrices_2
proj['level_1']=proj_matrices_1
proj['level_0']=proj_matrices_0
return {"imgs": imgs, # N*3*H0*W0
"proj_matrices": proj, # N*4*4
"depth_min": depth_min, # scalar
"depth_max": depth_max, # scalar
"filename": scan + '/{}/' + '{:0>8}'.format(view_ids[0]) + "{}"}

IGEV-MVS/datasets/eth3d.py (new file)
@@ -0,0 +1,158 @@
from torch.utils.data import Dataset
from datasets.data_io import *
import os
import numpy as np
import cv2
from PIL import Image
class MVSDataset(Dataset):
def __init__(self, datapath, split='test', n_views=7, img_wh=(1920,1280)):
self.levels = 4
self.datapath = datapath
self.img_wh = img_wh
self.split = split
self.build_metas()
self.n_views = n_views
def build_metas(self):
self.metas = []
if self.split == "test":
self.scans = ['botanical_garden', 'boulders', 'bridge', 'door',
'exhibition_hall', 'lecture_room', 'living_room', 'lounge',
'observatory', 'old_computer', 'statue', 'terrace_2']
elif self.split == "train":
self.scans = ['courtyard', 'delivery_area', 'electro', 'facade',
'kicker', 'meadow', 'office', 'pipes', 'playground',
'relief', 'relief_2', 'terrace', 'terrains']
for scan in self.scans:
with open(os.path.join(self.datapath, scan, 'pair.txt')) as f:
num_viewpoint = int(f.readline())
for view_idx in range(num_viewpoint):
ref_view = int(f.readline().rstrip())
src_views = [int(x) for x in f.readline().rstrip().split()[1::2]]
if len(src_views) != 0:
self.metas += [(scan, -1, ref_view, src_views)]
def read_cam_file(self, filename):
with open(filename) as f:
lines = [line.rstrip() for line in f.readlines()]
# extrinsics: line [1,5), 4x4 matrix
extrinsics = np.fromstring(' '.join(lines[1:5]), dtype=np.float32, sep=' ')
extrinsics = extrinsics.reshape((4, 4))
# intrinsics: line [7,10), 3x3 matrix
intrinsics = np.fromstring(' '.join(lines[7:10]), dtype=np.float32, sep=' ')
intrinsics = intrinsics.reshape((3, 3))
depth_min = float(lines[11].split()[0])
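# guard against an invalid (negative) depth_min in the cam file; fall back to 1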
if depth_min < 0:
depth_min = 1
depth_max = float(lines[11].split()[-1])
return intrinsics, extrinsics, depth_min, depth_max
def read_img(self, filename, h, w):
img = Image.open(filename)
# scale 0~255 to -1~1
np_img = 2*np.array(img, dtype=np.float32) / 255. - 1
original_h, original_w, _ = np_img.shape
np_img = cv2.resize(np_img, self.img_wh, interpolation=cv2.INTER_LINEAR)
np_img_ms = {
"level_3": cv2.resize(np_img, (w//8, h//8), interpolation=cv2.INTER_LINEAR),
"level_2": cv2.resize(np_img, (w//4, h//4), interpolation=cv2.INTER_LINEAR),
"level_1": cv2.resize(np_img, (w//2, h//2), interpolation=cv2.INTER_LINEAR),
"level_0": np_img
}
return np_img_ms, original_h, original_w
def __len__(self):
return len(self.metas)
def __getitem__(self, idx):
scan, _, ref_view, src_views = self.metas[idx]
# use only the reference view and first nviews-1 source views
view_ids = [ref_view] + src_views[:self.n_views-1]
imgs_0 = []
imgs_1 = []
imgs_2 = []
imgs_3 = []
# depth = None
depth_min = None
depth_max = None
proj_matrices_0 = []
proj_matrices_1 = []
proj_matrices_2 = []
proj_matrices_3 = []
for i, vid in enumerate(view_ids):
img_filename = os.path.join(self.datapath, scan, f'images/{vid:08d}.jpg')
proj_mat_filename = os.path.join(self.datapath, scan, f'cams_1/{vid:08d}_cam.txt')
imgs, original_h, original_w = self.read_img(img_filename,self.img_wh[1], self.img_wh[0])
imgs_0.append(imgs['level_0'])
imgs_1.append(imgs['level_1'])
imgs_2.append(imgs['level_2'])
imgs_3.append(imgs['level_3'])
intrinsics, extrinsics, depth_min_, depth_max_ = self.read_cam_file(proj_mat_filename)
intrinsics[0] *= self.img_wh[0]/original_w
intrinsics[1] *= self.img_wh[1]/original_h
proj_mat = extrinsics.copy()
intrinsics[:2,:] *= 0.125
proj_mat[:3, :4] = np.matmul(intrinsics, proj_mat[:3, :4])
proj_matrices_3.append(proj_mat)
proj_mat = extrinsics.copy()
intrinsics[:2,:] *= 2
proj_mat[:3, :4] = np.matmul(intrinsics, proj_mat[:3, :4])
proj_matrices_2.append(proj_mat)
proj_mat = extrinsics.copy()
intrinsics[:2,:] *= 2
proj_mat[:3, :4] = np.matmul(intrinsics, proj_mat[:3, :4])
proj_matrices_1.append(proj_mat)
proj_mat = extrinsics.copy()
intrinsics[:2,:] *= 2
proj_mat[:3, :4] = np.matmul(intrinsics, proj_mat[:3, :4])
proj_matrices_0.append(proj_mat)
if i == 0: # reference view
depth_min = depth_min_
depth_max = depth_max_
# imgs: N*3*H0*W0, N is number of images
imgs_0 = np.stack(imgs_0).transpose([0, 3, 1, 2])
imgs_1 = np.stack(imgs_1).transpose([0, 3, 1, 2])
imgs_2 = np.stack(imgs_2).transpose([0, 3, 1, 2])
imgs_3 = np.stack(imgs_3).transpose([0, 3, 1, 2])
imgs = {}
imgs['level_0'] = imgs_0
imgs['level_1'] = imgs_1
imgs['level_2'] = imgs_2
imgs['level_3'] = imgs_3
# proj_matrices: N*4*4
proj_matrices_0 = np.stack(proj_matrices_0)
proj_matrices_1 = np.stack(proj_matrices_1)
proj_matrices_2 = np.stack(proj_matrices_2)
proj_matrices_3 = np.stack(proj_matrices_3)
proj={}
proj['level_3']=proj_matrices_3
proj['level_2']=proj_matrices_2
proj['level_1']=proj_matrices_1
proj['level_0']=proj_matrices_0
return {"imgs": imgs, # N*3*H0*W0
"proj_matrices": proj, # N*4*4
"depth_min": depth_min, # scalar
"depth_max": depth_max,
"filename": scan + '/{}/' + '{:0>8}'.format(view_ids[0]) + "{}"
}
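build_metas relies on the usual MVSNet-style pair.txt layout: the first line holds the number of viewpoints, then each reference view contributes two lines, its id and a line with the source count followed by interleaved (view id, score) pairs. The [1::2] slice therefore keeps the ids and drops the scores. A sketch with made-up numbers:

# pair.txt excerpt (hypothetical values)
#   10                       <- number of viewpoints
#   0                        <- reference view id
#   2 10 2346.41 1 2036.53   <- count, then (src id, score) pairs
line = "2 10 2346.41 1 2036.53"
src_views = [int(x) for x in line.split()[1::2]]
print(src_views)  # [10, 1]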

156
IGEV-MVS/datasets/tanks.py Normal file
View File

@ -0,0 +1,156 @@
from torch.utils.data import Dataset
from datasets.data_io import *
import os
import numpy as np
import cv2
from PIL import Image
class MVSDataset(Dataset):
def __init__(self, datapath, n_views=7, img_wh=(1920, 1024), split='intermediate'):
self.levels = 4
self.datapath = datapath
self.img_wh = img_wh
self.split = split
self.build_metas()
self.n_views = n_views
def build_metas(self):
self.metas = []
if self.split == 'intermediate':
self.scans = ['Family', 'Francis', 'Horse', 'Lighthouse',
'M60', 'Panther', 'Playground', 'Train']
elif self.split == 'advanced':
self.scans = ['Auditorium', 'Ballroom', 'Courtroom',
'Museum', 'Palace', 'Temple']
for scan in self.scans:
with open(os.path.join(self.datapath, self.split, scan, 'pair.txt')) as f:
num_viewpoint = int(f.readline())
for view_idx in range(num_viewpoint):
ref_view = int(f.readline().rstrip())
src_views = [int(x) for x in f.readline().rstrip().split()[1::2]]
if len(src_views) != 0:
self.metas += [(scan, -1, ref_view, src_views)]
def read_cam_file(self, filename):
with open(filename) as f:
lines = [line.rstrip() for line in f.readlines()]
# extrinsics: line [1,5), 4x4 matrix
extrinsics = np.fromstring(' '.join(lines[1:5]), dtype=np.float32, sep=' ')
extrinsics = extrinsics.reshape((4, 4))
# intrinsics: line [7,10), 3x3 matrix
intrinsics = np.fromstring(' '.join(lines[7:10]), dtype=np.float32, sep=' ')
intrinsics = intrinsics.reshape((3, 3))
depth_min = float(lines[11].split()[0])
depth_max = float(lines[11].split()[-1])
return intrinsics, extrinsics, depth_min, depth_max
def read_img(self, filename, h, w):
img = Image.open(filename)
# scale 0~255 to -1~1
np_img = 2*np.array(img, dtype=np.float32) / 255. - 1
original_h, original_w, _ = np_img.shape
np_img = cv2.resize(np_img, self.img_wh, interpolation=cv2.INTER_LINEAR)
np_img_ms = {
"level_3": cv2.resize(np_img, (w//8, h//8), interpolation=cv2.INTER_LINEAR),
"level_2": cv2.resize(np_img, (w//4, h//4), interpolation=cv2.INTER_LINEAR),
"level_1": cv2.resize(np_img, (w//2, h//2), interpolation=cv2.INTER_LINEAR),
"level_0": np_img
}
return np_img_ms, original_h, original_w
def __len__(self):
return len(self.metas)
def __getitem__(self, idx):
scan, _, ref_view, src_views = self.metas[idx]
# use only the reference view and first nviews-1 source views
view_ids = [ref_view] + src_views[:self.n_views-1]
imgs_0 = []
imgs_1 = []
imgs_2 = []
imgs_3 = []
# depth = None
depth_min = None
depth_max = None
proj_matrices_0 = []
proj_matrices_1 = []
proj_matrices_2 = []
proj_matrices_3 = []
for i, vid in enumerate(view_ids):
img_filename = os.path.join(self.datapath, self.split, scan, f'images/{vid:08d}.jpg')
proj_mat_filename = os.path.join(self.datapath, self.split, scan, f'cams_1/{vid:08d}_cam.txt')
imgs, original_h, original_w = self.read_img(img_filename,self.img_wh[1], self.img_wh[0])
imgs_0.append(imgs['level_0'])
imgs_1.append(imgs['level_1'])
imgs_2.append(imgs['level_2'])
imgs_3.append(imgs['level_3'])
intrinsics, extrinsics, depth_min_, depth_max_ = self.read_cam_file(proj_mat_filename)
intrinsics[0] *= self.img_wh[0]/original_w
intrinsics[1] *= self.img_wh[1]/original_h
proj_mat = extrinsics.copy()
intrinsics[:2,:] *= 0.125
proj_mat[:3, :4] = np.matmul(intrinsics, proj_mat[:3, :4])
proj_matrices_3.append(proj_mat)
proj_mat = extrinsics.copy()
intrinsics[:2,:] *= 2
proj_mat[:3, :4] = np.matmul(intrinsics, proj_mat[:3, :4])
proj_matrices_2.append(proj_mat)
proj_mat = extrinsics.copy()
intrinsics[:2,:] *= 2
proj_mat[:3, :4] = np.matmul(intrinsics, proj_mat[:3, :4])
proj_matrices_1.append(proj_mat)
proj_mat = extrinsics.copy()
intrinsics[:2,:] *= 2
proj_mat[:3, :4] = np.matmul(intrinsics, proj_mat[:3, :4])
proj_matrices_0.append(proj_mat)
if i == 0: # reference view
depth_min = depth_min_
depth_max = depth_max_
# imgs: N*3*H0*W0, N is number of images
imgs_0 = np.stack(imgs_0).transpose([0, 3, 1, 2])
imgs_1 = np.stack(imgs_1).transpose([0, 3, 1, 2])
imgs_2 = np.stack(imgs_2).transpose([0, 3, 1, 2])
imgs_3 = np.stack(imgs_3).transpose([0, 3, 1, 2])
imgs = {}
imgs['level_0'] = imgs_0
imgs['level_1'] = imgs_1
imgs['level_2'] = imgs_2
imgs['level_3'] = imgs_3
# proj_matrices: N*4*4
proj_matrices_0 = np.stack(proj_matrices_0)
proj_matrices_1 = np.stack(proj_matrices_1)
proj_matrices_2 = np.stack(proj_matrices_2)
proj_matrices_3 = np.stack(proj_matrices_3)
proj={}
proj['level_3']=proj_matrices_3
proj['level_2']=proj_matrices_2
proj['level_1']=proj_matrices_1
proj['level_0']=proj_matrices_0
return {"imgs": imgs, # N*3*H0*W0
"proj_matrices": proj, # N*4*4
"depth_min": depth_min, # scalar
"depth_max": depth_max,
"filename": scan + '/{}/' + '{:0>8}'.format(view_ids[0]) + "{}"
}
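All three evaluation datasets (DTU, ETH3D, Tanks and Temples) return the same dict layout, so a loader can be set up uniformly. A minimal sketch, assuming the Tanks and Temples data lives under a placeholder path:

from torch.utils.data import DataLoader
from datasets.tanks import MVSDataset

dataset = MVSDataset('/path/to/tanksandtemples/', n_views=7,
                     img_wh=(1920, 1024), split='intermediate')
loader = DataLoader(dataset, batch_size=1, shuffle=False, num_workers=4)
sample = next(iter(loader))
print(sample['imgs']['level_0'].shape)           # (1, N, 3, 1024, 1920)
print(sample['proj_matrices']['level_0'].shape)  # (1, N, 4, 4)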

450
IGEV-MVS/evaluate_mvs.py Normal file
View File

@ -0,0 +1,450 @@
import argparse
import os
os.environ['CUDA_VISIBLE_DEVICES'] = '0'
import torch
import torch.nn as nn
import torch.nn.parallel
import torch.backends.cudnn as cudnn
import torch.optim as optim
from torch.utils.data import DataLoader
import torch.nn.functional as F
import numpy as np
import time
from datasets import find_dataset_def
from core.igev_mvs import IGEVMVS
from utils import *
import sys
import cv2
from datasets.data_io import read_pfm, save_pfm
from core.submodule import depth_unnormalization
from plyfile import PlyData, PlyElement
from tqdm import tqdm
from PIL import Image
cudnn.benchmark = True
parser = argparse.ArgumentParser(description='Predict depth, filter, and fuse')
parser.add_argument('--model', default='IterMVS', help='select model')
parser.add_argument('--dataset', default='dtu_yao_eval', help='select dataset')
parser.add_argument('--testpath', default='/data/dtu_data/dtu_test/', help='testing data path')
parser.add_argument('--testlist', default='./lists/dtu/test.txt', help='testing scan list')
parser.add_argument('--maxdisp', default=256)
parser.add_argument('--split', default='intermediate', help='select data')
parser.add_argument('--batch_size', type=int, default=2, help='testing batch size')
parser.add_argument('--n_views', type=int, default=5, help='num of view')
parser.add_argument('--img_wh', nargs='+', type=int, default=[640, 480],
help='width and height of the image')
parser.add_argument('--loadckpt', default='./pretrained_models/dtu.ckpt', help='load a specific checkpoint')
parser.add_argument('--outdir', default='./output/', help='output dir')
parser.add_argument('--display', action='store_true', help='display depth images and masks')
parser.add_argument('--iteration', type=int, default=32, help='num of iteration of GRU')
parser.add_argument('--geo_pixel_thres', type=float, default=1, help='pixel threshold for geometric consistency filtering')
parser.add_argument('--geo_depth_thres', type=float, default=0.01, help='depth threshold for geometric consistency filtering')
parser.add_argument('--photo_thres', type=float, default=0.3, help='threshold for photometric consistency filtering')
# parse arguments and check
args = parser.parse_args()
print("argv:", sys.argv[1:])
print_args(args)
if args.dataset=="dtu_yao_eval":
img_wh=(1600, 1152)
elif args.dataset=="tanks":
img_wh=(1920, 1024)
elif args.dataset=="eth3d":
img_wh = (1920,1280)
else:
img_wh = (args.img_wh[0], args.img_wh[1]) # custom dataset
# read intrinsics and extrinsics
def read_camera_parameters(filename):
with open(filename) as f:
lines = f.readlines()
lines = [line.rstrip() for line in lines]
# extrinsics: line [1,5), 4x4 matrix
extrinsics = np.fromstring(' '.join(lines[1:5]), dtype=np.float32, sep=' ').reshape((4, 4))
# intrinsics: line [7,10), 3x3 matrix
intrinsics = np.fromstring(' '.join(lines[7:10]), dtype=np.float32, sep=' ').reshape((3, 3))
return intrinsics, extrinsics
# read an image
def read_img(filename, img_wh):
img = Image.open(filename)
# scale 0~255 to 0~1
np_img = np.array(img, dtype=np.float32) / 255.
original_h, original_w, _ = np_img.shape
np_img = cv2.resize(np_img, img_wh, interpolation=cv2.INTER_LINEAR)
return np_img, original_h, original_w
# save a binary mask
def save_mask(filename, mask):
assert mask.dtype == np.bool_
mask = mask.astype(np.uint8) * 255
Image.fromarray(mask).save(filename)
def save_depth_img(filename, depth):
# assert mask.dtype == np.bool
depth = depth.astype(np.float32) * 255
Image.fromarray(depth).save(filename)
def read_pair_file(filename):
data = []
with open(filename) as f:
num_viewpoint = int(f.readline())
# 49 viewpoints
for view_idx in range(num_viewpoint):
ref_view = int(f.readline().rstrip())
src_views = [int(x) for x in f.readline().rstrip().split()[1::2]]
if len(src_views) != 0:
data.append((ref_view, src_views))
return data
# run MVS model to save depth maps
def save_depth():
# dataset, dataloader
MVSDataset = find_dataset_def(args.dataset)
if args.dataset=="dtu_yao_eval":
test_dataset = MVSDataset(args.testpath, args.testlist, args.n_views, img_wh)
elif args.dataset=="tanks":
test_dataset = MVSDataset(args.testpath, args.n_views, img_wh, args.split)
elif args.dataset=="eth3d":
test_dataset = MVSDataset(args.testpath, args.split, args.n_views, img_wh)
else:
test_dataset = MVSDataset(args.testpath, args.n_views, img_wh)
TestImgLoader = DataLoader(test_dataset, args.batch_size, shuffle=False, num_workers=4, drop_last=False)
# model
model = IGEVMVS(args)
model = nn.DataParallel(model)
model.cuda()
# load checkpoint file specified by args.loadckpt
print("loading model {}".format(args.loadckpt))
state_dict = torch.load(args.loadckpt)
model.load_state_dict(state_dict['model'])
model.eval()
with torch.no_grad():
tbar = tqdm(TestImgLoader)
for batch_idx, sample in enumerate(tbar):
start_time = time.time()
sample_cuda = tocuda(sample)
disp_prediction = model(sample_cuda["imgs"], sample_cuda["proj_matrices"],
sample_cuda["depth_min"], sample_cuda["depth_max"], test_mode=True)
b = sample_cuda["depth_min"].shape[0]
inverse_depth_min = (1.0 / sample_cuda["depth_min"]).view(b, 1, 1, 1)
inverse_depth_max = (1.0 / sample_cuda["depth_max"]).view(b, 1, 1, 1)
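# depth_unnormalization (core/submodule.py) is expected to invert the linear
# inverse-depth normalization used in training, roughly
#   depth = 1 / (disp * (1/depth_min - 1/depth_max) + 1/depth_max)
# so disp = 0 maps to depth_max and disp = 1 to depth_min (sketch of the
# assumed mapping, not a verbatim copy of the implementation)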
depth_prediction = depth_unnormalization(disp_prediction, inverse_depth_min, inverse_depth_max)
depth_prediction = tensor2numpy(depth_prediction.float())
del sample_cuda, disp_prediction
tbar.set_description('Iter {}/{}, time = {:.3f}'.format(batch_idx, len(TestImgLoader), time.time() - start_time))
filenames = sample["filename"]
# save depth maps and confidence maps
for filename, depth_est in zip(filenames, depth_prediction):
depth_filename = os.path.join(args.outdir, filename.format('depth_est', '.pfm'))
os.makedirs(depth_filename.rsplit('/', 1)[0], exist_ok=True)
# save depth maps
depth_est = np.squeeze(depth_est, 0)
save_pfm(depth_filename, depth_est)
# project the reference point cloud into the source view, then project back
def reproject_with_depth(depth_ref, intrinsics_ref, extrinsics_ref, depth_src, intrinsics_src, extrinsics_src):
width, height = depth_ref.shape[1], depth_ref.shape[0]
## step1. project reference pixels to the source view
# reference view x, y
x_ref, y_ref = np.meshgrid(np.arange(0, width), np.arange(0, height))
x_ref, y_ref = x_ref.reshape([-1]), y_ref.reshape([-1])
# reference 3D space
xyz_ref = np.matmul(np.linalg.inv(intrinsics_ref),
np.vstack((x_ref, y_ref, np.ones_like(x_ref))) * depth_ref.reshape([-1]))
# source 3D space
xyz_src = np.matmul(np.matmul(extrinsics_src, np.linalg.inv(extrinsics_ref)),
np.vstack((xyz_ref, np.ones_like(x_ref))))[:3]
# source view x, y
K_xyz_src = np.matmul(intrinsics_src, xyz_src)
xy_src = K_xyz_src[:2] / K_xyz_src[2:3]
## step2. reproject the source view points with source view depth estimation
# find the depth estimation of the source view
x_src = xy_src[0].reshape([height, width]).astype(np.float32)
y_src = xy_src[1].reshape([height, width]).astype(np.float32)
sampled_depth_src = cv2.remap(depth_src, x_src, y_src, interpolation=cv2.INTER_LINEAR)
# mask = sampled_depth_src > 0
# source 3D space
# NOTE that we should use the sampled source-view depth here to project back
xyz_src = np.matmul(np.linalg.inv(intrinsics_src),
np.vstack((xy_src, np.ones_like(x_ref))) * sampled_depth_src.reshape([-1]))
# reference 3D space
xyz_reprojected = np.matmul(np.matmul(extrinsics_ref, np.linalg.inv(extrinsics_src)),
np.vstack((xyz_src, np.ones_like(x_ref))))[:3]
# source view x, y, depth
depth_reprojected = xyz_reprojected[2].reshape([height, width]).astype(np.float32)
K_xyz_reprojected = np.matmul(intrinsics_ref, xyz_reprojected)
xy_reprojected = K_xyz_reprojected[:2] / (K_xyz_reprojected[2:3]+1e-6)
x_reprojected = xy_reprojected[0].reshape([height, width]).astype(np.float32)
y_reprojected = xy_reprojected[1].reshape([height, width]).astype(np.float32)
return depth_reprojected, x_reprojected, y_reprojected, x_src, y_src
def check_geometric_consistency(depth_ref, intrinsics_ref, extrinsics_ref, depth_src, intrinsics_src, extrinsics_src, thre1, thre2):
width, height = depth_ref.shape[1], depth_ref.shape[0]
x_ref, y_ref = np.meshgrid(np.arange(0, width), np.arange(0, height))
depth_reprojected, x2d_reprojected, y2d_reprojected, x2d_src, y2d_src = reproject_with_depth(depth_ref,
intrinsics_ref,
extrinsics_ref,
depth_src,
intrinsics_src,
extrinsics_src)
# check |p_reproj-p_1| < 1
dist = np.sqrt((x2d_reprojected - x_ref) ** 2 + (y2d_reprojected - y_ref) ** 2)
# check |d_reproj-d_1| / d_1 < 0.01
depth_diff = np.abs(depth_reprojected - depth_ref)
relative_depth_diff = depth_diff / depth_ref
masks=[]
for i in range(2,11):
mask = np.logical_and(dist < i/thre1, relative_depth_diff < i/thre2)
masks.append(mask)
depth_reprojected[~mask] = 0
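# the tests get looser as i grows, so after the loop `mask` holds the
# loosest test (i = 10) while depth_reprojected keeps values only where
# the strictest test (i = 2) passed; the per-i masks are re-combined in
# filter_depth with view-count thresholds (test i must hold in >= i views)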
return masks, mask, depth_reprojected, x2d_src, y2d_src
def filter_depth(scan_folder, out_folder, plyfilename, geo_pixel_thres, geo_depth_thres, photo_thres, img_wh, geo_mask_thres=3):
# the pair file
pair_file = os.path.join(scan_folder, "pair.txt")
# for the final point cloud
vertexs = []
vertex_colors = []
pair_data = read_pair_file(pair_file)
nviews = len(pair_data)
thre_left = -2
thre_right = 2
total_iter = 10
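# bisection over the exponent of the consistency threshold: thre is
# searched in [-2, 2] and 10**thre scales both the pixel and the depth
# test, so larger thre means stricter masks; the search tightens while
# the mean geometric mask stays above 0.25 and loosens otherwise, and
# only the last iteration writes masks and collects points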
for iter in range(total_iter):
thre = (thre_left + thre_right) / 2
print(f"{iter} {10 ** thre}")
depth_est_averaged = []
geo_mask_all = []
# for each reference view and the corresponding source views
for ref_view, src_views in pair_data:
# load the camera parameters
ref_intrinsics, ref_extrinsics = read_camera_parameters(
os.path.join(scan_folder, 'cams_1/{:0>8}_cam.txt'.format(ref_view)))
ref_img, original_h, original_w = read_img(os.path.join(scan_folder, 'images/{:0>8}.jpg'.format(ref_view)), img_wh)
ref_intrinsics[0] *= img_wh[0]/original_w
ref_intrinsics[1] *= img_wh[1]/original_h
# load the estimated depth of the reference view
ref_depth_est = read_pfm(os.path.join(out_folder, 'depth_est/{:0>8}.pfm'.format(ref_view)))[0]
ref_depth_est = np.squeeze(ref_depth_est, 2)
all_srcview_depth_ests = []
# compute the geometric mask
geo_mask_sum = 0
geo_mask_sums=[]
n = 1 + len(src_views)
ct = 0
for src_view in src_views:
ct = ct + 1
# camera parameters of the source view
src_intrinsics, src_extrinsics = read_camera_parameters(
os.path.join(scan_folder, 'cams_1/{:0>8}_cam.txt'.format(src_view)))
_, original_h, original_w = read_img(os.path.join(scan_folder, 'images/{:0>8}.jpg'.format(src_view)), img_wh)
src_intrinsics[0] *= img_wh[0]/original_w
src_intrinsics[1] *= img_wh[1]/original_h
# the estimated depth of the source view
src_depth_est = read_pfm(os.path.join(out_folder, 'depth_est/{:0>8}.pfm'.format(src_view)))[0]
masks, geo_mask, depth_reprojected, _, _ = check_geometric_consistency(ref_depth_est, ref_intrinsics, ref_extrinsics,
src_depth_est,
src_intrinsics, src_extrinsics, 10 ** thre * 4, 10 ** thre * 1300)
if (ct==1):
for i in range(2,n):
geo_mask_sums.append(masks[i-2].astype(np.int32))
else:
for i in range(2,n):
geo_mask_sums[i-2]+=masks[i-2].astype(np.int32)
geo_mask_sum+=geo_mask.astype(np.int32)
all_srcview_depth_ests.append(depth_reprojected)
geo_mask=geo_mask_sum>=n
for i in range (2,n):
geo_mask=np.logical_or(geo_mask,geo_mask_sums[i-2]>=i)
depth_est_averaged.append((sum(all_srcview_depth_ests) + ref_depth_est) / (geo_mask_sum + 1))
geo_mask_all.append(np.mean(geo_mask))
final_mask = geo_mask
if iter == total_iter - 1:
os.makedirs(os.path.join(out_folder, "mask"), exist_ok=True)
save_mask(os.path.join(out_folder, "mask/{:0>8}_geo.png".format(ref_view)), geo_mask)
save_mask(os.path.join(out_folder, "mask/{:0>8}_final.png".format(ref_view)), final_mask)
print("processing {}, ref-view{:0>2}, geo_mask:{:3f} final_mask: {:3f}".format(scan_folder, ref_view,
geo_mask.mean(), final_mask.mean()))
if args.display:
cv2.imshow('ref_img', ref_img[:, :, ::-1])
cv2.imshow('ref_depth', ref_depth_est / np.max(ref_depth_est))
cv2.imshow('ref_depth * geo_mask', ref_depth_est * geo_mask.astype(np.float32) / np.max(ref_depth_est))
cv2.imshow('ref_depth * mask', ref_depth_est * final_mask.astype(np.float32) / np.max(ref_depth_est))
cv2.waitKey(0)
height, width = depth_est_averaged[-1].shape[:2]
x, y = np.meshgrid(np.arange(0, width), np.arange(0, height))
valid_points = final_mask
# print("valid_points", valid_points.mean())
x, y, depth = x[valid_points], y[valid_points], depth_est_averaged[-1][valid_points]
color = ref_img[valid_points]
xyz_ref = np.matmul(np.linalg.inv(ref_intrinsics),
np.vstack((x, y, np.ones_like(x))) * depth)
xyz_world = np.matmul(np.linalg.inv(ref_extrinsics),
np.vstack((xyz_ref, np.ones_like(x))))[:3]
vertexs.append(xyz_world.transpose((1, 0)))
vertex_colors.append((color * 255).astype(np.uint8))
if np.mean(geo_mask_all) >= 0.25:
thre_left = thre
else:
thre_right = thre
vertexs = np.concatenate(vertexs, axis=0)
vertex_colors = np.concatenate(vertex_colors, axis=0)
vertexs = np.array([tuple(v) for v in vertexs], dtype=[('x', 'f4'), ('y', 'f4'), ('z', 'f4')])
vertex_colors = np.array([tuple(v) for v in vertex_colors], dtype=[('red', 'u1'), ('green', 'u1'), ('blue', 'u1')])
vertex_all = np.empty(len(vertexs), vertexs.dtype.descr + vertex_colors.dtype.descr)
for prop in vertexs.dtype.names:
vertex_all[prop] = vertexs[prop]
for prop in vertex_colors.dtype.names:
vertex_all[prop] = vertex_colors[prop]
el = PlyElement.describe(vertex_all, 'vertex')
PlyData([el]).write(plyfilename)
print("saving the final model to", plyfilename)
if __name__ == '__main__':
save_depth()
if args.dataset=="dtu_yao_eval":
with open(args.testlist) as f:
scans = f.readlines()
scans = [line.rstrip() for line in scans]
for scan in scans:
scan_id = int(scan[4:])
scan_folder = os.path.join(args.testpath, scan)
out_folder = os.path.join(args.outdir, scan)
filter_depth(scan_folder, out_folder, os.path.join(args.outdir, 'igev_mvs{:0>3}_l3.ply'.format(scan_id)),
args.geo_pixel_thres, args.geo_depth_thres, args.photo_thres, img_wh, 4)
elif args.dataset=="tanks":
# intermediate dataset
if args.split == "intermediate":
scans = ['Family', 'Francis', 'Horse', 'Lighthouse',
'M60', 'Panther', 'Playground', 'Train']
geo_mask_thres = {'Family': 5,
'Francis': 6,
'Horse': 5,
'Lighthouse': 6,
'M60': 5,
'Panther': 5,
'Playground': 5,
'Train': 5}
for scan in scans:
scan_folder = os.path.join(args.testpath, args.split, scan)
out_folder = os.path.join(args.outdir, scan)
filter_depth(scan_folder, out_folder, os.path.join(args.outdir, scan + '.ply'),
args.geo_pixel_thres, args.geo_depth_thres, args.photo_thres, img_wh, geo_mask_thres[scan])
# advanced dataset
elif args.split == "advanced":
scans = ['Auditorium', 'Ballroom', 'Courtroom',
'Museum', 'Palace', 'Temple']
geo_mask_thres = {'Auditorium': 3,
'Ballroom': 4,
'Courtroom': 4,
'Museum': 4,
'Palace': 5,
'Temple': 4}
for scan in scans:
scan_folder = os.path.join(args.testpath, args.split, scan)
out_folder = os.path.join(args.outdir, scan)
filter_depth(scan_folder, out_folder, os.path.join(args.outdir, scan + '.ply'),
args.geo_pixel_thres, args.geo_depth_thres, args.photo_thres, img_wh, geo_mask_thres[scan])
elif args.dataset=="eth3d":
if args.split == "test":
scans = ['botanical_garden', 'boulders', 'bridge', 'door',
'exhibition_hall', 'lecture_room', 'living_room', 'lounge',
'observatory', 'old_computer', 'statue', 'terrace_2']
geo_mask_thres = {'botanical_garden':1, # 30 images, outdoor
'boulders':1, # 26 images, outdoor
'bridge':2, # 110 images, outdoor
'door':2, # 6 images, indoor
'exhibition_hall':2, # 68 images, indoor
'lecture_room':2, # 23 images, indoor
'living_room':2, # 65 images, indoor
'lounge':1,# 10 images, indoor
'observatory':2, # 27 images, outdoor
'old_computer':2, # 54 images, indoor
'statue':2, # 10 images, indoor
'terrace_2':2 # 13 images, outdoor
}
for scan in scans:
start_time = time.time()
scan_folder = os.path.join(args.testpath, scan)
out_folder = os.path.join(args.outdir, scan)
filter_depth(scan_folder, out_folder, os.path.join(args.outdir, scan + '.ply'),
args.geo_pixel_thres, args.geo_depth_thres, args.photo_thres, img_wh, geo_mask_thres[scan])
print('scan: '+scan+' time = {:.3f}'.format(time.time() - start_time))
elif args.split == "train":
scans = ['courtyard', 'delivery_area', 'electro', 'facade',
'kicker', 'meadow', 'office', 'pipes', 'playground',
'relief', 'relief_2', 'terrace', 'terrains']
geo_mask_thres = {'courtyard':1, # 38 images, outdoor
'delivery_area':2, # 44 images, indoor
'electro':1, # 45 images, outdoor
'facade':2, # 76 images, outdoor
'kicker':1, # 31 images, indoor
'meadow':1, # 15 images, outdoor
'office':1, # 26 images, indoor
'pipes':1,# 14 images, indoor
'playground':1, # 38 images, outdoor
'relief':1, # 31 images, indoor
'relief_2':1, # 31 images, indoor
'terrace':1, # 23 images, outdoor
'terrains':2 # 42 images, indoor
}
for scan in scans:
start_time = time.time()
scan_folder = os.path.join(args.testpath, scan)
out_folder = os.path.join(args.outdir, scan)
filter_depth(scan_folder, out_folder, os.path.join(args.outdir, scan + '.ply'),
args.geo_pixel_thres, args.geo_depth_thres, args.photo_thres, img_wh, geo_mask_thres[scan])
print('scan: '+scan+' time = {:.3f}'.format(time.time() - start_time))
else:
filter_depth(args.testpath, args.outdir, os.path.join(args.outdir, 'custom.ply'),
args.geo_pixel_thres, args.geo_depth_thres, args.photo_thres, img_wh, geo_mask_thres=3)
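For reference, a DTU run under the defaults above reduces to something like the following (paths are placeholders for your local copies):

python evaluate_mvs.py --dataset dtu_yao_eval --testpath /path/to/dtu_test/ \
    --testlist ./lists/dtu/test.txt --loadckpt ./pretrained_models/dtu.ckpt --outdir ./output/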

44
BaseEval2Obj_web.m Normal file
View File

@ -0,0 +1,44 @@
function BaseEval2Obj_web(BaseEval,method_string,outputPath)
if(nargin<3)
outputPath='./';
end
% threshold for coloring alpha channel in the range of 0-10 mm
dist_tresshold=10;
cSet=BaseEval.cSet;
Qdata=BaseEval.Qdata;
alpha=min(BaseEval.Ddata,dist_tresshold)/dist_tresshold;
fid=fopen([outputPath method_string '2Stl_' num2str(cSet) '.obj'],'w+');
for cP=1:size(Qdata,2)
if(BaseEval.DataInMask(cP))
C=[1 0 0]*alpha(cP)+[1 1 1]*(1-alpha(cP)); %coloring from red to white in the range of 0-10 mm (0 to dist_tresshold)
else
C=[0 1 0]*alpha(cP)+[0 0 1]*(1-alpha(cP)); %green to blue for points outside the mask (which are not included in the analysis)
end
fprintf(fid,'v %f %f %f %f %f %f\n',[Qdata(1,cP) Qdata(2,cP) Qdata(3,cP) C(1) C(2) C(3)]);
end
fclose(fid);
disp('Data2Stl saved as obj')
Qstl=BaseEval.Qstl;
fid=fopen([outputPath 'Stl2' method_string '_' num2str(cSet) '.obj'],'w+');
alpha=min(BaseEval.Dstl,dist_tresshold)/dist_tresshold;
for cP=1:size(Qstl,2)
if(BaseEval.StlAbovePlane(cP))
C=[1 0 0]*alpha(cP)+[1 1 1]*(1-alpha(cP)); %coloring from red to white in the range of 0-10 mm (0 to dist_tresshold)
else
C=[0 1 0]*alpha(cP)+[0 0 1]*(1-alpha(cP)); %green to blue for points below plane (which are not included in the analysis)
end
fprintf(fid,'v %f %f %f %f %f %f\n',[Qstl(1,cP) Qstl(2,cP) Qstl(3,cP) C(1) C(2) C(3)]);
end
fclose(fid);
disp('Stl2Data saved as obj')

104
BaseEvalMain_web.m Normal file
View File

@ -0,0 +1,104 @@
clear all
close all
format compact
clc
% script to calculate the distances for all included scans (UsedSets)
dataPath='D:\xgw\IterMVS_data\MVS Data\';
plyPath='/data/xgw/IGEV_MVS/conf_03/';
resultsPath='/data/xgw/IGEV_MVS/outputs_conf_03/';
method_string='itermvs';
light_string='l3'; % l3 is the setting with all lights on, l7 is randomly sampled between the 7 settings (index 0-6)
representation_string='Points'; %mvs representation 'Points' or 'Surfaces'
switch representation_string
case 'Points'
eval_string='_Eval_'; %results naming
settings_string='';
end
% get sets used in evaluation
UsedSets=[1 4 9 10 11 12 13 15 23 24 29 32 33 34 48 49 62 75 77 110 114 118];
result = zeros(length(UsedSets),4);
dst=0.2; %Min dist between points when reducing
for cIdx=1:length(UsedSets)
%Data set number
cSet = UsedSets(cIdx)
%input data name
DataInName=[plyPath sprintf('%s%03d_%s%s.ply',lower(method_string),cSet,light_string,settings_string)]
%results name
%concatenate strings into one string
EvalName=[resultsPath method_string eval_string num2str(cSet) '.mat']
disp(EvalName)
%check if file is already computed
if(~exist(EvalName,'file'))
disp(DataInName);
time=clock;time(4:5), drawnow
tic
Mesh = plyread(DataInName);
Qdata=[Mesh.vertex.x Mesh.vertex.y Mesh.vertex.z]';
toc
BaseEval=PointCompareMain(cSet,Qdata,dst,dataPath);
disp('Saving results'), drawnow
toc
save(EvalName,'BaseEval');
toc
% write obj-file of evaluation
% BaseEval2Obj_web(BaseEval,method_string, resultsPath)
% toc
time=clock;time(4:5), drawnow
BaseEval.MaxDist=20; %outlier threshold of 20 mm
BaseEval.FilteredDstl=BaseEval.Dstl(BaseEval.StlAbovePlane); %use only points that are above the plane
BaseEval.FilteredDstl=BaseEval.FilteredDstl(BaseEval.FilteredDstl<BaseEval.MaxDist); % discard outliers
BaseEval.FilteredDdata=BaseEval.Ddata(BaseEval.DataInMask); %use only points that within mask
BaseEval.FilteredDdata=BaseEval.FilteredDdata(BaseEval.FilteredDdata<BaseEval.MaxDist); % discard outliers
fprintf("mean/median Data (acc.) %f/%f\n", mean(BaseEval.FilteredDdata), median(BaseEval.FilteredDdata));
fprintf("mean/median Stl (comp.) %f/%f\n", mean(BaseEval.FilteredDstl), median(BaseEval.FilteredDstl));
result(cIdx,1) = mean(BaseEval.FilteredDdata);
result(cIdx,2) = median(BaseEval.FilteredDdata);
result(cIdx,3) = mean(BaseEval.FilteredDstl);
result(cIdx,4) = median(BaseEval.FilteredDstl);
else
load(EvalName);
BaseEval.MaxDist=20; %outlier threshold of 20 mm
BaseEval.FilteredDstl=BaseEval.Dstl(BaseEval.StlAbovePlane); %use only points that are above the plane
BaseEval.FilteredDstl=BaseEval.FilteredDstl(BaseEval.FilteredDstl<BaseEval.MaxDist); % discard outliers
BaseEval.FilteredDdata=BaseEval.Ddata(BaseEval.DataInMask); %use only points that within mask
BaseEval.FilteredDdata=BaseEval.FilteredDdata(BaseEval.FilteredDdata<BaseEval.MaxDist); % discard outliers
fprintf("mean/median Data (acc.) %f/%f\n", mean(BaseEval.FilteredDdata), median(BaseEval.FilteredDdata));
fprintf("mean/median Stl (comp.) %f/%f\n", mean(BaseEval.FilteredDstl), median(BaseEval.FilteredDstl));
result(cIdx,1) = mean(BaseEval.FilteredDdata);
result(cIdx,2) = median(BaseEval.FilteredDdata);
result(cIdx,3) = mean(BaseEval.FilteredDstl);
result(cIdx,4) = median(BaseEval.FilteredDstl);
end
end
mean_result=mean(result);
fprintf("final evaluation result on all scans: acc.: %f, comp.: %f, overall: %f\n", mean_result(1), mean_result(3), (mean_result(1)+mean_result(3))/2);

87
ComputeStat_web.m Normal file
View File

@ -0,0 +1,87 @@
clear all
close all
format compact
clc
% script to calculate the statistics for each scan; this will currently only run if distances have been measured
% for all included scans (UsedSets)
% modify the path to evaluate your models
dataPath='/home/SampleSet/MVS Data/';
resultsPath='/home/PatchmatchNet/outputs/';
MaxDist=20; %outlier threshold of 20 mm
time=clock;
method_string='patchmatchnet';
light_string='l3'; %'l7'; l3 is the setting with all lights on, l7 is randomly sampled between the 7 settings (index 0-6)
representation_string='Points'; %mvs representation 'Points' or 'Surfaces'
switch representation_string
case 'Points'
eval_string='_Eval_'; %results naming
settings_string='';
end
% get sets used in evaluation
UsedSets=[1 4 9 10 11 12 13 15 23 24 29 32 33 34 48 49 62 75 77 110 114 118];
nStat=length(UsedSets);
% struct
BaseStat.nStl=zeros(1,nStat);
BaseStat.nData=zeros(1,nStat);
BaseStat.MeanStl=zeros(1,nStat);
BaseStat.MeanData=zeros(1,nStat);
BaseStat.VarStl=zeros(1,nStat);
BaseStat.VarData=zeros(1,nStat);
BaseStat.MedStl=zeros(1,nStat);
BaseStat.MedData=zeros(1,nStat);
for cStat=1:length(UsedSets) %Data set number
currentSet=UsedSets(cStat);
%input results name
EvalName=[resultsPath method_string eval_string num2str(currentSet) '.mat'];
disp(EvalName);
load(EvalName);
Dstl=BaseEval.Dstl(BaseEval.StlAbovePlane); %use only points that are above the plane
Dstl=Dstl(Dstl<MaxDist); % discard outliers
Ddata=BaseEval.Ddata(BaseEval.DataInMask); %use only points that within mask
Ddata=Ddata(Ddata<MaxDist); % discard outliers
BaseStat.nStl(cStat)=length(Dstl);
BaseStat.nData(cStat)=length(Ddata);
BaseStat.MeanStl(cStat)=mean(Dstl);
BaseStat.MeanData(cStat)=mean(Ddata);
BaseStat.VarStl(cStat)=var(Dstl);
BaseStat.VarData(cStat)=var(Ddata);
BaseStat.MedStl(cStat)=median(Dstl);
BaseStat.MedData(cStat)=median(Ddata);
disp("acc");
disp(mean(Ddata));
disp("comp");
disp(mean(Dstl));
time=clock;
end
disp(BaseStat);
disp("mean acc")
disp(mean(BaseStat.MeanData));
disp("mean comp")
disp(mean(BaseStat.MeanStl));
totalStatName=[resultsPath 'TotalStat_' method_string eval_string '.mat']
save(totalStatName,'BaseStat','time','MaxDist');

50
MaxDistCP.m Normal file
View File

@ -0,0 +1,50 @@
function Dist = MaxDistCP(Qto,Qfrom,BB,MaxDist)
Dist=ones(1,size(Qfrom,2))*MaxDist;
Range=floor((BB(2,:)-BB(1,:))/MaxDist);
tic
Done=0;
LookAt=zeros(1,size(Qfrom,2));
for x=0:Range(1),
for y=0:Range(2),
for z=0:Range(3),
Low=BB(1,:)+[x y z]*MaxDist;
High=Low+MaxDist;
idxF=find(Qfrom(1,:)>=Low(1) & Qfrom(2,:)>=Low(2) & Qfrom(3,:)>=Low(3) &...
Qfrom(1,:)<High(1) & Qfrom(2,:)<High(2) & Qfrom(3,:)<High(3));
SQfrom=Qfrom(:,idxF);
LookAt(idxF)=LookAt(idxF)+1; %Debug
Low=Low-MaxDist;
High=High+MaxDist;
idxT=find(Qto(1,:)>=Low(1) & Qto(2,:)>=Low(2) & Qto(3,:)>=Low(3) &...
Qto(1,:)<High(1) & Qto(2,:)<High(2) & Qto(3,:)<High(3));
SQto=Qto(:,idxT);
if(isempty(SQto))
Dist(idxF)=MaxDist;
else
KDstl=KDTreeSearcher(SQto');
[~,SDist] = knnsearch(KDstl,SQfrom');
Dist(idxF)=SDist;
end
Done=Done+length(idxF); %Debug
end
end
%Complete=Done/size(Qfrom,2);
%EstTime=(toc/Complete)/60
%toc
%LA=[sum(LookAt==0),...
% sum(LookAt==1),...
% sum(LookAt==2),...
% sum(LookAt==3),...
% sum(LookAt>3)]
end
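MaxDistCP is a memory-bounded nearest-neighbour distance: it walks a voxel grid of cell size MaxDist so each chunk only searches a local KD-tree. With scipy the same quantity can be computed directly; a compact sketch (illustrative only, not part of the commit):

import numpy as np
from scipy.spatial import cKDTree

def max_dist_cp(q_to, q_from, max_dist):
    # q_to, q_from: (N, 3) arrays; distance from each q_from point to its
    # nearest neighbour in q_to, capped at max_dist like the MATLAB code
    dist, _ = cKDTree(q_to).query(q_from, k=1, distance_upper_bound=max_dist)
    return np.minimum(dist, max_dist)  # query returns inf past the bound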

58
PointCompareMain.m Normal file
View File

@ -0,0 +1,58 @@
function BaseEval=PointCompareMain(cSet,Qdata,dst,dataPath)
% evaluation function that calculates the distances from the reference data (stl) to the evaluation points (Qdata) and the
% distances from the evaluation points to the reference
tic
% reduce points 0.2 mm neighbourhood density
Qdata=reducePts_haa(Qdata,dst);
toc
StlInName=[dataPath '/Points/stl/stl' sprintf('%03d',cSet) '_total.ply'];
StlMesh = plyread(StlInName); %STL points already reduced 0.2 mm neighbourhood density
Qstl=[StlMesh.vertex.x StlMesh.vertex.y StlMesh.vertex.z]';
%Load Mask (ObsMask) and Bounding box (BB) and Resolution (Res)
Margin=10;
MaskName=[dataPath '/ObsMask/ObsMask' num2str(cSet) '_' num2str(Margin) '.mat'];
load(MaskName)
MaxDist=60;
disp('Computing Data 2 Stl distances')
Ddata = MaxDistCP(Qstl,Qdata,BB,MaxDist);
toc
disp('Computing Stl 2 Data distances')
Dstl=MaxDistCP(Qdata,Qstl,BB,MaxDist);
disp('Distances computed')
toc
%use mask
%From Get mask - inverted & modified.
One=ones(1,size(Qdata,2));
Qv=(Qdata-BB(1,:)'*One)/Res+1;
Qv=round(Qv);
Midx1=find(Qv(1,:)>0 & Qv(1,:)<=size(ObsMask,1) & Qv(2,:)>0 & Qv(2,:)<=size(ObsMask,2) & Qv(3,:)>0 & Qv(3,:)<=size(ObsMask,3));
MidxA=sub2ind(size(ObsMask),Qv(1,Midx1),Qv(2,Midx1),Qv(3,Midx1));
Midx2=find(ObsMask(MidxA));
BaseEval.DataInMask(1:size(Qv,2))=false;
BaseEval.DataInMask(Midx1(Midx2))=true; %If Data is within the mask
BaseEval.cSet=cSet;
BaseEval.Margin=Margin; %Margin of masks
BaseEval.dst=dst; %Min dist between points when reducing
BaseEval.Qdata=Qdata; %Input data points
BaseEval.Ddata=Ddata; %distance from data to stl
BaseEval.Qstl=Qstl; %Input stl points
BaseEval.Dstl=Dstl; %Distance from the stl to data
load([dataPath '/ObsMask/Plane' num2str(cSet)],'P')
BaseEval.GroundPlane=P; % Plane used to distinguish which Stl points are 'used'
BaseEval.StlAbovePlane=(P'*[Qstl;ones(1,size(Qstl,2))])>0; %Is stl above 'ground plane'
BaseEval.Time=clock; %Time when computation is finished

454
plyread.m Normal file
View File

@ -0,0 +1,454 @@
%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%
function [Elements,varargout] = plyread(Path,Str)
%PLYREAD Read a PLY 3D data file.
% [DATA,COMMENTS] = PLYREAD(FILENAME) reads a version 1.0 PLY file
% FILENAME and returns a structure DATA. The fields in this structure
% are defined by the PLY header; each element type is a field and each
% element property is a subfield. If the file contains any comments,
% they are returned in a cell string array COMMENTS.
%
% [TRI,PTS] = PLYREAD(FILENAME,'tri') or
% [TRI,PTS,DATA,COMMENTS] = PLYREAD(FILENAME,'tri') converts vertex
% and face data into triangular connectivity and vertex arrays. The
% mesh can then be displayed using the TRISURF command.
%
% Note: This function is slow for large mesh files (+50K faces),
% especially when reading data with list type properties.
%
% Example:
% [Tri,Pts] = PLYREAD('cow.ply','tri');
% trisurf(Tri,Pts(:,1),Pts(:,2),Pts(:,3));
% colormap(gray); axis equal;
%
% See also: PLYWRITE
% Pascal Getreuer 2004
[fid,Msg] = fopen(Path,'rt'); % open file in read text mode
if fid == -1, error(Msg); end
Buf = fscanf(fid,'%s',1);
if ~strcmp(Buf,'ply')
fclose(fid);
error('Not a PLY file.');
end
%%% read header %%%
Position = ftell(fid);
Format = '';
NumComments = 0;
Comments = {}; % for storing any file comments
NumElements = 0;
NumProperties = 0;
Elements = []; % structure for holding the element data
ElementCount = []; % number of each type of element in file
PropertyTypes = []; % corresponding structure recording property types
ElementNames = {}; % list of element names in the order they are stored in the file
PropertyNames = []; % structure of lists of property names
while 1
Buf = fgetl(fid); % read one line from file
BufRem = Buf;
Token = {};
Count = 0;
while ~isempty(BufRem) % split line into tokens
[tmp,BufRem] = strtok(BufRem);
if ~isempty(tmp)
Count = Count + 1; % count tokens
Token{Count} = tmp;
end
end
if Count % parse line
switch lower(Token{1})
case 'format' % read data format
if Count >= 2
Format = lower(Token{2});
if Count == 3 & ~strcmp(Token{3},'1.0')
fclose(fid);
error('Only PLY format version 1.0 supported.');
end
end
case 'comment' % read file comment
NumComments = NumComments + 1;
Comments{NumComments} = '';
for i = 2:Count
Comments{NumComments} = [Comments{NumComments},Token{i},' '];
end
case 'element' % element name
if Count >= 3
if isfield(Elements,Token{2})
fclose(fid);
error(['Duplicate element name, ''',Token{2},'''.']);
end
NumElements = NumElements + 1;
NumProperties = 0;
Elements = setfield(Elements,Token{2},[]);
PropertyTypes = setfield(PropertyTypes,Token{2},[]);
ElementNames{NumElements} = Token{2};
PropertyNames = setfield(PropertyNames,Token{2},{});
CurElement = Token{2};
ElementCount(NumElements) = str2double(Token{3});
if isnan(ElementCount(NumElements))
fclose(fid);
error(['Bad element definition: ',Buf]);
end
else
error(['Bad element definition: ',Buf]);
end
case 'property' % element property
if ~isempty(CurElement) & Count >= 3
NumProperties = NumProperties + 1;
eval(['tmp=isfield(Elements.',CurElement,',Token{Count});'],...
'fclose(fid);error([''Error reading property: '',Buf])');
if tmp
error(['Duplicate property name, ''',CurElement,'.',Token{2},'''.']);
end
% add property subfield to Elements
eval(['Elements.',CurElement,'.',Token{Count},'=[];'], ...
'fclose(fid);error([''Error reading property: '',Buf])');
% add property subfield to PropertyTypes and save type
eval(['PropertyTypes.',CurElement,'.',Token{Count},'={Token{2:Count-1}};'], ...
'fclose(fid);error([''Error reading property: '',Buf])');
% record property name order
eval(['PropertyNames.',CurElement,'{NumProperties}=Token{Count};'], ...
'fclose(fid);error([''Error reading property: '',Buf])');
else
fclose(fid);
if isempty(CurElement)
error(['Property definition without element definition: ',Buf]);
else
error(['Bad property definition: ',Buf]);
end
end
case 'end_header' % end of header, break from while loop
break;
end
end
end
%%% set reading for specified data format %%%
if isempty(Format)
warning('Data format unspecified, assuming ASCII.');
Format = 'ascii';
end
switch Format
case 'ascii'
Format = 0;
case 'binary_little_endian'
Format = 1;
case 'binary_big_endian'
Format = 2;
otherwise
fclose(fid);
error(['Data format ''',Format,''' not supported.']);
end
if ~Format
Buf = fscanf(fid,'%f'); % read the rest of the file as ASCII data
BufOff = 1;
else
% reopen the file in read binary mode
fclose(fid);
if Format == 1
fid = fopen(Path,'r','ieee-le.l64'); % little endian
else
fid = fopen(Path,'r','ieee-be.l64'); % big endian
end
% find the end of the header again (using ftell on the old handle doesn't give the correct position)
BufSize = 8192;
Buf = [blanks(10),char(fread(fid,BufSize,'uchar')')];
i = [];
tmp = -11;
while isempty(i)
i = findstr(Buf,['end_header',13,10]); % look for end_header + CR/LF
i = [i,findstr(Buf,['end_header',10])]; % look for end_header + LF
if isempty(i)
tmp = tmp + BufSize;
Buf = [Buf(BufSize+1:BufSize+10),char(fread(fid,BufSize,'uchar')')];
end
end
% seek to just after the line feed
fseek(fid,i + tmp + 11 + (Buf(i + 10) == 13),-1);
end
%%% read element data %%%
% PLY and MATLAB data types (for fread)
PlyTypeNames = {'char','uchar','short','ushort','int','uint','float','double', ...
'char8','uchar8','short16','ushort16','int32','uint32','float32','double64'};
MatlabTypeNames = {'schar','uchar','int16','uint16','int32','uint32','single','double'};
SizeOf = [1,1,2,2,4,4,4,8]; % size in bytes of each type
for i = 1:NumElements
% get current element property information
eval(['CurPropertyNames=PropertyNames.',ElementNames{i},';']);
eval(['CurPropertyTypes=PropertyTypes.',ElementNames{i},';']);
NumProperties = size(CurPropertyNames,2);
% fprintf('Reading %s...\n',ElementNames{i});
if ~Format %%% read ASCII data %%%
for j = 1:NumProperties
Token = getfield(CurPropertyTypes,CurPropertyNames{j});
if strcmpi(Token{1},'list')
Type(j) = 1;
else
Type(j) = 0;
end
end
% parse buffer
if ~any(Type)
% no list types
Data = reshape(Buf(BufOff:BufOff+ElementCount(i)*NumProperties-1),NumProperties,ElementCount(i))';
BufOff = BufOff + ElementCount(i)*NumProperties;
else
ListData = cell(NumProperties,1);
for k = 1:NumProperties
ListData{k} = cell(ElementCount(i),1);
end
% list type
for j = 1:ElementCount(i)
for k = 1:NumProperties
if ~Type(k)
Data(j,k) = Buf(BufOff);
BufOff = BufOff + 1;
else
tmp = Buf(BufOff);
ListData{k}{j} = Buf(BufOff+(1:tmp))';
BufOff = BufOff + tmp + 1;
end
end
end
end
else %%% read binary data %%%
% translate PLY data type names to MATLAB data type names
ListFlag = 0; % = 1 if there is a list type
SameFlag = 1; % = 1 if all types are the same
for j = 1:NumProperties
Token = getfield(CurPropertyTypes,CurPropertyNames{j});
if ~strcmp(Token{1},'list') % non-list type
tmp = rem(strmatch(Token{1},PlyTypeNames,'exact')-1,8)+1;
if ~isempty(tmp)
TypeSize(j) = SizeOf(tmp);
Type{j} = MatlabTypeNames{tmp};
TypeSize2(j) = 0;
Type2{j} = '';
SameFlag = SameFlag & strcmp(Type{1},Type{j});
else
fclose(fid);
error(['Unknown property data type, ''',Token{1},''', in ', ...
ElementNames{i},'.',CurPropertyNames{j},'.']);
end
else % list type
if length(Token) == 3
ListFlag = 1;
SameFlag = 0;
tmp = rem(strmatch(Token{2},PlyTypeNames,'exact')-1,8)+1;
tmp2 = rem(strmatch(Token{3},PlyTypeNames,'exact')-1,8)+1;
if ~isempty(tmp) & ~isempty(tmp2)
TypeSize(j) = SizeOf(tmp);
Type{j} = MatlabTypeNames{tmp};
TypeSize2(j) = SizeOf(tmp2);
Type2{j} = MatlabTypeNames{tmp2};
else
fclose(fid);
error(['Unknown property data type, ''list ',Token{2},' ',Token{3},''', in ', ...
ElementNames{i},'.',CurPropertyNames{j},'.']);
end
else
fclose(fid);
error(['Invalid list syntax in ',ElementNames{i},'.',CurPropertyNames{j},'.']);
end
end
end
% read file
if ~ListFlag
if SameFlag
% no list types, all the same type (fast)
Data = fread(fid,[NumProperties,ElementCount(i)],Type{1})';
else
% no list types, mixed type
Data = zeros(ElementCount(i),NumProperties);
for j = 1:ElementCount(i)
for k = 1:NumProperties
Data(j,k) = fread(fid,1,Type{k});
end
end
end
else
ListData = cell(NumProperties,1);
for k = 1:NumProperties
ListData{k} = cell(ElementCount(i),1);
end
if NumProperties == 1
BufSize = 512;
SkipNum = 4;
j = 0;
% list type, one property (fast if lists are usually the same length)
while j < ElementCount(i)
Position = ftell(fid);
% read in BufSize count values, assuming all counts = SkipNum
[Buf,BufSize] = fread(fid,BufSize,Type{1},SkipNum*TypeSize2(1));
Miss = find(Buf ~= SkipNum); % find first count that is not SkipNum
fseek(fid,Position + TypeSize(1),-1); % seek back to after first count
if isempty(Miss) % all counts are SkipNum
Buf = fread(fid,[SkipNum,BufSize],[int2str(SkipNum),'*',Type2{1}],TypeSize(1))';
fseek(fid,-TypeSize(1),0); % undo last skip
for k = 1:BufSize
ListData{1}{j+k} = Buf(k,:);
end
j = j + BufSize;
BufSize = floor(1.5*BufSize);
else
if Miss(1) > 1 % some counts are SkipNum
Buf2 = fread(fid,[SkipNum,Miss(1)-1],[int2str(SkipNum),'*',Type2{1}],TypeSize(1))';
for k = 1:Miss(1)-1
ListData{1}{j+k} = Buf2(k,:);
end
j = j + k;
end
% read in the list with the missed count
SkipNum = Buf(Miss(1));
j = j + 1;
ListData{1}{j} = fread(fid,[1,SkipNum],Type2{1});
BufSize = ceil(0.6*BufSize);
end
end
else
% list type(s), multiple properties (slow)
Data = zeros(ElementCount(i),NumProperties);
for j = 1:ElementCount(i)
for k = 1:NumProperties
if isempty(Type2{k})
Data(j,k) = fread(fid,1,Type{k});
else
tmp = fread(fid,1,Type{k});
ListData{k}{j} = fread(fid,[1,tmp],Type2{k});
end
end
end
end
end
end
% put data into Elements structure
for k = 1:NumProperties
if (~Format & ~Type(k)) | (Format & isempty(Type2{k}))
eval(['Elements.',ElementNames{i},'.',CurPropertyNames{k},'=Data(:,k);']);
else
eval(['Elements.',ElementNames{i},'.',CurPropertyNames{k},'=ListData{k};']);
end
end
end
clear Data ListData;
fclose(fid);
if (nargin > 1 & strcmpi(Str,'Tri')) | nargout > 2
% find vertex element field
Name = {'vertex','Vertex','point','Point','pts','Pts'};
Names = [];
for i = 1:length(Name)
if any(strcmp(ElementNames,Name{i}))
Names = getfield(PropertyNames,Name{i});
Name = Name{i};
break;
end
end
if any(strcmp(Names,'x')) & any(strcmp(Names,'y')) & any(strcmp(Names,'z'))
eval(['varargout{1}=[Elements.',Name,'.x,Elements.',Name,'.y,Elements.',Name,'.z];']);
else
varargout{1} = zeros(1,3);
end
varargout{2} = Elements;
varargout{3} = Comments;
Elements = [];
% find face element field
Name = {'face','Face','poly','Poly','tri','Tri'};
Names = [];
for i = 1:length(Name)
if any(strcmp(ElementNames,Name{i}))
Names = getfield(PropertyNames,Name{i});
Name = Name{i};
break;
end
end
if ~isempty(Names)
% find vertex indices property subfield
PropertyName = {'vertex_indices','vertex_indexes','vertex_index','indices','indexes'};
for i = 1:length(PropertyName)
if any(strcmp(Names,PropertyName{i}))
PropertyName = PropertyName{i};
break;
end
end
if ~iscell(PropertyName)
% convert face index lists to triangular connectivity
eval(['FaceIndices=varargout{2}.',Name,'.',PropertyName,';']);
N = length(FaceIndices);
Elements = zeros(N*2,3);
Extra = 0;
for k = 1:N
Elements(k,:) = FaceIndices{k}(1:3);
for j = 4:length(FaceIndices{k})
Extra = Extra + 1;
Elements(N + Extra,:) = [Elements(k,[1,j-1]),FaceIndices{k}(j)];
end
end
Elements = Elements(1:N+Extra,:) + 1;
end
end
else
varargout{1} = Comments;
end

35
reducePts_haa.m Normal file
View File

@ -0,0 +1,35 @@
function [ptsOut,indexSet] = reducePts_haa(pts, dst)
%Reduces a point set, pts, in a stochastic manner, such that the minimum distance
% between points is 'dst'. Written by abd, edited by haa, then by raje
nPoints=size(pts,2);
indexSet=true(nPoints,1);
RandOrd=randperm(nPoints);
%tic
NS = KDTreeSearcher(pts');
%toc
% search the KD-tree for close neighbours in a chunk-wise fashion to save memory if the point cloud is really big
Chunks=1:min(4e6,nPoints-1):nPoints;
Chunks(end)=nPoints;
for cChunk=1:(length(Chunks)-1)
Range=Chunks(cChunk):Chunks(cChunk+1);
idx = rangesearch(NS,pts(:,RandOrd(Range))',dst);
for i = 1:size(idx,1)
id =RandOrd(i-1+Chunks(cChunk));
if (indexSet(id))
indexSet(idx{i}) = 0;
indexSet(id) = 1;
end
end
end
ptsOut = pts(:,indexSet);
disp(['downsample factor: ' num2str(nPoints/sum(indexSet))]);
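The same stochastic thinning can be written compactly with scipy: visit the points in random order and suppress everything within dst of a kept point. A sketch of the idea (unchunked, so it assumes the cloud fits in memory):

import numpy as np
from scipy.spatial import cKDTree

def reduce_pts(pts, dst):
    # pts: (3, N); keep a random maximal subset with pairwise distance >= dst
    n = pts.shape[1]
    keep = np.ones(n, dtype=bool)
    tree = cKDTree(pts.T)
    for i in np.random.permutation(n):
        if keep[i]:
            keep[tree.query_ball_point(pts[:, i], dst)] = False
            keep[i] = True  # re-keep the visited point itself
    return pts[:, keep], keep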

View File

@ -0,0 +1,106 @@
5c1f33f1d33e1f2e4aa6dda4
5bfe5ae0fe0ea555e6a969ca
5bff3c5cfe0ea555e6bcbf3a
58eaf1513353456af3a1682a
5bfc9d5aec61ca1dd69132a2
5bf18642c50e6f7f8bdbd492
5bf26cbbd43923194854b270
5bf17c0fd439231948355385
5be3ae47f44e235bdbbc9771
5be3a5fb8cfdd56947f6b67c
5bbb6eb2ea1cfa39f1af7e0c
5ba75d79d76ffa2c86cf2f05
5bb7a08aea1cfa39f1a947ab
5b864d850d072a699b32f4ae
5b6eff8b67b396324c5b2672
5b6e716d67b396324c2d77cb
5b69cc0cb44b61786eb959bf
5b62647143840965efc0dbde
5b60fa0c764f146feef84df0
5b558a928bbfb62204e77ba2
5b271079e0878c3816dacca4
5b08286b2775267d5b0634ba
5afacb69ab00705d0cefdd5b
5af28cea59bc705737003253
5af02e904c8216544b4ab5a2
5aa515e613d42d091d29d300
5c34529873a8df509ae57b58
5c34300a73a8df509add216d
5c1af2e2bee9a723c963d019
5c1892f726173c3a09ea9aeb
5c0d13b795da9479e12e2ee9
5c062d84a96e33018ff6f0a6
5bfd0f32ec61ca1dd69dc77b
5bf21799d43923194842c001
5bf3a82cd439231948877aed
5bf03590d4392319481971dc
5beb6e66abd34c35e18e66b9
5be883a4f98cee15019d5b83
5be47bf9b18881428d8fbc1d
5bcf979a6d5f586b95c258cd
5bce7ac9ca24970bce4934b6
5bb8a49aea1cfa39f1aa7f75
5b78e57afc8fcf6781d0c3ba
5b21e18c58e2823a67a10dd8
5b22269758e2823a67a3bd03
5b192eb2170cf166458ff886
5ae2e9c5fe405c5076abc6b2
5adc6bd52430a05ecb2ffb85
5ab8b8e029f5351f7f2ccf59
5abc2506b53b042ead637d86
5ab85f1dac4291329b17cb50
5a969eea91dfc339a9a3ad2c
5a8aa0fab18050187cbe060e
5a7d3db14989e929563eb153
5a69c47d0d5d0a7f3b2e9752
5a618c72784780334bc1972d
5a6464143d809f1d8208c43c
5a588a8193ac3d233f77fbca
5a57542f333d180827dfc132
5a572fd9fc597b0478a81d14
5a563183425d0f5186314855
5a4a38dad38c8a075495b5d2
5a48d4b2c7dab83a7d7b9851
5a489fb1c7dab83a7d7b1070
5a48ba95c7dab83a7d7b44ed
5a3ca9cb270f0e3f14d0eddb
5a3cb4e4270f0e3f14d12f43
5a3f4aba5889373fbbc5d3b5
5a0271884e62597cdee0d0eb
59e864b2a9e91f2c5529325f
599aa591d5b41f366fed0d58
59350ca084b7f26bf5ce6eb8
59338e76772c3e6384afbb15
5c20ca3a0843bc542d94e3e2
5c1dbf200843bc542d8ef8c4
5c1b1500bee9a723c96c3e78
5bea87f4abd34c35e1860ab5
5c2b3ed5e611832e8aed46bf
57f8d9bbe73f6760f10e916a
5bf7d63575c26f32dbf7413b
5be4ab93870d330ff2dce134
5bd43b4ba6b28b1ee86b92dd
5bccd6beca24970bce448134
5bc5f0e896b66a2cd8f9bd36
5b908d3dc6ab78485f3d24a9
5b2c67b5e0878c381608b8d8
5b4933abf2b5f44e95de482a
5b3b353d8d46a939f93524b9
5acf8ca0f3d8a750097e4b15
5ab8713ba3799a1d138bd69a
5aa235f64a17b335eeaf9609
5aa0f9d7a9efce63548c69a1
5a8315f624b8e938486e0bd8
5a48c4e9c7dab83a7d7b5cc7
59ecfd02e225f6492d20fcc9
59f87d0bfa6280566fb38c9a
59f363a8b45be22330016cad
59f70ab1e5c5d366af29bf3e
59e75a2ca9e91f2c5526005d
5947719bf1b45630bd096665
5947b62af1b45630bd0c2a02
59056e6760bb961de55f3501
58f7f7299f5b5647873cb110
58cf4771d0f5fb221defe6da
58d36897f387231e6c929903
58c4bb4f4a69c55606122be4

View File

@ -0,0 +1,7 @@
5b7a3890fc8fcf6781e2593a
5c189f2326173c3a09ed7ef3
5b950c71608de421b1e7318f
5a6400933d809f1d8200af15
59d2657f82ca7774b1ec081d
5ba19a8a360c7c30c1c169df
59817e4a1bd4b175e7038d19

22
IGEV-MVS/lists/dtu/test.txt Normal file
View File

@ -0,0 +1,22 @@
scan1
scan4
scan9
scan10
scan11
scan12
scan13
scan15
scan23
scan24
scan29
scan32
scan33
scan34
scan48
scan49
scan62
scan75
scan77
scan110
scan114
scan118

79
IGEV-MVS/lists/dtu/train.txt Normal file
View File

@ -0,0 +1,79 @@
scan2
scan6
scan7
scan8
scan14
scan16
scan18
scan19
scan20
scan22
scan30
scan31
scan36
scan39
scan41
scan42
scan44
scan45
scan46
scan47
scan50
scan51
scan52
scan53
scan55
scan57
scan58
scan60
scan61
scan63
scan64
scan65
scan68
scan69
scan70
scan71
scan72
scan74
scan76
scan83
scan84
scan85
scan87
scan88
scan89
scan90
scan91
scan92
scan93
scan94
scan95
scan96
scan97
scan98
scan99
scan100
scan101
scan102
scan103
scan104
scan105
scan107
scan108
scan109
scan111
scan112
scan113
scan115
scan116
scan119
scan120
scan121
scan122
scan123
scan124
scan125
scan126
scan127
scan128

18
IGEV-MVS/lists/dtu/val.txt Normal file
View File

@ -0,0 +1,18 @@
scan3
scan5
scan17
scan21
scan28
scan35
scan37
scan38
scan40
scan43
scan56
scan59
scan66
scan67
scan82
scan86
scan106
scan117

293
IGEV-MVS/train_mvs.py Normal file
View File

@ -0,0 +1,293 @@
import argparse
import os
os.environ['CUDA_VISIBLE_DEVICES'] = '0,1,2'
import torch
import torch.nn as nn
import torch.nn.parallel
import torch.backends.cudnn as cudnn
import torch.optim as optim
from torch.utils.data import DataLoader
import torch.nn.functional as F
import numpy as np
import random
import time
from torch.utils.tensorboard import SummaryWriter
from datasets import find_dataset_def
from core.igev_mvs import IGEVMVS
from core.submodule import depth_normalization, depth_unnormalization
from utils import *
import sys
import datetime
from tqdm import tqdm
cudnn.benchmark = True
parser = argparse.ArgumentParser(description='IGEV-MVS for high-resolution multi-view stereo')
parser.add_argument('--mode', default='train', help='train or val', choices=['train', 'val'])
parser.add_argument('--dataset', default='dtu_yao', help='select dataset')
parser.add_argument('--trainpath', default='/data/dtu_data/dtu_train/', help='train datapath')
parser.add_argument('--valpath', help='validation datapath')
parser.add_argument('--trainlist', default='./lists/dtu/train.txt', help='train list')
parser.add_argument('--vallist', default='./lists/dtu/val.txt', help='validation list')
parser.add_argument('--maxdisp', default=256)
parser.add_argument('--epochs', type=int, default=32, help='number of epochs to train')
parser.add_argument('--lr', type=float, default=0.0002, help='learning rate')
parser.add_argument('--wd', type=float, default=.00001, help='weight decay')
parser.add_argument('--batch_size', type=int, default=6, help='train batch size')
parser.add_argument('--loadckpt', default=None, help='load a specific checkpoint')
parser.add_argument('--logdir', default='./checkpoints/', help='the directory to save checkpoints/logs')
parser.add_argument('--resume', action='store_true', help='continue to train the model')
parser.add_argument('--regress', action='store_true', help='train the regression and confidence')
parser.add_argument('--small_image', action='store_true', help='train with small input as 640x512, otherwise train with 1280x1024')
parser.add_argument('--summary_freq', type=int, default=20, help='print and summary frequency')
parser.add_argument('--save_freq', type=int, default=1, help='save checkpoint frequency')
parser.add_argument('--seed', type=int, default=1, metavar='S', help='random seed')
parser.add_argument('--iteration', type=int, default=22, help='num of iteration of GRU')
try:
from torch.cuda.amp import GradScaler
except ImportError:
# dummy GradScaler for PyTorch < 1.6
class GradScaler:
def __init__(self, enabled=False):
pass
def scale(self, loss):
return loss
def unscale_(self, optimizer):
pass
def step(self, optimizer):
optimizer.step()
def update(self):
pass
def sequence_loss(disp_preds, disp_init_pred, depth_gt, mask, depth_min, depth_max, loss_gamma=0.9):
""" Loss function defined over sequence of depth predictions """
cross_entropy = nn.BCEWithLogitsLoss()
n_predictions = len(disp_preds)
assert n_predictions >= 1
loss = 0.0
mask = mask > 0.5
batch, _, height, width = depth_gt.size()
inverse_depth_min = (1.0 / depth_min).view(batch, 1, 1, 1)
inverse_depth_max = (1.0 / depth_max).view(batch, 1, 1, 1)
normalized_disp_gt = depth_normalization(depth_gt, inverse_depth_min, inverse_depth_max)
loss += 1.0 * F.l1_loss(disp_init_pred[mask], normalized_disp_gt[mask], reduction='mean')
if args.iteration != 0:
for i in range(n_predictions):
adjusted_loss_gamma = loss_gamma**(15/(n_predictions - 1))
i_weight = adjusted_loss_gamma**(n_predictions - i - 1)
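# RAFT-style exponential weighting: with the adjusted gamma the weight
# decays by loss_gamma**15 from the last prediction back to the first,
# independent of the iteration count; e.g. n_predictions = 22 gives
# 0.9**15 ~= 0.21 for disp_preds[0] and 1.0 for disp_preds[-1]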
loss += i_weight * F.l1_loss(disp_preds[i][mask], normalized_disp_gt[mask], reduction='mean')
return loss
# parse arguments and check
args = parser.parse_args()
if args.resume: # store_true flag: True only when passed on the command line
assert args.mode == "train"
assert args.loadckpt is None
if args.valpath is None:
args.valpath = args.trainpath
torch.manual_seed(args.seed)
torch.cuda.manual_seed(args.seed)
np.random.seed(args.seed)
random.seed(args.seed)
if args.mode == "train":
if not os.path.isdir(args.logdir):
os.mkdir(args.logdir)
current_time_str = str(datetime.datetime.now().strftime('%Y%m%d_%H%M%S'))
print("current time", current_time_str)
print("creating new summary file")
logger = SummaryWriter(args.logdir)
print("argv:", sys.argv[1:])
print_args(args)
# dataset, dataloader
MVSDataset = find_dataset_def(args.dataset)
train_dataset = MVSDataset(args.trainpath, args.trainlist, "train", 5, robust_train=True)
test_dataset = MVSDataset(args.valpath, args.vallist, "val", 5, robust_train=False)
TrainImgLoader = DataLoader(train_dataset, args.batch_size, shuffle=True, num_workers=4, drop_last=True)
TestImgLoader = DataLoader(test_dataset, args.batch_size, shuffle=False, num_workers=4, drop_last=False)
# model, optimizer
model = IGEVMVS(args)
if args.mode in ["train", "val"]:
model = nn.DataParallel(model)
model.cuda()
model_loss = sequence_loss
optimizer = optim.AdamW(filter(lambda p: p.requires_grad, model.parameters()), lr=args.lr, weight_decay=args.wd, eps=1e-8)
# load parameters
start_epoch = 0
if (args.mode == "train" and args.resume) or (args.mode == "val" and not args.loadckpt):
saved_models = [fn for fn in os.listdir(args.logdir) if fn.endswith(".ckpt")]
saved_models = sorted(saved_models, key=lambda x: int(x.split('_')[-1].split('.')[0]))
# use the latest checkpoint file
loadckpt = os.path.join(args.logdir, saved_models[-1])
print("resuming", loadckpt)
state_dict = torch.load(loadckpt)
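    # strict=False tolerates missing/unexpected keys, e.g. a checkpoint saved
    # with or without DataParallel's "module." prefix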
model.load_state_dict(state_dict['model'], strict=False)
optimizer.load_state_dict(state_dict['optimizer'])
start_epoch = state_dict['epoch'] + 1
elif args.loadckpt:
# load checkpoint file specified by args.loadckpt
print("loading model {}".format(args.loadckpt))
state_dict = torch.load(args.loadckpt)
model.load_state_dict(state_dict['model'], strict=False)
print("start at epoch {}".format(start_epoch))
print('Number of model parameters: {}'.format(sum([p.data.nelement() for p in model.parameters()])))
# main function
def train(args):
total_steps = len(TrainImgLoader) * args.epochs + 100
lr_scheduler = optim.lr_scheduler.OneCycleLR(optimizer, args.lr, total_steps, pct_start=0.01, cycle_momentum=False, anneal_strategy='linear')
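    # Note: the +100 slack added to total_steps above presumably keeps
    # OneCycleLR from raising "stepped too many times" at the end of training.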
for epoch_idx in range(start_epoch, args.epochs):
print('Epoch {}:'.format(epoch_idx))
global_step = len(TrainImgLoader) * epoch_idx
# training
        tbar = tqdm(TrainImgLoader)
        # create the scaler once per epoch: re-instantiating it per batch
        # would reset the AMP loss-scale state on every iteration
        scaler = GradScaler(enabled=True)
        for batch_idx, sample in enumerate(tbar):
            start_time = time.time()
            global_step = len(TrainImgLoader) * epoch_idx + batch_idx
            do_summary = global_step % args.summary_freq == 0
            loss, scalar_outputs = train_sample(args, sample, detailed_summary=do_summary, scaler=scaler)
if do_summary:
save_scalars(logger, 'train', scalar_outputs, global_step)
del scalar_outputs
tbar.set_description(
'Epoch {}/{}, Iter {}/{}, train loss = {:.3f}, time = {:.3f}'.format(epoch_idx, args.epochs, batch_idx, len(TrainImgLoader), loss, time.time() - start_time))
lr_scheduler.step()
# checkpoint
if (epoch_idx + 1) % args.save_freq == 0:
            torch.save({
                'epoch': epoch_idx,
                'model': model.state_dict(),
                'optimizer': optimizer.state_dict()},
                "{}/model_{:0>6}.ckpt".format(args.logdir, epoch_idx))
torch.cuda.empty_cache()
# testing
avg_test_scalars = DictAverageMeter()
tbar = tqdm(TestImgLoader)
for batch_idx, sample in enumerate(tbar):
start_time = time.time()
global_step = len(TestImgLoader) * epoch_idx + batch_idx
do_summary = global_step % args.summary_freq == 0
loss, scalar_outputs = test_sample(args, sample, detailed_summary=do_summary)
if do_summary:
save_scalars(logger, 'test', scalar_outputs, global_step)
avg_test_scalars.update(scalar_outputs)
del scalar_outputs
            tbar.set_description('Epoch {}/{}, Iter {}/{}, test loss = {:.3f}, time = {:.3f}'.format(
                epoch_idx, args.epochs, batch_idx, len(TestImgLoader), loss, time.time() - start_time))
save_scalars(logger, 'fulltest', avg_test_scalars.mean(), global_step)
print("avg_test_scalars:", avg_test_scalars.mean())
torch.cuda.empty_cache()
def test(args):
avg_test_scalars = DictAverageMeter()
for batch_idx, sample in enumerate(TestImgLoader):
start_time = time.time()
loss, scalar_outputs = test_sample(args, sample, detailed_summary=True)
avg_test_scalars.update(scalar_outputs)
del scalar_outputs
        print('Iter {}/{}, test loss = {:.3f}, time = {:.3f}'.format(
            batch_idx, len(TestImgLoader), loss, time.time() - start_time))
if batch_idx % 100 == 0:
print("Iter {}/{}, test results = {}".format(batch_idx, len(TestImgLoader), avg_test_scalars.mean()))
print("final", avg_test_scalars)
def train_sample(args, sample, detailed_summary=False, scaler=None):
model.train()
optimizer.zero_grad()
sample_cuda = tocuda(sample)
depth_gt = sample_cuda["depth"]
mask = sample_cuda["mask"]
depth_gt_0 = depth_gt['level_0']
mask_0 = mask['level_0']
disp_init, disp_predictions = model(sample_cuda["imgs"], sample_cuda["proj_matrices"],
sample_cuda["depth_min"], sample_cuda["depth_max"])
loss = model_loss(disp_predictions, disp_init, depth_gt_0, mask_0, sample_cuda["depth_min"], sample_cuda["depth_max"])
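    # AMP update sequence: scale the loss before backward so small gradients
    # survive reduced precision, unscale before clipping so the 1.0 threshold
    # applies to true gradient magnitudes, then step and update the scale.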
scaler.scale(loss).backward()
scaler.unscale_(optimizer)
torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
scaler.step(optimizer)
scaler.update()
    batch_size = sample_cuda["depth_min"].size(0)
    inverse_depth_min = (1.0 / sample_cuda["depth_min"]).view(batch_size, 1, 1, 1)
    inverse_depth_max = (1.0 / sample_cuda["depth_max"]).view(batch_size, 1, 1, 1)
depth_init = depth_unnormalization(disp_init, inverse_depth_min, inverse_depth_max)
depth_predictions = []
for disp in disp_predictions:
depth_predictions.append(depth_unnormalization(disp, inverse_depth_min, inverse_depth_max))
scalar_outputs = {"loss": loss}
scalar_outputs["abs_error_initial"] = AbsDepthError_metrics(depth_init, depth_gt_0, mask_0 > 0.5)
scalar_outputs["thres1mm_initial"] = Thres_metrics(depth_init, depth_gt_0, mask_0 > 0.5, 1)
scalar_outputs["abs_error_final_full"] = AbsDepthError_metrics(depth_predictions[-1], depth_gt_0, mask_0 > 0.5)
return tensor2float(loss), tensor2float(scalar_outputs)
@make_nograd_func
def test_sample(args, sample, detailed_summary=True):
model.eval()
sample_cuda = tocuda(sample)
depth_gt = sample_cuda["depth"]
mask = sample_cuda["mask"]
depth_gt_0 = depth_gt['level_0']
mask_0 = mask['level_0']
disp_init, disp_predictions = model(sample_cuda["imgs"], sample_cuda["proj_matrices"],
sample_cuda["depth_min"], sample_cuda["depth_max"])
loss = model_loss(disp_predictions, disp_init, depth_gt_0, mask_0, sample_cuda["depth_min"], sample_cuda["depth_max"])
inverse_depth_min = (1.0 / sample_cuda["depth_min"]).view(sample_cuda["depth_min"].size()[0], 1, 1, 1)
inverse_depth_max = (1.0 / sample_cuda["depth_max"]).view(sample_cuda["depth_max"].size()[0], 1, 1, 1)
depth_init = depth_unnormalization(disp_init, inverse_depth_min, inverse_depth_max)
depth_predictions = []
for disp in disp_predictions:
depth_predictions.append(depth_unnormalization(disp, inverse_depth_min, inverse_depth_max))
scalar_outputs = {"loss": loss}
scalar_outputs["abs_error_initial"] = AbsDepthError_metrics(depth_init, depth_gt_0, mask_0 > 0.5)
scalar_outputs["thres1mm_initial"] = Thres_metrics(depth_init, depth_gt_0, mask_0 > 0.5, 1)
scalar_outputs["abs_error_final_full"] = AbsDepthError_metrics(depth_predictions[-1], depth_gt_0, mask_0 > 0.5)
return tensor2float(loss), tensor2float(scalar_outputs)
if __name__ == '__main__':
if args.mode == "train":
train(args)
elif args.mode == "val":
test(args)

155
IGEV-MVS/utils.py Normal file

@@ -0,0 +1,155 @@
import numpy as np
import torchvision.utils as vutils
import torch
import torch.nn.functional as F
# print arguments
def print_args(args):
print("################################ args ################################")
for k, v in args.__dict__.items():
print("{0: <10}\t{1: <30}\t{2: <20}".format(k, str(v), str(type(v))))
print("########################################################################")
# torch.no_grad wrapper for functions
def make_nograd_func(func):
def wrapper(*f_args, **f_kwargs):
with torch.no_grad():
ret = func(*f_args, **f_kwargs)
return ret
return wrapper
# convert a function into recursive style to handle nested dict/list/tuple variables
def make_recursive_func(func):
def wrapper(vars):
if isinstance(vars, list):
return [wrapper(x) for x in vars]
elif isinstance(vars, tuple):
return tuple([wrapper(x) for x in vars])
elif isinstance(vars, dict):
return {k: wrapper(v) for k, v in vars.items()}
else:
return func(vars)
return wrapper
@make_recursive_func
def tensor2float(vars):
if isinstance(vars, float):
return vars
elif isinstance(vars, torch.Tensor):
return vars.data.item()
else:
raise NotImplementedError("invalid input type {} for tensor2float".format(type(vars)))
@make_recursive_func
def tensor2numpy(vars):
if isinstance(vars, np.ndarray):
return vars
elif isinstance(vars, torch.Tensor):
return vars.detach().cpu().numpy().copy()
else:
raise NotImplementedError("invalid input type {} for tensor2numpy".format(type(vars)))
@make_recursive_func
def tocuda(vars):
if isinstance(vars, torch.Tensor):
return vars.cuda()
elif isinstance(vars, str):
return vars
else:
raise NotImplementedError("invalid input type {} for tocuda".format(type(vars)))
def save_scalars(logger, mode, scalar_dict, global_step):
scalar_dict = tensor2float(scalar_dict)
for key, value in scalar_dict.items():
if not isinstance(value, (list, tuple)):
name = '{}/{}'.format(mode, key)
logger.add_scalar(name, value, global_step)
else:
for idx in range(len(value)):
name = '{}/{}_{}'.format(mode, key, idx)
logger.add_scalar(name, value[idx], global_step)
def save_images(logger, mode, images_dict, global_step):
images_dict = tensor2numpy(images_dict)
def preprocess(name, img):
if not (len(img.shape) == 3 or len(img.shape) == 4):
raise NotImplementedError("invalid img shape {}:{} in save_images".format(name, img.shape))
if len(img.shape) == 3:
img = img[:, np.newaxis, :, :]
img = torch.from_numpy(img[:1])
return vutils.make_grid(img, padding=0, nrow=1, normalize=True, scale_each=True)
for key, value in images_dict.items():
if not isinstance(value, (list, tuple)):
name = '{}/{}'.format(mode, key)
logger.add_image(name, preprocess(name, value), global_step)
else:
for idx in range(len(value)):
name = '{}/{}_{}'.format(mode, key, idx)
logger.add_image(name, preprocess(name, value[idx]), global_step)
class DictAverageMeter(object):
def __init__(self):
self.data = {}
self.count = 0
    def update(self, new_input):
        self.count += 1
        for k, v in new_input.items():
            if not isinstance(v, float):
                raise NotImplementedError("invalid data {}: {}".format(k, type(v)))
            # accumulate per-key sums; .get avoids a KeyError if a new key
            # first appears after the initial update
            self.data[k] = self.data.get(k, 0.0) + v
def mean(self):
return {k: v / self.count for k, v in self.data.items()}
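# Minimal usage sketch (hypothetical values):
#   meter = DictAverageMeter()
#   meter.update({"loss": 0.5})
#   meter.update({"loss": 0.3})
#   meter.mean()  # -> {"loss": 0.4}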
# a wrapper to compute metrics for each image individually
def compute_metrics_for_each_image(metric_func):
def wrapper(depth_est, depth_gt, mask, *args):
batch_size = depth_gt.shape[0]
results = []
# compute result one by one
for idx in range(batch_size):
ret = metric_func(depth_est[idx], depth_gt[idx], mask[idx], *args)
results.append(ret)
return torch.stack(results).mean()
return wrapper
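# e.g. a metric f(depth_est, depth_gt, mask) wrapped this way returns the mean
# of the per-image scalars, so sparse and dense images contribute equally.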
@make_nograd_func
@compute_metrics_for_each_image
def Thres_metrics(depth_est, depth_gt, mask, thres):
    # thres must be a scalar threshold, in the same units as the depth maps
    assert isinstance(thres, (int, float))
depth_est, depth_gt = depth_est[mask], depth_gt[mask]
errors = torch.abs(depth_est - depth_gt)
err_mask = errors > thres
return torch.mean(err_mask.float())
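# Note: Thres_metrics is an error rate (the fraction of valid pixels with
# absolute error above thres), so lower is better; the "thres1mm" keys in the
# training script call it with thres=1.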
# NOTE: please do not use this to build up training loss
@make_nograd_func
@compute_metrics_for_each_image
def AbsDepthError_metrics(depth_est, depth_gt, mask):
depth_est, depth_gt = depth_est[mask], depth_gt[mask]
return torch.mean((depth_est - depth_gt).abs())