[dynamo] Add guard serialization for tensor matches. (#151318)
This is a proof of concept for serializing a guard and deserializing it back from bytes. The main behavioral change introduced in this diff is to CheckFunctionManager:

```
check_fn_manager = CheckFunctionManager(code, output_graph, guards_serialization_mode="save")
guards_state: bytes = check_fn_manager.guards_state
```

Once `guards_serialization_mode` is set to `save`, CheckFunctionManager returns an additional `bytes` object called `guards_state`, which should contain all the information needed to deserialize the guards later. To load the guards state back, we set `guards_serialization_mode` to `load`:

```
output_graph_state = pickle.loads(guards_state)
check_fn_manager = CheckFunctionManager(code, output_graph_state, guards_serialization_mode="load")
```

# TENSOR_MATCH

Since there are many types of guards to support, we will break the work into small diffs instead of supporting every guard in a single diff. This diff kicks off the work with TENSOR_MATCH.

# Testing

Each type of guard is tested as follows:
1. Use guard_filter_fn to select one type of guard at a time.
2. Call InstructionTranslator directly on an example function to get an OutputGraph and a CheckFunctionManager (the reference guard manager).
3. Serialize and deserialize the output graph state, then rebuild the guards with a new CheckFunctionManager (the loaded guard manager).
4. Feed a set of example inputs to both the reference and loaded guard managers and check that their behavior matches.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/151318
Approved by: https://github.com/jansel, https://github.com/anijain2305
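For orientation, here is a condensed, non-runnable sketch of the round-trip that the new test below (`test/dynamo/test_guard_serialization.py`) exercises. `code` and `output_graph` are placeholders for the traced frame's code object and the `OutputGraph` produced by running `InstructionTranslator`; the inputs dict stands in for the frame's local scope.

```python
import pickle

import torch
from torch._dynamo.guards import CheckFunctionManager

# Save: build the guards and additionally produce a serializable bytes blob.
saver = CheckFunctionManager(
    code,  # placeholder: the traced frame's code object
    output_graph,  # placeholder: OutputGraph from InstructionTranslator
    guards_serialization_mode="save",
)
ref_guard_manager = saver.guard_manager
guards_state: bytes = saver.guards_state

# Load: rebuild an equivalent guard manager from the serialized state.
loaded_state = pickle.loads(guards_state)
loaded_guard_manager = CheckFunctionManager(
    code,
    loaded_state.output_graph,
    guards_serialization_mode="load",
).guard_manager

# Both managers should agree on whether a frame's locals pass the guards.
inputs = {"x": torch.randn(2)}  # placeholder local scope
assert ref_guard_manager.check(inputs) == loaded_guard_manager.check(inputs)
```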
parent 6e8602b558
commit a34c28e0d2
@@ -398,6 +398,10 @@ class DispatchKeySet final {
     return repr_;
   }

+  static DispatchKeySet from_raw_repr(uint64_t x) {
+    return DispatchKeySet(RAW, x);
+  }
+
   DispatchKey highestFunctionalityKey() const {
     auto functionality_idx = indexOfHighestBit();
     // This means that none of the functionality bits were set.

@@ -295,7 +295,15 @@ num_guards_executed=0)
         x = torch.randn(4, 4)
         size = list(x.size())
         stride = list(x.stride())
-        guard_manager.add_tensor_match_guard(x, size, stride, "x", ["check_tensor(x)"])
+        guard_manager.add_tensor_match_guard(
+            x,
+            size,
+            stride,
+            "x",
+            ["check_tensor(x)"],
+            type(x),
+            torch._C._dispatch_keys(x),
+        )
         self.assertTrue(guard_manager.check(x))
         self.assertTrue(guard_manager.check_verbose(x).result)
         self.assertTrue(guard_manager.check(torch.randn(4, 4)))

test/dynamo/test_guard_serialization.py (new file, 147 lines)
@@ -0,0 +1,147 @@
+# Owner(s): ["module: dynamo"]
+
+import dataclasses
+import pickle
+import sys
+import types
+
+import torch
+import torch._dynamo.testing
+import torch._inductor.config
+import torch._inductor.test_case
+import torch.onnx.operators
+import torch.utils.cpp_extension
+from torch._dynamo.bytecode_transformation import transform_code_object
+from torch._dynamo.guards import CheckFunctionManager, CompileId
+from torch._dynamo.symbolic_convert import InstructionTranslator
+from torch._dynamo.utils import dynamo_timed, get_metrics_context
+from torch._guards import compile_context, CompileContext, tracing
+
+
+@dataclasses.dataclass
+class _FrameState:
+    f_locals: dict
+    f_globals: dict
+    f_code: types.CodeType
+    f_builtins: dict
+
+
+class TestGuardSerialization(torch._inductor.test_case.TestCase):
+    def _tracefunc(self, frame, event, arg):
+        if event != "call":
+            return
+
+        if self._frame_state is not None:
+            return
+
+        self._frame_state = _FrameState(
+            f_locals=frame.f_locals,
+            f_globals=frame.f_globals,
+            f_code=frame.f_code,
+            f_builtins=frame.f_builtins,
+        )
+
+    def _test_serialization(self, guard_type, fn, *args, **kwargs):
+        self._frame_state = None
+        sys.settrace(self._tracefunc)
+        try:
+            fn(*args, **kwargs)
+        finally:
+            sys.settrace(None)
+
+        assert self._frame_state is not None
+
+        def guard_filter_fn(guards):
+            return [g.guard_type == guard_type for g in guards]
+
+        ref_gm = None
+        loaded_gm = None
+
+        def transform(instructions: list, code_options: dict[str, object]):
+            """
+            The goal is here is not to reimplement dynamo, but just to have a
+            simplified version to extract the state from symbolic convert.
+            Should not work on all cases, but should work on simple functions
+            in this test file.
+            """
+            nonlocal ref_gm
+            nonlocal loaded_gm
+
+            tracer = InstructionTranslator(
+                instructions,
+                self._frame_state.f_code,
+                self._frame_state.f_locals,
+                self._frame_state.f_globals,
+                self._frame_state.f_builtins,
+                (),  # TODO closure
+                [],  # TODO tf_mode_stack,
+                code_options,
+                lambda gm, *args, **kwargs: gm.forward,
+                one_graph=False,
+                export=False,
+                export_constraints=None,
+                frame_state=None,
+                speculation_log=None,
+                exn_vt_stack=None,
+                distributed_state=None,
+            )
+            with compile_context(CompileContext(CompileId(0, 0))), tracing(
+                tracer.output.tracing_context
+            ), tracer.set_current_tx(), get_metrics_context(), dynamo_timed(""):
+                tracer.run()
+
+            check_fn_manager = CheckFunctionManager(
+                self._frame_state.f_code,
+                tracer.output,
+                guard_filter_fn=guard_filter_fn,
+                guards_serialization_mode="save",
+            )
+            ref_gm = check_fn_manager.guard_manager
+            guards_state = check_fn_manager.guards_state
+            self.assertIsNotNone(guards_state)
+            guards_state = pickle.loads(guards_state)
+
+            check_fn_manager = CheckFunctionManager(
+                self._frame_state.f_code,
+                guards_state.output_graph,
+                guards_serialization_mode="load",
+            )
+            loaded_gm = check_fn_manager.guard_manager
+
+        try:
+            transform_code_object(self._frame_state.f_code, transform)
+        finally:
+            self._frame_state = None
+
+        self.assertIsNotNone(ref_gm)
+        self.assertIsNotNone(loaded_gm)
+        return ref_gm, loaded_gm
+
+    def _test_check_fn(self, ref, loaded, inputs, expected):
+        self.assertIsInstance(inputs, dict)
+        self.assertEqual(ref.check(inputs), expected)
+        self.assertEqual(ref.check(inputs), loaded.check(inputs))
+
+    def test_tensor_match(self):
+        def f(x: torch.Tensor):
+            return x + 1
+
+        ref, loaded = self._test_serialization(
+            "TENSOR_MATCH", f, torch.ones(2, dtype=torch.float32)
+        )
+        self._test_check_fn(
+            ref, loaded, {"x": torch.randn(2, dtype=torch.float32)}, True
+        )
+        self._test_check_fn(
+            ref, loaded, {"x": torch.randn(3, dtype=torch.float32)}, False
+        )
+        self._test_check_fn(
+            ref, loaded, {"x": torch.randn(2, dtype=torch.float64)}, False
+        )
+        self._test_check_fn(ref, loaded, {"x": None}, False)
+
+
+if __name__ == "__main__":
+    from torch._dynamo.test_case import run_tests
+
+    run_tests()

@@ -1671,6 +1671,8 @@ class DispatchKeySet:
    def __sub__(self, other: DispatchKeySet) -> DispatchKeySet: ...
    def __and__(self, other: DispatchKeySet) -> DispatchKeySet: ...
    def raw_repr(self) -> _int: ...
+   @staticmethod
+   def from_raw_repr(raw: _int) -> DispatchKeySet: ...
    def highestPriorityTypeId(self) -> DispatchKey: ...
    def has(self, k: _dispatchkey) -> _bool: ...
    def add(self, k: _dispatchkey) -> DispatchKeySet: ...

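The `raw_repr`/`from_raw_repr` pair declared above is what lets the new `GuardsStatePickler` (later in this diff) flatten a tensor's dispatch key set into a plain integer and restore it on load. A minimal illustrative round-trip, not part of this diff, assuming a PyTorch build that includes this patch:

```python
import torch

x = torch.randn(4, 4)
keys = torch._C._dispatch_keys(x)  # DispatchKeySet of the tensor
raw = keys.raw_repr()  # plain int, trivially picklable
restored = torch._C.DispatchKeySet.from_raw_repr(raw)
assert restored.raw_repr() == raw  # lossless round-trip
```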
@@ -27,8 +27,10 @@ import enum
 import functools
 import importlib
 import inspect
+import io
 import logging
 import math
+import pickle
 import sys
 import textwrap
 import types

@@ -57,9 +59,9 @@ from torch._C._dynamo.guards import (
     RootGuardManager,
 )
 from torch._dynamo.source import (
+    get_global_source_name,
     IndexedSource,
     is_from_flatten_script_object_source,
-    is_from_global_source,
     is_from_local_source,
     is_from_optimizer_source,
     TensorProperty,

@@ -166,6 +168,8 @@ except ModuleNotFoundError:
 if TYPE_CHECKING:
     from sympy import Symbol

+    from torch._dynamo.output_graph import OutputGraphGuardsState
+

 log = logging.getLogger(__name__)
 guards_log = torch._logging.getArtifactLogger(__name__, "guards")

@@ -494,10 +498,9 @@ def convert_to_concrete_values(size_or_stride):
     return converted


-def get_tensor_guard_code_part(value, name, sizes, strides):
-    pytype = type(value)
+def get_tensor_guard_code_part(value, name, sizes, strides, pytype, dispatch_keys):
     dispatch_key = (
-        torch._C._dispatch_keys(value) | torch._C._dispatch_tls_local_include_set()
+        dispatch_keys | torch._C._dispatch_tls_local_include_set()
     ) - torch._C._dispatch_tls_local_exclude_set()
     dtype = value.dtype
     device_index = value.device.index

@@ -2142,6 +2145,15 @@ class GuardBuilder(GuardBuilderBase):
             value = value()

         value = value if value is not None else self.get(guard.name)
+
+        pytype = type(value)
+        dispatch_keys = torch._C._dispatch_keys(value)
+        if isinstance(value, torch._subclasses.FakeTensor):
+            if value.pytype is not None:
+                pytype = value.pytype
+            if value.dispatch_keys is not None:
+                dispatch_keys = value.dispatch_keys
+
         assert isinstance(value, torch.Tensor)

         if config.log_compilation_metrics and isinstance(value, torch.nn.Parameter):

@@ -2218,7 +2230,9 @@ class GuardBuilder(GuardBuilderBase):
             stride = convert_to_concrete_values(metadata["stride"])

             verbose_code_parts = get_verbose_code_parts(
-                get_tensor_guard_code_part(value, tensor_name, size, stride),
+                get_tensor_guard_code_part(
+                    value, tensor_name, size, stride, pytype, dispatch_keys
+                ),
                 guard,
             )
             guard_manager.add_tensor_match_guard(

@@ -2227,6 +2241,8 @@ class GuardBuilder(GuardBuilderBase):
                 stride,
                 tensor_name,
                 verbose_code_parts,
+                pytype,
+                dispatch_keys,
             )

             # We consider TENSOR_MATCH guard to be important enough to be

@@ -2459,6 +2475,66 @@ class DeletedGuardManagerWrapper(GuardManagerWrapper):
         self.diff_guard_root = None


+@dataclasses.dataclass
+class GuardsState:
+    output_graph: OutputGraphGuardsState
+    # TODO SHAPE_ENV states here
+
+
+class GuardsStatePickler(pickle.Pickler):
+    def __init__(self, *args, **kwargs):
+        super().__init__(*args, **kwargs)
+        self.fake_mode = torch._subclasses.FakeTensorMode()
+        self.tensor_converter = torch._subclasses.fake_tensor.FakeTensorConverter()
+
+    @classmethod
+    def _unpickle_module(cls, state):
+        mod = torch.nn.Module()
+        mod.__setstate__(state)
+        return mod
+
+    @classmethod
+    def _unpickle_tensor(cls, meta_tensor, device, pytype, dispatch_keys_raw):
+        fake_mode = torch._subclasses.FakeTensorMode()
+        tensor_converter = torch._subclasses.fake_tensor.FakeTensorConverter()
+        return tensor_converter.from_meta_and_device(
+            fake_mode,
+            meta_tensor,
+            device,
+            pytype,
+            torch._C.DispatchKeySet.from_raw_repr(dispatch_keys_raw),
+        )
+
+    def reducer_override(self, obj):
+        if isinstance(obj, torch.Tensor) and obj.device.type != "meta":
+            return type(self)._unpickle_tensor, (
+                torch.empty_like(obj, device="meta"),
+                obj.device,
+                type(obj),
+                torch._C._dispatch_keys(obj).raw_repr(),
+            )
+
+        elif isinstance(obj, torch.nn.Module):
+            if obj.__class__.__getstate__ == torch.nn.Module.__getstate__:
+                return type(self)._unpickle_module, (obj.__getstate__(),)
+
+        if type(obj).__qualname__ != type(obj).__name__:
+            raise RuntimeError(
+                f"Type {type(obj)} for object {obj} cannot be saved "
+                + "into torch.compile() package since it's defined in local scope. "
+                + "Please define the class at global scope (top level of a module)."
+            )
+
+        return NotImplemented
+
+
+def pickle_guards_state(state: GuardsState) -> bytes:
+    buf = io.BytesIO()
+    pickler = GuardsStatePickler(buf)
+    pickler.dump(state)
+    return buf.getvalue()
+
+
 # NB: Naively, you'd expect this to only be a function that produces
 # the callable that constitutes the guard.  However, there is some
 # delicate handling for invalidating this check function when the

@@ -2474,6 +2550,7 @@ class CheckFunctionManager:
        guard_filter_fn: Optional[
            Callable[[list[GuardFilterEntry]], list[bool]]
        ] = None,
+       guards_serialization_mode: Optional[str] = None,
    ):
        guards = output_graph.guards if output_graph else None
        self._weakrefs: dict[int, ReferenceType[object]] = {}

@@ -2488,6 +2565,7 @@ class CheckFunctionManager:
         self.torch_function_mode_stack = (
             output_graph.torch_function_mode_stack if output_graph else None
         )
+        self.guards_serialization_mode = guards_serialization_mode

         if not justknobs_check("pytorch/compiler:guard_nn_modules"):
             log.warning("guard_nn_modules is turned off using justknobs killswitch")

@@ -2508,7 +2586,7 @@ class CheckFunctionManager:
                 else:
                     has_value = True
                     value = builder.get(guard.name)
-                is_global = is_from_global_source(guard.originating_source)
+                is_global = get_global_source_name(guard.originating_source) is not None
                 guard_fn = guard.create_fn
                 if isinstance(guard_fn, functools.partial):
                     guard_fn = guard.create_fn.func

@@ -2558,7 +2636,7 @@ class CheckFunctionManager:
        # TODO(anijain2305, ydwu4) - Skipping export because of following test
        # python -s test/dynamo/test_export.py -k test_export_with_symbool_inputs
        latency = 0.0
-       if not output_graph.export:
+       if not output_graph.export and self.guards_serialization_mode != "load":
            if not self.guard_manager.check(output_graph.local_scope):
                reasons = get_guard_fail_reason_helper(
                    self.guard_manager,  # type: ignore[arg-type]

@@ -2591,6 +2669,40 @@ class CheckFunctionManager:
            # account for, we simply increment at the toplevel instead.
            CompileEventLogger.increment_toplevel("guard_latency_us", int(latency))

+       self.guards_state: Optional[bytes] = None
+       if self.guards_serialization_mode == "save":
+           output_graph_guards_state = self.output_graph.dump_guards_state()
+           # Only serialize the global variables that are actually used in guards.
+           used_global_vars = set()
+           for guard in sorted_guards:
+               if name := get_global_source_name(guard.originating_source):
+                   assert isinstance(name, str)
+                   used_global_vars.add(name)
+
+           output_graph_guards_state = dataclasses.replace(
+               output_graph_guards_state,
+               global_scope={
+                   k: v
+                   for k, v in output_graph_guards_state.global_scope.items()
+                   if k in used_global_vars
+               },
+               _guards=torch._guards.GuardsSet(
+                   {
+                       dataclasses.replace(
+                           guard,
+                           obj_weakref=None,
+                           guarded_class_weakref=None,
+                           create_fn=guard.inner_create_fn(),
+                       )
+                       for guard in sorted_guards
+                   }
+               ),
+           )
+           guards_state = GuardsState(
+               output_graph=output_graph_guards_state,
+           )
+           self.guards_state = pickle_guards_state(guards_state)
+
        # TODO: don't do the string rep, do something more structured here
        torch._logging.trace_structured(
            "dynamo_cpp_guards_str",

@@ -2757,9 +2869,7 @@ class CheckFunctionManager:
        )

        aotautograd_guards: list[GuardEnvExpr] = (
-           self.output_graph.tracing_context.guards_context.aotautograd_guards
-           if self.output_graph
-           else []
+           self.output_graph.aotautograd_guards if self.output_graph else []
        )

        # TODO(anijain2305) - There is a duplicate logic in Dynamo to find

@@ -283,7 +283,43 @@ class WrapperBackend:
 Scope = dict[str, object]


-class OutputGraph:
+@dataclass
+class OutputGraphGuardsState:
+    """
+    A base class containing fields that are considered "persistent" when we
+    want to save all the important state for reconstrucing guards in a different
+    process. Normally we don't need to add states here, but we may have to when
+    the information is needed to serialize the guards, so the fields here are
+    supposed to be serializable as a requirement.
+    """
+
+    local_scope: Scope
+    global_scope: Scope
+    # This records the initial torch function mode stack for guarding
+    torch_function_mode_stack: list[torch.overrides.TorchFunctionMode]
+    guard_on_key_order: set[str]
+    # Map from graph input's `Source` to sizes / strides metadata
+    input_source_to_sizes_strides: dict[Source, dict[str, Any]]
+    export: bool = False
+    export_constraints: bool = False
+
+    _guards: Optional[torch._guards.GuardsSet] = None
+    _aotautograd_guards: Optional[list[torch._guards.GuardEnvExpr]] = None
+
+    @property
+    def shape_env(self):
+        raise AssertionError(f"shape_env shouldn't be accessed from {type(self)}")
+
+    @property
+    def guards(self):
+        return self._guards
+
+    @property
+    def aotautograd_guards(self):
+        return self._aotautograd_guards
+
+
+class OutputGraph(OutputGraphGuardsState):
     """
     Wrapper class to hold outputs of InstructionTranslator.  Mainly the
     generated fx.Graph.

@@ -309,6 +345,13 @@ class OutputGraph:
        f_code,
        torch_function_mode_stack,
    ):
+       super().__init__(
+           local_scope,
+           global_scope,
+           torch_function_mode_stack,
+           guard_on_key_order=set(),
+           input_source_to_sizes_strides={},
+       )
        self.tracers = [SubgraphTracer(self, is_export=export)]
        # Map from graph input's `Source` to its `VariableTracker` to
        # de-duplicate graph inputs by source and reuse the tracker

@@ -316,8 +359,6 @@ class OutputGraph:
        self.export = export
        self.export_constraints = export_constraints
        self.frame_state = frame_state
-       # Map from graph input's `Source` to sizes / strides metadata
-       self.input_source_to_sizes_strides: dict[Source, dict[str, Any]] = {}
        self.cleanup_hooks: list[Callable[[], Any]] = []
        # compile_id is an id number for the current torch.compile
        self.compile_id: int = next(_compile_id_counter)

@@ -400,8 +441,6 @@ class OutputGraph:

        # Not checkpointed
        self.compiler_fn: Optional[CompilerFn] = compiler_fn
-       self.global_scope: Scope = global_scope
-       self.local_scope = local_scope
        self.root_tx = root_tx

        # Given a source, what are the user stacks of all locations that

@@ -423,8 +462,6 @@ class OutputGraph:

        # This returns false if TF Overall (both mode and subclass) is disabled OR that TF Mode stack is empty
        self.torch_function_mode_enabled = torch._C._is_torch_function_mode_enabled()
-       # This records the initial torch function mode stack for guarding
-       self.torch_function_mode_stack = torch_function_mode_stack

        # Tracks if the output graph has a user defined allowed function in the
        # graph. This is used later to determine if we should fallback to eager

@@ -473,8 +510,6 @@ class OutputGraph:
            self.install_builtins_dict_in_fglobals()
        )

-       self.guard_on_key_order: set[str] = set()
-
    def install_builtins_dict_in_fglobals(self):
        # f_globals["__builtins__"] can be a dict or a module. This is an
        # implemenation detail -

@@ -544,6 +579,19 @@ class OutputGraph:
            GlobalStateSource().make_guard(GuardBuilder.FUNCTORCH_STACK_MATCH)
        )

+   def dump_guards_state(self):
+       return OutputGraphGuardsState(
+           local_scope=self.local_scope,
+           global_scope=self.global_scope,
+           torch_function_mode_stack=self.torch_function_mode_stack,
+           guard_on_key_order=self.guard_on_key_order,
+           input_source_to_sizes_strides=self.input_source_to_sizes_strides,
+           export=self.export,
+           export_constraints=self.export_constraints,
+           _guards=self.guards,
+           _aotautograd_guards=self.aotautograd_guards,
+       )
+
    def synthetic_graph_input(self, fn, args):
        """
        call fn(*args) before the graph runs and turn the result into a fake input.

@@ -670,6 +718,10 @@ class OutputGraph:
    def nn_modules(self) -> dict[str, Any]:
        return self.tracing_context.module_context.nn_modules

+   @property
+   def aotautograd_guards(self):
+       return self.tracing_context.guards_context.aotautograd_guards
+
    def save_global_state(self, out=None):
        """
        Saves to out if it is provided. Else saves to the tracing context's global_state.

@@ -869,10 +869,16 @@ def is_from_local_source(source: Source, *, only_allow_input=False):
     return True


-def is_from_global_source(source: Source):
+def is_from_global_source(source: Source) -> bool:
+    return get_global_source_name(source) is not None
+
+
+def get_global_source_name(source: Source) -> Optional[str]:
     if isinstance(source, ChainedSource):
-        return is_from_global_source(source.base)
-    return isinstance(source, GlobalSource)
+        return get_global_source_name(source.base)
+    if not isinstance(source, GlobalSource):
+        return None
+    return source.global_name


 def is_from_nonlocal_source(source: Source):

@@ -410,7 +410,7 @@ torch._dynamo.guards._parse_guard_env_guards to avoid a RuntimeError.
 """


-@dataclasses.dataclass
+@dataclasses.dataclass(frozen=True)
 class GuardEnvExpr:
     pass

@@ -421,7 +421,7 @@ input_pos_a and input_pos_b are input positions we have deduped.
 """


-@dataclasses.dataclass
+@dataclasses.dataclass(frozen=True)
 class DuplicateInputs(GuardEnvExpr):
     input_source_a: Source
     input_source_b: Source

@@ -444,7 +444,7 @@ overlapping with any other input, overlapping_sources represent tensors that eit
 """


-@dataclasses.dataclass
+@dataclasses.dataclass(frozen=True)
 class StorageOverlap(GuardEnvExpr):
     overlapping_sources: list[Source]
     non_overlapping_sources: list[Source]

@@ -489,7 +489,12 @@ class FakeTensorConverter:

    # If you specify the device, it MUST be a meta tensor.
    def from_meta_and_device(
-       self, fake_mode: FakeTensorMode, t: Tensor, device: torch.device
+       self,
+       fake_mode: FakeTensorMode,
+       t: Tensor,
+       device: torch.device,
+       pytype: Optional[type[torch.Tensor]] = None,
+       dispatch_keys: Optional[torch.DispatchKeySet] = None,
    ) -> FakeTensor:
        assert (
            t.device.type == "meta"

@@ -499,7 +504,9 @@ class FakeTensorConverter:
        maybe_memo = self._get_memo(t)
        if maybe_memo is not None:
            return maybe_memo
-       out = FakeTensor(fake_mode, t, device)
+       out = FakeTensor(
+           fake_mode, t, device, pytype=pytype, dispatch_keys=dispatch_keys
+       )
        self.set_tensor_memo(t, out)
        return out

@@ -651,6 +658,12 @@ class FakeTensor(Tensor):
    # nested int.
    nested_int_memo = SymNumberMemoDescriptor(is_nested_int=True)

+   # FakeTensor doesn't fully emulate the original tensor's Python type
+   # and dispatch key set, therefore sometimes we want to track them
+   # separately.
+   pytype: Optional[type[Tensor]]
+   dispatch_keys: Optional[torch.DispatchKeySet]
+
    # Indicates to our torch_dispatch dispatching infra that
    # this is an "infra" mode with lower dispatching precedence.
    _mode_key = torch._C._TorchDispatchModeKey.FAKE

@@ -700,6 +713,8 @@ class FakeTensor(Tensor):
        device: torch.device,
        constant: Optional[Tensor] = None,
        real_tensor: Optional[Tensor] = None,
+       pytype: Optional[type[Tensor]] = None,
+       dispatch_keys: Optional[torch.DispatchKeySet] = None,
    ) -> Self:
        self = Tensor._make_subclass(
            cls,

@@ -742,6 +757,8 @@ class FakeTensor(Tensor):
        self.fake_device = device
        self.fake_mode = fake_mode
        self.constant = constant
+       self.pytype = pytype
+       self.dispatch_keys = dispatch_keys
        assert not isinstance(real_tensor, FakeTensor)
        self.real_tensor = real_tensor
        self.nonzero_memo = None

@@ -88,10 +88,11 @@ TensorCheck::TensorCheck(
    const LocalState& state,
    PyTypeObject* pt,
    const at::Tensor& v,
+   c10::DispatchKeySet dispatch_key_set,
    std::vector<std::optional<c10::SymInt>> dynamic_dims_sizes,
    std::vector<std::optional<c10::SymInt>> dynamic_dims_strides)
    : pytype(pt),
-     dispatch_key_(state.apply(v.key_set()).raw_repr()),
+     dispatch_key_(state.apply(dispatch_key_set).raw_repr()),
      dtype_(v.dtype().toScalarType()),
      device_index_(v.device().index()),
      requires_grad_(v.requires_grad()),

@@ -376,6 +377,7 @@ static int TensorGuards_init(
        state,
        Py_TYPE(item),
        std::move(tensor),
+       tensor.key_set(),
        std::move(tensor_dims_size),
        std::move(tensor_dims_stride));
  }

@@ -3477,7 +3479,9 @@ class TENSOR_MATCH : public LeafGuard {
      py::object dynamic_dims_sizes_py,
      py::object dynamic_dims_strides_py,
      py::object tensor_name,
-     py::object verbose_code_parts)
+     py::object verbose_code_parts,
+     py::object pytype,
+     py::object dispatch_keys)
      : LeafGuard(root_guard_manager, std::move(verbose_code_parts)),
        _tensor_name(py::cast<std::string>(std::move(tensor_name))) {
    root_guard_manager->set_init_local_state_flag();

@@ -3486,6 +3490,10 @@ class TENSOR_MATCH : public LeafGuard {
      PyErr_SetString(PyExc_TypeError, "expected Tensor()");
      return;
    }
+   if (!PyType_Check(pytype.ptr())) {
+     PyErr_SetString(PyExc_TypeError, "expected type object");
+     return;
+   }
    auto tensor = THPVariable_Unpack(item);

    std::vector<std::optional<c10::SymInt>> tensor_dims_size =

@@ -3502,8 +3510,9 @@ class TENSOR_MATCH : public LeafGuard {
    LocalState state;
    _tensor_check = std::make_unique<TensorCheck>(
        state,
-       Py_TYPE(item),
+       (PyTypeObject*)pytype.ptr(),
        std::move(tensor),
+       dispatch_keys.cast<c10::DispatchKeySet>(),
        std::move(tensor_dims_size),
        std::move(tensor_dims_stride));
  }

@@ -5523,7 +5532,9 @@ PyObject* torch_c_dynamo_guards_init() {
           py::object,
           py::object,
           py::str,
-          py::list>())
+          py::list,
+          py::type,
+          py::object>())
      .def("__call__", &TENSOR_MATCH::check);
  // NOLINTNEXTLINE(bugprone-unused-raii)
  py::class_<RelationalGuard, LeafGuard, std::shared_ptr<RelationalGuard>>(

@@ -5869,7 +5880,9 @@ PyObject* torch_c_dynamo_guards_init() {
             py::object sizes,
             py::object strides,
             py::object tensor_name,
-            py::object verbose_code_parts) -> void {
+            py::object verbose_code_parts,
+            py::object pytype,
+            py::object dispatch_keys) -> void {
            SKIP_IF_GUARD_ALREADY_PRESENT("TENSOR_MATCH");
            self.add_leaf_guard(std::make_shared<TENSOR_MATCH>(
                self.get_root(),

@@ -5877,7 +5890,9 @@ PyObject* torch_c_dynamo_guards_init() {
                std::move(sizes),
                std::move(strides),
                std::move(tensor_name),
-               std::move(verbose_code_parts)));
+               std::move(verbose_code_parts),
+               std::move(pytype),
+               std::move(dispatch_keys)));
          })

      // return by reference because GuardManager has the ownership of accessors

@@ -43,6 +43,7 @@ class TensorCheck {
      const LocalState& state,
      PyTypeObject* pt,
      const at::Tensor& v,
+     c10::DispatchKeySet dispatch_key_set,
      std::vector<std::optional<c10::SymInt>> dynamic_dims_sizes,
      std::vector<std::optional<c10::SymInt>> dynamic_dims_strides);

@@ -771,7 +771,8 @@ void initDispatchBindings(PyObject* module) {
        return self.add(k);
      })
      .def("has", &c10::DispatchKeySet::has)
-     .def("__repr__", [](c10::DispatchKeySet d) { return c10::toString(d); });
+     .def("__repr__", [](c10::DispatchKeySet d) { return c10::toString(d); })
+     .def_static("from_raw_repr", &c10::DispatchKeySet::from_raw_repr);

  m.attr("_dispatch_autogradother_backends") =
      py::cast(c10::autogradother_backends);