From f19e07b0565a7e12f11ae63c159a9b054772cb6d Mon Sep 17 00:00:00 2001
From: "Edward Z. Yang"
Date: Thu, 9 May 2024 13:57:47 -0700
Subject: [PATCH] Memoize local_scalar_dense calls, refactor all memos
 (#125623)

Signed-off-by: Edward Z. Yang

Pull Request resolved: https://github.com/pytorch/pytorch/pull/125623
Approved by: https://github.com/eellison
---
 test/export/test_export.py           |   5 +-
 .../collect_metadata_analysis.py     |   4 +-
 torch/_subclasses/fake_impls.py      |  31 ++---
 torch/_subclasses/fake_tensor.py     | 119 +++++++++---------
 4 files changed, 82 insertions(+), 77 deletions(-)

diff --git a/test/export/test_export.py b/test/export/test_export.py
index dc3ff219163..586fc403da9 100644
--- a/test/export/test_export.py
+++ b/test/export/test_export.py
@@ -180,7 +180,7 @@ class TestDynamismExpression(TestCase):
         res = gm(*inp)
         self.assertTrue(torchdynamo.utils.same(ref, res))
 
-    def test_export_constraints_error(self):
+    def test_export_constraints_error_not_in_range(self):
         class InvalidInputConflictWithInputConstraints(torch.nn.Module):
             def forward(self, x):
                 return x + 1
@@ -194,6 +194,7 @@ class TestDynamismExpression(TestCase):
             dynamic_shapes={"x": {0: dim_x}},
         )
 
+    def test_export_constraints_error(self):
         class ConflictingConstraints(torch.nn.Module):
             def forward(self, x):
                 b = x.item()
@@ -2432,7 +2433,7 @@ def forward(self, x):
 
         # This is because we insert sym_constrain_range in the graph now
         if is_non_strict_test(self._testMethodName):
-            error_msg = "Invalid value range"
+            error_msg = r"Invalid value range for -1 between"
         else:
             error_msg = "is outside of inline constraint"
         with self.assertRaisesRegex(RuntimeError, error_msg):
diff --git a/torch/_functorch/_aot_autograd/collect_metadata_analysis.py b/torch/_functorch/_aot_autograd/collect_metadata_analysis.py
index 3a1b9eac530..e01f6df6957 100644
--- a/torch/_functorch/_aot_autograd/collect_metadata_analysis.py
+++ b/torch/_functorch/_aot_autograd/collect_metadata_analysis.py
@@ -161,9 +161,11 @@ def run_functionalized_fw_and_collect_metadata(
         flat_f_args = pytree.tree_map(_to_fun, flat_args)
         flat_f_outs = f(*flat_f_args)
         # We didn't do any tracing, so we don't need to process the
-        # unbacked symbols, they will just disappear into the ether
+        # unbacked symbols; they will just disappear into the ether.
+        # Also, prevent memoization from applying.
         if (fake_mode := detect_fake_mode()) and (shape_env := fake_mode.shape_env):
             shape_env.pending_fresh_unbacked_symbols.clear()
+            fake_mode.epoch += 1
 
         if prior_autocast_states != _get_autocast_states():
             raise RuntimeError(
diff --git a/torch/_subclasses/fake_impls.py b/torch/_subclasses/fake_impls.py
index a0cc772c189..112c81979c0 100644
--- a/torch/_subclasses/fake_impls.py
+++ b/torch/_subclasses/fake_impls.py
@@ -269,7 +269,7 @@ def unique2(
         # Without symints/symfloats, cannot handle this
         raise DynamicOutputShapeException(func)
 
-    if arg.unique_memo is None:
+    if (nnz := arg.unique_memo) is None:
         # Avoid importing sympy at a module level
         from torch.fx.experimental.symbolic_shapes import (
             _constrain_range_for_size,
@@ -285,8 +285,7 @@ def unique2(
            # symint cannot equal zero). We could also unconditionally
            # allocate an unbacked SymInt and not refine its range,
            # but this seems more precise.
-            nnz = arg._unique_memo = 0
-            arg._unique_memo_vc = arg._version
+            nnz = 0
         else:
             nnz = fake_mode.shape_env.create_unbacked_symint()
 
@@ -299,7 +298,7 @@ def unique2(
 
         arg.unique_memo = nnz
 
-    ret = [arg.new_empty((arg.unique_memo,))]
+    ret = [arg.new_empty((nnz,))]
 
     if return_inverse:
         ret.append(torch.empty_like(arg))
@@ -341,14 +340,18 @@ def local_scalar_dense(fake_mode, func, arg):
     ):
         # Without symints/symfloats, cannot handle this
         raise DataDependentOutputException(func)
+    if (r := arg.item_memo) is not None:
+        return r
     if is_float_dtype(arg.dtype):
-        return fake_mode.shape_env.create_unbacked_symfloat()
+        r = fake_mode.shape_env.create_unbacked_symfloat()
     elif is_integer_dtype(arg.dtype):
-        return fake_mode.shape_env.create_unbacked_symint()
+        r = fake_mode.shape_env.create_unbacked_symint()
     elif is_boolean_dtype(arg.dtype):
-        return fake_mode.shape_env.create_unbacked_symbool()
+        r = fake_mode.shape_env.create_unbacked_symbool()
     else:
         raise NotImplementedError(f"local_scalar_dense/item NYI for {arg.dtype}")
+    arg.item_memo = r
+    return r
 
 
 @register_op_impl(torch.ops.aten.nonzero.default)
@@ -360,9 +363,7 @@ def nonzero(fake_mode, func, arg):
         # Without symints/symfloats, cannot handle this
         raise DynamicOutputShapeException(func)
 
-    if arg.nonzero_memo is not None:
-        nnz = arg.nonzero_memo
-    else:
+    if (nnz := arg.nonzero_memo) is None:
         # Avoid importing sympy at a module level
         from torch.fx.experimental.symbolic_shapes import (
             _constrain_range_for_size,
@@ -378,9 +379,7 @@ def nonzero(
            # symint cannot equal zero). We could also unconditionally
            # allocate an unbacked SymInt and not refine its range,
            # but this seems more precise.
-            nnz = arg._nonzero_memo = 0
-            arg._nonzero_memo_vc = arg._version
-            arg._nonzero_memo_epoch = fake_mode.epoch
+            nnz = 0
         else:
             nnz = fake_mode.shape_env.create_unbacked_symint()
 
@@ -391,11 +390,7 @@ def nonzero(
 
         _constrain_range_for_size(nnz, max=maxval)
 
-    if not torch.is_inference_mode_enabled():
-        # arg._version N/A in inference mode
-        arg._nonzero_memo = nnz
-        arg._nonzero_memo_vc = arg._version
-        arg._nonzero_memo_epoch = fake_mode.epoch
+    arg.nonzero_memo = nnz
 
     return arg.new_empty((nnz, arg.dim()), dtype=torch.int64)
 
diff --git a/torch/_subclasses/fake_tensor.py b/torch/_subclasses/fake_tensor.py
index 0b28e09d682..3fdc4fc01e6 100644
--- a/torch/_subclasses/fake_tensor.py
+++ b/torch/_subclasses/fake_tensor.py
@@ -394,6 +394,60 @@ class FakeTensorConfig:
     debug = os.environ.get("TORCH_FAKE_TENSOR_DEBUG", "0") == "1"
 
 
+# This memoizes the unbacked SymInt representing quantities like the number
+# of nonzero elements in this tensor. There is one instance of the descriptor
+# per particular quantity to memoize.
+#
+# Memoization is helpful if you do something like x[mask] and y[mask];
+# mask.nonzero() gets repeatedly called and should give a consistent unbacked
+# SymInt. It needs to be invalidated in the same way constant is.
+#
+# Making this a descriptor may seem overly fancy, but it's actually the most
+# convenient way to make sure we have access to the FakeTensor during
+# attribute access, which is required for checking version counter and epoch
+# validity.
+class UnbackedMemoDescriptor:
+    _name: str
+
+    def __set_name__(self, owner, name):
+        self._name = name
+
+    def _memo(self, obj):
+        return f"_{self._name}"
+
+    def _memo_vc(self, obj):
+        return f"_{self._name}_vc"
+
+    # When we retrace, we need to invalidate all the memos so that we can
+    # accurately identify the first time unbacked SymInts are allocated.
+    # This is only relevant for inputs; for intermediates, they will get
+    # fresh fake tensors, so you won't have a memo anyway.
+    def _memo_epoch(self, obj):
+        return f"_{self._name}_epoch"
+
+    def __get__(self, obj: "FakeTensor", objtype=None):
+        if (r := getattr(obj, self._memo(obj))) is None:
+            return None
+        # Version counter based tracking isn't 100% sound, but it's close
+        # enough.
+        if (
+            getattr(obj, self._memo_vc(obj)) != obj._version
+            or getattr(obj, self._memo_epoch(obj)) != obj.fake_mode.epoch
+        ):
+            setattr(obj, self._memo(obj), None)
+            return None
+        return r
+
+    def __set__(self, obj, value):
+        if value is None:
+            setattr(obj, self._memo(obj), None)
+            setattr(obj, self._memo_vc(obj), None)
+            setattr(obj, self._memo_epoch(obj), None)
+        elif not torch.is_inference_mode_enabled():
+            setattr(obj, self._memo(obj), value)
+            setattr(obj, self._memo_vc(obj), obj._version)
+            setattr(obj, self._memo_epoch(obj), obj.fake_mode.epoch)
+
+
 class FakeTensor(torch.Tensor):
     """
     Meta tensors give you the ability to run PyTorch code without having to
@@ -408,63 +462,17 @@ class FakeTensor(torch.Tensor):
     constant: Optional[torch.Tensor]
     real_tensor: Optional[torch.Tensor]
 
-    # This memorizes the unbacked SymInt representing the number of nonzero
-    # elements in this tensor. This is helpful if you do something like
-    # x[mask] and y[mask]; mask.nonzero() gets repeatedly called and should
-    # give a consistent unbacked SymInt. It needs to be invalidated in the
-    # same way constant is.
-    # TODO: Generalize this as needed, e.g., into a trie of memos
-    _nonzero_memo: Optional[torch.SymInt]
-    _nonzero_memo_vc: Optional[int]
-    # When we retrace, we need to invalidate all the memos so that we can
-    # accurately identify the first time unbacked SymInts are allocated.
-    # This is only relevant for inputs; for intermediates, they will get fresh
-    # fake tensors so you won't have a memo anyway
-    _nonzero_memo_epoch: Optional[int]
+    # TODO: Generalize this as needed, e.g., into a trie of memos, if
+    # you do something like x[0].item() (x[0] is fresh each time, so the
+    # memo mechanism here won't work)
+    nonzero_memo = UnbackedMemoDescriptor()
+    item_memo = UnbackedMemoDescriptor()
+    unique_memo = UnbackedMemoDescriptor()
 
     # Indicates to our torch_dispatch dispatching infra that
     # this is an "infra" mode with lower dispatching precedence.
     _mode_key = torch._C._TorchDispatchModeKey.FAKE
 
-    @property
-    def nonzero_memo(self):
-        if self._nonzero_memo is None:
-            return None
-        # Version counter based tracking isn't 100% sound but it's close
-        # enough
-        if (
-            self._nonzero_memo_vc != self._version
-            or self._nonzero_memo_epoch != self.fake_mode.epoch
-        ):
-            self._nonzero_memo = None
-            return None
-        return self._nonzero_memo
-
-    # This memorizes the unbacked SymInt representing the number of unique
-    # elements in this tensor. This is helpful if you do something like
-    # calling torch.unique(x) multiple times and should
-    # give a consistent unbacked SymInt. It needs to be invalidated in the
-    # same way constant is.
-    # TODO: Generalize this as needed, e.g., into a trie of memos
-    _unique_memo: Optional[torch.SymInt]
-    _unique_memo_vc: Optional[int]
-
-    @property
-    def unique_memo(self):
-        if self._unique_memo is None:
-            return None
-        # Version counter based tracking isn't 100% sound but it's close
-        # enough
-        if self._unique_memo_vc != self._version:
-            self._unique_memo = None
-            return None
-        return self._unique_memo
-
-    @unique_memo.setter
-    def unique_memo(self, value):
-        self._unique_memo = value
-        self._unique_memo_vc = self._version
-
     @property
     def device(self):
         if self.fake_mode.in_kernel_invocation:
@@ -539,10 +547,9 @@ class FakeTensor(torch.Tensor):
         self.constant = constant  # type: ignore[attr-defined]
         assert not isinstance(real_tensor, FakeTensor)
         self.real_tensor = real_tensor  # type: ignore[attr-defined]
-        self._nonzero_memo = None  # type: ignore[attr-defined]
-        self._nonzero_memo_vc = None  # type: ignore[attr-defined]
-        self._unique_memo = None  # type: ignore[attr-defined]
-        self._unique_memo_vc = None  # type: ignore[attr-defined]
+        self.nonzero_memo = None
+        self.item_memo = None
+        self.unique_memo = None
         if FakeTensorConfig.debug:
             self._debug_trace = CapturedTraceback.extract()  # type: ignore[attr-defined]
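
The refactor above hinges on the Python descriptor protocol: a single
UnbackedMemoDescriptor class replaces three hand-written property/setter
pairs, deriving its backing slots ("_nonzero_memo", "_nonzero_memo_vc", and
so on) from the attribute name via __set_name__. Here is a toy sketch of
just that mechanism, with hypothetical Memo/Holder names and the validity
tracking (version counter, epoch) omitted:

    # Toy illustration only; names are hypothetical, not from the patch.
    class Memo:
        def __set_name__(self, owner, name):
            # Called once per class attribute, e.g. name == "item_memo",
            # so each descriptor instance knows its own backing slot.
            self._name = name

        def __get__(self, obj, objtype=None):
            if obj is None:
                return self
            # Read the per-instance slot, defaulting to None if unset.
            return getattr(obj, f"_{self._name}", None)

        def __set__(self, obj, value):
            setattr(obj, f"_{self._name}", value)

    class Holder:
        nonzero_memo = Memo()
        item_memo = Memo()
        unique_memo = Memo()

    h = Holder()
    h.item_memo = 42
    assert h.item_memo == 42
    assert h.nonzero_memo is None  # each quantity gets its own slot

Because __set_name__ hands each descriptor its attribute name, adding a new
memoized quantity (as the patch does for item_memo and unique_memo) is a
one-line class attribute rather than another property/setter pair.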
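
And a minimal sketch of the end-to-end behavior this patch is after,
assuming a PyTorch build that already contains the change; the invalidation
details (version counter, epoch) are private internals, so treat this as
illustrative rather than a stable API:

    import torch
    from torch._subclasses.fake_tensor import FakeTensorMode
    from torch.fx.experimental.symbolic_shapes import ShapeEnv

    # Memoization only applies when the fake mode carries a ShapeEnv,
    # since the memoized value is an unbacked SymInt allocated from it.
    fake_mode = FakeTensorMode(shape_env=ShapeEnv())

    with fake_mode:
        x = torch.zeros(1, dtype=torch.int64)
        u0 = x.item()  # allocates an unbacked SymInt, stored in item_memo
        u1 = x.item()  # served from item_memo: the same SymInt object
        assert u0 is u1

        x.add_(1)      # in-place op bumps x._version, invalidating the memo
        u2 = x.item()  # a fresh unbacked SymInt
        assert u2 is not u0

        fake_mode.epoch += 1  # what the AOTAutograd hunk above does on retrace
        u3 = x.item()  # epoch mismatch also invalidates the memo
        assert u3 is not u2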