GwcNet/models/gwcnet.py

232 lines
9.4 KiB
Python
Raw Normal View History

2019-04-14 17:34:58 +08:00
from __future__ import print_function
import torch
import torch.nn as nn
import torch.utils.data
from torch.autograd import Variable
import torch.nn.functional as F
from models.submodule import *
2019-04-14 21:28:06 +08:00
import math
2019-04-14 17:34:58 +08:00
class feature_extraction(nn.Module):
    """2D convolutional backbone producing the features used for cost volumes.

    The network is a three-convolution stem followed by four residual stages
    built from ``BasicBlock``. The outputs of stages 2, 3 and 4 are
    concatenated along the channel axis to form ``gwc_feature``. When
    ``concat_feature`` is true, an extra head compresses that tensor to
    ``concat_feature_channel`` channels and returns it as ``concat_feature``.

    Args:
        concat_feature: whether to build and run the compression head.
        concat_feature_channel: output channels of the compression head.
    """

    def __init__(self, concat_feature=False, concat_feature_channel=12):
        super(feature_extraction, self).__init__()
        self.concat_feature = concat_feature
        # Running input-channel count; read and updated by _make_layer.
        self.inplanes = 32

        # Stem: the first convbn is called with stride argument 2, the others 1.
        stem = [
            convbn(3, 32, 3, 2, 1, 1),
            nn.ReLU(inplace=True),
            convbn(32, 32, 3, 1, 1, 1),
            nn.ReLU(inplace=True),
            convbn(32, 32, 3, 1, 1, 1),
            nn.ReLU(inplace=True),
        ]
        self.firstconv = nn.Sequential(*stem)

        # Residual stages: (block, planes, blocks, stride, pad, dilation).
        self.layer1 = self._make_layer(BasicBlock, 32, 3, 1, 1, 1)
        self.layer2 = self._make_layer(BasicBlock, 64, 16, 2, 1, 1)
        self.layer3 = self._make_layer(BasicBlock, 128, 3, 1, 1, 1)
        self.layer4 = self._make_layer(BasicBlock, 128, 3, 1, 1, 2)

        if self.concat_feature:
            # 320 input channels: presumably 64 + 128 + 128 from layers 2-4,
            # which holds if BasicBlock.expansion == 1 -- TODO confirm.
            head = [
                convbn(320, 128, 3, 1, 1, 1),
                nn.ReLU(inplace=True),
                nn.Conv2d(128, concat_feature_channel, kernel_size=1,
                          padding=0, stride=1, bias=False),
            ]
            self.lastconv = nn.Sequential(*head)

    def _make_layer(self, block, planes, blocks, stride, pad, dilation):
        """Stack ``blocks`` instances of ``block`` into one sequential stage.

        Only the first block receives ``stride`` and, when the stride is not 1
        or the channel count changes, a 1x1 conv + batch-norm module as its
        ``downsample`` argument. ``self.inplanes`` is updated to the stage's
        output width as a side effect.
        """
        out_planes = planes * block.expansion

        shortcut = None
        if not (stride == 1 and self.inplanes == out_planes):
            shortcut = nn.Sequential(
                nn.Conv2d(self.inplanes, out_planes,
                          kernel_size=1, stride=stride, bias=False),
                nn.BatchNorm2d(out_planes),
            )

        stages = [block(self.inplanes, planes, stride, shortcut, pad, dilation)]
        self.inplanes = out_planes
        stages.extend(
            block(self.inplanes, planes, 1, None, pad, dilation)
            for _ in range(1, blocks)
        )
        return nn.Sequential(*stages)

    def forward(self, x):
        """Run the backbone and return a dict of feature tensors.

        The dict always contains ``"gwc_feature"`` and additionally contains
        ``"concat_feature"`` when the module was built with
        ``concat_feature=True``.
        """
        stem_out = self.layer1(self.firstconv(x))
        l2 = self.layer2(stem_out)
        l3 = self.layer3(l2)
        l4 = self.layer4(l3)

        gwc_feature = torch.cat((l2, l3, l4), dim=1)
        outputs = {"gwc_feature": gwc_feature}
        if self.concat_feature:
            outputs["concat_feature"] = self.lastconv(gwc_feature)
        return outputs
