mirror of
https://github.com/zebrajr/pytorch.git
synced 2025-12-08 07:39:33 +01:00
Summary: Action following https://github.com/pytorch/pytorch/issues/66232 Pull Request resolved: https://github.com/pytorch/pytorch/pull/66808 Reviewed By: mrshenli Differential Revision: D31761414 Pulled By: janeyx99 fbshipit-source-id: baf8c49ff9c4bcda7b0ea0f6aafd26380586e72d
93 lines
4.3 KiB
Python
93 lines
4.3 KiB
Python
# Owner(s): ["oncall: jit"]
|
|
|
|
import torch
|
|
from torch.testing._internal.common_utils import TestCase
|
|
|
|
class TestAtenPow(TestCase):
|
|
def test_aten_pow_zero_negative_exponent(self):
|
|
'''
|
|
1. Testing a = int, b = int
|
|
'''
|
|
@torch.jit.script
|
|
def fn_int_int(a: int, b: int):
|
|
return a ** b
|
|
# Existing correct behaviors of aten::pow
|
|
self.assertEqual(fn_int_int(2, 1), 2 ** 1)
|
|
self.assertEqual(fn_int_int(2, 0), 2 ** 0)
|
|
self.assertEqual(fn_int_int(2, -2), 2 ** (-2))
|
|
self.assertEqual(fn_int_int(-2, 2), (-2) ** 2)
|
|
self.assertEqual(fn_int_int(-2, 0), (-2) ** 0)
|
|
self.assertEqual(fn_int_int(-2, -2), (-2) ** (-2))
|
|
self.assertEqual(fn_int_int(-2, -1), (-2) ** (-1))
|
|
self.assertEqual(fn_int_int(0, 2), 0 ** 1)
|
|
self.assertEqual(fn_int_int(0, 0), 0 ** 0)
|
|
# zero base and negative exponent case that should trigger RunTimeError
|
|
self.assertRaises(RuntimeError, fn_int_int, 0, -2)
|
|
|
|
'''
|
|
2. Testing a = int, b = float
|
|
'''
|
|
@torch.jit.script
|
|
def fn_int_float(a: int, b: float):
|
|
return a ** b
|
|
# Existing correct behaviors of aten::pow
|
|
self.assertEqual(fn_int_float(2, 2.5), 2 ** 2.5)
|
|
self.assertEqual(fn_int_float(2, -2.5), 2 ** (-2.5))
|
|
self.assertEqual(fn_int_float(2, -0.0), 2 ** (-0.0))
|
|
self.assertEqual(fn_int_float(2, 0.0), 2 ** (0.0))
|
|
self.assertEqual(fn_int_float(-2, 2.0), (-2) ** 2.0)
|
|
self.assertEqual(fn_int_float(-2, -2.0), (-2) ** (-2.0))
|
|
self.assertEqual(fn_int_float(-2, -3.0), (-2) ** (-3.0))
|
|
self.assertEqual(fn_int_float(-2, -0.0), (-2) ** (-0.0))
|
|
self.assertEqual(fn_int_float(-2, 0.0), (-2) ** (0.0))
|
|
self.assertEqual(fn_int_float(0, 2.0), 0 ** 2.0)
|
|
self.assertEqual(fn_int_float(0, 0.5), 0 ** 0.5)
|
|
self.assertEqual(fn_int_float(0, 0.0), 0 ** 0.0)
|
|
self.assertEqual(fn_int_float(0, -0.0), 0 ** (-0.0))
|
|
# zero base and negative exponent case that should trigger RunTimeError
|
|
self.assertRaises(RuntimeError, fn_int_float, 0, -2.5)
|
|
|
|
'''
|
|
3. Testing a = float, b = int
|
|
'''
|
|
@torch.jit.script
|
|
def fn_float_int(a: float, b: int):
|
|
return a ** b
|
|
# Existing correct behaviors of aten::pow
|
|
self.assertEqual(fn_float_int(2.5, 2), 2.5 ** 2)
|
|
self.assertEqual(fn_float_int(2.5, -2), 2.5 ** (-2))
|
|
self.assertEqual(fn_float_int(2.5, -0), 2.5 ** (-0))
|
|
self.assertEqual(fn_float_int(2.5, 0), 2.5 ** 0)
|
|
self.assertEqual(fn_float_int(-2.5, 2), 2.5 ** 2)
|
|
self.assertEqual(fn_float_int(-2.5, -2), (-2.5) ** (-2))
|
|
self.assertEqual(fn_float_int(-2.5, -3), (-2.5) ** (-3))
|
|
self.assertEqual(fn_float_int(-2.5, -0), (-2.5) ** (-0))
|
|
self.assertEqual(fn_float_int(-2.5, 0), (-2.5) ** 0)
|
|
self.assertEqual(fn_float_int(0.0, 2), 0 ** 2)
|
|
self.assertEqual(fn_float_int(0.0, 0), 0 ** 0)
|
|
self.assertEqual(fn_float_int(0.0, -0), 0 ** (-0))
|
|
# zero base and negative exponent case that should trigger RunTimeError
|
|
self.assertRaises(RuntimeError, fn_float_int, 0.0, -2)
|
|
|
|
'''
|
|
4. Testing a = float, b = float
|
|
'''
|
|
@torch.jit.script
|
|
def fn_float_float(a: float, b: float):
|
|
return a ** b
|
|
# Existing correct behaviors of aten::pow
|
|
self.assertEqual(fn_float_float(2.5, 2.0), 2.5 ** 2.0)
|
|
self.assertEqual(fn_float_float(2.5, -2.0), 2.5 ** (-2.0))
|
|
self.assertEqual(fn_float_float(2.5, -0.0), 2.5 ** (-0.0))
|
|
self.assertEqual(fn_float_float(2.5, 0.0), 2.5 ** 0.0)
|
|
self.assertEqual(fn_float_float(-2.5, 2.0), 2.5 ** 2.0)
|
|
self.assertEqual(fn_float_float(-2.5, -2.0), (-2.5) ** (-2.0))
|
|
self.assertEqual(fn_float_float(-2.5, -3.0), (-2.5) ** (-3.0))
|
|
self.assertEqual(fn_float_float(-2.5, -0.0), (-2.5) ** (-0.0))
|
|
self.assertEqual(fn_float_float(-2.5, 0.0), (-2.5) ** 0.0)
|
|
self.assertEqual(fn_float_float(0.0, 2.0), 0.0 ** 2.0)
|
|
self.assertEqual(fn_float_float(0.0, 0.0), 0.0 ** 0.0)
|
|
self.assertEqual(fn_float_float(0.0, -0.0), 0.0 ** (-0.0))
|
|
# zero base and negative exponent case that should trigger RunTimeError
|
|
self.assertRaises(RuntimeError, fn_float_float, 0.0, -2.0)
|