mirror of
https://github.com/zebrajr/pytorch.git
synced 2025-12-06 12:20:52 +01:00
[aotd] Support mutations of the same input in fw and bw (#155354)
Original issue: https://github.com/pytorch/pytorch/issues/154820

The issue happens when the same input is mutated in forward AND in backward. AOTD emitted the copy_ after joint_function tracing, so a single fx node carried the side effects of both mutations (the forward one and the backward one). The partitioner could then place that node either in forward or in backward, but not in both, leading to incorrect input state at the forward/backward boundary.

The fix:
1/ Introduce joint_function.handle, which allows setting a "post_forward" callback so we can inspect the state of the inputs right after the forward. We do not want to apply a mutation after the joint if we already applied it in the forward; for that we need a "mutation_counter" and we memorize the counter value of the mutation that was applied in the forward.
2/ Expose mutation_counter to Python. We want to keep the invariant that copy_ nodes exist only at the end of the joint graph.
3/ Memorize the mutation_counter and the state of the inputs after the forward via the post_forward handle, and emit the post-forward mutations only once the joint graph is fully traced. Tag these mutations "must_be_in_forward" (similar to the existing "must_be_in_backward") so the partitioner keeps them in the forward graph.
4/ Ban recompute of the source of the mutation. Recompute could otherwise apply the same op (e.g. add) in both forward and backward; to prevent this, mark the source of the forward mutation as MUST_SAVE.

proxy_tensor changes: by default proxy tensor updates the tensor_tracker, which would chain the applied mutations. We want the forward copy_ to be independent and applied just to the primals, so a context manager is introduced to disable tensor_tracker updates while the forward mutations are added.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/155354
Approved by: https://github.com/bdhirsh
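For illustration, a minimal sketch of the failing pattern, adapted from the test added in this PR (names shortened and the gradient arity adjusted for a standalone script; not copied verbatim): an autograd.Function mutates the same non-differentiable input once in forward and once in backward. Before this change, the single copy_ emitted for that input could land in only one of the partitioned graphs, so under torch.compile the input's value disagreed with eager at the forward/backward boundary.

    import torch

    class MutateInFwAndBw(torch.autograd.Function):
        @staticmethod
        def forward(ctx, dummy, inplace_tensor):
            inplace_tensor.add_(1)  # mutation in forward
            ctx.inplace_tensor = inplace_tensor
            return dummy.clone()

        @staticmethod
        def backward(ctx, grad_output):
            ctx.inplace_tensor.add_(1)  # mutation of the same input in backward
            return grad_output, None

    def fn(dummy, inplace_tensor):
        return MutateInFwAndBw.apply(dummy, inplace_tensor)

    dummy = torch.randn(2, requires_grad=True)
    buf = torch.zeros(2)
    out = torch.compile(fn, backend="aot_eager", fullgraph=True)(dummy, buf)
    # buf should already reflect the forward mutation here (all ones) ...
    out.sum().backward()
    # ... and both mutations after backward (all twos), matching eager.

With the fix, the joint graph ends with two independent copy_ nodes for buf, one tagged must_be_in_forward and one must_be_in_backward, so the partitioner keeps each on its own side of the split.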
This commit is contained in:
parent c82a174cea
commit 3f920f3d8f
@@ -122,6 +122,9 @@ struct TORCH_API FunctionalStorageImpl : public c10::StorageImpl {

   ~FunctionalStorageImpl() override = default;

+  uint64_t mutation_counter() {
+    return mutation_counter_;
+  }
   void mark_mutation() {
     mutation_counter_++;
   }
@@ -74,7 +74,9 @@ struct TORCH_API FunctionalTensorWrapper : public c10::TensorImpl {
   bool has_metadata_mutation() const {
     return has_metadata_mutation_;
   }
+  uint64_t mutation_counter() const {
+    return functional_storage_impl()->mutation_counter();
+  }
   void mark_mutation() {
     functional_storage_impl()->mark_mutation();
   }
@@ -14,7 +14,7 @@ add_loop_inductor_dynamic_gpu,compile_time_instruction_count,39110000000,0.025
-add_loop_inductor_gpu,compile_time_instruction_count,26180000000,0.015
+add_loop_inductor_gpu,compile_time_instruction_count,25780000000,0.015
@@ -62,7 +62,7 @@ aotdispatcher_partitioner_cpu,compile_time_instruction_count,8844000000,0.015
-aotdispatcher_partitioner_cpu2,compile_time_instruction_count,1963000000,0.015
+aotdispatcher_partitioner_cpu2,compile_time_instruction_count,1917000000,0.015
@@ -7842,6 +7842,53 @@ class TestAOTAutogradWithDynamo(TestAOTAutograd):
         self.assertEqual(ref_inps_after_fw, inps_after_fw)
         self.assertEqual(ref_inps_after_bw, inps_after_bw)

+    def test_mutation_of_input_in_fw_and_bw(self):
+        class AF(torch.autograd.Function):
+            @staticmethod
+            def forward(ctx, dummy, inplace_tensor):
+                inplace_tensor.add_(1)
+
+                ctx.inplace_tensor = inplace_tensor
+                return dummy.clone()
+
+            @staticmethod
+            def backward(ctx, grad_output):
+                inplace_tensor = ctx.inplace_tensor
+                inplace_tensor.add_(1)
+                return grad_output, None, None
+
+        def fn(dummy, inplace_tensor):
+            return AF.apply(dummy, inplace_tensor)
+
+        def inps():
+            dummy = torch.randn((2,), requires_grad=True)
+            inplace_tensor = torch.zeros((2,), requires_grad=False)
+            return dummy, inplace_tensor
+
+        def sc_inps():
+            dummy = TwoTensor(
+                torch.randn((2,), requires_grad=True),
+                torch.randn((2,), requires_grad=True),
+            )
+            inplace_tensor = TwoTensor(
+                torch.zeros((2,), requires_grad=False),
+                torch.zeros((2,), requires_grad=False),
+            )
+            return dummy, inplace_tensor
+
+        for _inps in [inps, sc_inps]:
+            dummy, inplace = _inps()
+            y = fn(dummy, inplace)
+            ref0 = inplace.clone().detach()
+            y.sum().backward()
+            ref = inplace.clone().detach()
+
+            dummy, inplace = _inps()
+            y = torch.compile(fn, backend="aot_eager", fullgraph=True)(dummy, inplace)
+            self.assertEqual(ref0, inplace)
+            y.sum().backward()
+            self.assertEqual(ref, inplace)
+
+
 class MockFXGraphCache:
     """
@@ -912,6 +912,13 @@ def gen_pyi(
                 "None",
             )
         ],
+        "_functionalize_mutation_counter": [
+            defs(
+                "_functionalize_mutation_counter",
+                ["t: Tensor"],
+                "_int",
+            )
+        ],
         "_functionalize_are_all_mutations_hidden_from_autograd": [
             defs(
                 "_functionalize_are_all_mutations_hidden_from_autograd",
@@ -265,6 +265,7 @@ def aot_dispatch_autograd_graph(
        fw_metadata,
    )
    joint_fn_to_trace = create_joint(fn_prepared_for_autograd, aot_config=aot_config)
+   joint_fn_handle = joint_fn_to_trace.handle

    joint_fn_to_trace, updated_joint_inputs = create_functionalized_fn(
        joint_fn_to_trace,
@@ -272,6 +273,7 @@ def aot_dispatch_autograd_graph(
        meta=fw_metadata,
        aot_config=aot_config,
        trace_joint=True,
+       joint_fn_handle=joint_fn_handle,
    )

    # TODO: replace with AOTDispatchSubclassWrapper once we refactor
@@ -10,11 +10,11 @@ It does so by:
 3. transforming mutations into extra outputs
 4. dispatching subclasses
 """

 import warnings
 from contextlib import contextmanager, nullcontext
+from dataclasses import dataclass
 from functools import wraps
-from typing import Any, Callable, Union
+from typing import Any, Callable, Optional, Union
 from unittest.mock import patch

 import torch
@@ -25,6 +25,7 @@ from torch._decomp.decompositions_for_rng import PhiloxStateTracker
 from torch._guards import detect_fake_mode
 from torch._prims_common import CUDARngStateHelper
 from torch.fx.experimental.proxy_tensor import (
+    _proxy_tensor_disable_update_tensor_tracker,
     maybe_disable_thunkify,
     maybe_enable_thunkify,
 )
@@ -34,6 +35,7 @@ from torch.fx.experimental.symbolic_shapes import (
     sym_eq,
 )
 from torch.nn.utils import stateless
+from torch.utils._python_dispatch import is_traceable_wrapper_subclass

 from .. import config
 from .collect_metadata_analysis import run_functionalized_fw_and_collect_metadata
@@ -182,6 +184,11 @@ def fn_prepped_for_autograd(
     return inner_fn


+@dataclass
+class JointFnHandle:
+    post_forward: Optional[Callable] = None
+
+
 # Given a fn, computes the joint.
 # NOTE: fn is expects the following behavior:
 # (1) fn() needs to return a tuple of (outs, mask),
@@ -193,9 +200,15 @@ def fn_prepped_for_autograd(
 # otherwise, when we compute autograd.grad(), we will not take those input mutations into account
 # (the way this is handled is that we ensure any inputs that normally get mutated are cloned first)
 def create_joint(fn: Callable, *, aot_config: AOTConfig) -> Any:
+    joint_fn_handle = JointFnHandle()
+
+    # post_forward
     def inner_fn(primals: list[Any], tangents: list[Any]):
         outs, tangent_mask = fn(*primals)

+        if joint_fn_handle and joint_fn_handle.post_forward:
+            joint_fn_handle.post_forward(primals)
+
         assert len(tangent_mask) == len(outs)
         outs_to_grad = [
             o for needs_tangent, o in zip(tangent_mask, outs) if needs_tangent
@@ -285,6 +298,8 @@ def create_joint(fn: Callable, *, aot_config: AOTConfig) -> Any:
             with torch.autograd.detect_anomaly(check_nan=False):
                 return inner_fn(*args)

+    inner_fn_with_anomaly.handle = joint_fn_handle  # type: ignore[attr-defined]
+
     return inner_fn_with_anomaly
@@ -379,6 +394,118 @@ def set_partitioner_tag_must_be_in_backward():
     return set_partitioner_tag("must_be_in_backward")


+def set_partitioner_tag_must_be_in_forward():
+    return set_partitioner_tag("must_be_in_forward")
+
+
+def _get_mutation_counter(t) -> int:
+    if not is_traceable_wrapper_subclass(t):
+        return torch._functionalize_mutation_counter(t.elem)  # type: ignore[attr-defined]
+
+    max_mc = -1
+
+    def visit(e):
+        if not is_traceable_wrapper_subclass(e):
+            mc = torch._functionalize_mutation_counter(e.elem)  # type: ignore[attr-defined]
+            nonlocal max_mc
+            max_mc = max(mc, max_mc)
+            return
+
+        for a in e.__tensor_flatten__()[0]:
+            visit(getattr(e, a))
+
+    visit(t)
+    return max_mc
+
+
+def apply_in_graph_mutations(input_info, inpt_old, inpt_new, f_inpt, input_idx):
+    assert input_info.mutation_type == MutationType.MUTATED_IN_GRAPH
+    # See Note [set_() Input Mutations in AOTAutograd]
+    # all mutations on the input must be under no_grad, so it is safe to put in the graph
+    # Here, we're saying that if an input experienced a set call, inp.set_(other),
+    # then we can effectively not have to worry about whether its data was mutated.
+    # There are 3 cases:
+    # (1) We mutate inp *after* the set_() call. other is a graph intermediate.
+    #     In this case, we're not really mutating the input storage of "inp";
+    #     we're mutating the storage of an intermdiate value (other),
+    #     and slamming that storage into the input tensor. So no data mutation is necessary.
+    # (2) We mutate inp *after* the set_() call. other is a graph *input*.
+    #     In this case, the data mutation will be properly handled in the runtime
+    #     epilogue during the processing of "other"
+    # (3) We mutate inp *before* the set_() call.
+    #     This case is *not* currently handled.
+    if input_info.mutates_storage_metadata:
+        with torch.no_grad():
+            inpt_old.set_(inpt_new)
+
+    # Note [Ordering of resize_() and set_()]
+    # Importantly: the common usage in FSDP is that we have a dummy parameter
+    # that sees a set_() and **Then** a resize_().
+    # We must put those mutations into the graph in the same order,
+    # Since running them in the opposite order will have different behavior.
+    # We fully ban resize_() followed by set_() for now, although in principal
+    # we could support this
+    if input_info.mutation_inductor_storage_resize:
+        # resizing is not supported on subclasses (we error earlier if this happens)
+        from torch._subclasses.functional_tensor import FunctionalTensor
+
+        assert isinstance(f_inpt, FunctionalTensor)
+        old_storage_size = torch._functionalize_get_storage_size(  # type: ignore[attr-defined]
+            f_inpt.elem, before=True
+        )
+        new_storage_size = torch._functionalize_get_storage_size(  # type: ignore[attr-defined]
+            f_inpt.elem, before=False
+        )
+        if old_storage_size != new_storage_size:
+            assert (
+                old_storage_size == 0 or new_storage_size == 0
+            ), f"""\
+Encountered a storage resize during tracing on input {input_idx}. Old nbytes={old_storage_size}, new nbytes={new_storage_size}
+We only support storage resizing on graph inputs as long as the input either starts or ends with a storage size of 0
+(the case for FSDP)"""
+            torch.ops.inductor.resize_storage_bytes_(inpt_old, new_storage_size)
+        if new_storage_size == 0:
+            # Even if we marked the input as having a data mutation (thus needing a copy_()),
+            # We should **ignore** it if our input has no storage
+            # (this can happen if, e.g. we temporarily resize our input, copy data into it,
+            # and resize it back down to zero)
+            return
+    # Optimization: if the copy_() is a no-op then don't include it in the graph.
+    # In theory inductor could optimize this away, however in fsdp, we end up with
+    # param.copy_(param), where param is a zero-storage-size tensor,
+    # and running this op in eager mode (using the aot_eager backend) will result in a segfault.
+    # So we may as well optimize it away here.
+    if inpt_old is inpt_new:
+        # (This check needs to be done after putting resize_() in the graph,
+        # since a resize_(0) doesn't actually change the FunctionalTensor's inner tensor)
+        return
+    # We found an input that had a (data-only) mutation.
+    # Since keep_input_mutations is set, we need to faithfully apply a copy_()
+    # so the compiler will see the input mutation in the graph.
+    if input_info.mutates_data and input_info.mutations_hidden_from_autograd:
+        # Hidden from autograd = run under no_grad, **and** don't bump VC
+        # (although if the tensor was created in inference mode, it has no VC)
+        if inpt_old.is_inference():
+            maybe_preserve_vc = nullcontext()
+        else:
+            maybe_preserve_vc = torch.autograd._unsafe_preserve_version_counter(
+                inpt_old  # type: ignore[assignment]
+            )
+        with torch.no_grad(), maybe_preserve_vc:
+            inpt_old.copy_(inpt_new)
+    elif (
+        input_info.mutates_data and input_info.mutations_under_no_grad_or_inference_mode
+    ):
+        # Under no_grad = run under no_grad (we still bump the VC though)
+        # (inference_mode will also bump the VC, as long as the tensor in question
+        # was created outside of inference_mode)
+
+        with torch.no_grad():
+            inpt_old.copy_(inpt_new)
+    elif input_info.mutates_data:
+        inpt_old.copy_(inpt_new)
+
+
 # This creates the final function that we want to trace using make_fx(),
 # in both aot_dispatch_autograd and aot_dispatch_base.
 # Preconditions:
@@ -398,7 +525,16 @@ def create_functionalized_fn(
     meta: ViewAndMutationMeta,
     aot_config: AOTConfig,
     trace_joint: bool,
+    joint_fn_handle: Optional[JointFnHandle] = None,
 ) -> Any:
+    primals_after_forward = None
+    f_args_after_forward = None
+    f_args_mutation_counter_after_forward: Optional[list[int]] = None
+    inputs_mutated_in_graph = [
+        info.mutation_type == MutationType.MUTATED_IN_GRAPH for info in meta.input_info
+    ]
+    has_input_mutated_in_graph = any(inputs_mutated_in_graph)
+
     @wraps(fn)
     def _functionalized_f_helper(*args):
         with maybe_enable_thunkify():
@@ -415,6 +551,24 @@
             # Wrap inputs into functional wrappers
             f_args = pytree.tree_map(to_fun, args)

+            if trace_joint and has_input_mutated_in_graph and joint_fn_handle:
+                # TODO(ivankobzarev): Support fw and bw mutations for subclasses
+                def _post_forward(primals):
+                    nonlocal primals_after_forward
+                    primals_after_forward = pytree.tree_map(from_fun, primals)
+                    nonlocal f_args_after_forward
+                    f_args_after_forward = f_args[0]
+                    nonlocal f_args_mutation_counter_after_forward
+
+                    f_args_mutation_counter_after_forward = [
+                        -1
+                        if not inputs_mutated_in_graph[i]
+                        else _get_mutation_counter(f_arg)
+                        for i, f_arg in enumerate(f_args_after_forward)
+                    ]
+
+                joint_fn_handle.post_forward = _post_forward
+
             # Run the joint
             f_outs = fn(*f_args)
@@ -535,110 +689,86 @@
            # we will materialize an "updated" synthetic base, and copy it back to the synthetic input base.
            # This allows us to factor aot autograd much more nicely, since only one area of the code needs to worry
            # about synthetic bases.
-           for i, (inpt_old, inpt_f) in enumerate(
+           # Apply in graph forward mutations only in joint case.
+           # Note: Mutations of primals in forward AND backward.
+           # If we have mutations of the same input in forward and in backward,
+           # we can not fuse them into one copy_ node. As in this case partitioner will put it
+           # either in forward or in backward. This will lead to incorrect state
+           # after forward and before backward.
+           # We have to emit two copy_ nodes, marking with additional meta each node,
+           # if it must be in forward or backward.
+           # We memorize mutation counter of the inputs after forward.
+           # Based on this after joint graph we check if backward also mutated input or not.
+           # We emit copy_ only in the end of joint tracing, to provide invariant for joint
+           # graph passes, that our graph is functional, except only some number of copy_ nodes
+           # in the end.
+           inputs_mutated_in_graph_applied_mutation_counters: list[int] = [
+               0
+           ] * len(meta.input_info)
+           if f_args_mutation_counter_after_forward is not None:
+               primals_before = args[0]
+               for idx, (f_inpt, before, after, inpt_info) in enumerate(
+                   zip(
+                       f_args_after_forward,  # type: ignore[arg-type]
+                       primals_before,  # type: ignore[arg-type]
+                       primals_after_forward,  # type: ignore[arg-type]
+                       meta.input_info,
+                   )
+               ):
+                   if inpt_info.mutation_type != MutationType.MUTATED_IN_GRAPH:
+                       continue
+
+                   assert f_args_mutation_counter_after_forward
+                   post_fw_mc = f_args_mutation_counter_after_forward[idx]
+                   mc = _get_mutation_counter(f_inpt)
+
+                   if mc > 0:
+                       # Mutation in forward.
+                       with (
+                           torch.fx.traceback.preserve_node_meta(),
+                           set_partitioner_tag_must_be_in_forward(),
+                           _proxy_tensor_disable_update_tensor_tracker(),
+                       ):
+                           apply_in_graph_mutations(
+                               inpt_info,
+                               before,
+                               after,
+                               f_inpt,
+                               idx,
+                           )
+                           inputs_mutated_in_graph_applied_mutation_counters[
+                               idx
+                           ] = post_fw_mc
+
+           for idx, (inpt_old, f_inpt) in enumerate(
                zip(args, f_args) if not trace_joint else zip(args[0], f_args[0])
            ):
-               if not isinstance(inpt_f, torch.Tensor):
+               if not isinstance(f_inpt, torch.Tensor):
                    continue
-               assert is_fun(inpt_f)
-               inpt_new = from_fun(inpt_f)
+               assert is_fun(f_inpt)
+               inpt_new = from_fun(f_inpt)
                if (
-                   meta.input_info[i].mutation_type
-                   == MutationType.MUTATED_IN_GRAPH
+                   meta.input_info[idx].mutation_type
+                   != MutationType.MUTATED_IN_GRAPH
                ):
-                   # See Note [set_() Input Mutations in AOTAutograd]
-                   # all mutations on the input must be under no_grad, so it is safe to put in the graph
-                   # Here, we're saying that if an input experienced a set call, inp.set_(other),
-                   # then we can effectively not have to worry about whether its data was mutated.
-                   # There are 3 cases:
-                   # (1) We mutate inp *after* the set_() call. other is a graph intermediate.
-                   #     In this case, we're not really mutating the input storage of "inp";
-                   #     we're mutating the storage of an intermdiate value (other),
-                   #     and slamming that storage into the input tensor. So no data mutation is necessary.
-                   # (2) We mutate inp *after* the set_() call. other is a graph *input*.
-                   #     In this case, the data mutation will be properly handled in the runtime
-                   #     epilogue during the processing of "other"
-                   # (3) We mutate inp *before* the set_() call.
-                   #     This case is *not* currently handled.
-                   if meta.input_info[i].mutates_storage_metadata:
-                       with torch.no_grad():
-                           inpt_old.set_(inpt_new)
-
-                   # Note [Ordering of resize_() and set_()]
-                   # Importantly: the common usage in FSDP is that we have a dummy parameter
-                   # that sees a set_() and **Then** a resize_().
-                   # We must put those mutations into the graph in the same order,
-                   # Since running them in the opposite order will have different behavior.
-                   # We fully ban resize_() followed by set_() for now, although in principal
-                   # we could support this
-                   if meta.input_info[i].mutation_inductor_storage_resize:
-                       # resizing is not supported on subclasses (we error earlier if this happens)
-                       from torch._subclasses.functional_tensor import (
-                           FunctionalTensor,
-                       )
-
-                       assert isinstance(inpt_f, FunctionalTensor)
-                       old_storage_size = torch._functionalize_get_storage_size(  # type: ignore[attr-defined]
-                           inpt_f.elem, before=True
-                       )
-                       new_storage_size = torch._functionalize_get_storage_size(  # type: ignore[attr-defined]
-                           inpt_f.elem, before=False
-                       )
-                       if old_storage_size != new_storage_size:
-                           assert (
-                               old_storage_size == 0 or new_storage_size == 0
-                           ), f"""\
-Encountered a storage resize during tracing on input {i}. Old nbytes={old_storage_size}, new nbytes={new_storage_size}
-We only support storage resizing on graph inputs as long as the input either starts or ends with a storage size of 0
-(the case for FSDP)"""
-                           torch.ops.inductor.resize_storage_bytes_(
-                               inpt_old, new_storage_size
-                           )
-                       if new_storage_size == 0:
-                           # Even if we marked the input as having a data mutation (thus needing a copy_()),
-                           # We should **ignore** it if our input has no storage
-                           # (this can happen if, e.g. we temporarily resize our input, copy data into it,
-                           # and resize it back down to zero)
-                           continue
-                   # Optimization: if the copy_() is a no-op then don't include it in the graph.
-                   # In theory inductor could optimize this away, however in fsdp, we end up with
-                   # param.copy_(param), where param is a zero-storage-size tensor,
-                   # and running this op in eager mode (using the aot_eager backend) will result in a segfault.
-                   # So we may as well optimize it away here.
-                   if inpt_old is inpt_new:
-                       # (This check needs to be done after putting resize_() in the graph,
-                       # since a resize_(0) doesn't actually change the FunctionalTensor's inner tensor)
-                       continue
-                   # We found an input that had a (data-only) mutation.
-                   # Since keep_input_mutations is set, we need to faithfully apply a copy_()
-                   # so the compiler will see the input mutation in the graph.
-                   if (
-                       meta.input_info[i].mutates_data
-                       and meta.input_info[i].mutations_hidden_from_autograd
-                   ):
-                       # Hidden from autograd = run under no_grad, **and** don't bump VC
-                       # (although if the tensor was created in inference mode, it has no VC)
-                       if inpt_old.is_inference():
-                           maybe_preserve_vc = nullcontext()
-                       else:
-                           maybe_preserve_vc = torch.autograd._unsafe_preserve_version_counter(
-                               inpt_old  # type: ignore[assignment]
-                           )
-                       with torch.no_grad(), maybe_preserve_vc:
-                           inpt_old.copy_(inpt_new)
-                   elif (
-                       meta.input_info[i].mutates_data
-                       and meta.input_info[
-                           i
-                       ].mutations_under_no_grad_or_inference_mode
-                   ):
-                       # Under no_grad = run under no_grad (we still bump the VC though)
-                       # (inference_mode will also bump the VC, as long as the tensor in question
-                       # was created outside of inference_mode)
-                       with torch.no_grad():
-                           inpt_old.copy_(inpt_new)
-                   elif meta.input_info[i].mutates_data:
-                       inpt_old.copy_(inpt_new)
+                   continue
+               if f_args_mutation_counter_after_forward is not None:
+                   # This could happen for subclasses tracing
+                   # Subclasses support for mutations in fw and bw is TBD.
+                   mc = _get_mutation_counter(f_inpt)
+                   if mc == inputs_mutated_in_graph_applied_mutation_counters[idx]:
+                       # No mutation in backward; mutation was already applied.
+                       continue
+
+               with torch.fx.traceback.preserve_node_meta(), set_partitioner_tag_must_be_in_backward():
+                   apply_in_graph_mutations(
+                       meta.input_info[idx],
+                       inpt_old,
+                       inpt_new,
+                       f_inpt,
+                       idx,
+                   )

            # When an output tensor is a functionalized mutated input, and we
            # were able to move the mutation in to the graph then we can return
@@ -201,6 +201,10 @@ def _extract_graph_with_inputs_outputs(
            env[node] = InvalidNode  # type: ignore[assignment]
            continue

+       if _must_be_in_forward(node) and subgraph != "forward":
+           env[node] = InvalidNode  # type: ignore[assignment]
+           continue
+
        if node in env:
            # Node must be one of our inputs. (Any member of env which wasn't an
            # input to start must have been created by this loop and won't be in
@@ -274,10 +278,18 @@ def _has_tag_is_backward(node: fx.Node) -> bool:
     return node.meta.get("partitioner_tag", None) == "is_backward"


+def _has_tag_must_be_in_forward(node: fx.Node) -> bool:
+    return node.meta.get("partitioner_tag", None) == "must_be_in_forward"
+
+
 def _has_tag_must_be_in_backward(node: fx.Node) -> bool:
     return node.meta.get("partitioner_tag", None) == "must_be_in_backward"


+def _must_be_in_forward(node: fx.Node) -> bool:
+    return _has_tag_must_be_in_forward(node)
+
+
 def _must_be_in_backward(node: fx.Node) -> bool:
     return _has_tag_must_be_in_backward(node) or (
         _has_tag_is_backward(node) and is_with_effects(node)
@@ -1465,6 +1477,25 @@ def force_save_collectives(joint_module: fx.GraphModule) -> None:
                node.meta["recompute"] = CheckpointPolicy.MUST_SAVE


+def force_save_bw_mutation_src(joint_module: fx.GraphModule) -> None:
+    # If we have mutations of the same primal in forward and backward,
+    # We must not recompute the source of mutation to not apply twice.
+    has_mutation_in_bw: OrderedSet[torch.fx.Node] = OrderedSet()
+    for node in reversed(joint_module.graph.nodes):
+        if (
+            node.target == torch.ops.aten.copy_.default
+            and _has_tag_must_be_in_backward(node)
+        ):
+            has_mutation_in_bw.add(node.args[0])
+
+        if (
+            node.target == torch.ops.aten.copy_.default
+            and _has_tag_must_be_in_forward(node)
+            and node.args[0] in has_mutation_in_bw
+        ):
+            node.args[1].meta["recompute"] = CheckpointPolicy.MUST_SAVE
+
+
 def cleanup_recompute_tags(joint_module: fx.GraphModule) -> fx.GraphModule:
     """
     If there are two consecutive checkpointed blocks with no operator in
@@ -2537,6 +2568,7 @@ def min_cut_rematerialization_partition(
    joint_module = cleanup_recompute_tags(joint_module)
    if not config.unsafe_allow_optimization_of_collectives:
        force_save_collectives(joint_module)
+   force_save_bw_mutation_src(joint_module)

    def classify_nodes(joint_module, static_lifetime_input_indices):
        name_to_node = get_name_to_node(joint_module.graph)
@@ -698,6 +698,11 @@ void initTorchFunctions(PyObject* module) {
     auto t_impl = at::functionalization::impl::unsafeGetFunctionalWrapper(t);
     return t_impl->has_data_mutation();
   });
+  py_module.def("_functionalize_mutation_counter", [](const at::Tensor& t) {
+    TORCH_INTERNAL_ASSERT(at::functionalization::impl::isFunctionalTensor(t));
+    auto t_impl = at::functionalization::impl::unsafeGetFunctionalWrapper(t);
+    return t_impl->mutation_counter();
+  });
   py_module.def(
       "_functionalize_get_storage_size", [](const at::Tensor& t, bool before) {
         TORCH_INTERNAL_ASSERT(
@@ -11,6 +11,7 @@ import functools
 import inspect
 import logging
 import operator
+import threading
 import traceback
 import typing
 import typing_extensions
@@ -181,7 +182,7 @@ def is_sym_node(node: _HasMeta) -> bool:
     return "val" in node.meta and isinstance(node.meta["val"], py_sym_types)


-@overload
+@overload  # type: ignore[no-overload-impl]
 def set_proxy_slot(obj: Tensor, tracer: _ProxyTracer, proxy: _ProxyTensor) -> None: ...
@@ -197,7 +198,66 @@ def set_proxy_slot(
 ) -> None: ...


-def set_proxy_slot(
+class _DisableUpdateTensorTracker(threading.local):
+    value: bool = False
+
+
+_disable_update_tensor_tracker_tls = _DisableUpdateTensorTracker()
+
+
+def _is_proxy_tensor_update_tensor_tracker_disabled() -> bool:
+    """
+    Returns current state of disabling update tensor tracker.
+    """
+    return _disable_update_tensor_tracker_tls.value
+
+
+@contextmanager
+def _proxy_tensor_disable_update_tensor_tracker() -> Generator[None, None, None]:
+    """
+    NOTE "Do not clobber inplace ops"
+    By default tensor_tracker is updated every time.
+    This leads to chaining every operation by the FakeTensor.
+    For example for mutable ops if we have several consecutive mutable operations:
+
+    def f(x, y, z):
+        x.copy_(y)
+        x.copy_(z)
+        return x
+
+    Default graph result:
+    def f_graph(x, y, z)
+        x_1 = x.copy_(y)
+        x_2 = x_1.copy_(z)
+        return x_2
+
+    This chaining simplifies the fx passes and helps to prevent the reordering.
+    But in some cases, we want those nodes to be disconnected.
+    E.g. in case of splitting joint graph into forward and backward.
+    If first inplace op happened in forward, second in backward,
+    we want them after split to be properly placed.
+
+    Enabling this context manager for copy_ will result in:
+    def f_graph_2(x, y, z):
+        x_1 = x.copy_(y)
+        x_2 = x.copy_(z)
+        return x
+
+    Results of copy_ x1 and x2 will have empty users in the graph.
+    The reason why this behavior is not enabled for all inplace ops is that
+    some fx passes (e.g. fx quantization) rely on chaining inplace ops like add_
+    in their fusions passes.
+    We could revisit enabling this logic for all inplace ops in future.
+    """
+    orig_value = _disable_update_tensor_tracker_tls.value
+    _disable_update_tensor_tracker_tls.value = True
+    try:
+        yield
+    finally:
+        _disable_update_tensor_tracker_tls.value = orig_value
+
+
+def set_proxy_slot(  # type: ignore[no-redef]
     obj: Union[PySymType, _AnyScriptObjectType, Tensor],
     tracer: _ProxyTracer,
     proxy: object,
@@ -207,7 +267,9 @@ def set_proxy_slot(
         # We DO want to clobber proxies whenever we run an inplace operation
         # on a tensor, and it affects the metadata on the proxy.
         assert isinstance(proxy, _ProxyTensor)
-        tracer.tensor_tracker[obj] = proxy
+        # see NOTE [Do not clobber inplace ops]
+        if not _is_proxy_tensor_update_tensor_tracker_disabled():
+            tracer.tensor_tracker[obj] = proxy
     elif isinstance(obj, (_AnyScriptObject)):
         # We DO want to clobber proxies, with a similar rationale as for tensors.
         assert isinstance(proxy, Proxy)