diff --git a/IGEV-MVS/core/__init__.py b/IGEV-MVS/core/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/IGEV-MVS/core/corr.py b/IGEV-MVS/core/corr.py new file mode 100644 index 0000000..f1283f6 --- /dev/null +++ b/IGEV-MVS/core/corr.py @@ -0,0 +1,61 @@ +import torch +import torch.nn as nn +import torch.nn.functional as F +from .submodule import * + +class CorrBlock1D_Cost_Volume: + def __init__(self, init_corr, corr, num_levels=2, radius=4, inverse_depth_min=None, inverse_depth_max=None, num_sample=None): + self.num_levels = 2 + self.radius = radius + self.inverse_depth_min = inverse_depth_min + self.inverse_depth_max = inverse_depth_max + self.num_sample = num_sample + self.corr_pyramid = [] + self.init_corr_pyramid = [] + + # all pairs correlation + + # batch, h1, w1, dim, w2 = corr.shape + b, c, d, h, w = corr.shape + corr = corr.permute(0, 3, 4, 1, 2).reshape(b*h*w, 1, 1, d) + init_corr = init_corr.permute(0, 3, 4, 1, 2).reshape(b*h*w, 1, 1, d) + + self.corr_pyramid.append(corr) + self.init_corr_pyramid.append(init_corr) + + + for i in range(self.num_levels): + corr = F.avg_pool2d(corr, [1,2], stride=[1,2]) + self.corr_pyramid.append(corr) + + for i in range(self.num_levels): + init_corr = F.avg_pool2d(init_corr, [1,2], stride=[1,2]) + self.init_corr_pyramid.append(init_corr) + + + + def __call__(self, disp): + r = self.radius + b, _, h, w = disp.shape + out_pyramid = [] + for i in range(self.num_levels): + corr = self.corr_pyramid[i] + init_corr = self.init_corr_pyramid[i] + dx = torch.linspace(-r, r, 2*r+1) + dx = dx.view(1, 1, 2*r+1, 1).to(disp.device) + x0 = dx + disp.reshape(b*h*w, 1, 1, 1) / 2**i + y0 = torch.zeros_like(x0) + + disp_lvl = torch.cat([x0,y0], dim=-1) + corr = bilinear_sampler(corr, disp_lvl) + corr = corr.view(b, h, w, -1) + + init_corr = bilinear_sampler(init_corr, disp_lvl) + init_corr = init_corr.view(b, h, w, -1) + + out_pyramid.append(corr) + out_pyramid.append(init_corr) + + + out = torch.cat(out_pyramid, dim=-1) + return out.permute(0, 3, 1, 2).contiguous().float() \ No newline at end of file diff --git a/IGEV-MVS/core/extractor.py b/IGEV-MVS/core/extractor.py new file mode 100644 index 0000000..8bac6f9 --- /dev/null +++ b/IGEV-MVS/core/extractor.py @@ -0,0 +1,212 @@ +import torch +import torch.nn as nn +import torch.nn.functional as F +import timm +import math +from .submodule import * + +class ResidualBlock(nn.Module): + def __init__(self, in_planes, planes, norm_fn='group', stride=1): + super(ResidualBlock, self).__init__() + + self.conv1 = nn.Conv2d(in_planes, planes, kernel_size=3, padding=1, stride=stride) + self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, padding=1) + self.relu = nn.ReLU(inplace=True) + + num_groups = planes // 8 + + if norm_fn == 'group': + self.norm1 = nn.GroupNorm(num_groups=num_groups, num_channels=planes) + self.norm2 = nn.GroupNorm(num_groups=num_groups, num_channels=planes) + if not (stride == 1 and in_planes == planes): + self.norm3 = nn.GroupNorm(num_groups=num_groups, num_channels=planes) + + elif norm_fn == 'batch': + self.norm1 = nn.BatchNorm2d(planes) + self.norm2 = nn.BatchNorm2d(planes) + if not (stride == 1 and in_planes == planes): + self.norm3 = nn.BatchNorm2d(planes) + + elif norm_fn == 'instance': + self.norm1 = nn.InstanceNorm2d(planes) + self.norm2 = nn.InstanceNorm2d(planes) + if not (stride == 1 and in_planes == planes): + self.norm3 = nn.InstanceNorm2d(planes) + + elif norm_fn == 'none': + self.norm1 = nn.Sequential() + self.norm2 = nn.Sequential() + if not (stride == 1 and in_planes 
== planes): + self.norm3 = nn.Sequential() + + if stride == 1 and in_planes == planes: + self.downsample = None + + else: + self.downsample = nn.Sequential( + nn.Conv2d(in_planes, planes, kernel_size=1, stride=stride), self.norm3) + + + def forward(self, x): + y = x + y = self.conv1(y) + y = self.norm1(y) + y = self.relu(y) + y = self.conv2(y) + y = self.norm2(y) + y = self.relu(y) + + if self.downsample is not None: + x = self.downsample(x) + + return self.relu(x+y) + +class MultiBasicEncoder(nn.Module): + def __init__(self, output_dim=[128], norm_fn='batch', dropout=0.0, downsample=3): + super(MultiBasicEncoder, self).__init__() + self.norm_fn = norm_fn + self.downsample = downsample + + if self.norm_fn == 'group': + self.norm1 = nn.GroupNorm(num_groups=8, num_channels=64) + + elif self.norm_fn == 'batch': + self.norm1 = nn.BatchNorm2d(64) + + elif self.norm_fn == 'instance': + self.norm1 = nn.InstanceNorm2d(64) + + elif self.norm_fn == 'none': + self.norm1 = nn.Sequential() + + self.conv1 = nn.Conv2d(3, 64, kernel_size=7, stride=1 + (downsample > 2), padding=3) + self.relu1 = nn.ReLU(inplace=True) + + self.in_planes = 64 + self.layer1 = self._make_layer(64, stride=1) + self.layer2 = self._make_layer(96, stride=1 + (downsample > 1)) + self.layer3 = self._make_layer(128, stride=1 + (downsample > 0)) + self.layer4 = self._make_layer(128, stride=2) + self.layer5 = self._make_layer(128, stride=2) + + output_list = [] + + for dim in output_dim: + conv_out = nn.Sequential( + ResidualBlock(128, 128, self.norm_fn, stride=1), + nn.Conv2d(128, dim[2], 3, padding=1)) + output_list.append(conv_out) + + self.outputs04 = nn.ModuleList(output_list) + + output_list = [] + for dim in output_dim: + conv_out = nn.Sequential( + ResidualBlock(128, 128, self.norm_fn, stride=1), + nn.Conv2d(128, dim[1], 3, padding=1)) + output_list.append(conv_out) + + self.outputs08 = nn.ModuleList(output_list) + + output_list = [] + for dim in output_dim: + conv_out = nn.Conv2d(128, dim[0], 3, padding=1) + output_list.append(conv_out) + + self.outputs16 = nn.ModuleList(output_list) + + if dropout > 0: + self.dropout = nn.Dropout2d(p=dropout) + else: + self.dropout = None + + for m in self.modules(): + if isinstance(m, nn.Conv2d): + nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu') + elif isinstance(m, (nn.BatchNorm2d, nn.InstanceNorm2d, nn.GroupNorm)): + if m.weight is not None: + nn.init.constant_(m.weight, 1) + if m.bias is not None: + nn.init.constant_(m.bias, 0) + + def _make_layer(self, dim, stride=1): + layer1 = ResidualBlock(self.in_planes, dim, self.norm_fn, stride=stride) + layer2 = ResidualBlock(dim, dim, self.norm_fn, stride=1) + layers = (layer1, layer2) + + self.in_planes = dim + return nn.Sequential(*layers) + + def forward(self, x, dual_inp=False, num_layers=3): + + x = self.conv1(x) + x = self.norm1(x) + x = self.relu1(x) + + x = self.layer1(x) + x = self.layer2(x) + x = self.layer3(x) + + if dual_inp: + v = x + x = x[:(x.shape[0]//2)] + + outputs04 = [f(x) for f in self.outputs04] + if num_layers == 1: + return (outputs04, v) if dual_inp else (outputs04,) + + y = self.layer4(x) + outputs08 = [f(y) for f in self.outputs08] + + if num_layers == 2: + return (outputs04, outputs08, v) if dual_inp else (outputs04, outputs08) + + z = self.layer5(y) + outputs16 = [f(z) for f in self.outputs16] + + return (outputs04, outputs08, outputs16, v) if dual_inp else (outputs04, outputs08, outputs16) + +class Feature(SubModule): + def __init__(self): + super(Feature, self).__init__() + pretrained = True + 
model = timm.create_model('mobilenetv2_100', pretrained=pretrained, features_only=True) + + layers = [1,2,3,5,6] + chans = [16, 24, 32, 96, 160] + self.conv_stem = model.conv_stem + self.bn1 = model.bn1 + + self.block0 = torch.nn.Sequential(*model.blocks[0:layers[0]]) + self.block1 = torch.nn.Sequential(*model.blocks[layers[0]:layers[1]]) + self.block2 = torch.nn.Sequential(*model.blocks[layers[1]:layers[2]]) + self.block3 = torch.nn.Sequential(*model.blocks[layers[2]:layers[3]]) + self.block4 = torch.nn.Sequential(*model.blocks[layers[3]:layers[4]]) + + self.deconv32_16 = Conv2x_IN(chans[4], chans[3], deconv=True, concat=True) + self.deconv16_8 = Conv2x_IN(chans[3]*2, chans[2], deconv=True, concat=True) + self.deconv8_4 = Conv2x_IN(chans[2]*2, chans[1], deconv=True, concat=True) + self.conv4 = BasicConv_IN(chans[1]*2, chans[1]*2, kernel_size=3, stride=1, padding=1) + + def forward(self, x): + B, V, _, H, W = x.size() + x = x.view(B * V, -1, H, W) + #x = self.act1(self.bn1(self.conv_stem(x))) + x = self.bn1(self.conv_stem(x)) + x2 = self.block0(x) + x4 = self.block1(x2) + # return x4,x4,x4,x4 + x8 = self.block2(x4) + x16 = self.block3(x8) + x32 = self.block4(x16) + + x16 = self.deconv32_16(x32, x16) + x8 = self.deconv16_8(x16, x8) + x4 = self.deconv8_4(x8, x4) + x4 = self.conv4(x4) + + x4 = x4.view(B, V, -1, H // 4, W // 4) + x8 = x8.view(B, V, -1, H // 8, W // 8) + x16 = x16.view(B, V, -1, H // 16, W // 16) + x32 = x32.view(B, V, -1, H // 32, W // 32) + return [x4, x8, x16, x32] \ No newline at end of file diff --git a/IGEV-MVS/core/igev_mvs.py b/IGEV-MVS/core/igev_mvs.py new file mode 100644 index 0000000..8e4290d --- /dev/null +++ b/IGEV-MVS/core/igev_mvs.py @@ -0,0 +1,195 @@ +import torch +import torch.nn as nn +import torch.nn.functional as F +from .submodule import * +from .corr import * +from .extractor import * +from .update import * + +try: + autocast = torch.cuda.amp.autocast +except: + class autocast: + def __init__(self, enabled): + pass + def __enter__(self): + pass + def __exit__(self, *args): + pass + +class IGEVMVS(nn.Module): + def __init__(self, args): + super().__init__() + + context_dims = [128, 128, 128] + self.n_gru_layers = 3 + self.slow_fast_gru = False + self.mixed_precision = True + self.num_sample = 64 + self.G = 1 + self.corr_radius = 4 + self.corr_levels = 2 + self.iters = args.iteration + self.update_block = BasicMultiUpdateBlock(hidden_dims=context_dims) + self.conv_hidden_1 = nn.Conv2d(64, 128, kernel_size=3, padding=1, stride=1) + self.conv_hidden_2 = nn.Conv2d(128, 128, kernel_size=3, padding=1, stride=2) + self.conv_hidden_4 = nn.Conv2d(128, 128, kernel_size=3, padding=1, stride=2) + self.feature = Feature() + + self.stem_2 = nn.Sequential( + BasicConv_IN(3, 32, kernel_size=3, stride=2, padding=1), + nn.Conv2d(32, 32, 3, 1, 1, bias=False), + nn.InstanceNorm2d(32), nn.ReLU() + ) + self.stem_4 = nn.Sequential( + BasicConv_IN(32, 48, kernel_size=3, stride=2, padding=1), + nn.Conv2d(48, 48, 3, 1, 1, bias=False), + nn.InstanceNorm2d(48), nn.ReLU() + ) + + self.conv = BasicConv_IN(96, 48, kernel_size=3, padding=1, stride=1) + self.desc = nn.Conv2d(48, 48, kernel_size=1, padding=0, stride=1) + + self.spx = nn.Sequential(nn.ConvTranspose2d(2*32, 9, kernel_size=4, stride=2, padding=1),) + self.spx_2 = Conv2x_IN(32, 32, True) + self.spx_4 = nn.Sequential( + BasicConv_IN(96, 32, kernel_size=3, stride=1, padding=1), + nn.Conv2d(32, 32, 3, 1, 1, bias=False), + nn.InstanceNorm2d(32), nn.ReLU() + ) + + self.depth_initialization = DepthInitialization(self.num_sample) + 
self.pixel_view_weight = PixelViewWeight(self.G)
+
+        self.corr_stem = BasicConv(1, 8, is_3d=True, kernel_size=3, stride=1, padding=1)
+        self.corr_feature_att = FeatureAtt(8, 96)
+        self.cost_agg = hourglass(8)
+
+        self.spx_2_gru = Conv2x(32, 32, True)
+        self.spx_gru = nn.Sequential(nn.ConvTranspose2d(2*32, 9, kernel_size=4, stride=2, padding=1),)
+
+    def upsample_disp(self, depth, mask_feat_4, stem_2x):
+        with autocast(enabled=self.mixed_precision):
+            xspx = self.spx_2_gru(mask_feat_4, stem_2x)
+            spx_pred = self.spx_gru(xspx)
+            spx_pred = F.softmax(spx_pred, 1)
+
+        up_depth = context_upsample(depth, spx_pred).unsqueeze(1)
+
+        return up_depth
+
+    def forward(self, imgs, proj_matrices, depth_min, depth_max, test_mode=False):
+        proj_matrices_2 = torch.unbind(proj_matrices['level_2'].float(), 1)
+        depth_min = depth_min.float()
+        depth_max = depth_max.float()
+
+        ref_proj = proj_matrices_2[0]
+        src_projs = proj_matrices_2[1:]
+
+        with autocast(enabled=self.mixed_precision):
+            images = torch.unbind(imgs['level_0'], dim=1)
+            features = self.feature(imgs['level_0'])
+            ref_feature = []
+            for fea in features:
+                ref_feature.append(torch.unbind(fea, dim=1)[0])
+            src_features = [src_fea for src_fea in torch.unbind(features[0], dim=1)[1:]]
+
+            stem_2x = self.stem_2(images[0])
+            stem_4x = self.stem_4(stem_2x)
+            ref_feature[0] = torch.cat((ref_feature[0], stem_4x), 1)
+
+            for idx, src_fea in enumerate(src_features):
+                stem_2y = self.stem_2(images[idx + 1])
+                stem_4y = self.stem_4(stem_2y)
+                src_features[idx] = torch.cat((src_fea, stem_4y), 1)
+
+            match_left = self.desc(self.conv(ref_feature[0]))
+            match_left = match_left / torch.norm(match_left, 2, 1, True)
+
+            match_rights = [self.desc(self.conv(src_fea)) for src_fea in src_features]
+            match_rights = [match_right / torch.norm(match_right, 2, 1, True) for match_right in match_rights]
+
+            xspx = self.spx_4(ref_feature[0])
+            xspx = self.spx_2(xspx, stem_2x)
+            spx_pred = self.spx(xspx)
+            spx_pred = F.softmax(spx_pred, 1)
+
+        batch, dim, height, width = match_left.size()
+        inverse_depth_min = (1.0 / depth_min).view(batch, 1, 1, 1)
+        inverse_depth_max = (1.0 / depth_max).view(batch, 1, 1, 1)
+
+        device = match_left.device
+        correlation_sum = 0
+        view_weight_sum = 1e-5
+
+        match_left = match_left.float()
+        depth_samples = self.depth_initialization(inverse_depth_min, inverse_depth_max, height, width, device)
+        for src_feature, src_proj in zip(match_rights, src_projs):
+            src_feature = src_feature.float()
+            warped_feature = differentiable_warping(src_feature, src_proj, ref_proj, depth_samples)
+            warped_feature = warped_feature.view(batch, self.G, dim // self.G, self.num_sample, height, width)
+            correlation = torch.mean(warped_feature * match_left.view(batch, self.G, dim // self.G, 1, height, width), dim=2, keepdim=False)
+
+            view_weight = self.pixel_view_weight(correlation)
+            del warped_feature, src_feature, src_proj
+
+            correlation_sum += correlation * view_weight.unsqueeze(1)
+            view_weight_sum += view_weight.unsqueeze(1)
+            del correlation, view_weight
+        del match_left, match_rights, src_projs
+
+        with autocast(enabled=self.mixed_precision):
+            init_corr_volume = correlation_sum.div_(view_weight_sum)
+            corr_volume = self.corr_stem(init_corr_volume)
+            corr_volume = self.corr_feature_att(corr_volume, ref_feature[0])
+            regularized_cost_volume = self.cost_agg(corr_volume, ref_feature)
+
+            GEV_hidden = self.conv_hidden_1(regularized_cost_volume.squeeze(1))
+
+            GEV_hidden_2 = self.conv_hidden_2(GEV_hidden)
+
+            GEV_hidden_4 =
self.conv_hidden_4(GEV_hidden_2) + + net_list = [GEV_hidden, GEV_hidden_2, GEV_hidden_4] + + net_list = [torch.tanh(x) for x in net_list] + + corr_block = CorrBlock1D_Cost_Volume + + init_corr_volume = init_corr_volume.float() + regularized_cost_volume = regularized_cost_volume.float() + probability = F.softmax(regularized_cost_volume.squeeze(1), dim=1) + index = torch.arange(0, self.num_sample, 1, device=probability.device).view(1, self.num_sample, 1, 1).float() + disp_init = torch.sum(index * probability, dim = 1, keepdim=True) + + corr_fn = corr_block(init_corr_volume, regularized_cost_volume, radius=self.corr_radius, num_levels=self.corr_levels, inverse_depth_min=inverse_depth_min, inverse_depth_max=inverse_depth_max, num_sample=self.num_sample) + + disp_predictions = [] + disp = disp_init + + for itr in range(self.iters): + disp = disp.detach() + corr = corr_fn(disp) + + with autocast(enabled=self.mixed_precision): + if self.n_gru_layers == 3 and self.slow_fast_gru: # Update low-res GRU + net_list = self.update_block(net_list, iter16=True, iter08=False, iter04=False, update=False) + if self.n_gru_layers >= 2 and self.slow_fast_gru:# Update low-res GRU and mid-res GRU + net_list = self.update_block(net_list, iter16=self.n_gru_layers==3, iter08=True, iter04=False, update=False) + net_list, mask_feat_4, delta_disp = self.update_block(net_list, corr, disp, iter16=self.n_gru_layers==3, iter08=self.n_gru_layers>=2) + + disp = disp + delta_disp + + if test_mode and itr < self.iters-1: + continue + + disp_up = self.upsample_disp(disp, mask_feat_4, stem_2x) / (self.num_sample-1) + disp_predictions.append(disp_up) + + disp_init = context_upsample(disp_init, spx_pred.float()).unsqueeze(1) / (self.num_sample-1) + + if test_mode: + return disp_up + + + return disp_init, disp_predictions \ No newline at end of file diff --git a/IGEV-MVS/core/submodule.py b/IGEV-MVS/core/submodule.py new file mode 100644 index 0000000..3caa80d --- /dev/null +++ b/IGEV-MVS/core/submodule.py @@ -0,0 +1,396 @@ +import torch +import torch.nn as nn +import torch.nn.functional as F +import math + +class SubModule(nn.Module): + def __init__(self): + super(SubModule, self).__init__() + + def weight_init(self): + for m in self.modules(): + if isinstance(m, nn.Conv2d): + n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels + m.weight.data.normal_(0, math.sqrt(2. / n)) + elif isinstance(m, nn.Conv3d): + n = m.kernel_size[0] * m.kernel_size[1] * m.kernel_size[2] * m.out_channels + m.weight.data.normal_(0, math.sqrt(2. 
/ n)) + elif isinstance(m, nn.BatchNorm2d): + m.weight.data.fill_(1) + m.bias.data.zero_() + elif isinstance(m, nn.BatchNorm3d): + m.weight.data.fill_(1) + m.bias.data.zero_() + +class BasicConv(nn.Module): + + def __init__(self, in_channels, out_channels, deconv=False, is_3d=False, bn=True, relu=True, **kwargs): + super(BasicConv, self).__init__() + + self.relu = relu + self.use_bn = bn + if is_3d: + if deconv: + self.conv = nn.ConvTranspose3d(in_channels, out_channels, bias=False, **kwargs) + else: + self.conv = nn.Conv3d(in_channels, out_channels, bias=False, **kwargs) + self.bn = nn.BatchNorm3d(out_channels) + else: + if deconv: + self.conv = nn.ConvTranspose2d(in_channels, out_channels, bias=False, **kwargs) + else: + self.conv = nn.Conv2d(in_channels, out_channels, bias=False, **kwargs) + self.bn = nn.BatchNorm2d(out_channels) + + def forward(self, x): + x = self.conv(x) + if self.use_bn: + x = self.bn(x) + if self.relu: + x = nn.LeakyReLU()(x)#, inplace=True) + return x + +class BasicConv_IN(nn.Module): + + def __init__(self, in_channels, out_channels, deconv=False, is_3d=False, IN=True, relu=True, **kwargs): + super(BasicConv_IN, self).__init__() + + self.relu = relu + self.use_in = IN + if is_3d: + if deconv: + self.conv = nn.ConvTranspose3d(in_channels, out_channels, bias=False, **kwargs) + else: + self.conv = nn.Conv3d(in_channels, out_channels, bias=False, **kwargs) + self.IN = nn.InstanceNorm3d(out_channels) + else: + if deconv: + self.conv = nn.ConvTranspose2d(in_channels, out_channels, bias=False, **kwargs) + else: + self.conv = nn.Conv2d(in_channels, out_channels, bias=False, **kwargs) + self.IN = nn.InstanceNorm2d(out_channels) + + def forward(self, x): + x = self.conv(x) + if self.use_in: + x = self.IN(x) + if self.relu: + x = nn.LeakyReLU()(x)#, inplace=True) + return x + +class Conv2x(nn.Module): + + def __init__(self, in_channels, out_channels, deconv=False, is_3d=False, concat=True, keep_concat=True, bn=True, relu=True, keep_dispc=False): + super(Conv2x, self).__init__() + self.concat = concat + self.is_3d = is_3d + if deconv and is_3d: + kernel = (4, 4, 4) + elif deconv: + kernel = 4 + else: + kernel = 3 + + if deconv and is_3d and keep_dispc: + kernel = (1, 4, 4) + stride = (1, 2, 2) + padding = (0, 1, 1) + self.conv1 = BasicConv(in_channels, out_channels, deconv, is_3d, bn=True, relu=True, kernel_size=kernel, stride=stride, padding=padding) + else: + self.conv1 = BasicConv(in_channels, out_channels, deconv, is_3d, bn=True, relu=True, kernel_size=kernel, stride=2, padding=1) + + if self.concat: + mul = 2 if keep_concat else 1 + self.conv2 = BasicConv(out_channels*2, out_channels*mul, False, is_3d, bn, relu, kernel_size=3, stride=1, padding=1) + else: + self.conv2 = BasicConv(out_channels, out_channels, False, is_3d, bn, relu, kernel_size=3, stride=1, padding=1) + + def forward(self, x, rem): + x = self.conv1(x) + if x.shape != rem.shape: + x = F.interpolate( + x, + size=(rem.shape[-2], rem.shape[-1]), + mode='nearest') + if self.concat: + x = torch.cat((x, rem), 1) + else: + x = x + rem + x = self.conv2(x) + return x + +class Conv2x_IN(nn.Module): + + def __init__(self, in_channels, out_channels, deconv=False, is_3d=False, concat=True, keep_concat=True, IN=True, relu=True, keep_dispc=False): + super(Conv2x_IN, self).__init__() + self.concat = concat + self.is_3d = is_3d + if deconv and is_3d: + kernel = (4, 4, 4) + elif deconv: + kernel = 4 + else: + kernel = 3 + + if deconv and is_3d and keep_dispc: + kernel = (1, 4, 4) + stride = (1, 2, 2) + padding = (0, 1, 1) + 
self.conv1 = BasicConv_IN(in_channels, out_channels, deconv, is_3d, IN=True, relu=True, kernel_size=kernel, stride=stride, padding=padding) + else: + self.conv1 = BasicConv_IN(in_channels, out_channels, deconv, is_3d, IN=True, relu=True, kernel_size=kernel, stride=2, padding=1) + + if self.concat: + mul = 2 if keep_concat else 1 + self.conv2 = BasicConv_IN(out_channels*2, out_channels*mul, False, is_3d, IN, relu, kernel_size=3, stride=1, padding=1) + else: + self.conv2 = BasicConv_IN(out_channels, out_channels, False, is_3d, IN, relu, kernel_size=3, stride=1, padding=1) + + def forward(self, x, rem): + x = self.conv1(x) + if x.shape != rem.shape: + x = F.interpolate( + x, + size=(rem.shape[-2], rem.shape[-1]), + mode='nearest') + if self.concat: + x = torch.cat((x, rem), 1) + else: + x = x + rem + x = self.conv2(x) + return x + +class ConvReLU(nn.Module): + def __init__(self, in_channels, out_channels, kernel_size=3, stride=1, pad=1, dilation=1): + super(ConvReLU, self).__init__() + self.conv = nn.Conv2d(in_channels, out_channels, kernel_size, stride=stride, padding=pad, dilation=dilation, bias=False) + + def forward(self,x): + return F.relu(self.conv(x), inplace=True) + +class DepthInitialization(nn.Module): + def __init__(self, num_sample): + super(DepthInitialization, self).__init__() + self.num_sample = num_sample + + def forward(self, inverse_depth_min, inverse_depth_max, height, width, device): + batch = inverse_depth_min.size()[0] + index = torch.arange(0, self.num_sample, 1, device=device).view(1, self.num_sample, 1, 1).float() + normalized_sample = index.repeat(batch, 1, height, width) / (self.num_sample-1) + depth_sample = inverse_depth_max + normalized_sample * (inverse_depth_min - inverse_depth_max) + + depth_sample = 1.0 / depth_sample + + return depth_sample + +class PixelViewWeight(nn.Module): + def __init__(self, G): + super(PixelViewWeight, self).__init__() + self.conv = nn.Sequential( + ConvReLU(G, 16), + nn.Conv2d(16, 1, 1, stride=1, padding=0), + ) + + def forward(self, x): + # x: [B, G, N, H, W] + batch, dim, num_depth, height, width = x.size() + x = x.permute(0, 2, 1, 3, 4).contiguous() + x = x.view(batch*num_depth, dim, height, width) # [B*N,G,H,W] + x =self.conv(x).view(batch, num_depth, height, width) + x = torch.softmax(x,dim=1) + x = torch.max(x, dim=1)[0] + + return x.unsqueeze(1) + +class FeatureAtt(nn.Module): + def __init__(self, cv_chan, feat_chan): + super(FeatureAtt, self).__init__() + + self.feat_att = nn.Sequential( + BasicConv(feat_chan, feat_chan//2, kernel_size=1, stride=1, padding=0), + nn.Conv2d(feat_chan//2, cv_chan, 1)) + + def forward(self, cv, feat): + ''' + ''' + feat_att = self.feat_att(feat).unsqueeze(2) + cv = torch.sigmoid(feat_att)*cv + return cv + +class hourglass(nn.Module): + def __init__(self, in_channels): + super(hourglass, self).__init__() + + self.conv1 = nn.Sequential(BasicConv(in_channels, in_channels*2, is_3d=True, bn=True, relu=True, kernel_size=3, + padding=1, stride=2, dilation=1), + BasicConv(in_channels*2, in_channels*2, is_3d=True, bn=True, relu=True, kernel_size=3, + padding=1, stride=1, dilation=1)) + + self.conv2 = nn.Sequential(BasicConv(in_channels*2, in_channels*4, is_3d=True, bn=True, relu=True, kernel_size=3, + padding=1, stride=2, dilation=1), + BasicConv(in_channels*4, in_channels*4, is_3d=True, bn=True, relu=True, kernel_size=3, + padding=1, stride=1, dilation=1)) + + self.conv3 = nn.Sequential(BasicConv(in_channels*4, in_channels*6, is_3d=True, bn=True, relu=True, kernel_size=3, + padding=1, stride=2, 
dilation=1), + BasicConv(in_channels*6, in_channels*6, is_3d=True, bn=True, relu=True, kernel_size=3, + padding=1, stride=1, dilation=1)) + + + self.conv3_up = BasicConv(in_channels*6, in_channels*4, deconv=True, is_3d=True, bn=True, + relu=True, kernel_size=(4, 4, 4), padding=(1, 1, 1), stride=(2, 2, 2)) + + self.conv2_up = BasicConv(in_channels*4, in_channels*2, deconv=True, is_3d=True, bn=True, + relu=True, kernel_size=(4, 4, 4), padding=(1, 1, 1), stride=(2, 2, 2)) + + self.conv1_up = BasicConv(in_channels*2, 1, deconv=True, is_3d=True, bn=False, + relu=False, kernel_size=(4, 4, 4), padding=(1, 1, 1), stride=(2, 2, 2)) + + self.agg_0 = nn.Sequential(BasicConv(in_channels*8, in_channels*4, is_3d=True, kernel_size=1, padding=0, stride=1), + BasicConv(in_channels*4, in_channels*4, is_3d=True, kernel_size=3, padding=1, stride=1), + BasicConv(in_channels*4, in_channels*4, is_3d=True, kernel_size=3, padding=1, stride=1),) + + self.agg_1 = nn.Sequential(BasicConv(in_channels*4, in_channels*2, is_3d=True, kernel_size=1, padding=0, stride=1), + BasicConv(in_channels*2, in_channels*2, is_3d=True, kernel_size=3, padding=1, stride=1), + BasicConv(in_channels*2, in_channels*2, is_3d=True, kernel_size=3, padding=1, stride=1)) + + + self.feature_att_8 = FeatureAtt(in_channels*2, 64) + self.feature_att_16 = FeatureAtt(in_channels*4, 192) + self.feature_att_32 = FeatureAtt(in_channels*6, 160) + self.feature_att_up_16 = FeatureAtt(in_channels*4, 192) + self.feature_att_up_8 = FeatureAtt(in_channels*2, 64) + + def forward(self, x, features): + conv1 = self.conv1(x) + conv1 = self.feature_att_8(conv1, features[1]) + + conv2 = self.conv2(conv1) + conv2 = self.feature_att_16(conv2, features[2]) + + conv3 = self.conv3(conv2) + conv3 = self.feature_att_32(conv3, features[3]) + + conv3_up = self.conv3_up(conv3) + + conv2 = torch.cat((conv3_up, conv2), dim=1) + conv2 = self.agg_0(conv2) + conv2 = self.feature_att_up_16(conv2, features[2]) + + conv2_up = self.conv2_up(conv2) + + conv1 = torch.cat((conv2_up, conv1), dim=1) + conv1 = self.agg_1(conv1) + conv1 = self.feature_att_up_8(conv1, features[1]) + + conv = self.conv1_up(conv1) + + return conv + +def bilinear_sampler(img, coords, mode='bilinear', mask=False): + """ Wrapper for grid_sample, uses pixel coordinates """ + H, W = img.shape[-2:] + xgrid, ygrid = coords.split([1,1], dim=-1) + xgrid = 2*xgrid/(W-1) - 1 + + assert torch.unique(ygrid).numel() == 1 and H == 1 # This is a stereo problem + + grid = torch.cat([xgrid, ygrid], dim=-1) + img = F.grid_sample(img, grid, align_corners=True) + if mask: + mask = (xgrid > -1) & (ygrid > -1) & (xgrid < 1) & (ygrid < 1) + return img, mask.float() + + return img + +def context_upsample(disp_low, up_weights): + ### + # cv (b,1,h,w) + # sp (b,9,4*h,4*w) + ### + b, c, h, w = disp_low.shape + + disp_unfold = F.unfold(disp_low.reshape(b,c,h,w),3,1,1).reshape(b,-1,h,w) + disp_unfold = F.interpolate(disp_unfold,(h*4,w*4),mode='nearest').reshape(b,9,h*4,w*4) + + disp = (disp_unfold*up_weights).sum(1) + + return disp + +def pool2x(x): + return F.avg_pool2d(x, 3, stride=2, padding=1) + +def interp(x, dest): + interp_args = {'mode': 'bilinear', 'align_corners': True} + return F.interpolate(x, dest.shape[2:], **interp_args) + +def differentiable_warping(src_fea, src_proj, ref_proj, depth_samples, return_mask=False): + # src_fea: [B, C, H, W] + # src_proj: [B, 4, 4] + # ref_proj: [B, 4, 4] + # depth_samples: [B, Ndepth, H, W] + # out: [B, C, Ndepth, H, W] + batch, num_depth, height, width = depth_samples.size() + height1, width1 
= src_fea.size()[2:] + + with torch.no_grad(): + if batch==2: + inv_ref_proj = [] + for i in range(batch): + inv_ref_proj.append(torch.inverse(ref_proj[i]).unsqueeze(0)) + inv_ref_proj = torch.cat(inv_ref_proj, dim=0) + assert (not torch.isnan(inv_ref_proj).any()), "nan in inverse(ref_proj)" + proj = torch.matmul(src_proj, inv_ref_proj) + else: + proj = torch.matmul(src_proj, torch.inverse(ref_proj)) + assert (not torch.isnan(proj).any()), "nan in proj" + + rot = proj[:, :3, :3] # [B,3,3] + trans = proj[:, :3, 3:4] # [B,3,1] + y, x = torch.meshgrid([torch.arange(0, height, dtype=torch.float32, device=depth_samples.device), + torch.arange(0, width, dtype=torch.float32, device=depth_samples.device)]) + y, x = y.contiguous(), x.contiguous() + y, x = y.view(height * width), x.view(height * width) + y = y*(height1/height) + x = x*(width1/width) + xyz = torch.stack((x, y, torch.ones_like(x))) # [3, H*W] + xyz = torch.unsqueeze(xyz, 0).repeat(batch, 1, 1) # [B, 3, H*W] + rot_xyz = torch.matmul(rot, xyz) # [B, 3, H*W] + + rot_depth_xyz = rot_xyz.unsqueeze(2).repeat(1, 1, num_depth, 1) * depth_samples.view(batch, 1, num_depth, + height * width) # [B, 3, Ndepth, H*W] + proj_xyz = rot_depth_xyz + trans.view(batch, 3, 1, 1) # [B, 3, Ndepth, H*W] + # avoid negative depth + valid_mask = proj_xyz[:, 2:] > 1e-2 + proj_xyz[:, 0:1][~valid_mask] = width + proj_xyz[:, 1:2][~valid_mask] = height + proj_xyz[:, 2:3][~valid_mask] = 1 + proj_xy = proj_xyz[:, :2, :, :] / proj_xyz[:, 2:3, :, :] # [B, 2, Ndepth, H*W] + valid_mask = valid_mask & (proj_xy[:, 0:1] >=0) & (proj_xy[:, 0:1] < width) \ + & (proj_xy[:, 1:2] >=0) & (proj_xy[:, 1:2] < height) + proj_x_normalized = proj_xy[:, 0, :, :] / ((width1 - 1) / 2) - 1 # [B, Ndepth, H*W] + proj_y_normalized = proj_xy[:, 1, :, :] / ((height1 - 1) / 2) - 1 + proj_xy = torch.stack((proj_x_normalized, proj_y_normalized), dim=3) # [B, Ndepth, H*W, 2] + grid = proj_xy + + dim = src_fea.size()[1] + warped_src_fea = F.grid_sample(src_fea, grid.view(batch, num_depth * height, width, 2), mode='bilinear', + padding_mode='zeros',align_corners=True) + warped_src_fea = warped_src_fea.view(batch, dim, num_depth, height, width) + if return_mask: + valid_mask = valid_mask.view(batch,num_depth,height,width) + return warped_src_fea, valid_mask + else: + return warped_src_fea + +def depth_normalization(depth, inverse_depth_min, inverse_depth_max): + '''convert depth map to the index in inverse range''' + inverse_depth = 1.0 / (depth+1e-5) + normalized_depth = (inverse_depth - inverse_depth_max) / (inverse_depth_min - inverse_depth_max) + return normalized_depth + +def depth_unnormalization(normalized_depth, inverse_depth_min, inverse_depth_max): + '''convert the index in inverse range to depth map''' + inverse_depth = inverse_depth_max + normalized_depth * (inverse_depth_min - inverse_depth_max) # [B,1,H,W] + depth = 1.0 / inverse_depth + return depth \ No newline at end of file diff --git a/IGEV-MVS/core/update.py b/IGEV-MVS/core/update.py new file mode 100644 index 0000000..10a6f3f --- /dev/null +++ b/IGEV-MVS/core/update.py @@ -0,0 +1,94 @@ +import torch +import torch.nn as nn +import torch.nn.functional as F +from .submodule import * + +class BasicMotionEncoder(nn.Module): + def __init__(self): + super(BasicMotionEncoder, self).__init__() + self.corr_levels = 2 + self.corr_radius = 4 + + cor_planes = 2 * self.corr_levels * (2*self.corr_radius + 1) + + self.convc1 = nn.Conv2d(cor_planes, 64, 1, padding=0) + self.convc2 = nn.Conv2d(64, 64, 3, padding=1) + self.convd1 = nn.Conv2d(1, 64, 7, 
padding=3) + self.convd2 = nn.Conv2d(64, 64, 3, padding=1) + self.conv = nn.Conv2d(64+64, 128-1, 3, padding=1) + + def forward(self, disp, corr): + cor = F.relu(self.convc1(corr)) + cor = F.relu(self.convc2(cor)) + disp_ = F.relu(self.convd1(disp)) + disp_ = F.relu(self.convd2(disp_)) + + cor_disp = torch.cat([cor, disp_], dim=1) + out = F.relu(self.conv(cor_disp)) + return torch.cat([out, disp], dim=1) + +class ConvGRU(nn.Module): + def __init__(self, hidden_dim, input_dim, kernel_size=3): + super(ConvGRU, self).__init__() + self.convz = nn.Conv2d(hidden_dim+input_dim, hidden_dim, kernel_size, padding=kernel_size//2) + self.convr = nn.Conv2d(hidden_dim+input_dim, hidden_dim, kernel_size, padding=kernel_size//2) + self.convq = nn.Conv2d(hidden_dim+input_dim, hidden_dim, kernel_size, padding=kernel_size//2) + + def forward(self, h, *x_list): + x = torch.cat(x_list, dim=1) + hx = torch.cat([h, x], dim=1) + z = torch.sigmoid(self.convz(hx)) + r = torch.sigmoid(self.convr(hx)) + q = torch.tanh(self.convq(torch.cat([r*h, x], dim=1))) + h = (1-z) * h + z * q + return h + +class DispHead(nn.Module): + def __init__(self, input_dim=128, hidden_dim=256, output_dim=1): + super(DispHead, self).__init__() + self.conv1 = nn.Conv2d(input_dim, hidden_dim, 3, padding=1) + self.conv2 = nn.Conv2d(hidden_dim, output_dim, 3, padding=1) + self.relu = nn.ReLU(inplace=True) + + def forward(self, x): + return self.conv2(self.relu(self.conv1(x))) + +class BasicMultiUpdateBlock(nn.Module): + def __init__(self, hidden_dims=[]): + super().__init__() + self.n_gru_layers = 3 + self.n_downsample = 2 + self.encoder = BasicMotionEncoder() + encoder_output_dim = 128 + + self.gru04 = ConvGRU(hidden_dims[2], encoder_output_dim + hidden_dims[1] * (self.n_gru_layers > 1)) + self.gru08 = ConvGRU(hidden_dims[1], hidden_dims[0] * (self.n_gru_layers == 3) + hidden_dims[2]) + self.gru16 = ConvGRU(hidden_dims[0], hidden_dims[1]) + self.disp_head = DispHead(hidden_dims[2], hidden_dim=256, output_dim=1) + factor = 2**self.n_downsample + + self.mask_feat_4 = nn.Sequential( + nn.Conv2d(hidden_dims[2], 32, 3, padding=1), + nn.ReLU(inplace=True)) + + def forward(self, net, corr=None, disp=None, iter04=True, iter08=True, iter16=True, update=True): + if iter16: + net[2] = self.gru16(net[2], pool2x(net[1])) + if iter08: + if self.n_gru_layers > 2: + net[1] = self.gru08(net[1], pool2x(net[0]), interp(net[2], net[1])) + else: + net[1] = self.gru08(net[1], pool2x(net[0])) + if iter04: + motion_features = self.encoder(disp, corr) + if self.n_gru_layers > 1: + net[0] = self.gru04(net[0], motion_features, interp(net[1], net[0])) + else: + net[0] = self.gru04(net[0], motion_features) + + if not update: + return net + + delta_disp = self.disp_head(net[0]) + mask_feat_4 = self.mask_feat_4(net[0]) + return net, mask_feat_4, delta_disp \ No newline at end of file diff --git a/IGEV-MVS/datasets/__init__.py b/IGEV-MVS/datasets/__init__.py new file mode 100644 index 0000000..36e74fa --- /dev/null +++ b/IGEV-MVS/datasets/__init__.py @@ -0,0 +1,8 @@ +import importlib + + +# find the dataset definition by name, for example dtu_yao (dtu_yao.py) +def find_dataset_def(dataset_name): + module_name = 'datasets.{}'.format(dataset_name) + module = importlib.import_module(module_name) + return getattr(module, "MVSDataset") diff --git a/IGEV-MVS/datasets/blendedmvs.py b/IGEV-MVS/datasets/blendedmvs.py new file mode 100644 index 0000000..75795e4 --- /dev/null +++ b/IGEV-MVS/datasets/blendedmvs.py @@ -0,0 +1,208 @@ +from torch.utils.data import Dataset +from 
datasets.data_io import * +import os +import numpy as np +import cv2 +from PIL import Image +from torchvision import transforms as T +import random + +class MVSDataset(Dataset): + def __init__(self, datapath, listfile, split, nviews, img_wh=(768, 576), robust_train=True): + + super(MVSDataset, self).__init__() + self.levels = 4 + self.datapath = datapath + self.split = split + self.listfile = listfile + self.robust_train = robust_train + assert self.split in ['train', 'val', 'all'], \ + 'split must be either "train", "val" or "all"!' + + self.img_wh = img_wh + if img_wh is not None: + assert img_wh[0]%32==0 and img_wh[1]%32==0, \ + 'img_wh must both be multiples of 32!' + self.nviews = nviews + self.scale_factors = {} # depth scale factors for each scan + self.build_metas() + + self.color_augment = T.ColorJitter(brightness=0.5, contrast=0.5) + + def build_metas(self): + self.metas = [] + with open(self.listfile) as f: + self.scans = [line.rstrip() for line in f.readlines()] + for scan in self.scans: + with open(os.path.join(self.datapath, scan, "cams/pair.txt")) as f: + num_viewpoint = int(f.readline()) + for _ in range(num_viewpoint): + ref_view = int(f.readline().rstrip()) + src_views = [int(x) for x in f.readline().rstrip().split()[1::2]] + if len(src_views) >= self.nviews-1: + self.metas += [(scan, ref_view, src_views)] + + def read_cam_file(self, scan, filename): + with open(filename) as f: + lines = f.readlines() + lines = [line.rstrip() for line in lines] + # extrinsics: line [1,5), 4x4 matrix + extrinsics = np.fromstring(' '.join(lines[1:5]), dtype=np.float32, sep=' ').reshape((4, 4)) + # intrinsics: line [7-10), 3x3 matrix + intrinsics = np.fromstring(' '.join(lines[7:10]), dtype=np.float32, sep=' ').reshape((3, 3)) + depth_min = float(lines[11].split()[0]) + depth_max = float(lines[11].split()[-1]) + if scan not in self.scale_factors: + self.scale_factors[scan] = 100.0/depth_min + depth_min *= self.scale_factors[scan] + depth_max *= self.scale_factors[scan] + extrinsics[:3, 3] *= self.scale_factors[scan] + return intrinsics, extrinsics, depth_min, depth_max + + def read_depth_mask(self, scan, filename, depth_min, depth_max, scale): + depth = np.array(read_pfm(filename)[0], dtype=np.float32) + depth = depth * self.scale_factors[scan] * scale + depth = np.squeeze(depth,2) + + mask = (depth>=depth_min) & (depth<=depth_max) + mask = mask.astype(np.float32) + if self.img_wh is not None: + depth = cv2.resize(depth, self.img_wh, + interpolation=cv2.INTER_NEAREST) + h, w = depth.shape + depth_ms = {} + mask_ms = {} + + for i in range(4): + depth_cur = cv2.resize(depth, (w//(2**i), h//(2**i)), interpolation=cv2.INTER_NEAREST) + mask_cur = cv2.resize(mask, (w//(2**i), h//(2**i)), interpolation=cv2.INTER_NEAREST) + + depth_ms[f"level_{i}"] = depth_cur + mask_ms[f"level_{i}"] = mask_cur + + return depth_ms, mask_ms + + def read_img(self, filename): + img = Image.open(filename) + if self.split=='train': + img = self.color_augment(img) + # scale 0~255 to -1~1 + np_img = 2*np.array(img, dtype=np.float32) / 255. 
- 1 + if self.img_wh is not None: + np_img = cv2.resize(np_img, self.img_wh, + interpolation=cv2.INTER_LINEAR) + h, w, _ = np_img.shape + np_img_ms = { + "level_3": cv2.resize(np_img, (w//8, h//8), interpolation=cv2.INTER_LINEAR), + "level_2": cv2.resize(np_img, (w//4, h//4), interpolation=cv2.INTER_LINEAR), + "level_1": cv2.resize(np_img, (w//2, h//2), interpolation=cv2.INTER_LINEAR), + "level_0": np_img + } + return np_img_ms + + def __len__(self): + return len(self.metas) + + def __getitem__(self, idx): + meta = self.metas[idx] + scan, ref_view, src_views = meta + + if self.robust_train: + num_src_views = len(src_views) + index = random.sample(range(num_src_views), self.nviews - 1) + view_ids = [ref_view] + [src_views[i] for i in index] + scale = random.uniform(0.8, 1.25) + + else: + view_ids = [ref_view] + src_views[:self.nviews - 1] + scale = 1 + + imgs_0 = [] + imgs_1 = [] + imgs_2 = [] + imgs_3 = [] + mask = None + depth = None + depth_min = None + depth_max = None + proj_matrices_0 = [] + proj_matrices_1 = [] + proj_matrices_2 = [] + proj_matrices_3 = [] + + + for i, vid in enumerate(view_ids): + img_filename = os.path.join(self.datapath, '{}/blended_images/{:0>8}.jpg'.format(scan, vid)) + depth_filename = os.path.join(self.datapath, '{}/rendered_depth_maps/{:0>8}.pfm'.format(scan, vid)) + proj_mat_filename = os.path.join(self.datapath, '{}/cams/{:0>8}_cam.txt'.format(scan, vid)) + + imgs = self.read_img(img_filename) + imgs_0.append(imgs['level_0']) + imgs_1.append(imgs['level_1']) + imgs_2.append(imgs['level_2']) + imgs_3.append(imgs['level_3']) + + # here, the intrinsics from file is already adjusted to the downsampled size of feature 1/4H0 * 1/4W0 + intrinsics, extrinsics, depth_min_, depth_max_ = self.read_cam_file(scan, proj_mat_filename) + extrinsics[:3, 3] *= scale + proj_mat = extrinsics.copy() + intrinsics[:2,:] *= 0.125 + proj_mat[:3, :4] = np.matmul(intrinsics, proj_mat[:3, :4]) + proj_matrices_3.append(proj_mat) + + proj_mat = extrinsics.copy() + intrinsics[:2,:] *= 2 + proj_mat[:3, :4] = np.matmul(intrinsics, proj_mat[:3, :4]) + proj_matrices_2.append(proj_mat) + + proj_mat = extrinsics.copy() + intrinsics[:2,:] *= 2 + proj_mat[:3, :4] = np.matmul(intrinsics, proj_mat[:3, :4]) + proj_matrices_1.append(proj_mat) + + proj_mat = extrinsics.copy() + intrinsics[:2,:] *= 2 + proj_mat[:3, :4] = np.matmul(intrinsics, proj_mat[:3, :4]) + proj_matrices_0.append(proj_mat) + + + if i == 0: # reference view + depth_min = depth_min_ * scale + depth_max = depth_max_ * scale + depth, mask = self.read_depth_mask(scan, depth_filename, depth_min, depth_max, scale) + for l in range(self.levels): + mask[f'level_{l}'] = np.expand_dims(mask[f'level_{l}'],2) + mask[f'level_{l}'] = mask[f'level_{l}'].transpose([2,0,1]) + depth[f'level_{l}'] = np.expand_dims(depth[f'level_{l}'],2) + depth[f'level_{l}'] = depth[f'level_{l}'].transpose([2,0,1]) + + # imgs: N*3*H0*W0, N is number of images + imgs_0 = np.stack(imgs_0).transpose([0, 3, 1, 2]) + imgs_1 = np.stack(imgs_1).transpose([0, 3, 1, 2]) + imgs_2 = np.stack(imgs_2).transpose([0, 3, 1, 2]) + imgs_3 = np.stack(imgs_3).transpose([0, 3, 1, 2]) + imgs = {} + imgs['level_0'] = imgs_0 + imgs['level_1'] = imgs_1 + imgs['level_2'] = imgs_2 + imgs['level_3'] = imgs_3 + + # proj_matrices: N*4*4 + proj_matrices_0 = np.stack(proj_matrices_0) + proj_matrices_1 = np.stack(proj_matrices_1) + proj_matrices_2 = np.stack(proj_matrices_2) + proj_matrices_3 = np.stack(proj_matrices_3) + proj={} + proj['level_3']=proj_matrices_3 + proj['level_2']=proj_matrices_2 + 
proj['level_1']=proj_matrices_1 + proj['level_0']=proj_matrices_0 + + # data is numpy array + return {"imgs": imgs, # [N, 3, H, W] + "proj_matrices": proj, # [N,4,4] + "depth": depth, # [1, H, W] + "depth_min": depth_min, # scalar + "depth_max": depth_max, # scalar + "mask": mask} # [1, H, W] + \ No newline at end of file diff --git a/IGEV-MVS/datasets/custom.py b/IGEV-MVS/datasets/custom.py new file mode 100644 index 0000000..df82954 --- /dev/null +++ b/IGEV-MVS/datasets/custom.py @@ -0,0 +1,145 @@ +from torch.utils.data import Dataset +from datasets.data_io import * +import os +import numpy as np +import cv2 +from PIL import Image +from torchvision import transforms as T +import math + +class MVSDataset(Dataset): + def __init__(self, datapath, n_views=5, img_wh=(640,480)): + self.levels = 4 + self.datapath = datapath + self.img_wh = img_wh + self.build_metas() + self.n_views = n_views + + def build_metas(self): + self.metas = [] + with open(os.path.join(self.datapath, 'pair.txt')) as f: + num_viewpoint = int(f.readline()) + for view_idx in range(num_viewpoint): + ref_view = int(f.readline().rstrip()) + src_views = [int(x) for x in f.readline().rstrip().split()[1::2]] + if len(src_views) != 0: + self.metas += [(ref_view, src_views)] + + + def read_cam_file(self, filename): + with open(filename) as f: + lines = [line.rstrip() for line in f.readlines()] + # extrinsics: line [1,5), 4x4 matrix + extrinsics = np.fromstring(' '.join(lines[1:5]), dtype=np.float32, sep=' ') + extrinsics = extrinsics.reshape((4, 4)) + # intrinsics: line [7-10), 3x3 matrix + intrinsics = np.fromstring(' '.join(lines[7:10]), dtype=np.float32, sep=' ') + intrinsics = intrinsics.reshape((3, 3)) + + depth_min = float(lines[11].split()[0]) + depth_max = float(lines[11].split()[-1]) + + return intrinsics, extrinsics, depth_min, depth_max + + def read_img(self, filename, h, w): + img = Image.open(filename) + # scale 0~255 to -1~1 + np_img = 2*np.array(img, dtype=np.float32) / 255. 
- 1 + original_h, original_w, _ = np_img.shape + np_img = cv2.resize(np_img, self.img_wh, interpolation=cv2.INTER_LINEAR) + + np_img_ms = { + "level_3": cv2.resize(np_img, (w//8, h//8), interpolation=cv2.INTER_LINEAR), + "level_2": cv2.resize(np_img, (w//4, h//4), interpolation=cv2.INTER_LINEAR), + "level_1": cv2.resize(np_img, (w//2, h//2), interpolation=cv2.INTER_LINEAR), + "level_0": np_img + } + return np_img_ms, original_h, original_w + + def __len__(self): + return len(self.metas) + + def __getitem__(self, idx): + ref_view, src_views = self.metas[idx] + # use only the reference view and first nviews-1 source views + view_ids = [ref_view] + src_views[:self.n_views-1] + + imgs_0 = [] + imgs_1 = [] + imgs_2 = [] + imgs_3 = [] + + # depth = None + depth_min = None + depth_max = None + + proj_matrices_0 = [] + proj_matrices_1 = [] + proj_matrices_2 = [] + proj_matrices_3 = [] + + for i, vid in enumerate(view_ids): + img_filename = os.path.join(self.datapath, f'images/{vid:08d}.jpg') + proj_mat_filename = os.path.join(self.datapath, f'cams_1/{vid:08d}_cam.txt') + + imgs, original_h, original_w = self.read_img(img_filename,self.img_wh[1], self.img_wh[0]) + imgs_0.append(imgs['level_0']) + imgs_1.append(imgs['level_1']) + imgs_2.append(imgs['level_2']) + imgs_3.append(imgs['level_3']) + + intrinsics, extrinsics, depth_min_, depth_max_ = self.read_cam_file(proj_mat_filename) + intrinsics[0] *= self.img_wh[0]/original_w + intrinsics[1] *= self.img_wh[1]/original_h + + proj_mat = extrinsics.copy() + intrinsics[:2,:] *= 0.125 + proj_mat[:3, :4] = np.matmul(intrinsics, proj_mat[:3, :4]) + proj_matrices_3.append(proj_mat) + + proj_mat = extrinsics.copy() + intrinsics[:2,:] *= 2 + proj_mat[:3, :4] = np.matmul(intrinsics, proj_mat[:3, :4]) + proj_matrices_2.append(proj_mat) + + proj_mat = extrinsics.copy() + intrinsics[:2,:] *= 2 + proj_mat[:3, :4] = np.matmul(intrinsics, proj_mat[:3, :4]) + proj_matrices_1.append(proj_mat) + + proj_mat = extrinsics.copy() + intrinsics[:2,:] *= 2 + proj_mat[:3, :4] = np.matmul(intrinsics, proj_mat[:3, :4]) + proj_matrices_0.append(proj_mat) + + if i == 0: # reference view + depth_min = depth_min_ + depth_max = depth_max_ + + # imgs: N*3*H0*W0, N is number of images + imgs_0 = np.stack(imgs_0).transpose([0, 3, 1, 2]) + imgs_1 = np.stack(imgs_1).transpose([0, 3, 1, 2]) + imgs_2 = np.stack(imgs_2).transpose([0, 3, 1, 2]) + imgs_3 = np.stack(imgs_3).transpose([0, 3, 1, 2]) + imgs = {} + imgs['level_0'] = imgs_0 + imgs['level_1'] = imgs_1 + imgs['level_2'] = imgs_2 + imgs['level_3'] = imgs_3 + # proj_matrices: N*4*4 + proj_matrices_0 = np.stack(proj_matrices_0) + proj_matrices_1 = np.stack(proj_matrices_1) + proj_matrices_2 = np.stack(proj_matrices_2) + proj_matrices_3 = np.stack(proj_matrices_3) + proj={} + proj['level_3']=proj_matrices_3 + proj['level_2']=proj_matrices_2 + proj['level_1']=proj_matrices_1 + proj['level_0']=proj_matrices_0 + + return {"imgs": imgs, # N*3*H0*W0 + "proj_matrices": proj, # N*4*4 + "depth_min": depth_min, # scalar + "depth_max": depth_max, + "filename": '{}/' + '{:0>8}'.format(view_ids[0]) + "{}" + } diff --git a/IGEV-MVS/datasets/data_io.py b/IGEV-MVS/datasets/data_io.py new file mode 100644 index 0000000..9b31c9f --- /dev/null +++ b/IGEV-MVS/datasets/data_io.py @@ -0,0 +1,73 @@ +import numpy as np +import re +import sys + + +def read_pfm(filename): + # rb: binary file and read only + file = open(filename, 'rb') + color = None + width = None + height = None + scale = None + endian = None + + header = file.readline().decode('utf-8').rstrip() 
+ if header == 'PF': + color = True + elif header == 'Pf': # depth is Pf + color = False + else: + raise Exception('Not a PFM file.') + + dim_match = re.match(r'^(\d+)\s(\d+)\s$', file.readline().decode('utf-8')) # re is used for matching + if dim_match: + width, height = map(int, dim_match.groups()) + else: + raise Exception('Malformed PFM header.') + + scale = float(file.readline().rstrip()) + if scale < 0: # little-endian + endian = '<' + scale = -scale + else: + endian = '>' # big-endian + + data = np.fromfile(file, endian + 'f') + shape = (height, width, 3) if color else (height, width, 1) + # depth: H*W + data = np.reshape(data, shape) + data = np.flipud(data) + file.close() + return data, scale + + +def save_pfm(filename, image, scale=1): + file = open(filename, "wb") + color = None + + image = np.flipud(image) + # print(image.shape) + + if image.dtype.name != 'float32': + raise Exception('Image dtype must be float32.') + + if len(image.shape) == 3 and image.shape[2] == 3: # color image + color = True + elif len(image.shape) == 2 or len(image.shape) == 3 and image.shape[2] == 1: # greyscale + color = False + else: + raise Exception('Image must have H x W x 3, H x W x 1 or H x W dimensions.') + + file.write('PF\n'.encode('utf-8') if color else 'Pf\n'.encode('utf-8')) + file.write('{} {}\n'.format(image.shape[1], image.shape[0]).encode('utf-8')) + + endian = image.dtype.byteorder + + if endian == '<' or endian == '=' and sys.byteorder == 'little': + scale = -scale + + file.write(('%f\n' % scale).encode('utf-8')) + + image.tofile(file) + file.close() diff --git a/IGEV-MVS/datasets/dtu_yao.py b/IGEV-MVS/datasets/dtu_yao.py new file mode 100644 index 0000000..bf05f71 --- /dev/null +++ b/IGEV-MVS/datasets/dtu_yao.py @@ -0,0 +1,236 @@ +from torch.utils.data import Dataset +import numpy as np +import os +from PIL import Image +from datasets.data_io import * +import cv2 +import random +from torchvision import transforms + + +class MVSDataset(Dataset): + def __init__(self, datapath, listfile, mode, nviews, robust_train = False): + super(MVSDataset, self).__init__() + + self.levels = 4 + self.datapath = datapath + self.listfile = listfile + self.mode = mode + self.nviews = nviews + self.img_wh = (640, 512) + # self.img_wh = (1440, 1056) + self.robust_train = robust_train + + + assert self.mode in ["train", "val", "test"] + self.metas = self.build_list() + self.color_augment = transforms.ColorJitter(brightness=0.5, contrast=0.5) + + def build_list(self): + metas = [] + with open(self.listfile) as f: + scans = f.readlines() + scans = [line.rstrip() for line in scans] + + for scan in scans: + pair_file = "Cameras_1/pair.txt" + + with open(os.path.join(self.datapath, pair_file)) as f: + self.num_viewpoint = int(f.readline()) + # viewpoints (49) + for view_idx in range(self.num_viewpoint): + ref_view = int(f.readline().rstrip()) + src_views = [int(x) for x in f.readline().rstrip().split()[1::2]] + # light conditions 0-6 + for light_idx in range(7): + metas.append((scan, light_idx, ref_view, src_views)) + print("dataset", self.mode, "metas:", len(metas)) + return metas + + def __len__(self): + return len(self.metas) + + def read_cam_file(self, filename): + with open(filename) as f: + lines = f.readlines() + lines = [line.rstrip() for line in lines] + # extrinsics: line [1,5), 4x4 matrix + extrinsics = np.fromstring(' '.join(lines[1:5]), dtype=np.float32, sep=' ').reshape((4, 4)) + # intrinsics: line [7-10), 3x3 matrix + intrinsics = np.fromstring(' '.join(lines[7:10]), dtype=np.float32, sep=' 
').reshape((3, 3)) + depth_min = float(lines[11].split()[0]) + depth_max = float(lines[11].split()[-1]) + return intrinsics, extrinsics, depth_min, depth_max + + def read_img(self, filename): + img = Image.open(filename) + if self.mode=='train': + img = self.color_augment(img) + # scale 0~255 to -1~1 + np_img = 2*np.array(img, dtype=np.float32) / 255. - 1 + h, w, _ = np_img.shape + np_img_ms = { + "level_3": cv2.resize(np_img, (w//8, h//8), interpolation=cv2.INTER_LINEAR), + "level_2": cv2.resize(np_img, (w//4, h//4), interpolation=cv2.INTER_LINEAR), + "level_1": cv2.resize(np_img, (w//2, h//2), interpolation=cv2.INTER_LINEAR), + "level_0": np_img + } + return np_img_ms + + + def prepare_img(self, hr_img): + #downsample + h, w = hr_img.shape + # original w,h: 1600, 1200; downsample -> 800, 600 ; crop -> 640, 512 + hr_img = cv2.resize(hr_img, (w//2, h//2), interpolation=cv2.INTER_NEAREST) + #crop + h, w = hr_img.shape + target_h, target_w = self.img_wh[1], self.img_wh[0] + start_h, start_w = (h - target_h)//2, (w - target_w)//2 + hr_img_crop = hr_img[start_h: start_h + target_h, start_w: start_w + target_w] + + return hr_img_crop + + def read_mask(self, filename): + img = Image.open(filename) + np_img = np.array(img, dtype=np.float32) + np_img = (np_img > 10).astype(np.float32) + return np_img + + + def read_depth_mask(self, filename, mask_filename, scale): + depth_hr = np.array(read_pfm(filename)[0], dtype=np.float32) * scale + depth_hr = np.squeeze(depth_hr,2) + depth_lr = self.prepare_img(depth_hr) + mask = self.read_mask(mask_filename) + mask = self.prepare_img(mask) + mask = mask.astype(np.bool_) + mask = mask.astype(np.float32) + + h, w = depth_lr.shape + depth_lr_ms = {} + mask_ms = {} + + for i in range(self.levels): + depth_cur = cv2.resize(depth_lr, (w//(2**i), h//(2**i)), interpolation=cv2.INTER_NEAREST) + mask_cur = cv2.resize(mask, (w//(2**i), h//(2**i)), interpolation=cv2.INTER_NEAREST) + depth_lr_ms[f"level_{i}"] = depth_cur + mask_ms[f"level_{i}"] = mask_cur + + return depth_lr_ms, mask_ms + + + def __getitem__(self, idx): + meta = self.metas[idx] + scan, light_idx, ref_view, src_views = meta + # robust training strategy + if self.robust_train: + num_src_views = len(src_views) + index = random.sample(range(num_src_views), self.nviews - 1) + view_ids = [ref_view] + [src_views[i] for i in index] + scale = random.uniform(0.8, 1.25) + + else: + view_ids = [ref_view] + src_views[:self.nviews - 1] + scale = 1 + + imgs_0 = [] + imgs_1 = [] + imgs_2 = [] + imgs_3 = [] + + mask = None + depth = None + depth_min = None + depth_max = None + + proj_matrices_0 = [] + proj_matrices_1 = [] + proj_matrices_2 = [] + proj_matrices_3 = [] + + + + for i, vid in enumerate(view_ids): + img_filename = os.path.join(self.datapath, + 'Rectified/{}_train/rect_{:0>3}_{}_r5000.png'.format(scan, vid + 1, light_idx)) + proj_mat_filename = os.path.join(self.datapath, 'Cameras_1/{}_train/{:0>8}_cam.txt').format(scan, vid) + + mask_filename = os.path.join(self.datapath, 'Depths_raw/{}/depth_visual_{:0>4}.png'.format(scan, vid)) + depth_filename = os.path.join(self.datapath, 'Depths_raw/{}/depth_map_{:0>4}.pfm'.format(scan, vid)) + + imgs = self.read_img(img_filename) + imgs_0.append(imgs['level_0']) + imgs_1.append(imgs['level_1']) + imgs_2.append(imgs['level_2']) + imgs_3.append(imgs['level_3']) + + intrinsics, extrinsics, depth_min_, depth_max_ = self.read_cam_file(proj_mat_filename) + extrinsics[:3,3] *= scale + intrinsics[0] *= 4 + intrinsics[1] *= 4 + + proj_mat = extrinsics.copy() + intrinsics[:2,:] *= 
0.125 + proj_mat[:3, :4] = np.matmul(intrinsics, proj_mat[:3, :4]) + proj_matrices_3.append(proj_mat) + + proj_mat = extrinsics.copy() + intrinsics[:2,:] *= 2 + proj_mat[:3, :4] = np.matmul(intrinsics, proj_mat[:3, :4]) + proj_matrices_2.append(proj_mat) + + proj_mat = extrinsics.copy() + intrinsics[:2,:] *= 2 + proj_mat[:3, :4] = np.matmul(intrinsics, proj_mat[:3, :4]) + proj_matrices_1.append(proj_mat) + + proj_mat = extrinsics.copy() + intrinsics[:2,:] *= 2 + proj_mat[:3, :4] = np.matmul(intrinsics, proj_mat[:3, :4]) + proj_matrices_0.append(proj_mat) + + if i == 0: # reference view + depth_min = depth_min_ * scale + depth_max = depth_max_ * scale + depth, mask = self.read_depth_mask(depth_filename, mask_filename, scale) + + for l in range(self.levels): + mask[f'level_{l}'] = np.expand_dims(mask[f'level_{l}'],2) + mask[f'level_{l}'] = mask[f'level_{l}'].transpose([2,0,1]) + depth[f'level_{l}'] = np.expand_dims(depth[f'level_{l}'],2) + depth[f'level_{l}'] = depth[f'level_{l}'].transpose([2,0,1]) + + # imgs: N*3*H0*W0, N is number of images + imgs_0 = np.stack(imgs_0).transpose([0, 3, 1, 2]) + imgs_1 = np.stack(imgs_1).transpose([0, 3, 1, 2]) + imgs_2 = np.stack(imgs_2).transpose([0, 3, 1, 2]) + imgs_3 = np.stack(imgs_3).transpose([0, 3, 1, 2]) + + imgs = {} + imgs['level_0'] = imgs_0 + imgs['level_1'] = imgs_1 + imgs['level_2'] = imgs_2 + imgs['level_3'] = imgs_3 + + # proj_matrices: N*4*4 + proj_matrices_0 = np.stack(proj_matrices_0) + proj_matrices_1 = np.stack(proj_matrices_1) + proj_matrices_2 = np.stack(proj_matrices_2) + proj_matrices_3 = np.stack(proj_matrices_3) + + proj={} + proj['level_3']=proj_matrices_3 + proj['level_2']=proj_matrices_2 + proj['level_1']=proj_matrices_1 + proj['level_0']=proj_matrices_0 + + + # data is numpy array + return {"imgs": imgs, # [N, 3, H, W] + "proj_matrices": proj, # [N,4,4] + "depth": depth, # [1, H, W] + "depth_min": depth_min, # scalar + "depth_max": depth_max, # scalar + "mask": mask} # [1, H, W] + diff --git a/IGEV-MVS/datasets/dtu_yao_eval.py b/IGEV-MVS/datasets/dtu_yao_eval.py new file mode 100644 index 0000000..59a92a5 --- /dev/null +++ b/IGEV-MVS/datasets/dtu_yao_eval.py @@ -0,0 +1,158 @@ +from torch.utils.data import Dataset +import numpy as np +import os +from PIL import Image +from datasets.data_io import * +import cv2 + + +class MVSDataset(Dataset): + def __init__(self, datapath, listfile, nviews=5, img_wh=(1600, 1152)): + super(MVSDataset, self).__init__() + self.levels = 4 + self.datapath = datapath + self.listfile = listfile + self.nviews = nviews + self.img_wh = img_wh + self.metas = self.build_list() + + def build_list(self): + metas = [] + with open(self.listfile) as f: + scans = f.readlines() + scans = [line.rstrip() for line in scans] + + for scan in scans: + pair_file = "{}/pair.txt".format(scan) + # read the pair file + with open(os.path.join(self.datapath, pair_file)) as f: + num_viewpoint = int(f.readline()) + # viewpoints (49) + for view_idx in range(num_viewpoint): + ref_view = int(f.readline().rstrip()) + src_views = [int(x) for x in f.readline().rstrip().split()[1::2]] + metas.append((scan, ref_view, src_views)) + print("dataset", "metas:", len(metas)) + return metas + + def __len__(self): + return len(self.metas) + + def read_cam_file(self, filename): + with open(filename) as f: + lines = f.readlines() + lines = [line.rstrip() for line in lines] + # extrinsics: line [1,5), 4x4 matrix + extrinsics = np.fromstring(' '.join(lines[1:5]), dtype=np.float32, sep=' ').reshape((4, 4)) + # intrinsics: line [7-10), 3x3 matrix + 
intrinsics = np.fromstring(' '.join(lines[7:10]), dtype=np.float32, sep=' ').reshape((3, 3)) + + depth_min = float(lines[11].split()[0]) + depth_max = float(lines[11].split()[-1]) + return intrinsics, extrinsics, depth_min, depth_max + + + def read_mask(self, filename): + img = Image.open(filename) + np_img = np.array(img, dtype=np.float32) + np_img = (np_img > 10).astype(np.float32) + return np_img + + def read_img(self, filename): + img = Image.open(filename) + # scale 0~255 to -1~1 + np_img = 2*np.array(img, dtype=np.float32) / 255. - 1 + np_img = cv2.resize(np_img, self.img_wh, interpolation=cv2.INTER_LINEAR) + + h, w, _ = np_img.shape + np_img_ms = { + "level_3": cv2.resize(np_img, (w//8, h//8), interpolation=cv2.INTER_LINEAR), + "level_2": cv2.resize(np_img, (w//4, h//4), interpolation=cv2.INTER_LINEAR), + "level_1": cv2.resize(np_img, (w//2, h//2), interpolation=cv2.INTER_LINEAR), + "level_0": np_img + } + return np_img_ms + + def __getitem__(self, idx): + scan, ref_view, src_views = self.metas[idx] + # use only the reference view and first nviews-1 source views + view_ids = [ref_view] + src_views[:self.nviews - 1] + img_w = 1600 + img_h = 1200 + imgs_0 = [] + imgs_1 = [] + imgs_2 = [] + imgs_3 = [] + + depth_min = None + depth_max = None + + proj_matrices_0 = [] + proj_matrices_1 = [] + proj_matrices_2 = [] + proj_matrices_3 = [] + + for i, vid in enumerate(view_ids): + img_filename = os.path.join(self.datapath, '{}/images/{:0>8}.jpg'.format(scan, vid)) + proj_mat_filename = os.path.join(self.datapath, '{}/cams_1/{:0>8}_cam.txt'.format(scan, vid)) + + imgs = self.read_img(img_filename) + imgs_0.append(imgs['level_0']) + imgs_1.append(imgs['level_1']) + imgs_2.append(imgs['level_2']) + imgs_3.append(imgs['level_3']) + + intrinsics, extrinsics, depth_min_, depth_max_ = self.read_cam_file(proj_mat_filename) + intrinsics[0] *= self.img_wh[0]/img_w + intrinsics[1] *= self.img_wh[1]/img_h + proj_mat = extrinsics.copy() + intrinsics[:2,:] *= 0.125 + proj_mat[:3, :4] = np.matmul(intrinsics, proj_mat[:3, :4]) + proj_matrices_3.append(proj_mat) + + proj_mat = extrinsics.copy() + intrinsics[:2,:] *= 2 + proj_mat[:3, :4] = np.matmul(intrinsics, proj_mat[:3, :4]) + proj_matrices_2.append(proj_mat) + + proj_mat = extrinsics.copy() + intrinsics[:2,:] *= 2 + proj_mat[:3, :4] = np.matmul(intrinsics, proj_mat[:3, :4]) + proj_matrices_1.append(proj_mat) + + proj_mat = extrinsics.copy() + intrinsics[:2,:] *= 2 + proj_mat[:3, :4] = np.matmul(intrinsics, proj_mat[:3, :4]) + proj_matrices_0.append(proj_mat) + + + if i == 0: # reference view + depth_min = depth_min_ + depth_max = depth_max_ + + imgs_0 = np.stack(imgs_0).transpose([0, 3, 1, 2]) + imgs_1 = np.stack(imgs_1).transpose([0, 3, 1, 2]) + imgs_2 = np.stack(imgs_2).transpose([0, 3, 1, 2]) + imgs_3 = np.stack(imgs_3).transpose([0, 3, 1, 2]) + imgs = {} + imgs['level_0'] = imgs_0 + imgs['level_1'] = imgs_1 + imgs['level_2'] = imgs_2 + imgs['level_3'] = imgs_3 + # proj_matrices: N*4*4 + proj_matrices_0 = np.stack(proj_matrices_0) + proj_matrices_1 = np.stack(proj_matrices_1) + proj_matrices_2 = np.stack(proj_matrices_2) + proj_matrices_3 = np.stack(proj_matrices_3) + proj={} + proj['level_3']=proj_matrices_3 + proj['level_2']=proj_matrices_2 + proj['level_1']=proj_matrices_1 + proj['level_0']=proj_matrices_0 + + + return {"imgs": imgs, # N*3*H0*W0 + "proj_matrices": proj, # N*4*4 + "depth_min": depth_min, # scalar + "depth_max": depth_max, # scalar + "filename": scan + '/{}/' + '{:0>8}'.format(view_ids[0]) + "{}"} diff --git 
a/IGEV-MVS/datasets/eth3d.py b/IGEV-MVS/datasets/eth3d.py new file mode 100644 index 0000000..f8a6cdf --- /dev/null +++ b/IGEV-MVS/datasets/eth3d.py @@ -0,0 +1,158 @@ +from torch.utils.data import Dataset +from datasets.data_io import * +import os +import numpy as np +import cv2 +from PIL import Image + +class MVSDataset(Dataset): + def __init__(self, datapath, split='test', n_views=7, img_wh=(1920,1280)): + self.levels = 4 + self.datapath = datapath + self.img_wh = img_wh + self.split = split + self.build_metas() + self.n_views = n_views + + def build_metas(self): + self.metas = [] + if self.split == "test": + self.scans = ['botanical_garden', 'boulders', 'bridge', 'door', + 'exhibition_hall', 'lecture_room', 'living_room', 'lounge', + 'observatory', 'old_computer', 'statue', 'terrace_2'] + + elif self.split == "train": + self.scans = ['courtyard', 'delivery_area', 'electro', 'facade', + 'kicker', 'meadow', 'office', 'pipes', 'playground', + 'relief', 'relief_2', 'terrace', 'terrains'] + + + for scan in self.scans: + with open(os.path.join(self.datapath, scan, 'pair.txt')) as f: + num_viewpoint = int(f.readline()) + for view_idx in range(num_viewpoint): + ref_view = int(f.readline().rstrip()) + src_views = [int(x) for x in f.readline().rstrip().split()[1::2]] + if len(src_views) != 0: + self.metas += [(scan, -1, ref_view, src_views)] + + + def read_cam_file(self, filename): + with open(filename) as f: + lines = [line.rstrip() for line in f.readlines()] + # extrinsics: line [1,5), 4x4 matrix + extrinsics = np.fromstring(' '.join(lines[1:5]), dtype=np.float32, sep=' ') + extrinsics = extrinsics.reshape((4, 4)) + # intrinsics: line [7-10), 3x3 matrix + intrinsics = np.fromstring(' '.join(lines[7:10]), dtype=np.float32, sep=' ') + intrinsics = intrinsics.reshape((3, 3)) + + depth_min = float(lines[11].split()[0]) + if depth_min < 0: + depth_min = 1 + depth_max = float(lines[11].split()[-1]) + + return intrinsics, extrinsics, depth_min, depth_max + + def read_img(self, filename, h, w): + img = Image.open(filename) + # scale 0~255 to -1~1 + np_img = 2*np.array(img, dtype=np.float32) / 255. 
- 1 + original_h, original_w, _ = np_img.shape + np_img = cv2.resize(np_img, self.img_wh, interpolation=cv2.INTER_LINEAR) + + np_img_ms = { + "level_3": cv2.resize(np_img, (w//8, h//8), interpolation=cv2.INTER_LINEAR), + "level_2": cv2.resize(np_img, (w//4, h//4), interpolation=cv2.INTER_LINEAR), + "level_1": cv2.resize(np_img, (w//2, h//2), interpolation=cv2.INTER_LINEAR), + "level_0": np_img + } + return np_img_ms, original_h, original_w + + def __len__(self): + return len(self.metas) + + def __getitem__(self, idx): + scan, _, ref_view, src_views = self.metas[idx] + # use only the reference view and first nviews-1 source views + view_ids = [ref_view] + src_views[:self.n_views-1] + imgs_0 = [] + imgs_1 = [] + imgs_2 = [] + imgs_3 = [] + + # depth = None + depth_min = None + depth_max = None + + proj_matrices_0 = [] + proj_matrices_1 = [] + proj_matrices_2 = [] + proj_matrices_3 = [] + + for i, vid in enumerate(view_ids): + img_filename = os.path.join(self.datapath, scan, f'images/{vid:08d}.jpg') + proj_mat_filename = os.path.join(self.datapath, scan, f'cams_1/{vid:08d}_cam.txt') + + imgs, original_h, original_w = self.read_img(img_filename,self.img_wh[1], self.img_wh[0]) + imgs_0.append(imgs['level_0']) + imgs_1.append(imgs['level_1']) + imgs_2.append(imgs['level_2']) + imgs_3.append(imgs['level_3']) + + intrinsics, extrinsics, depth_min_, depth_max_ = self.read_cam_file(proj_mat_filename) + intrinsics[0] *= self.img_wh[0]/original_w + intrinsics[1] *= self.img_wh[1]/original_h + + proj_mat = extrinsics.copy() + intrinsics[:2,:] *= 0.125 + proj_mat[:3, :4] = np.matmul(intrinsics, proj_mat[:3, :4]) + proj_matrices_3.append(proj_mat) + + proj_mat = extrinsics.copy() + intrinsics[:2,:] *= 2 + proj_mat[:3, :4] = np.matmul(intrinsics, proj_mat[:3, :4]) + proj_matrices_2.append(proj_mat) + + proj_mat = extrinsics.copy() + intrinsics[:2,:] *= 2 + proj_mat[:3, :4] = np.matmul(intrinsics, proj_mat[:3, :4]) + proj_matrices_1.append(proj_mat) + + proj_mat = extrinsics.copy() + intrinsics[:2,:] *= 2 + proj_mat[:3, :4] = np.matmul(intrinsics, proj_mat[:3, :4]) + proj_matrices_0.append(proj_mat) + + if i == 0: # reference view + depth_min = depth_min_ + depth_max = depth_max_ + + # imgs: N*3*H0*W0, N is number of images + imgs_0 = np.stack(imgs_0).transpose([0, 3, 1, 2]) + imgs_1 = np.stack(imgs_1).transpose([0, 3, 1, 2]) + imgs_2 = np.stack(imgs_2).transpose([0, 3, 1, 2]) + imgs_3 = np.stack(imgs_3).transpose([0, 3, 1, 2]) + imgs = {} + imgs['level_0'] = imgs_0 + imgs['level_1'] = imgs_1 + imgs['level_2'] = imgs_2 + imgs['level_3'] = imgs_3 + # proj_matrices: N*4*4 + proj_matrices_0 = np.stack(proj_matrices_0) + proj_matrices_1 = np.stack(proj_matrices_1) + proj_matrices_2 = np.stack(proj_matrices_2) + proj_matrices_3 = np.stack(proj_matrices_3) + proj={} + proj['level_3']=proj_matrices_3 + proj['level_2']=proj_matrices_2 + proj['level_1']=proj_matrices_1 + proj['level_0']=proj_matrices_0 + + + return {"imgs": imgs, # N*3*H0*W0 + "proj_matrices": proj, # N*4*4 + "depth_min": depth_min, # scalar + "depth_max": depth_max, + "filename": scan + '/{}/' + '{:0>8}'.format(view_ids[0]) + "{}" + } diff --git a/IGEV-MVS/datasets/tanks.py b/IGEV-MVS/datasets/tanks.py new file mode 100644 index 0000000..9046cc4 --- /dev/null +++ b/IGEV-MVS/datasets/tanks.py @@ -0,0 +1,156 @@ +from torch.utils.data import Dataset +from datasets.data_io import * +import os +import numpy as np +import cv2 +from PIL import Image + +class MVSDataset(Dataset): + def __init__(self, datapath, n_views=7, img_wh=(1920, 1024), 
split='intermediate'): + self.levels = 4 + self.datapath = datapath + self.img_wh = img_wh + self.split = split + self.build_metas() + self.n_views = n_views + + def build_metas(self): + self.metas = [] + if self.split == 'intermediate': + self.scans = ['Family', 'Francis', 'Horse', 'Lighthouse', + 'M60', 'Panther', 'Playground', 'Train'] + + elif self.split == 'advanced': + self.scans = ['Auditorium', 'Ballroom', 'Courtroom', + 'Museum', 'Palace', 'Temple'] + + for scan in self.scans: + with open(os.path.join(self.datapath, self.split, scan, 'pair.txt')) as f: + num_viewpoint = int(f.readline()) + for view_idx in range(num_viewpoint): + ref_view = int(f.readline().rstrip()) + src_views = [int(x) for x in f.readline().rstrip().split()[1::2]] + if len(src_views) != 0: + self.metas += [(scan, -1, ref_view, src_views)] + + def read_cam_file(self, filename): + with open(filename) as f: + lines = [line.rstrip() for line in f.readlines()] + # extrinsics: line [1,5), 4x4 matrix + extrinsics = np.fromstring(' '.join(lines[1:5]), dtype=np.float32, sep=' ') + extrinsics = extrinsics.reshape((4, 4)) + # intrinsics: line [7-10), 3x3 matrix + intrinsics = np.fromstring(' '.join(lines[7:10]), dtype=np.float32, sep=' ') + intrinsics = intrinsics.reshape((3, 3)) + + depth_min = float(lines[11].split()[0]) + depth_max = float(lines[11].split()[-1]) + + return intrinsics, extrinsics, depth_min, depth_max + + def read_img(self, filename, h, w): + img = Image.open(filename) + # scale 0~255 to -1~1 + np_img = 2*np.array(img, dtype=np.float32) / 255. - 1 + original_h, original_w, _ = np_img.shape + np_img = cv2.resize(np_img, self.img_wh, interpolation=cv2.INTER_LINEAR) + + np_img_ms = { + "level_3": cv2.resize(np_img, (w//8, h//8), interpolation=cv2.INTER_LINEAR), + "level_2": cv2.resize(np_img, (w//4, h//4), interpolation=cv2.INTER_LINEAR), + "level_1": cv2.resize(np_img, (w//2, h//2), interpolation=cv2.INTER_LINEAR), + "level_0": np_img + } + return np_img_ms, original_h, original_w + + def __len__(self): + return len(self.metas) + + def __getitem__(self, idx): + scan, _, ref_view, src_views = self.metas[idx] + # use only the reference view and first nviews-1 source views + view_ids = [ref_view] + src_views[:self.n_views-1] + + imgs_0 = [] + imgs_1 = [] + imgs_2 = [] + imgs_3 = [] + + # depth = None + depth_min = None + depth_max = None + + proj_matrices_0 = [] + proj_matrices_1 = [] + proj_matrices_2 = [] + proj_matrices_3 = [] + + for i, vid in enumerate(view_ids): + img_filename = os.path.join(self.datapath, self.split, scan, f'images/{vid:08d}.jpg') + proj_mat_filename = os.path.join(self.datapath, self.split, scan, f'cams_1/{vid:08d}_cam.txt') + + imgs, original_h, original_w = self.read_img(img_filename,self.img_wh[1], self.img_wh[0]) + imgs_0.append(imgs['level_0']) + imgs_1.append(imgs['level_1']) + imgs_2.append(imgs['level_2']) + imgs_3.append(imgs['level_3']) + + intrinsics, extrinsics, depth_min_, depth_max_ = self.read_cam_file(proj_mat_filename) + intrinsics[0] *= self.img_wh[0]/original_w + intrinsics[1] *= self.img_wh[1]/original_h + + proj_mat = extrinsics.copy() + intrinsics[:2,:] *= 0.125 + proj_mat[:3, :4] = np.matmul(intrinsics, proj_mat[:3, :4]) + proj_matrices_3.append(proj_mat) + + proj_mat = extrinsics.copy() + intrinsics[:2,:] *= 2 + proj_mat[:3, :4] = np.matmul(intrinsics, proj_mat[:3, :4]) + proj_matrices_2.append(proj_mat) + + proj_mat = extrinsics.copy() + intrinsics[:2,:] *= 2 + proj_mat[:3, :4] = np.matmul(intrinsics, proj_mat[:3, :4]) + proj_matrices_1.append(proj_mat) + + 
proj_mat = extrinsics.copy() + intrinsics[:2,:] *= 2 + proj_mat[:3, :4] = np.matmul(intrinsics, proj_mat[:3, :4]) + proj_matrices_0.append(proj_mat) + + + if i == 0: # reference view + depth_min = depth_min_ + depth_max = depth_max_ + + # imgs: N*3*H0*W0, N is number of images + imgs_0 = np.stack(imgs_0).transpose([0, 3, 1, 2]) + imgs_1 = np.stack(imgs_1).transpose([0, 3, 1, 2]) + imgs_2 = np.stack(imgs_2).transpose([0, 3, 1, 2]) + imgs_3 = np.stack(imgs_3).transpose([0, 3, 1, 2]) + imgs = {} + imgs['level_0'] = imgs_0 + imgs['level_1'] = imgs_1 + imgs['level_2'] = imgs_2 + imgs['level_3'] = imgs_3 + # proj_matrices: N*4*4 + proj_matrices_0 = np.stack(proj_matrices_0) + proj_matrices_1 = np.stack(proj_matrices_1) + proj_matrices_2 = np.stack(proj_matrices_2) + proj_matrices_3 = np.stack(proj_matrices_3) + proj={} + proj['level_3']=proj_matrices_3 + proj['level_2']=proj_matrices_2 + proj['level_1']=proj_matrices_1 + proj['level_0']=proj_matrices_0 + + + + + return {"imgs": imgs, # N*3*H0*W0 + "proj_matrices": proj, # N*4*4 + "depth_min": depth_min, # scalar + "depth_max": depth_max, + "filename": scan + '/{}/' + '{:0>8}'.format(view_ids[0]) + "{}" + } diff --git a/IGEV-MVS/evaluate_mvs.py b/IGEV-MVS/evaluate_mvs.py new file mode 100644 index 0000000..8b15d9c --- /dev/null +++ b/IGEV-MVS/evaluate_mvs.py @@ -0,0 +1,450 @@ +import argparse +import os +os.environ['CUDA_VISIBLE_DEVICES'] = '0' +import torch +import torch.nn as nn +import torch.nn.parallel +import torch.backends.cudnn as cudnn +import torch.optim as optim +from torch.utils.data import DataLoader +import torch.nn.functional as F +import numpy as np +import time +from datasets import find_dataset_def +from core.igev_mvs import IGEVMVS +from utils import * +import sys +import cv2 +from datasets.data_io import read_pfm, save_pfm +from core.submodule import depth_unnormalization +from plyfile import PlyData, PlyElement +from tqdm import tqdm +from PIL import Image + +cudnn.benchmark = True + +parser = argparse.ArgumentParser(description='Predict depth, filter, and fuse') +parser.add_argument('--model', default='IterMVS', help='select model') +parser.add_argument('--dataset', default='dtu_yao_eval', help='select dataset') +parser.add_argument('--testpath', default='/data/dtu_data/dtu_test/', help='testing data path') +parser.add_argument('--testlist', default='./lists/dtu/test.txt', help='testing scan list') +parser.add_argument('--maxdisp', default=256) +parser.add_argument('--split', default='intermediate', help='select data') +parser.add_argument('--batch_size', type=int, default=2, help='testing batch size') +parser.add_argument('--n_views', type=int, default=5, help='num of view') +parser.add_argument('--img_wh', nargs='+', type=int, default=[640, 480], + help='height and width of the image') +parser.add_argument('--loadckpt', default='./pretrained_models/dtu.ckpt', help='load a specific checkpoint') +parser.add_argument('--outdir', default='./output/', help='output dir') +parser.add_argument('--display', action='store_true', help='display depth images and masks') +parser.add_argument('--iteration', type=int, default=32, help='num of iteration of GRU') +parser.add_argument('--geo_pixel_thres', type=float, default=1, help='pixel threshold for geometric consistency filtering') +parser.add_argument('--geo_depth_thres', type=float, default=0.01, help='depth threshold for geometric consistency filtering') +parser.add_argument('--photo_thres', type=float, default=0.3, help='threshold for photometric consistency filtering') + +# parse 
arguments and check +args = parser.parse_args() +print("argv:", sys.argv[1:]) +print_args(args) + +if args.dataset=="dtu_yao_eval": + img_wh=(1600, 1152) +elif args.dataset=="tanks": + img_wh=(1920, 1024) +elif args.dataset=="eth3d": + img_wh = (1920,1280) +else: + img_wh = (args.img_wh[0], args.img_wh[1]) # custom dataset + +# read intrinsics and extrinsics +def read_camera_parameters(filename): + with open(filename) as f: + lines = f.readlines() + lines = [line.rstrip() for line in lines] + # extrinsics: line [1,5), 4x4 matrix + extrinsics = np.fromstring(' '.join(lines[1:5]), dtype=np.float32, sep=' ').reshape((4, 4)) + # intrinsics: line [7-10), 3x3 matrix + intrinsics = np.fromstring(' '.join(lines[7:10]), dtype=np.float32, sep=' ').reshape((3, 3)) + + return intrinsics, extrinsics + + +# read an image +def read_img(filename, img_wh): + img = Image.open(filename) + # scale 0~255 to 0~1 + np_img = np.array(img, dtype=np.float32) / 255. + original_h, original_w, _ = np_img.shape + np_img = cv2.resize(np_img, img_wh, interpolation=cv2.INTER_LINEAR) + return np_img, original_h, original_w + + +# save a binary mask +def save_mask(filename, mask): + assert mask.dtype == np.bool_ + mask = mask.astype(np.uint8) * 255 + Image.fromarray(mask).save(filename) + +def save_depth_img(filename, depth): + # assert mask.dtype == np.bool + depth = depth.astype(np.float32) * 255 + Image.fromarray(depth).save(filename) + + +def read_pair_file(filename): + data = [] + with open(filename) as f: + num_viewpoint = int(f.readline()) + # 49 viewpoints + for view_idx in range(num_viewpoint): + ref_view = int(f.readline().rstrip()) + src_views = [int(x) for x in f.readline().rstrip().split()[1::2]] + if len(src_views) != 0: + data.append((ref_view, src_views)) + return data + + +# run MVS model to save depth maps +def save_depth(): + # dataset, dataloader + MVSDataset = find_dataset_def(args.dataset) + if args.dataset=="dtu_yao_eval": + test_dataset = MVSDataset(args.testpath, args.testlist, args.n_views, img_wh) + elif args.dataset=="tanks": + test_dataset = MVSDataset(args.testpath, args.n_views, img_wh, args.split) + elif args.dataset=="eth3d": + test_dataset = MVSDataset(args.testpath, args.split, args.n_views, img_wh) + else: + test_dataset = MVSDataset(args.testpath, args.n_views, img_wh) + TestImgLoader = DataLoader(test_dataset, args.batch_size, shuffle=False, num_workers=4, drop_last=False) + + # model + model = IGEVMVS(args) + model = nn.DataParallel(model) + model.cuda() + + # load checkpoint file specified by args.loadckpt + print("loading model {}".format(args.loadckpt)) + state_dict = torch.load(args.loadckpt) + model.load_state_dict(state_dict['model']) + model.eval() + + with torch.no_grad(): + tbar = tqdm(TestImgLoader) + for batch_idx, sample in enumerate(tbar): + start_time = time.time() + sample_cuda = tocuda(sample) + disp_prediction = model(sample_cuda["imgs"], sample_cuda["proj_matrices"], + sample_cuda["depth_min"], sample_cuda["depth_max"], test_mode=True) + + b = sample_cuda["depth_min"].shape[0] + + inverse_depth_min = (1.0 / sample_cuda["depth_min"]).view(b, 1, 1, 1) + inverse_depth_max = (1.0 / sample_cuda["depth_max"]).view(b, 1, 1, 1) + + depth_prediction = depth_unnormalization(disp_prediction, inverse_depth_min, inverse_depth_max) + depth_prediction = tensor2numpy(depth_prediction.float()) + del sample_cuda, disp_prediction + tbar.set_description('Iter {}/{}, time = {:.3f}'.format(batch_idx, len(TestImgLoader), time.time() - start_time)) + filenames = sample["filename"] + + # save 
depth maps and confidence maps + for filename, depth_est in zip(filenames, depth_prediction): + depth_filename = os.path.join(args.outdir, filename.format('depth_est', '.pfm')) + os.makedirs(depth_filename.rsplit('/', 1)[0], exist_ok=True) + # save depth maps + depth_est = np.squeeze(depth_est, 0) + save_pfm(depth_filename, depth_est) + +# project the reference point cloud into the source view, then project back +def reproject_with_depth(depth_ref, intrinsics_ref, extrinsics_ref, depth_src, intrinsics_src, extrinsics_src): + width, height = depth_ref.shape[1], depth_ref.shape[0] + ## step1. project reference pixels to the source view + # reference view x, y + x_ref, y_ref = np.meshgrid(np.arange(0, width), np.arange(0, height)) + x_ref, y_ref = x_ref.reshape([-1]), y_ref.reshape([-1]) + # reference 3D space + xyz_ref = np.matmul(np.linalg.inv(intrinsics_ref), + np.vstack((x_ref, y_ref, np.ones_like(x_ref))) * depth_ref.reshape([-1])) + # source 3D space + xyz_src = np.matmul(np.matmul(extrinsics_src, np.linalg.inv(extrinsics_ref)), + np.vstack((xyz_ref, np.ones_like(x_ref))))[:3] + # source view x, y + K_xyz_src = np.matmul(intrinsics_src, xyz_src) + xy_src = K_xyz_src[:2] / K_xyz_src[2:3] + + ## step2. reproject the source view points with source view depth estimation + # find the depth estimation of the source view + x_src = xy_src[0].reshape([height, width]).astype(np.float32) + y_src = xy_src[1].reshape([height, width]).astype(np.float32) + sampled_depth_src = cv2.remap(depth_src, x_src, y_src, interpolation=cv2.INTER_LINEAR) + # mask = sampled_depth_src > 0 + + # source 3D space + # NOTE that we should use sampled source-view depth_here to project back + xyz_src = np.matmul(np.linalg.inv(intrinsics_src), + np.vstack((xy_src, np.ones_like(x_ref))) * sampled_depth_src.reshape([-1])) + # reference 3D space + xyz_reprojected = np.matmul(np.matmul(extrinsics_ref, np.linalg.inv(extrinsics_src)), + np.vstack((xyz_src, np.ones_like(x_ref))))[:3] + # source view x, y, depth + depth_reprojected = xyz_reprojected[2].reshape([height, width]).astype(np.float32) + K_xyz_reprojected = np.matmul(intrinsics_ref, xyz_reprojected) + xy_reprojected = K_xyz_reprojected[:2] / (K_xyz_reprojected[2:3]+1e-6) + x_reprojected = xy_reprojected[0].reshape([height, width]).astype(np.float32) + y_reprojected = xy_reprojected[1].reshape([height, width]).astype(np.float32) + + return depth_reprojected, x_reprojected, y_reprojected, x_src, y_src + + +def check_geometric_consistency(depth_ref, intrinsics_ref, extrinsics_ref, depth_src, intrinsics_src, extrinsics_src, thre1, thre2): + width, height = depth_ref.shape[1], depth_ref.shape[0] + x_ref, y_ref = np.meshgrid(np.arange(0, width), np.arange(0, height)) + depth_reprojected, x2d_reprojected, y2d_reprojected, x2d_src, y2d_src = reproject_with_depth(depth_ref, + intrinsics_ref, + extrinsics_ref, + depth_src, + intrinsics_src, + extrinsics_src) + # check |p_reproj-p_1| < 1 + dist = np.sqrt((x2d_reprojected - x_ref) ** 2 + (y2d_reprojected - y_ref) ** 2) + + # check |d_reproj-d_1| / d_1 < 0.01 + depth_diff = np.abs(depth_reprojected - depth_ref) + relative_depth_diff = depth_diff / depth_ref + masks=[] + for i in range(2,11): + mask = np.logical_and(dist < i/thre1, relative_depth_diff < i/thre2) + masks.append(mask) + depth_reprojected[~mask] = 0 + + return masks, mask, depth_reprojected, x2d_src, y2d_src + + +def filter_depth(scan_folder, out_folder, plyfilename, geo_pixel_thres, geo_depth_thres, photo_thres, img_wh, geo_mask_thres=3): + # the pair file + pair_file = 
os.path.join(scan_folder, "pair.txt") + # for the final point cloud + vertexs = [] + vertex_colors = [] + + pair_data = read_pair_file(pair_file) + nviews = len(pair_data) + + thre_left = -2 + thre_right = 2 + total_iter = 10 + for iter in range(total_iter): + thre = (thre_left + thre_right) / 2 + print(f"{iter} {10 ** thre}") + depth_est_averaged = [] + geo_mask_all = [] + # for each reference view and the corresponding source views + for ref_view, src_views in pair_data: + # load the camera parameters + ref_intrinsics, ref_extrinsics = read_camera_parameters( + os.path.join(scan_folder, 'cams_1/{:0>8}_cam.txt'.format(ref_view))) + ref_img, original_h, original_w = read_img(os.path.join(scan_folder, 'images/{:0>8}.jpg'.format(ref_view)), img_wh) + ref_intrinsics[0] *= img_wh[0]/original_w + ref_intrinsics[1] *= img_wh[1]/original_h + # load the estimated depth of the reference view + ref_depth_est = read_pfm(os.path.join(out_folder, 'depth_est/{:0>8}.pfm'.format(ref_view)))[0] + ref_depth_est = np.squeeze(ref_depth_est, 2) + + all_srcview_depth_ests = [] + # compute the geometric mask + geo_mask_sum = 0 + geo_mask_sums=[] + n = 1 + len(src_views) + ct = 0 + for src_view in src_views: + ct = ct + 1 + # camera parameters of the source view + src_intrinsics, src_extrinsics = read_camera_parameters( + os.path.join(scan_folder, 'cams_1/{:0>8}_cam.txt'.format(src_view))) + _, original_h, original_w = read_img(os.path.join(scan_folder, 'images/{:0>8}.jpg'.format(src_view)), img_wh) + src_intrinsics[0] *= img_wh[0]/original_w + src_intrinsics[1] *= img_wh[1]/original_h + + # the estimated depth of the source view + src_depth_est = read_pfm(os.path.join(out_folder, 'depth_est/{:0>8}.pfm'.format(src_view)))[0] + + + masks, geo_mask, depth_reprojected, _, _ = check_geometric_consistency(ref_depth_est, ref_intrinsics, ref_extrinsics, + src_depth_est, + src_intrinsics, src_extrinsics, 10 ** thre * 4, 10 ** thre * 1300) + if (ct==1): + for i in range(2,n): + geo_mask_sums.append(masks[i-2].astype(np.int32)) + else: + for i in range(2,n): + geo_mask_sums[i-2]+=masks[i-2].astype(np.int32) + + geo_mask_sum+=geo_mask.astype(np.int32) + all_srcview_depth_ests.append(depth_reprojected) + + geo_mask=geo_mask_sum>=n + for i in range (2,n): + geo_mask=np.logical_or(geo_mask,geo_mask_sums[i-2]>=i) + + depth_est_averaged.append((sum(all_srcview_depth_ests) + ref_depth_est) / (geo_mask_sum + 1)) + geo_mask_all.append(np.mean(geo_mask)) + final_mask = geo_mask + + if iter == total_iter - 1: + os.makedirs(os.path.join(out_folder, "mask"), exist_ok=True) + save_mask(os.path.join(out_folder, "mask/{:0>8}_geo.png".format(ref_view)), geo_mask) + save_mask(os.path.join(out_folder, "mask/{:0>8}_final.png".format(ref_view)), final_mask) + + print("processing {}, ref-view{:0>2}, geo_mask:{:3f} final_mask: {:3f}".format(scan_folder, ref_view, + geo_mask.mean(), final_mask.mean())) + + if args.display: + cv2.imshow('ref_img', ref_img[:, :, ::-1]) + cv2.imshow('ref_depth', ref_depth_est / np.max(ref_depth_est)) + cv2.imshow('ref_depth * geo_mask', ref_depth_est * geo_mask.astype(np.float32) / np.max(ref_depth_est)) + cv2.imshow('ref_depth * mask', ref_depth_est * final_mask.astype(np.float32) / np.max(ref_depth_est)) + cv2.waitKey(0) + + height, width = depth_est_averaged[-1].shape[:2] + x, y = np.meshgrid(np.arange(0, width), np.arange(0, height)) + + valid_points = final_mask + # print("valid_points", valid_points.mean()) + x, y, depth = x[valid_points], y[valid_points], depth_est_averaged[-1][valid_points] + + color = 
ref_img[valid_points] + xyz_ref = np.matmul(np.linalg.inv(ref_intrinsics), + np.vstack((x, y, np.ones_like(x))) * depth) + xyz_world = np.matmul(np.linalg.inv(ref_extrinsics), + np.vstack((xyz_ref, np.ones_like(x))))[:3] + vertexs.append(xyz_world.transpose((1, 0))) + vertex_colors.append((color * 255).astype(np.uint8)) + if np.mean(geo_mask_all) >= 0.25: + thre_left = thre + else: + thre_right = thre + vertexs = np.concatenate(vertexs, axis=0) + vertex_colors = np.concatenate(vertex_colors, axis=0) + vertexs = np.array([tuple(v) for v in vertexs], dtype=[('x', 'f4'), ('y', 'f4'), ('z', 'f4')]) + vertex_colors = np.array([tuple(v) for v in vertex_colors], dtype=[('red', 'u1'), ('green', 'u1'), ('blue', 'u1')]) + + vertex_all = np.empty(len(vertexs), vertexs.dtype.descr + vertex_colors.dtype.descr) + for prop in vertexs.dtype.names: + vertex_all[prop] = vertexs[prop] + for prop in vertex_colors.dtype.names: + vertex_all[prop] = vertex_colors[prop] + + el = PlyElement.describe(vertex_all, 'vertex') + PlyData([el]).write(plyfilename) + print("saving the final model to", plyfilename) + + +if __name__ == '__main__': + save_depth() + if args.dataset=="dtu_yao_eval": + with open(args.testlist) as f: + scans = f.readlines() + scans = [line.rstrip() for line in scans] + + for scan in scans: + scan_id = int(scan[4:]) + scan_folder = os.path.join(args.testpath, scan) + out_folder = os.path.join(args.outdir, scan) + filter_depth(scan_folder, out_folder, os.path.join(args.outdir, 'igev_mvs{:0>3}_l3.ply'.format(scan_id)), + args.geo_pixel_thres, args.geo_depth_thres, args.photo_thres, img_wh, 4) + elif args.dataset=="tanks": + # intermediate dataset + if args.split == "intermediate": + scans = ['Family', 'Francis', 'Horse', 'Lighthouse', + 'M60', 'Panther', 'Playground', 'Train'] + geo_mask_thres = {'Family': 5, + 'Francis': 6, + 'Horse': 5, + 'Lighthouse': 6, + 'M60': 5, + 'Panther': 5, + 'Playground': 5, + 'Train': 5} + + for scan in scans: + scan_folder = os.path.join(args.testpath, args.split, scan) + out_folder = os.path.join(args.outdir, scan) + + filter_depth(scan_folder, out_folder, os.path.join(args.outdir, scan + '.ply'), + args.geo_pixel_thres, args.geo_depth_thres, args.photo_thres, img_wh, geo_mask_thres[scan]) + + # advanced dataset + elif args.split == "advanced": + scans = ['Auditorium', 'Ballroom', 'Courtroom', + 'Museum', 'Palace', 'Temple'] + geo_mask_thres = {'Auditorium': 3, + 'Ballroom': 4, + 'Courtroom': 4, + 'Museum': 4, + 'Palace': 5, + 'Temple': 4} + + for scan in scans: + scan_folder = os.path.join(args.testpath, args.split, scan) + out_folder = os.path.join(args.outdir, scan) + filter_depth(scan_folder, out_folder, os.path.join(args.outdir, scan + '.ply'), + args.geo_pixel_thres, args.geo_depth_thres, args.photo_thres, img_wh, geo_mask_thres[scan]) + + elif args.dataset=="eth3d": + if args.split == "test": + scans = ['botanical_garden', 'boulders', 'bridge', 'door', + 'exhibition_hall', 'lecture_room', 'living_room', 'lounge', + 'observatory', 'old_computer', 'statue', 'terrace_2'] + + geo_mask_thres = {'botanical_garden':1, # 30 images, outdoor + 'boulders':1, # 26 images, outdoor + 'bridge':2, # 110 images, outdoor + 'door':2, # 6 images, indoor + 'exhibition_hall':2, # 68 images, indoor + 'lecture_room':2, # 23 images, indoor + 'living_room':2, # 65 images, indoor + 'lounge':1,# 10 images, indoor + 'observatory':2, # 27 images, outdoor + 'old_computer':2, # 54 images, indoor + 'statue':2, # 10 images, indoor + 'terrace_2':2 # 13 images, outdoor + } + for scan in scans: + 
start_time = time.time() + scan_folder = os.path.join(args.testpath, scan) + out_folder = os.path.join(args.outdir, scan) + filter_depth(scan_folder, out_folder, os.path.join(args.outdir, scan + '.ply'), + args.geo_pixel_thres, args.geo_depth_thres, args.photo_thres, img_wh, geo_mask_thres[scan]) + print('scan: '+scan+' time = {:3f}'.format(time.time() - start_time)) + + elif args.split == "train": + scans = ['courtyard', 'delivery_area', 'electro', 'facade', + 'kicker', 'meadow', 'office', 'pipes', 'playground', + 'relief', 'relief_2', 'terrace', 'terrains'] + + geo_mask_thres = {'courtyard':1, # 38 images, outdoor + 'delivery_area':2, # 44 images, indoor + 'electro':1, # 45 images, outdoor + 'facade':2, # 76 images, outdoor + 'kicker':1, # 31 images, indoor + 'meadow':1, # 15 images, outdoor + 'office':1, # 26 images, indoor + 'pipes':1,# 14 images, indoor + 'playground':1, # 38 images, outdoor + 'relief':1, # 31 images, indoor + 'relief_2':1, # 31 images, indoor + 'terrace':1, # 23 images, outdoor + 'terrains':2 # 42 images, indoor + } + + for scan in scans: + start_time = time.time() + scan_folder = os.path.join(args.testpath, scan) + out_folder = os.path.join(args.outdir, scan) + filter_depth(scan_folder, out_folder, os.path.join(args.outdir, scan + '.ply'), + args.geo_pixel_thres, args.geo_depth_thres, args.photo_thres, img_wh, geo_mask_thres[scan]) + print('scan: '+scan+' time = {:3f}'.format(time.time() - start_time)) + else: + filter_depth(args.testpath, args.outdir, os.path.join(args.outdir, 'custom.ply'), + args.geo_pixel_thres, args.geo_depth_thres, args.photo_thres, img_wh, geo_mask_thres=3) \ No newline at end of file diff --git a/IGEV-MVS/evaluations/dtu/BaseEval2Obj_web.m b/IGEV-MVS/evaluations/dtu/BaseEval2Obj_web.m new file mode 100644 index 0000000..80a6eb0 --- /dev/null +++ b/IGEV-MVS/evaluations/dtu/BaseEval2Obj_web.m @@ -0,0 +1,44 @@ +function BaseEval2Obj_web(BaseEval,method_string,outputPath) + +if(nargin<3) + outputPath='./'; +end + +% tresshold for coloring alpha channel in the range of 0-10 mm +dist_tresshold=10; + +cSet=BaseEval.cSet; + +Qdata=BaseEval.Qdata; +alpha=min(BaseEval.Ddata,dist_tresshold)/dist_tresshold; + +fid=fopen([outputPath method_string '2Stl_' num2str(cSet) ' .obj'],'w+'); + +for cP=1:size(Qdata,2) + if(BaseEval.DataInMask(cP)) + C=[1 0 0]*alpha(cP)+[1 1 1]*(1-alpha(cP)); %coloring from red to white in the range of 0-10 mm (0 to dist_tresshold) + else + C=[0 1 0]*alpha(cP)+[0 0 1]*(1-alpha(cP)); %green to blue for points outside the mask (which are not included in the analysis) + end + fprintf(fid,'v %f %f %f %f %f %f\n',[Qdata(1,cP) Qdata(2,cP) Qdata(3,cP) C(1) C(2) C(3)]); +end +fclose(fid); + +disp('Data2Stl saved as obj') + +Qstl=BaseEval.Qstl; +fid=fopen([outputPath 'Stl2' method_string '_' num2str(cSet) '.obj'],'w+'); + +alpha=min(BaseEval.Dstl,dist_tresshold)/dist_tresshold; + +for cP=1:size(Qstl,2) + if(BaseEval.StlAbovePlane(cP)) + C=[1 0 0]*alpha(cP)+[1 1 1]*(1-alpha(cP)); %coloring from red to white in the range of 0-10 mm (0 to dist_tresshold) + else + C=[0 1 0]*alpha(cP)+[0 0 1]*(1-alpha(cP)); %green to blue for points below plane (which are not included in the analysis) + end + fprintf(fid,'v %f %f %f %f %f %f\n',[Qstl(1,cP) Qstl(2,cP) Qstl(3,cP) C(1) C(2) C(3)]); +end +fclose(fid); + +disp('Stl2Data saved as obj') \ No newline at end of file diff --git a/IGEV-MVS/evaluations/dtu/BaseEvalMain_web.m b/IGEV-MVS/evaluations/dtu/BaseEvalMain_web.m new file mode 100644 index 0000000..44192dd --- /dev/null +++ 
b/IGEV-MVS/evaluations/dtu/BaseEvalMain_web.m @@ -0,0 +1,104 @@ +clear all +close all +format compact +clc + +% script to calculate distances have been measured for all included scans (UsedSets) + +dataPath='D:\xgw\IterMVS_data\MVS Data\'; +plyPath='/data/xgw/IGEV_MVS/conf_03/'; +resultsPath='/data/xgw/IGEV_MVS/outputs_conf_03/'; + + +method_string='itermvs'; +light_string='l3'; % l3 is the setting with all lights on, l7 is randomly sampled between the 7 settings (index 0-6) +representation_string='Points'; %mvs representation 'Points' or 'Surfaces' + +switch representation_string + case 'Points' + eval_string='_Eval_'; %results naming + settings_string=''; +end + +% get sets used in evaluation +UsedSets=[1 4 9 10 11 12 13 15 23 24 29 32 33 34 48 49 62 75 77 110 114 118]; + +result = zeros(length(UsedSets),4); + +dst=0.2; %Min dist between points when reducing + +for cIdx=1:length(UsedSets) + %Data set number + cSet = UsedSets(cIdx) + %input data name + DataInName=[plyPath sprintf('%s%03d_%s%s.ply',lower(method_string),cSet,light_string,settings_string)] + + %results name + %concatenate strings into one string + EvalName=[resultsPath method_string eval_string num2str(cSet) '.mat'] + + disp(EvalName) + + %check if file is already computed + if(~exist(EvalName,'file')) + disp(DataInName); + + time=clock;time(4:5), drawnow + + tic + Mesh = plyread(DataInName); + Qdata=[Mesh.vertex.x Mesh.vertex.y Mesh.vertex.z]'; + toc + + BaseEval=PointCompareMain(cSet,Qdata,dst,dataPath); + + disp('Saving results'), drawnow + toc + save(EvalName,'BaseEval'); + toc + + % write obj-file of evaluation +% BaseEval2Obj_web(BaseEval,method_string, resultsPath) +% toc + time=clock;time(4:5), drawnow + + BaseEval.MaxDist=20; %outlier threshold of 20 mm + + BaseEval.FilteredDstl=BaseEval.Dstl(BaseEval.StlAbovePlane); %use only points that are above the plane + BaseEval.FilteredDstl=BaseEval.FilteredDstl(BaseEval.FilteredDstl=Low(1) & Qfrom(2,:)>=Low(2) & Qfrom(3,:)>=Low(3) &... + Qfrom(1,:)=Low(1) & Qto(2,:)>=Low(2) & Qto(3,:)>=Low(3) &... + Qto(1,:)3)] +end + diff --git a/IGEV-MVS/evaluations/dtu/PointCompareMain.m b/IGEV-MVS/evaluations/dtu/PointCompareMain.m new file mode 100644 index 0000000..e84df08 --- /dev/null +++ b/IGEV-MVS/evaluations/dtu/PointCompareMain.m @@ -0,0 +1,58 @@ +function BaseEval=PointCompareMain(cSet,Qdata,dst,dataPath) +% evaluation function the calculates the distantes from the reference data (stl) to the evalution points (Qdata) and the +% distances from the evaluation points to the reference + +tic +% reduce points 0.2 mm neighbourhood density +Qdata=reducePts_haa(Qdata,dst); +toc + +StlInName=[dataPath '/Points/stl/stl' sprintf('%03d',cSet) '_total.ply']; + +StlMesh = plyread(StlInName); %STL points already reduced 0.2 mm neighbourhood density +Qstl=[StlMesh.vertex.x StlMesh.vertex.y StlMesh.vertex.z]'; + +%Load Mask (ObsMask) and Bounding box (BB) and Resolution (Res) +Margin=10; +MaskName=[dataPath '/ObsMask/ObsMask' num2str(cSet) '_' num2str(Margin) '.mat']; +load(MaskName) + +MaxDist=60; +disp('Computing Data 2 Stl distances') +Ddata = MaxDistCP(Qstl,Qdata,BB,MaxDist); +toc + +disp('Computing Stl 2 Data distances') +Dstl=MaxDistCP(Qdata,Qstl,BB,MaxDist); +disp('Distances computed') +toc + +%use mask +%From Get mask - inverted & modified. 
+One=ones(1,size(Qdata,2)); +Qv=(Qdata-BB(1,:)'*One)/Res+1; +Qv=round(Qv); + +Midx1=find(Qv(1,:)>0 & Qv(1,:)<=size(ObsMask,1) & Qv(2,:)>0 & Qv(2,:)<=size(ObsMask,2) & Qv(3,:)>0 & Qv(3,:)<=size(ObsMask,3)); +MidxA=sub2ind(size(ObsMask),Qv(1,Midx1),Qv(2,Midx1),Qv(3,Midx1)); +Midx2=find(ObsMask(MidxA)); + +BaseEval.DataInMask(1:size(Qv,2))=false; +BaseEval.DataInMask(Midx1(Midx2))=true; %If Data is within the mask + +BaseEval.cSet=cSet; +BaseEval.Margin=Margin; %Margin of masks +BaseEval.dst=dst; %Min dist between points when reducing +BaseEval.Qdata=Qdata; %Input data points +BaseEval.Ddata=Ddata; %distance from data to stl +BaseEval.Qstl=Qstl; %Input stl points +BaseEval.Dstl=Dstl; %Distance from the stl to data + +load([dataPath '/ObsMask/Plane' num2str(cSet)],'P') +BaseEval.GroundPlane=P; % Plane used to destinguise which Stl points are 'used' +BaseEval.StlAbovePlane=(P'*[Qstl;ones(1,size(Qstl,2))])>0; %Is stl above 'ground plane' +BaseEval.Time=clock; %Time when computation is finished + + + + diff --git a/IGEV-MVS/evaluations/dtu/plyread.m b/IGEV-MVS/evaluations/dtu/plyread.m new file mode 100644 index 0000000..5141a0c --- /dev/null +++ b/IGEV-MVS/evaluations/dtu/plyread.m @@ -0,0 +1,454 @@ +%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%% +function [Elements,varargout] = plyread(Path,Str) +%PLYREAD Read a PLY 3D data file. +% [DATA,COMMENTS] = PLYREAD(FILENAME) reads a version 1.0 PLY file +% FILENAME and returns a structure DATA. The fields in this structure +% are defined by the PLY header; each element type is a field and each +% element property is a subfield. If the file contains any comments, +% they are returned in a cell string array COMMENTS. +% +% [TRI,PTS] = PLYREAD(FILENAME,'tri') or +% [TRI,PTS,DATA,COMMENTS] = PLYREAD(FILENAME,'tri') converts vertex +% and face data into triangular connectivity and vertex arrays. The +% mesh can then be displayed using the TRISURF command. +% +% Note: This function is slow for large mesh files (+50K faces), +% especially when reading data with list type properties. 
+% +% Example: +% [Tri,Pts] = PLYREAD('cow.ply','tri'); +% trisurf(Tri,Pts(:,1),Pts(:,2),Pts(:,3)); +% colormap(gray); axis equal; +% +% See also: PLYWRITE + +% Pascal Getreuer 2004 + +[fid,Msg] = fopen(Path,'rt'); % open file in read text mode + +if fid == -1, error(Msg); end + +Buf = fscanf(fid,'%s',1); +if ~strcmp(Buf,'ply') + fclose(fid); + error('Not a PLY file.'); +end + + +%%% read header %%% + +Position = ftell(fid); +Format = ''; +NumComments = 0; +Comments = {}; % for storing any file comments +NumElements = 0; +NumProperties = 0; +Elements = []; % structure for holding the element data +ElementCount = []; % number of each type of element in file +PropertyTypes = []; % corresponding structure recording property types +ElementNames = {}; % list of element names in the order they are stored in the file +PropertyNames = []; % structure of lists of property names + +while 1 + Buf = fgetl(fid); % read one line from file + BufRem = Buf; + Token = {}; + Count = 0; + + while ~isempty(BufRem) % split line into tokens + [tmp,BufRem] = strtok(BufRem); + + if ~isempty(tmp) + Count = Count + 1; % count tokens + Token{Count} = tmp; + end + end + + if Count % parse line + switch lower(Token{1}) + case 'format' % read data format + if Count >= 2 + Format = lower(Token{2}); + + if Count == 3 & ~strcmp(Token{3},'1.0') + fclose(fid); + error('Only PLY format version 1.0 supported.'); + end + end + case 'comment' % read file comment + NumComments = NumComments + 1; + Comments{NumComments} = ''; + for i = 2:Count + Comments{NumComments} = [Comments{NumComments},Token{i},' ']; + end + case 'element' % element name + if Count >= 3 + if isfield(Elements,Token{2}) + fclose(fid); + error(['Duplicate element name, ''',Token{2},'''.']); + end + + NumElements = NumElements + 1; + NumProperties = 0; + Elements = setfield(Elements,Token{2},[]); + PropertyTypes = setfield(PropertyTypes,Token{2},[]); + ElementNames{NumElements} = Token{2}; + PropertyNames = setfield(PropertyNames,Token{2},{}); + CurElement = Token{2}; + ElementCount(NumElements) = str2double(Token{3}); + + if isnan(ElementCount(NumElements)) + fclose(fid); + error(['Bad element definition: ',Buf]); + end + else + error(['Bad element definition: ',Buf]); + end + case 'property' % element property + if ~isempty(CurElement) & Count >= 3 + NumProperties = NumProperties + 1; + eval(['tmp=isfield(Elements.',CurElement,',Token{Count});'],... + 'fclose(fid);error([''Error reading property: '',Buf])'); + + if tmp + error(['Duplicate property name, ''',CurElement,'.',Token{2},'''.']); + end + + % add property subfield to Elements + eval(['Elements.',CurElement,'.',Token{Count},'=[];'], ... + 'fclose(fid);error([''Error reading property: '',Buf])'); + % add property subfield to PropertyTypes and save type + eval(['PropertyTypes.',CurElement,'.',Token{Count},'={Token{2:Count-1}};'], ... + 'fclose(fid);error([''Error reading property: '',Buf])'); + % record property name order + eval(['PropertyNames.',CurElement,'{NumProperties}=Token{Count};'], ... 
+ 'fclose(fid);error([''Error reading property: '',Buf])'); + else + fclose(fid); + + if isempty(CurElement) + error(['Property definition without element definition: ',Buf]); + else + error(['Bad property definition: ',Buf]); + end + end + case 'end_header' % end of header, break from while loop + break; + end + end +end + +%%% set reading for specified data format %%% + +if isempty(Format) + warning('Data format unspecified, assuming ASCII.'); + Format = 'ascii'; +end + +switch Format +case 'ascii' + Format = 0; +case 'binary_little_endian' + Format = 1; +case 'binary_big_endian' + Format = 2; +otherwise + fclose(fid); + error(['Data format ''',Format,''' not supported.']); +end + +if ~Format + Buf = fscanf(fid,'%f'); % read the rest of the file as ASCII data + BufOff = 1; +else + % reopen the file in read binary mode + fclose(fid); + + if Format == 1 + fid = fopen(Path,'r','ieee-le.l64'); % little endian + else + fid = fopen(Path,'r','ieee-be.l64'); % big endian + end + + % find the end of the header again (using ftell on the old handle doesn't give the correct position) + BufSize = 8192; + Buf = [blanks(10),char(fread(fid,BufSize,'uchar')')]; + i = []; + tmp = -11; + + while isempty(i) + i = findstr(Buf,['end_header',13,10]); % look for end_header + CR/LF + i = [i,findstr(Buf,['end_header',10])]; % look for end_header + LF + + if isempty(i) + tmp = tmp + BufSize; + Buf = [Buf(BufSize+1:BufSize+10),char(fread(fid,BufSize,'uchar')')]; + end + end + + % seek to just after the line feed + fseek(fid,i + tmp + 11 + (Buf(i + 10) == 13),-1); +end + + +%%% read element data %%% + +% PLY and MATLAB data types (for fread) +PlyTypeNames = {'char','uchar','short','ushort','int','uint','float','double', ... + 'char8','uchar8','short16','ushort16','int32','uint32','float32','double64'}; +MatlabTypeNames = {'schar','uchar','int16','uint16','int32','uint32','single','double'}; +SizeOf = [1,1,2,2,4,4,4,8]; % size in bytes of each type + +for i = 1:NumElements + % get current element property information + eval(['CurPropertyNames=PropertyNames.',ElementNames{i},';']); + eval(['CurPropertyTypes=PropertyTypes.',ElementNames{i},';']); + NumProperties = size(CurPropertyNames,2); + +% fprintf('Reading %s...\n',ElementNames{i}); + + if ~Format %%% read ASCII data %%% + for j = 1:NumProperties + Token = getfield(CurPropertyTypes,CurPropertyNames{j}); + + if strcmpi(Token{1},'list') + Type(j) = 1; + else + Type(j) = 0; + end + end + + % parse buffer + if ~any(Type) + % no list types + Data = reshape(Buf(BufOff:BufOff+ElementCount(i)*NumProperties-1),NumProperties,ElementCount(i))'; + BufOff = BufOff + ElementCount(i)*NumProperties; + else + ListData = cell(NumProperties,1); + + for k = 1:NumProperties + ListData{k} = cell(ElementCount(i),1); + end + + % list type + for j = 1:ElementCount(i) + for k = 1:NumProperties + if ~Type(k) + Data(j,k) = Buf(BufOff); + BufOff = BufOff + 1; + else + tmp = Buf(BufOff); + ListData{k}{j} = Buf(BufOff+(1:tmp))'; + BufOff = BufOff + tmp + 1; + end + end + end + end + else %%% read binary data %%% + % translate PLY data type names to MATLAB data type names + ListFlag = 0; % = 1 if there is a list type + SameFlag = 1; % = 1 if all types are the same + + for j = 1:NumProperties + Token = getfield(CurPropertyTypes,CurPropertyNames{j}); + + if ~strcmp(Token{1},'list') % non-list type + tmp = rem(strmatch(Token{1},PlyTypeNames,'exact')-1,8)+1; + + if ~isempty(tmp) + TypeSize(j) = SizeOf(tmp); + Type{j} = MatlabTypeNames{tmp}; + TypeSize2(j) = 0; + Type2{j} = ''; + + SameFlag = 
SameFlag & strcmp(Type{1},Type{j}); + else + fclose(fid); + error(['Unknown property data type, ''',Token{1},''', in ', ... + ElementNames{i},'.',CurPropertyNames{j},'.']); + end + else % list type + if length(Token) == 3 + ListFlag = 1; + SameFlag = 0; + tmp = rem(strmatch(Token{2},PlyTypeNames,'exact')-1,8)+1; + tmp2 = rem(strmatch(Token{3},PlyTypeNames,'exact')-1,8)+1; + + if ~isempty(tmp) & ~isempty(tmp2) + TypeSize(j) = SizeOf(tmp); + Type{j} = MatlabTypeNames{tmp}; + TypeSize2(j) = SizeOf(tmp2); + Type2{j} = MatlabTypeNames{tmp2}; + else + fclose(fid); + error(['Unknown property data type, ''list ',Token{2},' ',Token{3},''', in ', ... + ElementNames{i},'.',CurPropertyNames{j},'.']); + end + else + fclose(fid); + error(['Invalid list syntax in ',ElementNames{i},'.',CurPropertyNames{j},'.']); + end + end + end + + % read file + if ~ListFlag + if SameFlag + % no list types, all the same type (fast) + Data = fread(fid,[NumProperties,ElementCount(i)],Type{1})'; + else + % no list types, mixed type + Data = zeros(ElementCount(i),NumProperties); + + for j = 1:ElementCount(i) + for k = 1:NumProperties + Data(j,k) = fread(fid,1,Type{k}); + end + end + end + else + ListData = cell(NumProperties,1); + + for k = 1:NumProperties + ListData{k} = cell(ElementCount(i),1); + end + + if NumProperties == 1 + BufSize = 512; + SkipNum = 4; + j = 0; + + % list type, one property (fast if lists are usually the same length) + while j < ElementCount(i) + Position = ftell(fid); + % read in BufSize count values, assuming all counts = SkipNum + [Buf,BufSize] = fread(fid,BufSize,Type{1},SkipNum*TypeSize2(1)); + Miss = find(Buf ~= SkipNum); % find first count that is not SkipNum + fseek(fid,Position + TypeSize(1),-1); % seek back to after first count + + if isempty(Miss) % all counts are SkipNum + Buf = fread(fid,[SkipNum,BufSize],[int2str(SkipNum),'*',Type2{1}],TypeSize(1))'; + fseek(fid,-TypeSize(1),0); % undo last skip + + for k = 1:BufSize + ListData{1}{j+k} = Buf(k,:); + end + + j = j + BufSize; + BufSize = floor(1.5*BufSize); + else + if Miss(1) > 1 % some counts are SkipNum + Buf2 = fread(fid,[SkipNum,Miss(1)-1],[int2str(SkipNum),'*',Type2{1}],TypeSize(1))'; + + for k = 1:Miss(1)-1 + ListData{1}{j+k} = Buf2(k,:); + end + + j = j + k; + end + + % read in the list with the missed count + SkipNum = Buf(Miss(1)); + j = j + 1; + ListData{1}{j} = fread(fid,[1,SkipNum],Type2{1}); + BufSize = ceil(0.6*BufSize); + end + end + else + % list type(s), multiple properties (slow) + Data = zeros(ElementCount(i),NumProperties); + + for j = 1:ElementCount(i) + for k = 1:NumProperties + if isempty(Type2{k}) + Data(j,k) = fread(fid,1,Type{k}); + else + tmp = fread(fid,1,Type{k}); + ListData{k}{j} = fread(fid,[1,tmp],Type2{k}); + end + end + end + end + end + end + + % put data into Elements structure + for k = 1:NumProperties + if (~Format & ~Type(k)) | (Format & isempty(Type2{k})) + eval(['Elements.',ElementNames{i},'.',CurPropertyNames{k},'=Data(:,k);']); + else + eval(['Elements.',ElementNames{i},'.',CurPropertyNames{k},'=ListData{k};']); + end + end +end + +clear Data ListData; +fclose(fid); + +if (nargin > 1 & strcmpi(Str,'Tri')) | nargout > 2 + % find vertex element field + Name = {'vertex','Vertex','point','Point','pts','Pts'}; + Names = []; + + for i = 1:length(Name) + if any(strcmp(ElementNames,Name{i})) + Names = getfield(PropertyNames,Name{i}); + Name = Name{i}; + break; + end + end + + if any(strcmp(Names,'x')) & any(strcmp(Names,'y')) & any(strcmp(Names,'z')) + 
eval(['varargout{1}=[Elements.',Name,'.x,Elements.',Name,'.y,Elements.',Name,'.z];']); + else + varargout{1} = zeros(1,3); + end + + varargout{2} = Elements; + varargout{3} = Comments; + Elements = []; + + % find face element field + Name = {'face','Face','poly','Poly','tri','Tri'}; + Names = []; + + for i = 1:length(Name) + if any(strcmp(ElementNames,Name{i})) + Names = getfield(PropertyNames,Name{i}); + Name = Name{i}; + break; + end + end + + if ~isempty(Names) + % find vertex indices property subfield + PropertyName = {'vertex_indices','vertex_indexes','vertex_index','indices','indexes'}; + + for i = 1:length(PropertyName) + if any(strcmp(Names,PropertyName{i})) + PropertyName = PropertyName{i}; + break; + end + end + + if ~iscell(PropertyName) + % convert face index lists to triangular connectivity + eval(['FaceIndices=varargout{2}.',Name,'.',PropertyName,';']); + N = length(FaceIndices); + Elements = zeros(N*2,3); + Extra = 0; + + for k = 1:N + Elements(k,:) = FaceIndices{k}(1:3); + + for j = 4:length(FaceIndices{k}) + Extra = Extra + 1; + Elements(N + Extra,:) = [Elements(k,[1,j-1]),FaceIndices{k}(j)]; + end + end + Elements = Elements(1:N+Extra,:) + 1; + end + end +else + varargout{1} = Comments; +end \ No newline at end of file diff --git a/IGEV-MVS/evaluations/dtu/reducePts_haa.m b/IGEV-MVS/evaluations/dtu/reducePts_haa.m new file mode 100644 index 0000000..b883d8a --- /dev/null +++ b/IGEV-MVS/evaluations/dtu/reducePts_haa.m @@ -0,0 +1,35 @@ +function [ptsOut,indexSet] = reducePts_haa(pts, dst) + +%Reduces a point set, pts, in a stochastic manner, such that the minimum sdistance +% between points is 'dst'. Writen by abd, edited by haa, then by raje + +nPoints=size(pts,2); + +indexSet=true(nPoints,1); +RandOrd=randperm(nPoints); + +%tic +NS = KDTreeSearcher(pts'); +%toc + +% search the KNTree for close neighbours in a chunk-wise fashion to save memory if point cloud is really big +Chunks=1:min(4e6,nPoints-1):nPoints; +Chunks(end)=nPoints; + +for cChunk=1:(length(Chunks)-1) + Range=Chunks(cChunk):Chunks(cChunk+1); + + idx = rangesearch(NS,pts(:,RandOrd(Range))',dst); + + for i = 1:size(idx,1) + id =RandOrd(i-1+Chunks(cChunk)); + if (indexSet(id)) + indexSet(idx{i}) = 0; + indexSet(id) = 1; + end + end +end + +ptsOut = pts(:,indexSet); + +disp(['downsample factor: ' num2str(nPoints/sum(indexSet))]); diff --git a/IGEV-MVS/lists/blendedmvs/train.txt b/IGEV-MVS/lists/blendedmvs/train.txt new file mode 100644 index 0000000..22a9db2 --- /dev/null +++ b/IGEV-MVS/lists/blendedmvs/train.txt @@ -0,0 +1,106 @@ +5c1f33f1d33e1f2e4aa6dda4 +5bfe5ae0fe0ea555e6a969ca +5bff3c5cfe0ea555e6bcbf3a +58eaf1513353456af3a1682a +5bfc9d5aec61ca1dd69132a2 +5bf18642c50e6f7f8bdbd492 +5bf26cbbd43923194854b270 +5bf17c0fd439231948355385 +5be3ae47f44e235bdbbc9771 +5be3a5fb8cfdd56947f6b67c +5bbb6eb2ea1cfa39f1af7e0c +5ba75d79d76ffa2c86cf2f05 +5bb7a08aea1cfa39f1a947ab +5b864d850d072a699b32f4ae +5b6eff8b67b396324c5b2672 +5b6e716d67b396324c2d77cb +5b69cc0cb44b61786eb959bf +5b62647143840965efc0dbde +5b60fa0c764f146feef84df0 +5b558a928bbfb62204e77ba2 +5b271079e0878c3816dacca4 +5b08286b2775267d5b0634ba +5afacb69ab00705d0cefdd5b +5af28cea59bc705737003253 +5af02e904c8216544b4ab5a2 +5aa515e613d42d091d29d300 +5c34529873a8df509ae57b58 +5c34300a73a8df509add216d +5c1af2e2bee9a723c963d019 +5c1892f726173c3a09ea9aeb +5c0d13b795da9479e12e2ee9 +5c062d84a96e33018ff6f0a6 +5bfd0f32ec61ca1dd69dc77b +5bf21799d43923194842c001 +5bf3a82cd439231948877aed +5bf03590d4392319481971dc +5beb6e66abd34c35e18e66b9 +5be883a4f98cee15019d5b83 
+5be47bf9b18881428d8fbc1d +5bcf979a6d5f586b95c258cd +5bce7ac9ca24970bce4934b6 +5bb8a49aea1cfa39f1aa7f75 +5b78e57afc8fcf6781d0c3ba +5b21e18c58e2823a67a10dd8 +5b22269758e2823a67a3bd03 +5b192eb2170cf166458ff886 +5ae2e9c5fe405c5076abc6b2 +5adc6bd52430a05ecb2ffb85 +5ab8b8e029f5351f7f2ccf59 +5abc2506b53b042ead637d86 +5ab85f1dac4291329b17cb50 +5a969eea91dfc339a9a3ad2c +5a8aa0fab18050187cbe060e +5a7d3db14989e929563eb153 +5a69c47d0d5d0a7f3b2e9752 +5a618c72784780334bc1972d +5a6464143d809f1d8208c43c +5a588a8193ac3d233f77fbca +5a57542f333d180827dfc132 +5a572fd9fc597b0478a81d14 +5a563183425d0f5186314855 +5a4a38dad38c8a075495b5d2 +5a48d4b2c7dab83a7d7b9851 +5a489fb1c7dab83a7d7b1070 +5a48ba95c7dab83a7d7b44ed +5a3ca9cb270f0e3f14d0eddb +5a3cb4e4270f0e3f14d12f43 +5a3f4aba5889373fbbc5d3b5 +5a0271884e62597cdee0d0eb +59e864b2a9e91f2c5529325f +599aa591d5b41f366fed0d58 +59350ca084b7f26bf5ce6eb8 +59338e76772c3e6384afbb15 +5c20ca3a0843bc542d94e3e2 +5c1dbf200843bc542d8ef8c4 +5c1b1500bee9a723c96c3e78 +5bea87f4abd34c35e1860ab5 +5c2b3ed5e611832e8aed46bf +57f8d9bbe73f6760f10e916a +5bf7d63575c26f32dbf7413b +5be4ab93870d330ff2dce134 +5bd43b4ba6b28b1ee86b92dd +5bccd6beca24970bce448134 +5bc5f0e896b66a2cd8f9bd36 +5b908d3dc6ab78485f3d24a9 +5b2c67b5e0878c381608b8d8 +5b4933abf2b5f44e95de482a +5b3b353d8d46a939f93524b9 +5acf8ca0f3d8a750097e4b15 +5ab8713ba3799a1d138bd69a +5aa235f64a17b335eeaf9609 +5aa0f9d7a9efce63548c69a1 +5a8315f624b8e938486e0bd8 +5a48c4e9c7dab83a7d7b5cc7 +59ecfd02e225f6492d20fcc9 +59f87d0bfa6280566fb38c9a +59f363a8b45be22330016cad +59f70ab1e5c5d366af29bf3e +59e75a2ca9e91f2c5526005d +5947719bf1b45630bd096665 +5947b62af1b45630bd0c2a02 +59056e6760bb961de55f3501 +58f7f7299f5b5647873cb110 +58cf4771d0f5fb221defe6da +58d36897f387231e6c929903 +58c4bb4f4a69c55606122be4 diff --git a/IGEV-MVS/lists/blendedmvs/val.txt b/IGEV-MVS/lists/blendedmvs/val.txt new file mode 100644 index 0000000..7916fcb --- /dev/null +++ b/IGEV-MVS/lists/blendedmvs/val.txt @@ -0,0 +1,7 @@ +5b7a3890fc8fcf6781e2593a +5c189f2326173c3a09ed7ef3 +5b950c71608de421b1e7318f +5a6400933d809f1d8200af15 +59d2657f82ca7774b1ec081d +5ba19a8a360c7c30c1c169df +59817e4a1bd4b175e7038d19 diff --git a/IGEV-MVS/lists/dtu/test.txt b/IGEV-MVS/lists/dtu/test.txt new file mode 100644 index 0000000..bb75734 --- /dev/null +++ b/IGEV-MVS/lists/dtu/test.txt @@ -0,0 +1,22 @@ +scan1 +scan4 +scan9 +scan10 +scan11 +scan12 +scan13 +scan15 +scan23 +scan24 +scan29 +scan32 +scan33 +scan34 +scan48 +scan49 +scan62 +scan75 +scan77 +scan110 +scan114 +scan118 \ No newline at end of file diff --git a/IGEV-MVS/lists/dtu/train.txt b/IGEV-MVS/lists/dtu/train.txt new file mode 100644 index 0000000..12ffd26 --- /dev/null +++ b/IGEV-MVS/lists/dtu/train.txt @@ -0,0 +1,79 @@ +scan2 +scan6 +scan7 +scan8 +scan14 +scan16 +scan18 +scan19 +scan20 +scan22 +scan30 +scan31 +scan36 +scan39 +scan41 +scan42 +scan44 +scan45 +scan46 +scan47 +scan50 +scan51 +scan52 +scan53 +scan55 +scan57 +scan58 +scan60 +scan61 +scan63 +scan64 +scan65 +scan68 +scan69 +scan70 +scan71 +scan72 +scan74 +scan76 +scan83 +scan84 +scan85 +scan87 +scan88 +scan89 +scan90 +scan91 +scan92 +scan93 +scan94 +scan95 +scan96 +scan97 +scan98 +scan99 +scan100 +scan101 +scan102 +scan103 +scan104 +scan105 +scan107 +scan108 +scan109 +scan111 +scan112 +scan113 +scan115 +scan116 +scan119 +scan120 +scan121 +scan122 +scan123 +scan124 +scan125 +scan126 +scan127 +scan128 \ No newline at end of file diff --git a/IGEV-MVS/lists/dtu/val.txt b/IGEV-MVS/lists/dtu/val.txt new file mode 100644 index 0000000..25323c8 --- /dev/null +++ b/IGEV-MVS/lists/dtu/val.txt 
@@ -0,0 +1,18 @@ +scan3 +scan5 +scan17 +scan21 +scan28 +scan35 +scan37 +scan38 +scan40 +scan43 +scan56 +scan59 +scan66 +scan67 +scan82 +scan86 +scan106 +scan117 \ No newline at end of file diff --git a/IGEV-MVS/train_mvs.py b/IGEV-MVS/train_mvs.py new file mode 100644 index 0000000..83fffb2 --- /dev/null +++ b/IGEV-MVS/train_mvs.py @@ -0,0 +1,293 @@ +import argparse +import os +os.environ['CUDA_VISIBLE_DEVICES'] = '0,1,2' +import torch +import torch.nn as nn +import torch.nn.parallel +import torch.backends.cudnn as cudnn +import torch.optim as optim +from torch.utils.data import DataLoader +import torch.nn.functional as F +import numpy as np +import random +import time +from torch.utils.tensorboard import SummaryWriter +from datasets import find_dataset_def +from core.igev_mvs import IGEVMVS +from core.submodule import depth_normalization, depth_unnormalization +from utils import * +import sys +import datetime +from tqdm import tqdm + +cudnn.benchmark = True + + + +parser = argparse.ArgumentParser(description='IterMVStereo for high-resolution multi-view stereo') +parser.add_argument('--mode', default='train', help='train or val', choices=['train', 'val']) + +parser.add_argument('--dataset', default='dtu_yao', help='select dataset') +parser.add_argument('--trainpath', default='/data/dtu_data/dtu_train/', help='train datapath') +parser.add_argument('--valpath', help='validation datapath') +parser.add_argument('--trainlist', default='./lists/dtu/train.txt', help='train list') +parser.add_argument('--vallist', default='./lists/dtu/val.txt', help='validation list') +parser.add_argument('--maxdisp', default=256) + +parser.add_argument('--epochs', type=int, default=32, help='number of epochs to train') +parser.add_argument('--lr', type=float, default=0.0002, help='learning rate') +parser.add_argument('--wd', type=float, default=.00001, help='weight decay') + +parser.add_argument('--batch_size', type=int, default=6, help='train batch size') +parser.add_argument('--loadckpt', default=None, help='load a specific checkpoint') +parser.add_argument('--logdir', default='./checkpoints/', help='the directory to save checkpoints/logs') +parser.add_argument('--resume', action='store_true', help='continue to train the model') +parser.add_argument('--regress', action='store_true', help='train the regression and confidence') +parser.add_argument('--small_image', action='store_true', help='train with small input as 640x512, otherwise train with 1280x1024') + +parser.add_argument('--summary_freq', type=int, default=20, help='print and summary frequency') +parser.add_argument('--save_freq', type=int, default=1, help='save checkpoint frequency') +parser.add_argument('--seed', type=int, default=1, metavar='S', help='random seed') +parser.add_argument('--iteration', type=int, default=22, help='num of iteration of GRU') + +try: + from torch.cuda.amp import GradScaler +except: + # dummy GradScaler for PyTorch < 1.6 + class GradScaler: + def __init__(self): + pass + def scale(self, loss): + return loss + def unscale_(self, optimizer): + pass + def step(self, optimizer): + optimizer.step() + def update(self): + pass + +def sequence_loss(disp_preds, disp_init_pred, depth_gt, mask, depth_min, depth_max, loss_gamma=0.9): + """ Loss function defined over sequence of depth predictions """ + cross_entropy = nn.BCEWithLogitsLoss() + n_predictions = len(disp_preds) + assert n_predictions >= 1 + loss = 0.0 + mask = mask > 0.5 + batch, _, height, width = depth_gt.size() + inverse_depth_min = (1.0 / depth_min).view(batch, 1, 1, 1) 
+ inverse_depth_max = (1.0 / depth_max).view(batch, 1, 1, 1) + + normalized_disp_gt = depth_normalization(depth_gt, inverse_depth_min, inverse_depth_max) + loss += 1.0 * F.l1_loss(disp_init_pred[mask], normalized_disp_gt[mask], reduction='mean') + + if args.iteration != 0: + for i in range(n_predictions): + adjusted_loss_gamma = loss_gamma**(15/(n_predictions - 1)) + i_weight = adjusted_loss_gamma**(n_predictions - i - 1) + loss += i_weight * F.l1_loss(disp_preds[i][mask], normalized_disp_gt[mask], reduction='mean') + + return loss + +# parse arguments and check +args = parser.parse_args() +if args.resume: # store_true means set the variable as "True" + assert args.mode == "train" + assert args.loadckpt is None +if args.valpath is None: + args.valpath = args.trainpath + +torch.manual_seed(args.seed) +torch.cuda.manual_seed(args.seed) +np.random.seed(args.seed) +random.seed(args.seed) + +if args.mode == "train": + if not os.path.isdir(args.logdir): + os.mkdir(args.logdir) + + current_time_str = str(datetime.datetime.now().strftime('%Y%m%d_%H%M%S')) + print("current time", current_time_str) + + print("creating new summary file") + logger = SummaryWriter(args.logdir) + +print("argv:", sys.argv[1:]) +print_args(args) + +# dataset, dataloader +MVSDataset = find_dataset_def(args.dataset) + +train_dataset = MVSDataset(args.trainpath, args.trainlist, "train", 5, robust_train=True) +test_dataset = MVSDataset(args.valpath, args.vallist, "val", 5, robust_train=False) + +TrainImgLoader = DataLoader(train_dataset, args.batch_size, shuffle=True, num_workers=4, drop_last=True) +TestImgLoader = DataLoader(test_dataset, args.batch_size, shuffle=False, num_workers=4, drop_last=False) + +# model, optimizer +model = IGEVMVS(args) +if args.mode in ["train", "val"]: + model = nn.DataParallel(model) +model.cuda() +model_loss = sequence_loss +optimizer = optim.AdamW(filter(lambda p: p.requires_grad, model.parameters()), lr=args.lr, weight_decay=args.wd, eps=1e-8) + +# load parameters +start_epoch = 0 +if (args.mode == "train" and args.resume) or (args.mode == "val" and not args.loadckpt): + saved_models = [fn for fn in os.listdir(args.logdir) if fn.endswith(".ckpt")] + saved_models = sorted(saved_models, key=lambda x: int(x.split('_')[-1].split('.')[0])) + # use the latest checkpoint file + loadckpt = os.path.join(args.logdir, saved_models[-1]) + print("resuming", loadckpt) + state_dict = torch.load(loadckpt) + model.load_state_dict(state_dict['model'], strict=False) + optimizer.load_state_dict(state_dict['optimizer']) + start_epoch = state_dict['epoch'] + 1 +elif args.loadckpt: + # load checkpoint file specified by args.loadckpt + print("loading model {}".format(args.loadckpt)) + state_dict = torch.load(args.loadckpt) + model.load_state_dict(state_dict['model'], strict=False) +print("start at epoch {}".format(start_epoch)) +print('Number of model parameters: {}'.format(sum([p.data.nelement() for p in model.parameters()]))) + + +# main function +def train(args): + total_steps = len(TrainImgLoader) * args.epochs + 100 + lr_scheduler = optim.lr_scheduler.OneCycleLR(optimizer, args.lr, total_steps, pct_start=0.01, cycle_momentum=False, anneal_strategy='linear') + + for epoch_idx in range(start_epoch, args.epochs): + print('Epoch {}:'.format(epoch_idx)) + global_step = len(TrainImgLoader) * epoch_idx + + # training + tbar = tqdm(TrainImgLoader) + for batch_idx, sample in enumerate(tbar): + start_time = time.time() + global_step = len(TrainImgLoader) * epoch_idx + batch_idx + do_summary = global_step % 
args.summary_freq == 0 + scaler = GradScaler(enabled=True) + loss, scalar_outputs = train_sample(args, sample, detailed_summary=do_summary, scaler=scaler) + if do_summary: + save_scalars(logger, 'train', scalar_outputs, global_step) + del scalar_outputs + + tbar.set_description( + 'Epoch {}/{}, Iter {}/{}, train loss = {:.3f}, time = {:.3f}'.format(epoch_idx, args.epochs, batch_idx, len(TrainImgLoader), loss, time.time() - start_time)) + + lr_scheduler.step() + + # checkpoint + if (epoch_idx + 1) % args.save_freq == 0: + torch.save({ + 'model': model.state_dict()}, + "{}/model_{:0>6}.ckpt".format(args.logdir, epoch_idx)) + torch.cuda.empty_cache() + # testing + avg_test_scalars = DictAverageMeter() + tbar = tqdm(TestImgLoader) + for batch_idx, sample in enumerate(tbar): + start_time = time.time() + global_step = len(TestImgLoader) * epoch_idx + batch_idx + do_summary = global_step % args.summary_freq == 0 + loss, scalar_outputs = test_sample(args, sample, detailed_summary=do_summary) + if do_summary: + save_scalars(logger, 'test', scalar_outputs, global_step) + avg_test_scalars.update(scalar_outputs) + del scalar_outputs + tbar.set_description('Epoch {}/{}, Iter {}/{}, test loss = {:.3f}, time = {:3f}'.format(epoch_idx, args.epochs, batch_idx, + len(TestImgLoader), loss, + time.time() - start_time)) + save_scalars(logger, 'fulltest', avg_test_scalars.mean(), global_step) + print("avg_test_scalars:", avg_test_scalars.mean()) + torch.cuda.empty_cache() + +def test(args): + avg_test_scalars = DictAverageMeter() + for batch_idx, sample in enumerate(TestImgLoader): + start_time = time.time() + loss, scalar_outputs = test_sample(args, sample, detailed_summary=True) + avg_test_scalars.update(scalar_outputs) + del scalar_outputs + print('Iter {}/{}, test loss = {:.3f}, time = {:3f}'.format(batch_idx, len(TestImgLoader), loss, + time.time() - start_time)) + if batch_idx % 100 == 0: + print("Iter {}/{}, test results = {}".format(batch_idx, len(TestImgLoader), avg_test_scalars.mean())) + print("final", avg_test_scalars) + + +def train_sample(args, sample, detailed_summary=False, scaler=None): + model.train() + optimizer.zero_grad() + sample_cuda = tocuda(sample) + depth_gt = sample_cuda["depth"] + mask = sample_cuda["mask"] + depth_gt_0 = depth_gt['level_0'] + mask_0 = mask['level_0'] + depth_gt_1 = depth_gt['level_2'] + mask_1 = mask['level_2'] + + disp_init, disp_predictions = model(sample_cuda["imgs"], sample_cuda["proj_matrices"], + sample_cuda["depth_min"], sample_cuda["depth_max"]) + + loss = model_loss(disp_predictions, disp_init, depth_gt_0, mask_0, sample_cuda["depth_min"], sample_cuda["depth_max"]) + scaler.scale(loss).backward() + scaler.unscale_(optimizer) + torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0) + scaler.step(optimizer) + scaler.update() + + inverse_depth_min = (1.0 / sample_cuda["depth_min"]).view(args.batch_size, 1, 1, 1) + inverse_depth_max = (1.0 / sample_cuda["depth_max"]).view(args.batch_size, 1, 1, 1) + depth_init = depth_unnormalization(disp_init, inverse_depth_min, inverse_depth_max) + + depth_predictions = [] + for disp in disp_predictions: + depth_predictions.append(depth_unnormalization(disp, inverse_depth_min, inverse_depth_max)) + + scalar_outputs = {"loss": loss} + scalar_outputs["abs_error_initial"] = AbsDepthError_metrics(depth_init, depth_gt_0, mask_0 > 0.5) + scalar_outputs["thres1mm_initial"] = Thres_metrics(depth_init, depth_gt_0, mask_0 > 0.5, 1) + scalar_outputs["abs_error_final_full"] = AbsDepthError_metrics(depth_predictions[-1], depth_gt_0, 
mask_0 > 0.5) + return tensor2float(loss), tensor2float(scalar_outputs) + + +@make_nograd_func +def test_sample(args, sample, detailed_summary=True): + model.eval() + sample_cuda = tocuda(sample) + depth_gt = sample_cuda["depth"] + mask = sample_cuda["mask"] + depth_gt_0 = depth_gt['level_0'] + mask_0 = mask['level_0'] + depth_gt_1 = depth_gt['level_2'] + mask_1 = mask['level_2'] + + disp_init, disp_predictions = model(sample_cuda["imgs"], sample_cuda["proj_matrices"], + sample_cuda["depth_min"], sample_cuda["depth_max"]) + + loss = model_loss(disp_predictions, disp_init, depth_gt_0, mask_0, sample_cuda["depth_min"], sample_cuda["depth_max"]) + + + inverse_depth_min = (1.0 / sample_cuda["depth_min"]).view(sample_cuda["depth_min"].size()[0], 1, 1, 1) + inverse_depth_max = (1.0 / sample_cuda["depth_max"]).view(sample_cuda["depth_max"].size()[0], 1, 1, 1) + depth_init = depth_unnormalization(disp_init, inverse_depth_min, inverse_depth_max) + + depth_predictions = [] + for disp in disp_predictions: + depth_predictions.append(depth_unnormalization(disp, inverse_depth_min, inverse_depth_max)) + + scalar_outputs = {"loss": loss} + scalar_outputs["abs_error_initial"] = AbsDepthError_metrics(depth_init, depth_gt_0, mask_0 > 0.5) + scalar_outputs["thres1mm_initial"] = Thres_metrics(depth_init, depth_gt_0, mask_0 > 0.5, 1) + scalar_outputs["abs_error_final_full"] = AbsDepthError_metrics(depth_predictions[-1], depth_gt_0, mask_0 > 0.5) + return tensor2float(loss), tensor2float(scalar_outputs) + + +if __name__ == '__main__': + if args.mode == "train": + train(args) + elif args.mode == "val": + test(args) + diff --git a/IGEV-MVS/utils.py b/IGEV-MVS/utils.py new file mode 100644 index 0000000..c325f17 --- /dev/null +++ b/IGEV-MVS/utils.py @@ -0,0 +1,155 @@ +import numpy as np +import torchvision.utils as vutils +import torch +import torch.nn.functional as F + + +# print arguments +def print_args(args): + print("################################ args ################################") + for k, v in args.__dict__.items(): + print("{0: <10}\t{1: <30}\t{2: <20}".format(k, str(v), str(type(v)))) + print("########################################################################") + + +# torch.no_grad warpper for functions +def make_nograd_func(func): + def wrapper(*f_args, **f_kwargs): + with torch.no_grad(): + ret = func(*f_args, **f_kwargs) + return ret + + return wrapper + + +# convert a function into recursive style to handle nested dict/list/tuple variables +def make_recursive_func(func): + def wrapper(vars): + if isinstance(vars, list): + return [wrapper(x) for x in vars] + elif isinstance(vars, tuple): + return tuple([wrapper(x) for x in vars]) + elif isinstance(vars, dict): + return {k: wrapper(v) for k, v in vars.items()} + else: + return func(vars) + + return wrapper + + +@make_recursive_func +def tensor2float(vars): + if isinstance(vars, float): + return vars + elif isinstance(vars, torch.Tensor): + return vars.data.item() + else: + raise NotImplementedError("invalid input type {} for tensor2float".format(type(vars))) + + +@make_recursive_func +def tensor2numpy(vars): + if isinstance(vars, np.ndarray): + return vars + elif isinstance(vars, torch.Tensor): + return vars.detach().cpu().numpy().copy() + else: + raise NotImplementedError("invalid input type {} for tensor2numpy".format(type(vars))) + + +@make_recursive_func +def tocuda(vars): + if isinstance(vars, torch.Tensor): + return vars.cuda() + elif isinstance(vars, str): + return vars + else: + raise NotImplementedError("invalid input type {} for 
tocuda".format(type(vars))) + + +def save_scalars(logger, mode, scalar_dict, global_step): + scalar_dict = tensor2float(scalar_dict) + for key, value in scalar_dict.items(): + if not isinstance(value, (list, tuple)): + name = '{}/{}'.format(mode, key) + logger.add_scalar(name, value, global_step) + else: + for idx in range(len(value)): + name = '{}/{}_{}'.format(mode, key, idx) + logger.add_scalar(name, value[idx], global_step) + + +def save_images(logger, mode, images_dict, global_step): + images_dict = tensor2numpy(images_dict) + + def preprocess(name, img): + if not (len(img.shape) == 3 or len(img.shape) == 4): + raise NotImplementedError("invalid img shape {}:{} in save_images".format(name, img.shape)) + if len(img.shape) == 3: + img = img[:, np.newaxis, :, :] + img = torch.from_numpy(img[:1]) + return vutils.make_grid(img, padding=0, nrow=1, normalize=True, scale_each=True) + + for key, value in images_dict.items(): + if not isinstance(value, (list, tuple)): + name = '{}/{}'.format(mode, key) + logger.add_image(name, preprocess(name, value), global_step) + else: + for idx in range(len(value)): + name = '{}/{}_{}'.format(mode, key, idx) + logger.add_image(name, preprocess(name, value[idx]), global_step) + + +class DictAverageMeter(object): + def __init__(self): + self.data = {} + self.count = 0 + + def update(self, new_input): + self.count += 1 + if len(self.data) == 0: + for k, v in new_input.items(): + if not isinstance(v, float): + raise NotImplementedError("invalid data {}: {}".format(k, type(v))) + self.data[k] = v + else: + for k, v in new_input.items(): + if not isinstance(v, float): + raise NotImplementedError("invalid data {}: {}".format(k, type(v))) + self.data[k] += v + + def mean(self): + return {k: v / self.count for k, v in self.data.items()} + + +# a wrapper to compute metrics for each image individually +def compute_metrics_for_each_image(metric_func): + def wrapper(depth_est, depth_gt, mask, *args): + batch_size = depth_gt.shape[0] + results = [] + # compute result one by one + for idx in range(batch_size): + ret = metric_func(depth_est[idx], depth_gt[idx], mask[idx], *args) + results.append(ret) + return torch.stack(results).mean() + + return wrapper + + +@make_nograd_func +@compute_metrics_for_each_image +def Thres_metrics(depth_est, depth_gt, mask, thres): + # if thres is int or float, then True + assert isinstance(thres, (int, float)) + depth_est, depth_gt = depth_est[mask], depth_gt[mask] + errors = torch.abs(depth_est - depth_gt) + err_mask = errors > thres + return torch.mean(err_mask.float()) + + +# NOTE: please do not use this to build up training loss +@make_nograd_func +@compute_metrics_for_each_image +def AbsDepthError_metrics(depth_est, depth_gt, mask): + depth_est, depth_gt = depth_est[mask], depth_gt[mask] + return torch.mean((depth_est - depth_gt).abs()) \ No newline at end of file