[annotation] Override metadata on regenerated node in functional mode (#166200)

Fixes #165810

If we regenerate a node during functionalization, we override the "stack_trace", "custom", and "seq_nr" metadata of the regenerated node with the node meta of the original node.

```
python test/functorch/test_aot_joint_with_descriptors.py -k test_preserve_annotate_replay_view
python test/functorch/test_aotdispatch.py TestAOTAutogradWithDynamo.test_duplicated_arguments_on_tensor_overlap
```

Pull Request resolved: https://github.com/pytorch/pytorch/pull/166200
Approved by: https://github.com/bdhirsh
This commit is contained in:
parent 68b3984b77
commit f167fd09fa
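Before the per-file hunks below, here is a small, self-contained sketch of the override rule this commit implements. It is a toy, not the PyTorch code: `set_current_replay_node` and `make_node_meta` below are hypothetical stand-ins for the real `torch.fx.traceback.set_current_replay_node` and the node-meta assembly in `TracerBase.create_node`; only the copied keys ("stack_trace", "custom", "seq_nr") and the `is_functional_regenerated` flag come from the diff.

```python
# Toy sketch of the metadata override (hypothetical helpers, no PyTorch internals).
from contextlib import contextmanager
from typing import Optional

_current_replay_node_meta: Optional[dict] = None


@contextmanager
def set_current_replay_node(node_meta: dict):
    """Stand-in for fx_traceback.set_current_replay_node: stash the original node's meta."""
    global _current_replay_node_meta
    saved = _current_replay_node_meta
    try:
        _current_replay_node_meta = node_meta
        yield
    finally:
        _current_replay_node_meta = saved


def make_node_meta(current_meta: dict) -> dict:
    """Stand-in for the node-meta assembly in TracerBase.create_node."""
    meta = dict(current_meta)  # metadata of whatever op is being traced right now
    if _current_replay_node_meta is not None:
        # We are regenerating a view: the original node's metadata wins.
        meta["is_functional_regenerated"] = True
        for key in ("stack_trace", "custom", "seq_nr"):
            if key in _current_replay_node_meta:
                meta[key] = _current_replay_node_meta[key]
    return meta


original = {"stack_trace": "out = out_unpermuted[:-1]", "custom": {"pp_stage": 0}, "seq_nr": 7}
with set_current_replay_node(original):
    regenerated = make_node_meta({"stack_trace": "second_op(alias)", "seq_nr": 42})
assert regenerated["seq_nr"] == 7 and regenerated["custom"] == {"pp_stage": 0}
```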
@@ -1066,6 +1066,8 @@ coverage_ignore_functions = [
     "set_current_meta",
     "set_grad_fn_seq_nr",
     "set_stack_trace",
+    "set_current_replay_node",
+    "get_current_replay_node",
     # torch.jit.annotations
     "ann_to_type",
     "check_fn",

@@ -1016,6 +1016,59 @@ class inner_f(torch.nn.Module):
         self.assertFalse("self._opoverload" in foo_node.meta.get("stack_trace", None))
         self.assertFalse("self._opoverload" in gm.print_readable(print_output=False))

+    def test_preserve_annotate_replay_view(self):
+        """Test stack trace and annotation are correct on nodes regenerated in functionalization"""
+
+        def _unpermute(out, input_shape, permuted_indices):
+            """
+            Unpermute operation from torchtitan MoE utils.
+            """
+            out_unpermuted = out.new_empty(input_shape)
+            out_unpermuted[permuted_indices, :] = out
+            out = out_unpermuted[:-1]
+            return out
+
+        class Module(nn.Module):
+            def __init__(self):
+                super().__init__()
+                self.input_shape = (5, 3)
+                self.permuted_indices = torch.tensor([2, 0, 3, 1])
+
+            def forward(self, x):
+                with fx_traceback.annotate({"pp_stage": 0}):
+                    routed_output = _unpermute(
+                        x, self.input_shape, self.permuted_indices
+                    )
+                return routed_output.cos()
+
+        inputs = (torch.randn(4, 3, requires_grad=True),)
+        model = Module()
+
+        graph_module = graph_capture(model, inputs, True)
+        custom_metadata = fx_traceback._get_custom_metadata(graph_module)
+        slice_nodes = graph_module.graph.find_nodes(
+            op="call_function", target=torch.ops.aten.slice.Tensor
+        )
+        self.assertEqual(len(slice_nodes), 1)
+        slice_backward_nodes = graph_module.graph.find_nodes(
+            op="call_function", target=torch.ops.aten.slice_backward.default
+        )
+        self.assertEqual(len(slice_backward_nodes), 1)
+        slice_node = slice_nodes[0]
+        slice_backward_node = slice_backward_nodes[0]
+
+        self.assertEqual(slice_node.meta["seq_nr"], slice_backward_node.meta["seq_nr"])
+        self.assertTrue("out = out_unpermuted[:-1]" in slice_node.meta["stack_trace"])
+        self.assertExpectedInline(
+            str(custom_metadata),
+            """\
+('call_function', 'new_empty', {'pp_stage': 0})
+('call_function', 'index_put', {'pp_stage': 0})
+('call_function', 'slice_2', {'pp_stage': 0})
+('call_function', 'slice_backward', {'pp_stage': 0})
+('call_function', 'index', {'pp_stage': 0})""",
+        )
+

 if __name__ == "__main__":
     run_tests()

@@ -3245,8 +3245,8 @@ def forward(self, primals_1):
    as_strided = torch.ops.aten.as_strided.default(clone, [4], [1], 0)
    add = torch.ops.aten.add.Tensor(as_strided, 1); as_strided = None
    as_strided_scatter = torch.ops.aten.as_strided_scatter.default(clone, add, [4], [1], 0); clone = add = None
-    as_strided_8 = torch.ops.aten.as_strided.default(as_strided_scatter, [4], [1], 0)
-    view_1 = torch.ops.aten.view.default(as_strided_8, [4]); as_strided_8 = None
+    as_strided_9 = torch.ops.aten.as_strided.default(as_strided_scatter, [4], [1], 0)
+    view_1 = torch.ops.aten.view.default(as_strided_9, [4]); as_strided_9 = None
    return (as_strided_scatter, view_1)""",
        ) # noqa: B950

@@ -3409,13 +3409,13 @@ def forward(self, primals_1, primals_2, primals_3):
    as_strided = torch.ops.aten.as_strided.default(clone, [4], [1], 0)
    add = torch.ops.aten.add.Tensor(as_strided, 1); as_strided = None
    as_strided_scatter = torch.ops.aten.as_strided_scatter.default(clone, add, [4], [1], 0); clone = add = None
-    add_1 = torch.ops.aten.add.Tensor(primals_2, primals_3); primals_2 = primals_3 = None
    as_strided_5 = torch.ops.aten.as_strided.default(as_strided_scatter, [4], [1], 0)
-    unsqueeze_1 = torch.ops.aten.unsqueeze.default(as_strided_5, 0); as_strided_5 = None
-    add_2 = torch.ops.aten.add.Tensor(add_1, unsqueeze_1); add_1 = None
+    unsqueeze = torch.ops.aten.unsqueeze.default(as_strided_5, 0); as_strided_5 = None
+    add_1 = torch.ops.aten.add.Tensor(primals_2, primals_3); primals_2 = primals_3 = None
+    add_2 = torch.ops.aten.add.Tensor(add_1, unsqueeze); add_1 = None
    as_strided_14 = torch.ops.aten.as_strided.default(as_strided_scatter, [4], [1], 0)
    view_2 = torch.ops.aten.view.default(as_strided_14, [-1]); as_strided_14 = None
-    return (as_strided_scatter, add_2, view_2, unsqueeze_1)""",
+    return (as_strided_scatter, add_2, view_2, unsqueeze)""",
        ) # noqa: B950

    @unittest.skipIf(not torch.cuda.is_available(), "CUDA is unavailable")

@@ -8,6 +8,7 @@ from contextlib import AbstractContextManager
 from typing import Any, Optional, Union

 import torch
+import torch.fx.traceback as fx_traceback
 import torch.utils._pytree as pytree
 from torch._C import _functionalization_reapply_views_tls as _reapply_views
 from torch._ops import _get_dispatch_mode_pre_dispatch

@@ -512,6 +513,30 @@ class FunctionalTensorMode(TorchDispatchMode):
                     torch.Tensor, wrap, outs_unwrapped
                 )
             else:
+                # Note: [Functionalization View Replay Annotation]
+                # When functionalization encounters a mutation, it handles aliases by lazily regenerating the aliases
+                # at the first time they are next used.
+                # This is a problem when plumbing user annotations during tracing. We want the view ops from view replay
+                # to have the same annotation that the user specified on the original views. But view replay in
+                # functionalization happens the next time the alias is used (e.g. second_op(alias_with_pending_mutation)),
+                # so when we regenerate views before calling into second_op, those views will end up getting the metadata
+                # for second_op!
+                #
+                # Instead, we need to remember the node metadata from the original views, and ensure that this node metadata
+                # is globally set when we lazily perform view replay.
+                # The globally set metadata will be used to populate the fx node created for the replayed operation.
+                if m := torch._C._get_dispatch_mode(
+                    torch._C._TorchDispatchModeKey.PROXY
+                ):
+                    for a in pytree.tree_leaves([args, kwargs]):
+                        if not isinstance(a, FunctionalTensor):
+                            continue
+                        curr_node = m.tracer.tensor_tracker[
+                            torch._from_functional_tensor(a.elem)
+                        ].proxy.node
+                        with fx_traceback.set_current_replay_node(curr_node):
+                            torch._sync(a)
+
                 # When we dispatch to the C++ functionalization kernel, we might need to jump back to the
                 # PreDispatch mode stack afterwards, to handle any other PreDispatch modes underneath
                 # FunctionalTensorMode. If we call func() directly, we would need to exclude PreDispatch

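The note in the FunctionalTensorMode hunk above is the key idea: view replay runs lazily, at the call site of the next op that uses the alias, so the replayed view would otherwise inherit that op's metadata. Below is a minimal, PyTorch-free sketch of the timing problem and the save/restore fix; every name in it is hypothetical, only the structure mirrors the real code.

```python
# Toy model of lazy view replay (hypothetical names, no PyTorch internals).
current_annotation = None   # stands in for the tracer's currently active user annotation
pending_replays = []        # views waiting to be regenerated after a mutation
trace_log = []              # (op_name, annotation) pairs, like fx node metadata


def make_view(annotation):
    # Remember the annotation that was active when the user created the view.
    pending_replays.append(("view_replay", annotation))


def run_op(op_name):
    global current_annotation
    # Functionalization regenerates stale aliases lazily, right before the next op.
    # Without the save/restore below, the replayed view would be recorded with
    # op_name's annotation instead of the one saved at view-creation time.
    for replay_op, saved_annotation in pending_replays:
        saved, current_annotation = current_annotation, saved_annotation
        trace_log.append((replay_op, current_annotation))
        current_annotation = saved
    pending_replays.clear()
    trace_log.append((op_name, current_annotation))


current_annotation = {"pp_stage": 0}
make_view(current_annotation)       # user creates a view under an annotation
current_annotation = None           # the annotation scope has ended
run_op("second_op")                 # replay happens here, at second_op's call site
assert trace_log == [("view_replay", {"pp_stage": 0}), ("second_op", None)]
```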
@@ -206,6 +206,21 @@ class TracerBase:
            if current_meta.get("in_grad_fn", 0) > 0:
                annotation_log.debug("seq_nr from current_meta")
                new_seq_nr = current_meta["grad_fn_seq_nr"][-1]
+
+            # See Note [Functionalization View Replay Annotation]
+            # Overriding some node meta with the original node meta of the
+            # regenerated node.
+            replay_node: Node = fx_traceback.get_current_replay_node()
+            if replay_node is not None:
+                node.meta["is_functional_regenerated"] = True
+                if "seq_nr" in replay_node.meta:
+                    annotation_log.debug("seq_nr from replay_node")
+                    new_seq_nr = replay_node.meta["seq_nr"]
+                if "custom" in replay_node.meta:
+                    node.meta["custom"] = replay_node.meta.get("custom")
+                if "stack_trace" in replay_node.meta:
+                    node.stack_trace = replay_node.meta.get("stack_trace")
+
            annotation_log.debug("Assigning new_seq_nr %s to %s", new_seq_nr, node.name)
            node.meta["seq_nr"] = new_seq_nr

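To see the effect of this TracerBase change on a captured graph, a sketch along the following lines works, assuming you already have an fx `GraphModule` produced through a functionalizing trace (for example via the `graph_capture` helper used in the new test above, which is test-specific, not public API); the meta keys it reads ("is_functional_regenerated", "seq_nr", "custom", "stack_trace") are exactly the ones set in this hunk.

```python
import torch.fx as fx


def dump_regenerated_nodes(gm: fx.GraphModule) -> None:
    """Print the metadata this commit copies onto nodes regenerated by view replay."""
    for node in gm.graph.nodes:
        if node.meta.get("is_functional_regenerated"):
            stack = (node.meta.get("stack_trace") or "").strip().splitlines()
            print(
                node.name,
                node.meta.get("seq_nr"),
                node.meta.get("custom"),
                stack[-1] if stack else None,  # the user line, e.g. "out = out_unpermuted[:-1]"
            )
```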
@@ -30,9 +30,12 @@ __all__ = [
     "NodeSource",
     "NodeSourceAction",
     "get_graph_provenance_json",
+    "set_current_replay_node",
+    "get_current_replay_node",
 ]

 current_meta: dict[str, Any] = {}
+current_replay_node: Optional[Node] = None
 should_preserve_node_meta = False

@@ -400,6 +403,31 @@ def get_current_meta() -> dict[str, Any]:
     return current_meta


+@compatibility(is_backward_compatible=False)
+@contextmanager
+def set_current_replay_node(node):
+    """
+    Set the currently replay node. If `current_replay_node` is not None,
+    then we're re-generating the `current_replay_node` in FunctionalTensorMode.
+    """
+    # See [Note] annotation for more details.
+    global current_replay_node
+    saved_current_replay_node = current_replay_node
+    try:
+        current_replay_node = node
+        yield
+    finally:
+        current_replay_node = saved_current_replay_node
+
+
+@compatibility(is_backward_compatible=False)
+def get_current_replay_node():
+    """
+    Get the currently replay node
+    """
+    return current_replay_node
+
+
 @compatibility(is_backward_compatible=False)
 def get_graph_provenance_json(graph: Graph) -> dict[str, Any]:
     """

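The two helpers added above are internal (marked `is_backward_compatible=False`), but their contract is simple. A usage sketch, assuming a PyTorch build that includes this commit:

```python
import torch
import torch.fx as fx
import torch.fx.traceback as fx_traceback

# Build a trivial graph and pretend one of its nodes is the original view
# whose metadata should be replayed.
g = fx.Graph()
x = g.placeholder("x")
y = g.call_function(torch.relu, (x,))
y.meta["custom"] = {"pp_stage": 0}

assert fx_traceback.get_current_replay_node() is None
with fx_traceback.set_current_replay_node(y):
    # Inside this block, node-creating code (TracerBase.create_node) sees `y`
    # via get_current_replay_node() and copies its metadata onto new nodes.
    assert fx_traceback.get_current_replay_node() is y
assert fx_traceback.get_current_replay_node() is None
```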