[aotd] Support mutations of the same input in fw and bw (#155354)

Original issue: https://github.com/pytorch/pytorch/issues/154820

The issue occurs when the same input is mutated in the forward AND in the backward.

AOTD emitted the copy_ after tracing the joint function, so a single fx node carried the side effects of both mutations (the one in the forward and the one in the backward).
The partitioner could then place that node in either the forward or the backward graph, leaving the input state wrong on the other side of the forward/backward boundary.
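
For illustration, a condensed sketch of the failing pattern (it mirrors the test added in this PR, with the setup shortened):

import torch

class AF(torch.autograd.Function):
    @staticmethod
    def forward(ctx, dummy, inplace_tensor):
        inplace_tensor.add_(1)  # mutation of the input in forward
        ctx.inplace_tensor = inplace_tensor
        return dummy.clone()

    @staticmethod
    def backward(ctx, grad_output):
        ctx.inplace_tensor.add_(1)  # mutation of the same input in backward
        return grad_output, None

def fn(dummy, inplace_tensor):
    return AF.apply(dummy, inplace_tensor)

dummy = torch.randn(2, requires_grad=True)
inplace_tensor = torch.zeros(2)
out = torch.compile(fn, backend="aot_eager", fullgraph=True)(dummy, inplace_tensor)
# inplace_tensor should be all ones here (only the forward mutation applied) ...
out.sum().backward()
# ... and all twos here, matching eager. With a single fused copy_ one of the
# two observations is wrong, depending on where the partitioner placed it.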

The fix:

1/ Introduce joint_function.handle, which lets us register a "post_forward" callback so we can inspect the state of the inputs right after the forward.

We do not want to apply a mutation after the joint trace if it was already applied in the forward. For that we need a "mutation_counter" and to memorize the counter value at which the forward mutation was applied.
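
A minimal sketch of that bookkeeping (the helper name here is made up; the real logic lives in create_functionalized_fn in the diff below). The counter is read via the new _functionalize_mutation_counter binding right after the forward and compared against the counter after the full joint trace:

import torch

def backward_also_mutated(f_inpt, mc_after_fw: int) -> bool:
    # f_inpt is a (non-subclass) functionalized input; mc_after_fw is the
    # counter recorded by the post_forward callback. If the counter moved
    # again during the backward, the input needs its own backward-side copy_.
    return torch._functionalize_mutation_counter(f_inpt.elem) > mc_after_fw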

2/ Expose mutation_counter to Python.

We want to keep the invariant that copy_ nodes appear only at the end of the joint graph.

3/ Memorize the mutation_counter and the state of the inputs after the forward, using the post_forward handle.
Emit the post-forward mutations only once the joint graph is fully traced.

Tag the post-forward mutations with "must_be_in_forward" (analogous to the existing "must_be_in_backward") so the partitioner keeps them in the forward graph.
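
Roughly, the tag rides on node.meta and the partitioner then refuses to put such a node into the backward subgraph (condensed from the partitioner changes further below):

import torch.fx as fx

def _must_be_in_forward(node: fx.Node) -> bool:
    return node.meta.get("partitioner_tag", None) == "must_be_in_forward"

# When extracting the backward subgraph, a forward-only node is dropped:
#     if _must_be_in_forward(node) and subgraph != "forward":
#         env[node] = InvalidNode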

4/ Ban recomputation of the source of the mutation. Recomputation could otherwise apply the same op (e.g. add) in both the forward and the backward.
For this, mark the source of the mutation in the forward as MUST_SAVE.
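
Condensed from force_save_bw_mutation_src in the partitioner diff below: when a primal gets a copy_ tagged for the backward and another tagged for the forward, the value the forward copies in is pinned with MUST_SAVE so it cannot be recomputed on the backward side:

import torch
from torch.utils.checkpoint import CheckpointPolicy

def _pin_fw_mutation_sources(joint_graph: torch.fx.Graph) -> None:
    has_mutation_in_bw = set()
    # Walk backwards so backward-side copy_ nodes are seen first.
    for node in reversed(joint_graph.nodes):
        if node.target == torch.ops.aten.copy_.default:
            if node.meta.get("partitioner_tag") == "must_be_in_backward":
                has_mutation_in_bw.add(node.args[0])
            elif (
                node.meta.get("partitioner_tag") == "must_be_in_forward"
                and node.args[0] in has_mutation_in_bw
            ):
                # args[1] is the mutation source copied into the primal in forward;
                # saving it bans recomputing that value in the backward.
                node.args[1].meta["recompute"] = CheckpointPolicy.MUST_SAVE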

proxy_tensor changes:

By default, proxy tensor updates the tensor_tracker, which would chain the applied mutations together.
But we want each of these copy_ nodes to be independent and applied directly to the primals.
For this, introduce a context manager that disables tensor_tracker updates while the forward mutations are being added.
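
Condensed from create_functionalized_fn in the diff below, this is how the forward-side mutation is emitted: tagged for the forward, and with the tensor tracker frozen so that a later copy_ for the backward mutation still targets the original primal rather than the output of this node:

with (
    torch.fx.traceback.preserve_node_meta(),
    set_partitioner_tag_must_be_in_forward(),
    _proxy_tensor_disable_update_tensor_tracker(),
):
    apply_in_graph_mutations(inpt_info, before, after, f_inpt, idx)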

Pull Request resolved: https://github.com/pytorch/pytorch/pull/155354
Approved by: https://github.com/bdhirsh
IvanKobzarev 2025-06-23 10:04:16 -07:00 committed by PyTorch MergeBot
parent c82a174cea
commit 3f920f3d8f
10 changed files with 396 additions and 106 deletions

View File

@ -122,6 +122,9 @@ struct TORCH_API FunctionalStorageImpl : public c10::StorageImpl {
  ~FunctionalStorageImpl() override = default;

  uint64_t mutation_counter() {
    return mutation_counter_;
  }
  void mark_mutation() {
    mutation_counter_++;
  }

View File

@ -74,7 +74,9 @@ struct TORCH_API FunctionalTensorWrapper : public c10::TensorImpl {
  bool has_metadata_mutation() const {
    return has_metadata_mutation_;
  }
  uint64_t mutation_counter() const {
    return functional_storage_impl()->mutation_counter();
  }
  void mark_mutation() {
    functional_storage_impl()->mark_mutation();
  }

View File

@ -14,7 +14,7 @@ add_loop_inductor_dynamic_gpu,compile_time_instruction_count,39110000000,0.025
add_loop_inductor_gpu,compile_time_instruction_count,26180000000,0.015
add_loop_inductor_gpu,compile_time_instruction_count,25780000000,0.015
@ -62,7 +62,7 @@ aotdispatcher_partitioner_cpu,compile_time_instruction_count,8844000000,0.015
aotdispatcher_partitioner_cpu2,compile_time_instruction_count,1963000000,0.015
aotdispatcher_partitioner_cpu2,compile_time_instruction_count,1917000000,0.015

add_loop_eager,compile_time_instruction_count,2937000000,0.015
aotdispatcher_inference_nosubclass_cpu,compile_time_instruction_count,2113000000,0.015
aotdispatcher_inference_subclass_cpu,compile_time_instruction_count,6022000000,0.015
aotdispatcher_partitioner_cpu,compile_time_instruction_count,8844000000,0.015
aotdispatcher_partitioner_cpu2,compile_time_instruction_count,1963000000,0.015  (old)
aotdispatcher_partitioner_cpu2,compile_time_instruction_count,1917000000,0.015  (new)
aotdispatcher_training_nosubclass_cpu,compile_time_instruction_count,3875000000,0.015
aotdispatcher_training_subclass_cpu,compile_time_instruction_count,10420000000,0.015

View File

@ -7842,6 +7842,53 @@ class TestAOTAutogradWithDynamo(TestAOTAutograd):
        self.assertEqual(ref_inps_after_fw, inps_after_fw)
        self.assertEqual(ref_inps_after_bw, inps_after_bw)

    def test_mutation_of_input_in_fw_and_bw(self):
        class AF(torch.autograd.Function):
            @staticmethod
            def forward(ctx, dummy, inplace_tensor):
                inplace_tensor.add_(1)
                ctx.inplace_tensor = inplace_tensor
                return dummy.clone()

            @staticmethod
            def backward(ctx, grad_output):
                inplace_tensor = ctx.inplace_tensor
                inplace_tensor.add_(1)
                return grad_output, None, None

        def fn(dummy, inplace_tensor):
            return AF.apply(dummy, inplace_tensor)

        def inps():
            dummy = torch.randn((2,), requires_grad=True)
            inplace_tensor = torch.zeros((2,), requires_grad=False)
            return dummy, inplace_tensor

        def sc_inps():
            dummy = TwoTensor(
                torch.randn((2,), requires_grad=True),
                torch.randn((2,), requires_grad=True),
            )
            inplace_tensor = TwoTensor(
                torch.zeros((2,), requires_grad=False),
                torch.zeros((2,), requires_grad=False),
            )
            return dummy, inplace_tensor

        for _inps in [inps, sc_inps]:
            dummy, inplace = _inps()
            y = fn(dummy, inplace)
            ref0 = inplace.clone().detach()
            y.sum().backward()
            ref = inplace.clone().detach()

            dummy, inplace = _inps()
            y = torch.compile(fn, backend="aot_eager", fullgraph=True)(dummy, inplace)
            self.assertEqual(ref0, inplace)
            y.sum().backward()
            self.assertEqual(ref, inplace)
class MockFXGraphCache:
"""

View File

@ -912,6 +912,13 @@ def gen_pyi(
"None",
)
],
"_functionalize_mutation_counter": [
defs(
"_functionalize_mutation_counter",
["t: Tensor"],
"_int",
)
],
"_functionalize_are_all_mutations_hidden_from_autograd": [
defs(
"_functionalize_are_all_mutations_hidden_from_autograd",

View File

@ -265,6 +265,7 @@ def aot_dispatch_autograd_graph(
        fw_metadata,
    )
    joint_fn_to_trace = create_joint(fn_prepared_for_autograd, aot_config=aot_config)
    joint_fn_handle = joint_fn_to_trace.handle

    joint_fn_to_trace, updated_joint_inputs = create_functionalized_fn(
        joint_fn_to_trace,
@ -272,6 +273,7 @@
        meta=fw_metadata,
        aot_config=aot_config,
        trace_joint=True,
        joint_fn_handle=joint_fn_handle,
    )

    # TODO: replace with AOTDispatchSubclassWrapper once we refactor

View File

@ -10,11 +10,11 @@ It does so by:
3. transforming mutations into extra outputs
4. dispatching subclasses
"""
import warnings
from contextlib import contextmanager, nullcontext
from dataclasses import dataclass
from functools import wraps
from typing import Any, Callable, Union
from typing import Any, Callable, Optional, Union
from unittest.mock import patch
import torch
@ -25,6 +25,7 @@ from torch._decomp.decompositions_for_rng import PhiloxStateTracker
from torch._guards import detect_fake_mode
from torch._prims_common import CUDARngStateHelper
from torch.fx.experimental.proxy_tensor import (
_proxy_tensor_disable_update_tensor_tracker,
maybe_disable_thunkify,
maybe_enable_thunkify,
)
@ -34,6 +35,7 @@ from torch.fx.experimental.symbolic_shapes import (
sym_eq,
)
from torch.nn.utils import stateless
from torch.utils._python_dispatch import is_traceable_wrapper_subclass
from .. import config
from .collect_metadata_analysis import run_functionalized_fw_and_collect_metadata
@ -182,6 +184,11 @@ def fn_prepped_for_autograd(
return inner_fn
@dataclass
class JointFnHandle:
post_forward: Optional[Callable] = None
# Given a fn, computes the joint.
# NOTE: fn is expected to have the following behavior:
# (1) fn() needs to return a tuple of (outs, mask),
@ -193,9 +200,15 @@ def fn_prepped_for_autograd(
# otherwise, when we compute autograd.grad(), we will not take those input mutations into account
# (the way this is handled is that we ensure any inputs that normally get mutated are cloned first)
def create_joint(fn: Callable, *, aot_config: AOTConfig) -> Any:
joint_fn_handle = JointFnHandle()
# post_forward
def inner_fn(primals: list[Any], tangents: list[Any]):
outs, tangent_mask = fn(*primals)
if joint_fn_handle and joint_fn_handle.post_forward:
joint_fn_handle.post_forward(primals)
assert len(tangent_mask) == len(outs)
outs_to_grad = [
o for needs_tangent, o in zip(tangent_mask, outs) if needs_tangent
@ -285,6 +298,8 @@ def create_joint(fn: Callable, *, aot_config: AOTConfig) -> Any:
with torch.autograd.detect_anomaly(check_nan=False):
return inner_fn(*args)
inner_fn_with_anomaly.handle = joint_fn_handle # type: ignore[attr-defined]
return inner_fn_with_anomaly
@ -379,6 +394,118 @@ def set_partitioner_tag_must_be_in_backward():
return set_partitioner_tag("must_be_in_backward")
def set_partitioner_tag_must_be_in_forward():
return set_partitioner_tag("must_be_in_forward")
def _get_mutation_counter(t) -> int:
if not is_traceable_wrapper_subclass(t):
return torch._functionalize_mutation_counter(t.elem) # type: ignore[attr-defined]
max_mc = -1
def visit(e):
if not is_traceable_wrapper_subclass(e):
mc = torch._functionalize_mutation_counter(e.elem) # type: ignore[attr-defined]
nonlocal max_mc
max_mc = max(mc, max_mc)
return
for a in e.__tensor_flatten__()[0]:
visit(getattr(e, a))
visit(t)
return max_mc
def apply_in_graph_mutations(input_info, inpt_old, inpt_new, f_inpt, input_idx):
assert input_info.mutation_type == MutationType.MUTATED_IN_GRAPH
# See Note [set_() Input Mutations in AOTAutograd]
# all mutations on the input must be under no_grad, so it is safe to put in the graph
# Here, we're saying that if an input experienced a set call, inp.set_(other),
# then we can effectively not have to worry about whether its data was mutated.
# There are 3 cases:
# (1) We mutate inp *after* the set_() call. other is a graph intermediate.
# In this case, we're not really mutating the input storage of "inp";
# we're mutating the storage of an intermediate value (other),
# and slamming that storage into the input tensor. So no data mutation is necessary.
# (2) We mutate inp *after* the set_() call. other is a graph *input*.
# In this case, the data mutation will be properly handled in the runtime
# epilogue during the processing of "other"
# (3) We mutate inp *before* the set_() call.
# This case is *not* currently handled.
if input_info.mutates_storage_metadata:
with torch.no_grad():
inpt_old.set_(inpt_new)
# Note [Ordering of resize_() and set_()]
# Importantly: the common usage in FSDP is that we have a dummy parameter
# that sees a set_() and **Then** a resize_().
# We must put those mutations into the graph in the same order,
# Since running them in the opposite order will have different behavior.
# We fully ban resize_() followed by set_() for now, although in principle
# we could support this
if input_info.mutation_inductor_storage_resize:
# resizing is not supported on subclasses (we error earlier if this happens)
from torch._subclasses.functional_tensor import FunctionalTensor
assert isinstance(f_inpt, FunctionalTensor)
old_storage_size = torch._functionalize_get_storage_size( # type: ignore[attr-defined]
f_inpt.elem, before=True
)
new_storage_size = torch._functionalize_get_storage_size( # type: ignore[attr-defined]
f_inpt.elem, before=False
)
if old_storage_size != new_storage_size:
assert (
old_storage_size == 0 or new_storage_size == 0
), f"""\
Encountered a storage resize during tracing on input {input_idx}. Old nbytes={old_storage_size}, new nbytes={new_storage_size}
We only support storage resizing on graph inputs as long as the input either starts or ends with a storage size of 0
(the case for FSDP)"""
torch.ops.inductor.resize_storage_bytes_(inpt_old, new_storage_size)
if new_storage_size == 0:
# Even if we marked the input as having a data mutation (thus needing a copy_()),
# We should **ignore** it if our input has no storage
# (this can happen if, e.g. we temporarily resize our input, copy data into it,
# and resize it back down to zero)
return
# Optimization: if the copy_() is a no-op then don't include it in the graph.
# In theory inductor could optimize this away, however in fsdp, we end up with
# param.copy_(param), where param is a zero-storage-size tensor,
# and running this op in eager mode (using the aot_eager backend) will result in a segfault.
# So we may as well optimize it away here.
if inpt_old is inpt_new:
# (This check needs to be done after putting resize_() in the graph,
# since a resize_(0) doesn't actually change the FunctionalTensor's inner tensor)
return
# We found an input that had a (data-only) mutation.
# Since keep_input_mutations is set, we need to faithfully apply a copy_()
# so the compiler will see the input mutation in the graph.
if input_info.mutates_data and input_info.mutations_hidden_from_autograd:
# Hidden from autograd = run under no_grad, **and** don't bump VC
# (although if the tensor was created in inference mode, it has no VC)
if inpt_old.is_inference():
maybe_preserve_vc = nullcontext()
else:
maybe_preserve_vc = torch.autograd._unsafe_preserve_version_counter(
inpt_old # type: ignore[assignment]
)
with torch.no_grad(), maybe_preserve_vc:
inpt_old.copy_(inpt_new)
elif (
input_info.mutates_data and input_info.mutations_under_no_grad_or_inference_mode
):
# Under no_grad = run under no_grad (we still bump the VC though)
# (inference_mode will also bump the VC, as long as the tensor in question
# was created outside of inference_mode)
with torch.no_grad():
inpt_old.copy_(inpt_new)
elif input_info.mutates_data:
inpt_old.copy_(inpt_new)
# This creates the final function that we want to trace using make_fx(),
# in both aot_dispatch_autograd and aot_dispatch_base.
# Preconditions:
@ -398,7 +525,16 @@ def create_functionalized_fn(
meta: ViewAndMutationMeta,
aot_config: AOTConfig,
trace_joint: bool,
joint_fn_handle: Optional[JointFnHandle] = None,
) -> Any:
primals_after_forward = None
f_args_after_forward = None
f_args_mutation_counter_after_forward: Optional[list[int]] = None
inputs_mutated_in_graph = [
info.mutation_type == MutationType.MUTATED_IN_GRAPH for info in meta.input_info
]
has_input_mutated_in_graph = any(inputs_mutated_in_graph)
@wraps(fn)
def _functionalized_f_helper(*args):
with maybe_enable_thunkify():
@ -415,6 +551,24 @@ def create_functionalized_fn(
# Wrap inputs into functional wrappers
f_args = pytree.tree_map(to_fun, args)
if trace_joint and has_input_mutated_in_graph and joint_fn_handle:
# TODO(ivankobzarev): Support fw and bw mutations for subclasses
def _post_forward(primals):
nonlocal primals_after_forward
primals_after_forward = pytree.tree_map(from_fun, primals)
nonlocal f_args_after_forward
f_args_after_forward = f_args[0]
nonlocal f_args_mutation_counter_after_forward
f_args_mutation_counter_after_forward = [
-1
if not inputs_mutated_in_graph[i]
else _get_mutation_counter(f_arg)
for i, f_arg in enumerate(f_args_after_forward)
]
joint_fn_handle.post_forward = _post_forward
# Run the joint
f_outs = fn(*f_args)
@ -535,110 +689,86 @@ def create_functionalized_fn(
# we will materialize an "updated" synthetic base, and copy it back to the synthetic input base.
# This allows us to factor aot autograd much more nicely, since only one area of the code needs to worry
# about synthetic bases.
for i, (inpt_old, inpt_f) in enumerate(
# Apply in-graph forward mutations only in the joint case.
# Note: mutations of primals in forward AND backward.
# If the same input is mutated in both forward and backward,
# we cannot fuse them into one copy_ node: the partitioner would then put that node
# either in the forward or in the backward, which leads to incorrect input state
# after the forward and before the backward.
# We have to emit two copy_ nodes, marking each one with additional meta
# saying whether it must be in the forward or in the backward.
# We memorize the mutation counter of the inputs after the forward.
# Based on it, after the joint graph is traced we check whether the backward also mutated the input.
# We emit copy_ only at the end of joint tracing, to preserve the invariant for joint
# graph passes that the graph is functional except for some number of copy_ nodes
# at the end.
inputs_mutated_in_graph_applied_mutation_counters: list[int] = [
0
] * len(meta.input_info)
if f_args_mutation_counter_after_forward is not None:
primals_before = args[0]
for idx, (f_inpt, before, after, inpt_info) in enumerate(
zip(
f_args_after_forward, # type: ignore[arg-type]
primals_before, # type: ignore[arg-type]
primals_after_forward, # type: ignore[arg-type]
meta.input_info,
)
):
if inpt_info.mutation_type != MutationType.MUTATED_IN_GRAPH:
continue
assert f_args_mutation_counter_after_forward
post_fw_mc = f_args_mutation_counter_after_forward[idx]
mc = _get_mutation_counter(f_inpt)
if mc > 0:
# Mutation in forward.
with (
torch.fx.traceback.preserve_node_meta(),
set_partitioner_tag_must_be_in_forward(),
_proxy_tensor_disable_update_tensor_tracker(),
):
apply_in_graph_mutations(
inpt_info,
before,
after,
f_inpt,
idx,
)
inputs_mutated_in_graph_applied_mutation_counters[
idx
] = post_fw_mc
for idx, (inpt_old, f_inpt) in enumerate(
zip(args, f_args) if not trace_joint else zip(args[0], f_args[0])
):
if not isinstance(inpt_f, torch.Tensor):
if not isinstance(f_inpt, torch.Tensor):
continue
assert is_fun(inpt_f)
inpt_new = from_fun(inpt_f)
assert is_fun(f_inpt)
inpt_new = from_fun(f_inpt)
if (
meta.input_info[i].mutation_type
== MutationType.MUTATED_IN_GRAPH
meta.input_info[idx].mutation_type
!= MutationType.MUTATED_IN_GRAPH
):
# See Note [set_() Input Mutations in AOTAutograd]
# all mutations on the input must be under no_grad, so it is safe to put in the graph
# Here, we're saying that if an input experienced a set call, inp.set_(other),
# then we can effectively not have to worry about whether its data was mutated.
# There are 3 cases:
# (1) We mutate inp *after* the set_() call. other is a graph intermediate.
# In this case, we're not really mutating the input storage of "inp";
# we're mutating the storage of an intermdiate value (other),
# and slamming that storage into the input tensor. So no data mutation is necessary.
# (2) We mutate inp *after* the set_() call. other is a graph *input*.
# In this case, the data mutation will be properly handled in the runtime
# epilogue during the processing of "other"
# (3) We mutate inp *before* the set_() call.
# This case is *not* currently handled.
if meta.input_info[i].mutates_storage_metadata:
with torch.no_grad():
inpt_old.set_(inpt_new)
# Note [Ordering of resize_() and set_()]
# Importantly: the common usage in FSDP is that we have a dummy parameter
# that sees a set_() and **Then** a resize_().
# We must put those mutations into the graph in the same order,
# Since running them in the opposite order will have different behavior.
# We fully ban resize_() followed by set_() for now, although in principal
# we could support this
if meta.input_info[i].mutation_inductor_storage_resize:
# resizing is not supported on subclasses (we error earlier if this happens)
from torch._subclasses.functional_tensor import (
FunctionalTensor,
)
assert isinstance(inpt_f, FunctionalTensor)
old_storage_size = torch._functionalize_get_storage_size( # type: ignore[attr-defined]
inpt_f.elem, before=True
)
new_storage_size = torch._functionalize_get_storage_size( # type: ignore[attr-defined]
inpt_f.elem, before=False
)
if old_storage_size != new_storage_size:
assert (
old_storage_size == 0 or new_storage_size == 0
), f"""\
Encountered a storage resize during tracing on input {i}. Old nbytes={old_storage_size}, new nbytes={new_storage_size}
We only support storage resizing on graph inputs as long as the input either starts or ends with a storage size of 0
(the case for FSDP)"""
torch.ops.inductor.resize_storage_bytes_(
inpt_old, new_storage_size
)
if new_storage_size == 0:
# Even if we marked the input as having a data mutation (thus needing a copy_()),
# We should **ignore** it if our input has no storage
# (this can happen if, e.g. we temporarily resize our input, copy data into it,
# and resize it back down to zero)
continue
# Optimization: if the copy_() is a no-op then don't include it in the graph.
# In theory inductor could optimize this away, however in fsdp, we end up with
# param.copy_(param), where param is a zero-storage-size tensor,
# and running this op in eager mode (using the aot_eager backend) will result in a segfault.
# So we may as well optimize it away here.
if inpt_old is inpt_new:
# (This check needs to be done after putting resize_() in the graph,
# since a resize_(0) doesn't actually change the FunctionalTensor's inner tensor)
continue
if f_args_mutation_counter_after_forward is not None:
# This could happen for subclasses tracing
# Subclasses support for mutations in fw and bw is TBD.
mc = _get_mutation_counter(f_inpt)
if mc == inputs_mutated_in_graph_applied_mutation_counters[idx]:
# No mutation in backward; mutation was already applied.
continue
# We found an input that had a (data-only) mutation.
# Since keep_input_mutations is set, we need to faithfully apply a copy_()
# so the compiler will see the input mutation in the graph.
if (
meta.input_info[i].mutates_data
and meta.input_info[i].mutations_hidden_from_autograd
):
# Hidden from autograd = run under no_grad, **and** don't bump VC
# (although if the tensor was created in inference mode, it has no VC)
if inpt_old.is_inference():
maybe_preserve_vc = nullcontext()
else:
maybe_preserve_vc = torch.autograd._unsafe_preserve_version_counter(
inpt_old # type: ignore[assignment]
)
with torch.no_grad(), maybe_preserve_vc:
inpt_old.copy_(inpt_new)
elif (
meta.input_info[i].mutates_data
and meta.input_info[
i
].mutations_under_no_grad_or_inference_mode
):
# Under no_grad = run under no_grad (we still bump the VC though)
# (inference_mode will also bump the VC, as long as the tensor in question
# was created outside of inference_mode)
with torch.no_grad():
inpt_old.copy_(inpt_new)
elif meta.input_info[i].mutates_data:
inpt_old.copy_(inpt_new)
with torch.fx.traceback.preserve_node_meta(), set_partitioner_tag_must_be_in_backward():
apply_in_graph_mutations(
meta.input_info[idx],
inpt_old,
inpt_new,
f_inpt,
idx,
)
# When an output tensor is a functionalized mutated input, and we
# were able to move the mutation in to the graph then we can return

View File

@ -201,6 +201,10 @@ def _extract_graph_with_inputs_outputs(
env[node] = InvalidNode # type: ignore[assignment]
continue
if _must_be_in_forward(node) and subgraph != "forward":
env[node] = InvalidNode # type: ignore[assignment]
continue
if node in env:
# Node must be one of our inputs. (Any member of env which wasn't an
# input to start must have been created by this loop and won't be in
@ -274,10 +278,18 @@ def _has_tag_is_backward(node: fx.Node) -> bool:
return node.meta.get("partitioner_tag", None) == "is_backward"
def _has_tag_must_be_in_forward(node: fx.Node) -> bool:
return node.meta.get("partitioner_tag", None) == "must_be_in_forward"
def _has_tag_must_be_in_backward(node: fx.Node) -> bool:
return node.meta.get("partitioner_tag", None) == "must_be_in_backward"
def _must_be_in_forward(node: fx.Node) -> bool:
return _has_tag_must_be_in_forward(node)
def _must_be_in_backward(node: fx.Node) -> bool:
return _has_tag_must_be_in_backward(node) or (
_has_tag_is_backward(node) and is_with_effects(node)
@ -1465,6 +1477,25 @@ def force_save_collectives(joint_module: fx.GraphModule) -> None:
node.meta["recompute"] = CheckpointPolicy.MUST_SAVE
def force_save_bw_mutation_src(joint_module: fx.GraphModule) -> None:
# If we have mutations of the same primal in forward and backward,
# we must not recompute the source of the mutation, or it would be applied twice.
has_mutation_in_bw: OrderedSet[torch.fx.Node] = OrderedSet()
for node in reversed(joint_module.graph.nodes):
if (
node.target == torch.ops.aten.copy_.default
and _has_tag_must_be_in_backward(node)
):
has_mutation_in_bw.add(node.args[0])
if (
node.target == torch.ops.aten.copy_.default
and _has_tag_must_be_in_forward(node)
and node.args[0] in has_mutation_in_bw
):
node.args[1].meta["recompute"] = CheckpointPolicy.MUST_SAVE
def cleanup_recompute_tags(joint_module: fx.GraphModule) -> fx.GraphModule:
"""
If there are two consecutive checkpointed blocks with no operator in
@ -2537,6 +2568,7 @@ def min_cut_rematerialization_partition(
joint_module = cleanup_recompute_tags(joint_module)
if not config.unsafe_allow_optimization_of_collectives:
force_save_collectives(joint_module)
force_save_bw_mutation_src(joint_module)
def classify_nodes(joint_module, static_lifetime_input_indices):
name_to_node = get_name_to_node(joint_module.graph)

View File

@ -698,6 +698,11 @@ void initTorchFunctions(PyObject* module) {
    auto t_impl = at::functionalization::impl::unsafeGetFunctionalWrapper(t);
    return t_impl->has_data_mutation();
  });
  py_module.def("_functionalize_mutation_counter", [](const at::Tensor& t) {
    TORCH_INTERNAL_ASSERT(at::functionalization::impl::isFunctionalTensor(t));
    auto t_impl = at::functionalization::impl::unsafeGetFunctionalWrapper(t);
    return t_impl->mutation_counter();
  });
  py_module.def(
      "_functionalize_get_storage_size", [](const at::Tensor& t, bool before) {
        TORCH_INTERNAL_ASSERT(

View File

@ -11,6 +11,7 @@ import functools
import inspect
import logging
import operator
import threading
import traceback
import typing
import typing_extensions
@ -181,7 +182,7 @@ def is_sym_node(node: _HasMeta) -> bool:
return "val" in node.meta and isinstance(node.meta["val"], py_sym_types)
@overload
@overload # type: ignore[no-overload-impl]
def set_proxy_slot(obj: Tensor, tracer: _ProxyTracer, proxy: _ProxyTensor) -> None: ...
@ -197,7 +198,66 @@ def set_proxy_slot(
) -> None: ...
def set_proxy_slot(
class _DisableUpdateTensorTracker(threading.local):
value: bool = False
_disable_update_tensor_tracker_tls = _DisableUpdateTensorTracker()
def _is_proxy_tensor_update_tensor_tracker_disabled() -> bool:
"""
Returns current state of disabling update tensor tracker.
"""
return _disable_update_tensor_tracker_tls.value
@contextmanager
def _proxy_tensor_disable_update_tensor_tracker() -> Generator[None, None, None]:
"""
NOTE [Do not clobber inplace ops]
By default the tensor_tracker is updated on every operation.
This chains every operation through the tracked FakeTensor.
For example, if we have several consecutive mutable operations:
def f(x, y, z):
x.copy_(y)
x.copy_(z)
return x
Default graph result:
def f_graph(x, y, z):
x_1 = x.copy_(y)
x_2 = x_1.copy_(z)
return x_2
This chaining simplifies fx passes and helps to prevent reordering.
But in some cases we want those nodes to be disconnected,
e.g. when splitting the joint graph into forward and backward.
If the first inplace op happened in the forward and the second in the backward,
we want them to be placed on the correct side after the split.
Enabling this context manager for copy_ will result in:
def f_graph_2(x, y, z):
x_1 = x.copy_(y)
x_2 = x.copy_(z)
return x
The results of copy_ (x_1 and x_2) will have no users in the graph.
The reason this behavior is not enabled for all inplace ops is that
some fx passes (e.g. fx quantization) rely on chaining inplace ops like add_
in their fusion passes.
We could revisit enabling this logic for all inplace ops in the future.
"""
orig_value = _disable_update_tensor_tracker_tls.value
_disable_update_tensor_tracker_tls.value = True
try:
yield
finally:
_disable_update_tensor_tracker_tls.value = orig_value
def set_proxy_slot( # type: ignore[no-redef]
obj: Union[PySymType, _AnyScriptObjectType, Tensor],
tracer: _ProxyTracer,
proxy: object,
@ -207,7 +267,9 @@ def set_proxy_slot(
# We DO want to clobber proxies whenever we run an inplace operation
# on a tensor, and it affects the metadata on the proxy.
assert isinstance(proxy, _ProxyTensor)
tracer.tensor_tracker[obj] = proxy
# see NOTE [Do not clobber inplace ops]
if not _is_proxy_tensor_update_tensor_tracker_disabled():
tracer.tensor_tracker[obj] = proxy
elif isinstance(obj, (_AnyScriptObject)):
# We DO want to clobber proxies, with a similar rationale as for tensors.
assert isinstance(proxy, Proxy)