Memoize local_scalar_dense calls, refactor all memos (#125623)
Signed-off-by: Edward Z. Yang <ezyang@meta.com>
Pull Request resolved: https://github.com/pytorch/pytorch/pull/125623
Approved by: https://github.com/eellison
This commit is contained in:
parent 0935b3d794
commit f19e07b056
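In effect, once a data-dependent scalar has been read out of a fake tensor, reading it again reuses the same unbacked symbol. A minimal sketch of the observable behavior, assuming a build with this patch (from_tensor is used so no constant is attached that would short-circuit the symbolic path; symbol names like u0 are illustrative):

import torch
from torch._subclasses.fake_tensor import FakeTensorMode
from torch.fx.experimental.symbolic_shapes import ShapeEnv

mode = FakeTensorMode(shape_env=ShapeEnv())
x = mode.from_tensor(torch.tensor(3))  # FakeTensor with no constant attached
with mode:
    a = x.item()  # allocates an unbacked SymInt, e.g. u0
    b = x.item()  # item_memo: returns the same u0 instead of a fresh u1
# Before this patch, a and b were distinct unbacked symbols.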
@@ -180,7 +180,7 @@ class TestDynamismExpression(TestCase):
         res = gm(*inp)
         self.assertTrue(torchdynamo.utils.same(ref, res))
 
-    def test_export_constraints_error(self):
+    def test_export_constraints_error_not_in_range(self):
         class InvalidInputConflictWithInputConstraints(torch.nn.Module):
             def forward(self, x):
                 return x + 1

@@ -194,6 +194,7 @@ class TestDynamismExpression(TestCase):
             dynamic_shapes={"x": {0: dim_x}},
         )
 
+    def test_export_constraints_error(self):
         class ConflictingConstraints(torch.nn.Module):
             def forward(self, x):
                 b = x.item()

@@ -2432,7 +2433,7 @@ def forward(self, x):
 
         # This is because we insert sym_constrain_range in the graph now
         if is_non_strict_test(self._testMethodName):
-            error_msg = "Invalid value range"
+            error_msg = r"Invalid value range for -1 between"
         else:
             error_msg = "is outside of inline constraint"
         with self.assertRaisesRegex(RuntimeError, error_msg):

@@ -161,9 +161,11 @@ def run_functionalized_fw_and_collect_metadata(
         flat_f_args = pytree.tree_map(_to_fun, flat_args)
         flat_f_outs = f(*flat_f_args)
         # We didn't do any tracing, so we don't need to process the
-        # unbacked symbols, they will just disappear into the ether
+        # unbacked symbols, they will just disappear into the ether.
+        # Also, prevent memoization from applying.
         if (fake_mode := detect_fake_mode()) and (shape_env := fake_mode.shape_env):
             shape_env.pending_fresh_unbacked_symbols.clear()
+            fake_mode.epoch += 1
 
         if prior_autocast_states != _get_autocast_states():
             raise RuntimeError(

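Note: the epoch bump is what "prevent memoization from applying" means here. Every memo write records fake_mode.epoch, and the descriptor discards the memo when the stored epoch no longer matches (see UnbackedMemoDescriptor below). A sketch of the invariant, with hypothetical fake / sym / fake_mode values:

fake.item_memo = sym           # __set__ records the value, _version, and epoch
fake_mode.epoch += 1           # done after a metadata-only run like the above
assert fake.item_memo is None  # __get__ sees a stale epoch and clears the memo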
@@ -269,7 +269,7 @@ def unique2(
         # Without symints/symfloats, cannot handle this
         raise DynamicOutputShapeException(func)
 
-    if arg.unique_memo is None:
+    if (nnz := arg.unique_memo) is None:
         # Avoid importing sympy at a module level
         from torch.fx.experimental.symbolic_shapes import (
             _constrain_range_for_size,

@@ -285,8 +285,7 @@ def unique2(
             # symint cannot equal zero). We could also unconditionally
             # allocate an unbacked SymInt and not refine its range,
             # but this seems more precise.
-            nnz = arg._unique_memo = 0
-            arg._unique_memo_vc = arg._version
+            nnz = 0
         else:
             nnz = fake_mode.shape_env.create_unbacked_symint()
 

@@ -299,7 +298,7 @@ def unique2(
 
         arg.unique_memo = nnz
 
-    ret = [arg.new_empty((arg.unique_memo,))]
+    ret = [arg.new_empty((nnz,))]
 
     if return_inverse:
         ret.append(torch.empty_like(arg))

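With this refactor, unique2 reads and writes unique_memo through the descriptor, so repeated torch.unique calls on one tensor keep a consistent symbolic length. A sketch, assuming torch.unique takes its default lowering to aten._unique2:

import torch
from torch._subclasses.fake_tensor import FakeTensorMode
from torch.fx.experimental.symbolic_shapes import ShapeEnv

mode = FakeTensorMode(shape_env=ShapeEnv())
x = mode.from_tensor(torch.randint(0, 10, (8,)))
with mode:
    u1 = torch.unique(x)  # output length is a fresh unbacked SymInt
    u2 = torch.unique(x)  # unique_memo: u2 has the same symbolic length as u1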
@@ -341,14 +340,18 @@ def local_scalar_dense(fake_mode, func, arg):
     ):
         # Without symints/symfloats, cannot handle this
         raise DataDependentOutputException(func)
+    if (r := arg.item_memo) is not None:
+        return r
     if is_float_dtype(arg.dtype):
-        return fake_mode.shape_env.create_unbacked_symfloat()
+        r = fake_mode.shape_env.create_unbacked_symfloat()
     elif is_integer_dtype(arg.dtype):
-        return fake_mode.shape_env.create_unbacked_symint()
+        r = fake_mode.shape_env.create_unbacked_symint()
     elif is_boolean_dtype(arg.dtype):
-        return fake_mode.shape_env.create_unbacked_symbool()
+        r = fake_mode.shape_env.create_unbacked_symbool()
     else:
         raise NotImplementedError(f"local_scalar_dense/item NYI for {arg.dtype}")
+    arg.item_memo = r
+    return r
 
 
 @register_op_impl(torch.ops.aten.nonzero.default)

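local_scalar_dense picks the unbacked symbol kind from the dtype, and the new item memo pins whichever one was allocated. A sketch of the dtype dispatch (fake tensors built outside the mode so no constants are attached):

import torch
from torch._subclasses.fake_tensor import FakeTensorMode
from torch.fx.experimental.symbolic_shapes import ShapeEnv

mode = FakeTensorMode(shape_env=ShapeEnv())
f = mode.from_tensor(torch.tensor(1.5))
i = mode.from_tensor(torch.tensor(2))
b = mode.from_tensor(torch.tensor(True))
with mode:
    f.item()  # -> unbacked SymFloat
    i.item()  # -> unbacked SymInt
    b.item()  # -> unbacked SymBool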
@@ -360,9 +363,7 @@ def nonzero(fake_mode, func, arg):
         # Without symints/symfloats, cannot handle this
         raise DynamicOutputShapeException(func)
 
-    if arg.nonzero_memo is not None:
-        nnz = arg.nonzero_memo
-    else:
+    if (nnz := arg.nonzero_memo) is None:
         # Avoid importing sympy at a module level
         from torch.fx.experimental.symbolic_shapes import (
             _constrain_range_for_size,

@@ -378,9 +379,7 @@ def nonzero(fake_mode, func, arg):
             # symint cannot equal zero). We could also unconditionally
             # allocate an unbacked SymInt and not refine its range,
             # but this seems more precise.
-            nnz = arg._nonzero_memo = 0
-            arg._nonzero_memo_vc = arg._version
-            arg._nonzero_memo_epoch = fake_mode.epoch
+            nnz = 0
         else:
             nnz = fake_mode.shape_env.create_unbacked_symint()
 

@@ -391,11 +390,7 @@ def nonzero(fake_mode, func, arg):
 
             _constrain_range_for_size(nnz, max=maxval)
 
-        if not torch.is_inference_mode_enabled():
-            # arg._version N/A in inference mode
-            arg._nonzero_memo = nnz
-            arg._nonzero_memo_vc = arg._version
-            arg._nonzero_memo_epoch = fake_mode.epoch
+        arg.nonzero_memo = nnz
 
     return arg.new_empty((nnz, arg.dim()), dtype=torch.int64)

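This is the x[mask] / y[mask] scenario the memo exists for: both maskings call mask.nonzero(), and the memo makes the two result lengths the same unbacked symbol. Sketch:

import torch
from torch._subclasses.fake_tensor import FakeTensorMode
from torch.fx.experimental.symbolic_shapes import ShapeEnv

mode = FakeTensorMode(shape_env=ShapeEnv())
x = mode.from_tensor(torch.randn(8))
y = mode.from_tensor(torch.randn(8))
with mode:
    mask = x > 0
    a = x[mask]  # calls mask.nonzero(); length is an unbacked SymInt
    b = y[mask]  # nonzero_memo: b gets the same symbolic length as a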
@@ -394,6 +394,60 @@ class FakeTensorConfig:
     debug = os.environ.get("TORCH_FAKE_TENSOR_DEBUG", "0") == "1"
 
 
+# This memorizes the unbacked SymInt representing quantities like the number
+# of nonzero elements in this tensor. There is one instance of the descriptor
+# per particular quantity to memoize.
+#
+# Memoization is helpful if you do something like x[mask] and y[mask];
+# mask.nonzero() gets repeatedly called and should give a consistent unbacked
+# SymInt. It needs to be invalidated in the same way constant is.
+#
+# Making this a descriptor may seem overly fancy, but actually it's the most
+# convenient way to make sure we have access to FakeTensor during access,
+# which is required for testing version counter and epoch validity
+class UnbackedMemoDescriptor:
+    _name: str
+
+    def __set_name__(self, owner, name):
+        self._name = name
+
+    def _memo(self, obj):
+        return f"_{self._name}"
+
+    def _memo_vc(self, obj):
+        return f"_{self._name}_vc"
+
+    # When we retrace, we need to invalidate all the memos so that we can
+    # accurately identify the first time unbacked SymInts are allocated.
+    # This is only relevant for inputs; for intermediates, they will get fresh
+    # fake tensors so you won't have a memo anyway
+    def _memo_epoch(self, obj):
+        return f"_{self._name}_epoch"
+
+    def __get__(self, obj: "FakeTensor", objtype=None):
+        if (r := getattr(obj, self._memo(obj))) is None:
+            return None
+        # Version counter based tracking isn't 100% sound but it's close
+        # enough
+        if (
+            getattr(obj, self._memo_vc(obj)) != obj._version
+            or getattr(obj, self._memo_epoch(obj)) != obj.fake_mode.epoch
+        ):
+            setattr(obj, self._memo(obj), None)
+            return None
+        return r
+
+    def __set__(self, obj, value):
+        if value is None:
+            setattr(obj, self._memo(obj), None)
+            setattr(obj, self._memo_vc(obj), None)
+            setattr(obj, self._memo_epoch(obj), None)
+        elif not torch.is_inference_mode_enabled():
+            setattr(obj, self._memo(obj), value)
+            setattr(obj, self._memo_vc(obj), obj._version)
+            setattr(obj, self._memo_epoch(obj), obj.fake_mode.epoch)
+
+
 class FakeTensor(torch.Tensor):
     """
     Meta tensors give you the ability to run PyTorch code without having to

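Stripped of the version-counter and epoch checks, the pattern is a plain data descriptor that fans one logical attribute out to underscored per-instance slots. A simplified standalone sketch (MemoDescriptor and Holder are illustrative names, not part of the patch):

class MemoDescriptor:
    def __set_name__(self, owner, name):
        # e.g. name == "nonzero_memo" when assigned as a class attribute
        self._name = name

    def __get__(self, obj, objtype=None):
        if obj is None:
            return self  # class-level access returns the descriptor itself
        return getattr(obj, f"_{self._name}", None)

    def __set__(self, obj, value):
        setattr(obj, f"_{self._name}", value)

class Holder:
    nonzero_memo = MemoDescriptor()

h = Holder()
assert h.nonzero_memo is None  # unset slot reads back as None
h.nonzero_memo = 42            # stored on h as h._nonzero_memo
assert h.nonzero_memo == 42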
@@ -408,63 +462,17 @@ class FakeTensor(torch.Tensor):
     constant: Optional[torch.Tensor]
     real_tensor: Optional[torch.Tensor]
 
-    # This memorizes the unbacked SymInt representing the number of nonzero
-    # elements in this tensor. This is helpful if you do something like
-    # x[mask] and y[mask]; mask.nonzero() gets repeatedly called and should
-    # give a consistent unbacked SymInt. It needs to be invalidated in the
-    # same way constant is.
-    # TODO: Generalize this as needed, e.g., into a trie of memos
-    _nonzero_memo: Optional[torch.SymInt]
-    _nonzero_memo_vc: Optional[int]
-    # When we retrace, we need to invalidate all the memos so that we can
-    # accurately identify the first time unbacked SymInts are allocated.
-    # This is only relevant for inputs; for intermediates, they will get fresh
-    # fake tensors so you won't have a memo anyway
-    _nonzero_memo_epoch: Optional[int]
+    # TODO: Generalize this as needed, e.g., into a trie of memos, if
+    # you do something like x[0].item() (x[0] is fresh each time, so
+    # memo mechanism here won't work)
+    nonzero_memo = UnbackedMemoDescriptor()
+    item_memo = UnbackedMemoDescriptor()
+    unique_memo = UnbackedMemoDescriptor()
 
     # Indicates to our torch_dispatch dispatching infra that
     # this is an "infra" mode with lower dispatching precedence.
     _mode_key = torch._C._TorchDispatchModeKey.FAKE
 
-    @property
-    def nonzero_memo(self):
-        if self._nonzero_memo is None:
-            return None
-        # Version counter based tracking isn't 100% sound but it's close
-        # enough
-        if (
-            self._nonzero_memo_vc != self._version
-            or self._nonzero_memo_epoch != self.fake_mode.epoch
-        ):
-            self._nonzero_memo = None
-            return None
-        return self._nonzero_memo
-
-    # This memorizes the unbacked SymInt representing the number of unique
-    # elements in this tensor. This is helpful if you do something like
-    # calling torch.unique(x) multiple times and should
-    # give a consistent unbacked SymInt. It needs to be invalidated in the
-    # same way constant is.
-    # TODO: Generalize this as needed, e.g., into a trie of memos
-    _unique_memo: Optional[torch.SymInt]
-    _unique_memo_vc: Optional[int]
-
-    @property
-    def unique_memo(self):
-        if self._unique_memo is None:
-            return None
-        # Version counter based tracking isn't 100% sound but it's close
-        # enough
-        if self._unique_memo_vc != self._version:
-            self._unique_memo = None
-            return None
-        return self._unique_memo
-
-    @unique_memo.setter
-    def unique_memo(self, value):
-        self._unique_memo = value
-        self._unique_memo_vc = self._version
-
     @property
     def device(self):
         if self.fake_mode.in_kernel_invocation:

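The new TODO is about memo placement, not invalidation: the memo lives on the FakeTensor instance, so an expression that manufactures a fresh fake tensor each time never hits it. A sketch of the miss:

import torch
from torch._subclasses.fake_tensor import FakeTensorMode
from torch.fx.experimental.symbolic_shapes import ShapeEnv

mode = FakeTensorMode(shape_env=ShapeEnv())
x = mode.from_tensor(torch.arange(4))
with mode:
    a = x[0].item()  # memo is stored on the temporary x[0] FakeTensor
    b = x[0].item()  # a different temporary: memo miss, fresh unbacked SymInt
# a and b stay distinct symbols; a trie of memos keyed off x could fix this.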
@@ -539,10 +547,9 @@ class FakeTensor(torch.Tensor):
         self.constant = constant  # type: ignore[attr-defined]
         assert not isinstance(real_tensor, FakeTensor)
         self.real_tensor = real_tensor  # type: ignore[attr-defined]
-        self._nonzero_memo = None  # type: ignore[attr-defined]
-        self._nonzero_memo_vc = None  # type: ignore[attr-defined]
-        self._unique_memo = None  # type: ignore[attr-defined]
-        self._unique_memo_vc = None  # type: ignore[attr-defined]
+        self.nonzero_memo = None
+        self.item_memo = None
+        self.unique_memo = None
 
         if FakeTensorConfig.debug:
             self._debug_trace = CapturedTraceback.extract()  # type: ignore[attr-defined]

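Each of these assignments now routes through UnbackedMemoDescriptor.__set__ with value=None, which is equivalent to initializing the three underscored slots directly, e.g. for nonzero_memo:

self._nonzero_memo = None        # the memoized value
self._nonzero_memo_vc = None     # tensor _version at write time
self._nonzero_memo_epoch = None  # fake_mode.epoch at write time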