Add option to flop counter formula registration to get raw values (#110591)

Pull Request resolved: https://github.com/pytorch/pytorch/pull/110591
Approved by: https://github.com/awgu
ghstack dependencies: #110501, #110504
This commit is contained in:
chilli 2023-10-05 01:28:54 -07:00 committed by PyTorch MergeBot
parent 9e72c9cccd
commit ada65508d2
2 changed files with 27 additions and 4 deletions

View File

@ -155,6 +155,17 @@ class TestFlopCounter(TestCase):
self.assertExpectedInline(get_total_flops(mode), """5""")
def count(*args, out):
return out.numel()
count._get_raw = True
mode = FlopCounterMode(custom_mapping={torch.ops.aten.add: count})
with mode:
a = T(4, 5)
a + a
self.assertExpectedInline(get_total_flops(mode), """20""")
def test_noop(self):
mode = FlopCounterMode()
with mode:

View File

@ -7,6 +7,7 @@ from torch.utils._python_dispatch import TorchDispatchMode
from torch.utils.hooks import RemovableHandle
from torch._decomp import register_decomposition
from math import prod
from functools import wraps
@ -21,8 +22,17 @@ def get_shape(i):
flop_registry: Dict[Any, Any] = {}
def register_flop_formula(targets):
def shape_wrapper(f):
    """Wrap flop formula ``f`` so it can be invoked with un-mapped values.

    The returned callable accepts arbitrary positional and keyword arguments
    plus a keyword-only ``out`` (default ``None``). It applies ``get_shape``
    through ``tree_map`` to the positional arguments, the keyword arguments
    and ``out`` together, then forwards the mapped arguments to ``f`` with
    the mapped ``out`` supplied as the ``out_shape`` keyword.

    ``functools.wraps`` copies ``f``'s metadata onto the wrapper.
    """
    @wraps(f)
    def wrapped(*args, out=None, **kwargs):
        # Map all three pieces in one tree_map call over a single tuple.
        mapped = tree_map(get_shape, (args, kwargs, out))
        shaped_args, shaped_kwargs, out_shape = mapped
        return f(*shaped_args, out_shape=out_shape, **shaped_kwargs)

    return wrapped
def register_flop_formula(targets, get_raw=False):
def register_fun(flop_formula):
if not get_raw:
flop_formula = shape_wrapper(flop_formula)
register_decomposition(targets, registry=flop_registry, unsafe=True)(flop_formula)
return flop_formula
@ -273,7 +283,10 @@ class FlopCounterMode(TorchDispatchMode):
self.mods = mods
# Keys will include the modules in `mods` and their submodules
self._module_to_forward_hook_handles: Dict[nn.Module, _ForwardHookHandles] = {}
self.flop_registry = {**flop_registry, **custom_mapping}
self.flop_registry = {
**flop_registry,
**{k: v if getattr(v, "_get_raw", False) else shape_wrapper(v) for k, v in custom_mapping.items()}
}
def _register_forward_hooks(self):
if self.mods is None:
@ -439,8 +452,7 @@ class FlopCounterMode(TorchDispatchMode):
func_packet = func._overloadpacket
if func_packet in self.flop_registry:
flop_count_func = self.flop_registry[func_packet]
args, kwargs, out_shape = tree_map(get_shape, (args, kwargs, out))
flop_count = flop_count_func(*args, **kwargs, out_shape=out_shape) # type: ignore[operator]
flop_count = flop_count_func(*args, **kwargs, out=out) # type: ignore[operator]
for par in self.parents:
self.flop_counts[par][func_packet] += flop_count