# Copyright (C) 2020-2022, François-Guillaume Fernandez.
# This program is licensed under the Apache License 2.0.
# See LICENSE or go to <https://www.apache.org/licenses/LICENSE-2.0> for full license details.

import warnings
from functools import reduce
from operator import mul

from torch import Tensor, nn
from torch.nn import Module
from torch.nn.modules.batchnorm import _BatchNorm
from torch.nn.modules.conv import _ConvNd, _ConvTransposeNd
from torch.nn.modules.pooling import _AdaptiveAvgPoolNd, _AdaptiveMaxPoolNd, _AvgPoolNd, _MaxPoolNd

__all__ = ["module_macs"]
def module_macs(module: Module, inp: Tensor, out: Tensor) -> int:
    """Estimate the number of multiply-accumulation operations performed by the module

    Args:
        module (torch.nn.Module): PyTorch module
        inp (torch.Tensor): input to the module
        out (torch.Tensor): output of the module

    Returns:
        int: number of MACs
    """
    # Guard-clause dispatch; the order matters since _ConvTransposeNd
    # subclasses _ConvNd and must be matched first
    if isinstance(module, nn.Linear):
        return macs_linear(module, inp, out)
    # Activation & reshaping layers perform no multiply-accumulation
    if isinstance(module, (nn.Identity, nn.ReLU, nn.ELU, nn.LeakyReLU, nn.ReLU6, nn.Tanh, nn.Sigmoid, nn.Flatten)):
        return 0
    if isinstance(module, _ConvTransposeNd):
        return macs_convtransposend(module, inp, out)
    if isinstance(module, _ConvNd):
        return macs_convnd(module, inp, out)
    if isinstance(module, _BatchNorm):
        return macs_bn(module, inp, out)
    if isinstance(module, _MaxPoolNd):
        return macs_maxpool(module, inp, out)
    if isinstance(module, _AvgPoolNd):
        return macs_avgpool(module, inp, out)
    if isinstance(module, _AdaptiveMaxPoolNd):
        return macs_adaptive_maxpool(module, inp, out)
    if isinstance(module, _AdaptiveAvgPoolNd):
        return macs_adaptive_avgpool(module, inp, out)
    if isinstance(module, nn.Dropout):
        return 0
    # Unknown layer types are treated as cost-free, with a warning
    warnings.warn(f"Module type not supported: {module.__class__.__name__}")
    return 0
def macs_linear(module: nn.Linear, inp: Tensor, out: Tensor) -> int:
    """MACs estimation for `torch.nn.Linear`"""
    # Each output element accumulates `in_features` products
    # (the bias add is folded into the accumulation, so it costs no extra MAC)
    return module.in_features * out.numel()


def macs_convtransposend(module: _ConvTransposeNd, inp: Tensor, out: Tensor) -> int:
    """MACs estimation for `torch.nn.modules.conv._ConvTransposeNd`"""
    # Output padding resolution: min & max sizes are computed per spatial dim, then subtracted
    # (cf. https://github.com/pytorch/pytorch/blob/master/torch/nn/modules/conv.py#L496-L532)
    padding_macs = 4 * len(module.kernel_size)
    # The remaining work is almost identical to a regular convolution (given the padding)
    return padding_macs + macs_convnd(module, inp, out)


def macs_convnd(module: _ConvNd, inp: Tensor, out: Tensor) -> int:
    """MACs estimation for `torch.nn.modules.conv._ConvNd`"""
    # One kernel window: k multiplications and (k - 1) additions -> k MACs per input channel
    window_macs_per_chan = reduce(mul, module.kernel_size)
    # Grouped convolutions only connect to a slice of the input channels
    effective_in_chan = inp.shape[1] // module.groups
    window_mac = effective_in_chan * window_macs_per_chan
    # One window evaluation per output element; bias already counted in accumulation
    return out.numel() * window_mac


