pytorch/torch
Aaron Orenstein 5f21cc786a Teach ProxyTorchDispatchMode how to decompose sympy.Expr into known inputs (#164717)
In a training library we hit a weird conflict between dtensor, dynamic shapes, and proxy tensor.

The problem occurs because in sharding_prop we use FakeTensors to compute an operation's size (so we don't have to use the full "real" data). We turn off proxy tracing while doing that because we don't want the FakeTensor ops to end up in the graph. We then use that size in later operations.

Normally this is no problem, but when those sizes are dynamic shapes we run into trouble: the proxy tracer wants to track the provenance of all shape operations (`s1*s2`), but since tracing is disabled it doesn't see the operation. When we then use the resulting shape later on, the proxy tracer gets confused because the SymNode appeared out of nowhere.

At first we considered never disabling shape tracing, but that caused a slew of other downstream problems (lots of code actually needs shape tracing to be disabled). Instead we add a "sym tracing override" and surgically leave shape tracing enabled at the point where we disable proxy tracing.
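
To make the failure mode concrete, here is a toy sketch in plain Python. It is not the PyTorch implementation: `ToyTracer`, `disable_tracing`, `keep_sym_tracing`, and `lookup_after_untraced_mul` are hypothetical names invented for illustration, and symbols are modeled as plain strings rather than SymNodes. It only shows why a shape expression computed while all tracing is off has no provenance later, and how leaving shape tracing on avoids that.

```python
from contextlib import contextmanager


class ToyTracer:
    """Toy stand-in for a proxy tracer (hypothetical, not PyTorch's API)."""

    def __init__(self):
        self.trace_tensors = True
        self.trace_syms = True
        self.graph = []   # recorded shape operations
        self.known = {}   # symbolic expression -> where it came from

    def add_input(self, name):
        # Input symbols such as "s1" are always known to the tracer.
        self.known[name] = f"input:{name}"

    def mul(self, a, b):
        # Compute a symbolic product; record it only if shape tracing is on.
        result = f"{a}*{b}"
        if self.trace_syms:
            self.graph.append(("mul", a, b))
            self.known[result] = f"node:{len(self.graph) - 1}"
        return result

    def use(self, sym):
        # Later operations need the provenance of every symbolic size.
        if sym not in self.known:
            raise LookupError(f"{sym} appeared out of nowhere")
        return self.known[sym]

    @contextmanager
    def disable_tracing(self, keep_sym_tracing=False):
        # Turn off tensor tracing; optionally leave shape tracing enabled.
        old = (self.trace_tensors, self.trace_syms)
        self.trace_tensors = False
        self.trace_syms = keep_sym_tracing
        try:
            yield
        finally:
            self.trace_tensors, self.trace_syms = old


def lookup_after_untraced_mul(keep_sym_tracing):
    """Return True if the size computed under disable_tracing is usable later."""
    tracer = ToyTracer()
    tracer.add_input("s1")
    tracer.add_input("s2")
    with tracer.disable_tracing(keep_sym_tracing=keep_sym_tracing):
        size = tracer.mul("s1", "s2")
    try:
        tracer.use(size)
        return True
    except LookupError:
        return False
```

In this toy model, calling `lookup_after_untraced_mul(False)` hits the "appeared out of nowhere" case, while `lookup_after_untraced_mul(True)` succeeds because the `s1*s2` product was still recorded.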

After this change the dtensor embedding is "fixed" but then runs afoul of a FakeTensor cache bug - which is fixed in the next PR.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/164717
Approved by: https://github.com/bobrenjc93, https://github.com/ezyang
ghstack dependencies: #165266
2025-10-16 20:57:06 +00:00
..
_awaits
_C [ROCm][tunableop] Improvements to tunableop Numerical Check (#163079) 2025-10-15 22:26:47 +00:00
_C_flatbuffer
_custom_op [2/N] Fix ruff warnings (#164460) 2025-10-04 03:40:32 +00:00
_decomp Ensure rms_norm decomp generates add.Scalar for pattern match BC (#165437) 2025-10-14 19:56:37 +00:00
_dispatch Add initial suppressions for pyrefly (#164177) 2025-10-02 20:57:41 +00:00
_dynamo [RFC] Add pyrefly to lintrunner (#165179) 2025-10-16 20:07:09 +00:00
_export [torch.export] Rmoving unused constants - add support for corner case (#165205) 2025-10-14 20:26:28 +00:00
_functorch [RFC] Add pyrefly to lintrunner (#165179) 2025-10-16 20:07:09 +00:00
_higher_order_ops [hop] run local_map with interpreter to preserve fx_traceback annotations (#165336) 2025-10-16 02:53:17 +00:00
_inductor Use union syntax in torch/_inductor runtime and fx_passes (#165652) 2025-10-16 20:51:59 +00:00
_lazy Add initial suppressions for pyrefly (#164177) 2025-10-02 20:57:41 +00:00
_library [RFC] Add pyrefly to lintrunner (#165179) 2025-10-16 20:07:09 +00:00
_logging Enable ruff rule E721 (#165162) 2025-10-13 01:48:55 +00:00
_numpy Enable ruff rule E721 (#165162) 2025-10-13 01:48:55 +00:00
_prims Add pyrefly suppressions 2/n (#164513) 2025-10-03 02:46:13 +00:00
_prims_common [2/N] Use "is" in python type comparison (#165142) 2025-10-10 15:36:44 +00:00
_refs Enable ruff rule E721 (#165162) 2025-10-13 01:48:55 +00:00
_strobelight Add initial suppressions for pyrefly (#164177) 2025-10-02 20:57:41 +00:00
_subclasses Do not decompose in functionalization/proxy tensor if autograd wouldn't have decomposed (#164939) 2025-10-11 01:03:55 +00:00
_vendor
accelerator
amp Revert "[AMP][Refactor] Simplify dtype support logic in autocast context manager (#163446)" 2025-10-10 15:12:46 +00:00
ao [2/N] More ruff SIM fixes (#165031) 2025-10-14 14:22:54 +00:00
autograd [2/N] More ruff SIM fixes (#165031) 2025-10-14 14:22:54 +00:00
backends Revert "Add SVE128 ISA (#158932)" 2025-10-10 01:17:02 +00:00
compiler Megacache integration (#163533) 2025-10-15 22:49:15 +00:00
contrib
cpu Add initial suppressions for pyrefly (#164177) 2025-10-02 20:57:41 +00:00
csrc refactor: replace runtime_error with TORCH_CHECK for better error handling (#163628) 2025-10-16 11:09:48 +00:00
cuda [ROCm][tunableop] Improvements to tunableop Numerical Check (#163079) 2025-10-15 22:26:47 +00:00
distributed Enable local tensor mode on DTensor view ops test (#165596) 2025-10-16 20:52:06 +00:00
distributions [1/N] Use "is" in python type comparison (#165037) 2025-10-10 12:36:50 +00:00
export [RFC] Add pyrefly to lintrunner (#165179) 2025-10-16 20:07:09 +00:00
fft
func
futures [5/N] Apply ruff UP035 rule (#164423) 2025-10-02 07:31:11 +00:00
fx Teach ProxyTorchDispatchMode how to decompose sympy.Expr into known inputs (#164717) 2025-10-16 20:57:06 +00:00
headeronly Move toString(ScalarType) and ScalarType ostream operator to headeronly (#164405) 2025-10-16 00:55:43 +00:00
jit Fix missing brackets (#165138) 2025-10-10 17:23:31 +00:00
legacy
lib [2/N] Mark unused parameters in C++ code (#165121) 2025-10-15 03:04:39 +00:00
linalg Add initial suppressions for pyrefly (#164177) 2025-10-02 20:57:41 +00:00
masked [1/N] Use "is" in python type comparison (#165037) 2025-10-10 12:36:50 +00:00
monitor
mps Add type annotations to MPS profiler utilities (#163486) 2025-09-27 23:00:53 +00:00
mtia Add initial suppressions for pyrefly (#164177) 2025-10-02 20:57:41 +00:00
multiprocessing Add initial suppressions for pyrefly (#164177) 2025-10-02 20:57:41 +00:00
nativert [2/N] Mark unused parameters in C++ code (#165121) 2025-10-15 03:04:39 +00:00
nested [NJT] Fix schema validation error in jagged functions (#165307) 2025-10-13 17:59:18 +00:00
nn [RFC] Add pyrefly to lintrunner (#165179) 2025-10-16 20:07:09 +00:00
numa Add initial suppressions for pyrefly (#164177) 2025-10-02 20:57:41 +00:00
onnx [RFC] Add pyrefly to lintrunner (#165179) 2025-10-16 20:07:09 +00:00
optim [RFC] Add pyrefly to lintrunner (#165179) 2025-10-16 20:07:09 +00:00
package [1/N] Use "is" in python type comparison (#165037) 2025-10-10 12:36:50 +00:00
profiler Pyrefly suppressions 6/n (#164877) 2025-10-08 02:30:57 +00:00
quantization [RFC] Add pyrefly to lintrunner (#165179) 2025-10-16 20:07:09 +00:00
signal Add initial suppressions for pyrefly (#164177) 2025-10-02 20:57:41 +00:00
sparse [RFC] Add pyrefly to lintrunner (#165179) 2025-10-16 20:07:09 +00:00
special
testing Enable local tensor mode on DTensor view ops test (#165596) 2025-10-16 20:52:06 +00:00
utils [Fix] Use sys.executable instead of hardcoded python (#165633) 2025-10-16 20:26:10 +00:00
xpu Add a new API torch.xpu.is_tf32_supported for Intel GPU (#163141) 2025-10-12 12:11:57 +00:00
__config__.py
__future__.py
__init__.py [2/N] Use "is" in python type comparison (#165142) 2025-10-10 15:36:44 +00:00
_appdirs.py
_classes.py
_compile.py [4/N] Apply ruff UP035 rule to python code (#164206) 2025-10-01 19:05:53 +00:00
_custom_ops.py
_environment.py
_guards.py [4/N] Apply ruff UP035 rule to python code (#164206) 2025-10-01 19:05:53 +00:00
_jit_internal.py [2/N] More ruff SIM fixes (#165031) 2025-10-14 14:22:54 +00:00
_linalg_utils.py
_lobpcg.py Add initial suppressions for pyrefly (#164177) 2025-10-02 20:57:41 +00:00
_lowrank.py
_meta_registrations.py [2/N] More ruff SIM fixes (#165031) 2025-10-14 14:22:54 +00:00
_namedtensor_internals.py
_ops.py [2/N] More ruff SIM fixes (#165031) 2025-10-14 14:22:54 +00:00
_python_dispatcher.py
_size_docs.py
_sources.py
_storage_docs.py
_streambase.py
_tensor_docs.py [reland] Allow setting grad_dtype on leaf tensors (#164751) 2025-10-08 20:23:13 +00:00
_tensor_str.py Pyrefly suppressions 6/n (#164877) 2025-10-08 02:30:57 +00:00
_tensor.py Pyrefly suppressions 6/n (#164877) 2025-10-08 02:30:57 +00:00
_thread_safe_fork.py
_torch_docs.py Update documentation for torch.index_select (#163616) 2025-09-25 18:29:17 +00:00
_utils_internal.py Revert "Call internal log_compilation_event if it exists (#164855)" 2025-10-09 22:38:45 +00:00
_utils.py Enable ruff rule E721 (#165162) 2025-10-13 01:48:55 +00:00
_VF.py
_vmap_internals.py [4/N] Apply ruff UP035 rule to python code (#164206) 2025-10-01 19:05:53 +00:00
_weights_only_unpickler.py [4/N] Apply ruff UP035 rule to python code (#164206) 2025-10-01 19:05:53 +00:00
CMakeLists.txt Revert "[RELAND] Always build USE_DISTRIBUTED (#160449) and Make distributed modules importable even when backend not built (#159889) (#162594)" 2025-09-25 13:47:46 +00:00
custom_class_detail.h Mark unused parameters in C++ code (#164912) 2025-10-09 06:23:25 +00:00
custom_class.h Mark unused parameters in C++ code (#164912) 2025-10-09 06:23:25 +00:00
extension.h
functional.py [2/N] Fix ruff warnings (#164460) 2025-10-04 03:40:32 +00:00
header_only_apis.txt Move toString(ScalarType) and ScalarType ostream operator to headeronly (#164405) 2025-10-16 00:55:43 +00:00
hub.py Add initial suppressions for pyrefly (#164177) 2025-10-02 20:57:41 +00:00
library.h Mark unused parameters in C++ code (#164912) 2025-10-09 06:23:25 +00:00
library.py Add initial suppressions for pyrefly (#164177) 2025-10-02 20:57:41 +00:00
overrides.py Add scaled_grouped_mm_v2 and python API (#165154) 2025-10-15 17:47:23 +00:00
py.typed
quasirandom.py
random.py Revert "Add device argument to torch.random.get_rng_state (#163034)" 2025-10-04 15:25:45 +00:00
return_types.py
script.h
serialization.py Add initial suppressions for pyrefly (#164177) 2025-10-02 20:57:41 +00:00
storage.py [2/N] Use "is" in python type comparison (#165142) 2025-10-10 15:36:44 +00:00
torch_version.py
types.py [4/N] Apply ruff UP035 rule to python code (#164206) 2025-10-01 19:05:53 +00:00
version.py.tpl