Mirror of https://github.com/zebrajr/pytorch.git (synced 2025-12-06 12:20:52 +01:00)
Support map in pre-dispatch functionalization (#121444)
When we enter map_autograd, we try to trace through the fwd/bwd of a map operator that is wrapped in the ctx.functionalize wrapper. This forces us to go through PreDispatch functionalization again (only the Python part). As a result, it revealed a pre-existing bug: pre-dispatch mode handling didn't actually manage the local dispatch key set (if there is no active mode, we need to turn off the PreDispatch key). This PR fixes that. I also shuffled some APIs around to reduce code duplication, since the setting/unsetting logic is quite hard to get right.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/121444
Approved by: https://github.com/bdhirsh
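The invariant behind the fix is that the PreDispatch dispatch key stays in the thread-local include set only while at least one pre-dispatch mode is active. Below is a minimal, torch-free sketch of that bookkeeping; the names (PreDispatchModeStack, key_included) are illustrative stand-ins, not PyTorch APIs:

# Minimal, torch-free sketch of the invariant this commit enforces:
# the PreDispatch key is "included" only while at least one
# pre-dispatch mode is active. Names here are illustrative, not PyTorch APIs.

class PreDispatchModeStack:
    def __init__(self):
        # Two fixed slots, mirroring the PROXY / FUNCTIONAL slots in torch/_ops.py.
        self._slots = [None, None]
        self.key_included = False  # stands in for the TLS dispatch key include set

    def _active(self):
        return sum(m is not None for m in self._slots)

    def set(self, idx, mode):
        was_empty = self._active() == 0
        assert self._slots[idx] is None
        self._slots[idx] = mode
        if was_empty:
            # First active mode: turn the PreDispatch key back on.
            self.key_included = True

    def unset(self, idx):
        mode, self._slots[idx] = self._slots[idx], None
        if self._active() == 0:
            # No active modes left: drop the key so ordinary dispatch
            # skips the pre-dispatch layer entirely.
            self.key_included = False
        return mode


stack = PreDispatchModeStack()
stack.set(0, "proxy_mode")
assert stack.key_included
stack.unset(0)
assert not stack.key_included

The real implementation does this bookkeeping inside _set_mode_pre_dispatch and unset_mode_pre_dispatch via torch._C._dispatch_tls_set_dispatch_key_included, as shown in the diff below.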
This commit is contained in: parent 8ac0f072e6, commit 8a0436014d

@@ -70,7 +70,7 @@ class TestHOP(TestCase):
             self.assertEqual(type(orig), type(loaded))
             self.assertEqual(orig, loaded)
 
-    @ops(hop_tests, allowed_dtypes=(torch.float, torch.int))
+    @ops(hop_tests, allowed_dtypes=(torch.float,))
     def test_aot_export(self, device, dtype, op):
         class Foo(torch.nn.Module):
             def forward(self, *args):

@@ -85,7 +85,7 @@ class TestHOP(TestCase):
         ep = export(model, args, kwargs)
         self._compare(model, ep, args, kwargs)
 
-    @ops(hop_tests, allowed_dtypes=(torch.float, torch.int))
+    @ops(hop_tests, allowed_dtypes=(torch.float,))
     def test_pre_dispatch_export(self, device, dtype, op):
         class Foo(torch.nn.Module):
             def forward(self, *args):

@@ -100,7 +100,7 @@ class TestHOP(TestCase):
         ep = _export(model, args, kwargs, pre_dispatch=True)
         self._compare(model, ep, args, kwargs)
 
-    @ops(hop_tests, allowed_dtypes=(torch.float, torch.int))
+    @ops(hop_tests, allowed_dtypes=(torch.float,))
     def test_retrace_export(self, device, dtype, op):
         class Foo(torch.nn.Module):
             def forward(self, *args):

@@ -116,7 +116,7 @@ class TestHOP(TestCase):
         ep = ep.run_decompositions()
         self._compare(model, ep, args, kwargs)
 
-    @ops(hop_tests, allowed_dtypes=(torch.float, torch.int))
+    @ops(hop_tests, allowed_dtypes=(torch.float,))
     def test_serialize_export(self, device, dtype, op):
         class Foo(torch.nn.Module):
             def forward(self, *args):
@@ -49,6 +49,7 @@ from functorch.compile import (
     nop, default_partition, default_decompositions,
     memory_efficient_fusion, get_aot_compilation_context, make_boxed_compiler
 )
+from functorch.experimental import control_flow
 from torch._decomp import decomposition_table
 
 from torch.testing._internal.common_device_type import ops
@@ -3067,6 +3068,106 @@ def forward(self, arg0_1):
     sin_1 = torch.ops.aten.sin.default(add); add = None
     return (sin_1,)""")
 
+    @unittest.skipIf(IS_WINDOWS, "Windows isn't supported for this case")
+    @unittest.skipIf(not torchdynamo.is_dynamo_supported(), "TorchDynamo is not supported")
+    def test_aot_export_predispatch_map_1(self):
+        class M(torch.nn.Module):
+            def __init__(self):
+                super().__init__()
+
+            def forward(self, x, y):
+                def true_fn(x, r):
+                    y = x.sin()
+                    y.add_(5)
+                    return y.cos() + r.sum()
+
+                def false_fn(x, r):
+                    z = x.cos()
+
+                    def f(x, y):
+                        a = x.cos()
+                        a.add_(5)
+                        return a + y
+
+                    return z + control_flow.map(f, z, r).sum() + control_flow.map(f, z, r).sum()
+
+                a = torch.cond(x.shape[0] > 4, true_fn, false_fn, [x, y])
+                return (a + 3, a + 4)
+        inps = [torch.randn(2, 2), torch.ones(2)]
+        gm, _ = aot_export_module(M(), inps, trace_joint=False, pre_dispatch=True)
+        self.assertExpectedInline(str(gm.code).strip(), """\
+def forward(self, arg0_1, arg1_1):
+    true_graph_0 = self.true_graph_0
+    false_graph_0 = self.false_graph_0
+    conditional = torch.ops.higher_order.cond(False, true_graph_0, false_graph_0, [arg0_1, arg1_1]); true_graph_0 = false_graph_0 = arg0_1 = arg1_1 = None
+    getitem = conditional[0]; conditional = None
+    add = torch.ops.aten.add.Tensor(getitem, 3)
+    add_1 = torch.ops.aten.add.Tensor(getitem, 4); getitem = None
+    return (add, add_1)""")  # noqa: B950
+        self.assertExpectedInline(str(gm.true_graph_0.code).strip(), """\
+def forward(self, arg0_1, arg1_1):
+    sin = torch.ops.aten.sin.default(arg0_1); arg0_1 = None
+    add = torch.ops.aten.add.Tensor(sin, 5); sin = None
+    cos = torch.ops.aten.cos.default(add); add = None
+    sum_1 = torch.ops.aten.sum.default(arg1_1); arg1_1 = None
+    add_1 = torch.ops.aten.add.Tensor(cos, sum_1); cos = sum_1 = None
+    return (add_1,)""")
+        self.assertExpectedInline(str(gm.false_graph_0.code).strip(), """\
+def forward(self, arg0_1, arg1_1):
+    cos = torch.ops.aten.cos.default(arg0_1); arg0_1 = None
+    select = torch.ops.aten.select.int(cos, 0, 0)
+    body_graph_0 = self.body_graph_0
+    map_impl = torch.ops.higher_order.map_impl(body_graph_0, [cos], [arg1_1]); body_graph_0 = None
+    getitem = map_impl[0]; map_impl = None
+    sum_1 = torch.ops.aten.sum.default(getitem); getitem = None
+    add = torch.ops.aten.add.Tensor(cos, sum_1); sum_1 = None
+    select_1 = torch.ops.aten.select.int(cos, 0, 0)
+    body_graph_1 = self.body_graph_1
+    map_impl_1 = torch.ops.higher_order.map_impl(body_graph_1, [cos], [arg1_1]); body_graph_1 = cos = arg1_1 = None
+    getitem_1 = map_impl_1[0]; map_impl_1 = None
+    sum_2 = torch.ops.aten.sum.default(getitem_1); getitem_1 = None
+    add_1 = torch.ops.aten.add.Tensor(add, sum_2); add = sum_2 = None
+    return (add_1,)""")
+        self.assertExpectedInline(str(gm.false_graph_0.body_graph_0.code).strip(), """\
+def forward(self, arg0_1, arg1_1):
+    cos = torch.ops.aten.cos.default(arg0_1); arg0_1 = None
+    add = torch.ops.aten.add.Tensor(cos, 5); cos = None
+    add_1 = torch.ops.aten.add.Tensor(add, arg1_1); add = arg1_1 = None
+    return (add_1,)""")
+
+    def test_aot_export_predispatch_map_2(self):
+        class M(torch.nn.Module):
+            def __init__(self):
+                super().__init__()
+
+            def forward(self, x, y):
+                z = x.cos()
+
+                def f(x, y):
+                    a = x.cos()
+                    a.add_(5)
+                    return a + y
+
+                return (z + control_flow.map(f, z, y).sum(),)
+
+        inps = [torch.randn(2, 2), torch.ones(2)]
+        gm, _ = aot_export_module(M(), inps, trace_joint=False, pre_dispatch=True)
+        self.assertExpectedInline(str(gm.code).strip(), """\
+def forward(self, arg0_1, arg1_1):
+    cos = torch.ops.aten.cos.default(arg0_1); arg0_1 = None
+    body_graph_0 = self.body_graph_0
+    map_impl = torch.ops.higher_order.map_impl(body_graph_0, [cos], [arg1_1]); body_graph_0 = arg1_1 = None
+    getitem = map_impl[0]; map_impl = None
+    sum_1 = torch.ops.aten.sum.default(getitem); getitem = None
+    add = torch.ops.aten.add.Tensor(cos, sum_1); cos = sum_1 = None
+    return (add,)""")  # noqa: B950
+        self.assertExpectedInline(str(gm.body_graph_0.code).strip(), """\
+def forward(self, arg0_1, arg1_1):
+    cos = torch.ops.aten.cos.default(arg0_1); arg0_1 = None
+    add = torch.ops.aten.add.Tensor(cos, 5); cos = None
+    add_1 = torch.ops.aten.add.Tensor(add, arg1_1); add = arg1_1 = None
+    return [add_1]""")
+
     @unittest.skipIf(IS_WINDOWS, "Windows isn't supported for this case")
     @unittest.skipIf(not torchdynamo.is_dynamo_supported(), "TorchDynamo is not supported")
     def test_aot_export_predispatch_with_cond(self):
@@ -273,21 +273,6 @@ class HigherOrderOperator(OperatorBase):
         # that PreDispatch key is still active. In that case, we just redispatch
         # it to next key. This is only safe to do when PreDispatch key stack has no
         # active modes.
-        # TODO (tmanlaibaatar) Make it generic fallback mechanism
-        def _(*args, **kwargs):
-            if _len_torch_dispatch_stack_pre_dispatch() == 0:
-                with torch._C._ExcludeDispatchKeyGuard(
-                    torch._C.DispatchKeySet(DispatchKey.PreDispatch)
-                ):
-                    return self(*args, **kwargs)
-            raise AssertionError(
-                """
-                Can't directly invoke HOP implementation at PreDispatch key
-                if there are active modes on PreDispatch mode stack.
-                """
-            )
-
-        self.py_impl(torch._C.DispatchKey.PreDispatch)(_)
 
     def py_impl(self, k):
         if isinstance(k, torch._C.DispatchKey) and not self.non_fallthrough_keys.has(k):

@@ -456,14 +441,30 @@ def unset_mode_pre_dispatch(mode_key):
         torch._C._TorchDispatchModeKey.PROXY,
         torch._C._TorchDispatchModeKey.FUNCTIONAL,
     )
-    if mode_key == torch._C._TorchDispatchModeKey.PROXY:
-        current_mode = current_mode_stack_pre_dispatch.get(0)
-        mode_stack_state_for_pre_dispatch().set(0, None)
-        return current_mode
-    else:
-        current_mode = current_mode_stack_pre_dispatch.get(1)
-        mode_stack_state_for_pre_dispatch().set(1, None)
-        return current_mode
+
+    def _unset_mode():
+        if mode_key == torch._C._TorchDispatchModeKey.PROXY:
+            current_mode = current_mode_stack_pre_dispatch.get(0)
+            mode_stack_state_for_pre_dispatch().set(0, None)
+            return current_mode
+        else:
+            current_mode = current_mode_stack_pre_dispatch.get(1)
+            mode_stack_state_for_pre_dispatch().set(1, None)
+            return current_mode
+
+    current_mode = _unset_mode()
+
+    new_pre_dispatch_len = _len_torch_dispatch_stack_pre_dispatch()
+    # When we are unsetting a mode, we need to check if there is
+    # active mode left on the PreDispatch key. If there is nothing
+    # active, we need to remove PreDispatch key from local dispatch include
+    # set.
+    if new_pre_dispatch_len == 0:
+        torch._C._dispatch_tls_set_dispatch_key_included(
+            torch._C.DispatchKey.PreDispatch, False
+        )
+
+    return current_mode
 
 
 def _set_mode_pre_dispatch(mode):

@@ -471,30 +472,39 @@ def _set_mode_pre_dispatch(mode):
     from torch.fx.experimental.proxy_tensor import ProxyTorchDispatchMode
 
     assert isinstance(mode, (FunctionalTensorMode, ProxyTorchDispatchMode))
+
+    previous_mode_stack_len = _len_torch_dispatch_stack_pre_dispatch()
     if isinstance(mode, FunctionalTensorMode):
         current_mode = mode_stack_state_for_pre_dispatch().get(1)
         assert current_mode is None
         mode_stack_state_for_pre_dispatch().set(1, mode)
-        return
+    else:
+        current_mode = mode_stack_state_for_pre_dispatch().get(0)
+        assert current_mode is None
+        mode_stack_state_for_pre_dispatch().set(0, mode)
 
-    current_mode = mode_stack_state_for_pre_dispatch().get(0)
-    assert current_mode is None
-    mode_stack_state_for_pre_dispatch().set(0, mode)
+    # When we are setting a mode, we need to check if there is
+    # active mode left on the PreDispatch key. If there was nothing
+    # active before setting this mode, it means that PreDispatch key
+    # was turned off. So we need to turn it on again.
+    if previous_mode_stack_len == 0:
+        torch._C._dispatch_tls_set_dispatch_key_included(
+            torch._C.DispatchKey.PreDispatch, True
+        )
 
 
 def _pop_mode_from_pre_dispatch():
     mode_stack = mode_stack_state_for_pre_dispatch()
+    pre_dispatch_len = _len_torch_dispatch_stack_pre_dispatch()
+
+    if pre_dispatch_len == 0:
+        raise AssertionError("Trying to pop empty mode stack")
+
     if mode_stack.get(1) is not None:
-        res = mode_stack.get(1)
-        mode_stack.set(1, None)
-        return res
+        return unset_mode_pre_dispatch(torch._C._TorchDispatchModeKey.FUNCTIONAL)
 
     if mode_stack.get(0) is not None:
-        res = mode_stack.get(0)
-        mode_stack.set(0, None)
-        return res
-
-    raise AssertionError("Trying to pop empty mode stack")
+        return unset_mode_pre_dispatch(torch._C._TorchDispatchModeKey.PROXY)
 
 
 def _len_torch_dispatch_stack_pre_dispatch():
@@ -15,7 +15,7 @@ from torch.fx.graph_module import _assign_attr
 from weakref import WeakKeyDictionary
 from collections import defaultdict
 from torch._subclasses.fake_tensor import FakeTensor, FakeTensorMode, unset_fake_temporarily, is_fake
-from torch._dispatch.python import enable_python_dispatcher, enable_pre_dispatch
+from torch._dispatch.python import enable_python_dispatcher
 import torch.fx as fx
 from torch.fx.node import _side_effectful_need_to_be_preserved_pre_dispatch
 from torch.fx.passes.shape_prop import _extract_tensor_metadata

@@ -1154,13 +1154,10 @@ def make_fx(f,
         raise AssertionError(f"Unexpected tracing type: {tracing_mode}")
 
     python_dispatcher_mode: Any = nullcontext()
-    pre_dispatch_mode: Any = nullcontext()
     # pre-autograd tracing uses per-dispatch-key modes,
     # which requires the python dispatcher
     if tracing_mode == "symbolic" or pre_dispatch:
         python_dispatcher_mode = enable_python_dispatcher()
-    if pre_dispatch:
-        pre_dispatch_mode = enable_pre_dispatch()
 
     proxy_function_mode: Any = nullcontext()
     if pre_dispatch:

@@ -1218,7 +1215,7 @@ def make_fx(f,
     # We also disable tracing by any other tensor proxy-based tracers except the current. The
     # purpose of `make_fx` is to produce graphmodules as a side effect; its internal execution is
     # thus irrelevant to any external functional trace.
-    with decompose(decomposition_table), fake_tensor_mode, python_dispatcher_mode, pre_dispatch_mode, proxy_function_mode, \
+    with decompose(decomposition_table), fake_tensor_mode, python_dispatcher_mode, proxy_function_mode, \
             sym_mode, torch_fn_metadata_mode, proxy_mode, disable_autocast_cache():
         t = dispatch_trace(wrap_key(func, args, fx_tracer, pre_dispatch), tracer=fx_tracer, concrete_args=tuple(phs))
 
@@ -50,7 +50,7 @@ hop_that_doesnt_have_opinfo_test_allowlist = [
     "run_with_rng_state",
     "out_dtype",
     "trace_wrapped",
-    "map",
+    "map",  # T183144629
     "map_impl",
     "with_effects",
     "strict_mode",