FlopCounterMode: Decompose ops for inference mode (#138508)

Fixes #126268

I've basically followed @ezyang's suggestion (I think) to use `func.decompose(...)`. Since `__torch_dispatch__` won't be called a second time for the same op, I've added a second `TorchDispatchMode` (`_DecomposedCounterMode`) that simply dispatches to the parent flop counter. Using `self` as the inner context manager is not possible, since the second call to `__enter__` would re-initialize the counter's tracking state.

Let me know if there's something wrong with this implementation, since I'm quite unsure how the decomposition thing actually works :D

Pull Request resolved: https://github.com/pytorch/pytorch/pull/138508
Approved by: https://github.com/ezyang
This commit is contained in:
Florian (Feuermagier) 2024-11-09 03:13:50 +00:00 committed by PyTorch MergeBot
parent 4488e23763
commit f915409c26
2 changed files with 63 additions and 2 deletions

View File

@ -810,6 +810,30 @@ class TestFlopCounter(TestCase):
self.assertEqual(called, 1)
self.assertExpectedInline(get_total_flops(mode), """9001""")
@skipIfNoTorchVision
def test_inference_mode(self):
    """Check that FLOP counts agree with and without ``torch.inference_mode``.

    The same resnet18 instance is run once normally and once inside
    ``torch.inference_mode()``; both the global total and the
    ``aten.convolution`` count recorded for ``ResNet.layer1`` must match.
    """

    def count_flops(net):
        # Build the input and run the forward pass while the counter is active.
        with FlopCounterMode(net) as counter:
            sample = T(1, 3, 224, 224)
            net(sample).sum()
        return counter

    model = torchvision_models.resnet18()

    counter_default = count_flops(model)
    with torch.inference_mode():
        counter_inference = count_flops(model)

    # Global totals must be identical in both modes.
    self.assertEqual(
        get_total_flops(counter_default), get_total_flops(counter_inference)
    )

    # Per-module attribution must be identical too, not just the grand total.
    conv_default = counter_default.flop_counts["ResNet.layer1"][
        torch.ops.aten.convolution
    ]
    conv_inference = counter_inference.flop_counts["ResNet.layer1"][
        torch.ops.aten.convolution
    ]
    self.assertEqual(conv_default, conv_inference)
if __name__ == "__main__":
run_tests()

View File

@ -1,6 +1,7 @@
# mypy: allow-untyped-defs
# mypy: allow-untyped-decorators
import torch
from torch._C import DispatchKey
from torch.utils._pytree import tree_map, tree_flatten, tree_unflatten
from .module_tracker import ModuleTracker
from typing import List, Any, Dict, Optional, Union, Tuple, Iterator
@ -632,6 +633,7 @@ class FlopCounterMode(TorchDispatchMode):
**{k: v if getattr(v, "_get_raw", False) else shape_wrapper(v) for k, v in custom_mapping.items()}
}
self.mod_tracker = ModuleTracker()
self.decomposed_counter = _DecomposedCounterMode(self)
def get_total_flops(self) -> int:
    """Return the sum of every FLOP count stored under the ``'Global'`` key."""
    global_counts = self.flop_counts['Global']
    return sum(global_counts.values())
@ -722,8 +724,34 @@ class FlopCounterMode(TorchDispatchMode):
def __torch_dispatch__(self, func, types, args=(), kwargs=None):
    """Execute ``func`` and record its FLOPs, decomposing composite ops first.

    Args:
        func: the operator overload being dispatched.
        types: tensor subclass types involved in the call (unused here).
        args: positional arguments for ``func``.
        kwargs: keyword arguments for ``func``; ``None`` is treated as ``{}``.

    Returns:
        ``NotImplemented`` for the metadata queries listed below; otherwise
        the result of running ``func`` (either through its
        ``CompositeImplicitAutograd`` kernel or directly).
    """
    kwargs = kwargs if kwargs else {}

    # Skip ops from non-standard dispatch_sizes_strides_policy such as NJT
    if func in {torch.ops.aten.is_contiguous.default,
                torch.ops.aten.is_contiguous.memory_format,
                torch.ops.aten.is_strides_like_format.default,
                torch.ops.aten.is_non_overlapping_and_dense.default,
                torch.ops.aten.size.default,
                torch.ops.aten.sym_size.default,
                torch.ops.aten.stride.default,
                torch.ops.aten.sym_stride.default,
                torch.ops.aten.storage_offset.default,
                torch.ops.aten.sym_storage_offset.default,
                torch.ops.aten.numel.default,
                torch.ops.aten.sym_numel.default,
                torch.ops.aten.dim.default,
                torch.ops.prim.layout.default}:
        return NotImplemented

    dk = DispatchKey.CompositeImplicitAutograd
    if torch._C._dispatch_has_kernel_for_dispatch_key(func.name(), dk):
        # func can be decomposed; redispatch with the helper mode active so
        # that the ops produced by the decomposition are the ones counted.
        with self.decomposed_counter:
            return func._op_dk(dk, *args, **kwargs)

    # no further decomposition; execute & count flops
    out = func(*args, **kwargs)
    return self._count_flops(func._overloadpacket, out, args, kwargs)
def _count_flops(self, func_packet, out, args, kwargs):
if func_packet in self.flop_registry:
@ -733,3 +761,12 @@ class FlopCounterMode(TorchDispatchMode):
self.flop_counts[par][func_packet] += flop_count
return out
class _DecomposedCounterMode(TorchDispatchMode):
    """Helper dispatch mode that reports every op it sees to a parent counter.

    It holds no counting state of its own: each intercepted op is executed
    and then handed to ``counter._count_flops``.
    """

    def __init__(self, counter):
        """Store ``counter``, the object whose ``_count_flops`` receives each op."""
        # Initialise the TorchDispatchMode base state explicitly rather than
        # relying on it being set up lazily elsewhere.
        super().__init__()
        self.counter = counter

    def __torch_dispatch__(self, func, types, args=(), kwargs=None):
        """Run ``func`` and forward its overload packet and result to the counter."""
        kwargs = kwargs if kwargs else {}
        out = func(*args, **kwargs)
        return self.counter._count_flops(func._overloadpacket, out, args, kwargs)