[dynamo] Add guard serialization for tensor matches. (#151318)

This is a proof of concept of how we can serialize guards and deserialize them back from bytes.

The main behavioral change introduced in this diff is to CheckFunctionManager:

```
check_fn_manager = CheckFunctionManager(code, output_graph, guards_serialization_mode="save")

guards_state: bytes = check_fn_manager.guards_state
```

Once `guards_serialization_mode` is set to `save`, CheckFunctionManager will produce an additional `bytes` object called `guards_state`, which contains all the information needed to deserialize the guards later.

When we load the guards state back, we set `guards_serialization_mode` to `load`:

```
output_graph_state = pickle.loads(guards_state).output_graph
check_fn_manager = CheckFunctionManager(code, output_graph_state, guards_serialization_mode="load")
```
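
For illustration, here is a rough end-to-end sketch of the round trip; `code`, `output_graph`, and `frame_locals` are assumed to come from tracing a frame (the new test below obtains them via `InstructionTranslator` and `sys.settrace`), so this is a sketch of the intended usage rather than an exact recipe:

```
import pickle

from torch._dynamo.guards import CheckFunctionManager

# `code` and `output_graph` are assumed to come from a Dynamo trace of a frame.
saved = CheckFunctionManager(code, output_graph, guards_serialization_mode="save")
guards_state: bytes = saved.guards_state  # plain bytes, e.g. can be written to disk

# Later, possibly in a different process: rebuild an equivalent guard manager.
loaded = CheckFunctionManager(
    code,
    pickle.loads(guards_state).output_graph,
    guards_serialization_mode="load",
)

# `frame_locals` is an assumed dict of the traced frame's local variables;
# both guard managers should agree on whether the guards pass.
assert saved.guard_manager.check(frame_locals) == loaded.guard_manager.check(
    frame_locals
)
```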

# TENSOR_MATCH

Since there are many guard types to support, we will break the work into small diffs instead of a single diff that covers every guard type.

We kick off the work with TENSOR_MATCH in this diff.
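
With this change, a TENSOR_MATCH guard takes the tensor's Python type and dispatch key set as explicit arguments, so they can be replayed from serialized state (e.g. from a FakeTensor) instead of being recomputed from a live tensor. Below is a minimal sketch of the updated low-level call, mirroring the updated unit test in the diff; constructing a bare `RootGuardManager` like this is an assumption borrowed from the existing guard-manager tests:

```
import torch
from torch._C._dynamo.guards import RootGuardManager

guard_manager = RootGuardManager()
x = torch.randn(4, 4)
guard_manager.add_tensor_match_guard(
    x,
    list(x.size()),
    list(x.stride()),
    "x",
    ["check_tensor(x)"],
    type(x),                     # new: explicit Python type (pytype)
    torch._C._dispatch_keys(x),  # new: explicit dispatch key set
)
assert guard_manager.check(x)                         # same tensor passes
assert not guard_manager.check(torch.randn(4, 4, 4))  # different shape fails
```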

# Testing

For each type of guard, we test it as follows:
1. Use `guard_filter_fn` to select one type of guard at a time (see the sketch after this list).
2. Call InstructionTranslator directly on an example function to get an OutputGraph and a CheckFunctionManager (the reference guard manager).
3. Serialize and then deserialize the output graph state and rebuild the guards with a new CheckFunctionManager (the loaded guard manager).
4. Feed a set of example inputs to both the reference and the loaded guard managers and check that their behaviors match.
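
For step 1, `guard_filter_fn` is just a predicate over the guard entries handed to CheckFunctionManager, returning one boolean per guard. A minimal sketch; the `make_guard_filter` helper is hypothetical, but its body mirrors the filter used in the new test file:

```
def make_guard_filter(guard_type: str):
    # CheckFunctionManager(..., guard_filter_fn=...) receives the list of guard
    # entries and must return a parallel list of booleans (keep / drop).
    def guard_filter_fn(guards):
        return [g.guard_type == guard_type for g in guards]

    return guard_filter_fn


# Keep only TENSOR_MATCH guards while exercising serialization.
filter_fn = make_guard_filter("TENSOR_MATCH")
```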

Pull Request resolved: https://github.com/pytorch/pytorch/pull/151318
Approved by: https://github.com/jansel, https://github.com/anijain2305
Authored by zhxchen17 on 2025-04-24 13:09:26 -07:00, committed by PyTorch MergeBot
parent 6e8602b558
commit a34c28e0d2
12 changed files with 398 additions and 35 deletions

View File

@@ -398,6 +398,10 @@ class DispatchKeySet final {
    return repr_;
  }

+  static DispatchKeySet from_raw_repr(uint64_t x) {
+    return DispatchKeySet(RAW, x);
+  }
+
  DispatchKey highestFunctionalityKey() const {
    auto functionality_idx = indexOfHighestBit();
    // This means that none of the functionality bits were set.

View File

@@ -295,7 +295,15 @@ num_guards_executed=0)
        x = torch.randn(4, 4)
        size = list(x.size())
        stride = list(x.stride())
-        guard_manager.add_tensor_match_guard(x, size, stride, "x", ["check_tensor(x)"])
+        guard_manager.add_tensor_match_guard(
+            x,
+            size,
+            stride,
+            "x",
+            ["check_tensor(x)"],
+            type(x),
+            torch._C._dispatch_keys(x),
+        )
        self.assertTrue(guard_manager.check(x))
        self.assertTrue(guard_manager.check_verbose(x).result)
        self.assertTrue(guard_manager.check(torch.randn(4, 4)))

View File

@@ -0,0 +1,147 @@
# Owner(s): ["module: dynamo"]
import dataclasses
import pickle
import sys
import types

import torch
import torch._dynamo.testing
import torch._inductor.config
import torch._inductor.test_case
import torch.onnx.operators
import torch.utils.cpp_extension
from torch._dynamo.bytecode_transformation import transform_code_object
from torch._dynamo.guards import CheckFunctionManager, CompileId
from torch._dynamo.symbolic_convert import InstructionTranslator
from torch._dynamo.utils import dynamo_timed, get_metrics_context
from torch._guards import compile_context, CompileContext, tracing


@dataclasses.dataclass
class _FrameState:
    f_locals: dict
    f_globals: dict
    f_code: types.CodeType
    f_builtins: dict


class TestGuardSerialization(torch._inductor.test_case.TestCase):
    def _tracefunc(self, frame, event, arg):
        if event != "call":
            return

        if self._frame_state is not None:
            return

        self._frame_state = _FrameState(
            f_locals=frame.f_locals,
            f_globals=frame.f_globals,
            f_code=frame.f_code,
            f_builtins=frame.f_builtins,
        )

    def _test_serialization(self, guard_type, fn, *args, **kwargs):
        self._frame_state = None
        sys.settrace(self._tracefunc)
        try:
            fn(*args, **kwargs)
        finally:
            sys.settrace(None)
        assert self._frame_state is not None

        def guard_filter_fn(guards):
            return [g.guard_type == guard_type for g in guards]

        ref_gm = None
        loaded_gm = None

        def transform(instructions: list, code_options: dict[str, object]):
            """
            The goal is here is not to reimplement dynamo, but just to have a
            simplified version to extract the state from symbolic convert.
            Should not work on all cases, but should work on simple functions
            in this test file.
            """
            nonlocal ref_gm
            nonlocal loaded_gm
            tracer = InstructionTranslator(
                instructions,
                self._frame_state.f_code,
                self._frame_state.f_locals,
                self._frame_state.f_globals,
                self._frame_state.f_builtins,
                (),  # TODO closure
                [],  # TODO tf_mode_stack,
                code_options,
                lambda gm, *args, **kwargs: gm.forward,
                one_graph=False,
                export=False,
                export_constraints=None,
                frame_state=None,
                speculation_log=None,
                exn_vt_stack=None,
                distributed_state=None,
            )
            with compile_context(CompileContext(CompileId(0, 0))), tracing(
                tracer.output.tracing_context
            ), tracer.set_current_tx(), get_metrics_context(), dynamo_timed(""):
                tracer.run()

                check_fn_manager = CheckFunctionManager(
                    self._frame_state.f_code,
                    tracer.output,
                    guard_filter_fn=guard_filter_fn,
                    guards_serialization_mode="save",
                )
                ref_gm = check_fn_manager.guard_manager
                guards_state = check_fn_manager.guards_state
                self.assertIsNotNone(guards_state)

                guards_state = pickle.loads(guards_state)
                check_fn_manager = CheckFunctionManager(
                    self._frame_state.f_code,
                    guards_state.output_graph,
                    guards_serialization_mode="load",
                )
                loaded_gm = check_fn_manager.guard_manager

        try:
            transform_code_object(self._frame_state.f_code, transform)
        finally:
            self._frame_state = None

        self.assertIsNotNone(ref_gm)
        self.assertIsNotNone(loaded_gm)
        return ref_gm, loaded_gm

    def _test_check_fn(self, ref, loaded, inputs, expected):
        self.assertIsInstance(inputs, dict)
        self.assertEqual(ref.check(inputs), expected)
        self.assertEqual(ref.check(inputs), loaded.check(inputs))

    def test_tensor_match(self):
        def f(x: torch.Tensor):
            return x + 1

        ref, loaded = self._test_serialization(
            "TENSOR_MATCH", f, torch.ones(2, dtype=torch.float32)
        )
        self._test_check_fn(
            ref, loaded, {"x": torch.randn(2, dtype=torch.float32)}, True
        )
        self._test_check_fn(
            ref, loaded, {"x": torch.randn(3, dtype=torch.float32)}, False
        )
        self._test_check_fn(
            ref, loaded, {"x": torch.randn(2, dtype=torch.float64)}, False
        )
        self._test_check_fn(ref, loaded, {"x": None}, False)


if __name__ == "__main__":
    from torch._dynamo.test_case import run_tests

    run_tests()

View File

@@ -1671,6 +1671,8 @@ class DispatchKeySet:
    def __sub__(self, other: DispatchKeySet) -> DispatchKeySet: ...
    def __and__(self, other: DispatchKeySet) -> DispatchKeySet: ...
    def raw_repr(self) -> _int: ...
+    @staticmethod
+    def from_raw_repr(raw: _int) -> DispatchKeySet: ...
    def highestPriorityTypeId(self) -> DispatchKey: ...
    def has(self, k: _dispatchkey) -> _bool: ...
    def add(self, k: _dispatchkey) -> DispatchKeySet: ...

View File

@@ -27,8 +27,10 @@ import enum
import functools
import importlib
import inspect
+import io
import logging
import math
+import pickle
import sys
import textwrap
import types

@@ -57,9 +59,9 @@ from torch._C._dynamo.guards import (
    RootGuardManager,
)
from torch._dynamo.source import (
+    get_global_source_name,
    IndexedSource,
    is_from_flatten_script_object_source,
-    is_from_global_source,
    is_from_local_source,
    is_from_optimizer_source,
    TensorProperty,

@@ -166,6 +168,8 @@ except ModuleNotFoundError:
if TYPE_CHECKING:
    from sympy import Symbol

+    from torch._dynamo.output_graph import OutputGraphGuardsState
+

log = logging.getLogger(__name__)
guards_log = torch._logging.getArtifactLogger(__name__, "guards")
@@ -494,10 +498,9 @@ def convert_to_concrete_values(size_or_stride):
    return converted


-def get_tensor_guard_code_part(value, name, sizes, strides):
-    pytype = type(value)
+def get_tensor_guard_code_part(value, name, sizes, strides, pytype, dispatch_keys):
    dispatch_key = (
-        torch._C._dispatch_keys(value) | torch._C._dispatch_tls_local_include_set()
+        dispatch_keys | torch._C._dispatch_tls_local_include_set()
    ) - torch._C._dispatch_tls_local_exclude_set()
    dtype = value.dtype
    device_index = value.device.index

@@ -2142,6 +2145,15 @@ class GuardBuilder(GuardBuilderBase):
            value = value()
        value = value if value is not None else self.get(guard.name)

+        pytype = type(value)
+        dispatch_keys = torch._C._dispatch_keys(value)
+        if isinstance(value, torch._subclasses.FakeTensor):
+            if value.pytype is not None:
+                pytype = value.pytype
+            if value.dispatch_keys is not None:
+                dispatch_keys = value.dispatch_keys
+
        assert isinstance(value, torch.Tensor)

        if config.log_compilation_metrics and isinstance(value, torch.nn.Parameter):

@@ -2218,7 +2230,9 @@ class GuardBuilder(GuardBuilderBase):
        stride = convert_to_concrete_values(metadata["stride"])

        verbose_code_parts = get_verbose_code_parts(
-            get_tensor_guard_code_part(value, tensor_name, size, stride),
+            get_tensor_guard_code_part(
+                value, tensor_name, size, stride, pytype, dispatch_keys
+            ),
            guard,
        )
        guard_manager.add_tensor_match_guard(

@@ -2227,6 +2241,8 @@ class GuardBuilder(GuardBuilderBase):
            stride,
            tensor_name,
            verbose_code_parts,
+            pytype,
+            dispatch_keys,
        )

        # We consider TENSOR_MATCH guard to be important enough to be
@@ -2459,6 +2475,66 @@ class DeletedGuardManagerWrapper(GuardManagerWrapper):
        self.diff_guard_root = None


+@dataclasses.dataclass
+class GuardsState:
+    output_graph: OutputGraphGuardsState
+    # TODO SHAPE_ENV states here
+
+
+class GuardsStatePickler(pickle.Pickler):
+    def __init__(self, *args, **kwargs):
+        super().__init__(*args, **kwargs)
+        self.fake_mode = torch._subclasses.FakeTensorMode()
+        self.tensor_converter = torch._subclasses.fake_tensor.FakeTensorConverter()
+
+    @classmethod
+    def _unpickle_module(cls, state):
+        mod = torch.nn.Module()
+        mod.__setstate__(state)
+        return mod
+
+    @classmethod
+    def _unpickle_tensor(cls, meta_tensor, device, pytype, dispatch_keys_raw):
+        fake_mode = torch._subclasses.FakeTensorMode()
+        tensor_converter = torch._subclasses.fake_tensor.FakeTensorConverter()
+        return tensor_converter.from_meta_and_device(
+            fake_mode,
+            meta_tensor,
+            device,
+            pytype,
+            torch._C.DispatchKeySet.from_raw_repr(dispatch_keys_raw),
+        )
+
+    def reducer_override(self, obj):
+        if isinstance(obj, torch.Tensor) and obj.device.type != "meta":
+            return type(self)._unpickle_tensor, (
+                torch.empty_like(obj, device="meta"),
+                obj.device,
+                type(obj),
+                torch._C._dispatch_keys(obj).raw_repr(),
+            )
+        elif isinstance(obj, torch.nn.Module):
+            if obj.__class__.__getstate__ == torch.nn.Module.__getstate__:
+                return type(self)._unpickle_module, (obj.__getstate__(),)
+
+        if type(obj).__qualname__ != type(obj).__name__:
+            raise RuntimeError(
+                f"Type {type(obj)} for object {obj} cannot be saved "
+                + "into torch.compile() package since it's defined in local scope. "
+                + "Please define the class at global scope (top level of a module)."
+            )
+
+        return NotImplemented
+
+
+def pickle_guards_state(state: GuardsState) -> bytes:
+    buf = io.BytesIO()
+    pickler = GuardsStatePickler(buf)
+    pickler.dump(state)
+    return buf.getvalue()
+
+
# NB: Naively, you'd expect this to only be a function that produces
# the callable that constitutes the guard. However, there is some
# delicate handling for invalidating this check function when the
@@ -2474,6 +2550,7 @@ class CheckFunctionManager:
        guard_filter_fn: Optional[
            Callable[[list[GuardFilterEntry]], list[bool]]
        ] = None,
+        guards_serialization_mode: Optional[str] = None,
    ):
        guards = output_graph.guards if output_graph else None
        self._weakrefs: dict[int, ReferenceType[object]] = {}

@@ -2488,6 +2565,7 @@
        self.torch_function_mode_stack = (
            output_graph.torch_function_mode_stack if output_graph else None
        )
+        self.guards_serialization_mode = guards_serialization_mode

        if not justknobs_check("pytorch/compiler:guard_nn_modules"):
            log.warning("guard_nn_modules is turned off using justknobs killswitch")

@@ -2508,7 +2586,7 @@
            else:
                has_value = True
                value = builder.get(guard.name)
-            is_global = is_from_global_source(guard.originating_source)
+            is_global = get_global_source_name(guard.originating_source) is not None
            guard_fn = guard.create_fn
            if isinstance(guard_fn, functools.partial):
                guard_fn = guard.create_fn.func

@@ -2558,7 +2636,7 @@
        # TODO(anijain2305, ydwu4) - Skipping export because of following test
        # python -s test/dynamo/test_export.py -k test_export_with_symbool_inputs
        latency = 0.0
-        if not output_graph.export:
+        if not output_graph.export and self.guards_serialization_mode != "load":
            if not self.guard_manager.check(output_graph.local_scope):
                reasons = get_guard_fail_reason_helper(
                    self.guard_manager,  # type: ignore[arg-type]
@@ -2591,6 +2669,40 @@
            # account for, we simply increment at the toplevel instead.
            CompileEventLogger.increment_toplevel("guard_latency_us", int(latency))

+        self.guards_state: Optional[bytes] = None
+        if self.guards_serialization_mode == "save":
+            output_graph_guards_state = self.output_graph.dump_guards_state()
+            # Only serialize the global variables that are actually used in guards.
+            used_global_vars = set()
+            for guard in sorted_guards:
+                if name := get_global_source_name(guard.originating_source):
+                    assert isinstance(name, str)
+                    used_global_vars.add(name)
+
+            output_graph_guards_state = dataclasses.replace(
+                output_graph_guards_state,
+                global_scope={
+                    k: v
+                    for k, v in output_graph_guards_state.global_scope.items()
+                    if k in used_global_vars
+                },
+                _guards=torch._guards.GuardsSet(
+                    {
+                        dataclasses.replace(
+                            guard,
+                            obj_weakref=None,
+                            guarded_class_weakref=None,
+                            create_fn=guard.inner_create_fn(),
+                        )
+                        for guard in sorted_guards
+                    }
+                ),
+            )
+            guards_state = GuardsState(
+                output_graph=output_graph_guards_state,
+            )
+            self.guards_state = pickle_guards_state(guards_state)
+
        # TODO: don't do the string rep, do something more structured here
        torch._logging.trace_structured(
            "dynamo_cpp_guards_str",

@@ -2757,9 +2869,7 @@
        )

        aotautograd_guards: list[GuardEnvExpr] = (
-            self.output_graph.tracing_context.guards_context.aotautograd_guards
-            if self.output_graph
-            else []
+            self.output_graph.aotautograd_guards if self.output_graph else []
        )

        # TODO(anijain2305) - There is a duplicate logic in Dynamo to find

View File

@@ -283,7 +283,43 @@ class WrapperBackend:
Scope = dict[str, object]


-class OutputGraph:
+@dataclass
+class OutputGraphGuardsState:
+    """
+    A base class containing fields that are considered "persistent" when we
+    want to save all the important state for reconstrucing guards in a different
+    process. Normally we don't need to add states here, but we may have to when
+    the information is needed to serialize the guards, so the fields here are
+    supposed to be serializable as a requirement.
+    """
+
+    local_scope: Scope
+    global_scope: Scope
+
+    # This records the initial torch function mode stack for guarding
+    torch_function_mode_stack: list[torch.overrides.TorchFunctionMode]
+
+    guard_on_key_order: set[str]
+    # Map from graph input's `Source` to sizes / strides metadata
+    input_source_to_sizes_strides: dict[Source, dict[str, Any]]
+    export: bool = False
+    export_constraints: bool = False
+    _guards: Optional[torch._guards.GuardsSet] = None
+    _aotautograd_guards: Optional[list[torch._guards.GuardEnvExpr]] = None
+
+    @property
+    def shape_env(self):
+        raise AssertionError(f"shape_env shouldn't be accessed from {type(self)}")
+
+    @property
+    def guards(self):
+        return self._guards
+
+    @property
+    def aotautograd_guards(self):
+        return self._aotautograd_guards
+
+
+class OutputGraph(OutputGraphGuardsState):
    """
    Wrapper class to hold outputs of InstructionTranslator. Mainly the
    generated fx.Graph.
@@ -309,6 +345,13 @@
        f_code,
        torch_function_mode_stack,
    ):
+        super().__init__(
+            local_scope,
+            global_scope,
+            torch_function_mode_stack,
+            guard_on_key_order=set(),
+            input_source_to_sizes_strides={},
+        )
        self.tracers = [SubgraphTracer(self, is_export=export)]
        # Map from graph input's `Source` to its `VariableTracker` to
        # de-duplicate graph inputs by source and reuse the tracker

@@ -316,8 +359,6 @@
        self.export = export
        self.export_constraints = export_constraints
        self.frame_state = frame_state
-        # Map from graph input's `Source` to sizes / strides metadata
-        self.input_source_to_sizes_strides: dict[Source, dict[str, Any]] = {}
        self.cleanup_hooks: list[Callable[[], Any]] = []
        # compile_id is an id number for the current torch.compile
        self.compile_id: int = next(_compile_id_counter)

@@ -400,8 +441,6 @@
        # Not checkpointed
        self.compiler_fn: Optional[CompilerFn] = compiler_fn
-        self.global_scope: Scope = global_scope
-        self.local_scope = local_scope
        self.root_tx = root_tx

        # Given a source, what are the user stacks of all locations that

@@ -423,8 +462,6 @@
        # This returns false if TF Overall (both mode and subclass) is disabled OR that TF Mode stack is empty
        self.torch_function_mode_enabled = torch._C._is_torch_function_mode_enabled()

-        # This records the initial torch function mode stack for guarding
-        self.torch_function_mode_stack = torch_function_mode_stack

        # Tracks if the output graph has a user defined allowed function in the
        # graph. This is used later to determine if we should fallback to eager

@@ -473,8 +510,6 @@
            self.install_builtins_dict_in_fglobals()
        )

-        self.guard_on_key_order: set[str] = set()
-
    def install_builtins_dict_in_fglobals(self):
        # f_globals["__builtins__"] can be a dict or a module. This is an
        # implemenation detail -
@@ -544,6 +579,19 @@
            GlobalStateSource().make_guard(GuardBuilder.FUNCTORCH_STACK_MATCH)
        )

+    def dump_guards_state(self):
+        return OutputGraphGuardsState(
+            local_scope=self.local_scope,
+            global_scope=self.global_scope,
+            torch_function_mode_stack=self.torch_function_mode_stack,
+            guard_on_key_order=self.guard_on_key_order,
+            input_source_to_sizes_strides=self.input_source_to_sizes_strides,
+            export=self.export,
+            export_constraints=self.export_constraints,
+            _guards=self.guards,
+            _aotautograd_guards=self.aotautograd_guards,
+        )
+
    def synthetic_graph_input(self, fn, args):
        """
        call fn(*args) before the graph runs and turn the result into a fake input.

@@ -670,6 +718,10 @@
    def nn_modules(self) -> dict[str, Any]:
        return self.tracing_context.module_context.nn_modules

+    @property
+    def aotautograd_guards(self):
+        return self.tracing_context.guards_context.aotautograd_guards
+
    def save_global_state(self, out=None):
        """
        Saves to out if it is provided. Else saves to the tracing context's global_state.

View File

@@ -869,10 +869,16 @@ def is_from_local_source(source: Source, *, only_allow_input=False):
    return True


-def is_from_global_source(source: Source):
+def is_from_global_source(source: Source) -> bool:
+    return get_global_source_name(source) is not None
+
+
+def get_global_source_name(source: Source) -> Optional[str]:
    if isinstance(source, ChainedSource):
-        return is_from_global_source(source.base)
-    return isinstance(source, GlobalSource)
+        return get_global_source_name(source.base)
+    if not isinstance(source, GlobalSource):
+        return None
+    return source.global_name


def is_from_nonlocal_source(source: Source):

View File

@@ -410,7 +410,7 @@ torch._dynamo.guards._parse_guard_env_guards to avoid a RuntimeError.
"""


-@dataclasses.dataclass
+@dataclasses.dataclass(frozen=True)
class GuardEnvExpr:
    pass

@@ -421,7 +421,7 @@ input_pos_a and input_pos_b are input positions we have deduped.
"""


-@dataclasses.dataclass
+@dataclasses.dataclass(frozen=True)
class DuplicateInputs(GuardEnvExpr):
    input_source_a: Source
    input_source_b: Source

@@ -444,7 +444,7 @@ overlapping with any other input, overlapping_sources represent tensors that eit
"""


-@dataclasses.dataclass
+@dataclasses.dataclass(frozen=True)
class StorageOverlap(GuardEnvExpr):
    overlapping_sources: list[Source]
    non_overlapping_sources: list[Source]

View File

@@ -489,7 +489,12 @@ class FakeTensorConverter:
    # If you specify the device, it MUST be a meta tensor.
    def from_meta_and_device(
-        self, fake_mode: FakeTensorMode, t: Tensor, device: torch.device
+        self,
+        fake_mode: FakeTensorMode,
+        t: Tensor,
+        device: torch.device,
+        pytype: Optional[type[torch.Tensor]] = None,
+        dispatch_keys: Optional[torch.DispatchKeySet] = None,
    ) -> FakeTensor:
        assert (
            t.device.type == "meta"

@@ -499,7 +504,9 @@
        maybe_memo = self._get_memo(t)
        if maybe_memo is not None:
            return maybe_memo
-        out = FakeTensor(fake_mode, t, device)
+        out = FakeTensor(
+            fake_mode, t, device, pytype=pytype, dispatch_keys=dispatch_keys
+        )
        self.set_tensor_memo(t, out)
        return out

@@ -651,6 +658,12 @@ class FakeTensor(Tensor):
    # nested int.
    nested_int_memo = SymNumberMemoDescriptor(is_nested_int=True)

+    # FakeTensor doesn't fully emulate the original tensor's Python type
+    # and dispatch key set, therefore sometimes we want to track them
+    # separately.
+    pytype: Optional[type[Tensor]]
+    dispatch_keys: Optional[torch.DispatchKeySet]
+
    # Indicates to our torch_dispatch dispatching infra that
    # this is an "infra" mode with lower dispatching precedence.
    _mode_key = torch._C._TorchDispatchModeKey.FAKE

@@ -700,6 +713,8 @@
        device: torch.device,
        constant: Optional[Tensor] = None,
        real_tensor: Optional[Tensor] = None,
+        pytype: Optional[type[Tensor]] = None,
+        dispatch_keys: Optional[torch.DispatchKeySet] = None,
    ) -> Self:
        self = Tensor._make_subclass(
            cls,

@@ -742,6 +757,8 @@
        self.fake_device = device
        self.fake_mode = fake_mode
        self.constant = constant
+        self.pytype = pytype
+        self.dispatch_keys = dispatch_keys
        assert not isinstance(real_tensor, FakeTensor)
        self.real_tensor = real_tensor
        self.nonzero_memo = None

View File

@@ -88,10 +88,11 @@ TensorCheck::TensorCheck(
    const LocalState& state,
    PyTypeObject* pt,
    const at::Tensor& v,
+    c10::DispatchKeySet dispatch_key_set,
    std::vector<std::optional<c10::SymInt>> dynamic_dims_sizes,
    std::vector<std::optional<c10::SymInt>> dynamic_dims_strides)
    : pytype(pt),
-      dispatch_key_(state.apply(v.key_set()).raw_repr()),
+      dispatch_key_(state.apply(dispatch_key_set).raw_repr()),
      dtype_(v.dtype().toScalarType()),
      device_index_(v.device().index()),
      requires_grad_(v.requires_grad()),

@@ -376,6 +377,7 @@ static int TensorGuards_init(
        state,
        Py_TYPE(item),
        std::move(tensor),
+        tensor.key_set(),
        std::move(tensor_dims_size),
        std::move(tensor_dims_stride));
  }

@@ -3477,7 +3479,9 @@ class TENSOR_MATCH : public LeafGuard {
      py::object dynamic_dims_sizes_py,
      py::object dynamic_dims_strides_py,
      py::object tensor_name,
-      py::object verbose_code_parts)
+      py::object verbose_code_parts,
+      py::object pytype,
+      py::object dispatch_keys)
      : LeafGuard(root_guard_manager, std::move(verbose_code_parts)),
        _tensor_name(py::cast<std::string>(std::move(tensor_name))) {
    root_guard_manager->set_init_local_state_flag();

@@ -3486,6 +3490,10 @@ class TENSOR_MATCH : public LeafGuard {
      PyErr_SetString(PyExc_TypeError, "expected Tensor()");
      return;
    }
+    if (!PyType_Check(pytype.ptr())) {
+      PyErr_SetString(PyExc_TypeError, "expected type object");
+      return;
+    }
    auto tensor = THPVariable_Unpack(item);

    std::vector<std::optional<c10::SymInt>> tensor_dims_size =

@@ -3502,8 +3510,9 @@
    LocalState state;
    _tensor_check = std::make_unique<TensorCheck>(
        state,
-        Py_TYPE(item),
+        (PyTypeObject*)pytype.ptr(),
        std::move(tensor),
+        dispatch_keys.cast<c10::DispatchKeySet>(),
        std::move(tensor_dims_size),
        std::move(tensor_dims_stride));
  }

@@ -5523,7 +5532,9 @@ PyObject* torch_c_dynamo_guards_init() {
              py::object,
              py::object,
              py::str,
-             py::list>())
+             py::list,
+             py::type,
+             py::object>())
      .def("__call__", &TENSOR_MATCH::check);
  // NOLINTNEXTLINE(bugprone-unused-raii)
  py::class_<RelationalGuard, LeafGuard, std::shared_ptr<RelationalGuard>>(

@@ -5869,7 +5880,9 @@
            py::object sizes,
            py::object strides,
            py::object tensor_name,
-            py::object verbose_code_parts) -> void {
+            py::object verbose_code_parts,
+            py::object pytype,
+            py::object dispatch_keys) -> void {
            SKIP_IF_GUARD_ALREADY_PRESENT("TENSOR_MATCH");
            self.add_leaf_guard(std::make_shared<TENSOR_MATCH>(
                self.get_root(),

@@ -5877,7 +5890,9 @@
                std::move(sizes),
                std::move(strides),
                std::move(tensor_name),
-                std::move(verbose_code_parts)));
+                std::move(verbose_code_parts),
+                std::move(pytype),
+                std::move(dispatch_keys)));
          })
      // return by reference because GuardManager has the ownership of accessors

View File

@@ -43,6 +43,7 @@ class TensorCheck {
      const LocalState& state,
      PyTypeObject* pt,
      const at::Tensor& v,
+      c10::DispatchKeySet dispatch_key_set,
      std::vector<std::optional<c10::SymInt>> dynamic_dims_sizes,
      std::vector<std::optional<c10::SymInt>> dynamic_dims_strides);

View File

@@ -771,7 +771,8 @@ void initDispatchBindings(PyObject* module) {
        return self.add(k);
      })
      .def("has", &c10::DispatchKeySet::has)
-      .def("__repr__", [](c10::DispatchKeySet d) { return c10::toString(d); });
+      .def("__repr__", [](c10::DispatchKeySet d) { return c10::toString(d); })
+      .def_static("from_raw_repr", &c10::DispatchKeySet::from_raw_repr);

  m.attr("_dispatch_autogradother_backends") =
      py::cast(c10::autogradother_backends);