diff --git a/main.py b/main.py index 16e58bc..beb9f4d 100644 --- a/main.py +++ b/main.py @@ -15,18 +15,15 @@ import time from tensorboardX import SummaryWriter from datasets import __datasets__ from models import __models__ -from models import * from utils import * from torch.utils.data import DataLoader -import skimage import gc -import datetime -import cv2 cudnn.benchmark = True parser = argparse.ArgumentParser(description='Group-wise Correlation Stereo Network (GwcNet)') parser.add_argument('--model', default='gwcnet-g', help='select a model structure', choices=__models__.keys()) +parser.add_argument('--maxdisp', type=int, default=192, help='maximum disparity') parser.add_argument('--dataset', required=True, help='dataset name', choices=__datasets__.keys()) parser.add_argument('--datapath', required=True, help='data path') @@ -36,7 +33,6 @@ parser.add_argument('--testlist', required=True, help='testing list') parser.add_argument('--lr', type=float, default=0.001, help='base learning rate') parser.add_argument('--batch_size', type=int, default=16, help='training batch size') parser.add_argument('--test_batch_size', type=int, default=8, help='testing batch size') -parser.add_argument('--maxdisp', type=int, default=192, help='maximum disparity') parser.add_argument('--epochs', type=int, required=True, help='number of epochs to train') parser.add_argument('--lrepochs', type=str, required=True, help='the epochs to decay lr: the downscale rate') diff --git a/models/gwcnet.py b/models/gwcnet.py index 923e6bd..4746807 100644 --- a/models/gwcnet.py +++ b/models/gwcnet.py @@ -4,7 +4,6 @@ import torch.nn as nn import torch.utils.data from torch.autograd import Variable import torch.nn.functional as F -import math from models.submodule import * diff --git a/models/submodule.py b/models/submodule.py index 7cd1b1e..a4618c4 100644 --- a/models/submodule.py +++ b/models/submodule.py @@ -5,7 +5,6 @@ import torch.utils.data from torch.autograd import Variable from torch.autograd.function import Function import torch.nn.functional as F -import math import numpy as np diff --git a/utils/experiment.py b/utils/experiment.py index f6e45d1..340f921 100644 --- a/utils/experiment.py +++ b/utils/experiment.py @@ -7,13 +7,7 @@ from torch.autograd import Variable import torchvision.utils as vutils import torch.nn.functional as F import numpy as np -import time -from datasets import * -from models import * import copy -import yaml -import sys -import argparse def make_iterative_func(func): diff --git a/utils/visualization.py b/utils/visualization.py index f898882..3be1305 100644 --- a/utils/visualization.py +++ b/utils/visualization.py @@ -6,10 +6,6 @@ from torch.autograd import Variable, Function import torch.nn.functional as F import math import numpy as np -import cv2 - -# disable multi-thread -cv2.setNumThreads(0) def gen_error_colormap():