IGEV/IGEV-Stereo/core/igev_stereo.py
2023-04-24 16:37:30 +08:00

222 lines
10 KiB
Python

import torch
import torch.nn as nn
import torch.nn.functional as F
from core.update import BasicMultiUpdateBlock
from core.extractor import MultiBasicEncoder, Feature
from core.geometry import Combined_Geo_Encoding_Volume
from core.submodule import *
import time
# Resolve the mixed-precision context manager once at import time.
try:
    autocast = torch.cuda.amp.autocast
except AttributeError:
    # This PyTorch build does not expose ``torch.cuda.amp.autocast``; fall
    # back to a context manager that does nothing so that call sites can keep
    # writing ``with autocast(enabled=...):`` unconditionally.
    # Only the missing-attribute case is handled: a bare ``except`` would also
    # swallow KeyboardInterrupt/SystemExit and hide unrelated failures.
    class autocast:
        """No-op stand-in for ``torch.cuda.amp.autocast``."""

        def __init__(self, enabled=True):
            # ``enabled`` is accepted for call-site compatibility and ignored.
            pass

        def __enter__(self):
            pass

        def __exit__(self, *args):
            # Returning None (falsy) lets exceptions from the body propagate.
            pass
class hourglass(nn.Module):
    """3D encoder-decoder applied to a cost volume, guided by 2D feature maps.

    Three stride-2 stages shrink the volume, three ``deconv=True`` stages with
    stride 2 grow it back. On the way up, each stage is concatenated with the
    same-level encoder output (channel dim) and fused by a 1x1x1 + two 3x3x3
    convolutions. After every stage except the last up-convolution the volume
    is passed through a ``FeatureAtt`` module together with one entry of the
    ``features`` list.
    """

    def __init__(self, in_channels):
        """Create all sub-modules.

        Args:
            in_channels: channel count of the incoming volume. Intermediate
                widths are 2x/4x/6x this value; the last up-convolution always
                emits 8 channels.
        """
        super().__init__()
        ch2, ch4, ch6, ch8 = (in_channels * k for k in (2, 4, 6, 8))

        def down_stage(c_in, c_out):
            # Stride-2 3x3x3 conv followed by a stride-1 3x3x3 conv.
            return nn.Sequential(
                BasicConv(c_in, c_out, is_3d=True, bn=True, relu=True,
                          kernel_size=3, padding=1, stride=2, dilation=1),
                BasicConv(c_out, c_out, is_3d=True, bn=True, relu=True,
                          kernel_size=3, padding=1, stride=1, dilation=1),
            )

        def fuse_stage(c_in, c_out):
            # 1x1x1 channel reduction followed by two 3x3x3 convs.
            return nn.Sequential(
                BasicConv(c_in, c_out, is_3d=True, kernel_size=1, padding=0, stride=1),
                BasicConv(c_out, c_out, is_3d=True, kernel_size=3, padding=1, stride=1),
                BasicConv(c_out, c_out, is_3d=True, kernel_size=3, padding=1, stride=1),
            )

        # Geometry shared by the three up-sampling convolutions.
        up_geometry = dict(deconv=True, is_3d=True, kernel_size=(4, 4, 4),
                           padding=(1, 1, 1), stride=(2, 2, 2))

        # NOTE: sub-modules are created in the same order as before so that
        # parameter names and random initialisation order do not change.
        self.conv1 = down_stage(in_channels, ch2)
        self.conv2 = down_stage(ch2, ch4)
        self.conv3 = down_stage(ch4, ch6)

        self.conv3_up = BasicConv(ch6, ch4, bn=True, relu=True, **up_geometry)
        self.conv2_up = BasicConv(ch4, ch2, bn=True, relu=True, **up_geometry)
        # Output head: fixed 8 channels, no norm / activation.
        self.conv1_up = BasicConv(ch2, 8, bn=False, relu=False, **up_geometry)

        # Inputs are concatenations of two equally wide tensors (4x+4x, 2x+2x).
        self.agg_0 = fuse_stage(ch8, ch4)
        self.agg_1 = fuse_stage(ch4, ch2)

        # The second FeatureAtt argument is presumably the channel count of the
        # matching ``features`` entry -- confirm against FeatureAtt/Feature.
        self.feature_att_8 = FeatureAtt(ch2, 64)
        self.feature_att_16 = FeatureAtt(ch4, 192)
        self.feature_att_32 = FeatureAtt(ch6, 160)
        self.feature_att_up_16 = FeatureAtt(ch4, 192)
        self.feature_att_up_8 = FeatureAtt(ch2, 64)

    def forward(self, x, features):
        """Run the volume through the encoder-decoder.

        Args:
            x: input volume fed to the first down-sampling stage.
            features: indexable collection; entries 1, 2 and 3 are handed to
                the FeatureAtt modules (entry 0 is not read here).

        Returns:
            Output of the final up-convolution ``conv1_up``.
        """
        # Encoder path.
        skip_8 = self.feature_att_8(self.conv1(x), features[1])
        skip_16 = self.feature_att_16(self.conv2(skip_8), features[2])
        bottom = self.feature_att_32(self.conv3(skip_16), features[3])

        # Decoder path: upsample, concatenate the skip tensor, fuse, attend.
        merged_16 = torch.cat((self.conv3_up(bottom), skip_16), dim=1)
        merged_16 = self.feature_att_up_16(self.agg_0(merged_16), features[2])

        merged_8 = torch.cat((self.conv2_up(merged_16), skip_8), dim=1)
        merged_8 = self.feature_att_up_8(self.agg_1(merged_8), features[1])

        return self.conv1_up(merged_8)