def macs_bn(module: _BatchNorm, inp: Tensor, out: Tensor) -> int:
    """MACs estimation for `torch.nn.modules.batchnorm._BatchNorm`"""
    # Normalization: subtract the mean, divide by the denominator -> 1 MAC per element
    norm_mac = 1
    # Optional affine transform: multiply by gamma, add beta
    scale_mac = 1 if module.affine else 0
    bn_mac = inp.numel() * (norm_mac + scale_mac)

    # Running statistics are only refreshed in training mode
    # cf. https://github.com/pytorch/pytorch/blob/master/torch/nn/modules/batchnorm.py#L94-L101
    tracking_mac = 0
    if module.track_running_stats and module.training:
        num_elts = inp.shape[0] * inp.shape[2:].numel()
        # batch mean: per channel, sum the values and divide by the element count
        tracking_mac += module.num_features * (num_elts - 1)
        # batch var: per channel, subtract the mean, square, sum, then divide
        tracking_mac += module.num_features * (2 * num_elts - 1)
        # blend each running stat with the batch stat: rescale (mul by N), add, div by (N + 1)
        tracking_mac += 2 * module.num_features * 2

    return bn_mac + tracking_mac


def _pool_window_size(kernel_size, num_spatial_dims: int) -> int:
    """Number of elements inside one pooling window (an int size applies to every spatial dim)."""
    if isinstance(kernel_size, tuple):
        return reduce(mul, kernel_size)
    return kernel_size ** num_spatial_dims


def macs_maxpool(module: _MaxPoolNd, inp: Tensor, out: Tensor) -> int:
    """MACs estimation for `torch.nn.modules.pooling._MaxPoolNd`"""
    # Fix: an int kernel_size spans every spatial dim, so the window holds
    # kernel_size ** num_spatial_dims elements (the previous code used it raw)
    k_size = _pool_window_size(module.kernel_size, inp.ndim - 2)
    # for each output element, (k - 1) comparisons locate the max of the window
    return out.numel() * (k_size - 1)


def macs_avgpool(module: _AvgPoolNd, inp: Tensor, out: Tensor) -> int:
    """MACs estimation for `torch.nn.modules.pooling._AvgPoolNd`"""
    # Fix: expand an int kernel_size over every spatial dim (see macs_maxpool)
    k_size = _pool_window_size(module.kernel_size, inp.ndim - 2)
    # for each output element, sum the window (k - 1 adds) then divide
    # (NOTE(review): one division per spatial dim is counted here, i.e. inp.ndim - 2,
    # matching the original accounting — kept as-is)
    return out.numel() * (k_size - 1 + inp.ndim - 2)


def _adaptive_kernel_size(inp: Tensor, out: Tensor) -> tuple:
    """Approximate the pooling window from the ratio of input/output spatial shapes."""
    return tuple(
        i_size // o_size if (i_size % o_size) == 0 else i_size - o_size * (i_size // o_size) + 1
        for i_size, o_size in zip(inp.shape[2:], out.shape[2:])
    )


def macs_adaptive_maxpool(module: _AdaptiveMaxPoolNd, inp: Tensor, out: Tensor) -> int:
    """MACs estimation for `torch.nn.modules.pooling._AdaptiveMaxPoolNd`"""
    kernel_size = _adaptive_kernel_size(inp, out)
    # for each output element, (k - 1) comparisons locate the max of the window
    return out.numel() * (reduce(mul, kernel_size) - 1)


def macs_adaptive_avgpool(module: _AdaptiveAvgPoolNd, inp: Tensor, out: Tensor) -> int:
    """MACs estimation for `torch.nn.modules.pooling._AdaptiveAvgPoolNd`"""
    kernel_size = _adaptive_kernel_size(inp, out)
    # for each output element, sum the window then divide (one div per spatial dim)
    return out.numel() * (reduce(mul, kernel_size) - 1 + len(kernel_size))