158 lines
4.5 KiB
Python
158 lines
4.5 KiB
Python
|
from __future__ import print_function, division
|
||
|
import torch
|
||
|
import torch.nn as nn
|
||
|
import torch.nn.parallel
|
||
|
import torch.utils.data
|
||
|
from torch.autograd import Variable
|
||
|
import torchvision.utils as vutils
|
||
|
import torch.nn.functional as F
|
||
|
import numpy as np
|
||
|
import time
|
||
|
from datasets import *
|
||
|
from models import *
|
||
|
import copy
|
||
|
import yaml
|
||
|
import sys
|
||
|
import argparse
|
||
|
|
||
|
|
||
|
def make_iterative_func(func):
    """Lift a per-leaf function ``func`` to nested list/tuple/dict structures.

    The returned wrapper walks its argument recursively:
      * list  -> list of wrapped elements
      * tuple -> tuple of wrapped elements
      * dict  -> dict with the same keys and wrapped values
      * anything else is a leaf and is passed to ``func``.

    ``functools.wraps`` is applied so decorated functions keep their own
    ``__name__``/``__doc__`` (useful for logging and introspection).

    Args:
        func: callable applied to every leaf value.

    Returns:
        The structure-preserving wrapper.
    """
    import functools  # local import keeps this block self-contained

    @functools.wraps(func)
    def wrapper(vars):
        if isinstance(vars, list):
            return [wrapper(x) for x in vars]
        elif isinstance(vars, tuple):
            return tuple([wrapper(x) for x in vars])
        elif isinstance(vars, dict):
            return {k: wrapper(v) for k, v in vars.items()}
        else:
            return func(vars)

    return wrapper
|
||
|
|
||
|
|
||
|
def make_nograd_func(func):
    """Decorate ``func`` so every call runs inside ``torch.no_grad()``.

    All positional and keyword arguments are forwarded unchanged and the
    return value of ``func`` is returned. ``functools.wraps`` preserves the
    wrapped function's ``__name__``/``__doc__``.

    Args:
        func: callable to run with autograd recording disabled.

    Returns:
        The wrapping callable.
    """
    import functools  # local import keeps this block self-contained

    @functools.wraps(func)
    def wrapper(*f_args, **f_kwargs):
        with torch.no_grad():
            return func(*f_args, **f_kwargs)

    return wrapper
|
||
|
|
||
|
|
||
|
@make_iterative_func
def tensor2float(vars):
    """Convert one leaf to a plain Python number.

    Applied recursively over lists/tuples/dicts by ``make_iterative_func``.
    Tensors are converted with ``.item()`` (PyTorch requires a single-element
    tensor for this); Python floats are returned unchanged.

    Raises:
        NotImplementedError: for any other leaf type.
    """
    if isinstance(vars, torch.Tensor):
        return vars.data.item()
    if isinstance(vars, float):
        return vars
    raise NotImplementedError("invalid input type for tensor2float")
|
||
|
|
||
|
|
||
|
@make_iterative_func
def tensor2numpy(vars):
    """Convert one leaf to a NumPy array.

    Applied recursively over lists/tuples/dicts by ``make_iterative_func``.
    Tensors are moved to CPU and converted with ``.numpy()``; existing
    ``np.ndarray`` values are returned unchanged.

    Raises:
        NotImplementedError: for any other leaf type.
    """
    if isinstance(vars, torch.Tensor):
        host_tensor = vars.data.cpu()
        return host_tensor.numpy()
    if isinstance(vars, np.ndarray):
        return vars
    raise NotImplementedError("invalid input type for tensor2numpy")
|
||
|
|
||
|
|
||
|
@make_iterative_func
def check_allfloat(vars):
    """Assert that a leaf is a Python float.

    Applied recursively over lists/tuples/dicts by ``make_iterative_func``,
    so any non-float leaf triggers an AssertionError. Note: like any
    ``assert``, the check is skipped when Python runs with ``-O``.
    """
    leaf_is_float = isinstance(vars, float)
    assert leaf_is_float
|
||
|
|
||
|
|
||
|
def save_scalars(logger, mode_tag, scalar_dict, global_step):
    """Log every scalar in ``scalar_dict`` through ``logger.add_scalar``.

    Values are first converted with ``tensor2float``. A value that is not a
    list/tuple is treated as a one-element sequence. Each element is logged
    under ``'<mode_tag>/<tag>_<idx>'``. Unlike ``save_images``, the index
    suffix is appended even when there is only a single value
    (e.g. ``'train/loss_0'``).

    Args:
        logger: object exposing ``add_scalar(name, value, step)``
            (presumably a TensorBoard SummaryWriter — confirm with caller).
        mode_tag: prefix for every scalar name.
        scalar_dict: mapping of tag -> scalar or list/tuple of scalars.
        global_step: step value passed to ``add_scalar``.
    """
    as_floats = tensor2float(scalar_dict)
    for tag, values in as_floats.items():
        seq = values if isinstance(values, (list, tuple)) else [values]
        prefix = '{}/{}'.format(mode_tag, tag)
        for idx, value in enumerate(seq):
            logger.add_scalar(prefix + "_" + str(idx), value, global_step)
