FlopCounterMode: Decompose ops for inference mode (#138508)
Fixes #126268

I've basically followed @ezyang's suggestion (I think) to use `func.decompose(...)`. Since `__torch_dispatch__` won't be called a second time for the same op, I've added a second `TorchDispatchMode` (`_DecomposedCounterMode`) that simply dispatches to the parent flop counter. Using `self` as the inner context manager is not possible, since the second call to `__enter__` would re-initialize the counter's tracking state.

Let me know if there's something wrong with this implementation, since I'm quite unsure how the decomposition thing actually works :D

Pull Request resolved: https://github.com/pytorch/pytorch/pull/138508
Approved by: https://github.com/ezyang
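For context on the pattern the message describes, here is a minimal, self-contained sketch (hypothetical class names, not the PR's code) of why the nested mode is needed: while a mode's `__torch_dispatch__` handler runs, that mode is temporarily off the mode stack, so sub-ops issued by a decomposition would go unseen unless a second, nested mode catches them and forwards to the parent.

import torch
from torch.utils._python_dispatch import TorchDispatchMode

class _ChildMode(TorchDispatchMode):
    # Catches the sub-ops produced while the parent decomposes an op
    # and reports them back to the parent.
    def __init__(self, parent):
        super().__init__()
        self.parent = parent

    def __torch_dispatch__(self, func, types, args=(), kwargs=None):
        kwargs = kwargs or {}
        out = func(*args, **kwargs)
        self.parent.seen.append(func._overloadpacket)
        return out

class DecomposingRecorder(TorchDispatchMode):
    def __init__(self):
        super().__init__()
        self.seen = []
        # A dedicated child instance: re-entering `self` would call
        # __enter__ again and reset per-run tracking state.
        self.child = _ChildMode(self)

    def __torch_dispatch__(self, func, types, args=(), kwargs=None):
        kwargs = kwargs or {}
        # This mode is popped while this handler runs, so the child mode
        # must catch whatever the decomposition dispatches.
        with self.child:
            decomposed = func.decompose(*args, **kwargs)
        if decomposed is not NotImplemented:
            return decomposed
        # Not decomposable: execute and record the op itself.
        out = func(*args, **kwargs)
        self.seen.append(func._overloadpacket)
        return out

x, w = torch.randn(2, 3), torch.randn(4, 3)
with torch.inference_mode(), DecomposingRecorder() as rec:
    torch.nn.functional.linear(x, w)
print(rec.seen)  # sub-ops of linear (e.g. aten.t, aten.matmul), not aten.linear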
This commit is contained in:
parent
4488e23763
commit
f915409c26
@@ -810,6 +810,30 @@ class TestFlopCounter(TestCase):
        self.assertEqual(called, 1)
        self.assertExpectedInline(get_total_flops(mode), """9001""")

    @skipIfNoTorchVision
    def test_inference_mode(self):
        def get_flops(model):
            with FlopCounterMode(model) as mode:
                a = T(1, 3, 224, 224)
                model(a).sum()
            return mode

        resnet18 = torchvision_models.resnet18()

        mode_standard = get_flops(resnet18)

        with torch.inference_mode():
            mode_inference = get_flops(resnet18)

        self.assertEqual(get_total_flops(mode_standard), get_total_flops(mode_inference))

        layer1_conv_flops_standard = mode_standard.flop_counts["ResNet.layer1"][
            torch.ops.aten.convolution
        ]
        layer1_conv_flops_inference = mode_inference.flop_counts["ResNet.layer1"][
            torch.ops.aten.convolution
        ]
        self.assertEqual(layer1_conv_flops_standard, layer1_conv_flops_inference)


if __name__ == "__main__":
    run_tests()
@@ -1,6 +1,7 @@
# mypy: allow-untyped-defs
# mypy: allow-untyped-decorators
import torch
from torch._C import DispatchKey
from torch.utils._pytree import tree_map, tree_flatten, tree_unflatten
from .module_tracker import ModuleTracker
from typing import List, Any, Dict, Optional, Union, Tuple, Iterator
@@ -632,6 +633,7 @@ class FlopCounterMode(TorchDispatchMode):
            **{k: v if getattr(v, "_get_raw", False) else shape_wrapper(v) for k, v in custom_mapping.items()}
        }
        self.mod_tracker = ModuleTracker()
        self.decomposed_counter = _DecomposedCounterMode(self)

    def get_total_flops(self) -> int:
        return sum(self.flop_counts['Global'].values())
@@ -722,6 +724,32 @@ class FlopCounterMode(TorchDispatchMode):

    def __torch_dispatch__(self, func, types, args=(), kwargs=None):
        kwargs = kwargs if kwargs else {}

        # Skip ops from non-standard dispatch_sizes_strides_policy such as NJT
        if func in {torch.ops.aten.is_contiguous.default,
                    torch.ops.aten.is_contiguous.memory_format,
                    torch.ops.aten.is_strides_like_format.default,
                    torch.ops.aten.is_non_overlapping_and_dense.default,
                    torch.ops.aten.size.default,
                    torch.ops.aten.sym_size.default,
                    torch.ops.aten.stride.default,
                    torch.ops.aten.sym_stride.default,
                    torch.ops.aten.storage_offset.default,
                    torch.ops.aten.sym_storage_offset.default,
                    torch.ops.aten.numel.default,
                    torch.ops.aten.sym_numel.default,
                    torch.ops.aten.dim.default,
                    torch.ops.prim.layout.default}:
            return NotImplemented

        dk = DispatchKey.CompositeImplicitAutograd
        if torch._C._dispatch_has_kernel_for_dispatch_key(func.name(), dk):
            # func can be decomposed; redispatch
            with self.decomposed_counter:
                return func._op_dk(dk, *args, **kwargs)
        else:
            # no further decomposition; execute & count flops
            out = func(*args, **kwargs)
            return self._count_flops(func._overloadpacket, out, args, kwargs)
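A quick illustration of the decomposability check used above (a sketch, not part of the diff; the two ops are representative examples): `aten::linear` is implemented as a CompositeImplicitAutograd decomposition, while `aten::mm` has real backend kernels and therefore gets its flops counted directly.

import torch
from torch._C import DispatchKey

dk = DispatchKey.CompositeImplicitAutograd
# linear decomposes (eventually into addmm/matmul), so it takes the
# redispatch branch above...
print(torch._C._dispatch_has_kernel_for_dispatch_key("aten::linear", dk))  # True
# ...while mm hits the else branch and is counted via the flop registry.
print(torch._C._dispatch_has_kernel_for_dispatch_key("aten::mm", dk))  # False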
@@ -733,3 +761,12 @@ class FlopCounterMode(TorchDispatchMode):
        self.flop_counts[par][func_packet] += flop_count

        return out


class _DecomposedCounterMode(TorchDispatchMode):
    def __init__(self, counter):
        self.counter = counter

    def __torch_dispatch__(self, func, types, args=(), kwargs=None):
        kwargs = kwargs if kwargs else {}
        out = func(*args, **kwargs)
        return self.counter._count_flops(func._overloadpacket, out, args, kwargs)
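End to end, the change makes flop counting work under `torch.inference_mode()`. A minimal usage sketch (not part of the diff; the printed total assumes the standard 2·M·K·N flop formula for `addmm`):

import torch
from torch.utils.flop_counter import FlopCounterMode

lin = torch.nn.Linear(16, 16)
x = torch.randn(4, 16)

with torch.inference_mode():
    # Previously, aten.linear reached the mode undecomposed here, no
    # registered flop formula matched, and the total came out as 0.
    with FlopCounterMode(display=False) as mode:
        lin(x)

# linear now decomposes to addmm, counted as 2 * 4 * 16 * 16 flops.
print(mode.get_total_flops())  # 2048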