# Copyright (C) 2020-2024, François-Guillaume Fernandez.
# This program is licensed under the Apache License 2.0.
# See LICENSE or go to <https://www.apache.org/licenses/LICENSE-2.0> for full license details.
from itertools import starmap
from typing import Any, Dict, List, Optional, Tuple
def format_name(name: str, depth: int = 0) -> str:
    """Decorate a name with tree-drawing characters matching its nesting level

    Args:
        name: label to decorate
        depth: nesting level of the entry (0 leaves the name untouched)

    Returns:
        the name as is at depth 0, prefixed with "├─" at depth 1, and otherwise prefixed with
        `depth - 1` repetitions of "| " followed by "└─"
    """
    if depth == 0:
        return name
    # Direct children get a tee, deeper entries get one guide per extra level and a corner
    branch = "├─" if depth == 1 else "| " * (depth - 1) + "└─"
    return f"{branch}{name}"
def wrap_string(s: str, max_len: int, delimiter: str = ".", wrap: str = "[...]", mode: Optional[str] = "end") -> str:
    """Wrap a string into a given length

    Strings that already fit in `max_len` are returned unchanged. Otherwise the string is cut and the
    `wrap` marker is inserted, either at the end ("end" mode) or right before the last
    `delimiter`-separated component, which is preserved in full ("mid" mode).

    Args:
        s: input string
        max_len: maximum string length
        delimiter: character used for delimiting information categories
        wrap: wrapping sequence used
        mode: wrapping mode, one of "end" or "mid"; `None` disables wrapping

    Returns:
        wrapped string. When the marker (plus the preserved last component in "mid" mode) is itself
        longer than `max_len`, nothing of the head is kept and the result may exceed `max_len`.

    Raises:
        ValueError: if `s` needs wrapping and `mode` is not one of the supported values
    """
    if len(s) <= max_len or mode is None:
        return s

    def _end_wrap() -> str:
        # Clamp at 0: a negative bound would slice from the end and keep almost the whole string
        return s[: max(max_len - len(wrap), 0)] + wrap

    if mode == "end":
        return _end_wrap()
    if mode == "mid":
        _, sep, final_part = s.rpartition(delimiter)
        if not sep:
            # No delimiter: there is no trailing component to preserve, so cut at the end instead
            return _end_wrap()
        # Rebuild the tail with the actual delimiter rather than a hard-coded "."
        wrapped_end = f"{wrap}{delimiter}{final_part}"
        return s[: max(max_len - len(wrapped_end), 0)] + wrapped_end
    raise ValueError("received an unexpected value of argument `mode`")
def unit_scale(val: float) -> Tuple[float, str]:
    """Express a value with the largest fitting unit prefix

    Args:
        val: input value

    Returns:
        tuple of the value divided by the selected factor and the matching prefix ("T", "G", "M" or "k");
        values for which no factor fits are returned untouched with an empty prefix
    """
    # Ordered from largest to smallest so that the first fitting factor wins
    for factor, prefix in ((1e12, "T"), (1e9, "G"), (1e6, "M"), (1e3, "k")):
        if val // factor > 0:
            return val / factor, prefix
    return val, ""
def format_s(f_string: str, min_w: Optional[int] = None, max_w: Optional[int] = None) -> str:
    """Pad and/or truncate a string to fit column widths

    Args:
        f_string: string to format
        min_w: if an integer, the string is left-aligned and padded up to this width
        max_w: if an integer, the (possibly padded) string is cut down to this many characters

    Returns:
        formatted string
    """
    result = f_string
    if isinstance(min_w, int):
        # "<N": left-align within a field of at least N characters
        result = format(result, f"<{min_w}")
    if isinstance(max_w, int):
        # ".N": keep at most the first N characters
        result = format(result, f".{max_w}")
    return result
