mirror of
https://github.com/zebrajr/pytorch.git
synced 2025-12-07 12:21:27 +01:00
New output <img width="942" alt="image" src="https://user-images.githubusercontent.com/6355099/224794006-a993a2a8-d6ff-49da-8891-7b2373030a3d.png"> Pull Request resolved: https://github.com/pytorch/pytorch/pull/96581 Approved by: https://github.com/ngimel, https://github.com/shunting314, https://github.com/voznesenskym
400 lines
13 KiB
Python
400 lines
13 KiB
Python
import torch
|
|
from torch.utils._pytree import tree_map
|
|
from typing import List, Any, Dict, Optional, Union
|
|
from collections import defaultdict
|
|
from torch.utils._python_dispatch import TorchDispatchMode
|
|
from math import prod
|
|
|
|
# Only the context manager is part of this module's public API.
__all__ = ["FlopCounterMode"]

# Shorthand for the aten operator namespace; its op packets are the keys of
# the ``flop_mapping`` registry.
aten = torch.ops.aten
|
|
|
|
def get_shape(i):
    """Return ``i.shape`` when ``i`` is a tensor, otherwise ``i`` unchanged.

    Used with ``tree_map`` to turn every tensor inside an op's arguments and
    outputs into its shape while leaving all other values alone.
    """
    return i.shape if isinstance(i, torch.Tensor) else i
|
|
|
|
def mm_flop(a_shape, b_shape, out=None) -> int:
    """
    Count flops for a 2-D matrix multiply of ``a_shape`` by ``b_shape``.

    Args:
        a_shape: ``(m, k)`` shape of the left operand.
        b_shape: ``(k, n)`` shape of the right operand.
        out: output shape; accepted for the common formula signature, unused.

    Returns:
        ``m * n * 2 * k``: each of the ``m * n`` outputs needs ``k``
        multiply-adds, each counted as two FLOPs.
    """
    rows, inner = a_shape
    inner_rhs, cols = b_shape
    # Sanity check that the contraction dimensions line up.
    assert inner == inner_rhs
    # NB(chilli): Should be 2 * k - 1 technically for FLOPs.
    return rows * cols * 2 * inner
|
|
|
|
def addmm_flop(self_shape, a_shape, b_shape, out=None, **kwargs) -> int:
    """
    Count flops for ``addmm``.

    Only the ``a_shape @ b_shape`` matrix product is counted (via
    ``mm_flop``); ``self_shape``, ``out`` and any extra keyword arguments are
    ignored.
    """
    product_flops = mm_flop(a_shape, b_shape)
    return product_flops
|
|
|
|
def bmm_flop(a_shape, b_shape, out=None) -> int:
    """
    Count flops for a batched matrix multiply of ``a_shape`` by ``b_shape``.

    Args:
        a_shape: ``(b, m, k)`` shape of the left operand.
        b_shape: ``(b, k, n)`` shape of the right operand.
        out: output shape; accepted for the common formula signature, unused.

    Returns:
        ``b * m * n * 2 * k``: ``b`` independent ``(m, k) @ (k, n)`` products,
        counting each multiply-add as two FLOPs.
    """
    batch, rows, inner = a_shape
    batch_rhs, inner_rhs, cols = b_shape
    # Batch sizes and contraction dimensions must agree.
    assert batch == batch_rhs
    assert inner == inner_rhs
    # NB(chilli): Should be 2 * k - 1 technically for FLOPs.
    return batch * rows * cols * 2 * inner
|
|
|
|
def baddbmm_flop(self_shape, a_shape, b_shape, out=None, **kwargs) -> int:
    """
    Count flops for the baddbmm operation.

    Only the batched product ``a_shape @ b_shape`` is counted (via
    ``bmm_flop``); ``self_shape``, ``out`` and any keyword arguments are
    ignored.

    ``**kwargs`` mirrors ``addmm_flop``: ``baddbmm`` takes keyword-only
    ``beta`` / ``alpha`` scalars, which are forwarded to the flop formula as
    keyword arguments when set. Without ``**kwargs`` such calls would raise
    ``TypeError``.
    """
    # Inputs contain the shapes of three tensors; only the two batch operands
    # of the matrix product contribute FLOPs here.
    return bmm_flop(a_shape, b_shape)
