Only thunkify proxies in some situations (#132421)

The goal of this PR is to avoid stack overflow when we create extremely long chains of thunks and then evaluate them (e.g., as happens if you call sum() on a long list of SymInts). The basic idea is to only thunkify proxies when they are created in places where they may or may not be used; crucially, SymInt operations that occur in user code we are tracing are eagerly placed into the graph, even if they may eventually be dead.
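For intuition, here is a minimal standalone sketch of the failure mode this avoids. It is not code from this PR; the Thunk class below is a simplified, hypothetical stand-in for the tracer's internal thunks.

    # Hypothetical illustration of why long thunk chains overflow the stack.
    from typing import Callable, Optional

    class Thunk:
        def __init__(self, f: Callable[[], int]) -> None:
            self.f = f
            self.result: Optional[int] = None

        def force(self) -> int:
            # Evaluate lazily on first use and cache the result.
            if self.result is None:
                self.result = self.f()
            return self.result

    total = Thunk(lambda: 0)
    for i in range(100_000):  # roughly "sum(long list of symint)"
        prev = total
        # Lazy: each step merely wraps the previous thunk, building a long chain.
        total = Thunk(lambda prev=prev, i=i: prev.force() + i)

    # total.force()  # RecursionError: forcing recurses through the whole chain.

    # Eager placement (what this PR does for symint ops in traced user code)
    # evaluates and records each step immediately, so no chain is ever built:
    eager_total = 0
    for i in range(100_000):
        eager_total = eager_total + i  # recorded right away, even if dead later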

I annotated the PR with explanations of the changes.

Signed-off-by: Edward Z. Yang <ezyang@meta.com>

Pull Request resolved: https://github.com/pytorch/pytorch/pull/132421
Approved by: https://github.com/Skylion007, https://github.com/zou3519
ghstack dependencies: #132674, #132675
Author: Edward Z. Yang
Date: 2024-08-08 04:59:12 -07:00
Committed by: PyTorch MergeBot
Commit: aec6332356 (parent 54efd43022)
12 changed files with 463 additions and 310 deletions
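As context for the hunks below: this change also exposes two public context managers, maybe_enable_thunkify and maybe_disable_thunkify, from torch.fx.experimental.proxy_tensor. The following is a rough usage sketch, not taken from the PR, modeled on how the AOTAutograd changes nest them (enable around framework wrapper code, disable around the original user function); it assumes a PyTorch build that already contains this commit.

    import torch
    from torch.fx.experimental.proxy_tensor import (
        make_fx,
        maybe_disable_thunkify,
        maybe_enable_thunkify,
    )

    def wrap(user_fn):
        def inner(x):
            with maybe_enable_thunkify():
                # Framework-side SymInt compute: thunkified, only traced if used.
                n = x.shape[0]
                with maybe_disable_thunkify():
                    # User code: SymInt ops here are traced eagerly into the
                    # graph, even if dead, so no long thunk chains can form.
                    return user_fn(x, n)
        return inner

    def user_fn(x, n):
        return x + n * 2

    gm = make_fx(wrap(user_fn), tracing_mode="symbolic")(torch.randn(3))
    print(gm.graph)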


@@ -51,6 +51,7 @@ torch.fx.experimental.symbolic_shapes
compute_unbacked_bindings
rebind_unbacked
resolve_unbacked_bindings
is_accessor_node
torch.fx.experimental.proxy_tensor
-------------------------------------
@@ -65,3 +66,5 @@ torch.fx.experimental.proxy_tensor
make_fx
handle_sym_dispatch
get_proxy_mode
maybe_enable_thunkify
maybe_disable_thunkify


@@ -3480,8 +3480,8 @@ def forward(self, x):
def forward(self, x):
arg0, = fx_pytree.tree_flatten_spec(([x], {}), self._in_spec)
arg0_1 = arg0
slice_1 = torch.ops.aten.slice.Tensor(arg0_1, 2, 0, 3)
sym_size_int = torch.ops.aten.sym_size.int(arg0_1, 0)
slice_1 = torch.ops.aten.slice.Tensor(arg0_1, 2, 0, 3)
sub = sym_size_int - 1
slice_2 = torch.ops.aten.slice.Tensor(arg0_1, 0, 0, sub); sub = None
slice_3 = torch.ops.aten.slice.Tensor(slice_2, 1, 1, sym_size_int); slice_2 = None


@@ -2148,7 +2148,6 @@ class TestNestedTensor(torch._dynamo.test_case.TestCase, NestedTensorTestCase):
self._validate_compile(fn, arg_fn)
@unittest.expectedFailure
def test_return_shape(self):
nt, _ = self._get_jagged_tensor(((2, 3, 4), 5), None)


@@ -6037,14 +6037,14 @@ def forward(self, x):
def forward(self, x):
item = torch.ops.aten.item.default(x); x = None
sym_constrain_range_for_size_default = torch.ops.aten.sym_constrain_range_for_size.default(item); sym_constrain_range_for_size_default = None
ge = item >= 3
_assert_scalar_default = torch.ops.aten._assert_scalar.default(ge, "Runtime assertion failed for expression u1 >= 3 on node 'ge'"); ge = _assert_scalar_default = None
ge_1 = item >= 3
_assert_scalar_default = torch.ops.aten._assert_scalar.default(ge_1, "Runtime assertion failed for expression u1 >= 3 on node 'ge_1'"); ge_1 = _assert_scalar_default = None
le = item <= 5
_assert_scalar_default_1 = torch.ops.aten._assert_scalar.default(le, "Runtime assertion failed for expression u1 <= 5 on node 'le'"); le = _assert_scalar_default_1 = None
gt = item > 2
_assert_scalar_default_2 = torch.ops.aten._assert_scalar.default(gt, "Runtime assertion failed for expression 2 < u1 on node 'gt'"); gt = _assert_scalar_default_2 = None
lt = item < 6
_assert_scalar_default_3 = torch.ops.aten._assert_scalar.default(lt, "Runtime assertion failed for expression u1 < 6 on node 'lt'"); lt = _assert_scalar_default_3 = None
gt_1 = item > 2
_assert_scalar_default_2 = torch.ops.aten._assert_scalar.default(gt_1, "Runtime assertion failed for expression 2 < u1 on node 'gt_1'"); gt_1 = _assert_scalar_default_2 = None
lt_1 = item < 6
_assert_scalar_default_3 = torch.ops.aten._assert_scalar.default(lt_1, "Runtime assertion failed for expression u1 < 6 on node 'lt_1'"); lt_1 = _assert_scalar_default_3 = None
foo_unbacked = torch.ops.testlib.foo_unbacked.default(item); item = None
return (foo_unbacked,)""",
)


@@ -1320,8 +1320,8 @@ def forward(self, token, obj, x):
with_effects = torch._higher_order_ops.effects.with_effects(token, torch.ops.higher_order.call_torchbind, obj, 'get'); token = obj = None
getitem = with_effects[0]
getitem_1 = with_effects[1]; with_effects = None
add = torch.ops.aten.add.Tensor(getitem_1, x); getitem_1 = x = None
return (getitem, add)""", # noqa: B950
add_3 = torch.ops.aten.add.Tensor(getitem_1, x); getitem_1 = x = None
return (getitem, add_3)""", # noqa: B950
)
@parametrize("backend", ["eager", "aot_eager"])


@@ -1056,8 +1056,8 @@ def forward(self, s0_1, s1_1, x_1, y_1):
self.assertExpectedInline(r, """\
def forward(self, x_1, y_1):
sym_size_int = torch.ops.aten.sym_size.int(x_1, 0); x_1 = None
empty = torch.ops.aten.empty.memory_format([sym_size_int], device = device(type='cpu'), pin_memory = False)
sym_size_int_1 = torch.ops.aten.sym_size.int(y_1, 0); y_1 = None
empty = torch.ops.aten.empty.memory_format([sym_size_int], device = device(type='cpu'), pin_memory = False)
return ((sym_size_int, sym_size_int_1), empty)""")
def test_unary(self):
@@ -1355,8 +1355,8 @@ def forward(self, crop_camera_1, mask_1):
view_1 = torch.ops.aten.view.default(expand_1, [sym_size_int, sym_size_int_1, sym_size_int_2]); expand_1 = sym_size_int_1 = sym_size_int_2 = None
bmm = torch.ops.aten.bmm.default(view, view_1); view = view_1 = None
view_2 = torch.ops.aten.view.default(bmm, [sym_size_int, 3, 3]); bmm = None
mul = sym_size_int * 3
view_3 = torch.ops.aten.view.default(view_2, [mul, 3]); view_2 = mul = None
mul_4 = sym_size_int * 3
view_3 = torch.ops.aten.view.default(view_2, [mul_4, 3]); view_2 = mul_4 = None
mm = torch.ops.aten.mm.default(view_3, eye); view_3 = eye = None
view_4 = torch.ops.aten.view.default(mm, [sym_size_int, 3, 3]); mm = sym_size_int = None
index_put_ = torch.ops.aten.index_put_.default(crop_camera_1, [mask_1], view_4); crop_camera_1 = mask_1 = view_4 = index_put_ = None


@@ -1536,29 +1536,7 @@ class OutputGraph:
return False
return True
# NB: You could try to expand this to cover more cases by simply
# detecting whenever you have an int output, but this is a bit
# dangerous in case someone adds a function that returns an int but is
# mutating. So manually whitelist for now.
def is_accessor_node(node):
if (
node.op == "call_method"
and isinstance(node.args[0].meta.get("example_value"), torch.Tensor)
and node.target in ["size", "stride", "storage_offset", "item"]
):
return True
if node.op == "call_function" and node.target in [
torch.ops.aten.sym_size,
torch.ops.aten.sym_size.default,
torch.ops.aten.sym_size.int,
torch.ops.aten.sym_stride,
torch.ops.aten.sym_stride.default,
torch.ops.aten.sym_stride.int,
torch.ops.aten.sym_storage_offset,
torch.ops.aten.sym_storage_offset.default,
]:
return True
return False
from torch.fx.experimental.symbolic_shapes import is_accessor_node
for node in reversed(list(self.graph.nodes)):
if len(list(node.users)) == 0:


@@ -24,6 +24,10 @@ from torch import Tensor
from torch._decomp.decompositions_for_rng import PhiloxStateTracker
from torch._guards import detect_fake_mode
from torch._prims_common import CUDARngStateHelper
from torch.fx.experimental.proxy_tensor import (
maybe_disable_thunkify,
maybe_enable_thunkify,
)
from torch.fx.experimental.symbolic_shapes import (
definitely_false,
PropagateUnbackedSymInts,
@@ -188,6 +192,7 @@ def fn_prepped_for_autograd(
def create_joint(fn: Callable, *, aot_config: AOTConfig) -> Any:
def inner_fn(primals: List[Any], tangents: List[Any]):
outs, tangent_mask = fn(*primals)
assert len(tangent_mask) == len(outs)
outs_to_grad = [
o for needs_tangent, o in zip(tangent_mask, outs) if needs_tangent
@@ -365,266 +370,280 @@ def create_functionalized_fn(
) -> Any:
@wraps(fn)
def _functionalized_f_helper(*args):
# See Note [Disabling Functionalize TLS Above Python Functionalization]
disable_above = torch._C._ExcludeDispatchKeyGuard(
torch._C.DispatchKeySet(torch._C.DispatchKey.Functionalize)
)
# See Note [Side-Effectful Tokens in AOTAutograd]
if trace_joint:
assert (
isinstance(args, tuple)
and len(args) == 2
and isinstance(args[0], (list, tuple))
with maybe_enable_thunkify():
# See Note [Disabling Functionalize TLS Above Python Functionalization]
disable_above = torch._C._ExcludeDispatchKeyGuard(
torch._C.DispatchKeySet(torch._C.DispatchKey.Functionalize)
)
tokens = args[0][: len(meta.tokens)]
actual_args = args[0][len(meta.tokens) :]
args = (actual_args, args[1])
else:
tokens = args[: len(meta.tokens)]
args = args[len(meta.tokens) :]
assert all(token.numel() == 0 for token in tokens)
with disable_above:
# Wrap inputs into functional wrappers
f_args = pytree.tree_map(to_fun, args)
f_tokens = pytree.tree_map(to_fun, tokens)
# Populate the current FunctionalTensorMode with the tokens per
# operator. See Note [FunctionalTensorMode is Stateful]
functional_tensor_mode = torch.utils._python_dispatch._detect_infra_mode(
torch._C._TorchDispatchModeKey.FUNCTIONAL
)
assert functional_tensor_mode is not None
for i, k in enumerate(meta.tokens.keys()):
functional_tensor_mode._tokens[k] = f_tokens[i]
# Run the joint
f_outs = fn(*f_args)
# Return both the tokens and the outputs
# See Note [Side-Effectful Tokens in AOTAutograd]
f_outs = (*functional_tensor_mode._tokens.values(), *f_outs)
if trace_joint:
# We support a limited amount of mutation of graph inputs during the backward pass.
# (This is used e.g. by Float8, which needs to update buffers during the backward pass)
# Here, we perform extra checks for primals that were mutated in the **backward**
# We're doing the checks here instead of doing them with the rest of the input mutation handling because:
# - We need to detect inputs that were mutated in the backward **separately** from mutations that happened
# during the forward, because the handling is different: some input mutations from the forward
# can be only handled in a fw-only runtime epilogue, and in theory if we wanted to handle those same
# types of mutations in the backward we would need a bw-only runtime epilogue.
# - We could in theory have our analysis pass differentiate mutations in the fw from mutations in
# the bw by running our analysis first on the fw-only graph, and then on the joint graph. This would
# require an extra round of tracing though, so it's more efficient to do in-line here.
assert (
isinstance(args, tuple)
and len(args) == 2
and isinstance(args[0], (list, tuple))
)
# Only look at mutations that happened to forward inputs (e.g. fw buffers that were saved for bw)
primals_before = args[0]
primals_after = pytree.tree_map(from_fun, f_args[0])
for idx, (f_inpt, before, after, inpt_info) in enumerate(
zip(f_args[0], primals_before, primals_after, meta.input_info)
):
# Store information about mutations in joint (for backward analysis)
joint_mutates_data = has_data_mutation(f_inpt)
joint_mutates_metadata = has_metadata_mutation(
f_inpt, before, check_only_storage_mutation=False
if trace_joint:
assert (
isinstance(args, tuple)
and len(args) == 2
and isinstance(args[0], (list, tuple))
)
tokens = args[0][: len(meta.tokens)]
actual_args = args[0][len(meta.tokens) :]
args = (actual_args, args[1])
else:
tokens = args[: len(meta.tokens)]
args = args[len(meta.tokens) :]
assert all(token.numel() == 0 for token in tokens)
# Ban metadata mutations on fw inputs during the bw
if not inpt_info.mutates_metadata:
assert (
not joint_mutates_metadata
), "Found a graph input that had its metadata mutated in the backward. This is not supported"
with disable_above:
# The functionalization code here can potentially trigger traces
# into the graph, but we'd prefer to NOT do this, because if we
# trace them now, we will end up with FX nodes that don't have
# module stack annotations, which makes unflattener unhappy.
# Wrap inputs into functional wrappers
f_args = pytree.tree_map(to_fun, args)
f_tokens = pytree.tree_map(to_fun, tokens)
# Ban storage resizing on fw inputs during the bw
if not inpt_info.mutation_inductor_storage_resize:
assert not was_inductor_storage_resized(
f_inpt
), "Found a graph input that had storage resizing in the backward. This is not supported"
# Allow data mutations on fw inputs during the bw, but only if they do not require grad
# So we can guarantee that we can keep the mutations in the graph
if (
joint_mutates_data
and not inpt_info.mutates_data
and not inpt_info.mutates_storage_metadata
):
# Not banning here mutations on inpt_info.requires_grad -
# we'll check at runtime and fail only when backward is under torch.is_grad_enabled (create_graph)
# Add node meta for copy_ for partitioner that this node should be in backward graph.
with torch.fx.traceback.preserve_node_meta(), set_partitioner_tag(
"must_be_in_backward"
):
before.copy_(after)
meta.indices_of_inputs_that_requires_grad_with_mutations_in_bw.append(
idx
# Populate the current FunctionalTensorMode with the tokens per
# operator. See Note [FunctionalTensorMode is Stateful]
functional_tensor_mode = (
torch.utils._python_dispatch._detect_infra_mode(
torch._C._TorchDispatchModeKey.FUNCTIONAL
)
# Now that we covered mutations to *forward* inputs during the backward,
# we also need to cover mutations to *backward-only* inputs during the backward (e.g. mutation to a grad_out).
# Today, we will just error in all cases of this happening unless someone needs us to support it.
tangents_before = args[1]
tangents_after = pytree.tree_map(from_fun, f_args[1])
for f_inpt, before, after in zip(
f_args[1], tangents_before, tangents_after
):
assert not has_metadata_mutation(
f_inpt, before, check_only_storage_mutation=False
) and not has_data_mutation(
f_inpt
), "Found an input to the backward that was mutated during the backward pass. This is not supported"
)
assert functional_tensor_mode is not None
for i, k in enumerate(meta.tokens.keys()):
functional_tensor_mode._tokens[k] = f_tokens[i]
if aot_config.keep_inference_input_mutations:
# Note: This is a bit annoying. There's a layering issue here, where:
# (1) functionalization needs to operate on **synthetic base** inputs, before unpacking them into the "real" inputs.
# (2) For keep_input_mutations, we support tracing a call to copy_() directly on mutated inputs.
# However, we **only** want to support this for inputs that have data-only (and no metadata) mutations,
# because inductor (and backends in general) would prefer not to see these (e.g. as_strided_(), resize_()).
# This makes it pretty difficult for this logic to operate on synthetic bases.
# (3) In addition, there are cases where it's significantly cheaper to perform the copy on the individual
# (unpacked) input aliases, instead of the synthetic base.
# Example case where (3) could be important:
#
# def f(x, y):
# x.mul_(2)
# y.mul_(3)
# return x, y
# a = torch.ones(1_000_000)
# x, y = f(a[0:9], a[1:10])
#
# It would be much better to add copy_() calls into the graph for the two tiny slices, instead of materializing
# a giant "updated synthetic base" and copying into a's entire storage.
#
# For now, we are pessimistically not performing the optimization from (3);
# we will materialize an "updated" synthetic base, and copy it back to the synthetic input base.
# This allows us to factor aot autograd much more nicely, since only one area of the code needs to worry
# about synthetic bases.
for i, (inpt_old, inpt_f) in enumerate(
zip(args, f_args) if not trace_joint else zip(args[0], f_args[0])
):
if not isinstance(inpt_f, torch.Tensor):
continue
assert is_fun(inpt_f)
inpt_new = from_fun(inpt_f)
if meta.input_info[i].mutation_type == MutationType.MUTATED_IN_GRAPH:
# See Note [set_() Input Mutations in AOTAutograd]
# all mutations on the input must be under no_grad, so it is safe to put in the graph
# Here, we're saying that if an input experienced a set call, inp.set_(other),
# then we can effectively not have to worry about whether its data was mutated.
# There are 3 cases:
# (1) We mutate inp *after* the set_() call. other is a graph intermediate.
# In this case, we're not really mutating the input storage of "inp";
# we're mutating the storage of an intermediate value (other),
# and slamming that storage into the input tensor. So no data mutation is necessary.
# (2) We mutate inp *after* the set_() call. other is a graph *input*.
# In this case, the data mutation will be properly handled in the runtime
# epilogue during the processing of "other"
# (3) We mutate inp *before* the set_() call.
# This case is *not* currently handled.
if meta.input_info[i].mutates_storage_metadata:
with torch.no_grad():
inpt_old.set_(inpt_new)
# Run the joint
f_outs = fn(*f_args)
# Note [Ordering of resize_() and set_()]
# Importantly: the common usage in FSDP is that we have a dummy parameter
# that sees a set_() and **Then** a resize_().
# We must put those mutations into the graph in the same order,
# Since running them in the opposite order will have different behavior.
# We fully ban resize_() followed by set_() for now, although in principle
# we could support this
if meta.input_info[i].mutation_inductor_storage_resize:
# resizing is not supported on subclasses (we error earlier if this happens)
from torch._subclasses.functional_tensor import FunctionalTensor
# Return both the tokens and the outputs
# See Note [Side-Effectful Tokens in AOTAutograd]
f_outs = (*functional_tensor_mode._tokens.values(), *f_outs)
assert isinstance(inpt_f, FunctionalTensor)
old_storage_size = torch._functionalize_get_storage_size( # type: ignore[attr-defined]
inpt_f.elem, before=True
)
new_storage_size = torch._functionalize_get_storage_size( # type: ignore[attr-defined]
inpt_f.elem, before=False
)
if old_storage_size != new_storage_size:
assert (
old_storage_size == 0 or new_storage_size == 0
), f"""\
Encountered a storage resize during tracing on input {i}. Old nbytes={old_storage_size}, new nbytes={new_storage_size}
We only support storage resizing on graph inputs as long as the input either starts or ends with a storage size of 0
(the case for FSDP)"""
torch.ops.inductor.resize_storage_bytes_(
inpt_old, new_storage_size
)
if new_storage_size == 0:
# Even if we marked the input as having a data mutation (thus needing a copy_()),
# We should **ignore** it if our input has no storage
# (this can happen if, e.g. we temporarily resize our input, copy data into it,
# and resize it back down to zero)
continue
# Optimization: if the copy_() is a no-op then don't include it in the graph.
# In theory inductor could optimize this away, however in fsdp, we end up with
# param.copy_(param), where param is a zero-storage-size tensor,
# and running this op in eager mode (using the aot_eager backend) will result in a segfault.
# So we may as well optimize it away here.
if inpt_old is inpt_new:
# (This check needs to be done after putting resize_() in the graph,
# since a resize_(0) doesn't actually change the FunctionalTensor's inner tensor)
continue
# We found an input that had a (data-only) mutation.
# Since keep_input_mutations is set, we need to faithfully apply a copy_()
# so the compiler will see the input mutation in the graph.
if (
meta.input_info[i].mutates_data
and meta.input_info[i].mutations_hidden_from_autograd
):
# Hidden from autograd = run under no_grad, **and** don't bump VC
# (although if the tensor was created in inference mode, it has no VC)
if inpt_old.is_inference():
maybe_preserve_vc = nullcontext()
else:
maybe_preserve_vc = torch.autograd._unsafe_preserve_version_counter(
inpt_old # type: ignore[assignment]
)
with torch.no_grad(), maybe_preserve_vc:
inpt_old.copy_(inpt_new)
elif (
meta.input_info[i].mutates_data
and meta.input_info[i].mutations_under_no_grad_or_inference_mode
):
# Under no_grad = run under no_grad (we still bump the VC though)
# (inference_mode will also bump the VC, as long as the tensor in question
# was created outside of inference_mode)
with torch.no_grad():
inpt_old.copy_(inpt_new)
elif meta.input_info[i].mutates_data:
inpt_old.copy_(inpt_new)
# When an output tensor is a functionalized mutated input, and we
# were able to move the mutation into the graph then we can return
# the mutated input directly. This prevents duplicating the
# tensor's contents.
flat_outs, outs_spec = pytree.tree_flatten(f_outs)
flat_outs = [from_fun(o) for o in flat_outs]
num_outs = len(meta.output_info)
for i, outp in enumerate(flat_outs[:num_outs]):
info = meta.output_info[i]
if info.output_type != OutputType.is_input:
continue
assert info.base_idx is not None
if (
meta.input_info[info.base_idx].mutation_type
== MutationType.MUTATED_IN_GRAPH
if trace_joint:
# We support a limited amount of mutation of graph inputs during the backward pass.
# (This is used e.g. by Float8, which needs to update buffers during the backward pass)
# Here, we perform extra checks for primals that were mutated in the **backward**
# We're doing the checks here instead of doing them with the rest of the input mutation handling because:
# - We need to detect inputs that were mutated in the backward **separately** from mutations that happened
# during the forward, because the handling is different: some input mutations from the forward
# can be only handled in a fw-only runtime epilogue, and in theory if we wanted to handle those same
# types of mutations in the backward we would need a bw-only runtime epilogue.
# - We could in theory have our analysis pass differentiate mutations in the fw from mutations in
# the bw by running our analysis first on the fw-only graph, and then on the joint graph. This would
# require an extra round of tracing though, so it's more efficient to do in-line here.
assert (
isinstance(args, tuple)
and len(args) == 2
and isinstance(args[0], (list, tuple))
)
# Only look at mutations that happened to forward inputs (e.g. fw buffers that were saved for bw)
primals_before = args[0]
primals_after = pytree.tree_map(from_fun, f_args[0])
for idx, (f_inpt, before, after, inpt_info) in enumerate(
zip(f_args[0], primals_before, primals_after, meta.input_info)
):
fw_args = args[0] if trace_joint else args
flat_outs[i] = fw_args[info.base_idx]
return pytree.tree_unflatten(flat_outs, outs_spec)
# Store information about mutations in joint (for backward analysis)
joint_mutates_data = has_data_mutation(f_inpt)
return pytree.tree_map(from_fun, f_outs)
joint_mutates_metadata = has_metadata_mutation(
f_inpt, before, check_only_storage_mutation=False
)
# Ban metadata mutations on fw inputs during the bw
if not inpt_info.mutates_metadata:
assert (
not joint_mutates_metadata
), "Found a graph input that had its metadata mutated in the backward. This is not supported"
# Ban storage resizing on fw inputs during the bw
if not inpt_info.mutation_inductor_storage_resize:
assert not was_inductor_storage_resized(
f_inpt
), "Found a graph input that had storage resizing in the backward. This is not supported"
# Allow data mutations on fw inputs during the bw, but only if they do not require grad
# So we can guarantee that we can keep the mutations in the graph
if (
joint_mutates_data
and not inpt_info.mutates_data
and not inpt_info.mutates_storage_metadata
):
# Not banning here mutations on inpt_info.requires_grad -
# we'll check at runtime and fail only when backward is under torch.is_grad_enabled (create_graph)
# Add node meta for copy_ for partitioner that this node should be in backward graph.
with torch.fx.traceback.preserve_node_meta(), set_partitioner_tag(
"must_be_in_backward"
):
before.copy_(after)
meta.indices_of_inputs_that_requires_grad_with_mutations_in_bw.append(
idx
)
# Now that we covered mutations to *forward* inputs during the backward,
# we also need to cover mutations to *backward-only* inputs during the backward (e.g. mutation to a grad_out).
# Today, we will just error in all cases of this happening unless someone needs us to support it.
tangents_before = args[1]
tangents_after = pytree.tree_map(from_fun, f_args[1])
for f_inpt, before, after in zip(
f_args[1], tangents_before, tangents_after
):
assert not has_metadata_mutation(
f_inpt, before, check_only_storage_mutation=False
) and not has_data_mutation(
f_inpt
), "Found an input to the backward that was mutated during the backward pass. This is not supported"
if aot_config.keep_inference_input_mutations:
# Note: This is a bit annoying. There's a layering issue here, where:
# (1) functionalization needs to operate on **synthetic base** inputs, before unpacking them into the "real" inputs.
# (2) For keep_input_mutations, we support tracing a call to copy_() directly on mutated inputs.
# However, we **only** want to support this for inputs that have data-only (and no metadata) mutations,
# because inductor (and backends in general) would prefer not to see these (e.g. as_strided_(), resize_()).
# This makes it pretty difficult for this logic to operate on synthetic bases.
# (3) In addition, there are cases where it's significantly cheaper to perform the copy on the individual
# (unpacked) input aliases, instead of the synthetic base.
# Example case where (3) could be important:
#
# def f(x, y):
# x.mul_(2)
# y.mul_(3)
# return x, y
# a = torch.ones(1_000_000)
# x, y = f(a[0:9], a[1:10])
#
# It would be much better to add copy_() calls into the graph for the two tiny slices, instead of materializing
# a giant "updated synthetic base" and copying into a's entire storage.
#
# For now, we are pessimistically not performing the optimization from (3);
# we will materialize an "updated" synthetic base, and copy it back to the synthetic input base.
# This allows us to factor aot autograd much more nicely, since only one area of the code needs to worry
# about synthetic bases.
for i, (inpt_old, inpt_f) in enumerate(
zip(args, f_args) if not trace_joint else zip(args[0], f_args[0])
):
if not isinstance(inpt_f, torch.Tensor):
continue
assert is_fun(inpt_f)
inpt_new = from_fun(inpt_f)
if (
meta.input_info[i].mutation_type
== MutationType.MUTATED_IN_GRAPH
):
# See Note [set_() Input Mutations in AOTAutograd]
# all mutations on the input must be under no_grad, so it is safe to put in the graph
# Here, we're saying that if an input experienced a set call, inp.set_(other),
# then we can effectively not have to worry about whether its data was mutated.
# There are 3 cases:
# (1) We mutate inp *after* the set_() call. other is a graph intermediate.
# In this case, we're not really mutating the input storage of "inp";
# we're mutating the storage of an intermediate value (other),
# and slamming that storage into the input tensor. So no data mutation is necessary.
# (2) We mutate inp *after* the set_() call. other is a graph *input*.
# In this case, the data mutation will be properly handled in the runtime
# epilogue during the processing of "other"
# (3) We mutate inp *before* the set_() call.
# This case is *not* currently handled.
if meta.input_info[i].mutates_storage_metadata:
with torch.no_grad():
inpt_old.set_(inpt_new)
# Note [Ordering of resize_() and set_()]
# Importantly: the common usage in FSDP is that we have a dummy parameter
# that sees a set_() and **Then** a resize_().
# We must put those mutations into the graph in the same order,
# Since running them in the opposite order will have different behavior.
# We fully ban resize_() followed by set_() for now, although in principle
# we could support this
if meta.input_info[i].mutation_inductor_storage_resize:
# resizing is not supported on subclasses (we error earlier if this happens)
from torch._subclasses.functional_tensor import (
FunctionalTensor,
)
assert isinstance(inpt_f, FunctionalTensor)
old_storage_size = torch._functionalize_get_storage_size( # type: ignore[attr-defined]
inpt_f.elem, before=True
)
new_storage_size = torch._functionalize_get_storage_size( # type: ignore[attr-defined]
inpt_f.elem, before=False
)
if old_storage_size != new_storage_size:
assert (
old_storage_size == 0 or new_storage_size == 0
), f"""\
Encountered a storage resize during tracing on input {i}. Old nbytes={old_storage_size}, new nbytes={new_storage_size}
We only support storage resizing on graph inputs as long as the input either starts or ends with a storage size of 0
(the case for FSDP)"""
torch.ops.inductor.resize_storage_bytes_(
inpt_old, new_storage_size
)
if new_storage_size == 0:
# Even if we marked the input as having a data mutation (thus needing a copy_()),
# We should **ignore** it if our input has no storage
# (this can happen if, e.g. we temporarily resize our input, copy data into it,
# and resize it back down to zero)
continue
# Optimization: if the copy_() is a no-op then don't include it in the graph.
# In theory inductor could optimize this away, however in fsdp, we end up with
# param.copy_(param), where param is a zero-storage-size tensor,
# and running this op in eager mode (using the aot_eager backend) will result in a segfault.
# So we may as well optimize it away here.
if inpt_old is inpt_new:
# (This check needs to be done after putting resize_() in the graph,
# since a resize_(0) doesn't actually change the FunctionalTensor's inner tensor)
continue
# We found an input that had a (data-only) mutation.
# Since keep_input_mutations is set, we need to faithfully apply a copy_()
# so the compiler will see the input mutation in the graph.
if (
meta.input_info[i].mutates_data
and meta.input_info[i].mutations_hidden_from_autograd
):
# Hidden from autograd = run under no_grad, **and** don't bump VC
# (although if the tensor was created in inference mode, it has no VC)
if inpt_old.is_inference():
maybe_preserve_vc = nullcontext()
else:
maybe_preserve_vc = torch.autograd._unsafe_preserve_version_counter(
inpt_old # type: ignore[assignment]
)
with torch.no_grad(), maybe_preserve_vc:
inpt_old.copy_(inpt_new)
elif (
meta.input_info[i].mutates_data
and meta.input_info[
i
].mutations_under_no_grad_or_inference_mode
):
# Under no_grad = run under no_grad (we still bump the VC though)
# (inference_mode will also bump the VC, as long as the tensor in question
# was created outside of inference_mode)
with torch.no_grad():
inpt_old.copy_(inpt_new)
elif meta.input_info[i].mutates_data:
inpt_old.copy_(inpt_new)
# When an output tensor is a functionalized mutated input, and we
# were able to move the mutation into the graph then we can return
# the mutated input directly. This prevents duplicating the
# tensor's contents.
flat_outs, outs_spec = pytree.tree_flatten(f_outs)
flat_outs = [from_fun(o) for o in flat_outs]
num_outs = len(meta.output_info)
for i, outp in enumerate(flat_outs[:num_outs]):
info = meta.output_info[i]
if info.output_type != OutputType.is_input:
continue
assert info.base_idx is not None
if (
meta.input_info[info.base_idx].mutation_type
== MutationType.MUTATED_IN_GRAPH
):
fw_args = args[0] if trace_joint else args
flat_outs[i] = fw_args[info.base_idx]
return pytree.tree_unflatten(flat_outs, outs_spec)
return pytree.tree_map(from_fun, f_outs)
# Kinda annoying, but needed to make sure that the fx graph we trace out has "primals"
# and "tangents" as its input names (which are special-cased by the partitioner)
@@ -709,10 +728,14 @@ def aot_dispatch_subclass(
return unwrapped_outs
def joint_fn(primals, tangents):
return inner_fn(flat_fn_maybe_joint, (primals, tangents), use_trace_joint=True)
with maybe_enable_thunkify():
return inner_fn(
flat_fn_maybe_joint, (primals, tangents), use_trace_joint=True
)
def fw_fn(*primals):
return inner_fn(flat_fn_maybe_joint, primals, use_trace_joint=False)
with maybe_enable_thunkify():
return inner_fn(flat_fn_maybe_joint, primals, use_trace_joint=False)
def metadata_fn(*primals):
return inner_fn(fw_only, primals, use_trace_joint=False)
@@ -771,7 +794,7 @@ def create_functional_call(mod, params_spec, params_len, store_orig_mod=False):
def functional_call(*args, **kwargs):
with stateless._reparametrize_module(
mod, pytree.tree_unflatten(args[:params_len], params_spec)
):
), maybe_disable_thunkify():
if isinstance(mod, torch.fx.GraphModule):
with fx_traceback.preserve_node_meta(), warnings.catch_warnings():
warnings.filterwarnings(


@@ -54,6 +54,7 @@ from weakref import WeakKeyDictionary
if TYPE_CHECKING:
import types
import sympy
from torch._ops import OpOverload
from torch.fx._symbolic_trace import PHBase
@@ -61,7 +62,8 @@ if TYPE_CHECKING:
__all__ = [
"PythonKeyTracer", "dispatch_trace", "make_fx", "DecompositionInterpreter",
"py_sym_types", "get_innermost_proxy_mode", "get_proxy_mode", "handle_sym_dispatch"
"py_sym_types", "get_innermost_proxy_mode", "get_proxy_mode", "handle_sym_dispatch",
"maybe_enable_thunkify", "maybe_disable_thunkify",
]
_ProxyTracer = Union["PythonKeyTracer", "_GraphAppendingTracerEx"]
@@ -157,6 +159,7 @@ def set_proxy_slot(
tracer: _ProxyTracer,
proxy: object
) -> None:
log.debug("set_proxy_slot %s (%s) %s", obj, id(obj), proxy)
if isinstance(obj, Tensor):
# We DO want to clobber proxies whenever we run an inplace operation
# on a tensor, and it affects the metadata on the proxy.
@@ -175,6 +178,22 @@ def set_proxy_slot(
if obj not in tracer.symnode_tracker:
tracer.symnode_tracker[obj] = typing.cast(_PySymProxyType, proxy)
# WAR: python test/dynamo/test_subclasses.py
# TestNestedTensor.test_basic_autograd
#
# AOTAutograd doesn't pass the "outer sizes" as an actual argument
# to make_fx, but it is made use of internally in AOTAutograd's
# call to tensor unflatten. Because the outer sizes isn't passed
# as an argument, it is therefore untracked. However, it turns
# out you luck out, because *Dynamo* will manually add the outer
# sizes as an argument so you can fix up the proxy'ness.
#
# This is probably fixed in
# https://github.com/pytorch/pytorch/pull/125941/
import sympy
if isinstance(obj.node.expr, sympy.Symbol):
tracer.sympy_expr_tracker[obj.node.expr] = proxy
def has_proxy_slot(obj: Tensor, tracer: _ProxyTracer) -> bool:
assert isinstance(obj, (Tensor, SymNode)), type(obj)
return bool(get_proxy_slot(obj, tracer, False, lambda _: True))
@@ -276,10 +295,15 @@ def get_proxy_slot(
tracker = tracer.symnode_tracker
if obj not in tracker:
if isinstance(default, _NoDefault):
raise RuntimeError(f"{obj} is not tracked with proxy for {tracer}")
return default
value = tracker[obj]
# Last ditch
if isinstance(obj, py_sym_types) and obj.node.expr in tracer.sympy_expr_tracker:
value = tracer.sympy_expr_tracker[obj.node.expr]
else:
if isinstance(default, _NoDefault):
raise RuntimeError(f"{obj} ({id(obj)}) is not tracked with proxy for {tracer}")
return default
else:
value = tracker[obj]
res = transform(value)
return res
@@ -330,6 +354,54 @@ def extract_val(val: _ExtractValType) -> _ExtractValType:
typing_extensions.assert_never(val)
@contextmanager
def _enable_thunkify(tracer: _ProxyTracer, *, enable: bool = True) -> Generator[None, None, None]:
"""
Enable thunkification inside the context manager. Thunkification prevents
SymNode computation from directly being traced into an FX graph; instead,
the compute is only added to the graph if it is actually used. This helps
us track SymNode compute when it is computed (since we need /something/
to put in the tracker) even if it is unlikely to be used.
"""
old = tracer.enable_thunkify
tracer.enable_thunkify = enable
try:
yield
finally:
tracer.enable_thunkify = old
@contextmanager
def maybe_disable_thunkify() -> Generator[None, None, None]:
"""Within a context, disable thunkification. See :func:`maybe_enable_thunkify`
for more details. This is helpful if you have a wrapper function which
you want to enable thunkification on, but in some segment on the inside (say,
the original user function), you want to disable thunkification as you know
it is not needed there.
"""
proxy_mode = get_proxy_mode()
if proxy_mode is not None:
with _enable_thunkify(proxy_mode.tracer, enable=False):
yield
else:
yield
@contextmanager
def maybe_enable_thunkify() -> Generator[None, None, None]:
"""Within this context manager, if you are doing make_fx tracing, we will thunkify
all SymNode compute and avoid tracing it into the graph unless it is actually needed.
You should prefer to avoid using this as much as possible, as lazy evaluation of
SymNode tracing can lead to long chains of thunks which will stack overflow
if you evaluate them. However, this is currently sometimes necessary as there
are buggy parts of PT2 which will fail with "s0 is not tracked with proxy" error
due to insufficient tracing of SymNode computation.
"""
proxy_mode = get_proxy_mode()
if proxy_mode is not None:
with _enable_thunkify(proxy_mode.tracer):
yield
else:
yield
# Note [invariants for node meta 'val']
# What invariants do we have for the 'val' set on the FX node? It has accurate
# metadata... but only for metadata that exists "below" all other subsystems
@@ -340,19 +412,24 @@ def extract_val(val: _ExtractValType) -> _ExtractValType:
def set_meta(proxy: Proxy, val: _ExtractValType) -> Proxy:
proxy.node.meta['val'] = extract_val(val)
# Best effort tensor_meta setting; prefer using val!
if is_fake(val):
proxy.node.meta['tensor_meta'] = _extract_tensor_metadata(val)
elif isinstance(val, Tensor) and not val.is_sparse:
proxy.node.meta['tensor_meta'] = _extract_tensor_metadata(val)
with _enable_thunkify(proxy.tracer): # type: ignore[arg-type]
# Best effort tensor_meta setting; prefer using val!
if is_fake(val):
proxy.node.meta['tensor_meta'] = _extract_tensor_metadata(val)
elif isinstance(val, Tensor) and not val.is_sparse:
proxy.node.meta['tensor_meta'] = _extract_tensor_metadata(val)
return proxy
def thunkify(f: Callable[_P, R], *args: _P.args, **kwargs: _P.kwargs) -> Thunk[R]:
def thunkify(tracer: _ProxyTracer, f: Callable[_P, R], *args: _P.args, **kwargs: _P.kwargs) -> Thunk[R]:
"""
Delays computation of f until it's called again
Also caches the result
"""
return Thunk(functools.partial(f, *args, **kwargs))
if tracer.enable_thunkify:
return Thunk(functools.partial(f, *args, **kwargs))
else:
r = f(*args, **kwargs)
return Thunk(lambda: r)
def track_tensor(tensor: Tensor, proxy: Proxy, *, constant: Optional[Tensor], tracer: _ProxyTracer) -> None:
def try_set_proxy_slot(
@@ -363,7 +440,8 @@ def track_tensor(tensor: Tensor, proxy: Proxy, *, constant: Optional[Tensor], tr
) -> None:
assert callable(proxy_callable)
if isinstance(outer_s, SymInt):
set_proxy_slot(outer_s, tracer, thunkify(proxy_callable, outer_s, *args, **kwargs))
with _enable_thunkify(tracer):
set_proxy_slot(outer_s, tracer, thunkify(tracer, proxy_callable, outer_s, *args, **kwargs))
# The basic idea is that we need to associate each tensor/SymInt
# with a Proxy. How do we setup this association? We just store
# the proxy on the proxy slot of the object, keyed on the tracer
@@ -411,7 +489,7 @@ def track_tensor_tree(
assert isinstance(proxy, Proxy)
# NB: eagerly set meta here, so that the numbering is in order
set_meta(proxy, e)
set_proxy_slot(e, tracer, thunkify(lambda: proxy))
set_proxy_slot(e, tracer, thunkify(tracer, lambda: proxy))
elif isinstance(e, _AnyScriptObject):
assert isinstance(proxy, Proxy)
set_proxy_slot(e, tracer, proxy)
@@ -700,7 +778,8 @@ def proxy_call(
# Adding an undefined attribute to Tensor?
args[0].proxy = proxy_out # type: ignore[attr-defined]
out = func(*args, **kwargs)
with _enable_thunkify(proxy_mode.tracer):
out = func(*args, **kwargs)
# In some circumstances, we will be tracing in a situation where a tensor
# is *statically* known to be a constant (currently, this only happens if
@@ -804,12 +883,14 @@ class PythonKeyTracer(Tracer):
self.tensor_tracker = WeakTensorKeyDictionary()
self.symnode_tracker = _SymNodeDict()
self.script_object_tracker = WeakIdKeyDictionary(dict=None, ref_type=_WeakHashRef)
self.sympy_expr_tracker: Dict[sympy.Symbol, object] = dict()
# Stores the torch function that was called during tracing
self.torch_fn_metadata = None
# Stores the counts for every torch function called. This is to help
# distinguish between different calls to the same torch function.
self.torch_fn_counts = {}
self.enable_thunkify = False
# In general, we don't want to make modules leaves. In principle, users of
# this tracer might want to override this in order to turn a couple specific
@@ -901,6 +982,32 @@ def dispatch_trace(
concrete_args: Optional[Tuple[Any, ...]] = None,
) -> GraphModule:
graph = tracer.trace(root, concrete_args)
# NB: be careful not to DCE .item() calls
def impure_pred(n: fx.Node) -> bool:
from .symbolic_shapes import is_accessor_node
# Always defer to the built-in notion of impure
if n.is_impure():
return True
# Accessors always OK to DCE
if is_accessor_node(n):
return False
# If the operator in question takes SymInt args to SymInt output,
# we assume it's pure and OK to DCE
if (
isinstance(n.meta.get('val'), py_sym_types) and
# NB: constant args ok
all(isinstance(a.meta.get('val'), py_sym_types) for a in n.args if isinstance(a, fx.Node))
):
return False
# No idea, just assume it's not OK
return True
graph.eliminate_dead_code(impure_pred)
from torch._inductor.fx_passes.dedupe_symint_uses import dedupe_symints
dedupe_symints(graph)
name = root.__class__.__name__ if isinstance(root, Module) else root.__name__
@@ -945,6 +1052,7 @@ def wrap_key(f: Callable[_P, R], tensors: _P.args, tracer: _ProxyTracer, pre_dis
return wrapped
# TODO: Make downstream users of this work with OperatorBase
ORIGINAL_ATEN: Optional[object] = None
@contextmanager
def set_original_aten_op(func: OpOverload) -> Generator[None, None, None]:
@@ -1129,7 +1237,7 @@ class ProxyTorchDispatchMode(TorchDispatchMode):
# were symbolic) and it is no longer necessary to trace the
# computation. This could occur if func triggered some guards.
if isinstance(out, py_sym_types):
p_out_thunk = thunkify(self._compute_proxy, func=func, args=args, out=out)
p_out_thunk = thunkify(self.tracer, self._compute_proxy, func=func, args=args, out=out)
set_proxy_slot(out, self.tracer, p_out_thunk)
return out
@@ -1139,8 +1247,10 @@ class _GraphAppendingTracerEx(fx.proxy.GraphAppendingTracer):
script_object_tracker: WeakKeyDictionary
symnode_tracker: WeakKeyDictionary
tensor_tracker: WeakTensorKeyDictionary
sympy_expr_tracker: Dict[sympy.Symbol, object]
torch_fn_metadata: Optional[OpOverload]
torch_fn_counts: Dict[OpOverload, int]
enable_thunkify: bool = False
# TODO: I'm not sure what the point of this class is; you can just
@@ -1159,6 +1269,7 @@ class DecompositionInterpreter(fx.Interpreter):
# Blegh
self.tracer.tensor_tracker = WeakTensorKeyDictionary()
self.tracer.symnode_tracker = weakref.WeakKeyDictionary()
self.tracer.sympy_expr_tracker = dict()
self.decomposition_table = decomposition_table or {}
self.mode = ProxyTorchDispatchMode(self.tracer, tracing_mode="real")
@@ -1358,6 +1469,7 @@ class _ModuleStackTracer(PythonKeyTracer):
concrete_args: Optional[Dict[str, object]]
) -> fx.Graph:
res = super().trace(root, concrete_args)
# Since we are making _AttrProxy mimic the original
# submodule, when someone registers a module directly
# to the tracer while tracing, the proxy object gets registered


@@ -110,7 +110,7 @@ __all__ = [
"StatefulSymbolicContext", "SubclassSymbolicContext", "statically_known_true",
"guard_size_oblivious", "check_consistent",
"compute_unbacked_bindings", "ConvertIntKey",
"rebind_unbacked", "resolve_unbacked_bindings",
"rebind_unbacked", "resolve_unbacked_bindings", "is_accessor_node",
]
# FX node metadata keys for symbolic shape FX graph.
@@ -337,6 +337,32 @@ def rebind_unbacked(shape_env, n: torch.fx.Node, result):
# Reuse the OLD symbol name
shape_env._rename_unbacked_to(raw_u1, raw_u0)
# NB: You could try to expand this to cover more cases by simply
# detecting whenever you have an int output, but this is a bit
# dangerous in case someone adds a function that returns an int but is
# mutating. So manually whitelist for now.
def is_accessor_node(node: torch.fx.Node) -> bool:
# Dynamo only exercised condition
if (
node.op == "call_method"
and isinstance(node.args[0].meta.get("example_value"), torch.Tensor)
and node.target in ["size", "stride", "storage_offset", "item"]
):
return True
if node.op == "call_function" and node.target in [
torch.ops.aten.sym_size,
torch.ops.aten.sym_size.default,
torch.ops.aten.sym_size.int,
torch.ops.aten.sym_stride,
torch.ops.aten.sym_stride.default,
torch.ops.aten.sym_stride.int,
torch.ops.aten.sym_storage_offset,
torch.ops.aten.sym_storage_offset.default,
torch.ops.aten.sym_numel.default,
]:
return True
return False
def canonicalize_bool_expr(expr: SympyBoolean) -> SympyBoolean:
r""" Canonicalize a boolean expression by transforming it into a lt / le
inequality and moving all the non-constant terms to the rhs.


@@ -8,6 +8,7 @@ import torch
import inspect
import operator
import collections
import logging
from dataclasses import is_dataclass, fields
@@ -25,6 +26,9 @@ __all__ = ['TracerBase', 'GraphAppendingTracer', 'TraceError',
'ScopeContextManager']
log = logging.getLogger(__name__)
@compatibility(is_backward_compatible=False)
class Scope:
""" Scope object that records the module path and the module type
@@ -136,6 +140,7 @@ class TracerBase:
modification of values used in node creation. For example, one might
want to disallow in-place operations from being recorded.
"""
if kind == 'call_function' and self.check_mutable_operations:
check_for_mutable_operation(target, args, kwargs)
@@ -175,6 +180,8 @@ class TracerBase:
elif self.module_stack:
node.meta['nn_module_stack'] = copy.copy(self.module_stack)
log.debug("create_node %s", node)
return node
@compatibility(is_backward_compatible=True)


@@ -302,14 +302,19 @@ class NestedTensor(torch.Tensor):
if kwargs is None:
kwargs = {}
from torch.fx.experimental.proxy_tensor import maybe_enable_thunkify
from .ops import jagged_torch_function
try:
return jagged_torch_function(func, *args, **kwargs)
except NotImplementedError:
pass
with torch._C.DisableTorchFunctionSubclass():
return func(*args, **kwargs)
# This should be removed after
# https://github.com/pytorch/pytorch/pull/125941/ lands
with maybe_enable_thunkify():
try:
return jagged_torch_function(func, *args, **kwargs)
except NotImplementedError:
pass
with torch._C.DisableTorchFunctionSubclass():
return func(*args, **kwargs)
# NB: These fake view autograd.Functions are superseded by real view ops. Don't use them!