[Dynamo] Ensure torch function modes are dispatched on builtin ops (#137117)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/137117
Approved by: https://github.com/yanboliang, https://github.com/williamwen42
ghstack dependencies: #137114, #137115, #137116
parent 38afac2917
commit 27dee935af
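Illustrative sketch (editor's addition, not part of this commit): with this change, a builtin operator applied to tensors inside a compiled region is remapped to the equivalent torch.Tensor method and dispatched through any active torch function mode. Details such as the exact identity of the remapped function (torch.Tensor.mul below) are assumptions based on the diff, not guaranteed API.

import operator

import torch
from torch.overrides import BaseTorchFunctionMode

saw_tensor_mul = False

class CheckMode(BaseTorchFunctionMode):
    def __torch_function__(self, func, types, args=(), kwargs=None):
        global saw_tensor_mul
        kwargs = kwargs or {}
        # During compilation, the builtin operator.mul should arrive here as the
        # equivalent tensor method rather than bypassing the mode entirely.
        if torch._dynamo.is_compiling() and func == torch.Tensor.mul:
            saw_tensor_mul = True
        return super().__torch_function__(func, types, args, kwargs)

@torch.compile(fullgraph=True)
def fn(a, b):
    return operator.mul(a, b)  # builtin op on two tensors

with CheckMode():
    fn(torch.ones(2), torch.ones(2))

print(saw_tensor_mul)  # expected: True after this change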
@@ -1,5 +1,5 @@
-add_loop_eager, compile_time_instruction_count, 2834456320, 0.015
-add_loop_eager_dynamic, compile_time_instruction_count, 5528896630, 0.025
+add_loop_eager, compile_time_instruction_count, 3004749893, 0.015
+add_loop_eager_dynamic, compile_time_instruction_count, 5726573328, 0.025
 add_loop_inductor, compile_time_instruction_count, 24146845503, 0.015
 add_loop_inductor_dynamic_gpu, compile_time_instruction_count, 39411706509, 0.025
 add_loop_inductor_gpu, compile_time_instruction_count, 22171041650, 0.015
@@ -1,5 +1,8 @@
 # Owner(s): ["module: dynamo"]
 
+import operator
+from unittest.mock import patch
+
 import torch
 import torch._dynamo.test_case
 import torch._dynamo.testing
@@ -484,6 +487,100 @@ class TorchFunctionModeTests(torch._dynamo.test_case.TestCase):
 
         self.assertEqual(expected, actual)
 
+    # Needs larger cache size since we recompile for each op
+    @patch.object(torch._dynamo.config, "cache_size_limit", 48)
+    def test_builtin_equivalent_funcs(self):
+        from torch._dynamo.variables.torch_function import (
+            bin_int_ops,
+            bin_ops,
+            BUILTIN_TO_TENSOR_FN_MAP,
+            BUILTIN_TO_TENSOR_RFN_MAP,
+            tensor_and_int_ops,
+            un_int_ops,
+            un_ops,
+        )
+
+        expected_func = None
+        valid = False
+
+        class FuncEquivMode(BaseTorchFunctionMode):
+            def __torch_function__(self, func, types, args=(), kwargs=None):
+                nonlocal expected_func
+                nonlocal valid
+                if not kwargs:
+                    kwargs = {}
+                if torch._dynamo.is_compiling():
+                    valid = expected_func == func
+                return super().__torch_function__(func, types, args, kwargs)
+
+        inp0 = torch.ones(1, 1)
+        inp1 = torch.ones(1, 1)
+        inp0_int = torch.ones(1, 1, dtype=torch.int32)
+        inp1_int = torch.ones(1, 1, dtype=torch.int32)
+
+        @torch.compile(fullgraph=True)
+        def fn_un(op, inp):
+            return op(inp)
+
+        @torch.compile(fullgraph=True)
+        def fn_un_int(op, inp):
+            return op(inp)
+
+        @torch.compile(fullgraph=True)
+        def fn_bin(op, inp0, inp1):
+            return op(inp0, inp1)
+
+        @torch.compile(fullgraph=True)
+        def fn_bin_int(op, inp0, inp1):
+            return op(inp0, inp1)
+
+        @torch.compile(fullgraph=True)
+        def fn_tensor_and_int(op, inp0, inp1):
+            return op(inp0, inp1)
+
+        setups_and_oplists = [
+            (lambda o: fn_un(o, inp0), un_ops),
+            (lambda o: fn_un_int(o, inp0_int), un_int_ops),
+            (lambda o: fn_bin(o, inp0, inp1), bin_ops),
+            (lambda o: fn_bin_int(o, inp0_int, inp1_int), bin_int_ops),
+            (lambda o: fn_tensor_and_int(o, inp0_int, 0), tensor_and_int_ops),
+        ]
+
+        # gather the reverse functions
+        rsetups_and_oplists = [
+            (
+                lambda o: fn_bin(o, 1, inp1),
+                bin_ops,
+            ),  # Get r* ops, (ex. __sub__(int, Tensor) -> __rsub__(Tensor, int))
+            (lambda o: fn_bin_int(o, 1, inp1_int), bin_int_ops),
+            (lambda o: fn_tensor_and_int(o, 0, inp0_int), tensor_and_int_ops),
+        ]
+
+        skips = {operator.not_}  # Has local scalar dense call which graph breaks
+        rskips = {
+            operator.matmul,
+            operator.imatmul,
+            operator.getitem,
+        }  # Doesn't type check with reversed args
+
+        def run_checks(setups_and_oplists, skips, ref_map):
+            nonlocal valid
+            nonlocal expected_func
+            for setup_fn, op_list in setups_and_oplists:
+                for op in op_list:
+                    if op in skips or op not in ref_map:
+                        continue
+                    with FuncEquivMode():
+                        expected_func = ref_map[op]
+                        setup_fn(op)
+                        self.assertTrue(valid)
+
+                    expected_func = None
+                    valid = False
+
+        run_checks(setups_and_oplists, skips, BUILTIN_TO_TENSOR_FN_MAP)
+        run_checks(rsetups_and_oplists, rskips, BUILTIN_TO_TENSOR_RFN_MAP)
+
 
 if __name__ == "__main__":
     from torch._dynamo.test_case import run_tests
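Background for the FuncEquivMode helper above (editor's illustration, not part of the test file): BaseTorchFunctionMode's default __torch_function__ simply forwards to func, so a subclass can observe which function is dispatched while still returning the normal result. A small eager-mode sketch:

import torch
from torch.overrides import BaseTorchFunctionMode

class ObserverMode(BaseTorchFunctionMode):
    def __torch_function__(self, func, types, args=(), kwargs=None):
        kwargs = kwargs or {}
        print("dispatched:", func)  # e.g. torch.ones, then a Tensor method for "+"
        return super().__torch_function__(func, types, args, kwargs)

with ObserverMode():
    torch.ones(2) + torch.ones(2)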
@@ -49,9 +49,9 @@ def forward(self, b_submodule_buffer1, x):
     sin = torch.ops.aten.sin.default(x)
     strict_graph_0 = self.strict_graph_0
     strict_mode = torch.ops.higher_order.strict_mode(strict_graph_0, (sin, b_submodule_buffer1)); strict_graph_0 = sin = b_submodule_buffer1 = None
-    getitem_2 = strict_mode[0]; strict_mode = None
+    getitem = strict_mode[0]; strict_mode = None
     add = torch.ops.aten.add.Tensor(x, 3); x = None
-    return (getitem_2, add)""",
+    return (getitem, add)""",
         )
 
         self.assertExpectedInline(
@@ -64,6 +64,7 @@ from torch.testing._internal.common_utils import (
     IS_SANDCASTLE,
     IS_WINDOWS,
     run_tests,
+    skipIfCrossRef,
    TEST_TRANSFORMERS,
    TestCase as TorchTestCase,
 )
@@ -6989,6 +6990,7 @@ def forward(self, x):
         real_names_and_ops = [(node.name, node.op) for node in ep.graph.nodes]
         self.assertEqual(expected_names_and_ops, real_names_and_ops)
 
+    @skipIfCrossRef  # Dynamo changes the order of ops under Torch function modes
     def test_placeholder_naming_collisions_hoo_subgraphs(self):
         # test collisions between user inputs, top-level nodes, and HOO subgraph nodes
         class Foo(torch.nn.Module):
@@ -8325,6 +8327,7 @@ class TestOneOffModelExportResult(TestCase):
     #     getitem = _scaled_dot_product_flash_attention_for_cpu[0]; _scaled_dot_product_flash_attention_for_cpu = None
     #     return (getitem,)""")
 
+    @skipIfCrossRef
     @unittest.skipIf(
         not PLATFORM_SUPPORTS_FLASH_ATTENTION,
         "Can't run fused SDPA on this platform",
@@ -4902,6 +4902,7 @@ def forward(self, arg0_1, arg1_1):
     return [getitem]""",  # noqa: B950
         )
 
+    @skipIfCrossRef  # Arg order changes with crossref
     def test_cond_make_fx_preserve_stack_trace_for_nodes_in_subgraph(self):
         def true_fn(x):
             return x + x.cos()
@@ -5252,6 +5253,7 @@ def forward(self, arg0_1):
         ):
             torch.cond(inp.sum() > 0, f, f, (inp, tmp))
 
+    @skipIfCrossRef  # Arg order changes with crossref
    def test_cond_trace_set__and_mutate_intermediate(self):
        def f(a, tmp):
            a = a.clone()
@@ -32,13 +32,23 @@ def eager(gm, fake_tensor_inputs, **kwargs):
 
 
 def make_eager_backend_with_torch_function_mode(mode):
+    return make_eager_backend_with_torch_function_modes([mode])
+
+
+def make_eager_backend_with_torch_function_modes(modes):
     """Used to trace HOPs (cond and while) for eager execution; the metadata
     TF mode mutates vars outside of the scope of the HOP, and we can't have graph breaks
     in the HOP, so we need to externally run this mode and not trace it."""
+    from contextlib import ExitStack
 
     def fn(gm, fake_tensor_inputs, **kwargs):
-        with mode:
-            return gm.forward
+        stack = ExitStack()
+        for mode in modes:
+            stack.enter_context(mode)
+
+        result = gm.forward
+        stack.close()
+        return result
 
     return fn
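A possible usage sketch for the backend factory above (editor's addition; the import path matches the import used later in this diff, the no-op mode is a placeholder). The returned backend runs the captured graph eagerly, entering the given modes around the lookup of gm.forward instead of tracing them:

import torch
from torch._dynamo.backends.debugging import (
    make_eager_backend_with_torch_function_mode,
)
from torch.overrides import BaseTorchFunctionMode

class MyMode(BaseTorchFunctionMode):
    pass

backend = make_eager_backend_with_torch_function_mode(MyMode())

@torch.compile(backend=backend, fullgraph=True)
def f(x):
    return torch.sin(x) + 1

f(torch.randn(4))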
@@ -200,7 +200,6 @@ class BuiltinVariable(VariableTracker):
         operator.ne,
         operator.eq,
         operator.sub,
-        operator.getitem,
         operator.length_hint,
         operator.lshift,
         operator.rshift,
@@ -212,6 +211,7 @@ class BuiltinVariable(VariableTracker):
         operator.imatmul,
         operator.ifloordiv,
         operator.itruediv,
+        operator.getitem,
         operator.imod,
         operator.iadd,
         operator.isub,
@@ -858,6 +858,39 @@ class BuiltinVariable(VariableTracker):
         if kwargs and not self.tensor_args(*args, *kwargs.values()):
             return
 
+        # insert handling for torch function here
+        from .builder import SourcelessBuilder
+        from .torch_function import (
+            BUILTIN_TO_TENSOR_FN_MAP,
+            BUILTIN_TO_TENSOR_RFN_MAP,
+            can_dispatch_torch_function,
+            dispatch_torch_function,
+        )
+
+        if can_dispatch_torch_function(tx, args, kwargs):
+            # Only remap the fn to tensor methods if we aren't exporting
+            # export serde does not handle method descriptors today
+            if not tx.export:
+                # Use sourceless builder, we built the map ourselves
+                if not isinstance(args[0], TensorVariable):
+                    if self.fn in BUILTIN_TO_TENSOR_RFN_MAP:
+                        func = BUILTIN_TO_TENSOR_RFN_MAP[self.fn]
+                    else:
+                        func = BUILTIN_TO_TENSOR_FN_MAP[self.fn]
+
+                    tmp = args[0]
+                    # swap args and call reverse version of func
+                    args[0] = args[1]
+                    args[1] = tmp
+                else:
+                    func = BUILTIN_TO_TENSOR_FN_MAP[self.fn]
+            else:
+                func = self.fn
+
+            fn_var = SourcelessBuilder.create(tx, func)
+
+            return dispatch_torch_function(tx, fn_var, args, kwargs)
+
         fn = self.fn
         try:
             # Constant fold for constant tensor and python constants
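The comments above describe remapping a builtin to the equivalent Tensor method, swapping arguments when the tensor is not in the first position. A plain-Python illustration of that equivalence (editor's sketch; the specific map entries are inferred from the population code later in this diff):

import operator

import torch

t = torch.ones(2)

# operator.add(t, 1) corresponds to the tensor method torch.Tensor.add(t, 1)
assert operator.add(t, 1).equal(torch.Tensor.add(t, 1))

# operator.add(1, t) has no tensor in the first position, so it corresponds to
# the reflected method with the arguments swapped: torch.Tensor.__radd__(t, 1)
assert operator.add(1, t).equal(torch.Tensor.__radd__(t, 1))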
@@ -772,6 +772,30 @@ class TensorVariable(VariableTracker):
             self._warn_capture_scalar_outputs()
             unimplemented("Tensor.item")
 
+    def method_getitem(self, *args, **kwargs):
+        from ..symbolic_convert import InstructionTranslator
+        from .builder import wrap_fx_proxy
+
+        tx = InstructionTranslator.current_tx()
+        if isinstance(args[0], SymNodeVariable):
+            # Standard indexing will force specialization due to
+            # __index__. Rewrite as a regular torch op which will
+            # trace fine
+            fn, args = torch.select, [
+                variables.ConstantVariable.create(0),
+                args[0],
+            ]
+        else:
+            fn = operator.getitem
+
+        proxy = tx.output.create_proxy(
+            "call_function",
+            fn,
+            *proxy_args_kwargs([self] + list(args), kwargs),
+        )
+
+        return wrap_fx_proxy(tx, proxy)
+
     @staticmethod
     @functools.lru_cache(None)
     def _warn_capture_scalar_outputs():
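The comment above notes that indexing with a symbolic integer would specialize via __index__, so the access is rewritten to torch.select, which traces cleanly. The two forms agree for a plain integer index (a quick sanity sketch, not from the PR):

import torch

t = torch.arange(12).reshape(4, 3)
i = 2
# t[i] along dim 0 is the same as torch.select(t, 0, i)
assert torch.equal(t[i], torch.select(t, 0, i))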
@@ -871,6 +871,10 @@ class TorchInGraphFunctionVariable(BaseTorchVariable):
 
             return ConstantVariable.create(None)
 
+        @register(torch._C.TensorBase.__getitem__)
+        def handle_getitem(self, tx: "InstructionTranslator", *args, **kwargs):
+            return args[0].call_method(tx, "getitem", args[1:], kwargs)
+
         return handlers
 
     def call_function(
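The handler above registers torch._C.TensorBase.__getitem__; the working assumption (editor's note, not stated in the diff) is that torch.Tensor inherits its indexing slot from TensorBase, so this is the object Dynamo sees for tensor indexing:

import torch

# Tensor does not override __getitem__ in Python, so the class attribute is
# the same C-implemented descriptor that is registered above.
assert torch.Tensor.__getitem__ is torch._C.TensorBase.__getitem__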
@@ -4,6 +4,7 @@ import collections
 import contextlib
 import functools
 import inspect
+import operator
 from typing import Deque, Dict, List, TYPE_CHECKING
 
 import torch._C
@@ -11,6 +12,7 @@ import torch.utils._pytree as pytree
 from torch._guards import Source
 from torch.overrides import (
     _get_overloaded_args,
+    BaseTorchFunctionMode,
     get_default_nowrap_functions,
     TorchFunctionMode,
 )
@@ -62,6 +64,125 @@ if TYPE_CHECKING:
 
 # To enable subclass behavior, add your tensor subclass type to traceable_tensor_subclasses in dynamo/config.py
 
+bin_ops = [
+    operator.pow,
+    operator.mul,
+    operator.matmul,
+    operator.floordiv,
+    operator.truediv,
+    operator.mod,
+    operator.add,
+    operator.lt,
+    operator.gt,
+    operator.ge,
+    operator.le,
+    operator.ne,
+    operator.eq,
+    operator.sub,
+    operator.ipow,
+    operator.imul,
+    operator.imatmul,
+    operator.ifloordiv,
+    operator.itruediv,
+    operator.imod,
+    operator.iadd,
+    operator.isub,
+]
+
+bin_int_ops = [
+    operator.and_,
+    operator.or_,
+    operator.xor,
+    operator.iand,
+    operator.ixor,
+    operator.ior,
+]
+
+un_int_ops = [operator.invert]
+
+tensor_and_int_ops = [
+    operator.lshift,
+    operator.rshift,
+    operator.ilshift,
+    operator.irshift,
+    operator.getitem,
+]
+
+un_ops = [
+    operator.abs,
+    operator.pos,
+    operator.neg,
+    operator.not_,  # Note: this has a local scalar dense call
+    operator.length_hint,
+]
+
+BUILTIN_TO_TENSOR_FN_MAP = {}
+
+# These functions represent the r* versions of the above ops
+# Basically, if __add__(1, Tensor) is called, it is translated
+# to __radd__(Tensor, 1).
+# In the builtin var, we check if there is a tensor in the first args position,
+# if not, we swap the args and use the r* version of the op.
+BUILTIN_TO_TENSOR_RFN_MAP = {}
+
+
+def populate_builtin_to_tensor_fn_map():
+    global BUILTIN_TO_TENSOR_FN_MAP
+
+    most_recent_func = None
+
+    class GetMethodMode(BaseTorchFunctionMode):
+        """
+        Mode to extract the correct methods from torch function invocations
+        (Used to get the correct torch.Tensor methods from builtins)
+        """
+
+        def __torch_function__(self, func, types, args=(), kwargs=None):
+            kwargs = kwargs or {}
+            nonlocal most_recent_func
+            most_recent_func = func
+            return func(*args, **kwargs)
+
+    inp0 = torch.ones(1)
+    inp1 = torch.ones(1)
+    inp0_int = torch.ones(1, dtype=torch.int32)
+    inp1_int = torch.ones(1, dtype=torch.int32)
+    with GetMethodMode():
+        setups_and_oplists = [
+            (lambda o: o(inp0), un_ops),
+            (lambda o: o(inp0_int), un_int_ops),
+            (lambda o: o(inp0, inp1), bin_ops),
+            (lambda o: o(inp0_int, inp1_int), bin_int_ops),
+            (lambda o: o(inp0_int, 0), tensor_and_int_ops),
+        ]
+        for setup_fn, op_list in setups_and_oplists:
+            for op in op_list:
+                setup_fn(op)
+                assert most_recent_func is not None
+                BUILTIN_TO_TENSOR_FN_MAP[op] = most_recent_func
+
+        # gather the reverse functions
+        rsetups_and_oplists = [
+            (
+                lambda o: o(1, inp1),
+                bin_ops,
+            ),  # Get r* ops, (ex. __sub__(int, Tensor) -> __rsub__(Tensor, int))
+            (lambda o: o(1, inp1_int), bin_int_ops),
+            (lambda o: o(0, inp0_int), tensor_and_int_ops),
+        ]
+
+        rskips = {operator.matmul, operator.imatmul, operator.getitem}
+        for setup_fn, op_list in rsetups_and_oplists:
+            for op in op_list:
+                if op in rskips:
+                    continue
+                setup_fn(op)
+                assert most_recent_func is not None
+                if most_recent_func != BUILTIN_TO_TENSOR_FN_MAP[op]:
+                    BUILTIN_TO_TENSOR_RFN_MAP[op] = most_recent_func
+
+
+populate_builtin_to_tensor_fn_map()
+
 
 banned_attrs = [
     fn.__self__.__name__
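A hedged illustration of what the populated maps end up holding (editor's addition; entries are inferred from the population logic above, and the exact objects may differ by build):

import operator

import torch
from torch._dynamo.variables.torch_function import (
    BUILTIN_TO_TENSOR_FN_MAP,
    BUILTIN_TO_TENSOR_RFN_MAP,
)

# Forward map: builtin -> tensor method, e.g. operator.mul -> torch.Tensor.mul
print(BUILTIN_TO_TENSOR_FN_MAP[operator.mul])

# Reverse map: used when the tensor is not the first argument, e.g.
# operator.sub -> a reflected method such as Tensor.__rsub__
print(BUILTIN_TO_TENSOR_RFN_MAP[operator.sub])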
@@ -389,8 +510,15 @@ def call_torch_function(
 
 
 def build_torch_function_fn(tx: "InstructionTranslator", value, source):
+    from types import FunctionType
+
     from .builder import SourcelessBuilder, VariableBuilder
 
+    func = value.__torch_function__.__func__
+
+    if not isinstance(func, FunctionType):
+        unimplemented("Builtin/C++ torch function implementations NYI")
+
     if source:
         return VariableBuilder(
             tx,
@@ -8,6 +8,8 @@ from torch._higher_order_ops.utils import _set_compilation_env, autograd_not_imp
 from torch._ops import HigherOrderOperator
 from torch._subclasses.fake_tensor import FakeTensorMode
 from torch.fx.experimental.proxy_tensor import (
+    _temp_remove_metadata_torch_function_mode,
+    _temp_remove_pre_dispatch_torch_function_mode,
     disable_proxy_modes_tracing,
     make_fx,
     ProxyTorchDispatchMode,
@@ -18,14 +20,26 @@ from torch.utils._python_dispatch import _get_current_dispatch_mode
 
 @exposed_in("torch")
 def strict_mode(callable, operands):
+    from torch._dynamo.backends.debugging import (
+        make_eager_backend_with_torch_function_modes,
+    )
+
     if torch.compiler.is_dynamo_compiling():
         return strict_mode_op(callable, operands)
 
     with _set_compilation_env():
-        with torch._dynamo.utils.disable_cache_limit():
-            return torch.compile(strict_mode_op, backend="eager", fullgraph=True)(
-                callable, operands
-            )
+        with _temp_remove_metadata_torch_function_mode() as metadata_mode:
+            with _temp_remove_pre_dispatch_torch_function_mode() as predispatch_mode:
+                modes = [metadata_mode, predispatch_mode]
+                modes = [mode for mode in modes if mode is not None]
+                if modes:
+                    backend = make_eager_backend_with_torch_function_modes(modes)
+                else:
+                    backend = "eager"
+                with torch._dynamo.utils.disable_cache_limit():
+                    return torch.compile(
+                        strict_mode_op, backend=backend, fullgraph=True
+                    )(callable, operands)
 
 
 class StrictMode(HigherOrderOperator):
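A possible eager-call sketch of strict_mode after this change (editor's addition; the import path torch._higher_order_ops.strict_mode is an assumption based on this file's imports, and strict_mode is expected to error on graph breaks rather than fall back):

import torch
from torch._higher_order_ops.strict_mode import strict_mode

def body(x):
    return x.sin() + 1

# Compiles `body` with the eager backend built above (or plain "eager" when no
# metadata/pre-dispatch torch function modes are active) and runs it.
out = strict_mode(body, (torch.randn(3),))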