mirror of
https://github.com/zebrajr/pytorch.git
synced 2025-12-06 12:20:52 +01:00
Switch to using Python nested int (#141166)
Doesn't seem to noticeably slow down eager - TestNestedTensorSubclass tests with and without the PR finished in similar amounts of time (around 57s, 58s) Pull Request resolved: https://github.com/pytorch/pytorch/pull/141166 Approved by: https://github.com/ezyang
This commit is contained in:
parent
2d708752f0
commit
161a2340ee
|
|
@ -8559,6 +8559,101 @@ class TestNestedTensorOpInfo(NestedTensorTestCase):
|
|||
self.assertEqualNoncontigAware(grads_compile, grads_ref)
|
||||
|
||||
|
||||
from torch.nested._internal.nested_int import NestedIntNode
|
||||
|
||||
|
||||
class TestNestedInt(torch.testing._internal.common_utils.TestCase):
|
||||
def test_comparisons(self):
|
||||
a = torch.SymInt(NestedIntNode(1, 1))
|
||||
b = torch.SymInt(NestedIntNode(1, 1))
|
||||
c = torch.SymInt(NestedIntNode(2, 1))
|
||||
d = 3
|
||||
|
||||
self.assertTrue(a == a)
|
||||
self.assertTrue(a == b)
|
||||
self.assertFalse(a != a)
|
||||
self.assertFalse(a != b)
|
||||
self.assertFalse(a == c)
|
||||
self.assertTrue(a != c)
|
||||
|
||||
self.assertFalse(a == d)
|
||||
self.assertTrue(a != d)
|
||||
self.assertFalse(d == a)
|
||||
self.assertTrue(d != a)
|
||||
|
||||
# ge
|
||||
self.assertTrue(a >= a)
|
||||
self.assertTrue(a >= b)
|
||||
self.assertTrue(b >= a)
|
||||
with self.assertRaises(ValueError):
|
||||
_ = a >= c
|
||||
with self.assertRaises(ValueError):
|
||||
_ = c >= a
|
||||
with self.assertRaises(ValueError):
|
||||
_ = c >= 3
|
||||
self.assertTrue(c >= 2)
|
||||
self.assertTrue(c >= 1)
|
||||
self.assertFalse(c <= 1)
|
||||
|
||||
# lt
|
||||
self.assertFalse(a < a)
|
||||
self.assertFalse(a < b)
|
||||
self.assertFalse(b < a)
|
||||
with self.assertRaises(ValueError):
|
||||
_ = a < c
|
||||
with self.assertRaises(ValueError):
|
||||
_ = c < a
|
||||
with self.assertRaises(ValueError):
|
||||
_ = 3 < a
|
||||
with self.assertRaises(ValueError):
|
||||
_ = 2 < a
|
||||
self.assertTrue(a > 1)
|
||||
|
||||
# le
|
||||
self.assertTrue(a <= a)
|
||||
self.assertTrue(b <= a)
|
||||
self.assertTrue(a <= b)
|
||||
with self.assertRaises(ValueError):
|
||||
_ = a <= c
|
||||
with self.assertRaises(ValueError):
|
||||
_ = c <= a
|
||||
with self.assertRaises(ValueError):
|
||||
_ = 3 <= c
|
||||
self.assertTrue(c >= 2)
|
||||
self.assertTrue(c >= 1)
|
||||
self.assertFalse(c <= 1)
|
||||
|
||||
# gt
|
||||
self.assertFalse(a > a)
|
||||
self.assertFalse(b > a)
|
||||
self.assertFalse(a > b)
|
||||
with self.assertRaises(ValueError):
|
||||
_ = a > c
|
||||
with self.assertRaises(ValueError):
|
||||
_ = c > a
|
||||
with self.assertRaises(ValueError):
|
||||
_ = a > 3
|
||||
with self.assertRaises(ValueError):
|
||||
_ = a > 2
|
||||
self.assertTrue(a > 1)
|
||||
|
||||
def test_with_factor(self):
|
||||
a = torch.SymInt(NestedIntNode(1, 5))
|
||||
b = torch.SymInt(NestedIntNode(1, 10))
|
||||
# eq
|
||||
self.assertFalse(a == b)
|
||||
self.assertFalse(a >= b)
|
||||
self.assertTrue(b >= a)
|
||||
self.assertTrue(a <= b)
|
||||
self.assertFalse(b <= a)
|
||||
# ne
|
||||
self.assertTrue(a != b)
|
||||
# mul
|
||||
self.assertTrue(a * 2 == b)
|
||||
self.assertTrue(a * 3 >= b)
|
||||
self.assertTrue(a * 2 == 2 * a)
|
||||
|
||||
|
||||
instantiate_parametrized_tests(TestNestedTensor)
|
||||
instantiate_device_type_tests(TestNestedTensorDeviceType, globals())
|
||||
instantiate_device_type_tests(TestNestedTensorAutograd, globals())
|
||||
|
|
|
|||
|
|
@ -2513,12 +2513,13 @@ class FakeTensorMode(TorchDispatchMode):
|
|||
# See Note: [Creating symbolic nested int]
|
||||
# Returned nested int always has coeff=1; multiply the result by coeff if needed
|
||||
import torch.nested._internal.nested_tensor
|
||||
from torch.nested._internal.nested_int import NestedIntNode
|
||||
|
||||
if nt_tensor_id is None:
|
||||
nt_tensor_id = self.nt_tensor_id_counter
|
||||
assert self.enter_stack, "should only called while FakeTensorMode is active"
|
||||
self.nt_tensor_id_counter += 1
|
||||
hint = torch._C._get_nested_int(nt_tensor_id, 1)
|
||||
hint = torch.SymInt(NestedIntNode(nt_tensor_id, 1))
|
||||
|
||||
src = torch._dynamo.source.EphemeralSource("intermediate_offsets_or_lengths")
|
||||
assert self.shape_env is not None
|
||||
|
|
|
|||
69
torch/fx/experimental/_constant_symnode.py
Normal file
69
torch/fx/experimental/_constant_symnode.py
Normal file
|
|
@ -0,0 +1,69 @@
|
|||
from typing import * # noqa: F403
|
||||
|
||||
|
||||
# Python version of c10/core/ConstantSymNodeImpl.cpp
|
||||
# This needs to exist because the Python version of nested int is not compatible
|
||||
# with the C++ version of constant symnode.
|
||||
class ConstantIntNode:
|
||||
def __init__(self, val: int):
|
||||
self.val = val
|
||||
|
||||
def is_constant(self) -> bool:
|
||||
return True
|
||||
|
||||
def maybe_as_int(self) -> int:
|
||||
return self.val
|
||||
|
||||
def is_int(self) -> bool:
|
||||
return True
|
||||
|
||||
def is_float(self) -> bool:
|
||||
return False
|
||||
|
||||
def is_bool(self) -> bool:
|
||||
return False
|
||||
|
||||
def is_nested_int(self) -> bool:
|
||||
return False
|
||||
|
||||
def clone(self) -> "ConstantIntNode":
|
||||
return self
|
||||
|
||||
def _str(self) -> str:
|
||||
return str(self.val)
|
||||
|
||||
def __str__(self) -> str:
|
||||
return self._str()
|
||||
|
||||
def __repr__(self) -> str:
|
||||
return self._str()
|
||||
|
||||
def _graph_repr(self) -> str:
|
||||
return self._str()
|
||||
|
||||
def mul(self, other: Any) -> Any:
|
||||
return other.mul(self)
|
||||
|
||||
def eq(self, other: Any) -> Any:
|
||||
return other.eq(self)
|
||||
|
||||
def ne(self, other: Any) -> Any:
|
||||
return other.ne(self)
|
||||
|
||||
def gt(self, other: Any) -> Any:
|
||||
return other.lt(self)
|
||||
|
||||
def lt(self, other: Any) -> Any:
|
||||
return other.gt(self)
|
||||
|
||||
def le(self, other: Any) -> Any:
|
||||
return other.ge(self)
|
||||
|
||||
def ge(self, other: Any) -> Any:
|
||||
return other.le(self)
|
||||
|
||||
def is_symbolic(self) -> bool:
|
||||
return False
|
||||
|
||||
def constant_int(self) -> int:
|
||||
return self.val
|
||||
116
torch/nested/_internal/nested_int.py
Normal file
116
torch/nested/_internal/nested_int.py
Normal file
|
|
@ -0,0 +1,116 @@
|
|||
from typing import * # noqa: F403
|
||||
|
||||
import torch
|
||||
from torch.fx.experimental._constant_symnode import ConstantIntNode
|
||||
|
||||
|
||||
__all__ = ["NestedIntNode"]
|
||||
|
||||
|
||||
# Python version of aten/src/ATen/core/NestedIntSymNodeImpl.cpp
|
||||
def _eq(lhs: Any, rhs: Any) -> bool:
|
||||
return (
|
||||
isinstance(lhs, NestedIntNode)
|
||||
and isinstance(rhs, NestedIntNode)
|
||||
and lhs.t_id == rhs.t_id
|
||||
and lhs.coeff == rhs.coeff
|
||||
)
|
||||
|
||||
|
||||
def _ge(lhs: Any, rhs: Any) -> bool:
|
||||
if isinstance(rhs, NestedIntNode) and isinstance(lhs, NestedIntNode):
|
||||
if lhs.t_id == rhs.t_id:
|
||||
return lhs.coeff >= rhs.coeff
|
||||
raise ValueError("ge: relation is indeterminate")
|
||||
elif isinstance(lhs, NestedIntNode):
|
||||
if rhs.is_constant() and rhs.constant_int() <= 2:
|
||||
return True
|
||||
raise ValueError("ge: relation is indeterminate")
|
||||
elif isinstance(rhs, NestedIntNode):
|
||||
if lhs.is_constant() and lhs.constant_int() < 2:
|
||||
return False
|
||||
raise ValueError("ge: relation is indeterminate")
|
||||
else:
|
||||
raise ValueError("inputs unsupported")
|
||||
|
||||
|
||||
class NestedIntNode:
|
||||
def __init__(self, t_id: int, coeff: int):
|
||||
self.t_id = t_id
|
||||
self.coeff = coeff
|
||||
|
||||
def nested_int_coeff(self) -> int:
|
||||
return self.coeff
|
||||
|
||||
def maybe_as_int(self) -> Optional[int]:
|
||||
return None
|
||||
|
||||
def is_int(self) -> bool:
|
||||
return True
|
||||
|
||||
def is_float(self) -> bool:
|
||||
return False
|
||||
|
||||
def is_bool(self) -> bool:
|
||||
return False
|
||||
|
||||
def is_nested_int(self) -> bool:
|
||||
return True
|
||||
|
||||
def clone(self) -> "NestedIntNode":
|
||||
return self
|
||||
|
||||
def _str(self) -> Any:
|
||||
if self.coeff == 1:
|
||||
return f"j{self.t_id}"
|
||||
return f"{self.coeff}*j{self.t_id}"
|
||||
|
||||
def str(self) -> Any:
|
||||
return self._str()
|
||||
|
||||
def __str__(self) -> Any:
|
||||
return self._str()
|
||||
|
||||
def __repr__(self) -> Any:
|
||||
return self._str()
|
||||
|
||||
def _graph_repr(self) -> Any:
|
||||
return self._str()
|
||||
|
||||
def mul(self, other: Any) -> "NestedIntNode":
|
||||
if other.is_constant():
|
||||
other = other.constant_int()
|
||||
else:
|
||||
raise ValueError(f"unsupported: {type(other)}")
|
||||
return NestedIntNode(self.t_id, self.coeff * other)
|
||||
|
||||
def eq(self, other: Any) -> Any:
|
||||
return torch._C._get_constant_bool_symnode(_eq(self, other))
|
||||
|
||||
def ne(self, other: Any) -> Any:
|
||||
return torch._C._get_constant_bool_symnode(not _eq(self, other))
|
||||
|
||||
def gt(self, other: Any) -> Any:
|
||||
return torch._C._get_constant_bool_symnode(not _ge(other, self))
|
||||
|
||||
def lt(self, other: Any) -> Any:
|
||||
return torch._C._get_constant_bool_symnode(not _ge(self, other))
|
||||
|
||||
def le(self, other: Any) -> Any:
|
||||
return torch._C._get_constant_bool_symnode(_ge(other, self))
|
||||
|
||||
def ge(self, other: Any) -> Any:
|
||||
return torch._C._get_constant_bool_symnode(_ge(self, other))
|
||||
|
||||
def is_symbolic(self) -> bool:
|
||||
return False
|
||||
|
||||
def nested_int(self) -> int:
|
||||
return self.t_id
|
||||
|
||||
def is_constant(self) -> bool:
|
||||
return False
|
||||
|
||||
def wrap_int(self, num: int) -> ConstantIntNode:
|
||||
assert type(num) is int
|
||||
return ConstantIntNode(num)
|
||||
|
|
@ -5,6 +5,7 @@ from typing import Tuple
|
|||
import torch
|
||||
from torch._C import DispatchKey, DispatchKeySet
|
||||
from torch._prims_common import is_expandable_to
|
||||
from torch.nested._internal.nested_int import NestedIntNode
|
||||
from torch.utils.weak import WeakTensorKeyDictionary
|
||||
|
||||
|
||||
|
|
@ -25,7 +26,7 @@ def get_tensor_symint(tensor, *, coeff=1):
|
|||
|
||||
tensor_symint = _tensor_symint_registry.get(tensor)
|
||||
if tensor_symint is None:
|
||||
tensor_symint = torch._C._get_nested_int(_tensor_id_counter, coeff)
|
||||
tensor_symint = torch.SymInt(NestedIntNode(_tensor_id_counter, coeff))
|
||||
_tensor_id_counter += 1
|
||||
_tensor_symint_registry[tensor] = tensor_symint
|
||||
return tensor_symint
|
||||
|
|
|
|||
Loading…
Reference in New Issue
Block a user