pytorch/test/dynamo/test_dynamic_shapes.py
Yukio Siraichi 12e8530b35 Record and replay for ShapeEnv. (#107989)
This PR introduces record and replay functionality for `ShapeEnv` instances. In short,
throughout the execution of a program, we record events (e.g. function calls that modify
its state) so that, in the future, we are able to reproduce any intermediate state of the
instance.

In summary, this PR introduces the following changes (they belong to
_symbolic_shapes.py_ unless otherwise stated):

- Create the `ShapeEnvEvent` class for recording function calls and their arguments
- Create the `record_shapeenv_event` decorator and apply it to every function that changes
  the state of a `ShapeEnv`: it creates an appropriate event and adds it to the available
  `ShapeEnv` instance (sometimes it has to extract the instance from `SymTypes`)
- Create the `SymNode.with_shape_env` convenience function for replacing `ShapeEnv` references
- Wrap the `ShapeEnv` initialization method so that we also save the exact way a `ShapeEnv`
  was constructed, i.e. its arguments
- Introduce a way to compare two `ShapeEnv` instances, defining a concept of state for
  that class. In short, the state of a `ShapeEnv` is every variable that may change the
  execution flow
- Create the `check_shape_env_recorded_events` dynamo configuration option, which enables
  checking that the state of a `ShapeEnv` equals the state of another one constructed by
  replaying all the recorded events. This check takes place inside `produce_guards`
- Create the `replay_shape_env_events` function for replaying given events. It assumes the
  first event is the `ShapeEnv` initialization function
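
The record-and-replay idea in the list above can be illustrated with a small, self-contained sketch. It does not use PyTorch, and every name in it (`Event`, `record_event`, `ToyEnv`, `replay_events`) is hypothetical; it only mirrors the roles described for `ShapeEnvEvent`, `record_shapeenv_event`, `ShapeEnv`, and `replay_shape_env_events`, and is not the actual implementation.

```python
import functools
from dataclasses import dataclass, field
from typing import Any, Callable, Dict, Tuple


@dataclass
class Event:
    """One recorded call: the function plus the arguments it received."""

    f: Callable[..., Any]
    args: Tuple[Any, ...] = ()
    kwargs: Dict[str, Any] = field(default_factory=dict)

    def run(self, env=None):
        # The first event constructs the environment; later ones mutate it.
        if env is None:
            return self.f(*self.args, **self.kwargs)
        self.f(env, *self.args, **self.kwargs)
        return env


def record_event(fn):
    """Decorator for state-changing methods: log the call, then perform it."""

    @functools.wraps(fn)
    def wrapper(self, *args, **kwargs):
        # Store the undecorated function so replaying does not record again.
        self.events.append(Event(fn, args, kwargs))
        return fn(self, *args, **kwargs)

    return wrapper


class ToyEnv:
    """Hypothetical stand-in for an environment whose state we want to replay."""

    def __init__(self, **kwargs):
        # Save exactly how the instance was constructed as the first event.
        self.events = [Event(ToyEnv, (), kwargs)]
        self.limit = kwargs.get("limit", 10)
        self.replacements = {}

    @record_event
    def set_replacement(self, name, value):
        self.replacements[name] = min(value, self.limit)

    def state(self):
        # "State" here means every variable that may change execution flow.
        return (self.limit, dict(self.replacements))


def replay_events(events):
    """Rebuild an environment; assumes the first event is the constructor."""
    env = events[0].run()
    for event in events[1:]:
        event.run(env)
    return env


env = ToyEnv(limit=5)
env.set_replacement("s0", 3)
env.set_replacement("s1", 9)
replayed = replay_events(env.events)
```

In this sketch, comparing `replayed.state()` with `env.state()` plays the role of the equality check that `check_shape_env_recorded_events` is described as enabling: if a state-changing method were left undecorated, the two states would diverge.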

Pull Request resolved: https://github.com/pytorch/pytorch/pull/107989
Approved by: https://github.com/ezyang
2023-09-13 00:22:38 +00:00


# Owner(s): ["module: dynamo"]
import unittest
import warnings
from torch._dynamo import config
from torch._dynamo.testing import make_test_cls_with_patches
from torch.testing._internal.common_utils import TEST_Z3
try:
    from . import (
        test_aot_autograd,
        test_ctx_manager,
        test_export,
        test_functions,
        test_higher_order_ops,
        test_misc,
        test_modules,
        test_repros,
        test_subgraphs,
    )
except ImportError:
    import test_aot_autograd
    import test_ctx_manager
    import test_export
    import test_functions
    import test_higher_order_ops
    import test_misc
    import test_modules
    import test_repros
    import test_subgraphs
test_classes = {}


def make_dynamic_cls(cls):
    suffix = "_dynamic_shapes"
    cls_prefix = "DynamicShapes"
    test_class = make_test_cls_with_patches(
        cls,
        cls_prefix,
        suffix,
        (config, "assume_static_by_default", False),
        (config, "specialize_int", False),
        (config, "translation_validation", TEST_Z3),
        (config, "check_shape_env_recorded_events", True),
        xfail_prop="_expected_failure_dynamic",
    )
    test_classes[test_class.__name__] = test_class
    # REMOVING THIS LINE WILL STOP TESTS FROM RUNNING
    globals()[test_class.__name__] = test_class
    return test_class
tests = [
    test_ctx_manager.CtxManagerTests,
    test_functions.FunctionTests,
    test_misc.MiscTests,
    test_repros.ReproTests,
    test_modules.NNModuleTests,
    test_export.ExportTests,
    test_subgraphs.SubGraphTests,
    test_higher_order_ops.HigherOrderOpTests,
    test_higher_order_ops.FuncTorchHigherOrderOpTests,
    test_aot_autograd.AotAutogradFallbackTests,
]
for test in tests:
    make_dynamic_cls(test)
del test
if TEST_Z3:
    # this only fails when z3 is available
    unittest.expectedFailure(
        # SymPy is incorrectly transforming 's0 / 6 == 0.5' into 'False'.
        # Ref: https://github.com/sympy/sympy/issues/25146
        DynamicShapesReproTests.test_dynamic_shapes_float_guard_dynamic_shapes
    )
if __name__ == "__main__":
    from torch._dynamo.test_case import run_tests

    if not TEST_Z3:
        warnings.warn(
            "translation validation is off. "
            "Testing with translation validation requires Z3."
        )
    run_tests()