From 0a3613711bafef22a39cd83319e82552de20d304 Mon Sep 17 00:00:00 2001 From: HTensor Date: Tue, 25 Apr 2023 20:19:43 +0800 Subject: [PATCH] locked backbone parameters --- IGEV-Stereo/core/extractor.py | 23 +++++++++++++---------- IGEV-Stereo/core/igev_stereo.py | 2 +- IGEV-Stereo/train_stereo.py | 3 ++- 3 files changed, 16 insertions(+), 12 deletions(-) diff --git a/IGEV-Stereo/core/extractor.py b/IGEV-Stereo/core/extractor.py index 307df38..35e3f35 100644 --- a/IGEV-Stereo/core/extractor.py +++ b/IGEV-Stereo/core/extractor.py @@ -325,21 +325,24 @@ class SubModule(nn.Module): class Feature(SubModule): - def __init__(self): + def __init__(self, freeze): super(Feature, self).__init__() pretrained = True - model = timm.create_model('mobilenetv2_100', pretrained=pretrained, features_only=True) + self.model = timm.create_model('mobilenetv2_100', pretrained=pretrained, features_only=True) + if freeze: + for p in self.model.parameters(): + p.requires_grad = False layers = [1,2,3,5,6] chans = [16, 24, 32, 96, 160] - self.conv_stem = model.conv_stem - self.bn1 = model.bn1 - self.act1 = model.act1 + self.conv_stem = self.model.conv_stem + self.bn1 = self.model.bn1 + self.act1 = self.model.act1 - self.block0 = torch.nn.Sequential(*model.blocks[0:layers[0]]) - self.block1 = torch.nn.Sequential(*model.blocks[layers[0]:layers[1]]) - self.block2 = torch.nn.Sequential(*model.blocks[layers[1]:layers[2]]) - self.block3 = torch.nn.Sequential(*model.blocks[layers[2]:layers[3]]) - self.block4 = torch.nn.Sequential(*model.blocks[layers[3]:layers[4]]) + self.block0 = torch.nn.Sequential(*self.model.blocks[0:layers[0]]) + self.block1 = torch.nn.Sequential(*self.model.blocks[layers[0]:layers[1]]) + self.block2 = torch.nn.Sequential(*self.model.blocks[layers[1]:layers[2]]) + self.block3 = torch.nn.Sequential(*self.model.blocks[layers[2]:layers[3]]) + self.block4 = torch.nn.Sequential(*self.model.blocks[layers[3]:layers[4]]) self.deconv32_16 = Conv2x_IN(chans[4], chans[3], deconv=True, concat=True) self.deconv16_8 = Conv2x_IN(chans[3]*2, chans[2], deconv=True, concat=True) diff --git a/IGEV-Stereo/core/igev_stereo.py b/IGEV-Stereo/core/igev_stereo.py index c432092..acb316e 100644 --- a/IGEV-Stereo/core/igev_stereo.py +++ b/IGEV-Stereo/core/igev_stereo.py @@ -100,7 +100,7 @@ class IGEVStereo(nn.Module): self.context_zqr_convs = nn.ModuleList([nn.Conv2d(context_dims[i], args.hidden_dims[i]*3, 3, padding=3//2) for i in range(self.args.n_gru_layers)]) - self.feature = Feature() + self.feature = Feature(args.freeze_backbone_params) self.stem_2 = nn.Sequential( BasicConv_IN(3, 32, kernel_size=3, stride=2, padding=1), diff --git a/IGEV-Stereo/train_stereo.py b/IGEV-Stereo/train_stereo.py index 7e3e493..19262c5 100644 --- a/IGEV-Stereo/train_stereo.py +++ b/IGEV-Stereo/train_stereo.py @@ -78,7 +78,7 @@ def sequence_loss(disp_preds, disp_init_pred, disp_gt, valid, loss_gamma=0.9, ma def fetch_optimizer(args, model): """ Create the optimizer and learning rate scheduler """ - optimizer = optim.AdamW(model.parameters(), lr=args.lr, weight_decay=args.wdecay, eps=1e-8) + optimizer = optim.AdamW(filter(lambda p: p.requires_grad, model.parameters()), lr=args.lr, weight_decay=args.wdecay, eps=1e-8) # todo: cosine scheduler, warm-up scheduler = optim.lr_scheduler.OneCycleLR(optimizer, args.lr, args.num_steps+100, @@ -222,6 +222,7 @@ if __name__ == '__main__': parser.add_argument('--train_datasets', nargs='+', default=['sceneflow'], help="training datasets.") parser.add_argument('--lr', type=float, default=0.0002, help="max learning rate.") parser.add_argument('--num_steps', type=int, default=200000, help="length of training schedule.") + parser.add_argument('--freeze_backbone_params', action="store_true", help="freeze backbone parameters") parser.add_argument('--image_size', type=int, nargs='+', default=[320, 736], help="size of the random image crops used during training.") parser.add_argument('--train_iters', type=int, default=22, help="number of updates to the disparity field in each forward pass.") parser.add_argument('--wdecay', type=float, default=.00001, help="Weight decay in optimizer.")