# Source code for torchscan.modules.memory

# -*- coding: utf-8 -*-

"""
Module DMAs
"""

import warnings
from functools import reduce
from operator import mul

from torch import nn
from torch.nn.modules.batchnorm import _BatchNorm
from torch.nn.modules.conv import _ConvNd, _ConvTransposeNd
from torch.nn.modules.pooling import _MaxPoolNd, _AvgPoolNd, _AdaptiveMaxPoolNd, _AdaptiveAvgPoolNd


__all__ = ['module_dmas']


def module_dmas(module, input, output):
    """Estimate the number of direct memory accesses performed by ``module``.

    The implementation overhead is neglected: only input reads, parameter
    accesses and output writes of the supported layer types are counted.

    Args:
        module (torch.nn.Module): PyTorch module
        input (torch.Tensor): input to the module
        output (torch.Tensor): output of the module

    Returns:
        int: number of DMAs (0, with a warning, for unsupported module types)
    """
    # Guard clauses instead of an if/elif ladder. Order matters only for
    # _ConvTransposeNd, which must be tested before its parent _ConvNd.
    if isinstance(module, nn.Identity):
        return dmas_identity(module, input, output)
    if isinstance(module, nn.Linear):
        return dmas_linear(module, input, output)
    if isinstance(module, (nn.ReLU, nn.ReLU6)):
        return dmas_relu(module, input, output)
    if isinstance(module, (nn.ELU, nn.LeakyReLU)):
        return dmas_act_single_param(module, input, output)
    if isinstance(module, nn.Sigmoid):
        return dmas_sigmoid(module, input, output)
    if isinstance(module, nn.Tanh):
        return dmas_tanh(module, input, output)
    if isinstance(module, _ConvTransposeNd):
        return dmas_convtransposend(module, input, output)
    if isinstance(module, _ConvNd):
        return dmas_convnd(module, input, output)
    if isinstance(module, _BatchNorm):
        return dmas_bn(module, input, output)
    if isinstance(module, (_MaxPoolNd, _AvgPoolNd)):
        return dmas_pool(module, input, output)
    if isinstance(module, (_AdaptiveMaxPoolNd, _AdaptiveAvgPoolNd)):
        return dmas_adaptive_pool(module, input, output)
    if isinstance(module, nn.Dropout):
        return dmas_dropout(module, input, output)
    # Unknown layer type: report it and count nothing.
    warnings.warn(f'Module type not supported: {module.__class__.__name__}')
    return 0
def num_params(module):
    """Compute the number of parameters

    Args:
        module (torch.nn.Module): PyTorch module

    Returns:
        int: number of parameter elements
    """
    count = 0
    for param in module.parameters():
        count += param.data.numel()
    return count


def dmas_identity(module, input, output):
    """DMAs estimation for `torch.nn.Identity`"""
    # A pass-through only reads its input once.
    return input.numel()


def dmas_linear(module, input, output):
    """DMAs estimation for `torch.nn.Linear`"""
    # One read per input element, one access per weight/bias element,
    # one write per output element.
    return input.numel() + num_params(module) + output.numel()


def dmas_relu(module, input, output):
    """DMAs estimation for `torch.nn.ReLU`"""
    reads = input.numel()
    # In-place variants overwrite the input buffer, so there are no
    # additional output writes.
    writes = 0 if module.inplace else output.numel()
    return reads + writes


def dmas_act_single_param(module, input, output):
    """DMAs estimation for activations with single parameter"""
    # One extra access for the scalar parameter (alpha, negative slope, ...).
    writes = 0 if module.inplace else output.numel()
    return input.numel() + 1 + writes


def dmas_sigmoid(module, input, output):
    """DMAs estimation for `torch.nn.Sigmoid`"""
    # One read of the input (single exponential), one write of the output.
    return input.numel() + output.numel()


def dmas_tanh(module, input, output):
    """DMAs estimation for `torch.nn.Tanh`"""
    # The input is read twice, once for each exponential in tanh.
    return input.numel() * 2 + output.numel()


def dmas_dropout(module, input, output):
    """DMAs estimation for `torch.nn.Dropout`"""
    # One extra access for the sampling probability.
    writes = 0 if module.inplace else output.numel()
    return input.numel() + 1 + writes


def dmas_convtransposend(module, input, output):
    """DMAs estimation for `torch.nn.modules.conv._ConvTransposeNd`"""
    # Padding resolution (cf.
    # https://github.com/pytorch/pytorch/blob/master/torch/nn/modules/conv.py#L496-L532)
    # accesses stride, padding, kernel_size (4 values) plus the output
    # padding, per spatial dimension.
    spatial_dims = len(module.kernel_size)
    padding_dma = spatial_dims * 4 + spatial_dims
    # Apart from padding handling, accesses match a regular convolution.
    return padding_dma + dmas_convnd(module, input, output)


def dmas_convnd(module, input, output):
    """DMAs estimation for `torch.nn.modules.conv._ConvNd`"""
    # Every output element reads one kernel window from each input channel,
    # scaled down by the number of groups.
    window = reduce(mul, module.kernel_size)
    read_dma = module.in_channels * window * output.numel() / module.groups
    # Weight and bias are each accessed once.
    return read_dma + num_params(module) + output.numel()


def dmas_bn(module, input, output):
    """DMAs estimation for `torch.nn.modules.batchnorm._BatchNorm`"""
    reads = input.numel()
    # running_mean, running_var and eps are always accessed.
    ops_dma = module.running_mean.numel() + module.running_var.numel() + 1
    if module.affine:
        # Learnable scale and shift.
        ops_dma += module.weight.data.numel() + module.bias.data.numel()
    if module.momentum:
        # Exponential-average factor.
        ops_dma += 1
    if module.training and module.track_running_stats:
        # Batch statistics reuse the input reads counted above; only the
        # batch counter and running-stat updates add accesses.
        ops_dma += (module.num_batches_tracked.numel()
                    + module.running_mean.numel()
                    + module.running_var.numel())
    return reads + ops_dma + output.numel()


def dmas_pool(module, input, output):
    """DMAs estimation for spatial pooling modules"""
    # kernel_size may be a scalar applied to every spatial dim, or a tuple.
    k = module.kernel_size
    if not isinstance(k, tuple):
        k = (k,) * (input.ndim - 2)
    # Each output element reads one full pooling window, then is written once.
    return reduce(mul, k) * output.numel() + output.numel()


def dmas_adaptive_pool(module, input, output):
    """DMAs estimation for adaptive spatial pooling modules"""
    out_sizes = module.output_size
    if not isinstance(out_sizes, tuple):
        out_sizes = (out_sizes,) * (input.ndim - 2)
    # Approximate the kernel size from the input/output spatial ratio;
    # non-divisible shapes fall back to the maximum window extent.
    kernel = []
    for in_size, out_size in zip(input.shape[2:], out_sizes):
        if in_size % out_size == 0:
            kernel.append(in_size // out_size)
        else:
            kernel.append(in_size - out_size * (in_size // out_size) + 1)
    # Each output element reads one window, then is written once.
    return reduce(mul, kernel) * output.numel() + output.numel()