mirror of
https://github.com/zebrajr/pytorch.git
synced 2025-12-07 12:21:27 +01:00
Record and replay for ShapeEnv. (#107989)
This PR introduces record and replay functionality for `ShapeEnv` instances. In short, throughout the execution of a program, we record events (e.g. function calls that modify its state) so that, in the future, we are able to reproduce any intermediary state of the instance. In summary, this PR introduces the following changes (they mostly belong to _symbolic_shapes.py_ unless otherwise stated): - Create `ShapeEnvEvent` class for recording function calls + arguments - Create `record_shapeenv_event` decorator and decorate every function that changes the state of a `ShapeEnv`: it creates an appropriate event and add it to the available ShapeEnv instance (sometimes it has to extract from `SymTypes`). - Create `SymNode.with_shape_env` convenient function for replacing `ShapeEnv` references - Wraps `ShapeEnv` initialization method: so that we also save the exact way a `ShapeEnv` was constructed, i.e. arguments - Introduces a way to compare two `ShapeEnv` instances, defining a concept of state for that class. In short, the state of `ShapeEnv` is every variable that may change the execution flow - Create `check_shape_env_recorded_events` dynamo configuration for enabling the check for equality the state of `ShapeEnv` with another one that was constructed by replaying all the recorded events. This check takes place inside `produce_guards` - Create `replay_shape_env_events` function for replaying given events. It assumes the first event is `ShapeEnv` initialization function Pull Request resolved: https://github.com/pytorch/pytorch/pull/107989 Approved by: https://github.com/ezyang
This commit is contained in:
parent
e066056414
commit
12e8530b35
|
|
@ -45,6 +45,7 @@ def make_dynamic_cls(cls):
|
|||
(config, "assume_static_by_default", False),
|
||||
(config, "specialize_int", False),
|
||||
(config, "translation_validation", TEST_Z3),
|
||||
(config, "check_shape_env_recorded_events", True),
|
||||
xfail_prop="_expected_failure_dynamic",
|
||||
)
|
||||
|
||||
|
|
|
|||
|
|
@ -5,6 +5,7 @@ import copy
|
|||
import dataclasses
|
||||
import dis
|
||||
import enum
|
||||
import functools
|
||||
import itertools
|
||||
import logging
|
||||
import math
|
||||
|
|
@ -32,7 +33,7 @@ from torch._C import FileCheck
|
|||
from torch._dynamo import allow_in_graph, bytecode_analysis, bytecode_transformation
|
||||
from torch._dynamo.eval_frame import _debug_get_cache_entry_list
|
||||
from torch._dynamo.exc import Unsupported
|
||||
from torch._dynamo.source import GetItemSource, LocalSource
|
||||
from torch._dynamo.source import ConstantSource, GetItemSource, LocalSource
|
||||
from torch._dynamo.testing import (
|
||||
CompileCounter,
|
||||
CompileCounterWithBackend,
|
||||
|
|
@ -47,7 +48,12 @@ from torch.ao.quantization import MinMaxObserver
|
|||
from torch.ao.quantization.fake_quantize import FakeQuantize
|
||||
from torch.ao.quantization.qconfig import QConfig
|
||||
from torch.ao.quantization.quantize_fx import prepare_qat_fx
|
||||
from torch.fx.experimental.symbolic_shapes import ConstraintViolationError
|
||||
from torch.fx.experimental.recording import NotEqualError, replay_shape_env_events
|
||||
from torch.fx.experimental.symbolic_shapes import (
|
||||
ConstraintViolationError,
|
||||
expect_true,
|
||||
ShapeEnv,
|
||||
)
|
||||
from torch.nn import functional as F
|
||||
from torch.testing._internal.common_cuda import (
|
||||
PLATFORM_SUPPORTS_FUSED_SDPA,
|
||||
|
|
@ -62,6 +68,19 @@ from torch.testing._internal.jit_utils import JitTestCase
|
|||
mytuple = collections.namedtuple("mytuple", ["a", "b", "ab"])
|
||||
|
||||
|
||||
# Specializes a test to run only if translation validation is set.
|
||||
def onlyIfTranslationValidation(fn: typing.Callable) -> typing.Callable:
|
||||
@functools.wraps(fn)
|
||||
def wrapper(*args, **kwargs):
|
||||
import torch.fx.experimental.validator
|
||||
|
||||
if torch.fx.experimental.validator.translation_validation_enabled():
|
||||
return fn(*args, **kwargs)
|
||||
raise unittest.SkipTest(f"only works when TV is True.")
|
||||
|
||||
return wrapper
|
||||
|
||||
|
||||
class MyPickledModule(torch.nn.Module):
|
||||
def __init__(self, z):
|
||||
super().__init__()
|
||||
|
|
@ -7074,6 +7093,241 @@ def ___make_guard_fn():
|
|||
self.assertEqual(list(eager), list(compiled))
|
||||
self.assertEqual(counter.frame_count, 1)
|
||||
|
||||
def _replay_and_check(self, shape_env: ShapeEnv):
|
||||
replayed = replay_shape_env_events(shape_env.events)
|
||||
shape_env.check_equal(replayed)
|
||||
|
||||
@onlyIfTranslationValidation
|
||||
def test_shape_env_equal_empty(self):
|
||||
main, other = ShapeEnv(), ShapeEnv()
|
||||
main.check_equal(other)
|
||||
self._replay_and_check(main)
|
||||
|
||||
@onlyIfTranslationValidation
|
||||
def test_shape_env_equal_constructor(self):
|
||||
main, other = ShapeEnv(allow_scalar_outputs=False), ShapeEnv()
|
||||
self.assertExpectedRaisesInline(
|
||||
NotEqualError,
|
||||
lambda: main.check_equal(other),
|
||||
"""\
|
||||
ShapeEnv not equal: field values don't match:
|
||||
|
||||
==> allow_scalar_outputs: values don't match.
|
||||
> Left: False
|
||||
> Right: True
|
||||
""",
|
||||
)
|
||||
self._replay_and_check(main)
|
||||
|
||||
@onlyIfTranslationValidation
|
||||
def test_shape_env_equal_create_symbolic_sizes_strides_storage_offset(self):
|
||||
main, other = ShapeEnv(), ShapeEnv()
|
||||
main.create_symbolic_sizes_strides_storage_offset(
|
||||
torch.randn(3, 2), ConstantSource("x")
|
||||
)
|
||||
self.assertExpectedRaisesInline(
|
||||
NotEqualError,
|
||||
lambda: main.check_equal(other),
|
||||
"""\
|
||||
ShapeEnv not equal: field values don't match:
|
||||
|
||||
==> name_to_node: values don't match.
|
||||
> Left: {x_size_0_, x_size_1_, x_storage_offset, x_stride_0_, x_stride_1_}
|
||||
> Right: {}
|
||||
==> source_to_symbol: values don't match.
|
||||
> Left: {x.size()[0]: x.size()[0], x.size()[1]: x.size()[1], x.storage_offset(): x.storage_offset(), x.stride()[0]: x.stride()[0], x.stride()[1]: x.stride()[1]}
|
||||
> Right: {}
|
||||
==> val_to_var: values don't match.
|
||||
> Left: {0: 0, 1: 1, 2: s1, 3: s0}
|
||||
> Right: {0: 0, 1: 1}
|
||||
==> var_to_range: values don't match.
|
||||
> Left: {s0: ValueRanges(lower=2, upper=9223372036854775806, is_bool=False), s1: ValueRanges(lower=2, upper=9223372036854775806, is_bool=False)}
|
||||
> Right: {}
|
||||
==> var_to_sources: values don't match.
|
||||
> Left: {s0: [TensorPropertySource(base=ConstantSource(source_name='x'), prop=<TensorProperty.SIZE: 0>, idx=0)], s1: [TensorPropertySource(base=ConstantSource(source_name='x'), prop=<TensorProperty.SIZE: 0>, idx=1)]}
|
||||
> Right: {}
|
||||
==> var_to_val: values don't match.
|
||||
> Left: {s0: 3, s1: 2}
|
||||
> Right: {}
|
||||
""",
|
||||
)
|
||||
self._replay_and_check(main)
|
||||
|
||||
@onlyIfTranslationValidation
|
||||
def test_shape_env_equal_unbacked(self):
|
||||
main, other = ShapeEnv(), ShapeEnv()
|
||||
main.create_unbacked_symint()
|
||||
main.create_unbacked_symfloat()
|
||||
main.create_unbacked_symbool()
|
||||
self.assertExpectedRaisesInline(
|
||||
NotEqualError,
|
||||
lambda: main.check_equal(other),
|
||||
"""\
|
||||
ShapeEnv not equal: field values don't match:
|
||||
|
||||
==> name_to_node: values don't match.
|
||||
> Left: {f0, i0, i1}
|
||||
> Right: {}
|
||||
==> unbacked_symfloat_counter: values don't match.
|
||||
> Left: 1
|
||||
> Right: 0
|
||||
==> unbacked_symint_counter: values don't match.
|
||||
> Left: 2
|
||||
> Right: 0
|
||||
==> var_to_range: values don't match.
|
||||
> Left: {f0: ValueRanges(lower=-oo, upper=oo, is_bool=False), i0: ValueRanges(lower=-9223372036854775808, upper=9223372036854775807, is_bool=False), i1: ValueRanges(lower=0, upper=1, is_bool=False)}
|
||||
> Right: {}
|
||||
""",
|
||||
)
|
||||
self._replay_and_check(main)
|
||||
|
||||
@onlyIfTranslationValidation
|
||||
def test_shape_env_equal_evaluate_expr_divisible(self):
|
||||
main, other = ShapeEnv(), ShapeEnv()
|
||||
|
||||
# Call create_symbolic_sizes_strides_storage_offset on both of them.
|
||||
r = main.create_symbolic_sizes_strides_storage_offset(
|
||||
torch.randn(3, 2), ConstantSource("x")
|
||||
)
|
||||
other.create_symbolic_sizes_strides_storage_offset(
|
||||
torch.randn(3, 2), ConstantSource("x")
|
||||
)
|
||||
|
||||
# Create a guard: size[0] % 3 == 0 (only in the main ShapeEnv)
|
||||
# - +1 guard entry
|
||||
# - +1 divisible entry
|
||||
size = r[0]
|
||||
bool(size[0] % 3 == 0)
|
||||
|
||||
self.assertExpectedRaisesInline(
|
||||
NotEqualError,
|
||||
lambda: main.check_equal(other),
|
||||
"""\
|
||||
ShapeEnv not equal: field values don't match:
|
||||
|
||||
==> divisible: values don't match.
|
||||
> Left: {Mod(s0, 3)}
|
||||
> Right: {}
|
||||
==> guards: values don't match.
|
||||
> Left: [Eq(Mod(s0, 3), 0)]
|
||||
> Right: []
|
||||
==> name_to_node: values don't match.
|
||||
> Left: {_assert, eq, mod, x_size_0_, x_size_1_, x_storage_offset, x_stride_0_, x_stride_1_}
|
||||
> Right: {x_size_0_, x_size_1_, x_storage_offset, x_stride_0_, x_stride_1_}
|
||||
""",
|
||||
)
|
||||
self._replay_and_check(main)
|
||||
|
||||
@onlyIfTranslationValidation
|
||||
def test_shape_env_equal_evaluate_expr_replacement(self):
|
||||
main, other = ShapeEnv(), ShapeEnv()
|
||||
|
||||
# Call create_symbolic_sizes_strides_storage_offset on both of them.
|
||||
r = main.create_symbolic_sizes_strides_storage_offset(
|
||||
torch.randn(3, 2), ConstantSource("x")
|
||||
)
|
||||
other.create_symbolic_sizes_strides_storage_offset(
|
||||
torch.randn(3, 2), ConstantSource("x")
|
||||
)
|
||||
|
||||
# Create a guard: size[0] == 3 (only in the main ShapeEnv)
|
||||
# - +1 guard entry
|
||||
# - +1 replacement entry
|
||||
size = r[0]
|
||||
bool(size[0] == 3)
|
||||
|
||||
self.assertExpectedRaisesInline(
|
||||
NotEqualError,
|
||||
lambda: main.check_equal(other),
|
||||
"""\
|
||||
ShapeEnv not equal: field values don't match:
|
||||
|
||||
==> guards: values don't match.
|
||||
> Left: [Eq(s0, 3)]
|
||||
> Right: []
|
||||
==> name_to_node: values don't match.
|
||||
> Left: {_assert, eq, x_size_0_, x_size_1_, x_storage_offset, x_stride_0_, x_stride_1_}
|
||||
> Right: {x_size_0_, x_size_1_, x_storage_offset, x_stride_0_, x_stride_1_}
|
||||
==> replacements: values don't match.
|
||||
> Left: {s0: 3}
|
||||
> Right: {}
|
||||
""",
|
||||
)
|
||||
self._replay_and_check(main)
|
||||
|
||||
@onlyIfTranslationValidation
|
||||
def test_shape_env_equal_evaluate_expr_refinement(self):
|
||||
main, other = ShapeEnv(), ShapeEnv()
|
||||
|
||||
# Call create_symbolic_sizes_strides_storage_offset on both of them.
|
||||
r = main.create_symbolic_sizes_strides_storage_offset(
|
||||
torch.randn(3, 2), ConstantSource("x")
|
||||
)
|
||||
other.create_symbolic_sizes_strides_storage_offset(
|
||||
torch.randn(3, 2), ConstantSource("x")
|
||||
)
|
||||
|
||||
# Create a guard: size[0] >= 3 (only in the main ShapeEnv)
|
||||
# - +1 guard entry
|
||||
# - +1 var_to_guard entry
|
||||
# - Change: var_to_range
|
||||
size = r[0]
|
||||
bool(size[0] >= 3)
|
||||
|
||||
self.assertExpectedRaisesInline(
|
||||
NotEqualError,
|
||||
lambda: main.check_equal(other),
|
||||
"""\
|
||||
ShapeEnv not equal: field values don't match:
|
||||
|
||||
==> guards: values don't match.
|
||||
> Left: [s0 >= 3]
|
||||
> Right: []
|
||||
==> name_to_node: values don't match.
|
||||
> Left: {_assert, ge, x_size_0_, x_size_1_, x_storage_offset, x_stride_0_, x_stride_1_}
|
||||
> Right: {x_size_0_, x_size_1_, x_storage_offset, x_stride_0_, x_stride_1_}
|
||||
==> var_to_guards: values don't match.
|
||||
> Left: {s0: (s0 >= 3, None)}
|
||||
> Right: {}
|
||||
==> var_to_range: values don't match.
|
||||
> Left: {s0: ValueRanges(lower=3, upper=9223372036854775806, is_bool=False), s1: ValueRanges(lower=2, upper=9223372036854775806, is_bool=False)}
|
||||
> Right: {s0: ValueRanges(lower=2, upper=9223372036854775806, is_bool=False), s1: ValueRanges(lower=2, upper=9223372036854775806, is_bool=False)}
|
||||
""",
|
||||
)
|
||||
self._replay_and_check(main)
|
||||
|
||||
@onlyIfTranslationValidation
|
||||
def test_shape_env_equal_runtime_assert(self):
|
||||
main, other = ShapeEnv(), ShapeEnv()
|
||||
|
||||
# Call create_unbacked_symint on both of them.
|
||||
r = main.create_unbacked_symint()
|
||||
other.create_unbacked_symint()
|
||||
|
||||
# Create a runtime assert: r % 3 == 0 (only in the main ShapeEnv)
|
||||
# - +1 defferred_runtime_asserts entry
|
||||
# - Change: num_defferred_runtime_asserts
|
||||
expect_true(r % 3 == 0)
|
||||
|
||||
self.assertExpectedRaisesInline(
|
||||
NotEqualError,
|
||||
lambda: main.check_equal(other),
|
||||
"""\
|
||||
ShapeEnv not equal: field values don't match:
|
||||
|
||||
==> deferred_runtime_asserts: values don't match.
|
||||
> Left: {i0: [Eq(Mod(i0, 3), 0)]}
|
||||
> Right: {}
|
||||
==> name_to_node: values don't match.
|
||||
> Left: {_assert, eq, i0, mod}
|
||||
> Right: {i0}
|
||||
==> num_deferred_runtime_asserts: values don't match.
|
||||
> Left: 1
|
||||
> Right: 0
|
||||
""",
|
||||
)
|
||||
self._replay_and_check(main)
|
||||
|
||||
|
||||
class TestTracer(JitTestCase):
|
||||
def test_jit_save(self):
|
||||
|
|
|
|||
|
|
@ -256,6 +256,22 @@ translation_validation = (
|
|||
translation_validation_timeout = int(
|
||||
os.environ.get("TORCHDYNAMO_TRANSLATION_VALIDATION_TIMEOUT", "600000")
|
||||
)
|
||||
# Disables bisection for translation validation.
|
||||
#
|
||||
# Translation validation bisection is enabled by default, if translation validation
|
||||
# is also enabled. This should help finding guard simplification issues. However,
|
||||
# since validation uses Z3 for bisecting, it might take a lot of time.
|
||||
#
|
||||
# Set this configuration option so as to avoid bisecting.
|
||||
translation_validation_no_bisect = (
|
||||
os.environ.get("TORCHDYNAMO_TRANSLATION_NO_BISECT", "0") == "1"
|
||||
)
|
||||
# Disables ShapeEnv event recording.
|
||||
# See: [Note: Recording ShapeEnv Events]
|
||||
dont_record_shape_env_events = False
|
||||
# Checks whether replaying ShapeEnv events on a freshly constructed one yields
|
||||
# the a ShapeEnv with the same state. This should be used only in testing.
|
||||
check_shape_env_recorded_events = False
|
||||
|
||||
# Trace through NumPy or graphbreak
|
||||
trace_numpy = True
|
||||
|
|
|
|||
|
|
@ -267,26 +267,37 @@ class OutputGraph(Checkpointable[OutputGraphState]):
|
|||
"co_firstlineno": f_code.co_firstlineno,
|
||||
}
|
||||
|
||||
# In export mode, we force the shape_env to strictly disallow any constraining
|
||||
# of the user marked dynamic dims
|
||||
fake_mode = torch._subclasses.FakeTensorMode(
|
||||
shape_env=ShapeEnv(
|
||||
allow_scalar_outputs=config.capture_scalar_outputs,
|
||||
allow_dynamic_output_shape_ops=config.capture_dynamic_output_shape_ops,
|
||||
co_fields=self.co_fields,
|
||||
),
|
||||
# TODO (tmanlaibaatar) Remove this once we always lift params and buffers
|
||||
allow_non_fake_inputs=True if self.export else False,
|
||||
)
|
||||
self.tracing_context: TracingContext = TracingContext(fake_mode)
|
||||
self.init_ambient_guards()
|
||||
|
||||
# tracked_fakes says where any tensor that was wrapped to fake came
|
||||
# from. It is similar to GraphArg, in that all GraphArgs will get
|
||||
# will get added to TrackedFakes, but TrackedFakes also contains
|
||||
# GraphArgs that got pruned, and things like Tensor attributes which
|
||||
# aren't explicit graph inputs. Used by shape guard
|
||||
self.tracked_fakes: List[TrackedFake] = []
|
||||
|
||||
shape_env = ShapeEnv(
|
||||
# Reference Cycle!
|
||||
# Share a reference to the list of TrackedFake.
|
||||
#
|
||||
# ShapeEnv needs this in order to be able to reproduce the call
|
||||
# to produce_guards at an arbitrary time point. That is because
|
||||
# TrackedFake instances may have its metadata changed throughout
|
||||
# the program execution.
|
||||
tracked_fakes=self.tracked_fakes,
|
||||
allow_scalar_outputs=config.capture_scalar_outputs,
|
||||
allow_dynamic_output_shape_ops=config.capture_dynamic_output_shape_ops,
|
||||
co_fields=self.co_fields,
|
||||
)
|
||||
|
||||
# In export mode, we force the shape_env to strictly disallow any constraining
|
||||
# of the user marked dynamic dims
|
||||
fake_mode = torch._subclasses.FakeTensorMode(
|
||||
shape_env=shape_env,
|
||||
# TODO (tmanlaibaatar) Remove this once we always lift params and buffers
|
||||
allow_non_fake_inputs=True if self.export else False,
|
||||
)
|
||||
self.tracing_context: TracingContext = TracingContext(fake_mode)
|
||||
self.init_ambient_guards()
|
||||
|
||||
# Map each tensor id to a list of sources. This is necessary because
|
||||
# tensor ids cannot be recovered from tracked fakes (in general).
|
||||
# We use this map to interpret (i.e., check for violations of) constraints,
|
||||
|
|
|
|||
444
torch/fx/experimental/recording.py
Normal file
444
torch/fx/experimental/recording.py
Normal file
|
|
@ -0,0 +1,444 @@
|
|||
import functools
|
||||
import itertools
|
||||
from dataclasses import dataclass
|
||||
from typing import Any, Callable, Dict, List, Optional, Tuple, Union
|
||||
|
||||
import torch
|
||||
import torch.utils._pytree as pytree
|
||||
|
||||
|
||||
__all__ = [
|
||||
"ShapeEnvEvent",
|
||||
"record_shapeenv_event",
|
||||
"replay_shape_env_events",
|
||||
"FakeTensorMeta",
|
||||
"shape_env_check_state_equal",
|
||||
"NotEqualError",
|
||||
]
|
||||
|
||||
# [Note: Recording ShapeEnv Events]
|
||||
# =================================
|
||||
#
|
||||
# What is a ShapeEnv event?
|
||||
# -------------------------
|
||||
# We consider a ShapeEnv event every function call (ShapeEnv method or
|
||||
# independent function) that modifies the state of the ShapeEnv instance.
|
||||
# Such calls are recorded alongside their positional and keyword arguments,
|
||||
# so that it may be replayed over a different ShapeEnv instance.
|
||||
#
|
||||
# See [Note: ShapeEnv State Equality] for what is considered the state
|
||||
# of a ShapeEnv instance.
|
||||
#
|
||||
# What is it for?
|
||||
# ---------------
|
||||
# ShapeEnv events recording is used for reconstructing the ShapeEnv in an
|
||||
# arbitrary state in time.
|
||||
#
|
||||
# Being able to arbitrarily replay events like so is useful, mainly for
|
||||
# translation validation bisection. i.e. if a ValidationException has been
|
||||
# raised, find the earliest point in time where the translation validation
|
||||
# fails.
|
||||
#
|
||||
# Besides that, it also allows us to inspect the given instance and,
|
||||
# for example, check the guards that would actually be issued at that point.
|
||||
#
|
||||
# What kind of arguments can be stored in an event?
|
||||
# -------------------------------------------------
|
||||
# There's no specific rule for what cannot be used as an argument.
|
||||
# That said, pay special attention to the following cases:
|
||||
#
|
||||
# 1. Tensor inputs: there are some tests that check whether the inputs
|
||||
# were garbage collected after execution. These will fail if there's
|
||||
# an event that is holding a reference to those inputs.
|
||||
#
|
||||
# 2. ShapeEnv arguments: if there is an argument of ShapeEnv type, that
|
||||
# will be automatically replaced by the new given ShapeEnv instance.
|
||||
#
|
||||
# 3. SymTypes arguments: they also hold references to ShapeEnv. So,
|
||||
# whenever we see them, we create a new instance, replacing the
|
||||
# ShapeEnv reference.
|
||||
#
|
||||
# 4. FX nodes: specifically, FX nodes from the FX graph for symbolic
|
||||
# shapes. That argument must be replaced when replaying the event at
|
||||
# ShapeEnvEvent.run, since it has to reference a node from the given
|
||||
# instance, and not from the recorded instance.
|
||||
|
||||
|
||||
# Event class for reconstructing ShapeEnv at arbitrary time.
|
||||
#
|
||||
# Represents a method call that mutates ShapeEnv in a way that affects the
|
||||
# issued guards, when ShapeEnv.produce_guards is called.
|
||||
@dataclass
|
||||
class ShapeEnvEvent:
|
||||
# ShapeEnv method.
|
||||
f: Callable
|
||||
|
||||
# Arguments and keyword arguments called with.
|
||||
args: Optional[List[Any]] = None
|
||||
kwargs: Optional[Dict[str, Any]] = None
|
||||
|
||||
# List of tracked_fakes at the time the method was called.
|
||||
tracked_fakes: Optional[List[Any]] = None
|
||||
|
||||
# Name of the captured event.
|
||||
# Used for special handling of particular methods.
|
||||
name: Optional[str] = None
|
||||
|
||||
# Replay itself, but using shape_env as self.
|
||||
def run(self, shape_env=None) -> Any:
|
||||
from torch.fx.experimental.symbolic_shapes import ShapeEnv, SymTypes
|
||||
|
||||
# Special handling for the constructor event.
|
||||
if self.f is ShapeEnv:
|
||||
assert shape_env is None and self.args is None and self.kwargs is not None
|
||||
return ShapeEnv(**self.kwargs)
|
||||
|
||||
assert shape_env is not None
|
||||
args = list(self.args or list())
|
||||
kwargs = dict(self.kwargs or dict())
|
||||
|
||||
# Replace any argument of type ShapeEnv by the given one.
|
||||
args, kwargs = pytree.tree_map_only(
|
||||
ShapeEnv, lambda _: shape_env, (args, kwargs)
|
||||
)
|
||||
|
||||
# Replace any argument of type SymTypes by a new instance,
|
||||
# replacing its ShapeEnv reference.
|
||||
args, kwargs = pytree.tree_map_only(
|
||||
SymTypes,
|
||||
lambda a: type(a)(a.node.with_shape_env(shape_env)),
|
||||
(args, kwargs),
|
||||
)
|
||||
|
||||
# Converts FX nodes using the mapping argument.
|
||||
def maybe_convert_node(x: Any) -> Any:
|
||||
if not isinstance(x, torch.fx.Node):
|
||||
# Don't do anything to x if it's not an FX node.
|
||||
return x
|
||||
# If, at some point, we created an FX node, it means that translation validation is on.
|
||||
# It also means we are building an FX graph for symbolic shapes at shape_env.graph, and
|
||||
# we are tracking node names at shape_env.name_to_node.
|
||||
assert hasattr(shape_env, "name_to_node")
|
||||
name_to_node = shape_env.name_to_node # type: ignore[attr-defined]
|
||||
assert x.name in name_to_node
|
||||
return name_to_node[x.name]
|
||||
|
||||
# Replaces the value of an specific argument by the result of fn.
|
||||
def replacearg(index: int, key: str, fn: Callable):
|
||||
if index < len(args):
|
||||
args[index] = fn(args[index])
|
||||
if key in kwargs:
|
||||
kwargs[key] = fn(kwargs[key])
|
||||
|
||||
if self.is_create_fx_call_function():
|
||||
# ShapeEnv.create_fx_call_function:
|
||||
# "args" parameter is a tuple of FX nodes from the FX graph of the old ShapeEnv.
|
||||
# They must be replaced, since a "call_function" FX node with this tuple as argument
|
||||
# will be added to the FX graph of the new shape_env.
|
||||
replacearg(
|
||||
index=2,
|
||||
key="args",
|
||||
fn=lambda args: tuple(maybe_convert_node(a) for a in args),
|
||||
)
|
||||
if self.is_evaluate_expr() or self.is_defer_runtime_assert():
|
||||
# ShapeEnv.evaluate_expr and ShapeEnv.defer_runtime_assert:
|
||||
# "fx_node" parameter is an (optional) FX node that represents the evaluate expression.
|
||||
# They must be replaced, since it will be part of a "call_function" FX node for
|
||||
# torch._assert, which will be added to the FX graph of the new shape_env.
|
||||
replacearg(index=3, key="fx_node", fn=maybe_convert_node)
|
||||
|
||||
# Actually call the method with the converted arguments.
|
||||
return self.f(*args, **kwargs)
|
||||
|
||||
def __str__(self) -> str:
|
||||
name = self.name if self.name is not None else self.f.__name__
|
||||
return f"event: {name} ({self.args}, {self.kwargs})"
|
||||
|
||||
def is_create_fx_call_function(self) -> bool:
|
||||
return self.name == "create_fx_call_function"
|
||||
|
||||
def is_evaluate_expr(self) -> bool:
|
||||
return self.name == "evaluate_expr"
|
||||
|
||||
def is_defer_runtime_assert(self) -> bool:
|
||||
return self.name == "defer_runtime_assert"
|
||||
|
||||
|
||||
# Extracts a ShapeEnv instance inside args and kwargs.
|
||||
# Specifically, it looks for:
|
||||
# 1. ShapeEnv arguments
|
||||
# 2. SymInt, SymFloat, or SymBool arguments
|
||||
# If we find more than one object of any of the above types, we
|
||||
# also check that the ShapeEnv instance is the same for all of them.
|
||||
def _extract_shape_env_and_assert_equal(args, kwargs):
|
||||
from torch.fx.experimental.symbolic_shapes import ShapeEnv, SymTypes
|
||||
|
||||
def assert_equal(old: Optional[ShapeEnv], new: ShapeEnv) -> ShapeEnv:
|
||||
if old is not None:
|
||||
assert old is new, "call with different ShapeEnv"
|
||||
return new
|
||||
|
||||
shape_env = None
|
||||
for val in itertools.chain(args, kwargs.values()):
|
||||
if isinstance(val, ShapeEnv):
|
||||
shape_env = assert_equal(shape_env, val)
|
||||
if isinstance(val, SymTypes):
|
||||
shape_env = assert_equal(shape_env, val.node.shape_env)
|
||||
|
||||
assert shape_env is not None, "ShapeEnv not found"
|
||||
return shape_env
|
||||
|
||||
|
||||
# Decorator for recording the given function as a replayable event.
|
||||
#
|
||||
# This decorator should be used at every function that mutates the state of
|
||||
# ShapeEnv in some way that affects the resulting issued guards (i.e. when
|
||||
# ShapeEnv.produce_guards is called).
|
||||
#
|
||||
# save_tracked_fakes: saves a snapshot of the TrackedFake list.
|
||||
# This is used when calling ShapeEnv.produce_guards at arbitrary points in time.
|
||||
#
|
||||
# When to save the list of TrackedFake?
|
||||
# =====================================
|
||||
# We should save the list of TrackedFake whenever the translation validation
|
||||
# bisection may actually stop and call the produce_guards method at the moment
|
||||
# right after the recorded function was played. In other words, since the
|
||||
# bisection bisects through torch._assert calls, we should save in all methods
|
||||
# that adds a torch._assert call to the symbolic shapes FX graph.
|
||||
#
|
||||
# At the moment, there are 2 methods that save the list:
|
||||
# - ShapeEnv.evaluate_expr
|
||||
# - ShapeEnv.defer_runtime_assert
|
||||
def record_shapeenv_event(*, save_tracked_fakes: bool = False) -> Callable:
|
||||
def decorator(fn: Callable) -> Callable:
|
||||
assert callable(fn)
|
||||
name = fn.__name__
|
||||
|
||||
@functools.wraps(fn)
|
||||
def wrapper(*args, **kwargs):
|
||||
from torch.fx.experimental.symbolic_shapes import ShapeEnv
|
||||
|
||||
if isinstance(args[0], ShapeEnv) and args[0].is_recording: # type: ignore[has-type]
|
||||
# If ShapeEnv is already recording an event, call the wrapped
|
||||
# function directly.
|
||||
#
|
||||
# NB: here, we skip the check of whether all ShapeEnv instances
|
||||
# are equal, in favor of a faster dispatch.
|
||||
return fn(*args, **kwargs)
|
||||
|
||||
# Retrieve an instance of ShapeEnv.
|
||||
# Assumption: the collection of args and kwargs may not reference
|
||||
# different ShapeEnv instances.
|
||||
self = _extract_shape_env_and_assert_equal(args, kwargs)
|
||||
|
||||
# Otherwise, start recording and call the function.
|
||||
with self.recording():
|
||||
# Take a snapshot of the current tracked_fakes.
|
||||
tracked_fakes = (
|
||||
self.snapshot_tracked_fakes() if save_tracked_fakes else None
|
||||
)
|
||||
# Record the event for 'fn'.
|
||||
event = ShapeEnvEvent(
|
||||
fn, list(args), kwargs, tracked_fakes, name=fn.__name__
|
||||
)
|
||||
self.events.append(event)
|
||||
# Play the event on this ShapeEnv.
|
||||
return event.run(self)
|
||||
|
||||
return wrapper
|
||||
|
||||
return decorator
|
||||
|
||||
|
||||
# Replays the ShapeEnvEvents list.
|
||||
# It assumes the first event is the constructor call.
|
||||
#
|
||||
# fn: transforms an old FX node into one corresponding to the newly created ShapeEnv.
|
||||
def replay_shape_env_events(events):
|
||||
from torch.fx.experimental.symbolic_shapes import ShapeEnv
|
||||
|
||||
constructor_event = events[0]
|
||||
assert constructor_event.f == ShapeEnv
|
||||
|
||||
# Constructs the new ShapeEnv.
|
||||
shape_env = constructor_event.run()
|
||||
|
||||
for event in events[1:]:
|
||||
try:
|
||||
# Actually replays each event.
|
||||
# We need to call create_mapping_fn every time, since the node list might
|
||||
# change after each event is replayed.
|
||||
event.run(shape_env)
|
||||
except Exception as e:
|
||||
raise RuntimeError(f"failed when running event: {event}") from e
|
||||
|
||||
return shape_env
|
||||
|
||||
|
||||
# FakeTensor metadata.
|
||||
# This is to be used in place of FakeTensor placeholders when calling
|
||||
# ShapeEnv.produce_guards.
|
||||
@dataclass
|
||||
class FakeTensorMeta:
|
||||
tensor_size: Tuple[Union[int, torch.SymInt], ...]
|
||||
tensor_stride: Tuple[Union[int, torch.SymInt], ...]
|
||||
tensor_storage_offset: Union[int, torch.SymInt]
|
||||
|
||||
def size(self) -> Tuple[Union[int, torch.SymInt], ...]:
|
||||
return self.tensor_size
|
||||
|
||||
def stride(self) -> Tuple[Union[int, torch.SymInt], ...]:
|
||||
return self.tensor_stride
|
||||
|
||||
def storage_offset(self) -> Union[int, torch.SymInt]:
|
||||
return self.tensor_storage_offset
|
||||
|
||||
def dim(self) -> int:
|
||||
return len(self.tensor_size)
|
||||
|
||||
@staticmethod
|
||||
def from_fake(fake) -> "FakeTensorMeta":
|
||||
return FakeTensorMeta(fake.size(), fake.stride(), fake.storage_offset())
|
||||
|
||||
|
||||
# [Note: ShapeEnv State Equality]
|
||||
# ===============================
|
||||
#
|
||||
# What is considered ShapeEnv state?
|
||||
# ----------------------------------
|
||||
# We consider to be the state of a ShapeEnv instance everything that
|
||||
# is not in the inline tuple inside remove_nonstate_variables function.
|
||||
# That is: the fields within ShapeEnv that modify the flow of execution
|
||||
# of the program.
|
||||
#
|
||||
# So, for example: the replacements field might influence on how an
|
||||
# expression is simplified. That, in turn, may result in a guard being
|
||||
# statically known (i.e. not added).
|
||||
#
|
||||
# On the other hand, var_to_stack serves only changes what is printed
|
||||
# in the screen, i.e. used only for debugging purposes. Therefore, we
|
||||
# should not consider it when comparing states.
|
||||
#
|
||||
# What to do on NotEqualError?
|
||||
# ----------------------------
|
||||
# Here are a few possible causes for getting a NotEqualError raised:
|
||||
#
|
||||
# 1. New field that does not belong in the ShapeEnv state.
|
||||
# For example: log field of type ShapeEnvLoggerAdapter. Different
|
||||
# ShapeEnv instances will always have different ShapeEnvLoggerAdapter
|
||||
# instances, i.e. equality comparison would fail.
|
||||
# Solution: add it to the inlined tuple inside remove_nonstate_variables
|
||||
# function inside check_equal method.
|
||||
#
|
||||
# 2. New field that is not directly comparable across instances.
|
||||
# For example: guards field of type List[ShapeGuard]. More specifically,
|
||||
# the ShapeGuard type holds an expression and a stack information
|
||||
# for debugging purposes. When replaying the even on a new ShapeEnv
|
||||
# instance, the stack would be different, which would trigger this error.
|
||||
# Solution: add a special case to the map_value function inside
|
||||
# check_equal function.
|
||||
#
|
||||
# 3. Mutation of ShapeEnv on some not recorded function.
|
||||
# If a mutation of the state of ShapeEnv happens inside a function
|
||||
# that is not recorded (or that no caller in the stack is recorded),
|
||||
# then, the replayed ShapeEnv won't catch that.
|
||||
# Solution: decorate the function with record_shape_env_event.
|
||||
|
||||
|
||||
# Checks whether the state of two ShapeEnv are equal w.r.t. the guards
|
||||
# returned by ShapeEnv.produce_guards.
|
||||
def shape_env_check_state_equal(env1, env2, non_state_variable_names, map_value):
|
||||
# Collect and remove variables that don't necessarily represent the state
|
||||
# of a ShapeEnv. Note: we copy the dictionary so that we don't modify the
|
||||
# instance itself.
|
||||
env1_vars = vars(env1).copy()
|
||||
env2_vars = vars(env2).copy()
|
||||
|
||||
for v in non_state_variable_names:
|
||||
env1_vars.pop(v)
|
||||
env2_vars.pop(v)
|
||||
|
||||
# Function for transforming the mismatched values into string.
|
||||
# Needed, since dict and set entries order might not be the same every time.
|
||||
def value_to_str(value: Any) -> str:
|
||||
if isinstance(value, dict):
|
||||
return (
|
||||
"{"
|
||||
+ ", ".join(f"{k}: {value[k]}" for k in sorted(value.keys(), key=str))
|
||||
+ "}"
|
||||
)
|
||||
if isinstance(value, set):
|
||||
return "{" + ", ".join(f"{v}" for v in sorted(value)) + "}"
|
||||
return str(value)
|
||||
|
||||
# Compares env1_vars with env2_vars.
|
||||
# Here, we allow the value of each field to be mapped, so that we appropriately
|
||||
# compare the two values.
|
||||
def compare_vars(
|
||||
map_value: Callable[[str, Any], Any]
|
||||
) -> List[Tuple[str, str, str]]:
|
||||
env1_set, env2_set = set(env1_vars), set(env2_vars)
|
||||
|
||||
# First, compare the set of keys in each vars dictionary.
|
||||
if env1_set != env2_set:
|
||||
raise NotEqualError(
|
||||
"field set mismatch:",
|
||||
[
|
||||
(
|
||||
"found unique fields:",
|
||||
str(sorted(env1_set - env2_set)),
|
||||
str(sorted(env2_set - env1_set)),
|
||||
),
|
||||
],
|
||||
)
|
||||
|
||||
# Then, sort the keys, and compare the mapped values of each key.
|
||||
sorted_keys = list(env1_set)
|
||||
sorted_keys.sort()
|
||||
|
||||
mapped_dict = [
|
||||
(k, map_value(k, env1_vars[k]), map_value(k, env2_vars[k]))
|
||||
for k in sorted_keys
|
||||
]
|
||||
|
||||
# Return a list of tuples representing the fields that did not match
|
||||
# alongside their respective mapped values.
|
||||
return [
|
||||
(f"{k}: values don't match.", value_to_str(val1), value_to_str(val2))
|
||||
for k, val1, val2 in mapped_dict
|
||||
if val1 != val2
|
||||
]
|
||||
|
||||
# Accumulate the mismatching fields.
|
||||
errors = compare_vars(map_value)
|
||||
|
||||
if len(errors) > 0:
|
||||
raise NotEqualError("field values don't match:", errors)
|
||||
|
||||
|
||||
class NotEqualError(Exception):
|
||||
def __init__(
|
||||
self,
|
||||
msg: str,
|
||||
mismatched: List[Tuple[str, str, str]],
|
||||
) -> None:
|
||||
details = "\n".join(
|
||||
[
|
||||
"\n".join(
|
||||
[
|
||||
f"==> {inner_msg}",
|
||||
f" > Left: {str1}",
|
||||
f" > Right: {str2}",
|
||||
]
|
||||
)
|
||||
for inner_msg, str1, str2 in mismatched
|
||||
]
|
||||
)
|
||||
|
||||
super().__init__(
|
||||
f"""\
|
||||
ShapeEnv not equal: {msg}
|
||||
|
||||
{details}
|
||||
"""
|
||||
)
|
||||
|
|
@ -15,12 +15,20 @@ from contextlib import contextmanager
|
|||
from dataclasses import dataclass, field
|
||||
from enum import Enum
|
||||
from functools import lru_cache
|
||||
from typing import Any, cast, Callable, Dict, List, Optional, Set, Tuple, Type, Union
|
||||
from typing import Any, cast, Callable, Dict, List, Optional, Sequence, Set, Tuple, Type, Union
|
||||
|
||||
import torch
|
||||
import torch.fx
|
||||
import torch.fx.traceback as fx_traceback
|
||||
|
||||
from torch.fx.experimental.recording import (
|
||||
FakeTensorMeta,
|
||||
ShapeEnvEvent,
|
||||
record_shapeenv_event,
|
||||
replay_shape_env_events,
|
||||
shape_env_check_state_equal
|
||||
)
|
||||
|
||||
# NB: The sym_* functions are used via getattr() and must be imported here.
|
||||
from torch import ( # noqa: F401
|
||||
sym_float,
|
||||
|
|
@ -304,6 +312,8 @@ def guard_scalar(a):
|
|||
else:
|
||||
raise AssertionError(f"unrecognized scalar {a}")
|
||||
|
||||
|
||||
@record_shapeenv_event()
|
||||
def _constrain_symbol_range(shape_env, s: sympy.Symbol, compiler_min: int, compiler_max: int, runtime_min: int, runtime_max: int):
|
||||
if r := shape_env.var_to_range.get(s, None):
|
||||
shape_env.var_to_range[s] = ValueRanges(
|
||||
|
|
@ -349,6 +359,7 @@ def _advise_is_size(a):
|
|||
if isinstance(a, SymInt) and isinstance(a.node.expr, sympy.Symbol):
|
||||
_constrain_range_for_size(a)
|
||||
|
||||
@record_shapeenv_event()
|
||||
def _constrain_range_for_size(a, min: Optional[int] = None, max: Optional[int] = None):
|
||||
"""
|
||||
This function is NOT INTENDED to be used by itself.
|
||||
|
|
@ -387,6 +398,7 @@ def _constrain_range_for_size(a, min: Optional[int] = None, max: Optional[int] =
|
|||
|
||||
|
||||
# inclusive both ways
|
||||
@record_shapeenv_event()
|
||||
def constrain_range(a, *, min: Optional[int], max: Optional[int] = None):
|
||||
"""
|
||||
Applies a constraint that the passed in SymInt must lie between min-max
|
||||
|
|
@ -456,6 +468,7 @@ def constrain_range(a, *, min: Optional[int], max: Optional[int] = None):
|
|||
)
|
||||
|
||||
|
||||
@record_shapeenv_event()
|
||||
def constrain_unify(a, b):
|
||||
"""
|
||||
Given two SymInts, constrain them so that they must be equal. NB:
|
||||
|
|
@ -772,6 +785,9 @@ class SymNode:
|
|||
# the translation validation problem.
|
||||
self.fx_node = fx_node if _translation_validation_enabled() else None
|
||||
|
||||
def with_shape_env(self, shape_env: "ShapeEnv") -> "SymNode":
|
||||
return SymNode(self._expr, shape_env, self.pytype, self._hint, self.constant, self.fx_node)
|
||||
|
||||
@property
|
||||
def expr(self):
|
||||
return self.shape_env.replace(self._expr)
|
||||
|
|
@ -2030,7 +2046,56 @@ TLS = threading.local()
|
|||
|
||||
|
||||
class ShapeEnv:
|
||||
# This is a wrapper over the actual __init__ function.
|
||||
#
|
||||
# Where to add a new constructor parameter to ShapeEnv?
|
||||
# =====================================================
|
||||
# This __init__ function should be used only for parameters related to event recording.
|
||||
# These are parameters that we don't wish to pass down the road to new ShapeEnv instances
|
||||
# created from replaying events.
|
||||
#
|
||||
# If you wish to add a parameter to the constructor of ShapeEnv, unrelated to event
|
||||
# recording, do so in the _init function.
|
||||
def __init__(
|
||||
self, *,
|
||||
should_record_events: Optional[bool] = None,
|
||||
tracked_fakes: Optional[List[Any]] = None,
|
||||
**kwargs
|
||||
) -> None:
|
||||
self._init(**kwargs)
|
||||
|
||||
# Disable event recording when replaying.
|
||||
kwargs["should_record_events"] = False
|
||||
|
||||
# If not specified, enable event recording if both:
|
||||
# - Translation validation is on
|
||||
# - Translation validation bisection is not disabled
|
||||
self.should_record_events = (
|
||||
should_record_events
|
||||
if should_record_events is not None
|
||||
else (
|
||||
_translation_validation_enabled()
|
||||
and not torch._dynamo.config.translation_validation_no_bisect
|
||||
)
|
||||
)
|
||||
|
||||
# Enable event recording check if both:
|
||||
# - It should record events
|
||||
# - The recording check is enabled
|
||||
self.check_recorded_events = (
|
||||
self.should_record_events and torch._dynamo.config.check_shape_env_recorded_events
|
||||
)
|
||||
|
||||
# This will make sure we only record the top-level function call.
|
||||
self.is_recording = not self.should_record_events
|
||||
# Keep track of the list of tracked fakes.
|
||||
self.tracked_fakes = tracked_fakes
|
||||
# List of events for reconstructing ShapeEnv at arbitrary points in time.
|
||||
self.events: List[ShapeEnvEvent] = (
|
||||
[ShapeEnvEvent(ShapeEnv, kwargs=kwargs)] if self.should_record_events else []
|
||||
)
|
||||
|
||||
def _init(
|
||||
self, *,
|
||||
allow_scalar_outputs=True,
|
||||
allow_dynamic_output_shape_ops=True,
|
||||
|
|
@ -2153,6 +2218,110 @@ class ShapeEnv:
|
|||
# This is needed when 'deepcopy'-ing this object.
|
||||
self.graph.inserting_before(self.graph.output(None))
|
||||
|
||||
# Mapping of each node name to the node itself.
|
||||
#
|
||||
# This is useful for matching an FX node from a recorded ShapeEnv.graph
|
||||
# to the FX node of the ShapeEnv we are running the event on.
|
||||
#
|
||||
# Whenever you add a node to self.graph, you must add a mapping to this
|
||||
# variable. Otherwise, the built FX graph on the replayed ShapeEnv will
|
||||
# not be valid.
|
||||
self.name_to_node: Dict[str, torch.fx.Node] = {}
|
||||
|
||||
def check_equal(self, other: "ShapeEnv") -> None:
|
||||
# ShapeEnv fields that are not relevant for the outcome of
|
||||
# ShapeEnv.produce_guards call:
|
||||
# - Debugging variables
|
||||
# - Translation validation related variables
|
||||
# - Events recording related variables
|
||||
non_state_variable_names = (
|
||||
"counter",
|
||||
"log",
|
||||
"var_to_stack",
|
||||
"fx_node_cache",
|
||||
"graph",
|
||||
"validator",
|
||||
"check_recorded_events",
|
||||
"should_record_events",
|
||||
"is_recording",
|
||||
"tracked_fakes",
|
||||
"events",
|
||||
)
|
||||
|
||||
# Mapping of the value of each to-be-compared field into the values that
|
||||
# should actually be compared.
|
||||
#
|
||||
# You should modify this if, for example, the field that holds state and
|
||||
# debugging information. e.g. ShapeGuard holds the actual guard (sympy.Expr)
|
||||
# and the stack when it was added to the set of guards. In order to compare
|
||||
# it, we throw away the stack information.
|
||||
def map_value(key: str, value: Any) -> Any:
|
||||
if key in ("unbacked_symfloat_counter", "unbacked_symint_counter"):
|
||||
from copy import copy
|
||||
|
||||
# For itertools.count(), we compare the next integer returned
|
||||
# by the count iterators. Not that we need to copy the iterator
|
||||
# first. Otherwise we are mutating the object.
|
||||
return next(copy(value))
|
||||
elif key == "guards":
|
||||
# Transform the list of ShapeGuard into a list of expressions.
|
||||
return [g.expr for g in value]
|
||||
elif key == "var_to_guards":
|
||||
# Transform the tuple of optional ShapeGuards of each entry into
|
||||
# a tuple of optional expressions.
|
||||
return {
|
||||
s: (
|
||||
lb.expr if lb is not None else None,
|
||||
ub.expr if ub is not None else None,
|
||||
)
|
||||
for s, (lb, ub) in value.items()
|
||||
}
|
||||
elif key == "deferred_runtime_asserts":
|
||||
# Transform the list of RuntimeAsserts into a list of expressions.
|
||||
return {s: [ra.expr for ra in ras] for s, ras in value.items()}
|
||||
elif key == "name_to_node":
|
||||
# Compare just the set of keys is the same.
|
||||
return set(value.keys())
|
||||
return value
|
||||
|
||||
shape_env_check_state_equal(self, other, non_state_variable_names, map_value)
|
||||
|
||||
def snapshot_tracked_fakes(self) -> Optional[List[Any]]:
|
||||
if self.tracked_fakes is None:
|
||||
return None
|
||||
|
||||
from torch._dynamo.variables.builder import TrackedFake
|
||||
|
||||
def maybe_transform_fake(fake: TrackedFake):
|
||||
inner_fake = fake.fake \
|
||||
if isinstance(fake.fake, torch.SymInt) \
|
||||
else FakeTensorMeta.from_fake(fake.fake)
|
||||
# Even though TrackedFake accepts either a Union[SymInt, FakeTensor], here we give it a
|
||||
# FakeTensorMeta for two reasons:
|
||||
# 1. this is all the information we need when recording ShapeEnvEvents.
|
||||
# 2. it works even if each TrackedFake changes its metadata.
|
||||
return TrackedFake(inner_fake, fake.source, fake.constraint_dims) # type: ignore[arg-type]
|
||||
|
||||
return [maybe_transform_fake(fake) for fake in self.tracked_fakes]
|
||||
|
||||
def inc_tracked_fakes_length(self) -> None:
|
||||
self.tracked_fakes_length += 1
|
||||
|
||||
def set_tracked_fakes_length(self, i: int) -> None:
|
||||
self.tracked_fakes_length = i
|
||||
|
||||
def last_event_index(self) -> int:
|
||||
return len(self.events) - 1
|
||||
|
||||
@contextmanager
|
||||
def recording(self):
|
||||
self.is_recording = True
|
||||
try:
|
||||
yield
|
||||
finally:
|
||||
self.is_recording = False
|
||||
|
||||
@record_shapeenv_event()
|
||||
def freeze(self):
|
||||
self.frozen = True
|
||||
|
||||
|
|
@ -2180,6 +2349,7 @@ class ShapeEnv:
|
|||
if _translation_validation_enabled():
|
||||
self.validator.validate()
|
||||
|
||||
@record_shapeenv_event()
|
||||
def create_fx_call_function(
|
||||
self,
|
||||
op: Callable,
|
||||
|
|
@ -2201,12 +2371,13 @@ class ShapeEnv:
|
|||
return None, fresh
|
||||
|
||||
fresh = True
|
||||
lifted_op = z3op(op, self.validator)
|
||||
|
||||
# If translation validation is enabled, all arguments must have its
|
||||
# own FX node.
|
||||
assert all(a is not None for a in args), f"missing arg in FX graph ({op.__name__}): {args}"
|
||||
lifted_op = z3op(op, self.validator)
|
||||
self.fx_node_cache[node_key] = self.graph.call_function(lifted_op, args)
|
||||
node = self.fx_node_cache[node_key] = self.graph.call_function(lifted_op, args)
|
||||
self.name_to_node[node.name] = node
|
||||
|
||||
return self.fx_node_cache.get(node_key, None), fresh
|
||||
|
||||
|
|
@ -2228,29 +2399,37 @@ class ShapeEnv:
|
|||
self._add_z3var(symbol, type)
|
||||
# Create the FX placeholder out of a mangled name.
|
||||
mangled_name = re.sub(r'[^a-zA-Z0-9]', '_', re.sub(r'[()]', '', symbol.name))
|
||||
node = self.graph.placeholder(mangled_name)
|
||||
node = self.fx_node_cache[node_key] = self.graph.placeholder(mangled_name)
|
||||
self.name_to_node[node.name] = node
|
||||
# Attach the 'symbol' to the placeholder so that we can retrieve
|
||||
# the Z3 variable later.
|
||||
node.meta["symbol"] = symbol
|
||||
# Put it in the cache.
|
||||
self.fx_node_cache[node_key] = node
|
||||
|
||||
return self.fx_node_cache[node_key]
|
||||
|
||||
def remove_fx_node(self, node: Optional[torch.fx.Node]) -> None:
|
||||
if _translation_validation_enabled() and node is not None:
|
||||
self.name_to_node.pop(node.name)
|
||||
self.graph.erase_node(node)
|
||||
|
||||
def _suppress_guards_tls(self):
|
||||
return getattr(TLS, "suppress_guards", False)
|
||||
|
||||
@record_shapeenv_event()
|
||||
def suppress_guards_enter(self):
|
||||
TLS.suppress_guards = True
|
||||
|
||||
@record_shapeenv_event()
|
||||
def suppress_guards_exit(self):
|
||||
TLS.suppress_guards = False
|
||||
|
||||
@contextmanager
|
||||
def suppress_guards(self):
|
||||
TLS.suppress_guards = True
|
||||
self.suppress_guards_enter()
|
||||
try:
|
||||
yield
|
||||
finally:
|
||||
TLS.suppress_guards = False
|
||||
self.suppress_guards_exit()
|
||||
|
||||
def _get_key(self):
|
||||
"""
|
||||
|
|
@ -2260,7 +2439,7 @@ class ShapeEnv:
|
|||
return (len(self.replacements), len(self.divisible), self.num_deferred_runtime_asserts)
|
||||
|
||||
def _produce_dyn_sizes(self,
|
||||
ex: torch.Tensor,
|
||||
ex_size: Sequence[int],
|
||||
source: Source,
|
||||
dynamic_dims: DimList[DimDynamic],
|
||||
constraint_dims: DimList[DimConstraint]) -> List[sympy.Expr]:
|
||||
|
|
@ -2295,7 +2474,6 @@ class ShapeEnv:
|
|||
introduce new symbolic variables.
|
||||
"""
|
||||
|
||||
|
||||
# Dynamo may want to wrap FakeTensors with SymInt sizes up e.g. make_fx(opt_f(), tracing_mode="symbolic").
|
||||
# We create symbols in shape_env using the backed hints behind SymInt.
|
||||
|
||||
|
|
@ -2341,7 +2519,30 @@ class ShapeEnv:
|
|||
ex_size = tuple(maybe_specialize_sym_int_with_hint(sz) for sz in ex.size())
|
||||
ex_stride = tuple(maybe_specialize_sym_int_with_hint(sd) for sd in ex.stride())
|
||||
ex_storage_offset = maybe_specialize_sym_int_with_hint(ex.storage_offset())
|
||||
dim = ex.dim()
|
||||
|
||||
return self._create_symbolic_sizes_strides_storage_offset(
|
||||
ex_size,
|
||||
ex_stride,
|
||||
ex_storage_offset,
|
||||
[_is_dim_dynamic(ex, i) for i in range(ex.dim())],
|
||||
source,
|
||||
dynamic_dims=dynamic_dims,
|
||||
constraint_dims=constraint_dims
|
||||
)
|
||||
|
||||
@record_shapeenv_event()
|
||||
def _create_symbolic_sizes_strides_storage_offset(
|
||||
self,
|
||||
ex_size: Sequence[int],
|
||||
ex_stride: Sequence[int],
|
||||
ex_storage_offset: int,
|
||||
is_dim_dynamic: Sequence[bool],
|
||||
source: Source,
|
||||
*,
|
||||
dynamic_dims: Optional[DimList[DimDynamic]] = None,
|
||||
constraint_dims: Optional[DimList[DimConstraint]] = None,
|
||||
):
|
||||
dim = len(ex_size)
|
||||
|
||||
# Reimplement the legacy behavior
|
||||
if constraint_dims is None:
|
||||
|
|
@ -2351,7 +2552,7 @@ class ShapeEnv:
|
|||
for i in range(dim):
|
||||
# NB: This is encapsulation breaking! Legacy behavior was
|
||||
# bad.
|
||||
if _is_dim_dynamic(ex, i):
|
||||
if is_dim_dynamic[i]:
|
||||
r = DimDynamic.DYNAMIC
|
||||
elif self.assume_static_by_default:
|
||||
r = DimDynamic.STATIC
|
||||
|
|
@ -2432,6 +2633,7 @@ class ShapeEnv:
|
|||
# If you know what the current hint value of the SymInt to be created
|
||||
# is, pass it into hint. Otherwise, pass None and we will make our best
|
||||
# guess
|
||||
@record_shapeenv_event()
|
||||
def create_symintnode(
|
||||
self,
|
||||
sym: "sympy.Expr",
|
||||
|
|
@ -2458,6 +2660,7 @@ class ShapeEnv:
|
|||
return int(sym)
|
||||
return SymInt(SymNode(sym, self, int, hint, fx_node=fx_node))
|
||||
|
||||
@record_shapeenv_event()
|
||||
def create_unspecified_symint_and_symbol(self, value, source, dynamic_dim):
|
||||
return self.create_symintnode(
|
||||
self.create_unspecified_symbol(
|
||||
|
|
@ -2474,6 +2677,7 @@ class ShapeEnv:
|
|||
# for validation.
|
||||
return SymBool(SymNode(sym, self, bool, None))
|
||||
|
||||
@record_shapeenv_event()
|
||||
def create_unbacked_symfloat(self):
|
||||
symbol: sympy.Symbol = sympy.Symbol(f"f{next(self.unbacked_symfloat_counter)}")
|
||||
self.counter["create_unbacked_symbol"] += 1
|
||||
|
|
@ -2485,6 +2689,7 @@ class ShapeEnv:
|
|||
|
||||
return SymFloat(SymNode(symbol, self, float, None, fx_node=fx_node))
|
||||
|
||||
@record_shapeenv_event()
|
||||
def create_unbacked_symint(self):
|
||||
symbol: sympy.Symbol = sympy.Symbol(f"i{next(self.unbacked_symint_counter)}", integer=True)
|
||||
self.counter["create_unbacked_symbol"] += 1
|
||||
|
|
@ -2496,6 +2701,7 @@ class ShapeEnv:
|
|||
|
||||
return SymInt(SymNode(symbol, self, int, None, fx_node=fx_node))
|
||||
|
||||
@record_shapeenv_event()
|
||||
def create_unbacked_symbool(self):
|
||||
symbol: sympy.Symbol = sympy.Symbol(f"i{next(self.unbacked_symint_counter)}", integer=True)
|
||||
self.counter["create_unbacked_symbol"] += 1
|
||||
|
|
@ -2507,6 +2713,7 @@ class ShapeEnv:
|
|||
|
||||
return SymBool(SymNode(sympy.Eq(symbol, 1), self, bool, None, fx_node=fx_node))
|
||||
|
||||
@record_shapeenv_event()
|
||||
def create_unspecified_symbol(
|
||||
self,
|
||||
val: int,
|
||||
|
|
@ -2518,6 +2725,7 @@ class ShapeEnv:
|
|||
# assume that it will be neither positive nor negative.
|
||||
return self.create_symbol(val, source, dynamic_dim, constraint_dim, positive=None)
|
||||
|
||||
@record_shapeenv_event()
|
||||
def create_symbol(
|
||||
self,
|
||||
val: int,
|
||||
|
|
@ -2647,17 +2855,32 @@ class ShapeEnv:
|
|||
) -> List[str]:
|
||||
self.log.info("produce_guards")
|
||||
|
||||
# Check if we get to the same ShapeEnv state by replaying the recorded events.
|
||||
# This will create a new ShapeEnv instance, and call all recorded function
|
||||
# calls on this new instance. Finally, it will check whether this new instance
|
||||
# has equal state.
|
||||
#
|
||||
# It's important that we do it in the begining of this function, since it modifies
|
||||
# self.dim_constraints through its execution. Changes that happen in this method
|
||||
# aren't interesting, since this is the function call we wish to reproduce at the
|
||||
# end. If we wish to simply reproduce ShapeEnv instances even after this call,
|
||||
# this method should also be recorded.
|
||||
if self.check_recorded_events:
|
||||
shape_env = replay_shape_env_events(self.events)
|
||||
self.check_equal(shape_env)
|
||||
|
||||
assert len(placeholders) == len(sources)
|
||||
Tensorlike = (torch.Tensor, FakeTensorMeta)
|
||||
|
||||
# Expand optional inputs, or verify invariants are upheld
|
||||
if constraint_inputs is None:
|
||||
constraint_inputs = [
|
||||
[None] * t.dim() if isinstance(t, torch.Tensor) else None for t in placeholders
|
||||
[None] * t.dim() if isinstance(t, Tensorlike) else None for t in placeholders
|
||||
]
|
||||
else:
|
||||
assert len(constraint_inputs) == len(placeholders)
|
||||
for i, (t, constraint) in enumerate(zip(placeholders, constraint_inputs)):
|
||||
if isinstance(t, torch.Tensor):
|
||||
if isinstance(t, Tensorlike):
|
||||
if constraint is None:
|
||||
constraint_inputs[i] = [None] * t.dim()
|
||||
else:
|
||||
|
|
@ -2853,7 +3076,7 @@ class ShapeEnv:
|
|||
if isinstance(t, (SymInt, int)):
|
||||
track_symint(source, t)
|
||||
continue
|
||||
assert isinstance(t, torch.Tensor)
|
||||
assert isinstance(t, Tensorlike)
|
||||
if is_traceable_wrapper_subclass(t):
|
||||
# If our placeholder is a tensor subclass, then the "true" symints
|
||||
# come from the subclass's inner tensors.
|
||||
|
|
@ -3356,6 +3579,7 @@ class ShapeEnv:
|
|||
self._add_target_expr(sympy.Eq(a, expr))
|
||||
|
||||
@_lru_cache
|
||||
@record_shapeenv_event()
|
||||
def _find(self, a: "sympy.Symbol") -> "sympy.Expr":
|
||||
"""
|
||||
Implements a DSU-like algorithm to find the variable that represents a
|
||||
|
|
@ -3504,6 +3728,7 @@ class ShapeEnv:
|
|||
)
|
||||
|
||||
@lru_cache(256)
|
||||
@record_shapeenv_event(save_tracked_fakes=True)
|
||||
def evaluate_expr(self, orig_expr: "sympy.Expr", hint=None, fx_node=None):
|
||||
"""
|
||||
Given an expression, evaluates it, adding guards if necessary
|
||||
|
|
@ -3633,6 +3858,7 @@ class ShapeEnv:
|
|||
for ra in ras:
|
||||
ra.stack.cleanup()
|
||||
|
||||
@record_shapeenv_event(save_tracked_fakes=True)
|
||||
def defer_runtime_assert(self, orig_expr: "sympy.Expr", msg, fx_node=None):
|
||||
expr = orig_expr
|
||||
|
||||
|
|
|
|||
Loading…
Reference in New Issue
Block a user