locked backbone parameters

This commit is contained in:
HTensor 2023-04-25 20:19:43 +08:00
parent 3f60e691f8
commit 0a3613711b
3 changed files with 16 additions and 12 deletions

View File

@ -325,21 +325,24 @@ class SubModule(nn.Module):
class Feature(SubModule): class Feature(SubModule):
def __init__(self): def __init__(self, freeze):
super(Feature, self).__init__() super(Feature, self).__init__()
pretrained = True pretrained = True
model = timm.create_model('mobilenetv2_100', pretrained=pretrained, features_only=True) self.model = timm.create_model('mobilenetv2_100', pretrained=pretrained, features_only=True)
if freeze:
for p in self.model.parameters():
p.requires_grad = False
layers = [1,2,3,5,6] layers = [1,2,3,5,6]
chans = [16, 24, 32, 96, 160] chans = [16, 24, 32, 96, 160]
self.conv_stem = model.conv_stem self.conv_stem = self.model.conv_stem
self.bn1 = model.bn1 self.bn1 = self.model.bn1
self.act1 = model.act1 self.act1 = self.model.act1
self.block0 = torch.nn.Sequential(*model.blocks[0:layers[0]]) self.block0 = torch.nn.Sequential(*self.model.blocks[0:layers[0]])
self.block1 = torch.nn.Sequential(*model.blocks[layers[0]:layers[1]]) self.block1 = torch.nn.Sequential(*self.model.blocks[layers[0]:layers[1]])
self.block2 = torch.nn.Sequential(*model.blocks[layers[1]:layers[2]]) self.block2 = torch.nn.Sequential(*self.model.blocks[layers[1]:layers[2]])
self.block3 = torch.nn.Sequential(*model.blocks[layers[2]:layers[3]]) self.block3 = torch.nn.Sequential(*self.model.blocks[layers[2]:layers[3]])
self.block4 = torch.nn.Sequential(*model.blocks[layers[3]:layers[4]]) self.block4 = torch.nn.Sequential(*self.model.blocks[layers[3]:layers[4]])
self.deconv32_16 = Conv2x_IN(chans[4], chans[3], deconv=True, concat=True) self.deconv32_16 = Conv2x_IN(chans[4], chans[3], deconv=True, concat=True)
self.deconv16_8 = Conv2x_IN(chans[3]*2, chans[2], deconv=True, concat=True) self.deconv16_8 = Conv2x_IN(chans[3]*2, chans[2], deconv=True, concat=True)

View File

@ -100,7 +100,7 @@ class IGEVStereo(nn.Module):
self.context_zqr_convs = nn.ModuleList([nn.Conv2d(context_dims[i], args.hidden_dims[i]*3, 3, padding=3//2) for i in range(self.args.n_gru_layers)]) self.context_zqr_convs = nn.ModuleList([nn.Conv2d(context_dims[i], args.hidden_dims[i]*3, 3, padding=3//2) for i in range(self.args.n_gru_layers)])
self.feature = Feature() self.feature = Feature(args.freeze_backbone_params)
self.stem_2 = nn.Sequential( self.stem_2 = nn.Sequential(
BasicConv_IN(3, 32, kernel_size=3, stride=2, padding=1), BasicConv_IN(3, 32, kernel_size=3, stride=2, padding=1),

View File

@ -78,7 +78,7 @@ def sequence_loss(disp_preds, disp_init_pred, disp_gt, valid, loss_gamma=0.9, ma
def fetch_optimizer(args, model): def fetch_optimizer(args, model):
""" Create the optimizer and learning rate scheduler """ """ Create the optimizer and learning rate scheduler """
optimizer = optim.AdamW(model.parameters(), lr=args.lr, weight_decay=args.wdecay, eps=1e-8) optimizer = optim.AdamW(filter(lambda p: p.requires_grad, model.parameters()), lr=args.lr, weight_decay=args.wdecay, eps=1e-8)
# todo: cosine scheduler, warm-up # todo: cosine scheduler, warm-up
scheduler = optim.lr_scheduler.OneCycleLR(optimizer, args.lr, args.num_steps+100, scheduler = optim.lr_scheduler.OneCycleLR(optimizer, args.lr, args.num_steps+100,
@ -222,6 +222,7 @@ if __name__ == '__main__':
parser.add_argument('--train_datasets', nargs='+', default=['sceneflow'], help="training datasets.") parser.add_argument('--train_datasets', nargs='+', default=['sceneflow'], help="training datasets.")
parser.add_argument('--lr', type=float, default=0.0002, help="max learning rate.") parser.add_argument('--lr', type=float, default=0.0002, help="max learning rate.")
parser.add_argument('--num_steps', type=int, default=200000, help="length of training schedule.") parser.add_argument('--num_steps', type=int, default=200000, help="length of training schedule.")
parser.add_argument('--freeze_backbone_params', action="store_true", help="freeze backbone parameters")
parser.add_argument('--image_size', type=int, nargs='+', default=[320, 736], help="size of the random image crops used during training.") parser.add_argument('--image_size', type=int, nargs='+', default=[320, 736], help="size of the random image crops used during training.")
parser.add_argument('--train_iters', type=int, default=22, help="number of updates to the disparity field in each forward pass.") parser.add_argument('--train_iters', type=int, default=22, help="number of updates to the disparity field in each forward pass.")
parser.add_argument('--wdecay', type=float, default=.00001, help="Weight decay in optimizer.") parser.add_argument('--wdecay', type=float, default=.00001, help="Weight decay in optimizer.")