import operator
from collections import defaultdict
from typing import Any, Callable, DefaultDict, Dict, Optional, Tuple, Type

import torch
import torch.fx
from torch.fx.experimental.symbolic_shapes import statically_known_true, sym_eq
from torch.utils import _pytree as pytree
from torch.utils._pytree import tree_map

from .virtualized import V


# Check whether the pattern (nn.Module, F.function / torch.Tensor method)
# matches at the given node.  Works for length-2 patterns with one module
# and one function/method.
def matches_module_function_pattern(
    pattern: Tuple[Type[torch.nn.modules.Module], Callable[..., Any]],
    node: torch.fx.node.Node,
    modules: Dict[str, torch.nn.modules.Module],
) -> bool:
    if len(node.args) == 0:
        return False
    if not isinstance(node.args[0], torch.fx.Node) or not isinstance(
        node, torch.fx.Node
    ):
        return False
    # the first node is call_module
    if node.args[0].op != "call_module":
        return False
    if not isinstance(node.args[0].target, str):
        return False
    if node.args[0].target not in modules:
        return False
    if type(modules[node.args[0].target]) is not pattern[0]:
        return False
    # the second node is call_function or call_method
    if node.op != "call_function" and node.op != "call_method":
        return False
    if node.target != pattern[1]:
        return False
    # make sure the output of node.args[0] is only used by the current node
    if len(node.args[0].users) > 1:
        return False
    return True
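

# Example (illustrative sketch, not part of the original file): scanning a
# traced module for Linear modules whose output feeds only F.relu.  `MyModel`
# is a hypothetical nn.Module; only the matcher above is real.
#
#   import torch.nn.functional as F
#
#   gm = torch.fx.symbolic_trace(MyModel())
#   modules = dict(gm.named_modules())
#   for node in gm.graph.nodes:
#       if matches_module_function_pattern((torch.nn.Linear, F.relu), node, modules):
#           # `node` is the F.relu call; node.args[0] is the Linear call_module.
#           ...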


class FakeTensorUpdater:
    """
    The main idea here is that it's difficult to maintain accurate fake
    tensors (our primary form of metadata) for each node in our graph as we
    transform it.

    The most reliable way to obtain this information is by rerunning
    faketensor propagation. However, in general, faketensor propagation is
    fairly expensive. So, instead we'd like to only rerun faketensor
    propagation on nodes that have changed.

    In order to detect which nodes have changed, we first hash each node,
    its target, and the identity of its argument lists (which are immutable
    in FX).

    Then, whenever we call incremental_update, we check which FX nodes have a
    new hash, and recompute the faketensor metadata for that node. Then, we
    continue to recursively compute the faketensors for all users until the
    fake tensors stop changing.
    """

    def __init__(self, graph: torch.fx.Graph):
        self.processed_hashes = set()
        self.graph = graph

        for node in self.graph.nodes:
            self.processed_hashes.add(self.hash_node(node))

    def hash_node(self, node: torch.fx.Node):
        # todo(chilli): Not a great hash function
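        # (id() can serve as a cheap change indicator here because FX rebuilds
        # a node's args/kwargs containers when the node is rewritten, rather
        # than mutating them in place.)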
        return (node, node.target, id(node.args), id(node.kwargs))

    def incremental_update(self):
        processed = set()
        existing_storages: DefaultDict[Optional[int], int] = defaultdict(int)
        for node in self.graph.nodes:
            existing_storages[get_node_storage(node)] += 1

        def is_intlist_same(new, old):
            return statically_known_true(sym_eq(new, old))

        def is_fake_tensor_same(new, old):
            if type(new) != type(old):
                return False
            if isinstance(new, (list, tuple)):
                if len(new) != len(old):
                    return False
                return all(
                    is_fake_tensor_same(new_i, old_i) for new_i, old_i in zip(new, old)
                )
            assert isinstance(new, torch.Tensor)
            if not is_intlist_same(new.shape, old.shape) or new.layout != old.layout:
                return False
            if new.layout == torch.strided and (
                not is_intlist_same(new.stride(), old.stride())
                or not statically_known_true(
                    new.storage_offset() == old.storage_offset()
                )
            ):
                return False

            if get_storage(new) == get_storage(old):
                return True

            # This is the case where it returns a completely fresh storage that's used nowhere else.
            if (
                existing_storages[get_storage(old)] == 1
                and get_storage(new) not in existing_storages
            ):
                return True
            return False

        for node in self.graph.nodes:
            if self.hash_node(node) in self.processed_hashes:
                continue

            def is_aten_node(node):
                return node.op == "call_function" and isinstance(
                    node.target, torch._ops.OpOverload
                )

            if not is_aten_node(node):
                continue

            processing = [node]
            while len(processing) > 0:
                updating_node = processing.pop()
                if updating_node in processed:
                    continue
                if not is_aten_node(updating_node):
                    continue

                is_valid, args, kwargs = get_fake_args_kwargs(updating_node)
                if not is_valid:
                    continue
                with V.fake_mode:
                    new_fake_tensor = updating_node.target(*args, **kwargs)
                if "val" in updating_node.meta and is_fake_tensor_same(
                    new_fake_tensor, updating_node.meta["val"]
                ):
                    continue
                updating_node.meta["val"] = new_fake_tensor

                # todo(chilli): This code path is not exercised by our existing
                # tests - add a test
                existing_storages[get_node_storage(updating_node)] += 1
                processed.add(updating_node)
                processing.extend(updating_node.users)

                self.processed_hashes.add(self.hash_node(updating_node))
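

# Example (illustrative sketch, not part of the original file): driving
# FakeTensorUpdater across a graph transform.  `gm` and `swap_add_for_mul`
# are hypothetical, and incremental_update re-runs ops under V.fake_mode,
# so a fake mode must already be installed on V.
#
#   updater = FakeTensorUpdater(gm.graph)
#   swap_add_for_mul(gm.graph)    # rewrite some aten nodes in place
#   updater.incremental_update()  # refreshes node.meta["val"] only where changed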


def get_storage(t: torch.Tensor) -> int:
    return t.untyped_storage()._cdata


def get_node_storage(node: torch.fx.Node) -> Optional[int]:
    if "val" not in node.meta:
        return None
    if not isinstance(node.meta["val"], torch.Tensor):
        return None
    if not torch._C._has_storage(node.meta["val"]):
        return None
    return get_storage(node.meta["val"])
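

# Example (illustrative sketch, not part of the original file): comparing the
# storages of two nodes' fake tensors to detect aliasing.  `node_a` and
# `node_b` are hypothetical graph nodes.
#
#   s_a, s_b = get_node_storage(node_a), get_node_storage(node_b)
#   if s_a is not None and s_a == s_b:
#       pass  # both values view the same storage; mutating one affects the other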


def get_fake(x):
    if isinstance(x, torch.fx.Node):
        if "val" not in x.meta:
            return x
        return x.meta["val"]
    return x


def get_fake_args_kwargs(x: torch.fx.Node) -> Tuple[bool, Tuple[Any], Dict[str, Any]]:
    """
    The first return value is False if any of the input nodes does not have
    a fake tensor; in that case, the unconverted torch.fx.Node is left in the
    returned args/kwargs.
    """
    args, kwargs = tree_map(get_fake, (x.args, x.kwargs))
    if any(
        isinstance(a, torch.fx.Node) for a in pytree.arg_tree_leaves(*args, **kwargs)
    ):
        return False, args, kwargs
    return True, args, kwargs
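

# Example (illustrative sketch, not part of the original file): recomputing a
# node's fake tensor from its fake inputs, mirroring what incremental_update
# does above.  `node` is a hypothetical aten call_function node.
#
#   is_valid, args, kwargs = get_fake_args_kwargs(node)
#   if is_valid:
#       with V.fake_mode:
#           node.meta["val"] = node.target(*args, **kwargs)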


def is_node_realized(node: torch.fx.Node) -> bool:
    """Returns True if a node is always realized when lowered to inductor IR.

    NOTE: This may return some false negatives. e.g. it doesn't
    handle buffers realized heuristically during lowering, or
    buffers realized indirectly through view ops.
    """
    from torch._inductor.lowering import fallbacks, needs_realized_inputs

    def is_buffer(node: torch.fx.Node) -> bool:
        if node.op == "call_function" and node.target is operator.getitem:
            # For nodes with multiple outputs, we get the fx graph:
            #     foo = torch.ops.aten.foo(...)
            #     getitem = foo[0]
            #     getitem_1 = foo[1]
            # where we need to check if foo is a fallback kernel
            return is_buffer(node.args[0])  # type: ignore[arg-type]
        return node.op in ("placeholder", "output") or node.target in fallbacks

    if is_buffer(node):
        return True

    def realizes_inputs(node: torch.fx.Node) -> bool:
        return node.op == "output" or node.target in needs_realized_inputs

    if any(realizes_inputs(user) for user in node.users):
        return True

    # Otherwise, assume the node isn't realized
    return False
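

# Example (illustrative sketch, not part of the original file): a pass might
# consult is_node_realized when deciding whether duplicating a computation is
# effectively free (it will fuse away) or will materialize an extra buffer.
# `node` is a hypothetical graph node.
#
#   if is_node_realized(node):
#       pass  # value becomes a real buffer in inductor IR; duplication costs memory
#   else:
#       pass  # value stays inline in its consumers; recomputation is likely cheap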