def format_line_str(
    layer: Dict[str, Any],
    col_w: Optional[List[int]] = None,
    wrap_mode: str = "mid",
    receptive_field: bool = False,
    effective_rf_stats: bool = False,
) -> List[str]:
    """Build the formatted cells describing a single layer

    Args:
        layer: layer information; the "name", "depth", "type", "output_shape", "grad_params",
            "nograd_params" and "num_buffers" entries are always read, "rf" / "s" / "p" only when the
            corresponding flags are set
        col_w: per-column widths, each used both as padding and truncation length; any non-list value
            disables width formatting
        wrap_mode: wrapping mode forwarded to `wrap_string` for the name cell
        receptive_field: whether to append the "rf" cell
        effective_rf_stats: whether to also append the "s" and "p" cells (only looked at when
            `receptive_field` is set)

    Returns:
        list of cell strings, in column order
    """
    widths = col_w if isinstance(col_w, list) else [None] * 7  # type: ignore[list-item]

    def _cell(text: str, col_idx: int) -> str:
        # Each column is padded and truncated to the very same width
        return format_s(text, widths[col_idx], widths[col_idx])

    # The name is wrapped slightly wider than its column (or to 100 characters without a width)
    name_w = widths[0]
    max_len = name_w + 3 if isinstance(name_w, int) else 100
    wrapped_name = wrap_string(format_name(layer["name"], layer["depth"]), max_len, mode=wrap_mode)

    cells = [_cell(wrapped_name, 0)]
    cells.append(_cell(layer["type"], 1))
    cells.append(_cell(str(layer["output_shape"]), 2))
    # Total count of parameters (trainable or not) and buffers, with thousands separators
    cells.append(_cell(f"{layer['grad_params'] + layer['nograd_params'] + layer['num_buffers']:,}", 3))

    if not receptive_field:
        return cells

    cells.append(_cell(f"{layer['rf']:.0f}", 4))
    if effective_rf_stats:
        cells.append(_cell(f"{layer['s']:.0f}", 5))
        cells.append(_cell(f"{layer['p']:.0f}", 6))
    return cells
[docs]
def aggregate_info(info: Dict[str, Any], max_depth: int) -> Dict[str, Any]:
    """Aggregate module information to a maximum depth

    Each non-leaf entry sitting exactly at `max_depth` absorbs the statistics of the deeper entries
    that directly follow it in `info["layers"]`: counters are summed, while "rf", "s" and "p" are
    taken from the last of those entries. Entries deeper than `max_depth` are then removed.
    The `info` dictionary (and its layer dictionaries) are edited in place and `info` is returned.

    Args:
        info: dictionary output of `crawl_module`
        max_depth: depth at which parent node aggregates children information

    Returns:
        edited dictionary information

    Raises:
        ValueError: if no entry of `info["layers"]` has a depth equal to `max_depth`
    """
    layers = info["layers"]
    if not any(layer["depth"] == max_depth for layer in layers):
        raise ValueError("The `max_depth` argument cannot be higher than module depth.")

    for fw_idx, layer in enumerate(layers):
        # Only non-leaf entries located at the requested depth need aggregation
        if layer["is_leaf"] or layer["depth"] != max_depth:
            continue

        # Children have superior depth and were hooked after parent
        children = []
        for _layer in layers[fw_idx + 1 :]:
            if _layer["depth"] <= max_depth:
                break
            children.append(_layer)

        # Nothing to aggregate: keep the entry as is instead of zeroing its counters and
        # reading receptive field values that were never set (or belong to a previous parent)
        if not children:
            continue

        # Aggregate all information (flops, macc, ram)
        for key in ("flops", "macs", "dmas"):
            layer[key] = sum(child[key] for child in children)
        # Take last child effective RF
        last_child = children[-1]
        layer["rf"], layer["s"], layer["p"] = last_child["rf"], last_child["s"], last_child["p"]
        for key in ("grad_params", "nograd_params", "param_size", "num_buffers", "buffer_size"):
            layer[key] = sum(child[key] for child in children)

    # Filter out further depth information
    info["layers"] = [layer for layer in layers if layer["depth"] <= max_depth]
    return info