Format torch.fx.experimental.validator (#136935)

Signed-off-by: Edward Z. Yang <ezyang@meta.com>
Pull Request resolved: https://github.com/pytorch/pytorch/pull/136935
Approved by: https://github.com/Skylion007
ghstack dependencies: #136934
This commit is contained in:
Edward Z. Yang 2024-09-30 18:21:40 -07:00 committed by PyTorch MergeBot
parent 33c2d3232f
commit 951af3d3d8
2 changed files with 77 additions and 40 deletions

View File

@ -1269,7 +1269,6 @@ exclude_patterns = [
'torch/fx/experimental/unification/utils.py',
'torch/fx/experimental/unification/variable.py',
'torch/fx/experimental/unify_refinements.py',
'torch/fx/experimental/validator.py',
'torch/fx/graph.py',
'torch/fx/graph_module.py',
'torch/fx/interpreter.py',

View File

@ -1,22 +1,22 @@
# mypy: allow-untyped-defs
import builtins
import functools
import logging
import math
import operator
import sympy
import builtins
from dataclasses import dataclass
from typing import Any, Callable, Dict, List, Optional, Set, Tuple, Type, Union
import sympy
import torch
import torch.fx
import torch.fx.traceback as fx_traceback
from torch._dynamo.exc import TorchDynamoException
from torch._dynamo.utils import dynamo_timed
from torch.fx.node import Argument, Target
from torch.utils._sympy.interp import sympy_interp
from torch._dynamo.utils import dynamo_timed
log = logging.getLogger(__name__)
@ -45,7 +45,6 @@ try:
# and the FX nodes (see [Note: PopulateValidator]) that go through
# 'ShapeEnv.evaluate_expr' function. Finally, we run the validation.
# (see [Note: TranslationValidator])
# Better Z3 to string implementation (for a small fraction of Z3).
#
# Here are the things we clean before showing the Z3 expression:
@ -68,7 +67,6 @@ try:
# This is done using rewriting rules, so shouldn't take long.
e = z3.simplify(e)
# Only support function applications.
# Even Z3 "variables" are, in fact, function applications.
if not z3.is_app(e):
@ -176,7 +174,9 @@ try:
# Python semantics for 'FloorDiv' states that before applying the floor
# function, the operands are converted to their common type.
def floordiv(self, numerator: z3.ArithRef, denominator: z3.ArithRef) -> z3.ArithRef:
def floordiv(
self, numerator: z3.ArithRef, denominator: z3.ArithRef
) -> z3.ArithRef:
cast_result_to_real = numerator.is_real() or denominator.is_real()
result = _Z3Ops.to_int(self.div(numerator, denominator))
# Since the 'result' is already an integer, we just have to check
@ -185,9 +185,7 @@ try:
def ceil(self, number: z3.ArithRef) -> z3.ArithRef:
return z3.If(
self.floor(number) < number,
self.floor(number + 1),
number
self.floor(number) < number, self.floor(number + 1), number
) # type: ignore[return-value]
def max(self, a: z3.ArithRef, b: z3.ArithRef) -> z3.ArithRef:
@ -284,11 +282,9 @@ try:
operator.mod: lift(ops.mod),
operator.abs: lift(ops.abs),
builtins.round: lift(ops.round_to_int),
# Math module.
math.ceil: lift(ops.ceil),
math.floor: lift(ops.floor),
# Torch module.
torch.sym_float: lift(ops.to_real),
torch.sym_max: lift(ops.max),
@ -319,17 +315,23 @@ try:
module = torch.fx.GraphModule(root={}, graph=graph)
super().__init__(module, garbage_collect_values=True)
def placeholder(self, target: Target, args: Tuple[Argument, ...], kwargs: Dict[str, Any]) -> Any:
def placeholder(
self, target: Target, args: Tuple[Argument, ...], kwargs: Dict[str, Any]
) -> Any:
symbol = fx_traceback.get_current_meta()["symbol"]
return self.validator.z3var(symbol)
def call_function(self, target: Target, args: Tuple[Argument, ...], kwargs: Dict[str, Any]) -> Any:
def call_function(
self, target: Target, args: Tuple[Argument, ...], kwargs: Dict[str, Any]
) -> Any:
if target != torch._assert:
# Lift and runs the node target function
return super().call_function(z3op(target, self.validator), args, kwargs) # type: ignore[arg-type]
# Adds the Z3 expression corresponding to the first argument
# as a validator input.
assert len(args) == 1, f"expected 1 argument on assertion. Got: {len(args)} "
assert (
len(args) == 1
), f"expected 1 argument on assertion. Got: {len(args)} "
self.validator.add_source_expr(args[0]) # type: ignore[arg-type]
# Translates SymPy expressions into Z3 expressions.
@ -369,13 +371,19 @@ try:
def round_to_int(self, x: z3.ArithRef, dtype: torch.dtype) -> z3.ArithRef:
return self._ops.round_to_int(x)
def int_truediv(self, numerator: z3.ArithRef, denominator: z3.ArithRef) -> z3.ArithRef:
def int_truediv(
self, numerator: z3.ArithRef, denominator: z3.ArithRef
) -> z3.ArithRef:
return self._ops.div(numerator, denominator)
def truediv(self, numerator: z3.ArithRef, denominator: z3.ArithRef) -> z3.ArithRef:
def truediv(
self, numerator: z3.ArithRef, denominator: z3.ArithRef
) -> z3.ArithRef:
return self._ops.div(numerator, denominator)
def floordiv(self, numerator: z3.ArithRef, denominator: z3.ArithRef) -> z3.ArithRef:
def floordiv(
self, numerator: z3.ArithRef, denominator: z3.ArithRef
) -> z3.ArithRef:
return self._ops.floordiv(numerator, denominator)
def div(self, numerator: z3.ArithRef, denominator: z3.ArithRef) -> z3.ArithRef:
@ -488,10 +496,11 @@ try:
# Z3 variable corresponding to 's'.
self.z3var(s)
def to_z3_boolean_expr(self, e: sympy.Basic) -> z3.BoolRef:
z3expr = SympyToZ3(self).run(e)
assert isinstance(z3expr, z3.BoolRef), f"expected boolean expression. Got: {z3expr}"
assert isinstance(
z3expr, z3.BoolRef
), f"expected boolean expression. Got: {z3expr}"
return z3expr
def add_source_expr(self, e: z3.BoolRef) -> None:
@ -557,17 +566,21 @@ try:
# Log the found model and the source expressions that failed.
model = solver.model()
raise ValidationException(
model, self._assertions, self._target_exprs,
model,
self._assertions,
self._target_exprs,
failed_source_exprs=[
inp for inp in self._source_exprs if not model.evaluate(inp)
]
],
)
else:
if r == z3.unknown:
# Could not find a solution. It didn't fail, but it also
# didn't succeed. Canceling the validation execution (keyboard
# interrupt) also gets to this branch.
log.warning("translation validation: could not validate: got z3.unknown")
log.warning(
"translation validation: could not validate: got z3.unknown"
)
else:
# Target expressions are sound.
assert r == z3.unsat
@ -577,21 +590,30 @@ except ImportError:
_HAS_Z3 = False
__all__ = [
"translation_validation_enabled", "translation_validation_timeout",
"ValidationException", "BisectValidationException",
"translation_validation_enabled",
"translation_validation_timeout",
"ValidationException",
"BisectValidationException",
]
else:
_HAS_Z3 = True
__all__ = [
"z3str", "z3op", "PopulateValidator", "SympyToZ3", "TranslationValidator",
"translation_validation_enabled", "translation_validation_timeout",
"ValidationException", "BisectValidationException",
"z3str",
"z3op",
"PopulateValidator",
"SympyToZ3",
"TranslationValidator",
"translation_validation_enabled",
"translation_validation_timeout",
"ValidationException",
"BisectValidationException",
]
from torch.fx.experimental import _config as config
def translation_validation_enabled() -> bool:
# Checks everytime this function is called, in case the Dynamo
# option is set, but Z3 is not installed.
@ -655,9 +677,11 @@ Failure occurred while running node:
def __str__(self):
return f"{self.msg}\n\n{self.details}"
# Checks when this module is loaded.
_assert_z3_installed_if_tv_set()
# Translation validation bisection.
#
# Bisect into the torch._assert nodes recorded in the shape_env FX graph, and raise
@ -667,8 +691,16 @@ _assert_z3_installed_if_tv_set()
# might be silently happening. This function tries to nail down exactly at which
# point things went wrong from a validation perspective.
def bisect(shape_env):
from torch.fx.experimental.symbolic_shapes import ShapeEnv, SHAPEENV_EVENT_KEY, CURRENT_NODE_KEY
from torch.fx.experimental.recording import FakeTensorMeta, ShapeEnvEvent, replay_shape_env_events
from torch.fx.experimental.recording import (
FakeTensorMeta,
replay_shape_env_events,
ShapeEnvEvent,
)
from torch.fx.experimental.symbolic_shapes import (
CURRENT_NODE_KEY,
ShapeEnv,
SHAPEENV_EVENT_KEY,
)
events = shape_env.events
@ -696,7 +728,9 @@ def bisect(shape_env):
)
# Checks whether the given shape_env fails when produce_guards is called.
def check_shapeenv_fails(shape_env: ShapeEnv, tracked_fakes: Optional[List[Any]]) -> Optional[ValidationException]:
def check_shapeenv_fails(
shape_env: ShapeEnv, tracked_fakes: Optional[List[Any]]
) -> Optional[ValidationException]:
assert tracked_fakes is not None
try:
# This produce_guards call is a best-effort replication, since we
@ -720,7 +754,9 @@ def bisect(shape_env):
shape_env.graph.lint()
return check_shapeenv_fails(shape_env, events[number].tracked_fakes)
last_exception = check_shapeenv_fails(shape_env, shape_env._snapshot_tracked_fakes())
last_exception = check_shapeenv_fails(
shape_env, shape_env._snapshot_tracked_fakes()
)
if not last_exception:
# We don't actually fail due to a produce_guards call.
@ -738,7 +774,9 @@ def bisect(shape_env):
# Bisection happens on the assertion nodes of the recorded FX graph for
# dynamic shapes.
assert_nodes = [node for node in shape_env.graph.nodes if node.target == torch._assert]
assert_nodes = [
node for node in shape_env.graph.nodes if node.target == torch._assert
]
# Preparing the indices for binary search.
left, mid, right = 0, 0, len(assert_nodes) - 1