fix bug: stride 2->1, global_step when testing

This commit is contained in:
Xiaoyang Guo 2019-04-14 08:48:49 -04:00
parent 6b8f506591
commit eead4d9e8a
2 changed files with 3 additions and 3 deletions

View File

@ -119,7 +119,7 @@ def train():
# testing
avg_test_scalars = AverageMeterDict()
for batch_idx, sample in enumerate(TestImgLoader):
global_step = len(TrainImgLoader) * epoch_idx + batch_idx
global_step = len(TestImgLoader) * epoch_idx + batch_idx
start_time = time.time()
do_summary = global_step % args.summary_freq == 0
loss, scalar_outputs, image_outputs = test_sample(sample, compute_metrics=do_summary)

View File

@ -71,13 +71,13 @@ class hourglass(nn.Module):
self.conv1 = nn.Sequential(convbn_3d(in_channels, in_channels * 2, 3, 2, 1),
nn.ReLU(inplace=True))
self.conv2 = nn.Sequential(convbn_3d(in_channels * 2, in_channels * 2, 3, 2, 1),
self.conv2 = nn.Sequential(convbn_3d(in_channels * 2, in_channels * 2, 3, 1, 1),
nn.ReLU(inplace=True))
self.conv3 = nn.Sequential(convbn_3d(in_channels * 2, in_channels * 4, 3, 2, 1),
nn.ReLU(inplace=True))
self.conv4 = nn.Sequential(convbn_3d(in_channels * 4, in_channels * 4, 3, 2, 1),
self.conv4 = nn.Sequential(convbn_3d(in_channels * 4, in_channels * 4, 3, 1, 1),
nn.ReLU(inplace=True))
self.conv5 = nn.Sequential(