|
|
|
|
|
|
def conv_flop_count(
    x_shape: List[int],
    w_shape: List[int],
    out_shape: List[int],
    transposed: bool = False,
) -> int:
    """
    Count flops for a convolution from its input, weight and output shapes.

    Each multiply-add is counted as two FLOPs; the bias computation is
    ignored. The spatial positions iterated over are the output's for a
    regular convolution and the input's for a transposed one, giving

        flops = 2 * batch_size * prod(spatial) * prod(w_shape)

    Args:
        x_shape (list(int)): The input shape before convolution.
        w_shape (list(int)): The filter shape.
        out_shape (list(int)): The output shape after convolution.
        transposed (bool): is the convolution transposed

    Returns:
        int: the number of flops
    """
    batch_size = x_shape[0]
    if transposed:
        spatial_dims = x_shape[2:]
    else:
        spatial_dims = out_shape[2:]
    c_out, c_in, *kernel_dims = w_shape

    spatial_positions = prod(spatial_dims)
    kernel_volume = prod(kernel_dims)
    # NB(chilli): I don't think this properly accounts for padding :think:
    # NB(chilli): Should be 2 * c_in - 1 technically for FLOPs.
    return batch_size * spatial_positions * c_out * kernel_volume * 2 * c_in
|
|
|
|
def conv_flop(x_shape, w_shape, _bias, _stride, _padding, _dilation, transposed, *args, out=None, **kwargs) -> int:
    """
    Count flops for a forward convolution (registered for
    ``aten.convolution`` and ``aten._convolution``).

    Only the input, weight and output shapes and the ``transposed`` flag are
    used; bias, stride, padding, dilation and any remaining arguments are
    ignored.
    """
    return conv_flop_count(
        x_shape=x_shape,
        w_shape=w_shape,
        out_shape=out,
        transposed=transposed,
    )
|
|
|
|
def transpose_shape(shape):
    """Return ``shape`` as a list with its first two dimensions swapped.

    For example ``(N, C, H, W)`` becomes ``[C, N, H, W]``. Shapes with fewer
    than two dimensions raise ``IndexError``.
    """
    return [shape[1], shape[0], *shape[2:]]
|
|
|
|
def conv_backward_flop(
    grad_out_shape,
    x_shape,
    w_shape,
    _bias,
    _stride,
    _padding,
    _dilation,
    transposed,
    _output_padding,
    _groups,
    output_mask,
    out,
) -> int:
    """
    Count flops for ``aten.convolution_backward``.

    Only the input and weight gradients are counted, selected by
    ``output_mask[0]`` and ``output_mask[1]``. The bias gradient is ignored,
    matching ``conv_flop_count``.

    Both gradients perform as many multiply-adds as the forward convolution.
    Every forward product ``x * w`` feeding an output element yields one
    product ``grad_out * w`` for the input gradient and one product
    ``grad_out * x`` for the weight gradient.
    """
    flop_count = 0

    if output_mask[0]:
        # The input gradient is a convolution of grad_out with the same weight,
        # in the opposite direction (regular <-> transposed).
        grad_input_shape = get_shape(out[0])
        flop_count += conv_flop_count(grad_out_shape, w_shape, grad_input_shape, not transposed)

    if output_mask[1]:
        # The weight gradient does exactly the forward conv's work, so count
        # it with the forward shapes. This takes the per-group channel count
        # from ``w_shape``, so grouped convs are not overcounted by ``groups``.
        # For transposed convs it also iterates over the kernel dims, where
        # the previous formula multiplied by the output spatial size.
        flop_count += conv_flop_count(x_shape, w_shape, grad_out_shape, transposed)

    return flop_count
