Add a registry for GraphModuleSerializer (#126550)

This PR adds a registration function and a global registry for `GraphModuleSerializer`. After this PR, custom serialization can be implemented by registering a handler instead of subclassing `GraphModuleSerializer`, which is easier to maintain.
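
For illustration, here is a minimal sketch of the new flow, modeled on the test added below; `MyCustomOp`, `MyHandler`, the `"my_ns"` namespace, and the `attached_schema` attribute are illustrative names, not part of a fixed API:

```python
import torch
from torch._export.serde.serialize import (
    CustomOpHandler,
    register_custom_op_handler,
)

# Hypothetical custom op type; any type can be registered.
class MyCustomOp(torch.nn.Module):
    pass

class MyHandler(CustomOpHandler):
    def namespace(self):
        return "my_ns"

    def op_name(self, op_type):
        # Map the op type to its serialized name.
        return "MyCustomOp" if op_type is MyCustomOp else None

    def op_type(self, op_name):
        # Inverse mapping, intended for deserialization.
        return MyCustomOp if op_name == "MyCustomOp" else None

    def op_schema(self, op_type):
        # Return the schema describing the op's signature
        # (here stashed on the handler, as the test below does).
        return self.attached_schema if op_type is MyCustomOp else None

# Register the handler instead of subclassing GraphModuleSerializer.
register_custom_op_handler(MyHandler(), MyCustomOp)
```

After registration, `GraphModuleSerializer` serializes nodes whose target type is `MyCustomOp` through the handler instead of raising `SerializeError`.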

## Changes
- Add a test case that injects a custom op into an exported graph to exercise serialization.
- Add a `CustomOpHandler` base class and a `register_custom_op_handler` registration API.
- Extend the verifier's allowed op types to include registered custom op types.
Co-authored-by: Zhengxu Chen <zhxchen17@outlook.com>
Pull Request resolved: https://github.com/pytorch/pytorch/pull/126550
Approved by: https://github.com/zhxchen17
Jiashen Cao, 2024-05-29 03:12:48 +00:00, committed by PyTorch MergeBot
3 changed files with 136 additions and 14 deletions

### test/export/test_serialize.py

```diff
@@ -52,6 +52,52 @@ def get_filtered_export_db_tests():
 @unittest.skipIf(not torchdynamo.is_dynamo_supported(), "dynamo doesn't support")
 class TestSerialize(TestCase):
+    def test_export_with_custom_op_serialization(self):
+        class TestModule(torch.nn.Module):
+            def forward(self, x):
+                return x + 1
+
+        class FooCustomOp(torch.nn.Module):
+            pass
+
+        class FooCustomOpHandler(torch._export.serde.serialize.CustomOpHandler):
+            def namespace(self):
+                return "Foo"
+
+            def op_name(self, op_type):
+                if op_type == FooCustomOp:
+                    return "FooCustomOp"
+                return None
+
+            def op_type(self, op_name):
+                if op_name == "FooCustomOp":
+                    return FooCustomOp
+                return None
+
+            def op_schema(self, op_type):
+                if op_type == FooCustomOp:
+                    return self.attached_schema
+                return None
+
+        inp = (torch.ones(10),)
+        ep = export(TestModule(), inp)
+
+        # Register the custom op handler.
+        foo_custom_op = FooCustomOp()
+        foo_custom_op_handler = FooCustomOpHandler()
+        torch._export.serde.serialize.register_custom_op_handler(
+            foo_custom_op_handler, type(foo_custom_op)
+        )
+
+        # Inject the custom operator.
+        for node in ep.graph.nodes:
+            if node.name == "add":
+                foo_custom_op_handler.attached_schema = node.target._schema
+                node.target = foo_custom_op
+
+        # Serialization.
+        serialize(ep)
+
     def test_predispatch_export_with_autograd_op(self):
         class Foo(torch.nn.Module):
             def __init__(self):
```

### torch/_export/serde/serialize.py

```diff
@@ -27,6 +27,7 @@ from typing import (
     Set,
     Tuple,
     Union,
+    Type,
 )
 
 import sympy
@@ -361,13 +362,24 @@ def serialize_range_constraints(
     }
 
 
+def _get_schema_from_target(target):
+    if isinstance(target, torch._ops.OpOverload):
+        return target._schema
+    elif type(target) in _serialization_registry:
+        return _serialization_registry[type(target)].op_schema(type(target))
+    raise RuntimeError(f"Cannot find schema for {type(target)}")
+
+
 def _is_single_tensor_return(target: torch._ops.OpOverload) -> bool:
-    returns = target._schema.returns
+    schema = _get_schema_from_target(target)
+    returns = schema.returns
     return len(returns) == 1 and isinstance(returns[0].real_type, torch.TensorType)
 
 
-def _is_single_tensor_list_return(target: torch._ops.OpOverload) -> bool:
-    returns = target._schema.returns
+def _is_single_tensor_list_return(target: Any) -> bool:
+    schema = _get_schema_from_target(target)
+    returns = schema.returns
     if len(returns) != 1:
         return False
     return_type = returns[0].real_type
@@ -519,6 +531,19 @@ class GraphModuleSerializer(metaclass=Final):
                 outputs=self.serialize_hoo_outputs(node),
                 metadata=self.serialize_metadata(node),
             )
+        elif type(node.target) in _serialization_registry:
+            custom_op_handler = node.target
+            # Sanity check for unhandled serialization.
+            assert type(node.target) in _serialization_registry, f"Miss {type(node.target)} CustomOpHandler"
+            handler = _serialization_registry[type(node.target)]
+            ex_node = Node(
+                target=f"${handler.namespace()}:{handler.op_name(node.target)}",
+                inputs=self.serialize_inputs(node.target, node.args, node.kwargs),
+                outputs=self.serialize_outputs(node),
+                metadata=self.serialize_metadata(node),
+            )
         else:
             raise SerializeError(f"Serializing {node.target} is not supported")
@@ -579,12 +604,18 @@ class GraphModuleSerializer(metaclass=Final):
         return serialized_args
 
     def serialize_inputs(
-        self, target: torch._ops.OpOverload, args, kwargs=None
+        self,
+        target: Any,  # torch._ops.OpOverload and other custom operator types.
+        args,
+        kwargs=None
     ) -> List[NamedArgument]:
-        assert isinstance(target, torch._ops.OpOverload)
+        assert isinstance(target, (torch._ops.OpOverload, *allowed_registered_op_types()))
         kwargs = kwargs or {}
         serialized_args = []
-        for i, schema_arg in enumerate(target._schema.arguments):
+        schema = _get_schema_from_target(target)
+        for i, schema_arg in enumerate(schema.arguments):
             if schema_arg.name in kwargs:
                 serialized_args.append(
                     NamedArgument(
@@ -1075,12 +1106,10 @@ class GraphModuleSerializer(metaclass=Final):
         mostly reuse the names coming from FX. This function computes a mapping from
        the FX representation to our representation, preserving the names.
         """
-        assert node.op == "call_function" and isinstance(
-            node.target, torch._ops.OpOverload
-        )
+        assert node.op == "call_function" and isinstance(node.target, (torch._ops.OpOverload, *allowed_registered_op_types()))
 
-        assert isinstance(node.target, torch._ops.OpOverload)
-        returns = node.target._schema.returns
+        schema = _get_schema_from_target(node.target)
+        returns = schema.returns
 
         if len(returns) == 0:
             return []
@@ -2779,3 +2808,50 @@ def canonicalize(ep: ExportedProgram) -> ExportedProgram:
         schema_version=ep.schema_version,
         dialect=ep.dialect
     )
+
+
+class CustomOpHandler:
+    """
+    Base class for handling custom operators.
+    """
+    @classmethod
+    def namespace(cls):
+        raise NotImplementedError(f"{cls.__class__} namespace() must be implemented")
+
+    @classmethod
+    def op_name(cls, op_type):
+        raise NotImplementedError(f"{cls.__class__} op_name() must be implemented")
+
+    @classmethod
+    def op_type(cls, op_name):
+        raise NotImplementedError(f"{cls.__class__} op_type() must be implemented")
+
+    @classmethod
+    def op_schema(cls, op_type):
+        raise NotImplementedError(f"{cls.__class__} op_schema() must be implemented")
+
+
+def register_custom_op_handler(
+    op_handler: CustomOpHandler,
+    op_type: Type[Any],
+):
+    """Register custom de/serialization method for a node."""
+    assert isinstance(op_handler, CustomOpHandler), f"Expected CustomOpHandler, got {type(op_handler)}."
+    _serialization_registry[op_type] = op_handler
+    # FIXME: handle deserialization later.
+    _deserialization_registry[op_handler.namespace()] = op_handler
+
+
+def allowed_registered_op_types():
+    return tuple(
+        _serialization_registry.keys()
+    )
+
+
+# Registry to store all custom serialization implementations.
+# The registry maps an operation to its serialization function (a callable), in its own
+# namespace to avoid conflicts.
+# Serialization: Op type --> custom handler.
+# De-serialization: Namespace --> custom handler.
+_serialization_registry: Dict[Type[Any], CustomOpHandler] = {}
+_deserialization_registry: Dict[str, CustomOpHandler] = {}
```
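
With a handler registered, the new `elif` branch in the serializer above emits a namespaced target string instead of raising. A small sketch of that lookup, assuming `node_target` is an instance of a registered op type:

```python
# Sketch of the target-string construction from the elif branch above.
# node_target: an instance of a registered custom op type.
handler = _serialization_registry[type(node_target)]
# Serialized form is "$<namespace>:<op_name>".
serialized_target = f"${handler.namespace()}:{handler.op_name(node_target)}"
```

Note that `op_name()` is passed the node's target itself here, not its type.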

### torch/_export/verifier.py

```diff
@@ -133,7 +133,8 @@ class Verifier(metaclass=_VerifierMeta):
         ]
 
     def allowed_op_types(self) -> Tuple[Type[Any], ...]:
-        return (OpOverload, HigherOrderOperator)
+        from torch._export.serde.serialize import allowed_registered_op_types  # Avoid circular import.
+        return (OpOverload, HigherOrderOperator, *allowed_registered_op_types())
 
     def allowed_getattr_types(self) -> Tuple[Type[Any], ...]:
         return (torch.fx.GraphModule,)
@@ -182,8 +183,7 @@ class Verifier(metaclass=_VerifierMeta):
                 # TODO (tmanlaibaatar)
                 # Predispatch export is able to contain autograd ops.
                 # These will be modeled as HOO later
-                torch._C._set_grad_enabled
+                torch._C._set_grad_enabled,
             )
 
             if not isinstance(op, _allowed_op_types()):
```