mirror of
https://github.com/zebrajr/pytorch.git
synced 2025-12-07 12:21:27 +01:00
This PR introduces record and replay functionality for `ShapeEnv` instances. In short, throughout the execution of a program, we record events (e.g. function calls that modify its state) so that, in the future, we are able to reproduce any intermediate state of the instance. In summary, this PR introduces the following changes (they mostly belong to _symbolic_shapes.py_ unless otherwise stated): - Create `ShapeEnvEvent` class for recording function calls + arguments - Create `record_shapeenv_event` decorator and decorate every function that changes the state of a `ShapeEnv`: it creates an appropriate event and adds it to the available ShapeEnv instance (sometimes it has to extract that instance from `SymTypes`). - Create `SymNode.with_shape_env` convenience function for replacing `ShapeEnv` references - Wrap the `ShapeEnv` initialization method, so that we also save the exact way a `ShapeEnv` was constructed, i.e. its arguments - Introduce a way to compare two `ShapeEnv` instances, defining a concept of state for that class. In short, the state of `ShapeEnv` is every variable that may change the execution flow - Create `check_shape_env_recorded_events` dynamo configuration for enabling a check that the state of a `ShapeEnv` is equal to that of another instance constructed by replaying all the recorded events. This check takes place inside `produce_guards` - Create `replay_shape_env_events` function for replaying given events. It assumes the first event is the `ShapeEnv` initialization function Pull Request resolved: https://github.com/pytorch/pytorch/pull/107989 Approved by: https://github.com/ezyang
92 lines
2.3 KiB
Python
92 lines
2.3 KiB
Python
# Owner(s): ["module: dynamo"]
|
|
import unittest
|
|
import warnings
|
|
|
|
from torch._dynamo import config
|
|
from torch._dynamo.testing import make_test_cls_with_patches
|
|
from torch.testing._internal.common_utils import TEST_Z3
|
|
|
|
try:
|
|
from . import (
|
|
test_aot_autograd,
|
|
test_ctx_manager,
|
|
test_export,
|
|
test_functions,
|
|
test_higher_order_ops,
|
|
test_misc,
|
|
test_modules,
|
|
test_repros,
|
|
test_subgraphs,
|
|
)
|
|
except ImportError:
|
|
import test_aot_autograd
|
|
import test_ctx_manager
|
|
import test_export
|
|
import test_functions
|
|
import test_higher_order_ops
|
|
import test_misc
|
|
import test_modules
|
|
import test_repros
|
|
import test_subgraphs
|
|
|
|
|
|
test_classes = {}
|
|
|
|
|
|
def make_dynamic_cls(cls):
|
|
suffix = "_dynamic_shapes"
|
|
|
|
cls_prefix = "DynamicShapes"
|
|
|
|
test_class = make_test_cls_with_patches(
|
|
cls,
|
|
cls_prefix,
|
|
suffix,
|
|
(config, "assume_static_by_default", False),
|
|
(config, "specialize_int", False),
|
|
(config, "translation_validation", TEST_Z3),
|
|
(config, "check_shape_env_recorded_events", True),
|
|
xfail_prop="_expected_failure_dynamic",
|
|
)
|
|
|
|
test_classes[test_class.__name__] = test_class
|
|
# REMOVING THIS LINE WILL STOP TESTS FROM RUNNING
|
|
globals()[test_class.__name__] = test_class
|
|
return test_class
|
|
|
|
|
|
# Base test classes; each one gets a dynamic-shapes variant via make_dynamic_cls.
tests = [
    test_ctx_manager.CtxManagerTests,
    test_functions.FunctionTests,
    test_misc.MiscTests,
    test_repros.ReproTests,
    test_modules.NNModuleTests,
    test_export.ExportTests,
    test_subgraphs.SubGraphTests,
    test_higher_order_ops.HigherOrderOpTests,
    test_higher_order_ops.FuncTorchHigherOrderOpTests,
    test_aot_autograd.AotAutogradFallbackTests,
]
for base_cls in tests:
    make_dynamic_cls(base_cls)
# Unbind the loop variable. Otherwise the last base class stays bound at module
# level, where the test loader would likely also collect it as an extra
# (non-dynamic) test class.
del base_cls
|
|
|
|
def _mark_expected_failure(test_cls, test_name):
    """Replace ``test_cls.<test_name>`` with its ``unittest.expectedFailure`` form.

    The decorated function is assigned back explicitly, rather than calling
    ``expectedFailure`` for its side effect and discarding the result.
    """
    setattr(test_cls, test_name, unittest.expectedFailure(getattr(test_cls, test_name)))


if TEST_Z3:
    # This test only fails when Z3 is available.
    # SymPy is incorrectly transforming 's0 / 6 == 0.5' into 'False'.
    # Ref: https://github.com/sympy/sympy/issues/25146
    #
    # The class is fetched from ``test_classes`` instead of by its bare name.
    # ``DynamicShapesReproTests`` exists only because ``make_dynamic_cls``
    # writes it into ``globals()``, so linters report the bare name as
    # undefined (F821). No module-level alias is kept for the class: an extra
    # global bound to a test class could be collected by the loader again.
    _mark_expected_failure(
        test_classes["DynamicShapesReproTests"],
        "test_dynamic_shapes_float_guard_dynamic_shapes",
    )
|
|
|
|
if __name__ == "__main__":
    from torch._dynamo.test_case import run_tests

    # Without Z3 the generated classes run with translation_validation patched
    # to False, so tell whoever launched the suite that this coverage is missing.
    if not TEST_Z3:
        warnings.warn(
            "translation validation is off. "
            "Testing with translation validation requires Z3."
        )

    run_tests()
|