mirror of
https://github.com/zebrajr/pytorch.git
synced 2025-12-06 12:20:52 +01:00
Format torch.fx.experimental.validator (#136935)
Signed-off-by: Edward Z. Yang <ezyang@meta.com> Pull Request resolved: https://github.com/pytorch/pytorch/pull/136935 Approved by: https://github.com/Skylion007 ghstack dependencies: #136934
This commit is contained in:
parent
33c2d3232f
commit
951af3d3d8
|
|
@ -1269,7 +1269,6 @@ exclude_patterns = [
|
|||
'torch/fx/experimental/unification/utils.py',
|
||||
'torch/fx/experimental/unification/variable.py',
|
||||
'torch/fx/experimental/unify_refinements.py',
|
||||
'torch/fx/experimental/validator.py',
|
||||
'torch/fx/graph.py',
|
||||
'torch/fx/graph_module.py',
|
||||
'torch/fx/interpreter.py',
|
||||
|
|
|
|||
|
|
@ -1,22 +1,22 @@
|
|||
# mypy: allow-untyped-defs
|
||||
import builtins
|
||||
import functools
|
||||
import logging
|
||||
import math
|
||||
import operator
|
||||
import sympy
|
||||
import builtins
|
||||
|
||||
from dataclasses import dataclass
|
||||
from typing import Any, Callable, Dict, List, Optional, Set, Tuple, Type, Union
|
||||
|
||||
import sympy
|
||||
|
||||
import torch
|
||||
import torch.fx
|
||||
import torch.fx.traceback as fx_traceback
|
||||
|
||||
from torch._dynamo.exc import TorchDynamoException
|
||||
from torch._dynamo.utils import dynamo_timed
|
||||
from torch.fx.node import Argument, Target
|
||||
from torch.utils._sympy.interp import sympy_interp
|
||||
from torch._dynamo.utils import dynamo_timed
|
||||
|
||||
|
||||
log = logging.getLogger(__name__)
|
||||
|
||||
|
|
@ -45,7 +45,6 @@ try:
|
|||
# and the FX nodes (see [Note: PopulateValidator]) that go through
|
||||
# 'ShapeEnv.evaluate_expr' function. Finally, we run the validation.
|
||||
# (see [Note: TranslationValidator])
|
||||
|
||||
# Better Z3 to string implementation (for a small fraction of Z3).
|
||||
#
|
||||
# Here are the things we clean before showing the Z3 expression:
|
||||
|
|
@ -68,7 +67,6 @@ try:
|
|||
# This is done using rewriting rules, so shouldn't take long.
|
||||
e = z3.simplify(e)
|
||||
|
||||
|
||||
# Only support function applications.
|
||||
# Even Z3 "variables" are, in fact, function applications.
|
||||
if not z3.is_app(e):
|
||||
|
|
@ -176,7 +174,9 @@ try:
|
|||
|
||||
# Python semantics for 'FloorDiv' states that before applying the floor
|
||||
# function, the operands are converted to their common type.
|
||||
def floordiv(self, numerator: z3.ArithRef, denominator: z3.ArithRef) -> z3.ArithRef:
|
||||
def floordiv(
|
||||
self, numerator: z3.ArithRef, denominator: z3.ArithRef
|
||||
) -> z3.ArithRef:
|
||||
cast_result_to_real = numerator.is_real() or denominator.is_real()
|
||||
result = _Z3Ops.to_int(self.div(numerator, denominator))
|
||||
# Since the 'result' is already an integer, we just have to check
|
||||
|
|
@ -185,9 +185,7 @@ try:
|
|||
|
||||
def ceil(self, number: z3.ArithRef) -> z3.ArithRef:
|
||||
return z3.If(
|
||||
self.floor(number) < number,
|
||||
self.floor(number + 1),
|
||||
number
|
||||
self.floor(number) < number, self.floor(number + 1), number
|
||||
) # type: ignore[return-value]
|
||||
|
||||
def max(self, a: z3.ArithRef, b: z3.ArithRef) -> z3.ArithRef:
|
||||
|
|
@ -284,11 +282,9 @@ try:
|
|||
operator.mod: lift(ops.mod),
|
||||
operator.abs: lift(ops.abs),
|
||||
builtins.round: lift(ops.round_to_int),
|
||||
|
||||
# Math module.
|
||||
math.ceil: lift(ops.ceil),
|
||||
math.floor: lift(ops.floor),
|
||||
|
||||
# Torch module.
|
||||
torch.sym_float: lift(ops.to_real),
|
||||
torch.sym_max: lift(ops.max),
|
||||
|
|
@ -319,17 +315,23 @@ try:
|
|||
module = torch.fx.GraphModule(root={}, graph=graph)
|
||||
super().__init__(module, garbage_collect_values=True)
|
||||
|
||||
def placeholder(self, target: Target, args: Tuple[Argument, ...], kwargs: Dict[str, Any]) -> Any:
|
||||
def placeholder(
|
||||
self, target: Target, args: Tuple[Argument, ...], kwargs: Dict[str, Any]
|
||||
) -> Any:
|
||||
symbol = fx_traceback.get_current_meta()["symbol"]
|
||||
return self.validator.z3var(symbol)
|
||||
|
||||
def call_function(self, target: Target, args: Tuple[Argument, ...], kwargs: Dict[str, Any]) -> Any:
|
||||
def call_function(
|
||||
self, target: Target, args: Tuple[Argument, ...], kwargs: Dict[str, Any]
|
||||
) -> Any:
|
||||
if target != torch._assert:
|
||||
# Lift and runs the node target function
|
||||
return super().call_function(z3op(target, self.validator), args, kwargs) # type: ignore[arg-type]
|
||||
# Adds the Z3 expression corresponding to the first argument
|
||||
# as a validator input.
|
||||
assert len(args) == 1, f"expected 1 argument on assertion. Got: {len(args)} "
|
||||
assert (
|
||||
len(args) == 1
|
||||
), f"expected 1 argument on assertion. Got: {len(args)} "
|
||||
self.validator.add_source_expr(args[0]) # type: ignore[arg-type]
|
||||
|
||||
# Translates SymPy expressions into Z3 expressions.
|
||||
|
|
@ -369,13 +371,19 @@ try:
|
|||
def round_to_int(self, x: z3.ArithRef, dtype: torch.dtype) -> z3.ArithRef:
|
||||
return self._ops.round_to_int(x)
|
||||
|
||||
def int_truediv(self, numerator: z3.ArithRef, denominator: z3.ArithRef) -> z3.ArithRef:
|
||||
def int_truediv(
|
||||
self, numerator: z3.ArithRef, denominator: z3.ArithRef
|
||||
) -> z3.ArithRef:
|
||||
return self._ops.div(numerator, denominator)
|
||||
|
||||
def truediv(self, numerator: z3.ArithRef, denominator: z3.ArithRef) -> z3.ArithRef:
|
||||
def truediv(
|
||||
self, numerator: z3.ArithRef, denominator: z3.ArithRef
|
||||
) -> z3.ArithRef:
|
||||
return self._ops.div(numerator, denominator)
|
||||
|
||||
def floordiv(self, numerator: z3.ArithRef, denominator: z3.ArithRef) -> z3.ArithRef:
|
||||
def floordiv(
|
||||
self, numerator: z3.ArithRef, denominator: z3.ArithRef
|
||||
) -> z3.ArithRef:
|
||||
return self._ops.floordiv(numerator, denominator)
|
||||
|
||||
def div(self, numerator: z3.ArithRef, denominator: z3.ArithRef) -> z3.ArithRef:
|
||||
|
|
@ -488,10 +496,11 @@ try:
|
|||
# Z3 variable corresponding to 's'.
|
||||
self.z3var(s)
|
||||
|
||||
|
||||
def to_z3_boolean_expr(self, e: sympy.Basic) -> z3.BoolRef:
|
||||
z3expr = SympyToZ3(self).run(e)
|
||||
assert isinstance(z3expr, z3.BoolRef), f"expected boolean expression. Got: {z3expr}"
|
||||
assert isinstance(
|
||||
z3expr, z3.BoolRef
|
||||
), f"expected boolean expression. Got: {z3expr}"
|
||||
return z3expr
|
||||
|
||||
def add_source_expr(self, e: z3.BoolRef) -> None:
|
||||
|
|
@ -557,17 +566,21 @@ try:
|
|||
# Log the found model and the source expressions that failed.
|
||||
model = solver.model()
|
||||
raise ValidationException(
|
||||
model, self._assertions, self._target_exprs,
|
||||
model,
|
||||
self._assertions,
|
||||
self._target_exprs,
|
||||
failed_source_exprs=[
|
||||
inp for inp in self._source_exprs if not model.evaluate(inp)
|
||||
]
|
||||
],
|
||||
)
|
||||
else:
|
||||
if r == z3.unknown:
|
||||
# Could not find a solution. It didn't fail, but it also
|
||||
# didn't succeed. Canceling the validation execution (keyboard
|
||||
# interrupt) also gets to this branch.
|
||||
log.warning("translation validation: could not validate: got z3.unknown")
|
||||
log.warning(
|
||||
"translation validation: could not validate: got z3.unknown"
|
||||
)
|
||||
else:
|
||||
# Target expressions are sound.
|
||||
assert r == z3.unsat
|
||||
|
|
@ -577,21 +590,30 @@ except ImportError:
|
|||
_HAS_Z3 = False
|
||||
|
||||
__all__ = [
|
||||
"translation_validation_enabled", "translation_validation_timeout",
|
||||
"ValidationException", "BisectValidationException",
|
||||
"translation_validation_enabled",
|
||||
"translation_validation_timeout",
|
||||
"ValidationException",
|
||||
"BisectValidationException",
|
||||
]
|
||||
|
||||
else:
|
||||
_HAS_Z3 = True
|
||||
|
||||
__all__ = [
|
||||
"z3str", "z3op", "PopulateValidator", "SympyToZ3", "TranslationValidator",
|
||||
"translation_validation_enabled", "translation_validation_timeout",
|
||||
"ValidationException", "BisectValidationException",
|
||||
"z3str",
|
||||
"z3op",
|
||||
"PopulateValidator",
|
||||
"SympyToZ3",
|
||||
"TranslationValidator",
|
||||
"translation_validation_enabled",
|
||||
"translation_validation_timeout",
|
||||
"ValidationException",
|
||||
"BisectValidationException",
|
||||
]
|
||||
|
||||
from torch.fx.experimental import _config as config
|
||||
|
||||
|
||||
def translation_validation_enabled() -> bool:
|
||||
# Checks everytime this function is called, in case the Dynamo
|
||||
# option is set, but Z3 is not installed.
|
||||
|
|
@ -655,9 +677,11 @@ Failure occurred while running node:
|
|||
def __str__(self):
|
||||
return f"{self.msg}\n\n{self.details}"
|
||||
|
||||
|
||||
# Checks when this module is loaded.
|
||||
_assert_z3_installed_if_tv_set()
|
||||
|
||||
|
||||
# Translation validation bisection.
|
||||
#
|
||||
# Bisect into the torch._assert nodes recorded in the shape_env FX graph, and raise
|
||||
|
|
@ -667,8 +691,16 @@ _assert_z3_installed_if_tv_set()
|
|||
# might be silently happening. This function tries to nail down exactly at which
|
||||
# point things went wrong from a validation perspective.
|
||||
def bisect(shape_env):
|
||||
from torch.fx.experimental.symbolic_shapes import ShapeEnv, SHAPEENV_EVENT_KEY, CURRENT_NODE_KEY
|
||||
from torch.fx.experimental.recording import FakeTensorMeta, ShapeEnvEvent, replay_shape_env_events
|
||||
from torch.fx.experimental.recording import (
|
||||
FakeTensorMeta,
|
||||
replay_shape_env_events,
|
||||
ShapeEnvEvent,
|
||||
)
|
||||
from torch.fx.experimental.symbolic_shapes import (
|
||||
CURRENT_NODE_KEY,
|
||||
ShapeEnv,
|
||||
SHAPEENV_EVENT_KEY,
|
||||
)
|
||||
|
||||
events = shape_env.events
|
||||
|
||||
|
|
@ -696,7 +728,9 @@ def bisect(shape_env):
|
|||
)
|
||||
|
||||
# Checks whether the given shape_env fails when produce_guards is called.
|
||||
def check_shapeenv_fails(shape_env: ShapeEnv, tracked_fakes: Optional[List[Any]]) -> Optional[ValidationException]:
|
||||
def check_shapeenv_fails(
|
||||
shape_env: ShapeEnv, tracked_fakes: Optional[List[Any]]
|
||||
) -> Optional[ValidationException]:
|
||||
assert tracked_fakes is not None
|
||||
try:
|
||||
# This produce_guards call is a best-effort replication, since we
|
||||
|
|
@ -720,7 +754,9 @@ def bisect(shape_env):
|
|||
shape_env.graph.lint()
|
||||
return check_shapeenv_fails(shape_env, events[number].tracked_fakes)
|
||||
|
||||
last_exception = check_shapeenv_fails(shape_env, shape_env._snapshot_tracked_fakes())
|
||||
last_exception = check_shapeenv_fails(
|
||||
shape_env, shape_env._snapshot_tracked_fakes()
|
||||
)
|
||||
|
||||
if not last_exception:
|
||||
# We don't actually fail due to a produce_guards call.
|
||||
|
|
@ -738,7 +774,9 @@ def bisect(shape_env):
|
|||
|
||||
# Bisection happens on the assertion nodes of the recorded FX graph for
|
||||
# dynamic shapes.
|
||||
assert_nodes = [node for node in shape_env.graph.nodes if node.target == torch._assert]
|
||||
assert_nodes = [
|
||||
node for node in shape_env.graph.nodes if node.target == torch._assert
|
||||
]
|
||||
|
||||
# Preparing the indices for binary search.
|
||||
left, mid, right = 0, 0, len(assert_nodes) - 1
|
||||
|
|
|
|||
Loading…
Reference in New Issue
Block a user