Add files via upload

parent 84baaf7d8e
commit e079f027ff
0
IGEV-MVS/core/__init__.py
Normal file
61
IGEV-MVS/core/corr.py
Normal file
@@ -0,0 +1,61 @@
import torch
import torch.nn as nn
import torch.nn.functional as F
from .submodule import *


class CorrBlock1D_Cost_Volume:
    def __init__(self, init_corr, corr, num_levels=2, radius=4, inverse_depth_min=None, inverse_depth_max=None, num_sample=None):
        self.num_levels = num_levels  # honor the argument instead of hard-coding 2
        self.radius = radius
        self.inverse_depth_min = inverse_depth_min
        self.inverse_depth_max = inverse_depth_max
        self.num_sample = num_sample
        self.corr_pyramid = []
        self.init_corr_pyramid = []

        # reshape the [B, C, D, H, W] volumes so each pixel's depth dimension
        # becomes a 1-D signal that grid_sample can index
        b, c, d, h, w = corr.shape
        corr = corr.permute(0, 3, 4, 1, 2).reshape(b*h*w, 1, 1, d)
        init_corr = init_corr.permute(0, 3, 4, 1, 2).reshape(b*h*w, 1, 1, d)

        self.corr_pyramid.append(corr)
        self.init_corr_pyramid.append(init_corr)

        # build average-pooled pyramids over the depth dimension
        for i in range(self.num_levels):
            corr = F.avg_pool2d(corr, [1, 2], stride=[1, 2])
            self.corr_pyramid.append(corr)

        for i in range(self.num_levels):
            init_corr = F.avg_pool2d(init_corr, [1, 2], stride=[1, 2])
            self.init_corr_pyramid.append(init_corr)

    def __call__(self, disp):
        r = self.radius
        b, _, h, w = disp.shape
        out_pyramid = []
        for i in range(self.num_levels):
            corr = self.corr_pyramid[i]
            init_corr = self.init_corr_pyramid[i]
            # sample a (2r+1)-tap window centred on the current disparity
            dx = torch.linspace(-r, r, 2*r+1)
            dx = dx.view(1, 1, 2*r+1, 1).to(disp.device)
            x0 = dx + disp.reshape(b*h*w, 1, 1, 1) / 2**i
            y0 = torch.zeros_like(x0)

            disp_lvl = torch.cat([x0, y0], dim=-1)
            corr = bilinear_sampler(corr, disp_lvl)
            corr = corr.view(b, h, w, -1)

            init_corr = bilinear_sampler(init_corr, disp_lvl)
            init_corr = init_corr.view(b, h, w, -1)

            out_pyramid.append(corr)
            out_pyramid.append(init_corr)

        out = torch.cat(out_pyramid, dim=-1)
        return out.permute(0, 3, 1, 2).contiguous().float()
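For reference, a minimal sketch of how this lookup is driven, assuming it is run from the IGEV-MVS directory; the shapes are illustrative assumptions. Both volumes are [B, 1, D, H, W], and each call concatenates a (2*radius+1)-tap window from both pyramids at every level:

    import torch
    from core.corr import CorrBlock1D_Cost_Volume

    B, D, H, W = 1, 64, 32, 40                # assumed sizes
    init_corr = torch.randn(B, 1, D, H, W)    # raw matching volume
    corr = torch.randn(B, 1, D, H, W)         # regularized volume
    disp = torch.full((B, 1, H, W), D / 2.0)  # current depth-index estimate

    lookup = CorrBlock1D_Cost_Volume(init_corr, corr, num_levels=2, radius=4)
    feat = lookup(disp)
    print(feat.shape)  # torch.Size([1, 36, 32, 40]): 2 volumes x 2 levels x 9 taps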
212
IGEV-MVS/core/extractor.py
Normal file
@@ -0,0 +1,212 @@
import torch
import torch.nn as nn
import torch.nn.functional as F
import timm
import math
from .submodule import *

class ResidualBlock(nn.Module):
    def __init__(self, in_planes, planes, norm_fn='group', stride=1):
        super(ResidualBlock, self).__init__()

        self.conv1 = nn.Conv2d(in_planes, planes, kernel_size=3, padding=1, stride=stride)
        self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, padding=1)
        self.relu = nn.ReLU(inplace=True)

        num_groups = planes // 8

        if norm_fn == 'group':
            self.norm1 = nn.GroupNorm(num_groups=num_groups, num_channels=planes)
            self.norm2 = nn.GroupNorm(num_groups=num_groups, num_channels=planes)
            if not (stride == 1 and in_planes == planes):
                self.norm3 = nn.GroupNorm(num_groups=num_groups, num_channels=planes)

        elif norm_fn == 'batch':
            self.norm1 = nn.BatchNorm2d(planes)
            self.norm2 = nn.BatchNorm2d(planes)
            if not (stride == 1 and in_planes == planes):
                self.norm3 = nn.BatchNorm2d(planes)

        elif norm_fn == 'instance':
            self.norm1 = nn.InstanceNorm2d(planes)
            self.norm2 = nn.InstanceNorm2d(planes)
            if not (stride == 1 and in_planes == planes):
                self.norm3 = nn.InstanceNorm2d(planes)

        elif norm_fn == 'none':
            self.norm1 = nn.Sequential()
            self.norm2 = nn.Sequential()
            if not (stride == 1 and in_planes == planes):
                self.norm3 = nn.Sequential()

        if stride == 1 and in_planes == planes:
            self.downsample = None
        else:
            self.downsample = nn.Sequential(
                nn.Conv2d(in_planes, planes, kernel_size=1, stride=stride), self.norm3)

    def forward(self, x):
        y = x
        y = self.conv1(y)
        y = self.norm1(y)
        y = self.relu(y)
        y = self.conv2(y)
        y = self.norm2(y)
        y = self.relu(y)

        if self.downsample is not None:
            x = self.downsample(x)

        return self.relu(x+y)

class MultiBasicEncoder(nn.Module):
    def __init__(self, output_dim=[128], norm_fn='batch', dropout=0.0, downsample=3):
        super(MultiBasicEncoder, self).__init__()
        self.norm_fn = norm_fn
        self.downsample = downsample

        if self.norm_fn == 'group':
            self.norm1 = nn.GroupNorm(num_groups=8, num_channels=64)

        elif self.norm_fn == 'batch':
            self.norm1 = nn.BatchNorm2d(64)

        elif self.norm_fn == 'instance':
            self.norm1 = nn.InstanceNorm2d(64)

        elif self.norm_fn == 'none':
            self.norm1 = nn.Sequential()

        self.conv1 = nn.Conv2d(3, 64, kernel_size=7, stride=1 + (downsample > 2), padding=3)
        self.relu1 = nn.ReLU(inplace=True)

        self.in_planes = 64
        self.layer1 = self._make_layer(64, stride=1)
        self.layer2 = self._make_layer(96, stride=1 + (downsample > 1))
        self.layer3 = self._make_layer(128, stride=1 + (downsample > 0))
        self.layer4 = self._make_layer(128, stride=2)
        self.layer5 = self._make_layer(128, stride=2)

        output_list = []

        for dim in output_dim:
            conv_out = nn.Sequential(
                ResidualBlock(128, 128, self.norm_fn, stride=1),
                nn.Conv2d(128, dim[2], 3, padding=1))
            output_list.append(conv_out)

        self.outputs04 = nn.ModuleList(output_list)

        output_list = []
        for dim in output_dim:
            conv_out = nn.Sequential(
                ResidualBlock(128, 128, self.norm_fn, stride=1),
                nn.Conv2d(128, dim[1], 3, padding=1))
            output_list.append(conv_out)

        self.outputs08 = nn.ModuleList(output_list)

        output_list = []
        for dim in output_dim:
            conv_out = nn.Conv2d(128, dim[0], 3, padding=1)
            output_list.append(conv_out)

        self.outputs16 = nn.ModuleList(output_list)

        if dropout > 0:
            self.dropout = nn.Dropout2d(p=dropout)
        else:
            self.dropout = None

        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
            elif isinstance(m, (nn.BatchNorm2d, nn.InstanceNorm2d, nn.GroupNorm)):
                if m.weight is not None:
                    nn.init.constant_(m.weight, 1)
                if m.bias is not None:
                    nn.init.constant_(m.bias, 0)

    def _make_layer(self, dim, stride=1):
        layer1 = ResidualBlock(self.in_planes, dim, self.norm_fn, stride=stride)
        layer2 = ResidualBlock(dim, dim, self.norm_fn, stride=1)
        layers = (layer1, layer2)

        self.in_planes = dim
        return nn.Sequential(*layers)

    def forward(self, x, dual_inp=False, num_layers=3):

        x = self.conv1(x)
        x = self.norm1(x)
        x = self.relu1(x)

        x = self.layer1(x)
        x = self.layer2(x)
        x = self.layer3(x)

        if dual_inp:
            v = x
            x = x[:(x.shape[0]//2)]

        outputs04 = [f(x) for f in self.outputs04]
        if num_layers == 1:
            return (outputs04, v) if dual_inp else (outputs04,)

        y = self.layer4(x)
        outputs08 = [f(y) for f in self.outputs08]

        if num_layers == 2:
            return (outputs04, outputs08, v) if dual_inp else (outputs04, outputs08)

        z = self.layer5(y)
        outputs16 = [f(z) for f in self.outputs16]

        return (outputs04, outputs08, outputs16, v) if dual_inp else (outputs04, outputs08, outputs16)

class Feature(SubModule):
    def __init__(self):
        super(Feature, self).__init__()
        pretrained = True
        model = timm.create_model('mobilenetv2_100', pretrained=pretrained, features_only=True)

        layers = [1, 2, 3, 5, 6]
        chans = [16, 24, 32, 96, 160]
        self.conv_stem = model.conv_stem
        self.bn1 = model.bn1

        self.block0 = torch.nn.Sequential(*model.blocks[0:layers[0]])
        self.block1 = torch.nn.Sequential(*model.blocks[layers[0]:layers[1]])
        self.block2 = torch.nn.Sequential(*model.blocks[layers[1]:layers[2]])
        self.block3 = torch.nn.Sequential(*model.blocks[layers[2]:layers[3]])
        self.block4 = torch.nn.Sequential(*model.blocks[layers[3]:layers[4]])

        self.deconv32_16 = Conv2x_IN(chans[4], chans[3], deconv=True, concat=True)
        self.deconv16_8 = Conv2x_IN(chans[3]*2, chans[2], deconv=True, concat=True)
        self.deconv8_4 = Conv2x_IN(chans[2]*2, chans[1], deconv=True, concat=True)
        self.conv4 = BasicConv_IN(chans[1]*2, chans[1]*2, kernel_size=3, stride=1, padding=1)

    def forward(self, x):
        # x: [B, V, 3, H, W]; views are folded into the batch dimension
        B, V, _, H, W = x.size()
        x = x.view(B * V, -1, H, W)
        x = self.bn1(self.conv_stem(x))
        x2 = self.block0(x)
        x4 = self.block1(x2)
        x8 = self.block2(x4)
        x16 = self.block3(x8)
        x32 = self.block4(x16)

        # top-down fusion back to 1/4 resolution
        x16 = self.deconv32_16(x32, x16)
        x8 = self.deconv16_8(x16, x8)
        x4 = self.deconv8_4(x8, x4)
        x4 = self.conv4(x4)

        x4 = x4.view(B, V, -1, H // 4, W // 4)
        x8 = x8.view(B, V, -1, H // 8, W // 8)
        x16 = x16.view(B, V, -1, H // 16, W // 16)
        x32 = x32.view(B, V, -1, H // 32, W // 32)
        return [x4, x8, x16, x32]
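A quick shape walk-through for the backbone above; the input sizes are assumptions for illustration, and the first run downloads timm's mobilenetv2_100 weights:

    import torch
    from core.extractor import Feature

    net = Feature().eval()
    imgs = torch.randn(2, 3, 3, 256, 320)   # [B, V, 3, H, W], assumed sizes
    with torch.no_grad():
        x4, x8, x16, x32 = net(imgs)
    print(x4.shape)   # torch.Size([2, 3, 48, 64, 80])   -> chans[1]*2 channels at 1/4
    print(x8.shape)   # torch.Size([2, 3, 64, 32, 40])   -> chans[2]*2 channels at 1/8
    print(x16.shape)  # torch.Size([2, 3, 192, 16, 20])  -> chans[3]*2 channels at 1/16
    print(x32.shape)  # torch.Size([2, 3, 160, 8, 10])   -> chans[4] channels at 1/32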
195
IGEV-MVS/core/igev_mvs.py
Normal file
@@ -0,0 +1,195 @@
import torch
import torch.nn as nn
import torch.nn.functional as F
from .submodule import *
from .corr import *
from .extractor import *
from .update import *

try:
    autocast = torch.cuda.amp.autocast
except:
    # dummy autocast for PyTorch versions without torch.cuda.amp
    class autocast:
        def __init__(self, enabled):
            pass
        def __enter__(self):
            pass
        def __exit__(self, *args):
            pass

class IGEVMVS(nn.Module):
    def __init__(self, args):
        super().__init__()

        context_dims = [128, 128, 128]
        self.n_gru_layers = 3
        self.slow_fast_gru = False
        self.mixed_precision = True
        self.num_sample = 64
        self.G = 1
        self.corr_radius = 4
        self.corr_levels = 2
        self.iters = args.iteration
        self.update_block = BasicMultiUpdateBlock(hidden_dims=context_dims)
        self.conv_hidden_1 = nn.Conv2d(64, 128, kernel_size=3, padding=1, stride=1)
        self.conv_hidden_2 = nn.Conv2d(128, 128, kernel_size=3, padding=1, stride=2)
        self.conv_hidden_4 = nn.Conv2d(128, 128, kernel_size=3, padding=1, stride=2)
        self.feature = Feature()

        self.stem_2 = nn.Sequential(
            BasicConv_IN(3, 32, kernel_size=3, stride=2, padding=1),
            nn.Conv2d(32, 32, 3, 1, 1, bias=False),
            nn.InstanceNorm2d(32), nn.ReLU()
        )
        self.stem_4 = nn.Sequential(
            BasicConv_IN(32, 48, kernel_size=3, stride=2, padding=1),
            nn.Conv2d(48, 48, 3, 1, 1, bias=False),
            nn.InstanceNorm2d(48), nn.ReLU()
        )

        self.conv = BasicConv_IN(96, 48, kernel_size=3, padding=1, stride=1)
        self.desc = nn.Conv2d(48, 48, kernel_size=1, padding=0, stride=1)

        self.spx = nn.Sequential(nn.ConvTranspose2d(2*32, 9, kernel_size=4, stride=2, padding=1),)
        self.spx_2 = Conv2x_IN(32, 32, True)
        self.spx_4 = nn.Sequential(
            BasicConv_IN(96, 32, kernel_size=3, stride=1, padding=1),
            nn.Conv2d(32, 32, 3, 1, 1, bias=False),
            nn.InstanceNorm2d(32), nn.ReLU()
        )

        self.depth_initialization = DepthInitialization(self.num_sample)
        self.pixel_view_weight = PixelViewWeight(self.G)

        self.corr_stem = BasicConv(1, 8, is_3d=True, kernel_size=3, stride=1, padding=1)
        self.corr_feature_att = FeatureAtt(8, 96)
        self.cost_agg = hourglass(8)

        self.spx_2_gru = Conv2x(32, 32, True)
        self.spx_gru = nn.Sequential(nn.ConvTranspose2d(2*32, 9, kernel_size=4, stride=2, padding=1),)

    def upsample_disp(self, depth, mask_feat_4, stem_2x):
        with autocast(enabled=self.mixed_precision):
            xspx = self.spx_2_gru(mask_feat_4, stem_2x)
            spx_pred = self.spx_gru(xspx)
            spx_pred = F.softmax(spx_pred, 1)

            up_depth = context_upsample(depth, spx_pred).unsqueeze(1)

        return up_depth

    def forward(self, imgs, proj_matrices, depth_min, depth_max, test_mode=False):
        proj_matrices_2 = torch.unbind(proj_matrices['level_2'].float(), 1)
        depth_min = depth_min.float()
        depth_max = depth_max.float()

        ref_proj = proj_matrices_2[0]
        src_projs = proj_matrices_2[1:]

        with autocast(enabled=self.mixed_precision):
            images = torch.unbind(imgs['level_0'], dim=1)
            features = self.feature(imgs['level_0'])
            ref_feature = []
            for fea in features:
                ref_feature.append(torch.unbind(fea, dim=1)[0])
            src_features = [src_fea for src_fea in torch.unbind(features[0], dim=1)[1:]]

            stem_2x = self.stem_2(images[0])
            stem_4x = self.stem_4(stem_2x)
            ref_feature[0] = torch.cat((ref_feature[0], stem_4x), 1)

            for idx, src_fea in enumerate(src_features):
                stem_2y = self.stem_2(images[idx + 1])
                stem_4y = self.stem_4(stem_2y)
                src_features[idx] = torch.cat((src_fea, stem_4y), 1)

            match_left = self.desc(self.conv(ref_feature[0]))
            match_left = match_left / torch.norm(match_left, 2, 1, True)

            match_rights = [self.desc(self.conv(src_fea)) for src_fea in src_features]
            match_rights = [match_right / torch.norm(match_right, 2, 1, True) for match_right in match_rights]

            xspx = self.spx_4(ref_feature[0])
            xspx = self.spx_2(xspx, stem_2x)
            spx_pred = self.spx(xspx)
            spx_pred = F.softmax(spx_pred, 1)

        batch, dim, height, width = match_left.size()
        inverse_depth_min = (1.0 / depth_min).view(batch, 1, 1, 1)
        inverse_depth_max = (1.0 / depth_max).view(batch, 1, 1, 1)

        device = match_left.device
        correlation_sum = 0
        view_weight_sum = 1e-5

        # warping and correlation run in fp32, outside autocast
        match_left = match_left.float()
        depth_samples = self.depth_initialization(inverse_depth_min, inverse_depth_max, height, width, device)
        for src_feature, src_proj in zip(match_rights, src_projs):
            src_feature = src_feature.float()
            warped_feature = differentiable_warping(src_feature, src_proj, ref_proj, depth_samples)
            warped_feature = warped_feature.view(batch, self.G, dim // self.G, self.num_sample, height, width)
            correlation = torch.mean(warped_feature * match_left.view(batch, self.G, dim // self.G, 1, height, width), dim=2, keepdim=False)

            view_weight = self.pixel_view_weight(correlation)
            del warped_feature, src_feature, src_proj

            correlation_sum += correlation * view_weight.unsqueeze(1)
            view_weight_sum += view_weight.unsqueeze(1)  # accumulate once per view (the uploaded line also re-added the running sum)
            del correlation, view_weight
        del match_left, match_rights, src_projs

        with autocast(enabled=self.mixed_precision):
            init_corr_volume = correlation_sum.div_(view_weight_sum)
            corr_volume = self.corr_stem(init_corr_volume)
            corr_volume = self.corr_feature_att(corr_volume, ref_feature[0])
            regularized_cost_volume = self.cost_agg(corr_volume, ref_feature)

            GEV_hidden = self.conv_hidden_1(regularized_cost_volume.squeeze(1))
            GEV_hidden_2 = self.conv_hidden_2(GEV_hidden)
            GEV_hidden_4 = self.conv_hidden_4(GEV_hidden_2)
            net_list = [GEV_hidden, GEV_hidden_2, GEV_hidden_4]
            net_list = [torch.tanh(x) for x in net_list]

        corr_block = CorrBlock1D_Cost_Volume

        init_corr_volume = init_corr_volume.float()
        regularized_cost_volume = regularized_cost_volume.float()
        # soft-argmin over the depth dimension gives the initial depth index
        probability = F.softmax(regularized_cost_volume.squeeze(1), dim=1)
        index = torch.arange(0, self.num_sample, 1, device=probability.device).view(1, self.num_sample, 1, 1).float()
        disp_init = torch.sum(index * probability, dim=1, keepdim=True)

        corr_fn = corr_block(init_corr_volume, regularized_cost_volume, radius=self.corr_radius, num_levels=self.corr_levels, inverse_depth_min=inverse_depth_min, inverse_depth_max=inverse_depth_max, num_sample=self.num_sample)

        disp_predictions = []
        disp = disp_init

        for itr in range(self.iters):
            disp = disp.detach()
            corr = corr_fn(disp)

            with autocast(enabled=self.mixed_precision):
                if self.n_gru_layers == 3 and self.slow_fast_gru:  # update low-res GRU
                    net_list = self.update_block(net_list, iter16=True, iter08=False, iter04=False, update=False)
                if self.n_gru_layers >= 2 and self.slow_fast_gru:  # update low-res and mid-res GRUs
                    net_list = self.update_block(net_list, iter16=self.n_gru_layers==3, iter08=True, iter04=False, update=False)
                net_list, mask_feat_4, delta_disp = self.update_block(net_list, corr, disp, iter16=self.n_gru_layers==3, iter08=self.n_gru_layers>=2)

            disp = disp + delta_disp

            if test_mode and itr < self.iters-1:
                continue

            disp_up = self.upsample_disp(disp, mask_feat_4, stem_2x) / (self.num_sample-1)
            disp_predictions.append(disp_up)

        disp_init = context_upsample(disp_init, spx_pred.float()).unsqueeze(1) / (self.num_sample-1)

        if test_mode:
            return disp_up

        return disp_init, disp_predictions
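The loop above regresses a depth index that is divided by (num_sample - 1), so every returned map lies in the normalized inverse-depth range [0, 1]. A hedged sketch of the conversion back to metric depth, using the depth_unnormalization helper from submodule.py; the variable names mirror the forward pass but are assumptions about a typical evaluation script:

    # disp_up: [B, 1, H, W] in [0, 1]; depth_min / depth_max: [B] metric bounds
    inverse_depth_min = (1.0 / depth_min).view(-1, 1, 1, 1)
    inverse_depth_max = (1.0 / depth_max).view(-1, 1, 1, 1)
    depth = depth_unnormalization(disp_up, inverse_depth_min, inverse_depth_max)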
396
IGEV-MVS/core/submodule.py
Normal file
@@ -0,0 +1,396 @@
import torch
import torch.nn as nn
import torch.nn.functional as F
import math

class SubModule(nn.Module):
    def __init__(self):
        super(SubModule, self).__init__()

    def weight_init(self):
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
                m.weight.data.normal_(0, math.sqrt(2. / n))
            elif isinstance(m, nn.Conv3d):
                n = m.kernel_size[0] * m.kernel_size[1] * m.kernel_size[2] * m.out_channels
                m.weight.data.normal_(0, math.sqrt(2. / n))
            elif isinstance(m, nn.BatchNorm2d):
                m.weight.data.fill_(1)
                m.bias.data.zero_()
            elif isinstance(m, nn.BatchNorm3d):
                m.weight.data.fill_(1)
                m.bias.data.zero_()

class BasicConv(nn.Module):

    def __init__(self, in_channels, out_channels, deconv=False, is_3d=False, bn=True, relu=True, **kwargs):
        super(BasicConv, self).__init__()

        self.relu = relu
        self.use_bn = bn
        if is_3d:
            if deconv:
                self.conv = nn.ConvTranspose3d(in_channels, out_channels, bias=False, **kwargs)
            else:
                self.conv = nn.Conv3d(in_channels, out_channels, bias=False, **kwargs)
            self.bn = nn.BatchNorm3d(out_channels)
        else:
            if deconv:
                self.conv = nn.ConvTranspose2d(in_channels, out_channels, bias=False, **kwargs)
            else:
                self.conv = nn.Conv2d(in_channels, out_channels, bias=False, **kwargs)
            self.bn = nn.BatchNorm2d(out_channels)

    def forward(self, x):
        x = self.conv(x)
        if self.use_bn:
            x = self.bn(x)
        if self.relu:
            x = nn.LeakyReLU()(x)
        return x

class BasicConv_IN(nn.Module):

    def __init__(self, in_channels, out_channels, deconv=False, is_3d=False, IN=True, relu=True, **kwargs):
        super(BasicConv_IN, self).__init__()

        self.relu = relu
        self.use_in = IN
        if is_3d:
            if deconv:
                self.conv = nn.ConvTranspose3d(in_channels, out_channels, bias=False, **kwargs)
            else:
                self.conv = nn.Conv3d(in_channels, out_channels, bias=False, **kwargs)
            self.IN = nn.InstanceNorm3d(out_channels)
        else:
            if deconv:
                self.conv = nn.ConvTranspose2d(in_channels, out_channels, bias=False, **kwargs)
            else:
                self.conv = nn.Conv2d(in_channels, out_channels, bias=False, **kwargs)
            self.IN = nn.InstanceNorm2d(out_channels)

    def forward(self, x):
        x = self.conv(x)
        if self.use_in:
            x = self.IN(x)
        if self.relu:
            x = nn.LeakyReLU()(x)
        return x

class Conv2x(nn.Module):

    def __init__(self, in_channels, out_channels, deconv=False, is_3d=False, concat=True, keep_concat=True, bn=True, relu=True, keep_dispc=False):
        super(Conv2x, self).__init__()
        self.concat = concat
        self.is_3d = is_3d
        if deconv and is_3d:
            kernel = (4, 4, 4)
        elif deconv:
            kernel = 4
        else:
            kernel = 3

        if deconv and is_3d and keep_dispc:
            kernel = (1, 4, 4)
            stride = (1, 2, 2)
            padding = (0, 1, 1)
            self.conv1 = BasicConv(in_channels, out_channels, deconv, is_3d, bn=True, relu=True, kernel_size=kernel, stride=stride, padding=padding)
        else:
            self.conv1 = BasicConv(in_channels, out_channels, deconv, is_3d, bn=True, relu=True, kernel_size=kernel, stride=2, padding=1)

        if self.concat:
            mul = 2 if keep_concat else 1
            self.conv2 = BasicConv(out_channels*2, out_channels*mul, False, is_3d, bn, relu, kernel_size=3, stride=1, padding=1)
        else:
            self.conv2 = BasicConv(out_channels, out_channels, False, is_3d, bn, relu, kernel_size=3, stride=1, padding=1)

    def forward(self, x, rem):
        x = self.conv1(x)
        if x.shape != rem.shape:
            x = F.interpolate(
                x,
                size=(rem.shape[-2], rem.shape[-1]),
                mode='nearest')
        if self.concat:
            x = torch.cat((x, rem), 1)
        else:
            x = x + rem
        x = self.conv2(x)
        return x

class Conv2x_IN(nn.Module):

    def __init__(self, in_channels, out_channels, deconv=False, is_3d=False, concat=True, keep_concat=True, IN=True, relu=True, keep_dispc=False):
        super(Conv2x_IN, self).__init__()
        self.concat = concat
        self.is_3d = is_3d
        if deconv and is_3d:
            kernel = (4, 4, 4)
        elif deconv:
            kernel = 4
        else:
            kernel = 3

        if deconv and is_3d and keep_dispc:
            kernel = (1, 4, 4)
            stride = (1, 2, 2)
            padding = (0, 1, 1)
            self.conv1 = BasicConv_IN(in_channels, out_channels, deconv, is_3d, IN=True, relu=True, kernel_size=kernel, stride=stride, padding=padding)
        else:
            self.conv1 = BasicConv_IN(in_channels, out_channels, deconv, is_3d, IN=True, relu=True, kernel_size=kernel, stride=2, padding=1)

        if self.concat:
            mul = 2 if keep_concat else 1
            self.conv2 = BasicConv_IN(out_channels*2, out_channels*mul, False, is_3d, IN, relu, kernel_size=3, stride=1, padding=1)
        else:
            self.conv2 = BasicConv_IN(out_channels, out_channels, False, is_3d, IN, relu, kernel_size=3, stride=1, padding=1)

    def forward(self, x, rem):
        x = self.conv1(x)
        if x.shape != rem.shape:
            x = F.interpolate(
                x,
                size=(rem.shape[-2], rem.shape[-1]),
                mode='nearest')
        if self.concat:
            x = torch.cat((x, rem), 1)
        else:
            x = x + rem
        x = self.conv2(x)
        return x

class ConvReLU(nn.Module):
    def __init__(self, in_channels, out_channels, kernel_size=3, stride=1, pad=1, dilation=1):
        super(ConvReLU, self).__init__()
        self.conv = nn.Conv2d(in_channels, out_channels, kernel_size, stride=stride, padding=pad, dilation=dilation, bias=False)

    def forward(self, x):
        return F.relu(self.conv(x), inplace=True)

class DepthInitialization(nn.Module):
    def __init__(self, num_sample):
        super(DepthInitialization, self).__init__()
        self.num_sample = num_sample

    def forward(self, inverse_depth_min, inverse_depth_max, height, width, device):
        batch = inverse_depth_min.size()[0]
        # uniform samples in inverse-depth space, then converted back to depth
        index = torch.arange(0, self.num_sample, 1, device=device).view(1, self.num_sample, 1, 1).float()
        normalized_sample = index.repeat(batch, 1, height, width) / (self.num_sample-1)
        depth_sample = inverse_depth_max + normalized_sample * (inverse_depth_min - inverse_depth_max)

        depth_sample = 1.0 / depth_sample

        return depth_sample

class PixelViewWeight(nn.Module):
    def __init__(self, G):
        super(PixelViewWeight, self).__init__()
        self.conv = nn.Sequential(
            ConvReLU(G, 16),
            nn.Conv2d(16, 1, 1, stride=1, padding=0),
        )

    def forward(self, x):
        # x: [B, G, N, H, W]
        batch, dim, num_depth, height, width = x.size()
        x = x.permute(0, 2, 1, 3, 4).contiguous()
        x = x.view(batch*num_depth, dim, height, width)  # [B*N, G, H, W]
        x = self.conv(x).view(batch, num_depth, height, width)
        x = torch.softmax(x, dim=1)
        x = torch.max(x, dim=1)[0]

        return x.unsqueeze(1)

class FeatureAtt(nn.Module):
    def __init__(self, cv_chan, feat_chan):
        super(FeatureAtt, self).__init__()

        self.feat_att = nn.Sequential(
            BasicConv(feat_chan, feat_chan//2, kernel_size=1, stride=1, padding=0),
            nn.Conv2d(feat_chan//2, cv_chan, 1))

    def forward(self, cv, feat):
        '''Modulate the 3D cost volume with attention weights from 2D features.'''
        feat_att = self.feat_att(feat).unsqueeze(2)
        cv = torch.sigmoid(feat_att)*cv
        return cv

class hourglass(nn.Module):
    def __init__(self, in_channels):
        super(hourglass, self).__init__()

        self.conv1 = nn.Sequential(BasicConv(in_channels, in_channels*2, is_3d=True, bn=True, relu=True, kernel_size=3,
                                             padding=1, stride=2, dilation=1),
                                   BasicConv(in_channels*2, in_channels*2, is_3d=True, bn=True, relu=True, kernel_size=3,
                                             padding=1, stride=1, dilation=1))

        self.conv2 = nn.Sequential(BasicConv(in_channels*2, in_channels*4, is_3d=True, bn=True, relu=True, kernel_size=3,
                                             padding=1, stride=2, dilation=1),
                                   BasicConv(in_channels*4, in_channels*4, is_3d=True, bn=True, relu=True, kernel_size=3,
                                             padding=1, stride=1, dilation=1))

        self.conv3 = nn.Sequential(BasicConv(in_channels*4, in_channels*6, is_3d=True, bn=True, relu=True, kernel_size=3,
                                             padding=1, stride=2, dilation=1),
                                   BasicConv(in_channels*6, in_channels*6, is_3d=True, bn=True, relu=True, kernel_size=3,
                                             padding=1, stride=1, dilation=1))

        self.conv3_up = BasicConv(in_channels*6, in_channels*4, deconv=True, is_3d=True, bn=True,
                                  relu=True, kernel_size=(4, 4, 4), padding=(1, 1, 1), stride=(2, 2, 2))

        self.conv2_up = BasicConv(in_channels*4, in_channels*2, deconv=True, is_3d=True, bn=True,
                                  relu=True, kernel_size=(4, 4, 4), padding=(1, 1, 1), stride=(2, 2, 2))

        self.conv1_up = BasicConv(in_channels*2, 1, deconv=True, is_3d=True, bn=False,
                                  relu=False, kernel_size=(4, 4, 4), padding=(1, 1, 1), stride=(2, 2, 2))

        self.agg_0 = nn.Sequential(BasicConv(in_channels*8, in_channels*4, is_3d=True, kernel_size=1, padding=0, stride=1),
                                   BasicConv(in_channels*4, in_channels*4, is_3d=True, kernel_size=3, padding=1, stride=1),
                                   BasicConv(in_channels*4, in_channels*4, is_3d=True, kernel_size=3, padding=1, stride=1),)

        self.agg_1 = nn.Sequential(BasicConv(in_channels*4, in_channels*2, is_3d=True, kernel_size=1, padding=0, stride=1),
                                   BasicConv(in_channels*2, in_channels*2, is_3d=True, kernel_size=3, padding=1, stride=1),
                                   BasicConv(in_channels*2, in_channels*2, is_3d=True, kernel_size=3, padding=1, stride=1))

        self.feature_att_8 = FeatureAtt(in_channels*2, 64)
        self.feature_att_16 = FeatureAtt(in_channels*4, 192)
        self.feature_att_32 = FeatureAtt(in_channels*6, 160)
        self.feature_att_up_16 = FeatureAtt(in_channels*4, 192)
        self.feature_att_up_8 = FeatureAtt(in_channels*2, 64)

    def forward(self, x, features):
        conv1 = self.conv1(x)
        conv1 = self.feature_att_8(conv1, features[1])

        conv2 = self.conv2(conv1)
        conv2 = self.feature_att_16(conv2, features[2])

        conv3 = self.conv3(conv2)
        conv3 = self.feature_att_32(conv3, features[3])

        conv3_up = self.conv3_up(conv3)

        conv2 = torch.cat((conv3_up, conv2), dim=1)
        conv2 = self.agg_0(conv2)
        conv2 = self.feature_att_up_16(conv2, features[2])

        conv2_up = self.conv2_up(conv2)

        conv1 = torch.cat((conv2_up, conv1), dim=1)
        conv1 = self.agg_1(conv1)
        conv1 = self.feature_att_up_8(conv1, features[1])

        conv = self.conv1_up(conv1)

        return conv

def bilinear_sampler(img, coords, mode='bilinear', mask=False):
    """ Wrapper for grid_sample, uses pixel coordinates """
    H, W = img.shape[-2:]
    xgrid, ygrid = coords.split([1, 1], dim=-1)
    xgrid = 2*xgrid/(W-1) - 1

    # the depth lookup is 1-D: every row index must be identical
    assert torch.unique(ygrid).numel() == 1 and H == 1

    grid = torch.cat([xgrid, ygrid], dim=-1)
    img = F.grid_sample(img, grid, align_corners=True)
    if mask:
        mask = (xgrid > -1) & (ygrid > -1) & (xgrid < 1) & (ygrid < 1)
        return img, mask.float()

    return img

def context_upsample(disp_low, up_weights):
    # disp_low:   (b, 1, h, w)
    # up_weights: (b, 9, 4*h, 4*w)
    b, c, h, w = disp_low.shape

    disp_unfold = F.unfold(disp_low.reshape(b, c, h, w), 3, 1, 1).reshape(b, -1, h, w)
    disp_unfold = F.interpolate(disp_unfold, (h*4, w*4), mode='nearest').reshape(b, 9, h*4, w*4)

    disp = (disp_unfold*up_weights).sum(1)

    return disp

def pool2x(x):
    return F.avg_pool2d(x, 3, stride=2, padding=1)

def interp(x, dest):
    interp_args = {'mode': 'bilinear', 'align_corners': True}
    return F.interpolate(x, dest.shape[2:], **interp_args)

def differentiable_warping(src_fea, src_proj, ref_proj, depth_samples, return_mask=False):
    # src_fea: [B, C, H, W]
    # src_proj: [B, 4, 4]
    # ref_proj: [B, 4, 4]
    # depth_samples: [B, Ndepth, H, W]
    # out: [B, C, Ndepth, H, W]
    batch, num_depth, height, width = depth_samples.size()
    height1, width1 = src_fea.size()[2:]

    with torch.no_grad():
        if batch == 2:
            inv_ref_proj = []
            for i in range(batch):
                inv_ref_proj.append(torch.inverse(ref_proj[i]).unsqueeze(0))
            inv_ref_proj = torch.cat(inv_ref_proj, dim=0)
            assert (not torch.isnan(inv_ref_proj).any()), "nan in inverse(ref_proj)"
            proj = torch.matmul(src_proj, inv_ref_proj)
        else:
            proj = torch.matmul(src_proj, torch.inverse(ref_proj))
        assert (not torch.isnan(proj).any()), "nan in proj"

        rot = proj[:, :3, :3]  # [B,3,3]
        trans = proj[:, :3, 3:4]  # [B,3,1]
        y, x = torch.meshgrid([torch.arange(0, height, dtype=torch.float32, device=depth_samples.device),
                               torch.arange(0, width, dtype=torch.float32, device=depth_samples.device)])
        y, x = y.contiguous(), x.contiguous()
        y, x = y.view(height * width), x.view(height * width)
        y = y*(height1/height)
        x = x*(width1/width)
        xyz = torch.stack((x, y, torch.ones_like(x)))  # [3, H*W]
        xyz = torch.unsqueeze(xyz, 0).repeat(batch, 1, 1)  # [B, 3, H*W]
        rot_xyz = torch.matmul(rot, xyz)  # [B, 3, H*W]

        rot_depth_xyz = rot_xyz.unsqueeze(2).repeat(1, 1, num_depth, 1) * depth_samples.view(batch, 1, num_depth,
                                                                                             height * width)  # [B, 3, Ndepth, H*W]
        proj_xyz = rot_depth_xyz + trans.view(batch, 3, 1, 1)  # [B, 3, Ndepth, H*W]
        # avoid negative depth
        valid_mask = proj_xyz[:, 2:] > 1e-2
        proj_xyz[:, 0:1][~valid_mask] = width
        proj_xyz[:, 1:2][~valid_mask] = height
        proj_xyz[:, 2:3][~valid_mask] = 1
        proj_xy = proj_xyz[:, :2, :, :] / proj_xyz[:, 2:3, :, :]  # [B, 2, Ndepth, H*W]
        valid_mask = valid_mask & (proj_xy[:, 0:1] >= 0) & (proj_xy[:, 0:1] < width) \
                     & (proj_xy[:, 1:2] >= 0) & (proj_xy[:, 1:2] < height)
        proj_x_normalized = proj_xy[:, 0, :, :] / ((width1 - 1) / 2) - 1  # [B, Ndepth, H*W]
        proj_y_normalized = proj_xy[:, 1, :, :] / ((height1 - 1) / 2) - 1
        proj_xy = torch.stack((proj_x_normalized, proj_y_normalized), dim=3)  # [B, Ndepth, H*W, 2]
        grid = proj_xy

    dim = src_fea.size()[1]
    warped_src_fea = F.grid_sample(src_fea, grid.view(batch, num_depth * height, width, 2), mode='bilinear',
                                   padding_mode='zeros', align_corners=True)
    warped_src_fea = warped_src_fea.view(batch, dim, num_depth, height, width)
    if return_mask:
        valid_mask = valid_mask.view(batch, num_depth, height, width)
        return warped_src_fea, valid_mask
    else:
        return warped_src_fea

def depth_normalization(depth, inverse_depth_min, inverse_depth_max):
    '''convert depth map to the index in inverse range'''
    inverse_depth = 1.0 / (depth+1e-5)
    normalized_depth = (inverse_depth - inverse_depth_max) / (inverse_depth_min - inverse_depth_max)
    return normalized_depth

def depth_unnormalization(normalized_depth, inverse_depth_min, inverse_depth_max):
    '''convert the index in inverse range to depth map'''
    inverse_depth = inverse_depth_max + normalized_depth * (inverse_depth_min - inverse_depth_max)  # [B,1,H,W]
    depth = 1.0 / inverse_depth
    return depth
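The two depth helpers at the end are inverses of one another up to the 1e-5 guard; a minimal round-trip check with an assumed depth range:

    import torch
    from core.submodule import depth_normalization, depth_unnormalization

    depth_min, depth_max = 2.0, 10.0              # assumed metric range
    inv_min, inv_max = 1.0 / depth_min, 1.0 / depth_max
    depth = torch.tensor([[[[2.5, 5.0, 9.0]]]])   # [B, 1, H, W]

    idx = depth_normalization(depth, inv_min, inv_max)    # indices in [0, 1]
    back = depth_unnormalization(idx, inv_min, inv_max)   # metric depth again
    print(torch.allclose(back, depth, atol=1e-3))         # True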
94
IGEV-MVS/core/update.py
Normal file
@@ -0,0 +1,94 @@
import torch
import torch.nn as nn
import torch.nn.functional as F
from .submodule import *

class BasicMotionEncoder(nn.Module):
    def __init__(self):
        super(BasicMotionEncoder, self).__init__()
        self.corr_levels = 2
        self.corr_radius = 4

        cor_planes = 2 * self.corr_levels * (2*self.corr_radius + 1)

        self.convc1 = nn.Conv2d(cor_planes, 64, 1, padding=0)
        self.convc2 = nn.Conv2d(64, 64, 3, padding=1)
        self.convd1 = nn.Conv2d(1, 64, 7, padding=3)
        self.convd2 = nn.Conv2d(64, 64, 3, padding=1)
        self.conv = nn.Conv2d(64+64, 128-1, 3, padding=1)

    def forward(self, disp, corr):
        cor = F.relu(self.convc1(corr))
        cor = F.relu(self.convc2(cor))
        disp_ = F.relu(self.convd1(disp))
        disp_ = F.relu(self.convd2(disp_))

        cor_disp = torch.cat([cor, disp_], dim=1)
        out = F.relu(self.conv(cor_disp))
        return torch.cat([out, disp], dim=1)

class ConvGRU(nn.Module):
    def __init__(self, hidden_dim, input_dim, kernel_size=3):
        super(ConvGRU, self).__init__()
        self.convz = nn.Conv2d(hidden_dim+input_dim, hidden_dim, kernel_size, padding=kernel_size//2)
        self.convr = nn.Conv2d(hidden_dim+input_dim, hidden_dim, kernel_size, padding=kernel_size//2)
        self.convq = nn.Conv2d(hidden_dim+input_dim, hidden_dim, kernel_size, padding=kernel_size//2)

    def forward(self, h, *x_list):
        x = torch.cat(x_list, dim=1)
        hx = torch.cat([h, x], dim=1)
        z = torch.sigmoid(self.convz(hx))
        r = torch.sigmoid(self.convr(hx))
        q = torch.tanh(self.convq(torch.cat([r*h, x], dim=1)))
        h = (1-z) * h + z * q
        return h

class DispHead(nn.Module):
    def __init__(self, input_dim=128, hidden_dim=256, output_dim=1):
        super(DispHead, self).__init__()
        self.conv1 = nn.Conv2d(input_dim, hidden_dim, 3, padding=1)
        self.conv2 = nn.Conv2d(hidden_dim, output_dim, 3, padding=1)
        self.relu = nn.ReLU(inplace=True)

    def forward(self, x):
        return self.conv2(self.relu(self.conv1(x)))

class BasicMultiUpdateBlock(nn.Module):
    def __init__(self, hidden_dims=[]):
        super().__init__()
        self.n_gru_layers = 3
        self.n_downsample = 2
        self.encoder = BasicMotionEncoder()
        encoder_output_dim = 128

        self.gru04 = ConvGRU(hidden_dims[2], encoder_output_dim + hidden_dims[1] * (self.n_gru_layers > 1))
        self.gru08 = ConvGRU(hidden_dims[1], hidden_dims[0] * (self.n_gru_layers == 3) + hidden_dims[2])
        self.gru16 = ConvGRU(hidden_dims[0], hidden_dims[1])
        self.disp_head = DispHead(hidden_dims[2], hidden_dim=256, output_dim=1)
        factor = 2**self.n_downsample

        self.mask_feat_4 = nn.Sequential(
            nn.Conv2d(hidden_dims[2], 32, 3, padding=1),
            nn.ReLU(inplace=True))

    def forward(self, net, corr=None, disp=None, iter04=True, iter08=True, iter16=True, update=True):
        if iter16:
            net[2] = self.gru16(net[2], pool2x(net[1]))
        if iter08:
            if self.n_gru_layers > 2:
                net[1] = self.gru08(net[1], pool2x(net[0]), interp(net[2], net[1]))
            else:
                net[1] = self.gru08(net[1], pool2x(net[0]))
        if iter04:
            motion_features = self.encoder(disp, corr)
            if self.n_gru_layers > 1:
                net[0] = self.gru04(net[0], motion_features, interp(net[1], net[0]))
            else:
                net[0] = self.gru04(net[0], motion_features)

        if not update:
            return net

        delta_disp = self.disp_head(net[0])
        mask_feat_4 = self.mask_feat_4(net[0])
        return net, mask_feat_4, delta_disp
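The ConvGRU accepts any number of input maps and concatenates them along channels, so the same cell serves every pyramid level; a single-step sketch with assumed sizes:

    import torch
    from core.update import ConvGRU

    gru = ConvGRU(hidden_dim=128, input_dim=256)  # input_dim = sum of the inputs' channels
    h = torch.zeros(1, 128, 32, 40)               # hidden state
    motion = torch.randn(1, 128, 32, 40)          # e.g. motion features
    context = torch.randn(1, 128, 32, 40)         # e.g. interpolated coarser hidden state
    h = gru(h, motion, context)
    print(h.shape)  # torch.Size([1, 128, 32, 40])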
8
IGEV-MVS/datasets/__init__.py
Normal file
@@ -0,0 +1,8 @@
import importlib


# find the dataset definition by name, for example dtu_yao (dtu_yao.py)
def find_dataset_def(dataset_name):
    module_name = 'datasets.{}'.format(dataset_name)
    module = importlib.import_module(module_name)
    return getattr(module, "MVSDataset")
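Typical usage, assuming the script runs from the IGEV-MVS directory; the paths are placeholders, and dtu_yao.py below defines the matching MVSDataset:

    from datasets import find_dataset_def

    MVSDataset = find_dataset_def('dtu_yao')
    dataset = MVSDataset(datapath='/path/to/dtu', listfile='lists/train.txt',
                         mode='train', nviews=5)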
208
IGEV-MVS/datasets/blendedmvs.py
Normal file
@@ -0,0 +1,208 @@
from torch.utils.data import Dataset
from datasets.data_io import *
import os
import numpy as np
import cv2
from PIL import Image
from torchvision import transforms as T
import random

class MVSDataset(Dataset):
    def __init__(self, datapath, listfile, split, nviews, img_wh=(768, 576), robust_train=True):

        super(MVSDataset, self).__init__()
        self.levels = 4
        self.datapath = datapath
        self.split = split
        self.listfile = listfile
        self.robust_train = robust_train
        assert self.split in ['train', 'val', 'all'], \
            'split must be either "train", "val" or "all"!'

        self.img_wh = img_wh
        if img_wh is not None:
            assert img_wh[0] % 32 == 0 and img_wh[1] % 32 == 0, \
                'img_wh must both be multiples of 32!'
        self.nviews = nviews
        self.scale_factors = {}  # depth scale factors for each scan
        self.build_metas()

        self.color_augment = T.ColorJitter(brightness=0.5, contrast=0.5)

    def build_metas(self):
        self.metas = []
        with open(self.listfile) as f:
            self.scans = [line.rstrip() for line in f.readlines()]
        for scan in self.scans:
            with open(os.path.join(self.datapath, scan, "cams/pair.txt")) as f:
                num_viewpoint = int(f.readline())
                for _ in range(num_viewpoint):
                    ref_view = int(f.readline().rstrip())
                    src_views = [int(x) for x in f.readline().rstrip().split()[1::2]]
                    if len(src_views) >= self.nviews-1:
                        self.metas += [(scan, ref_view, src_views)]

    def read_cam_file(self, scan, filename):
        with open(filename) as f:
            lines = f.readlines()
            lines = [line.rstrip() for line in lines]
        # extrinsics: lines [1, 5), 4x4 matrix
        extrinsics = np.fromstring(' '.join(lines[1:5]), dtype=np.float32, sep=' ').reshape((4, 4))
        # intrinsics: lines [7, 10), 3x3 matrix
        intrinsics = np.fromstring(' '.join(lines[7:10]), dtype=np.float32, sep=' ').reshape((3, 3))
        depth_min = float(lines[11].split()[0])
        depth_max = float(lines[11].split()[-1])
        if scan not in self.scale_factors:
            self.scale_factors[scan] = 100.0/depth_min
        depth_min *= self.scale_factors[scan]
        depth_max *= self.scale_factors[scan]
        extrinsics[:3, 3] *= self.scale_factors[scan]
        return intrinsics, extrinsics, depth_min, depth_max

    def read_depth_mask(self, scan, filename, depth_min, depth_max, scale):
        depth = np.array(read_pfm(filename)[0], dtype=np.float32)
        depth = depth * self.scale_factors[scan] * scale
        depth = np.squeeze(depth, 2)

        mask = (depth >= depth_min) & (depth <= depth_max)
        mask = mask.astype(np.float32)
        if self.img_wh is not None:
            depth = cv2.resize(depth, self.img_wh,
                               interpolation=cv2.INTER_NEAREST)
        h, w = depth.shape
        depth_ms = {}
        mask_ms = {}

        for i in range(4):
            depth_cur = cv2.resize(depth, (w//(2**i), h//(2**i)), interpolation=cv2.INTER_NEAREST)
            mask_cur = cv2.resize(mask, (w//(2**i), h//(2**i)), interpolation=cv2.INTER_NEAREST)

            depth_ms[f"level_{i}"] = depth_cur
            mask_ms[f"level_{i}"] = mask_cur

        return depth_ms, mask_ms

    def read_img(self, filename):
        img = Image.open(filename)
        if self.split == 'train':
            img = self.color_augment(img)
        # scale 0~255 to -1~1
        np_img = 2*np.array(img, dtype=np.float32) / 255. - 1
        if self.img_wh is not None:
            np_img = cv2.resize(np_img, self.img_wh,
                                interpolation=cv2.INTER_LINEAR)
        h, w, _ = np_img.shape
        np_img_ms = {
            "level_3": cv2.resize(np_img, (w//8, h//8), interpolation=cv2.INTER_LINEAR),
            "level_2": cv2.resize(np_img, (w//4, h//4), interpolation=cv2.INTER_LINEAR),
            "level_1": cv2.resize(np_img, (w//2, h//2), interpolation=cv2.INTER_LINEAR),
            "level_0": np_img
        }
        return np_img_ms

    def __len__(self):
        return len(self.metas)

    def __getitem__(self, idx):
        meta = self.metas[idx]
        scan, ref_view, src_views = meta

        if self.robust_train:
            num_src_views = len(src_views)
            index = random.sample(range(num_src_views), self.nviews - 1)
            view_ids = [ref_view] + [src_views[i] for i in index]
            scale = random.uniform(0.8, 1.25)

        else:
            view_ids = [ref_view] + src_views[:self.nviews - 1]
            scale = 1

        imgs_0 = []
        imgs_1 = []
        imgs_2 = []
        imgs_3 = []
        mask = None
        depth = None
        depth_min = None
        depth_max = None
        proj_matrices_0 = []
        proj_matrices_1 = []
        proj_matrices_2 = []
        proj_matrices_3 = []

        for i, vid in enumerate(view_ids):
            img_filename = os.path.join(self.datapath, '{}/blended_images/{:0>8}.jpg'.format(scan, vid))
            depth_filename = os.path.join(self.datapath, '{}/rendered_depth_maps/{:0>8}.pfm'.format(scan, vid))
            proj_mat_filename = os.path.join(self.datapath, '{}/cams/{:0>8}_cam.txt'.format(scan, vid))

            imgs = self.read_img(img_filename)
            imgs_0.append(imgs['level_0'])
            imgs_1.append(imgs['level_1'])
            imgs_2.append(imgs['level_2'])
            imgs_3.append(imgs['level_3'])

            # scale the intrinsics from the camera file down to each pyramid level,
            # starting at 1/8 (level_3) and doubling back up to level_0
            intrinsics, extrinsics, depth_min_, depth_max_ = self.read_cam_file(scan, proj_mat_filename)
            extrinsics[:3, 3] *= scale
            proj_mat = extrinsics.copy()
            intrinsics[:2, :] *= 0.125
            proj_mat[:3, :4] = np.matmul(intrinsics, proj_mat[:3, :4])
            proj_matrices_3.append(proj_mat)

            proj_mat = extrinsics.copy()
            intrinsics[:2, :] *= 2
            proj_mat[:3, :4] = np.matmul(intrinsics, proj_mat[:3, :4])
            proj_matrices_2.append(proj_mat)

            proj_mat = extrinsics.copy()
            intrinsics[:2, :] *= 2
            proj_mat[:3, :4] = np.matmul(intrinsics, proj_mat[:3, :4])
            proj_matrices_1.append(proj_mat)

            proj_mat = extrinsics.copy()
            intrinsics[:2, :] *= 2
            proj_mat[:3, :4] = np.matmul(intrinsics, proj_mat[:3, :4])
            proj_matrices_0.append(proj_mat)

            if i == 0:  # reference view
                depth_min = depth_min_ * scale
                depth_max = depth_max_ * scale
                depth, mask = self.read_depth_mask(scan, depth_filename, depth_min, depth_max, scale)
                for l in range(self.levels):
                    mask[f'level_{l}'] = np.expand_dims(mask[f'level_{l}'], 2)
                    mask[f'level_{l}'] = mask[f'level_{l}'].transpose([2, 0, 1])
                    depth[f'level_{l}'] = np.expand_dims(depth[f'level_{l}'], 2)
                    depth[f'level_{l}'] = depth[f'level_{l}'].transpose([2, 0, 1])

        # imgs: N*3*H0*W0, N is the number of views
        imgs_0 = np.stack(imgs_0).transpose([0, 3, 1, 2])
        imgs_1 = np.stack(imgs_1).transpose([0, 3, 1, 2])
        imgs_2 = np.stack(imgs_2).transpose([0, 3, 1, 2])
        imgs_3 = np.stack(imgs_3).transpose([0, 3, 1, 2])
        imgs = {}
        imgs['level_0'] = imgs_0
        imgs['level_1'] = imgs_1
        imgs['level_2'] = imgs_2
        imgs['level_3'] = imgs_3

        # proj_matrices: N*4*4
        proj_matrices_0 = np.stack(proj_matrices_0)
        proj_matrices_1 = np.stack(proj_matrices_1)
        proj_matrices_2 = np.stack(proj_matrices_2)
        proj_matrices_3 = np.stack(proj_matrices_3)
        proj = {}
        proj['level_3'] = proj_matrices_3
        proj['level_2'] = proj_matrices_2
        proj['level_1'] = proj_matrices_1
        proj['level_0'] = proj_matrices_0

        # all entries are numpy arrays
        return {"imgs": imgs,               # [N, 3, H, W]
                "proj_matrices": proj,      # [N, 4, 4]
                "depth": depth,             # [1, H, W]
                "depth_min": depth_min,     # scalar
                "depth_max": depth_max,     # scalar
                "mask": mask}               # [1, H, W]
145
IGEV-MVS/datasets/custom.py
Normal file
@@ -0,0 +1,145 @@
from torch.utils.data import Dataset
from datasets.data_io import *
import os
import numpy as np
import cv2
from PIL import Image
from torchvision import transforms as T
import math

class MVSDataset(Dataset):
    def __init__(self, datapath, n_views=5, img_wh=(640, 480)):
        self.levels = 4
        self.datapath = datapath
        self.img_wh = img_wh
        self.build_metas()
        self.n_views = n_views

    def build_metas(self):
        self.metas = []
        with open(os.path.join(self.datapath, 'pair.txt')) as f:
            num_viewpoint = int(f.readline())
            for view_idx in range(num_viewpoint):
                ref_view = int(f.readline().rstrip())
                src_views = [int(x) for x in f.readline().rstrip().split()[1::2]]
                if len(src_views) != 0:
                    self.metas += [(ref_view, src_views)]

    def read_cam_file(self, filename):
        with open(filename) as f:
            lines = [line.rstrip() for line in f.readlines()]
        # extrinsics: lines [1, 5), 4x4 matrix
        extrinsics = np.fromstring(' '.join(lines[1:5]), dtype=np.float32, sep=' ')
        extrinsics = extrinsics.reshape((4, 4))
        # intrinsics: lines [7, 10), 3x3 matrix
        intrinsics = np.fromstring(' '.join(lines[7:10]), dtype=np.float32, sep=' ')
        intrinsics = intrinsics.reshape((3, 3))

        depth_min = float(lines[11].split()[0])
        depth_max = float(lines[11].split()[-1])

        return intrinsics, extrinsics, depth_min, depth_max

    def read_img(self, filename, h, w):
        img = Image.open(filename)
        # scale 0~255 to -1~1
        np_img = 2*np.array(img, dtype=np.float32) / 255. - 1
        original_h, original_w, _ = np_img.shape
        np_img = cv2.resize(np_img, self.img_wh, interpolation=cv2.INTER_LINEAR)

        np_img_ms = {
            "level_3": cv2.resize(np_img, (w//8, h//8), interpolation=cv2.INTER_LINEAR),
            "level_2": cv2.resize(np_img, (w//4, h//4), interpolation=cv2.INTER_LINEAR),
            "level_1": cv2.resize(np_img, (w//2, h//2), interpolation=cv2.INTER_LINEAR),
            "level_0": np_img
        }
        return np_img_ms, original_h, original_w

    def __len__(self):
        return len(self.metas)

    def __getitem__(self, idx):
        ref_view, src_views = self.metas[idx]
        # use only the reference view and the first nviews-1 source views
        view_ids = [ref_view] + src_views[:self.n_views-1]

        imgs_0 = []
        imgs_1 = []
        imgs_2 = []
        imgs_3 = []

        depth_min = None
        depth_max = None

        proj_matrices_0 = []
        proj_matrices_1 = []
        proj_matrices_2 = []
        proj_matrices_3 = []

        for i, vid in enumerate(view_ids):
            img_filename = os.path.join(self.datapath, f'images/{vid:08d}.jpg')
            proj_mat_filename = os.path.join(self.datapath, f'cams_1/{vid:08d}_cam.txt')

            imgs, original_h, original_w = self.read_img(img_filename, self.img_wh[1], self.img_wh[0])
            imgs_0.append(imgs['level_0'])
            imgs_1.append(imgs['level_1'])
            imgs_2.append(imgs['level_2'])
            imgs_3.append(imgs['level_3'])

            intrinsics, extrinsics, depth_min_, depth_max_ = self.read_cam_file(proj_mat_filename)
            intrinsics[0] *= self.img_wh[0]/original_w
            intrinsics[1] *= self.img_wh[1]/original_h

            proj_mat = extrinsics.copy()
            intrinsics[:2, :] *= 0.125
            proj_mat[:3, :4] = np.matmul(intrinsics, proj_mat[:3, :4])
            proj_matrices_3.append(proj_mat)

            proj_mat = extrinsics.copy()
            intrinsics[:2, :] *= 2
            proj_mat[:3, :4] = np.matmul(intrinsics, proj_mat[:3, :4])
            proj_matrices_2.append(proj_mat)

            proj_mat = extrinsics.copy()
            intrinsics[:2, :] *= 2
            proj_mat[:3, :4] = np.matmul(intrinsics, proj_mat[:3, :4])
            proj_matrices_1.append(proj_mat)

            proj_mat = extrinsics.copy()
            intrinsics[:2, :] *= 2
            proj_mat[:3, :4] = np.matmul(intrinsics, proj_mat[:3, :4])
            proj_matrices_0.append(proj_mat)

            if i == 0:  # reference view
                depth_min = depth_min_
                depth_max = depth_max_

        # imgs: N*3*H0*W0, N is the number of views
        imgs_0 = np.stack(imgs_0).transpose([0, 3, 1, 2])
        imgs_1 = np.stack(imgs_1).transpose([0, 3, 1, 2])
        imgs_2 = np.stack(imgs_2).transpose([0, 3, 1, 2])
        imgs_3 = np.stack(imgs_3).transpose([0, 3, 1, 2])
        imgs = {}
        imgs['level_0'] = imgs_0
        imgs['level_1'] = imgs_1
        imgs['level_2'] = imgs_2
        imgs['level_3'] = imgs_3
        # proj_matrices: N*4*4
        proj_matrices_0 = np.stack(proj_matrices_0)
        proj_matrices_1 = np.stack(proj_matrices_1)
        proj_matrices_2 = np.stack(proj_matrices_2)
        proj_matrices_3 = np.stack(proj_matrices_3)
        proj = {}
        proj['level_3'] = proj_matrices_3
        proj['level_2'] = proj_matrices_2
        proj['level_1'] = proj_matrices_1
        proj['level_0'] = proj_matrices_0

        return {"imgs": imgs,            # N*3*H0*W0
                "proj_matrices": proj,   # N*4*4
                "depth_min": depth_min,  # scalar
                "depth_max": depth_max,
                "filename": '{}/' + '{:0>8}'.format(view_ids[0]) + "{}"
                }
73
IGEV-MVS/datasets/data_io.py
Normal file
@@ -0,0 +1,73 @@
import numpy as np
import re
import sys


def read_pfm(filename):
    # rb: binary file, read only
    file = open(filename, 'rb')
    color = None
    width = None
    height = None
    scale = None
    endian = None

    header = file.readline().decode('utf-8').rstrip()
    if header == 'PF':
        color = True
    elif header == 'Pf':  # depth maps are greyscale 'Pf'
        color = False
    else:
        raise Exception('Not a PFM file.')

    dim_match = re.match(r'^(\d+)\s(\d+)\s$', file.readline().decode('utf-8'))
    if dim_match:
        width, height = map(int, dim_match.groups())
    else:
        raise Exception('Malformed PFM header.')

    scale = float(file.readline().rstrip())
    if scale < 0:  # little-endian
        endian = '<'
        scale = -scale
    else:
        endian = '>'  # big-endian

    data = np.fromfile(file, endian + 'f')
    shape = (height, width, 3) if color else (height, width, 1)
    # depth: H*W
    data = np.reshape(data, shape)
    data = np.flipud(data)
    file.close()
    return data, scale


def save_pfm(filename, image, scale=1):
    file = open(filename, "wb")
    color = None

    image = np.flipud(image)

    if image.dtype.name != 'float32':
        raise Exception('Image dtype must be float32.')

    if len(image.shape) == 3 and image.shape[2] == 3:  # color image
        color = True
    elif len(image.shape) == 2 or len(image.shape) == 3 and image.shape[2] == 1:  # greyscale
        color = False
    else:
        raise Exception('Image must have H x W x 3, H x W x 1 or H x W dimensions.')

    file.write('PF\n'.encode('utf-8') if color else 'Pf\n'.encode('utf-8'))
    file.write('{} {}\n'.format(image.shape[1], image.shape[0]).encode('utf-8'))

    endian = image.dtype.byteorder

    if endian == '<' or endian == '=' and sys.byteorder == 'little':
        scale = -scale

    file.write(('%f\n' % scale).encode('utf-8'))

    image.tofile(file)
    file.close()
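A round-trip sanity check for the PFM helpers; the path is only for illustration:

    import numpy as np
    from datasets.data_io import read_pfm, save_pfm

    depth = np.random.rand(512, 640).astype(np.float32)   # H x W depth map
    save_pfm('/tmp/depth.pfm', depth)
    loaded, scale = read_pfm('/tmp/depth.pfm')
    print(loaded.shape, scale)                  # (512, 640, 1) 1.0
    print(np.allclose(loaded[..., 0], depth))   # True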
236
IGEV-MVS/datasets/dtu_yao.py
Normal file
@ -0,0 +1,236 @@
|
||||
from torch.utils.data import Dataset
import numpy as np
import os
from PIL import Image
from datasets.data_io import *
import cv2
import random
from torchvision import transforms


class MVSDataset(Dataset):
    def __init__(self, datapath, listfile, mode, nviews, robust_train=False):
        super(MVSDataset, self).__init__()

        self.levels = 4
        self.datapath = datapath
        self.listfile = listfile
        self.mode = mode
        self.nviews = nviews
        self.img_wh = (640, 512)
        # self.img_wh = (1440, 1056)
        self.robust_train = robust_train

        assert self.mode in ["train", "val", "test"]
        self.metas = self.build_list()
        self.color_augment = transforms.ColorJitter(brightness=0.5, contrast=0.5)

    def build_list(self):
        metas = []
        with open(self.listfile) as f:
            scans = f.readlines()
            scans = [line.rstrip() for line in scans]

        for scan in scans:
            pair_file = "Cameras_1/pair.txt"

            with open(os.path.join(self.datapath, pair_file)) as f:
                self.num_viewpoint = int(f.readline())
                # viewpoints (49)
                for view_idx in range(self.num_viewpoint):
                    ref_view = int(f.readline().rstrip())
                    src_views = [int(x) for x in f.readline().rstrip().split()[1::2]]
                    # light conditions 0-6
                    for light_idx in range(7):
                        metas.append((scan, light_idx, ref_view, src_views))
        print("dataset", self.mode, "metas:", len(metas))
        return metas

    def __len__(self):
        return len(self.metas)

    def read_cam_file(self, filename):
        with open(filename) as f:
            lines = f.readlines()
            lines = [line.rstrip() for line in lines]
        # extrinsics: line [1,5), 4x4 matrix
        extrinsics = np.fromstring(' '.join(lines[1:5]), dtype=np.float32, sep=' ').reshape((4, 4))
        # intrinsics: line [7,10), 3x3 matrix
        intrinsics = np.fromstring(' '.join(lines[7:10]), dtype=np.float32, sep=' ').reshape((3, 3))
        depth_min = float(lines[11].split()[0])
        depth_max = float(lines[11].split()[-1])
        return intrinsics, extrinsics, depth_min, depth_max

    def read_img(self, filename):
        img = Image.open(filename)
        if self.mode == 'train':
            img = self.color_augment(img)
        # scale 0~255 to -1~1
        np_img = 2 * np.array(img, dtype=np.float32) / 255. - 1
        h, w, _ = np_img.shape
        np_img_ms = {
            "level_3": cv2.resize(np_img, (w//8, h//8), interpolation=cv2.INTER_LINEAR),
            "level_2": cv2.resize(np_img, (w//4, h//4), interpolation=cv2.INTER_LINEAR),
            "level_1": cv2.resize(np_img, (w//2, h//2), interpolation=cv2.INTER_LINEAR),
            "level_0": np_img
        }
        return np_img_ms

    def prepare_img(self, hr_img):
        # downsample
        h, w = hr_img.shape
        # original w,h: 1600, 1200; downsample -> 800, 600; crop -> 640, 512
        hr_img = cv2.resize(hr_img, (w//2, h//2), interpolation=cv2.INTER_NEAREST)
        # crop
        h, w = hr_img.shape
        target_h, target_w = self.img_wh[1], self.img_wh[0]
        start_h, start_w = (h - target_h)//2, (w - target_w)//2
        hr_img_crop = hr_img[start_h: start_h + target_h, start_w: start_w + target_w]

        return hr_img_crop

    def read_mask(self, filename):
        img = Image.open(filename)
        np_img = np.array(img, dtype=np.float32)
        np_img = (np_img > 10).astype(np.float32)
        return np_img

    def read_depth_mask(self, filename, mask_filename, scale):
        depth_hr = np.array(read_pfm(filename)[0], dtype=np.float32) * scale
        depth_hr = np.squeeze(depth_hr, 2)
        depth_lr = self.prepare_img(depth_hr)
        mask = self.read_mask(mask_filename)
        mask = self.prepare_img(mask)
        mask = mask.astype(np.bool_)
        mask = mask.astype(np.float32)

        h, w = depth_lr.shape
        depth_lr_ms = {}
        mask_ms = {}

        for i in range(self.levels):
            depth_cur = cv2.resize(depth_lr, (w//(2**i), h//(2**i)), interpolation=cv2.INTER_NEAREST)
            mask_cur = cv2.resize(mask, (w//(2**i), h//(2**i)), interpolation=cv2.INTER_NEAREST)
            depth_lr_ms[f"level_{i}"] = depth_cur
            mask_ms[f"level_{i}"] = mask_cur

        return depth_lr_ms, mask_ms

    def __getitem__(self, idx):
        meta = self.metas[idx]
        scan, light_idx, ref_view, src_views = meta
        # robust training strategy
        if self.robust_train:
            num_src_views = len(src_views)
            index = random.sample(range(num_src_views), self.nviews - 1)
            view_ids = [ref_view] + [src_views[i] for i in index]
            scale = random.uniform(0.8, 1.25)
        else:
            view_ids = [ref_view] + src_views[:self.nviews - 1]
            scale = 1

        imgs_0 = []
        imgs_1 = []
        imgs_2 = []
        imgs_3 = []

        mask = None
        depth = None
        depth_min = None
        depth_max = None

        proj_matrices_0 = []
        proj_matrices_1 = []
        proj_matrices_2 = []
        proj_matrices_3 = []

        for i, vid in enumerate(view_ids):
            img_filename = os.path.join(self.datapath,
                                        'Rectified/{}_train/rect_{:0>3}_{}_r5000.png'.format(scan, vid + 1, light_idx))
            proj_mat_filename = os.path.join(self.datapath, 'Cameras_1/{}_train/{:0>8}_cam.txt').format(scan, vid)

            mask_filename = os.path.join(self.datapath, 'Depths_raw/{}/depth_visual_{:0>4}.png'.format(scan, vid))
            depth_filename = os.path.join(self.datapath, 'Depths_raw/{}/depth_map_{:0>4}.pfm'.format(scan, vid))

            imgs = self.read_img(img_filename)
            imgs_0.append(imgs['level_0'])
            imgs_1.append(imgs['level_1'])
            imgs_2.append(imgs['level_2'])
            imgs_3.append(imgs['level_3'])

            intrinsics, extrinsics, depth_min_, depth_max_ = self.read_cam_file(proj_mat_filename)
            extrinsics[:3, 3] *= scale
            intrinsics[0] *= 4
            intrinsics[1] *= 4

            proj_mat = extrinsics.copy()
            intrinsics[:2, :] *= 0.125
            proj_mat[:3, :4] = np.matmul(intrinsics, proj_mat[:3, :4])
            proj_matrices_3.append(proj_mat)

            proj_mat = extrinsics.copy()
            intrinsics[:2, :] *= 2
            proj_mat[:3, :4] = np.matmul(intrinsics, proj_mat[:3, :4])
            proj_matrices_2.append(proj_mat)

            proj_mat = extrinsics.copy()
            intrinsics[:2, :] *= 2
            proj_mat[:3, :4] = np.matmul(intrinsics, proj_mat[:3, :4])
            proj_matrices_1.append(proj_mat)

            proj_mat = extrinsics.copy()
            intrinsics[:2, :] *= 2
            proj_mat[:3, :4] = np.matmul(intrinsics, proj_mat[:3, :4])
            proj_matrices_0.append(proj_mat)

            if i == 0:  # reference view
                depth_min = depth_min_ * scale
                depth_max = depth_max_ * scale
                depth, mask = self.read_depth_mask(depth_filename, mask_filename, scale)

                for l in range(self.levels):
                    mask[f'level_{l}'] = np.expand_dims(mask[f'level_{l}'], 2)
                    mask[f'level_{l}'] = mask[f'level_{l}'].transpose([2, 0, 1])
                    depth[f'level_{l}'] = np.expand_dims(depth[f'level_{l}'], 2)
                    depth[f'level_{l}'] = depth[f'level_{l}'].transpose([2, 0, 1])

        # imgs: N*3*H0*W0, N is number of images
        imgs_0 = np.stack(imgs_0).transpose([0, 3, 1, 2])
        imgs_1 = np.stack(imgs_1).transpose([0, 3, 1, 2])
        imgs_2 = np.stack(imgs_2).transpose([0, 3, 1, 2])
        imgs_3 = np.stack(imgs_3).transpose([0, 3, 1, 2])

        imgs = {}
        imgs['level_0'] = imgs_0
        imgs['level_1'] = imgs_1
        imgs['level_2'] = imgs_2
        imgs['level_3'] = imgs_3

        # proj_matrices: N*4*4
        proj_matrices_0 = np.stack(proj_matrices_0)
        proj_matrices_1 = np.stack(proj_matrices_1)
        proj_matrices_2 = np.stack(proj_matrices_2)
        proj_matrices_3 = np.stack(proj_matrices_3)

        proj = {}
        proj['level_3'] = proj_matrices_3
        proj['level_2'] = proj_matrices_2
        proj['level_1'] = proj_matrices_1
        proj['level_0'] = proj_matrices_0

        # data is numpy array
        return {"imgs": imgs,            # [N, 3, H, W]
                "proj_matrices": proj,   # [N, 4, 4]
                "depth": depth,          # [1, H, W]
                "depth_min": depth_min,  # scalar
                "depth_max": depth_max,  # scalar
                "mask": mask}            # [1, H, W]
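For orientation, a minimal training-loader sketch for the dataset above; the data path and scan-list file are placeholders for a local DTU training layout.

# Hypothetical usage of datasets/dtu_yao.py; paths are placeholders.
from torch.utils.data import DataLoader
from datasets.dtu_yao import MVSDataset

dataset = MVSDataset(datapath='/data/dtu_training/', listfile='./lists/dtu/train.txt',
                     mode='train', nviews=5, robust_train=True)
loader = DataLoader(dataset, batch_size=2, shuffle=True, num_workers=4, drop_last=True)
sample = next(iter(loader))
print(sample['imgs']['level_0'].shape)   # [B, N, 3, 512, 640]
print(sample['depth']['level_0'].shape)  # [B, 1, 512, 640]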
158
IGEV-MVS/datasets/dtu_yao_eval.py
Normal file
@ -0,0 +1,158 @@
from torch.utils.data import Dataset
import numpy as np
import os
from PIL import Image
from datasets.data_io import *
import cv2


class MVSDataset(Dataset):
    def __init__(self, datapath, listfile, nviews=5, img_wh=(1600, 1152)):
        super(MVSDataset, self).__init__()
        self.levels = 4
        self.datapath = datapath
        self.listfile = listfile
        self.nviews = nviews
        self.img_wh = img_wh
        self.metas = self.build_list()

    def build_list(self):
        metas = []
        with open(self.listfile) as f:
            scans = f.readlines()
            scans = [line.rstrip() for line in scans]

        for scan in scans:
            pair_file = "{}/pair.txt".format(scan)
            # read the pair file
            with open(os.path.join(self.datapath, pair_file)) as f:
                num_viewpoint = int(f.readline())
                # viewpoints (49)
                for view_idx in range(num_viewpoint):
                    ref_view = int(f.readline().rstrip())
                    src_views = [int(x) for x in f.readline().rstrip().split()[1::2]]
                    metas.append((scan, ref_view, src_views))
        print("dataset", "metas:", len(metas))
        return metas

    def __len__(self):
        return len(self.metas)

    def read_cam_file(self, filename):
        with open(filename) as f:
            lines = f.readlines()
            lines = [line.rstrip() for line in lines]
        # extrinsics: line [1,5), 4x4 matrix
        extrinsics = np.fromstring(' '.join(lines[1:5]), dtype=np.float32, sep=' ').reshape((4, 4))
        # intrinsics: line [7,10), 3x3 matrix
        intrinsics = np.fromstring(' '.join(lines[7:10]), dtype=np.float32, sep=' ').reshape((3, 3))

        depth_min = float(lines[11].split()[0])
        depth_max = float(lines[11].split()[-1])
        return intrinsics, extrinsics, depth_min, depth_max

    def read_mask(self, filename):
        img = Image.open(filename)
        np_img = np.array(img, dtype=np.float32)
        np_img = (np_img > 10).astype(np.float32)
        return np_img

    def read_img(self, filename):
        img = Image.open(filename)
        # scale 0~255 to -1~1
        np_img = 2 * np.array(img, dtype=np.float32) / 255. - 1
        np_img = cv2.resize(np_img, self.img_wh, interpolation=cv2.INTER_LINEAR)

        h, w, _ = np_img.shape
        np_img_ms = {
            "level_3": cv2.resize(np_img, (w//8, h//8), interpolation=cv2.INTER_LINEAR),
            "level_2": cv2.resize(np_img, (w//4, h//4), interpolation=cv2.INTER_LINEAR),
            "level_1": cv2.resize(np_img, (w//2, h//2), interpolation=cv2.INTER_LINEAR),
            "level_0": np_img
        }
        return np_img_ms

    def __getitem__(self, idx):
        scan, ref_view, src_views = self.metas[idx]
        # use only the reference view and first nviews-1 source views
        view_ids = [ref_view] + src_views[:self.nviews - 1]
        img_w = 1600
        img_h = 1200
        imgs_0 = []
        imgs_1 = []
        imgs_2 = []
        imgs_3 = []

        depth_min = None
        depth_max = None

        proj_matrices_0 = []
        proj_matrices_1 = []
        proj_matrices_2 = []
        proj_matrices_3 = []

        for i, vid in enumerate(view_ids):
            img_filename = os.path.join(self.datapath, '{}/images/{:0>8}.jpg'.format(scan, vid))
            proj_mat_filename = os.path.join(self.datapath, '{}/cams_1/{:0>8}_cam.txt'.format(scan, vid))

            imgs = self.read_img(img_filename)
            imgs_0.append(imgs['level_0'])
            imgs_1.append(imgs['level_1'])
            imgs_2.append(imgs['level_2'])
            imgs_3.append(imgs['level_3'])

            intrinsics, extrinsics, depth_min_, depth_max_ = self.read_cam_file(proj_mat_filename)
            intrinsics[0] *= self.img_wh[0] / img_w
            intrinsics[1] *= self.img_wh[1] / img_h
            proj_mat = extrinsics.copy()
            intrinsics[:2, :] *= 0.125
            proj_mat[:3, :4] = np.matmul(intrinsics, proj_mat[:3, :4])
            proj_matrices_3.append(proj_mat)

            proj_mat = extrinsics.copy()
            intrinsics[:2, :] *= 2
            proj_mat[:3, :4] = np.matmul(intrinsics, proj_mat[:3, :4])
            proj_matrices_2.append(proj_mat)

            proj_mat = extrinsics.copy()
            intrinsics[:2, :] *= 2
            proj_mat[:3, :4] = np.matmul(intrinsics, proj_mat[:3, :4])
            proj_matrices_1.append(proj_mat)

            proj_mat = extrinsics.copy()
            intrinsics[:2, :] *= 2
            proj_mat[:3, :4] = np.matmul(intrinsics, proj_mat[:3, :4])
            proj_matrices_0.append(proj_mat)

            if i == 0:  # reference view
                depth_min = depth_min_
                depth_max = depth_max_

        imgs_0 = np.stack(imgs_0).transpose([0, 3, 1, 2])
        imgs_1 = np.stack(imgs_1).transpose([0, 3, 1, 2])
        imgs_2 = np.stack(imgs_2).transpose([0, 3, 1, 2])
        imgs_3 = np.stack(imgs_3).transpose([0, 3, 1, 2])
        imgs = {}
        imgs['level_0'] = imgs_0
        imgs['level_1'] = imgs_1
        imgs['level_2'] = imgs_2
        imgs['level_3'] = imgs_3
        # proj_matrices: N*4*4
        proj_matrices_0 = np.stack(proj_matrices_0)
        proj_matrices_1 = np.stack(proj_matrices_1)
        proj_matrices_2 = np.stack(proj_matrices_2)
        proj_matrices_3 = np.stack(proj_matrices_3)
        proj = {}
        proj['level_3'] = proj_matrices_3
        proj['level_2'] = proj_matrices_2
        proj['level_1'] = proj_matrices_1
        proj['level_0'] = proj_matrices_0

        return {"imgs": imgs,            # N*3*H0*W0
                "proj_matrices": proj,   # N*4*4
                "depth_min": depth_min,  # scalar
                "depth_max": depth_max,  # scalar
                "filename": scan + '/{}/' + '{:0>8}'.format(view_ids[0]) + "{}"}
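The `[1::2]` slice used by every `build_list`/`build_metas` in this commit relies on the pair.txt line format, where a count is followed by interleaved view ids and matching scores. A short sketch of what it extracts (the line content is illustrative):

# pair.txt source-view lines interleave view ids with matching scores;
# every second token from index 1 keeps the ids and drops the scores.
line = "10 10 2346.41 1 2036.53 9 1243.89"
tokens = line.split()                       # first token: number of source views
src_views = [int(x) for x in tokens[1::2]]  # [10, 1, 9]
scores = [float(x) for x in tokens[2::2]]   # [2346.41, 2036.53, 1243.89]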
158
IGEV-MVS/datasets/eth3d.py
Normal file
@ -0,0 +1,158 @@
from torch.utils.data import Dataset
from datasets.data_io import *
import os
import numpy as np
import cv2
from PIL import Image


class MVSDataset(Dataset):
    def __init__(self, datapath, split='test', n_views=7, img_wh=(1920, 1280)):
        self.levels = 4
        self.datapath = datapath
        self.img_wh = img_wh
        self.split = split
        self.build_metas()
        self.n_views = n_views

    def build_metas(self):
        self.metas = []
        if self.split == "test":
            self.scans = ['botanical_garden', 'boulders', 'bridge', 'door',
                          'exhibition_hall', 'lecture_room', 'living_room', 'lounge',
                          'observatory', 'old_computer', 'statue', 'terrace_2']
        elif self.split == "train":
            self.scans = ['courtyard', 'delivery_area', 'electro', 'facade',
                          'kicker', 'meadow', 'office', 'pipes', 'playground',
                          'relief', 'relief_2', 'terrace', 'terrains']

        for scan in self.scans:
            with open(os.path.join(self.datapath, scan, 'pair.txt')) as f:
                num_viewpoint = int(f.readline())
                for view_idx in range(num_viewpoint):
                    ref_view = int(f.readline().rstrip())
                    src_views = [int(x) for x in f.readline().rstrip().split()[1::2]]
                    if len(src_views) != 0:
                        self.metas += [(scan, -1, ref_view, src_views)]

    def read_cam_file(self, filename):
        with open(filename) as f:
            lines = [line.rstrip() for line in f.readlines()]
        # extrinsics: line [1,5), 4x4 matrix
        extrinsics = np.fromstring(' '.join(lines[1:5]), dtype=np.float32, sep=' ')
        extrinsics = extrinsics.reshape((4, 4))
        # intrinsics: line [7,10), 3x3 matrix
        intrinsics = np.fromstring(' '.join(lines[7:10]), dtype=np.float32, sep=' ')
        intrinsics = intrinsics.reshape((3, 3))

        depth_min = float(lines[11].split()[0])
        if depth_min < 0:
            depth_min = 1
        depth_max = float(lines[11].split()[-1])

        return intrinsics, extrinsics, depth_min, depth_max

    def read_img(self, filename, h, w):
        img = Image.open(filename)
        # scale 0~255 to -1~1
        np_img = 2 * np.array(img, dtype=np.float32) / 255. - 1
        original_h, original_w, _ = np_img.shape
        np_img = cv2.resize(np_img, self.img_wh, interpolation=cv2.INTER_LINEAR)

        np_img_ms = {
            "level_3": cv2.resize(np_img, (w//8, h//8), interpolation=cv2.INTER_LINEAR),
            "level_2": cv2.resize(np_img, (w//4, h//4), interpolation=cv2.INTER_LINEAR),
            "level_1": cv2.resize(np_img, (w//2, h//2), interpolation=cv2.INTER_LINEAR),
            "level_0": np_img
        }
        return np_img_ms, original_h, original_w

    def __len__(self):
        return len(self.metas)

    def __getitem__(self, idx):
        scan, _, ref_view, src_views = self.metas[idx]
        # use only the reference view and first n_views-1 source views
        view_ids = [ref_view] + src_views[:self.n_views - 1]
        imgs_0 = []
        imgs_1 = []
        imgs_2 = []
        imgs_3 = []

        depth_min = None
        depth_max = None

        proj_matrices_0 = []
        proj_matrices_1 = []
        proj_matrices_2 = []
        proj_matrices_3 = []

        for i, vid in enumerate(view_ids):
            img_filename = os.path.join(self.datapath, scan, f'images/{vid:08d}.jpg')
            proj_mat_filename = os.path.join(self.datapath, scan, f'cams_1/{vid:08d}_cam.txt')

            imgs, original_h, original_w = self.read_img(img_filename, self.img_wh[1], self.img_wh[0])
            imgs_0.append(imgs['level_0'])
            imgs_1.append(imgs['level_1'])
            imgs_2.append(imgs['level_2'])
            imgs_3.append(imgs['level_3'])

            intrinsics, extrinsics, depth_min_, depth_max_ = self.read_cam_file(proj_mat_filename)
            intrinsics[0] *= self.img_wh[0] / original_w
            intrinsics[1] *= self.img_wh[1] / original_h

            proj_mat = extrinsics.copy()
            intrinsics[:2, :] *= 0.125
            proj_mat[:3, :4] = np.matmul(intrinsics, proj_mat[:3, :4])
            proj_matrices_3.append(proj_mat)

            proj_mat = extrinsics.copy()
            intrinsics[:2, :] *= 2
            proj_mat[:3, :4] = np.matmul(intrinsics, proj_mat[:3, :4])
            proj_matrices_2.append(proj_mat)

            proj_mat = extrinsics.copy()
            intrinsics[:2, :] *= 2
            proj_mat[:3, :4] = np.matmul(intrinsics, proj_mat[:3, :4])
            proj_matrices_1.append(proj_mat)

            proj_mat = extrinsics.copy()
            intrinsics[:2, :] *= 2
            proj_mat[:3, :4] = np.matmul(intrinsics, proj_mat[:3, :4])
            proj_matrices_0.append(proj_mat)

            if i == 0:  # reference view
                depth_min = depth_min_
                depth_max = depth_max_

        # imgs: N*3*H0*W0, N is number of images
        imgs_0 = np.stack(imgs_0).transpose([0, 3, 1, 2])
        imgs_1 = np.stack(imgs_1).transpose([0, 3, 1, 2])
        imgs_2 = np.stack(imgs_2).transpose([0, 3, 1, 2])
        imgs_3 = np.stack(imgs_3).transpose([0, 3, 1, 2])
        imgs = {}
        imgs['level_0'] = imgs_0
        imgs['level_1'] = imgs_1
        imgs['level_2'] = imgs_2
        imgs['level_3'] = imgs_3
        # proj_matrices: N*4*4
        proj_matrices_0 = np.stack(proj_matrices_0)
        proj_matrices_1 = np.stack(proj_matrices_1)
        proj_matrices_2 = np.stack(proj_matrices_2)
        proj_matrices_3 = np.stack(proj_matrices_3)
        proj = {}
        proj['level_3'] = proj_matrices_3
        proj['level_2'] = proj_matrices_2
        proj['level_1'] = proj_matrices_1
        proj['level_0'] = proj_matrices_0

        return {"imgs": imgs,            # N*3*H0*W0
                "proj_matrices": proj,   # N*4*4
                "depth_min": depth_min,  # scalar
                "depth_max": depth_max,  # scalar
                "filename": scan + '/{}/' + '{:0>8}'.format(view_ids[0]) + "{}"}
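Both evaluation datasets rescale the intrinsics whenever the image is resized: focal lengths and principal point scale linearly with resolution. A standalone sketch of that adjustment (matrix values are illustrative):

# When resizing from (original_w, original_h) to (new_w, new_h), row 0 of K
# (fx, 0, cx) scales by new_w/original_w and row 1 (0, fy, cy) by
# new_h/original_h, as in the read_cam_file consumers above.
import numpy as np

K = np.array([[2892.3, 0.0, 960.0],
              [0.0, 2883.2, 540.0],
              [0.0, 0.0, 1.0]], dtype=np.float32)
original_wh, new_wh = (1920, 1080), (1920, 1280)
K[0] *= new_wh[0] / original_wh[0]
K[1] *= new_wh[1] / original_wh[1]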
156
IGEV-MVS/datasets/tanks.py
Normal file
@ -0,0 +1,156 @@
from torch.utils.data import Dataset
from datasets.data_io import *
import os
import numpy as np
import cv2
from PIL import Image


class MVSDataset(Dataset):
    def __init__(self, datapath, n_views=7, img_wh=(1920, 1024), split='intermediate'):
        self.levels = 4
        self.datapath = datapath
        self.img_wh = img_wh
        self.split = split
        self.build_metas()
        self.n_views = n_views

    def build_metas(self):
        self.metas = []
        if self.split == 'intermediate':
            self.scans = ['Family', 'Francis', 'Horse', 'Lighthouse',
                          'M60', 'Panther', 'Playground', 'Train']
        elif self.split == 'advanced':
            self.scans = ['Auditorium', 'Ballroom', 'Courtroom',
                          'Museum', 'Palace', 'Temple']

        for scan in self.scans:
            with open(os.path.join(self.datapath, self.split, scan, 'pair.txt')) as f:
                num_viewpoint = int(f.readline())
                for view_idx in range(num_viewpoint):
                    ref_view = int(f.readline().rstrip())
                    src_views = [int(x) for x in f.readline().rstrip().split()[1::2]]
                    if len(src_views) != 0:
                        self.metas += [(scan, -1, ref_view, src_views)]

    def read_cam_file(self, filename):
        with open(filename) as f:
            lines = [line.rstrip() for line in f.readlines()]
        # extrinsics: line [1,5), 4x4 matrix
        extrinsics = np.fromstring(' '.join(lines[1:5]), dtype=np.float32, sep=' ')
        extrinsics = extrinsics.reshape((4, 4))
        # intrinsics: line [7,10), 3x3 matrix
        intrinsics = np.fromstring(' '.join(lines[7:10]), dtype=np.float32, sep=' ')
        intrinsics = intrinsics.reshape((3, 3))

        depth_min = float(lines[11].split()[0])
        depth_max = float(lines[11].split()[-1])

        return intrinsics, extrinsics, depth_min, depth_max

    def read_img(self, filename, h, w):
        img = Image.open(filename)
        # scale 0~255 to -1~1
        np_img = 2 * np.array(img, dtype=np.float32) / 255. - 1
        original_h, original_w, _ = np_img.shape
        np_img = cv2.resize(np_img, self.img_wh, interpolation=cv2.INTER_LINEAR)

        np_img_ms = {
            "level_3": cv2.resize(np_img, (w//8, h//8), interpolation=cv2.INTER_LINEAR),
            "level_2": cv2.resize(np_img, (w//4, h//4), interpolation=cv2.INTER_LINEAR),
            "level_1": cv2.resize(np_img, (w//2, h//2), interpolation=cv2.INTER_LINEAR),
            "level_0": np_img
        }
        return np_img_ms, original_h, original_w

    def __len__(self):
        return len(self.metas)

    def __getitem__(self, idx):
        scan, _, ref_view, src_views = self.metas[idx]
        # use only the reference view and first n_views-1 source views
        view_ids = [ref_view] + src_views[:self.n_views - 1]

        imgs_0 = []
        imgs_1 = []
        imgs_2 = []
        imgs_3 = []

        depth_min = None
        depth_max = None

        proj_matrices_0 = []
        proj_matrices_1 = []
        proj_matrices_2 = []
        proj_matrices_3 = []

        for i, vid in enumerate(view_ids):
            img_filename = os.path.join(self.datapath, self.split, scan, f'images/{vid:08d}.jpg')
            proj_mat_filename = os.path.join(self.datapath, self.split, scan, f'cams_1/{vid:08d}_cam.txt')

            imgs, original_h, original_w = self.read_img(img_filename, self.img_wh[1], self.img_wh[0])
            imgs_0.append(imgs['level_0'])
            imgs_1.append(imgs['level_1'])
            imgs_2.append(imgs['level_2'])
            imgs_3.append(imgs['level_3'])

            intrinsics, extrinsics, depth_min_, depth_max_ = self.read_cam_file(proj_mat_filename)
            intrinsics[0] *= self.img_wh[0] / original_w
            intrinsics[1] *= self.img_wh[1] / original_h

            proj_mat = extrinsics.copy()
            intrinsics[:2, :] *= 0.125
            proj_mat[:3, :4] = np.matmul(intrinsics, proj_mat[:3, :4])
            proj_matrices_3.append(proj_mat)

            proj_mat = extrinsics.copy()
            intrinsics[:2, :] *= 2
            proj_mat[:3, :4] = np.matmul(intrinsics, proj_mat[:3, :4])
            proj_matrices_2.append(proj_mat)

            proj_mat = extrinsics.copy()
            intrinsics[:2, :] *= 2
            proj_mat[:3, :4] = np.matmul(intrinsics, proj_mat[:3, :4])
            proj_matrices_1.append(proj_mat)

            proj_mat = extrinsics.copy()
            intrinsics[:2, :] *= 2
            proj_mat[:3, :4] = np.matmul(intrinsics, proj_mat[:3, :4])
            proj_matrices_0.append(proj_mat)

            if i == 0:  # reference view
                depth_min = depth_min_
                depth_max = depth_max_

        # imgs: N*3*H0*W0, N is number of images
        imgs_0 = np.stack(imgs_0).transpose([0, 3, 1, 2])
        imgs_1 = np.stack(imgs_1).transpose([0, 3, 1, 2])
        imgs_2 = np.stack(imgs_2).transpose([0, 3, 1, 2])
        imgs_3 = np.stack(imgs_3).transpose([0, 3, 1, 2])
        imgs = {}
        imgs['level_0'] = imgs_0
        imgs['level_1'] = imgs_1
        imgs['level_2'] = imgs_2
        imgs['level_3'] = imgs_3
        # proj_matrices: N*4*4
        proj_matrices_0 = np.stack(proj_matrices_0)
        proj_matrices_1 = np.stack(proj_matrices_1)
        proj_matrices_2 = np.stack(proj_matrices_2)
        proj_matrices_3 = np.stack(proj_matrices_3)
        proj = {}
        proj['level_3'] = proj_matrices_3
        proj['level_2'] = proj_matrices_2
        proj['level_1'] = proj_matrices_1
        proj['level_0'] = proj_matrices_0

        return {"imgs": imgs,            # N*3*H0*W0
                "proj_matrices": proj,   # N*4*4
                "depth_min": depth_min,  # scalar
                "depth_max": depth_max,  # scalar
                "filename": scan + '/{}/' + '{:0>8}'.format(view_ids[0]) + "{}"}
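Every dataset in this commit builds its per-level projection matrices the same way: copy the 4x4 extrinsics and overwrite the top 3x4 block with K @ [R|t], halving the intrinsics once per pyramid level. A condensed, equivalent sketch of one level (the helper name and inputs are hypothetical):

# Condensed form of the proj-matrix construction repeated in the datasets:
# level 0 uses the full working-resolution K; level i scales fx, fy, cx, cy
# by 0.5**i before forming the 3x4 block.
import numpy as np

def build_proj_matrix(intrinsics, extrinsics, level):
    K = intrinsics.copy()
    K[:2, :] *= 0.5 ** level
    proj = extrinsics.copy()          # 4x4 world-to-camera matrix
    proj[:3, :4] = K @ proj[:3, :4]   # top block becomes K @ [R|t]
    return proj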
450
IGEV-MVS/evaluate_mvs.py
Normal file
@ -0,0 +1,450 @@
import argparse
import os
os.environ['CUDA_VISIBLE_DEVICES'] = '0'
import torch
import torch.nn as nn
import torch.nn.parallel
import torch.backends.cudnn as cudnn
import torch.optim as optim
from torch.utils.data import DataLoader
import torch.nn.functional as F
import numpy as np
import time
from datasets import find_dataset_def
from core.igev_mvs import IGEVMVS
from utils import *
import sys
import cv2
from datasets.data_io import read_pfm, save_pfm
from core.submodule import depth_unnormalization
from plyfile import PlyData, PlyElement
from tqdm import tqdm
from PIL import Image

cudnn.benchmark = True

parser = argparse.ArgumentParser(description='Predict depth, filter, and fuse')
parser.add_argument('--model', default='IterMVS', help='select model')
parser.add_argument('--dataset', default='dtu_yao_eval', help='select dataset')
parser.add_argument('--testpath', default='/data/dtu_data/dtu_test/', help='testing data path')
parser.add_argument('--testlist', default='./lists/dtu/test.txt', help='testing scan list')
parser.add_argument('--maxdisp', default=256)
parser.add_argument('--split', default='intermediate', help='select data')
parser.add_argument('--batch_size', type=int, default=2, help='testing batch size')
parser.add_argument('--n_views', type=int, default=5, help='num of view')
parser.add_argument('--img_wh', nargs='+', type=int, default=[640, 480],
                    help='width and height of the image')
parser.add_argument('--loadckpt', default='./pretrained_models/dtu.ckpt', help='load a specific checkpoint')
parser.add_argument('--outdir', default='./output/', help='output dir')
parser.add_argument('--display', action='store_true', help='display depth images and masks')
parser.add_argument('--iteration', type=int, default=32, help='num of iteration of GRU')
parser.add_argument('--geo_pixel_thres', type=float, default=1, help='pixel threshold for geometric consistency filtering')
parser.add_argument('--geo_depth_thres', type=float, default=0.01, help='depth threshold for geometric consistency filtering')
parser.add_argument('--photo_thres', type=float, default=0.3, help='threshold for photometric consistency filtering')

# parse arguments and check
args = parser.parse_args()
print("argv:", sys.argv[1:])
print_args(args)

if args.dataset == "dtu_yao_eval":
    img_wh = (1600, 1152)
elif args.dataset == "tanks":
    img_wh = (1920, 1024)
elif args.dataset == "eth3d":
    img_wh = (1920, 1280)
else:
    img_wh = (args.img_wh[0], args.img_wh[1])  # custom dataset


# read intrinsics and extrinsics
def read_camera_parameters(filename):
    with open(filename) as f:
        lines = f.readlines()
        lines = [line.rstrip() for line in lines]
    # extrinsics: line [1,5), 4x4 matrix
    extrinsics = np.fromstring(' '.join(lines[1:5]), dtype=np.float32, sep=' ').reshape((4, 4))
    # intrinsics: line [7,10), 3x3 matrix
    intrinsics = np.fromstring(' '.join(lines[7:10]), dtype=np.float32, sep=' ').reshape((3, 3))

    return intrinsics, extrinsics


# read an image
def read_img(filename, img_wh):
    img = Image.open(filename)
    # scale 0~255 to 0~1
    np_img = np.array(img, dtype=np.float32) / 255.
    original_h, original_w, _ = np_img.shape
    np_img = cv2.resize(np_img, img_wh, interpolation=cv2.INTER_LINEAR)
    return np_img, original_h, original_w


# save a binary mask
def save_mask(filename, mask):
    assert mask.dtype == np.bool_
    mask = mask.astype(np.uint8) * 255
    Image.fromarray(mask).save(filename)


def save_depth_img(filename, depth):
    depth = depth.astype(np.float32) * 255
    Image.fromarray(depth).save(filename)


def read_pair_file(filename):
    data = []
    with open(filename) as f:
        num_viewpoint = int(f.readline())
        # 49 viewpoints
        for view_idx in range(num_viewpoint):
            ref_view = int(f.readline().rstrip())
            src_views = [int(x) for x in f.readline().rstrip().split()[1::2]]
            if len(src_views) != 0:
                data.append((ref_view, src_views))
    return data


# run MVS model to save depth maps
def save_depth():
    # dataset, dataloader
    MVSDataset = find_dataset_def(args.dataset)
    if args.dataset == "dtu_yao_eval":
        test_dataset = MVSDataset(args.testpath, args.testlist, args.n_views, img_wh)
    elif args.dataset == "tanks":
        test_dataset = MVSDataset(args.testpath, args.n_views, img_wh, args.split)
    elif args.dataset == "eth3d":
        test_dataset = MVSDataset(args.testpath, args.split, args.n_views, img_wh)
    else:
        test_dataset = MVSDataset(args.testpath, args.n_views, img_wh)
    TestImgLoader = DataLoader(test_dataset, args.batch_size, shuffle=False, num_workers=4, drop_last=False)

    # model
    model = IGEVMVS(args)
    model = nn.DataParallel(model)
    model.cuda()

    # load checkpoint file specified by args.loadckpt
    print("loading model {}".format(args.loadckpt))
    state_dict = torch.load(args.loadckpt)
    model.load_state_dict(state_dict['model'])
    model.eval()

    with torch.no_grad():
        tbar = tqdm(TestImgLoader)
        for batch_idx, sample in enumerate(tbar):
            start_time = time.time()
            sample_cuda = tocuda(sample)
            disp_prediction = model(sample_cuda["imgs"], sample_cuda["proj_matrices"],
                                    sample_cuda["depth_min"], sample_cuda["depth_max"], test_mode=True)

            b = sample_cuda["depth_min"].shape[0]

            inverse_depth_min = (1.0 / sample_cuda["depth_min"]).view(b, 1, 1, 1)
            inverse_depth_max = (1.0 / sample_cuda["depth_max"]).view(b, 1, 1, 1)

            depth_prediction = depth_unnormalization(disp_prediction, inverse_depth_min, inverse_depth_max)
            depth_prediction = tensor2numpy(depth_prediction.float())
            del sample_cuda, disp_prediction
            tbar.set_description('Iter {}/{}, time = {:.3f}'.format(batch_idx, len(TestImgLoader), time.time() - start_time))
            filenames = sample["filename"]

            # save depth maps and confidence maps
            for filename, depth_est in zip(filenames, depth_prediction):
                depth_filename = os.path.join(args.outdir, filename.format('depth_est', '.pfm'))
                os.makedirs(depth_filename.rsplit('/', 1)[0], exist_ok=True)
                # save depth maps
                depth_est = np.squeeze(depth_est, 0)
                save_pfm(depth_filename, depth_est)


# project the reference point cloud into the source view, then project back
def reproject_with_depth(depth_ref, intrinsics_ref, extrinsics_ref, depth_src, intrinsics_src, extrinsics_src):
    width, height = depth_ref.shape[1], depth_ref.shape[0]
    ## step1. project reference pixels to the source view
    # reference view x, y
    x_ref, y_ref = np.meshgrid(np.arange(0, width), np.arange(0, height))
    x_ref, y_ref = x_ref.reshape([-1]), y_ref.reshape([-1])
    # reference 3D space
    xyz_ref = np.matmul(np.linalg.inv(intrinsics_ref),
                        np.vstack((x_ref, y_ref, np.ones_like(x_ref))) * depth_ref.reshape([-1]))
    # source 3D space
    xyz_src = np.matmul(np.matmul(extrinsics_src, np.linalg.inv(extrinsics_ref)),
                        np.vstack((xyz_ref, np.ones_like(x_ref))))[:3]
    # source view x, y
    K_xyz_src = np.matmul(intrinsics_src, xyz_src)
    xy_src = K_xyz_src[:2] / K_xyz_src[2:3]

    ## step2. reproject the source view points with source view depth estimation
    # find the depth estimation of the source view
    x_src = xy_src[0].reshape([height, width]).astype(np.float32)
    y_src = xy_src[1].reshape([height, width]).astype(np.float32)
    sampled_depth_src = cv2.remap(depth_src, x_src, y_src, interpolation=cv2.INTER_LINEAR)
    # mask = sampled_depth_src > 0

    # source 3D space
    # NOTE that we should use sampled source-view depth here to project back
    xyz_src = np.matmul(np.linalg.inv(intrinsics_src),
                        np.vstack((xy_src, np.ones_like(x_ref))) * sampled_depth_src.reshape([-1]))
    # reference 3D space
    xyz_reprojected = np.matmul(np.matmul(extrinsics_ref, np.linalg.inv(extrinsics_src)),
                                np.vstack((xyz_src, np.ones_like(x_ref))))[:3]
    # reference view x, y, depth
    depth_reprojected = xyz_reprojected[2].reshape([height, width]).astype(np.float32)
    K_xyz_reprojected = np.matmul(intrinsics_ref, xyz_reprojected)
    xy_reprojected = K_xyz_reprojected[:2] / (K_xyz_reprojected[2:3] + 1e-6)
    x_reprojected = xy_reprojected[0].reshape([height, width]).astype(np.float32)
    y_reprojected = xy_reprojected[1].reshape([height, width]).astype(np.float32)

    return depth_reprojected, x_reprojected, y_reprojected, x_src, y_src


def check_geometric_consistency(depth_ref, intrinsics_ref, extrinsics_ref, depth_src, intrinsics_src, extrinsics_src, thre1, thre2):
    width, height = depth_ref.shape[1], depth_ref.shape[0]
    x_ref, y_ref = np.meshgrid(np.arange(0, width), np.arange(0, height))
    depth_reprojected, x2d_reprojected, y2d_reprojected, x2d_src, y2d_src = reproject_with_depth(depth_ref,
                                                                                                 intrinsics_ref,
                                                                                                 extrinsics_ref,
                                                                                                 depth_src,
                                                                                                 intrinsics_src,
                                                                                                 extrinsics_src)
    # check |p_reproj - p_1| < 1
    dist = np.sqrt((x2d_reprojected - x_ref) ** 2 + (y2d_reprojected - y_ref) ** 2)

    # check |d_reproj - d_1| / d_1 < 0.01
    depth_diff = np.abs(depth_reprojected - depth_ref)
    relative_depth_diff = depth_diff / depth_ref
    masks = []
    for i in range(2, 11):
        mask = np.logical_and(dist < i / thre1, relative_depth_diff < i / thre2)
        masks.append(mask)
    depth_reprojected[~mask] = 0

    return masks, mask, depth_reprojected, x2d_src, y2d_src


def filter_depth(scan_folder, out_folder, plyfilename, geo_pixel_thres, geo_depth_thres, photo_thres, img_wh, geo_mask_thres=3):
    # the pair file
    pair_file = os.path.join(scan_folder, "pair.txt")
    # for the final point cloud
    vertexs = []
    vertex_colors = []

    pair_data = read_pair_file(pair_file)
    nviews = len(pair_data)

    thre_left = -2
    thre_right = 2
    total_iter = 10
    for iter in range(total_iter):
        thre = (thre_left + thre_right) / 2
        print(f"{iter} {10 ** thre}")
        depth_est_averaged = []
        geo_mask_all = []
        # for each reference view and the corresponding source views
        for ref_view, src_views in pair_data:
            # load the camera parameters
            ref_intrinsics, ref_extrinsics = read_camera_parameters(
                os.path.join(scan_folder, 'cams_1/{:0>8}_cam.txt'.format(ref_view)))
            ref_img, original_h, original_w = read_img(os.path.join(scan_folder, 'images/{:0>8}.jpg'.format(ref_view)), img_wh)
            ref_intrinsics[0] *= img_wh[0] / original_w
            ref_intrinsics[1] *= img_wh[1] / original_h
            # load the estimated depth of the reference view
            ref_depth_est = read_pfm(os.path.join(out_folder, 'depth_est/{:0>8}.pfm'.format(ref_view)))[0]
            ref_depth_est = np.squeeze(ref_depth_est, 2)

            all_srcview_depth_ests = []
            # compute the geometric mask
            geo_mask_sum = 0
            geo_mask_sums = []
            n = 1 + len(src_views)
            ct = 0
            for src_view in src_views:
                ct = ct + 1
                # camera parameters of the source view
                src_intrinsics, src_extrinsics = read_camera_parameters(
                    os.path.join(scan_folder, 'cams_1/{:0>8}_cam.txt'.format(src_view)))
                _, original_h, original_w = read_img(os.path.join(scan_folder, 'images/{:0>8}.jpg'.format(src_view)), img_wh)
                src_intrinsics[0] *= img_wh[0] / original_w
                src_intrinsics[1] *= img_wh[1] / original_h

                # the estimated depth of the source view
                src_depth_est = read_pfm(os.path.join(out_folder, 'depth_est/{:0>8}.pfm'.format(src_view)))[0]

                masks, geo_mask, depth_reprojected, _, _ = check_geometric_consistency(ref_depth_est, ref_intrinsics, ref_extrinsics,
                                                                                       src_depth_est,
                                                                                       src_intrinsics, src_extrinsics, 10 ** thre * 4, 10 ** thre * 1300)
                if ct == 1:
                    for i in range(2, n):
                        geo_mask_sums.append(masks[i - 2].astype(np.int32))
                else:
                    for i in range(2, n):
                        geo_mask_sums[i - 2] += masks[i - 2].astype(np.int32)

                geo_mask_sum += geo_mask.astype(np.int32)
                all_srcview_depth_ests.append(depth_reprojected)

            geo_mask = geo_mask_sum >= n
            for i in range(2, n):
                geo_mask = np.logical_or(geo_mask, geo_mask_sums[i - 2] >= i)

            depth_est_averaged.append((sum(all_srcview_depth_ests) + ref_depth_est) / (geo_mask_sum + 1))
            geo_mask_all.append(np.mean(geo_mask))
            final_mask = geo_mask

            if iter == total_iter - 1:
                os.makedirs(os.path.join(out_folder, "mask"), exist_ok=True)
                save_mask(os.path.join(out_folder, "mask/{:0>8}_geo.png".format(ref_view)), geo_mask)
                save_mask(os.path.join(out_folder, "mask/{:0>8}_final.png".format(ref_view)), final_mask)

                print("processing {}, ref-view{:0>2}, geo_mask:{:3f} final_mask: {:3f}".format(scan_folder, ref_view,
                                                                                               geo_mask.mean(), final_mask.mean()))

                if args.display:
                    cv2.imshow('ref_img', ref_img[:, :, ::-1])
                    cv2.imshow('ref_depth', ref_depth_est / np.max(ref_depth_est))
                    cv2.imshow('ref_depth * geo_mask', ref_depth_est * geo_mask.astype(np.float32) / np.max(ref_depth_est))
                    cv2.imshow('ref_depth * mask', ref_depth_est * final_mask.astype(np.float32) / np.max(ref_depth_est))
                    cv2.waitKey(0)

                height, width = depth_est_averaged[-1].shape[:2]
                x, y = np.meshgrid(np.arange(0, width), np.arange(0, height))

                valid_points = final_mask
                # print("valid_points", valid_points.mean())
                x, y, depth = x[valid_points], y[valid_points], depth_est_averaged[-1][valid_points]

                color = ref_img[valid_points]
                xyz_ref = np.matmul(np.linalg.inv(ref_intrinsics),
                                    np.vstack((x, y, np.ones_like(x))) * depth)
                xyz_world = np.matmul(np.linalg.inv(ref_extrinsics),
                                      np.vstack((xyz_ref, np.ones_like(x))))[:3]
                vertexs.append(xyz_world.transpose((1, 0)))
                vertex_colors.append((color * 255).astype(np.uint8))
        if np.mean(geo_mask_all) >= 0.25:
            thre_left = thre
        else:
            thre_right = thre
    vertexs = np.concatenate(vertexs, axis=0)
    vertex_colors = np.concatenate(vertex_colors, axis=0)
    vertexs = np.array([tuple(v) for v in vertexs], dtype=[('x', 'f4'), ('y', 'f4'), ('z', 'f4')])
    vertex_colors = np.array([tuple(v) for v in vertex_colors], dtype=[('red', 'u1'), ('green', 'u1'), ('blue', 'u1')])

    vertex_all = np.empty(len(vertexs), vertexs.dtype.descr + vertex_colors.dtype.descr)
    for prop in vertexs.dtype.names:
        vertex_all[prop] = vertexs[prop]
    for prop in vertex_colors.dtype.names:
        vertex_all[prop] = vertex_colors[prop]

    el = PlyElement.describe(vertex_all, 'vertex')
    PlyData([el]).write(plyfilename)
    print("saving the final model to", plyfilename)


if __name__ == '__main__':
    save_depth()
    if args.dataset == "dtu_yao_eval":
        with open(args.testlist) as f:
            scans = f.readlines()
            scans = [line.rstrip() for line in scans]

        for scan in scans:
            scan_id = int(scan[4:])
            scan_folder = os.path.join(args.testpath, scan)
            out_folder = os.path.join(args.outdir, scan)
            filter_depth(scan_folder, out_folder, os.path.join(args.outdir, 'igev_mvs{:0>3}_l3.ply'.format(scan_id)),
                         args.geo_pixel_thres, args.geo_depth_thres, args.photo_thres, img_wh, 4)
    elif args.dataset == "tanks":
        # intermediate dataset
        if args.split == "intermediate":
            scans = ['Family', 'Francis', 'Horse', 'Lighthouse',
                     'M60', 'Panther', 'Playground', 'Train']
            geo_mask_thres = {'Family': 5,
                              'Francis': 6,
                              'Horse': 5,
                              'Lighthouse': 6,
                              'M60': 5,
                              'Panther': 5,
                              'Playground': 5,
                              'Train': 5}

            for scan in scans:
                scan_folder = os.path.join(args.testpath, args.split, scan)
                out_folder = os.path.join(args.outdir, scan)

                filter_depth(scan_folder, out_folder, os.path.join(args.outdir, scan + '.ply'),
                             args.geo_pixel_thres, args.geo_depth_thres, args.photo_thres, img_wh, geo_mask_thres[scan])

        # advanced dataset
        elif args.split == "advanced":
            scans = ['Auditorium', 'Ballroom', 'Courtroom',
                     'Museum', 'Palace', 'Temple']
            geo_mask_thres = {'Auditorium': 3,
                              'Ballroom': 4,
                              'Courtroom': 4,
                              'Museum': 4,
                              'Palace': 5,
                              'Temple': 4}

            for scan in scans:
                scan_folder = os.path.join(args.testpath, args.split, scan)
                out_folder = os.path.join(args.outdir, scan)
                filter_depth(scan_folder, out_folder, os.path.join(args.outdir, scan + '.ply'),
                             args.geo_pixel_thres, args.geo_depth_thres, args.photo_thres, img_wh, geo_mask_thres[scan])

    elif args.dataset == "eth3d":
        if args.split == "test":
            scans = ['botanical_garden', 'boulders', 'bridge', 'door',
                     'exhibition_hall', 'lecture_room', 'living_room', 'lounge',
                     'observatory', 'old_computer', 'statue', 'terrace_2']

            geo_mask_thres = {'botanical_garden': 1,  # 30 images, outdoor
                              'boulders': 1,          # 26 images, outdoor
                              'bridge': 2,            # 110 images, outdoor
                              'door': 2,              # 6 images, indoor
                              'exhibition_hall': 2,   # 68 images, indoor
                              'lecture_room': 2,      # 23 images, indoor
                              'living_room': 2,       # 65 images, indoor
                              'lounge': 1,            # 10 images, indoor
                              'observatory': 2,       # 27 images, outdoor
                              'old_computer': 2,      # 54 images, indoor
                              'statue': 2,            # 10 images, indoor
                              'terrace_2': 2          # 13 images, outdoor
                              }
            for scan in scans:
                start_time = time.time()
                scan_folder = os.path.join(args.testpath, scan)
                out_folder = os.path.join(args.outdir, scan)
                filter_depth(scan_folder, out_folder, os.path.join(args.outdir, scan + '.ply'),
                             args.geo_pixel_thres, args.geo_depth_thres, args.photo_thres, img_wh, geo_mask_thres[scan])
                print('scan: ' + scan + ' time = {:3f}'.format(time.time() - start_time))

        elif args.split == "train":
            scans = ['courtyard', 'delivery_area', 'electro', 'facade',
                     'kicker', 'meadow', 'office', 'pipes', 'playground',
                     'relief', 'relief_2', 'terrace', 'terrains']

            geo_mask_thres = {'courtyard': 1,      # 38 images, outdoor
                              'delivery_area': 2,  # 44 images, indoor
                              'electro': 1,        # 45 images, outdoor
                              'facade': 2,         # 76 images, outdoor
                              'kicker': 1,         # 31 images, indoor
                              'meadow': 1,         # 15 images, outdoor
                              'office': 1,         # 26 images, indoor
                              'pipes': 1,          # 14 images, indoor
                              'playground': 1,     # 38 images, outdoor
                              'relief': 1,         # 31 images, indoor
                              'relief_2': 1,       # 31 images, indoor
                              'terrace': 1,        # 23 images, outdoor
                              'terrains': 2        # 42 images, indoor
                              }

            for scan in scans:
                start_time = time.time()
                scan_folder = os.path.join(args.testpath, scan)
                out_folder = os.path.join(args.outdir, scan)
                filter_depth(scan_folder, out_folder, os.path.join(args.outdir, scan + '.ply'),
                             args.geo_pixel_thres, args.geo_depth_thres, args.photo_thres, img_wh, geo_mask_thres[scan])
                print('scan: ' + scan + ' time = {:3f}'.format(time.time() - start_time))
    else:
        filter_depth(args.testpath, args.outdir, os.path.join(args.outdir, 'custom.ply'),
                     args.geo_pixel_thres, args.geo_depth_thres, args.photo_thres, img_wh, geo_mask_thres=3)
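One step in `filter_depth` deserves a note: the outer loop is a bisection over a log-scale threshold `thre`. Each pass rebuilds the geometric masks with tolerances of `i / (10**thre * 4)` pixels and `i / (10**thre * 1300)` relative depth, so occupancy shrinks as `thre` grows, and the bounds are narrowed until the mean mask occupancy lands near 0.25. The same bisection in isolation (the ratio function is a stand-in for one full masking pass over all reference views; target and bounds are copied from the script):

# Stand-alone sketch of the threshold bisection used by filter_depth.
def bisect_threshold(evaluate_mask_ratio, target=0.25, lo=-2.0, hi=2.0, iters=10):
    for _ in range(iters):
        mid = (lo + hi) / 2
        if evaluate_mask_ratio(10 ** mid) >= target:
            lo = mid   # masks still dense enough: tighten the tolerance
        else:
            hi = mid   # masks too sparse: loosen it
    return 10 ** ((lo + hi) / 2)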
44
IGEV-MVS/evaluations/dtu/BaseEval2Obj_web.m
Normal file
@ -0,0 +1,44 @@
function BaseEval2Obj_web(BaseEval,method_string,outputPath)

if(nargin<3)
    outputPath='./';
end

% threshold for coloring alpha channel in the range of 0-10 mm
dist_tresshold=10;

cSet=BaseEval.cSet;

Qdata=BaseEval.Qdata;
alpha=min(BaseEval.Ddata,dist_tresshold)/dist_tresshold;

fid=fopen([outputPath method_string '2Stl_' num2str(cSet) '.obj'],'w+');

for cP=1:size(Qdata,2)
    if(BaseEval.DataInMask(cP))
        C=[1 0 0]*alpha(cP)+[1 1 1]*(1-alpha(cP)); % coloring from red to white in the range of 0-10 mm (0 to dist_tresshold)
    else
        C=[0 1 0]*alpha(cP)+[0 0 1]*(1-alpha(cP)); % green to blue for points outside the mask (which are not included in the analysis)
    end
    fprintf(fid,'v %f %f %f %f %f %f\n',[Qdata(1,cP) Qdata(2,cP) Qdata(3,cP) C(1) C(2) C(3)]);
end
fclose(fid);

disp('Data2Stl saved as obj')

Qstl=BaseEval.Qstl;
fid=fopen([outputPath 'Stl2' method_string '_' num2str(cSet) '.obj'],'w+');

alpha=min(BaseEval.Dstl,dist_tresshold)/dist_tresshold;

for cP=1:size(Qstl,2)
    if(BaseEval.StlAbovePlane(cP))
        C=[1 0 0]*alpha(cP)+[1 1 1]*(1-alpha(cP)); % coloring from red to white in the range of 0-10 mm (0 to dist_tresshold)
    else
        C=[0 1 0]*alpha(cP)+[0 0 1]*(1-alpha(cP)); % green to blue for points below plane (which are not included in the analysis)
    end
    fprintf(fid,'v %f %f %f %f %f %f\n',[Qstl(1,cP) Qstl(2,cP) Qstl(3,cP) C(1) C(2) C(3)]);
end
fclose(fid);

disp('Stl2Data saved as obj')
104
IGEV-MVS/evaluations/dtu/BaseEvalMain_web.m
Normal file
@ -0,0 +1,104 @@
clear all
close all
format compact
clc

% script to calculate distances for all included scans (UsedSets)

dataPath='D:\xgw\IterMVS_data\MVS Data\';
plyPath='/data/xgw/IGEV_MVS/conf_03/';
resultsPath='/data/xgw/IGEV_MVS/outputs_conf_03/';

method_string='itermvs';
light_string='l3'; % l3 is the setting with all lights on, l7 is randomly sampled between the 7 settings (index 0-6)
representation_string='Points'; % mvs representation 'Points' or 'Surfaces'

switch representation_string
    case 'Points'
        eval_string='_Eval_'; % results naming
        settings_string='';
end

% get sets used in evaluation
UsedSets=[1 4 9 10 11 12 13 15 23 24 29 32 33 34 48 49 62 75 77 110 114 118];

result = zeros(length(UsedSets),4);

dst=0.2; % min dist between points when reducing

for cIdx=1:length(UsedSets)
    % data set number
    cSet = UsedSets(cIdx)
    % input data name
    DataInName=[plyPath sprintf('%s%03d_%s%s.ply',lower(method_string),cSet,light_string,settings_string)]

    % results name (concatenate strings into one string)
    EvalName=[resultsPath method_string eval_string num2str(cSet) '.mat']

    disp(EvalName)

    % check if file is already computed
    if(~exist(EvalName,'file'))
        disp(DataInName);

        time=clock;time(4:5), drawnow

        tic
        Mesh = plyread(DataInName);
        Qdata=[Mesh.vertex.x Mesh.vertex.y Mesh.vertex.z]';
        toc

        BaseEval=PointCompareMain(cSet,Qdata,dst,dataPath);

        disp('Saving results'), drawnow
        toc
        save(EvalName,'BaseEval');
        toc

        % write obj-file of evaluation
        % BaseEval2Obj_web(BaseEval,method_string, resultsPath)
        % toc
        time=clock;time(4:5), drawnow

        BaseEval.MaxDist=20; % outlier threshold of 20 mm

        BaseEval.FilteredDstl=BaseEval.Dstl(BaseEval.StlAbovePlane); % use only points that are above the plane
        BaseEval.FilteredDstl=BaseEval.FilteredDstl(BaseEval.FilteredDstl<BaseEval.MaxDist); % discard outliers

        BaseEval.FilteredDdata=BaseEval.Ddata(BaseEval.DataInMask); % use only points that are within the mask
        BaseEval.FilteredDdata=BaseEval.FilteredDdata(BaseEval.FilteredDdata<BaseEval.MaxDist); % discard outliers

        fprintf("mean/median Data (acc.) %f/%f\n", mean(BaseEval.FilteredDdata), median(BaseEval.FilteredDdata));
        fprintf("mean/median Stl (comp.) %f/%f\n", mean(BaseEval.FilteredDstl), median(BaseEval.FilteredDstl));
        result(cIdx,1) = mean(BaseEval.FilteredDdata);
        result(cIdx,2) = median(BaseEval.FilteredDdata);
        result(cIdx,3) = mean(BaseEval.FilteredDstl);
        result(cIdx,4) = median(BaseEval.FilteredDstl);
    else
        load(EvalName);

        BaseEval.MaxDist=20; % outlier threshold of 20 mm

        BaseEval.FilteredDstl=BaseEval.Dstl(BaseEval.StlAbovePlane); % use only points that are above the plane
        BaseEval.FilteredDstl=BaseEval.FilteredDstl(BaseEval.FilteredDstl<BaseEval.MaxDist); % discard outliers

        BaseEval.FilteredDdata=BaseEval.Ddata(BaseEval.DataInMask); % use only points that are within the mask
        BaseEval.FilteredDdata=BaseEval.FilteredDdata(BaseEval.FilteredDdata<BaseEval.MaxDist); % discard outliers

        fprintf("mean/median Data (acc.) %f/%f\n", mean(BaseEval.FilteredDdata), median(BaseEval.FilteredDdata));
        fprintf("mean/median Stl (comp.) %f/%f\n", mean(BaseEval.FilteredDstl), median(BaseEval.FilteredDstl));
        result(cIdx,1) = mean(BaseEval.FilteredDdata);
        result(cIdx,2) = median(BaseEval.FilteredDdata);
        result(cIdx,3) = mean(BaseEval.FilteredDstl);
        result(cIdx,4) = median(BaseEval.FilteredDstl);
    end
end

mean_result=mean(result);
fprintf("final evaluation result on all scans: acc.: %f, comp.: %f, overall: %f\n", mean_result(1), mean_result(3), (mean_result(1)+mean_result(3))/2);
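For readers without MATLAB: the accuracy/completeness numbers printed above are mean distances between the reconstruction and the reference scan after masking and a 20 mm outlier cut. A rough Python analogue of the final aggregation, assuming per-scan distance arrays have already been computed (the function name and input layout are hypothetical):

# Rough Python sketch of the DTU aggregation above: accuracy is the mean
# reconstruction-to-reference distance, completeness the reverse direction,
# both after the 20 mm outlier cut; `scans` is a hypothetical list of
# (d_data, d_stl) distance-array pairs, one per evaluated set.
import numpy as np

def dtu_summary(scans, max_dist=20.0):
    acc = [np.mean(d_data[d_data < max_dist]) for d_data, _ in scans]
    comp = [np.mean(d_stl[d_stl < max_dist]) for _, d_stl in scans]
    acc_mean, comp_mean = np.mean(acc), np.mean(comp)
    return acc_mean, comp_mean, (acc_mean + comp_mean) / 2  # overall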
87
IGEV-MVS/evaluations/dtu/ComputeStat_web.m
Normal file
@ -0,0 +1,87 @@
clear all
close all
format compact
clc

% script to calculate the statistics for each scan; this will currently only run if distances
% have been measured for all included scans (UsedSets)

% modify the paths to evaluate your models
dataPath='/home/SampleSet/MVS Data/';
resultsPath='/home/PatchmatchNet/outputs/';

MaxDist=20; % outlier threshold of 20 mm

time=clock;

method_string='patchmatchnet';
light_string='l3'; % 'l7'; l3 is the setting with all lights on, l7 is randomly sampled between the 7 settings (index 0-6)
representation_string='Points'; % mvs representation 'Points' or 'Surfaces'

switch representation_string
    case 'Points'
        eval_string='_Eval_'; % results naming
        settings_string='';
end

% get sets used in evaluation
UsedSets=[1 4 9 10 11 12 13 15 23 24 29 32 33 34 48 49 62 75 77 110 114 118];

nStat=length(UsedSets);

% struct
BaseStat.nStl=zeros(1,nStat);
BaseStat.nData=zeros(1,nStat);
BaseStat.MeanStl=zeros(1,nStat);
BaseStat.MeanData=zeros(1,nStat);
BaseStat.VarStl=zeros(1,nStat);
BaseStat.VarData=zeros(1,nStat);
BaseStat.MedStl=zeros(1,nStat);
BaseStat.MedData=zeros(1,nStat);

for cStat=1:length(UsedSets) % data set number
    currentSet=UsedSets(cStat);

    % input results name
    EvalName=[resultsPath method_string eval_string num2str(currentSet) '.mat'];

    disp(EvalName);
    load(EvalName);

    Dstl=BaseEval.Dstl(BaseEval.StlAbovePlane); % use only points that are above the plane
    Dstl=Dstl(Dstl<MaxDist); % discard outliers

    Ddata=BaseEval.Ddata(BaseEval.DataInMask); % use only points that are within the mask
    Ddata=Ddata(Ddata<MaxDist); % discard outliers

    BaseStat.nStl(cStat)=length(Dstl);
    BaseStat.nData(cStat)=length(Ddata);

    BaseStat.MeanStl(cStat)=mean(Dstl);
    BaseStat.MeanData(cStat)=mean(Ddata);

    BaseStat.VarStl(cStat)=var(Dstl);
    BaseStat.VarData(cStat)=var(Ddata);

    BaseStat.MedStl(cStat)=median(Dstl);
    BaseStat.MedData(cStat)=median(Ddata);

    disp("acc");
    disp(mean(Ddata));
    disp("comp");
    disp(mean(Dstl));
    time=clock;
end

disp(BaseStat);
disp("mean acc")
disp(mean(BaseStat.MeanData));
disp("mean comp")
disp(mean(BaseStat.MeanStl));

totalStatName=[resultsPath 'TotalStat_' method_string eval_string '.mat']
save(totalStatName,'BaseStat','time','MaxDist');
50
IGEV-MVS/evaluations/dtu/MaxDistCP.m
Normal file
@ -0,0 +1,50 @@
function Dist = MaxDistCP(Qto,Qfrom,BB,MaxDist)

Dist=ones(1,size(Qfrom,2))*MaxDist;

Range=floor((BB(2,:)-BB(1,:))/MaxDist);

tic
Done=0;
LookAt=zeros(1,size(Qfrom,2));
for x=0:Range(1)
    for y=0:Range(2)
        for z=0:Range(3)

            Low=BB(1,:)+[x y z]*MaxDist;
            High=Low+MaxDist;

            idxF=find(Qfrom(1,:)>=Low(1) & Qfrom(2,:)>=Low(2) & Qfrom(3,:)>=Low(3) &...
                Qfrom(1,:)<High(1) & Qfrom(2,:)<High(2) & Qfrom(3,:)<High(3));
            SQfrom=Qfrom(:,idxF);
            LookAt(idxF)=LookAt(idxF)+1; % debug

            Low=Low-MaxDist;
            High=High+MaxDist;
            idxT=find(Qto(1,:)>=Low(1) & Qto(2,:)>=Low(2) & Qto(3,:)>=Low(3) &...
                Qto(1,:)<High(1) & Qto(2,:)<High(2) & Qto(3,:)<High(3));
            SQto=Qto(:,idxT);

            if(isempty(SQto))
                Dist(idxF)=MaxDist;
            else
                KDstl=KDTreeSearcher(SQto');
                [~,SDist] = knnsearch(KDstl,SQfrom');
                Dist(idxF)=SDist;
            end

            Done=Done+length(idxF); % debug
        end
    end
end
%Complete=Done/size(Qfrom,2);
%EstTime=(toc/Complete)/60
%toc
%LA=[sum(LookAt==0),...
%   sum(LookAt==1),...
%   sum(LookAt==2),...
%   sum(LookAt==3),...
%   sum(LookAt>3)]
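MaxDistCP chunks both clouds into MaxDist-sized voxels so that each KD-tree query only sees nearby points. With SciPy, the same clamped nearest-neighbour distance can be computed directly; a sketch rather than a line-by-line port:

# SciPy sketch of what MaxDistCP computes: for every point in q_from, the
# distance to its nearest neighbour in q_to, clamped to max_dist. The MATLAB
# version adds voxel chunking purely to keep each KD-tree query small.
import numpy as np
from scipy.spatial import cKDTree

def max_dist_cp(q_to, q_from, max_dist):
    # q_to, q_from: (N, 3) arrays of points
    dist, _ = cKDTree(q_to).query(q_from, k=1, distance_upper_bound=max_dist)
    return np.minimum(dist, max_dist)  # query returns inf beyond the bound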
58
IGEV-MVS/evaluations/dtu/PointCompareMain.m
Normal file
@ -0,0 +1,58 @@
function BaseEval=PointCompareMain(cSet,Qdata,dst,dataPath)
|
||||
% evaluation function the calculates the distantes from the reference data (stl) to the evalution points (Qdata) and the
|
||||
% distances from the evaluation points to the reference
|
||||
|
||||
tic
|
||||
% reduce points 0.2 mm neighbourhood density
|
||||
Qdata=reducePts_haa(Qdata,dst);
|
||||
toc
|
||||
|
||||
StlInName=[dataPath '/Points/stl/stl' sprintf('%03d',cSet) '_total.ply'];
|
||||
|
||||
StlMesh = plyread(StlInName); %STL points already reduced 0.2 mm neighbourhood density
|
||||
Qstl=[StlMesh.vertex.x StlMesh.vertex.y StlMesh.vertex.z]';
|
||||
|
||||
%Load Mask (ObsMask) and Bounding box (BB) and Resolution (Res)
|
||||
Margin=10;
|
||||
MaskName=[dataPath '/ObsMask/ObsMask' num2str(cSet) '_' num2str(Margin) '.mat'];
|
||||
load(MaskName)
|
||||
|
||||
MaxDist=60;
|
||||
disp('Computing Data 2 Stl distances')
|
||||
Ddata = MaxDistCP(Qstl,Qdata,BB,MaxDist);
|
||||
toc
|
||||
|
||||
disp('Computing Stl 2 Data distances')
|
||||
Dstl=MaxDistCP(Qdata,Qstl,BB,MaxDist);
|
||||
disp('Distances computed')
|
||||
toc
|
||||
|
||||
%use mask
|
||||
%From Get mask - inverted & modified.
|
||||
One=ones(1,size(Qdata,2));
|
||||
Qv=(Qdata-BB(1,:)'*One)/Res+1;
|
||||
Qv=round(Qv);
|
||||
|
||||
Midx1=find(Qv(1,:)>0 & Qv(1,:)<=size(ObsMask,1) & Qv(2,:)>0 & Qv(2,:)<=size(ObsMask,2) & Qv(3,:)>0 & Qv(3,:)<=size(ObsMask,3));
|
||||
MidxA=sub2ind(size(ObsMask),Qv(1,Midx1),Qv(2,Midx1),Qv(3,Midx1));
|
||||
Midx2=find(ObsMask(MidxA));
|
||||
|
||||
BaseEval.DataInMask(1:size(Qv,2))=false;
|
||||
BaseEval.DataInMask(Midx1(Midx2))=true; %If Data is within the mask
|
||||
|
||||
BaseEval.cSet=cSet;
|
||||
BaseEval.Margin=Margin; %Margin of masks
|
||||
BaseEval.dst=dst; %Min dist between points when reducing
|
||||
BaseEval.Qdata=Qdata; %Input data points
|
||||
BaseEval.Ddata=Ddata; %distance from data to stl
|
||||
BaseEval.Qstl=Qstl; %Input stl points
|
||||
BaseEval.Dstl=Dstl; %Distance from the stl to data
|
||||
|
||||
load([dataPath '/ObsMask/Plane' num2str(cSet)],'P')
|
||||
BaseEval.GroundPlane=P; % Plane used to destinguise which Stl points are 'used'
|
||||
BaseEval.StlAbovePlane=(P'*[Qstl;ones(1,size(Qstl,2))])>0; %Is stl above 'ground plane'
|
||||
BaseEval.Time=clock; %Time when computation is finished
|
||||
|
||||
|
||||
|
||||
|
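In the standard DTU protocol the two distance sets computed here are summarised as accuracy (Ddata, reconstruction to reference) and completeness (Dstl, reference to reconstruction), filtered by DataInMask and StlAbovePlane respectively. A hedged sketch of that aggregation, assuming the BaseEval fields were exported to NumPy float/boolean arrays (the official summary scripts may filter slightly differently):

import numpy as np

def dtu_summary(ddata, dstl, data_in_mask, stl_above_plane, max_dist=60.0):
    acc = ddata[data_in_mask & (ddata < max_dist)]    # accuracy distances
    comp = dstl[stl_above_plane & (dstl < max_dist)]  # completeness distances
    return float(np.mean(acc)), float(np.mean(comp))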
454
IGEV-MVS/evaluations/dtu/plyread.m
Normal file
@ -0,0 +1,454 @@
%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%
function [Elements,varargout] = plyread(Path,Str)
%PLYREAD Read a PLY 3D data file.
%   [DATA,COMMENTS] = PLYREAD(FILENAME) reads a version 1.0 PLY file
%   FILENAME and returns a structure DATA. The fields in this structure
%   are defined by the PLY header; each element type is a field and each
%   element property is a subfield. If the file contains any comments,
%   they are returned in a cell string array COMMENTS.
%
%   [TRI,PTS] = PLYREAD(FILENAME,'tri') or
%   [TRI,PTS,DATA,COMMENTS] = PLYREAD(FILENAME,'tri') converts vertex
%   and face data into triangular connectivity and vertex arrays. The
%   mesh can then be displayed using the TRISURF command.
%
%   Note: This function is slow for large mesh files (+50K faces),
%   especially when reading data with list type properties.
%
%   Example:
%   [Tri,Pts] = PLYREAD('cow.ply','tri');
%   trisurf(Tri,Pts(:,1),Pts(:,2),Pts(:,3));
%   colormap(gray); axis equal;
%
%   See also: PLYWRITE

% Pascal Getreuer 2004

[fid,Msg] = fopen(Path,'rt'); % open file in read text mode

if fid == -1, error(Msg); end

Buf = fscanf(fid,'%s',1);
if ~strcmp(Buf,'ply')
    fclose(fid);
    error('Not a PLY file.');
end


%%% read header %%%

Position = ftell(fid);
Format = '';
NumComments = 0;
Comments = {}; % for storing any file comments
NumElements = 0;
NumProperties = 0;
Elements = []; % structure for holding the element data
ElementCount = []; % number of each type of element in file
PropertyTypes = []; % corresponding structure recording property types
ElementNames = {}; % list of element names in the order they are stored in the file
PropertyNames = []; % structure of lists of property names

while 1
    Buf = fgetl(fid); % read one line from file
    BufRem = Buf;
    Token = {};
    Count = 0;

    while ~isempty(BufRem) % split line into tokens
        [tmp,BufRem] = strtok(BufRem);

        if ~isempty(tmp)
            Count = Count + 1; % count tokens
            Token{Count} = tmp;
        end
    end

    if Count % parse line
        switch lower(Token{1})
            case 'format' % read data format
                if Count >= 2
                    Format = lower(Token{2});

                    if Count == 3 & ~strcmp(Token{3},'1.0')
                        fclose(fid);
                        error('Only PLY format version 1.0 supported.');
                    end
                end
            case 'comment' % read file comment
                NumComments = NumComments + 1;
                Comments{NumComments} = '';
                for i = 2:Count
                    Comments{NumComments} = [Comments{NumComments},Token{i},' '];
                end
            case 'element' % element name
                if Count >= 3
                    if isfield(Elements,Token{2})
                        fclose(fid);
                        error(['Duplicate element name, ''',Token{2},'''.']);
                    end

                    NumElements = NumElements + 1;
                    NumProperties = 0;
                    Elements = setfield(Elements,Token{2},[]);
                    PropertyTypes = setfield(PropertyTypes,Token{2},[]);
                    ElementNames{NumElements} = Token{2};
                    PropertyNames = setfield(PropertyNames,Token{2},{});
                    CurElement = Token{2};
                    ElementCount(NumElements) = str2double(Token{3});

                    if isnan(ElementCount(NumElements))
                        fclose(fid);
                        error(['Bad element definition: ',Buf]);
                    end
                else
                    error(['Bad element definition: ',Buf]);
                end
            case 'property' % element property
                if ~isempty(CurElement) & Count >= 3
                    NumProperties = NumProperties + 1;
                    eval(['tmp=isfield(Elements.',CurElement,',Token{Count});'],...
                        'fclose(fid);error([''Error reading property: '',Buf])');

                    if tmp
                        error(['Duplicate property name, ''',CurElement,'.',Token{2},'''.']);
                    end

                    % add property subfield to Elements
                    eval(['Elements.',CurElement,'.',Token{Count},'=[];'], ...
                        'fclose(fid);error([''Error reading property: '',Buf])');
                    % add property subfield to PropertyTypes and save type
                    eval(['PropertyTypes.',CurElement,'.',Token{Count},'={Token{2:Count-1}};'], ...
                        'fclose(fid);error([''Error reading property: '',Buf])');
                    % record property name order
                    eval(['PropertyNames.',CurElement,'{NumProperties}=Token{Count};'], ...
                        'fclose(fid);error([''Error reading property: '',Buf])');
                else
                    fclose(fid);

                    if isempty(CurElement)
                        error(['Property definition without element definition: ',Buf]);
                    else
                        error(['Bad property definition: ',Buf]);
                    end
                end
            case 'end_header' % end of header, break from while loop
                break;
        end
    end
end

%%% set reading for specified data format %%%

if isempty(Format)
    warning('Data format unspecified, assuming ASCII.');
    Format = 'ascii';
end

switch Format
    case 'ascii'
        Format = 0;
    case 'binary_little_endian'
        Format = 1;
    case 'binary_big_endian'
        Format = 2;
    otherwise
        fclose(fid);
        error(['Data format ''',Format,''' not supported.']);
end

if ~Format
    Buf = fscanf(fid,'%f'); % read the rest of the file as ASCII data
    BufOff = 1;
else
    % reopen the file in read binary mode
    fclose(fid);

    if Format == 1
        fid = fopen(Path,'r','ieee-le.l64'); % little endian
    else
        fid = fopen(Path,'r','ieee-be.l64'); % big endian
    end

    % find the end of the header again (using ftell on the old handle doesn't give the correct position)
    BufSize = 8192;
    Buf = [blanks(10),char(fread(fid,BufSize,'uchar')')];
    i = [];
    tmp = -11;

    while isempty(i)
        i = findstr(Buf,['end_header',13,10]); % look for end_header + CR/LF
        i = [i,findstr(Buf,['end_header',10])]; % look for end_header + LF

        if isempty(i)
            tmp = tmp + BufSize;
            Buf = [Buf(BufSize+1:BufSize+10),char(fread(fid,BufSize,'uchar')')];
        end
    end

    % seek to just after the line feed
    fseek(fid,i + tmp + 11 + (Buf(i + 10) == 13),-1);
end


%%% read element data %%%

% PLY and MATLAB data types (for fread)
PlyTypeNames = {'char','uchar','short','ushort','int','uint','float','double', ...
    'char8','uchar8','short16','ushort16','int32','uint32','float32','double64'};
MatlabTypeNames = {'schar','uchar','int16','uint16','int32','uint32','single','double'};
SizeOf = [1,1,2,2,4,4,4,8]; % size in bytes of each type

for i = 1:NumElements
    % get current element property information
    eval(['CurPropertyNames=PropertyNames.',ElementNames{i},';']);
    eval(['CurPropertyTypes=PropertyTypes.',ElementNames{i},';']);
    NumProperties = size(CurPropertyNames,2);

    % fprintf('Reading %s...\n',ElementNames{i});

    if ~Format %%% read ASCII data %%%
        for j = 1:NumProperties
            Token = getfield(CurPropertyTypes,CurPropertyNames{j});

            if strcmpi(Token{1},'list')
                Type(j) = 1;
            else
                Type(j) = 0;
            end
        end

        % parse buffer
        if ~any(Type)
            % no list types
            Data = reshape(Buf(BufOff:BufOff+ElementCount(i)*NumProperties-1),NumProperties,ElementCount(i))';
            BufOff = BufOff + ElementCount(i)*NumProperties;
        else
            ListData = cell(NumProperties,1);

            for k = 1:NumProperties
                ListData{k} = cell(ElementCount(i),1);
            end

            % list type
            for j = 1:ElementCount(i)
                for k = 1:NumProperties
                    if ~Type(k)
                        Data(j,k) = Buf(BufOff);
                        BufOff = BufOff + 1;
                    else
                        tmp = Buf(BufOff);
                        ListData{k}{j} = Buf(BufOff+(1:tmp))';
                        BufOff = BufOff + tmp + 1;
                    end
                end
            end
        end
    else %%% read binary data %%%
        % translate PLY data type names to MATLAB data type names
        ListFlag = 0; % = 1 if there is a list type
        SameFlag = 1; % = 1 if all types are the same

        for j = 1:NumProperties
            Token = getfield(CurPropertyTypes,CurPropertyNames{j});

            if ~strcmp(Token{1},'list') % non-list type
                tmp = rem(strmatch(Token{1},PlyTypeNames,'exact')-1,8)+1;

                if ~isempty(tmp)
                    TypeSize(j) = SizeOf(tmp);
                    Type{j} = MatlabTypeNames{tmp};
                    TypeSize2(j) = 0;
                    Type2{j} = '';

                    SameFlag = SameFlag & strcmp(Type{1},Type{j});
                else
                    fclose(fid);
                    error(['Unknown property data type, ''',Token{1},''', in ', ...
                        ElementNames{i},'.',CurPropertyNames{j},'.']);
                end
            else % list type
                if length(Token) == 3
                    ListFlag = 1;
                    SameFlag = 0;
                    tmp = rem(strmatch(Token{2},PlyTypeNames,'exact')-1,8)+1;
                    tmp2 = rem(strmatch(Token{3},PlyTypeNames,'exact')-1,8)+1;

                    if ~isempty(tmp) & ~isempty(tmp2)
                        TypeSize(j) = SizeOf(tmp);
                        Type{j} = MatlabTypeNames{tmp};
                        TypeSize2(j) = SizeOf(tmp2);
                        Type2{j} = MatlabTypeNames{tmp2};
                    else
                        fclose(fid);
                        error(['Unknown property data type, ''list ',Token{2},' ',Token{3},''', in ', ...
                            ElementNames{i},'.',CurPropertyNames{j},'.']);
                    end
                else
                    fclose(fid);
                    error(['Invalid list syntax in ',ElementNames{i},'.',CurPropertyNames{j},'.']);
                end
            end
        end

        % read file
        if ~ListFlag
            if SameFlag
                % no list types, all the same type (fast)
                Data = fread(fid,[NumProperties,ElementCount(i)],Type{1})';
            else
                % no list types, mixed type
                Data = zeros(ElementCount(i),NumProperties);

                for j = 1:ElementCount(i)
                    for k = 1:NumProperties
                        Data(j,k) = fread(fid,1,Type{k});
                    end
                end
            end
        else
            ListData = cell(NumProperties,1);

            for k = 1:NumProperties
                ListData{k} = cell(ElementCount(i),1);
            end

            if NumProperties == 1
                BufSize = 512;
                SkipNum = 4;
                j = 0;

                % list type, one property (fast if lists are usually the same length)
                while j < ElementCount(i)
                    Position = ftell(fid);
                    % read in BufSize count values, assuming all counts = SkipNum
                    [Buf,BufSize] = fread(fid,BufSize,Type{1},SkipNum*TypeSize2(1));
                    Miss = find(Buf ~= SkipNum); % find first count that is not SkipNum
                    fseek(fid,Position + TypeSize(1),-1); % seek back to after first count

                    if isempty(Miss) % all counts are SkipNum
                        Buf = fread(fid,[SkipNum,BufSize],[int2str(SkipNum),'*',Type2{1}],TypeSize(1))';
                        fseek(fid,-TypeSize(1),0); % undo last skip

                        for k = 1:BufSize
                            ListData{1}{j+k} = Buf(k,:);
                        end

                        j = j + BufSize;
                        BufSize = floor(1.5*BufSize);
                    else
                        if Miss(1) > 1 % some counts are SkipNum
                            Buf2 = fread(fid,[SkipNum,Miss(1)-1],[int2str(SkipNum),'*',Type2{1}],TypeSize(1))';

                            for k = 1:Miss(1)-1
                                ListData{1}{j+k} = Buf2(k,:);
                            end

                            j = j + k;
                        end

                        % read in the list with the missed count
                        SkipNum = Buf(Miss(1));
                        j = j + 1;
                        ListData{1}{j} = fread(fid,[1,SkipNum],Type2{1});
                        BufSize = ceil(0.6*BufSize);
                    end
                end
            else
                % list type(s), multiple properties (slow)
                Data = zeros(ElementCount(i),NumProperties);

                for j = 1:ElementCount(i)
                    for k = 1:NumProperties
                        if isempty(Type2{k})
                            Data(j,k) = fread(fid,1,Type{k});
                        else
                            tmp = fread(fid,1,Type{k});
                            ListData{k}{j} = fread(fid,[1,tmp],Type2{k});
                        end
                    end
                end
            end
        end
    end

    % put data into Elements structure
    for k = 1:NumProperties
        if (~Format & ~Type(k)) | (Format & isempty(Type2{k}))
            eval(['Elements.',ElementNames{i},'.',CurPropertyNames{k},'=Data(:,k);']);
        else
            eval(['Elements.',ElementNames{i},'.',CurPropertyNames{k},'=ListData{k};']);
        end
    end
end

clear Data ListData;
fclose(fid);

if (nargin > 1 & strcmpi(Str,'Tri')) | nargout > 2
    % find vertex element field
    Name = {'vertex','Vertex','point','Point','pts','Pts'};
    Names = [];

    for i = 1:length(Name)
        if any(strcmp(ElementNames,Name{i}))
            Names = getfield(PropertyNames,Name{i});
            Name = Name{i};
            break;
        end
    end

    if any(strcmp(Names,'x')) & any(strcmp(Names,'y')) & any(strcmp(Names,'z'))
        eval(['varargout{1}=[Elements.',Name,'.x,Elements.',Name,'.y,Elements.',Name,'.z];']);
    else
        varargout{1} = zeros(1,3);
    end

    varargout{2} = Elements;
    varargout{3} = Comments;
    Elements = [];

    % find face element field
    Name = {'face','Face','poly','Poly','tri','Tri'};
    Names = [];

    for i = 1:length(Name)
        if any(strcmp(ElementNames,Name{i}))
            Names = getfield(PropertyNames,Name{i});
            Name = Name{i};
            break;
        end
    end

    if ~isempty(Names)
        % find vertex indices property subfield
        PropertyName = {'vertex_indices','vertex_indexes','vertex_index','indices','indexes'};

        for i = 1:length(PropertyName)
            if any(strcmp(Names,PropertyName{i}))
                PropertyName = PropertyName{i};
                break;
            end
        end

        if ~iscell(PropertyName)
            % convert face index lists to triangular connectivity
            eval(['FaceIndices=varargout{2}.',Name,'.',PropertyName,';']);
            N = length(FaceIndices);
            Elements = zeros(N*2,3);
            Extra = 0;

            for k = 1:N
                Elements(k,:) = FaceIndices{k}(1:3);

                for j = 4:length(FaceIndices{k})
                    Extra = Extra + 1;
                    Elements(N + Extra,:) = [Elements(k,[1,j-1]),FaceIndices{k}(j)];
                end
            end
            Elements = Elements(1:N+Extra,:) + 1;
        end
    end
else
    varargout{1} = Comments;
end
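plyread.m is bundled so the evaluation runs without extra toolboxes; PointCompareMain only needs the vertex x/y/z arrays. For a Python workflow, the same points can be loaded with the third-party plyfile package, a minimal sketch (assumes pip install plyfile; not part of this repo):

import numpy as np
from plyfile import PlyData

def load_ply_vertices(path):
    # returns a 3xN array matching the Qstl layout used in PointCompareMain.m
    v = PlyData.read(path)['vertex']
    return np.vstack([v['x'], v['y'], v['z']])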
35
IGEV-MVS/evaluations/dtu/reducePts_haa.m
Normal file
@ -0,0 +1,35 @@
function [ptsOut,indexSet] = reducePts_haa(pts, dst)

%Reduces a point set, pts, in a stochastic manner, such that the minimum distance
% between points is 'dst'. Written by abd, edited by haa, then by raje

nPoints=size(pts,2);

indexSet=true(nPoints,1);
RandOrd=randperm(nPoints);

%tic
NS = KDTreeSearcher(pts');
%toc

% search the KD-tree for close neighbours in a chunk-wise fashion to save memory if point cloud is really big
Chunks=1:min(4e6,nPoints-1):nPoints;
Chunks(end)=nPoints;

for cChunk=1:(length(Chunks)-1)
    Range=Chunks(cChunk):Chunks(cChunk+1);

    idx = rangesearch(NS,pts(:,RandOrd(Range))',dst);

    for i = 1:size(idx,1)
        id =RandOrd(i-1+Chunks(cChunk));
        if (indexSet(id))
            indexSet(idx{i}) = 0;
            indexSet(id) = 1;
        end
    end
end

ptsOut = pts(:,indexSet);

disp(['downsample factor: ' num2str(nPoints/sum(indexSet))]);
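The reduction visits points in random order and suppresses every not-yet-kept neighbour within dst, which yields a roughly even subsample without imposing a regular grid. A rough Python equivalent, assuming SciPy (the memory-saving chunking is dropped for clarity; this is an illustrative sketch, not the repo's code):

import numpy as np
from scipy.spatial import cKDTree

def reduce_pts(pts, dst):
    # pts: 3xN array; stochastic downsampling with minimum spacing ~dst
    n = pts.shape[1]
    keep = np.ones(n, dtype=bool)
    tree = cKDTree(pts.T)
    for i in np.random.permutation(n):
        if keep[i]:
            keep[tree.query_ball_point(pts[:, i], dst)] = False  # drop neighbours
            keep[i] = True                                       # but keep the seed
    return pts[:, keep]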
106
IGEV-MVS/lists/blendedmvs/train.txt
Normal file
@ -0,0 +1,106 @@
5c1f33f1d33e1f2e4aa6dda4
5bfe5ae0fe0ea555e6a969ca
5bff3c5cfe0ea555e6bcbf3a
58eaf1513353456af3a1682a
5bfc9d5aec61ca1dd69132a2
5bf18642c50e6f7f8bdbd492
5bf26cbbd43923194854b270
5bf17c0fd439231948355385
5be3ae47f44e235bdbbc9771
5be3a5fb8cfdd56947f6b67c
5bbb6eb2ea1cfa39f1af7e0c
5ba75d79d76ffa2c86cf2f05
5bb7a08aea1cfa39f1a947ab
5b864d850d072a699b32f4ae
5b6eff8b67b396324c5b2672
5b6e716d67b396324c2d77cb
5b69cc0cb44b61786eb959bf
5b62647143840965efc0dbde
5b60fa0c764f146feef84df0
5b558a928bbfb62204e77ba2
5b271079e0878c3816dacca4
5b08286b2775267d5b0634ba
5afacb69ab00705d0cefdd5b
5af28cea59bc705737003253
5af02e904c8216544b4ab5a2
5aa515e613d42d091d29d300
5c34529873a8df509ae57b58
5c34300a73a8df509add216d
5c1af2e2bee9a723c963d019
5c1892f726173c3a09ea9aeb
5c0d13b795da9479e12e2ee9
5c062d84a96e33018ff6f0a6
5bfd0f32ec61ca1dd69dc77b
5bf21799d43923194842c001
5bf3a82cd439231948877aed
5bf03590d4392319481971dc
5beb6e66abd34c35e18e66b9
5be883a4f98cee15019d5b83
5be47bf9b18881428d8fbc1d
5bcf979a6d5f586b95c258cd
5bce7ac9ca24970bce4934b6
5bb8a49aea1cfa39f1aa7f75
5b78e57afc8fcf6781d0c3ba
5b21e18c58e2823a67a10dd8
5b22269758e2823a67a3bd03
5b192eb2170cf166458ff886
5ae2e9c5fe405c5076abc6b2
5adc6bd52430a05ecb2ffb85
5ab8b8e029f5351f7f2ccf59
5abc2506b53b042ead637d86
5ab85f1dac4291329b17cb50
5a969eea91dfc339a9a3ad2c
5a8aa0fab18050187cbe060e
5a7d3db14989e929563eb153
5a69c47d0d5d0a7f3b2e9752
5a618c72784780334bc1972d
5a6464143d809f1d8208c43c
5a588a8193ac3d233f77fbca
5a57542f333d180827dfc132
5a572fd9fc597b0478a81d14
5a563183425d0f5186314855
5a4a38dad38c8a075495b5d2
5a48d4b2c7dab83a7d7b9851
5a489fb1c7dab83a7d7b1070
5a48ba95c7dab83a7d7b44ed
5a3ca9cb270f0e3f14d0eddb
5a3cb4e4270f0e3f14d12f43
5a3f4aba5889373fbbc5d3b5
5a0271884e62597cdee0d0eb
59e864b2a9e91f2c5529325f
599aa591d5b41f366fed0d58
59350ca084b7f26bf5ce6eb8
59338e76772c3e6384afbb15
5c20ca3a0843bc542d94e3e2
5c1dbf200843bc542d8ef8c4
5c1b1500bee9a723c96c3e78
5bea87f4abd34c35e1860ab5
5c2b3ed5e611832e8aed46bf
57f8d9bbe73f6760f10e916a
5bf7d63575c26f32dbf7413b
5be4ab93870d330ff2dce134
5bd43b4ba6b28b1ee86b92dd
5bccd6beca24970bce448134
5bc5f0e896b66a2cd8f9bd36
5b908d3dc6ab78485f3d24a9
5b2c67b5e0878c381608b8d8
5b4933abf2b5f44e95de482a
5b3b353d8d46a939f93524b9
5acf8ca0f3d8a750097e4b15
5ab8713ba3799a1d138bd69a
5aa235f64a17b335eeaf9609
5aa0f9d7a9efce63548c69a1
5a8315f624b8e938486e0bd8
5a48c4e9c7dab83a7d7b5cc7
59ecfd02e225f6492d20fcc9
59f87d0bfa6280566fb38c9a
59f363a8b45be22330016cad
59f70ab1e5c5d366af29bf3e
59e75a2ca9e91f2c5526005d
5947719bf1b45630bd096665
5947b62af1b45630bd0c2a02
59056e6760bb961de55f3501
58f7f7299f5b5647873cb110
58cf4771d0f5fb221defe6da
58d36897f387231e6c929903
58c4bb4f4a69c55606122be4
7
IGEV-MVS/lists/blendedmvs/val.txt
Normal file
@ -0,0 +1,7 @@
5b7a3890fc8fcf6781e2593a
5c189f2326173c3a09ed7ef3
5b950c71608de421b1e7318f
5a6400933d809f1d8200af15
59d2657f82ca7774b1ec081d
5ba19a8a360c7c30c1c169df
59817e4a1bd4b175e7038d19
22
IGEV-MVS/lists/dtu/test.txt
Normal file
@ -0,0 +1,22 @@
scan1
scan4
scan9
scan10
scan11
scan12
scan13
scan15
scan23
scan24
scan29
scan32
scan33
scan34
scan48
scan49
scan62
scan75
scan77
scan110
scan114
scan118
79
IGEV-MVS/lists/dtu/train.txt
Normal file
@ -0,0 +1,79 @@
scan2
scan6
scan7
scan8
scan14
scan16
scan18
scan19
scan20
scan22
scan30
scan31
scan36
scan39
scan41
scan42
scan44
scan45
scan46
scan47
scan50
scan51
scan52
scan53
scan55
scan57
scan58
scan60
scan61
scan63
scan64
scan65
scan68
scan69
scan70
scan71
scan72
scan74
scan76
scan83
scan84
scan85
scan87
scan88
scan89
scan90
scan91
scan92
scan93
scan94
scan95
scan96
scan97
scan98
scan99
scan100
scan101
scan102
scan103
scan104
scan105
scan107
scan108
scan109
scan111
scan112
scan113
scan115
scan116
scan119
scan120
scan121
scan122
scan123
scan124
scan125
scan126
scan127
scan128
18
IGEV-MVS/lists/dtu/val.txt
Normal file
@ -0,0 +1,18 @@
scan3
scan5
scan17
scan21
scan28
scan35
scan37
scan38
scan40
scan43
scan56
scan59
scan66
scan67
scan82
scan86
scan106
scan117
293
IGEV-MVS/train_mvs.py
Normal file
@ -0,0 +1,293 @@
import argparse
import os
os.environ['CUDA_VISIBLE_DEVICES'] = '0,1,2'
import torch
import torch.nn as nn
import torch.nn.parallel
import torch.backends.cudnn as cudnn
import torch.optim as optim
from torch.utils.data import DataLoader
import torch.nn.functional as F
import numpy as np
import random
import time
from torch.utils.tensorboard import SummaryWriter
from datasets import find_dataset_def
from core.igev_mvs import IGEVMVS
from core.submodule import depth_normalization, depth_unnormalization
from utils import *
import sys
import datetime
from tqdm import tqdm

cudnn.benchmark = True


parser = argparse.ArgumentParser(description='IterMVStereo for high-resolution multi-view stereo')
parser.add_argument('--mode', default='train', help='train or val', choices=['train', 'val'])

parser.add_argument('--dataset', default='dtu_yao', help='select dataset')
parser.add_argument('--trainpath', default='/data/dtu_data/dtu_train/', help='train datapath')
parser.add_argument('--valpath', help='validation datapath')
parser.add_argument('--trainlist', default='./lists/dtu/train.txt', help='train list')
parser.add_argument('--vallist', default='./lists/dtu/val.txt', help='validation list')
parser.add_argument('--maxdisp', default=256)

parser.add_argument('--epochs', type=int, default=32, help='number of epochs to train')
parser.add_argument('--lr', type=float, default=0.0002, help='learning rate')
parser.add_argument('--wd', type=float, default=.00001, help='weight decay')

parser.add_argument('--batch_size', type=int, default=6, help='train batch size')
parser.add_argument('--loadckpt', default=None, help='load a specific checkpoint')
parser.add_argument('--logdir', default='./checkpoints/', help='the directory to save checkpoints/logs')
parser.add_argument('--resume', action='store_true', help='continue to train the model')
parser.add_argument('--regress', action='store_true', help='train the regression and confidence')
parser.add_argument('--small_image', action='store_true', help='train with small input as 640x512, otherwise train with 1280x1024')

parser.add_argument('--summary_freq', type=int, default=20, help='print and summary frequency')
parser.add_argument('--save_freq', type=int, default=1, help='save checkpoint frequency')
parser.add_argument('--seed', type=int, default=1, metavar='S', help='random seed')
parser.add_argument('--iteration', type=int, default=22, help='num of iteration of GRU')

try:
    from torch.cuda.amp import GradScaler
except ImportError:
    # dummy GradScaler for PyTorch < 1.6; accepts the `enabled` kwarg used below
    class GradScaler:
        def __init__(self, enabled=True):
            pass
        def scale(self, loss):
            return loss
        def unscale_(self, optimizer):
            pass
        def step(self, optimizer):
            optimizer.step()
        def update(self):
            pass


def sequence_loss(disp_preds, disp_init_pred, depth_gt, mask, depth_min, depth_max, loss_gamma=0.9):
    """ Loss function defined over sequence of depth predictions """
    cross_entropy = nn.BCEWithLogitsLoss()
    n_predictions = len(disp_preds)
    assert n_predictions >= 1
    loss = 0.0
    mask = mask > 0.5
    batch, _, height, width = depth_gt.size()
    inverse_depth_min = (1.0 / depth_min).view(batch, 1, 1, 1)
    inverse_depth_max = (1.0 / depth_max).view(batch, 1, 1, 1)

    normalized_disp_gt = depth_normalization(depth_gt, inverse_depth_min, inverse_depth_max)
    loss += 1.0 * F.l1_loss(disp_init_pred[mask], normalized_disp_gt[mask], reduction='mean')

    if args.iteration != 0:
        for i in range(n_predictions):
            adjusted_loss_gamma = loss_gamma**(15/(n_predictions - 1))
            i_weight = adjusted_loss_gamma**(n_predictions - i - 1)
            loss += i_weight * F.l1_loss(disp_preds[i][mask], normalized_disp_gt[mask], reduction='mean')

    return loss
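# A note on the weighting above: adjusted_loss_gamma satisfies
# adjusted_loss_gamma**(n_predictions - 1) == loss_gamma**15, so the
# first-to-last weight ratio stays pinned at 0.9**15 ~= 0.206 no matter how
# many GRU iterations are supervised. A minimal sketch of the resulting
# schedule for the default --iteration 22 (illustrative only, not used by
# the script):
#
#     n = 22
#     g = 0.9 ** (15 / (n - 1))             # ~0.9275
#     weights = [g ** (n - 1 - i) for i in range(n)]
#     # weights[0] ~= 0.206 ... weights[-1] == 1.0 (later iterations dominate)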

# parse arguments and check
args = parser.parse_args()
if args.resume: # store_true means set the variable as "True"
    assert args.mode == "train"
    assert args.loadckpt is None
if args.valpath is None:
    args.valpath = args.trainpath

torch.manual_seed(args.seed)
torch.cuda.manual_seed(args.seed)
np.random.seed(args.seed)
random.seed(args.seed)

if args.mode == "train":
    if not os.path.isdir(args.logdir):
        os.mkdir(args.logdir)

    current_time_str = str(datetime.datetime.now().strftime('%Y%m%d_%H%M%S'))
    print("current time", current_time_str)

    print("creating new summary file")
    logger = SummaryWriter(args.logdir)

print("argv:", sys.argv[1:])
print_args(args)

# dataset, dataloader
MVSDataset = find_dataset_def(args.dataset)

train_dataset = MVSDataset(args.trainpath, args.trainlist, "train", 5, robust_train=True)
test_dataset = MVSDataset(args.valpath, args.vallist, "val", 5, robust_train=False)

TrainImgLoader = DataLoader(train_dataset, args.batch_size, shuffle=True, num_workers=4, drop_last=True)
TestImgLoader = DataLoader(test_dataset, args.batch_size, shuffle=False, num_workers=4, drop_last=False)

# model, optimizer
model = IGEVMVS(args)
if args.mode in ["train", "val"]:
    model = nn.DataParallel(model)
    model.cuda()
model_loss = sequence_loss
optimizer = optim.AdamW(filter(lambda p: p.requires_grad, model.parameters()), lr=args.lr, weight_decay=args.wd, eps=1e-8)

# load parameters
start_epoch = 0
if (args.mode == "train" and args.resume) or (args.mode == "val" and not args.loadckpt):
    saved_models = [fn for fn in os.listdir(args.logdir) if fn.endswith(".ckpt")]
    saved_models = sorted(saved_models, key=lambda x: int(x.split('_')[-1].split('.')[0]))
    # use the latest checkpoint file
    loadckpt = os.path.join(args.logdir, saved_models[-1])
    print("resuming", loadckpt)
    state_dict = torch.load(loadckpt)
    model.load_state_dict(state_dict['model'], strict=False)
    optimizer.load_state_dict(state_dict['optimizer'])
    start_epoch = state_dict['epoch'] + 1
elif args.loadckpt:
    # load checkpoint file specified by args.loadckpt
    print("loading model {}".format(args.loadckpt))
    state_dict = torch.load(args.loadckpt)
    model.load_state_dict(state_dict['model'], strict=False)
print("start at epoch {}".format(start_epoch))
print('Number of model parameters: {}'.format(sum([p.data.nelement() for p in model.parameters()])))


# main function
def train(args):
    total_steps = len(TrainImgLoader) * args.epochs + 100
    lr_scheduler = optim.lr_scheduler.OneCycleLR(optimizer, args.lr, total_steps, pct_start=0.01, cycle_momentum=False, anneal_strategy='linear')

    for epoch_idx in range(start_epoch, args.epochs):
        print('Epoch {}:'.format(epoch_idx))
        global_step = len(TrainImgLoader) * epoch_idx

        # training
        tbar = tqdm(TrainImgLoader)
        for batch_idx, sample in enumerate(tbar):
            start_time = time.time()
            global_step = len(TrainImgLoader) * epoch_idx + batch_idx
            do_summary = global_step % args.summary_freq == 0
            scaler = GradScaler(enabled=True)
            loss, scalar_outputs = train_sample(args, sample, detailed_summary=do_summary, scaler=scaler)
            if do_summary:
                save_scalars(logger, 'train', scalar_outputs, global_step)
            del scalar_outputs

            tbar.set_description(
                'Epoch {}/{}, Iter {}/{}, train loss = {:.3f}, time = {:.3f}'.format(epoch_idx, args.epochs, batch_idx, len(TrainImgLoader), loss, time.time() - start_time))

        lr_scheduler.step()

        # checkpoint: also save the epoch index and optimizer state, since the
        # resume path above reads state_dict['optimizer'] and state_dict['epoch']
        if (epoch_idx + 1) % args.save_freq == 0:
            torch.save({
                'epoch': epoch_idx,
                'model': model.state_dict(),
                'optimizer': optimizer.state_dict()},
                "{}/model_{:0>6}.ckpt".format(args.logdir, epoch_idx))
        torch.cuda.empty_cache()
        # testing
        avg_test_scalars = DictAverageMeter()
        tbar = tqdm(TestImgLoader)
        for batch_idx, sample in enumerate(tbar):
            start_time = time.time()
            global_step = len(TestImgLoader) * epoch_idx + batch_idx
            do_summary = global_step % args.summary_freq == 0
            loss, scalar_outputs = test_sample(args, sample, detailed_summary=do_summary)
            if do_summary:
                save_scalars(logger, 'test', scalar_outputs, global_step)
            avg_test_scalars.update(scalar_outputs)
            del scalar_outputs
            tbar.set_description('Epoch {}/{}, Iter {}/{}, test loss = {:.3f}, time = {:.3f}'.format(epoch_idx, args.epochs, batch_idx,
                                                                                                     len(TestImgLoader), loss,
                                                                                                     time.time() - start_time))
        save_scalars(logger, 'fulltest', avg_test_scalars.mean(), global_step)
        print("avg_test_scalars:", avg_test_scalars.mean())
        torch.cuda.empty_cache()


def test(args):
    avg_test_scalars = DictAverageMeter()
    for batch_idx, sample in enumerate(TestImgLoader):
        start_time = time.time()
        loss, scalar_outputs = test_sample(args, sample, detailed_summary=True)
        avg_test_scalars.update(scalar_outputs)
        del scalar_outputs
        print('Iter {}/{}, test loss = {:.3f}, time = {:.3f}'.format(batch_idx, len(TestImgLoader), loss,
                                                                     time.time() - start_time))
        if batch_idx % 100 == 0:
            print("Iter {}/{}, test results = {}".format(batch_idx, len(TestImgLoader), avg_test_scalars.mean()))
    print("final", avg_test_scalars)


def train_sample(args, sample, detailed_summary=False, scaler=None):
    model.train()
    optimizer.zero_grad()
    sample_cuda = tocuda(sample)
    depth_gt = sample_cuda["depth"]
    mask = sample_cuda["mask"]
    depth_gt_0 = depth_gt['level_0']
    mask_0 = mask['level_0']
    depth_gt_1 = depth_gt['level_2']
    mask_1 = mask['level_2']

    disp_init, disp_predictions = model(sample_cuda["imgs"], sample_cuda["proj_matrices"],
                                        sample_cuda["depth_min"], sample_cuda["depth_max"])

    loss = model_loss(disp_predictions, disp_init, depth_gt_0, mask_0, sample_cuda["depth_min"], sample_cuda["depth_max"])
    scaler.scale(loss).backward()
    scaler.unscale_(optimizer)
    torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
    scaler.step(optimizer)
    scaler.update()

    inverse_depth_min = (1.0 / sample_cuda["depth_min"]).view(args.batch_size, 1, 1, 1)
    inverse_depth_max = (1.0 / sample_cuda["depth_max"]).view(args.batch_size, 1, 1, 1)
    depth_init = depth_unnormalization(disp_init, inverse_depth_min, inverse_depth_max)

    depth_predictions = []
    for disp in disp_predictions:
        depth_predictions.append(depth_unnormalization(disp, inverse_depth_min, inverse_depth_max))

    scalar_outputs = {"loss": loss}
    scalar_outputs["abs_error_initial"] = AbsDepthError_metrics(depth_init, depth_gt_0, mask_0 > 0.5)
    scalar_outputs["thres1mm_initial"] = Thres_metrics(depth_init, depth_gt_0, mask_0 > 0.5, 1)
    scalar_outputs["abs_error_final_full"] = AbsDepthError_metrics(depth_predictions[-1], depth_gt_0, mask_0 > 0.5)
    return tensor2float(loss), tensor2float(scalar_outputs)


@make_nograd_func
def test_sample(args, sample, detailed_summary=True):
    model.eval()
    sample_cuda = tocuda(sample)
    depth_gt = sample_cuda["depth"]
    mask = sample_cuda["mask"]
    depth_gt_0 = depth_gt['level_0']
    mask_0 = mask['level_0']
    depth_gt_1 = depth_gt['level_2']
    mask_1 = mask['level_2']

    disp_init, disp_predictions = model(sample_cuda["imgs"], sample_cuda["proj_matrices"],
                                        sample_cuda["depth_min"], sample_cuda["depth_max"])

    loss = model_loss(disp_predictions, disp_init, depth_gt_0, mask_0, sample_cuda["depth_min"], sample_cuda["depth_max"])

    inverse_depth_min = (1.0 / sample_cuda["depth_min"]).view(sample_cuda["depth_min"].size()[0], 1, 1, 1)
    inverse_depth_max = (1.0 / sample_cuda["depth_max"]).view(sample_cuda["depth_max"].size()[0], 1, 1, 1)
    depth_init = depth_unnormalization(disp_init, inverse_depth_min, inverse_depth_max)

    depth_predictions = []
    for disp in disp_predictions:
        depth_predictions.append(depth_unnormalization(disp, inverse_depth_min, inverse_depth_max))

    scalar_outputs = {"loss": loss}
    scalar_outputs["abs_error_initial"] = AbsDepthError_metrics(depth_init, depth_gt_0, mask_0 > 0.5)
    scalar_outputs["thres1mm_initial"] = Thres_metrics(depth_init, depth_gt_0, mask_0 > 0.5, 1)
    scalar_outputs["abs_error_final_full"] = AbsDepthError_metrics(depth_predictions[-1], depth_gt_0, mask_0 > 0.5)
    return tensor2float(loss), tensor2float(scalar_outputs)


if __name__ == '__main__':
    if args.mode == "train":
        train(args)
    elif args.mode == "val":
        test(args)
155
IGEV-MVS/utils.py
Normal file
@ -0,0 +1,155 @@
import numpy as np
import torchvision.utils as vutils
import torch
import torch.nn.functional as F


# print arguments
def print_args(args):
    print("################################ args ################################")
    for k, v in args.__dict__.items():
        print("{0: <10}\t{1: <30}\t{2: <20}".format(k, str(v), str(type(v))))
    print("########################################################################")


# torch.no_grad wrapper for functions
def make_nograd_func(func):
    def wrapper(*f_args, **f_kwargs):
        with torch.no_grad():
            ret = func(*f_args, **f_kwargs)
        return ret

    return wrapper


# convert a function into recursive style to handle nested dict/list/tuple variables
def make_recursive_func(func):
    def wrapper(vars):
        if isinstance(vars, list):
            return [wrapper(x) for x in vars]
        elif isinstance(vars, tuple):
            return tuple([wrapper(x) for x in vars])
        elif isinstance(vars, dict):
            return {k: wrapper(v) for k, v in vars.items()}
        else:
            return func(vars)

    return wrapper


@make_recursive_func
def tensor2float(vars):
    if isinstance(vars, float):
        return vars
    elif isinstance(vars, torch.Tensor):
        return vars.data.item()
    else:
        raise NotImplementedError("invalid input type {} for tensor2float".format(type(vars)))


@make_recursive_func
def tensor2numpy(vars):
    if isinstance(vars, np.ndarray):
        return vars
    elif isinstance(vars, torch.Tensor):
        return vars.detach().cpu().numpy().copy()
    else:
        raise NotImplementedError("invalid input type {} for tensor2numpy".format(type(vars)))


@make_recursive_func
def tocuda(vars):
    if isinstance(vars, torch.Tensor):
        return vars.cuda()
    elif isinstance(vars, str):
        return vars
    else:
        raise NotImplementedError("invalid input type {} for tocuda".format(type(vars)))


def save_scalars(logger, mode, scalar_dict, global_step):
    scalar_dict = tensor2float(scalar_dict)
    for key, value in scalar_dict.items():
        if not isinstance(value, (list, tuple)):
            name = '{}/{}'.format(mode, key)
            logger.add_scalar(name, value, global_step)
        else:
            for idx in range(len(value)):
                name = '{}/{}_{}'.format(mode, key, idx)
                logger.add_scalar(name, value[idx], global_step)


def save_images(logger, mode, images_dict, global_step):
    images_dict = tensor2numpy(images_dict)

    def preprocess(name, img):
        if not (len(img.shape) == 3 or len(img.shape) == 4):
            raise NotImplementedError("invalid img shape {}:{} in save_images".format(name, img.shape))
        if len(img.shape) == 3:
            img = img[:, np.newaxis, :, :]
        img = torch.from_numpy(img[:1])
        return vutils.make_grid(img, padding=0, nrow=1, normalize=True, scale_each=True)

    for key, value in images_dict.items():
        if not isinstance(value, (list, tuple)):
            name = '{}/{}'.format(mode, key)
            logger.add_image(name, preprocess(name, value), global_step)
        else:
            for idx in range(len(value)):
                name = '{}/{}_{}'.format(mode, key, idx)
                logger.add_image(name, preprocess(name, value[idx]), global_step)


class DictAverageMeter(object):
    def __init__(self):
        self.data = {}
        self.count = 0

    def update(self, new_input):
        self.count += 1
        if len(self.data) == 0:
            for k, v in new_input.items():
                if not isinstance(v, float):
                    raise NotImplementedError("invalid data {}: {}".format(k, type(v)))
                self.data[k] = v
        else:
            for k, v in new_input.items():
                if not isinstance(v, float):
                    raise NotImplementedError("invalid data {}: {}".format(k, type(v)))
                self.data[k] += v

    def mean(self):
        return {k: v / self.count for k, v in self.data.items()}


# a wrapper to compute metrics for each image individually
def compute_metrics_for_each_image(metric_func):
    def wrapper(depth_est, depth_gt, mask, *args):
        batch_size = depth_gt.shape[0]
        results = []
        # compute result one by one
        for idx in range(batch_size):
            ret = metric_func(depth_est[idx], depth_gt[idx], mask[idx], *args)
            results.append(ret)
        return torch.stack(results).mean()

    return wrapper


@make_nograd_func
@compute_metrics_for_each_image
def Thres_metrics(depth_est, depth_gt, mask, thres):
    # if thres is int or float, then True
    assert isinstance(thres, (int, float))
    depth_est, depth_gt = depth_est[mask], depth_gt[mask]
    errors = torch.abs(depth_est - depth_gt)
    err_mask = errors > thres
    return torch.mean(err_mask.float())
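# Note: Thres_metrics averages (errors > thres), so it returns the fraction of
# masked pixels whose absolute depth error EXCEEDS the threshold (an outlier
# ratio, lower is better); the "thres1mm_initial" scalar logged in
# train_mvs.py should be read accordingly.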

# NOTE: please do not use this to build up training loss
@make_nograd_func
@compute_metrics_for_each_image
def AbsDepthError_metrics(depth_est, depth_gt, mask):
    depth_est, depth_gt = depth_est[mask], depth_gt[mask]
    return torch.mean((depth_est - depth_gt).abs())