pytorch/test/jit/test_enum.py
Yanan Cao 35a36c1280 Implement JIT Enum type serialization and deserialization (#43460)
Summary:
[Re-review tips: nothing changed other than a type in python_ir.cpp to fix a windows build failure]

Adds code printing for enum type
Enhance enum type to include all contained enum names and values
Adds code parsing for enum type in deserialization
Enabled serialization/deserialization tests in most test cases (a few dangling issues will be addressed in later PRs to keep this PR from growing too large).

Pull Request resolved: https://github.com/pytorch/pytorch/pull/43460

Reviewed By: albanD

Differential Revision: D23284929

Pulled By: gmagogsfm

fbshipit-source-id: e3e81d6106f18b7337ac3ff5cd1eeaff854904f3
2020-08-24 12:04:31 -07:00

384 lines
10 KiB
Python

import os
import sys
import torch
from torch.testing import FileCheck
from enum import Enum
from typing import Any, List
# Make the helper files in test/ importable
pytorch_test_dir = os.path.dirname(os.path.dirname(os.path.realpath(__file__)))
sys.path.append(pytorch_test_dir)
from torch.testing._internal.jit_utils import JitTestCase
if __name__ == '__main__':
    # This module only defines test cases; as the message says, it is meant to
    # be driven through test/test_jit.py rather than executed directly.
    usage_hint = ("This test file is not meant to be run directly, use:\n\n"
                  "\tpython test/test_jit.py TESTNAME\n\n"
                  "instead.")
    raise RuntimeError(usage_hint)
class TestEnum(JitTestCase):
    """TorchScript support for Python ``enum.Enum`` classes.

    Every test runs with ``EXPERIMENTAL_ENUM_SUPPORT=1``; ``setUp`` and
    ``tearDown`` save and restore the caller's value of that variable.

    The Enum classes are declared ``global`` inside each test, presumably so
    TorchScript can resolve them as module-level names (their qualified names
    show up as ``__torch__.jit.test_enum.<Name>`` in the graphs checked below).
    """

    def setUp(self):
        """Enable experimental Enum support, remembering the previous value."""
        super().setUp()
        # None when the variable was not set before this test.
        self.saved_enum_env_var = os.environ.get("EXPERIMENTAL_ENUM_SUPPORT")
        os.environ["EXPERIMENTAL_ENUM_SUPPORT"] = "1"

    def tearDown(self):
        """Restore EXPERIMENTAL_ENUM_SUPPORT exactly as setUp() found it."""
        try:
            super().tearDown()
        finally:
            # Restoring only truthy values would leave the "1" written by
            # setUp() in place whenever the variable was originally unset or
            # empty, leaking Enum support into every test that runs later.
            if self.saved_enum_env_var is None:
                os.environ.pop("EXPERIMENTAL_ENUM_SUPPORT", None)
            else:
                os.environ["EXPERIMENTAL_ENUM_SUPPORT"] = self.saved_enum_env_var

    def test_enum_value_types(self):
        """Int-, float- and str-valued enums script; a Tensor-valued one is rejected."""
        global IntEnum

        class IntEnum(Enum):
            FOO = 1
            BAR = 2

        global FloatEnum

        class FloatEnum(Enum):
            FOO = 1.2
            BAR = 2.3

        global StringEnum

        class StringEnum(Enum):
            FOO = "foo as in foo bar"
            BAR = "bar as in foo bar"

        @torch.jit.script
        def supported_enum_types(a: IntEnum, b: FloatEnum, c: StringEnum):
            return (a.name, b.name, c.name)

        FileCheck() \
            .check("IntEnum") \
            .check("FloatEnum") \
            .check("StringEnum") \
            .run(str(supported_enum_types.graph))

        global TensorEnum

        class TensorEnum(Enum):
            FOO = torch.tensor(0)
            BAR = torch.tensor(1)

        def unsupported_enum_types(a: TensorEnum):
            return a.name

        with self.assertRaisesRegex(RuntimeError, "Cannot create Enum with value type 'Tensor'"):
            torch.jit.script(unsupported_enum_types)

    def test_enum_comp(self):
        """Members of the same Enum compare with aten::eq."""
        global Color

        class Color(Enum):
            RED = 1
            GREEN = 2

        @torch.jit.script
        def enum_comp(x: Color, y: Color) -> bool:
            return x == y

        FileCheck().check("aten::eq").run(str(enum_comp.graph))

        self.assertEqual(enum_comp(Color.RED, Color.RED), True)
        self.assertEqual(enum_comp(Color.RED, Color.GREEN), False)

    def test_enum_comp_diff_classes(self):
        """Members of two distinct Enums never compare equal, even with equal values."""
        global Foo, Bar

        class Foo(Enum):
            ITEM1 = 1
            ITEM2 = 2

        class Bar(Enum):
            ITEM1 = 1
            ITEM2 = 2

        @torch.jit.script
        def enum_comp(x: Foo) -> bool:
            return x == Bar.ITEM1

        FileCheck() \
            .check("prim::Constant") \
            .check_same("Bar.ITEM1") \
            .check("aten::eq") \
            .run(str(enum_comp.graph))

        self.assertEqual(enum_comp(Foo.ITEM1), False)

    def test_heterogenous_value_type_enum_error(self):
        """An Enum whose members have different value types cannot be scripted."""
        global Color

        class Color(Enum):
            RED = 1
            GREEN = "green"

        def enum_comp(x: Color, y: Color) -> bool:
            return x == y

        with self.assertRaisesRegex(RuntimeError, "Could not unify type list"):
            torch.jit.script(enum_comp)

    def test_enum_name(self):
        """``.name`` lowers to prim::EnumName and matches Python's result."""
        global Color

        class Color(Enum):
            RED = 1
            GREEN = 2

        @torch.jit.script
        def enum_name(x: Color) -> str:
            return x.name

        FileCheck() \
            .check("Color") \
            .check_next("prim::EnumName") \
            .check_next("return") \
            .run(str(enum_name.graph))

        self.assertEqual(enum_name(Color.RED), Color.RED.name)
        self.assertEqual(enum_name(Color.GREEN), Color.GREEN.name)

    def test_enum_value(self):
        """``.value`` lowers to prim::EnumValue and matches Python's result."""
        global Color

        class Color(Enum):
            RED = 1
            GREEN = 2

        @torch.jit.script
        def enum_value(x: Color) -> int:
            return x.value

        FileCheck() \
            .check("Color") \
            .check_next("prim::EnumValue") \
            .check_next("return") \
            .run(str(enum_value.graph))

        self.assertEqual(enum_value(Color.RED), Color.RED.value)
        self.assertEqual(enum_value(Color.GREEN), Color.GREEN.value)

    def test_enum_as_const(self):
        """A referenced Enum member becomes a prim::Constant in the graph."""
        global Color

        class Color(Enum):
            RED = 1
            GREEN = 2

        @torch.jit.script
        def enum_const(x: Color) -> bool:
            return x == Color.RED

        FileCheck() \
            .check("prim::Constant[value=__torch__.jit.test_enum.Color.RED]") \
            .check_next("aten::eq") \
            .check_next("return") \
            .run(str(enum_const.graph))

        self.assertEqual(enum_const(Color.RED), True)
        self.assertEqual(enum_const(Color.GREEN), False)

    def test_non_existent_enum_value(self):
        """Referencing an undefined member fails with a highlighted error."""
        global Color

        class Color(Enum):
            RED = 1
            GREEN = 2

        def enum_const(x: Color) -> bool:
            if x == Color.PURPLE:
                return True
            else:
                return False

        with self.assertRaisesRegexWithHighlight(RuntimeError, "has no attribute 'PURPLE'", "Color.PURPLE"):
            torch.jit.script(enum_const)

    def test_enum_ivalue_type(self):
        """``isinstance(x, Color)`` works on an ``Any`` argument."""
        global Color

        class Color(Enum):
            RED = 1
            GREEN = 2

        @torch.jit.script
        def is_color_enum(x: Any):
            return isinstance(x, Color)

        FileCheck() \
            .check("prim::isinstance[types=[Enum<__torch__.jit.test_enum.Color>]]") \
            .check_next("return") \
            .run(str(is_color_enum.graph))

        self.assertEqual(is_color_enum(Color.RED), True)
        self.assertEqual(is_color_enum(Color.GREEN), True)
        self.assertEqual(is_color_enum(1), False)

    def test_closed_over_enum_constant(self):
        """An Enum class or member captured through a local alias folds to a constant."""
        global Color

        class Color(Enum):
            RED = 1
            GREEN = 2

        a = Color

        @torch.jit.script
        def closed_over_aliased_type():
            return a.RED.value

        FileCheck() \
            .check(f"prim::Constant[value={a.RED.value}]") \
            .check_next("return") \
            .run(str(closed_over_aliased_type.graph))

        self.assertEqual(closed_over_aliased_type(), Color.RED.value)

        b = Color.RED

        @torch.jit.script
        def closed_over_aliased_value():
            return b.value

        FileCheck() \
            .check(f"prim::Constant[value={b.value}]") \
            .check_next("return") \
            .run(str(closed_over_aliased_value.graph))

        self.assertEqual(closed_over_aliased_value(), Color.RED.value)

    def test_enum_as_module_attribute(self):
        """The value of an Enum module attribute can be read in forward()."""
        global Color

        class Color(Enum):
            RED = 1
            GREEN = 2

        class TestModule(torch.nn.Module):
            def __init__(self, e: Color):
                super().__init__()
                self.e = e

            def forward(self):
                return self.e.value

        m = TestModule(Color.RED)
        # TODO(gmagogsfm): Re-enable hooks when Enum attr pickling is supported.
        with torch._jit_internal._disable_emit_hooks():
            scripted = torch.jit.script(m)

        FileCheck() \
            .check("TestModule") \
            .check_next("Color") \
            .check_same("prim::GetAttr[name=\"e\"]") \
            .check_next("prim::EnumValue") \
            .check_next("return") \
            .run(str(scripted.graph))

        self.assertEqual(scripted(), Color.RED.value)

    def test_enum_return(self):
        """A scripted function can return Enum members."""
        global Color

        class Color(Enum):
            RED = 1
            GREEN = 2

        @torch.jit.script
        def return_enum(cond: bool):
            if cond:
                return Color.RED
            else:
                return Color.GREEN

        self.assertEqual(return_enum(True), Color.RED)
        self.assertEqual(return_enum(False), Color.GREEN)

    def test_enum_module_return(self):
        """A scripted module's forward() can return an Enum attribute."""
        global Color

        class Color(Enum):
            RED = 1
            GREEN = 2

        class TestModule(torch.nn.Module):
            def __init__(self, e: Color):
                super().__init__()
                self.e = e

            def forward(self):
                return self.e

        m = TestModule(Color.RED)
        # TODO(gmagogsfm): Re-enable hooks when Enum attr pickling is supported.
        with torch._jit_internal._disable_emit_hooks():
            scripted = torch.jit.script(m)

        FileCheck() \
            .check("TestModule") \
            .check_next("Color") \
            .check_same("prim::GetAttr[name=\"e\"]") \
            .check_next("return") \
            .run(str(scripted.graph))

        self.assertEqual(scripted(), Color.RED)

    def test_enum_iterate(self):
        """Iterating an Enum class yields its members in definition order."""
        global Color

        class Color(Enum):
            RED = 1
            GREEN = 2
            BLUE = 3

        def iterate_enum(x: Color):
            res: List[int] = []
            for e in Color:
                if e != x:
                    res.append(e.value)
            return res

        scripted = torch.jit.script(iterate_enum)

        FileCheck() \
            .check("Enum<__torch__.jit.test_enum.Color>[]") \
            .check_same("Color.RED") \
            .check_same("Color.GREEN") \
            .check_same("Color.BLUE") \
            .run(str(scripted.graph))

        # BLUE always appears last because iteration follows Python's Enum
        # definition order.
        self.assertEqual(scripted(Color.RED), [Color.GREEN.value, Color.BLUE.value])
        self.assertEqual(scripted(Color.GREEN), [Color.RED.value, Color.BLUE.value])
# Tests that Enum support stays disabled (guarded) while it is still experimental.
class TestEnumFeatureGuard(JitTestCase):
    """Check Enum behavior while EXPERIMENTAL_ENUM_SUPPORT is absent from the environment."""

    def setUp(self):
        """Remove EXPERIMENTAL_ENUM_SUPPORT for the test, remembering any prior value."""
        super().setUp()
        # pop() removes the variable whatever its value (an empty string
        # included) and returns None when it was not set at all.
        self.saved_enum_env_var = os.environ.pop("EXPERIMENTAL_ENUM_SUPPORT", None)

    def tearDown(self):
        """Restore EXPERIMENTAL_ENUM_SUPPORT exactly as setUp() found it."""
        try:
            super().tearDown()
        finally:
            # Drop anything a test may have written, then put back the
            # caller's original value if there was one.
            os.environ.pop("EXPERIMENTAL_ENUM_SUPPORT", None)
            if self.saved_enum_env_var is not None:
                os.environ["EXPERIMENTAL_ENUM_SUPPORT"] = self.saved_enum_env_var

    def test_enum_comp_disabled(self):
        """Without the flag, an Enum annotation is reported as an unknown type."""
        global Color

        class Color(Enum):
            RED = 1
            GREEN = 2

        def enum_comp(x: Color, y: Color) -> bool:
            return x == y

        with self.assertRaisesRegexWithHighlight(RuntimeError, "Unknown type name 'Color'", "Color"):
            torch.jit.script(enum_comp)