diff --git a/docs/source/fx.experimental.md b/docs/source/fx.experimental.md
index 24125cd310b..cba695b5e1c 100644
--- a/docs/source/fx.experimental.md
+++ b/docs/source/fx.experimental.md
@@ -8,6 +8,10 @@ These APIs are experimental and subject to change without notice.
 :::
 
 
+```{eval-rst}
+.. autoclass:: torch.fx.experimental.sym_node.DynamicInt
+```
+
 ## torch.fx.experimental.symbolic_shapes
 
 ```{eval-rst}
diff --git a/test/test_dynamic_shapes.py b/test/test_dynamic_shapes.py
index 0e90587822d..3fee860a798 100644
--- a/test/test_dynamic_shapes.py
+++ b/test/test_dynamic_shapes.py
@@ -1818,6 +1818,96 @@ class TestSymNumberMagicMethods(TestCase):
         self.assertTrue(isinstance(s3, int))
         self.assertTrue(str(s1.node.expr) != str(s2.node.expr))
 
+    @fresh_cache()
+    @torch._dynamo.config.patch("capture_scalar_outputs", True)
+    @parametrize("backend", ["inductor", "eager"])
+    def test_dynamic_int_basic_compile(self, backend):
+        from torch.fx.experimental.sym_node import DynamicInt
+
+        cnt = CompileCounterWithBackend(backend)
+
+        # test scalar inputs to function
+        def f(x, y, z):
+            out = torch.tensor([x + y + z])
+            out = out + torch.zeros(abs(x) + 2).sum()  # test out tensor construction
+            return out
+
+        fn = torch.compile(f, fullgraph=True, backend=cnt)
+        x = DynamicInt(1)
+        z = DynamicInt(3)
+        self.assertEqual(fn(x, x, z), f(1, 1, 3))  # guard: x == y
+        self.assertEqual(fn(2, 2, 0), f(2, 2, 0))
+        self.assertEqual(fn(-1, -1, 2), f(-1, -1, 2))
+        self.assertEqual(cnt.frame_count, 1)  # no recompiles
+
+        self.assertEqual(fn(3, 4, 5), f(3, 4, 5))  # now we recompile
+        self.assertEqual(cnt.frame_count, 2)
+
+        # test nn module property
+        class Foo(torch.nn.Module):
+            def __init__(self):
+                super().__init__()
+                self.i = DynamicInt(1)
+
+            def forward(self, x):
+                return torch.tensor([x + self.i])
+
+        cnt.clear()
+        m = Foo()
+        mc = torch.compile(m, backend=cnt, fullgraph=True)
+
+        self.assertEqual(mc(DynamicInt(0)), m(0))
+        mc.i = -2  # override attribute
+        self.assertEqual(mc(-1), m(-1))
+        self.assertEqual(cnt.frame_count, 1)
+
+    def test_dynamic_int_eager_usage(self):
+        from torch.fx.experimental.sym_node import DynamicInt
+
+        w = DynamicInt(-1)
+        x = DynamicInt(0)
+        y = DynamicInt(1)
+        z = DynamicInt(2)
+
+        def check(l, r):
+            self.assertTrue(isinstance(l, DynamicInt))
+            self.assertEqual(l, r)
+
+        # test arithmetic
+        check(2 * y + z, 4)
+        check((10 - z) // 2, 4)
+        check(1 // z, 0)
+        check(-w + w**2, 2)
+        check(x % z, 0)
+        check(1 << z, 4)
+        check(z | y, 3)
+        check(min(y, z), 1)
+        self.assertTrue(z > -2)
+        with self.assertRaises(ZeroDivisionError):
+            y % x
+
+        # math, numpy
+        self.assertEqual(math.cos(x), y)
+        self.assertEqual(math.prod([z, z], start=z), 8)
+        self.assertEqual(np.arange(z)[y], 1)
+        self.assertTrue(np.allclose(np.ones([y, z]).sum(axis=x), np.ones(z)))
+
+        # test conversions
+        self.assertTrue(isinstance(x + 2, int))
+        self.assertTrue(isinstance(x + 2, DynamicInt))
+        self.assertEqual(y / 2.0, 0.5)  # this could return DynamicFloat in future
+        self.assertEqual(float(z), 2.0)
+        self.assertFalse(bool(x))
+        self.assertEqual(DynamicInt(x).real, x.real)
+
+        # torch functions, scalar inputs
+        self.assertEqual(torch.arange(z)[:w][x], 0)
+        self.assertEqual(torch.add(torch.tensor(w), torch.tensor(w), alpha=z), -3)
+        self.assertEqual(
+            list(torch.nn.Linear(z, y)(torch.randn(z * 2, z)).shape), [4, 1]
+        )
+        self.assertEqual(z * torch.ones(z).sum(dim=x), 4)
+
 
 instantiate_parametrized_tests(TestSymNumberMagicMethods)
 
diff --git a/torch/_dynamo/guards.py b/torch/_dynamo/guards.py
index bf5ba4be497..1e23e4e4e4e 100644
--- a/torch/_dynamo/guards.py
+++ b/torch/_dynamo/guards.py
@@ -136,6 +136,7 @@ from .source import (
     DefaultsSource,
     DictGetItemSource,
     DictSubclassGetItemSource,
+    DynamicScalarSource,
     FlattenScriptObjectSource,
     FloatTensorSource,
     FSDPNNModuleSource,
@@ -1719,6 +1720,14 @@ class GuardBuilder(GuardBuilderBase):
                 example_value=example_value,
                 guard_manager_enum=guard_manager_enum,
             )
+        elif istype(source, DynamicScalarSource):
+            assert base_guard_manager
+            out = base_guard_manager.lambda_manager(
+                python_lambda=lambda x: int(x),
+                source=source_name,
+                example_value=example_value,
+                guard_manager_enum=guard_manager_enum,
+            )
         else:
             raise AssertionError(
                 f"missing guard manager builder {source} - {source.name()}"
diff --git a/torch/_dynamo/output_graph.py b/torch/_dynamo/output_graph.py
index 8a162942fe7..8c0c7633fa6 100644
--- a/torch/_dynamo/output_graph.py
+++ b/torch/_dynamo/output_graph.py
@@ -2698,6 +2698,9 @@ class SubgraphTracer(fx.Tracer):
         # tracer is the current tracer that's readily accessible in current tracer's graph.
         self.bound_symbols: dict[sympy.Symbol, Union[torch.fx.Proxy, LazyProxy]] = {}
 
+        # Maps _DynamicScalar object ids to allocated SymInt nodes, for symbol reuse
+        self.dynamic_scalar_nodes: dict[int, torch.SymInt] = {}
+
         self.prev_inst = None
         # True if this tracer is currently tracing into torch.utils.checkpoint
         # as part of speculate_subgraph.
diff --git a/torch/_dynamo/source.py b/torch/_dynamo/source.py
index b17ccfe09da..559972464f8 100644
--- a/torch/_dynamo/source.py
+++ b/torch/_dynamo/source.py
@@ -526,6 +526,29 @@ class ConvertIntSource(ChainedSource):
         return f"cast_symbool_to_symint_guardless({self.base.name()})"
 
 
+@dataclasses.dataclass(frozen=True)
+class DynamicScalarSource(ChainedSource):
+    is_int: bool
+
+    def __post_init__(self) -> None:
+        assert self.base is not None
+
+    def reconstruct(self, codegen: "PyCodegen") -> None:
+        # Casting to int at reconstruction reduces the number of DynamicInts
+        # returned to the user, in favor of plain ints. For example, without
+        # the cast here, a compiled region that only does int arithmetic
+        # could return a DynamicInt instead of a plain int.
+        codegen.add_push_null(lambda: codegen.load_import_from("builtins", "int"))
+        codegen(self.base)
+        codegen.extend_output(create_call_function(1, False))
+
+    def guard_source(self) -> GuardSource:
+        return self.base.guard_source()
+
+    def name(self) -> str:
+        return f"int({self.base.name()})"
+
+
 @dataclasses.dataclass(frozen=True)
 class FlattenScriptObjectSource(ChainedSource):
     def __post_init__(self) -> None:
diff --git a/torch/_dynamo/variables/builder.py b/torch/_dynamo/variables/builder.py
index 660042b33b8..547e826f343 100644
--- a/torch/_dynamo/variables/builder.py
+++ b/torch/_dynamo/variables/builder.py
@@ -60,6 +60,7 @@ from torch._subclasses.meta_utils import is_sparse_any, safe_grad
 from torch._utils_internal import justknobs_check
 from torch.fx.experimental._backward_state import BackwardState
 from torch.fx.experimental._dynamism import normalize_source_name
+from torch.fx.experimental.sym_node import _DynamicScalar, DynamicInt
 from torch.fx.experimental.symbolic_shapes import (
     _constrain_range_for_size,
     _nested_int_aware_sort,
@@ -101,6 +102,7 @@ from ..source import (
     ConvertIntSource,
     DictGetItemSource,
     DictSubclassGetItemSource,
+    DynamicScalarSource,
     FloatTensorSource,
     GetItemSource,
     GradSource,
@@ -456,7 +458,9 @@ class VariableBuilder:
             # should NOT track them. If we use a single SymNodeVariable instance to track them
             # across multiple uses, then guards created for one usage will incorrectly apply to
             # all other usages of that constant, leading to unnecessary recompilations.
-            return is_torch_sym(value) and isinstance(vt, SymNodeVariable)
+            return (
+                is_torch_sym(value) or isinstance(value, _DynamicScalar)
+            ) and isinstance(vt, SymNodeVariable)
         if (
             (
                 (
@@ -1103,6 +1107,46 @@ class VariableBuilder:
         ):
             self.install_guards(GuardBuilder.FUNCTION_MATCH)
             return ItertoolsVariable(value, source=self.source)
+        elif isinstance(value, _DynamicScalar):
+            is_int = isinstance(value, DynamicInt)
+            source = DynamicScalarSource(self.source, is_int)
+            if id(value) in self.tx.output.root_tracer.dynamic_scalar_nodes:
+                # If we've already seen this dynamic scalar, reuse the existing
+                # SymInt/SymFloat node.
+                node = self.tx.output.root_tracer.dynamic_scalar_nodes[id(value)]
+            else:
+                sym = self.tx.output.shape_env.create_unspecified_symbol(
+                    value.real,
+                    source=source,
+                    dynamic_dim=DimDynamic.DYNAMIC,
+                )
+                node = self.tx.output.shape_env.create_symintnode(
+                    sym,
+                    hint=value.real,
+                    source=source,
+                )
+
+            # Bind to graph input
+            sym_node_proxy = self.tx.output.root_tracer.create_graph_input(
+                re.sub(r"[^a-zA-Z0-9]+", "_", self.name),
+                type(node),
+                node,
+                source=source,
+            )
+            sym_node_proxy.node.meta["grapharg"] = GraphArg(
+                source,
+                node,
+                False,
+                None,
+                is_tensor=False,
+                example_strong_ref=node,
+            )
+            sym_expr = node.node.expr
+            assert isinstance(sym_expr, sympy.Symbol), (
+                f"{sym_expr} is not a basic Symbol."
+            )
+            self.tx.output.tracked_fakes.append(TrackedFake(node, source, None))
+            return SymNodeVariable(sym_node_proxy, node)
         elif is_torch_sym(value):
             # Note: this doesn't handle nested symints.
             # For SymBool input, we reuse the infra for SymInt by simulating SymBool with a SymInt in dynamo.
diff --git a/torch/csrc/utils/python_arg_parser.cpp b/torch/csrc/utils/python_arg_parser.cpp
index 53cee2632b1..a51cfaf8c5c 100644
--- a/torch/csrc/utils/python_arg_parser.cpp
+++ b/torch/csrc/utils/python_arg_parser.cpp
@@ -936,6 +936,9 @@ static bool is_int_or_symint(PyObject* obj) {
   if (torch::is_symint(py::handle(obj))) {
     return true;
   }
+  if (torch::is_dynint(py::handle(obj))) {
+    return true;
+  }
 
   // FakeTensor(..., size=()) is qualified for SymInt param,
   // but we can't go via __index__ (below) as we would normally
@@ -1070,7 +1073,8 @@ auto FunctionParameter::_check(
         return !var.requires_grad() && var.dim() == 0;
       }
       if (torch::is_symfloat(py::handle(obj)) ||
-          torch::is_symint(py::handle(obj))) {
+          torch::is_symint(py::handle(obj)) ||
+          torch::is_dynint(py::handle(obj))) {
         // This will induce a guard
         return true;
       }
@@ -1085,7 +1089,8 @@ auto FunctionParameter::_check(
         return at::isIntegralType(var.scalar_type(), /*includeBool=*/false) &&
             !var.requires_grad() && var.dim() == 0;
       }
-      if (torch::is_symint(py::handle(obj))) {
+      if (torch::is_symint(py::handle(obj)) ||
+          torch::is_dynint(py::handle(obj))) {
         // This will induce a guard
         return true;
       }
@@ -1127,7 +1132,8 @@ auto FunctionParameter::_check(
       // Allow symint to be passed in as device, but we'll specialize and
       // guard in this case.
       return THPUtils_checkLong(obj) || THPUtils_checkString(obj) ||
-          THPDevice_Check(obj) || torch::is_symint(py::handle(obj));
+          THPDevice_Check(obj) || torch::is_symint(py::handle(obj)) ||
+          torch::is_dynint(py::handle(obj));
     case ParameterType::STREAM:
       return THPStream_Check(obj);
     case ParameterType::STRING:
@@ -1881,7 +1887,8 @@ at::Tensor PythonArgs::tensor_slow(int i) {
     // NB: we DO NOT put symbolic ints/floats into the Scalar itself,
     // because although Scalar supports SymInt/SymFloat, the subsequent
    // conversion to Tensor does not. Instead, do it out of band.
-  } else if (torch::is_symint(py::handle(obj))) {
+  } else if (
+      torch::is_symint(py::handle(obj)) || torch::is_dynint(py::handle(obj))) {
     save_symint = true;
     // This scalar value doesn't matter, it shouldn't ever actually
     // get read out. Make it a big and weird looking number to help
@@ -1969,6 +1976,10 @@ at::Scalar PythonArgs::scalar_slow(PyObject* arg) {
     return at::Scalar(py::cast<c10::SymInt>(arg));
   }
 
+  if (torch::is_dynint(arg)) {
+    return at::Scalar(py::cast<int64_t>(arg));
+  }
+
   if (torch::is_symfloat(arg)) {
     return at::Scalar(py::cast<c10::SymFloat>(arg));
   }
diff --git a/torch/csrc/utils/python_arg_parser.h b/torch/csrc/utils/python_arg_parser.h
index a81f861ae90..5887235f72e 100644
--- a/torch/csrc/utils/python_arg_parser.h
+++ b/torch/csrc/utils/python_arg_parser.h
@@ -89,7 +89,7 @@ inline bool THPUtils_checkScalar(PyObject* obj) {
   }
 #endif
   return PyFloat_Check(obj) || PyLong_Check(obj) || PyComplex_Check(obj) ||
-      torch::is_symint(py::handle(obj)) ||
+      torch::is_symint(py::handle(obj)) || torch::is_dynint(py::handle(obj)) ||
       torch::is_symfloat(py::handle(obj)) || torch::is_symbool(py::handle(obj));
 }
 
@@ -612,6 +612,8 @@ inline std::vector<c10::SymInt> PythonArgs::symintlist(int i) {
     try {
       if (is_symint(py::handle(obj))) {
         res.push_back(py::handle(obj).cast<c10::SymInt>());
+      } else if (is_dynint(py::handle(obj))) {
+        res.push_back(py::handle(obj).cast<c10::SymInt>());
       } else {
         res.emplace_back(THPUtils_unpackIndex(obj));
       }
@@ -640,6 +642,9 @@ inline std::vector<int64_t> PythonArgs::intlistWithDefault(
         size1,
         py::handle(arg).cast<c10::SymInt>().guard_int(__FILE__, __LINE__));
   }
+  if (size1 > 0 && torch::is_dynint(py::handle(arg))) {
+    return std::vector<int64_t>(size1, py::handle(arg).cast<int64_t>());
+  }
   auto tuple = PyTuple_Check(arg);
   // NOLINTNEXTLINE(bugprone-branch-clone)
   const auto size2 = tuple ? PyTuple_GET_SIZE(arg) : PyList_GET_SIZE(arg);
@@ -672,6 +677,8 @@ inline std::vector<int64_t> PythonArgs::intlistWithDefault(
     } else if (torch::is_symint(py::handle(obj))) {
       res[idx] = py::cast<c10::SymInt>(py::handle(obj))
                      .guard_int(__FILE__, __LINE__);
+    } else if (torch::is_dynint(py::handle(obj))) {
+      res[idx] = py::handle(obj).cast<int64_t>();
     } else if (THPVariable_Check(obj)) {
       auto& var = THPVariable_Unpack(obj);
       if (var.numel() != 1 ||
@@ -846,6 +853,10 @@ inline at::Device toDevice(PyObject* obj) {
         py::cast<c10::SymInt>(py::handle(obj)).guard_int(__FILE__, __LINE__);
     return deviceFromLong(device_index);
   }
+  if (torch::is_dynint(py::handle(obj))) {
+    auto device_index = py::cast<int64_t>(py::handle(obj));
+    return deviceFromLong(device_index);
+  }
   const std::string& device_str = THPUtils_unpackString(obj);
   return at::Device(device_str);
 }
@@ -982,6 +993,9 @@ inline int64_t PythonArgs::toInt64(int i) {
     return py::cast<c10::SymInt>(py::handle(args[i]))
         .guard_int(__FILE__, __LINE__);
   }
+  if (torch::is_dynint(py::handle(args[i]))) {
+    return py::cast<int64_t>(py::handle(args[i]));
+  }
   return THPUtils_unpackLong(args[i]);
 }
 
@@ -1055,6 +1069,9 @@ inline double PythonArgs::toDouble(int i) {
     return static_cast<double>(py::cast<c10::SymInt>(py::handle(args[i]))
                                    .guard_int(__FILE__, __LINE__));
   }
+  if (torch::is_dynint(py::handle(args[i]))) {
+    return static_cast<double>(py::cast<int64_t>(py::handle(args[i])));
+  }
   return THPUtils_unpackDouble(args[i]);
 }
 
diff --git a/torch/csrc/utils/python_symnode.cpp b/torch/csrc/utils/python_symnode.cpp
index 2c12e730abb..9e17f8166a4 100644
--- a/torch/csrc/utils/python_symnode.cpp
+++ b/torch/csrc/utils/python_symnode.cpp
@@ -53,4 +53,24 @@ py::handle get_symbool_class() {
 #endif
 }
 
+py::handle get_dynint_class() {
+  // NB: leak
+#if IS_PYBIND_2_13_PLUS
+  PYBIND11_CONSTINIT static py::gil_safe_call_once_and_store<py::object>
+      storage;
+  return storage
+      .call_once_and_store_result([]() -> py::object {
+        return py::module::import("torch.fx.experimental.sym_node")
+            .attr("DynamicInt");
+      })
+      .get_stored();
+#else
+  static py::handle dynint_class =
+      py::object(py::module::import("torch.fx.experimental.sym_node")
+                     .attr("DynamicInt"))
+          .release();
+  return dynint_class;
+#endif
+}
+
 } // namespace torch
diff --git a/torch/csrc/utils/python_symnode.h b/torch/csrc/utils/python_symnode.h
index 69d03b9b7a4..4b023744677 100644
--- a/torch/csrc/utils/python_symnode.h
+++ b/torch/csrc/utils/python_symnode.h
@@ -12,6 +12,7 @@ namespace torch {
 TORCH_PYTHON_API py::handle get_symint_class();
 TORCH_PYTHON_API py::handle get_symfloat_class();
 TORCH_PYTHON_API py::handle get_symbool_class();
+TORCH_PYTHON_API py::handle get_dynint_class();
 
 // NB: These functions must not be called too early, otherwise torch not setup.
 // Alternate design is to have torch "register" the object to us
@@ -24,6 +25,9 @@ inline bool is_symfloat(py::handle obj) {
 inline bool is_symbool(py::handle obj) {
   return py::isinstance(obj, get_symbool_class());
 }
+inline bool is_dynint(py::handle obj) {
+  return py::isinstance(obj, get_dynint_class());
+}
 
 namespace impl {
 
diff --git a/torch/fx/experimental/sym_node.py b/torch/fx/experimental/sym_node.py
index 5468191163a..b6c19b9ddeb 100644
--- a/torch/fx/experimental/sym_node.py
+++ b/torch/fx/experimental/sym_node.py
@@ -49,7 +49,7 @@
 log = logging.getLogger(__name__)
 sym_node_log = torch._logging.getArtifactLogger(__name__, "sym_node")
 
-__all__ = ["SymNode", "method_to_operator", "magic_methods"]
+__all__ = ["SymNode", "method_to_operator", "magic_methods", "DynamicInt"]
 
 from torch.types import py_sym_types as SymTypes
 
@@ -625,6 +625,40 @@ class SymNode:
         return False
 
 
+class _DynamicScalar:
+    def __new__(cls, *args):
+        if cls is _DynamicScalar:
+            raise TypeError("_DynamicScalar is an abstract base class, use DynamicInt.")
+        return super().__new__(cls, *args)
+
+
+class DynamicInt(_DynamicScalar, int):
+    """
+    User API for marking dynamic integers in `torch.compile`.
+    Intended to be usable in both compiled and eager mode.
+
+    Example usage::
+
+        fn = torch.compile(f)
+        x = DynamicInt(4)
+        fn(x)  # compiles x as a dynamic integer input; returns f(4)
+    """
+
+    def __new__(cls, val):
+        assert isinstance(val, int)
+        obj = super().__new__(cls, int(val))
+        return obj
+
+    def __repr__(self):
+        return f"DynamicInt({self.real})"
+
+    def __floordiv__(self, other):  # int.__floordiv__ returns a plain int; keep DynamicInt
+        return DynamicInt(self.real // other)
+
+    def __rfloordiv__(self, other):
+        return DynamicInt(other // self.real)
+
+
 # TODO: this probably needs the sizes-strides eval functions
 METHOD_TO_OPERATOR = {
     "pos": operator.pos,
@@ -1650,7 +1684,6 @@ for method, func in sizes_strides_methods.items():
 def _make_user_magic(method, user_type):
     # User magic takes care of wrapping the other operand into a node,
     # so that our internal logic can assume everything is nodes
-
     if method in magic_methods_on_operator_with_trailing_underscore:
         method_attr = f"sym_{method}"
     else:
@@ -1781,7 +1814,7 @@ def _make_user_magic(method, user_type):
         other = promote(other)
         self, other = promote2(self, other)
         if is_constant(self):
-            return (method_to_operator(method))(get_constant(self), other)
+            return (method_to_operator(method))(other, get_constant(self))
         if is_constant(other):
             other = get_constant(other)
         other_node = to_node(self.node, other)
@@ -1790,11 +1823,31 @@ def _make_user_magic(method, user_type):
         ret = wrap_node(getattr(other_node, method_attr)(self.node))
         return get_constant(ret) if is_constant(ret) else ret
 
+    def setattrs(user_type, attr, symnode_impl):
+        """
+        Registers the SymNode magic method on SymInt/Float/Bool,
+        and optionally registers a corresponding wrapped method on DynamicInt.
+ """ + + # SymInt/Float/Bool + setattr(user_type, attr, symnode_impl) + + # DynamicInt impl + def dynamic_int_impl(*args): + args = [x.real if isinstance(x, DynamicInt) else x for x in args] + out = getattr(int, attr)(*args) + if isinstance(out, int) and not isinstance(out, bool): + return DynamicInt(out) + return out + + if user_type is SymInt: + setattr(DynamicInt, attr, dynamic_int_impl) + if method in unary_magic_methods: - setattr(user_type, f"__{method}__", unary_magic_impl) + setattrs(user_type, f"__{method}__", unary_magic_impl) elif method in unary_nonmagic_methods: orig = getattr(user_type, method) - setattr(user_type, method, update_wrapper(unary_magic_impl, orig)) + setattrs(user_type, method, update_wrapper(unary_magic_impl, orig)) elif method == "sym_ite": def sym_ite_magic_impl(pred, then_val, else_val): @@ -1811,7 +1864,7 @@ def _make_user_magic(method, user_type): ret = wrap_node(getattr(pred.node, method_attr)(then_node, else_node)) return get_constant(ret) if ret.node.is_constant() else ret - setattr(user_type, f"__{method}__", sym_ite_magic_impl) + setattrs(user_type, f"__{method}__", sym_ite_magic_impl) elif method == "round": def round_magic_impl(self, ndigits=None): @@ -1820,14 +1873,14 @@ def _make_user_magic(method, user_type): return wrap_node(getattr(self.node, method)(ndigits)) - setattr(user_type, f"__{method}__", round_magic_impl) + setattrs(user_type, f"__{method}__", round_magic_impl) else: method_name = method if method in bitwise_ops: method_name = bitwise_ops[method] - setattr(user_type, f"__{method_name}__", binary_magic_impl) + setattrs(user_type, f"__{method_name}__", binary_magic_impl) if method in reflectable_magic_methods: - setattr(user_type, f"__r{method_name}__", rbinary_magic_impl) + setattrs(user_type, f"__r{method_name}__", rbinary_magic_impl) for method, func in magic_methods.items(): # type: ignore[assignment]