pytorch/torch/_inductor/fx_utils.py
Gabriel Ferns 7e83d50845 Inductor logging + analysis of torch.profile (#149697)
Prereqs:
 - https://github.com/pytorch/pytorch/pull/152708

Features:
1. Adds inductor's estimate of flops and bandwidth to the JSON trace events that perfetto uses.
1. Only uses the tflops estimate from Triton when we don't have the info from the datasheet, because Triton's estimates are inaccurate. I have a backlog item to fix Triton's flops estimation upstream. Adds a new `DeviceInfo` class and a new `get_device_tflops` function.
1. New helpers `countable_fx` and `count_flops_fx` help get the flops of an `fx.Node` (see the usage sketch after the sample output below).
1. Extends Triton `torch.profiler` logging to `DebugAutotuner`.
1. New script `profile_analysis.py`: `--augment_trace` adds perf estimates to any perfetto JSON trace, `--analyze` creates a summary table of these perf estimates, and `--diff` compares two traces side by side:
```
Device(NVIDIA H100, 0):
 Kernel Name                              | resnet Kernel Count | resnet FLOPS       | resnet bw gbps        | resnet Dur (ms)    | resnet Achieved FLOPS % | resnet Achieved Bandwidth % | newresnet Kernel Count | newresnet FLOPS    | newresnet bw gbps     | newresnet Dur (ms) | newresnet Achieved FLOPS % | newresnet Achieved Bandwidth %
---------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------
 triton_poi_fused__native_batch_norm_legi | 24                  | 0                  | 0.11395268248131513   | 2.5919166666666666 | 0                       | 0.003401572611382541        | 24                     | 0                  | 0.11395268248131513   | 2.5919166666666666 | 0                          | 0.003401572611382541
 sm90_xmma_fprop_implicit_gemm_f32f32_tf3 | 142                 | 16932673552.422373 | 0.2585007824198784    | 12.441619718309857 | 0.08683422334575583     | 0.007716441266265022        | 142                    | 16932673552.422373 | 0.2585007824198784    | 12.441619718309857 | 0.08683422334575583        | 0.007716441266265022
 triton_red_fused__native_batch_norm_legi | 39                  | 0                  | 0.13990024992108846   | 5.752589743589743  | 0                       | 0.004176126863316074        | 39                     | 0                  | 0.13990024992108846   | 5.752589743589743  | 0                          | 0.004176126863316074
 triton_poi_fused__native_batch_norm_legi | 25                  | 0                  | 0.31824055917536503   | 2.5291999999999994 | 0                       | 0.009499718184339253        | 25                     | 0                  | 0.31824055917536503   | 2.5291999999999994 | 0                          | 0.009499718184339253
 void cutlass::Kernel2<cutlass_80_tensoro | 98                  | 16211056473.596165 | 0.42972434051025826   | 7.130408163265306  | 0.08313362294151874     | 0.012827592254037562        | 98                     | 16211056473.596165 | 0.42972434051025826   | 7.130408163265306  | 0.08313362294151874        | 0.012827592254037562
 triton_red_fused__native_batch_norm_legi | 73                  | 0                  | 0.3225381327611705    | 9.987068493150682  | 0                       | 0.009628003963020014        | 73                     | 0                  | 0.3225381327611705    | 9.987068493150682  | 0                          | 0.009628003963020014
 triton_poi_fused__native_batch_norm_legi | 15                  | 0                  | 1.4491211346487216    | 4.439333333333333  | 0                       | 0.043257347302946926        | 15                     | 0                  | 1.4491211346487216    | 4.439333333333333  | 0                          | 0.043257347302946926
 void cutlass::Kernel2<cutlass_80_tensoro | 186                 | 14501701145.337954 | 0.2667131401910989    | 7.873865591397849  | 0.07436769818122027     | 0.007961586274361157        | 186                    | 14501701145.337954 | 0.2667131401910989    | 7.873865591397849  | 0.07436769818122027        | 0.007961586274361157
 triton_poi_fused__native_batch_norm_legi | 33                  | 0                  | 1.4924556538193923    | 4.3101515151515155 | 0                       | 0.044550915039384846        | 33                     | 0                  | 1.4924556538193923    | 4.3101515151515155 | 0                          | 0.044550915039384846
 triton_red_fused__native_batch_norm_legi | 29                  | 0                  | 0.25562590522631107   | 6.296275862068965  | 0                       | 0.007630624036606301        | 29                     | 0                  | 0.25562590522631107   | 6.296275862068965  | 0                          | 0.007630624036606301
 triton_poi_fused__native_batch_norm_legi | 13                  | 0                  | 0.5870562174192726    | 2.7397692307692307 | 0                       | 0.01752406619162008         | 13                     | 0                  | 0.5870562174192726    | 2.7397692307692307 | 0                          | 0.01752406619162008
 triton_poi_fused__native_batch_norm_legi | 34                  | 0                  | 0.41409928846284      | 2.853588235294117  | 0                       | 0.012361172789935523        | 34                     | 0                  | 0.41409928846284      | 2.853588235294117  | 0                          | 0.012361172789935523
 triton_per_fused__native_batch_norm_legi | 34                  | 0                  | 0.11705315007018151   | 3.460647058823529  | 0                       | 0.0034941238826919864       | 34                     | 0                  | 0.11705315007018151   | 3.460647058823529  | 0                          | 0.0034941238826919864
 triton_poi_fused__native_batch_norm_legi | 16                  | 0                  | 0.17207853197124584   | 2.3459375000000002 | 0                       | 0.005136672596156592        | 16                     | 0                  | 0.17207853197124584   | 2.3459375000000002 | 0                          | 0.005136672596156592
 triton_per_fused__native_batch_norm_legi | 30                  | 0                  | 0.2639714322022256    | 6.131199999999999  | 0                       | 0.007879744244842555        | 30                     | 0                  | 0.2639714322022256    | 6.131199999999999  | 0                          | 0.007879744244842555
 sm90_xmma_fprop_implicit_gemm_f32f32_tf3 | 100                 | 11875430356.891787 | 0.19494470869421385   | 16.36534           | 0.06089964285585531     | 0.005819245035648175        | 100                    | 11875430356.891787 | 0.19494470869421385   | 16.36534           | 0.06089964285585531        | 0.005819245035648175
 triton_poi_fused__native_batch_norm_legi | 8                   | 0                  | 0.9854096626224687    | 3.2757500000000004 | 0                       | 0.029415213809625928        | 8                      | 0                  | 0.9854096626224687    | 3.2757500000000004 | 0                          | 0.029415213809625928
 void cublasLt::splitKreduce_kernel<32, 1 | 56                  | 34377923395.147064 | 0.8310300045762317    | 3.4199999999999986 | 0.17629704305203628     | 0.024806865808245714        | 56                     | 34377923395.147064 | 0.8310300045762317    | 3.4199999999999986 | 0.17629704305203628        | 0.024806865808245714
 triton_poi_fused__native_batch_norm_legi | 23                  | 0                  | 0.9944002965861103    | 3.2431304347826084 | 0                       | 0.02968359094286896         | 23                     | 0                  | 0.9944002965861103    | 3.2431304347826084 | 0                          | 0.02968359094286896
 triton_per_fused__native_batch_norm_legi | 10                  | 0                  | 0.1826801058931057    | 4.428800000000001  | 0                       | 0.00545313748934644         | 10                     | 0                  | 0.1826801058931057    | 4.428800000000001  | 0                          | 0.00545313748934644
 triton_poi_fused__native_batch_norm_legi | 10                  | 0                  | 0.3168973585366449    | 2.5471999999999997 | 0                       | 0.009459622642884923        | 10                     | 0                  | 0.3168973585366449    | 2.5471999999999997 | 0                          | 0.009459622642884923
 triton_poi_fused__native_batch_norm_legi | 34                  | 0                  | 1.1463614897015777    | 4.124323529411764  | 0                       | 0.03421974596124114         | 34                     | 0                  | 1.1463614897015777    | 4.124323529411764  | 0                          | 0.03421974596124114
 void cask_plugin_cudnn::xmma_cudnn::init | 44                  | 44045510816.64277  | 2.0661232850348643    | 3.6887499999999993 | 0.22587441444432194     | 0.06167532194133924         | 44                     | 44045510816.64277  | 2.0661232850348643    | 3.6887499999999993 | 0.22587441444432194        | 0.06167532194133924
 sm90_xmma_fprop_implicit_gemm_f32f32_tf3 | 95                  | 7876855400.165316  | 0.4694941555946739    | 18.224315789473682 | 0.04039413025725802     | 0.014014750913273854        | 95                     | 7876855400.165316  | 0.4694941555946739    | 18.224315789473682 | 0.04039413025725802        | 0.014014750913273854
 triton_per_fused__native_batch_norm_legi | 41                  | 0                  | 0.06825669875995298   | 3.0384146341463416 | 0                       | 0.002037513395819492        | 41                     | 0                  | 0.06825669875995298   | 3.0384146341463416 | 0                          | 0.002037513395819492
 triton_poi_fused__native_batch_norm_legi | 23                  | 0                  | 0.08808154712430301   | 2.3275652173913044 | 0                       | 0.0026292999141582997       | 23                     | 0                  | 0.08808154712430301   | 2.3275652173913044 | 0                          | 0.0026292999141582997
 triton_per_fused__native_batch_norm_legi | 40                  | 0                  | 0.18179321034952417   | 4.556825           | 0                       | 0.005426662995508183        | 40                     | 0                  | 0.18179321034952417   | 4.556825           | 0                          | 0.005426662995508183
 triton_poi_fused__native_batch_norm_legi | 15                  | 0                  | 0.5887415155454232    | 2.783866666666667  | 0                       | 0.017574373598370836        | 15                     | 0                  | 0.5887415155454232    | 2.783866666666667  | 0                          | 0.017574373598370836
 void cutlass::Kernel2<cutlass_80_tensoro | 38                  | 14242013806.264643 | 0.256592404353939     | 7.217631578947369  | 0.0730359682372546      | 0.007659474756834           | 38                     | 14242013806.264643 | 0.256592404353939     | 7.217631578947369  | 0.0730359682372546         | 0.007659474756834
 triton_poi_fused__native_batch_norm_legi | 21                  | 0                  | 0.5842860973430516    | 2.7779047619047623 | 0                       | 0.017441376040091088        | 21                     | 0                  | 0.5842860973430516    | 2.7779047619047623 | 0                          | 0.017441376040091088
 triton_per_fused__native_batch_norm_legi | 16                  | 0                  | 0.11509365173486417   | 3.5959375000000002 | 0                       | 0.0034356313950705724       | 16                     | 0                  | 0.11509365173486417   | 3.5959375000000002 | 0                          | 0.0034356313950705724
 triton_poi_fused__native_batch_norm_legi | 14                  | 0                  | 0.1704672000243914    | 2.4044285714285714 | 0                       | 0.00508857313505646         | 14                     | 0                  | 0.1704672000243914    | 2.4044285714285714 | 0                          | 0.00508857313505646
 triton_poi_fused__native_batch_norm_legi | 58                  | 0                  | 2.307520779930795     | 8.190706896551722  | 0                       | 0.06888121731136704         | 58                     | 0                  | 2.307520779930795     | 8.190706896551722  | 0                          | 0.06888121731136704
 triton_per_fused__native_batch_norm_legi | 29                  | 0                  | 0.037243248971881276  | 3.0277586206896556 | 0                       | 0.001111738775280038        | 29                     | 0                  | 0.037243248971881276  | 3.0277586206896556 | 0                          | 0.001111738775280038
 triton_poi_fused__native_batch_norm_legi | 20                  | 0                  | 0.04741699795428918   | 2.2911500000000005 | 0                       | 0.0014154327747549007       | 20                     | 0                  | 0.04741699795428918   | 2.2911500000000005 | 0                          | 0.0014154327747549007
 triton_per_fused__native_batch_norm_legi | 25                  | 0                  | 0.13357016893727824   | 3.37536            | 0                       | 0.003987169222008305        | 25                     | 0                  | 0.13357016893727824   | 3.37536            | 0                          | 0.003987169222008305
 triton_poi_fused__native_batch_norm_legi | 13                  | 0                  | 0.3089862268300253    | 2.8111538461538457 | 0                       | 0.009223469457612694        | 13                     | 0                  | 0.3089862268300253    | 2.8111538461538457 | 0                          | 0.009223469457612694
 triton_poi_fused__native_batch_norm_legi | 17                  | 0                  | 0.3129385387909844    | 2.673              | 0                       | 0.009341448919133863        | 17                     | 0                  | 0.3129385387909844    | 2.673              | 0                          | 0.009341448919133863
 triton_per_fused__native_batch_norm_legi | 19                  | 0                  | 0.2215568162533158    | 3.8837368421052636 | 0                       | 0.0066136363060691275       | 19                     | 0                  | 0.2215568162533158    | 3.8837368421052636 | 0                          | 0.0066136363060691275
 std::enable_if<!(false), void>::type int | 23                  | 504916805.19297093 | 1.0118296096314707    | 8.113913043478261  | 0.0025893169497075447   | 0.030203868944223014        | 23                     | 504916805.19297093 | 1.0118296096314707    | 8.113913043478261  | 0.0025893169497075447      | 0.030203868944223014
 triton_poi_fused_add_copy__38            | 56                  | 0                  | 0                     | 2.132482142857143  | 0                       | 0                           | 56                     | 0                  | 0                     | 2.132482142857143  | 0                          | 0
 triton_poi_fused_convolution_0           | 18                  | 0                  | 0.43458610794936897   | 2.773333333333334  | 0                       | 0.012972719640279667        | 18                     | 0                  | 0.43458610794936897   | 2.773333333333334  | 0                          | 0.012972719640279667
 triton_poi_fused_convolution_1           | 17                  | 0                  | 0.028816312469162712  | 2.6145882352941174 | 0                       | 0.0008601884319153051       | 17                     | 0                  | 0.028816312469162712  | 2.6145882352941174 | 0                          | 0.0008601884319153051
 void convolve_common_engine_float_NHWC<f | 44                  | 8641868995.31118   | 0.024730540008465626  | 25.87327272727273  | 0.04431727689903169     | 0.0007382250748795709       | 44                     | 8641868995.31118   | 0.024730540008465626  | 25.87327272727273  | 0.04431727689903169        | 0.0007382250748795709
 triton_per_fused__native_batch_norm_legi | 12                  | 0                  | 0.6809930918986744    | 4.82675            | 0                       | 0.020328151996975356        | 12                     | 0                  | 0.6809930918986744    | 4.82675            | 0                          | 0.020328151996975356
 triton_per_fused__native_batch_norm_legi | 14                  | 0                  | 0.02883030597936608   | 2.6651428571428575 | 0                       | 0.0008606061486377935       | 14                     | 0                  | 0.02883030597936608   | 2.6651428571428575 | 0                          | 0.0008606061486377935
 triton_per_fused__native_batch_norm_legi | 16                  | 0                  | 0.0014658988233201874 | 2.098              | 0                       | 4.375817383045335e-05       | 16                     | 0                  | 0.0014658988233201874 | 2.098              | 0                          | 4.375817383045335e-05
 triton_poi_fused__native_batch_norm_legi | 13                  | 0                  | 0.9926297180284697    | 3.2367692307692306 | 0                       | 0.02963073785159611         | 13                     | 0                  | 0.9926297180284697    | 3.2367692307692306 | 0                          | 0.02963073785159611
 triton_poi_fused__native_batch_norm_legi | 9                   | 0                  | 1.3008817095666507    | 3.0863333333333336 | 0                       | 0.03883228983781048         | 9                      | 0                  | 1.3008817095666507    | 3.0863333333333336 | 0                          | 0.03883228983781048
 void at::native::(anonymous namespace):: | 98                  | 0                  | 0.09174335613709389   | 4.408520408163265  | 0                       | 0.0027386076458833994       | 98                     | 0                  | 0.09174335613709389   | 4.408520408163265  | 0                          | 0.0027386076458833994
 void at::native::vectorized_elementwise_ | 7                   | 0                  | 0                     | 1.7278571428571428 | 0                       | 0                           | 7                      | 0                  | 0                     | 1.7278571428571428 | 0                          | 0
```
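
A minimal usage sketch of the new FX-level helpers (the traced function below is made up for illustration, and `make_fx` is just one way to obtain an ATen-level graph with `meta["val"]` populated):

```python
import torch
from torch._inductor.fx_utils import count_flops_fx, countable_fx
from torch.fx.experimental.proxy_tensor import make_fx


def f(a, b):
    return torch.mm(a, b).relu()


gm = make_fx(f)(torch.randn(64, 128), torch.randn(128, 32))
for node in gm.graph.nodes:
    if node.op == "call_function" and countable_fx(node):
        # aten.mm is in flop_registry; relu is not, so it is skipped.
        print(node.target, count_flops_fx(node))  # expect 2 * 64 * 128 * 32 for mm
```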

Pull Request resolved: https://github.com/pytorch/pytorch/pull/149697
Approved by: https://github.com/eellison, https://github.com/shunting314
2025-07-07 22:13:34 +00:00

# mypy: allow-untyped-defs
import contextlib
import operator
from collections import defaultdict
from typing import Any, Callable, Optional

import sympy

import torch
import torch.fx
from torch._dispatch.python import enable_python_dispatcher
from torch._subclasses.fake_tensor import FakeTensorMode
from torch.fx.experimental.symbolic_shapes import (
    compute_unbacked_bindings,
    rebind_unbacked,
    statically_known_true,
    sym_eq,
)
from torch.utils import _pytree as pytree
from torch.utils._ordered_set import OrderedSet
from torch.utils._pytree import tree_map
from torch.utils.flop_counter import flop_registry

from .virtualized import V


# Check the pattern: (nn.module, F.function/torch.Tensor.method) matched.
# Works for length 2 patterns with 1 module and 1 function/method.
def matches_module_function_pattern(
    pattern: tuple[type[torch.nn.modules.Module], Callable[..., Any]],
    node: torch.fx.node.Node,
    modules: dict[str, torch.nn.modules.Module],
) -> bool:
    if len(node.args) == 0:
        return False
    if not isinstance(node.args[0], torch.fx.Node) or not isinstance(
        node, torch.fx.Node
    ):
        return False
    # the first node is call_module
    if node.args[0].op != "call_module":
        return False
    if not isinstance(node.args[0].target, str):
        return False
    if node.args[0].target not in modules:
        return False
    if type(modules[node.args[0].target]) is not pattern[0]:
        return False
    # the second node is call_function or call_method
    if node.op != "call_function" and node.op != "call_method":
        return False
    if node.target != pattern[1]:
        return False
    # make sure node.args[0] output is only used by current node.
    if len(node.args[0].users) > 1:
        return False
    return True
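

# Usage sketch (illustrative only): matching a (nn.BatchNorm2d, F.relu) pair in
# a symbolically traced module. The module and pattern below are made up for
# the example; the helper only assumes the call_module ->
# call_function/call_method shape documented above.
def _example_matches_module_function_pattern():
    import torch.nn.functional as F

    class _M(torch.nn.Module):
        def __init__(self) -> None:
            super().__init__()
            self.bn = torch.nn.BatchNorm2d(8)

        def forward(self, x):
            return F.relu(self.bn(x))

    gm = torch.fx.symbolic_trace(_M())
    modules = dict(gm.named_modules())
    pattern = (torch.nn.BatchNorm2d, F.relu)
    # Only the F.relu node matches: its first arg is the call_module node for
    # self.bn, and that output has no other users.
    return [
        node
        for node in gm.graph.nodes
        if matches_module_function_pattern(pattern, node, modules)
    ]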


class FakeTensorUpdater:
    """
    The main idea here is that it's difficult to maintain accurate fake
    tensors (our primary form of metadata) for each node in our graph as we
    transform it.

    The most reliable way to obtain this information is by rerunning
    faketensor propagation. However, in general, faketensor propagation is
    fairly expensive. So, instead we'd like to only rerun faketensor
    propagation on nodes that have changed.

    In order to detect which nodes have changed, we first hash each node,
    its target, and its argument lists (which are immutable in FX).

    Then, whenever we call incremental_update, we check which FX nodes have a
    new hash, and recompute the faketensor metadata for that node. Then, we
    continue to recursively compute the faketensors for all users until the
    fake tensors stop changing.
    """

    def __init__(self, graph: torch.fx.Graph) -> None:
        self.processed_hashes = OrderedSet[Any]()
        self.graph = graph

        for node in self.graph.nodes:
            self.processed_hashes.add(self.hash_node(node))

    def hash_node(self, node: torch.fx.Node):
        # todo(chilli): Not a great hash function
        return (node, node.target, id(node.args), id(node.kwargs))

    def incremental_update(self):
        """Update FakeTensors on self.graph. We will try to do the minimum amount of work."""
        existing_storages: defaultdict[Optional[int], int] = defaultdict(int)
        for node in self.graph.nodes:
            existing_storages[get_node_storage(node)] += 1

        def is_intlist_same(new, old):
            return statically_known_true(sym_eq(new, old))

        def is_fake_tensor_same(new, old, *, node):
            if type(new) != type(old):
                return False
            if isinstance(new, (list, tuple)):
                if len(new) != len(old):
                    return False
                return all(
                    is_fake_tensor_same(new_i, old_i, node=node)
                    for new_i, old_i in zip(new, old)
                )
            if new is None:
                return old is None
            if not isinstance(new, torch.Tensor):
                assert isinstance(new, (torch.SymInt, torch.SymBool, torch.SymFloat)), (
                    f"Unknown type {type(new)} in {self.graph}"
                )
                return (
                    new.node.shape_env._maybe_evaluate_static(
                        sympy.Eq(new.node.expr, old.node.expr)
                    )
                    == sympy.true
                )
            if not is_intlist_same(new.shape, old.shape) or new.layout != old.layout:
                return False
            if new.layout == torch.strided and (
                not is_intlist_same(new.stride(), old.stride())
                or not statically_known_true(
                    new.storage_offset() == old.storage_offset()
                )
            ):
                return False
            if new.device != old.device:
                return False
            if get_storage(new) == get_storage(old):
                return True

            def any_user_may_alias(node):
                if not isinstance(node.meta["val"], torch.Tensor):
                    # analysis too complicated on lists, can support in the future
                    return True
                for user in node.users:
                    if not (
                        isinstance(
                            user.target,
                            (torch._ops.OpOverload, torch._ops.HigherOrderOperator),
                        )
                        or user.target
                        == torch._inductor.fx_passes.reinplace._generalized_scatter
                    ):
                        return True
                    if isinstance(user.target, torch._ops.HigherOrderOperator):
                        # HOPs that survive until inductor are all non-aliasing HOPs.
                        # We will likely never support HOPs that are aliasing.
                        continue
                    # Strategy: do a FakeTensor prop and see if the storage aliases.
                    # If Inductor ever gets tighter invariants on OpOverloads
                    # (that is, we ban things like torch.ops.aten.reshape calls in
                    # the graph), then this could just be a fast schema lookup.
                    is_valid, args, kwargs = get_fake_args_kwargs(user)
                    if not is_valid:
                        return True
                    with (
                        V.fake_mode,
                        enable_python_dispatcher(),
                        contextlib.ExitStack() as stack,
                    ):
                        # Ignore unbacked symbols (if they exist): we're making
                        # this FakeTensor and then throwing it away.
                        shape_env = V.fake_mode.shape_env
                        if shape_env is not None:
                            stack.enter_context(
                                shape_env.ignore_fresh_unbacked_symbols()
                            )
                        new_fake_tensor = user.target(*args, **kwargs)
                    if not isinstance(new_fake_tensor, torch.Tensor):
                        # analysis too complicated on lists, can support in the future
                        return True
                    if get_storage(new_fake_tensor) == get_storage(node.meta["val"]):
                        return True
                return False

            # This is the case where the node returns a completely fresh storage
            # that's used nowhere else: if the FakeTensor's storage is fresh and
            # none of the node's users can alias it, then we don't need to update
            # this node.
            if (
                existing_storages[get_storage(old)] == 1
                and get_storage(new) not in existing_storages
                and not any_user_may_alias(node)
            ):
                return True

            return False

        def should_process_node(node):
            # node.target for nodes returning true from this function is
            # called under fake mode, which does not work for inductor
            # lowerings. We check whether node.target is an aten operator or
            # operator.getitem, which is used when returning multiple tensors
            # from an op.
            return node.op == "call_function" and (
                isinstance(node.target, torch._ops.OpOverload)
                or node.target == operator.getitem
                or node.target
                == torch._inductor.fx_passes.reinplace._generalized_scatter
            )

        to_process = OrderedSet[int]()
        for node in self.graph.nodes:
            # NB: Be very careful about skipping nodes (via continues) here
            # and ask for a careful review when changing this code. The
            # consequence of incorrect FakeTensor metadata is difficult-to-debug
            # silent incorrectness.
            if (
                self.hash_node(node) in self.processed_hashes
                and id(node) not in to_process
            ):
                continue

            if not should_process_node(node):
                continue

            is_valid, args, kwargs = get_fake_args_kwargs(node)
            if not is_valid:
                continue

            with V.fake_mode, enable_python_dispatcher():
                new_fake_tensor = node.target(*args, **kwargs)

            if "val" in node.meta and is_fake_tensor_same(
                new_fake_tensor, node.meta["val"], node=node
            ):
                continue

            rebind_unbacked(V.fake_mode.shape_env, node, new_fake_tensor)
            node.meta["val"] = new_fake_tensor

            if (shape_env := V.fake_mode.shape_env) and (
                symbol_to_path := compute_unbacked_bindings(shape_env, new_fake_tensor)
            ):
                # Refresh the bindings to the new symbols
                node.meta["unbacked_bindings"] = symbol_to_path

            existing_storages[get_node_storage(node)] += 1
            to_process.update([id(user) for user in node.users])
            self.processed_hashes.add(self.hash_node(node))
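

# Usage sketch (illustrative only): FakeTensorUpdater is normally driven by
# Inductor's FX passes, which install the ambient fake mode via V.set_fake_mode
# before mutating the graph. This sketch assumes that context manager behaves
# as it does on the compile path; the traced function and the target rewrite
# below are made up for the example.
def _example_fake_tensor_updater():
    from torch.fx.experimental.proxy_tensor import make_fx

    gm = make_fx(lambda x: (x + 1).mul(2), tracing_mode="fake")(torch.randn(4, 4))
    # Reuse the FakeTensorMode that produced the graph's meta["val"] entries.
    placeholder = next(iter(gm.graph.nodes))
    fake_mode = placeholder.meta["val"].fake_mode

    updater = FakeTensorUpdater(gm.graph)

    # Mutate the graph: swap mul for div. The mul node's cached FakeTensor
    # (and potentially its users') is now stale, so it must be recomputed.
    for node in gm.graph.nodes:
        if node.target is torch.ops.aten.mul.Tensor:
            node.target = torch.ops.aten.div.Tensor

    with V.set_fake_mode(fake_mode):
        updater.incremental_update()
    return gm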


def get_storage(t: torch.Tensor) -> int:
    return t.untyped_storage()._cdata


def get_node_storage(node: torch.fx.Node) -> Optional[int]:
    if "val" not in node.meta:
        return None
    if not isinstance(node.meta["val"], torch.Tensor):
        return None
    if not torch._C._has_storage(node.meta["val"]):
        return None
    return get_storage(node.meta["val"])


def get_fake(x):
    if isinstance(x, torch.fx.Node):
        if "val" not in x.meta:
            return x
        return x.meta["val"]
    return x


def get_fake_args_kwargs(x: torch.fx.Node) -> tuple[bool, tuple[Any], dict[str, Any]]:
    """
    Maps x's args and kwargs through get_fake. The first return value is False
    if any of the input nodes don't have a FakeTensor in their meta, and True
    otherwise.
    """
    args, kwargs = tree_map(get_fake, (x.args, x.kwargs))
    if any(
        isinstance(a, torch.fx.Node) for a in pytree.arg_tree_leaves(*args, **kwargs)
    ):
        return False, args, kwargs
    return True, args, kwargs
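

# Usage sketch (illustrative only): populate meta["val"] with FakeTensorProp
# and then query the helpers above. The traced function is made up for the
# example.
def _example_fake_metadata_helpers():
    from torch.fx.passes.fake_tensor_prop import FakeTensorProp

    def f(x):
        y = x.view(2, 8)  # aliases x's storage
        return y + 1  # allocates fresh storage

    gm = torch.fx.symbolic_trace(f)
    FakeTensorProp(gm, FakeTensorMode()).propagate(torch.randn(4, 4))

    x_node, view_node, add_node, _ = list(gm.graph.nodes)
    # The view shares its input's storage; the add does not.
    assert get_node_storage(view_node) == get_node_storage(x_node)
    assert get_node_storage(add_node) != get_node_storage(x_node)

    # All of add's inputs carry FakeTensors now, so the first value is True and
    # args has Nodes replaced by their FakeTensors.
    is_valid, args, kwargs = get_fake_args_kwargs(add_node)
    return is_valid, args, kwargs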


def is_node_realized(node: torch.fx.Node) -> bool:
    """Returns true if a node is always realized when lowered to inductor IR.

    NOTE: This may return some false negatives. e.g. it doesn't
    handle buffers realized heuristically during lowering, or
    buffers realized indirectly through view ops.
    """
    from torch._inductor.lowering import fallbacks, needs_realized_inputs

    def is_buffer(node: torch.fx.Node) -> bool:
        if node.op == "call_function" and node.target is operator.getitem:
            # For nodes with multiple outputs, we get the fx graph:
            #     foo = torch.ops.aten.foo(...)
            #     getitem = foo[0]
            #     getitem_1 = foo[1]
            # where we need to check if foo is a fallback kernel
            return is_buffer(node.args[0])  # type: ignore[arg-type]
        return node.op in ("placeholder", "output") or node.target in fallbacks

    if is_buffer(node):
        return True

    def realizes_inputs(node: torch.fx.Node) -> bool:
        return node.op == "output" or node.target in needs_realized_inputs

    if any(realizes_inputs(user) for user in node.users):
        return True

    # Otherwise, assume node isn't realized
    return False
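

# Usage sketch (illustrative only): on an ATen-level graph from make_fx, an
# intermediate that feeds aten.mm is reported as realized (mm is registered in
# needs_realized_inputs), while a pointwise op that only feeds another
# pointwise op is expected to fuse and is not. The traced function is made up
# for the example.
def _example_is_node_realized():
    from torch.fx.experimental.proxy_tensor import make_fx

    def f(a, b):
        c = a + 1  # feeds aten.mm -> realized
        d = (c * 2).relu()  # c * 2 only feeds a pointwise op -> not realized
        return torch.mm(c, b), d

    gm = make_fx(f)(torch.randn(8, 8), torch.randn(8, 8))
    return {
        node.name: is_node_realized(node)
        for node in gm.graph.nodes
        if node.op == "call_function"
    }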


def count_flops_fx(node: torch.fx.Node) -> Optional[int]:
    if not countable_fx(node) or isinstance(node.target, str):
        return None
    with FakeTensorMode(allow_non_fake_inputs=True):
        success, args, kwargs = get_fake_args_kwargs(node)
        if success:
            with torch.utils.flop_counter.FlopCounterMode(
                display=False
            ) as flop_counter_mode:
                node.target(*args, **kwargs)
                counted_flops = flop_counter_mode.get_total_flops()
            return counted_flops
    return None


def countable_fx(node: torch.fx.Node) -> bool:
    """
    Whether or not we can count the flops of an FX node.
    """
    assert isinstance(node, torch.fx.Node)
    if not hasattr(node, "target"):
        return False
    target = node.target
    if not hasattr(target, "overloadpacket"):
        return target in flop_registry
    packet = target.overloadpacket
    return packet in flop_registry