# Source code for torchscan.modules.macs

# Copyright (C) 2020-2024, François-Guillaume Fernandez.

# This program is licensed under the Apache License 2.0.
# See LICENSE or go to <https://www.apache.org/licenses/LICENSE-2.0> for full license details.

import warnings
from functools import reduce
from operator import mul

from torch import Tensor, nn
from torch.nn import Module
from torch.nn.modules.batchnorm import _BatchNorm
from torch.nn.modules.conv import _ConvNd, _ConvTransposeNd
from torch.nn.modules.dropout import _DropoutNd
from torch.nn.modules.pooling import _AdaptiveAvgPoolNd, _AdaptiveMaxPoolNd, _AvgPoolNd, _MaxPoolNd

__all__ = ["module_macs"]


def module_macs(module: Module, inp: Tensor, out: Tensor) -> int:
    """Estimate the number of multiply-accumulation operations performed by the module

    Args:
        module (torch.nn.Module): PyTorch module
        inp (torch.Tensor): input to the module
        out (torch.Tensor): output of the module

    Returns:
        int: number of MACs (0 for unsupported module types, with a warning)
    """
    if isinstance(module, nn.Linear):
        return macs_linear(module, inp, out)
    if isinstance(module, (nn.Identity, nn.ReLU, nn.ELU, nn.LeakyReLU, nn.ReLU6, nn.Tanh, nn.Sigmoid, nn.Flatten)):
        # Activations / reshaping layers are not counted as MACs
        return 0
    # Transposed convolutions also derive from `_ConvNd`, so they must be matched first
    if isinstance(module, _ConvTransposeNd):
        return macs_convtransposend(module, inp, out)
    if isinstance(module, _ConvNd):
        return macs_convnd(module, inp, out)
    if isinstance(module, _BatchNorm):
        return macs_bn(module, inp, out)
    if isinstance(module, _MaxPoolNd):
        return macs_maxpool(module, inp, out)
    if isinstance(module, _AvgPoolNd):
        return macs_avgpool(module, inp, out)
    if isinstance(module, _AdaptiveMaxPoolNd):
        return macs_adaptive_maxpool(module, inp, out)
    if isinstance(module, _AdaptiveAvgPoolNd):
        return macs_adaptive_avgpool(module, inp, out)
    # `_DropoutNd` covers `nn.Dropout` as well as Dropout2d/3d and the alpha-dropout variants,
    # which do not inherit from `nn.Dropout` and would otherwise hit the warning below
    if isinstance(module, _DropoutNd):
        return 0
    warnings.warn(f"Module type not supported: {module.__class__.__name__}", stacklevel=1)
    return 0
def macs_linear(module: nn.Linear, _: Tensor, out: Tensor) -> int: """MACs estimation for `torch.nn.Linear`""" # batch size * out_chan * macs_per_elt (bias already counted in accumulation) return module.in_features * reduce(mul, out.shape) def macs_convtransposend(module: _ConvTransposeNd, inp: Tensor, out: Tensor) -> int: """MACs estimation for `torch.nn.modules.conv._ConvTransposeNd`""" # Padding (# cf. https://github.com/pytorch/pytorch/blob/master/torch/nn/modules/conv.py#L496-L532) # Define min and max sizes, then subtract them padding_macs = len(module.kernel_size) * 4 # Rest of the operations are almost identical to a convolution (given the padding) conv_macs = macs_convnd(module, inp, out) return padding_macs + conv_macs def macs_convnd(module: _ConvNd, inp: Tensor, out: Tensor) -> int: """MACs estimation for `torch.nn.modules.conv._ConvNd`""" # For each position, # mult = kernel size, # adds = kernel size - 1 window_macs_per_chan = reduce(mul, module.kernel_size) # Connections to input channels is controlled by the group parameter effective_in_chan = inp.shape[1] // module.groups # N * mac window_mac = effective_in_chan * window_macs_per_chan return out.numel() * window_mac # bias already counted in accumulation def macs_bn(module: _BatchNorm, inp: Tensor, _: Tensor) -> int: """MACs estimation for `torch.nn.modules.batchnorm._BatchNorm`""" # sub mean, div by denom norm_mac = 1 # mul by gamma, add beta scale_mac = 1 if module.affine else 0 # Sum everything up bn_mac = inp.numel() * (norm_mac + scale_mac) # Count tracking stats update ops # cf. 
https://github.com/pytorch/pytorch/blob/master/torch/nn/modules/batchnorm.py#L94-L101 tracking_mac = 0 b = inp.shape[0] num_spatial_elts = inp.shape[2:].numel() if module.track_running_stats and module.training: # running_mean: by channel, sum value and div by batch size tracking_mac += module.num_features * (b * num_spatial_elts - 1) # running_var: by channel, sub mean and square values, sum them, divide by batch size active_elts = b * num_spatial_elts tracking_mac += module.num_features * (2 * active_elts - 1) # Update both runnning stat: rescale previous value (mul by N), add it the new one, then div by (N + 1) tracking_mac += 2 * module.num_features * 2 return bn_mac + tracking_mac def macs_maxpool(module: _MaxPoolNd, _: Tensor, out: Tensor) -> int: """MACs estimation for `torch.nn.modules.pooling._MaxPoolNd`""" k_size = reduce(mul, module.kernel_size) if isinstance(module.kernel_size, tuple) else module.kernel_size # for each spatial output element, check max element in kernel scope return out.numel() * (k_size - 1) def macs_avgpool(module: _AvgPoolNd, inp: Tensor, out: Tensor) -> int: """MACs estimation for `torch.nn.modules.pooling._AvgPoolNd`""" k_size = reduce(mul, module.kernel_size) if isinstance(module.kernel_size, tuple) else module.kernel_size # for each spatial output element, sum elements in kernel scope and div by kernel size return out.numel() * (k_size - 1 + inp.ndim - 2) def macs_adaptive_maxpool(_: _AdaptiveMaxPoolNd, inp: Tensor, out: Tensor) -> int: """MACs estimation for `torch.nn.modules.pooling._AdaptiveMaxPoolNd`""" # Approximate kernel_size using ratio of spatial shapes between input and output kernel_size = tuple( i_size // o_size if (i_size % o_size) == 0 else i_size - o_size * (i_size // o_size) + 1 for i_size, o_size in zip(inp.shape[2:], out.shape[2:], strict=False) ) # for each spatial output element, check max element in kernel scope return out.numel() * (reduce(mul, kernel_size) - 1) def macs_adaptive_avgpool(_: 
_AdaptiveAvgPoolNd, inp: Tensor, out: Tensor) -> int: """MACs estimation for `torch.nn.modules.pooling._AdaptiveAvgPoolNd`""" # Approximate kernel_size using ratio of spatial shapes between input and output kernel_size = tuple( i_size // o_size if (i_size % o_size) == 0 else i_size - o_size * (i_size // o_size) + 1 for i_size, o_size in zip(inp.shape[2:], out.shape[2:], strict=False) ) # for each spatial output element, sum elements in kernel scope and div by kernel size return out.numel() * (reduce(mul, kernel_size) - 1 + len(kernel_size))