import argparse
import os

os.environ['CUDA_VISIBLE_DEVICES'] = '0,1,2'

import torch
import torch.nn as nn
import torch.nn.parallel
import torch.backends.cudnn as cudnn
import torch.optim as optim
from torch.utils.data import DataLoader
import torch.nn.functional as F
import numpy as np
import random
import time
from torch.utils.tensorboard import SummaryWriter
from datasets import find_dataset_def
from core.igev_mvs import IGEVMVS
from core.submodule import depth_normalization, depth_unnormalization
from utils import *
import sys
import datetime
from tqdm import tqdm

cudnn.benchmark = True

parser = argparse.ArgumentParser(description='IterMVStereo for high-resolution multi-view stereo')
parser.add_argument('--mode', default='train', help='train or val', choices=['train', 'val'])
parser.add_argument('--dataset', default='dtu_yao', help='select dataset')
parser.add_argument('--trainpath', default='/data/DTU_data/dtu_train/', help='train datapath')
parser.add_argument('--valpath', help='validation datapath')
parser.add_argument('--trainlist', default='./lists/dtu/train.txt', help='train list')
parser.add_argument('--vallist', default='./lists/dtu/val.txt', help='validation list')
parser.add_argument('--maxdisp', type=int, default=256)
parser.add_argument('--epochs', type=int, default=32, help='number of epochs to train')
parser.add_argument('--lr', type=float, default=0.0002, help='learning rate')
parser.add_argument('--wd', type=float, default=.00001, help='weight decay')
parser.add_argument('--batch_size', type=int, default=6, help='train batch size')
parser.add_argument('--loadckpt', default=None, help='load a specific checkpoint')
parser.add_argument('--logdir', default='./checkpoints/', help='the directory to save checkpoints/logs')
parser.add_argument('--resume', action='store_true', help='continue to train the model')
parser.add_argument('--regress', action='store_true', help='train the regression and confidence')
parser.add_argument('--small_image', action='store_true', help='train with small input as 640x512, otherwise train with 1280x1024')
parser.add_argument('--summary_freq', type=int, default=20, help='print and summary frequency')
parser.add_argument('--save_freq', type=int, default=1, help='save checkpoint frequency')
parser.add_argument('--seed', type=int, default=1, metavar='S', help='random seed')
parser.add_argument('--iteration', type=int, default=22, help='num of iteration of GRU')

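# Illustrative launch command (not taken from the original repo: the script name
# "train_mvs.py" and the data paths are placeholders, only the flags are the ones
# defined above):
#
#   python train_mvs.py --mode train \
#       --trainpath /data/DTU_data/dtu_train/ --trainlist ./lists/dtu/train.txt \
#       --vallist ./lists/dtu/val.txt --batch_size 6 --epochs 32 --lr 0.0002 \
#       --iteration 22 --logdir ./checkpoints/
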
try:
    from torch.cuda.amp import GradScaler
except ImportError:
    # dummy GradScaler for PyTorch < 1.6; accepts the same `enabled` keyword as the real one
    class GradScaler:
        def __init__(self, enabled=True):
            pass

        def scale(self, loss):
            return loss

        def unscale_(self, optimizer):
            pass

        def step(self, optimizer):
            optimizer.step()

        def update(self):
            pass

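# The loss below and the metrics further down rely on depth_normalization /
# depth_unnormalization from core.submodule. A sketch of the assumed convention (the
# authoritative definitions live in core/submodule.py): predictions are expressed as
# normalized inverse depth,
#
#   d_norm = (1/d - 1/d_max) / (1/d_min - 1/d_max)          # depth_normalization
#   d      = 1 / (d_norm * (1/d_min - 1/d_max) + 1/d_max)   # depth_unnormalization
#
# so d_norm is 0 at the far plane d_max and 1 at the near plane d_min.
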
def sequence_loss(disp_preds, disp_init_pred, depth_gt, mask, depth_min, depth_max, loss_gamma=0.9):
    """Loss function defined over a sequence of depth predictions."""
    n_predictions = len(disp_preds)
    assert n_predictions >= 1
    loss = 0.0
    mask = mask > 0.5
    batch, _, height, width = depth_gt.size()
    inverse_depth_min = (1.0 / depth_min).view(batch, 1, 1, 1)
    inverse_depth_max = (1.0 / depth_max).view(batch, 1, 1, 1)

    # supervise the initial prediction in normalized inverse-depth space
    normalized_disp_gt = depth_normalization(depth_gt, inverse_depth_min, inverse_depth_max)
    loss += 1.0 * F.l1_loss(disp_init_pred[mask], normalized_disp_gt[mask], reduction='mean')

    if args.iteration != 0:
        # exponentially increasing weights, so later GRU iterations contribute more
        adjusted_loss_gamma = loss_gamma ** (15 / (n_predictions - 1))
        for i in range(n_predictions):
            i_weight = adjusted_loss_gamma ** (n_predictions - i - 1)
            loss += i_weight * F.l1_loss(disp_preds[i][mask], normalized_disp_gt[mask], reduction='mean')

    return loss

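# Worked example of the weighting above (illustrative numbers, assuming the model returns
# one refined map per GRU iteration): with the default --iteration 22, n_predictions = 22,
# so adjusted_loss_gamma = 0.9 ** (15 / 21) ≈ 0.93 and prediction i gets weight
# adjusted_loss_gamma ** (21 - i), i.e. roughly 0.21 for the first refinement and 1.0 for
# the last one, so the final iterations dominate the total loss.
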
# parse arguments and check
args = parser.parse_args()
if args.resume:  # --resume is a store_true flag, so it is only True when passed explicitly
    assert args.mode == "train"
    assert args.loadckpt is None
if args.valpath is None:
    args.valpath = args.trainpath

torch.manual_seed(args.seed)
torch.cuda.manual_seed(args.seed)
np.random.seed(args.seed)
random.seed(args.seed)

if args.mode == "train":
    if not os.path.isdir(args.logdir):
        os.mkdir(args.logdir)

    current_time_str = str(datetime.datetime.now().strftime('%Y%m%d_%H%M%S'))
    print("current time", current_time_str)

    print("creating new summary file")
    logger = SummaryWriter(args.logdir)

print("argv:", sys.argv[1:])
print_args(args)

# dataset, dataloader
MVSDataset = find_dataset_def(args.dataset)

train_dataset = MVSDataset(args.trainpath, args.trainlist, "train", 5, robust_train=True)
test_dataset = MVSDataset(args.valpath, args.vallist, "val", 5, robust_train=False)

TrainImgLoader = DataLoader(train_dataset, args.batch_size, shuffle=True, num_workers=4, drop_last=True)
TestImgLoader = DataLoader(test_dataset, args.batch_size, shuffle=False, num_workers=4, drop_last=False)

# model, optimizer
model = IGEVMVS(args)
if args.mode in ["train", "val"]:
    model = nn.DataParallel(model)
model.cuda()
model_loss = sequence_loss
optimizer = optim.AdamW(filter(lambda p: p.requires_grad, model.parameters()), lr=args.lr, weight_decay=args.wd, eps=1e-8)

# load parameters
start_epoch = 0
if (args.mode == "train" and args.resume) or (args.mode == "val" and not args.loadckpt):
    # use the latest checkpoint file in the log directory
    saved_models = [fn for fn in os.listdir(args.logdir) if fn.endswith(".ckpt")]
    saved_models = sorted(saved_models, key=lambda x: int(x.split('_')[-1].split('.')[0]))
    loadckpt = os.path.join(args.logdir, saved_models[-1])
    print("resuming", loadckpt)
    state_dict = torch.load(loadckpt)
    model.load_state_dict(state_dict['model'], strict=False)
    optimizer.load_state_dict(state_dict['optimizer'])
    start_epoch = state_dict['epoch'] + 1
elif args.loadckpt:
    # load the checkpoint file specified by args.loadckpt
    print("loading model {}".format(args.loadckpt))
    state_dict = torch.load(args.loadckpt)
    model.load_state_dict(state_dict['model'], strict=False)
print("start at epoch {}".format(start_epoch))
print('Number of model parameters: {}'.format(sum([p.data.nelement() for p in model.parameters()])))

# main function
def train(args):
    total_steps = len(TrainImgLoader) * args.epochs + 100
    lr_scheduler = optim.lr_scheduler.OneCycleLR(optimizer, args.lr, total_steps, pct_start=0.01, cycle_momentum=False, anneal_strategy='linear')
    # one scaler for the whole run; re-creating it every batch would reset its loss scale
    scaler = GradScaler(enabled=True)

    for epoch_idx in range(start_epoch, args.epochs):
        print('Epoch {}:'.format(epoch_idx))
        global_step = len(TrainImgLoader) * epoch_idx

        # training
        tbar = tqdm(TrainImgLoader)
        for batch_idx, sample in enumerate(tbar):
            start_time = time.time()
            global_step = len(TrainImgLoader) * epoch_idx + batch_idx
            do_summary = global_step % args.summary_freq == 0
            loss, scalar_outputs = train_sample(args, sample, detailed_summary=do_summary, scaler=scaler)
            if do_summary:
                save_scalars(logger, 'train', scalar_outputs, global_step)
            del scalar_outputs
            tbar.set_description(
                'Epoch {}/{}, Iter {}/{}, train loss = {:.3f}, time = {:.3f}'.format(
                    epoch_idx, args.epochs, batch_idx, len(TrainImgLoader), loss, time.time() - start_time))
            # OneCycleLR is sized for one step per batch (total_steps above), so step it every iteration
            lr_scheduler.step()
        # checkpoint
        if (epoch_idx + 1) % args.save_freq == 0:
            # also store the epoch index and optimizer state so that --resume can restore them
            torch.save({
                'epoch': epoch_idx,
                'model': model.state_dict(),
                'optimizer': optimizer.state_dict()},
                "{}/model_{:0>6}.ckpt".format(args.logdir, epoch_idx))
        torch.cuda.empty_cache()

        # testing
        avg_test_scalars = DictAverageMeter()
        tbar = tqdm(TestImgLoader)
        for batch_idx, sample in enumerate(tbar):
            start_time = time.time()
            global_step = len(TestImgLoader) * epoch_idx + batch_idx
            do_summary = global_step % args.summary_freq == 0
            loss, scalar_outputs = test_sample(args, sample, detailed_summary=do_summary)
            if do_summary:
                save_scalars(logger, 'test', scalar_outputs, global_step)
            avg_test_scalars.update(scalar_outputs)
            del scalar_outputs
            tbar.set_description(
                'Epoch {}/{}, Iter {}/{}, test loss = {:.3f}, time = {:.3f}'.format(
                    epoch_idx, args.epochs, batch_idx, len(TestImgLoader), loss, time.time() - start_time))
        save_scalars(logger, 'fulltest', avg_test_scalars.mean(), global_step)
        print("avg_test_scalars:", avg_test_scalars.mean())
        torch.cuda.empty_cache()

def test(args):
    avg_test_scalars = DictAverageMeter()
    for batch_idx, sample in enumerate(TestImgLoader):
        start_time = time.time()
        loss, scalar_outputs = test_sample(args, sample, detailed_summary=True)
        avg_test_scalars.update(scalar_outputs)
        del scalar_outputs
        print('Iter {}/{}, test loss = {:.3f}, time = {:.3f}'.format(
            batch_idx, len(TestImgLoader), loss, time.time() - start_time))
        if batch_idx % 100 == 0:
            print("Iter {}/{}, test results = {}".format(batch_idx, len(TestImgLoader), avg_test_scalars.mean()))
    print("final", avg_test_scalars.mean())

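# train_sample / test_sample below expect each `sample` dict from the dataset to provide
# the keys used in the code (shapes are assumptions, B = batch, V = number of views):
#   "imgs":                   input images, roughly [B, V, 3, H, W]
#   "proj_matrices":          per-view projection matrices
#   "depth_min", "depth_max": per-sample depth range, [B]
#   "depth", "mask":          dicts of ground-truth depth / validity masks at several
#                             resolutions, keyed 'level_0' and 'level_2'
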
def train_sample(args, sample, detailed_summary=False, scaler=None):
    model.train()
    optimizer.zero_grad()

    sample_cuda = tocuda(sample)
    depth_gt = sample_cuda["depth"]
    mask = sample_cuda["mask"]
    depth_gt_0 = depth_gt['level_0']
    mask_0 = mask['level_0']
    depth_gt_1 = depth_gt['level_2']
    mask_1 = mask['level_2']

    disp_init, disp_predictions = model(sample_cuda["imgs"], sample_cuda["proj_matrices"],
                                        sample_cuda["depth_min"], sample_cuda["depth_max"])

    loss = model_loss(disp_predictions, disp_init, depth_gt_0, mask_0, sample_cuda["depth_min"], sample_cuda["depth_max"])
    scaler.scale(loss).backward()
    scaler.unscale_(optimizer)
    torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
    scaler.step(optimizer)
    scaler.update()

    # convert the normalized inverse-depth predictions back to metric depth for the error metrics
    inverse_depth_min = (1.0 / sample_cuda["depth_min"]).view(args.batch_size, 1, 1, 1)
    inverse_depth_max = (1.0 / sample_cuda["depth_max"]).view(args.batch_size, 1, 1, 1)
    depth_init = depth_unnormalization(disp_init, inverse_depth_min, inverse_depth_max)

    depth_predictions = []
    for disp in disp_predictions:
        depth_predictions.append(depth_unnormalization(disp, inverse_depth_min, inverse_depth_max))

    scalar_outputs = {"loss": loss}
    scalar_outputs["abs_error_initial"] = AbsDepthError_metrics(depth_init, depth_gt_0, mask_0 > 0.5)
    scalar_outputs["thres1mm_initial"] = Thres_metrics(depth_init, depth_gt_0, mask_0 > 0.5, 1)
    scalar_outputs["abs_error_final_full"] = AbsDepthError_metrics(depth_predictions[-1], depth_gt_0, mask_0 > 0.5)
    return tensor2float(loss), tensor2float(scalar_outputs)

@make_nograd_func
def test_sample(args, sample, detailed_summary=True):
    model.eval()

    sample_cuda = tocuda(sample)
    depth_gt = sample_cuda["depth"]
    mask = sample_cuda["mask"]
    depth_gt_0 = depth_gt['level_0']
    mask_0 = mask['level_0']
    depth_gt_1 = depth_gt['level_2']
    mask_1 = mask['level_2']

    disp_init, disp_predictions = model(sample_cuda["imgs"], sample_cuda["proj_matrices"],
                                        sample_cuda["depth_min"], sample_cuda["depth_max"])

    loss = model_loss(disp_predictions, disp_init, depth_gt_0, mask_0, sample_cuda["depth_min"], sample_cuda["depth_max"])

    inverse_depth_min = (1.0 / sample_cuda["depth_min"]).view(sample_cuda["depth_min"].size()[0], 1, 1, 1)
    inverse_depth_max = (1.0 / sample_cuda["depth_max"]).view(sample_cuda["depth_max"].size()[0], 1, 1, 1)
    depth_init = depth_unnormalization(disp_init, inverse_depth_min, inverse_depth_max)

    depth_predictions = []
    for disp in disp_predictions:
        depth_predictions.append(depth_unnormalization(disp, inverse_depth_min, inverse_depth_max))

    scalar_outputs = {"loss": loss}
    scalar_outputs["abs_error_initial"] = AbsDepthError_metrics(depth_init, depth_gt_0, mask_0 > 0.5)
    scalar_outputs["thres1mm_initial"] = Thres_metrics(depth_init, depth_gt_0, mask_0 > 0.5, 1)
    scalar_outputs["abs_error_final_full"] = AbsDepthError_metrics(depth_predictions[-1], depth_gt_0, mask_0 > 0.5)
    return tensor2float(loss), tensor2float(scalar_outputs)

if __name__ == '__main__':
    if args.mode == "train":
        train(args)
    elif args.mode == "val":
        test(args)