|
|
|
|
def sdpa_flop_count(query_shape, key_shape):
    """
    Count flops for self-attention.

    NB: We can assume that value_shape == key_shape

    Shapes are ``[b, h, s, d]`` for the query and ``[b, h, s2, d]`` for the
    key. Only the two batched matmuls (scores and weighted sum) are counted.
    """
    batch, heads, q_len, head_dim = query_shape
    k_batch, k_heads, kv_len, k_head_dim = key_shape
    assert batch == k_batch and heads == k_heads and head_dim == k_head_dim

    flat_batch = batch * heads
    # q: [b, h, s, d] @ k: [b, h, d, s2] -> scores: [b, h, s, s2]
    score_flops = bmm_flop((flat_batch, q_len, head_dim), (flat_batch, head_dim, kv_len))
    # scores: [b, h, s, s2] @ v: [b, h, s2, d] -> out: [b, h, s, d]
    output_flops = bmm_flop((flat_batch, q_len, kv_len), (flat_batch, kv_len, head_dim))
    return score_flops + output_flops
|
|
|
|
|
|
|
|
def sdpa_flop(query_shape, key_shape, value_shape, *args, out=None, **kwargs) -> int:
    """
    Count flops for the forward pass of scaled-dot-product attention.

    Only ``query_shape`` and ``key_shape`` are inspected. The value shape,
    the remaining positional/keyword arguments and ``out`` are ignored.
    """
    # NB: We aren't accounting for causal attention here, so a causal call is
    # counted like a full (unmasked) one.
    attention_flops = sdpa_flop_count(query_shape, key_shape)
    return attention_flops
|
|
|
|
|
|
def sdpa_backward_flop_count(grad_out_shape, query_shape, key_shape):
    """
    Count flops for the backward pass of self-attention.

    Shapes are ``[b, h, s, d]`` for the query and the output gradient, and
    ``[b, h, s2, d]`` for the key. The value is assumed to have the key's
    shape, since only ``key_shape`` is inspected. Five batched matmuls are
    counted: the recomputed scores plus the gradients of both attention
    matmuls.
    """
    b, h, s, d = query_shape
    _b2, _h2, s2, _d2 = key_shape
    _b3, _h3, _s3, _d3 = grad_out_shape
    assert b == _b2 == _b3 and h == _h2 == _h3 and d == _d2 == _d3
    assert s == _s3

    total_flops = 0
    # Step 1: We recompute the scores matrix.
    # q: [b, h, s, d] @ k: [b, h, d, s2] -> scores: [b, h, s, s2]
    total_flops += bmm_flop((b * h, s, d), (b * h, d, s2))

    # Step 2: We propagate the gradients through the scores @ v operation.
    # gradOut: [b, h, s, d] @ v: [b, h, d, s2] -> gradScores: [b, h, s, s2]
    total_flops += bmm_flop((b * h, s, d), (b * h, d, s2))
    # scores: [b, h, s2, s] @ gradOut: [b, h, s, d] -> gradV: [b, h, s2, d]
    total_flops += bmm_flop((b * h, s2, s), (b * h, s, d))

    # Step 3: We propagate the gradients through the q @ k operation.
    # gradScores: [b, h, s, s2] @ k: [b, h, s2, d] -> gradQ: [b, h, s, d]
    total_flops += bmm_flop((b * h, s, s2), (b * h, s2, d))
    # q: [b, h, d, s] @ gradScores: [b, h, s, s2] -> gradK: [b, h, d, s2]
    total_flops += bmm_flop((b * h, d, s), (b * h, s, s2))
    return total_flops
