import inspect
import math
import operator
from collections.abc import Iterable
from typing import Any, Dict, final, List, Optional, Tuple, Type

import torch
from torch._ops import HigherOrderOperator, OpOverload
from torch._subclasses.fake_tensor import FakeTensor
from torch.export.exported_program import ExportedProgram
from torch.export.graph_signature import (
    ExportGraphSignature,
    InputKind,
    SymIntArgument,
    TensorArgument,
)
from torch.fx import GraphModule
from torch.fx.experimental.symbolic_shapes import SymBool, SymFloat, SymInt


class SpecViolationError(Exception):
    pass


def is_functional(op: OpOverload) -> bool:
    return not op._schema.is_mutable
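
# Illustrative example (added note, not part of the original file): out-of-place
# ATen ops are functional, while their in-place variants declare mutable
# arguments in their schema, e.g.
#   is_functional(torch.ops.aten.add.Tensor)   # True
#   is_functional(torch.ops.aten.add_.Tensor)  # False -- add_ mutates `self`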


def _check_has_fake_tensor(node: torch.fx.Node) -> None:
    # TODO(angelayi): remove this in favor of _check_val
    return _check_val(node)


def _check_val(node: torch.fx.Node) -> None:
    def _check_correct_val(val):
        if val is None:
            return True
        elif isinstance(val, (int, bool, str, float)):
            return True
        elif isinstance(val, (torch.memory_format, torch.dtype, torch.device, torch.layout)):
            return True
        elif isinstance(val, (FakeTensor, torch.Tensor)):  # TODO(zhxchen17) Remove Tensor.
            return True
        elif isinstance(val, (SymInt, SymFloat, SymBool)):
            return True
        elif isinstance(val, Iterable):
            return all(_check_correct_val(x) for x in val)
        return False

    def _no_returns(op):
        if not isinstance(op, OpOverload):
            return False
        return len(op._schema.returns) == 0

    if "val" not in node.meta:
        if node.op == "call_function" and _no_returns(node.target):
            return
        raise SpecViolationError(f"Node.meta {node.name} is missing val field.")

    val = node.meta["val"]
    if not _check_correct_val(val):
        raise SpecViolationError(f"Node.meta {node.name} has invalid val field {val}")
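
# Added note (not in the original file): in practice node.meta["val"] holds the
# FakeTensor (or SymInt/SymFloat/SymBool) recorded while tracing the node, or an
# iterable of such values for multi-output ops; _check_val rejects anything else.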


class _VerifierMeta(type):
    _registry: Dict[str, Type['Verifier']] = {}

    def __new__(metacls, name, bases, attrs):
        if bases:
            if "check" in attrs or "_check_graph_module" in attrs:
                raise SyntaxError("Overriding method check is not allowed.")
            assert "dialect" in attrs and attrs["dialect"] != "ATEN"
        else:
            assert "check" in attrs
            assert "_check_graph_module" in attrs
            assert attrs["dialect"] == "ATEN"

        assert isinstance(attrs["dialect"], str)
        ret = type.__new__(metacls, name, bases, attrs)
        metacls._registry[attrs["dialect"]] = ret  # type: ignore[assignment]
        return ret
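
# Illustrative sketch (hypothetical subclass, not in the original file): the
# metaclass auto-registers every subclass under its dialect string, so a new
# dialect only needs to declare itself:
#
#   class MyDialectVerifier(Verifier):
#       dialect = "MYDIALECT"
#
#   assert load_verifier("MYDIALECT") is MyDialectVerifier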


class Verifier(metaclass=_VerifierMeta):
    dialect = "ATEN"

    def allowed_builtin_ops(self) -> List:
        return [
            operator.getitem,
            operator.add,
            operator.mul,
            operator.sub,
            operator.truediv,
            operator.ge,
            operator.le,
            operator.gt,
            operator.lt,
            operator.eq,
            operator.ne,
            operator.floordiv,
            operator.mod,
            operator.and_,
            operator.or_,
            operator.not_,
            operator.pow,
            operator.neg,
            operator.abs,
            math.ceil,
            math.floor,
        ]

    def allowed_op_types(self) -> Tuple[Type[Any], ...]:
        return (OpOverload, HigherOrderOperator)

    def allowed_getattr_types(self) -> Tuple[Type[Any], ...]:
        return (torch.fx.GraphModule,)

    def check_valid_op(self, op):
        pass

    def check_additional(self, gm: GraphModule) -> None:
        """
        Additional checks that are specific to some dialects.
        """
        pass
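
    # Added note (not in the original file): dialect subclasses customize
    # verification by overriding the hooks above (allowed_builtin_ops,
    # allowed_op_types, allowed_getattr_types, check_valid_op, check_additional);
    # check() and _check_graph_module() themselves are @final, and the metaclass
    # rejects attempts to override them.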

    @final
    def check(self, ep: ExportedProgram) -> None:
        if not isinstance(ep.graph_signature, ExportGraphSignature):
            # TODO Enforce type checking in the constructor.
            return
        self._check_graph_module(ep.graph_module)
        try:
            _verify_exported_program_signature(ep)
        except SpecViolationError as e:
            # TODO Remove this branch.
            if ep.dialect == "EDGE":  # !!! Don't change this allowlist. !!!
                pass
            else:
                raise e

    @final
    def _check_graph_module(self, gm: torch.fx.GraphModule) -> None:
        def _allowed_getattr_types() -> Tuple[Type[Any], ...]:
            ret = self.allowed_getattr_types()
            assert not any(t is object for t in ret)
            return ret

        def _check_valid_op(op) -> None:
            def _allowed_builtin_ops() -> List:
                ret = self.allowed_builtin_ops()
                assert all(inspect.isbuiltin(op) for op in ret)
                return ret

            def _allowed_op_types() -> Tuple[Type[Any], ...]:
                ret = self.allowed_op_types()
                assert not any(t is object for t in ret)
                return ret

            # TODO Remove this allowlist.
            _allowed_torch_functions = (torch.autograd.grad_mode.set_grad_enabled,)

            if not isinstance(op, _allowed_op_types()):
                if op not in _allowed_builtin_ops() and op not in _allowed_torch_functions:
                    raise SpecViolationError(
                        f"Operator '{op}' is not an allowed operator type: {_allowed_op_types()}\n"
                        f"Valid builtin ops: {_allowed_builtin_ops()}\n"
                        f"Valid torch functions: {_allowed_torch_functions}"
                    )

            if isinstance(op, OpOverload):
                # All ops functional
                if not is_functional(op):
                    raise SpecViolationError(
                        f"operator '{op}' is not functional"
                    )
            self.check_valid_op(op)

        for mod in gm.modules():
            if not isinstance(mod, torch.fx.GraphModule):
                continue

            mod.graph.lint()
            for node in mod.graph.nodes:
                # TODO(T140410192): should have fake tensor for all dialects
                if node.op in {"call_module", "call_method"}:
                    raise SpecViolationError(
                        f"call_module is not valid: got a class '{node.target}' ",
                    )

                elif node.op == "call_function":
                    _check_val(node)

                    _check_valid_op(node.target)

                elif node.op == "get_attr":
                    if not isinstance(node.target, str):
                        raise SpecViolationError(
                            f"Expected get_attr target to be string, but got {type(node.target)}"
                        )

                    attr = getattr(mod, node.target)
                    if isinstance(attr, torch.nn.Module):
                        def _is_type(name, ty):
                            return isinstance(getattr(attr, name, None), ty)
                        if type(attr).__name__ == "LoweredBackendModule" \
                                and _is_type("backend_id", str) \
                                and _is_type("processed_bytes", bytes) \
                                and _is_type("compile_specs", list) \
                                and hasattr(attr, "original_module"):
                            continue

                    if not isinstance(attr, _allowed_getattr_types()):
                        raise SpecViolationError(
                            f"Invalid get_attr type {type(attr)}. \n"
                            f"Valid get_attr types: {_allowed_getattr_types()}"
                        )

                elif node.op == "placeholder":
                    _check_val(node)
                # TODO(zhxchen17)
                # elif node.op == "output":
                #     _check_flattened_outputs()

        self.check_additional(gm)
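
    # Added summary (not in the original file) of the invariants enforced above:
    # every placeholder and call_function node carries a valid meta["val"];
    # call_module/call_method are banned (export graphs are flattened);
    # call_function targets must be allowed op types, builtins, or allowlisted
    # torch functions, and OpOverload targets must be functional; get_attr may
    # only reach allowed attribute types (plus the LoweredBackendModule escape
    # hatch).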


def _verify_exported_program_signature(exported_program) -> None:
    # Check ExportedProgram signature matches
    gs = exported_program.graph_signature

    # Check every node in the signature exists in the graph
    input_node_names = [node.name for node in exported_program.graph.nodes if node.op == "placeholder"]

    if len(input_node_names) != len(gs.input_specs):
        raise SpecViolationError(
            f"Number of graph inputs ({len(input_node_names)}) "
            f"does not match number of inputs in the graph signature ({len(gs.input_specs)})"
        )

    for input_spec, node in zip(gs.input_specs, input_node_names):
        if isinstance(input_spec.arg, (TensorArgument, SymIntArgument)):
            if input_spec.arg.name != node:
                raise SpecViolationError(
                    f"Input spec name {input_spec.arg.name} does not match node name {node}"
                )

        if input_spec.kind == InputKind.USER_INPUT:
            continue

        elif input_spec.kind == InputKind.PARAMETER:
            if not isinstance(input_spec.arg, TensorArgument):
                raise SpecViolationError(
                    f"Parameter {input_spec.name} is not a tensor argument. Found {input_spec.arg} instead."
                )
            if input_spec.target is None:
                raise SpecViolationError(
                    f"InputSpec for {input_spec.name} has no target."
                )

            param = input_spec.target
            if param not in exported_program.state_dict:
                raise SpecViolationError(
                    f"Parameter {param} is not in the state dict."
                )

            if not isinstance(exported_program.state_dict[param], torch.nn.Parameter):
                raise SpecViolationError(
                    f"State dict entry for parameter {param} is not an instance of torch.nn.Parameter."
                )

        elif input_spec.kind == InputKind.BUFFER:
            if not isinstance(input_spec.arg, TensorArgument):
                raise SpecViolationError(
                    f"Buffer {input_spec.name} is not a tensor argument. Found {input_spec.arg} instead."
                )
            if input_spec.target is None:
                raise SpecViolationError(
                    f"InputSpec for {input_spec.name} has no target."
                )

            buffer = input_spec.target
            if buffer not in exported_program.state_dict:
                raise SpecViolationError(
                    f"Buffer {buffer} is not in the state dict."
                )
        elif input_spec.kind == InputKind.CONSTANT_TENSOR:
            if not isinstance(input_spec.arg, TensorArgument):
                raise SpecViolationError(
                    f"Constant tensor {input_spec.name} is not a tensor argument. Found {input_spec.arg} instead."
                )
            if input_spec.target is None:
                raise SpecViolationError(
                    f"InputSpec for {input_spec.name} has no target."
                )

            tensor_const = input_spec.target
            if tensor_const not in exported_program.tensor_constants:
                raise SpecViolationError(
                    f"Constant tensor {tensor_const} is not in the tensor constants dictionary."
                )
        else:
            raise SpecViolationError(
                f"Unknown InputKind {input_spec.kind}."
            )
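
    # Added note (not in the original file): under torch.export's calling
    # convention, lifted inputs (parameters, buffers, constant tensors) appear
    # as placeholders ahead of user inputs, which is why input_specs and the
    # graph's placeholder nodes can be zipped positionally above.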

    # Check outputs
    output_node = list(exported_program.graph.nodes)[-1]
    assert output_node.op == "output"
    output_nodes = [arg.name for arg in output_node.args[0]]

    if len(output_nodes) != len(gs.output_specs):
        raise SpecViolationError(
            f"Number of output nodes {len(output_nodes)} is different "
            "than the number of outputs specified by the graph signature: \n"
            f"Number of mutated buffers: {len(gs.buffers_to_mutate)}. \n"
            f"Number of user outputs: {len(gs.user_outputs)}. \n"
        )
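
    # Added note (not in the original file): graph outputs are laid out as
    # [mutated buffers..., mutated user inputs..., user outputs...], so the
    # first `end` output nodes below are mutations and the following block must
    # match gs.user_outputs in order.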

    end = len(gs.buffers_to_mutate) + len(gs.user_inputs_to_mutate)
    mutate_nodes: List[str] = output_nodes[:end]
    user_output_nodes = output_nodes[end:end + len(gs.user_outputs)]

    for mutation_node in mutate_nodes:
        if mutation_node in gs.buffers_to_mutate:
            if gs.buffers_to_mutate[mutation_node] not in gs.buffers:
                raise SpecViolationError(
                    f"Buffer output {mutation_node} does not point to a buffer that exists. \n"
                    f"Dict of buffers that are mutated, in order: {gs.buffers_to_mutate} \n"
                    f"Buffer nodes available: {gs.buffers} \n"
                )
        elif mutation_node in gs.user_inputs_to_mutate:
            if gs.user_inputs_to_mutate[mutation_node] not in gs.user_inputs:
                raise SpecViolationError(
                    f"User input output {mutation_node} does not point to a user input that exists. \n"
                    f"Dict of user inputs that are mutated, in order: {gs.user_inputs_to_mutate} \n"
                    f"User input nodes available: {gs.user_inputs} \n")
        else:
            raise SpecViolationError(
                f"Mutation node {mutation_node} is neither a buffer nor a user input. "
                f"Buffers to mutate: {gs.buffers_to_mutate}, User inputs to mutate: {gs.user_inputs_to_mutate}"
            )

    for user_output_node, user_output_name in zip(user_output_nodes, gs.user_outputs):
        if user_output_node != user_output_name:
            raise SpecViolationError(
                f"User output {user_output_node} is not in the correct "
                "order or is not found in the "
                f"exported program's user_output list: {gs.user_outputs}. "
            )


def load_verifier(dialect: str) -> Optional[Type[Verifier]]:
    if dialect == "ATEN":
        return _VerifierMeta._registry.get(dialect)
    return _VerifierMeta._registry[dialect]
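
# Illustrative usage (not in the original file; `ep` is a hypothetical
# ExportedProgram):
#   verifier_cls = load_verifier("ATEN")  # base Verifier, or None if unregistered
#   verifier_cls = load_verifier("EDGE")  # raises KeyError for unknown dialects
#   if verifier_cls is not None:
#       verifier_cls().check(ep)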