Description:
- Added dynamic shapes support for the `math` trigonometric ops: `sin(h)`, `cos(h)`, `tan(h)`, ...
```python
import math

import torch


def func(x, a, b):
    c = 0
    c = c + math.sqrt(a)
    c = c + math.cos(a)
    c = c + math.cosh(a)
    c = c + math.sin(a)
    c = c + math.sinh(a)
    c = c + math.tan(a)
    c = c + math.tanh(a)
    c = c + math.asin(b)
    c = c + math.acos(b)
    c = c + math.atan(a)
    y = x + c
    return y


cfunc = torch.compile(func, dynamic=True, fullgraph=True)

device = "cpu"  # or "cuda"
x = torch.tensor([0, 1, 2, 3], dtype=torch.float32, device=device)
a = 12
b = 1

out = cfunc(x, a, b)
expected = func(x, a, b)
torch.testing.assert_close(out, expected)
```
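Because `a` is captured as a `torch.SymInt` input rather than baked into the graph as a constant (see the traced graph below), the compiled function can be compared against eager mode for several values of `a`. A minimal sanity check, assuming `func`, `cfunc`, and `x` from the snippet above are in scope:

```python
# Sketch: the compiled function should match eager for different values of `a`.
# Assumes func, cfunc, and x from the repro above; 0 and 1 are skipped because
# those values are commonly specialized by the dynamic-shapes machinery.
for a in (5, 12, 33):
    torch.testing.assert_close(cfunc(x, a, 1), func(x, a, 1))
```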
The traced graph, from `TORCH_LOGS=+graph_code python check_math_ops.py`:
<details>
<summary>
graph code
</summary>
```
[2023-11-30 22:16:10,654] [0/0] torch._dynamo.output_graph.__graph_code: [DEBUG] TRACED GRAPH
[2023-11-30 22:16:10,654] [0/0] torch._dynamo.output_graph.__graph_code: [DEBUG] ===== __compiled_fn_0 =====
[2023-11-30 22:16:10,654] [0/0] torch._dynamo.output_graph.__graph_code: [DEBUG] <eval_with_key>.0 class GraphModule(torch.nn.Module):
[2023-11-30 22:16:10,654] [0/0] torch._dynamo.output_graph.__graph_code: [DEBUG] def forward(self, L_a_ : torch.SymInt, s1 : torch.SymInt, L_x_ : torch.Tensor):
[2023-11-30 22:16:10,654] [0/0] torch._dynamo.output_graph.__graph_code: [DEBUG] l_a_ = L_a_
[2023-11-30 22:16:10,654] [0/0] torch._dynamo.output_graph.__graph_code: [DEBUG] l_x_ = L_x_
[2023-11-30 22:16:10,654] [0/0] torch._dynamo.output_graph.__graph_code: [DEBUG]
[2023-11-30 22:16:10,654] [0/0] torch._dynamo.output_graph.__graph_code: [DEBUG] # File: check_math_ops.py:57, code: c = c + math.sqrt(a)
[2023-11-30 22:16:10,654] [0/0] torch._dynamo.output_graph.__graph_code: [DEBUG] sym_sqrt = torch.sym_sqrt(l_a_)
[2023-11-30 22:16:10,654] [0/0] torch._dynamo.output_graph.__graph_code: [DEBUG] add = 0 + sym_sqrt; sym_sqrt = None
[2023-11-30 22:16:10,654] [0/0] torch._dynamo.output_graph.__graph_code: [DEBUG]
[2023-11-30 22:16:10,654] [0/0] torch._dynamo.output_graph.__graph_code: [DEBUG] # File: check_math_ops.py:58, code: c = c + math.cos(a)
[2023-11-30 22:16:10,654] [0/0] torch._dynamo.output_graph.__graph_code: [DEBUG] sym_cos = torch.sym_cos(l_a_)
[2023-11-30 22:16:10,654] [0/0] torch._dynamo.output_graph.__graph_code: [DEBUG] add_1 = add + sym_cos; add = sym_cos = None
[2023-11-30 22:16:10,654] [0/0] torch._dynamo.output_graph.__graph_code: [DEBUG]
[2023-11-30 22:16:10,654] [0/0] torch._dynamo.output_graph.__graph_code: [DEBUG] # File: check_math_ops.py:59, code: c = c + math.cosh(a)
[2023-11-30 22:16:10,654] [0/0] torch._dynamo.output_graph.__graph_code: [DEBUG] sym_cosh = torch.sym_cosh(l_a_)
[2023-11-30 22:16:10,654] [0/0] torch._dynamo.output_graph.__graph_code: [DEBUG] add_2 = add_1 + sym_cosh; add_1 = sym_cosh = None
[2023-11-30 22:16:10,654] [0/0] torch._dynamo.output_graph.__graph_code: [DEBUG]
[2023-11-30 22:16:10,654] [0/0] torch._dynamo.output_graph.__graph_code: [DEBUG] # File: check_math_ops.py:60, code: c = c + math.sin(a)
[2023-11-30 22:16:10,654] [0/0] torch._dynamo.output_graph.__graph_code: [DEBUG] sym_sin = torch.sym_sin(l_a_)
[2023-11-30 22:16:10,654] [0/0] torch._dynamo.output_graph.__graph_code: [DEBUG] add_3 = add_2 + sym_sin; add_2 = sym_sin = None
[2023-11-30 22:16:10,654] [0/0] torch._dynamo.output_graph.__graph_code: [DEBUG]
[2023-11-30 22:16:10,654] [0/0] torch._dynamo.output_graph.__graph_code: [DEBUG] # File: check_math_ops.py:61, code: c = c + math.sinh(a)
[2023-11-30 22:16:10,654] [0/0] torch._dynamo.output_graph.__graph_code: [DEBUG] sym_sinh = torch.sym_sinh(l_a_)
[2023-11-30 22:16:10,654] [0/0] torch._dynamo.output_graph.__graph_code: [DEBUG] add_4 = add_3 + sym_sinh; add_3 = sym_sinh = None
[2023-11-30 22:16:10,654] [0/0] torch._dynamo.output_graph.__graph_code: [DEBUG]
[2023-11-30 22:16:10,654] [0/0] torch._dynamo.output_graph.__graph_code: [DEBUG] # File: check_math_ops.py:62, code: c = c + math.tan(a)
[2023-11-30 22:16:10,654] [0/0] torch._dynamo.output_graph.__graph_code: [DEBUG] sym_tan = torch.sym_tan(l_a_)
[2023-11-30 22:16:10,654] [0/0] torch._dynamo.output_graph.__graph_code: [DEBUG] add_5 = add_4 + sym_tan; add_4 = sym_tan = None
[2023-11-30 22:16:10,654] [0/0] torch._dynamo.output_graph.__graph_code: [DEBUG]
[2023-11-30 22:16:10,654] [0/0] torch._dynamo.output_graph.__graph_code: [DEBUG] # File: check_math_ops.py:63, code: c = c + math.tanh(a)
[2023-11-30 22:16:10,654] [0/0] torch._dynamo.output_graph.__graph_code: [DEBUG] sym_tanh = torch.sym_tanh(l_a_)
[2023-11-30 22:16:10,654] [0/0] torch._dynamo.output_graph.__graph_code: [DEBUG] add_6 = add_5 + sym_tanh; add_5 = sym_tanh = None
[2023-11-30 22:16:10,654] [0/0] torch._dynamo.output_graph.__graph_code: [DEBUG]
[2023-11-30 22:16:10,654] [0/0] torch._dynamo.output_graph.__graph_code: [DEBUG] # File: check_math_ops.py:64, code: c = c + math.asin(b)
[2023-11-30 22:16:10,654] [0/0] torch._dynamo.output_graph.__graph_code: [DEBUG] add_7 = add_6 + 1.5707963267948966; add_6 = None
[2023-11-30 22:16:10,654] [0/0] torch._dynamo.output_graph.__graph_code: [DEBUG]
[2023-11-30 22:16:10,654] [0/0] torch._dynamo.output_graph.__graph_code: [DEBUG] # File: check_math_ops.py:65, code: c = c + math.acos(b)
[2023-11-30 22:16:10,654] [0/0] torch._dynamo.output_graph.__graph_code: [DEBUG] add_8 = add_7 + 0.0; add_7 = None
[2023-11-30 22:16:10,654] [0/0] torch._dynamo.output_graph.__graph_code: [DEBUG]
[2023-11-30 22:16:10,654] [0/0] torch._dynamo.output_graph.__graph_code: [DEBUG] # File: check_math_ops.py:66, code: c = c + math.atan(a)
[2023-11-30 22:16:10,654] [0/0] torch._dynamo.output_graph.__graph_code: [DEBUG] sym_atan = torch.sym_atan(l_a_); l_a_ = None
[2023-11-30 22:16:10,654] [0/0] torch._dynamo.output_graph.__graph_code: [DEBUG] add_9 = add_8 + sym_atan; add_8 = sym_atan = None
[2023-11-30 22:16:10,654] [0/0] torch._dynamo.output_graph.__graph_code: [DEBUG]
[2023-11-30 22:16:10,654] [0/0] torch._dynamo.output_graph.__graph_code: [DEBUG] # File: check_math_ops.py:67, code: y = x + c
[2023-11-30 22:16:10,654] [0/0] torch._dynamo.output_graph.__graph_code: [DEBUG] y = l_x_ + add_9; l_x_ = add_9 = None
[2023-11-30 22:16:10,654] [0/0] torch._dynamo.output_graph.__graph_code: [DEBUG] return (y,)
[2023-11-30 22:16:10,654] [0/0] torch._dynamo.output_graph.__graph_code: [DEBUG]
[2023-11-30 22:16:10,654] [0/0] torch._dynamo.output_graph.__graph_code: [DEBUG]
```
</details>
Generated code with `TORCH_LOGS=+output_code python check_math_ops.py`:
<details>
<summary>
C++ code
</summary>
```
[2023-11-30 22:19:09,709] [0/0] torch._inductor.graph.__output_code: [DEBUG] cpp_fused_add_0 = async_compile.cpp('''
[2023-11-30 22:19:09,709] [0/0] torch._inductor.graph.__output_code: [DEBUG] #include "/tmp/torchinductor_root/2l/c2ljzlm4sosod7u6lyrroqdba6hmfcyijrric6p4t3fhbcmw6osp.h"
[2023-11-30 22:19:09,709] [0/0] torch._inductor.graph.__output_code: [DEBUG] extern "C" void kernel(const float* in_ptr0,
[2023-11-30 22:19:09,709] [0/0] torch._inductor.graph.__output_code: [DEBUG] float* out_ptr0,
[2023-11-30 22:19:09,709] [0/0] torch._inductor.graph.__output_code: [DEBUG] const long ks0,
[2023-11-30 22:19:09,709] [0/0] torch._inductor.graph.__output_code: [DEBUG] const long ks1)
[2023-11-30 22:19:09,709] [0/0] torch._inductor.graph.__output_code: [DEBUG] {
[2023-11-30 22:19:09,709] [0/0] torch._inductor.graph.__output_code: [DEBUG] {
[2023-11-30 22:19:09,709] [0/0] torch._inductor.graph.__output_code: [DEBUG] #pragma GCC ivdep
[2023-11-30 22:19:09,709] [0/0] torch._inductor.graph.__output_code: [DEBUG] for(long x0=static_cast<long>(0L); x0<static_cast<long>(ks0); x0+=static_cast<long>(1L))
[2023-11-30 22:19:09,709] [0/0] torch._inductor.graph.__output_code: [DEBUG] {
[2023-11-30 22:19:09,709] [0/0] torch._inductor.graph.__output_code: [DEBUG] auto tmp0 = in_ptr0[static_cast<long>(x0)];
[2023-11-30 22:19:09,709] [0/0] torch._inductor.graph.__output_code: [DEBUG] auto tmp1 = c10::convert<float>(1.57079632679490 + (std::sqrt(ks1)) + (std::atan(ks1)) + (std::cos(ks1)) + (std::cosh(ks1)) + (std::sin(ks1)) + (std::sinh(ks1)) + (std::tan(ks1)) + (std::tanh(ks1)));
[2023-11-30 22:19:09,709] [0/0] torch._inductor.graph.__output_code: [DEBUG] auto tmp2 = decltype(tmp0)(tmp0 + tmp1);
[2023-11-30 22:19:09,709] [0/0] torch._inductor.graph.__output_code: [DEBUG] out_ptr0[static_cast<long>(x0)] = tmp2;
[2023-11-30 22:19:09,709] [0/0] torch._inductor.graph.__output_code: [DEBUG] }
[2023-11-30 22:19:09,709] [0/0] torch._inductor.graph.__output_code: [DEBUG] }
[2023-11-30 22:19:09,709] [0/0] torch._inductor.graph.__output_code: [DEBUG] }
[2023-11-30 22:19:09,709] [0/0] torch._inductor.graph.__output_code: [DEBUG] ''')
```
</details>
<details>
<summary>
Triton code
</summary>
```
[2023-11-30 22:20:00,383] [0/0] torch._inductor.graph.__output_code: [DEBUG] @pointwise(
[2023-11-30 22:20:00,383] [0/0] torch._inductor.graph.__output_code: [DEBUG] size_hints=[4],
[2023-11-30 22:20:00,383] [0/0] torch._inductor.graph.__output_code: [DEBUG] filename=__file__,
[2023-11-30 22:20:00,383] [0/0] torch._inductor.graph.__output_code: [DEBUG] triton_meta={'signature': {0: '*fp32', 1: '*fp32', 2: 'i32', 3: 'i32'}, 'device': 0, 'device_type': 'cuda', 'constants': {}, 'configs': [instance_descriptor(divisible_by_16=(0, 1), equal_to_1=(), ids_of_folded_args=(), divisible_by_8=())]},
[2023-11-30 22:20:00,383] [0/0] torch._inductor.graph.__output_code: [DEBUG] inductor_meta={'autotune_hints': set(), 'kernel_name': 'triton_poi_fused_add_0', 'mutated_arg_names': []},
[2023-11-30 22:20:00,383] [0/0] torch._inductor.graph.__output_code: [DEBUG] min_elem_per_thread=0
[2023-11-30 22:20:00,383] [0/0] torch._inductor.graph.__output_code: [DEBUG] )
[2023-11-30 22:20:00,383] [0/0] torch._inductor.graph.__output_code: [DEBUG] @triton.jit
[2023-11-30 22:20:00,383] [0/0] torch._inductor.graph.__output_code: [DEBUG] def triton_(in_ptr0, out_ptr0, ks0, xnumel, XBLOCK : tl.constexpr):
[2023-11-30 22:20:00,383] [0/0] torch._inductor.graph.__output_code: [DEBUG] xoffset = tl.program_id(0) * XBLOCK
[2023-11-30 22:20:00,383] [0/0] torch._inductor.graph.__output_code: [DEBUG] xindex = xoffset + tl.arange(0, XBLOCK)[:]
[2023-11-30 22:20:00,383] [0/0] torch._inductor.graph.__output_code: [DEBUG] xmask = xindex < xnumel
[2023-11-30 22:20:00,383] [0/0] torch._inductor.graph.__output_code: [DEBUG] x0 = xindex
[2023-11-30 22:20:00,383] [0/0] torch._inductor.graph.__output_code: [DEBUG] tmp0 = tl.load(in_ptr0 + (x0), xmask)
[2023-11-30 22:20:00,383] [0/0] torch._inductor.graph.__output_code: [DEBUG] tmp1 = 1.57079632679490 + (tl.math.sqrt(ks0.to(tl.float32))) + (tl.math.atan((ks0).to(tl.float32))) + (tl.math.cos((ks0).to(tl.float32))) + (tl.math.cosh((ks0).to(tl.float32))) + (tl.math.sin((ks0).to(tl.float32))) + (tl.math.sinh((ks0).to(tl.float32))) + (tl.math.tan((ks0).to(tl.float32))) + (tl.math.tanh((ks0).to(tl.float32)))
[2023-11-30 22:20:00,383] [0/0] torch._inductor.graph.__output_code: [DEBUG] tmp2 = tmp1.to(tl.float32)
[2023-11-30 22:20:00,383] [0/0] torch._inductor.graph.__output_code: [DEBUG] tmp3 = tmp0 + tmp2
[2023-11-30 22:20:00,383] [0/0] torch._inductor.graph.__output_code: [DEBUG] tl.store(out_ptr0 + (x0), tmp3, xmask)
[2023-11-30 22:20:00,383] [0/0] torch._inductor.graph.__output_code: [DEBUG] ''')
```
</details>
Pull Request resolved: https://github.com/pytorch/pytorch/pull/114866
Approved by: https://github.com/peterbell10
215 lines · 4.6 KiB · Python
```python
import math

import sympy

import torch


# The sympy interpretation of operators. It will also sometimes work with
# plain int/float, but if you do certain operations you will get out a
# sympy.Basic in the end. If you want the Python/FX traceable interpretation,
# check PythonReferenceAnalysis.
# NB: For magic methods this needs to use normal magic methods
# so that test_magic_methods works
class ReferenceAnalysis:
    @staticmethod
    def constant(c, dtype):
        return sympy.sympify(c)

    @staticmethod
    def or_(a, b):
        return a | b

    @staticmethod
    def and_(a, b):
        return a & b

    @staticmethod
    def eq(a, b):
        if isinstance(a, sympy.Expr) or isinstance(b, sympy.Expr):
            return sympy.Eq(a, b)
        return a == b

    @classmethod
    def ne(cls, a, b):
        return cls.not_(cls.eq(a, b))

    @staticmethod
    def lt(a, b):
        return a < b

    @staticmethod
    def gt(a, b):
        return a > b

    @staticmethod
    def le(a, b):
        return a <= b

    @staticmethod
    def ge(a, b):
        return a >= b

    @staticmethod
    def not_(a):
        assert not isinstance(a, bool)
        return ~a

    @staticmethod
    def reciprocal(x):
        return 1 / x

    @staticmethod
    def square(x):
        return x * x

    @staticmethod
    def mod(x, y):
        return x % y

    @staticmethod
    def abs(x):
        return abs(x)

    @staticmethod
    def neg(x):
        return -x

    @staticmethod
    def truediv(a, b):
        return a / b

    @staticmethod
    def div(a, b):
        return ReferenceAnalysis.truediv(a, b)

    @staticmethod
    def floordiv(a, b):
        if b == 0:
            return sympy.nan if a == 0 else sympy.zoo
        return a // b

    @staticmethod
    def truncdiv(a, b):
        result = a / b
        if result.is_finite:
            result = sympy.Integer(result)

        return result

    @staticmethod
    def add(a, b):
        return a + b

    @staticmethod
    def mul(a, b):
        return a * b

    @staticmethod
    def sub(a, b):
        return a - b

    @staticmethod
    def exp(x):
        return sympy.exp(x)

    @staticmethod
    def log(x):
        return sympy.log(x)

    @staticmethod
    def sqrt(x):
        return sympy.sqrt(x)

    @staticmethod
    def pow(a, b):
        return a**b

    @staticmethod
    def minimum(a, b):
        # Poorman's version of upcasting in Sympy
        # This won't do for sympy.Expr as the casting does nothing for those
        if a.is_Float or not a.is_finite or b.is_Float or not b.is_finite:
            result_type = sympy.Float
        else:
            assert a.is_Integer
            assert b.is_Integer
            result_type = sympy.Integer
        return sympy.Min(result_type(a), result_type(b))

    @staticmethod
    def maximum(a, b):
        # Poorman's version of upcasting in Sympy
        # This won't do for sympy.Expr as the casting does nothing for those
        if a.is_Float or not a.is_finite or b.is_Float or not b.is_finite:
            result_type = sympy.Float
        else:
            assert a.is_Integer
            assert b.is_Integer
            result_type = sympy.Integer
        return sympy.Max(result_type(a), result_type(b))

    @staticmethod
    def floor(x):
        return sympy.floor(x)

    @staticmethod
    def ceil(x):
        return sympy.ceiling(x)


# Unlike ReferenceAnalysis, does NOT sympyify, instead, works with plain
# Python types and is FX traceable. Inheritance here is purely for code
# sharing (TODO: considering splitting out a BaseReferenceAnalysis).
class PythonReferenceAnalysis(ReferenceAnalysis):
    @staticmethod
    def constant(c, dtype):
        if dtype is torch.int64:
            return int(c)
        elif dtype is torch.double:
            return float(c)
        elif dtype is torch.bool:
            return bool(c)
        else:
            raise AssertionError(f"unrecognized dtype {dtype}")

    @staticmethod
    def not_(a):
        return torch.sym_not(a)

    @staticmethod
    def floordiv(a, b):
        return a // b

    @staticmethod
    def truncdiv(a, b):
        return a / b

    @staticmethod
    def exp(x):
        raise AssertionError("exp is not valid shape sympy expr")

    @staticmethod
    def log(x):
        raise AssertionError("log is not valid shape sympy expr")

    @staticmethod
    def sqrt(x):
        return torch._sym_sqrt(x)  # type: ignore[attr-defined]

    @staticmethod
    def minimum(a, b):
        return torch.sym_min(a, b)

    @staticmethod
    def maximum(a, b):
        return torch.sym_max(a, b)

    @staticmethod
    def floor(x):
        return math.floor(x)

    @staticmethod
    def ceil(x):
        return math.ceil(x)
```
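The two analysis classes can also be exercised directly, which makes the split between the sympy-backed and plain-Python interpretations easy to see. A small sketch; the import path is an assumption, since the listing above does not name the file:

```python
import sympy

# Assumed module path; the listing above does not state where these classes live.
from torch.utils._sympy.reference import PythonReferenceAnalysis, ReferenceAnalysis

s0 = sympy.Symbol("s0", integer=True, positive=True)

# ReferenceAnalysis keeps everything symbolic as sympy expressions.
expr = ReferenceAnalysis.add(ReferenceAnalysis.sqrt(s0), sympy.Integer(1))
print(expr)              # sqrt(s0) + 1
print(expr.subs(s0, 9))  # 4

# PythonReferenceAnalysis returns plain Python values and defers to torch sym ops.
print(PythonReferenceAnalysis.floor(3.7))     # 3
print(PythonReferenceAnalysis.maximum(2, 5))  # 5, via torch.sym_max
```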