# -*- coding: utf-8 -*-

"""
Module FLOPs
"""

import warnings
from functools import reduce
from operator import mul

from torch import nn
from torch.nn.modules.batchnorm import _BatchNorm
from torch.nn.modules.conv import _ConvNd, _ConvTransposeNd
from torch.nn.modules.pooling import _MaxPoolNd, _AvgPoolNd, _AdaptiveMaxPoolNd, _AdaptiveAvgPoolNd


__all__ = ['module_flops']


def module_flops(module, input, output):
    """Estimate the number of floating point operations performed by the module

    Args:
        module (torch.nn.Module): PyTorch module
        input (torch.Tensor): input to the module
        output (torch.Tensor): output of the module
    Returns:
        int: number of FLOPs
    """
    if isinstance(module, nn.Identity):
        return 0
    elif isinstance(module, nn.Linear):
        return flops_linear(module, input, output)
    elif isinstance(module, nn.ReLU):
        return flops_relu(module, input, output)
    elif isinstance(module, nn.ELU):
        return flops_elu(module, input, output)
    elif isinstance(module, nn.LeakyReLU):
        return flops_leakyrelu(module, input, output)
    elif isinstance(module, nn.ReLU6):
        return flops_relu6(module, input, output)
    elif isinstance(module, nn.Tanh):
        return flops_tanh(module, input, output)
    elif isinstance(module, nn.Sigmoid):
        return flops_sigmoid(module, input, output)
    # Check transposed convolutions before _ConvNd, since they subclass it
    elif isinstance(module, _ConvTransposeNd):
        return flops_convtransposend(module, input, output)
    elif isinstance(module, _ConvNd):
        return flops_convnd(module, input, output)
    elif isinstance(module, _BatchNorm):
        return flops_bn(module, input, output)
    elif isinstance(module, _MaxPoolNd):
        return flops_maxpool(module, input, output)
    elif isinstance(module, _AvgPoolNd):
        return flops_avgpool(module, input, output)
    elif isinstance(module, _AdaptiveMaxPoolNd):
        return flops_adaptive_maxpool(module, input, output)
    elif isinstance(module, _AdaptiveAvgPoolNd):
        return flops_adaptive_avgpool(module, input, output)
    elif isinstance(module, nn.Dropout):
        return flops_dropout(module, input, output)
    else:
        warnings.warn(f'Module type not supported: {module.__class__.__name__}')
        return 0
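# Illustrative usage (a minimal sketch, not part of the library API beyond
# `module_flops` itself): the input/output tensors are assumed to come from a
# dummy forward pass of the module being measured.
#
#   >>> import torch
#   >>> from torch import nn
#   >>> mod = nn.Linear(8, 4)
#   >>> inp = torch.rand(2, 8)
#   >>> module_flops(mod, inp, mod(inp))
#   128
#
# i.e. 2 * 4 * (2 * 8 - 1) = 120 multiply-add FLOPs for the matrix product,
# plus 2 * 4 = 8 bias additions.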
def flops_linear(module, input, output):
    """FLOPs estimation for `torch.nn.Linear`"""
    # Each of the (batch size * out_features) outputs costs in_features
    # multiplications and (in_features - 1) additions
    mm_flops = input.shape[0] * output.shape[1] * (2 * input.shape[1] - 1)
    bias_flops = output.numel() if module.bias is not None else 0
    return mm_flops + bias_flops


def flops_sigmoid(module, input, output):
    """FLOPs estimation for `torch.nn.Sigmoid`"""
    # For each element, mul by -1, exp it, add 1, div
    return input.numel() * 4


def flops_relu(module, input, output):
    """FLOPs estimation for `torch.nn.ReLU`"""
    # Each element is compared to 0
    return input.numel()


def flops_elu(module, input, output):
    """FLOPs estimation for `torch.nn.ELU`"""
    # For each element, compare it to 0, exp it, sub 1, mul by alpha, compare it to 0 and sum both
    return input.numel() * 6


def flops_leakyrelu(module, input, output):
    """FLOPs estimation for `torch.nn.LeakyReLU`"""
    # For each element, compare it to 0 (max), compare it to 0 (min), mul by slope and sum both
    return input.numel() * 4


def flops_relu6(module, input, output):
    """FLOPs estimation for `torch.nn.ReLU6`"""
    # For each element, compare it to 0 (max) and to 6 (min)
    return input.numel() * 2


def flops_tanh(module, input, output):
    """FLOPs estimation for `torch.nn.Tanh`"""
    # For each element, exp it twice (once for x, once for -x, with one negation),
    # then divide the difference of the two by their sum
    return input.numel() * 6


def flops_dropout(module, input, output):
    """FLOPs estimation for `torch.nn.Dropout`"""
    if module.p > 0:
        # Sample a random number for each input element
        return input.numel()
    else:
        return 0


def flops_convtransposend(module, input, output):
    """FLOPs estimation for `torch.nn.modules.conv._ConvTransposeNd`"""
    # Padding: define min and max sizes
    # cf. https://github.com/pytorch/pytorch/blob/master/torch/nn/modules/conv.py#L496-L532
    padding_flops = len(module.kernel_size) * 8
    # Once padding is determined, the operations are almost identical to those of a convolution
    conv_flops = flops_convnd(module, input, output)
    return padding_flops + conv_flops


def flops_convnd(module, input, output):
    """FLOPs estimation for `torch.nn.modules.conv._ConvNd`"""
    # For each position: #mults = kernel size, #adds = kernel size - 1
    window_flops_per_chan = 2 * reduce(mul, module.kernel_size) - 1
    # Connections to input channels are controlled by the group parameter
    effective_in_chan = input.shape[1] // module.groups
    # N channels' worth of window FLOPs + (N - 1) additions to sum them
    window_flops = effective_in_chan * window_flops_per_chan + (effective_in_chan - 1)
    conv_flops = output.numel() * window_flops
    # Each output element gets a bias addition
    bias_flops = output.numel() if module.bias is not None else 0
    return conv_flops + bias_flops


def flops_bn(module, input, output):
    """FLOPs estimation for `torch.nn.modules.batchnorm._BatchNorm`"""
    # For each channel, add eps to running_var and sqrt it
    norm_ops = module.num_features * 2
    # For each element, sub running_mean, div by denom
    norm_ops += input.numel() * 2
    # For each element, mul by gamma, add beta
    scale_ops = input.numel() * 2 if module.affine else 0
    bn_flops = norm_ops + scale_ops

    # Count tracking stats update ops
    # cf. https://github.com/pytorch/pytorch/blob/master/torch/nn/modules/batchnorm.py#L94-L101
    tracking_flops = 0
    if module.track_running_stats and module.training:
        # exponential_average_factor
        if module.momentum is None:
            tracking_flops += 1
        # running_mean: by channel, sum values and div by batch size
        tracking_flops += module.num_features * (input.shape[0] * input.shape[2:].numel())
        # running_var: by channel, sub mean and square values, sum them, divide by batch size
        tracking_flops += 3 * input.numel()
        # Update both running stats: rescale previous value (mul by N), add the new one, then div by (N + 1)
        tracking_flops += 2 * module.num_features * 3

    return bn_flops + tracking_flops
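# Worked example for `flops_convnd` (a sketch with assumed shapes): a 3x3
# Conv2d with 16 input channels, 32 output channels, groups=1 and bias, on a
# (1, 32, 32, 32) output map:
#
#   window_flops_per_chan = 2 * (3 * 3) - 1 = 17
#   window_flops = 16 * 17 + (16 - 1) = 287
#   conv_flops = (1 * 32 * 32 * 32) * 287 = 9_404_416
#
# plus one bias addition per output element (32_768 here).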
def flops_maxpool(module, input, output):
    """FLOPs estimation for `torch.nn.modules.pooling._MaxPoolNd`"""
    k_size = reduce(mul, module.kernel_size) if isinstance(module.kernel_size, tuple) else module.kernel_size
    # For each spatial output element, check the max element in the kernel scope
    return output.numel() * (k_size - 1)


def flops_avgpool(module, input, output):
    """FLOPs estimation for `torch.nn.modules.pooling._AvgPoolNd`"""
    k_size = reduce(mul, module.kernel_size) if isinstance(module.kernel_size, tuple) else module.kernel_size
    # For each spatial output element, sum the elements in the kernel scope and div by the kernel size
    return output.numel() * (k_size - 1 + input.ndim - 2)


def flops_adaptive_maxpool(module, input, output):
    """FLOPs estimation for `torch.nn.modules.pooling._AdaptiveMaxPoolNd`"""
    o_sizes = module.output_size if isinstance(module.output_size, tuple) else (module.output_size,) * (input.ndim - 2)
    # Approximate the kernel size using the ratio of input to output spatial shapes
    kernel_size = tuple(i_size // o_size if (i_size % o_size) == 0 else i_size - o_size * (i_size // o_size) + 1
                        for i_size, o_size in zip(input.shape[2:], o_sizes))
    # For each spatial output element, check the max element in the kernel scope
    return output.numel() * (reduce(mul, kernel_size) - 1)


def flops_adaptive_avgpool(module, input, output):
    """FLOPs estimation for `torch.nn.modules.pooling._AdaptiveAvgPoolNd`"""
    o_sizes = module.output_size if isinstance(module.output_size, tuple) else (module.output_size,) * (input.ndim - 2)
    # Approximate the kernel size using the ratio of input to output spatial shapes
    kernel_size = tuple(i_size // o_size if (i_size % o_size) == 0 else i_size - o_size * (i_size // o_size) + 1
                        for i_size, o_size in zip(input.shape[2:], o_sizes))
    # For each spatial output element, sum the elements in the kernel scope and div by the kernel size
    return output.numel() * (reduce(mul, kernel_size) - 1 + len(kernel_size))
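# Worked example for the adaptive-pooling kernel approximation above (a sketch
# with assumed shapes): nn.AdaptiveAvgPool2d(output_size=7) on a
# (1, 64, 14, 14) input infers kernel_size = (14 // 7, 14 // 7) = (2, 2), so
# each of the 1 * 64 * 7 * 7 = 3136 output elements costs
# (2 * 2 - 1) additions + 2 division ops, i.e. 3136 * 5 = 15_680 FLOPs.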