[Dynamo] Ensure torch function modes are dispatched on builtin ops (#137117)

Pull Request resolved: https://github.com/pytorch/pytorch/pull/137117
Approved by: https://github.com/yanboliang, https://github.com/williamwen42
ghstack dependencies: #137114, #137115, #137116
Michael Lazos 2024-10-08 14:11:03 -07:00 committed by PyTorch MergeBot
parent 38afac2917
commit 27dee935af
11 changed files with 326 additions and 11 deletions
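For context, a minimal sketch of the behavior this commit enables, modeled on the new test added in this diff: a user-defined torch function mode now fires for builtin operators (here `operator.add` via `a + b`) inside a compiled region. `CaptureMode` and `demo` are illustrative names, not part of the PR.

```python
# Illustrative sketch (not part of the diff): with this change, a torch
# function mode active around a compiled region is also dispatched when a
# builtin operator such as `a + b` runs inside that region.
import torch
from torch.overrides import BaseTorchFunctionMode


def demo():
    hit = False

    class CaptureMode(BaseTorchFunctionMode):
        def __torch_function__(self, func, types, args=(), kwargs=None):
            nonlocal hit
            kwargs = kwargs or {}
            if torch._dynamo.is_compiling():
                hit = True  # the mode was dispatched while tracing
            return super().__torch_function__(func, types, args, kwargs)

    @torch.compile(fullgraph=True)
    def add(a, b):
        return a + b  # operator.add, remapped to its Tensor-method equivalent

    with CaptureMode():
        add(torch.ones(2), torch.ones(2))
    return hit


print(demo())  # expected: True once builtin ops dispatch torch function modes
```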

View File

@@ -1,5 +1,5 @@
-add_loop_eager, compile_time_instruction_count, 2834456320, 0.015
-add_loop_eager_dynamic, compile_time_instruction_count, 5528896630, 0.025
+add_loop_eager, compile_time_instruction_count, 3004749893, 0.015
+add_loop_eager_dynamic, compile_time_instruction_count, 5726573328, 0.025
add_loop_inductor, compile_time_instruction_count, 24146845503, 0.015
add_loop_inductor_dynamic_gpu, compile_time_instruction_count, 39411706509, 0.025
add_loop_inductor_gpu, compile_time_instruction_count, 22171041650, 0.015


View File

@@ -1,5 +1,8 @@
# Owner(s): ["module: dynamo"]
import operator
from unittest.mock import patch
import torch
import torch._dynamo.test_case
import torch._dynamo.testing
@@ -484,6 +487,100 @@ class TorchFunctionModeTests(torch._dynamo.test_case.TestCase):
self.assertEqual(expected, actual)
# Needs larger cache size since we recompile for each op
@patch.object(torch._dynamo.config, "cache_size_limit", 48)
def test_builtin_equivalent_funcs(self):
from torch._dynamo.variables.torch_function import (
bin_int_ops,
bin_ops,
BUILTIN_TO_TENSOR_FN_MAP,
BUILTIN_TO_TENSOR_RFN_MAP,
tensor_and_int_ops,
un_int_ops,
un_ops,
)
expected_func = None
valid = False
class FuncEquivMode(BaseTorchFunctionMode):
def __torch_function__(self, func, types, args=(), kwargs=None):
nonlocal expected_func
nonlocal valid
if not kwargs:
kwargs = {}
if torch._dynamo.is_compiling():
valid = expected_func == func
return super().__torch_function__(func, types, args, kwargs)
inp0 = torch.ones(1, 1)
inp1 = torch.ones(1, 1)
inp0_int = torch.ones(1, 1, dtype=torch.int32)
inp1_int = torch.ones(1, 1, dtype=torch.int32)
@torch.compile(fullgraph=True)
def fn_un(op, inp):
return op(inp)
@torch.compile(fullgraph=True)
def fn_un_int(op, inp):
return op(inp)
@torch.compile(fullgraph=True)
def fn_bin(op, inp0, inp1):
return op(inp0, inp1)
@torch.compile(fullgraph=True)
def fn_bin_int(op, inp0, inp1):
return op(inp0, inp1)
@torch.compile(fullgraph=True)
def fn_tensor_and_int(op, inp0, inp1):
return op(inp0, inp1)
setups_and_oplists = [
(lambda o: fn_un(o, inp0), un_ops),
(lambda o: fn_un_int(o, inp0_int), un_int_ops),
(lambda o: fn_bin(o, inp0, inp1), bin_ops),
(lambda o: fn_bin_int(o, inp0_int, inp1_int), bin_int_ops),
(lambda o: fn_tensor_and_int(o, inp0_int, 0), tensor_and_int_ops),
]
# gather the reverse functions
rsetups_and_oplists = [
(
lambda o: fn_bin(o, 1, inp1),
bin_ops,
), # Get r* ops (e.g. __sub__(int, Tensor) -> __rsub__(Tensor, int))
(lambda o: fn_bin_int(o, 1, inp1_int), bin_int_ops),
(lambda o: fn_tensor_and_int(o, 0, inp0_int), tensor_and_int_ops),
]
skips = {operator.not_} # Has local scalar dense call which graph breaks
rskips = {
operator.matmul,
operator.imatmul,
operator.getitem,
} # Doesn't type check with reversed args
def run_checks(setups_and_oplists, skips, ref_map):
nonlocal valid
nonlocal expected_func
for setup_fn, op_list in setups_and_oplists:
for op in op_list:
if op in skips or op not in ref_map:
continue
with FuncEquivMode():
expected_func = ref_map[op]
setup_fn(op)
self.assertTrue(valid)
expected_func = None
valid = False
run_checks(setups_and_oplists, skips, BUILTIN_TO_TENSOR_FN_MAP)
run_checks(rsetups_and_oplists, rskips, BUILTIN_TO_TENSOR_RFN_MAP)
if __name__ == "__main__":
from torch._dynamo.test_case import run_tests

View File

@@ -49,9 +49,9 @@ def forward(self, b_submodule_buffer1, x):
sin = torch.ops.aten.sin.default(x)
strict_graph_0 = self.strict_graph_0
strict_mode = torch.ops.higher_order.strict_mode(strict_graph_0, (sin, b_submodule_buffer1)); strict_graph_0 = sin = b_submodule_buffer1 = None
-getitem_2 = strict_mode[0]; strict_mode = None
+getitem = strict_mode[0]; strict_mode = None
add = torch.ops.aten.add.Tensor(x, 3); x = None
-return (getitem_2, add)""",
+return (getitem, add)""",
)
self.assertExpectedInline(

View File

@@ -64,6 +64,7 @@ from torch.testing._internal.common_utils import (
IS_SANDCASTLE,
IS_WINDOWS,
run_tests,
skipIfCrossRef,
TEST_TRANSFORMERS,
TestCase as TorchTestCase,
)
@@ -6989,6 +6990,7 @@ def forward(self, x):
real_names_and_ops = [(node.name, node.op) for node in ep.graph.nodes]
self.assertEqual(expected_names_and_ops, real_names_and_ops)
@skipIfCrossRef # Dynamo changes the order of ops under Torch function modes
def test_placeholder_naming_collisions_hoo_subgraphs(self):
# test collisions between user inputs, top-level nodes, and HOO subgraph nodes
class Foo(torch.nn.Module):
@@ -8325,6 +8327,7 @@ class TestOneOffModelExportResult(TestCase):
# getitem = _scaled_dot_product_flash_attention_for_cpu[0]; _scaled_dot_product_flash_attention_for_cpu = None
# return (getitem,)""")
@skipIfCrossRef
@unittest.skipIf(
not PLATFORM_SUPPORTS_FLASH_ATTENTION,
"Can't run fused SDPA on this platform",

View File

@@ -4902,6 +4902,7 @@ def forward(self, arg0_1, arg1_1):
return [getitem]""", # noqa: B950
)
@skipIfCrossRef # Arg order changes with crossref
def test_cond_make_fx_preserve_stack_trace_for_nodes_in_subgraph(self):
def true_fn(x):
return x + x.cos()
@@ -5252,6 +5253,7 @@ def forward(self, arg0_1):
):
torch.cond(inp.sum() > 0, f, f, (inp, tmp))
@skipIfCrossRef # Arg order changes with crossref
def test_cond_trace_set__and_mutate_intermediate(self):
def f(a, tmp):
a = a.clone()

View File

@@ -32,13 +32,23 @@ def eager(gm, fake_tensor_inputs, **kwargs):
def make_eager_backend_with_torch_function_mode(mode):
return make_eager_backend_with_torch_function_modes([mode])
def make_eager_backend_with_torch_function_modes(modes):
"""Used to trace HOPs (cond and while) for eager exectution, the metadata
TF mode mutates vars outside of the scope of the HOP, and we can't have graph breaks
in the HOP, so we need to externally run this mode and not trace it."""
from contextlib import ExitStack
def fn(gm, fake_tensor_inputs, **kwargs):
-with mode:
-return gm.forward
+stack = ExitStack()
+for mode in modes:
+stack.enter_context(mode)
+result = gm.forward
+stack.close()
+return result
return fn
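A rough, self-contained usage sketch for the backend factory above; `NoopMode` and `f` are illustrative stand-ins for the metadata/pre-dispatch modes and the callable that strict_mode passes in (see the higher-order-op hunk at the end of this diff).

```python
# Hedged usage sketch: build an eager backend that re-enters the given torch
# function modes (which were removed from the stack so Dynamo does not trace
# them) and hand it to torch.compile.
import torch
from torch.overrides import BaseTorchFunctionMode
from torch._dynamo.backends.debugging import (
    make_eager_backend_with_torch_function_modes,
)


class NoopMode(BaseTorchFunctionMode):
    # Illustrative placeholder for a mode that must not be traced.
    pass


def f(x):
    return x.sin() + 1


backend = make_eager_backend_with_torch_function_modes([NoopMode()])
compiled = torch.compile(f, backend=backend, fullgraph=True)
print(compiled(torch.ones(3)))
```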

View File

@@ -200,7 +200,6 @@ class BuiltinVariable(VariableTracker):
operator.ne,
operator.eq,
operator.sub,
-operator.getitem,
operator.length_hint,
operator.lshift,
operator.rshift,
@@ -212,6 +211,7 @@ class BuiltinVariable(VariableTracker):
operator.imatmul,
operator.ifloordiv,
operator.itruediv,
+operator.getitem,
operator.imod,
operator.iadd,
operator.isub,
@@ -858,6 +858,39 @@ class BuiltinVariable(VariableTracker):
if kwargs and not self.tensor_args(*args, *kwargs.values()):
return
# insert handling for torch function here
from .builder import SourcelessBuilder
from .torch_function import (
BUILTIN_TO_TENSOR_FN_MAP,
BUILTIN_TO_TENSOR_RFN_MAP,
can_dispatch_torch_function,
dispatch_torch_function,
)
if can_dispatch_torch_function(tx, args, kwargs):
# Only remap the fn to tensor methods if we aren't exporting
# export serde does not handle method descriptors today
if not tx.export:
# Use sourceless builder, we built the map ourselves
if not isinstance(args[0], TensorVariable):
if self.fn in BUILTIN_TO_TENSOR_RFN_MAP:
func = BUILTIN_TO_TENSOR_RFN_MAP[self.fn]
else:
func = BUILTIN_TO_TENSOR_FN_MAP[self.fn]
tmp = args[0]
# swap args and call reverse version of func
args[0] = args[1]
args[1] = tmp
else:
func = BUILTIN_TO_TENSOR_FN_MAP[self.fn]
else:
func = self.fn
fn_var = SourcelessBuilder.create(tx, func)
return dispatch_torch_function(tx, fn_var, args, kwargs)
fn = self.fn
try:
# Constant fold for constant tensor and python constants
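The branch above can be paraphrased in plain Python (outside Dynamo): if the first argument is not a tensor, swap the arguments and dispatch on the reverse (r*) tensor method when one exists. `remap_builtin` below is an illustrative helper, not code from the PR.

```python
# Illustrative paraphrase of the remapping rule in BuiltinVariable above.
import operator

import torch
from torch._dynamo.variables.torch_function import (
    BUILTIN_TO_TENSOR_FN_MAP,
    BUILTIN_TO_TENSOR_RFN_MAP,
)


def remap_builtin(op, lhs, rhs):
    if not isinstance(lhs, torch.Tensor):
        # Tensor is on the right: prefer the r* method and swap the args,
        # e.g. op(1, t) becomes rfn(t, 1).
        fn = BUILTIN_TO_TENSOR_RFN_MAP.get(op, BUILTIN_TO_TENSOR_FN_MAP[op])
        return fn, (rhs, lhs)
    return BUILTIN_TO_TENSOR_FN_MAP[op], (lhs, rhs)


fn, args = remap_builtin(operator.sub, 1, torch.ones(2))
print(fn, fn(*args))  # typically the r* equivalent of operator.sub, called as fn(tensor, 1)
```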

View File

@@ -772,6 +772,30 @@ class TensorVariable(VariableTracker):
self._warn_capture_scalar_outputs()
unimplemented("Tensor.item")
def method___getitem__(self, *args, **kwargs):
from ..symbolic_convert import InstructionTranslator
from .builder import wrap_fx_proxy
tx = InstructionTranslator.current_tx()
if isinstance(args[0], SymNodeVariable):
# Standard indexing will force specialization due to
# __index__. Rewrite as a regular torch op which will
# trace fine
fn, args = torch.select, [
variables.ConstantVariable.create(0),
args[0],
]
else:
fn = operator.getitem
proxy = tx.output.create_proxy(
"call_function",
fn,
*proxy_args_kwargs([self] + list(args), kwargs),
)
return wrap_fx_proxy(tx, proxy)
@staticmethod
@functools.lru_cache(None)
def _warn_capture_scalar_outputs():
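For reference, the eager-mode equivalence the rewrite in the comment relies on (a small sketch, not part of the diff): integer indexing along dim 0 matches torch.select, which traces without specializing the index through __index__.

```python
# Eager-mode illustration of the getitem -> torch.select rewrite above.
import torch

t = torch.arange(12).reshape(4, 3)
i = 2
assert torch.equal(t[i], torch.select(t, 0, i))
```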

View File

@@ -871,6 +871,10 @@ class TorchInGraphFunctionVariable(BaseTorchVariable):
return ConstantVariable.create(None)
@register(torch._C.TensorBase.__getitem__)
def handle_getitem(self, tx: "InstructionTranslator", *args, **kwargs):
return args[0].call_method(tx, "__getitem__", args[1:], kwargs)
return handlers
def call_function(

View File

@@ -4,6 +4,7 @@ import collections
import contextlib
import functools
import inspect
import operator
from typing import Deque, Dict, List, TYPE_CHECKING
import torch._C
@@ -11,6 +12,7 @@ import torch.utils._pytree as pytree
from torch._guards import Source
from torch.overrides import (
_get_overloaded_args,
BaseTorchFunctionMode,
get_default_nowrap_functions,
TorchFunctionMode,
)
@@ -62,6 +64,125 @@ if TYPE_CHECKING:
# To enable subclass behavior, add your tensor subclass type to traceable_tensor_subclasses in dynamo/config.py
bin_ops = [
operator.pow,
operator.mul,
operator.matmul,
operator.floordiv,
operator.truediv,
operator.mod,
operator.add,
operator.lt,
operator.gt,
operator.ge,
operator.le,
operator.ne,
operator.eq,
operator.sub,
operator.ipow,
operator.imul,
operator.imatmul,
operator.ifloordiv,
operator.itruediv,
operator.imod,
operator.iadd,
operator.isub,
]
bin_int_ops = [
operator.and_,
operator.or_,
operator.xor,
operator.iand,
operator.ixor,
operator.ior,
]
un_int_ops = [operator.invert]
tensor_and_int_ops = [
operator.lshift,
operator.rshift,
operator.ilshift,
operator.irshift,
operator.getitem,
]
un_ops = [
operator.abs,
operator.pos,
operator.neg,
operator.not_, # Note: this has a local scalar dense call
operator.length_hint,
]
BUILTIN_TO_TENSOR_FN_MAP = {}
# These functions represent the r* versions of the above ops
# Basically, if __add__(1, Tensor) is called, it is translated
# to __radd__(Tensor, 1).
# In the builtin variable, we check whether there is a tensor in the first args position;
# if not, we swap the args and use the r* version of the op.
BUILTIN_TO_TENSOR_RFN_MAP = {}
def populate_builtin_to_tensor_fn_map():
global BUILTIN_TO_TENSOR_FN_MAP
most_recent_func = None
class GetMethodMode(BaseTorchFunctionMode):
"""
Mode to extract the correct methods from torch function invocations
(Used to get the correct torch.Tensor methods from builtins)
"""
def __torch_function__(self, func, types, args=(), kwargs=None):
kwargs = kwargs or {}
nonlocal most_recent_func
most_recent_func = func
return func(*args, **kwargs)
inp0 = torch.ones(1)
inp1 = torch.ones(1)
inp0_int = torch.ones(1, dtype=torch.int32)
inp1_int = torch.ones(1, dtype=torch.int32)
with GetMethodMode():
setups_and_oplists = [
(lambda o: o(inp0), un_ops),
(lambda o: o(inp0_int), un_int_ops),
(lambda o: o(inp0, inp1), bin_ops),
(lambda o: o(inp0_int, inp1_int), bin_int_ops),
(lambda o: o(inp0_int, 0), tensor_and_int_ops),
]
for setup_fn, op_list in setups_and_oplists:
for op in op_list:
setup_fn(op)
assert most_recent_func is not None
BUILTIN_TO_TENSOR_FN_MAP[op] = most_recent_func
# gather the reverse functions
rsetups_and_oplists = [
(
lambda o: o(1, inp1),
bin_ops,
), # Get r* ops (e.g. __sub__(int, Tensor) -> __rsub__(Tensor, int))
(lambda o: o(1, inp1_int), bin_int_ops),
(lambda o: o(0, inp0_int), tensor_and_int_ops),
]
rskips = {operator.matmul, operator.imatmul, operator.getitem}
for setup_fn, op_list in rsetups_and_oplists:
for op in op_list:
if op in rskips:
continue
setup_fn(op)
assert most_recent_func is not None
if most_recent_func != BUILTIN_TO_TENSOR_FN_MAP[op]:
BUILTIN_TO_TENSOR_RFN_MAP[op] = most_recent_func
populate_builtin_to_tensor_fn_map()
banned_attrs = [
fn.__self__.__name__
@@ -389,8 +510,15 @@ def call_torch_function(
def build_torch_function_fn(tx: "InstructionTranslator", value, source):
from types import FunctionType
from .builder import SourcelessBuilder, VariableBuilder
func = value.__torch_function__.__func__
if not isinstance(func, FunctionType):
unimplemented("Builtin/C++ torch function implementations NYI")
if source:
return VariableBuilder(
tx,

View File

@@ -8,6 +8,8 @@ from torch._higher_order_ops.utils import _set_compilation_env, autograd_not_imp
from torch._ops import HigherOrderOperator
from torch._subclasses.fake_tensor import FakeTensorMode
from torch.fx.experimental.proxy_tensor import (
_temp_remove_metadata_torch_function_mode,
_temp_remove_pre_dispatch_torch_function_mode,
disable_proxy_modes_tracing,
make_fx,
ProxyTorchDispatchMode,
@@ -18,14 +20,26 @@ from torch.utils._python_dispatch import _get_current_dispatch_mode
@exposed_in("torch")
def strict_mode(callable, operands):
from torch._dynamo.backends.debugging import (
make_eager_backend_with_torch_function_modes,
)
if torch.compiler.is_dynamo_compiling():
return strict_mode_op(callable, operands)
with _set_compilation_env():
with _temp_remove_metadata_torch_function_mode() as metadata_mode:
with _temp_remove_pre_dispatch_torch_function_mode() as predispatch_mode:
modes = [metadata_mode, predispatch_mode]
modes = [mode for mode in modes if mode is not None]
if modes:
backend = make_eager_backend_with_torch_function_modes(modes)
else:
backend = "eager"
with torch._dynamo.utils.disable_cache_limit():
-return torch.compile(strict_mode_op, backend="eager", fullgraph=True)(
-callable, operands
-)
+return torch.compile(
+strict_mode_op, backend=backend, fullgraph=True
+)(callable, operands)
class StrictMode(HigherOrderOperator):