mirror of
https://github.com/zebrajr/pytorch.git
synced 2025-12-07 12:21:27 +01:00
At a high level, the idea behind this PR is:

* Make it clearer what the promotion and int/float rules for various Sympy operations are. Operators that previously were polymorphic over int/float are now split into separate operators for clarity. We never do mixed int/float addition/multiplication etc in sympy, instead, we always promote to the appropriate operator. (However, equality is currently not done correctly.)
* Enforce strict typing on ValueRanges: if you have a ValueRange for a float, the lower and upper MUST be floats, and so forth for integers.

The story begins in **torch/utils/_sympy/functions.py**. Here, I make some changes to how we represent certain operations in sympy expressions:

* FloorDiv now only supports integer inputs; to do float floor division, do a truediv and then a trunc. Additionally, we remove the divide out addition by gcd optimization, because sympy gcd is over fields and is willing to generate rationals (but rationals are bad for ValueRange strict typing).
* ModularIndexing, LShift, RShift now assert they are given integer inputs.
* Mod only supports integer inputs; eventually we will support FloatMod (left for later work, when we build out Sympy support for floating operations). Unfortunately, I couldn't assert integer inputs here, because of a bad interaction with sympy's inequality solver that is used by the offline solver.
* TrueDiv is split into FloatTrueDiv and IntTrueDiv. This allows for us to eventually generate accurate code for Python semantics IntTrueDiv, which is written in a special way to preserve precision when the inputs are >= 2**53 beyond what first coercing the integer to floats and then doing true division.
* Trunc is split to TruncToFloat and TruncToInt.
* Round is updated to return a float, not an int, making it consistent with the round op handler in Inductor. To get Python-style conversion to int, we call TruncToInt on the result.
* RoundDecimal updated to consistently only ever return a float.
* Add ToFloat for explicit coercion to float (required so we can enforce strict ValueRanges typing).

In **torch/__init__.py**, we modify SymInt and SymFloat to appropriately call into new bindings that route to these refined sympy operations. Also, we modify `torch.sym_min` and `torch.sym_max` to have promotion semantics (if one argument is a float, the return result is always a float), making them inconsistent with builtins.min/max, but possible to do type analysis without runtime information.

We also need to introduce some new op handlers in **torch/_inductor/ops_handler.py**:

* `to_int` for truncation to int64, directly corresponding to TruncToInt; this can be implemented by trunc and dtype, but with a dedicated handler it is more convenient for roundtripping in Sympy
* `int_truediv` for Python-style integer true division, which has higher precision than casting to floats and then running `truediv`

These changes have consequences. First, we need to make some administrative changes:

* Actually wire up these Sympy functions from SymInt/SymFloat in **torch/fx/experimental/sym_node.py**, including the new promotion rules (promote2)
* Add support for new Sympy functions in **torch/utils/_sympy/interp.py**, **torch/utils/_sympy/reference.py**
  * In particular, in torch.utils._sympy.reference, we have a strong preference to NOT do nontrivial compute, instead, everything in ops handler should map to a singular sympy function
  * TODO: I chose to roundtrip mod back to our Mod function, but I think I'm going to have to deal with the C/Python inconsistency to fix tests here
* Add printer support for the Sympy functions in **torch/_inductor/codegen/common.py**, **torch/_inductor/codegen/cpp_utils.py**, **torch/_inductor/codegen/triton.py**. `int_truediv` and mixed precision equality is currently not implemented soundly, so we will lose precision in codegen for large values.
  * TODO: The additions here are not exhaustive yet
* Update ValueRanges logic to use new sympy functions in **torch/utils/_sympy/value_ranges.py**. In general, we prefer to use the new Sympy function rather than try to roll things by hand, which is what was done previously for many VR analysis functions.

In **torch/fx/experimental/symbolic_shapes.py** we need to make some symbolic reasoning adjustments:

* Avoid generation of rational subexpressions by removing simplification of `x // y` into `floor(x / y)`. This simplification then triggers an addition simplification rule `(x + y) / c --> x / c + y / c` which is bad because x / c is a rational number now
* `_assert_bound_is_rational` is no more, we no longer generate rational bounds
* Don't intersect non-int value ranges with the `int_range`
* Support more sympy Functions for guard SYMPY_INTERP
* Assert the type of value range is consistent with the variable type

The new asserts uncovered necessary bug fixes:

* **torch/_inductor/codegen/cpp.py**, **torch/_inductor/select_algorithm.py**, **torch/_inductor/sizevars.py** - Ensure Wild/Symbol manually allocated in Inductor is marked `is_integer` so it's accepted to build expressions
* **torch/_inductor/utils.py** - make sure you actually pass in sympy.Expr to these functions
* **torch/_inductor/ir.py** - make_contiguous_strides_for takes int/SymInt, not sympy.Expr!
* **torch/export/dynamic_shapes.py** - don't use infinity to represent int ranges, instead use sys.maxsize - 1

Because of the removal of some symbolic reasoning that produced rationals, some of our symbolic reasoning has gotten worse and we are unable to simplify some guards. Check the TODO at **test/test_proxy_tensor.py**

**Reland notes.** This requires this internal fbcode diff https://www.internalfb.com/phabricator/paste/view/P1403322587 but I cannot prepare the diff codev due to https://fb.workplace.com/groups/osssupport/posts/26343544518600814/

It also requires this Executorch PR https://github.com/pytorch/executorch/pull/3911 but the ET PR can be landed prior to this landing.

Signed-off-by: Edward Z. Yang <ezyang@meta.com>
Pull Request resolved: https://github.com/pytorch/pytorch/pull/126905
Approved by: https://github.com/xadupre, https://github.com/lezcano
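The precision argument for a dedicated IntTrueDiv / `int_truediv` can be seen in plain Python, without torch or sympy. The sketch below is only an illustration: the helper names `int_truediv_exact` and `int_truediv_via_floats` are hypothetical and are not part of the PR. It assumes CPython's behavior that `int / int` rounds the exact quotient once, whereas converting to float first rounds each operand before dividing.

```python
# Hypothetical helper names, for illustration only (not PyTorch APIs).

def int_truediv_exact(a: int, b: int) -> float:
    # Python-style integer true division: the exact rational a/b is
    # rounded to the nearest float a single time.
    return a / b


def int_truediv_via_floats(a: int, b: int) -> float:
    # Coerce each operand to float first, then divide. For |a| >= 2**53
    # the coercion itself may round, so the result can drift.
    return float(a) / float(b)


a = 3 * (2**53 + 1)  # larger than 2**53, so float(a) is not exact
b = 3

exact = int_truediv_exact(a, b)
lossy = int_truediv_via_floats(a, b)

print(exact == lossy)  # the two strategies disagree for this input
print(exact, lossy)
```

For small operands the two helpers agree; they only diverge once an operand no longer fits exactly in a float, which is the case the commit message singles out.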
442 lines
17 KiB
Python
# mypy: allow-untyped-defs
import operator
import traceback
import typing
from contextlib import nullcontext
from typing import Any, Callable, Dict, List, Optional, Set, Tuple, Union

import torch
from functorch.experimental.control_flow import _unstack_pytree
from torch import fx
from torch._dispatch.python import enable_python_dispatcher
from torch._export.pass_infra.node_metadata import NodeMetadata
from torch._export.pass_infra.proxy_value import ProxyValue
from torch._subclasses import FakeTensor, UnsupportedFakeTensorException
from torch._subclasses.fake_tensor import FakeTensorMode
from torch.fx import traceback as fx_traceback
from torch.fx.experimental.proxy_tensor import PythonKeyTracer
from torch.fx.graph import CodeGen
from torch.fx.passes.infra.pass_base import PassBase, PassResult
from torch.fx.passes.shape_prop import _extract_tensor_metadata, TensorMetadata
from torch.utils import _pytree as pytree
from torch.fx.experimental.symbolic_shapes import PropagateUnbackedSymInts, compute_unbacked_bindings


__all__ = ["_ExportPassBaseDeprecatedDoNotUse"]


Argument = Any
Value = Any
Fn = Callable[..., Any]
PassType = Callable[[torch.fx.GraphModule], Optional[PassResult]]


_TORCH_SYM_OPS: Set[Callable] = {
    torch.sym_int,
    torch.sym_float,
    torch.sym_ite,
    torch.sym_max,
    torch.sym_min,
    torch.sym_not,
    torch.sym_sqrt,
}


class ExportPassBaseError(RuntimeError):
    pass


class _ExportPassBaseDeprecatedDoNotUse(PassBase):
    """
    Interpreter-based pass class to help users maintain the IR spec while writing
    transformations.
    """

    @staticmethod
    def _create_dummy_node_metadata():
        return NodeMetadata({"stack_trace": "".join(traceback.format_stack(limit=1))})


    class ExportTracer(PythonKeyTracer):
        def __init__(self, callback: "_ExportPassBaseDeprecatedDoNotUse", codegen: CodeGen) -> None:
            super().__init__()
            self.callback = callback
            self.root = torch.nn.Module()
            self.graph = torch.fx.Graph()
            self.graph.set_codegen(codegen)
            self.tensor_attrs: Dict[str, torch.Tensor] = {}  # type: ignore[assignment]
            self.fake_tensor_mode: Optional[FakeTensorMode] = None
            self.submodules: Dict[torch.nn.Module, str] = {}

        def trace(self) -> None:
            raise ExportPassBaseError("ExportTracer doesn't support trace().")

        def create_arg(self, a: Argument) -> torch.fx.Node:
            if isinstance(a, torch.nn.Module):
                if a not in self.submodules:
                    name_submodule = f"submodule_{len(self.submodules)}"
                    self.root.add_module(name_submodule, a)
                    self.submodules[a] = name_submodule
            elif isinstance(a, FakeTensor):
                if not hasattr(a, "constant") or a.constant is None:
                    raise ExportPassBaseError(f"Cannot add {a} to graph.")
                a = a.constant
            node = super().create_arg(a)
            if (
                isinstance(a, torch.Tensor)
                and isinstance(node, torch.fx.Node)
                and node.op == "get_attr"
            ):
                self.set_metadata(node, a)
                self.callback.on_attr(ProxyValue(a, node))
            return node

        def set_metadata(
            self, node: torch.fx.Node, value: Argument,
        ) -> None:
            # propagate the fake tensor or sym nodes
            def make_val(
                x: Argument,
            ) -> Union[FakeTensor, torch.SymInt, torch.SymFloat, torch.SymBool, int, float, bool, str, None]:
                if isinstance(x, FakeTensor):
                    return x
                elif isinstance(x, torch.Tensor):
                    if x.is_quantized:
                        # TODO (tmanlaibaatar) properly support Quantized FakeTensor
                        x = torch.dequantize(x)

                    try:
                        assert self.fake_tensor_mode is not None
                        # TODO we should allocate static shapes
                        # for param/buffer values
                        if isinstance(x, torch.nn.Parameter):
                            fake_tensor = self.fake_tensor_mode.from_tensor(
                                x, static_shapes=True
                            )
                        else:
                            fake_tensor = self.fake_tensor_mode.from_tensor(x)
                    except UnsupportedFakeTensorException:
                        # TODO: This is just a workaround to get over the
                        # x.as_subclass error
                        print(
                            "Fakeifying a Tensor subclass is not supported \
                            right now. Instead a TensorMetadata is used."
                        )
                        fake_tensor = None
                    return fake_tensor
                elif isinstance(x, (torch.SymInt, torch.SymFloat, torch.SymBool, int, float, bool, str)):
                    return x
                else:
                    return None

            node.meta["val"] = pytree.tree_map(make_val, value)

            # Set the tensor_metadata for values that do not have a corresponding FakeTensor
            def make_tensor_meta(x: Argument) -> Optional[TensorMetadata]:
                if not isinstance(x, FakeTensor) and isinstance(x, torch.Tensor):
                    if x.is_quantized:
                        # TODO (tmanlaibaatar) properly support Quantized FakeTensor
                        x = torch.dequantize(x)

                    try:
                        assert self.fake_tensor_mode is not None
                        _ = self.fake_tensor_mode.from_tensor(x)
                        tensor_meta = None
                    except UnsupportedFakeTensorException:
                        # TODO: This is just a workaround to get over the
                        # x.as_subclass error
                        tensor_meta = _extract_tensor_metadata(x)
                    return tensor_meta
                else:
                    return None

            node.meta["tensor_meta"] = pytree.tree_map(make_tensor_meta, value)

    class ExportInterpreter(fx.Interpreter):
        def __init__(self, callback: "_ExportPassBaseDeprecatedDoNotUse", gm: fx.GraphModule) -> None:
            super().__init__(gm)
            self.callback = callback
            self.node: torch.fx.Node = next(iter(gm.graph.nodes))

        def placeholder(
            self,
            target: str,
            args: Tuple[Argument, ...],
            kwargs: Dict[str, Argument],
        ) -> ProxyValue:
            arg = super().placeholder(target, args, kwargs)
            return self.callback.placeholder(target, arg, NodeMetadata(self.node.meta))

        def output(
            self,
            target: torch.fx.node.Target,
            args: Tuple[Argument, ...],
            kwargs: Dict[str, Argument],
        ) -> ProxyValue:
            return self.callback.output(args[0], NodeMetadata(self.node.meta)).data

        def call_function(
            self,
            target: torch.fx.node.Target,
            args: Tuple[Argument, ...],
            kwargs: Dict[str, Argument],
        ) -> ProxyValue:
            meta = NodeMetadata(self.node.meta)

            if target == operator.getitem:
                value, key = args
                return self.callback.call_getitem(value, key, meta)
            elif getattr(target, "__module__", None) in {"_operator", "math"}:
                assert callable(target)
                return self.callback.call_sym(target, args, meta)
            elif target in _TORCH_SYM_OPS:
                assert callable(target)
                return self.callback.call_sym(target, args, meta)
            elif isinstance(target, (torch._ops.OpOverload, torch._ops.OpOverloadPacket)):
                return self.callback.call_operator(
                    target,
                    args,
                    kwargs,
                    meta,
                )
            elif target == torch.ops.higher_order.cond:
                pred, true_fn, false_fn, inputs = args
                return self.callback.call_cond(pred, true_fn, false_fn, inputs, meta)
            elif target == torch.ops.higher_order.map_impl:
                f, mapped_args, operands = args  # type: ignore[assignment]
                return self.callback.call_map(f, mapped_args, operands, meta)
            # For other unregistered HigherOrderOps, just interpret them blindly
            elif isinstance(target, torch._ops.HigherOrderOperator):
                return self.callback._fx(
                    "call_function",
                    target,
                    args,
                    kwargs,
                    meta,
                )
            else:
                raise ExportPassBaseError(f"Unsupported target type: {target}")

        def get_attr(
            self, target: str, args: Tuple[Argument, ...], kwargs: Dict[str, Argument]
        ) -> Argument:
            return super().get_attr(target, args, kwargs)

        def call_module(
            self,
            target: torch.fx.node.Target,
            args: Tuple[Argument, ...],
            kwargs: Dict[str, Argument],
        ) -> None:
            raise ExportPassBaseError("call_module is not supported.")

        def call_method(
            self, target: str, args: Tuple[Argument, ...], kwargs: Dict[str, Argument]
        ) -> None:
            raise ExportPassBaseError("call_method is not supported.")

        def run_node(self, n: torch.fx.Node) -> Argument:
            self.node = n
            self.callback.node_debug_str = n.format_node()
            return super().run_node(n)

    def __init__(self) -> None:
        self.interpreter = PropagateUnbackedSymInts(
            torch.fx.GraphModule(torch.nn.Module(), torch.fx.Graph())
        )
        self.tracer = self.ExportTracer(self, CodeGen())
        self.fake_tensor_mode: Optional[FakeTensorMode] = None
        self._initialized = True
        self.node_debug_str: typing.Optional[str] = None

    def _fx(
        self,
        kind: str,
        target: torch.fx.node.Target,
        args: Tuple[Argument, ...],
        kwargs: Dict[str, Argument],
        meta: NodeMetadata,
    ) -> ProxyValue:
        args_data, kwargs_data = pytree.tree_map_only(
            ProxyValue, lambda x: x.data, (args, kwargs)
        )
        res_data = getattr(self.interpreter, kind)(target, args_data, kwargs_data)
        args_proxy, kwargs_proxy = pytree.tree_map_only(
            ProxyValue, lambda x: x.proxy, (args, kwargs)
        )

        name = None
        if isinstance(target, torch._ops.OpOverload):
            name = self.tracer.graph._target_to_str(target.overloadpacket.__name__)

        res_proxy = self.tracer.create_proxy(kind, target, args_proxy, kwargs_proxy, name=name)
        res_proxy.node.meta.update(meta.data)
        if self.fake_tensor_mode and (shape_env := self.fake_tensor_mode.shape_env):
            if symbol_to_path := compute_unbacked_bindings(shape_env, res_data):
                res_proxy.node.meta["unbacked_bindings"] = symbol_to_path
        self.tracer.set_metadata(res_proxy.node, res_data)
        return ProxyValue(res_data, res_proxy)

    def inputs(self, graph_module: torch.fx.GraphModule) -> List[Argument]:
        # TODO(angelayi): Update this with what we decide to do for metadata in
        # the exported graph module
        if (args := graph_module.meta.get("args", None)) is not None:
            return list(args)

        def extract_input(node: torch.fx.Node) -> Optional[FakeTensor]:
            if "val" in node.meta:
                fake = node.meta["val"]
                if hasattr(fake, "constant") and fake.constant is not None:
                    return fake.constant
                return fake
            elif tensor_meta := node.meta.get("tensor_meta"):
                assert self.fake_tensor_mode is not None
                return FakeTensor(
                    self.fake_tensor_mode,
                    torch.empty(
                        tensor_meta.shape,
                        dtype=tensor_meta.dtype,
                        device="meta",
                        requires_grad=tensor_meta.requires_grad,
                        memory_format=tensor_meta.memory_format,
                    ),
                    torch.device("cpu"),
                )
            elif len(node.users) == 0:
                return None
            raise ExportPassBaseError(
                f"Cannot construct an input for graph module: {graph_module}.",
            )

        return [
            extract_input(node)
            for node in graph_module.graph.nodes
            if node.op == "placeholder"
        ]

    def on_attr(self, attr: ProxyValue) -> None:
        pass

    def placeholder(self, name: str, arg: Argument, meta: NodeMetadata) -> ProxyValue:
        arg_proxy = self.tracer.create_proxy("placeholder", name, (), {})
        arg_proxy.node.meta = meta.data
        self.tracer.set_metadata(arg_proxy.node, arg)
        return ProxyValue(arg, arg_proxy)

    def call_operator(
        self,
        op,
        args: Tuple[Argument, ...],
        kwargs: Dict[str, Argument],
        meta: NodeMetadata,
    ) -> ProxyValue:
        return self._fx("call_function", op, args, kwargs, meta)

    def call_sym(
        self,
        target: Fn,
        args: Tuple[Argument, ...],
        meta: NodeMetadata,
    ) -> ProxyValue:
        return self._fx("call_function", target, args, {}, meta)

    def call_cond(
        self,
        pred: ProxyValue,
        true_fn: torch.fx.GraphModule,
        false_fn: torch.fx.GraphModule,
        inputs: List[Argument],
        meta: NodeMetadata,
    ) -> ProxyValue:
        true_branch = self.call_submodule(true_fn, tuple(inputs))
        false_branch = self.call_submodule(false_fn, tuple(inputs))
        assert true_branch is not None
        assert false_branch is not None
        return self._fx(
            "call_function",
            torch.ops.higher_order.cond,
            (pred, true_branch.graph_module, false_branch.graph_module, list(inputs)),
            {},
            meta,
        )

    def call_map(
        self,
        f: torch.fx.GraphModule,
        mapped_args: List[ProxyValue],
        operands: List[ProxyValue],
        meta: NodeMetadata,
    ) -> ProxyValue:
        xs = _unstack_pytree([arg.data for arg in mapped_args])[0]
        f_branch = self.call_submodule(f, tuple(xs + [arg.data for arg in operands]))
        assert f_branch is not None
        return self._fx(
            "call_function",
            torch.ops.higher_order.map_impl,
            (f_branch.graph_module, mapped_args, operands),
            {},
            meta,
        )

    def call_getitem(
        self, value: ProxyValue, key: int, meta: NodeMetadata
    ) -> ProxyValue:
        return self._fx("call_function", operator.getitem, (value, key), {}, meta)

    def output(self, results: List[Argument], meta: NodeMetadata) -> ProxyValue:
        return self._fx("output", "output", (results,), {}, meta)

    def call_submodule(
        self, graph_module: fx.GraphModule, inputs: Tuple[Argument, ...]
    ) -> PassResult:
        # Swap in a fresh tracer that records the rewritten graph; the previous
        # tracer is restored once this (sub)graph has been processed.
        prev_tracer, self.tracer = self.tracer, self.ExportTracer(
            self, graph_module.graph._codegen
        )
        self.tracer.fake_tensor_mode = prev_tracer.fake_tensor_mode
        interpreter = self.ExportInterpreter(self, graph_module)
        # _fx() evaluates each node through self.interpreter, so install a
        # fresh interpreter over an empty graph module for this (sub)graph.
        prev_interpreter, self.interpreter = self.interpreter, torch.fx.Interpreter(
            torch.fx.GraphModule(torch.nn.Module(), torch.fx.Graph())
        )
        inputs_data = pytree.tree_map_only(ProxyValue, lambda x: x.data, inputs)
        with fx_traceback.preserve_node_meta():
            interpreter.run(*inputs_data)

        new_graph_module = torch.fx.GraphModule(self.tracer.root, self.tracer.graph)

        self.tracer = prev_tracer
        self.interpreter = prev_interpreter
        return PassResult(
            new_graph_module,
            True,
        )

    def call(self, graph_module: fx.GraphModule) -> PassResult:
        if not getattr(self, "_initialized", False):
            raise ExportPassBaseError(
                "ExportPass is not initialized with __init__().",
            )

        inputs = self.inputs(graph_module)

        # Reuse the FakeTensorMode carried by the inputs; all fake inputs must share one.
        fake_tensor_mode = None
        for i in inputs:
            if isinstance(i, FakeTensor):
                assert (
                    fake_tensor_mode is None or fake_tensor_mode is i.fake_mode
                ), "Multiple fake tensor mode detected."
                fake_tensor_mode = i.fake_mode
        if fake_tensor_mode is None:
            # No fake inputs: the tracer gets its own mode, and no mode is entered below.
            self.tracer.fake_tensor_mode = FakeTensorMode(allow_non_fake_inputs=True)
            fake_tensor_mode = nullcontext()  # type: ignore[assignment]
            dispatcher_mode = nullcontext()  # type: ignore[assignment]
        else:
            fake_tensor_mode.allow_non_fake_inputs = True
            self.tracer.fake_tensor_mode = fake_tensor_mode
            dispatcher_mode = enable_python_dispatcher()  # type: ignore[assignment]
        self.fake_tensor_mode = self.tracer.fake_tensor_mode

        with fake_tensor_mode, dispatcher_mode:  # type: ignore[assignment, union-attr]
            result = self.call_submodule(graph_module, tuple(inputs))

        return result