mirror of
https://github.com/zebrajr/pytorch.git
synced 2025-12-06 12:20:52 +01:00
Only thunkify proxies in some situations (#132421)
The goal of this PR is to avoid stack overflow when we create extremely long chains of thunks, and then evaluate them (e.g., as occurs if you sum(long list of symint)). The basic idea behind this PR is to only thunkify proxies if they're being created in places where they may or may not be used--crucially, symint operations that occur in user code we are tracing are eagerly placed into the graph, even if they may eventually be dead. I annotated the PR with explanation of changes. Signed-off-by: Edward Z. Yang <ezyang@meta.com> Pull Request resolved: https://github.com/pytorch/pytorch/pull/132421 Approved by: https://github.com/Skylion007, https://github.com/zou3519 ghstack dependencies: #132674, #132675
This commit is contained in:
parent
54efd43022
commit
aec6332356
|
|
@ -51,6 +51,7 @@ torch.fx.experimental.symbolic_shapes
|
|||
compute_unbacked_bindings
|
||||
rebind_unbacked
|
||||
resolve_unbacked_bindings
|
||||
is_accessor_node
|
||||
|
||||
torch.fx.experimental.proxy_tensor
|
||||
-------------------------------------
|
||||
|
|
@ -65,3 +66,5 @@ torch.fx.experimental.proxy_tensor
|
|||
make_fx
|
||||
handle_sym_dispatch
|
||||
get_proxy_mode
|
||||
maybe_enable_thunkify
|
||||
maybe_disable_thunkify
|
||||
|
|
|
|||
|
|
@ -3480,8 +3480,8 @@ def forward(self, x):
|
|||
def forward(self, x):
|
||||
arg0, = fx_pytree.tree_flatten_spec(([x], {}), self._in_spec)
|
||||
arg0_1 = arg0
|
||||
slice_1 = torch.ops.aten.slice.Tensor(arg0_1, 2, 0, 3)
|
||||
sym_size_int = torch.ops.aten.sym_size.int(arg0_1, 0)
|
||||
slice_1 = torch.ops.aten.slice.Tensor(arg0_1, 2, 0, 3)
|
||||
sub = sym_size_int - 1
|
||||
slice_2 = torch.ops.aten.slice.Tensor(arg0_1, 0, 0, sub); sub = None
|
||||
slice_3 = torch.ops.aten.slice.Tensor(slice_2, 1, 1, sym_size_int); slice_2 = None
|
||||
|
|
|
|||
|
|
@ -2148,7 +2148,6 @@ class TestNestedTensor(torch._dynamo.test_case.TestCase, NestedTensorTestCase):
|
|||
|
||||
self._validate_compile(fn, arg_fn)
|
||||
|
||||
@unittest.expectedFailure
|
||||
def test_return_shape(self):
|
||||
nt, _ = self._get_jagged_tensor(((2, 3, 4), 5), None)
|
||||
|
||||
|
|
|
|||
|
|
@ -6037,14 +6037,14 @@ def forward(self, x):
|
|||
def forward(self, x):
|
||||
item = torch.ops.aten.item.default(x); x = None
|
||||
sym_constrain_range_for_size_default = torch.ops.aten.sym_constrain_range_for_size.default(item); sym_constrain_range_for_size_default = None
|
||||
ge = item >= 3
|
||||
_assert_scalar_default = torch.ops.aten._assert_scalar.default(ge, "Runtime assertion failed for expression u1 >= 3 on node 'ge'"); ge = _assert_scalar_default = None
|
||||
ge_1 = item >= 3
|
||||
_assert_scalar_default = torch.ops.aten._assert_scalar.default(ge_1, "Runtime assertion failed for expression u1 >= 3 on node 'ge_1'"); ge_1 = _assert_scalar_default = None
|
||||
le = item <= 5
|
||||
_assert_scalar_default_1 = torch.ops.aten._assert_scalar.default(le, "Runtime assertion failed for expression u1 <= 5 on node 'le'"); le = _assert_scalar_default_1 = None
|
||||
gt = item > 2
|
||||
_assert_scalar_default_2 = torch.ops.aten._assert_scalar.default(gt, "Runtime assertion failed for expression 2 < u1 on node 'gt'"); gt = _assert_scalar_default_2 = None
|
||||
lt = item < 6
|
||||
_assert_scalar_default_3 = torch.ops.aten._assert_scalar.default(lt, "Runtime assertion failed for expression u1 < 6 on node 'lt'"); lt = _assert_scalar_default_3 = None
|
||||
gt_1 = item > 2
|
||||
_assert_scalar_default_2 = torch.ops.aten._assert_scalar.default(gt_1, "Runtime assertion failed for expression 2 < u1 on node 'gt_1'"); gt_1 = _assert_scalar_default_2 = None
|
||||
lt_1 = item < 6
|
||||
_assert_scalar_default_3 = torch.ops.aten._assert_scalar.default(lt_1, "Runtime assertion failed for expression u1 < 6 on node 'lt_1'"); lt_1 = _assert_scalar_default_3 = None
|
||||
foo_unbacked = torch.ops.testlib.foo_unbacked.default(item); item = None
|
||||
return (foo_unbacked,)""",
|
||||
)
|
||||
|
|
|
|||
|
|
@ -1320,8 +1320,8 @@ def forward(self, token, obj, x):
|
|||
with_effects = torch._higher_order_ops.effects.with_effects(token, torch.ops.higher_order.call_torchbind, obj, 'get'); token = obj = None
|
||||
getitem = with_effects[0]
|
||||
getitem_1 = with_effects[1]; with_effects = None
|
||||
add = torch.ops.aten.add.Tensor(getitem_1, x); getitem_1 = x = None
|
||||
return (getitem, add)""", # noqa: B950
|
||||
add_3 = torch.ops.aten.add.Tensor(getitem_1, x); getitem_1 = x = None
|
||||
return (getitem, add_3)""", # noqa: B950
|
||||
)
|
||||
|
||||
@parametrize("backend", ["eager", "aot_eager"])
|
||||
|
|
|
|||
|
|
@ -1056,8 +1056,8 @@ def forward(self, s0_1, s1_1, x_1, y_1):
|
|||
self.assertExpectedInline(r, """\
|
||||
def forward(self, x_1, y_1):
|
||||
sym_size_int = torch.ops.aten.sym_size.int(x_1, 0); x_1 = None
|
||||
empty = torch.ops.aten.empty.memory_format([sym_size_int], device = device(type='cpu'), pin_memory = False)
|
||||
sym_size_int_1 = torch.ops.aten.sym_size.int(y_1, 0); y_1 = None
|
||||
empty = torch.ops.aten.empty.memory_format([sym_size_int], device = device(type='cpu'), pin_memory = False)
|
||||
return ((sym_size_int, sym_size_int_1), empty)""")
|
||||
|
||||
def test_unary(self):
|
||||
|
|
@ -1355,8 +1355,8 @@ def forward(self, crop_camera_1, mask_1):
|
|||
view_1 = torch.ops.aten.view.default(expand_1, [sym_size_int, sym_size_int_1, sym_size_int_2]); expand_1 = sym_size_int_1 = sym_size_int_2 = None
|
||||
bmm = torch.ops.aten.bmm.default(view, view_1); view = view_1 = None
|
||||
view_2 = torch.ops.aten.view.default(bmm, [sym_size_int, 3, 3]); bmm = None
|
||||
mul = sym_size_int * 3
|
||||
view_3 = torch.ops.aten.view.default(view_2, [mul, 3]); view_2 = mul = None
|
||||
mul_4 = sym_size_int * 3
|
||||
view_3 = torch.ops.aten.view.default(view_2, [mul_4, 3]); view_2 = mul_4 = None
|
||||
mm = torch.ops.aten.mm.default(view_3, eye); view_3 = eye = None
|
||||
view_4 = torch.ops.aten.view.default(mm, [sym_size_int, 3, 3]); mm = sym_size_int = None
|
||||
index_put_ = torch.ops.aten.index_put_.default(crop_camera_1, [mask_1], view_4); crop_camera_1 = mask_1 = view_4 = index_put_ = None
|
||||
|
|
|
|||
|
|
@ -1536,29 +1536,7 @@ class OutputGraph:
|
|||
return False
|
||||
return True
|
||||
|
||||
# NB: You could try to expand this to cover more cases by simply
|
||||
# detecting whenever you have an int output, but this is a bit
|
||||
# dangerous in case someone adds a function that returns an int but is
|
||||
# mutating. So manually whitelist for now.
|
||||
def is_accessor_node(node):
|
||||
if (
|
||||
node.op == "call_method"
|
||||
and isinstance(node.args[0].meta.get("example_value"), torch.Tensor)
|
||||
and node.target in ["size", "stride", "storage_offset", "item"]
|
||||
):
|
||||
return True
|
||||
if node.op == "call_function" and node.target in [
|
||||
torch.ops.aten.sym_size,
|
||||
torch.ops.aten.sym_size.default,
|
||||
torch.ops.aten.sym_size.int,
|
||||
torch.ops.aten.sym_stride,
|
||||
torch.ops.aten.sym_stride.default,
|
||||
torch.ops.aten.sym_stride.int,
|
||||
torch.ops.aten.sym_storage_offset,
|
||||
torch.ops.aten.sym_storage_offset.default,
|
||||
]:
|
||||
return True
|
||||
return False
|
||||
from torch.fx.experimental.symbolic_shapes import is_accessor_node
|
||||
|
||||
for node in reversed(list(self.graph.nodes)):
|
||||
if len(list(node.users)) == 0:
|
||||
|
|
|
|||
|
|
@ -24,6 +24,10 @@ from torch import Tensor
|
|||
from torch._decomp.decompositions_for_rng import PhiloxStateTracker
|
||||
from torch._guards import detect_fake_mode
|
||||
from torch._prims_common import CUDARngStateHelper
|
||||
from torch.fx.experimental.proxy_tensor import (
|
||||
maybe_disable_thunkify,
|
||||
maybe_enable_thunkify,
|
||||
)
|
||||
from torch.fx.experimental.symbolic_shapes import (
|
||||
definitely_false,
|
||||
PropagateUnbackedSymInts,
|
||||
|
|
@ -188,6 +192,7 @@ def fn_prepped_for_autograd(
|
|||
def create_joint(fn: Callable, *, aot_config: AOTConfig) -> Any:
|
||||
def inner_fn(primals: List[Any], tangents: List[Any]):
|
||||
outs, tangent_mask = fn(*primals)
|
||||
|
||||
assert len(tangent_mask) == len(outs)
|
||||
outs_to_grad = [
|
||||
o for needs_tangent, o in zip(tangent_mask, outs) if needs_tangent
|
||||
|
|
@ -365,266 +370,280 @@ def create_functionalized_fn(
|
|||
) -> Any:
|
||||
@wraps(fn)
|
||||
def _functionalized_f_helper(*args):
|
||||
# See Note [Disabling Functionalize TLS Above Python Functionalization]
|
||||
disable_above = torch._C._ExcludeDispatchKeyGuard(
|
||||
torch._C.DispatchKeySet(torch._C.DispatchKey.Functionalize)
|
||||
)
|
||||
|
||||
# See Note [Side-Effectful Tokens in AOTAutograd]
|
||||
if trace_joint:
|
||||
assert (
|
||||
isinstance(args, tuple)
|
||||
and len(args) == 2
|
||||
and isinstance(args[0], (list, tuple))
|
||||
with maybe_enable_thunkify():
|
||||
# See Note [Disabling Functionalize TLS Above Python Functionalization]
|
||||
disable_above = torch._C._ExcludeDispatchKeyGuard(
|
||||
torch._C.DispatchKeySet(torch._C.DispatchKey.Functionalize)
|
||||
)
|
||||
tokens = args[0][: len(meta.tokens)]
|
||||
actual_args = args[0][len(meta.tokens) :]
|
||||
args = (actual_args, args[1])
|
||||
else:
|
||||
tokens = args[: len(meta.tokens)]
|
||||
args = args[len(meta.tokens) :]
|
||||
assert all(token.numel() == 0 for token in tokens)
|
||||
|
||||
with disable_above:
|
||||
# Wrap inputs into functional wrappers
|
||||
f_args = pytree.tree_map(to_fun, args)
|
||||
f_tokens = pytree.tree_map(to_fun, tokens)
|
||||
|
||||
# Populate the current FunctionalTensorMode with the tokens per
|
||||
# operator. See Note [FunctionalTensorMode is Stateful]
|
||||
functional_tensor_mode = torch.utils._python_dispatch._detect_infra_mode(
|
||||
torch._C._TorchDispatchModeKey.FUNCTIONAL
|
||||
)
|
||||
assert functional_tensor_mode is not None
|
||||
for i, k in enumerate(meta.tokens.keys()):
|
||||
functional_tensor_mode._tokens[k] = f_tokens[i]
|
||||
|
||||
# Run the joint
|
||||
f_outs = fn(*f_args)
|
||||
|
||||
# Return both the tokens and the outputs
|
||||
# See Note [Side-Effectful Tokens in AOTAutograd]
|
||||
f_outs = (*functional_tensor_mode._tokens.values(), *f_outs)
|
||||
|
||||
if trace_joint:
|
||||
# We support a limited amount of mutation of graph inputs during the backward pass.
|
||||
# (This is used e.g. by Float8, which needs to update buffers during the backward pass)
|
||||
# Here, we perform extra checks for primals that were mutated in the **backward**
|
||||
# We're doing the checks here instead of doing them with the rest of the input mutation handling because:
|
||||
# - We need to detect inputs that were mutated in the backward **separately** from mutations that happened
|
||||
# during the forward, because the handling is different: some input mutations from the the forward
|
||||
# can be only handled in a fw-only runtime epilogue, and in theory if we wanted to handle those same
|
||||
# types of mutations in the backward we would need a bw-only runtime epilogue.
|
||||
# - We could in theory have our analysis pass differentiate mutations in the fw from mutations in
|
||||
# the bw by running our analysis first on the fw-only graph, and then on the joint graph. This would
|
||||
# require an extra round of tracing though, so it's more efficient to do in-line here.
|
||||
assert (
|
||||
isinstance(args, tuple)
|
||||
and len(args) == 2
|
||||
and isinstance(args[0], (list, tuple))
|
||||
)
|
||||
# Only look at mutations that happened to forward inputs (e.g. fw buffers that were saved for bw)
|
||||
primals_before = args[0]
|
||||
primals_after = pytree.tree_map(from_fun, f_args[0])
|
||||
for idx, (f_inpt, before, after, inpt_info) in enumerate(
|
||||
zip(f_args[0], primals_before, primals_after, meta.input_info)
|
||||
):
|
||||
# Store information about mutations in joint(for backward analysis)
|
||||
joint_mutates_data = has_data_mutation(f_inpt)
|
||||
|
||||
joint_mutates_metadata = has_metadata_mutation(
|
||||
f_inpt, before, check_only_storage_mutation=False
|
||||
if trace_joint:
|
||||
assert (
|
||||
isinstance(args, tuple)
|
||||
and len(args) == 2
|
||||
and isinstance(args[0], (list, tuple))
|
||||
)
|
||||
tokens = args[0][: len(meta.tokens)]
|
||||
actual_args = args[0][len(meta.tokens) :]
|
||||
args = (actual_args, args[1])
|
||||
else:
|
||||
tokens = args[: len(meta.tokens)]
|
||||
args = args[len(meta.tokens) :]
|
||||
assert all(token.numel() == 0 for token in tokens)
|
||||
|
||||
# Ban metadata mutations on fw inputs during the bw
|
||||
if not inpt_info.mutates_metadata:
|
||||
assert (
|
||||
not joint_mutates_metadata
|
||||
), "Found a graph input that had its metadata mutated in the backward. This is not supported"
|
||||
with disable_above:
|
||||
# The functionalization code here can potentially trigger traces
|
||||
# into the graph, but we'd prefer to NOT do this, because if we
|
||||
# trace them now, we will end up with FX nodes that don't have
|
||||
# module stack annotations, which makes unflattener unhappy.
|
||||
# Wrap inputs into functional wrappers
|
||||
f_args = pytree.tree_map(to_fun, args)
|
||||
f_tokens = pytree.tree_map(to_fun, tokens)
|
||||
|
||||
# Ban storage resizing on fw inputs during the bw
|
||||
if not inpt_info.mutation_inductor_storage_resize:
|
||||
assert not was_inductor_storage_resized(
|
||||
f_inpt
|
||||
), "Found a graph input that had storage resizing in the backward. This is not supported"
|
||||
|
||||
# Allow data mutations on fw inputs during the bw, but only if they do not require grad
|
||||
# So we can guarantee that we can keep the mutations in the graph
|
||||
if (
|
||||
joint_mutates_data
|
||||
and not inpt_info.mutates_data
|
||||
and not inpt_info.mutates_storage_metadata
|
||||
):
|
||||
# Not banning here mutations on inpt_info.requires_grad -
|
||||
# we'll check at runtime and fail only when backward is under torch.is_grad_enabled (create_graph)
|
||||
# Add node meta for copy_ for partitioner that this node should be in backward graph.
|
||||
with torch.fx.traceback.preserve_node_meta(), set_partitioner_tag(
|
||||
"must_be_in_backward"
|
||||
):
|
||||
before.copy_(after)
|
||||
meta.indices_of_inputs_that_requires_grad_with_mutations_in_bw.append(
|
||||
idx
|
||||
# Populate the current FunctionalTensorMode with the tokens per
|
||||
# operator. See Note [FunctionalTensorMode is Stateful]
|
||||
functional_tensor_mode = (
|
||||
torch.utils._python_dispatch._detect_infra_mode(
|
||||
torch._C._TorchDispatchModeKey.FUNCTIONAL
|
||||
)
|
||||
# Now that we covered mutations to *forward* inputs during the backward,
|
||||
# we also need to cover mutations to *backward-only* inputs during the backward (e.g. mutation to a grad_out).
|
||||
# Today, we will just error in all cases of this happening unless someone needs us to support it.
|
||||
tangents_before = args[1]
|
||||
tangents_after = pytree.tree_map(from_fun, f_args[1])
|
||||
for f_inpt, before, after in zip(
|
||||
f_args[1], tangents_before, tangents_after
|
||||
):
|
||||
assert not has_metadata_mutation(
|
||||
f_inpt, before, check_only_storage_mutation=False
|
||||
) and not has_data_mutation(
|
||||
f_inpt
|
||||
), "Found an input to the backward that was mutated during the backward pass. This is not supported"
|
||||
)
|
||||
assert functional_tensor_mode is not None
|
||||
for i, k in enumerate(meta.tokens.keys()):
|
||||
functional_tensor_mode._tokens[k] = f_tokens[i]
|
||||
|
||||
if aot_config.keep_inference_input_mutations:
|
||||
# Note: This is a bit annoying. There's a layering issue here, where:
|
||||
# (1) functionalization needs to operate on **synthetic base** inputs, before unpacking them into the "real" inputs.
|
||||
# (2) For keep_input_mutations, we support tracing a call to copy_() directly on mutated inputs.
|
||||
# However, we **only** want to support this for inputs that have data-only (and no metadata) mutations,
|
||||
# because inductor (and backends in generally) would prefer not to see these (e.g. as_strided_(), resize_()).
|
||||
# This makes it pretty difficult for this logic to operate on synthetic bases.
|
||||
# (3) In addition, there are cases where it's significantly cheaper to perform the copy on the individual
|
||||
# (unpacked) input aliases, instead of the synthetic base.
|
||||
# Example case where (3) could be important:
|
||||
#
|
||||
# def f(x, y):
|
||||
# x.mul_(2)
|
||||
# y.mul_(3)
|
||||
# return x, y
|
||||
# a = torch.ones(1'000'000)
|
||||
# x, y = out(a[0:9], a[1:10])
|
||||
#
|
||||
# It would be much better to add copy_() calls into the graph for the two tiny slices, instead of materializing
|
||||
# a giant "updated synthetic base" and copying into a's entire storage.
|
||||
#
|
||||
# For now, we are pessimistically not performing the optimization from (3);
|
||||
# we will materialize an "updated" synthetic base, and copy it back to the synthetic input base.
|
||||
# This allows us to factor aot autograd much more nicely, since only one area of the code needs to worry
|
||||
# about synthetic bases.
|
||||
for i, (inpt_old, inpt_f) in enumerate(
|
||||
zip(args, f_args) if not trace_joint else zip(args[0], f_args[0])
|
||||
):
|
||||
if not isinstance(inpt_f, torch.Tensor):
|
||||
continue
|
||||
assert is_fun(inpt_f)
|
||||
inpt_new = from_fun(inpt_f)
|
||||
if meta.input_info[i].mutation_type == MutationType.MUTATED_IN_GRAPH:
|
||||
# See Note [set_() Input Mutations in AOTAutograd]
|
||||
# all mutations on the input must be under no_grad, so it is safe to put in the graph
|
||||
# Here, we're saying that if an input experienced a set call, inp.set_(other),
|
||||
# then we can effectively not have to worry about whether its data was mutated.
|
||||
# There are 3 cases:
|
||||
# (1) We mutate inp *after* the set_() call. other is a graph intermediate.
|
||||
# In this case, we're not really mutating the input storage of "inp";
|
||||
# we're mutating the storage of an intermdiate value (other),
|
||||
# and slamming that storage into the input tensor. So no data mutation is necessary.
|
||||
# (2) We mutate inp *after* the set_() call. other is a graph *input*.
|
||||
# In this case, the data mutation will be properly handled in the runtime
|
||||
# epilogue during the processing of "other"
|
||||
# (3) We mutate inp *before* the set_() call.
|
||||
# This case is *not* currently handled.
|
||||
if meta.input_info[i].mutates_storage_metadata:
|
||||
with torch.no_grad():
|
||||
inpt_old.set_(inpt_new)
|
||||
# Run the joint
|
||||
f_outs = fn(*f_args)
|
||||
|
||||
# Note [Ordering of resize_() and set_()]
|
||||
# Importantly: the common usage in FSDP is that we have a dummy parameter
|
||||
# that sees a set_() and **Then** a resize_().
|
||||
# We must put those mutations into the graph in the same order,
|
||||
# Since running them in the opposite order will have different behavior.
|
||||
# We fully ban resize_() followed by set_() for now, although in principal
|
||||
# we could support this
|
||||
if meta.input_info[i].mutation_inductor_storage_resize:
|
||||
# resizing is not supported on subclasses (we error earlier if this happens)
|
||||
from torch._subclasses.functional_tensor import FunctionalTensor
|
||||
# Return both the tokens and the outputs
|
||||
# See Note [Side-Effectful Tokens in AOTAutograd]
|
||||
f_outs = (*functional_tensor_mode._tokens.values(), *f_outs)
|
||||
|
||||
assert isinstance(inpt_f, FunctionalTensor)
|
||||
old_storage_size = torch._functionalize_get_storage_size( # type: ignore[attr-defined]
|
||||
inpt_f.elem, before=True
|
||||
)
|
||||
new_storage_size = torch._functionalize_get_storage_size( # type: ignore[attr-defined]
|
||||
inpt_f.elem, before=False
|
||||
)
|
||||
if old_storage_size != new_storage_size:
|
||||
assert (
|
||||
old_storage_size == 0 or new_storage_size == 0
|
||||
), f"""\
|
||||
Encountered a storage resize during tracing on input {i}. Old nbytes={old_storage_size}, new nbytes={new_storage_size}
|
||||
We only support storage resizing on graph inputs as long as the input either starts or ends with a storage size of 0
|
||||
(the case for FSDP)"""
|
||||
torch.ops.inductor.resize_storage_bytes_(
|
||||
inpt_old, new_storage_size
|
||||
)
|
||||
if new_storage_size == 0:
|
||||
# Even if we marked the input as having a data mutation (thus needing a copy_()),
|
||||
# We should **ignore** it if our input has no storage
|
||||
# (this can happen if, e.g. we temporarily resize our input, copy data into it,
|
||||
# and resize it back down to zero)
|
||||
continue
|
||||
# Optimization: if the copy_() is a no-op then don't include it in the graph.
|
||||
# In theory inductor could optimize this away, however in fsdp, we end up with
|
||||
# param.copy_(param), where param is a zero-storage-size tensor,
|
||||
# and running this op in eager mode (using the aot_eager backend) will result in a segfault.
|
||||
# So we may as well optimize it away here.
|
||||
if inpt_old is inpt_new:
|
||||
# (This check needs to be done after putting resize_() in the graph,
|
||||
# since a resize_(0) doesn't actually change the FunctionalTensor's inner tensor)
|
||||
continue
|
||||
# We found an input that had a (data-only) mutation.
|
||||
# Since keep_input_mutations is set, we need to faithfully apply a copy_()
|
||||
# so the compiler will see the input mutation in the graph.
|
||||
if (
|
||||
meta.input_info[i].mutates_data
|
||||
and meta.input_info[i].mutations_hidden_from_autograd
|
||||
):
|
||||
# Hidden from autograd = run under no_grad, **and** don't bump VC
|
||||
# (although if the tensor was created in inference mode, it has no VC)
|
||||
if inpt_old.is_inference():
|
||||
maybe_preserve_vc = nullcontext()
|
||||
else:
|
||||
maybe_preserve_vc = torch.autograd._unsafe_preserve_version_counter(
|
||||
inpt_old # type: ignore[assignment]
|
||||
)
|
||||
with torch.no_grad(), maybe_preserve_vc:
|
||||
inpt_old.copy_(inpt_new)
|
||||
elif (
|
||||
meta.input_info[i].mutates_data
|
||||
and meta.input_info[i].mutations_under_no_grad_or_inference_mode
|
||||
):
|
||||
# Under no_grad = run under no_grad (we still bump the VC though)
|
||||
# (inference_mode will also bump the VC, as long as the tensor in question
|
||||
# was created outside of inference_mode)
|
||||
with torch.no_grad():
|
||||
inpt_old.copy_(inpt_new)
|
||||
elif meta.input_info[i].mutates_data:
|
||||
inpt_old.copy_(inpt_new)
|
||||
|
||||
# When an output tensor is a functionalized mutated input, and we
|
||||
# were able to move the mutation in to the graph then we can return
|
||||
# the mutated input directly. This prevents duplicating the
|
||||
# tensors contents.
|
||||
flat_outs, outs_spec = pytree.tree_flatten(f_outs)
|
||||
flat_outs = [from_fun(o) for o in flat_outs]
|
||||
num_outs = len(meta.output_info)
|
||||
|
||||
for i, outp in enumerate(flat_outs[:num_outs]):
|
||||
info = meta.output_info[i]
|
||||
if info.output_type != OutputType.is_input:
|
||||
continue
|
||||
|
||||
assert info.base_idx is not None
|
||||
if (
|
||||
meta.input_info[info.base_idx].mutation_type
|
||||
== MutationType.MUTATED_IN_GRAPH
|
||||
if trace_joint:
|
||||
# We support a limited amount of mutation of graph inputs during the backward pass.
|
||||
# (This is used e.g. by Float8, which needs to update buffers during the backward pass)
|
||||
# Here, we perform extra checks for primals that were mutated in the **backward**
|
||||
# We're doing the checks here instead of doing them with the rest of the input mutation handling because:
|
||||
# - We need to detect inputs that were mutated in the backward **separately** from mutations that happened
|
||||
# during the forward, because the handling is different: some input mutations from the the forward
|
||||
# can be only handled in a fw-only runtime epilogue, and in theory if we wanted to handle those same
|
||||
# types of mutations in the backward we would need a bw-only runtime epilogue.
|
||||
# - We could in theory have our analysis pass differentiate mutations in the fw from mutations in
|
||||
# the bw by running our analysis first on the fw-only graph, and then on the joint graph. This would
|
||||
# require an extra round of tracing though, so it's more efficient to do in-line here.
|
||||
assert (
|
||||
isinstance(args, tuple)
|
||||
and len(args) == 2
|
||||
and isinstance(args[0], (list, tuple))
|
||||
)
|
||||
# Only look at mutations that happened to forward inputs (e.g. fw buffers that were saved for bw)
|
||||
primals_before = args[0]
|
||||
primals_after = pytree.tree_map(from_fun, f_args[0])
|
||||
for idx, (f_inpt, before, after, inpt_info) in enumerate(
|
||||
zip(f_args[0], primals_before, primals_after, meta.input_info)
|
||||
):
|
||||
fw_args = args[0] if trace_joint else args
|
||||
flat_outs[i] = fw_args[info.base_idx]
|
||||
return pytree.tree_unflatten(flat_outs, outs_spec)
|
||||
# Store information about mutations in joint(for backward analysis)
|
||||
joint_mutates_data = has_data_mutation(f_inpt)
|
||||
|
||||
return pytree.tree_map(from_fun, f_outs)
|
||||
joint_mutates_metadata = has_metadata_mutation(
|
||||
f_inpt, before, check_only_storage_mutation=False
|
||||
)
|
||||
|
||||
# Ban metadata mutations on fw inputs during the bw
|
||||
if not inpt_info.mutates_metadata:
|
||||
assert (
|
||||
not joint_mutates_metadata
|
||||
), "Found a graph input that had its metadata mutated in the backward. This is not supported"
|
||||
|
||||
# Ban storage resizing on fw inputs during the bw
|
||||
if not inpt_info.mutation_inductor_storage_resize:
|
||||
assert not was_inductor_storage_resized(
|
||||
f_inpt
|
||||
), "Found a graph input that had storage resizing in the backward. This is not supported"
|
||||
|
||||
# Allow data mutations on fw inputs during the bw, but only if they do not require grad
|
||||
# So we can guarantee that we can keep the mutations in the graph
|
||||
if (
|
||||
joint_mutates_data
|
||||
and not inpt_info.mutates_data
|
||||
and not inpt_info.mutates_storage_metadata
|
||||
):
|
||||
# Not banning here mutations on inpt_info.requires_grad -
|
||||
# we'll check at runtime and fail only when backward is under torch.is_grad_enabled (create_graph)
|
||||
# Add node meta for copy_ for partitioner that this node should be in backward graph.
|
||||
with torch.fx.traceback.preserve_node_meta(), set_partitioner_tag(
|
||||
"must_be_in_backward"
|
||||
):
|
||||
before.copy_(after)
|
||||
meta.indices_of_inputs_that_requires_grad_with_mutations_in_bw.append(
|
||||
idx
|
||||
)
|
||||
# Now that we covered mutations to *forward* inputs during the backward,
|
||||
# we also need to cover mutations to *backward-only* inputs during the backward (e.g. mutation to a grad_out).
|
||||
# Today, we will just error in all cases of this happening unless someone needs us to support it.
|
||||
tangents_before = args[1]
|
||||
tangents_after = pytree.tree_map(from_fun, f_args[1])
|
||||
for f_inpt, before, after in zip(
|
||||
f_args[1], tangents_before, tangents_after
|
||||
):
|
||||
assert not has_metadata_mutation(
|
||||
f_inpt, before, check_only_storage_mutation=False
|
||||
) and not has_data_mutation(
|
||||
f_inpt
|
||||
), "Found an input to the backward that was mutated during the backward pass. This is not supported"
|
||||
|
||||
if aot_config.keep_inference_input_mutations:
|
||||
# Note: This is a bit annoying. There's a layering issue here, where:
|
||||
# (1) functionalization needs to operate on **synthetic base** inputs, before unpacking them into the "real" inputs.
|
||||
# (2) For keep_input_mutations, we support tracing a call to copy_() directly on mutated inputs.
|
||||
# However, we **only** want to support this for inputs that have data-only (and no metadata) mutations,
|
||||
# because inductor (and backends in generally) would prefer not to see these (e.g. as_strided_(), resize_()).
|
||||
# This makes it pretty difficult for this logic to operate on synthetic bases.
|
||||
# (3) In addition, there are cases where it's significantly cheaper to perform the copy on the individual
|
||||
# (unpacked) input aliases, instead of the synthetic base.
|
||||
# Example case where (3) could be important:
|
||||
#
|
||||
# def f(x, y):
|
||||
# x.mul_(2)
|
||||
# y.mul_(3)
|
||||
# return x, y
|
||||
# a = torch.ones(1'000'000)
|
||||
# x, y = out(a[0:9], a[1:10])
|
||||
#
|
||||
# It would be much better to add copy_() calls into the graph for the two tiny slices, instead of materializing
|
||||
# a giant "updated synthetic base" and copying into a's entire storage.
|
||||
#
|
||||
# For now, we are pessimistically not performing the optimization from (3);
|
||||
# we will materialize an "updated" synthetic base, and copy it back to the synthetic input base.
|
||||
# This allows us to factor aot autograd much more nicely, since only one area of the code needs to worry
|
||||
# about synthetic bases.
|
||||
for i, (inpt_old, inpt_f) in enumerate(
|
||||
zip(args, f_args) if not trace_joint else zip(args[0], f_args[0])
|
||||
):
|
||||
if not isinstance(inpt_f, torch.Tensor):
|
||||
continue
|
||||
assert is_fun(inpt_f)
|
||||
inpt_new = from_fun(inpt_f)
|
||||
if (
|
||||
meta.input_info[i].mutation_type
|
||||
== MutationType.MUTATED_IN_GRAPH
|
||||
):
|
||||
# See Note [set_() Input Mutations in AOTAutograd]
|
||||
# all mutations on the input must be under no_grad, so it is safe to put in the graph
|
||||
# Here, we're saying that if an input experienced a set call, inp.set_(other),
|
||||
# then we can effectively not have to worry about whether its data was mutated.
|
||||
# There are 3 cases:
|
||||
# (1) We mutate inp *after* the set_() call. other is a graph intermediate.
|
||||
# In this case, we're not really mutating the input storage of "inp";
|
||||
# we're mutating the storage of an intermdiate value (other),
|
||||
# and slamming that storage into the input tensor. So no data mutation is necessary.
|
||||
# (2) We mutate inp *after* the set_() call. other is a graph *input*.
|
||||
# In this case, the data mutation will be properly handled in the runtime
|
||||
# epilogue during the processing of "other"
|
||||
# (3) We mutate inp *before* the set_() call.
|
||||
# This case is *not* currently handled.
|
||||
if meta.input_info[i].mutates_storage_metadata:
|
||||
with torch.no_grad():
|
||||
inpt_old.set_(inpt_new)
|
||||
|
||||
# Note [Ordering of resize_() and set_()]
|
||||
# Importantly: the common usage in FSDP is that we have a dummy parameter
|
||||
# that sees a set_() and **Then** a resize_().
|
||||
# We must put those mutations into the graph in the same order,
|
||||
# Since running them in the opposite order will have different behavior.
|
||||
# We fully ban resize_() followed by set_() for now, although in principal
|
||||
# we could support this
|
||||
if meta.input_info[i].mutation_inductor_storage_resize:
|
||||
# resizing is not supported on subclasses (we error earlier if this happens)
|
||||
from torch._subclasses.functional_tensor import (
|
||||
FunctionalTensor,
|
||||
)
|
||||
|
||||
assert isinstance(inpt_f, FunctionalTensor)
|
||||
old_storage_size = torch._functionalize_get_storage_size( # type: ignore[attr-defined]
|
||||
inpt_f.elem, before=True
|
||||
)
|
||||
new_storage_size = torch._functionalize_get_storage_size( # type: ignore[attr-defined]
|
||||
inpt_f.elem, before=False
|
||||
)
|
||||
if old_storage_size != new_storage_size:
|
||||
assert (
|
||||
old_storage_size == 0 or new_storage_size == 0
|
||||
), f"""\
|
||||
Encountered a storage resize during tracing on input {i}. Old nbytes={old_storage_size}, new nbytes={new_storage_size}
|
||||
We only support storage resizing on graph inputs as long as the input either starts or ends with a storage size of 0
|
||||
(the case for FSDP)"""
|
||||
torch.ops.inductor.resize_storage_bytes_(
|
||||
inpt_old, new_storage_size
|
||||
)
|
||||
if new_storage_size == 0:
|
||||
# Even if we marked the input as having a data mutation (thus needing a copy_()),
|
||||
# We should **ignore** it if our input has no storage
|
||||
# (this can happen if, e.g. we temporarily resize our input, copy data into it,
|
||||
# and resize it back down to zero)
|
||||
continue
|
||||
# Optimization: if the copy_() is a no-op then don't include it in the graph.
|
||||
# In theory inductor could optimize this away, however in fsdp, we end up with
|
||||
# param.copy_(param), where param is a zero-storage-size tensor,
|
||||
# and running this op in eager mode (using the aot_eager backend) will result in a segfault.
|
||||
# So we may as well optimize it away here.
|
||||
if inpt_old is inpt_new:
|
||||
# (This check needs to be done after putting resize_() in the graph,
|
||||
# since a resize_(0) doesn't actually change the FunctionalTensor's inner tensor)
|
||||
continue
|
||||
# We found an input that had a (data-only) mutation.
|
||||
# Since keep_input_mutations is set, we need to faithfully apply a copy_()
|
||||
# so the compiler will see the input mutation in the graph.
|
||||
if (
|
||||
meta.input_info[i].mutates_data
|
||||
and meta.input_info[i].mutations_hidden_from_autograd
|
||||
):
|
||||
# Hidden from autograd = run under no_grad, **and** don't bump VC
|
||||
# (although if the tensor was created in inference mode, it has no VC)
|
||||
if inpt_old.is_inference():
|
||||
maybe_preserve_vc = nullcontext()
|
||||
else:
|
||||
maybe_preserve_vc = torch.autograd._unsafe_preserve_version_counter(
|
||||
inpt_old # type: ignore[assignment]
|
||||
)
|
||||
with torch.no_grad(), maybe_preserve_vc:
|
||||
inpt_old.copy_(inpt_new)
|
||||
elif (
|
||||
meta.input_info[i].mutates_data
|
||||
and meta.input_info[
|
||||
i
|
||||
].mutations_under_no_grad_or_inference_mode
|
||||
):
|
||||
# Under no_grad = run under no_grad (we still bump the VC though)
|
||||
# (inference_mode will also bump the VC, as long as the tensor in question
|
||||
# was created outside of inference_mode)
|
||||
with torch.no_grad():
|
||||
inpt_old.copy_(inpt_new)
|
||||
elif meta.input_info[i].mutates_data:
|
||||
inpt_old.copy_(inpt_new)
|
||||
|
||||
# When an output tensor is a functionalized mutated input, and we
|
||||
# were able to move the mutation in to the graph then we can return
|
||||
# the mutated input directly. This prevents duplicating the
|
||||
# tensors contents.
|
||||
flat_outs, outs_spec = pytree.tree_flatten(f_outs)
|
||||
flat_outs = [from_fun(o) for o in flat_outs]
|
||||
num_outs = len(meta.output_info)
|
||||
|
||||
for i, outp in enumerate(flat_outs[:num_outs]):
|
||||
info = meta.output_info[i]
|
||||
if info.output_type != OutputType.is_input:
|
||||
continue
|
||||
|
||||
assert info.base_idx is not None
|
||||
if (
|
||||
meta.input_info[info.base_idx].mutation_type
|
||||
== MutationType.MUTATED_IN_GRAPH
|
||||
):
|
||||
fw_args = args[0] if trace_joint else args
|
||||
flat_outs[i] = fw_args[info.base_idx]
|
||||
return pytree.tree_unflatten(flat_outs, outs_spec)
|
||||
|
||||
return pytree.tree_map(from_fun, f_outs)
|
||||
|
||||
# Kinda annoying, but needed to make sure that the fx graph we trace out has "primals"
|
||||
# and "tangents" as its input names (which are special-cased by the partitioner)
|
||||
|
|
@ -709,10 +728,14 @@ def aot_dispatch_subclass(
|
|||
return unwrapped_outs
|
||||
|
||||
def joint_fn(primals, tangents):
|
||||
return inner_fn(flat_fn_maybe_joint, (primals, tangents), use_trace_joint=True)
|
||||
with maybe_enable_thunkify():
|
||||
return inner_fn(
|
||||
flat_fn_maybe_joint, (primals, tangents), use_trace_joint=True
|
||||
)
|
||||
|
||||
def fw_fn(*primals):
|
||||
return inner_fn(flat_fn_maybe_joint, primals, use_trace_joint=False)
|
||||
with maybe_enable_thunkify():
|
||||
return inner_fn(flat_fn_maybe_joint, primals, use_trace_joint=False)
|
||||
|
||||
def metadata_fn(*primals):
|
||||
return inner_fn(fw_only, primals, use_trace_joint=False)
|
||||
|
|
@ -771,7 +794,7 @@ def create_functional_call(mod, params_spec, params_len, store_orig_mod=False):
|
|||
def functional_call(*args, **kwargs):
|
||||
with stateless._reparametrize_module(
|
||||
mod, pytree.tree_unflatten(args[:params_len], params_spec)
|
||||
):
|
||||
), maybe_disable_thunkify():
|
||||
if isinstance(mod, torch.fx.GraphModule):
|
||||
with fx_traceback.preserve_node_meta(), warnings.catch_warnings():
|
||||
warnings.filterwarnings(
|
||||
|
|
|
|||
|
|
@ -54,6 +54,7 @@ from weakref import WeakKeyDictionary
|
|||
|
||||
if TYPE_CHECKING:
|
||||
import types
|
||||
import sympy
|
||||
|
||||
from torch._ops import OpOverload
|
||||
from torch.fx._symbolic_trace import PHBase
|
||||
|
|
@ -61,7 +62,8 @@ if TYPE_CHECKING:
|
|||
|
||||
__all__ = [
|
||||
"PythonKeyTracer", "dispatch_trace", "make_fx", "DecompositionInterpreter",
|
||||
"py_sym_types", "get_innermost_proxy_mode", "get_proxy_mode", "handle_sym_dispatch"
|
||||
"py_sym_types", "get_innermost_proxy_mode", "get_proxy_mode", "handle_sym_dispatch",
|
||||
"maybe_enable_thunkify", "maybe_disable_thunkify",
|
||||
]
|
||||
|
||||
_ProxyTracer = Union["PythonKeyTracer", "_GraphAppendingTracerEx"]
|
||||
|
|
@ -157,6 +159,7 @@ def set_proxy_slot(
|
|||
tracer: _ProxyTracer,
|
||||
proxy: object
|
||||
) -> None:
|
||||
log.debug("set_proxy_slot %s (%s) %s", obj, id(obj), proxy)
|
||||
if isinstance(obj, Tensor):
|
||||
# We DO want to clobber proxies whenever we run an inplace operation
|
||||
# on a tensor, and it affects the metadata on the proxy.
|
||||
|
|
@ -175,6 +178,22 @@ def set_proxy_slot(
|
|||
if obj not in tracer.symnode_tracker:
|
||||
tracer.symnode_tracker[obj] = typing.cast(_PySymProxyType, proxy)
|
||||
|
||||
# WAR: python test/dynamo/test_subclasses.py
|
||||
# TestNestedTensor.test_basic_autograd
|
||||
#
|
||||
# AOTAutograd doesn't pass the "outer sizes" as an actual argument
|
||||
# to make_fx, but it is made use of internally in AOTAutograd's
|
||||
# call to tensor unflatten. Because the outer sizes isn't passed
|
||||
# as an argument, it is therefore untracked. However, it turns
|
||||
# out you luck out, because *Dynamo* will manually add the outer
|
||||
# sizes as an argument so you can fix up the proxy'ness.
|
||||
#
|
||||
# This is probably fixed in
|
||||
# https://github.com/pytorch/pytorch/pull/125941/
|
||||
import sympy
|
||||
if isinstance(obj.node.expr, sympy.Symbol):
|
||||
tracer.sympy_expr_tracker[obj.node.expr] = proxy
|
||||
|
||||
def has_proxy_slot(obj: Tensor, tracer: _ProxyTracer) -> bool:
|
||||
assert isinstance(obj, (Tensor, SymNode)), type(obj)
|
||||
return bool(get_proxy_slot(obj, tracer, False, lambda _: True))
|
||||
|
|
@ -276,10 +295,15 @@ def get_proxy_slot(
|
|||
tracker = tracer.symnode_tracker
|
||||
|
||||
if obj not in tracker:
|
||||
if isinstance(default, _NoDefault):
|
||||
raise RuntimeError(f"{obj} is not tracked with proxy for {tracer}")
|
||||
return default
|
||||
value = tracker[obj]
|
||||
# Last ditch
|
||||
if isinstance(obj, py_sym_types) and obj.node.expr in tracer.sympy_expr_tracker:
|
||||
value = tracer.sympy_expr_tracker[obj.node.expr]
|
||||
else:
|
||||
if isinstance(default, _NoDefault):
|
||||
raise RuntimeError(f"{obj} ({id(obj)})is not tracked with proxy for {tracer}")
|
||||
return default
|
||||
else:
|
||||
value = tracker[obj]
|
||||
res = transform(value)
|
||||
return res
|
||||
|
||||
|
|
@ -330,6 +354,54 @@ def extract_val(val: _ExtractValType) -> _ExtractValType:
|
|||
|
||||
typing_extensions.assert_never(val)
|
||||
|
||||
@contextmanager
|
||||
def _enable_thunkify(tracer: _ProxyTracer, *, enable: bool = True) -> Generator[None, None, None]:
|
||||
"""
|
||||
Enable thunkification inside the context manager. Thunkification prevents
|
||||
SymNode computation from directly being traced into an FX graph; instead,
|
||||
the compute is only added to the graph if it is actually used. This helps
|
||||
us track SymNode compute when it is computed (since we need /something/
|
||||
to put in the tracker) even if it is unlikely to be used.
|
||||
"""
|
||||
old = tracer.enable_thunkify
|
||||
tracer.enable_thunkify = enable
|
||||
try:
|
||||
yield
|
||||
finally:
|
||||
tracer.enable_thunkify = old
|
||||
|
||||
@contextmanager
|
||||
def maybe_disable_thunkify() -> Generator[None, None, None]:
|
||||
"""Within a context, disable thunkification. See :func:`maybe_enable_thunkify`
|
||||
for more details. This is helpful if you have a wrapper function which
|
||||
you want to enable thunkification on, but in some segment on the inside (say,
|
||||
the original user function), you want to disable thunkification as you know
|
||||
it is not needed there.
|
||||
"""
|
||||
proxy_mode = get_proxy_mode()
|
||||
if proxy_mode is not None:
|
||||
with _enable_thunkify(proxy_mode.tracer, enable=False):
|
||||
yield
|
||||
else:
|
||||
yield
|
||||
|
||||
@contextmanager
|
||||
def maybe_enable_thunkify() -> Generator[None, None, None]:
|
||||
"""Within this context manager, if you are doing make_fx tracing, we will thunkify
|
||||
all SymNode compute and avoid tracing it into the graph unless it is actually needed.
|
||||
You should prefer to avoid using this as much as possible, as lazy evaluation of
|
||||
SymNode tracing can lead to long chains of thunks which will stack overflow
|
||||
if you evaluate them. However, this is currently sometimes necessary as there
|
||||
are buggy parts of PT2 which will fail with "s0 is not tracked with proxy" error
|
||||
due to insufficient tracing of SymNode computation.
|
||||
"""
|
||||
proxy_mode = get_proxy_mode()
|
||||
if proxy_mode is not None:
|
||||
with _enable_thunkify(proxy_mode.tracer):
|
||||
yield
|
||||
else:
|
||||
yield
|
||||
|
||||
# Note [invariants for node meta 'val']
|
||||
# What invariants do we have for the 'val' set on the FX node? It has accurate
|
||||
# metadata... but only for metadata that exists "below" all other subsystems
|
||||
|
|
@ -340,19 +412,24 @@ def extract_val(val: _ExtractValType) -> _ExtractValType:
|
|||
def set_meta(proxy: Proxy, val: _ExtractValType) -> Proxy:
|
||||
proxy.node.meta['val'] = extract_val(val)
|
||||
|
||||
# Best effort tensor_meta setting; prefer using val!
|
||||
if is_fake(val):
|
||||
proxy.node.meta['tensor_meta'] = _extract_tensor_metadata(val)
|
||||
elif isinstance(val, Tensor) and not val.is_sparse:
|
||||
proxy.node.meta['tensor_meta'] = _extract_tensor_metadata(val)
|
||||
with _enable_thunkify(proxy.tracer): # type: ignore[arg-type]
|
||||
# Best effort tensor_meta setting; prefer using val!
|
||||
if is_fake(val):
|
||||
proxy.node.meta['tensor_meta'] = _extract_tensor_metadata(val)
|
||||
elif isinstance(val, Tensor) and not val.is_sparse:
|
||||
proxy.node.meta['tensor_meta'] = _extract_tensor_metadata(val)
|
||||
return proxy
|
||||
|
||||
def thunkify(f: Callable[_P, R], *args: _P.args, **kwargs: _P.kwargs) -> Thunk[R]:
|
||||
def thunkify(tracer: _ProxyTracer, f: Callable[_P, R], *args: _P.args, **kwargs: _P.kwargs) -> Thunk[R]:
|
||||
"""
|
||||
Delays computation of f until it's called again
|
||||
Also caches the result
|
||||
"""
|
||||
return Thunk(functools.partial(f, *args, **kwargs))
|
||||
if tracer.enable_thunkify:
|
||||
return Thunk(functools.partial(f, *args, **kwargs))
|
||||
else:
|
||||
r = f(*args, **kwargs)
|
||||
return Thunk(lambda: r)
|
||||
|
||||
def track_tensor(tensor: Tensor, proxy: Proxy, *, constant: Optional[Tensor], tracer: _ProxyTracer) -> None:
|
||||
def try_set_proxy_slot(
|
||||
|
|
@ -363,7 +440,8 @@ def track_tensor(tensor: Tensor, proxy: Proxy, *, constant: Optional[Tensor], tr
|
|||
) -> None:
|
||||
assert callable(proxy_callable)
|
||||
if isinstance(outer_s, SymInt):
|
||||
set_proxy_slot(outer_s, tracer, thunkify(proxy_callable, outer_s, *args, **kwargs))
|
||||
with _enable_thunkify(tracer):
|
||||
set_proxy_slot(outer_s, tracer, thunkify(tracer, proxy_callable, outer_s, *args, **kwargs))
|
||||
# The basic idea is that we need to associate each tensor/SymInt
|
||||
# with a Proxy. How do we setup this association? We just store
|
||||
# the proxy on the proxy slot of the object, keyed on the tracer
|
||||
|
|
@ -411,7 +489,7 @@ def track_tensor_tree(
|
|||
assert isinstance(proxy, Proxy)
|
||||
# NB: eagerly set meta here, so that the numbering is in order
|
||||
set_meta(proxy, e)
|
||||
set_proxy_slot(e, tracer, thunkify(lambda: proxy))
|
||||
set_proxy_slot(e, tracer, thunkify(tracer, lambda: proxy))
|
||||
elif isinstance(e, _AnyScriptObject):
|
||||
assert isinstance(proxy, Proxy)
|
||||
set_proxy_slot(e, tracer, proxy)
|
||||
|
|
@ -700,7 +778,8 @@ def proxy_call(
|
|||
# Adding an undefined attribute to Tensor?
|
||||
args[0].proxy = proxy_out # type: ignore[attr-defined]
|
||||
|
||||
out = func(*args, **kwargs)
|
||||
with _enable_thunkify(proxy_mode.tracer):
|
||||
out = func(*args, **kwargs)
|
||||
|
||||
# In some circumstances, we will be tracing in a situation where a tensor
|
||||
# is *statically* known to be a constant (currently, this only happens if
|
||||
|
|
@ -804,12 +883,14 @@ class PythonKeyTracer(Tracer):
|
|||
self.tensor_tracker = WeakTensorKeyDictionary()
|
||||
self.symnode_tracker = _SymNodeDict()
|
||||
self.script_object_tracker = WeakIdKeyDictionary(dict=None, ref_type=_WeakHashRef)
|
||||
self.sympy_expr_tracker: Dict[sympy.Symbol, object] = dict()
|
||||
|
||||
# Stores the torch function that was called during tracing
|
||||
self.torch_fn_metadata = None
|
||||
# Stores the counts for every torch function called. This is to help
|
||||
# distinguish between different calls to the same torch function.
|
||||
self.torch_fn_counts = {}
|
||||
self.enable_thunkify = False
|
||||
|
||||
# In general, we don't want to make modules leaves. In principle, users of
|
||||
# this tracer might want to override this in order to turn a couple specific
|
||||
|
|
@ -901,6 +982,32 @@ def dispatch_trace(
|
|||
concrete_args: Optional[Tuple[Any, ...]] = None,
|
||||
) -> GraphModule:
|
||||
graph = tracer.trace(root, concrete_args)
|
||||
|
||||
# NB: be careful not to DCE .item() calls
|
||||
def impure_pred(n: fx.Node) -> bool:
|
||||
from .symbolic_shapes import is_accessor_node
|
||||
|
||||
# Always defer to the built-in notion of impure
|
||||
if n.is_impure():
|
||||
return True
|
||||
|
||||
# Accessors always OK to DCE
|
||||
if is_accessor_node(n):
|
||||
return False
|
||||
|
||||
# If the operator in question takes SymInt args to SymInt output,
|
||||
# we assume it's pure and OK to DCE
|
||||
if (
|
||||
isinstance(n.meta.get('val'), py_sym_types) and
|
||||
# NB: constant args ok
|
||||
all(isinstance(a.meta.get('val'), py_sym_types) for a in n.args if isinstance(a, fx.Node))
|
||||
):
|
||||
return False
|
||||
|
||||
# No idea, just assume it's not OK
|
||||
return True
|
||||
|
||||
graph.eliminate_dead_code(impure_pred)
|
||||
from torch._inductor.fx_passes.dedupe_symint_uses import dedupe_symints
|
||||
dedupe_symints(graph)
|
||||
name = root.__class__.__name__ if isinstance(root, Module) else root.__name__
|
||||
|
|
@ -945,6 +1052,7 @@ def wrap_key(f: Callable[_P, R], tensors: _P.args, tracer: _ProxyTracer, pre_dis
|
|||
|
||||
return wrapped
|
||||
|
||||
# TODO: Make downstream users of this work with OperatorBase
|
||||
ORIGINAL_ATEN: Optional[object] = None
|
||||
@contextmanager
|
||||
def set_original_aten_op(func: OpOverload) -> Generator[None, None, None]:
|
||||
|
|
@ -1129,7 +1237,7 @@ class ProxyTorchDispatchMode(TorchDispatchMode):
|
|||
# were symbolic) and it is no longer necessary to trace the
|
||||
# computation. This could occur if func triggered some guards.
|
||||
if isinstance(out, py_sym_types):
|
||||
p_out_thunk = thunkify(self._compute_proxy, func=func, args=args, out=out)
|
||||
p_out_thunk = thunkify(self.tracer, self._compute_proxy, func=func, args=args, out=out)
|
||||
set_proxy_slot(out, self.tracer, p_out_thunk)
|
||||
|
||||
return out
|
||||
|
|
@ -1139,8 +1247,10 @@ class _GraphAppendingTracerEx(fx.proxy.GraphAppendingTracer):
|
|||
script_object_tracker: WeakKeyDictionary
|
||||
symnode_tracker: WeakKeyDictionary
|
||||
tensor_tracker: WeakTensorKeyDictionary
|
||||
sympy_expr_tracker: Dict[sympy.Symbol, object]
|
||||
torch_fn_metadata: Optional[OpOverload]
|
||||
torch_fn_counts: Dict[OpOverload, int]
|
||||
enable_thunkify: bool = False
|
||||
|
||||
|
||||
# TODO: I'm not sure what the point of this class is; you can just
|
||||
|
|
@ -1159,6 +1269,7 @@ class DecompositionInterpreter(fx.Interpreter):
|
|||
# Blegh
|
||||
self.tracer.tensor_tracker = WeakTensorKeyDictionary()
|
||||
self.tracer.symnode_tracker = weakref.WeakKeyDictionary()
|
||||
self.tracer.sympy_expr_tracker = dict()
|
||||
self.decomposition_table = decomposition_table or {}
|
||||
self.mode = ProxyTorchDispatchMode(self.tracer, tracing_mode="real")
|
||||
|
||||
|
|
@ -1358,6 +1469,7 @@ class _ModuleStackTracer(PythonKeyTracer):
|
|||
concrete_args: Optional[Dict[str, object]]
|
||||
) -> fx.Graph:
|
||||
res = super().trace(root, concrete_args)
|
||||
|
||||
# Since we are making _AttrProxy mimic the original
|
||||
# submodule, when someone registers a module directly
|
||||
# to the tracer while tracing, the proxy object gets registered
|
||||
|
|
|
|||
|
|
@ -110,7 +110,7 @@ __all__ = [
|
|||
"StatefulSymbolicContext", "SubclassSymbolicContext", "statically_known_true",
|
||||
"guard_size_oblivious", "check_consistent",
|
||||
"compute_unbacked_bindings", "ConvertIntKey",
|
||||
"rebind_unbacked", "resolve_unbacked_bindings",
|
||||
"rebind_unbacked", "resolve_unbacked_bindings", "is_accessor_node",
|
||||
]
|
||||
|
||||
# FX node metadata keys for symbolic shape FX graph.
|
||||
|
|
@ -337,6 +337,32 @@ def rebind_unbacked(shape_env, n: torch.fx.Node, result):
|
|||
# Reuse the OLD symbol name
|
||||
shape_env._rename_unbacked_to(raw_u1, raw_u0)
|
||||
|
||||
# NB: You could try to expand this to cover more cases by simply
|
||||
# detecting whenever you have an int output, but this is a bit
|
||||
# dangerous in case someone adds a function that returns an int but is
|
||||
# mutating. So manually whitelist for now.
|
||||
def is_accessor_node(node: torch.fx.Node) -> bool:
|
||||
# Dynamo only exercised condition
|
||||
if (
|
||||
node.op == "call_method"
|
||||
and isinstance(node.args[0].meta.get("example_value"), torch.Tensor)
|
||||
and node.target in ["size", "stride", "storage_offset", "item"]
|
||||
):
|
||||
return True
|
||||
if node.op == "call_function" and node.target in [
|
||||
torch.ops.aten.sym_size,
|
||||
torch.ops.aten.sym_size.default,
|
||||
torch.ops.aten.sym_size.int,
|
||||
torch.ops.aten.sym_stride,
|
||||
torch.ops.aten.sym_stride.default,
|
||||
torch.ops.aten.sym_stride.int,
|
||||
torch.ops.aten.sym_storage_offset,
|
||||
torch.ops.aten.sym_storage_offset.default,
|
||||
torch.ops.aten.sym_numel.default,
|
||||
]:
|
||||
return True
|
||||
return False
|
||||
|
||||
def canonicalize_bool_expr(expr: SympyBoolean) -> SympyBoolean:
|
||||
r""" Canonicalize a boolean expression by transforming it into a lt / le
|
||||
inequality and moving all the non-constant terms to the rhs.
|
||||
|
|
|
|||
|
|
@ -8,6 +8,7 @@ import torch
|
|||
import inspect
|
||||
import operator
|
||||
import collections
|
||||
import logging
|
||||
|
||||
from dataclasses import is_dataclass, fields
|
||||
|
||||
|
|
@ -25,6 +26,9 @@ __all__ = ['TracerBase', 'GraphAppendingTracer', 'TraceError',
|
|||
'ScopeContextManager']
|
||||
|
||||
|
||||
log = logging.getLogger(__name__)
|
||||
|
||||
|
||||
@compatibility(is_backward_compatible=False)
|
||||
class Scope:
|
||||
""" Scope object that records the module path and the module type
|
||||
|
|
@ -136,6 +140,7 @@ class TracerBase:
|
|||
modification of values used in node creation. For example, one might
|
||||
want to disallow in-place operations from being recorded.
|
||||
"""
|
||||
|
||||
if kind == 'call_function' and self.check_mutable_operations:
|
||||
check_for_mutable_operation(target, args, kwargs)
|
||||
|
||||
|
|
@ -175,6 +180,8 @@ class TracerBase:
|
|||
|
||||
elif self.module_stack:
|
||||
node.meta['nn_module_stack'] = copy.copy(self.module_stack)
|
||||
|
||||
log.debug("create_node %s", node)
|
||||
return node
|
||||
|
||||
@compatibility(is_backward_compatible=True)
|
||||
|
|
|
|||
|
|
@ -302,14 +302,19 @@ class NestedTensor(torch.Tensor):
|
|||
if kwargs is None:
|
||||
kwargs = {}
|
||||
|
||||
from torch.fx.experimental.proxy_tensor import maybe_enable_thunkify
|
||||
|
||||
from .ops import jagged_torch_function
|
||||
|
||||
try:
|
||||
return jagged_torch_function(func, *args, **kwargs)
|
||||
except NotImplementedError:
|
||||
pass
|
||||
with torch._C.DisableTorchFunctionSubclass():
|
||||
return func(*args, **kwargs)
|
||||
# This should be removed after
|
||||
# https://github.com/pytorch/pytorch/pull/125941/ lands
|
||||
with maybe_enable_thunkify():
|
||||
try:
|
||||
return jagged_torch_function(func, *args, **kwargs)
|
||||
except NotImplementedError:
|
||||
pass
|
||||
with torch._C.DisableTorchFunctionSubclass():
|
||||
return func(*args, **kwargs)
|
||||
|
||||
|
||||
# NB: These fake view autograd.Functions are superseded by real view ops. Don't use them!
|
||||
|
|
|
|||
Loading…
Reference in New Issue
Block a user