def freeze_bn(mod: Module) -> Module:
    """Stop frozen BatchNorm layers from updating their running statistics.

    A BatchNorm layer (any ``_BatchNorm`` subclass) is treated as frozen when it
    is affine and none of its parameters require gradients. Each such layer gets
    ``track_running_stats`` set to ``False`` and is switched to eval mode. Every
    other module is left untouched. The model is modified in place.

    Args:
        mod (torch.nn.Module): model to train

    Returns:
        torch.nn.Module: the same model instance that was passed in
    """
    for layer in mod.modules():
        # Only affine BatchNorm layers are candidates
        if not isinstance(layer, _BatchNorm) or not layer.affine:
            continue
        # A single trainable parameter means the layer is not frozen
        if any(param.requires_grad for param in layer.parameters()):
            continue
        # Workaround tied to https://github.com/pytorch/pytorch/issues/37823
        # NOTE(review): the original comment refers to "commented code" to switch back to,
        # which is not visible here. Also, a later mod.train() presumably flips these layers
        # back to training mode — confirm this is acceptable.
        layer.track_running_stats = False
        layer.eval()
    return mod
def freeze_model(
    model: Module,
    last_frozen_layer: Optional[str] = None,
    frozen_bn_stat_update: bool = False,
) -> Module:
    """Freeze a specific range of model layers

    Every parameter registered before ``last_frozen_layer`` is frozen, along
    with the parameters of ``last_frozen_layer`` itself, via
    ``requires_grad_(False)``. Parameters registered after it are left as they
    are. Matching respects the dotted module path, so ``"features.1"`` does not
    also select ``"features.10"``. An empty string designates the root module,
    which freezes every parameter.

    Args:
        model (torch.nn.Module): model to train
        last_frozen_layer (str, optional): qualified name of the last layer to freeze.
            Assumes layers have been registered in forward order
        frozen_bn_stat_update (bool, optional): force stats update in BN layers that are frozen

    Returns:
        torch.nn.Module: the same model instance, modified in place

    Raises:
        ValueError: if no parameter belongs to ``last_frozen_layer``. This check runs
            before any parameter is modified.
    """
    if isinstance(last_frozen_layer, str):
        prefix = f"{last_frozen_layer}."

        def _in_target(name: str) -> bool:
            # "" is the root module's name, so it owns every parameter
            if not last_frozen_layer:
                return True
            # Require a "." boundary so "layer1" does not match "layer10.*"
            return name == last_frozen_layer or name.startswith(prefix)

        named_params = list(model.named_parameters())
        # Validate first, so a bad name does not leave the model partially/fully frozen
        if not any(_in_target(name) for name, _ in named_params):
            raise ValueError(f"Unable to locate child module {last_frozen_layer}")

        # Loop on parameters: freeze everything up to and including the target layer
        layer_reached = False
        for name, param in named_params:
            if _in_target(name):
                layer_reached = True
                param.requires_grad_(False)
            elif not layer_reached:
                param.requires_grad_(False)

    # Loop on modules
    if not frozen_bn_stat_update:
        model = freeze_bn(model)

    return model