Add a registry for GraphModuleSerializer (#126550)
This PR adds a registration function and a global registry for GraphModuleSerializer. After this PR, custom serialization methods can be added through registration instead of subclassing, which is easier to maintain.

## Changes
- Add a test case that injects a custom op to test serialization.
- Add a custom op handler base class and registration API.
- Extend the allowed op types for the verifier.

Co-authored-by: Zhengxu Chen <zhxchen17@outlook.com>
Pull Request resolved: https://github.com/pytorch/pytorch/pull/126550
Approved by: https://github.com/zhxchen17
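As a quick orientation before the diff, here is a minimal sketch of the registration flow this PR enables. `CustomOpHandler` and `register_custom_op_handler` are the entry points added in `torch._export.serde.serialize` in the hunks below; the operator type, handler name, and namespace (`MyCustomOp`, `MyOpHandler`, `"my_ns"`) are illustrative only.

```python
# A minimal sketch (not part of the diff) of the registration flow added by this PR.
import torch
from torch._export.serde.serialize import CustomOpHandler, register_custom_op_handler

class MyCustomOp(torch.nn.Module):
    """Illustrative stand-in for a custom operator type used as a node target."""
    pass

class MyOpHandler(CustomOpHandler):
    """Tells the serializer how to name MyCustomOp and where to find its schema."""
    attached_schema = None  # expected to be filled in before serialization

    def namespace(self):
        return "my_ns"

    def op_name(self, op_type):
        return "MyCustomOp" if op_type is MyCustomOp else None

    def op_type(self, op_name):
        return MyCustomOp if op_name == "MyCustomOp" else None

    def op_schema(self, op_type):
        return self.attached_schema if op_type is MyCustomOp else None

# One registration call replaces subclassing GraphModuleSerializer.
register_custom_op_handler(MyOpHandler(), MyCustomOp)
```

The new test in the first hunk below exercises the same flow end to end: it rewrites an `add` node's target to the custom op, attaches the original schema to the handler, and calls `serialize(ep)`.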
This commit is contained in:
parent
cdbb2c9acc
commit
10d2373abd
@@ -52,6 +52,52 @@ def get_filtered_export_db_tests():
 @unittest.skipIf(not torchdynamo.is_dynamo_supported(), "dynamo doesn't support")
 class TestSerialize(TestCase):
+    def test_export_with_custom_op_serialization(self):
+        class TestModule(torch.nn.Module):
+            def forward(self, x):
+                return x + 1
+
+        class FooCustomOp(torch.nn.Module):
+            pass
+
+        class FooCustomOpHandler(torch._export.serde.serialize.CustomOpHandler):
+            def namespace(self):
+                return "Foo"
+
+            def op_name(self, op_type):
+                if op_type == FooCustomOp:
+                    return "FooCustomOp"
+                return None
+
+            def op_type(self, op_name):
+                if op_name == "FooCustomOp":
+                    return FooCustomOp
+                return None
+
+            def op_schema(self, op_type):
+                if op_type == FooCustomOp:
+                    return self.attached_schema
+                return None
+
+        inp = (torch.ones(10),)
+        ep = export(TestModule(), inp)
+
+        # Register the custom op handler.
+        foo_custom_op = FooCustomOp()
+        foo_custom_op_handler = FooCustomOpHandler()
+        torch._export.serde.serialize.register_custom_op_handler(
+            foo_custom_op_handler, type(foo_custom_op)
+        )
+
+        # Inject the custom operator.
+        for node in ep.graph.nodes:
+            if node.name == "add":
+                foo_custom_op_handler.attached_schema = node.target._schema
+                node.target = foo_custom_op
+
+        # Serialization.
+        serialize(ep)
+
     def test_predispatch_export_with_autograd_op(self):
         class Foo(torch.nn.Module):
             def __init__(self):

@@ -27,6 +27,7 @@ from typing import (
     Set,
     Tuple,
     Union,
+    Type,
 )
 
 import sympy

@@ -361,13 +362,24 @@ def serialize_range_constraints(
     }
 
 
+def _get_schema_from_target(target):
+    if isinstance(target, torch._ops.OpOverload):
+        return target._schema
+    elif type(target) in _serialization_registry:
+        return _serialization_registry[type(target)].op_schema(type(target))
+    raise RuntimeError(f"Cannot find schema for {type(target)}")
+
+
 def _is_single_tensor_return(target: torch._ops.OpOverload) -> bool:
-    returns = target._schema.returns
+    schema = _get_schema_from_target(target)
+    returns = schema.returns
     return len(returns) == 1 and isinstance(returns[0].real_type, torch.TensorType)
 
 
-def _is_single_tensor_list_return(target: torch._ops.OpOverload) -> bool:
-    returns = target._schema.returns
+def _is_single_tensor_list_return(target: Any) -> bool:
+    schema = _get_schema_from_target(target)
+    returns = schema.returns
 
     if len(returns) != 1:
         return False
     return_type = returns[0].real_type

@@ -519,6 +531,19 @@ class GraphModuleSerializer(metaclass=Final):
                 outputs=self.serialize_hoo_outputs(node),
                 metadata=self.serialize_metadata(node),
             )
+        elif type(node.target) in _serialization_registry:
+            custom_op_handler = node.target
+
+            # Sanity check for unhandled serialization.
+            assert type(node.target) in _serialization_registry, f"Miss {type(node.target)} CustomOpHandler"
+
+            handler = _serialization_registry[type(node.target)]
+            ex_node = Node(
+                target=f"${handler.namespace()}:{handler.op_name(node.target)}",
+                inputs=self.serialize_inputs(node.target, node.args, node.kwargs),
+                outputs=self.serialize_outputs(node),
+                metadata=self.serialize_metadata(node),
+            )
         else:
             raise SerializeError(f"Serializing {node.target} is not supported")

@@ -579,12 +604,18 @@ class GraphModuleSerializer(metaclass=Final):
         return serialized_args
 
     def serialize_inputs(
-        self, target: torch._ops.OpOverload, args, kwargs=None
+        self,
+        target: Any,  # torch._ops.OpOverload and other custom operator types.
+        args,
+        kwargs=None
     ) -> List[NamedArgument]:
-        assert isinstance(target, torch._ops.OpOverload)
+        assert isinstance(target, (torch._ops.OpOverload, *allowed_registered_op_types()))
         kwargs = kwargs or {}
         serialized_args = []
-        for i, schema_arg in enumerate(target._schema.arguments):
+
+        schema = _get_schema_from_target(target)
+
+        for i, schema_arg in enumerate(schema.arguments):
             if schema_arg.name in kwargs:
                 serialized_args.append(
                     NamedArgument(

@@ -1075,12 +1106,10 @@ class GraphModuleSerializer(metaclass=Final):
         mostly reuse the names coming from FX. This function computes a mapping from
         the FX representation to our representation, preserving the names.
         """
-        assert node.op == "call_function" and isinstance(
-            node.target, torch._ops.OpOverload
-        )
+        assert node.op == "call_function" and isinstance(node.target, (torch._ops.OpOverload, *allowed_registered_op_types()))
 
-        assert isinstance(node.target, torch._ops.OpOverload)
-        returns = node.target._schema.returns
+        schema = _get_schema_from_target(node.target)
+        returns = schema.returns
 
         if len(returns) == 0:
             return []

@@ -2779,3 +2808,50 @@ def canonicalize(ep: ExportedProgram) -> ExportedProgram:
         schema_version=ep.schema_version,
         dialect=ep.dialect
     )
+
+
+class CustomOpHandler:
+    """
+    Base class for handling custom operators.
+    """
+    @classmethod
+    def namespace(cls):
+        raise NotImplementedError(f"{cls.__class__} namespace() must be implemented")
+
+    @classmethod
+    def op_name(cls, op_type):
+        raise NotImplementedError(f"{cls.__class__} op_name() must be implemented")
+
+    @classmethod
+    def op_type(cls, op_name):
+        raise NotImplementedError(f"{cls.__class__} op_type() must be implemented")
+
+    @classmethod
+    def op_schema(cls, op_type):
+        raise NotImplementedError(f"{cls.__class__} op_schema() must be implemented")
+
+
+def register_custom_op_handler(
+    op_handler: CustomOpHandler,
+    op_type: Type[Any],
+):
+    """Register custom de/serialization method for a node."""
+    assert isinstance(op_handler, CustomOpHandler), f"Expected CustomOpHandler, got {type(op_handler)}."
+    _serialization_registry[op_type] = op_handler
+    # FIXME: handles deserialization later.
+    _deserialization_registry[op_handler.namespace()] = op_handler
+
+
+def allowed_registered_op_types():
+    return tuple(
+        _serialization_registry.keys()
+    )
+
+
+# Registry to store all custom serialization implementations.
+# The registry maps an operation to its serialization handler, in its own
+# namespace to avoid conflicts.
+# Serialization: Op type --> custom handler.
+# De-serialization: Namespace --> custom handler.
+_serialization_registry: Dict[Type[Any], CustomOpHandler] = {}
+_deserialization_registry: Dict[str, CustomOpHandler] = {}

@@ -133,7 +133,8 @@ class Verifier(metaclass=_VerifierMeta):
         ]
 
     def allowed_op_types(self) -> Tuple[Type[Any], ...]:
-        return (OpOverload, HigherOrderOperator)
+        from torch._export.serde.serialize import allowed_registered_op_types  # Avoid circular import.
+        return (OpOverload, HigherOrderOperator, *allowed_registered_op_types())
 
     def allowed_getattr_types(self) -> Tuple[Type[Any], ...]:
         return (torch.fx.GraphModule,)

@@ -182,8 +183,7 @@ class Verifier(metaclass=_VerifierMeta):
                 # TODO (tmanlaibaatar)
                 # Predispatch export is able to contain autograd ops.
                 # These will be modeled as HOO later
-                torch._C._set_grad_enabled
-
+                torch._C._set_grad_enabled,
             )
 
             if not isinstance(op, _allowed_op_types()):
