# IGEV/IGEV-Stereo/train_stereo.py

from __future__ import print_function, division
import math
import os
os.environ['CUDA_VISIBLE_DEVICES'] = '0, 1'
import argparse
import logging
import numpy as np
from pathlib import Path
from tqdm import tqdm
from torch.utils.tensorboard import SummaryWriter
import torch
import torch.nn as nn
import torch.optim as optim
from core.igev_stereo import IGEVStereo
from evaluate_stereo import *
import core.stereo_datasets as datasets
import torch.nn.functional as F
ckpt_path = './checkpoints/igev_stereo'
log_path = './checkpoints/igev_stereo'
try:
    from torch.cuda.amp import GradScaler
except ImportError:
    # Dummy GradScaler for older PyTorch builds without torch.cuda.amp; mirrors the real API as a no-op.
    class GradScaler:
        def __init__(self, enabled=False):
            pass
        def scale(self, loss):
            return loss
        def unscale_(self, optimizer):
            pass
        def step(self, optimizer):
            optimizer.step()
        def update(self):
            pass
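

# sequence_loss (below) supervises every disparity estimate produced in a forward pass:
# the initial prediction from the geometry encoding volume is penalised with smooth L1,
# and each iterative refinement adds an L1 term whose weight decays exponentially with
# its distance from the final iteration, so later (more refined) estimates count most.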
def sequence_loss(disp_preds, disp_init_pred, disp_gt, valid, loss_gamma=0.9, max_disp=192):
    """ Loss function defined over a sequence of disparity predictions """
    n_predictions = len(disp_preds)
    assert n_predictions >= 1
    disp_loss = 0.0
    mag = torch.sum(disp_gt**2, dim=1).sqrt()
    valid = ((valid >= 0.5) & (mag < max_disp)).unsqueeze(1)
    assert valid.shape == disp_gt.shape, [valid.shape, disp_gt.shape]
    assert not torch.isinf(disp_gt[valid.bool()]).any()
    disp_loss += 1.0 * F.smooth_l1_loss(disp_init_pred[valid.bool()], disp_gt[valid.bool()], reduction='mean')

    for i in range(n_predictions):
        adjusted_loss_gamma = loss_gamma**(15/(n_predictions - 1))
        i_weight = adjusted_loss_gamma**(n_predictions - i - 1)
        i_loss = (disp_preds[i] - disp_gt).abs()
        assert i_loss.shape == valid.shape, [i_loss.shape, valid.shape, disp_gt.shape, disp_preds[i].shape]
        disp_loss += i_weight * i_loss[valid.bool()].mean()

    epe = torch.sum((disp_preds[-1] - disp_gt)**2, dim=1).sqrt()
    epe = epe.view(-1)[valid.view(-1)]

    metrics = {
        'epe': epe.mean().item(),
        '1px': (epe < 1).float().mean().item(),
        '3px': (epe < 3).float().mean().item(),
        '5px': (epe < 5).float().mean().item(),
    }
    return disp_loss, metrics
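

# fetch_optimizer pairs AdamW with a OneCycleLR schedule: a short linear warm-up
# (pct_start=0.01) up to args.lr followed by a linear anneal. The schedule is built
# with args.num_steps + 100 steps so it never runs out before training stops.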
def fetch_optimizer(args, model):
    """ Create the optimizer and learning rate scheduler """
    optimizer = optim.AdamW(model.parameters(), lr=args.lr, weight_decay=args.wdecay, eps=1e-8)
    # todo: cosine scheduler, warm-up
    scheduler = optim.lr_scheduler.OneCycleLR(optimizer, args.lr, args.num_steps + 100,
                                              pct_start=0.01, cycle_momentum=False, anneal_strategy='linear')
    return optimizer, scheduler


class Logger:
    SUM_FREQ = 100

    def __init__(self, model, scheduler):
        self.model = model
        self.scheduler = scheduler
        self.total_steps = 0
        self.running_loss = {}
        self.writer = SummaryWriter(log_dir=log_path)

    def _print_training_status(self):
        metrics_data = [self.running_loss[k]/Logger.SUM_FREQ for k in sorted(self.running_loss.keys())]
        training_str = "[{:6d}, {:10.7f}] ".format(self.total_steps+1, self.scheduler.get_last_lr()[0])
        metrics_str = ("{:10.4f}, "*len(metrics_data)).format(*metrics_data)

        # print the training status
        logging.info(f"Training Metrics ({self.total_steps}): {training_str + metrics_str}")

        if self.writer is None:
            self.writer = SummaryWriter(log_dir=log_path)

        for k in self.running_loss:
            self.writer.add_scalar(k, self.running_loss[k]/Logger.SUM_FREQ, self.total_steps)
            self.running_loss[k] = 0.0

    def push(self, metrics):
        self.total_steps += 1

        for key in metrics:
            if key not in self.running_loss:
                self.running_loss[key] = 0.0
            self.running_loss[key] += metrics[key]

        if self.total_steps % Logger.SUM_FREQ == Logger.SUM_FREQ-1:
            self._print_training_status()
            self.running_loss = {}

    def write_dict(self, results):
        if self.writer is None:
            self.writer = SummaryWriter(log_dir=log_path)

        for key in results:
            self.writer.add_scalar(key, results[key], self.total_steps)

    def close(self):
        self.writer.close()
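

# train() wraps the model in nn.DataParallel, keeps BatchNorm layers frozen, and
# iterates until args.num_steps optimizer updates have been taken. Every
# validation_frequency steps it saves a checkpoint, runs validation, and logs the
# results before resuming training.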
def train(args):
    # todo: compile the model to speed up at pytorch 2.0.
    model = nn.DataParallel(IGEVStereo(args))
    print("Parameter Count: %d" % count_parameters(model))

    train_loader = datasets.fetch_dataloader(args)
    optimizer, scheduler = fetch_optimizer(args, model)
    total_steps = 0
    logger = Logger(model, scheduler)

    if args.restore_ckpt is not None:
        assert args.restore_ckpt.endswith(".pth")
        logging.info("Loading checkpoint...")
        checkpoint = torch.load(args.restore_ckpt)
        model.load_state_dict(checkpoint, strict=True)
        logging.info("Done loading checkpoint")

    model.cuda()
    model.train()
    model.module.freeze_bn()  # We keep BatchNorm frozen

    validation_frequency = 1000
    scaler = GradScaler(enabled=args.mixed_precision)

    should_keep_training = True
    global_batch_num = 0
    while should_keep_training:

        for i_batch, (_, *data_blob) in enumerate(tqdm(train_loader)):
            optimizer.zero_grad()
            image1, image2, disp_gt, valid = [x.cuda() for x in data_blob]
            assert model.training

            disp_init_pred, disp_preds = model(image1, image2, iters=args.train_iters)
            assert model.training

            loss, metrics = sequence_loss(disp_preds, disp_init_pred, disp_gt, valid, max_disp=args.max_disp)
            logger.writer.add_scalar("live_loss", loss.item(), global_batch_num)
            logger.writer.add_scalar('learning_rate', optimizer.param_groups[0]['lr'], global_batch_num)
            global_batch_num += 1

            scaler.scale(loss).backward()
            scaler.unscale_(optimizer)
            torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
            #warning
            scaler.step(optimizer)
            scheduler.step()
            scaler.update()

            logger.push(metrics)

            if total_steps % validation_frequency == validation_frequency - 1:
                save_path = Path(ckpt_path + '/%d_%s.pth' % (total_steps + 1, args.name))
                logging.info(f"Saving file {save_path.absolute()}")
                torch.save(model.state_dict(), save_path)

                results = validate_middlebury(model.module, iters=args.valid_iters)
                logger.write_dict(results)

                model.train()
                model.module.freeze_bn()

            total_steps += 1

            if total_steps > args.num_steps:
                should_keep_training = False
                break

        if len(train_loader) >= 10000:
            save_path = Path(ckpt_path + '/%d_epoch_%s.pth.gz' % (total_steps + 1, args.name))
            logging.info(f"Saving file {save_path}")
            torch.save(model.state_dict(), save_path)

    print("FINISHED TRAINING")
    logger.close()

    PATH = ckpt_path + '/%s.pth' % args.name
    torch.save(model.state_dict(), PATH)

    return PATH
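

# The defaults below train on Scene Flow with batch size 8 for 200k steps, using
# 320x736 random crops, 22 refinement iterations at training time, 32 at validation
# time, and a maximum disparity of 192 for the geometry encoding volume.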
if __name__ == '__main__':
    parser = argparse.ArgumentParser()
    parser.add_argument('--name', default='igev-stereo', help="name your experiment")
    parser.add_argument('--restore_ckpt', default=None, help="path to a .pth checkpoint to restore")
    parser.add_argument('--mixed_precision', default=False, action='store_true', help='use mixed precision')
    # Training parameters
    parser.add_argument('--batch_size', type=int, default=8, help="batch size used during training.")
    parser.add_argument('--train_datasets', nargs='+', default=['sceneflow'], help="training datasets.")
    parser.add_argument('--lr', type=float, default=0.0002, help="max learning rate.")
    parser.add_argument('--num_steps', type=int, default=200000, help="length of training schedule.")
    parser.add_argument('--image_size', type=int, nargs='+', default=[320, 736], help="size of the random image crops used during training.")
    parser.add_argument('--train_iters', type=int, default=22, help="number of updates to the disparity field in each forward pass.")
    parser.add_argument('--wdecay', type=float, default=.00001, help="weight decay in optimizer.")

    # Validation parameters
    parser.add_argument('--valid_iters', type=int, default=32, help='number of disparity-field updates during validation forward pass')

    # Architecture choices
    parser.add_argument('--corr_implementation', choices=["reg", "alt", "reg_cuda", "alt_cuda"], default="reg", help="correlation volume implementation")
    parser.add_argument('--shared_backbone', action='store_true', help="use a single backbone for the context and feature encoders")
    parser.add_argument('--corr_levels', type=int, default=2, help="number of levels in the correlation pyramid")
    parser.add_argument('--corr_radius', type=int, default=4, help="width of the correlation pyramid")
    parser.add_argument('--n_downsample', type=int, default=2, help="resolution of the disparity field (1/2^K)")
    parser.add_argument('--slow_fast_gru', action='store_true', help="iterate the low-res GRUs more frequently")
    parser.add_argument('--n_gru_layers', type=int, default=3, help="number of hidden GRU levels")
    parser.add_argument('--hidden_dims', nargs='+', type=int, default=[128]*3, help="hidden state and context dimensions")
    parser.add_argument('--max_disp', type=int, default=192, help="max disp of geometry encoding volume")

    # Data augmentation
    # parser.add_argument('--img_gamma', type=float, nargs='+', default=None, help="gamma range")
    # parser.add_argument('--saturation_range', type=float, nargs='+', default=[0, 1.4], help='color saturation')
    parser.add_argument('--do_flip', default=False, choices=['h', 'v'], help='flip the images horizontally or vertically')
    parser.add_argument('--spatial_scale', type=float, nargs='+', default=[-0.2, 0.4], help='re-scale the images randomly')
    parser.add_argument('--noyjitter', action='store_true', help="don't simulate imperfect rectification")
    args = parser.parse_args()

    torch.manual_seed(666)
    np.random.seed(666)

    logging.basicConfig(level=logging.INFO,
                        format='%(asctime)s %(levelname)-8s [%(filename)s:%(lineno)d] %(message)s')

    Path(ckpt_path).mkdir(exist_ok=True, parents=True)

    train(args)
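

# Example invocation (illustrative; adjust GPU visibility, batch size, and datasets to your setup):
#   python train_stereo.py --name igev-stereo --train_datasets sceneflow --batch_size 8 --mixed_precision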