import itertools
import operator
from collections.abc import Iterable
from typing import Set

import torch
from functorch.experimental import control_flow
from torch._ops import OpOverload
from torch._subclasses.fake_tensor import FakeTensor
from torch.fx import GraphModule
from torch.fx._compatibility import compatibility

# Metadata keys that every OpOverload call_function node must carry.
PRESERVED_META_KEYS: Set[str] = {
    "val",
    "stack_trace",
}


@compatibility(is_backward_compatible=False)
class SpecViolationError(Exception):
    """Raised when a graph module does not satisfy the checks in this file."""

    pass


@compatibility(is_backward_compatible=False)
def is_functional(op: OpOverload) -> bool:
    """Return True when the schema of ``op`` is not marked as mutable."""
    return not op._schema.is_mutable


def _op_display_name(op) -> str:
    """Return a human-readable name for ``op`` for use in error messages.

    ``OpOverload.name`` is a method, so it has to be called; reading the
    attribute without calling it would embed a bound-method repr in the
    message. Other targets fall back to a ``name`` attribute, then to
    ``__name__``, then to ``repr`` so that building the message never raises.
    """
    if isinstance(op, OpOverload):
        return op.name()
    name = getattr(op, "name", None)
    if name is None:
        name = getattr(op, "__name__", None)
    return repr(op) if name is None else str(name)


@compatibility(is_backward_compatible=False)
def _check_has_fake_tensor(node: torch.fx.Node) -> None:
    """Check that ``node.meta["val"]`` is a FakeTensor or a nest of them.

    Raises:
        SpecViolationError: if ``val`` is missing, or if anything inside it is
            not a FakeTensor.
    """

    def _check_is_fake_tensor(val):
        if isinstance(val, FakeTensor):
            return True
        # Strings/bytes iterate into themselves (unbounded recursion) and real
        # tensors iterate down to 0-d tensors that cannot be iterated. None of
        # them is a FakeTensor, so reject them before the generic Iterable
        # branch instead of crashing with RecursionError / TypeError.
        if isinstance(val, (torch.Tensor, str, bytes)):
            return False
        if isinstance(val, Iterable):
            return all(_check_is_fake_tensor(x) for x in val)
        return False

    val = node.meta.get("val", None)
    if val is None or not _check_is_fake_tensor(val):
        raise SpecViolationError("Node.meta {} is missing val field.".format(node.name))


@compatibility(is_backward_compatible=False)
def _check_tensors_are_contiguous(gm: GraphModule) -> None:
    """Check that every parameter and buffer tensor of ``gm`` is contiguous.

    Raises:
        SpecViolationError: naming the first non-contiguous tensor found.
    """
    # Tensors must be in contiguous format.
    for name, param in itertools.chain(gm.named_parameters(), gm.named_buffers()):
        if isinstance(param, torch.Tensor) and not param.is_contiguous():
            raise SpecViolationError(
                f"Tensors in Aten dialect must be contiguous, {name} is not contiguous"
            )


@compatibility(is_backward_compatible=False)
class Verifier:
    """Base graph verifier.

    Rejects ``call_module`` / ``call_method`` nodes and requires every
    ``call_function`` node to carry fake-tensor metadata and to target either
    an allowed builtin or a functional ``OpOverload``.
    """

    def __call__(self, gm: GraphModule) -> None:
        """Run :meth:`check_valid` on ``gm``."""
        self.check_valid(gm)

    @compatibility(is_backward_compatible=False)
    def valid_builtin_funcs(self):
        """Return the non-OpOverload call targets that skip ``check_valid_op``."""
        return [
            operator.getitem,
            control_flow.cond,
            control_flow.map,
        ]

    @compatibility(is_backward_compatible=False)
    def check_valid_op(self, op):
        """Check that ``op`` is an ``OpOverload`` with a non-mutable schema.

        Raises:
            SpecViolationError: if ``op`` is not an ``OpOverload`` or is not
                functional.
        """
        op_name = _op_display_name(op)

        if not isinstance(op, OpOverload):
            raise SpecViolationError(
                "Operator '{}' is not a registered Op".format(op_name),
            )

        # All ops must be functional.
        if not is_functional(op):
            raise SpecViolationError(
                f"operator '{op_name}' is not functional"
            )

    @compatibility(is_backward_compatible=False)
    def check_valid(self, gm: GraphModule) -> None:  # noqa: C901
        """Check every node of ``gm.graph``.

        Raises:
            SpecViolationError: on the first node that fails a check.
        """
        # The allow-list does not depend on the node; build it once.
        builtin_funcs = self.valid_builtin_funcs()

        for node in gm.graph.nodes:
            # TODO(T140410192): should have fake tensor for all dialects
            if node.op in {"call_module", "call_method"}:
                # Report the op kind that was actually found; the text for
                # call_module is unchanged.
                raise SpecViolationError(
                    "{} is not valid: got a class '{}' ".format(node.op, node.target),
                )

            if node.op == "call_function":
                _check_has_fake_tensor(node)

                if node.target not in builtin_funcs:
                    self.check_valid_op(node.target)

                if isinstance(node.target, OpOverload):
                    # Check preserved metadata
                    for meta in PRESERVED_META_KEYS:
                        if node.meta.get(meta, None) is None:
                            raise SpecViolationError(
                                f"node {node} is missing metadata {meta}"
                            )

    @compatibility(is_backward_compatible=False)
    def is_valid(self, gm: GraphModule) -> bool:
        """Return True if :meth:`check_valid` passes, False on a violation."""
        try:
            self.check_valid(gm)
            return True
        except SpecViolationError:
            return False


class ATenDialectVerifier(Verifier):
    """Verifier that additionally requires core/view_copy ops and contiguous tensors."""

    @compatibility(is_backward_compatible=False)
    def check_valid_op(self, op) -> None:
        """Run the base checks, then require a ``core`` or ``view_copy`` tag.

        Raises:
            SpecViolationError: if the base checks fail or ``op`` carries
                neither tag.
        """
        # The base check already rejects anything that is not an OpOverload,
        # so ``op.tags`` is safe to read below.
        super().check_valid_op(op)

        if (
            torch.Tag.core not in op.tags  # type: ignore[attr-defined]
            and torch.Tag.view_copy not in op.tags  # type: ignore[attr-defined]
        ):
            # NOTE(qihan): whether view_copy operators are marked as canonical is still under
            #            discussion.
            raise SpecViolationError(
                "Operator {}.{} is not Aten Canonical.".format(
                    op.__module__, op.__name__
                )
            )

    @compatibility(is_backward_compatible=False)
    def check_valid(self, gm: GraphModule) -> None:
        """Run the base graph checks, then check tensor contiguity."""
        super().check_valid(gm)
        _check_tensors_are_contiguous(gm)