[dynamo] propagate tensor metadata on Tensor.__setitem__(tensor) (#161036)

Fixes silent incorrectness for autograd function tracing, where we rely on FakeTensor metadata (`requires_grad`) to decide whether to trace the function with the higher-order op (HOP) or not: 5ee464db5c/torch/_dynamo/variables/misc.py (L671)

Stared at this with @anijain2305 yesterday: `Tensor.__setitem__` can update tensor metadata, so we can just run fake prop and extract the output metadata from the updated FakeTensor.
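For context, the metadata mutation in question is plain eager-mode PyTorch behavior; a minimal sketch, no dynamo involved:

```python
import torch

x = torch.randn(3, 3)                      # plain leaf, requires_grad=False
y = torch.randn(3, 3, requires_grad=True)

print(x.requires_grad, x.grad_fn)          # False None
x[0] = y[0]                                # Tensor.__setitem__ mutates x in place
print(x.requires_grad, x.grad_fn)          # True <CopySlices ...>: metadata changed
```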

FIXES https://github.com/pytorch/pytorch/issues/160901

It should also be the root cause of the issue in https://github.com/pytorch/torchtitan/pull/1604 cc @bdhirsh @ruisizhang123

Pull Request resolved: https://github.com/pytorch/pytorch/pull/161036
Approved by: https://github.com/anijain2305
ghstack dependencies: #160805
Simon Fan, 2025-08-20 17:17:52 -07:00, committed by PyTorch MergeBot
commit 8aad3a60ce (parent c7fb031706)
4 changed files with 112 additions and 48 deletions


@@ -7141,6 +7141,37 @@ def forward(self, s77 : torch.SymInt, s27 : torch.SymInt, L_x_ : torch.Tensor):
         torch.compile(f, backend="eager", fullgraph=True)(eye, out_res)
         self.assertEqual(out_ref, out_res)
 
+    def test_setitem_tensor_prop(self):
+        # Using the composite implicit of the forward would be incorrect
+        class MyFn(torch.autograd.Function):
+            @staticmethod
+            def forward(ctx, x):
+                return torch.matmul(x, x.t())
+
+            @staticmethod
+            def backward(ctx, grad_out):
+                return grad_out
+
+        def fn(x, y):
+            x[0] = y[0]
+            return MyFn.apply(x)
+
+        def inputs():
+            torch.manual_seed(123)
+            x = torch.randn(10, 10)
+            y = torch.randn(10, 10, requires_grad=True)
+            return x, y
+
+        x1, y1 = inputs()
+        fn(x1, y1).sum().backward()
+        self.assertTrue(x1.requires_grad)
+
+        x2, y2 = inputs()
+        torch.compile(fn, backend="eager")(x2, y2).sum().backward()
+        self.assertTrue(x2.requires_grad)
+        self.assertEqual(y1.grad, y2.grad)
+
     def test_nn_parameter_ctor_graph_breaks(self):
         def fn():
             param = torch.nn.Parameter(torch.ones(10))
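A standalone eager sketch of why the test's first comment holds: `MyFn.backward` is deliberately the identity, so gradients through `MyFn` must differ from what autograd would derive by differentiating the composite forward. Nothing below is new API; it mirrors the test's `MyFn`:

```python
import torch

class MyFn(torch.autograd.Function):
    @staticmethod
    def forward(ctx, x):
        return torch.matmul(x, x.t())

    @staticmethod
    def backward(ctx, grad_out):
        return grad_out  # identity, NOT the matmul derivative

x = torch.randn(4, 4, requires_grad=True)
MyFn.apply(x).sum().backward()
print(torch.allclose(x.grad, torch.ones(4, 4)))   # True: identity backward

x2 = x.detach().clone().requires_grad_()
torch.matmul(x2, x2.t()).sum().backward()
print(torch.allclose(x2.grad, x.grad))            # False: composite grads differ
```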


@@ -2055,32 +2055,8 @@ class VariableBuilder:
             return self.tx.output.input_source_to_var[source]
 
         options = {}
-        if type(value) in (
-            torch.Tensor,
-            torch.nn.Parameter,
-            torch._subclasses.fake_tensor.FakeTensor,
-            torch._subclasses.functional_tensor.FunctionalTensor,
-        ) or is_traceable_wrapper_subclass(value):
-            # Ordinarily, we would fakeify a tensor so that it can get dynamic
-            # shapes and be computed on without triggering actual operations.
-            # However, how can we fakeify a tensor subclass? Neither ordinary
-            # inheritance nor multiple inheritance will work.
-            #
-            # Instead, our plan is to *manually simulate* the tensor subclass
-            # inheriting from a fake tensor with dynamo. This means our
-            # data representation for a tensor subclass will be a fake tensor
-            # + tensor subclass type + any extra data the subclass may have
-            # been storing on the tensor. Because all Python accesses are
-            # mediated through TensorWithTFOverrideVariable, we can ensure
-            # that we dispatch differently, e.g., according to
-            # __torch_function__
-            #
-            # To simplify things for now, the __dict__ tracking bits haven't
-            # been implemented yet, but they can be added into this design at
-            # a later point in time.
-            subclass_type = None
-        else:
-            subclass_type = type(value)
+        subclass_type = infer_subclass_type(value)
         if subclass_type is not None:
             self.install_guards(GuardBuilder.TYPE_MATCH)
         if get_static_address_type(value) == "guarded":
@@ -3038,6 +3014,55 @@ def handle_traced_output(example_value, tx, proxy, options, subclass_type, target_cls):
     )
 
 
+def infer_subclass_type(value):
+    if type(value) in (
+        torch.Tensor,
+        torch.nn.Parameter,
+        torch._subclasses.fake_tensor.FakeTensor,
+        torch._subclasses.functional_tensor.FunctionalTensor,
+    ) or is_traceable_wrapper_subclass(value):
+        # Ordinarily, we would fakeify a tensor so that it can get dynamic
+        # shapes and be computed on without triggering actual operations.
+        # However, how can we fakeify a tensor subclass? Neither ordinary
+        # inheritance nor multiple inheritance will work.
+        #
+        # Instead, our plan is to *manually simulate* the tensor subclass
+        # inheriting from a fake tensor with dynamo. This means our
+        # data representation for a tensor subclass will be a fake tensor
+        # + tensor subclass type + any extra data the subclass may have
+        # been storing on the tensor. Because all Python accesses are
+        # mediated through TensorWithTFOverrideVariable, we can ensure
+        # that we dispatch differently, e.g., according to
+        # __torch_function__
+        #
+        # To simplify things for now, the __dict__ tracking bits haven't
+        # been implemented yet, but they can be added into this design at
+        # a later point in time.
+        return None
+    else:
+        return type(value)
+
+
+def get_specialized_props(target_cls, tx, example_value, subclass_type):
+    specialized_props = target_cls.specialize(example_value)
+    # TODO: not sure about this fake mode test
+    if (
+        isinstance(example_value, torch._subclasses.fake_tensor.FakeTensor)
+        and example_value.fake_mode is tx.fake_mode
+    ):
+        if subclass_type:
+            tensor_type = subclass_type
+        elif isinstance(example_value, torch.nn.Parameter):
+            tensor_type = torch.nn.Parameter
+        elif isinstance(example_value, torch.nn.Buffer):
+            tensor_type = torch.nn.Buffer
+        else:
+            tensor_type = torch.Tensor
+        specialized_props["class_type"] = tensor_type
+    return specialized_props
+
+
 def construct_tensor_variable(
     target_cls, tx, proxy, example_value, subclass_type, options
 ):
@@ -3055,23 +3080,7 @@ def construct_tensor_variable(
         # when lifting unbacked symbols of input tensors to subgraph inputs.
         # We do it lazily because the tensor may not be used in subgraphs.
         tx.output.current_tracer.track_unbacked_symbols(example_value, proxy)
-    specialized_props = target_cls.specialize(example_value)
-    # TODO: not sure about this fake mode test
-    if (
-        isinstance(example_value, torch._subclasses.fake_tensor.FakeTensor)
-        and example_value.fake_mode is tx.fake_mode
-    ):
-        if subclass_type:
-            tensor_type = subclass_type
-        elif isinstance(example_value, torch.nn.Parameter):
-            tensor_type = torch.nn.Parameter
-        elif isinstance(example_value, torch.nn.Buffer):
-            tensor_type = torch.nn.Buffer
-        else:
-            tensor_type = torch.Tensor
-        specialized_props["class_type"] = tensor_type
-    options.update(specialized_props)
+    options.update(get_specialized_props(target_cls, tx, example_value, subclass_type))
     return target_cls(proxy, **options)
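As a standalone sketch of what `get_specialized_props` resolves for `class_type` (simplified: the real helper also calls `target_cls.specialize` and checks that the FakeTensor belongs to the current `fake_mode`; `resolve_class_type` below is a hypothetical name for illustration only):

```python
import torch

def resolve_class_type(example_value, subclass_type=None):
    # Mirrors the precedence in get_specialized_props: an explicit subclass
    # type wins, then Parameter, then Buffer, then plain Tensor.
    if subclass_type:
        return subclass_type
    if isinstance(example_value, torch.nn.Parameter):
        return torch.nn.Parameter
    if isinstance(example_value, torch.nn.Buffer):
        return torch.nn.Buffer
    return torch.Tensor

assert resolve_class_type(torch.nn.Parameter(torch.ones(2))) is torch.nn.Parameter
assert resolve_class_type(torch.ones(2)) is torch.Tensor
```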


@@ -657,13 +657,13 @@ class AutogradFunctionVariable(VariableTracker):
     def call_apply(self, tx: "InstructionTranslator", args, kwargs):
         requires_grad = False
 
-        def visit(node):
+        def visit(vt):
             nonlocal requires_grad
-            if isinstance(node, variables.TensorVariable):
-                if node.requires_grad is not False:
+            if isinstance(vt, variables.TensorVariable):
+                if vt.requires_grad is not False:
                     requires_grad = True
-            if isinstance(node, variables.NNModuleVariable):
-                if node.is_training(tx):
+            if isinstance(vt, variables.NNModuleVariable):
+                if vt.is_training(tx):
                     requires_grad = True
 
         VariableTracker.visit(visit, (args, kwargs))


@@ -1090,6 +1090,30 @@ class TensorVariable(VariableTracker):
             *proxy_args_kwargs([self, key, value], {}),
         )
 
+        if isinstance(value, TensorVariable):
+            # [Note: Tensor.__setitem__ and VariableTracker metadata]
+            # At this point, we have proxied a node representing `self[key] = value`
+            # into the graph. When executed, this node will mutate `self`'s tensor
+            # metadata, so it is important to propagate that metadata even during
+            # tracing. For example:
+            #   value.requires_grad is True => self.requires_grad becomes True
+            #   value.requires_grad is True => self.has_grad_fn becomes True
+
+            # Not sure if __setitem__ can ever save activations, disabling just in case
+            with torch._dynamo.utils._disable_saved_tensors_hooks_during_tracing():
+                get_fake_value(proxy.node, tx, allow_non_graph_fake=False)
+
+            example_value = self.proxy.node.meta.get("example_value")
+            from .builder import get_specialized_props, infer_subclass_type
+
+            if isinstance(value, variables.lazy.LazyVariableTracker):
+                value = variables.lazy.LazyVariableTracker.realize_all(value)
+            specialized_props = get_specialized_props(
+                type(value), tx, example_value, infer_subclass_type(example_value)
+            )
+            for k, v in specialized_props.items():
+                setattr(self, k, v)
+
         if config.use_graph_deduplication or config.track_nodes_for_deduplication:
             tx.output.region_tracker.add_node_mutation(proxy.node, 0)
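The fix leans on the fact that fake prop alone is enough to observe the metadata change. A minimal sketch of that property using `FakeTensorMode` directly (assuming a recent PyTorch where tensor constructors under the mode produce FakeTensors; this is an illustration, not the dynamo code path itself):

```python
import torch
from torch._subclasses.fake_tensor import FakeTensorMode

with FakeTensorMode():
    x = torch.randn(4, 4)                       # FakeTensor, requires_grad=False
    y = torch.randn(4, 4, requires_grad=True)   # FakeTensor, requires_grad=True
    x[0] = y[0]                                 # fake __setitem__, no real compute
    print(x.requires_grad)                      # True: metadata propagated
```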