from __future__ import print_function, division

import argparse
import gc
import os
import time

import torch
import torch.backends.cudnn as cudnn
import torch.nn as nn
import torch.optim as optim
from tensorboardX import SummaryWriter
from torch.utils.data import DataLoader

from datasets import __datasets__
from models import __models__, model_loss
from utils import *

# let cuDNN benchmark convolution algorithms; input sizes are fixed, so this is a net win
cudnn.benchmark = True

parser = argparse.ArgumentParser(description='Group-wise Correlation Stereo Network (GwcNet)')
parser.add_argument('--model', default='gwcnet-g', help='select a model structure', choices=__models__.keys())
parser.add_argument('--maxdisp', type=int, default=192, help='maximum disparity')

parser.add_argument('--dataset', required=True, help='dataset name', choices=__datasets__.keys())
parser.add_argument('--datapath', required=True, help='data path')
parser.add_argument('--trainlist', required=True, help='training list')
parser.add_argument('--testlist', required=True, help='testing list')

parser.add_argument('--lr', type=float, default=0.001, help='base learning rate')
parser.add_argument('--batch_size', type=int, default=16, help='training batch size')
parser.add_argument('--test_batch_size', type=int, default=8, help='testing batch size')
parser.add_argument('--epochs', type=int, required=True, help='number of epochs to train')
parser.add_argument('--lrepochs', type=str, required=True,
                    help='epochs at which to decay the lr, and the decay rate, in the form "e1,e2,...:rate"')

parser.add_argument('--logdir', required=True, help='the directory to save logs and checkpoints')
parser.add_argument('--loadckpt', help='load the weights from a specific checkpoint')
parser.add_argument('--resume', action='store_true', help='continue training the model')
parser.add_argument('--seed', type=int, default=1, metavar='S', help='random seed (default: 1)')

parser.add_argument('--summary_freq', type=int, default=20, help='the frequency of saving summary (in steps)')
parser.add_argument('--save_freq', type=int, default=1, help='the frequency of saving checkpoint (in epochs)')
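
# Hypothetical invocation, for reference only (dataset name, list files, and
# paths below are placeholders; substitute values matching your setup):
#
#   python main.py --dataset sceneflow --datapath /path/to/sceneflow \
#       --trainlist ./filenames/sceneflow_train.txt --testlist ./filenames/sceneflow_test.txt \
#       --epochs 16 --lrepochs "10,12,14,16:2" --logdir ./checkpoints/gwcnet-g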

# parse arguments, set seeds
args = parser.parse_args()
torch.manual_seed(args.seed)
torch.cuda.manual_seed(args.seed)
os.makedirs(args.logdir, exist_ok=True)

# create summary logger
print("creating new summary file")
logger = SummaryWriter(args.logdir)
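
# The SummaryWriter above writes standard TensorBoard event files; training and
# test curves can be followed live with: tensorboard --logdir <logdir>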

# dataset, dataloader
StereoDataset = __datasets__[args.dataset]
train_dataset = StereoDataset(args.datapath, args.trainlist, True)
test_dataset = StereoDataset(args.datapath, args.testlist, False)
TrainImgLoader = DataLoader(train_dataset, args.batch_size, shuffle=True, num_workers=8, drop_last=True)
TestImgLoader = DataLoader(test_dataset, args.test_batch_size, shuffle=False, num_workers=4, drop_last=False)
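
# drop_last=True discards the final partial batch each epoch so every training
# step sees a full batch_size (a small trailing batch can skew BatchNorm
# statistics); the test loader keeps its last partial batch instead.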

# model, optimizer
model = __models__[args.model](args.maxdisp)
model = nn.DataParallel(model)
model.cuda()
optimizer = optim.Adam(model.parameters(), lr=args.lr, betas=(0.9, 0.999))
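
# Note: nn.DataParallel prefixes every parameter key with "module.", so the
# checkpoints saved below load cleanly only into a model wrapped the same way.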

# load parameters
start_epoch = 0
if args.resume:
    # find all checkpoint files and sort them by epoch id
    all_saved_ckpts = [fn for fn in os.listdir(args.logdir) if fn.endswith(".ckpt")]
    all_saved_ckpts = sorted(all_saved_ckpts, key=lambda x: int(x.split('_')[-1].split('.')[0]))
    # use the latest checkpoint file
    loadckpt = os.path.join(args.logdir, all_saved_ckpts[-1])
    print("loading the latest model in logdir: {}".format(loadckpt))
    state_dict = torch.load(loadckpt)
    model.load_state_dict(state_dict['model'])
    optimizer.load_state_dict(state_dict['optimizer'])
    start_epoch = state_dict['epoch'] + 1
elif args.loadckpt:
    # load only the weights specified by args.loadckpt (optimizer state and epoch counter are not restored)
    print("loading model {}".format(args.loadckpt))
    state_dict = torch.load(args.loadckpt)
    model.load_state_dict(state_dict['model'])
print("start at epoch {}".format(start_epoch))
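
# adjust_learning_rate (imported from utils) is assumed to parse --lrepochs
# strings of the form "e1,e2,...:rate", dividing the base lr by `rate` at each
# listed epoch, e.g. "10,12,14,16:2" halves the lr at epochs 10, 12, 14 and 16.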

def train():
    for epoch_idx in range(start_epoch, args.epochs):
        adjust_learning_rate(optimizer, epoch_idx, args.lr, args.lrepochs)

        # training
        for batch_idx, sample in enumerate(TrainImgLoader):
            global_step = len(TrainImgLoader) * epoch_idx + batch_idx
            start_time = time.time()
            do_summary = global_step % args.summary_freq == 0
            loss, scalar_outputs, image_outputs = train_sample(sample, compute_metrics=do_summary)
            if do_summary:
                save_scalars(logger, 'train', scalar_outputs, global_step)
                save_images(logger, 'train', image_outputs, global_step)
            del scalar_outputs, image_outputs
            print('Epoch {}/{}, Iter {}/{}, train loss = {:.3f}, time = {:.3f}'.format(
                epoch_idx, args.epochs, batch_idx, len(TrainImgLoader), loss,
                time.time() - start_time))

        # saving checkpoints
        if (epoch_idx + 1) % args.save_freq == 0:
            checkpoint_data = {'epoch': epoch_idx, 'model': model.state_dict(), 'optimizer': optimizer.state_dict()}
            torch.save(checkpoint_data, "{}/checkpoint_{:0>6}.ckpt".format(args.logdir, epoch_idx))
        gc.collect()

        # testing
        avg_test_scalars = AverageMeterDict()
        for batch_idx, sample in enumerate(TestImgLoader):
            global_step = len(TestImgLoader) * epoch_idx + batch_idx
            start_time = time.time()
            do_summary = global_step % args.summary_freq == 0
            loss, scalar_outputs, image_outputs = test_sample(sample, compute_metrics=do_summary)
            if do_summary:
                save_scalars(logger, 'test', scalar_outputs, global_step)
                save_images(logger, 'test', image_outputs, global_step)
            avg_test_scalars.update(scalar_outputs)
            del scalar_outputs, image_outputs
            print('Epoch {}/{}, Iter {}/{}, test loss = {:.3f}, time = {:.3f}'.format(
                epoch_idx, args.epochs, batch_idx, len(TestImgLoader), loss,
                time.time() - start_time))
        avg_test_scalars = avg_test_scalars.mean()
        save_scalars(logger, 'fulltest', avg_test_scalars, len(TrainImgLoader) * (epoch_idx + 1))
        print("avg_test_scalars", avg_test_scalars)
        gc.collect()
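
# A note on the loss: in training mode the network returns a list of disparity
# maps (disp_ests), one per intermediate output; model_loss (from models) is
# assumed to reduce them to one scalar as a weighted sum of per-output smooth-L1
# losses over the masked pixels, as in the GwcNet paper.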

# train one sample
def train_sample(sample, compute_metrics=False):
    model.train()

    imgL, imgR, disp_gt = sample['left'], sample['right'], sample['disparity']
    imgL = imgL.cuda()
    imgR = imgR.cuda()
    disp_gt = disp_gt.cuda()

    optimizer.zero_grad()

    disp_ests = model(imgL, imgR)
    # only pixels with a valid ground-truth disparity in (0, maxdisp) contribute
    # to the loss; invalid pixels (e.g. 0 in sparse KITTI ground truth) are masked out
    mask = (disp_gt < args.maxdisp) & (disp_gt > 0)
    loss = model_loss(disp_ests, disp_gt, mask)

    scalar_outputs = {"loss": loss}
    image_outputs = {"disp_est": disp_ests, "disp_gt": disp_gt, "imgL": imgL, "imgR": imgR}
    if compute_metrics:
        with torch.no_grad():
            image_outputs["errormap"] = [disp_error_image_func()(disp_est, disp_gt) for disp_est in disp_ests]
            scalar_outputs["EPE"] = [EPE_metric(disp_est, disp_gt, mask) for disp_est in disp_ests]
            scalar_outputs["D1"] = [D1_metric(disp_est, disp_gt, mask) for disp_est in disp_ests]
            scalar_outputs["Thres1"] = [Thres_metric(disp_est, disp_gt, mask, 1.0) for disp_est in disp_ests]
            scalar_outputs["Thres2"] = [Thres_metric(disp_est, disp_gt, mask, 2.0) for disp_est in disp_ests]
            scalar_outputs["Thres3"] = [Thres_metric(disp_est, disp_gt, mask, 3.0) for disp_est in disp_ests]

    loss.backward()
    optimizer.step()

    return tensor2float(loss), tensor2float(scalar_outputs), image_outputs
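
# make_nograd_func (from utils) is assumed to run the wrapped function inside a
# torch.no_grad() context, so evaluation builds no autograd graph.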

# test one sample
@make_nograd_func
def test_sample(sample, compute_metrics=True):
    model.eval()

    imgL, imgR, disp_gt = sample['left'], sample['right'], sample['disparity']
    imgL = imgL.cuda()
    imgR = imgR.cuda()
    disp_gt = disp_gt.cuda()

    disp_ests = model(imgL, imgR)
    mask = (disp_gt < args.maxdisp) & (disp_gt > 0)
    loss = model_loss(disp_ests, disp_gt, mask)

    scalar_outputs = {"loss": loss}
    image_outputs = {"disp_est": disp_ests, "disp_gt": disp_gt, "imgL": imgL, "imgR": imgR}

    scalar_outputs["D1"] = [D1_metric(disp_est, disp_gt, mask) for disp_est in disp_ests]
    scalar_outputs["EPE"] = [EPE_metric(disp_est, disp_gt, mask) for disp_est in disp_ests]
    scalar_outputs["Thres1"] = [Thres_metric(disp_est, disp_gt, mask, 1.0) for disp_est in disp_ests]
    scalar_outputs["Thres2"] = [Thres_metric(disp_est, disp_gt, mask, 2.0) for disp_est in disp_ests]
    scalar_outputs["Thres3"] = [Thres_metric(disp_est, disp_gt, mask, 3.0) for disp_est in disp_ests]

    # the error map is only rendered when a summary will actually be written
    if compute_metrics:
        image_outputs["errormap"] = [disp_error_image_func()(disp_est, disp_gt) for disp_est in disp_ests]

    return tensor2float(loss), tensor2float(scalar_outputs), image_outputs

if __name__ == '__main__':
    train()