[dynamo] propagate tensor metadata on Tensor.__setitem__(tensor) (#161036)
Fixes silent incorrectness in autograd function tracing, where we rely on FakeTensor metadata (requires_grad) to decide whether to trace through the higher-order op (HOP) or not: 5ee464db5c/torch/_dynamo/variables/misc.py (L671)
Stared at this with @anijain2305 yesterday: `Tensor.__setitem__` can update tensor metadata, so we can simply run fake prop and extract the updated metadata from the resulting FakeTensor (see the eager-mode sketch below).
Fixes https://github.com/pytorch/pytorch/issues/160901
It is likely also the root cause of the issue in https://github.com/pytorch/torchtitan/pull/1604 @bdhirsh @ruisizhang123
Pull Request resolved: https://github.com/pytorch/pytorch/pull/161036
Approved by: https://github.com/anijain2305
ghstack dependencies: #160805
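Editor's note: a minimal eager-mode sketch (illustrative only, not part of the commit) of the metadata mutation the tracer needs to mirror. Assigning into a tensor from a source that requires grad flips the target's requires_grad, which is exactly the signal the autograd-function HOP check referenced above reads.

import torch

x = torch.randn(10, 10)                      # requires_grad=False
y = torch.randn(10, 10, requires_grad=True)
x[0] = y[0]                                  # in-place __setitem__ routed through autograd
assert x.requires_grad                       # metadata mutated by the setitem
assert x.grad_fn is not None                 # x now carries a grad_fn as well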
This commit is contained in:
parent c7fb031706
commit 8aad3a60ce
@@ -7141,6 +7141,37 @@ def forward(self, s77 : torch.SymInt, s27 : torch.SymInt, L_x_ : torch.Tensor):
         torch.compile(f, backend="eager", fullgraph=True)(eye, out_res)
         self.assertEqual(out_ref, out_res)

+    def test_setitem_tensor_prop(self):
+        # Using the composite implicit of the forward would be incorrect
+        class MyFn(torch.autograd.Function):
+            @staticmethod
+            def forward(ctx, x):
+                return torch.matmul(x, x.t())
+
+            @staticmethod
+            def backward(ctx, grad_out):
+                return grad_out
+
+        def fn(x, y):
+            x[0] = y[0]
+            return MyFn.apply(x)
+
+        def inputs():
+            torch.manual_seed(123)
+            x = torch.randn(10, 10)
+            y = torch.randn(10, 10, requires_grad=True)
+            return x, y
+
+        x1, y1 = inputs()
+        fn(x1, y1).sum().backward()
+        self.assertTrue(x1.requires_grad)
+
+        x2, y2 = inputs()
+        torch.compile(fn, backend="eager")(x2, y2).sum().backward()
+        self.assertTrue(x2.requires_grad)
+
+        self.assertEqual(y1.grad, y2.grad)
+
     def test_nn_parameter_ctor_graph_breaks(self):
         def fn():
             param = torch.nn.Parameter(torch.ones(10))
@@ -2055,32 +2055,8 @@ class VariableBuilder:
             return self.tx.output.input_source_to_var[source]

         options = {}
-        if type(value) in (
-            torch.Tensor,
-            torch.nn.Parameter,
-            torch._subclasses.fake_tensor.FakeTensor,
-            torch._subclasses.functional_tensor.FunctionalTensor,
-        ) or is_traceable_wrapper_subclass(value):
-            # Ordinarily, we would fakeify a tensor so that it can get dynamic
-            # shapes and be computed on without triggering actual operations.
-            # However, how can we fakeify a tensor subclass? Ordinary
-            # inheritance (nor multiple inheritance) won't work work.
-            #
-            # Instead, our plan is to *manually simulate* the tensor subclass
-            # inheriting from a fake tensor with dynamo. This means our
-            # data representation for a tensor subclass will be a fake tensor
-            # + tensor subclass type + any extra data the subclass may have
-            # been storing on the tensor. Because all Python accesses are
-            # mediated through TensorWithTFOverrideVariable, we can ensure
-            # that we dispatch differently, e.g., according to
-            # __torch_function__
-            #
-            # To simplify things for now, the __dict__ tracking bits haven't
-            # been implemented yet, but they can be added into this design at
-            # a later point in time.
-            subclass_type = None
-        else:
-            subclass_type = type(value)
+        subclass_type = infer_subclass_type(value)
         if subclass_type is not None:
             self.install_guards(GuardBuilder.TYPE_MATCH)

         if get_static_address_type(value) == "guarded":
@@ -3038,6 +3014,55 @@ def handle_traced_output(example_value, tx, proxy, options, subclass_type, targe
     )


+def infer_subclass_type(value):
+    if type(value) in (
+        torch.Tensor,
+        torch.nn.Parameter,
+        torch._subclasses.fake_tensor.FakeTensor,
+        torch._subclasses.functional_tensor.FunctionalTensor,
+    ) or is_traceable_wrapper_subclass(value):
+        # Ordinarily, we would fakeify a tensor so that it can get dynamic
+        # shapes and be computed on without triggering actual operations.
+        # However, how can we fakeify a tensor subclass? Ordinary
+        # inheritance (nor multiple inheritance) won't work work.
+        #
+        # Instead, our plan is to *manually simulate* the tensor subclass
+        # inheriting from a fake tensor with dynamo. This means our
+        # data representation for a tensor subclass will be a fake tensor
+        # + tensor subclass type + any extra data the subclass may have
+        # been storing on the tensor. Because all Python accesses are
+        # mediated through TensorWithTFOverrideVariable, we can ensure
+        # that we dispatch differently, e.g., according to
+        # __torch_function__
+        #
+        # To simplify things for now, the __dict__ tracking bits haven't
+        # been implemented yet, but they can be added into this design at
+        # a later point in time.
+        return None
+    else:
+        return type(value)
+
+
+def get_specialized_props(target_cls, tx, example_value, subclass_type):
+    specialized_props = target_cls.specialize(example_value)
+    # TODO: not sure about this fake mode test
+    if (
+        isinstance(example_value, torch._subclasses.fake_tensor.FakeTensor)
+        and example_value.fake_mode is tx.fake_mode
+    ):
+        if subclass_type:
+            tensor_type = subclass_type
+        elif isinstance(example_value, torch.nn.Parameter):
+            tensor_type = torch.nn.Parameter
+        elif isinstance(example_value, torch.nn.Buffer):
+            tensor_type = torch.nn.Buffer
+        else:
+            tensor_type = torch.Tensor
+        specialized_props["class_type"] = tensor_type
+
+    return specialized_props
+
+
 def construct_tensor_variable(
     target_cls, tx, proxy, example_value, subclass_type, options
 ):
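Editor's note: a quick illustration of the new helper's contract (a hypothetical usage sketch; `infer_subclass_type` is an internal helper in Dynamo's builder module, and `MySubclass` below is made up for the example). Plain tensors, Parameters, and fake/functional tensors map to None, while other tensor subclasses report their own type.

import torch
from torch._dynamo.variables.builder import infer_subclass_type  # internal API

class MySubclass(torch.Tensor):  # example subclass, not a traceable wrapper subclass
    pass

assert infer_subclass_type(torch.ones(3)) is None
assert infer_subclass_type(torch.nn.Parameter(torch.ones(3))) is None
assert infer_subclass_type(torch.ones(3).as_subclass(MySubclass)) is MySubclass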
@@ -3055,23 +3080,7 @@ def construct_tensor_variable(
     # when lifting unbacked symbols of input tensors to subgraph inputs.
     # We do it lazily because the tensor may not be used in subgraphs.
     tx.output.current_tracer.track_unbacked_symbols(example_value, proxy)
-    specialized_props = target_cls.specialize(example_value)
-    # TODO: not sure about this fake mode test
-    if (
-        isinstance(example_value, torch._subclasses.fake_tensor.FakeTensor)
-        and example_value.fake_mode is tx.fake_mode
-    ):
-        if subclass_type:
-            tensor_type = subclass_type
-        elif isinstance(example_value, torch.nn.Parameter):
-            tensor_type = torch.nn.Parameter
-        elif isinstance(example_value, torch.nn.Buffer):
-            tensor_type = torch.nn.Buffer
-        else:
-            tensor_type = torch.Tensor
-        specialized_props["class_type"] = tensor_type
-
-    options.update(specialized_props)
+    options.update(get_specialized_props(target_cls, tx, example_value, subclass_type))
     return target_cls(proxy, **options)
@@ -657,13 +657,13 @@ class AutogradFunctionVariable(VariableTracker):
     def call_apply(self, tx: "InstructionTranslator", args, kwargs):
         requires_grad = False

-        def visit(node):
+        def visit(vt):
             nonlocal requires_grad
-            if isinstance(node, variables.TensorVariable):
-                if node.requires_grad is not False:
+            if isinstance(vt, variables.TensorVariable):
+                if vt.requires_grad is not False:
                     requires_grad = True
-            if isinstance(node, variables.NNModuleVariable):
-                if node.is_training(tx):
+            if isinstance(vt, variables.NNModuleVariable):
+                if vt.is_training(tx):
                     requires_grad = True

         VariableTracker.visit(visit, (args, kwargs))
@@ -1090,6 +1090,30 @@ class TensorVariable(VariableTracker):
             *proxy_args_kwargs([self, key, value], {}),
         )

+        if isinstance(value, TensorVariable):
+            # [Note: Tensor.__setitem__ and VariableTracker metadata]
+            # At this point, we proxied a node representing `self[key] = value` into the graph.
+            # When executed, this node will mutate `self`'s tensor metadata, so it's important
+            # even during tracing to propagate. For example:
+            #   value.requires_grad is True => self.requires_grad becomes True
+            #   value.requires_grad is True => self.has_grad_fn becomes True
+
+            # Not sure if __setitem__ can ever save activations, disabling just in case
+            with torch._dynamo.utils._disable_saved_tensors_hooks_during_tracing():
+                get_fake_value(proxy.node, tx, allow_non_graph_fake=False)
+
+            example_value = self.proxy.node.meta.get("example_value")
+            from .builder import get_specialized_props, infer_subclass_type
+
+            if isinstance(value, variables.lazy.LazyVariableTracker):
+                value = variables.lazy.LazyVariableTracker.realize_all(value)
+
+            specialized_props = get_specialized_props(
+                type(value), tx, example_value, infer_subclass_type(example_value)
+            )
+            for k, v in specialized_props.items():
+                setattr(self, k, v)
+
         if config.use_graph_deduplication or config.track_nodes_for_deduplication:
             tx.output.region_tracker.add_node_mutation(proxy.node, 0)
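Editor's note: for intuition about why running fake prop here is enough, a sketch under the assumption that factory calls and in-place ops behave the usual way under FakeTensorMode (this is not code from the commit). Autograd runs above the fake mode, so fake-propagating the in-place setitem mutates the fake target's requires_grad, and the updated metadata can then be read back off the stored example_value.

import torch
from torch._subclasses.fake_tensor import FakeTensorMode

with FakeTensorMode():
    fx = torch.randn(10, 10)                      # FakeTensor, requires_grad=False
    fy = torch.randn(10, 10, requires_grad=True)  # FakeTensor, requires_grad=True
    fx[0] = fy[0]                                 # fake prop of the in-place setitem, no real compute
    assert fx.requires_grad                       # metadata updated on the FakeTensor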