class IGEVStereo(nn.Module):
    """Stereo disparity network: cost-volume initialisation + iterative GRU refinement.

    ``forward`` builds a group-wise correlation volume from left/right
    descriptors, aggregates it with an ``hourglass`` network, regresses an
    initial disparity from it, and then refines that disparity with
    ``BasicMultiUpdateBlock`` for a configurable number of iterations.

    The ``args`` object must provide at least: ``hidden_dims``,
    ``n_downsample``, ``n_gru_layers``, ``mixed_precision``, ``max_disp``,
    ``corr_radius``, ``corr_levels`` and ``slow_fast_gru`` (all are read in
    ``__init__`` or ``forward``).
    """

    def __init__(self, args):
        """Create all sub-modules from the configuration object ``args``."""
        super().__init__()
        self.args = args

        context_dims = args.hidden_dims

        # Context encoder; its outputs seed the GRU hidden states and inputs
        # in ``forward``.
        self.cnet = MultiBasicEncoder(output_dim=[args.hidden_dims, context_dims], norm_fn="batch", downsample=args.n_downsample)
        self.update_block = BasicMultiUpdateBlock(self.args, hidden_dims=args.hidden_dims)

        # One conv per GRU layer producing 3x the hidden width; ``forward``
        # splits the result into three equal chunks along the channel dim
        # (presumably the z/q/r gate inputs -- confirm against the update block).
        self.context_zqr_convs = nn.ModuleList([nn.Conv2d(context_dims[i], args.hidden_dims[i]*3, 3, padding=3//2) for i in range(self.args.n_gru_layers)])

        self.feature = Feature()

        # Two stacked stride-2 image stems (applied to both images in forward).
        self.stem_2 = nn.Sequential(
            BasicConv_IN(3, 32, kernel_size=3, stride=2, padding=1),
            nn.Conv2d(32, 32, 3, 1, 1, bias=False),
            nn.InstanceNorm2d(32), nn.ReLU()
        )
        self.stem_4 = nn.Sequential(
            BasicConv_IN(32, 48, kernel_size=3, stride=2, padding=1),
            nn.Conv2d(48, 48, 3, 1, 1, bias=False),
            nn.InstanceNorm2d(48), nn.ReLU()
        )

        # Upsampling-weight head for the initial disparity (training only).
        # Emits 9 channels that are soft-maxed before ``context_upsample``.
        self.spx = nn.Sequential(nn.ConvTranspose2d(2*32, 9, kernel_size=4, stride=2, padding=1),)
        self.spx_2 = Conv2x_IN(24, 32, True)
        self.spx_4 = nn.Sequential(
            BasicConv_IN(96, 24, kernel_size=3, stride=1, padding=1),
            nn.Conv2d(24, 24, 3, 1, 1, bias=False),
            nn.InstanceNorm2d(24), nn.ReLU()
        )

        # Upsampling-weight head for the GRU-refined disparity.
        self.spx_2_gru = Conv2x(32, 32, True)
        self.spx_gru = nn.Sequential(nn.ConvTranspose2d(2*32, 9, kernel_size=4, stride=2, padding=1),)

        # Matching descriptor head applied to both views.
        self.conv = BasicConv_IN(96, 96, kernel_size=3, padding=1, stride=1)
        self.desc = nn.Conv2d(96, 96, kernel_size=1, padding=0, stride=1)

        # Cost-volume processing: stem -> feature attention -> hourglass -> 1-channel classifier.
        self.corr_stem = BasicConv(8, 8, is_3d=True, kernel_size=3, stride=1, padding=1)
        self.corr_feature_att = FeatureAtt(8, 96)
        self.cost_agg = hourglass(8)
        self.classifier = nn.Conv3d(8, 1, 3, 1, 1, bias=False)

    def freeze_bn(self):
        """Put every ``nn.BatchNorm2d`` sub-module into eval mode.

        Only ``BatchNorm2d`` instances are matched; other normalisation layers
        (e.g. 3D batch norm inside the cost-volume branch, if any) keep their
        current mode.
        """
        for m in self.modules():
            if isinstance(m, nn.BatchNorm2d):
                m.eval()

    def upsample_disp(self, disp, mask_feat_4, stem_2x):
        """Upsample a low-resolution disparity map with learned weights.

        Args:
            disp: disparity to upsample; it is multiplied by 4 before
                upsampling (presumably because it lives at 1/4 resolution).
            mask_feat_4: feature map returned by the update block.
            stem_2x: output of ``self.stem_2`` for the left image.

        Returns:
            Result of ``context_upsample`` with a singleton dim inserted at
            position 1.
        """
        with autocast(enabled=self.args.mixed_precision):
            xspx = self.spx_2_gru(mask_feat_4, stem_2x)
            spx_pred = self.spx_gru(xspx)
            # Normalise the upsampling weights over the channel dimension.
            spx_pred = F.softmax(spx_pred, 1)
            up_disp = context_upsample(disp*4., spx_pred).unsqueeze(1)

        return up_disp

    def forward(self, image1, image2, iters=12, flow_init=None, test_mode=False):
        """Estimate disparity between a pair of frames.

        Args:
            image1: left image; rescaled here from [0, 255] to [-1, 1].
            image2: right image; rescaled the same way.
            iters: number of GRU refinement iterations.
            flow_init: accepted for interface compatibility; not used.
            test_mode: if True, only the final upsampled disparity is computed
                and returned.

        Returns:
            ``disp_up`` of the last iteration when ``test_mode`` is True,
            otherwise ``(init_disp, disp_preds)`` where ``init_disp`` is the
            upsampled initial disparity and ``disp_preds`` holds one upsampled
            prediction per iteration.

        Raises:
            ValueError: if ``test_mode`` is True and ``iters < 1``. No refined
                disparity would exist to return in that case (previously this
                surfaced as an UnboundLocalError after the full forward pass).
        """
        if test_mode and iters < 1:
            raise ValueError(f"iters must be >= 1 in test_mode, got {iters}")

        image1 = (2 * (image1 / 255.0) - 1.0).contiguous()
        image2 = (2 * (image2 / 255.0) - 1.0).contiguous()
        with autocast(enabled=self.args.mixed_precision):
            features_left = self.feature(image1)
            features_right = self.feature(image2)
            stem_2x = self.stem_2(image1)
            stem_4x = self.stem_4(stem_2x)
            stem_2y = self.stem_2(image2)
            stem_4y = self.stem_4(stem_2y)
            # Entry 0 of each feature list is replaced in place by its
            # concatenation with the stem features.
            features_left[0] = torch.cat((features_left[0], stem_4x), 1)
            features_right[0] = torch.cat((features_right[0], stem_4y), 1)

            match_left = self.desc(self.conv(features_left[0]))
            match_right = self.desc(self.conv(features_right[0]))
            gwc_volume = build_gwc_volume(match_left, match_right, self.args.max_disp//4, 8)
            gwc_volume = self.corr_stem(gwc_volume)
            gwc_volume = self.corr_feature_att(gwc_volume, features_left[0])
            geo_encoding_volume = self.cost_agg(gwc_volume, features_left)

            # Init disp from geometry encoding volume
            prob = F.softmax(self.classifier(geo_encoding_volume).squeeze(1), dim=1)
            init_disp = disparity_regression(prob, self.args.max_disp//4)

            # Drop references that are no longer needed before refinement.
            del prob, gwc_volume

            if not test_mode:
                # Upsampling weights for init_disp; only consumed by the
                # training-time return value at the end of this method.
                xspx = self.spx_4(features_left[0])
                xspx = self.spx_2(xspx, stem_2x)
                spx_pred = self.spx(xspx)
                spx_pred = F.softmax(spx_pred, 1)

            cnet_list = self.cnet(image1, num_layers=self.args.n_gru_layers)
            net_list = [torch.tanh(x[0]) for x in cnet_list]
            inp_list = [torch.relu(x[1]) for x in cnet_list]
            # Split each projected context tensor into three equal channel chunks.
            inp_list = [list(conv(i).split(split_size=conv.out_channels//3, dim=1)) for i, conv in zip(inp_list, self.context_zqr_convs)]

        # The lookup volume is built from float32 copies of the autocast outputs.
        geo_fn = Combined_Geo_Encoding_Volume(match_left.float(), match_right.float(), geo_encoding_volume.float(), radius=self.args.corr_radius, num_levels=self.args.corr_levels)
        b, _, h, w = match_left.shape
        # Column index of every pixel, shaped (b, h, w, 1).
        coords = torch.arange(w, device=match_left.device).float().reshape(1, 1, w, 1).repeat(b, h, 1, 1)
        disp = init_disp
        disp_preds = []

        # GRUs iterations to update disparity
        for itr in range(iters):
            # Gradients do not flow through the disparity from one iteration to the next.
            disp = disp.detach()
            geo_feat = geo_fn(disp, coords)
            with autocast(enabled=self.args.mixed_precision):
                if self.args.n_gru_layers == 3 and self.args.slow_fast_gru:  # Update low-res ConvGRU
                    net_list = self.update_block(net_list, inp_list, iter16=True, iter08=False, iter04=False, update=False)
                if self.args.n_gru_layers >= 2 and self.args.slow_fast_gru:  # Update low-res ConvGRU and mid-res ConvGRU
                    net_list = self.update_block(net_list, inp_list, iter16=self.args.n_gru_layers==3, iter08=True, iter04=False, update=False)
                net_list, mask_feat_4, delta_disp = self.update_block(net_list, inp_list, geo_feat, disp, iter16=self.args.n_gru_layers==3, iter08=self.args.n_gru_layers>=2)

            disp = disp + delta_disp
            # In test mode only the last iteration needs to be upsampled.
            if test_mode and itr < iters-1:
                continue

            # upsample predictions
            disp_up = self.upsample_disp(disp, mask_feat_4, stem_2x)
            disp_preds.append(disp_up)

        if test_mode:
            return disp_up

        init_disp = context_upsample(init_disp*4., spx_pred.float()).unsqueeze(1)
        return init_disp, disp_preds