[aotd] Support mutations of the same input in fw and bw (#155354)
Original issue: https://github.com/pytorch/pytorch/issues/154820

The issue occurs when the same input is mutated both in forward AND in backward. AOTD emitted the copy_ after joint_function tracing, so a single fx node ended up carrying the side effects of both mutations (the forward one and the backward one). The partitioner could then place that node in either the forward or the backward graph, leaving the input in the wrong state in the other.

The fix:
1/ Introduce joint_function.handle, which allows setting a "post_forward" callback so we can inspect the state of the inputs right after forward. We do not want to apply a mutation after the joint trace if it was already applied in forward; for that we need a "mutation_counter" and we memorize the counter value at which the forward mutation was applied.
2/ Expose mutation_counter to Python. We want to keep the invariant that copy_ nodes exist only at the end of the joint graph.
3/ Memorize the mutation_counter and the state of the inputs after forward via the post_forward handle, and emit the post_forward mutations only once the joint graph is fully traced. Tag these mutations with "must_be_in_forward" (analogous to the existing "must_be_in_backward") so the partitioner keeps them in forward.
4/ Ban recomputation of the source of the mutation. Recompute could otherwise apply the same op (e.g. add) in both forward and backward, so the source of the forward mutation is marked MUST_SAVE.

proxy_tensor changes: by default, proxy tensor updates the tensor_tracker, which would chain the applied mutations together. We want this copy_ to be independent and applied just to the primals, so a context manager is introduced that disables tensor_tracker updates while the forward mutations are added.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/155354
Approved by: https://github.com/bdhirsh
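For concreteness, here is a minimal sketch of the failing pattern, distilled from the new test_mutation_of_input_in_fw_and_bw test added in this PR; the MutateInFwAndBw class, fn, and buf names below are illustrative only and not part of the change:

import torch

# Hypothetical minimal repro of the pattern this PR addresses: the same
# non-differentiable input is mutated in place once in forward and once in backward.
class MutateInFwAndBw(torch.autograd.Function):
    @staticmethod
    def forward(ctx, dummy, inplace_tensor):
        inplace_tensor.add_(1)  # mutation of the input during forward
        ctx.inplace_tensor = inplace_tensor
        return dummy.clone()

    @staticmethod
    def backward(ctx, grad_output):
        ctx.inplace_tensor.add_(1)  # mutation of the same input during backward
        return grad_output, None


def fn(dummy, inplace_tensor):
    return MutateInFwAndBw.apply(dummy, inplace_tensor)


dummy = torch.randn(2, requires_grad=True)
buf = torch.zeros(2)

out = torch.compile(fn, backend="aot_eager", fullgraph=True)(dummy, buf)
# Before the fix, the single copy_ emitted after joint tracing could land in either
# the forward or the backward graph, so buf could miss the forward-side +1 here ...
out.sum().backward()
# ... or end up with the wrong value here; after the fix buf reflects both add_(1)s.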
This commit is contained in:
parent c82a174cea
commit 3f920f3d8f
@@ -122,6 +122,9 @@ struct TORCH_API FunctionalStorageImpl : public c10::StorageImpl {
~FunctionalStorageImpl() override = default;

uint64_t mutation_counter() {
return mutation_counter_;
}
void mark_mutation() {
mutation_counter_++;
}

@@ -74,7 +74,9 @@ struct TORCH_API FunctionalTensorWrapper : public c10::TensorImpl {
bool has_metadata_mutation() const {
return has_metadata_mutation_;
}

uint64_t mutation_counter() const {
return functional_storage_impl()->mutation_counter();
}
void mark_mutation() {
functional_storage_impl()->mark_mutation();
}

@@ -14,7 +14,7 @@ add_loop_inductor_dynamic_gpu,compile_time_instruction_count,39110000000,0.025
add_loop_inductor_gpu,compile_time_instruction_count,26180000000,0.015
add_loop_inductor_gpu,compile_time_instruction_count,25780000000,0.015

@@ -62,7 +62,7 @@ aotdispatcher_partitioner_cpu,compile_time_instruction_count,8844000000,0.015
aotdispatcher_partitioner_cpu2,compile_time_instruction_count,1963000000,0.015
aotdispatcher_partitioner_cpu2,compile_time_instruction_count,1917000000,0.015

@@ -7842,6 +7842,53 @@ class TestAOTAutogradWithDynamo(TestAOTAutograd):
self.assertEqual(ref_inps_after_fw, inps_after_fw)
self.assertEqual(ref_inps_after_bw, inps_after_bw)

def test_mutation_of_input_in_fw_and_bw(self):
class AF(torch.autograd.Function):
@staticmethod
def forward(ctx, dummy, inplace_tensor):
inplace_tensor.add_(1)
ctx.inplace_tensor = inplace_tensor
return dummy.clone()

@staticmethod
def backward(ctx, grad_output):
inplace_tensor = ctx.inplace_tensor
inplace_tensor.add_(1)
return grad_output, None, None

def fn(dummy, inplace_tensor):
return AF.apply(dummy, inplace_tensor)

def inps():
dummy = torch.randn((2,), requires_grad=True)
inplace_tensor = torch.zeros((2,), requires_grad=False)
return dummy, inplace_tensor

def sc_inps():
dummy = TwoTensor(
torch.randn((2,), requires_grad=True),
torch.randn((2,), requires_grad=True),
)
inplace_tensor = TwoTensor(
torch.zeros((2,), requires_grad=False),
torch.zeros((2,), requires_grad=False),
)
return dummy, inplace_tensor

for _inps in [inps, sc_inps]:
dummy, inplace = _inps()
y = fn(dummy, inplace)
ref0 = inplace.clone().detach()
y.sum().backward()
ref = inplace.clone().detach()

dummy, inplace = _inps()
y = torch.compile(fn, backend="aot_eager", fullgraph=True)(dummy, inplace)
self.assertEqual(ref0, inplace)
y.sum().backward()
self.assertEqual(ref, inplace)


class MockFXGraphCache:
"""

@@ -912,6 +912,13 @@ def gen_pyi(
"None",
)
],
"_functionalize_mutation_counter": [
defs(
"_functionalize_mutation_counter",
["t: Tensor"],
"_int",
)
],
"_functionalize_are_all_mutations_hidden_from_autograd": [
defs(
"_functionalize_are_all_mutations_hidden_from_autograd",

@@ -265,6 +265,7 @@ def aot_dispatch_autograd_graph(
fw_metadata,
)
joint_fn_to_trace = create_joint(fn_prepared_for_autograd, aot_config=aot_config)
joint_fn_handle = joint_fn_to_trace.handle

joint_fn_to_trace, updated_joint_inputs = create_functionalized_fn(
joint_fn_to_trace,

@@ -272,6 +273,7 @@ def aot_dispatch_autograd_graph(
meta=fw_metadata,
aot_config=aot_config,
trace_joint=True,
joint_fn_handle=joint_fn_handle,
)

# TODO: replace with AOTDispatchSubclassWrapper once we refactor

@@ -10,11 +10,11 @@ It does so by:
3. transforming mutations into extra outputs
4. dispatching subclasses
"""

import warnings
from contextlib import contextmanager, nullcontext
from dataclasses import dataclass
from functools import wraps
from typing import Any, Callable, Union
from typing import Any, Callable, Optional, Union
from unittest.mock import patch

import torch

@@ -25,6 +25,7 @@ from torch._decomp.decompositions_for_rng import PhiloxStateTracker
from torch._guards import detect_fake_mode
from torch._prims_common import CUDARngStateHelper
from torch.fx.experimental.proxy_tensor import (
_proxy_tensor_disable_update_tensor_tracker,
maybe_disable_thunkify,
maybe_enable_thunkify,
)

@@ -34,6 +35,7 @@ from torch.fx.experimental.symbolic_shapes import (
sym_eq,
)
from torch.nn.utils import stateless
from torch.utils._python_dispatch import is_traceable_wrapper_subclass

from .. import config
from .collect_metadata_analysis import run_functionalized_fw_and_collect_metadata

@@ -182,6 +184,11 @@ def fn_prepped_for_autograd(
return inner_fn


@dataclass
class JointFnHandle:
post_forward: Optional[Callable] = None


# Given a fn, computes the joint.
# NOTE: fn is expects the following behavior:
# (1) fn() needs to return a tuple of (outs, mask),

@@ -193,9 +200,15 @@ def fn_prepped_for_autograd(
# otherwise, when we compute autograd.grad(), we will not take those input mutations into account
# (the way this is handled is that we ensure any inputs that normally get mutated are cloned first)
def create_joint(fn: Callable, *, aot_config: AOTConfig) -> Any:
joint_fn_handle = JointFnHandle()

# post_forward
def inner_fn(primals: list[Any], tangents: list[Any]):
outs, tangent_mask = fn(*primals)

if joint_fn_handle and joint_fn_handle.post_forward:
joint_fn_handle.post_forward(primals)

assert len(tangent_mask) == len(outs)
outs_to_grad = [
o for needs_tangent, o in zip(tangent_mask, outs) if needs_tangent

@@ -285,6 +298,8 @@ def create_joint(fn: Callable, *, aot_config: AOTConfig) -> Any:
with torch.autograd.detect_anomaly(check_nan=False):
return inner_fn(*args)

inner_fn_with_anomaly.handle = joint_fn_handle # type: ignore[attr-defined]

return inner_fn_with_anomaly

@@ -379,6 +394,118 @@ def set_partitioner_tag_must_be_in_backward():
return set_partitioner_tag("must_be_in_backward")


def set_partitioner_tag_must_be_in_forward():
return set_partitioner_tag("must_be_in_forward")


def _get_mutation_counter(t) -> int:
if not is_traceable_wrapper_subclass(t):
return torch._functionalize_mutation_counter(t.elem) # type: ignore[attr-defined]

max_mc = -1

def visit(e):
if not is_traceable_wrapper_subclass(e):
mc = torch._functionalize_mutation_counter(e.elem) # type: ignore[attr-defined]
nonlocal max_mc
max_mc = max(mc, max_mc)
return

for a in e.__tensor_flatten__()[0]:
visit(getattr(e, a))

visit(t)
return max_mc


def apply_in_graph_mutations(input_info, inpt_old, inpt_new, f_inpt, input_idx):
assert input_info.mutation_type == MutationType.MUTATED_IN_GRAPH
# See Note [set_() Input Mutations in AOTAutograd]
# all mutations on the input must be under no_grad, so it is safe to put in the graph
# Here, we're saying that if an input experienced a set call, inp.set_(other),
# then we can effectively not have to worry about whether its data was mutated.
# There are 3 cases:
# (1) We mutate inp *after* the set_() call. other is a graph intermediate.
# In this case, we're not really mutating the input storage of "inp";
# we're mutating the storage of an intermdiate value (other),
# and slamming that storage into the input tensor. So no data mutation is necessary.
# (2) We mutate inp *after* the set_() call. other is a graph *input*.
# In this case, the data mutation will be properly handled in the runtime
# epilogue during the processing of "other"
# (3) We mutate inp *before* the set_() call.
# This case is *not* currently handled.
if input_info.mutates_storage_metadata:
with torch.no_grad():
inpt_old.set_(inpt_new)

# Note [Ordering of resize_() and set_()]
# Importantly: the common usage in FSDP is that we have a dummy parameter
# that sees a set_() and **Then** a resize_().
# We must put those mutations into the graph in the same order,
# Since running them in the opposite order will have different behavior.
# We fully ban resize_() followed by set_() for now, although in principal
# we could support this
if input_info.mutation_inductor_storage_resize:
# resizing is not supported on subclasses (we error earlier if this happens)
from torch._subclasses.functional_tensor import FunctionalTensor

assert isinstance(f_inpt, FunctionalTensor)
old_storage_size = torch._functionalize_get_storage_size( # type: ignore[attr-defined]
f_inpt.elem, before=True
)
new_storage_size = torch._functionalize_get_storage_size( # type: ignore[attr-defined]
f_inpt.elem, before=False
)
if old_storage_size != new_storage_size:
assert (
old_storage_size == 0 or new_storage_size == 0
), f"""\
Encountered a storage resize during tracing on input {input_idx}. Old nbytes={old_storage_size}, new nbytes={new_storage_size}
We only support storage resizing on graph inputs as long as the input either starts or ends with a storage size of 0
(the case for FSDP)"""
torch.ops.inductor.resize_storage_bytes_(inpt_old, new_storage_size)
if new_storage_size == 0:
# Even if we marked the input as having a data mutation (thus needing a copy_()),
# We should **ignore** it if our input has no storage
# (this can happen if, e.g. we temporarily resize our input, copy data into it,
# and resize it back down to zero)
return
# Optimization: if the copy_() is a no-op then don't include it in the graph.
# In theory inductor could optimize this away, however in fsdp, we end up with
# param.copy_(param), where param is a zero-storage-size tensor,
# and running this op in eager mode (using the aot_eager backend) will result in a segfault.
# So we may as well optimize it away here.
if inpt_old is inpt_new:
# (This check needs to be done after putting resize_() in the graph,
# since a resize_(0) doesn't actually change the FunctionalTensor's inner tensor)
return
# We found an input that had a (data-only) mutation.
# Since keep_input_mutations is set, we need to faithfully apply a copy_()
# so the compiler will see the input mutation in the graph.
if input_info.mutates_data and input_info.mutations_hidden_from_autograd:
# Hidden from autograd = run under no_grad, **and** don't bump VC
# (although if the tensor was created in inference mode, it has no VC)
if inpt_old.is_inference():
maybe_preserve_vc = nullcontext()
else:
maybe_preserve_vc = torch.autograd._unsafe_preserve_version_counter(
inpt_old # type: ignore[assignment]
)
with torch.no_grad(), maybe_preserve_vc:
inpt_old.copy_(inpt_new)
elif (
input_info.mutates_data and input_info.mutations_under_no_grad_or_inference_mode
):
# Under no_grad = run under no_grad (we still bump the VC though)
# (inference_mode will also bump the VC, as long as the tensor in question
# was created outside of inference_mode)

with torch.no_grad():
inpt_old.copy_(inpt_new)
elif input_info.mutates_data:
inpt_old.copy_(inpt_new)


# This creates the final function that we want to trace using make_fx(),
# in both aot_dispatch_autograd and aot_dispatch_base.
# Preconditions:

@@ -398,7 +525,16 @@ def create_functionalized_fn(
meta: ViewAndMutationMeta,
aot_config: AOTConfig,
trace_joint: bool,
joint_fn_handle: Optional[JointFnHandle] = None,
) -> Any:
primals_after_forward = None
f_args_after_forward = None
f_args_mutation_counter_after_forward: Optional[list[int]] = None
inputs_mutated_in_graph = [
info.mutation_type == MutationType.MUTATED_IN_GRAPH for info in meta.input_info
]
has_input_mutated_in_graph = any(inputs_mutated_in_graph)

@wraps(fn)
def _functionalized_f_helper(*args):
with maybe_enable_thunkify():

@@ -415,6 +551,24 @@
# Wrap inputs into functional wrappers
f_args = pytree.tree_map(to_fun, args)

if trace_joint and has_input_mutated_in_graph and joint_fn_handle:
# TODO(ivankobzarev): Support fw and bw mutations for subclasses
def _post_forward(primals):
nonlocal primals_after_forward
primals_after_forward = pytree.tree_map(from_fun, primals)
nonlocal f_args_after_forward
f_args_after_forward = f_args[0]
nonlocal f_args_mutation_counter_after_forward

f_args_mutation_counter_after_forward = [
-1
if not inputs_mutated_in_graph[i]
else _get_mutation_counter(f_arg)
for i, f_arg in enumerate(f_args_after_forward)
]

joint_fn_handle.post_forward = _post_forward

# Run the joint
f_outs = fn(*f_args)

@@ -535,110 +689,86 @@
# we will materialize an "updated" synthetic base, and copy it back to the synthetic input base.
# This allows us to factor aot autograd much more nicely, since only one area of the code needs to worry
# about synthetic bases.
for i, (inpt_old, inpt_f) in enumerate(

# Apply in graph forward mutations only in joint case.
# Note: Mutations of primals in forward AND backward.
# If we have mutations of the same input in forward and in backward,
# we can not fuse them into one copy_ node. As in this case partitioner will put it
# either in forward or in backward. This will lead to incorrect state
# after forward and before backward.
# We have to emit two copy_ nodes, marking with additional meta each node,
# if it must be in forward or backward.
# We memorize mutation counter of the inputs after forward.
# Based on this after joint graph we check if backward also mutated input or not.
# We emit copy_ only in the end of joint tracing, to provide invariant for joint
# graph passes, that our graph is functional, except only some number of copy_ nodes
# in the end.
inputs_mutated_in_graph_applied_mutation_counters: list[int] = [
0
] * len(meta.input_info)
if f_args_mutation_counter_after_forward is not None:
primals_before = args[0]
for idx, (f_inpt, before, after, inpt_info) in enumerate(
zip(
f_args_after_forward, # type: ignore[arg-type]
primals_before, # type: ignore[arg-type]
primals_after_forward, # type: ignore[arg-type]
meta.input_info,
)
):
if inpt_info.mutation_type != MutationType.MUTATED_IN_GRAPH:
continue

assert f_args_mutation_counter_after_forward
post_fw_mc = f_args_mutation_counter_after_forward[idx]
mc = _get_mutation_counter(f_inpt)

if mc > 0:
# Mutation in forward.
with (
torch.fx.traceback.preserve_node_meta(),
set_partitioner_tag_must_be_in_forward(),
_proxy_tensor_disable_update_tensor_tracker(),
):
apply_in_graph_mutations(
inpt_info,
before,
after,
f_inpt,
idx,
)
inputs_mutated_in_graph_applied_mutation_counters[
idx
] = post_fw_mc

for idx, (inpt_old, f_inpt) in enumerate(
zip(args, f_args) if not trace_joint else zip(args[0], f_args[0])
):
if not isinstance(inpt_f, torch.Tensor):
if not isinstance(f_inpt, torch.Tensor):
continue
assert is_fun(inpt_f)
inpt_new = from_fun(inpt_f)
assert is_fun(f_inpt)
inpt_new = from_fun(f_inpt)
if (
meta.input_info[i].mutation_type
== MutationType.MUTATED_IN_GRAPH
meta.input_info[idx].mutation_type
!= MutationType.MUTATED_IN_GRAPH
):
# See Note [set_() Input Mutations in AOTAutograd]
# all mutations on the input must be under no_grad, so it is safe to put in the graph
# Here, we're saying that if an input experienced a set call, inp.set_(other),
# then we can effectively not have to worry about whether its data was mutated.
# There are 3 cases:
# (1) We mutate inp *after* the set_() call. other is a graph intermediate.
# In this case, we're not really mutating the input storage of "inp";
# we're mutating the storage of an intermdiate value (other),
# and slamming that storage into the input tensor. So no data mutation is necessary.
# (2) We mutate inp *after* the set_() call. other is a graph *input*.
# In this case, the data mutation will be properly handled in the runtime
# epilogue during the processing of "other"
# (3) We mutate inp *before* the set_() call.
# This case is *not* currently handled.
if meta.input_info[i].mutates_storage_metadata:
with torch.no_grad():
inpt_old.set_(inpt_new)

# Note [Ordering of resize_() and set_()]
# Importantly: the common usage in FSDP is that we have a dummy parameter
# that sees a set_() and **Then** a resize_().
# We must put those mutations into the graph in the same order,
# Since running them in the opposite order will have different behavior.
# We fully ban resize_() followed by set_() for now, although in principal
# we could support this
if meta.input_info[i].mutation_inductor_storage_resize:
# resizing is not supported on subclasses (we error earlier if this happens)
from torch._subclasses.functional_tensor import (
FunctionalTensor,
)

assert isinstance(inpt_f, FunctionalTensor)
old_storage_size = torch._functionalize_get_storage_size( # type: ignore[attr-defined]
inpt_f.elem, before=True
)
new_storage_size = torch._functionalize_get_storage_size( # type: ignore[attr-defined]
inpt_f.elem, before=False
)
if old_storage_size != new_storage_size:
assert (
old_storage_size == 0 or new_storage_size == 0
), f"""\
Encountered a storage resize during tracing on input {i}. Old nbytes={old_storage_size}, new nbytes={new_storage_size}
We only support storage resizing on graph inputs as long as the input either starts or ends with a storage size of 0
(the case for FSDP)"""
torch.ops.inductor.resize_storage_bytes_(
inpt_old, new_storage_size
)
if new_storage_size == 0:
# Even if we marked the input as having a data mutation (thus needing a copy_()),
# We should **ignore** it if our input has no storage
# (this can happen if, e.g. we temporarily resize our input, copy data into it,
# and resize it back down to zero)
continue
# Optimization: if the copy_() is a no-op then don't include it in the graph.
# In theory inductor could optimize this away, however in fsdp, we end up with
# param.copy_(param), where param is a zero-storage-size tensor,
# and running this op in eager mode (using the aot_eager backend) will result in a segfault.
# So we may as well optimize it away here.
if inpt_old is inpt_new:
# (This check needs to be done after putting resize_() in the graph,
# since a resize_(0) doesn't actually change the FunctionalTensor's inner tensor)
continue
if f_args_mutation_counter_after_forward is not None:
# This could happen for subclasses tracing
# Subclasses support for mutations in fw and bw is TBD.
mc = _get_mutation_counter(f_inpt)
if mc == inputs_mutated_in_graph_applied_mutation_counters[idx]:
# No mutation in backward; mutation was already applied.
continue
# We found an input that had a (data-only) mutation.
# Since keep_input_mutations is set, we need to faithfully apply a copy_()
# so the compiler will see the input mutation in the graph.
if (
meta.input_info[i].mutates_data
and meta.input_info[i].mutations_hidden_from_autograd
):
# Hidden from autograd = run under no_grad, **and** don't bump VC
# (although if the tensor was created in inference mode, it has no VC)
if inpt_old.is_inference():
maybe_preserve_vc = nullcontext()
else:
maybe_preserve_vc = torch.autograd._unsafe_preserve_version_counter(
inpt_old # type: ignore[assignment]
)
with torch.no_grad(), maybe_preserve_vc:
inpt_old.copy_(inpt_new)
elif (
meta.input_info[i].mutates_data
and meta.input_info[
i
].mutations_under_no_grad_or_inference_mode
):
# Under no_grad = run under no_grad (we still bump the VC though)
# (inference_mode will also bump the VC, as long as the tensor in question
# was created outside of inference_mode)
with torch.no_grad():
inpt_old.copy_(inpt_new)
elif meta.input_info[i].mutates_data:
inpt_old.copy_(inpt_new)

with torch.fx.traceback.preserve_node_meta(), set_partitioner_tag_must_be_in_backward():
apply_in_graph_mutations(
meta.input_info[idx],
inpt_old,
inpt_new,
f_inpt,
idx,
)

# When an output tensor is a functionalized mutated input, and we
# were able to move the mutation in to the graph then we can return

@@ -201,6 +201,10 @@ def _extract_graph_with_inputs_outputs(
env[node] = InvalidNode # type: ignore[assignment]
continue

if _must_be_in_forward(node) and subgraph != "forward":
env[node] = InvalidNode # type: ignore[assignment]
continue

if node in env:
# Node must be one of our inputs. (Any member of env which wasn't an
# input to start must have been created by this loop and won't be in

@@ -274,10 +278,18 @@ def _has_tag_is_backward(node: fx.Node) -> bool:
return node.meta.get("partitioner_tag", None) == "is_backward"


def _has_tag_must_be_in_forward(node: fx.Node) -> bool:
return node.meta.get("partitioner_tag", None) == "must_be_in_forward"


def _has_tag_must_be_in_backward(node: fx.Node) -> bool:
return node.meta.get("partitioner_tag", None) == "must_be_in_backward"


def _must_be_in_forward(node: fx.Node) -> bool:
return _has_tag_must_be_in_forward(node)


def _must_be_in_backward(node: fx.Node) -> bool:
return _has_tag_must_be_in_backward(node) or (
_has_tag_is_backward(node) and is_with_effects(node)

@@ -1465,6 +1477,25 @@ def force_save_collectives(joint_module: fx.GraphModule) -> None:
node.meta["recompute"] = CheckpointPolicy.MUST_SAVE


def force_save_bw_mutation_src(joint_module: fx.GraphModule) -> None:
# If we have mutations of the same primal in forward and backward,
# We must not recompute the source of mutation to not apply twice.
has_mutation_in_bw: OrderedSet[torch.fx.Node] = OrderedSet()
for node in reversed(joint_module.graph.nodes):
if (
node.target == torch.ops.aten.copy_.default
and _has_tag_must_be_in_backward(node)
):
has_mutation_in_bw.add(node.args[0])

if (
node.target == torch.ops.aten.copy_.default
and _has_tag_must_be_in_forward(node)
and node.args[0] in has_mutation_in_bw
):
node.args[1].meta["recompute"] = CheckpointPolicy.MUST_SAVE


def cleanup_recompute_tags(joint_module: fx.GraphModule) -> fx.GraphModule:
"""
If there are two consecutive checkpointed blocks with no operator in

@@ -2537,6 +2568,7 @@ def min_cut_rematerialization_partition(
joint_module = cleanup_recompute_tags(joint_module)
if not config.unsafe_allow_optimization_of_collectives:
force_save_collectives(joint_module)
force_save_bw_mutation_src(joint_module)

def classify_nodes(joint_module, static_lifetime_input_indices):
name_to_node = get_name_to_node(joint_module.graph)

@@ -698,6 +698,11 @@ void initTorchFunctions(PyObject* module) {
auto t_impl = at::functionalization::impl::unsafeGetFunctionalWrapper(t);
return t_impl->has_data_mutation();
});
py_module.def("_functionalize_mutation_counter", [](const at::Tensor& t) {
TORCH_INTERNAL_ASSERT(at::functionalization::impl::isFunctionalTensor(t));
auto t_impl = at::functionalization::impl::unsafeGetFunctionalWrapper(t);
return t_impl->mutation_counter();
});
py_module.def(
"_functionalize_get_storage_size", [](const at::Tensor& t, bool before) {
TORCH_INTERNAL_ASSERT(

@@ -11,6 +11,7 @@ import functools
import inspect
import logging
import operator
import threading
import traceback
import typing
import typing_extensions

@@ -181,7 +182,7 @@ def is_sym_node(node: _HasMeta) -> bool:
return "val" in node.meta and isinstance(node.meta["val"], py_sym_types)


@overload
@overload # type: ignore[no-overload-impl]
def set_proxy_slot(obj: Tensor, tracer: _ProxyTracer, proxy: _ProxyTensor) -> None: ...

@@ -197,7 +198,66 @@ def set_proxy_slot(
) -> None: ...


def set_proxy_slot(
class _DisableUpdateTensorTracker(threading.local):
value: bool = False


_disable_update_tensor_tracker_tls = _DisableUpdateTensorTracker()


def _is_proxy_tensor_update_tensor_tracker_disabled() -> bool:
"""
Returns current state of disabling update tensor tracker.
"""
return _disable_update_tensor_tracker_tls.value


@contextmanager
def _proxy_tensor_disable_update_tensor_tracker() -> Generator[None, None, None]:
"""
NOTE "Do not clobber inplace ops"
By default tensor_tracker is updated every time.
This leads to chaining every operation by the FakeTensor.
For example for mutable ops if we have several consecutive mutable operations:

def f(x, y, z):
x.copy_(y)
x.copy_(z)
return x

Default graph result:
def f_graph(x, y, z)
x_1 = x.copy_(y)
x_2 = x_1.copy_(z)
return x_2

This chaining simplifies the fx passes and helps to prevent the reordering.
But in some cases, we want those nodes to be disconnected.
E.g. in case of splitting joint graph into forward and backward.
If first inplace op happened in forward, second in backward,
we want them after split to be properly placed.

Enabling this context manager for copy_ will result in:
def f_graph_2(x, y, z):
x_1 = x.copy_(y)
x_2 = x.copy_(z)
return x

Results of copy_ x1 and x2 will have empty users in the graph.
The reason why this behavior is not enabled for all inplace ops is that
some fx passes (e.g. fx quantization) rely on chaining inplace ops like add_
in their fusions passes.
We could revisit enabling this logic for all inplace ops in future.
"""
orig_value = _disable_update_tensor_tracker_tls.value
_disable_update_tensor_tracker_tls.value = True
try:
yield
finally:
_disable_update_tensor_tracker_tls.value = orig_value


def set_proxy_slot( # type: ignore[no-redef]
obj: Union[PySymType, _AnyScriptObjectType, Tensor],
tracer: _ProxyTracer,
proxy: object,

@@ -207,7 +267,9 @@ def set_proxy_slot(
# We DO want to clobber proxies whenever we run an inplace operation
# on a tensor, and it affects the metadata on the proxy.
assert isinstance(proxy, _ProxyTensor)
tracer.tensor_tracker[obj] = proxy
# see NOTE [Do not clobber inplace ops]
if not _is_proxy_tensor_update_tensor_tracker_disabled():
tracer.tensor_tracker[obj] = proxy
elif isinstance(obj, (_AnyScriptObject)):
# We DO want to clobber proxies, with a similar rationale as for tensors.
assert isinstance(proxy, Proxy)