|
||
|
|
||
|
|
||
|
def save_images(logger, mode_tag, images_dict, global_step):
    """Log image arrays from ``images_dict`` through ``logger.add_image``.

    Values are first converted with ``tensor2numpy``. A value that is not a
    list/tuple is treated as a one-element sequence. For every array:
      * a 3-D array gets a new axis inserted at position 1
        (presumably (N, H, W) -> (N, 1, H, W) — confirm with caller);
      * only the first slice along axis 0 is kept (``[:1]``);
      * it is turned into a grid with ``vutils.make_grid`` (normalized per
        image, no padding, one per row) and logged.

    The name is ``'<mode_tag>/<tag>'``, with ``'_<idx>'`` appended only
    when the tag holds more than one array.

    Args:
        logger: object exposing ``add_image(name, image, step)``.
        mode_tag: prefix for every image name.
        images_dict: mapping of tag -> array/tensor or list/tuple of them.
        global_step: step value passed to ``add_image``.
    """
    arrays = tensor2numpy(images_dict)
    for tag, values in arrays.items():
        seq = values if isinstance(values, (list, tuple)) else [values]
        base_name = '{}/{}'.format(mode_tag, tag)
        for idx, value in enumerate(seq):
            batch = value[:, np.newaxis, :, :] if len(value.shape) == 3 else value
            first = torch.from_numpy(batch[:1])
            name = base_name + "_" + str(idx) if len(seq) > 1 else base_name
            grid = vutils.make_grid(first, padding=0, nrow=1, normalize=True, scale_each=True)
            logger.add_image(name, grid, global_step)
|
||
|
|
||
|
|
||
|
def adjust_learning_rate(optimizer, epoch, base_lr, lrepochs):
    """Apply a step-decay learning-rate schedule to every param group.

    ``lrepochs`` has the form ``"e1,e2,...:rate"``. The learning rate is
    ``base_lr`` divided by ``rate`` once for every milestone ``e_i`` with
    ``epoch >= e_i``. Milestones may be given in any order.

    Args:
        optimizer: object with a ``param_groups`` list of dicts
            (e.g. a ``torch.optim.Optimizer``); each group's ``'lr'`` is set.
        epoch: current epoch number.
        base_lr: learning rate before any decay.
        lrepochs: schedule string, e.g. ``"10,12,14:2"``.

    Raises:
        ValueError: if ``lrepochs`` does not contain exactly one ``':'``.
    """
    splits = lrepochs.split(':')
    # Explicit check instead of `assert`, so validation survives `python -O`.
    if len(splits) != 2:
        raise ValueError("lrepochs must look like 'e1,e2,...:rate', got {!r}".format(lrepochs))

    # parse the epochs to downscale the learning rate (before :)
    downscale_epochs = [int(eid_str) for eid_str in splits[0].split(',')]
    # parse downscale rate (after :)
    downscale_rate = float(splits[1])
    print("downscale epochs: {}, downscale rate: {}".format(downscale_epochs, downscale_rate))

    lr = base_lr
    # Iterate milestones in ascending order so the early exit is valid even if
    # the user listed them unsorted (previously "14,10:2" skipped the 10 decay).
    for eid in sorted(downscale_epochs):
        if epoch < eid:
            break
        lr /= downscale_rate
    print("setting learning rate to {}".format(lr))
    for param_group in optimizer.param_groups:
        param_group['lr'] = lr
|
||
|
|
||
|
|
||
|
class AverageMeter(object):
    """Running arithmetic mean over a stream of float values."""

    def __init__(self):
        # running total of all values passed to update()
        self.sum_value = 0.
        # number of update() calls so far
        self.count = 0

    def update(self, x):
        """Add one value ``x`` (checked with ``check_allfloat``)."""
        check_allfloat(x)
        self.sum_value += x
        self.count += 1

    def mean(self):
        """Return sum / count (raises ZeroDivisionError before any update)."""
        total, n = self.sum_value, self.count
        return total / n
|
||
|
|
||
|
|
||
|
class AverageMeterDict(object):
    """Running element-wise mean over dict-structured float values.

    The first ``update`` stores a deep copy of ``x``. Later updates add each
    float leaf of ``x`` to the matching stored leaf. Nested dicts, lists and
    tuples are supported. ``x`` is expected to have the same structure on
    every call.
    """

    def __init__(self):
        # accumulated sums, same structure as the update() inputs; None until first update
        self.data = None
        # number of update() calls so far
        self.count = 0

    @staticmethod
    def _accumulate(total, value):
        """Return ``total + value`` summed leaf-by-leaf over nested containers.

        Floats are added. Lists/tuples are summed element-wise and keep the
        stored container kind (tuple stays tuple). Dicts are summed per key of
        ``value``; keys present only in ``total`` are kept unchanged.

        Raises:
            ValueError: if a list/tuple changed length between updates.
            NotImplementedError: for any other leaf type.
        """
        if isinstance(value, float):
            return total + value
        if isinstance(value, (list, tuple)):
            if len(total) != len(value):
                raise ValueError("AverageMeterDict: sequence length changed between updates")
            summed = [AverageMeterDict._accumulate(t, v) for t, v in zip(total, value)]
            return tuple(summed) if isinstance(total, tuple) else summed
        if isinstance(value, dict):
            merged = dict(total)
            for k, v in value.items():
                merged[k] = AverageMeterDict._accumulate(total[k], v)
            return merged
        # Previously `assert NotImplementedError(...)`, which is always truthy and
        # silently skipped the value while `count` still advanced.
        raise NotImplementedError("error input type for update AvgMeterDict")

    def update(self, x):
        """Add one dict of float leaves ``x`` to the running sums."""
        check_allfloat(x)
        self.count += 1
        if self.data is None:
            self.data = copy.deepcopy(x)
        else:
            for k1, v1 in x.items():
                self.data[k1] = self._accumulate(self.data[k1], v1)

    def mean(self):
        """Return the accumulated sums divided by the update count, same structure.

        Presumably only meaningful after at least one update (``data`` is None
        before that).
        """
        @make_iterative_func
        def get_mean(v):
            return v / float(self.count)

        return get_mean(self.data)
|