mirror of
https://github.com/zebrajr/pytorch.git
synced 2025-12-06 12:20:52 +01:00
Add option to flop counter formula registration to get raw values (#110591)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/110591 Approved by: https://github.com/awgu ghstack dependencies: #110501, #110504
This commit is contained in:
parent
9e72c9cccd
commit
ada65508d2
|
|
@ -155,6 +155,17 @@ class TestFlopCounter(TestCase):
|
|||
|
||||
self.assertExpectedInline(get_total_flops(mode), """5""")
|
||||
|
||||
def count(*args, out):
|
||||
return out.numel()
|
||||
count._get_raw = True
|
||||
|
||||
mode = FlopCounterMode(custom_mapping={torch.ops.aten.add: count})
|
||||
with mode:
|
||||
a = T(4, 5)
|
||||
a + a
|
||||
|
||||
self.assertExpectedInline(get_total_flops(mode), """20""")
|
||||
|
||||
def test_noop(self):
|
||||
mode = FlopCounterMode()
|
||||
with mode:
|
||||
|
|
|
|||
|
|
@ -7,6 +7,7 @@ from torch.utils._python_dispatch import TorchDispatchMode
|
|||
from torch.utils.hooks import RemovableHandle
|
||||
from torch._decomp import register_decomposition
|
||||
from math import prod
|
||||
from functools import wraps
|
||||
|
||||
|
||||
|
||||
|
|
@ -21,8 +22,17 @@ def get_shape(i):
|
|||
|
||||
flop_registry: Dict[Any, Any] = {}
|
||||
|
||||
def register_flop_formula(targets):
|
||||
def shape_wrapper(f):
|
||||
@wraps(f)
|
||||
def nf(*args, out=None, **kwargs):
|
||||
args, kwargs, out_shape = tree_map(get_shape, (args, kwargs, out))
|
||||
return f(*args, out_shape=out_shape, **kwargs)
|
||||
return nf
|
||||
|
||||
def register_flop_formula(targets, get_raw=False):
|
||||
def register_fun(flop_formula):
|
||||
if not get_raw:
|
||||
flop_formula = shape_wrapper(flop_formula)
|
||||
register_decomposition(targets, registry=flop_registry, unsafe=True)(flop_formula)
|
||||
return flop_formula
|
||||
|
||||
|
|
@ -273,7 +283,10 @@ class FlopCounterMode(TorchDispatchMode):
|
|||
self.mods = mods
|
||||
# Keys will include the modules in `mods` and their submodules
|
||||
self._module_to_forward_hook_handles: Dict[nn.Module, _ForwardHookHandles] = {}
|
||||
self.flop_registry = {**flop_registry, **custom_mapping}
|
||||
self.flop_registry = {
|
||||
**flop_registry,
|
||||
**{k: v if getattr(v, "_get_raw", False) else shape_wrapper(v) for k, v in custom_mapping.items()}
|
||||
}
|
||||
|
||||
def _register_forward_hooks(self):
|
||||
if self.mods is None:
|
||||
|
|
@ -439,8 +452,7 @@ class FlopCounterMode(TorchDispatchMode):
|
|||
func_packet = func._overloadpacket
|
||||
if func_packet in self.flop_registry:
|
||||
flop_count_func = self.flop_registry[func_packet]
|
||||
args, kwargs, out_shape = tree_map(get_shape, (args, kwargs, out))
|
||||
flop_count = flop_count_func(*args, **kwargs, out_shape=out_shape) # type: ignore[operator]
|
||||
flop_count = flop_count_func(*args, **kwargs, out=out) # type: ignore[operator]
|
||||
for par in self.parents:
|
||||
self.flop_counts[par][func_packet] += flop_count
|
||||
|
||||
|
|
|
|||
Loading…
Reference in New Issue
Block a user