Compare commits

..

3 Commits

Author SHA1 Message Date
8a1b0e18f7 added asymmetric chromatic augmentation & adjusted augmentor param 2023-05-03 17:33:27 +08:00
HTensor
06fa0c222c set 12 threads to prepare datasets 2023-05-03 17:30:12 +08:00
HTensor
5cbf5ede88 disabled cudnn benchmark 2023-05-03 17:30:12 +08:00
3 changed files with 30 additions and 12 deletions

View File

@ -352,7 +352,7 @@ def fetch_dataloader(args):
train_dataset = new_dataset if train_dataset is None else train_dataset + new_dataset train_dataset = new_dataset if train_dataset is None else train_dataset + new_dataset
train_loader = data.DataLoader(train_dataset, batch_size=args.batch_size, train_loader = data.DataLoader(train_dataset, batch_size=args.batch_size,
pin_memory=True, shuffle=True, num_workers=int(os.environ.get('SLURM_CPUS_PER_TASK', 6))-2, drop_last=True) pin_memory=True, shuffle=True, num_workers=12, drop_last=True)
logging.info('Training with %d image pairs' % len(train_dataset)) logging.info('Training with %d image pairs' % len(train_dataset))
return train_loader return train_loader

View File

@ -5,7 +5,7 @@ import os
import time import time
from glob import glob from glob import glob
from skimage import color, io from skimage import color, io
from PIL import Image from PIL import Image, ImageEnhance
import cv2 import cv2
cv2.setNumThreads(0) cv2.setNumThreads(0)
@ -198,21 +198,40 @@ class SparseFlowAugmentor:
self.v_flip_prob = 0.1 self.v_flip_prob = 0.1
# photometric augmentation params # photometric augmentation params
self.photo_aug = Compose([ColorJitter(brightness=0.3, contrast=0.3, saturation=saturation_range, hue=0.3/3.14), AdjustGamma(*gamma)]) # self.photo_aug = Compose([ColorJitter(brightness=0.3, contrast=0.3, saturation=saturation_range, hue=0.3/3.14), AdjustGamma(*gamma)])
self.asymmetric_color_aug_prob = 0.2
self.eraser_aug_prob = 0.5 self.eraser_aug_prob = 0.5
def chromatic_augmentation(self, img):
random_brightness = np.random.uniform(0.8, 1.2)
random_contrast = np.random.uniform(0.8, 1.2)
random_gamma = np.random.uniform(0.8, 1.2)
img = Image.fromarray(img)
enhancer = ImageEnhance.Brightness(img)
img = enhancer.enhance(random_brightness)
enhancer = ImageEnhance.Contrast(img)
img = enhancer.enhance(random_contrast)
gamma_map = [
255 * 1.0 * pow(ele / 255.0, random_gamma) for ele in range(256)
] * 3
img = img.point(gamma_map) # use PIL's point-function to accelerate this part
img_ = np.array(img)
return img_
def color_transform(self, img1, img2): def color_transform(self, img1, img2):
image_stack = np.concatenate([img1, img2], axis=0) img1 = self.chromatic_augmentation(img1)
image_stack = np.array(self.photo_aug(Image.fromarray(image_stack)), dtype=np.uint8) img2 = self.chromatic_augmentation(img2)
img1, img2 = np.split(image_stack, 2, axis=0)
return img1, img2 return img1, img2
def eraser_transform(self, img1, img2): def eraser_transform(self, img1, img2):
ht, wd = img1.shape[:2] ht, wd = img1.shape[:2]
if np.random.rand() < self.eraser_aug_prob: if np.random.rand() < self.eraser_aug_prob:
mean_color = np.mean(img2.reshape(-1, 3), axis=0) mean_color = np.mean(img2.reshape(-1, 3), axis=0)
for _ in range(np.random.randint(1, 3)): for _ in range(1):
x0 = np.random.randint(0, wd) x0 = np.random.randint(0, wd)
y0 = np.random.randint(0, ht) y0 = np.random.randint(0, ht)
dx = np.random.randint(50, 100) dx = np.random.randint(50, 100)

View File

@ -22,7 +22,6 @@ from evaluate_stereo import *
import core.stereo_datasets as datasets import core.stereo_datasets as datasets
import torch.nn.functional as F import torch.nn.functional as F
torch.backends.cudnn.benchmark = True
ckpt_path = './checkpoints/igev_stereo' ckpt_path = './checkpoints/igev_stereo'
log_path = './checkpoints/igev_stereo' log_path = './checkpoints/igev_stereo'
@ -241,8 +240,8 @@ if __name__ == '__main__':
parser.add_argument('--max_disp', type=int, default=192, help="max disp of geometry encoding volume") parser.add_argument('--max_disp', type=int, default=192, help="max disp of geometry encoding volume")
# Data augmentation # Data augmentation
parser.add_argument('--img_gamma', type=float, nargs='+', default=None, help="gamma range") # parser.add_argument('--img_gamma', type=float, nargs='+', default=None, help="gamma range")
parser.add_argument('--saturation_range', type=float, nargs='+', default=[0, 1.4], help='color saturation') # parser.add_argument('--saturation_range', type=float, nargs='+', default=[0, 1.4], help='color saturation')
parser.add_argument('--do_flip', default=False, choices=['h', 'v'], help='flip the images horizontally or vertically') parser.add_argument('--do_flip', default=False, choices=['h', 'v'], help='flip the images horizontally or vertically')
parser.add_argument('--spatial_scale', type=float, nargs='+', default=[-0.2, 0.4], help='re-scale the images randomly') parser.add_argument('--spatial_scale', type=float, nargs='+', default=[-0.2, 0.4], help='re-scale the images randomly')
parser.add_argument('--noyjitter', action='store_true', help='don\'t simulate imperfect rectification') parser.add_argument('--noyjitter', action='store_true', help='don\'t simulate imperfect rectification')