class hourglass(nn.Module):
    """3D encoder-decoder block with two skip connections.

    Four ``convbn_3d`` + ReLU layers (the first and third called with stride
    argument 2) are followed by two stride-2 transposed convolutions. Each
    transposed-convolution output is summed with a 1x1 ``convbn_3d``
    projection of an earlier tensor before the ReLU. The channel count of the
    result equals ``in_channels``.

    Args:
        in_channels: channel count of the input (and output) volume.
    """

    def __init__(self, in_channels):
        super(hourglass, self).__init__()
        double = in_channels * 2
        quad = in_channels * 4

        # Encoder.
        self.conv1 = nn.Sequential(convbn_3d(in_channels, double, 3, 2, 1),
                                   nn.ReLU(inplace=True))
        self.conv2 = nn.Sequential(convbn_3d(double, double, 3, 1, 1),
                                   nn.ReLU(inplace=True))
        self.conv3 = nn.Sequential(convbn_3d(double, quad, 3, 2, 1),
                                   nn.ReLU(inplace=True))
        self.conv4 = nn.Sequential(convbn_3d(quad, quad, 3, 1, 1),
                                   nn.ReLU(inplace=True))

        # Decoder: no activation here; ReLU is applied after the skip sum
        # in forward().
        self.conv5 = nn.Sequential(
            nn.ConvTranspose3d(quad, double, 3, padding=1, output_padding=1,
                               stride=2, bias=False),
            nn.BatchNorm3d(double))
        self.conv6 = nn.Sequential(
            nn.ConvTranspose3d(double, in_channels, 3, padding=1,
                               output_padding=1, stride=2, bias=False),
            nn.BatchNorm3d(in_channels))

        # 1x1 projections applied to the skip tensors.
        self.redir1 = convbn_3d(in_channels, in_channels,
                                kernel_size=1, stride=1, pad=0)
        self.redir2 = convbn_3d(double, double,
                                kernel_size=1, stride=1, pad=0)

    def forward(self, x):
        """Return the refined volume for input volume ``x``."""
        enc1 = self.conv1(x)
        enc2 = self.conv2(enc1)
        enc3 = self.conv3(enc2)
        enc4 = self.conv4(enc3)

        # First upsampling step, fused with the projected conv2 output.
        up1 = self.conv5(enc4)
        skip1 = self.redir2(enc2)
        dec1 = F.relu(up1 + skip1, inplace=True)

        # Second upsampling step, fused with the projected block input.
        up2 = self.conv6(dec1)
        skip2 = self.redir1(x)
        dec2 = F.relu(up2 + skip2, inplace=True)
        return dec2
