pytorch/torch/_export/pass_base.py
Edward Z. Yang 3964a3ec73 Complete revamp of float/promotion sympy handling (#126905)
At a high level, the idea behind this PR is:

* Make it clearer what the promotion and int/float rules for various Sympy operations are. Operators that previously were polymorphic over int/float are now split into separate operators for clarity. We never do mixed int/float addition/multiplication etc. in sympy; instead, we always promote to the appropriate operator. (However, equality is currently not done correctly.)
* Enforce strict typing on ValueRanges: if you have a ValueRange for a float, the lower and upper MUST be floats, and so forth for integers.
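
To make the second point concrete, here is a minimal sketch of what strict ValueRange typing means. The `StrictValueRange` class below is an illustrative stand-in, not the actual implementation in torch/utils/_sympy/value_ranges.py:

```python
class StrictValueRange:
    """Illustrative stand-in: both bounds must share a single type."""

    def __init__(self, lower, upper):
        # Mixed bounds such as (0, 10.0) are rejected rather than being
        # silently promoted: an int range has int bounds, a float range
        # has float bounds.
        assert type(lower) is type(upper), (lower, upper)
        self.lower, self.upper = lower, upper


StrictValueRange(0, 10)       # ok: integer range
StrictValueRange(0.0, 10.0)   # ok: float range
# StrictValueRange(0, 10.0)   # would fail the assert: mixed int/float bounds
```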

The story begins in **torch/utils/_sympy/functions.py**. Here, I make some changes to how we represent certain operations in sympy expressions:

* FloorDiv now only supports integer inputs; to do float floor division, do a truediv and then a trunc. Additionally, we remove the optimization that divided out additions by their gcd, because sympy gcd works over fields and is willing to generate rationals (and rationals are bad for strict ValueRange typing).
* ModularIndexing, LShift, RShift now assert they are given integer inputs.
* Mod only supports integer inputs; eventually we will support FloatMod (left for later work, when we build out sympy support for floating-point operations). Unfortunately, I couldn't assert integer inputs here because of a bad interaction with sympy's inequality solver, which is used by the offline solver.
* TrueDiv is split into FloatTrueDiv and IntTrueDiv. This allows us to eventually generate accurate code for Python-semantics IntTrueDiv, which is written in a special way to preserve precision when the inputs are >= 2**53, beyond what you would get by first coercing the integers to float and then doing true division (see the division sketch after this list).
* Trunc is split into TruncToFloat and TruncToInt.
* Round is updated to return a float, not an int, making it consistent with the round op handler in Inductor. To get Python-style conversion to int, we call TruncToInt on the result.
* RoundDecimal is updated to consistently return a float.
* ToFloat is added for explicit coercion to float (required so we can enforce strict ValueRanges typing).
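
To make the IntTrueDiv precision point concrete, here is a small illustration (the operands are illustrative, not taken from the PR). CPython's int/int true division rounds the exact quotient once, while coercing to float first rounds twice and can land on a different double once the operands exceed 2**53:

```python
a, b = 2**53 - 1, 2**53 + 1

# Python's integer true division rounds the exact quotient once.
print(a / b)                # 0.9999999999999998

# Coercing to float first rounds b down to 2**53, and the subsequent
# division rounds again, landing on a different double.
print(float(a) / float(b))  # 0.9999999999999999
```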

In **torch/__init__.py**, we modify SymInt and SymFloat to appropriately call into new bindings that route to these refined sympy operations. Also, we modify `torch.sym_min` and `torch.sym_max` to have promotion semantics (if one argument is a float, the return result is always a float), making them inconsistent with builtins.min/max but making it possible to do type analysis without runtime information.
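
A small illustration of the promotion semantics (the outputs show the intended behavior described above; contrast with builtins.max, whose result type depends on runtime values):

```python
import torch

# builtins.max returns whichever argument compares larger, so the result
# type depends on the values: here it is the int 3.
print(max(3, 2.0))            # 3

# With promotion semantics, the presence of a float argument alone makes
# the result a float, so type analysis needs no runtime information.
print(torch.sym_max(3, 2.0))  # 3.0
```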

We also need to introduce some new op handlers in **torch/_inductor/ops_handler.py**:

* `to_int` for truncation to int64, directly corresponding to TruncToInt; this could be implemented as a trunc followed by a dtype conversion, but a dedicated handler is more convenient for roundtripping through sympy
* `int_truediv` for Python-style integer true division, which has higher precision than casting to floats and then running `truediv`
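
A rough sketch of what fallback lowerings for these two handlers could look like, based purely on the description above (the function bodies are illustrative, not the actual code in torch/_inductor/ops_handler.py):

```python
import torch


def to_int(ops, x):
    # Truncate toward zero, then convert to int64. Keeping this as one
    # dedicated handler makes it easy to round-trip through sympy as
    # TruncToInt.
    return ops.to_dtype(ops.trunc(x), torch.int64)


def int_truediv(ops, a, b):
    # Naive fallback: cast both operands to float and divide. This loses
    # precision once |a| or |b| exceeds 2**53, which is exactly why a
    # dedicated handler (and IntTrueDiv in sympy) exists.
    return ops.truediv(ops.to_dtype(a, torch.float64), ops.to_dtype(b, torch.float64))
```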

These changes have consequences. First, we need to make some administrative changes:

* Actually wire up these Sympy functions from SymInt/SymFloat in **torch/fx/experimental/sym_node.py**, including the new promotion rules (promote2)
* Add support for new Sympy functions in **torch/utils/_sympy/interp.py**, **torch/utils/_sympy/reference.py**
  * In particular, in torch.utils._sympy.reference, we have a strong preference to NOT do nontrivial compute; instead, everything in the ops handler should map to a single sympy function (see the sketch after this list)
  * TODO: I chose to roundtrip mod back to our Mod function, but I think I'm going to have to deal with the C/Python inconsistency here to fix the tests
* Add printer support for the Sympy functions in **torch/_inductor/codegen/common.py**, **torch/_inductor/codegen/cpp_utils.py**, **torch/_inductor/codegen/triton.py**. `int_truediv` and mixed-precision equality are currently not implemented soundly, so we will lose precision in codegen for large values. TODO: The additions here are not exhaustive yet
* Update ValueRanges logic to use new sympy functions in **torch/utils/_sympy/value_ranges.py**. In general, we prefer to use the new Sympy function rather than try to roll things by hand, which is what was done previously for many VR analysis functions.
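
To illustrate the "map to a singular sympy function" preference in torch.utils._sympy.reference, a handler for the new `to_int` op would ideally look like the sketch below (the handler itself is illustrative; TruncToInt is the new function from torch/utils/_sympy/functions.py):

```python
import sympy

from torch.utils._sympy.functions import TruncToInt  # added by this PR


def to_int(x: sympy.Expr) -> sympy.Expr:
    # Map the ops-handler method to exactly one sympy function; no
    # nontrivial compute happens in the reference layer itself.
    return TruncToInt(x)
```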

In **torch/fx/experimental/symbolic_shapes.py** we need to make some symbolic reasoning adjustments:

* Avoid generating rational subexpressions by removing the simplification of `x // y` into `floor(x / y)`. That simplification triggered an additional rule, `(x + y) / c --> x / c + y / c`, which is bad because `x / c` is now a rational expression (see the sympy snippet after this list)
* `_assert_bound_is_rational` is gone; we no longer generate rational bounds
* Don't intersect non-int value ranges with the `int_range`
* Support more sympy Functions for guard SYMPY_INTERP
* Assert the type of value range is consistent with the variable type
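
The rational-subexpression problem from the first bullet is easy to see directly in sympy (illustrative snippet):

```python
import sympy

x = sympy.Symbol("x", integer=True)

# Rewriting x // 2 as floor(x / 2) introduces a Rational coefficient:
# x / 2 is Mul(1/2, x). Once rationals appear, rules like
# (x + y) / c --> x / c + y / c spread them through the expression,
# which breaks strictly integer-typed value ranges.
print((x / 2).args)        # (1/2, x)
print(sympy.floor(x / 2))  # floor(x/2)
```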

The new asserts uncovered necessary bug fixes:

* **torch/_inductor/codegen/cpp.py**, **torch/_inductor/select_algorithm.py**, **torch/_inductor/sizevars.py** - Ensure that Wild/Symbol objects manually allocated in Inductor are marked `is_integer` so they are accepted when building expressions (see the sketch after this list)
* **torch/_inductor/utils.py** - make sure you actually pass in sympy.Expr to these functions
* **torch/_inductor/ir.py** - make_contiguous_strides_for takes int/SymInt, not sympy.Expr!
* **torch/export/dynamic_shapes.py** - don't use infinity to represent int ranges, instead use sys.maxsize - 1
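
A minimal sketch of the Wild/Symbol fix from the first bullet (the symbol names are illustrative):

```python
import sympy

# Symbols and Wilds that Inductor allocates by hand need the integer
# assumption, otherwise the new integer-only functions (FloorDiv,
# ModularIndexing, ...) reject expressions built from them.
s0 = sympy.Symbol("s0", integer=True, positive=True)
w = sympy.Wild("w", integer=True)

print(s0.is_integer, w.is_integer)  # True True
```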

Because we removed some simplifications that produced rationals, our symbolic reasoning has gotten worse in places and we are unable to simplify some guards. See the TODO in **test/test_proxy_tensor.py**.

**Reland notes.** This requires this internal fbcode diff https://www.internalfb.com/phabricator/paste/view/P1403322587 but I cannot prepare the diff codev due to https://fb.workplace.com/groups/osssupport/posts/26343544518600814/

It also requires this Executorch PR https://github.com/pytorch/executorch/pull/3911 but the ET PR can be landed prior to this landing.

Signed-off-by: Edward Z. Yang <ezyang@meta.com>

Pull Request resolved: https://github.com/pytorch/pytorch/pull/126905
Approved by: https://github.com/xadupre, https://github.com/lezcano
2024-06-09 06:20:25 +00:00


# mypy: allow-untyped-defs
import operator
import traceback
import typing
from contextlib import nullcontext
from typing import Any, Callable, Dict, List, Optional, Set, Tuple, Union

import torch
from functorch.experimental.control_flow import _unstack_pytree
from torch import fx
from torch._dispatch.python import enable_python_dispatcher
from torch._export.pass_infra.node_metadata import NodeMetadata
from torch._export.pass_infra.proxy_value import ProxyValue
from torch._subclasses import FakeTensor, UnsupportedFakeTensorException
from torch._subclasses.fake_tensor import FakeTensorMode
from torch.fx import traceback as fx_traceback
from torch.fx.experimental.proxy_tensor import PythonKeyTracer
from torch.fx.graph import CodeGen
from torch.fx.passes.infra.pass_base import PassBase, PassResult
from torch.fx.passes.shape_prop import _extract_tensor_metadata, TensorMetadata
from torch.utils import _pytree as pytree
from torch.fx.experimental.symbolic_shapes import PropagateUnbackedSymInts, compute_unbacked_bindings


__all__ = ["_ExportPassBaseDeprecatedDoNotUse"]


Argument = Any
Value = Any
Fn = Callable[..., Any]
PassType = Callable[[torch.fx.GraphModule], Optional[PassResult]]


_TORCH_SYM_OPS: Set[Callable] = {
    torch.sym_int,
    torch.sym_float,
    torch.sym_ite,
    torch.sym_max,
    torch.sym_min,
    torch.sym_not,
    torch.sym_sqrt,
}


class ExportPassBaseError(RuntimeError):
    pass


class _ExportPassBaseDeprecatedDoNotUse(PassBase):
    """
    Interpreter-based pass class to help users maintain the IR spec while writing
    transformations.
    """

    @staticmethod
    def _create_dummy_node_metadata():
        return NodeMetadata({"stack_trace": "".join(traceback.format_stack(limit=1))})

    class ExportTracer(PythonKeyTracer):
        def __init__(self, callback: "_ExportPassBaseDeprecatedDoNotUse", codegen: CodeGen) -> None:
            super().__init__()
            self.callback = callback
            self.root = torch.nn.Module()
            self.graph = torch.fx.Graph()
            self.graph.set_codegen(codegen)
            self.tensor_attrs: Dict[str, torch.Tensor] = {}  # type: ignore[assignment]
            self.fake_tensor_mode: Optional[FakeTensorMode] = None
            self.submodules: Dict[torch.nn.Module, str] = {}

        def trace(self) -> None:
            raise ExportPassBaseError("ExportTracer doesn't support trace().")

        def create_arg(self, a: Argument) -> torch.fx.Node:
            if isinstance(a, torch.nn.Module):
                if a not in self.submodules:
                    name_submodule = f"submodule_{len(self.submodules)}"
                    self.root.add_module(name_submodule, a)
                    self.submodules[a] = name_submodule
            elif isinstance(a, FakeTensor):
                if not hasattr(a, "constant") or a.constant is None:
                    raise ExportPassBaseError(f"Cannot add {a} to graph.")
                a = a.constant
            node = super().create_arg(a)
            if (
                isinstance(a, torch.Tensor)
                and isinstance(node, torch.fx.Node)
                and node.op == "get_attr"
            ):
                self.set_metadata(node, a)
                self.callback.on_attr(ProxyValue(a, node))
            return node

        def set_metadata(
            self, node: torch.fx.Node, value: Argument,
        ) -> None:
            # propagate the fake tensor or sym nodes
            def make_val(
                x: Argument,
            ) -> Union[FakeTensor, torch.SymInt, torch.SymFloat, torch.SymBool, int, float, bool, str, None]:
                if isinstance(x, FakeTensor):
                    return x
                elif isinstance(x, torch.Tensor):
                    if x.is_quantized:
                        # TODO (tmanlaibaatar) properly support Quantized FakeTensor
                        x = torch.dequantize(x)

                    try:
                        assert self.fake_tensor_mode is not None
                        # TODO we should allocate static shapes
                        # for param/buffer values
                        if isinstance(x, torch.nn.Parameter):
                            fake_tensor = self.fake_tensor_mode.from_tensor(
                                x, static_shapes=True
                            )
                        else:
                            fake_tensor = self.fake_tensor_mode.from_tensor(x)
                    except UnsupportedFakeTensorException:
                        # TODO: This is just a workaround to get over the
                        # x.as_subclass error
                        print(
                            "Fakeifying a Tensor subclass is not supported \
                            right now. Instead a TensorMetadata is used."
                        )
                        fake_tensor = None
                    return fake_tensor
                elif isinstance(x, (torch.SymInt, torch.SymFloat, torch.SymBool, int, float, bool, str)):
                    return x
                else:
                    return None

            node.meta["val"] = pytree.tree_map(make_val, value)

            # Set the tensor_metadata for values that do not have a corresponding FakeTensor
            def make_tensor_meta(x: Argument) -> Optional[TensorMetadata]:
                if not isinstance(x, FakeTensor) and isinstance(x, torch.Tensor):
                    if x.is_quantized:
                        # TODO (tmanlaibaatar) properly support Quantized FakeTensor
                        x = torch.dequantize(x)

                    try:
                        assert self.fake_tensor_mode is not None
                        _ = self.fake_tensor_mode.from_tensor(x)
                        tensor_meta = None
                    except UnsupportedFakeTensorException:
                        # TODO: This is just a workaround to get over the
                        # x.as_subclass error
                        tensor_meta = _extract_tensor_metadata(x)
                    return tensor_meta
                else:
                    return None

            node.meta["tensor_meta"] = pytree.tree_map(make_tensor_meta, value)

    class ExportInterpreter(fx.Interpreter):
        def __init__(self, callback: "_ExportPassBaseDeprecatedDoNotUse", gm: fx.GraphModule) -> None:
            super().__init__(gm)
            self.callback = callback
            self.node: torch.fx.Node = next(iter(gm.graph.nodes))

        def placeholder(
            self,
            target: str,
            args: Tuple[Argument, ...],
            kwargs: Dict[str, Argument],
        ) -> ProxyValue:
            arg = super().placeholder(target, args, kwargs)
            return self.callback.placeholder(target, arg, NodeMetadata(self.node.meta))

        def output(
            self,
            target: torch.fx.node.Target,
            args: Tuple[Argument, ...],
            kwargs: Dict[str, Argument],
        ) -> ProxyValue:
            return self.callback.output(args[0], NodeMetadata(self.node.meta)).data

        def call_function(
            self,
            target: torch.fx.node.Target,
            args: Tuple[Argument, ...],
            kwargs: Dict[str, Argument],
        ) -> ProxyValue:
            meta = NodeMetadata(self.node.meta)

            if target == operator.getitem:
                value, key = args
                return self.callback.call_getitem(value, key, meta)
            elif getattr(target, "__module__", None) in {"_operator", "math"}:
                assert callable(target)
                return self.callback.call_sym(target, args, meta)
            elif target in _TORCH_SYM_OPS:
                assert callable(target)
                return self.callback.call_sym(target, args, meta)
            elif isinstance(target, (torch._ops.OpOverload, torch._ops.OpOverloadPacket)):
                return self.callback.call_operator(
                    target,
                    args,
                    kwargs,
                    meta,
                )
            elif target == torch.ops.higher_order.cond:
                pred, true_fn, false_fn, inputs = args
                return self.callback.call_cond(pred, true_fn, false_fn, inputs, meta)
            elif target == torch.ops.higher_order.map_impl:
                f, mapped_args, operands = args  # type: ignore[assignment]
                return self.callback.call_map(f, mapped_args, operands, meta)
            # For other unregistered HigherOrderOps, just interpret them blindly
            elif isinstance(target, torch._ops.HigherOrderOperator):
                return self.callback._fx(
                    "call_function",
                    target,
                    args,
                    kwargs,
                    meta,
                )
            else:
                raise ExportPassBaseError(f"Unsupported target type: {target}")

        def get_attr(
            self, target: str, args: Tuple[Argument, ...], kwargs: Dict[str, Argument]
        ) -> Argument:
            return super().get_attr(target, args, kwargs)

        def call_module(
            self,
            target: torch.fx.node.Target,
            args: Tuple[Argument, ...],
            kwargs: Dict[str, Argument],
        ) -> None:
            raise ExportPassBaseError("call_module is not supported.")

        def call_method(
            self, target: str, args: Tuple[Argument, ...], kwargs: Dict[str, Argument]
        ) -> None:
            raise ExportPassBaseError("call_method is not supported.")

        def run_node(self, n: torch.fx.Node) -> Argument:
            self.node = n
            self.callback.node_debug_str = n.format_node()
            return super().run_node(n)

    def __init__(self) -> None:
        self.interpreter = PropagateUnbackedSymInts(
            torch.fx.GraphModule(torch.nn.Module(), torch.fx.Graph())
        )
        self.tracer = self.ExportTracer(self, CodeGen())
        self.fake_tensor_mode: Optional[FakeTensorMode] = None
        self._initialized = True
        self.node_debug_str: typing.Optional[str] = None

    def _fx(
        self,
        kind: str,
        target: torch.fx.node.Target,
        args: Tuple[Argument, ...],
        kwargs: Dict[str, Argument],
        meta: NodeMetadata,
    ) -> ProxyValue:
        args_data, kwargs_data = pytree.tree_map_only(
            ProxyValue, lambda x: x.data, (args, kwargs)
        )
        res_data = getattr(self.interpreter, kind)(target, args_data, kwargs_data)
        args_proxy, kwargs_proxy = pytree.tree_map_only(
            ProxyValue, lambda x: x.proxy, (args, kwargs)
        )

        name = None
        if isinstance(target, torch._ops.OpOverload):
            name = self.tracer.graph._target_to_str(target.overloadpacket.__name__)

        res_proxy = self.tracer.create_proxy(kind, target, args_proxy, kwargs_proxy, name=name)
        res_proxy.node.meta.update(meta.data)
        if self.fake_tensor_mode and (shape_env := self.fake_tensor_mode.shape_env):
            if symbol_to_path := compute_unbacked_bindings(shape_env, res_data):
                res_proxy.node.meta["unbacked_bindings"] = symbol_to_path
        self.tracer.set_metadata(res_proxy.node, res_data)
        return ProxyValue(res_data, res_proxy)

    def inputs(self, graph_module: torch.fx.GraphModule) -> List[Argument]:
        # TODO(angelayi): Update this with what we decide to do for metadata in
        # the exported graph module
        if (args := graph_module.meta.get("args", None)) is not None:
            return list(args)

        def extract_input(node: torch.fx.Node) -> Optional[FakeTensor]:
            if "val" in node.meta:
                fake = node.meta["val"]
                if hasattr(fake, "constant") and fake.constant is not None:
                    return fake.constant
                return fake
            elif tensor_meta := node.meta.get("tensor_meta"):
                assert self.fake_tensor_mode is not None
                return FakeTensor(
                    self.fake_tensor_mode,
                    torch.empty(
                        tensor_meta.shape,
                        dtype=tensor_meta.dtype,
                        device="meta",
                        requires_grad=tensor_meta.requires_grad,
                        memory_format=tensor_meta.memory_format,
                    ),
                    torch.device("cpu"),
                )
            elif len(node.users) == 0:
                return None
            raise ExportPassBaseError(
                f"Cannot construct an input for graph module: {graph_module}.",
            )

        return [
            extract_input(node)
            for node in graph_module.graph.nodes
            if node.op == "placeholder"
        ]

    def on_attr(self, attr: ProxyValue) -> None:
        pass

    def placeholder(self, name: str, arg: Argument, meta: NodeMetadata) -> ProxyValue:
        arg_proxy = self.tracer.create_proxy("placeholder", name, (), {})
        arg_proxy.node.meta = meta.data
        self.tracer.set_metadata(arg_proxy.node, arg)
        return ProxyValue(arg, arg_proxy)

    def call_operator(
        self,
        op,
        args: Tuple[Argument, ...],
        kwargs: Dict[str, Argument],
        meta: NodeMetadata,
    ) -> ProxyValue:
        return self._fx("call_function", op, args, kwargs, meta)

    def call_sym(
        self,
        target: Fn,
        args: Tuple[Argument, ...],
        meta: NodeMetadata,
    ) -> ProxyValue:
        return self._fx("call_function", target, args, {}, meta)

    def call_cond(
        self,
        pred: ProxyValue,
        true_fn: torch.fx.GraphModule,
        false_fn: torch.fx.GraphModule,
        inputs: List[Argument],
        meta: NodeMetadata,
    ) -> ProxyValue:
        true_branch = self.call_submodule(true_fn, tuple(inputs))
        false_branch = self.call_submodule(false_fn, tuple(inputs))
        assert true_branch is not None
        assert false_branch is not None
        return self._fx(
            "call_function",
            torch.ops.higher_order.cond,
            (pred, true_branch.graph_module, false_branch.graph_module, list(inputs)),
            {},
            meta,
        )

    def call_map(
        self,
        f: torch.fx.GraphModule,
        mapped_args: List[ProxyValue],
        operands: List[ProxyValue],
        meta: NodeMetadata,
    ) -> ProxyValue:
        xs = _unstack_pytree([arg.data for arg in mapped_args])[0]
        f_branch = self.call_submodule(f, tuple(xs + [arg.data for arg in operands]))
        assert f_branch is not None
        return self._fx(
            "call_function",
            torch.ops.higher_order.map_impl,
            (f_branch.graph_module, mapped_args, operands),
            {},
            meta,
        )

    def call_getitem(
        self, value: ProxyValue, key: int, meta: NodeMetadata
    ) -> ProxyValue:
        return self._fx("call_function", operator.getitem, (value, key), {}, meta)

    def output(self, results: List[Argument], meta: NodeMetadata) -> ProxyValue:
        return self._fx("output", "output", (results,), {}, meta)

    def call_submodule(
        self, graph_module: fx.GraphModule, inputs: Tuple[Argument, ...]
    ) -> PassResult:
        prev_tracer, self.tracer = self.tracer, self.ExportTracer(
            self, graph_module.graph._codegen
        )
        self.tracer.fake_tensor_mode = prev_tracer.fake_tensor_mode
        interpreter = self.ExportInterpreter(self, graph_module)
        prev_interpreter, self.interpreter = self.interpreter, torch.fx.Interpreter(
            torch.fx.GraphModule(torch.nn.Module(), torch.fx.Graph())
        )
        inputs_data = pytree.tree_map_only(ProxyValue, lambda x: x.data, inputs)
        with fx_traceback.preserve_node_meta():
            interpreter.run(*inputs_data)

        new_graph_module = torch.fx.GraphModule(self.tracer.root, self.tracer.graph)

        self.tracer = prev_tracer
        self.interpreter = prev_interpreter
        return PassResult(
            new_graph_module,
            True,
        )

    def call(self, graph_module: fx.GraphModule) -> PassResult:
        if not getattr(self, "_initialized", False):
            raise ExportPassBaseError(
                "ExportPass is not initialized with __init__().",
            )

        inputs = self.inputs(graph_module)

        fake_tensor_mode = None
        for i in inputs:
            if isinstance(i, FakeTensor):
                assert (
                    fake_tensor_mode is None or fake_tensor_mode is i.fake_mode
                ), "Multiple fake tensor mode detected."
                fake_tensor_mode = i.fake_mode
        if fake_tensor_mode is None:
            self.tracer.fake_tensor_mode = FakeTensorMode(allow_non_fake_inputs=True)
            fake_tensor_mode = nullcontext()  # type: ignore[assignment]
            dispatcher_mode = nullcontext()  # type: ignore[assignment]
        else:
            fake_tensor_mode.allow_non_fake_inputs = True
            self.tracer.fake_tensor_mode = fake_tensor_mode
            dispatcher_mode = enable_python_dispatcher()  # type: ignore[assignment]
        self.fake_tensor_mode = self.tracer.fake_tensor_mode

        with fake_tensor_mode, dispatcher_mode:  # type: ignore[assignment, union-attr]
            result = self.call_submodule(graph_module, tuple(inputs))

        return result