diff --git a/test/distributed/tensor/test_dynamic.py b/test/distributed/tensor/test_dynamic.py
index 963428fecf8..a53f9e6d8dd 100644
--- a/test/distributed/tensor/test_dynamic.py
+++ b/test/distributed/tensor/test_dynamic.py
@@ -22,8 +22,7 @@ from torch.testing._internal.triton_utils import requires_gpu
 class TestDynamic(DTensorTestBase):
     @requires_gpu
     @with_comms
-    # FIXME: Currently broken for fake tensor cache
-    @parametrize("fake_tensor_cache_enabled", [False])
+    @parametrize("fake_tensor_cache_enabled", [False, True])
     def test_embedding(self, fake_tensor_cache_enabled):
         with patch.object(
             torch._dynamo.config, "fake_tensor_cache_enabled", fake_tensor_cache_enabled
diff --git a/torch/_subclasses/fake_tensor.py b/torch/_subclasses/fake_tensor.py
index a01991e19e6..31d129a3c86 100644
--- a/torch/_subclasses/fake_tensor.py
+++ b/torch/_subclasses/fake_tensor.py
@@ -1538,7 +1538,7 @@ class FakeTensorMode(TorchDispatchMode):
 
         try:
             # pyrefly: ignore  # bad-argument-type
-            self._validate_cache_key(func, args, kwargs)
+            entry = self._make_cache_entry(state, key, func, args, kwargs, output)
         except _BypassDispatchCache as e:
             # We ran "extra" checks on the cache key and determined that it's no
             # good. Record the reason and mark it so we don't bother validating
@@ -1556,16 +1556,6 @@ class FakeTensorMode(TorchDispatchMode):
             set_cache_key(cache, key, _DispatchCacheBypassEntry(e.reason))
             return output
 
-        try:
-            # pyrefly: ignore  # bad-argument-type
-            entry = self._make_cache_entry(state, key, func, args, kwargs, output)
-        except _BypassDispatchCache as e:
-            # We had trouble making the cache entry. Record the reason and mark
-            # it.
-            FakeTensorMode.cache_bypasses[e.reason] += 1
-            set_cache_key(cache, key, _DispatchCacheBypassEntry(e.reason))
-            return output
-
         set_cache_key(cache, key, entry)
         FakeTensorMode.cache_misses += 1
         return output
@@ -1581,6 +1571,7 @@ class FakeTensorMode(TorchDispatchMode):
         Create a cache key given the dispatch args. Raises _BypassDispatchCache
         for any situation that precludes caching.
         """
+        is_tracing = torch.fx.experimental.proxy_tensor.get_proxy_mode() is not None
         key_values = [
             func,
             # Capture the default_dtype mode since that can affect the output tensor,
@@ -1596,6 +1587,10 @@ class FakeTensorMode(TorchDispatchMode):
             # Disallowing dynamic shapes can introduce a DynamicOutputShapeException
             # where it wasn't seen on a previous instance of the same op.
             self.shape_env.settings if self.shape_env else None,
+            # ProxyTorchDispatchMode needs to track how SymNodes are constructed,
+            # so we need to handle things a little differently depending on
+            # whether we're tracing or not.
+            is_tracing,
         ]
         if state.known_symbols:
             # If there are symbols then include the epoch - this is really more
@@ -1776,11 +1771,9 @@ class FakeTensorMode(TorchDispatchMode):
         if isinstance(output, (int, type(None))):
             return
 
-        if _has_unrepresented_symbols(state, output):
-            # Unbacked symbols are fine - but only if they're also represented
-            # in the input. If there are any new unbacked symbols then we can't
-            # cache this output.
-            raise _BypassDispatchCache("unrepresented symbol in output")
+        # Check for symbolic content that should bypass caching - raises
+        # _BypassDispatchCache if necessary.
+        _validate_symbolic_output_for_caching(state, output)
 
         # Some ops return tuples of Tensors, but it's rare, so avoid
         # the complexity of caching other types.
@@ -1896,6 +1889,8 @@ class FakeTensorMode(TorchDispatchMode):
         from torch._higher_order_ops.utils import registered_hop_fake_fns
         from torch.fx.experimental.symbolic_shapes import has_free_unbacked_symbols
 
+        self._validate_cache_key(func, args, kwargs)
+
         # For hops, lets look at the output tensor to find any unbacked symints.
         # If there are none, then we rely on the existing checks to validate
         # caching.
@@ -3072,17 +3067,65 @@ class FakeTensorMode(TorchDispatchMode):
 _StoragePointer = object
 
 
-def _has_unrepresented_symbols(
-    state: _CacheKeyState, output: Optional[FakeTensor]
-) -> bool:
-    from torch.fx.experimental.symbolic_shapes import _iterate_exprs
+def _validate_symbolic_output_for_caching(
+    state: _CacheKeyState, output: FakeTensor
+) -> None:
+    """
+    Validate symbolic content in output and raise _BypassDispatchCache if
+    caching should be bypassed.
 
-    for s in _iterate_exprs(output):
-        for symbol in s.free_symbols:
-            if symbol not in state.known_symbols:
-                return True
+    Args:
+        state: Cache key state containing known symbols
+        output: Output to validate
 
-    return False
+    Raises:
+        _BypassDispatchCache: If output contains symbolic content that
+            prevents caching.
+
+    Details:
+
+    If our output contains any symbols that didn't appear in the input then we
+    need to bypass. Usually these will be unbacked symbols, which can't be
+    properly reconstructed, but there could be "weird" cases where backed
+    symbols spontaneously appear (from non-input state).
+
+    If we're proxy (symbolic) tracing and the output contains ANY symbols then
+    we need to bypass. The problem is that ProxyTorchDispatchMode relies on
+    SymNode object identity and on being able to see the construction of
+    SymNodes.
+
+    We could improve the proxy tracing case in a few ways:
+
+    1. If the output SymNodes are directly copied from inputs then this is
+       actually fine - they're already tracked. This would probably be the
+       biggest bang for the buck.
+
+    2. If the output (tensors) are all direct copies of the inputs then this is
+       also fine - since they're inputs they must be tracked. We already compute
+       this; we just don't plumb it around enough.
+
+    3. If the output SymNodes are already tracked by the proxy then this is
+       also actually fine - they're properly tracked. This probably wouldn't be
+       common since for most outputs we use torch.empty_strided() and recompute
+       strides.
+
+    4. We could use the proxy to track "how" the SymNodes were computed and,
+       when using the cache, "replay" them properly to teach the proxy how to
+       build them.
+ """ + from torch.fx.experimental.symbolic_shapes import _iterate_exprs, _iterate_nodes + + is_tracing = torch.fx.experimental.proxy_tensor.get_proxy_mode() is not None + if is_tracing: + # Check for SymNode types in PROXY mode - this should bypass caching + # regardless of whether symbols are known or not + for node in _iterate_nodes(output): + raise _BypassDispatchCache("Proxy mode with SymNode output") + else: + # Check for unrepresented symbols in tensor expressions + for s in _iterate_exprs(output): + for symbol in s.free_symbols: + if symbol not in state.known_symbols: + raise _BypassDispatchCache("unrepresented symbol in output") # NB: returns fake tensors diff --git a/torch/fx/experimental/symbolic_shapes.py b/torch/fx/experimental/symbolic_shapes.py index bbe84a2e414..771e7527201 100644 --- a/torch/fx/experimental/symbolic_shapes.py +++ b/torch/fx/experimental/symbolic_shapes.py @@ -883,11 +883,16 @@ def _iterate_exprs(val: IterateExprs) -> Iterator[sympy.Basic]: Raises: AssertionError: If the value is of an unsupported type. """ + # This is almost close enough to implement in terms of _iterate_nodes() + # except that it needs to handle `list[sympy.Basic]` which _iterate_nodes() + # can't handle. if isinstance(val, SymTypes): # This allow applies to the jagged layout NestedTensor case as # nested ints are not symbolic if is_symbolic(val): yield val.node.expr + elif isinstance(val, SymNode): + yield val.expr elif isinstance(val, sympy.Basic): yield val elif isinstance(val, (int, float, bool)): @@ -910,6 +915,28 @@ def _iterate_exprs(val: IterateExprs) -> Iterator[sympy.Basic]: raise AssertionError(f"cannot extract sympy expressions from {val} {type(val)}") +def _iterate_nodes(val: Any) -> Iterator[SymNode]: + """ + Recursively iterate through a value and yield all SymNodes contained + within it. + """ + if isinstance(val, SymNode): + yield val + elif isinstance(val, py_sym_types): + # This allow applies to the jagged layout NestedTensor case as + # nested ints are not symbolic + if is_symbolic(val): + yield val.node + elif isinstance(val, (tuple, list, torch.Size)): + for s in val: + yield from _iterate_nodes(s) + elif isinstance(val, torch.Tensor): + yield from _iterate_nodes(val.size()) + if not is_sparse_any(val): + yield from _iterate_nodes(val.stride()) + yield from _iterate_nodes(val.storage_offset()) + + def free_symbols(val: IterateExprs) -> OrderedSet[sympy.Symbol]: """ Recursively collect all free symbols from a value.