class GwcNet(nn.Module):
    """Stereo network: shared 2D features, a cost volume, and 3D aggregation.

    Both images go through the same ``feature_extraction`` module. A
    group-wise correlation volume is built with ``build_gwc_volume`` and,
    when ``use_concat_volume`` is true, concatenated along dim 1 with a
    volume from ``build_concat_volume``. The volume is processed by two 3D
    conv stacks (``dres0``, ``dres1``) and three ``hourglass`` blocks
    (``dres2``-``dres4``). Four classifier heads turn intermediate volumes
    into single-channel costs, which are upsampled and regressed to disparity.

    Args:
        maxdisp: number of disparity levels of the final upsampled cost
            volume; the cost volumes are built with ``maxdisp // 4`` levels.
        use_concat_volume: whether to add the concatenation volume.
    """

    def __init__(self, maxdisp, use_concat_volume=False):
        super(GwcNet, self).__init__()
        self.maxdisp = maxdisp
        self.use_concat_volume = use_concat_volume

        self.num_groups = 40

        if self.use_concat_volume:
            self.concat_channels = 12
            self.feature_extraction = feature_extraction(
                concat_feature=True,
                concat_feature_channel=self.concat_channels)
        else:
            self.concat_channels = 0
            self.feature_extraction = feature_extraction(concat_feature=False)

        # Input channels: one per correlation group plus the left and right
        # concat features (hence the factor 2).
        self.dres0 = nn.Sequential(
            convbn_3d(self.num_groups + self.concat_channels * 2, 32, 3, 1, 1),
            nn.ReLU(inplace=True),
            convbn_3d(32, 32, 3, 1, 1),
            nn.ReLU(inplace=True))

        # Residual branch; its output is added to the dres0 output in forward().
        self.dres1 = nn.Sequential(convbn_3d(32, 32, 3, 1, 1),
                                   nn.ReLU(inplace=True),
                                   convbn_3d(32, 32, 3, 1, 1))

        self.dres2 = hourglass(32)
        self.dres3 = hourglass(32)
        self.dres4 = hourglass(32)

        # One head per supervised stage. Attribute names and creation order
        # are kept so existing checkpoints still load.
        self.classif0 = self._make_classifier()
        self.classif1 = self._make_classifier()
        self.classif2 = self._make_classifier()
        self.classif3 = self._make_classifier()

        self._init_weights()

    @staticmethod
    def _make_classifier():
        """Build a head mapping a 32-channel volume to a 1-channel cost."""
        return nn.Sequential(
            convbn_3d(32, 32, 3, 1, 1),
            nn.ReLU(inplace=True),
            nn.Conv3d(32, 1, kernel_size=3, padding=1, stride=1, bias=False))

    def _init_weights(self):
        """Re-initialise conv, batch-norm and linear parameters in place.

        Conv weights are drawn from N(0, sqrt(2 / n)) where n is the kernel
        volume times the number of output channels. Batch-norm weights are set
        to 1 and biases to 0. Linear biases are zeroed.
        """
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
                m.weight.data.normal_(0, math.sqrt(2. / n))
            elif isinstance(m, nn.Conv3d):
                n = (m.kernel_size[0] * m.kernel_size[1] * m.kernel_size[2]
                     * m.out_channels)
                m.weight.data.normal_(0, math.sqrt(2. / n))
            elif isinstance(m, (nn.BatchNorm2d, nn.BatchNorm3d)):
                m.weight.data.fill_(1)
                m.bias.data.zero_()
            elif isinstance(m, nn.Linear):
                # A Linear layer built with bias=False has no bias tensor.
                if m.bias is not None:
                    m.bias.data.zero_()

    def _regress_disparity(self, cost, height, width):
        """Upsample a 1-channel cost volume and regress a disparity map.

        The cost is resized to ``[maxdisp, height, width]``, its channel
        dimension is squeezed away, a softmax is taken over dim 1, and the
        result is passed to ``disparity_regression``.
        """
        # F.interpolate replaces the deprecated F.upsample. For 'trilinear',
        # align_corners=False is what F.upsample used by default, so the
        # output is unchanged and the deprecation warning is gone.
        cost = F.interpolate(cost, [self.maxdisp, height, width],
                             mode='trilinear', align_corners=False)
        cost = torch.squeeze(cost, 1)
        prob = F.softmax(cost, dim=1)
        return disparity_regression(prob, self.maxdisp)

    def forward(self, left, right):
        """Predict disparity for a pair of images.

        Args:
            left: left image batch; dims 2 and 3 set the output resolution.
            right: right image batch.

        Returns:
            In training mode, a list ``[pred0, pred1, pred2, pred3]`` with one
            prediction per classifier head. In eval mode, a one-element list
            ``[pred3]`` holding the final prediction only.
        """
        features_left = self.feature_extraction(left)
        features_right = self.feature_extraction(right)

        gwc_volume = build_gwc_volume(features_left["gwc_feature"],
                                      features_right["gwc_feature"],
                                      self.maxdisp // 4,
                                      self.num_groups)
        if self.use_concat_volume:
            concat_volume = build_concat_volume(features_left["concat_feature"],
                                                features_right["concat_feature"],
                                                self.maxdisp // 4)
            volume = torch.cat((gwc_volume, concat_volume), 1)
        else:
            volume = gwc_volume

        cost0 = self.dres0(volume)
        cost0 = self.dres1(cost0) + cost0

        out1 = self.dres2(cost0)
        out2 = self.dres3(out1)
        out3 = self.dres4(out2)

        height, width = left.size()[2], left.size()[3]

        if not self.training:
            # Inference only needs the last, most refined head.
            return [self._regress_disparity(self.classif3(out3), height, width)]

        heads = (
            (self.classif0, cost0),
            (self.classif1, out1),
            (self.classif2, out2),
            (self.classif3, out3),
        )
        return [self._regress_disparity(classifier(feature), height, width)
                for classifier, feature in heads]
def GwcNet_G(d):
    """Build a ``GwcNet`` with ``maxdisp=d`` and no concatenation volume."""
    model = GwcNet(d, use_concat_volume=False)
    return model
def GwcNet_GC(d):
    """Build a ``GwcNet`` with ``maxdisp=d`` and the concatenation volume."""
    model = GwcNet(d, use_concat_volume=True)
    return model