|
|
|
|
|
|
def sdpa_backward_flop(grad_out_shape, query_shape, key_shape, value_shape, *args, out=None, **kwargs) -> int:
    """
    Count flops for self-attention backward.

    Delegates to ``sdpa_backward_flop_count``. The value shape, the remaining
    positional/keyword arguments and ``out`` are ignored.
    """
    backward_flops = sdpa_backward_flop_count(grad_out_shape, query_shape, key_shape)
    return backward_flops
|
|
|
|
# Default registry of FLOP formulas, keyed by aten op overload packet.
# ``FlopCounterMode`` looks up ``func._overloadpacket`` here (merged with any
# ``custom_mapping``). It then calls the formula with the op's arguments,
# with every tensor replaced by its shape, plus ``out=<output shape(s)>``.
# Ops not listed here contribute no FLOPs.
flop_mapping = {
    # Matrix multiplications.
    aten.mm: mm_flop,
    aten.addmm: addmm_flop,
    aten.bmm: bmm_flop,
    aten.baddbmm: baddbmm_flop,
    # Convolutions (forward and backward).
    aten.convolution: conv_flop,
    aten._convolution: conv_flop,
    aten.convolution_backward: conv_backward_flop,
    # Scaled-dot-product attention kernels (forward and backward).
    aten._scaled_dot_product_efficient_attention: sdpa_flop,
    aten._scaled_dot_product_flash_attention: sdpa_flop,
    aten._scaled_dot_product_efficient_attention_backward: sdpa_backward_flop,
    aten._scaled_dot_product_flash_attention_backward: sdpa_backward_flop,
}
|
|
|
|
def normalize_tuple(x):
    """Return ``x`` unchanged if it is a tuple, otherwise wrap it as ``(x,)``.

    Module hooks use this so single inputs/outputs can always be splatted.
    """
    return x if isinstance(x, tuple) else (x,)
|
|
|
|
|
|
# Display suffixes for successive orders of magnitude, in steps of 1000:
# "" (units), "K" (thousand), "M" (million), "B" (billion), "T" (trillion).
suffixes = ["", "K", "M", "B", "T"]
|
|
# Thanks BingChat!
|
|
def get_suffix_str(number):
    """
    Pick the entry of ``suffixes`` used to display ``number``.

    The choice depends only on ``len(str(number))``, and it deliberately
    leaves extra digits in front of the decimal point. For example,
    1_010_000_000 is shown as ``1010M`` rather than ``1.01B``. Numbers with at
    most five characters get the empty suffix, and the last suffix is the cap.
    """
    digit_count = len(str(number))
    slot = (digit_count - 3) // 3
    slot = min(slot, len(suffixes) - 1)
    slot = max(slot, 0)
    return suffixes[slot]
|
|
|
|
def convert_num_with_suffix(number, suffix):
    """
    Format ``number`` scaled to the magnitude named by ``suffix``.

    ``suffix`` must be an entry of ``suffixes``, otherwise ``list.index``
    raises ``ValueError``. The number is divided by ``1000 ** position`` and
    rendered with three decimal places, followed by the suffix. For example,
    ``1234567`` with ``"K"`` gives ``"1234.567K"``.
    """
    position = suffixes.index(suffix)
    scaled = number / (1000 ** position)
    return "{:.3f}".format(scaled) + suffixes[position]
