pytorch/torch/utils/_sympy/reference.py
vfdev-5 7005a4bcb6 [dynamo] Added dyn shapes support for math trigo ops: sin(h), cos(h), tan(h) ... (#114866)
Description:
- Added dynamic shapes support for math trigo ops: sin(h), cos(h), tan(h) ...

```python
import math
import torch

def func(x, a, b):
    c = 0
    c = c + math.sqrt(a)
    c = c + math.cos(a)
    c = c + math.cosh(a)
    c = c + math.sin(a)
    c = c + math.sinh(a)
    c = c + math.tan(a)
    c = c + math.tanh(a)
    c = c + math.asin(b)
    c = c + math.acos(b)
    c = c + math.atan(a)
    y = x + c
    return y

cfunc = torch.compile(func, dynamic=True, fullgraph=True)

device = "cpu"  # or "cuda"
x = torch.tensor([0, 1, 2, 3], dtype=torch.float32, device=device)
a = 12
b = 1

out = cfunc(x, a, b)
expected = func(x, a, b)
torch.testing.assert_close(out, expected)
```

and the graph `TORCH_LOGS=+graph_code python check_math_ops.py`:

<details>
<summary>
graph code
</summary>

```
[2023-11-30 22:16:10,654] [0/0] torch._dynamo.output_graph.__graph_code: [DEBUG] TRACED GRAPH
[2023-11-30 22:16:10,654] [0/0] torch._dynamo.output_graph.__graph_code: [DEBUG]  ===== __compiled_fn_0 =====
[2023-11-30 22:16:10,654] [0/0] torch._dynamo.output_graph.__graph_code: [DEBUG]  <eval_with_key>.0 class GraphModule(torch.nn.Module):
[2023-11-30 22:16:10,654] [0/0] torch._dynamo.output_graph.__graph_code: [DEBUG]     def forward(self, L_a_ : torch.SymInt, s1 : torch.SymInt, L_x_ : torch.Tensor):
[2023-11-30 22:16:10,654] [0/0] torch._dynamo.output_graph.__graph_code: [DEBUG]         l_a_ = L_a_
[2023-11-30 22:16:10,654] [0/0] torch._dynamo.output_graph.__graph_code: [DEBUG]         l_x_ = L_x_
[2023-11-30 22:16:10,654] [0/0] torch._dynamo.output_graph.__graph_code: [DEBUG]
[2023-11-30 22:16:10,654] [0/0] torch._dynamo.output_graph.__graph_code: [DEBUG]         # File: check_math_ops.py:57, code: c = c + math.sqrt(a)
[2023-11-30 22:16:10,654] [0/0] torch._dynamo.output_graph.__graph_code: [DEBUG]         sym_sqrt = torch.sym_sqrt(l_a_)
[2023-11-30 22:16:10,654] [0/0] torch._dynamo.output_graph.__graph_code: [DEBUG]         add = 0 + sym_sqrt;  sym_sqrt = None
[2023-11-30 22:16:10,654] [0/0] torch._dynamo.output_graph.__graph_code: [DEBUG]
[2023-11-30 22:16:10,654] [0/0] torch._dynamo.output_graph.__graph_code: [DEBUG]         # File: check_math_ops.py:58, code: c = c + math.cos(a)
[2023-11-30 22:16:10,654] [0/0] torch._dynamo.output_graph.__graph_code: [DEBUG]         sym_cos = torch.sym_cos(l_a_)
[2023-11-30 22:16:10,654] [0/0] torch._dynamo.output_graph.__graph_code: [DEBUG]         add_1 = add + sym_cos;  add = sym_cos = None
[2023-11-30 22:16:10,654] [0/0] torch._dynamo.output_graph.__graph_code: [DEBUG]
[2023-11-30 22:16:10,654] [0/0] torch._dynamo.output_graph.__graph_code: [DEBUG]         # File: check_math_ops.py:59, code: c = c + math.cosh(a)
[2023-11-30 22:16:10,654] [0/0] torch._dynamo.output_graph.__graph_code: [DEBUG]         sym_cosh = torch.sym_cosh(l_a_)
[2023-11-30 22:16:10,654] [0/0] torch._dynamo.output_graph.__graph_code: [DEBUG]         add_2 = add_1 + sym_cosh;  add_1 = sym_cosh = None
[2023-11-30 22:16:10,654] [0/0] torch._dynamo.output_graph.__graph_code: [DEBUG]
[2023-11-30 22:16:10,654] [0/0] torch._dynamo.output_graph.__graph_code: [DEBUG]         # File: check_math_ops.py:60, code: c = c + math.sin(a)
[2023-11-30 22:16:10,654] [0/0] torch._dynamo.output_graph.__graph_code: [DEBUG]         sym_sin = torch.sym_sin(l_a_)
[2023-11-30 22:16:10,654] [0/0] torch._dynamo.output_graph.__graph_code: [DEBUG]         add_3 = add_2 + sym_sin;  add_2 = sym_sin = None
[2023-11-30 22:16:10,654] [0/0] torch._dynamo.output_graph.__graph_code: [DEBUG]
[2023-11-30 22:16:10,654] [0/0] torch._dynamo.output_graph.__graph_code: [DEBUG]         # File: check_math_ops.py:61, code: c = c + math.sinh(a)
[2023-11-30 22:16:10,654] [0/0] torch._dynamo.output_graph.__graph_code: [DEBUG]         sym_sinh = torch.sym_sinh(l_a_)
[2023-11-30 22:16:10,654] [0/0] torch._dynamo.output_graph.__graph_code: [DEBUG]         add_4 = add_3 + sym_sinh;  add_3 = sym_sinh = None
[2023-11-30 22:16:10,654] [0/0] torch._dynamo.output_graph.__graph_code: [DEBUG]
[2023-11-30 22:16:10,654] [0/0] torch._dynamo.output_graph.__graph_code: [DEBUG]         # File: check_math_ops.py:62, code: c = c + math.tan(a)
[2023-11-30 22:16:10,654] [0/0] torch._dynamo.output_graph.__graph_code: [DEBUG]         sym_tan = torch.sym_tan(l_a_)
[2023-11-30 22:16:10,654] [0/0] torch._dynamo.output_graph.__graph_code: [DEBUG]         add_5 = add_4 + sym_tan;  add_4 = sym_tan = None
[2023-11-30 22:16:10,654] [0/0] torch._dynamo.output_graph.__graph_code: [DEBUG]
[2023-11-30 22:16:10,654] [0/0] torch._dynamo.output_graph.__graph_code: [DEBUG]         # File: check_math_ops.py:63, code: c = c + math.tanh(a)
[2023-11-30 22:16:10,654] [0/0] torch._dynamo.output_graph.__graph_code: [DEBUG]         sym_tanh = torch.sym_tanh(l_a_)
[2023-11-30 22:16:10,654] [0/0] torch._dynamo.output_graph.__graph_code: [DEBUG]         add_6 = add_5 + sym_tanh;  add_5 = sym_tanh = None
[2023-11-30 22:16:10,654] [0/0] torch._dynamo.output_graph.__graph_code: [DEBUG]
[2023-11-30 22:16:10,654] [0/0] torch._dynamo.output_graph.__graph_code: [DEBUG]         # File: check_math_ops.py:64, code: c = c + math.asin(b)
[2023-11-30 22:16:10,654] [0/0] torch._dynamo.output_graph.__graph_code: [DEBUG]         add_7 = add_6 + 1.5707963267948966;  add_6 = None
[2023-11-30 22:16:10,654] [0/0] torch._dynamo.output_graph.__graph_code: [DEBUG]
[2023-11-30 22:16:10,654] [0/0] torch._dynamo.output_graph.__graph_code: [DEBUG]         # File: check_math_ops.py:65, code: c = c + math.acos(b)
[2023-11-30 22:16:10,654] [0/0] torch._dynamo.output_graph.__graph_code: [DEBUG]         add_8 = add_7 + 0.0;  add_7 = None
[2023-11-30 22:16:10,654] [0/0] torch._dynamo.output_graph.__graph_code: [DEBUG]
[2023-11-30 22:16:10,654] [0/0] torch._dynamo.output_graph.__graph_code: [DEBUG]         # File: check_math_ops.py:66, code: c = c + math.atan(a)
[2023-11-30 22:16:10,654] [0/0] torch._dynamo.output_graph.__graph_code: [DEBUG]         sym_atan = torch.sym_atan(l_a_);  l_a_ = None
[2023-11-30 22:16:10,654] [0/0] torch._dynamo.output_graph.__graph_code: [DEBUG]         add_9 = add_8 + sym_atan;  add_8 = sym_atan = None
[2023-11-30 22:16:10,654] [0/0] torch._dynamo.output_graph.__graph_code: [DEBUG]
[2023-11-30 22:16:10,654] [0/0] torch._dynamo.output_graph.__graph_code: [DEBUG]         # File: check_math_ops.py:67, code: y = x + c
[2023-11-30 22:16:10,654] [0/0] torch._dynamo.output_graph.__graph_code: [DEBUG]         y = l_x_ + add_9;  l_x_ = add_9 = None
[2023-11-30 22:16:10,654] [0/0] torch._dynamo.output_graph.__graph_code: [DEBUG]         return (y,)
[2023-11-30 22:16:10,654] [0/0] torch._dynamo.output_graph.__graph_code: [DEBUG]
[2023-11-30 22:16:10,654] [0/0] torch._dynamo.output_graph.__graph_code: [DEBUG]
```
</details>

Generated code with `TORCH_LOGS=+output_code python check_math_ops.py`:
<details>
<summary>
C++ code
</summary>

```
[2023-11-30 22:19:09,709] [0/0] torch._inductor.graph.__output_code: [DEBUG] cpp_fused_add_0 = async_compile.cpp('''
[2023-11-30 22:19:09,709] [0/0] torch._inductor.graph.__output_code: [DEBUG] #include "/tmp/torchinductor_root/2l/c2ljzlm4sosod7u6lyrroqdba6hmfcyijrric6p4t3fhbcmw6osp.h"
[2023-11-30 22:19:09,709] [0/0] torch._inductor.graph.__output_code: [DEBUG] extern "C" void kernel(const float* in_ptr0,
[2023-11-30 22:19:09,709] [0/0] torch._inductor.graph.__output_code: [DEBUG]                        float* out_ptr0,
[2023-11-30 22:19:09,709] [0/0] torch._inductor.graph.__output_code: [DEBUG]                        const long ks0,
[2023-11-30 22:19:09,709] [0/0] torch._inductor.graph.__output_code: [DEBUG]                        const long ks1)
[2023-11-30 22:19:09,709] [0/0] torch._inductor.graph.__output_code: [DEBUG] {
[2023-11-30 22:19:09,709] [0/0] torch._inductor.graph.__output_code: [DEBUG]     {
[2023-11-30 22:19:09,709] [0/0] torch._inductor.graph.__output_code: [DEBUG]         #pragma GCC ivdep
[2023-11-30 22:19:09,709] [0/0] torch._inductor.graph.__output_code: [DEBUG]         for(long x0=static_cast<long>(0L); x0<static_cast<long>(ks0); x0+=static_cast<long>(1L))
[2023-11-30 22:19:09,709] [0/0] torch._inductor.graph.__output_code: [DEBUG]         {
[2023-11-30 22:19:09,709] [0/0] torch._inductor.graph.__output_code: [DEBUG]             auto tmp0 = in_ptr0[static_cast<long>(x0)];
[2023-11-30 22:19:09,709] [0/0] torch._inductor.graph.__output_code: [DEBUG]             auto tmp1 = c10::convert<float>(1.57079632679490 + (std::sqrt(ks1)) + (std::atan(ks1)) + (std::cos(ks1)) + (std::cosh(ks1)) + (std::sin(ks1)) + (std::sinh(ks1)) + (std::tan(ks1)) + (std::tanh(ks1)));
[2023-11-30 22:19:09,709] [0/0] torch._inductor.graph.__output_code: [DEBUG]             auto tmp2 = decltype(tmp0)(tmp0 + tmp1);
[2023-11-30 22:19:09,709] [0/0] torch._inductor.graph.__output_code: [DEBUG]             out_ptr0[static_cast<long>(x0)] = tmp2;
[2023-11-30 22:19:09,709] [0/0] torch._inductor.graph.__output_code: [DEBUG]         }
[2023-11-30 22:19:09,709] [0/0] torch._inductor.graph.__output_code: [DEBUG]     }
[2023-11-30 22:19:09,709] [0/0] torch._inductor.graph.__output_code: [DEBUG] }
[2023-11-30 22:19:09,709] [0/0] torch._inductor.graph.__output_code: [DEBUG] ''')
```

</details>

<details>
<summary>
Triton code
</summary>

```
[2023-11-30 22:20:00,383] [0/0] torch._inductor.graph.__output_code: [DEBUG] @pointwise(
[2023-11-30 22:20:00,383] [0/0] torch._inductor.graph.__output_code: [DEBUG]     size_hints=[4],
[2023-11-30 22:20:00,383] [0/0] torch._inductor.graph.__output_code: [DEBUG]     filename=__file__,
[2023-11-30 22:20:00,383] [0/0] torch._inductor.graph.__output_code: [DEBUG]     triton_meta={'signature': {0: '*fp32', 1: '*fp32', 2: 'i32', 3: 'i32'}, 'device': 0, 'device_type': 'cuda', 'constants': {}, 'configs': [instance_descriptor(divisible_by_16=(0, 1), equal_to_1=(), i
ds_of_folded_args=(), divisible_by_8=())]},
[2023-11-30 22:20:00,383] [0/0] torch._inductor.graph.__output_code: [DEBUG]     inductor_meta={'autotune_hints': set(), 'kernel_name': 'triton_poi_fused_add_0', 'mutated_arg_names': []},
[2023-11-30 22:20:00,383] [0/0] torch._inductor.graph.__output_code: [DEBUG]     min_elem_per_thread=0
[2023-11-30 22:20:00,383] [0/0] torch._inductor.graph.__output_code: [DEBUG] )
[2023-11-30 22:20:00,383] [0/0] torch._inductor.graph.__output_code: [DEBUG] @triton.jit
[2023-11-30 22:20:00,383] [0/0] torch._inductor.graph.__output_code: [DEBUG] def triton_(in_ptr0, out_ptr0, ks0, xnumel, XBLOCK : tl.constexpr):
[2023-11-30 22:20:00,383] [0/0] torch._inductor.graph.__output_code: [DEBUG]     xoffset = tl.program_id(0) * XBLOCK
[2023-11-30 22:20:00,383] [0/0] torch._inductor.graph.__output_code: [DEBUG]     xindex = xoffset + tl.arange(0, XBLOCK)[:]
[2023-11-30 22:20:00,383] [0/0] torch._inductor.graph.__output_code: [DEBUG]     xmask = xindex < xnumel
[2023-11-30 22:20:00,383] [0/0] torch._inductor.graph.__output_code: [DEBUG]     x0 = xindex
[2023-11-30 22:20:00,383] [0/0] torch._inductor.graph.__output_code: [DEBUG]     tmp0 = tl.load(in_ptr0 + (x0), xmask)
[2023-11-30 22:20:00,383] [0/0] torch._inductor.graph.__output_code: [DEBUG]     tmp1 = 1.57079632679490 + (tl.math.sqrt(ks0.to(tl.float32))) + (tl.math.atan((ks0).to(tl.float32))) + (tl.math.cos((ks0).to(tl.float32))) + (tl.math.cosh((ks0).to(tl.float32))) + (tl.math.sin((ks0)
.to(tl.float32))) + (tl.math.sinh((ks0).to(tl.float32))) + (tl.math.tan((ks0).to(tl.float32))) + (tl.math.tanh((ks0).to(tl.float32)))
[2023-11-30 22:20:00,383] [0/0] torch._inductor.graph.__output_code: [DEBUG]     tmp2 = tmp1.to(tl.float32)
[2023-11-30 22:20:00,383] [0/0] torch._inductor.graph.__output_code: [DEBUG]     tmp3 = tmp0 + tmp2
[2023-11-30 22:20:00,383] [0/0] torch._inductor.graph.__output_code: [DEBUG]     tl.store(out_ptr0 + (x0), tmp3, xmask)
[2023-11-30 22:20:00,383] [0/0] torch._inductor.graph.__output_code: [DEBUG] ''')
```

</details>

Pull Request resolved: https://github.com/pytorch/pytorch/pull/114866
Approved by: https://github.com/peterbell10
2024-01-11 11:52:28 +00:00

215 lines
4.6 KiB
Python

import math
import sympy
import torch
# The sympy interpretation of operators. It will also sometimes work with
# plain int/float, but if you do certain operations you will get out a
# sympy.Basic in the end. If you want the Python/FX traceable interpretation,
# check PythonReferenceAnalysis.
# NB: For magic methods this needs to use normal magic methods
# so that test_magic_methods works
class ReferenceAnalysis:
@staticmethod
def constant(c, dtype):
return sympy.sympify(c)
@staticmethod
def or_(a, b):
return a | b
@staticmethod
def and_(a, b):
return a & b
@staticmethod
def eq(a, b):
if isinstance(a, sympy.Expr) or isinstance(b, sympy.Expr):
return sympy.Eq(a, b)
return a == b
@classmethod
def ne(cls, a, b):
return cls.not_(cls.eq(a, b))
@staticmethod
def lt(a, b):
return a < b
@staticmethod
def gt(a, b):
return a > b
@staticmethod
def le(a, b):
return a <= b
@staticmethod
def ge(a, b):
return a >= b
@staticmethod
def not_(a):
assert not isinstance(a, bool)
return ~a
@staticmethod
def reciprocal(x):
return 1 / x
@staticmethod
def square(x):
return x * x
@staticmethod
def mod(x, y):
return x % y
@staticmethod
def abs(x):
return abs(x)
@staticmethod
def neg(x):
return -x
@staticmethod
def truediv(a, b):
return a / b
@staticmethod
def div(a, b):
return ReferenceAnalysis.truediv(a, b)
@staticmethod
def floordiv(a, b):
if b == 0:
return sympy.nan if a == 0 else sympy.zoo
return a // b
@staticmethod
def truncdiv(a, b):
result = a / b
if result.is_finite:
result = sympy.Integer(result)
return result
@staticmethod
def add(a, b):
return a + b
@staticmethod
def mul(a, b):
return a * b
@staticmethod
def sub(a, b):
return a - b
@staticmethod
def exp(x):
return sympy.exp(x)
@staticmethod
def log(x):
return sympy.log(x)
@staticmethod
def sqrt(x):
return sympy.sqrt(x)
@staticmethod
def pow(a, b):
return a**b
@staticmethod
def minimum(a, b):
# Poorman's version of upcasting in Sympy
# This won't do for sympy.Expr as the casting does nothing for those
if a.is_Float or not a.is_finite or b.is_Float or not b.is_finite:
result_type = sympy.Float
else:
assert a.is_Integer
assert b.is_Integer
result_type = sympy.Integer
return sympy.Min(result_type(a), result_type(b))
@staticmethod
def maximum(a, b):
# Poorman's version of upcasting in Sympy
# This won't do for sympy.Expr as the casting does nothing for those
if a.is_Float or not a.is_finite or b.is_Float or not b.is_finite:
result_type = sympy.Float
else:
assert a.is_Integer
assert b.is_Integer
result_type = sympy.Integer
return sympy.Max(result_type(a), result_type(b))
@staticmethod
def floor(x):
return sympy.floor(x)
@staticmethod
def ceil(x):
return sympy.ceiling(x)
# Unlike ReferenceAnalysis, does NOT sympyify, instead, works with plain
# Python types and is FX traceable. Inheritance here is purely for code
# sharing (TODO: considering splitting out a BaseReferenceAnalysis).
class PythonReferenceAnalysis(ReferenceAnalysis):
@staticmethod
def constant(c, dtype):
if dtype is torch.int64:
return int(c)
elif dtype is torch.double:
return float(c)
elif dtype is torch.bool:
return bool(c)
else:
raise AssertionError(f"unrecognized dtype {dtype}")
@staticmethod
def not_(a):
return torch.sym_not(a)
@staticmethod
def floordiv(a, b):
return a // b
@staticmethod
def truncdiv(a, b):
return a / b
@staticmethod
def exp(x):
raise AssertionError("exp is not valid shape sympy expr")
@staticmethod
def log(x):
raise AssertionError("log is not valid shape sympy expr")
@staticmethod
def sqrt(x):
return torch._sym_sqrt(x) # type: ignore[attr-defined]
@staticmethod
def minimum(a, b):
return torch.sym_min(a, b)
@staticmethod
def maximum(a, b):
return torch.sym_max(a, b)
@staticmethod
def floor(x):
return math.floor(x)
@staticmethod
def ceil(x):
return math.ceil(x)