Memoize local_scalar_dense calls, refactor all memos (#125623)

Signed-off-by: Edward Z. Yang <ezyang@meta.com>
Pull Request resolved: https://github.com/pytorch/pytorch/pull/125623
Approved by: https://github.com/eellison
Author: Edward Z. Yang
Date: 2024-05-09 13:57:47 -07:00
Committed-by: PyTorch MergeBot
parent 0935b3d794
commit f19e07b056
4 changed files with 82 additions and 77 deletions
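
The change to aten._local_scalar_dense (the op behind Tensor.item()) means repeated .item() calls on the same fake tensor now reuse one unbacked SymInt instead of allocating a fresh symbol each time. A rough sketch of that behavior, not part of the diff, assuming ShapeEnv's defaults allow data-dependent scalar outputs:

import torch
from torch._subclasses.fake_tensor import FakeTensorMode
from torch.fx.experimental.symbolic_shapes import ShapeEnv

# Illustrative only: exercises the newly memoized item() path.
with FakeTensorMode(shape_env=ShapeEnv()):
    x = torch.randint(0, 100, (1,))  # a FakeTensor under the mode
    u0 = x.item()                    # allocates an unbacked SymInt
    u1 = x.item()                    # with item_memo: the same SymInt as u0, not a new one
    print(u0, u1)                    # expected to print the same symbol twice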

View File

@@ -180,7 +180,7 @@ class TestDynamismExpression(TestCase):
res = gm(*inp)
self.assertTrue(torchdynamo.utils.same(ref, res))
def test_export_constraints_error(self):
def test_export_constraints_error_not_in_range(self):
class InvalidInputConflictWithInputConstraints(torch.nn.Module):
def forward(self, x):
return x + 1
@@ -194,6 +194,7 @@ class TestDynamismExpression(TestCase):
dynamic_shapes={"x": {0: dim_x}},
)
def test_export_constraints_error(self):
class ConflictingConstraints(torch.nn.Module):
def forward(self, x):
b = x.item()
@@ -2432,7 +2433,7 @@ def forward(self, x):
# This is because we insert sym_constrain_range in the graph now
if is_non_strict_test(self._testMethodName):
error_msg = "Invalid value range"
error_msg = r"Invalid value range for -1 between"
else:
error_msg = "is outside of inline constraint"
with self.assertRaisesRegex(RuntimeError, error_msg):
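
For context on the message asserted above: export now inserts sym_constrain_range into the graph for data-dependent values, so an out-of-range runtime value fails with an "Invalid value range" error. A hedged sketch of that failure mode (hypothetical module, not the test's actual body):

import torch

class UsesItemAsSize(torch.nn.Module):  # hypothetical name, for illustration only
    def forward(self, x):
        n = x.item()
        torch._check_is_size(n)   # constrains n to be a valid (non-negative) size
        return torch.zeros(n)

ep = torch.export.export(UsesItemAsSize(), (torch.tensor(3),))
# ep.module()(torch.tensor(-1))  # expected to raise a RuntimeError about the value range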

View File

@@ -161,9 +161,11 @@ def run_functionalized_fw_and_collect_metadata(
flat_f_args = pytree.tree_map(_to_fun, flat_args)
flat_f_outs = f(*flat_f_args)
# We didn't do any tracing, so we don't need to process the
# unbacked symbols, they will just disappear into the ether
# unbacked symbols, they will just disappear into the ether.
# Also, prevent memoization from applying.
if (fake_mode := detect_fake_mode()) and (shape_env := fake_mode.shape_env):
shape_env.pending_fresh_unbacked_symbols.clear()
fake_mode.epoch += 1
if prior_autocast_states != _get_autocast_states():
raise RuntimeError(
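
Bumping fake_mode.epoch here is what invalidates any memos recorded during this non-tracing metadata pass: a memo is only honored when both the tensor's version counter and the mode's epoch still match the values captured at write time. A small self-contained illustration of that rule (assumed toy names, not PyTorch internals):

class _Memo:
    def __init__(self, value, version, epoch):
        self.value, self.version, self.epoch = value, version, epoch

def read_memo(memo, tensor_version, mode_epoch):
    # Mirrors the validity check: a stale version counter or stale epoch means no memo.
    if memo is None or memo.version != tensor_version or memo.epoch != mode_epoch:
        return None
    return memo.value

m = _Memo(value="u0", version=3, epoch=0)
assert read_memo(m, tensor_version=3, mode_epoch=0) == "u0"  # still valid
assert read_memo(m, tensor_version=3, mode_epoch=1) is None  # epoch bumped -> invalidated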

View File

@@ -269,7 +269,7 @@ def unique2(
# Without symints/symfloats, cannot handle this
raise DynamicOutputShapeException(func)
if arg.unique_memo is None:
if (nnz := arg.unique_memo) is None:
# Avoid importing sympy at a module level
from torch.fx.experimental.symbolic_shapes import (
_constrain_range_for_size,
@@ -285,8 +285,7 @@ def unique2(
# symint cannot equal zero). We could also unconditionally
# allocate an unbacked SymInt and not refine its range,
# but this seems more precise.
nnz = arg._unique_memo = 0
arg._unique_memo_vc = arg._version
nnz = 0
else:
nnz = fake_mode.shape_env.create_unbacked_symint()
@@ -299,7 +298,7 @@ def unique2(
arg.unique_memo = nnz
ret = [arg.new_empty((arg.unique_memo,))]
ret = [arg.new_empty((nnz,))]
if return_inverse:
ret.append(torch.empty_like(arg))
@@ -341,14 +340,18 @@ def local_scalar_dense(fake_mode, func, arg):
):
# Without symints/symfloats, cannot handle this
raise DataDependentOutputException(func)
if (r := arg.item_memo) is not None:
return r
if is_float_dtype(arg.dtype):
return fake_mode.shape_env.create_unbacked_symfloat()
r = fake_mode.shape_env.create_unbacked_symfloat()
elif is_integer_dtype(arg.dtype):
return fake_mode.shape_env.create_unbacked_symint()
r = fake_mode.shape_env.create_unbacked_symint()
elif is_boolean_dtype(arg.dtype):
return fake_mode.shape_env.create_unbacked_symbool()
r = fake_mode.shape_env.create_unbacked_symbool()
else:
raise NotImplementedError(f"local_scalar_dense/item NYI for {arg.dtype}")
arg.item_memo = r
return r
@register_op_impl(torch.ops.aten.nonzero.default)
@@ -360,9 +363,7 @@ def nonzero(fake_mode, func, arg):
# Without symints/symfloats, cannot handle this
raise DynamicOutputShapeException(func)
if arg.nonzero_memo is not None:
nnz = arg.nonzero_memo
else:
if (nnz := arg.nonzero_memo) is None:
# Avoid importing sympy at a module level
from torch.fx.experimental.symbolic_shapes import (
_constrain_range_for_size,
@@ -378,9 +379,7 @@ def nonzero(fake_mode, func, arg):
# symint cannot equal zero). We could also unconditionally
# allocate an unbacked SymInt and not refine its range,
# but this seems more precise.
nnz = arg._nonzero_memo = 0
arg._nonzero_memo_vc = arg._version
arg._nonzero_memo_epoch = fake_mode.epoch
nnz = 0
else:
nnz = fake_mode.shape_env.create_unbacked_symint()
@@ -391,11 +390,7 @@ def nonzero(fake_mode, func, arg):
_constrain_range_for_size(nnz, max=maxval)
if not torch.is_inference_mode_enabled():
# arg._version N/A in inference mode
arg._nonzero_memo = nnz
arg._nonzero_memo_vc = arg._version
arg._nonzero_memo_epoch = fake_mode.epoch
arg.nonzero_memo = nnz
return arg.new_empty((nnz, arg.dim()), dtype=torch.int64)
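
All three ops above (nonzero, unique2, local_scalar_dense) now read and write their memo through the same tensor-side attributes (nonzero_memo, unique_memo, item_memo), so repeated data-dependent calls on one tensor stay symbolically consistent. A hedged sketch of that consistency from the caller's side, assuming ShapeEnv's defaults allow dynamic output shapes:

import torch
from torch._subclasses.fake_tensor import FakeTensorMode
from torch.fx.experimental.symbolic_shapes import ShapeEnv

with FakeTensorMode(shape_env=ShapeEnv()):
    mask = torch.empty(32, dtype=torch.bool)
    idx1 = mask.nonzero()      # row count is a fresh unbacked SymInt, say u0
    idx2 = mask.nonzero()      # memo hit: also u0, so idx1 and idx2 agree symbolically
    uniq = torch.unique(mask)  # unique_memo plays the same role for torch.unique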

View File

@@ -394,6 +394,60 @@ class FakeTensorConfig:
debug = os.environ.get("TORCH_FAKE_TENSOR_DEBUG", "0") == "1"
# This memorizes the unbacked SymInt representing quantities like the number
# of nonzero elements in this tensor. There is one instance of the descriptor
# per particular quantity to memoize.
#
# Memoization is helpful if you do something like x[mask] and y[mask];
# mask.nonzero() gets repeatedly called and should give a consistent unbacked
# SymInt. It needs to be invalidated in the same way constant is.
#
# Making this a descriptor may seem overly fancy, but actually it's the most
# convenient way to make sure we have access to FakeTensor during access,
# which is required for testing version counter and epoch validity
class UnbackedMemoDescriptor:
_name: str
def __set_name__(self, owner, name):
self._name = name
def _memo(self, obj):
return f"_{self._name}"
def _memo_vc(self, obj):
return f"_{self._name}_vc"
# When we retrace, we need to invalidate all the memos so that we can
# accurately identify the first time unbacked SymInts are allocated.
# This is only relevant for inputs; for intermediates, they will get fresh
# fake tensors so you won't have a memo anyway
def _memo_epoch(self, obj):
return f"_{self._name}_epoch"
def __get__(self, obj: "FakeTensor", objtype=None):
if (r := getattr(obj, self._memo(obj))) is None:
return None
# Version counter based tracking isn't 100% sound but it's close
# enough
if (
getattr(obj, self._memo_vc(obj)) != obj._version
or getattr(obj, self._memo_epoch(obj)) != obj.fake_mode.epoch
):
setattr(obj, self._memo(obj), None)
return None
return r
def __set__(self, obj, value):
if value is None:
setattr(obj, self._memo(obj), None)
setattr(obj, self._memo_vc(obj), None)
setattr(obj, self._memo_epoch(obj), None)
elif not torch.is_inference_mode_enabled():
setattr(obj, self._memo(obj), value)
setattr(obj, self._memo_vc(obj), obj._version)
setattr(obj, self._memo_epoch(obj), obj.fake_mode.epoch)
class FakeTensor(torch.Tensor):
"""
Meta tensors give you the ability to run PyTorch code without having to
@@ -408,63 +462,17 @@ class FakeTensor(torch.Tensor):
constant: Optional[torch.Tensor]
real_tensor: Optional[torch.Tensor]
# This memorizes the unbacked SymInt representing the number of nonzero
# elements in this tensor. This is helpful if you do something like
# x[mask] and y[mask]; mask.nonzero() gets repeatedly called and should
# give a consistent unbacked SymInt. It needs to be invalidated in the
# same way constant is.
# TODO: Generalize this as needed, e.g., into a trie of memos
_nonzero_memo: Optional[torch.SymInt]
_nonzero_memo_vc: Optional[int]
# When we retrace, we need to invalidate all the memos so that we can
# accurately identify the first time unbacked SymInts are allocated.
# This is only relevant for inputs; for intermediates, they will get fresh
# fake tensors so you won't have a memo anyway
_nonzero_memo_epoch: Optional[int]
# TODO: Generalize this as needed, e.g., into a trie of memos, if
# you do something like x[0].item() (x[0] is fresh each time, so
# memo mechanism here won't work)
nonzero_memo = UnbackedMemoDescriptor()
item_memo = UnbackedMemoDescriptor()
unique_memo = UnbackedMemoDescriptor()
# Indicates to our torch_dispatch dispatching infra that
# this is an "infra" mode with lower dispatching precedence.
_mode_key = torch._C._TorchDispatchModeKey.FAKE
@property
def nonzero_memo(self):
if self._nonzero_memo is None:
return None
# Version counter based tracking isn't 100% sound but it's close
# enough
if (
self._nonzero_memo_vc != self._version
or self._nonzero_memo_epoch != self.fake_mode.epoch
):
self._nonzero_memo = None
return None
return self._nonzero_memo
# This memorizes the unbacked SymInt representing the number of unique
# elements in this tensor. This is helpful if you do something like
# calling torch.unique(x) multiple times and should
# give a consistent unbacked SymInt. It needs to be invalidated in the
# same way constant is.
# TODO: Generalize this as needed, e.g., into a trie of memos
_unique_memo: Optional[torch.SymInt]
_unique_memo_vc: Optional[int]
@property
def unique_memo(self):
if self._unique_memo is None:
return None
# Version counter based tracking isn't 100% sound but it's close
# enough
if self._unique_memo_vc != self._version:
self._unique_memo = None
return None
return self._unique_memo
@unique_memo.setter
def unique_memo(self, value):
self._unique_memo = value
self._unique_memo_vc = self._version
@property
def device(self):
if self.fake_mode.in_kernel_invocation:
@@ -539,10 +547,9 @@ class FakeTensor(torch.Tensor):
self.constant = constant # type: ignore[attr-defined]
assert not isinstance(real_tensor, FakeTensor)
self.real_tensor = real_tensor # type: ignore[attr-defined]
self._nonzero_memo = None # type: ignore[attr-defined]
self._nonzero_memo_vc = None # type: ignore[attr-defined]
self._unique_memo = None # type: ignore[attr-defined]
self._unique_memo_vc = None # type: ignore[attr-defined]
self.nonzero_memo = None
self.item_memo = None
self.unique_memo = None
if FakeTensorConfig.debug:
self._debug_trace = CapturedTraceback.extract() # type: ignore[attr-defined]
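
The descriptor above replaces three hand-written property/setter pairs with one reusable mechanism: __set_name__ derives the backing attribute names, and reads self-invalidate when the version counter or epoch no longer matches. A standalone sketch of the same pattern outside FakeTensor (toy names, no epoch, purely illustrative):

class MemoDescriptor:
    def __set_name__(self, owner, name):
        # e.g. "nonzero_memo" -> backing slots "_nonzero_memo" and "_nonzero_memo_vc"
        self._name = name

    def __get__(self, obj, objtype=None):
        if obj is None:
            return self
        value = getattr(obj, f"_{self._name}", None)
        if value is None:
            return None
        if getattr(obj, f"_{self._name}_vc", None) != obj.version:
            # Owner was mutated since the memo was written: drop it.
            setattr(obj, f"_{self._name}", None)
            return None
        return value

    def __set__(self, obj, value):
        setattr(obj, f"_{self._name}", value)
        setattr(obj, f"_{self._name}_vc", None if value is None else obj.version)

class Owner:
    nonzero_memo = MemoDescriptor()

    def __init__(self):
        self.version = 0          # stands in for Tensor._version
        self.nonzero_memo = None  # initializes both backing slots

o = Owner()
o.nonzero_memo = "u0"
assert o.nonzero_memo == "u0"   # memo valid while the version is unchanged
o.version += 1                  # simulate an in-place mutation
assert o.nonzero_memo is None   # memo silently invalidated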