mirror of
https://github.com/zebrajr/pytorch.git
synced 2025-12-06 12:20:52 +01:00
Only thunkify proxies in some situations (#132421)
The goal of this PR is to avoid stack overflow when we create extremely long chains of thunks, and then evaluate them (e.g., as occurs if you sum(long list of symint)). The basic idea behind this PR is to only thunkify proxies if they're being created in places where they may or may not be used--crucially, symint operations that occur in user code we are tracing are eagerly placed into the graph, even if they may eventually be dead. I annotated the PR with explanation of changes. Signed-off-by: Edward Z. Yang <ezyang@meta.com> Pull Request resolved: https://github.com/pytorch/pytorch/pull/132421 Approved by: https://github.com/Skylion007, https://github.com/zou3519 ghstack dependencies: #132674, #132675
This commit is contained in:
parent
54efd43022
commit
aec6332356
|
|
@ -51,6 +51,7 @@ torch.fx.experimental.symbolic_shapes
|
|||
compute_unbacked_bindings
|
||||
rebind_unbacked
|
||||
resolve_unbacked_bindings
|
||||
is_accessor_node
|
||||
|
||||
torch.fx.experimental.proxy_tensor
|
||||
-------------------------------------
|
||||
|
|
@ -65,3 +66,5 @@ torch.fx.experimental.proxy_tensor
|
|||
make_fx
|
||||
handle_sym_dispatch
|
||||
get_proxy_mode
|
||||
maybe_enable_thunkify
|
||||
maybe_disable_thunkify
|
||||
|
|
|
|||
|
|
@ -3480,8 +3480,8 @@ def forward(self, x):
|
|||
def forward(self, x):
|
||||
arg0, = fx_pytree.tree_flatten_spec(([x], {}), self._in_spec)
|
||||
arg0_1 = arg0
|
||||
slice_1 = torch.ops.aten.slice.Tensor(arg0_1, 2, 0, 3)
|
||||
sym_size_int = torch.ops.aten.sym_size.int(arg0_1, 0)
|
||||
slice_1 = torch.ops.aten.slice.Tensor(arg0_1, 2, 0, 3)
|
||||
sub = sym_size_int - 1
|
||||
slice_2 = torch.ops.aten.slice.Tensor(arg0_1, 0, 0, sub); sub = None
|
||||
slice_3 = torch.ops.aten.slice.Tensor(slice_2, 1, 1, sym_size_int); slice_2 = None
|
||||
|
|
|
|||
|
|
@ -2148,7 +2148,6 @@ class TestNestedTensor(torch._dynamo.test_case.TestCase, NestedTensorTestCase):
|
|||
|
||||
self._validate_compile(fn, arg_fn)
|
||||
|
||||
@unittest.expectedFailure
|
||||
def test_return_shape(self):
|
||||
nt, _ = self._get_jagged_tensor(((2, 3, 4), 5), None)
|
||||
|
||||
|
|
|
|||
|
|
@ -6037,14 +6037,14 @@ def forward(self, x):
|
|||
def forward(self, x):
|
||||
item = torch.ops.aten.item.default(x); x = None
|
||||
sym_constrain_range_for_size_default = torch.ops.aten.sym_constrain_range_for_size.default(item); sym_constrain_range_for_size_default = None
|
||||
ge = item >= 3
|
||||
_assert_scalar_default = torch.ops.aten._assert_scalar.default(ge, "Runtime assertion failed for expression u1 >= 3 on node 'ge'"); ge = _assert_scalar_default = None
|
||||
ge_1 = item >= 3
|
||||
_assert_scalar_default = torch.ops.aten._assert_scalar.default(ge_1, "Runtime assertion failed for expression u1 >= 3 on node 'ge_1'"); ge_1 = _assert_scalar_default = None
|
||||
le = item <= 5
|
||||
_assert_scalar_default_1 = torch.ops.aten._assert_scalar.default(le, "Runtime assertion failed for expression u1 <= 5 on node 'le'"); le = _assert_scalar_default_1 = None
|
||||
gt = item > 2
|
||||
_assert_scalar_default_2 = torch.ops.aten._assert_scalar.default(gt, "Runtime assertion failed for expression 2 < u1 on node 'gt'"); gt = _assert_scalar_default_2 = None
|
||||
lt = item < 6
|
||||
_assert_scalar_default_3 = torch.ops.aten._assert_scalar.default(lt, "Runtime assertion failed for expression u1 < 6 on node 'lt'"); lt = _assert_scalar_default_3 = None
|
||||
gt_1 = item > 2
|
||||
_assert_scalar_default_2 = torch.ops.aten._assert_scalar.default(gt_1, "Runtime assertion failed for expression 2 < u1 on node 'gt_1'"); gt_1 = _assert_scalar_default_2 = None
|
||||
lt_1 = item < 6
|
||||
_assert_scalar_default_3 = torch.ops.aten._assert_scalar.default(lt_1, "Runtime assertion failed for expression u1 < 6 on node 'lt_1'"); lt_1 = _assert_scalar_default_3 = None
|
||||
foo_unbacked = torch.ops.testlib.foo_unbacked.default(item); item = None
|
||||
return (foo_unbacked,)""",
|
||||
)
|
||||
|
|
|
|||
|
|
@ -1320,8 +1320,8 @@ def forward(self, token, obj, x):
|
|||
with_effects = torch._higher_order_ops.effects.with_effects(token, torch.ops.higher_order.call_torchbind, obj, 'get'); token = obj = None
|
||||
getitem = with_effects[0]
|
||||
getitem_1 = with_effects[1]; with_effects = None
|
||||
add = torch.ops.aten.add.Tensor(getitem_1, x); getitem_1 = x = None
|
||||
return (getitem, add)""", # noqa: B950
|
||||
add_3 = torch.ops.aten.add.Tensor(getitem_1, x); getitem_1 = x = None
|
||||
return (getitem, add_3)""", # noqa: B950
|
||||
)
|
||||
|
||||
@parametrize("backend", ["eager", "aot_eager"])
|
||||
|
|
|
|||
|
|
@ -1056,8 +1056,8 @@ def forward(self, s0_1, s1_1, x_1, y_1):
|
|||
self.assertExpectedInline(r, """\
|
||||
def forward(self, x_1, y_1):
|
||||
sym_size_int = torch.ops.aten.sym_size.int(x_1, 0); x_1 = None
|
||||
empty = torch.ops.aten.empty.memory_format([sym_size_int], device = device(type='cpu'), pin_memory = False)
|
||||
sym_size_int_1 = torch.ops.aten.sym_size.int(y_1, 0); y_1 = None
|
||||
empty = torch.ops.aten.empty.memory_format([sym_size_int], device = device(type='cpu'), pin_memory = False)
|
||||
return ((sym_size_int, sym_size_int_1), empty)""")
|
||||
|
||||
def test_unary(self):
|
||||
|
|
@ -1355,8 +1355,8 @@ def forward(self, crop_camera_1, mask_1):
|
|||
view_1 = torch.ops.aten.view.default(expand_1, [sym_size_int, sym_size_int_1, sym_size_int_2]); expand_1 = sym_size_int_1 = sym_size_int_2 = None
|
||||
bmm = torch.ops.aten.bmm.default(view, view_1); view = view_1 = None
|
||||
view_2 = torch.ops.aten.view.default(bmm, [sym_size_int, 3, 3]); bmm = None
|
||||
mul = sym_size_int * 3
|
||||
view_3 = torch.ops.aten.view.default(view_2, [mul, 3]); view_2 = mul = None
|
||||
mul_4 = sym_size_int * 3
|
||||
view_3 = torch.ops.aten.view.default(view_2, [mul_4, 3]); view_2 = mul_4 = None
|
||||
mm = torch.ops.aten.mm.default(view_3, eye); view_3 = eye = None
|
||||
view_4 = torch.ops.aten.view.default(mm, [sym_size_int, 3, 3]); mm = sym_size_int = None
|
||||
index_put_ = torch.ops.aten.index_put_.default(crop_camera_1, [mask_1], view_4); crop_camera_1 = mask_1 = view_4 = index_put_ = None
|
||||
|
|
|
|||
|
|
@ -1536,29 +1536,7 @@ class OutputGraph:
|
|||
return False
|
||||
return True
|
||||
|
||||
# NB: You could try to expand this to cover more cases by simply
|
||||
# detecting whenever you have an int output, but this is a bit
|
||||
# dangerous in case someone adds a function that returns an int but is
|
||||
# mutating. So manually whitelist for now.
|
||||
def is_accessor_node(node):
|
||||
if (
|
||||
node.op == "call_method"
|
||||
and isinstance(node.args[0].meta.get("example_value"), torch.Tensor)
|
||||
and node.target in ["size", "stride", "storage_offset", "item"]
|
||||
):
|
||||
return True
|
||||
if node.op == "call_function" and node.target in [
|
||||
torch.ops.aten.sym_size,
|
||||
torch.ops.aten.sym_size.default,
|
||||
torch.ops.aten.sym_size.int,
|
||||
torch.ops.aten.sym_stride,
|
||||
torch.ops.aten.sym_stride.default,
|
||||
torch.ops.aten.sym_stride.int,
|
||||
torch.ops.aten.sym_storage_offset,
|
||||
torch.ops.aten.sym_storage_offset.default,
|
||||
]:
|
||||
return True
|
||||
return False
|
||||
from torch.fx.experimental.symbolic_shapes import is_accessor_node
|
||||
|
||||
for node in reversed(list(self.graph.nodes)):
|
||||
if len(list(node.users)) == 0:
|
||||
|
|
|
|||
|
|
@ -24,6 +24,10 @@ from torch import Tensor
|
|||
from torch._decomp.decompositions_for_rng import PhiloxStateTracker
|
||||
from torch._guards import detect_fake_mode
|
||||
from torch._prims_common import CUDARngStateHelper
|
||||
from torch.fx.experimental.proxy_tensor import (
|
||||
maybe_disable_thunkify,
|
||||
maybe_enable_thunkify,
|
||||
)
|
||||
from torch.fx.experimental.symbolic_shapes import (
|
||||
definitely_false,
|
||||
PropagateUnbackedSymInts,
|
||||
|
|
@ -188,6 +192,7 @@ def fn_prepped_for_autograd(
|
|||
def create_joint(fn: Callable, *, aot_config: AOTConfig) -> Any:
|
||||
def inner_fn(primals: List[Any], tangents: List[Any]):
|
||||
outs, tangent_mask = fn(*primals)
|
||||
|
||||
assert len(tangent_mask) == len(outs)
|
||||
outs_to_grad = [
|
||||
o for needs_tangent, o in zip(tangent_mask, outs) if needs_tangent
|
||||
|
|
@ -365,6 +370,7 @@ def create_functionalized_fn(
|
|||
) -> Any:
|
||||
@wraps(fn)
|
||||
def _functionalized_f_helper(*args):
|
||||
with maybe_enable_thunkify():
|
||||
# See Note [Disabling Functionalize TLS Above Python Functionalization]
|
||||
disable_above = torch._C._ExcludeDispatchKeyGuard(
|
||||
torch._C.DispatchKeySet(torch._C.DispatchKey.Functionalize)
|
||||
|
|
@ -386,15 +392,21 @@ def create_functionalized_fn(
|
|||
assert all(token.numel() == 0 for token in tokens)
|
||||
|
||||
with disable_above:
|
||||
# The functionalization code here can potentially trigger traces
|
||||
# into the graph, but we'd prefer to NOT do this, because if we
|
||||
# trace them now, we will end up with FX nodes that don't have
|
||||
# module stack annotations, which makes unflattener unhappy.
|
||||
# Wrap inputs into functional wrappers
|
||||
f_args = pytree.tree_map(to_fun, args)
|
||||
f_tokens = pytree.tree_map(to_fun, tokens)
|
||||
|
||||
# Populate the current FunctionalTensorMode with the tokens per
|
||||
# operator. See Note [FunctionalTensorMode is Stateful]
|
||||
functional_tensor_mode = torch.utils._python_dispatch._detect_infra_mode(
|
||||
functional_tensor_mode = (
|
||||
torch.utils._python_dispatch._detect_infra_mode(
|
||||
torch._C._TorchDispatchModeKey.FUNCTIONAL
|
||||
)
|
||||
)
|
||||
assert functional_tensor_mode is not None
|
||||
for i, k in enumerate(meta.tokens.keys()):
|
||||
functional_tensor_mode._tokens[k] = f_tokens[i]
|
||||
|
|
@ -511,7 +523,10 @@ def create_functionalized_fn(
|
|||
continue
|
||||
assert is_fun(inpt_f)
|
||||
inpt_new = from_fun(inpt_f)
|
||||
if meta.input_info[i].mutation_type == MutationType.MUTATED_IN_GRAPH:
|
||||
if (
|
||||
meta.input_info[i].mutation_type
|
||||
== MutationType.MUTATED_IN_GRAPH
|
||||
):
|
||||
# See Note [set_() Input Mutations in AOTAutograd]
|
||||
# all mutations on the input must be under no_grad, so it is safe to put in the graph
|
||||
# Here, we're saying that if an input experienced a set call, inp.set_(other),
|
||||
|
|
@ -539,7 +554,9 @@ def create_functionalized_fn(
|
|||
# we could support this
|
||||
if meta.input_info[i].mutation_inductor_storage_resize:
|
||||
# resizing is not supported on subclasses (we error earlier if this happens)
|
||||
from torch._subclasses.functional_tensor import FunctionalTensor
|
||||
from torch._subclasses.functional_tensor import (
|
||||
FunctionalTensor,
|
||||
)
|
||||
|
||||
assert isinstance(inpt_f, FunctionalTensor)
|
||||
old_storage_size = torch._functionalize_get_storage_size( # type: ignore[attr-defined]
|
||||
|
|
@ -592,7 +609,9 @@ We only support storage resizing on graph inputs as long as the input either sta
|
|||
inpt_old.copy_(inpt_new)
|
||||
elif (
|
||||
meta.input_info[i].mutates_data
|
||||
and meta.input_info[i].mutations_under_no_grad_or_inference_mode
|
||||
and meta.input_info[
|
||||
i
|
||||
].mutations_under_no_grad_or_inference_mode
|
||||
):
|
||||
# Under no_grad = run under no_grad (we still bump the VC though)
|
||||
# (inference_mode will also bump the VC, as long as the tensor in question
|
||||
|
|
@ -709,9 +728,13 @@ def aot_dispatch_subclass(
|
|||
return unwrapped_outs
|
||||
|
||||
def joint_fn(primals, tangents):
|
||||
return inner_fn(flat_fn_maybe_joint, (primals, tangents), use_trace_joint=True)
|
||||
with maybe_enable_thunkify():
|
||||
return inner_fn(
|
||||
flat_fn_maybe_joint, (primals, tangents), use_trace_joint=True
|
||||
)
|
||||
|
||||
def fw_fn(*primals):
|
||||
with maybe_enable_thunkify():
|
||||
return inner_fn(flat_fn_maybe_joint, primals, use_trace_joint=False)
|
||||
|
||||
def metadata_fn(*primals):
|
||||
|
|
@ -771,7 +794,7 @@ def create_functional_call(mod, params_spec, params_len, store_orig_mod=False):
|
|||
def functional_call(*args, **kwargs):
|
||||
with stateless._reparametrize_module(
|
||||
mod, pytree.tree_unflatten(args[:params_len], params_spec)
|
||||
):
|
||||
), maybe_disable_thunkify():
|
||||
if isinstance(mod, torch.fx.GraphModule):
|
||||
with fx_traceback.preserve_node_meta(), warnings.catch_warnings():
|
||||
warnings.filterwarnings(
|
||||
|
|
|
|||
|
|
@ -54,6 +54,7 @@ from weakref import WeakKeyDictionary
|
|||
|
||||
if TYPE_CHECKING:
|
||||
import types
|
||||
import sympy
|
||||
|
||||
from torch._ops import OpOverload
|
||||
from torch.fx._symbolic_trace import PHBase
|
||||
|
|
@ -61,7 +62,8 @@ if TYPE_CHECKING:
|
|||
|
||||
__all__ = [
|
||||
"PythonKeyTracer", "dispatch_trace", "make_fx", "DecompositionInterpreter",
|
||||
"py_sym_types", "get_innermost_proxy_mode", "get_proxy_mode", "handle_sym_dispatch"
|
||||
"py_sym_types", "get_innermost_proxy_mode", "get_proxy_mode", "handle_sym_dispatch",
|
||||
"maybe_enable_thunkify", "maybe_disable_thunkify",
|
||||
]
|
||||
|
||||
_ProxyTracer = Union["PythonKeyTracer", "_GraphAppendingTracerEx"]
|
||||
|
|
@ -157,6 +159,7 @@ def set_proxy_slot(
|
|||
tracer: _ProxyTracer,
|
||||
proxy: object
|
||||
) -> None:
|
||||
log.debug("set_proxy_slot %s (%s) %s", obj, id(obj), proxy)
|
||||
if isinstance(obj, Tensor):
|
||||
# We DO want to clobber proxies whenever we run an inplace operation
|
||||
# on a tensor, and it affects the metadata on the proxy.
|
||||
|
|
@ -175,6 +178,22 @@ def set_proxy_slot(
|
|||
if obj not in tracer.symnode_tracker:
|
||||
tracer.symnode_tracker[obj] = typing.cast(_PySymProxyType, proxy)
|
||||
|
||||
# WAR: python test/dynamo/test_subclasses.py
|
||||
# TestNestedTensor.test_basic_autograd
|
||||
#
|
||||
# AOTAutograd doesn't pass the "outer sizes" as an actual argument
|
||||
# to make_fx, but it is made use of internally in AOTAutograd's
|
||||
# call to tensor unflatten. Because the outer sizes isn't passed
|
||||
# as an argument, it is therefore untracked. However, it turns
|
||||
# out you luck out, because *Dynamo* will manually add the outer
|
||||
# sizes as an argument so you can fix up the proxy'ness.
|
||||
#
|
||||
# This is probably fixed in
|
||||
# https://github.com/pytorch/pytorch/pull/125941/
|
||||
import sympy
|
||||
if isinstance(obj.node.expr, sympy.Symbol):
|
||||
tracer.sympy_expr_tracker[obj.node.expr] = proxy
|
||||
|
||||
def has_proxy_slot(obj: Tensor, tracer: _ProxyTracer) -> bool:
|
||||
assert isinstance(obj, (Tensor, SymNode)), type(obj)
|
||||
return bool(get_proxy_slot(obj, tracer, False, lambda _: True))
|
||||
|
|
@ -276,9 +295,14 @@ def get_proxy_slot(
|
|||
tracker = tracer.symnode_tracker
|
||||
|
||||
if obj not in tracker:
|
||||
# Last ditch
|
||||
if isinstance(obj, py_sym_types) and obj.node.expr in tracer.sympy_expr_tracker:
|
||||
value = tracer.sympy_expr_tracker[obj.node.expr]
|
||||
else:
|
||||
if isinstance(default, _NoDefault):
|
||||
raise RuntimeError(f"{obj} is not tracked with proxy for {tracer}")
|
||||
raise RuntimeError(f"{obj} ({id(obj)})is not tracked with proxy for {tracer}")
|
||||
return default
|
||||
else:
|
||||
value = tracker[obj]
|
||||
res = transform(value)
|
||||
return res
|
||||
|
|
@ -330,6 +354,54 @@ def extract_val(val: _ExtractValType) -> _ExtractValType:
|
|||
|
||||
typing_extensions.assert_never(val)
|
||||
|
||||
@contextmanager
|
||||
def _enable_thunkify(tracer: _ProxyTracer, *, enable: bool = True) -> Generator[None, None, None]:
|
||||
"""
|
||||
Enable thunkification inside the context manager. Thunkification prevents
|
||||
SymNode computation from directly being traced into an FX graph; instead,
|
||||
the compute is only added to the graph if it is actually used. This helps
|
||||
us track SymNode compute when it is computed (since we need /something/
|
||||
to put in the tracker) even if it is unlikely to be used.
|
||||
"""
|
||||
old = tracer.enable_thunkify
|
||||
tracer.enable_thunkify = enable
|
||||
try:
|
||||
yield
|
||||
finally:
|
||||
tracer.enable_thunkify = old
|
||||
|
||||
@contextmanager
|
||||
def maybe_disable_thunkify() -> Generator[None, None, None]:
|
||||
"""Within a context, disable thunkification. See :func:`maybe_enable_thunkify`
|
||||
for more details. This is helpful if you have a wrapper function which
|
||||
you want to enable thunkification on, but in some segment on the inside (say,
|
||||
the original user function), you want to disable thunkification as you know
|
||||
it is not needed there.
|
||||
"""
|
||||
proxy_mode = get_proxy_mode()
|
||||
if proxy_mode is not None:
|
||||
with _enable_thunkify(proxy_mode.tracer, enable=False):
|
||||
yield
|
||||
else:
|
||||
yield
|
||||
|
||||
@contextmanager
|
||||
def maybe_enable_thunkify() -> Generator[None, None, None]:
|
||||
"""Within this context manager, if you are doing make_fx tracing, we will thunkify
|
||||
all SymNode compute and avoid tracing it into the graph unless it is actually needed.
|
||||
You should prefer to avoid using this as much as possible, as lazy evaluation of
|
||||
SymNode tracing can lead to long chains of thunks which will stack overflow
|
||||
if you evaluate them. However, this is currently sometimes necessary as there
|
||||
are buggy parts of PT2 which will fail with "s0 is not tracked with proxy" error
|
||||
due to insufficient tracing of SymNode computation.
|
||||
"""
|
||||
proxy_mode = get_proxy_mode()
|
||||
if proxy_mode is not None:
|
||||
with _enable_thunkify(proxy_mode.tracer):
|
||||
yield
|
||||
else:
|
||||
yield
|
||||
|
||||
# Note [invariants for node meta 'val']
|
||||
# What invariants do we have for the 'val' set on the FX node? It has accurate
|
||||
# metadata... but only for metadata that exists "below" all other subsystems
|
||||
|
|
@ -340,6 +412,7 @@ def extract_val(val: _ExtractValType) -> _ExtractValType:
|
|||
def set_meta(proxy: Proxy, val: _ExtractValType) -> Proxy:
|
||||
proxy.node.meta['val'] = extract_val(val)
|
||||
|
||||
with _enable_thunkify(proxy.tracer): # type: ignore[arg-type]
|
||||
# Best effort tensor_meta setting; prefer using val!
|
||||
if is_fake(val):
|
||||
proxy.node.meta['tensor_meta'] = _extract_tensor_metadata(val)
|
||||
|
|
@ -347,12 +420,16 @@ def set_meta(proxy: Proxy, val: _ExtractValType) -> Proxy:
|
|||
proxy.node.meta['tensor_meta'] = _extract_tensor_metadata(val)
|
||||
return proxy
|
||||
|
||||
def thunkify(f: Callable[_P, R], *args: _P.args, **kwargs: _P.kwargs) -> Thunk[R]:
|
||||
def thunkify(tracer: _ProxyTracer, f: Callable[_P, R], *args: _P.args, **kwargs: _P.kwargs) -> Thunk[R]:
|
||||
"""
|
||||
Delays computation of f until it's called again
|
||||
Also caches the result
|
||||
"""
|
||||
if tracer.enable_thunkify:
|
||||
return Thunk(functools.partial(f, *args, **kwargs))
|
||||
else:
|
||||
r = f(*args, **kwargs)
|
||||
return Thunk(lambda: r)
|
||||
|
||||
def track_tensor(tensor: Tensor, proxy: Proxy, *, constant: Optional[Tensor], tracer: _ProxyTracer) -> None:
|
||||
def try_set_proxy_slot(
|
||||
|
|
@ -363,7 +440,8 @@ def track_tensor(tensor: Tensor, proxy: Proxy, *, constant: Optional[Tensor], tr
|
|||
) -> None:
|
||||
assert callable(proxy_callable)
|
||||
if isinstance(outer_s, SymInt):
|
||||
set_proxy_slot(outer_s, tracer, thunkify(proxy_callable, outer_s, *args, **kwargs))
|
||||
with _enable_thunkify(tracer):
|
||||
set_proxy_slot(outer_s, tracer, thunkify(tracer, proxy_callable, outer_s, *args, **kwargs))
|
||||
# The basic idea is that we need to associate each tensor/SymInt
|
||||
# with a Proxy. How do we setup this association? We just store
|
||||
# the proxy on the proxy slot of the object, keyed on the tracer
|
||||
|
|
@ -411,7 +489,7 @@ def track_tensor_tree(
|
|||
assert isinstance(proxy, Proxy)
|
||||
# NB: eagerly set meta here, so that the numbering is in order
|
||||
set_meta(proxy, e)
|
||||
set_proxy_slot(e, tracer, thunkify(lambda: proxy))
|
||||
set_proxy_slot(e, tracer, thunkify(tracer, lambda: proxy))
|
||||
elif isinstance(e, _AnyScriptObject):
|
||||
assert isinstance(proxy, Proxy)
|
||||
set_proxy_slot(e, tracer, proxy)
|
||||
|
|
@ -700,6 +778,7 @@ def proxy_call(
|
|||
# Adding an undefined attribute to Tensor?
|
||||
args[0].proxy = proxy_out # type: ignore[attr-defined]
|
||||
|
||||
with _enable_thunkify(proxy_mode.tracer):
|
||||
out = func(*args, **kwargs)
|
||||
|
||||
# In some circumstances, we will be tracing in a situation where a tensor
|
||||
|
|
@ -804,12 +883,14 @@ class PythonKeyTracer(Tracer):
|
|||
self.tensor_tracker = WeakTensorKeyDictionary()
|
||||
self.symnode_tracker = _SymNodeDict()
|
||||
self.script_object_tracker = WeakIdKeyDictionary(dict=None, ref_type=_WeakHashRef)
|
||||
self.sympy_expr_tracker: Dict[sympy.Symbol, object] = dict()
|
||||
|
||||
# Stores the torch function that was called during tracing
|
||||
self.torch_fn_metadata = None
|
||||
# Stores the counts for every torch function called. This is to help
|
||||
# distinguish between different calls to the same torch function.
|
||||
self.torch_fn_counts = {}
|
||||
self.enable_thunkify = False
|
||||
|
||||
# In general, we don't want to make modules leaves. In principle, users of
|
||||
# this tracer might want to override this in order to turn a couple specific
|
||||
|
|
@ -901,6 +982,32 @@ def dispatch_trace(
|
|||
concrete_args: Optional[Tuple[Any, ...]] = None,
|
||||
) -> GraphModule:
|
||||
graph = tracer.trace(root, concrete_args)
|
||||
|
||||
# NB: be careful not to DCE .item() calls
|
||||
def impure_pred(n: fx.Node) -> bool:
|
||||
from .symbolic_shapes import is_accessor_node
|
||||
|
||||
# Always defer to the built-in notion of impure
|
||||
if n.is_impure():
|
||||
return True
|
||||
|
||||
# Accessors always OK to DCE
|
||||
if is_accessor_node(n):
|
||||
return False
|
||||
|
||||
# If the operator in question takes SymInt args to SymInt output,
|
||||
# we assume it's pure and OK to DCE
|
||||
if (
|
||||
isinstance(n.meta.get('val'), py_sym_types) and
|
||||
# NB: constant args ok
|
||||
all(isinstance(a.meta.get('val'), py_sym_types) for a in n.args if isinstance(a, fx.Node))
|
||||
):
|
||||
return False
|
||||
|
||||
# No idea, just assume it's not OK
|
||||
return True
|
||||
|
||||
graph.eliminate_dead_code(impure_pred)
|
||||
from torch._inductor.fx_passes.dedupe_symint_uses import dedupe_symints
|
||||
dedupe_symints(graph)
|
||||
name = root.__class__.__name__ if isinstance(root, Module) else root.__name__
|
||||
|
|
@ -945,6 +1052,7 @@ def wrap_key(f: Callable[_P, R], tensors: _P.args, tracer: _ProxyTracer, pre_dis
|
|||
|
||||
return wrapped
|
||||
|
||||
# TODO: Make downstream users of this work with OperatorBase
|
||||
ORIGINAL_ATEN: Optional[object] = None
|
||||
@contextmanager
|
||||
def set_original_aten_op(func: OpOverload) -> Generator[None, None, None]:
|
||||
|
|
@ -1129,7 +1237,7 @@ class ProxyTorchDispatchMode(TorchDispatchMode):
|
|||
# were symbolic) and it is no longer necessary to trace the
|
||||
# computation. This could occur if func triggered some guards.
|
||||
if isinstance(out, py_sym_types):
|
||||
p_out_thunk = thunkify(self._compute_proxy, func=func, args=args, out=out)
|
||||
p_out_thunk = thunkify(self.tracer, self._compute_proxy, func=func, args=args, out=out)
|
||||
set_proxy_slot(out, self.tracer, p_out_thunk)
|
||||
|
||||
return out
|
||||
|
|
@ -1139,8 +1247,10 @@ class _GraphAppendingTracerEx(fx.proxy.GraphAppendingTracer):
|
|||
script_object_tracker: WeakKeyDictionary
|
||||
symnode_tracker: WeakKeyDictionary
|
||||
tensor_tracker: WeakTensorKeyDictionary
|
||||
sympy_expr_tracker: Dict[sympy.Symbol, object]
|
||||
torch_fn_metadata: Optional[OpOverload]
|
||||
torch_fn_counts: Dict[OpOverload, int]
|
||||
enable_thunkify: bool = False
|
||||
|
||||
|
||||
# TODO: I'm not sure what the point of this class is; you can just
|
||||
|
|
@ -1159,6 +1269,7 @@ class DecompositionInterpreter(fx.Interpreter):
|
|||
# Blegh
|
||||
self.tracer.tensor_tracker = WeakTensorKeyDictionary()
|
||||
self.tracer.symnode_tracker = weakref.WeakKeyDictionary()
|
||||
self.tracer.sympy_expr_tracker = dict()
|
||||
self.decomposition_table = decomposition_table or {}
|
||||
self.mode = ProxyTorchDispatchMode(self.tracer, tracing_mode="real")
|
||||
|
||||
|
|
@ -1358,6 +1469,7 @@ class _ModuleStackTracer(PythonKeyTracer):
|
|||
concrete_args: Optional[Dict[str, object]]
|
||||
) -> fx.Graph:
|
||||
res = super().trace(root, concrete_args)
|
||||
|
||||
# Since we are making _AttrProxy mimic the original
|
||||
# submodule, when someone registers a module directly
|
||||
# to the tracer while tracing, the proxy object gets registered
|
||||
|
|
|
|||
|
|
@ -110,7 +110,7 @@ __all__ = [
|
|||
"StatefulSymbolicContext", "SubclassSymbolicContext", "statically_known_true",
|
||||
"guard_size_oblivious", "check_consistent",
|
||||
"compute_unbacked_bindings", "ConvertIntKey",
|
||||
"rebind_unbacked", "resolve_unbacked_bindings",
|
||||
"rebind_unbacked", "resolve_unbacked_bindings", "is_accessor_node",
|
||||
]
|
||||
|
||||
# FX node metadata keys for symbolic shape FX graph.
|
||||
|
|
@ -337,6 +337,32 @@ def rebind_unbacked(shape_env, n: torch.fx.Node, result):
|
|||
# Reuse the OLD symbol name
|
||||
shape_env._rename_unbacked_to(raw_u1, raw_u0)
|
||||
|
||||
# NB: You could try to expand this to cover more cases by simply
|
||||
# detecting whenever you have an int output, but this is a bit
|
||||
# dangerous in case someone adds a function that returns an int but is
|
||||
# mutating. So manually whitelist for now.
|
||||
def is_accessor_node(node: torch.fx.Node) -> bool:
|
||||
# Dynamo only exercised condition
|
||||
if (
|
||||
node.op == "call_method"
|
||||
and isinstance(node.args[0].meta.get("example_value"), torch.Tensor)
|
||||
and node.target in ["size", "stride", "storage_offset", "item"]
|
||||
):
|
||||
return True
|
||||
if node.op == "call_function" and node.target in [
|
||||
torch.ops.aten.sym_size,
|
||||
torch.ops.aten.sym_size.default,
|
||||
torch.ops.aten.sym_size.int,
|
||||
torch.ops.aten.sym_stride,
|
||||
torch.ops.aten.sym_stride.default,
|
||||
torch.ops.aten.sym_stride.int,
|
||||
torch.ops.aten.sym_storage_offset,
|
||||
torch.ops.aten.sym_storage_offset.default,
|
||||
torch.ops.aten.sym_numel.default,
|
||||
]:
|
||||
return True
|
||||
return False
|
||||
|
||||
def canonicalize_bool_expr(expr: SympyBoolean) -> SympyBoolean:
|
||||
r""" Canonicalize a boolean expression by transforming it into a lt / le
|
||||
inequality and moving all the non-constant terms to the rhs.
|
||||
|
|
|
|||
|
|
@ -8,6 +8,7 @@ import torch
|
|||
import inspect
|
||||
import operator
|
||||
import collections
|
||||
import logging
|
||||
|
||||
from dataclasses import is_dataclass, fields
|
||||
|
||||
|
|
@ -25,6 +26,9 @@ __all__ = ['TracerBase', 'GraphAppendingTracer', 'TraceError',
|
|||
'ScopeContextManager']
|
||||
|
||||
|
||||
log = logging.getLogger(__name__)
|
||||
|
||||
|
||||
@compatibility(is_backward_compatible=False)
|
||||
class Scope:
|
||||
""" Scope object that records the module path and the module type
|
||||
|
|
@ -136,6 +140,7 @@ class TracerBase:
|
|||
modification of values used in node creation. For example, one might
|
||||
want to disallow in-place operations from being recorded.
|
||||
"""
|
||||
|
||||
if kind == 'call_function' and self.check_mutable_operations:
|
||||
check_for_mutable_operation(target, args, kwargs)
|
||||
|
||||
|
|
@ -175,6 +180,8 @@ class TracerBase:
|
|||
|
||||
elif self.module_stack:
|
||||
node.meta['nn_module_stack'] = copy.copy(self.module_stack)
|
||||
|
||||
log.debug("create_node %s", node)
|
||||
return node
|
||||
|
||||
@compatibility(is_backward_compatible=True)
|
||||
|
|
|
|||
|
|
@ -302,8 +302,13 @@ class NestedTensor(torch.Tensor):
|
|||
if kwargs is None:
|
||||
kwargs = {}
|
||||
|
||||
from torch.fx.experimental.proxy_tensor import maybe_enable_thunkify
|
||||
|
||||
from .ops import jagged_torch_function
|
||||
|
||||
# This should be removed after
|
||||
# https://github.com/pytorch/pytorch/pull/125941/ lands
|
||||
with maybe_enable_thunkify():
|
||||
try:
|
||||
return jagged_torch_function(func, *args, **kwargs)
|
||||
except NotImplementedError:
|
||||
|
|
|
|||
Loading…
Reference in New Issue
Block a user