From eead4d9e8a739d36b27d7b075ccb740b58cd8e5f Mon Sep 17 00:00:00 2001 From: Xiaoyang Guo Date: Sun, 14 Apr 2019 08:48:49 -0400 Subject: [PATCH] fix bug: stride 2->1, global_step when testing --- main.py | 2 +- models/gwcnet.py | 4 ++-- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/main.py b/main.py index 3144215..16e58bc 100644 --- a/main.py +++ b/main.py @@ -119,7 +119,7 @@ def train(): # testing avg_test_scalars = AverageMeterDict() for batch_idx, sample in enumerate(TestImgLoader): - global_step = len(TrainImgLoader) * epoch_idx + batch_idx + global_step = len(TestImgLoader) * epoch_idx + batch_idx start_time = time.time() do_summary = global_step % args.summary_freq == 0 loss, scalar_outputs, image_outputs = test_sample(sample, compute_metrics=do_summary) diff --git a/models/gwcnet.py b/models/gwcnet.py index 3622e1a..923e6bd 100644 --- a/models/gwcnet.py +++ b/models/gwcnet.py @@ -71,13 +71,13 @@ class hourglass(nn.Module): self.conv1 = nn.Sequential(convbn_3d(in_channels, in_channels * 2, 3, 2, 1), nn.ReLU(inplace=True)) - self.conv2 = nn.Sequential(convbn_3d(in_channels * 2, in_channels * 2, 3, 2, 1), + self.conv2 = nn.Sequential(convbn_3d(in_channels * 2, in_channels * 2, 3, 1, 1), nn.ReLU(inplace=True)) self.conv3 = nn.Sequential(convbn_3d(in_channels * 2, in_channels * 4, 3, 2, 1), nn.ReLU(inplace=True)) - self.conv4 = nn.Sequential(convbn_3d(in_channels * 4, in_channels * 4, 3, 2, 1), + self.conv4 = nn.Sequential(convbn_3d(in_channels * 4, in_channels * 4, 3, 1, 1), nn.ReLU(inplace=True)) self.conv5 = nn.Sequential(