# Copyright (C) 2019-2022, François-Guillaume Fernandez.

# This program is licensed under the Apache License 2.0.
# See LICENSE or go to <https://www.apache.org/licenses/LICENSE-2.0> for full license details.

from typing import List, Optional, Tuple

from torch import nn
from torch.nn.modules.batchnorm import _BatchNorm

__all__ = ["freeze_bn", "freeze_model", "split_normalization_params"]


def freeze_bn(mod: nn.Module) -> nn.Module:
    """Stop running-stat updates in batch norm layers whose affine parameters are all frozen

    A layer is affected only if it is a ``_BatchNorm`` instance, was built with ``affine=True`` and none of its
    parameters require gradients. Such layers get ``track_running_stats`` disabled and are switched to eval mode;
    every other module is left untouched.

    Args:
        mod (torch.nn.Module): model to train

    Returns:
        torch.nn.Module: the same model instance, modified in place
    """

    def _is_frozen_bn(layer: nn.Module) -> bool:
        # Only affine BN layers qualify, and only once every one of their parameters is frozen
        if not isinstance(layer, _BatchNorm) or not layer.affine:
            return False
        return not any(param.requires_grad for param in layer.parameters())

    for bn in filter(_is_frozen_bn, mod.modules()):
        # Workaround kept until https://github.com/pytorch/pytorch/issues/37823 is resolved
        bn.track_running_stats = False
        # NOTE: a later `train()` call on a parent module puts this layer back in training mode
        bn.eval()

    return mod


def freeze_model(
    model: nn.Module,
    last_frozen_layer: Optional[str] = None,
    frozen_bn_stat_update: bool = False,
) -> nn.Module:
    """Freeze a specific range of model layers

    Every parameter is first made trainable again. If `last_frozen_layer` is given, all parameters registered before
    that layer, as well as the parameters of that layer, are then frozen.

    Args:
        model (torch.nn.Module): model to train
        last_frozen_layer (str, optional): name of the last layer to freeze (e.g. "features.3"), or the full name of a
            parameter. Matching is done on module-name boundaries, so "conv1" matches "conv1.weight" but neither
            "conv10.weight" nor "conv1_bn.weight". An empty string freezes every parameter. Assumes layers have been
            registered in forward order
        frozen_bn_stat_update (bool, optional): if True, skip `freeze_bn`, so BN layers with frozen parameters are not
            switched to frozen-statistics mode

    Returns:
        torch.nn.Module: model

    Raises:
        ValueError: if no parameter of `model` belongs to `last_frozen_layer`; the model is left unmodified in that case
    """
    to_freeze: List[nn.Parameter] = []  # type: ignore[name-defined]
    if isinstance(last_frozen_layer, str):
        # Parameters inside the target layer are named "<layer>.<...>". Keep the caller's separator if a trailing dot
        # was given, and let an empty name match everything (same as a bare `startswith("")`).
        if not last_frozen_layer or last_frozen_layer.endswith("."):
            boundary = last_frozen_layer
        else:
            boundary = f"{last_frozen_layer}."

        layer_reached = False
        for name, param in model.named_parameters():
            in_layer = name == last_frozen_layer or name.startswith(boundary)
            # Freeze everything registered before the target layer, plus the target layer itself
            if not layer_reached or in_layer:
                to_freeze.append(param)
            if in_layer:
                layer_reached = True
            elif layer_reached:
                # Once the last param of the layer has been collected, we break
                break

        # Validate before touching `requires_grad` so that a bad name does not leave the model half-frozen
        if not layer_reached:
            raise ValueError(f"Unable to locate child module {last_frozen_layer}")

    # Unfreeze everything, then freeze the selected range
    for p in model.parameters():
        p.requires_grad_(True)
    for p in to_freeze:
        p.requires_grad_(False)

    # Loop on modules
    if not frozen_bn_stat_update:
        model = freeze_bn(model)

    return model


def split_normalization_params(
    model: nn.Module,
    norm_classes: Optional[List[type]] = None,
) -> Tuple[List[nn.Parameter], List[nn.Parameter]]:  # type: ignore[name-defined]
    """Split the trainable parameters of a model into normalization and non-normalization groups

    Each trainable parameter is returned at most once (shared/tied parameters are kept in the group of the first
    module that owns them), so both lists can be used as separate optimizer parameter groups.

    Args:
        model (torch.nn.Module): model whose parameters are split
        norm_classes (list of type, optional): module classes considered as normalization layers. Defaults to
            ``[_BatchNorm, nn.LayerNorm, nn.GroupNorm]`` when ``None`` or empty

    Returns:
        tuple: (normalization parameters, other parameters), both restricted to parameters with ``requires_grad=True``

    Raises:
        ValueError: if one of `norm_classes` is not a subclass of ``nn.Module``
    """
    # Borrowed from https://github.com/pytorch/vision/blob/main/torchvision/ops/_utils.py
    # Adapted from https://github.com/facebookresearch/ClassyVision/blob/659d7f78/classy_vision/generic/util.py#L501
    if not norm_classes:
        norm_classes = [nn.modules.batchnorm._BatchNorm, nn.LayerNorm, nn.GroupNorm]

    for t in norm_classes:
        if not issubclass(t, nn.Module):
            raise ValueError(f"Class {t} is not a subclass of nn.Module.")

    classes = tuple(norm_classes)

    norm_params: List[nn.Parameter] = []  # type: ignore[name-defined]
    other_params: List[nn.Parameter] = []  # type: ignore[name-defined]
    # Identities of already collected parameters: tied weights would otherwise be listed once per owning module
    seen: set = set()

    def _collect(target: list, params) -> None:
        for p in params:
            if p.requires_grad and id(p) not in seen:
                seen.add(id(p))
                target.append(p)

    for module in model.modules():
        # `is not None` matters: an empty Sequential/ModuleList child is falsy, which would wrongly turn its
        # parent into a "leaf" and collect the whole subtree twice
        if next(module.children(), None) is not None:
            # Container: only its own direct parameters, children are visited separately by `modules()`
            _collect(other_params, module.parameters(recurse=False))
        elif isinstance(module, classes):
            _collect(norm_params, module.parameters())
        else:
            _collect(other_params, module.parameters())
    return norm_params, other_params