FlopCounterMode: Decompose ops for inference mode (#138508)
Fixes #126268. I've basically followed @ezyang's suggestion (I think) to use `func.decompose(...)`. Since `__torch_dispatch__` won't be called a second time for the same op, I've added a second `TorchDispatchMode` (`_DecomposedCounterMode`) that simply dispatches to the parent flop counter. Using `self` as the inner context manager is not possible, since the second call to `__enter__` would re-initialize the counter's tracking state. Let me know if there's something wrong with this implementation, since I'm quite unsure how the decomposition machinery actually works :D

Pull Request resolved: https://github.com/pytorch/pytorch/pull/138508
Approved by: https://github.com/ezyang
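For context, a minimal usage sketch of what this change enables (illustrative only, not part of the diff; the model and input shape mirror the new test below, and `display=False` just silences the printed table). Before this change, running the counter under `torch.inference_mode()` could under-count, since composite ops were not decomposed into the ops registered in the flop registry:

import torch
import torchvision.models as models  # assumes torchvision is installed
from torch.utils.flop_counter import FlopCounterMode

model = models.resnet18()
x = torch.randn(1, 3, 224, 224)

# Standard (autograd-enabled) forward pass
with FlopCounterMode(display=False) as mode_standard:
    model(x).sum()

# Inference-mode forward pass
with torch.inference_mode():
    with FlopCounterMode(display=False) as mode_inference:
        model(x).sum()

# After this change the two totals should agree.
assert mode_standard.get_total_flops() == mode_inference.get_total_flops()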
This commit is contained in:
parent 4488e23763
commit f915409c26
test/test_flop_counter.py
@@ -810,6 +810,30 @@ class TestFlopCounter(TestCase):
         self.assertEqual(called, 1)
         self.assertExpectedInline(get_total_flops(mode), """9001""")
 
+    @skipIfNoTorchVision
+    def test_inference_mode(self):
+        def get_flops(model):
+            with FlopCounterMode(model) as mode:
+                a = T(1, 3, 224, 224)
+                model(a).sum()
+            return mode
+
+        resnet18 = torchvision_models.resnet18()
+
+        mode_standard = get_flops(resnet18)
+
+        with torch.inference_mode():
+            mode_inference = get_flops(resnet18)
+
+        self.assertEqual(get_total_flops(mode_standard), get_total_flops(mode_inference))
+
+        layer1_conv_flops_standard = mode_standard.flop_counts["ResNet.layer1"][
+            torch.ops.aten.convolution
+        ]
+        layer1_conv_flops_inference = mode_inference.flop_counts["ResNet.layer1"][
+            torch.ops.aten.convolution
+        ]
+        self.assertEqual(layer1_conv_flops_standard, layer1_conv_flops_inference)
 
 if __name__ == "__main__":
     run_tests()
torch/utils/flop_counter.py
@@ -1,6 +1,7 @@
 # mypy: allow-untyped-defs
 # mypy: allow-untyped-decorators
 import torch
+from torch._C import DispatchKey
 from torch.utils._pytree import tree_map, tree_flatten, tree_unflatten
 from .module_tracker import ModuleTracker
 from typing import List, Any, Dict, Optional, Union, Tuple, Iterator
@@ -632,6 +633,7 @@ class FlopCounterMode(TorchDispatchMode):
             **{k: v if getattr(v, "_get_raw", False) else shape_wrapper(v) for k, v in custom_mapping.items()}
         }
         self.mod_tracker = ModuleTracker()
+        self.decomposed_counter = _DecomposedCounterMode(self)
 
     def get_total_flops(self) -> int:
         return sum(self.flop_counts['Global'].values())
@@ -722,8 +724,34 @@ class FlopCounterMode(TorchDispatchMode):
 
     def __torch_dispatch__(self, func, types, args=(), kwargs=None):
         kwargs = kwargs if kwargs else {}
-        out = func(*args, **kwargs)
-        return self._count_flops(func._overloadpacket, out, args, kwargs)
+
+        # Skip ops from non-standard dispatch_sizes_strides_policy such as NJT
+        if func in {torch.ops.aten.is_contiguous.default,
+                    torch.ops.aten.is_contiguous.memory_format,
+                    torch.ops.aten.is_strides_like_format.default,
+                    torch.ops.aten.is_non_overlapping_and_dense.default,
+                    torch.ops.aten.size.default,
+                    torch.ops.aten.sym_size.default,
+                    torch.ops.aten.stride.default,
+                    torch.ops.aten.sym_stride.default,
+                    torch.ops.aten.storage_offset.default,
+                    torch.ops.aten.sym_storage_offset.default,
+                    torch.ops.aten.numel.default,
+                    torch.ops.aten.sym_numel.default,
+                    torch.ops.aten.dim.default,
+                    torch.ops.prim.layout.default}:
+
+            return NotImplemented
+
+        dk = DispatchKey.CompositeImplicitAutograd
+        if torch._C._dispatch_has_kernel_for_dispatch_key(func.name(), dk):
+            # func can be decomposed; redispatch
+            with self.decomposed_counter:
+                return func._op_dk(dk, *args, **kwargs)
+        else:
+            # no further decomposition; execute & count flops
+            out = func(*args, **kwargs)
+            return self._count_flops(func._overloadpacket, out, args, kwargs)
 
     def _count_flops(self, func_packet, out, args, kwargs):
         if func_packet in self.flop_registry:
@@ -733,3 +761,12 @@ class FlopCounterMode(TorchDispatchMode):
             self.flop_counts[par][func_packet] += flop_count
 
         return out
+
+class _DecomposedCounterMode(TorchDispatchMode):
+    def __init__(self, counter):
+        self.counter = counter
+
+    def __torch_dispatch__(self, func, types, args=(), kwargs=None):
+        kwargs = kwargs if kwargs else {}
+        out = func(*args, **kwargs)
+        return self.counter._count_flops(func._overloadpacket, out, args, kwargs)
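The core of the change is the `CompositeImplicitAutograd` check in `__torch_dispatch__` above: ops that still have an implicit-autograd decomposition are redispatched under the child `_DecomposedCounterMode`, so their decomposed ops reach the counter again. A standalone probe of that check (illustrative sketch only; `can_decompose` is a throwaway helper name, and the exact results can vary by build):

import torch
from torch._C import DispatchKey

def can_decompose(op) -> bool:
    # Same check as in __torch_dispatch__ above: does this overload have a
    # CompositeImplicitAutograd kernel, i.e. a decomposition into other ops?
    return torch._C._dispatch_has_kernel_for_dispatch_key(
        op.name(), DispatchKey.CompositeImplicitAutograd
    )

# aten::linear is typically composite (it decomposes into matmul/addmm),
# while aten::mm is a leaf op whose FLOPs are counted directly.
print(can_decompose(torch.ops.aten.linear.default))  # expected: True
print(can_decompose(torch.ops.aten.mm.default))      # expected: False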