mirror of
https://github.com/zebrajr/pytorch.git
synced 2025-12-07 00:21:07 +01:00
Partially addresses https://github.com/pytorch/pytorch/issues/128150

When you have big sums of values, we end up computing long chains of binary addition in our FX graph representation. Not only is this ugly, it is also quadratic, as the sympy.Add constructor is O(N) in the number of arguments. Instead, ensure that we maintain the summation as a single FX node so we can do the entire addition all in one go.

update_hint_regression benchmark, before and after:

```
update_hint_regression,compile_time_instruction_count,2648328980
update_hint_regression,compile_time_instruction_count,2563748678
```

Signed-off-by: Edward Z. Yang <ezyang@meta.com>

Pull Request resolved: https://github.com/pytorch/pytorch/pull/136429
Approved by: https://github.com/isuruf
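A minimal sketch of what the change avoids, using plain sympy with hypothetical symbols (not taken from the patch itself): chained binary `+` rebuilds an ever-growing `sympy.Add` at every step, while a single n-ary `sympy.Add` (what `sym_sum` lowers to) constructs the sum once.

```python
import sympy

xs = sympy.symbols("x0:100")  # hypothetical operands

# Chained binary addition: each `+` invokes the sympy.Add constructor on an
# ever-growing expression, so building the whole sum is quadratic overall.
acc = xs[0]
for x in xs[1:]:
    acc = acc + x

# A single n-ary Add: one constructor call over all arguments.
total = sympy.Add(*xs)

assert acc == total  # same expression, built far more cheaply
```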
507 lines
12 KiB
Python
# mypy: allow-untyped-defs
import math
import operator
from typing import Union

import sympy

import torch
from torch.utils._sympy.functions import (
    _keep_float,
    FloatPow,
    FloatTrueDiv,
    FloorDiv,
    IntTrueDiv,
    Max,
    Min,
    Mod,
    OpaqueUnaryFn_exp,
    OpaqueUnaryFn_log,
    OpaqueUnaryFn_sqrt,
    PowByNatural,
    RoundDecimal,
    RoundToInt,
    ToFloat,
    TruncToInt,
)


# The sympy interpretation of operators.  It will also sometimes work with
# plain int/float, but if you do certain operations you will get out a
# sympy.Basic in the end.  If you want the Python/FX traceable interpretation,
# check PythonReferenceAnalysis.
# NB: For magic methods this needs to use normal magic methods
# so that test_magic_methods works
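#
# A rough usage sketch (hypothetical, assuming an integer sympy Symbol "s";
# not part of this module):
#
#   s = sympy.Symbol("s", integer=True)
#   ReferenceAnalysis.add(s, 1)       # sympy expression: s + 1
#   ReferenceAnalysis.eq(s, 3)        # sympy.Eq(s, 3) rather than a plain bool
#   ReferenceAnalysis.floordiv(s, 4)  # FloorDiv(s, 4), a sympy.Basic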
class ReferenceAnalysis:
    @staticmethod
    def constant(c, dtype):
        return sympy.sympify(c)

    @staticmethod
    def or_(a, b):
        return a | b

    @staticmethod
    def and_(a, b):
        return a & b

    @staticmethod
    def eq(a, b):
        if isinstance(a, sympy.Expr) or isinstance(b, sympy.Expr):
            return sympy.Eq(a, b)
        return a == b

    @classmethod
    def ne(cls, a, b):
        return cls.not_(cls.eq(a, b))

    @staticmethod
    def lt(a, b):
        return a < b

    @staticmethod
    def gt(a, b):
        return a > b

    @staticmethod
    def le(a, b):
        return a <= b

    @staticmethod
    def ge(a, b):
        return a >= b

    @staticmethod
    def not_(a):
        assert not isinstance(a, bool)
        return ~a

    @staticmethod
    def reciprocal(x):
        return FloatTrueDiv(1.0, x)

    @staticmethod
    def square(x):
        return PowByNatural(x, 2)

    @staticmethod
    def trunc_to_int(x, dtype):
        return TruncToInt(x)

    @staticmethod
    def ceil_to_int(x, dtype):
        return sympy.ceiling(x)

    @staticmethod
    def floor_to_int(x, dtype):
        return sympy.floor(x)

    @staticmethod
    def floor(x):
        return _keep_float(sympy.floor)(x)

    @staticmethod
    def ceil(x):
        return _keep_float(sympy.ceiling)(x)

    @staticmethod
    def to_dtype(x, dtype):
        if dtype == torch.float64:
            return ToFloat(x)
        raise NotImplementedError(f"to_dtype {dtype} NYI")

    @staticmethod
    def mod(x, y):
        return Mod(x, y)

    @staticmethod
    def abs(x):
        return abs(x)

    @staticmethod
    def neg(x):
        return -x

    @staticmethod
    def truediv(a, b):
        return FloatTrueDiv(a, b)

    @staticmethod
    def int_truediv(a, b):
        return IntTrueDiv(a, b)

    @staticmethod
    def floordiv(a, b):
        return FloorDiv(a, b)

    @staticmethod
    def truncdiv(a, b):
        raise NotImplementedError("TODO: truncdiv")

    @staticmethod
    def add(a, b):
        return _keep_float(operator.add)(a, b)

    @classmethod
    def sym_sum(cls, args):
        return sympy.Add(*args)

    @staticmethod
    def mul(a, b):
        return _keep_float(operator.mul)(a, b)

    @staticmethod
    def sub(a, b):
        return _keep_float(operator.sub)(a, b)

    @staticmethod
    def exp(x):
        return OpaqueUnaryFn_exp(x)

    @staticmethod
    def log(x):
        return OpaqueUnaryFn_log(x)

    @staticmethod
    def sqrt(x):
        return OpaqueUnaryFn_sqrt(x)

    @staticmethod
    def pow(a, b):
        return _keep_float(FloatPow)(a, b)

    @staticmethod
    def pow_by_natural(a, b):
        return PowByNatural(a, b)

    @staticmethod
    def minimum(a, b):
        return Min(a, b)

    @staticmethod
    def maximum(a, b):
        return Max(a, b)

    @staticmethod
    def round_to_int(a, dtype):
        return RoundToInt(a)

    @staticmethod
    def round_decimal(a, b):
        return RoundDecimal(a, b)


# Unlike ReferenceAnalysis, does NOT sympyify; instead, it works with plain
# Python types and is FX traceable.  Inheritance here is purely for code
# sharing (TODO: consider splitting out a BaseReferenceAnalysis).
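#
# A rough illustration of the difference (hypothetical plain-Python inputs;
# not part of this module):
#
#   PythonReferenceAnalysis.truediv(5, 2)       # 2.5, a plain float
#   PythonReferenceAnalysis.minimum(3, 4)       # 3, via torch.sym_min
#   PythonReferenceAnalysis.sym_sum([1, 2, 3])  # 6, folded left to right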
class PythonReferenceAnalysis(ReferenceAnalysis):
    @staticmethod
    def constant(c, dtype):
        if dtype is torch.int64:
            return int(c)
        elif dtype is torch.double:
            return float(c)
        elif dtype is torch.bool:
            return bool(c)
        else:
            raise AssertionError(f"unrecognized dtype {dtype}")

    @staticmethod
    def not_(a):
        return torch.sym_not(a)

    @classmethod
    def sym_sum(cls, args):
        if len(args) == 0:
            return 0
        if len(args) == 1:
            return args[0]
        acc = cls.add(args[0], args[1])
        for i in range(2, len(args)):
            acc = cls.add(acc, args[i])
        return acc

    @staticmethod
    def floordiv(a, b):
        return a // b

    @staticmethod
    def mod(x, y):
        return x % y

    @staticmethod
    def truncdiv(a, b):
        return a / b

    @staticmethod
    def to_dtype(x, dtype):
        if dtype == torch.float64:
            return torch.sym_float(x)
        raise NotImplementedError(f"to_dtype {dtype} NYI")

    @staticmethod
    def exp(x):
        raise AssertionError("exp is not valid shape sympy expr")

    @staticmethod
    def log(x):
        raise AssertionError("log is not valid shape sympy expr")

    @staticmethod
    def sqrt(x):
        return torch._sym_sqrt(x)  # type: ignore[attr-defined]

    @staticmethod
    def minimum(a, b):
        return torch.sym_min(a, b)

    @staticmethod
    def maximum(a, b):
        return torch.sym_max(a, b)

    @staticmethod
    def floor_to_int(x, dtype):
        return math.floor(x)

    @staticmethod
    def ceil_to_int(x, dtype):
        return math.ceil(x)

    @staticmethod
    def floor(x):
        return float(math.floor(x))

    @staticmethod
    def ceil(x):
        return float(math.ceil(x))

    @staticmethod
    def truediv(a, b):
        return a / b

    @staticmethod
    def pow(a, b):
        return a**b

    @staticmethod
    def pow_by_natural(a, b):
        # Pray that safe_pow is not needed here lol.  In particular, this
        # never participates in VR low/high ranges, so overflow should be
        # unlikely
        return a**b

    @staticmethod
    def round_to_int(a, dtype):
        return round(a)

    @staticmethod
    def round_decimal(a, b):
        return round(a, ndigits=b)


# Like PythonReferenceAnalysis, but makes some export-unfriendly choices of
# operators to make things faster
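#
# A rough comparison (hypothetical symbolic inputs s0, s1, s2; not part of
# this module):
#
#   PythonReferenceAnalysis.sym_sum([s0, s1, s2])           # ((s0 + s1) + s2), a chain of adds
#   OptimizedPythonReferenceAnalysis.sym_sum([s0, s1, s2])  # torch.sym_sum([s0, s1, s2]), one call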
class OptimizedPythonReferenceAnalysis(PythonReferenceAnalysis):
    @staticmethod
    def sym_sum(args):
        return torch.sym_sum(args)


def _to_dtype(x: torch.Tensor, dtype: torch.dtype) -> torch.Tensor:
    return torch.ops.aten._to_copy(x, dtype=dtype)


# Suppose we have some int/float arguments.  This diagram commutes:
#
#   int/float  --  PythonReferenceAnalysis.op  -->  int/float
#       |                                               |
#       |                                               |
#       torch.tensor(..., dtype=torch.int64/torch.float64)
#       |                                               |
#       V                                               V
#     Tensor   --  TensorReferenceAnalysis.op   -->   Tensor
#
# NB: int before and after must be representable in int64 (we will
# insert guards accordingly.)
#
# This is guaranteed to be FX traceable with OpOverloads only.
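#
# For example (a hypothetical spot-check of the diagram, not part of this
# module):
#
#   py = PythonReferenceAnalysis.add(7, 3)        # 10
#   ta = torch.tensor(7, dtype=torch.int64)
#   tb = torch.tensor(3, dtype=torch.int64)
#   te = TensorReferenceAnalysis.add(ta, tb)      # tensor(10)
#   assert te.item() == py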
class TensorReferenceAnalysis:
    # NB: This is actually dead, because with Proxy tracing the factory
    # function isn't traced correctly.  Here for completeness.
    @staticmethod
    def constant(c, dtype):
        d: Union[int, float, bool]
        if dtype is torch.int64:
            d = int(c)
        elif dtype is torch.double:
            d = float(c)
        elif dtype is torch.bool:
            d = bool(c)
        else:
            raise AssertionError(f"unrecognized dtype {dtype}")
        return torch.ops.aten.scalar_tensor.default(d, dtype=dtype)

    @staticmethod
    def or_(a, b):
        return torch.ops.aten.logical_or.default(a, b)

    @staticmethod
    def and_(a, b):
        return torch.ops.aten.logical_and.default(a, b)

    @staticmethod
    def eq(a, b):
        return torch.ops.aten.eq.Tensor(a, b)

    @classmethod
    def ne(cls, a, b):
        return torch.ops.aten.ne.Tensor(a, b)

    @staticmethod
    def lt(a, b):
        return torch.ops.aten.lt.Tensor(a, b)

    @staticmethod
    def gt(a, b):
        return torch.ops.aten.gt.Tensor(a, b)

    @staticmethod
    def le(a, b):
        return torch.ops.aten.le.Tensor(a, b)

    @staticmethod
    def ge(a, b):
        return torch.ops.aten.ge.Tensor(a, b)

    @staticmethod
    def not_(a):
        return torch.ops.aten.logical_not.default(a)

    @staticmethod
    def reciprocal(x):
        return torch.ops.aten.reciprocal.default(x)

    @staticmethod
    def square(x):
        # TODO: maybe composite implicit autograd doesn't work here?
        return torch.ops.aten.square.default(x)

    @staticmethod
    def trunc_to_int(x, dtype):
        return _to_dtype(torch.ops.aten.trunc.default(x), dtype)

    @staticmethod
    def ceil_to_int(x, dtype):
        return _to_dtype(torch.ops.aten.ceil.default(x), dtype)

    @staticmethod
    def floor_to_int(x, dtype):
        return _to_dtype(torch.ops.aten.floor.default(x), dtype)

    @staticmethod
    def floor(x):
        return torch.ops.aten.floor.default(x)

    @staticmethod
    def ceil(x):
        return torch.ops.aten.ceil.default(x)

    @staticmethod
    def to_dtype(x, dtype):
        return _to_dtype(x, dtype)

    @staticmethod
    def mod(x, y):
        # TODO: https://github.com/pytorch/pytorch/pull/133654
        raise NotImplementedError(
            "no C-style modulus operation available from frontend atm"
        )

    @staticmethod
    def abs(x):
        return torch.ops.aten.abs.default(x)

    @staticmethod
    def neg(x):
        return torch.ops.aten.neg.default(x)

    @staticmethod
    def truediv(a, b):
        return torch.ops.aten.true_divide.Tensor(a, b)

    @staticmethod
    def int_truediv(a, b):
        raise NotImplementedError(
            "Python int truediv difficult to implement in PyTorch atm"
        )

        # TODO: This is wrong, CPython has a custom implementation of true
        # division that results in higher precision when the floats are
        # sufficiently large.  Short term fix: add a guard here
        return torch.ops.aten.true_divide.default(
            _to_dtype(a, torch.float64), _to_dtype(b, torch.float64)
        )

    @staticmethod
    def floordiv(a, b):
        return torch.ops.aten.floor_divide(a, b)

    @staticmethod
    def truncdiv(a, b):
        raise NotImplementedError(
            "no C-style truncdiv operation available from frontend atm"
        )

    @staticmethod
    def add(a, b):
        return torch.ops.aten.add.Tensor(a, b)

    @staticmethod
    def mul(a, b):
        return torch.ops.aten.mul.Tensor(a, b)

    @staticmethod
    def sub(a, b):
        return torch.ops.aten.sub.Tensor(a, b)

    @staticmethod
    def exp(x):
        return torch.ops.aten.exp.default(x)

    @staticmethod
    def log(x):
        return torch.ops.aten.log.default(x)

    @staticmethod
    def sqrt(x):
        return torch.ops.aten.sqrt.default(x)

    @staticmethod
    def pow(a, b):
        return torch.ops.aten.pow.Tensor_Tensor(a, b)

    @staticmethod
    def pow_by_natural(a, b):
        # NB: pow handles int x int fine
        return torch.ops.aten.pow.Tensor_Tensor(a, b)

    @staticmethod
    def minimum(a, b):
        return torch.ops.aten.minimum.default(a, b)

    @staticmethod
    def maximum(a, b):
        return torch.ops.aten.maximum.default(a, b)

    @staticmethod
    def round_to_int(a, dtype):
        return torch.ops.aten.round.default(a)

    @staticmethod
    def round_decimal(a, b):
        raise NotImplementedError(
            "round decimal doesn't support Tensor second argument atm"
        )

        # return torch.ops.aten.round.decimals(a, b)