mirror of
https://github.com/zebrajr/pytorch.git
synced 2025-12-07 00:21:07 +01:00
Pull Request resolved: https://github.com/pytorch/pytorch/pull/96578 Approved by: https://github.com/ezyang
124 lines
2.2 KiB
Python
124 lines
2.2 KiB
Python
import sympy
|
|
|
|
|
|
# The normal Python interpretation of the operators.
# NB: For magic methods this needs to use the normal magic methods
# so that test_magic_methods works.
|
|
class ReferenceAnalysis:
    """Reference implementation of each operator.

    Arithmetic, comparison and bitwise operations are delegated to the
    ordinary Python operators (and therefore to the operands' magic methods);
    the remaining operations call the corresponding ``sympy`` function.
    Every method is callable directly on the class.
    """

    @staticmethod
    def constant(c, dtype):
        """Return ``c`` converted with ``sympy.sympify``; ``dtype`` is unused."""
        return sympy.sympify(c)

    @staticmethod
    def or_(a, b):
        """Return ``a | b``; plain Python bools are rejected by assertion."""
        # NOTE(review): presumably callers are expected to pass sympy booleans
        # rather than Python bools -- confirm against the call sites.
        assert not isinstance(a, bool) and not isinstance(b, bool)
        return a | b

    @staticmethod
    def and_(a, b):
        """Return ``a & b``; plain Python bools are rejected by assertion."""
        assert not isinstance(a, bool) and not isinstance(b, bool)
        return a & b

    @staticmethod
    def eq(a, b):
        """Return the equality of ``a`` and ``b``.

        If either operand is a ``sympy.Expr`` the result is ``sympy.Eq(a, b)``;
        otherwise it is whatever ``a == b`` evaluates to.
        """
        if isinstance(a, sympy.Expr) or isinstance(b, sympy.Expr):
            return sympy.Eq(a, b)
        return a == b

    @classmethod
    def ne(cls, a, b):
        """Return the negation of ``cls.eq(a, b)``.

        Dispatches through ``cls`` so subclasses overriding ``eq`` or ``not_``
        are honoured.
        """
        equal = cls.eq(a, b)
        # eq() falls back to ``a == b`` when neither operand is a sympy.Expr,
        # which can produce a plain Python bool.  not_() refuses plain bools
        # (``~True`` is the integer -2), so negate those logically here
        # instead of tripping its assertion.
        if isinstance(equal, bool):
            return not equal
        return cls.not_(equal)

    @staticmethod
    def lt(a, b):
        """Return ``a < b``."""
        return a < b

    @staticmethod
    def gt(a, b):
        """Return ``a > b``."""
        return a > b

    @staticmethod
    def le(a, b):
        """Return ``a <= b``."""
        return a <= b

    @staticmethod
    def ge(a, b):
        """Return ``a >= b``."""
        return a >= b

    @staticmethod
    def not_(a):
        """Return ``~a``; a plain Python bool is rejected by assertion."""
        # ``~`` on a Python bool is integer bitwise inversion (~True == -2),
        # not logical negation, hence the guard.
        assert not isinstance(a, bool)
        return ~a

    @staticmethod
    def reciprocal(x):
        """Return ``1 / x``."""
        return 1 / x

    @staticmethod
    def square(x):
        """Return ``x * x``."""
        return x * x

    @staticmethod
    def mod(x, y):
        """Return ``x % y``."""
        return x % y

    @staticmethod
    def abs(x):
        """Return the absolute value of ``x`` via the builtin ``abs``."""
        # Inside a method body the name ``abs`` resolves to the builtin, not
        # to this method: class-scope names are not visible from function
        # bodies, so there is no recursion here.
        return abs(x)

    @staticmethod
    def neg(x):
        """Return ``-x``."""
        return -x

    @staticmethod
    def truediv(a, b):
        """Return ``a / b`` (true division)."""
        return a / b

    @staticmethod
    def div(a, b):
        """Return ``a // b`` (floor division)."""
        return a // b

    @staticmethod
    def add(a, b):
        """Return ``a + b``."""
        return a + b

    @staticmethod
    def mul(a, b):
        """Return ``a * b``."""
        return a * b

    @staticmethod
    def sub(a, b):
        """Return ``a - b``."""
        return a - b

    @staticmethod
    def exp(x):
        """Return ``sympy.exp(x)``."""
        return sympy.exp(x)

    @staticmethod
    def log(x):
        """Return ``sympy.log(x)``."""
        return sympy.log(x)

    @staticmethod
    def sqrt(x):
        """Return ``sympy.sqrt(x)``."""
        return sympy.sqrt(x)

    @staticmethod
    def pow(a, b):
        """Return ``a ** b``."""
        return a**b

    @staticmethod
    def minimum(a, b):
        """Return ``sympy.Min(a, b)``."""
        return sympy.Min(a, b)

    @staticmethod
    def maximum(a, b):
        """Return ``sympy.Max(a, b)``."""
        return sympy.Max(a, b)

    @staticmethod
    def floor(x):
        """Return ``sympy.floor(x)``."""
        return sympy.floor(x)

    @staticmethod
    def ceil(x):
        """Return ``sympy.ceiling(x)``."""
        return sympy.ceiling(x)
|