pytorch/torch/_subclasses
Colin Peppler 7b7cd56f5e [export] support linear & layer_norm unbacked (#155260)
## What
- Use `definitely_contiguous_for_memory_format` instead of `is_contiguous` wherever taking the non-contiguous path is acceptable if we encounter a DDE (data-dependent error).
- Use the refs version of `contiguous` instead of ATen's. ATen's version raises a DDE and stops tracing. The refs version uses `definitely_contiguous_for_memory_format` and clones the tensor when contiguity cannot be decided, instead of raising.
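The trade-off in the first bullet can be sketched with a toy model. It assumes a three-valued condition: true, false, or unknown because it depends on an unbacked symbol such as `u0`. `UNKNOWN`, `DataDependentError`, `guard_strict`, `definitely`, and `choose_path` are hypothetical names for illustration, not PyTorch APIs. The real helpers work on symbolic expressions.

```python
# Toy model (hypothetical names, not PyTorch APIs): a condition is True,
# False, or UNKNOWN when it depends on an unbacked symbol such as u0.
UNKNOWN = object()


class DataDependentError(RuntimeError):
    """Stand-in for a DDE such as GuardOnDataDependentSymNode."""


def guard_strict(cond):
    # Like a plain `is_contiguous` guard: undecidable -> raise, tracing stops.
    if cond is UNKNOWN:
        raise DataDependentError("could not guard on data-dependent expression")
    return cond


def definitely(cond):
    # Like the `definitely_*` helpers: undecidable -> assume False, keep tracing.
    if cond is UNKNOWN:
        return False
    return cond


def choose_path(contiguous_cond):
    # Only valid because the "general" path is also correct for contiguous
    # inputs, so picking it when unsure is safe, just possibly slower.
    return "fast" if definitely(contiguous_cond) else "general"
```

Here `choose_path(UNKNOWN)` returns `"general"` where `guard_strict(UNKNOWN)` would raise. That is why the pattern is only used where the non-contiguous branch is correct for every runtime value.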

## Example DDEs

- Fixed with `definitely_contiguous_for_memory_format` in `fast_binary_impl`
```
torch._dynamo.exc.UserError: Could not guard on data-dependent expression Eq((u0//387), 0) (unhinted: Eq((u0//387), 0)).  (Size-like symbols: u0)

Caused by: layer_norm = self.layer_norm(linear)  # caffe2/test/export/test_export.py:4566 in forward (_subclasses/fake_impls.py:1022 in fast_binary_impl)
```

- Fixed by calling `refs.contiguous` instead of ATen's `contiguous` (fixing it in ATen would require a bigger rewrite)
```
  File "c10/core/TensorImpl.h", line 825, in torch::autograd::THPVariable_contiguous(_object*, _object*, _object*)
  File "c10/core/SymbolicShapeMeta.h", line 87, in c10::TensorImpl::is_contiguous_default(c10::MemoryFormat) const
  File "c10/core/SymbolicShapeMeta.cpp", line 250, in c10::SymbolicShapeMeta::init_is_contiguous() const

torch.fx.experimental.symbolic_shapes.GuardOnDataDependentSymNode: Could not guard on data-dependent expression Eq(128*((u0//387)), 0) (unhinted: Eq(128*((u0//387)), 0)).  (Size-like symbols: u0)

Caused by: (_refs/__init__.py:3302 in native_layer_norm)
```
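The difference between the two `contiguous` behaviors can be sketched as follows. This is a toy, not the code in `torch/_refs`. It assumes a hypothetical `ToyTensor` whose contiguity is `True`, `False`, or `UNKNOWN`, and shows why cloning on an undecidable check keeps tracing alive.

```python
from dataclasses import dataclass, replace

# Hypothetical toy: UNKNOWN marks a contiguity check that depends on an
# unbacked symbol and so cannot be decided at trace time.
UNKNOWN = object()


class DataDependentError(RuntimeError):
    """Stand-in for GuardOnDataDependentSymNode."""


@dataclass(frozen=True)
class ToyTensor:
    name: str
    contig: object  # True, False, or UNKNOWN


def clone_contiguous(t):
    # A clone is contiguous regardless of the input's layout.
    return replace(t, name=t.name + "_clone", contig=True)


def aten_style_contiguous(t):
    # Must decide contiguity exactly; an undecidable check aborts tracing.
    if t.contig is UNKNOWN:
        raise DataDependentError("could not guard on data-dependent expression")
    return t if t.contig else clone_contiguous(t)


def refs_style_contiguous(t):
    # Returns `t` only when it is definitely contiguous; otherwise clones.
    # An unnecessary copy is still correct, it just costs extra memory traffic.
    if t.contig is not UNKNOWN and t.contig:
        return t
    return clone_contiguous(t)
```

The refs-style variant trades a possibly redundant copy for never raising. That matches the commit's reasoning for preferring it during export.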

- Fixed by using `definitely_contiguous_for_memory_format` in the refs `contiguous`
```
torch.fx.experimental.symbolic_shapes.GuardOnDataDependentSymNode: Could not guard on data-dependent expression 387*((u0//387)) < 2 (unhinted: 387*((u0//387)) < 2).  (Size-like symbols: u0)

Caused by: (_prims_common/__init__.py:279 in is_contiguous)
```
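The guarded expression `387*((u0//387)) < 2` looks like a `numel < 2` shortcut for a tensor of shape `(u0//387, 387)`. The sketch below is a simplified model, not the `_prims_common` source. A size of `None` stands in for a data-dependent extent such as `u0 // 387`. Strides are assumed to be known ints. `strict=True` mimics a guard that raises, while `strict=False` mimics the `definitely_*` behavior of treating an undecidable condition as False. `is_contiguous` and `_decide` here are hypothetical helpers.

```python
import math
from typing import List, Optional

# Hypothetical toy model, not the `_prims_common` source:
# a size of None stands for a data-dependent extent such as u0 // 387.


class DataDependentError(RuntimeError):
    """Stand-in for GuardOnDataDependentSymNode."""


def _decide(cond: Optional[bool], strict: bool) -> bool:
    # cond is None when it depends on an unknown size.
    if cond is None:
        if strict:
            raise DataDependentError("could not guard on data-dependent expression")
        return False  # "definitely" semantics: unprovable means False
    return cond


def is_contiguous(sizes: List[Optional[int]], strides: List[int], strict: bool) -> bool:
    # numel is 0 if any known size is 0, unknown if any size is unknown.
    if any(s == 0 for s in sizes):
        numel: Optional[int] = 0
    elif any(s is None for s in sizes):
        numel = None
    else:
        numel = math.prod(sizes)

    # Tensors with fewer than 2 elements are trivially contiguous.
    if _decide(None if numel is None else numel < 2, strict):
        return True

    expected: Optional[int] = 1
    for size, stride in reversed(list(zip(sizes, strides))):
        # Size-1 dims may have any stride, but "size == 1" is unknown for None.
        if _decide(None if size is None else size == 1, strict):
            continue
        if not _decide(None if expected is None else stride == expected, strict):
            return False
        expected = None if (expected is None or size is None) else expected * size
    return True
```

With `sizes=[None, 387]` and `strides=[387, 1]`, the strict mode raises at the `numel < 2` check. The non-strict mode still proves contiguity, because every stride comparison it reaches involves only known values.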

Pull Request resolved: https://github.com/pytorch/pytorch/pull/155260
Approved by: https://github.com/laithsakka
ghstack dependencies: #155499
2025-06-11 16:47:34 +00:00
__init__.py [BE][Easy][14/19] enforce style for empty lines in import segments in torch/_[a-c]*/ and torch/_[e-h]*/ and torch/_[j-z]*/ (#129765) 2024-07-31 10:42:50 +00:00
_fake_tensor_utils.py Fix fake tensor caching when output has unbacked (#153034) 2025-05-23 15:03:31 +00:00
fake_impls.py [export] support linear & layer_norm unbacked (#155260) 2025-06-11 16:47:34 +00:00
fake_tensor.py [test] use JK to force graph break on slow aliasing/mutation/dynamic_shape behavior (#155257) 2025-06-09 16:21:59 +00:00
fake_utils.py PEP585 update - torch/_higher_order_ops torch/_subclasses torch/backends torch/compiler torch/cuda torch/masked torch/mtia torch/nested (#145202) 2025-01-20 22:37:26 +00:00
functional_tensor.py Add torch.Tensor._make_wrapper_subclass to torch/_C/__init__.pyi (#154022) 2025-05-27 14:10:00 +00:00
meta_utils.py [EASY] use guard_or_false instead of gso in Meta converter (#154234) 2025-05-26 21:59:52 +00:00
schema_check_mode.py Remove unused Python variables in torch/[_-a]* (#133492) 2024-12-12 17:39:14 +00:00