locked backbone parameters
parent 3f60e691f8 · commit 0a3613711b
@@ -325,21 +325,24 @@ class SubModule(nn.Module):
 
 
 class Feature(SubModule):
-    def __init__(self):
+    def __init__(self, freeze):
         super(Feature, self).__init__()
         pretrained = True
-        model = timm.create_model('mobilenetv2_100', pretrained=pretrained, features_only=True)
+        self.model = timm.create_model('mobilenetv2_100', pretrained=pretrained, features_only=True)
+        if freeze:
+            for p in self.model.parameters():
+                p.requires_grad = False
         layers = [1,2,3,5,6]
         chans = [16, 24, 32, 96, 160]
-        self.conv_stem = model.conv_stem
-        self.bn1 = model.bn1
-        self.act1 = model.act1
+        self.conv_stem = self.model.conv_stem
+        self.bn1 = self.model.bn1
+        self.act1 = self.model.act1
 
-        self.block0 = torch.nn.Sequential(*model.blocks[0:layers[0]])
-        self.block1 = torch.nn.Sequential(*model.blocks[layers[0]:layers[1]])
-        self.block2 = torch.nn.Sequential(*model.blocks[layers[1]:layers[2]])
-        self.block3 = torch.nn.Sequential(*model.blocks[layers[2]:layers[3]])
-        self.block4 = torch.nn.Sequential(*model.blocks[layers[3]:layers[4]])
+        self.block0 = torch.nn.Sequential(*self.model.blocks[0:layers[0]])
+        self.block1 = torch.nn.Sequential(*self.model.blocks[layers[0]:layers[1]])
+        self.block2 = torch.nn.Sequential(*self.model.blocks[layers[1]:layers[2]])
+        self.block3 = torch.nn.Sequential(*self.model.blocks[layers[2]:layers[3]])
+        self.block4 = torch.nn.Sequential(*self.model.blocks[layers[3]:layers[4]])
 
         self.deconv32_16 = Conv2x_IN(chans[4], chans[3], deconv=True, concat=True)
         self.deconv16_8 = Conv2x_IN(chans[3]*2, chans[2], deconv=True, concat=True)
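For reference, a minimal sketch (not part of the commit) of the freeze idiom introduced above, assuming only that timm and torch are installed; it builds the same backbone, applies the same loop, and confirms that no backbone parameter remains trainable:

    # Rebuild the same timm backbone and apply the freeze loop from the diff.
    import timm

    model = timm.create_model('mobilenetv2_100', pretrained=True, features_only=True)
    for p in model.parameters():
        p.requires_grad = False

    trainable = sum(p.numel() for p in model.parameters() if p.requires_grad)
    total = sum(p.numel() for p in model.parameters())
    print(f"trainable backbone params: {trainable} / {total}")  # expected: 0 / total

Note that self.block0 … self.block4 wrap the same module objects as self.model.blocks, so the requires_grad flags set on self.model carry over to them.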
@@ -100,7 +100,7 @@ class IGEVStereo(nn.Module):
 
         self.context_zqr_convs = nn.ModuleList([nn.Conv2d(context_dims[i], args.hidden_dims[i]*3, 3, padding=3//2) for i in range(self.args.n_gru_layers)])
 
-        self.feature = Feature()
+        self.feature = Feature(args.freeze_backbone_params)
 
         self.stem_2 = nn.Sequential(
             BasicConv_IN(3, 32, kernel_size=3, stride=2, padding=1),
@@ -78,7 +78,7 @@ def sequence_loss(disp_preds, disp_init_pred, disp_gt, valid, loss_gamma=0.9, ma
 
 def fetch_optimizer(args, model):
     """ Create the optimizer and learning rate scheduler """
-    optimizer = optim.AdamW(model.parameters(), lr=args.lr, weight_decay=args.wdecay, eps=1e-8)
+    optimizer = optim.AdamW(filter(lambda p: p.requires_grad, model.parameters()), lr=args.lr, weight_decay=args.wdecay, eps=1e-8)
 
     # todo: cosine scheduler, warm-up
     scheduler = optim.lr_scheduler.OneCycleLR(optimizer, args.lr, args.num_steps+100,
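A small sketch (with a toy model, not the repo's classes) of what the added filter does: AdamW only receives the parameters that still require gradients, so the frozen backbone never enters the optimizer's param groups.

    import torch.nn as nn
    import torch.optim as optim

    model = nn.Sequential(nn.Linear(8, 8), nn.Linear(8, 2))
    for p in model[0].parameters():      # stand-in for the frozen backbone
        p.requires_grad = False

    optimizer = optim.AdamW(filter(lambda p: p.requires_grad, model.parameters()),
                            lr=2e-4, weight_decay=1e-5, eps=1e-8)
    n_opt = sum(p.numel() for g in optimizer.param_groups for p in g['params'])
    n_all = sum(p.numel() for p in model.parameters())
    print(n_opt, "of", n_all, "parameters are being optimized")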
@@ -222,6 +222,7 @@ if __name__ == '__main__':
     parser.add_argument('--train_datasets', nargs='+', default=['sceneflow'], help="training datasets.")
     parser.add_argument('--lr', type=float, default=0.0002, help="max learning rate.")
     parser.add_argument('--num_steps', type=int, default=200000, help="length of training schedule.")
+    parser.add_argument('--freeze_backbone_params', action="store_true", help="freeze backbone parameters")
     parser.add_argument('--image_size', type=int, nargs='+', default=[320, 736], help="size of the random image crops used during training.")
     parser.add_argument('--train_iters', type=int, default=22, help="number of updates to the disparity field in each forward pass.")
     parser.add_argument('--wdecay', type=float, default=.00001, help="Weight decay in optimizer.")
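The new switch defaults to False and flips to True only when passed on the command line; a quick, standalone check of that behaviour (argparse only, nothing repo-specific):

    import argparse

    parser = argparse.ArgumentParser()
    parser.add_argument('--freeze_backbone_params', action="store_true", help="freeze backbone parameters")
    print(parser.parse_args([]).freeze_backbone_params)                            # False
    print(parser.parse_args(['--freeze_backbone_params']).freeze_backbone_params)  # True

The resulting args.freeze_backbone_params is forwarded to Feature(args.freeze_backbone_params) in the second hunk above.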