mirror of
https://github.com/zebrajr/pytorch.git
synced 2025-12-06 12:20:52 +01:00
Doesn't seem to noticeably slow down eager: TestNestedTensorSubclass tests finished in similar amounts of time with and without the PR (around 57 s and 58 s).
Pull Request resolved: https://github.com/pytorch/pytorch/pull/141166
Approved by: https://github.com/ezyang
117 lines
3.1 KiB
Python
117 lines
3.1 KiB
Python
from typing import * # noqa: F403
|
|
|
|
import torch
|
|
from torch.fx.experimental._constant_symnode import ConstantIntNode
|
|
|
|
|
|
# Public API of this module: only the node class is exported; the ``_eq`` and
# ``_ge`` helpers are module-private.
__all__ = ["NestedIntNode"]
|
|
|
|
|
|
# Python version of aten/src/ATen/core/NestedIntSymNodeImpl.cpp
|
|
def _eq(lhs: Any, rhs: Any) -> bool:
    """Return whether ``lhs`` and ``rhs`` denote the same nested int.

    Equality requires both operands to be ``NestedIntNode`` instances with
    matching ``t_id`` and ``coeff``. Any other pairing (for example a nested
    int against a constant int node) compares unequal rather than raising.
    """
    both_nested = isinstance(lhs, NestedIntNode) and isinstance(rhs, NestedIntNode)
    if not both_nested:
        return False
    return lhs.t_id == rhs.t_id and lhs.coeff == rhs.coeff
|
|
|
|
|
|
def _ge(lhs: Any, rhs: Any) -> bool:
    """Decide ``lhs >= rhs`` where at least one operand is a ``NestedIntNode``.

    Rules applied:
      * both nested ints with the same ``t_id``: compare their coefficients;
      * nested int >= constant ``c``: ``True`` when ``c <= 2``;
      * constant ``c`` >= nested int: ``False`` when ``c < 2``.

    The two constant rules are exactly what follows if every nested int is
    taken to be at least 2 (presumably the intended assumption - confirm
    against the C++ NestedIntSymNodeImpl). The coefficient is not consulted
    in the mixed nested/constant cases.

    Raises:
        ValueError: ``"ge: relation is indeterminate"`` when no rule decides
            the comparison, or ``"inputs unsupported"`` when neither operand
            is a ``NestedIntNode``.
    """
    lhs_nested = isinstance(lhs, NestedIntNode)
    rhs_nested = isinstance(rhs, NestedIntNode)

    if not (lhs_nested or rhs_nested):
        raise ValueError("inputs unsupported")

    if lhs_nested and rhs_nested:
        # Different ids are unrelated nested ints: no ordering is known.
        if lhs.t_id != rhs.t_id:
            raise ValueError("ge: relation is indeterminate")
        return lhs.coeff >= rhs.coeff

    if lhs_nested:
        # nested >= small constant always holds.
        if rhs.is_constant() and rhs.constant_int() <= 2:
            return True
    elif lhs.is_constant() and lhs.constant_int() < 2:
        # small constant >= nested never holds.
        return False

    raise ValueError("ge: relation is indeterminate")
|
|
|
|
|
|
class NestedIntNode:
|
|
def __init__(self, t_id: int, coeff: int):
|
|
self.t_id = t_id
|
|
self.coeff = coeff
|
|
|
|
def nested_int_coeff(self) -> int:
|
|
return self.coeff
|
|
|
|
def maybe_as_int(self) -> Optional[int]:
|
|
return None
|
|
|
|
def is_int(self) -> bool:
|
|
return True
|
|
|
|
def is_float(self) -> bool:
|
|
return False
|
|
|
|
def is_bool(self) -> bool:
|
|
return False
|
|
|
|
def is_nested_int(self) -> bool:
|
|
return True
|
|
|
|
def clone(self) -> "NestedIntNode":
|
|
return self
|
|
|
|
def _str(self) -> Any:
|
|
if self.coeff == 1:
|
|
return f"j{self.t_id}"
|
|
return f"{self.coeff}*j{self.t_id}"
|
|
|
|
def str(self) -> Any:
|
|
return self._str()
|
|
|
|
def __str__(self) -> Any:
|
|
return self._str()
|
|
|
|
def __repr__(self) -> Any:
|
|
return self._str()
|
|
|
|
def _graph_repr(self) -> Any:
|
|
return self._str()
|
|
|
|
def mul(self, other: Any) -> "NestedIntNode":
|
|
if other.is_constant():
|
|
other = other.constant_int()
|
|
else:
|
|
raise ValueError(f"unsupported: {type(other)}")
|
|
return NestedIntNode(self.t_id, self.coeff * other)
|
|
|
|
def eq(self, other: Any) -> Any:
|
|
return torch._C._get_constant_bool_symnode(_eq(self, other))
|
|
|
|
def ne(self, other: Any) -> Any:
|
|
return torch._C._get_constant_bool_symnode(not _eq(self, other))
|
|
|
|
def gt(self, other: Any) -> Any:
|
|
return torch._C._get_constant_bool_symnode(not _ge(other, self))
|
|
|
|
def lt(self, other: Any) -> Any:
|
|
return torch._C._get_constant_bool_symnode(not _ge(self, other))
|
|
|
|
def le(self, other: Any) -> Any:
|
|
return torch._C._get_constant_bool_symnode(_ge(other, self))
|
|
|
|
def ge(self, other: Any) -> Any:
|
|
return torch._C._get_constant_bool_symnode(_ge(self, other))
|
|
|
|
def is_symbolic(self) -> bool:
|
|
return False
|
|
|
|
def nested_int(self) -> int:
|
|
return self.t_id
|
|
|
|
def is_constant(self) -> bool:
|
|
return False
|
|
|
|
def wrap_int(self, num: int) -> ConstantIntNode:
|
|
assert type(num) is int
|
|
return ConstantIntNode(num)
|