Source code for torchscan.crawler

# -*- coding: utf-8 -*-

"""
Module crawler
"""

import os

import torch

from .modules import module_dmas, module_flops, module_macs, module_rf
from .process import get_process_gpu_ram
from .utils import aggregate_info, format_info

__all__ = ['crawl_module', 'summary']


def apply(module, fn, depth=0, name=None):
    """Modified version of `torch.nn.Module.apply` method

    Args:
        module (torch.nn.Module): target module
        fn (callable): function to apply to each module
        depth (int, optional): current depth of `module`
        name (str, optional): name of the current module
    """

    if name is None:
        name = module.__class__.__name__.lower()
    fn(module, depth, name)
    for n, m in module.named_children():
        apply(m, fn, depth + 1, n)

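# Illustrative sketch (not part of the library): `apply` walks the module tree and
# calls `fn(module, depth, name)` on every node, which is how the hooks below get
# registered. For instance, printing an indented layout of a small model:
#
#   >>> import torch.nn as nn
#   >>> mod = nn.Sequential(nn.Conv2d(3, 8, 3), nn.ReLU())
#   >>> apply(mod, lambda m, depth, name: print("  " * depth + f"{name}: {m.__class__.__name__}"))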

def crawl_module(module, input_shape, dtype=None, max_depth=None):
    """Retrieves module information for an expected input tensor shape

    Example::
        >>> import torch.nn as nn
        >>> from torchscan import crawl_module
        >>> mod = nn.Conv2d(3, 8, 3)
        >>> module_info = crawl_module(mod, (3, 224, 224))

    Args:
        module (torch.nn.Module): module to inspect
        input_shape (tuple<int>): expected input shapes
        dtype (type, optional): data type of each input argument to the module
        max_depth (int, optional): maximum depth of layer information
    Returns:
        dict: layer and overhead information
    """

    # Get device and data types from model
    p = next(module.parameters())
    device = p.device

    cuda_overhead, framework_overhead = 0, 0
    if torch.cuda.is_available():
        # Process RAM - allocator RAM
        cuda_overhead = get_process_gpu_ram(os.getpid()) - (torch.cuda.memory_reserved() / 1024 ** 2)
        # Allocator RAM - Used RAM
        framework_overhead = (torch.cuda.memory_reserved() - torch.cuda.memory_allocated()) / 1024 ** 2

    # Input
    if isinstance(input_shape[0], int):
        input_shape = [input_shape]
    if dtype is None:
        dtype = p.data.dtype
    if isinstance(dtype, torch.dtype):
        dtype = [dtype] * len(input_shape)

    # Tensor arguments
    input_ts = [torch.rand(1, *in_shape).to(dtype=_dtype, device=device)
                for in_shape, _dtype in zip(input_shape, dtype)]

    # Hook definition
    def _hook_info(module, depth, name):

        def _pre_hook(module, input):
            """Pre-forward hook"""

            # Params
            grad_params, nograd_params, param_size = 0, 0, 0
            num_buffers, buffer_size = 0, 0
            is_shared = False
            if not any(module.children()):
                # Parameters
                for p in module.parameters():
                    if id(p) not in param_ids:
                        if p.requires_grad:
                            grad_params += p.data.numel()
                        else:
                            nograd_params += p.data.numel()
                        param_size += p.data.numel() * p.data.element_size()
                        param_ids.append(id(p))
                    else:
                        is_shared = True
                # Buffers
                for b in module.buffers():
                    if id(b) not in param_ids:
                        num_buffers += b.numel()
                        buffer_size += b.numel() * b.element_size()
                        param_ids.append(id(b))
                    else:
                        is_shared = True

            call_idxs[id(module)] = len(info)

            info.append(dict(name=name,
                             depth=depth,
                             type=module.__class__.__name__,
                             input_shape=(-1, *input[0][0].shape[1:]),
                             output_shape=None,
                             grad_params=grad_params,
                             nograd_params=nograd_params,
                             param_size=param_size,
                             num_buffers=num_buffers,
                             buffer_size=buffer_size,
                             flops=0,
                             macs=0,
                             dmas=0,
                             rf=1,
                             s=1,
                             p=0,
                             is_shared=is_shared,
                             is_leaf=not any(module.children())))
            # Remove the hook by using its handle
            pre_fw_handle.remove()

        def _fwd_hook(module, input, output):
            """Post-forward hook"""

            # Retrieve forward index
            fw_idx = call_idxs[id(module)]

            if any(module.children()):
                tot_flops, tot_macs, tot_dmas = 0, 0, 0
                current_rf, current_stride, current_padding = 1, 1, 0
            else:
                # Compute stats for standalone layers
                tot_flops = module_flops(module, input[0], output)
                tot_macs = module_macs(module, input[0], output)
                tot_dmas = module_dmas(module, input[0], output)
                current_rf, current_stride, current_padding = module_rf(module, input[0], output)

            # Update layer information
            info[fw_idx]['output_shape'] = (-1, *output.shape[1:])
            # Add them, since some modules can be used several times
            info[fw_idx]['flops'] = tot_flops
            info[fw_idx]['macs'] = tot_macs
            info[fw_idx]['dmas'] = tot_dmas
            # Compute receptive field
            info[fw_idx]['rf'] = current_rf
            info[fw_idx]['s'] = current_stride
            info[fw_idx]['p'] = current_padding

            # Remove the hook by using its handle
            post_fw_handle.remove()

        # Hook only leaf children
        pre_fw_handle = module.register_forward_pre_hook(_pre_hook)
        post_fw_handle = module.register_forward_hook(_fwd_hook)

    # Hook model
    info = []
    param_ids = []
    call_idxs = {}
    apply(module, _hook_info)

    # Forward
    with torch.no_grad():
        module(*input_ts)

    reserved_ram, diff_ram = 0, 0
    if torch.cuda.is_available():
        reserved_ram = torch.cuda.memory_reserved() / 1024 ** 2
        diff_ram = (torch.cuda.memory_reserved() - torch.cuda.memory_allocated()) / 1024 ** 2
        torch.cuda.synchronize()
        torch.cuda.empty_cache()

    grad_params, nograd_params, param_size = 0, 0, 0
    num_buffers, buffer_size = 0, 0
    for p in module.parameters():
        if p.requires_grad:
            grad_params += p.data.numel()
        else:
            nograd_params += p.data.numel()
        param_size += p.data.numel() * p.data.element_size()
    for b in module.buffers():
        num_buffers += b.numel()
        buffer_size += b.numel() * b.element_size()

    # Update cumulative receptive field
    _rf, _s, _p = 1, 1, 0
    for fw_idx, _layer in enumerate(info[::-1]):
        _rf = _layer['s'] * (_rf - 1) + _layer['rf']
        _s *= _layer['s']
        _p = _layer['s'] * _p + _layer['p']

        info[len(info) - 1 - fw_idx]['rf'] = _rf
        info[len(info) - 1 - fw_idx]['s'] = _s
        info[len(info) - 1 - fw_idx]['p'] = _p

    return dict(overheads=dict(cuda=dict(pre=cuda_overhead,
                                         fwd=get_process_gpu_ram(os.getpid()) - reserved_ram),
                               framework=dict(pre=framework_overhead,
                                              fwd=diff_ram)),
                layers=info,
                overall=dict(grad_params=grad_params, nograd_params=nograd_params, param_size=param_size,
                             num_buffers=num_buffers, buffer_size=buffer_size))
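

# Illustrative sketch (not part of the library): the dict returned by `crawl_module`
# can be inspected directly; the keys below mirror the `info` entries and the
# `overall` / `overheads` sections built above.
#
#   >>> import torch.nn as nn
#   >>> mod = nn.Sequential(nn.Conv2d(3, 8, 3), nn.ReLU())
#   >>> module_info = crawl_module(mod, (3, 224, 224))
#   >>> [(layer['name'], layer['flops'], layer['rf']) for layer in module_info['layers']]
#   >>> module_info['overall']['grad_params']

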
def summary(module, input_shape, wrap_mode='mid', max_depth=None, receptive_field=False):
    """Print module summary for an expected input tensor shape

    Example::
        >>> import torch.nn as nn
        >>> from torchscan import summary
        >>> mod = nn.Conv2d(3, 8, 3)
        >>> summary(mod, (3, 224, 224), receptive_field=True)

    Args:
        module (torch.nn.Module): module to inspect
        input_shape (tuple<int>): expected input shapes
        wrap_mode (str, optional): if a value is too long, where the wrapping should be performed
        max_depth (int, optional): maximum depth of layer information
        receptive_field (bool, optional): whether receptive field estimation should be performed
    """

    # Get the summary dict
    module_info = crawl_module(module, input_shape)
    # Aggregate until max_depth
    if isinstance(max_depth, int):
        module_info = aggregate_info(module_info, max_depth)

    # Format it and print it
    print(format_info(module_info, wrap_mode, receptive_field))
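

# Illustrative usage sketch (not part of the library); assumes torchvision is installed.
# `summary` prints a formatted table, and `max_depth` aggregates nested blocks:
#
#   >>> import torchvision.models as models
#   >>> summary(models.resnet18(), (3, 224, 224), max_depth=1)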