# Copyright (C) 2019-2024, François-Guillaume Fernandez.
# This program is licensed under the Apache License 2.0.
# See LICENSE or go to <https://www.apache.org/licenses/LICENSE-2.0> for full license details.

from typing import List, Optional, Tuple

from torch import nn
from torch.nn.modules.batchnorm import _BatchNorm

__all__ = ["freeze_bn", "freeze_model", "split_normalization_params"]
def freeze_bn(mod: nn.Module) -> None:
    """Prevents parameters and stats from updating in BatchNorm layers that are frozen

    >>> from holocron.models import rexnet1_0x
    >>> from holocron.trainer.utils import freeze_bn
    >>> model = rexnet1_0x()
    >>> freeze_bn(model)

    Args:
        mod (torch.nn.Module): model to train
    """
    # Loop on modules
    for m in mod.modules():
        if isinstance(m, _BatchNorm) and m.affine and all(not p.requires_grad for p in m.parameters()):
            # Switch back to commented code when https://github.com/pytorch/pytorch/issues/37823 is resolved
            m.track_running_stats = False
            m.eval()
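# Illustrative usage sketch (not part of the original module; the model below is a
# made-up stand-in): freeze_bn only touches BatchNorm layers whose parameters already
# have requires_grad=False, so gradients must be disabled beforehand.
if __name__ == "__main__":
    _bn_demo = nn.Sequential(nn.Conv2d(3, 8, 3), nn.BatchNorm2d(8))
    for p in _bn_demo.parameters():
        p.requires_grad_(False)
    freeze_bn(_bn_demo)
    # The frozen BN layer stops tracking running stats and stays in eval mode
    print(_bn_demo[1].track_running_stats, _bn_demo[1].training)  # False False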
def freeze_model(
    model: nn.Module,
    last_frozen_layer: Optional[str] = None,
    frozen_bn_stat_update: bool = False,
) -> None:
    """Freeze a specific range of model layers.

    >>> from holocron.models import rexnet1_0x
    >>> from holocron.trainer.utils import freeze_model
    >>> model = rexnet1_0x()
    >>> freeze_model(model)

    Args:
        model (torch.nn.Module): model to train
        last_frozen_layer (str, optional): last layer to freeze. Assumes layers have been registered in forward order
        frozen_bn_stat_update (bool, optional): force stats update in BN layers that are frozen
    """
    # Unfreeze everything
    for p in model.parameters():
        p.requires_grad_(True)
    # Loop on parameters
    if isinstance(last_frozen_layer, str):
        layer_reached = False
        for n, p in model.named_parameters():
            if not layer_reached or n.startswith(last_frozen_layer):
                p.requires_grad_(False)
                if n.startswith(last_frozen_layer):
                    layer_reached = True
            # Once the last param of the layer is frozen, we break
            elif layer_reached:
                break
        if not layer_reached:
            raise ValueError(f"Unable to locate child module {last_frozen_layer}")
    # Loop on modules
    if not frozen_bn_stat_update:
        freeze_bn(model)
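# Illustrative usage sketch (not part of the original module): freezing a toy
# Sequential model up to a named child. The layer name "1" is simply the index-based
# name that nn.Sequential registers; real models expose their own names via
# model.named_parameters().
if __name__ == "__main__":
    _freeze_demo = nn.Sequential(nn.Conv2d(3, 8, 3), nn.BatchNorm2d(8), nn.Conv2d(8, 16, 3))
    freeze_model(_freeze_demo, last_frozen_layer="1")
    # Only the parameters registered after layer "1" keep training
    print([n for n, p in _freeze_demo.named_parameters() if p.requires_grad])  # ['2.weight', '2.bias']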
def split_normalization_params(
    model: nn.Module,
    norm_classes: Optional[List[type]] = None,
) -> Tuple[List[nn.Parameter], List[nn.Parameter]]:
    """Split the param groups by normalization schemes"""
    # Borrowed from https://github.com/pytorch/vision/blob/main/torchvision/ops/_utils.py
    # Adapted from https://github.com/facebookresearch/ClassyVision/blob/659d7f78/classy_vision/generic/util.py#L501
    if not norm_classes:
        norm_classes = [nn.modules.batchnorm._BatchNorm, nn.LayerNorm, nn.GroupNorm]

    for t in norm_classes:
        if not issubclass(t, nn.Module):
            raise ValueError(f"Class {t} is not a subclass of nn.Module.")

    classes = tuple(norm_classes)

    norm_params: List[nn.Parameter] = []
    other_params: List[nn.Parameter] = []
    for module in model.modules():
        if next(module.children(), None):
            other_params.extend(p for p in module.parameters(recurse=False) if p.requires_grad)
        elif isinstance(module, classes):
            norm_params.extend(p for p in module.parameters() if p.requires_grad)
        else:
            other_params.extend(p for p in module.parameters() if p.requires_grad)
    return norm_params, other_params
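# Illustrative usage sketch (not part of the original module): building optimizer
# parameter groups so that normalization weights skip weight decay. The lr and
# weight_decay values are arbitrary placeholders, not library defaults.
if __name__ == "__main__":
    from torch.optim import SGD

    _split_demo = nn.Sequential(nn.Conv2d(3, 8, 3), nn.BatchNorm2d(8), nn.Conv2d(8, 16, 3))
    norm_params, other_params = split_normalization_params(_split_demo)
    optimizer = SGD(
        [
            {"params": norm_params, "weight_decay": 0.0},
            {"params": other_params, "weight_decay": 1e-4},
        ],
        lr=0.1,
    )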