Compare commits
No commits in common. "8a1b0e18f7f206d1ff0a529c23ee41c7564ae611" and "9df896bb706828a196f4f6658c6150137a3e5e13" have entirely different histories.
8a1b0e18f7
...
9df896bb70
@ -352,7 +352,7 @@ def fetch_dataloader(args):
|
|||||||
train_dataset = new_dataset if train_dataset is None else train_dataset + new_dataset
|
train_dataset = new_dataset if train_dataset is None else train_dataset + new_dataset
|
||||||
|
|
||||||
train_loader = data.DataLoader(train_dataset, batch_size=args.batch_size,
|
train_loader = data.DataLoader(train_dataset, batch_size=args.batch_size,
|
||||||
pin_memory=True, shuffle=True, num_workers=12, drop_last=True)
|
pin_memory=True, shuffle=True, num_workers=int(os.environ.get('SLURM_CPUS_PER_TASK', 6))-2, drop_last=True)
|
||||||
|
|
||||||
logging.info('Training with %d image pairs' % len(train_dataset))
|
logging.info('Training with %d image pairs' % len(train_dataset))
|
||||||
return train_loader
|
return train_loader
|
||||||
|
@ -5,7 +5,7 @@ import os
|
|||||||
import time
|
import time
|
||||||
from glob import glob
|
from glob import glob
|
||||||
from skimage import color, io
|
from skimage import color, io
|
||||||
from PIL import Image, ImageEnhance
|
from PIL import Image
|
||||||
|
|
||||||
import cv2
|
import cv2
|
||||||
cv2.setNumThreads(0)
|
cv2.setNumThreads(0)
|
||||||
@ -198,40 +198,21 @@ class SparseFlowAugmentor:
|
|||||||
self.v_flip_prob = 0.1
|
self.v_flip_prob = 0.1
|
||||||
|
|
||||||
# photometric augmentation params
|
# photometric augmentation params
|
||||||
# self.photo_aug = Compose([ColorJitter(brightness=0.3, contrast=0.3, saturation=saturation_range, hue=0.3/3.14), AdjustGamma(*gamma)])
|
self.photo_aug = Compose([ColorJitter(brightness=0.3, contrast=0.3, saturation=saturation_range, hue=0.3/3.14), AdjustGamma(*gamma)])
|
||||||
|
self.asymmetric_color_aug_prob = 0.2
|
||||||
self.eraser_aug_prob = 0.5
|
self.eraser_aug_prob = 0.5
|
||||||
|
|
||||||
def chromatic_augmentation(self, img):
|
|
||||||
random_brightness = np.random.uniform(0.8, 1.2)
|
|
||||||
random_contrast = np.random.uniform(0.8, 1.2)
|
|
||||||
random_gamma = np.random.uniform(0.8, 1.2)
|
|
||||||
|
|
||||||
img = Image.fromarray(img)
|
|
||||||
|
|
||||||
enhancer = ImageEnhance.Brightness(img)
|
|
||||||
img = enhancer.enhance(random_brightness)
|
|
||||||
enhancer = ImageEnhance.Contrast(img)
|
|
||||||
img = enhancer.enhance(random_contrast)
|
|
||||||
|
|
||||||
gamma_map = [
|
|
||||||
255 * 1.0 * pow(ele / 255.0, random_gamma) for ele in range(256)
|
|
||||||
] * 3
|
|
||||||
img = img.point(gamma_map) # use PIL's point-function to accelerate this part
|
|
||||||
|
|
||||||
img_ = np.array(img)
|
|
||||||
|
|
||||||
return img_
|
|
||||||
|
|
||||||
def color_transform(self, img1, img2):
|
def color_transform(self, img1, img2):
|
||||||
img1 = self.chromatic_augmentation(img1)
|
image_stack = np.concatenate([img1, img2], axis=0)
|
||||||
img2 = self.chromatic_augmentation(img2)
|
image_stack = np.array(self.photo_aug(Image.fromarray(image_stack)), dtype=np.uint8)
|
||||||
|
img1, img2 = np.split(image_stack, 2, axis=0)
|
||||||
return img1, img2
|
return img1, img2
|
||||||
|
|
||||||
def eraser_transform(self, img1, img2):
|
def eraser_transform(self, img1, img2):
|
||||||
ht, wd = img1.shape[:2]
|
ht, wd = img1.shape[:2]
|
||||||
if np.random.rand() < self.eraser_aug_prob:
|
if np.random.rand() < self.eraser_aug_prob:
|
||||||
mean_color = np.mean(img2.reshape(-1, 3), axis=0)
|
mean_color = np.mean(img2.reshape(-1, 3), axis=0)
|
||||||
for _ in range(1):
|
for _ in range(np.random.randint(1, 3)):
|
||||||
x0 = np.random.randint(0, wd)
|
x0 = np.random.randint(0, wd)
|
||||||
y0 = np.random.randint(0, ht)
|
y0 = np.random.randint(0, ht)
|
||||||
dx = np.random.randint(50, 100)
|
dx = np.random.randint(50, 100)
|
||||||
|
@ -22,6 +22,7 @@ from evaluate_stereo import *
|
|||||||
import core.stereo_datasets as datasets
|
import core.stereo_datasets as datasets
|
||||||
import torch.nn.functional as F
|
import torch.nn.functional as F
|
||||||
|
|
||||||
|
torch.backends.cudnn.benchmark = True
|
||||||
|
|
||||||
ckpt_path = './checkpoints/igev_stereo'
|
ckpt_path = './checkpoints/igev_stereo'
|
||||||
log_path = './checkpoints/igev_stereo'
|
log_path = './checkpoints/igev_stereo'
|
||||||
@ -240,8 +241,8 @@ if __name__ == '__main__':
|
|||||||
parser.add_argument('--max_disp', type=int, default=192, help="max disp of geometry encoding volume")
|
parser.add_argument('--max_disp', type=int, default=192, help="max disp of geometry encoding volume")
|
||||||
|
|
||||||
# Data augmentation
|
# Data augmentation
|
||||||
# parser.add_argument('--img_gamma', type=float, nargs='+', default=None, help="gamma range")
|
parser.add_argument('--img_gamma', type=float, nargs='+', default=None, help="gamma range")
|
||||||
# parser.add_argument('--saturation_range', type=float, nargs='+', default=[0, 1.4], help='color saturation')
|
parser.add_argument('--saturation_range', type=float, nargs='+', default=[0, 1.4], help='color saturation')
|
||||||
parser.add_argument('--do_flip', default=False, choices=['h', 'v'], help='flip the images horizontally or vertically')
|
parser.add_argument('--do_flip', default=False, choices=['h', 'v'], help='flip the images horizontally or vertically')
|
||||||
parser.add_argument('--spatial_scale', type=float, nargs='+', default=[-0.2, 0.4], help='re-scale the images randomly')
|
parser.add_argument('--spatial_scale', type=float, nargs='+', default=[-0.2, 0.4], help='re-scale the images randomly')
|
||||||
parser.add_argument('--noyjitter', action='store_true', help='don\'t simulate imperfect rectification')
|
parser.add_argument('--noyjitter', action='store_true', help='don\'t simulate imperfect rectification')
|
||||||
|
Loading…
Reference in New Issue
Block a user