GwcNet/utils/experiment.py
2019-04-14 21:26:51 +08:00

152 lines
4.4 KiB
Python

from __future__ import print_function, division

import copy
import functools

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.nn.parallel
import torch.utils.data
import torchvision.utils as vutils
from torch.autograd import Variable

def make_iterative_func(func):
    """Lift a leaf-wise function ``func`` so it also accepts nested containers.

    The returned wrapper recurses into list/tuple elements and dict values and
    rebuilds a plain ``list``, ``tuple`` or ``dict`` (same keys) from the
    results. Any other value is treated as a leaf and passed to ``func``.

    ``functools.wraps`` keeps the decorated function's name and docstring, so
    tracebacks and profilers show e.g. ``tensor2float`` rather than ``wrapper``.
    """
    @functools.wraps(func)
    def wrapper(vars):
        if isinstance(vars, list):
            return [wrapper(x) for x in vars]
        elif isinstance(vars, tuple):
            return tuple([wrapper(x) for x in vars])
        elif isinstance(vars, dict):
            return {k: wrapper(v) for k, v in vars.items()}
        else:
            return func(vars)
    return wrapper
def make_nograd_func(func):
    """Wrap ``func`` so that every call executes inside ``torch.no_grad()``.

    Positional and keyword arguments are forwarded unchanged and ``func``'s
    return value is returned as-is. ``functools.wraps`` preserves the wrapped
    function's name and docstring.
    """
    @functools.wraps(func)
    def wrapper(*f_args, **f_kwargs):
        with torch.no_grad():
            return func(*f_args, **f_kwargs)
    return wrapper
@make_iterative_func
def tensor2float(leaf):
    """Turn a single-element tensor into a Python number; floats pass through.

    ``make_iterative_func`` applies this to every leaf of nested lists,
    tuples and dicts. Tensors are converted with ``Tensor.item()``.
    Any other leaf type raises NotImplementedError.
    """
    if isinstance(leaf, torch.Tensor):
        return leaf.item()
    if isinstance(leaf, float):
        return leaf
    raise NotImplementedError("invalid input type for tensor2float")
@make_iterative_func
def tensor2numpy(leaf):
    """Turn a tensor leaf into a NumPy array; ndarrays are returned unchanged.

    Tensors are copied to the CPU with ``.cpu()`` before ``.numpy()``.
    ``make_iterative_func`` walks nested lists, tuples and dicts.
    Any other leaf type raises NotImplementedError.
    """
    if isinstance(leaf, torch.Tensor):
        return leaf.data.cpu().numpy()
    if not isinstance(leaf, np.ndarray):
        raise NotImplementedError("invalid input type for tensor2numpy")
    return leaf
@make_iterative_func
def check_allfloat(value):
    """Raise AssertionError unless the leaf ``value`` is a Python float.

    ``make_iterative_func`` applies this to every leaf of nested lists,
    tuples and dicts. The check uses an explicit ``raise`` instead of an
    ``assert`` statement, so it stays active when Python runs with ``-O``.
    The exception type is unchanged.
    """
    if not isinstance(value, float):
        raise AssertionError(
            "expected float, got {}".format(type(value).__name__))
def save_scalars(logger, mode_tag, scalar_dict, global_step):
    """Write every scalar in ``scalar_dict`` to ``logger`` at ``global_step``.

    Values are first converted with ``tensor2float``. A value may be a single
    scalar or a list/tuple of scalars. Each one is logged via
    ``logger.add_scalar`` under the tag ``"<mode_tag>/<tag>_<index>"``.
    """
    floats = tensor2float(scalar_dict)
    for tag, entry in floats.items():
        series = entry if isinstance(entry, (list, tuple)) else [entry]
        prefix = '{}/{}'.format(mode_tag, tag)
        for position, scalar in enumerate(series):
            # The index suffix is appended unconditionally, even for a lone
            # scalar (tag "<mode>/<tag>_0"). This differs from save_images,
            # which only suffixes when several values are present.
            logger.add_scalar(prefix + "_" + str(position), scalar, global_step)
