From 27dee935afe02cb79f97c6c284b346cf9da7071b Mon Sep 17 00:00:00 2001 From: Michael Lazos Date: Tue, 8 Oct 2024 14:11:03 -0700 Subject: [PATCH] [Dynamo] Ensure torch function modes are dispatched on builtin ops (#137117) Pull Request resolved: https://github.com/pytorch/pytorch/pull/137117 Approved by: https://github.com/yanboliang, https://github.com/williamwen42 ghstack dependencies: #137114, #137115, #137116 --- .../pr_time_benchmarks/expected_results.csv | 4 +- test/dynamo/test_modes.py | 97 +++++++++++++ test/export/test_experimental.py | 4 +- test/export/test_export.py | 3 + test/functorch/test_control_flow.py | 2 + torch/_dynamo/backends/debugging.py | 14 +- torch/_dynamo/variables/builtin.py | 35 ++++- torch/_dynamo/variables/tensor.py | 24 ++++ torch/_dynamo/variables/torch.py | 4 + torch/_dynamo/variables/torch_function.py | 128 ++++++++++++++++++ torch/_higher_order_ops/strict_mode.py | 22 ++- 11 files changed, 326 insertions(+), 11 deletions(-) diff --git a/benchmarks/dynamo/pr_time_benchmarks/expected_results.csv b/benchmarks/dynamo/pr_time_benchmarks/expected_results.csv index 6e2a281aa22..41649541c7e 100644 --- a/benchmarks/dynamo/pr_time_benchmarks/expected_results.csv +++ b/benchmarks/dynamo/pr_time_benchmarks/expected_results.csv @@ -1,5 +1,5 @@ -add_loop_eager, compile_time_instruction_count, 2834456320, 0.015 -add_loop_eager_dynamic, compile_time_instruction_count, 5528896630, 0.025 +add_loop_eager, compile_time_instruction_count, 3004749893, 0.015 +add_loop_eager_dynamic, compile_time_instruction_count, 5726573328, 0.025 add_loop_inductor, compile_time_instruction_count, 24146845503, 0.015 add_loop_inductor_dynamic_gpu, compile_time_instruction_count, 39411706509, 0.025 add_loop_inductor_gpu, compile_time_instruction_count, 22171041650, 0.015 diff --git a/test/dynamo/test_modes.py b/test/dynamo/test_modes.py index 4d1f2bbea38..64c97916822 100644 --- a/test/dynamo/test_modes.py +++ b/test/dynamo/test_modes.py @@ -1,5 +1,8 @@ # Owner(s): ["module: dynamo"] +import operator +from unittest.mock import patch + import torch import torch._dynamo.test_case import torch._dynamo.testing @@ -484,6 +487,100 @@ class TorchFunctionModeTests(torch._dynamo.test_case.TestCase): self.assertEqual(expected, actual) + # Needs larger cache size since we recompile for each op + @patch.object(torch._dynamo.config, "cache_size_limit", 48) + def test_builtin_equivalent_funcs(self): + from torch._dynamo.variables.torch_function import ( + bin_int_ops, + bin_ops, + BUILTIN_TO_TENSOR_FN_MAP, + BUILTIN_TO_TENSOR_RFN_MAP, + tensor_and_int_ops, + un_int_ops, + un_ops, + ) + + expected_func = None + valid = False + + class FuncEquivMode(BaseTorchFunctionMode): + def __torch_function__(self, func, types, args=(), kwargs=None): + nonlocal expected_func + nonlocal valid + if not kwargs: + kwargs = {} + if torch._dynamo.is_compiling(): + valid = expected_func == func + return super().__torch_function__(func, types, args, kwargs) + + inp0 = torch.ones(1, 1) + inp1 = torch.ones(1, 1) + inp0_int = torch.ones(1, 1, dtype=torch.int32) + inp1_int = torch.ones(1, 1, dtype=torch.int32) + + @torch.compile(fullgraph=True) + def fn_un(op, inp): + return op(inp) + + @torch.compile(fullgraph=True) + def fn_un_int(op, inp): + return op(inp) + + @torch.compile(fullgraph=True) + def fn_bin(op, inp0, inp1): + return op(inp0, inp1) + + @torch.compile(fullgraph=True) + def fn_bin_int(op, inp0, inp1): + return op(inp0, inp1) + + @torch.compile(fullgraph=True) + def fn_tensor_and_int(op, inp0, inp1): + return op(inp0, inp1) + + setups_and_oplists = [ + (lambda o: fn_un(o, inp0), un_ops), + (lambda o: fn_un_int(o, inp0_int), un_int_ops), + (lambda o: fn_bin(o, inp0, inp1), bin_ops), + (lambda o: fn_bin_int(o, inp0_int, inp1_int), bin_int_ops), + (lambda o: fn_tensor_and_int(o, inp0_int, 0), tensor_and_int_ops), + ] + + # gather the reverse functions + rsetups_and_oplists = [ + ( + lambda o: fn_bin(o, 1, inp1), + bin_ops, + ), # Get r* ops, (ex. __sub__(int, Tensor) -> __rsub__(Tensor, int)) + (lambda o: fn_bin_int(o, 1, inp1_int), bin_int_ops), + (lambda o: fn_tensor_and_int(o, 0, inp0_int), tensor_and_int_ops), + ] + + skips = {operator.not_} # Has local scalar dense call which graph breaks + rskips = { + operator.matmul, + operator.imatmul, + operator.getitem, + } # Doesn't type check with reversed args + + def run_checks(setups_and_oplists, skips, ref_map): + nonlocal valid + nonlocal expected_func + for setup_fn, op_list in setups_and_oplists: + for op in op_list: + if op in skips or op not in ref_map: + continue + with FuncEquivMode(): + expected_func = ref_map[op] + setup_fn(op) + self.assertTrue(valid) + + expected_func = None + valid = False + + run_checks(setups_and_oplists, skips, BUILTIN_TO_TENSOR_FN_MAP) + run_checks(rsetups_and_oplists, rskips, BUILTIN_TO_TENSOR_RFN_MAP) + if __name__ == "__main__": from torch._dynamo.test_case import run_tests diff --git a/test/export/test_experimental.py b/test/export/test_experimental.py index d12b1e53ca7..d5ac532a5cc 100644 --- a/test/export/test_experimental.py +++ b/test/export/test_experimental.py @@ -49,9 +49,9 @@ def forward(self, b_submodule_buffer1, x): sin = torch.ops.aten.sin.default(x) strict_graph_0 = self.strict_graph_0 strict_mode = torch.ops.higher_order.strict_mode(strict_graph_0, (sin, b_submodule_buffer1)); strict_graph_0 = sin = b_submodule_buffer1 = None - getitem_2 = strict_mode[0]; strict_mode = None + getitem = strict_mode[0]; strict_mode = None add = torch.ops.aten.add.Tensor(x, 3); x = None - return (getitem_2, add)""", + return (getitem, add)""", ) self.assertExpectedInline( diff --git a/test/export/test_export.py b/test/export/test_export.py index 5c985e27bd7..b13ce23c689 100755 --- a/test/export/test_export.py +++ b/test/export/test_export.py @@ -64,6 +64,7 @@ from torch.testing._internal.common_utils import ( IS_SANDCASTLE, IS_WINDOWS, run_tests, + skipIfCrossRef, TEST_TRANSFORMERS, TestCase as TorchTestCase, ) @@ -6989,6 +6990,7 @@ def forward(self, x): real_names_and_ops = [(node.name, node.op) for node in ep.graph.nodes] self.assertEqual(expected_names_and_ops, real_names_and_ops) + @skipIfCrossRef # Dynamo changes the order of ops under Torch function modes def test_placeholder_naming_collisions_hoo_subgraphs(self): # test collisions between user inputs, top-level nodes, and HOO subgraph nodes class Foo(torch.nn.Module): @@ -8325,6 +8327,7 @@ class TestOneOffModelExportResult(TestCase): # getitem = _scaled_dot_product_flash_attention_for_cpu[0]; _scaled_dot_product_flash_attention_for_cpu = None # return (getitem,)""") + @skipIfCrossRef @unittest.skipIf( not PLATFORM_SUPPORTS_FLASH_ATTENTION, "Can't run fused SDPA on this platform", diff --git a/test/functorch/test_control_flow.py b/test/functorch/test_control_flow.py index 134e1fd6e41..e4714fe768f 100644 --- a/test/functorch/test_control_flow.py +++ b/test/functorch/test_control_flow.py @@ -4902,6 +4902,7 @@ def forward(self, arg0_1, arg1_1): return [getitem]""", # noqa: B950 ) + @skipIfCrossRef # Arg order changes with crossref def test_cond_make_fx_preserve_stack_trace_for_nodes_in_subgraph(self): def true_fn(x): return x + x.cos() @@ -5252,6 +5253,7 @@ def forward(self, arg0_1): ): torch.cond(inp.sum() > 0, f, f, (inp, tmp)) + @skipIfCrossRef # Arg order changes with crossref def test_cond_trace_set__and_mutate_intermediate(self): def f(a, tmp): a = a.clone() diff --git a/torch/_dynamo/backends/debugging.py b/torch/_dynamo/backends/debugging.py index 18784498862..2121db54c26 100644 --- a/torch/_dynamo/backends/debugging.py +++ b/torch/_dynamo/backends/debugging.py @@ -32,13 +32,23 @@ def eager(gm, fake_tensor_inputs, **kwargs): def make_eager_backend_with_torch_function_mode(mode): + return make_eager_backend_with_torch_function_modes([mode]) + + +def make_eager_backend_with_torch_function_modes(modes): """Used to trace HOPs (cond and while) for eager exectution, the metadata TF mode mutates vars outside of the scope of the HOP, and we can't have graph breaks in the HOP, so we need to externally run this mode and not trace it.""" + from contextlib import ExitStack def fn(gm, fake_tensor_inputs, **kwargs): - with mode: - return gm.forward + stack = ExitStack() + for mode in modes: + stack.enter_context(mode) + + result = gm.forward + stack.close() + return result return fn diff --git a/torch/_dynamo/variables/builtin.py b/torch/_dynamo/variables/builtin.py index 296eb646187..66b5be01221 100644 --- a/torch/_dynamo/variables/builtin.py +++ b/torch/_dynamo/variables/builtin.py @@ -200,7 +200,6 @@ class BuiltinVariable(VariableTracker): operator.ne, operator.eq, operator.sub, - operator.getitem, operator.length_hint, operator.lshift, operator.rshift, @@ -212,6 +211,7 @@ class BuiltinVariable(VariableTracker): operator.imatmul, operator.ifloordiv, operator.itruediv, + operator.getitem, operator.imod, operator.iadd, operator.isub, @@ -858,6 +858,39 @@ class BuiltinVariable(VariableTracker): if kwargs and not self.tensor_args(*args, *kwargs.values()): return + # insert handling for torch function here + from .builder import SourcelessBuilder + from .torch_function import ( + BUILTIN_TO_TENSOR_FN_MAP, + BUILTIN_TO_TENSOR_RFN_MAP, + can_dispatch_torch_function, + dispatch_torch_function, + ) + + if can_dispatch_torch_function(tx, args, kwargs): + # Only remap the fn to tensor methods if we aren't exporting + # export serde does not handle method descriptors today + if not tx.export: + # Use sourceless builder, we built the map ourselves + if not isinstance(args[0], TensorVariable): + if self.fn in BUILTIN_TO_TENSOR_RFN_MAP: + func = BUILTIN_TO_TENSOR_RFN_MAP[self.fn] + else: + func = BUILTIN_TO_TENSOR_FN_MAP[self.fn] + + tmp = args[0] + # swap args and call reverse version of func + args[0] = args[1] + args[1] = tmp + else: + func = BUILTIN_TO_TENSOR_FN_MAP[self.fn] + else: + func = self.fn + + fn_var = SourcelessBuilder.create(tx, func) + + return dispatch_torch_function(tx, fn_var, args, kwargs) + fn = self.fn try: # Constant fold for constant tensor and python constants diff --git a/torch/_dynamo/variables/tensor.py b/torch/_dynamo/variables/tensor.py index 7cdf286dc4e..514a712c891 100644 --- a/torch/_dynamo/variables/tensor.py +++ b/torch/_dynamo/variables/tensor.py @@ -772,6 +772,30 @@ class TensorVariable(VariableTracker): self._warn_capture_scalar_outputs() unimplemented("Tensor.item") + def method_getitem(self, *args, **kwargs): + from ..symbolic_convert import InstructionTranslator + from .builder import wrap_fx_proxy + + tx = InstructionTranslator.current_tx() + if isinstance(args[0], SymNodeVariable): + # Standard indexing will force specialization due to + # __index__. Rewrite as a regular torch op which will + # trace fine + fn, args = torch.select, [ + variables.ConstantVariable.create(0), + args[0], + ] + else: + fn = operator.getitem + + proxy = tx.output.create_proxy( + "call_function", + fn, + *proxy_args_kwargs([self] + list(args), kwargs), + ) + + return wrap_fx_proxy(tx, proxy) + @staticmethod @functools.lru_cache(None) def _warn_capture_scalar_outputs(): diff --git a/torch/_dynamo/variables/torch.py b/torch/_dynamo/variables/torch.py index 91dd06c30e7..77d8d2fcf8c 100644 --- a/torch/_dynamo/variables/torch.py +++ b/torch/_dynamo/variables/torch.py @@ -871,6 +871,10 @@ class TorchInGraphFunctionVariable(BaseTorchVariable): return ConstantVariable.create(None) + @register(torch._C.TensorBase.__getitem__) + def handle_getitem(self, tx: "InstructionTranslator", *args, **kwargs): + return args[0].call_method(tx, "getitem", args[1:], kwargs) + return handlers def call_function( diff --git a/torch/_dynamo/variables/torch_function.py b/torch/_dynamo/variables/torch_function.py index ffb3d27d4d7..3662f34804a 100644 --- a/torch/_dynamo/variables/torch_function.py +++ b/torch/_dynamo/variables/torch_function.py @@ -4,6 +4,7 @@ import collections import contextlib import functools import inspect +import operator from typing import Deque, Dict, List, TYPE_CHECKING import torch._C @@ -11,6 +12,7 @@ import torch.utils._pytree as pytree from torch._guards import Source from torch.overrides import ( _get_overloaded_args, + BaseTorchFunctionMode, get_default_nowrap_functions, TorchFunctionMode, ) @@ -62,6 +64,125 @@ if TYPE_CHECKING: # To enable subclass behavior, add your tensor subclass type to traceable_tensor_subclasses in dynamo/config.py +bin_ops = [ + operator.pow, + operator.mul, + operator.matmul, + operator.floordiv, + operator.truediv, + operator.mod, + operator.add, + operator.lt, + operator.gt, + operator.ge, + operator.le, + operator.ne, + operator.eq, + operator.sub, + operator.ipow, + operator.imul, + operator.imatmul, + operator.ifloordiv, + operator.itruediv, + operator.imod, + operator.iadd, + operator.isub, +] + +bin_int_ops = [ + operator.and_, + operator.or_, + operator.xor, + operator.iand, + operator.ixor, + operator.ior, +] + +un_int_ops = [operator.invert] + +tensor_and_int_ops = [ + operator.lshift, + operator.rshift, + operator.ilshift, + operator.irshift, + operator.getitem, +] + +un_ops = [ + operator.abs, + operator.pos, + operator.neg, + operator.not_, # Note: this has a local scalar dense call + operator.length_hint, +] + +BUILTIN_TO_TENSOR_FN_MAP = {} + +# These functions represent the r* versions of the above ops +# Basically, if __add__(1, Tensor) is called, it is translated +# to __radd__(Tensor, 1). +# In the builtin var, we check if there is a tensor in the first args position, +# if not, we swap the args and use the r* version of the op. +BUILTIN_TO_TENSOR_RFN_MAP = {} + + +def populate_builtin_to_tensor_fn_map(): + global BUILTIN_TO_TENSOR_FN_MAP + + most_recent_func = None + + class GetMethodMode(BaseTorchFunctionMode): + """ + Mode to extract the correct methods from torch function invocations + (Used to get the correct torch.Tensor methods from builtins) + """ + + def __torch_function__(self, func, types, args=(), kwargs=None): + kwargs = kwargs or {} + nonlocal most_recent_func + most_recent_func = func + return func(*args, **kwargs) + + inp0 = torch.ones(1) + inp1 = torch.ones(1) + inp0_int = torch.ones(1, dtype=torch.int32) + inp1_int = torch.ones(1, dtype=torch.int32) + with GetMethodMode(): + setups_and_oplists = [ + (lambda o: o(inp0), un_ops), + (lambda o: o(inp0_int), un_int_ops), + (lambda o: o(inp0, inp1), bin_ops), + (lambda o: o(inp0_int, inp1_int), bin_int_ops), + (lambda o: o(inp0_int, 0), tensor_and_int_ops), + ] + for setup_fn, op_list in setups_and_oplists: + for op in op_list: + setup_fn(op) + assert most_recent_func is not None + BUILTIN_TO_TENSOR_FN_MAP[op] = most_recent_func + + # gather the reverse functions + rsetups_and_oplists = [ + ( + lambda o: o(1, inp1), + bin_ops, + ), # Get r* ops, (ex. __sub__(int, Tensor) -> __rsub__(Tensor, int)) + (lambda o: o(1, inp1_int), bin_int_ops), + (lambda o: o(0, inp0_int), tensor_and_int_ops), + ] + + rskips = {operator.matmul, operator.imatmul, operator.getitem} + for setup_fn, op_list in rsetups_and_oplists: + for op in op_list: + if op in rskips: + continue + setup_fn(op) + assert most_recent_func is not None + if most_recent_func != BUILTIN_TO_TENSOR_FN_MAP[op]: + BUILTIN_TO_TENSOR_RFN_MAP[op] = most_recent_func + + +populate_builtin_to_tensor_fn_map() banned_attrs = [ fn.__self__.__name__ @@ -389,8 +510,15 @@ def call_torch_function( def build_torch_function_fn(tx: "InstructionTranslator", value, source): + from types import FunctionType + from .builder import SourcelessBuilder, VariableBuilder + func = value.__torch_function__.__func__ + + if not isinstance(func, FunctionType): + unimplemented("Builtin/C++ torch function implementations NYI") + if source: return VariableBuilder( tx, diff --git a/torch/_higher_order_ops/strict_mode.py b/torch/_higher_order_ops/strict_mode.py index 7324e20dcd4..e8543412c53 100644 --- a/torch/_higher_order_ops/strict_mode.py +++ b/torch/_higher_order_ops/strict_mode.py @@ -8,6 +8,8 @@ from torch._higher_order_ops.utils import _set_compilation_env, autograd_not_imp from torch._ops import HigherOrderOperator from torch._subclasses.fake_tensor import FakeTensorMode from torch.fx.experimental.proxy_tensor import ( + _temp_remove_metadata_torch_function_mode, + _temp_remove_pre_dispatch_torch_function_mode, disable_proxy_modes_tracing, make_fx, ProxyTorchDispatchMode, @@ -18,14 +20,26 @@ from torch.utils._python_dispatch import _get_current_dispatch_mode @exposed_in("torch") def strict_mode(callable, operands): + from torch._dynamo.backends.debugging import ( + make_eager_backend_with_torch_function_modes, + ) + if torch.compiler.is_dynamo_compiling(): return strict_mode_op(callable, operands) with _set_compilation_env(): - with torch._dynamo.utils.disable_cache_limit(): - return torch.compile(strict_mode_op, backend="eager", fullgraph=True)( - callable, operands - ) + with _temp_remove_metadata_torch_function_mode() as metadata_mode: + with _temp_remove_pre_dispatch_torch_function_mode() as predispatch_mode: + modes = [metadata_mode, predispatch_mode] + modes = [mode for mode in modes if mode is not None] + if modes: + backend = make_eager_backend_with_torch_function_modes(modes) + else: + backend = "eager" + with torch._dynamo.utils.disable_cache_limit(): + return torch.compile( + strict_mode_op, backend=backend, fullgraph=True + )(callable, operands) class StrictMode(HigherOrderOperator):