mirror of
https://github.com/zebrajr/pytorch.git
synced 2025-12-07 12:21:27 +01:00
Add prim::EnumName and prim::EnumValue ops (#41965)
Summary: [2/N] Implement Enum JIT support Add prim::EnumName and prim::EnumValue and their lowerings to support getting `name` and `value` attribute of Python enums. Supported: Enum-typed function targuments using Enum type and comparing them Support getting name/value attrs of enums TODO: Add PyThon sugared value for Enum Support Enum-typed return values Support enum values of different types in same Enum class Support serialization and deserialization Pull Request resolved: https://github.com/pytorch/pytorch/pull/41965 Reviewed By: eellison Differential Revision: D22714446 Pulled By: gmagogsfm fbshipit-source-id: db8c4e26b657e7782dbfc2b58a141add1263f76e
This commit is contained in:
parent
6287f9ed65
commit
8e03c38a4f
|
|
@ -66,6 +66,8 @@ namespace c10 {
|
|||
_(prim, ListConstruct) \
|
||||
_(prim, ListUnpack) \
|
||||
_(prim, DictConstruct) \
|
||||
_(prim, EnumName) \
|
||||
_(prim, EnumValue) \
|
||||
_(prim, StringIndex) \
|
||||
_(prim, NumToTensor) \
|
||||
_(prim, Uninitialized) \
|
||||
|
|
|
|||
|
|
@ -1145,7 +1145,7 @@ struct CAFFE2_API EnumType : public NamedType {
|
|||
AT_ERROR(
|
||||
"Cannot create Enum with value type '",
|
||||
value->str(),
|
||||
"', only int, float, Tensor and string keys are supported");
|
||||
"', only int, float and string are supported");
|
||||
}
|
||||
}
|
||||
|
||||
|
|
|
|||
|
|
@ -23,6 +23,44 @@ class TestEnum(JitTestCase):
|
|||
if self.saved_enum_env_var:
|
||||
os.environ["EXPERIMENTAL_ENUM_SUPPORT"] = self.saved_enum_env_var
|
||||
|
||||
def test_enum_value_types(self):
|
||||
global IntEnum
|
||||
|
||||
class IntEnum(Enum):
|
||||
FOO = 1
|
||||
BAR = 2
|
||||
|
||||
global FloatEnum
|
||||
|
||||
class FloatEnum(Enum):
|
||||
FOO = 1.2
|
||||
BAR = 2.3
|
||||
|
||||
global StringEnum
|
||||
|
||||
class StringEnum(Enum):
|
||||
FOO = "foo as in foo bar"
|
||||
BAR = "bar as in foo bar"
|
||||
|
||||
def supported_enum_types(a: IntEnum, b: FloatEnum, c: StringEnum):
|
||||
return (a.name, b.name, c.name)
|
||||
# TODO(gmagogsfm): Re-enable hooks when serialization/deserialization
|
||||
# is supported.
|
||||
with torch._jit_internal._disable_emit_hooks():
|
||||
torch.jit.script(supported_enum_types)
|
||||
|
||||
global TensorEnum
|
||||
|
||||
class TensorEnum(Enum):
|
||||
FOO = torch.tensor(0)
|
||||
BAR = torch.tensor(1)
|
||||
|
||||
def unsupported_enum_types(a: TensorEnum):
|
||||
return a.name
|
||||
|
||||
with self.assertRaisesRegex(RuntimeError, "Cannot create Enum with value type 'Tensor'"):
|
||||
torch.jit.script(unsupported_enum_types)
|
||||
|
||||
def test_enum_comp(self):
|
||||
global Color
|
||||
|
||||
|
|
@ -33,7 +71,7 @@ class TestEnum(JitTestCase):
|
|||
def enum_comp(x: Color, y: Color) -> bool:
|
||||
return x == y
|
||||
|
||||
# TODO(gmagogsfm): Re-anble hooks when serialization/deserialization
|
||||
# TODO(gmagogsfm): Re-enable hooks when serialization/deserialization
|
||||
# is supported.
|
||||
with torch._jit_internal._disable_emit_hooks():
|
||||
scripted_enum_comp = torch.jit.script(enum_comp)
|
||||
|
|
@ -56,11 +94,47 @@ class TestEnum(JitTestCase):
|
|||
def enum_comp(x: Color, y: Color) -> bool:
|
||||
return x == y
|
||||
|
||||
# TODO(gmagogsfm): Re-anble hooks when serialization/deserialization
|
||||
# TODO(gmagogsfm): Re-enable hooks when serialization/deserialization
|
||||
# is supported.
|
||||
with self.assertRaisesRegex(RuntimeError, "Could not unify type list"):
|
||||
scripted_enum_comp = torch.jit.script(enum_comp)
|
||||
|
||||
def test_enum_name(self):
|
||||
global Color
|
||||
|
||||
class Color(Enum):
|
||||
RED = 1
|
||||
GREEN = 2
|
||||
|
||||
def enum_name(x: Color) -> str:
|
||||
return x.name
|
||||
|
||||
# TODO(gmagogsfm): Re-enable hooks when serialization/deserialization
|
||||
# is supported.
|
||||
with torch._jit_internal._disable_emit_hooks():
|
||||
scripted_enum_name = torch.jit.script(enum_name)
|
||||
|
||||
self.assertEqual(scripted_enum_name(Color.RED), Color.RED.name)
|
||||
self.assertEqual(scripted_enum_name(Color.GREEN), Color.GREEN.name)
|
||||
|
||||
def test_enum_value(self):
|
||||
global Color
|
||||
|
||||
class Color(Enum):
|
||||
RED = 1
|
||||
GREEN = 2
|
||||
|
||||
def enum_value(x: Color) -> int:
|
||||
return x.value
|
||||
|
||||
# TODO(gmagogsfm): Re-enable hooks when serialization/deserialization
|
||||
# is supported.
|
||||
with torch._jit_internal._disable_emit_hooks():
|
||||
scripted_enum_value = torch.jit.script(enum_value)
|
||||
|
||||
self.assertEqual(scripted_enum_value(Color.RED), Color.RED.value)
|
||||
self.assertEqual(scripted_enum_value(Color.GREEN), Color.GREEN.value)
|
||||
|
||||
|
||||
# Tests that Enum support features are properly guarded before they are mature.
|
||||
class TestEnumFeatureGuard(JitTestCase):
|
||||
|
|
|
|||
|
|
@ -162,6 +162,19 @@ std::shared_ptr<SugaredValue> SimpleValue::attr(
|
|||
if (auto schema = iface->getMethod(field)) {
|
||||
return std::make_shared<MethodValue>(getValue(), field);
|
||||
}
|
||||
} else if (auto enum_type = value_->type()->cast<EnumType>()) {
|
||||
// Handle access to Enum's `name` and `value` attribute.
|
||||
auto& g = *m.graph();
|
||||
|
||||
if (field == "name") {
|
||||
auto n = g.insertNode(g.createEnumName(value_));
|
||||
return std::make_shared<SimpleValue>(n->output());
|
||||
}
|
||||
|
||||
if (field == "value") {
|
||||
auto n = g.insertNode(g.createEnumValue(value_));
|
||||
return std::make_shared<SimpleValue>(n->output());
|
||||
}
|
||||
}
|
||||
|
||||
// none of the more-specific cases worked, so see if this is a builtin method
|
||||
|
|
|
|||
|
|
@ -1595,6 +1595,21 @@ Node* Graph::createTupleSlice(Value* tup, int64_t beg, int64_t end) {
|
|||
return n;
|
||||
}
|
||||
|
||||
Node* Graph::createEnumName(Value* e) {
|
||||
e->type()->expect<EnumType>();
|
||||
assert(e->type()->cast<EnumType>());
|
||||
auto n = create(prim::EnumName, {e});
|
||||
n->output()->setType(StringType::get());
|
||||
return n;
|
||||
}
|
||||
|
||||
Node* Graph::createEnumValue(Value* e) {
|
||||
auto enum_type = e->type()->expect<EnumType>();
|
||||
auto n = create(prim::EnumValue, {e});
|
||||
n->output()->setType(enum_type->getValueType());
|
||||
return n;
|
||||
}
|
||||
|
||||
Node* Graph::createList(const TypePtr& elem_type, at::ArrayRef<Value*> values) {
|
||||
auto n = create(prim::ListConstruct, values);
|
||||
for (const auto& v : values) {
|
||||
|
|
|
|||
|
|
@ -1115,6 +1115,8 @@ struct Graph {
|
|||
Value* idx,
|
||||
const TypePtr& output_type);
|
||||
TORCH_API Node* createTupleSlice(Value* tup, int64_t beg, int64_t end);
|
||||
TORCH_API Node* createEnumName(Value* e);
|
||||
TORCH_API Node* createEnumValue(Value* e);
|
||||
TORCH_API Node* createList(
|
||||
const TypePtr& elem_type,
|
||||
at::ArrayRef<Value*> values);
|
||||
|
|
|
|||
|
|
@ -311,6 +311,34 @@ RegisterOperators reg(
|
|||
pack(stack, t.sizes().vec());
|
||||
},
|
||||
aliasAnalysisFromSchema()),
|
||||
Operator(
|
||||
"prim::EnumName(AnyEnumType enum) -> str",
|
||||
[](Stack* stack) {
|
||||
IValue e = pop(stack);
|
||||
push(stack, e.toEnumHolder()->name());
|
||||
},
|
||||
aliasAnalysisFromSchema()),
|
||||
Operator(
|
||||
"prim::EnumValue.int(AnyEnumType enum) -> int",
|
||||
[](Stack* stack) {
|
||||
IValue e = pop(stack);
|
||||
push(stack, e.toEnumHolder()->value());
|
||||
},
|
||||
aliasAnalysisFromSchema()),
|
||||
Operator(
|
||||
"prim::EnumValue.float(AnyEnumType enum) -> float",
|
||||
[](Stack* stack) {
|
||||
IValue e = pop(stack);
|
||||
push(stack, e.toEnumHolder()->value());
|
||||
},
|
||||
aliasAnalysisFromSchema()),
|
||||
Operator(
|
||||
"prim::EnumValue.str(AnyEnumType enum) -> str",
|
||||
[](Stack* stack) {
|
||||
IValue e = pop(stack);
|
||||
push(stack, e.toEnumHolder()->value());
|
||||
},
|
||||
aliasAnalysisFromSchema()),
|
||||
Operator(
|
||||
// note the compiler knows to type TupleIndex more accurately than it
|
||||
// is listed here.
|
||||
|
|
|
|||
Loading…
Reference in New Issue
Block a user