|
|
|
|
class FlopCounterMode(TorchDispatchMode):
    """
    ``FlopCounterMode`` is a context manager that counts the number of
    flops within its context. It does this using a ``TorchDispatchMode``.

    It also supports hierarchical output by passing a module (or list of modules) to FlopCounterMode on construction.

    Example usage

    .. code-block:: python

        mod = ...
        inp = ...
        flop_counter = FlopCounterMode(mod)
        with flop_counter:
            mod(inp).sum().backward()

    Args:
        mods: module (or list of modules) whose submodules get forward
            pre/post hooks, so FLOPs are also attributed per module. Module
            names have the form ``<ClassName>.<submodule path>``.
        depth: default maximum module nesting depth shown by ``get_table``;
            ``None`` shows every level.
        display: if True, print ``get_table()`` when the context exits.
        custom_mapping: extra ``{op overload packet: flop formula}`` entries.
            They extend, and take precedence over, the module-level
            ``flop_mapping``.
    """

    def __init__(
            self,
            mods: Optional[Union[torch.nn.Module, List[torch.nn.Module]]] = None,
            depth: int = 2,
            display: bool = True,
            custom_mapping: Optional[Dict[Any, Any]] = None):
        super().__init__()
        # flop_counts[scope][op overload packet] -> accumulated FLOPs, where
        # scope is "Global" or a module name.
        self.flop_counts: Dict[str, Dict[Any, int]] = defaultdict(lambda: defaultdict(int))
        self.depth = depth
        # Stack of currently active scopes; every counted op is credited to
        # each entry, so a module's total includes its submodules.
        self.parents = ["Global"]
        self.display = display
        if custom_mapping is None:
            custom_mapping = {}
        if isinstance(mods, torch.nn.Module):
            mods = [mods]
        self.mods = mods
        if mods is not None:
            # NOTE(review): the hook handles are discarded, so these hooks stay
            # registered on the modules after the ``with`` block ends and for
            # every FlopCounterMode constructed on the same module.
            for mod in mods:
                prefix = type(mod).__name__
                for name, module in dict(mod.named_modules()).items():
                    name = prefix if name == "" else ".".join([prefix, name])
                    module.register_forward_pre_hook(self._enter_module(name))
                    module.register_forward_hook(self._exit_module(name))
        self.flop_mapping = {**flop_mapping, **custom_mapping}

    def _enter_module(self, name):
        # Forward pre-hook: routes the module's inputs through an autograd
        # function that opens scope ``name`` (see ``_create_pre_module``).
        def f(module, inputs):
            inputs = normalize_tuple(inputs)
            out = self._create_pre_module(name)(*inputs)
            return out

        return f

    def _exit_module(self, name):
        # Forward hook: routes the module's outputs through an autograd
        # function that closes scope ``name`` (see ``_create_post_module``).
        def f(module, inputs, outputs):
            outputs = normalize_tuple(outputs)
            return self._create_post_module(name)(*outputs)
        return f

    def _create_post_module(self, name):
        # Applied to a module's outputs. Forward pops scope ``name``. Backward
        # fires when gradients reach those outputs and pushes ``name`` again,
        # so the module's backward ops are attributed to it. (The class name
        # describes the backward behavior.)
        class PushState(torch.autograd.Function):
            @staticmethod
            def forward(ctx, *args):
                assert self.parents[-1] == name
                self.parents.pop()
                # Clone so the autograd function returns fresh tensors.
                args = tree_map(lambda x: x.clone() if isinstance(x, torch.Tensor) else x, args)
                if len(args) == 1:
                    return args[0]
                return args

            @staticmethod
            def backward(ctx, *grad_outs):
                self.parents.append(name)
                return grad_outs

        return PushState.apply

    def _create_pre_module(self, name):
        # Applied to a module's inputs. Forward pushes scope ``name``, and
        # backward (reached once gradients flow back to the inputs) pops it.
        class PopState(torch.autograd.Function):
            @staticmethod
            def forward(ctx, *args):
                self.parents.append(name)
                args = tree_map(lambda x: x.clone() if isinstance(x, torch.Tensor) else x, args)
                if len(args) == 1:
                    return args[0]
                return args

            @staticmethod
            def backward(ctx, *grad_outs):
                assert self.parents[-1] == name
                self.parents.pop()
                return grad_outs

        return PopState.apply

    def get_flop_counts(self) -> Dict[str, Dict[Any, int]]:
        """Returns the flop counts as a dictionary of dictionaries. The outer
        dictionary is keyed by module name, and the inner dictionary is keyed by
        operation name.

        Returns:
            Dict[str, Dict[Any, int]]: The flop counts as a dictionary.
        """
        return dict(self.flop_counts)

    def get_table(self, depth=None):
        """Render the recorded counts as a ``tabulate`` table.

        Modules nested deeper than ``depth`` (default ``self.depth``; ``None``
        means unlimited) are omitted. Percentages are relative to the
        "Global" total, and a run with no recorded FLOPs yields a single
        ``Global 0 0%`` row.
        """
        if depth is None:
            depth = self.depth
        if depth is None:
            depth = 999999

        import tabulate
        tabulate.PRESERVE_WHITESPACE = True
        header = ["Module", "FLOP", "% Total"]
        values = []
        # ``.get`` so that building the table never inserts an empty "Global"
        # entry into the defaultdict when nothing was counted.
        global_flops = sum(self.flop_counts.get('Global', {}).values())
        global_suffix = get_suffix_str(global_flops)
        is_global_subsumed = False

        def percent_of_global(num):
            # Nothing (or only zero-FLOP ops) was recorded: avoid 0 / 0.
            if global_flops == 0:
                return "0%"
            return "{:.2f}%".format(num / global_flops * 100)

        def process_mod(mod_name, indent):
            # One row for the scope's total, then one row per op it ran.
            nonlocal is_global_subsumed

            mod_counts = self.flop_counts[mod_name]
            total_flops = sum(mod_counts.values())

            is_global_subsumed |= total_flops >= global_flops

            padding = " " * indent
            rows = [[
                padding + mod_name,
                convert_num_with_suffix(total_flops, global_suffix),
                percent_of_global(total_flops),
            ]]
            for k, v in mod_counts.items():
                rows.append([
                    padding + " - " + str(k),
                    convert_num_with_suffix(v, global_suffix),
                    percent_of_global(v),
                ])
            return rows

        for mod in self.flop_counts:
            if mod == 'Global':
                continue
            mod_depth = mod.count(".") + 1
            if mod_depth > depth:
                continue

            values.extend(process_mod(mod, mod_depth - 1))

        # We do a bit of messing around here to only output the "Global" value
        # if there are any FLOPs in there that aren't already fully contained by
        # a module.
        if 'Global' in self.flop_counts and not is_global_subsumed:
            for row in values:
                row[0] = " " + row[0]

            values = process_mod('Global', 0) + values

        if len(values) == 0:
            values = [["Global", "0", "0%"]]

        return tabulate.tabulate(values, headers=header, colalign=("left", "right", "right"))

    def __enter__(self):
        self.flop_counts.clear()
        super().__enter__()
        return self

    def __exit__(self, *args):
        # Always pop the dispatch mode, even if building or printing the
        # table raises; otherwise this mode would stay active.
        try:
            if self.display:
                print(self.get_table(self.depth))
        finally:
            super().__exit__(*args)

    def __torch_dispatch__(self, func, types, args=(), kwargs=None):
        kwargs = kwargs if kwargs else {}
        out = func(*args, **kwargs)
        func_packet = func._overloadpacket
        if func_packet in self.flop_mapping:
            flop_count_func = self.flop_mapping[func_packet]
            # Formulas receive shapes in place of tensors.
            args, kwargs, out_shape = tree_map(get_shape, (args, kwargs, out))
            # ``out=`` overloads (e.g. ``aten.mm.out``) pass their destination
            # tensor as an ``out`` kwarg, which would clash with the ``out=``
            # keyword below ("multiple values for keyword argument 'out'").
            # Such overloads return that destination, so its shape is already
            # ``out_shape``.
            kwargs = {k: v for k, v in kwargs.items() if k != "out"}
            flop_count = flop_count_func(*args, **kwargs, out=out_shape)  # type: ignore[operator]
            for par in self.parents:
                self.flop_counts[par][func_packet] += flop_count

        return out
|