def save_images(logger, mode_tag, images_dict, global_step):
    """Log the first image of each batch in ``images_dict`` to ``logger``.

    Values are converted with ``tensor2numpy`` and may be a single array or a
    list/tuple of arrays. A 3-D array gets a singleton axis inserted at
    position 1. Only the first entry along axis 0 is kept. It is then tiled
    with ``vutils.make_grid`` (per-image normalisation) and written with
    ``logger.add_image``. The tag is ``"<mode_tag>/<tag>"``, suffixed with
    ``"_<index>"`` only when the entry holds more than one array.
    """
    arrays = tensor2numpy(images_dict)
    for tag, entry in arrays.items():
        batch_list = entry if isinstance(entry, (list, tuple)) else [entry]
        base_name = '{}/{}'.format(mode_tag, tag)
        several = len(batch_list) > 1
        for position, array in enumerate(batch_list):
            if array.ndim == 3:
                # presumably (N, H, W) -> (N, 1, H, W); TODO confirm layout
                array = np.expand_dims(array, axis=1)
            first = torch.from_numpy(array[:1])
            name = base_name + "_" + str(position) if several else base_name
            grid = vutils.make_grid(first, padding=0, nrow=1, normalize=True, scale_each=True)
            logger.add_image(name, grid, global_step)
def adjust_learning_rate(optimizer, epoch, base_lr, lrepochs):
    """Set ``'lr'`` of every entry in ``optimizer.param_groups`` from a step schedule.

    Args:
        optimizer: object with a ``param_groups`` sequence of dicts, each of
            which receives the new ``'lr'`` (e.g. a ``torch.optim`` optimizer).
        epoch: current epoch number.
        base_lr: learning rate before any decay.
        lrepochs: schedule string ``"e1,e2,...:rate"``. The learning rate is
            divided by ``rate`` once for every milestone ``e_i <= epoch``.
            The milestones may be given in any order.

    Prints the parsed schedule and the resulting learning rate.
    """
    splits = lrepochs.split(':')
    assert len(splits) == 2
    # parse the epochs to downscale the learning rate (before :)
    downscale_epochs = [int(eid_str) for eid_str in splits[0].split(',')]
    # parse downscale rate (after :)
    downscale_rate = float(splits[1])
    print("downscale epochs: {}, downscale rate: {}".format(downscale_epochs, downscale_rate))
    lr = base_lr
    # Check every milestone instead of stopping at the first unreached one, so
    # an unsorted schedule is still honoured. Dividing once per milestone
    # (rather than by rate ** n) gives bit-identical results for sorted input.
    for eid in downscale_epochs:
        if epoch >= eid:
            lr /= downscale_rate
    print("setting learning rate to {}".format(lr))
    for param_group in optimizer.param_groups:
        param_group['lr'] = lr
class AverageMeter(object):
    """Running arithmetic mean of a stream of Python floats."""

    def __init__(self):
        # Running total and number of values passed to update().
        self.sum_value, self.count = 0., 0

    def update(self, x):
        """Validate ``x`` with check_allfloat, then add it to the running total."""
        check_allfloat(x)
        self.sum_value = self.sum_value + x
        self.count = self.count + 1

    def mean(self):
        """Return ``sum_value / count``; raises ZeroDivisionError before any update."""
        total, seen = self.sum_value, self.count
        return total / seen
class AverageMeterDict(object):
    """Running mean over dicts whose values are floats or lists/tuples of floats.

    ``update(x)`` adds ``x`` element-wise into an accumulator. ``mean()``
    divides every accumulated value by the number of updates. Later updates
    are expected to use keys and structure that exist in the first update.
    """

    def __init__(self):
        self.data = None  # accumulated sums; mirrors the first update's structure
        self.count = 0    # number of update() calls so far

    def update(self, x):
        """Accumulate the dict ``x`` into the running sums.

        Leaves are validated with ``check_allfloat`` first.

        Raises:
            NotImplementedError: if a value is some other container
                (e.g. a nested dict) that cannot be accumulated.
        """
        check_allfloat(x)
        self.count += 1
        if self.data is None:
            # Deep copy so accumulation never mutates the caller's objects.
            self.data = copy.deepcopy(x)
            return
        for k1, v1 in x.items():
            if isinstance(v1, float):
                self.data[k1] += v1
            elif isinstance(v1, (tuple, list)):
                stored = self.data[k1]
                summed = list(stored)
                for idx, v2 in enumerate(v1):
                    summed[idx] += v2
                # Tuples are immutable, so item assignment into the stored
                # value would fail; rebuild it, keeping the first update's type.
                self.data[k1] = tuple(summed) if isinstance(stored, tuple) else summed
            else:
                # Must raise, not assert: an exception instance is truthy, so
                # `assert NotImplementedError(...)` can never fail.
                raise NotImplementedError("error input type for update AvgMeterDict")

    def mean(self):
        """Return the accumulated values divided by the number of updates."""
        @make_iterative_func
        def get_mean(v):
            return v / float(self.count)
        return get_mean(self.data)