pytorch/test/jit/test_enum.py
Yanan Cao 8e03c38a4f Add prim::EnumName and prim::EnumValue ops (#41965)
Summary:
[2/N] Implement Enum JIT support

Add prim::EnumName and prim::EnumValue and their lowerings to support getting the `name` and `value` attributes of Python enums.

Supported:
Enum-typed function arguments
Using Enum types and comparing them
Support getting name/value attrs of enums

TODO:
Add Python sugared value for Enum
Support Enum-typed return values
Support enum values of different types in same Enum class
Support serialization and deserialization

Pull Request resolved: https://github.com/pytorch/pytorch/pull/41965

Reviewed By: eellison

Differential Revision: D22714446

Pulled By: gmagogsfm

fbshipit-source-id: db8c4e26b657e7782dbfc2b58a141add1263f76e
2020-07-24 18:33:18 -07:00

163 lines
4.9 KiB
Python

import os
import sys
import torch
from enum import Enum
# Make the helper files in test/ importable
pytorch_test_dir = os.path.dirname(os.path.dirname(os.path.realpath(__file__)))
sys.path.append(pytorch_test_dir)
from torch.testing._internal.jit_utils import JitTestCase
# Refuse direct execution: the error message points at test/test_jit.py as the
# supported way to run these tests.
if __name__ == "__main__":
    direct_run_message = (
        "This test file is not meant to be run directly, use:\n\n"
        "\tpython test/test_jit.py TESTNAME\n\n"
        "instead."
    )
    raise RuntimeError(direct_run_message)
class TestEnum(JitTestCase):
def setUp(self):
self.saved_enum_env_var = os.environ.get("EXPERIMENTAL_ENUM_SUPPORT", None)
os.environ["EXPERIMENTAL_ENUM_SUPPORT"] = "1"
def tearDown(self):
if self.saved_enum_env_var:
os.environ["EXPERIMENTAL_ENUM_SUPPORT"] = self.saved_enum_env_var
def test_enum_value_types(self):
global IntEnum
class IntEnum(Enum):
FOO = 1
BAR = 2
global FloatEnum
class FloatEnum(Enum):
FOO = 1.2
BAR = 2.3
global StringEnum
class StringEnum(Enum):
FOO = "foo as in foo bar"
BAR = "bar as in foo bar"
def supported_enum_types(a: IntEnum, b: FloatEnum, c: StringEnum):
return (a.name, b.name, c.name)
# TODO(gmagogsfm): Re-enable hooks when serialization/deserialization
# is supported.
with torch._jit_internal._disable_emit_hooks():
torch.jit.script(supported_enum_types)
global TensorEnum
class TensorEnum(Enum):
FOO = torch.tensor(0)
BAR = torch.tensor(1)
def unsupported_enum_types(a: TensorEnum):
return a.name
with self.assertRaisesRegex(RuntimeError, "Cannot create Enum with value type 'Tensor'"):
torch.jit.script(unsupported_enum_types)
def test_enum_comp(self):
global Color
class Color(Enum):
RED = 1
GREEN = 2
def enum_comp(x: Color, y: Color) -> bool:
return x == y
# TODO(gmagogsfm): Re-enable hooks when serialization/deserialization
# is supported.
with torch._jit_internal._disable_emit_hooks():
scripted_enum_comp = torch.jit.script(enum_comp)
self.assertEqual(
scripted_enum_comp(Color.RED, Color.RED),
enum_comp(Color.RED, Color.RED))
self.assertEqual(
scripted_enum_comp(Color.RED, Color.GREEN),
enum_comp(Color.RED, Color.GREEN))
def test_heterogenous_value_type_enum_error(self):
global Color
class Color(Enum):
RED = 1
GREEN = "green"
def enum_comp(x: Color, y: Color) -> bool:
return x == y
# TODO(gmagogsfm): Re-enable hooks when serialization/deserialization
# is supported.
with self.assertRaisesRegex(RuntimeError, "Could not unify type list"):
scripted_enum_comp = torch.jit.script(enum_comp)
def test_enum_name(self):
global Color
class Color(Enum):
RED = 1
GREEN = 2
def enum_name(x: Color) -> str:
return x.name
# TODO(gmagogsfm): Re-enable hooks when serialization/deserialization
# is supported.
with torch._jit_internal._disable_emit_hooks():
scripted_enum_name = torch.jit.script(enum_name)
self.assertEqual(scripted_enum_name(Color.RED), Color.RED.name)
self.assertEqual(scripted_enum_name(Color.GREEN), Color.GREEN.name)
def test_enum_value(self):
global Color
class Color(Enum):
RED = 1
GREEN = 2
def enum_value(x: Color) -> int:
return x.value
# TODO(gmagogsfm): Re-enable hooks when serialization/deserialization
# is supported.
with torch._jit_internal._disable_emit_hooks():
scripted_enum_value = torch.jit.script(enum_value)
self.assertEqual(scripted_enum_value(Color.RED), Color.RED.value)
self.assertEqual(scripted_enum_value(Color.GREEN), Color.GREEN.value)
# Tests that Enum support features are properly guarded before they are mature.
class TestEnumFeatureGuard(JitTestCase):
def setUp(self):
self.saved_enum_env_var = os.environ.get("EXPERIMENTAL_ENUM_SUPPORT", None)
if self.saved_enum_env_var:
del os.environ["EXPERIMENTAL_ENUM_SUPPORT"]
def tearDown(self):
if self.saved_enum_env_var:
os.environ["EXPERIMENTAL_ENUM_SUPPORT"] = self.saved_enum_env_var
def test_enum_comp_disabled(self):
global Color
class Color(Enum):
RED = 1
GREEN = 2
def enum_comp(x: Color, y: Color) -> bool:
return x == y
with self.assertRaisesRegex(NotImplementedError,
"Enum support is work in progress"):
torch.jit.script(enum_comp)