pytorch/torch/_dynamo
Boyuan Feng 3ef031909f [Donated Buffer] support metadata mutation ops (#141308)
### Background

`set_(x, y)` changes the untyped storage of `x` to be the same as that of `y`.

```python
import torch
from torch._subclasses.fake_tensor import FakeTensorMode

x1 = torch.ones(2,3)
y1 = torch.ones(2,3)
z1 = torch.ops.aten.set_.source_Tensor(x1, y1)

fake_tensor_mode = FakeTensorMode()
x2 = fake_tensor_mode.from_tensor(torch.ones(2,3))
y2 = fake_tensor_mode.from_tensor(torch.ones(2,3))
z2 = torch.ops.aten.set_.source_Tensor(x2, y2)

print(f"x1: {x1.untyped_storage()._cdata}, y1: {y1.untyped_storage()._cdata}, z1: {z1.untyped_storage()._cdata}")
print(f"x2: {x2.untyped_storage()._cdata}, y2: {y2.untyped_storage()._cdata}, z2: {z2.untyped_storage()._cdata}")
# x1: 99973024, y1: 99973024, z1: 99973024
# x2: 112107232, y2: 112107232, z2: 112107232
```
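`set_` also replaces the size and stride of `x` with those of `y`, which is why it counts as a metadata mutation op rather than a plain in-place write. The small eager sketch below (shapes chosen only for illustration) uses the public `Tensor.set_` method and checks both effects:

```python
import torch

x = torch.ones(2, 3)
y = torch.zeros(5)

def storage_id(t: torch.Tensor) -> int:
    # Identity of the underlying storage, as in the snippet above.
    return t.untyped_storage()._cdata

before = storage_id(x)
x.set_(y)  # x now uses y's storage, size, stride and offset

print(x.shape)                         # torch.Size([5])
print(storage_id(x) == storage_id(y))  # True
print(storage_id(x) == before)         # False

# Writes through y are visible through x because they share storage.
y.fill_(7.0)
print(x.tolist())                      # [7.0, 7.0, 7.0, 7.0, 7.0]
```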

### Error before this diff

Consider this example:
```python
import torch

def fn(x):
    p = torch.nn.Parameter(x + 123)
    return p, p.sin()

opt = torch.compile(fn, fullgraph=True)
x = torch.ones(16, device="cuda", requires_grad=True)

p, r = opt(x)
r.sum().backward()
```

Running with `TORCH_LOGS=aot` shows a `set_` node in the forward graph:
```
def forward(self, primals_1: "f32[16][1]cuda:0", primals_2: "f32[16][1]cuda:0"):
   # File: /home/boyuan/playground/inductor/donated_buffer.py:4 in fn, code: p = torch.nn.Parameter(x + 123)
  add: "f32[16][1]cuda:0" = torch.ops.aten.add.Tensor(primals_1, 123);  primals_1 = None

   # File: /home/boyuan/playground/inductor/donated_buffer.py:5 in fn, code: return p, p.sin()
  sin: "f32[16][1]cuda:0" = torch.ops.aten.sin.default(add)

  # No stacktrace found for following nodes
  set_: "f32[16][1]cuda:0" = torch.ops.aten.set_.source_Tensor(primals_2, add);  primals_2 = set_ = None
  return (sin, add)
```

`set_: "f32[16][1]cuda:0" = torch.ops.aten.set_.source_Tensor(primals_2, add)` should make `primals_2` share the storage of `add`. Before this diff, however, it did not: the `meta['val']` of `set_`, `add`, and `primals_2` each reported a different `untyped_storage()`.

This also breaks donated buffers (#130580), which detect aliasing by comparing `untyped_storage()`. Because `add` and `primals_2` wrongly report different storages, `add` is incorrectly marked as a donated buffer.
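To see why stale metadata matters, here is a deliberately simplified, hypothetical alias check in the spirit of the donated-buffer logic: a saved tensor is treated as donatable only if it shares no untyped storage with any graph input. The helpers `shares_storage` and `donatable_indices` are made up for illustration and are not the actual AOTAutograd implementation.

```python
import torch

def shares_storage(a: torch.Tensor, b: torch.Tensor) -> bool:
    # Hypothetical alias test: same underlying untyped storage.
    return a.untyped_storage()._cdata == b.untyped_storage()._cdata

def donatable_indices(saved, inputs):
    # Hypothetical rule: a saved tensor may be donated only if it does not
    # alias any graph input.
    return [i for i, t in enumerate(saved)
            if not any(shares_storage(t, x) for x in inputs)]

primals_2 = torch.ones(16)
add = torch.ones(16) + 123

# Metadata recorded before the mutation: a detached alias of primals_2.
stale_meta = primals_2.detach()

# The traced op: primals_2 now shares add's storage.
torch.ops.aten.set_.source_Tensor(primals_2, add)

print(donatable_indices([add], [stale_meta]))  # [0]: wrongly treated as donatable
print(donatable_indices([add], [primals_2]))   # []: aliasing correctly detected
```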

### Root Cause

During tracing, each op call has its tensor arguments and result (`args`, `kwargs`, `out`) alongside the corresponding proxies (`proxy_args`, `proxy_kwargs`, `proxy_out`).

`args` and `kwargs` are used to compute `out = func(*args, **kwargs)` ([here](https://github.com/pytorch/pytorch/blob/main/torch/fx/experimental/proxy_tensor.py#L912)). Afterwards, `out` is attached to its proxy, which essentially amounts to `proxy_out.node.meta["val"] = out.detach()`.
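This recording step can be mimicked by hand with a real `torch.fx.Graph`, filling `node.meta["val"]` directly instead of running the proxy-tensor tracer. It is only a sketch: the real tracer stores fake tensors and does more bookkeeping. Each input value is stored as a detached alias, and the op runs on the original arguments:

```python
import torch
import torch.fx as fx

graph = fx.Graph()
x_node = graph.placeholder("x")
y_node = graph.placeholder("y")

x = torch.ones(2, 3)
y = torch.zeros(6)

# Record input values the way a tracer would: as detached aliases.
x_node.meta["val"] = x.detach()
y_node.meta["val"] = y.detach()

# Run the op on the real arguments, then record its output.
out = torch.ops.aten.set_.source_Tensor(x, y)
out_node = graph.call_function(torch.ops.aten.set_.source_Tensor, (x_node, y_node))
out_node.meta["val"] = out.detach()

def sid(t: torch.Tensor) -> int:
    return t.untyped_storage()._cdata

print(sid(out_node.meta["val"]) == sid(y))  # True
print(sid(x_node.meta["val"]) == sid(x))    # False: stale input metadata
print(x_node.meta["val"].shape)             # torch.Size([2, 3])
```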

Because of the `detach()`, when `func` is `torch.ops.aten.set_`, the storage change is applied to `args` but not to `proxy_args.node.meta["val"]`, which is only a detached alias. The same `detach()` behavior reproduces in eager code:

```python
import torch

x = torch.ones(2,3)
x_detach = x.detach()
y = torch.ones(2,3)
z = torch.ops.aten.set_.source_Tensor(x_detach, y)

print(f"x: {x.untyped_storage()._cdata}, x_detach: {x_detach.untyped_storage()._cdata}, y: {y.untyped_storage()._cdata}, z: {z.untyped_storage()._cdata}")
# x: 97023632, x_detach: 97026480, y: 97026480, z: 97026480
```

To fix this, the PR manually resets `node.meta["val"]` whenever the storage has changed.
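A minimal sketch of that idea follows, built around a hypothetical helper `refresh_meta_val`; the real patch may differ in where and how it performs the check. After the op runs, the storage of the live tensor is compared with the storage recorded in its node's `meta["val"]`, and the value is re-recorded when they differ:

```python
import torch
import torch.fx as fx

def sid(t: torch.Tensor) -> int:
    return t.untyped_storage()._cdata

def refresh_meta_val(node: fx.Node, t: torch.Tensor) -> bool:
    # Hypothetical helper: re-record node.meta["val"] if the live tensor's
    # storage no longer matches the recorded one. Returns True if it reset.
    old = node.meta.get("val")
    if isinstance(old, torch.Tensor) and sid(old) != sid(t):
        node.meta["val"] = t.detach()
        return True
    return False

graph = fx.Graph()
p_node = graph.placeholder("primals_2")
primals_2 = torch.ones(16)
add = torch.ones(16) + 123
p_node.meta["val"] = primals_2.detach()

torch.ops.aten.set_.source_Tensor(primals_2, add)
first = refresh_meta_val(p_node, primals_2)
second = refresh_meta_val(p_node, primals_2)

print(first)                               # True: storage had changed
print(sid(p_node.meta["val"]) == sid(add)) # True
print(second)                              # False: already in sync
```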

Pull Request resolved: https://github.com/pytorch/pytorch/pull/141308
Approved by: https://github.com/bdhirsh
2024-11-26 17:06:46 +00:00
backends Restart dynamo analysis when we fail to tensorify away all symfloat inputs (#140346) 2024-11-20 21:20:41 +00:00
polyfills [dynamo] Fix and simplify hanlding of Set.update method (#141286) 2024-11-26 00:41:50 +00:00
repro [AOTI Minifier] Save EP instead of graphs (#141159) 2024-11-22 01:51:10 +00:00
variables [Dynamo][autograd.Function] Use fake tensor prop to infer fwd output (#136184) 2024-11-26 01:10:08 +00:00
__init__.py Restart dynamo analysis when we fail to tensorify away all symfloat inputs (#140346) 2024-11-20 21:20:41 +00:00
_trace_wrapped_higher_order_op.py [FlexAttention] Rename zeros_and_scatter library (#141185) 2024-11-21 21:35:48 +00:00
bytecode_analysis.py [BE][Easy][15/19] enforce style for empty lines in import segments in torch/_d*/ (#129767) 2024-07-31 21:18:11 +00:00
bytecode_transformation.py [BE]: Apply PERF401 autofixes from ruff (#140980) 2024-11-20 17:52:07 +00:00
cache_size.py [dynamo] Add a DynamoFrameType type above Python frame object (#140330) 2024-11-15 17:17:30 +00:00
callback.py [PT2] Fix callbacks to account for entire execution in compilation (#141323) 2024-11-24 22:31:04 +00:00
code_context.py add types to _dynamo/code_context.py (#136665) 2024-09-27 18:27:42 +00:00
codegen.py [dynamo] Represent all cells as NewCellVariable (#140153) 2024-11-15 17:17:30 +00:00
compiled_autograd.py [ca] dead code elimination for compile time (#141289) 2024-11-22 19:26:27 +00:00
comptime.py [dynamo] Remove closure_cells and merge/remove code paths (#140154) 2024-11-15 17:17:30 +00:00
config.py Revert "Always unspecialize float in OSS (#138922)" 2024-11-26 00:03:03 +00:00
convert_frame.py [logging] Move population of common MetricsContext fields to record_compilation_metrics (#141291) 2024-11-25 13:18:40 +00:00
create_parameter_op.py [Donated Buffer] support metadata mutation ops (#141308) 2024-11-26 17:06:46 +00:00
current_scope_id.py add typing to _dynamo/current_scope_id.py (#136676) 2024-09-27 04:09:15 +00:00
debug_utils.py Ensure TORCH_TRACE is run for Dynamo/Distributed tests (#139786) 2024-11-07 01:58:05 +00:00
decorators.py [dynamo] skip_guard_eval_unsafe stance for power users (#140251) 2024-11-21 06:28:58 +00:00
device_interface.py Have Triton custom extension test use privateuseone device (#137611) 2024-10-11 21:27:29 +00:00
distributed.py [BE][Easy][15/19] enforce style for empty lines in import segments in torch/_d*/ (#129767) 2024-07-31 21:18:11 +00:00
eval_frame.py [ca] expose option to collect sizes as dynamic (#141153) 2024-11-22 19:26:27 +00:00
exc.py Restart dynamo analysis when we fail to tensorify away all symfloat inputs (#140346) 2024-11-20 21:20:41 +00:00
external_utils.py Revert "Deprecate torch._utils.is_compiling() and torch._dynamo.external_utils.is_compiling() (#127690)" 2024-11-05 23:10:38 +00:00
funcname_cache.py [BE][Easy][15/19] enforce style for empty lines in import segments in torch/_d*/ (#129767) 2024-07-31 21:18:11 +00:00
guards.py [dynamo] skip_guard_eval_unsafe stance for power users (#140251) 2024-11-21 06:28:58 +00:00
hooks.py [BE][Easy][15/19] enforce style for empty lines in import segments in torch/_d*/ (#129767) 2024-07-31 21:18:11 +00:00
logging.py Make automatic_dynamic state live per CodeId, rather than on code object (#138740) 2024-10-27 03:08:41 +00:00
metrics_context.py [logging] Move population of common MetricsContext fields to record_compilation_metrics (#141291) 2024-11-25 13:18:40 +00:00
mutation_guard.py Allow Lazy Module to be modelled as UnspecializedNNModuleVariable (#138639) 2024-10-26 02:17:07 +00:00
output_graph.py Restart dynamo analysis when we fail to tensorify away all symfloat inputs (#140346) 2024-11-20 21:20:41 +00:00
pgo.py Make PGO work correctly with NJT inputs (#140046) 2024-11-08 04:27:39 +00:00
profiler.py type _dynamo/profiler.py (#137351) 2024-10-07 18:54:33 +00:00
replay_record.py [dynamo] Identify pre-existing captured cells by cell id rather than content id (#140436) 2024-11-15 17:17:30 +00:00
resume_execution.py [BE]: Apply PERF401 autofixes from ruff (#140980) 2024-11-20 17:52:07 +00:00
side_effects.py [dynamo] Represent all cells as NewCellVariable (#140153) 2024-11-15 17:17:30 +00:00
source.py [dynamo] Represent all cells as NewCellVariable (#140153) 2024-11-15 17:17:30 +00:00
symbolic_convert.py misc. fixes to unflatten (#141066) 2024-11-23 07:31:51 +00:00
tensor_version_op.py [BE][Easy][15/19] enforce style for empty lines in import segments in torch/_d*/ (#129767) 2024-07-31 21:18:11 +00:00
test_case.py Ensure TORCH_TRACE is run for Dynamo/Distributed tests (#139786) 2024-11-07 01:58:05 +00:00
test_minifier_common.py AOTI Minifier (#139351) 2024-11-07 21:43:44 +00:00
testing.py [BE]: Apply PERF401 autofixes from ruff (#140980) 2024-11-20 17:52:07 +00:00
trace_rules.py [dynamo] Trace through dataclasses by removing it from BUILTIN_SKIPLIST (#141294) 2024-11-26 17:05:23 +00:00
types.py [dynamo] Add a DynamoFrameType type above Python frame object (#140330) 2024-11-15 17:17:30 +00:00
utils.py [logging] Move population of common MetricsContext fields to record_compilation_metrics (#141291) 2024-11-25 13:18:40 +00:00