[dynamo] Add guard serialization for tensor matches. (#151318)

This is a proof of concept of how we could serialize guards and deserialize them back from bytes.

The main behavioral change introduced in this diff is to CheckFunctionManager:

```
check_fn_manager = CheckFunctionManager(code, output_graph, guards_serialization_mode="save")

guards_state: bytes = check_fn_manager.guards_state
```

Once `guards_serialization_mode` is set to `save`, CheckFunctionManager will return an additional `bytes` object called `guards_state`, which should contain all the information needed to deserialize the guards later.
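
For reference, the saved bytes do not embed any tensor storage: a custom pickler (added in the `guards.py` hunk below) reduces real tensors to just their metadata so they can be rebuilt as FakeTensors on load. Below is a simplified sketch of that idea; `_rebuild_as_fake` and `TensorFreePickler` are illustrative stand-ins for `GuardsStatePickler._unpickle_tensor` and `GuardsStatePickler` from this diff, not the actual implementation.

```python
import pickle

import torch
from torch._subclasses.fake_tensor import FakeTensorConverter, FakeTensorMode


def _rebuild_as_fake(meta_tensor, device, pytype, dispatch_keys_raw):
    # Rebuild the tensor as a FakeTensor carrying the original device,
    # Python type and dispatch key set, but no data.
    fake_mode = FakeTensorMode()
    converter = FakeTensorConverter()
    return converter.from_meta_and_device(
        fake_mode,
        meta_tensor,
        device,
        pytype,
        torch._C.DispatchKeySet.from_raw_repr(dispatch_keys_raw),
    )


class TensorFreePickler(pickle.Pickler):
    # Serialize only a meta copy of each tensor plus enough metadata to
    # fake it back on the other side.
    def reducer_override(self, obj):
        if isinstance(obj, torch.Tensor) and obj.device.type != "meta":
            return _rebuild_as_fake, (
                torch.empty_like(obj, device="meta"),
                obj.device,
                type(obj),
                torch._C._dispatch_keys(obj).raw_repr(),
            )
        return NotImplemented
```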

When we load the guards state back, we set `guards_serialization_mode` to `load`:

```
output_graph_state = pickle.loads(guards_state)
check_fn_manager = CheckFunctionManager(code, output_graph_state, guards_serialization_mode="load")
```
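
Putting the two halves together, here is a minimal round-trip sketch across process boundaries. The file-based helpers and their names are hypothetical; only `CheckFunctionManager` and `guards_serialization_mode` come from this diff.

```python
import pickle

from torch._dynamo.guards import CheckFunctionManager


def save_guards(path, code, output_graph):
    # Hypothetical helper: build guards in "save" mode and persist the bytes.
    manager = CheckFunctionManager(
        code, output_graph, guards_serialization_mode="save"
    )
    with open(path, "wb") as f:
        f.write(manager.guards_state)


def load_guard_manager(path, code):
    # Hypothetical helper: the payload unpickles to a GuardsState whose
    # output_graph field feeds the loading CheckFunctionManager, as in the
    # test added by this diff.
    with open(path, "rb") as f:
        guards_state = pickle.loads(f.read())
    return CheckFunctionManager(
        code, guards_state.output_graph, guards_serialization_mode="load"
    ).guard_manager
```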

# TENSOR_MATCH

Since we have many types of guards to support, we will break the work into small diffs instead of a single diff that supports every guard type.

We kick off the work with TENSOR_MATCH in this diff.
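
At the C++ level, the TENSOR_MATCH leaf guard now receives the tensor's Python type and dispatch key set explicitly instead of reading them off the live tensor, which is what lets a deserialized guard reproduce the original check. Below is a minimal sketch of the updated binding, mirroring the test change in this diff and assuming a standalone `RootGuardManager` as used in the existing guard-manager tests.

```python
import torch
from torch._C._dynamo.guards import RootGuardManager

guard_manager = RootGuardManager()
x = torch.randn(4, 4)
guard_manager.add_tensor_match_guard(
    x,
    list(x.size()),
    list(x.stride()),
    "x",                          # tensor name used in guard failure messages
    ["check_tensor(x)"],          # verbose code parts
    type(x),                      # pytype, now passed explicitly
    torch._C._dispatch_keys(x),   # dispatch key set, now passed explicitly
)
assert guard_manager.check(x)                     # same type/dtype/shape: passes
assert not guard_manager.check(torch.randn(4, 4).double())  # dtype mismatch: fails
```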

# Testing

For each type of guard, we test it as follows:
1. Use `guard_filter_fn` to select one type of guard at a time (see the sketch after this list).
2. Call InstructionTranslator directly on an example function to get an OutputGraph and a CheckFunctionManager (the reference guard manager).
3. Serialize and deserialize the output graph state, then rebuild the guards with a new CheckFunctionManager (the loaded guard manager).
4. Feed a set of example inputs to both the reference and the loaded guard manager and check that their behaviors match.
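
For steps 1 and 4, here is a minimal sketch of the helpers involved; `make_guard_filter` and `check_both` are hypothetical names, but they follow the shape of `guard_filter_fn` and `_test_check_fn` in the test file added by this diff.

```python
def make_guard_filter(guard_type: str):
    # Step 1: CheckFunctionManager's guard_filter_fn receives the list of
    # guards and returns one boolean per guard; keep only the chosen type.
    def guard_filter_fn(guards):
        return [g.guard_type == guard_type for g in guards]

    return guard_filter_fn


def check_both(ref_gm, loaded_gm, inputs: dict, expected: bool) -> None:
    # Step 4: the reference and the deserialized guard managers must agree
    # on every example input.
    assert ref_gm.check(inputs) == expected
    assert ref_gm.check(inputs) == loaded_gm.check(inputs)
```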

Pull Request resolved: https://github.com/pytorch/pytorch/pull/151318
Approved by: https://github.com/jansel, https://github.com/anijain2305
zhxchen17 2025-04-24 13:09:26 -07:00 committed by PyTorch MergeBot
parent 6e8602b558
commit a34c28e0d2
12 changed files with 398 additions and 35 deletions

View File

@ -398,6 +398,10 @@ class DispatchKeySet final {
return repr_;
}
static DispatchKeySet from_raw_repr(uint64_t x) {
return DispatchKeySet(RAW, x);
}
DispatchKey highestFunctionalityKey() const {
auto functionality_idx = indexOfHighestBit();
// This means that none of the functionality bits were set.

View File

@ -295,7 +295,15 @@ num_guards_executed=0)
x = torch.randn(4, 4)
size = list(x.size())
stride = list(x.stride())
guard_manager.add_tensor_match_guard(x, size, stride, "x", ["check_tensor(x)"])
guard_manager.add_tensor_match_guard(
x,
size,
stride,
"x",
["check_tensor(x)"],
type(x),
torch._C._dispatch_keys(x),
)
self.assertTrue(guard_manager.check(x))
self.assertTrue(guard_manager.check_verbose(x).result)
self.assertTrue(guard_manager.check(torch.randn(4, 4)))

View File

@ -0,0 +1,147 @@
# Owner(s): ["module: dynamo"]
import dataclasses
import pickle
import sys
import types
import torch
import torch._dynamo.testing
import torch._inductor.config
import torch._inductor.test_case
import torch.onnx.operators
import torch.utils.cpp_extension
from torch._dynamo.bytecode_transformation import transform_code_object
from torch._dynamo.guards import CheckFunctionManager, CompileId
from torch._dynamo.symbolic_convert import InstructionTranslator
from torch._dynamo.utils import dynamo_timed, get_metrics_context
from torch._guards import compile_context, CompileContext, tracing
@dataclasses.dataclass
class _FrameState:
f_locals: dict
f_globals: dict
f_code: types.CodeType
f_builtins: dict
class TestGuardSerialization(torch._inductor.test_case.TestCase):
def _tracefunc(self, frame, event, arg):
if event != "call":
return
if self._frame_state is not None:
return
self._frame_state = _FrameState(
f_locals=frame.f_locals,
f_globals=frame.f_globals,
f_code=frame.f_code,
f_builtins=frame.f_builtins,
)
def _test_serialization(self, guard_type, fn, *args, **kwargs):
self._frame_state = None
sys.settrace(self._tracefunc)
try:
fn(*args, **kwargs)
finally:
sys.settrace(None)
assert self._frame_state is not None
def guard_filter_fn(guards):
return [g.guard_type == guard_type for g in guards]
ref_gm = None
loaded_gm = None
def transform(instructions: list, code_options: dict[str, object]):
"""
The goal here is not to reimplement dynamo, but just to have a
simplified version to extract the state from symbolic convert.
Should not work on all cases, but should work on simple functions
in this test file.
"""
nonlocal ref_gm
nonlocal loaded_gm
tracer = InstructionTranslator(
instructions,
self._frame_state.f_code,
self._frame_state.f_locals,
self._frame_state.f_globals,
self._frame_state.f_builtins,
(), # TODO closure
[], # TODO tf_mode_stack,
code_options,
lambda gm, *args, **kwargs: gm.forward,
one_graph=False,
export=False,
export_constraints=None,
frame_state=None,
speculation_log=None,
exn_vt_stack=None,
distributed_state=None,
)
with compile_context(CompileContext(CompileId(0, 0))), tracing(
tracer.output.tracing_context
), tracer.set_current_tx(), get_metrics_context(), dynamo_timed(""):
tracer.run()
check_fn_manager = CheckFunctionManager(
self._frame_state.f_code,
tracer.output,
guard_filter_fn=guard_filter_fn,
guards_serialization_mode="save",
)
ref_gm = check_fn_manager.guard_manager
guards_state = check_fn_manager.guards_state
self.assertIsNotNone(guards_state)
guards_state = pickle.loads(guards_state)
check_fn_manager = CheckFunctionManager(
self._frame_state.f_code,
guards_state.output_graph,
guards_serialization_mode="load",
)
loaded_gm = check_fn_manager.guard_manager
try:
transform_code_object(self._frame_state.f_code, transform)
finally:
self._frame_state = None
self.assertIsNotNone(ref_gm)
self.assertIsNotNone(loaded_gm)
return ref_gm, loaded_gm
def _test_check_fn(self, ref, loaded, inputs, expected):
self.assertIsInstance(inputs, dict)
self.assertEqual(ref.check(inputs), expected)
self.assertEqual(ref.check(inputs), loaded.check(inputs))
def test_tensor_match(self):
def f(x: torch.Tensor):
return x + 1
ref, loaded = self._test_serialization(
"TENSOR_MATCH", f, torch.ones(2, dtype=torch.float32)
)
self._test_check_fn(
ref, loaded, {"x": torch.randn(2, dtype=torch.float32)}, True
)
self._test_check_fn(
ref, loaded, {"x": torch.randn(3, dtype=torch.float32)}, False
)
self._test_check_fn(
ref, loaded, {"x": torch.randn(2, dtype=torch.float64)}, False
)
self._test_check_fn(ref, loaded, {"x": None}, False)
if __name__ == "__main__":
from torch._dynamo.test_case import run_tests
run_tests()

View File

@ -1671,6 +1671,8 @@ class DispatchKeySet:
def __sub__(self, other: DispatchKeySet) -> DispatchKeySet: ...
def __and__(self, other: DispatchKeySet) -> DispatchKeySet: ...
def raw_repr(self) -> _int: ...
@staticmethod
def from_raw_repr(raw: _int) -> DispatchKeySet: ...
def highestPriorityTypeId(self) -> DispatchKey: ...
def has(self, k: _dispatchkey) -> _bool: ...
def add(self, k: _dispatchkey) -> DispatchKeySet: ...

View File

@ -27,8 +27,10 @@ import enum
import functools
import importlib
import inspect
import io
import logging
import math
import pickle
import sys
import textwrap
import types
@ -57,9 +59,9 @@ from torch._C._dynamo.guards import (
RootGuardManager,
)
from torch._dynamo.source import (
get_global_source_name,
IndexedSource,
is_from_flatten_script_object_source,
is_from_global_source,
is_from_local_source,
is_from_optimizer_source,
TensorProperty,
@ -166,6 +168,8 @@ except ModuleNotFoundError:
if TYPE_CHECKING:
from sympy import Symbol
from torch._dynamo.output_graph import OutputGraphGuardsState
log = logging.getLogger(__name__)
guards_log = torch._logging.getArtifactLogger(__name__, "guards")
@ -494,10 +498,9 @@ def convert_to_concrete_values(size_or_stride):
return converted
def get_tensor_guard_code_part(value, name, sizes, strides):
pytype = type(value)
def get_tensor_guard_code_part(value, name, sizes, strides, pytype, dispatch_keys):
dispatch_key = (
torch._C._dispatch_keys(value) | torch._C._dispatch_tls_local_include_set()
dispatch_keys | torch._C._dispatch_tls_local_include_set()
) - torch._C._dispatch_tls_local_exclude_set()
dtype = value.dtype
device_index = value.device.index
@ -2142,6 +2145,15 @@ class GuardBuilder(GuardBuilderBase):
value = value()
value = value if value is not None else self.get(guard.name)
pytype = type(value)
dispatch_keys = torch._C._dispatch_keys(value)
if isinstance(value, torch._subclasses.FakeTensor):
if value.pytype is not None:
pytype = value.pytype
if value.dispatch_keys is not None:
dispatch_keys = value.dispatch_keys
assert isinstance(value, torch.Tensor)
if config.log_compilation_metrics and isinstance(value, torch.nn.Parameter):
@ -2218,7 +2230,9 @@ class GuardBuilder(GuardBuilderBase):
stride = convert_to_concrete_values(metadata["stride"])
verbose_code_parts = get_verbose_code_parts(
get_tensor_guard_code_part(value, tensor_name, size, stride),
get_tensor_guard_code_part(
value, tensor_name, size, stride, pytype, dispatch_keys
),
guard,
)
guard_manager.add_tensor_match_guard(
@ -2227,6 +2241,8 @@ class GuardBuilder(GuardBuilderBase):
stride,
tensor_name,
verbose_code_parts,
pytype,
dispatch_keys,
)
# We consider TENSOR_MATCH guard to be important enough to be
@ -2459,6 +2475,66 @@ class DeletedGuardManagerWrapper(GuardManagerWrapper):
self.diff_guard_root = None
@dataclasses.dataclass
class GuardsState:
output_graph: OutputGraphGuardsState
# TODO SHAPE_ENV states here
class GuardsStatePickler(pickle.Pickler):
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
self.fake_mode = torch._subclasses.FakeTensorMode()
self.tensor_converter = torch._subclasses.fake_tensor.FakeTensorConverter()
@classmethod
def _unpickle_module(cls, state):
mod = torch.nn.Module()
mod.__setstate__(state)
return mod
@classmethod
def _unpickle_tensor(cls, meta_tensor, device, pytype, dispatch_keys_raw):
fake_mode = torch._subclasses.FakeTensorMode()
tensor_converter = torch._subclasses.fake_tensor.FakeTensorConverter()
return tensor_converter.from_meta_and_device(
fake_mode,
meta_tensor,
device,
pytype,
torch._C.DispatchKeySet.from_raw_repr(dispatch_keys_raw),
)
def reducer_override(self, obj):
if isinstance(obj, torch.Tensor) and obj.device.type != "meta":
return type(self)._unpickle_tensor, (
torch.empty_like(obj, device="meta"),
obj.device,
type(obj),
torch._C._dispatch_keys(obj).raw_repr(),
)
elif isinstance(obj, torch.nn.Module):
if obj.__class__.__getstate__ == torch.nn.Module.__getstate__:
return type(self)._unpickle_module, (obj.__getstate__(),)
if type(obj).__qualname__ != type(obj).__name__:
raise RuntimeError(
f"Type {type(obj)} for object {obj} cannot be saved "
+ "into torch.compile() package since it's defined in local scope. "
+ "Please define the class at global scope (top level of a module)."
)
return NotImplemented
def pickle_guards_state(state: GuardsState) -> bytes:
buf = io.BytesIO()
pickler = GuardsStatePickler(buf)
pickler.dump(state)
return buf.getvalue()
# NB: Naively, you'd expect this to only be a function that produces
# the callable that constitutes the guard. However, there is some
# delicate handling for invalidating this check function when the
@ -2474,6 +2550,7 @@ class CheckFunctionManager:
guard_filter_fn: Optional[
Callable[[list[GuardFilterEntry]], list[bool]]
] = None,
guards_serialization_mode: Optional[str] = None,
):
guards = output_graph.guards if output_graph else None
self._weakrefs: dict[int, ReferenceType[object]] = {}
@ -2488,6 +2565,7 @@ class CheckFunctionManager:
self.torch_function_mode_stack = (
output_graph.torch_function_mode_stack if output_graph else None
)
self.guards_serialization_mode = guards_serialization_mode
if not justknobs_check("pytorch/compiler:guard_nn_modules"):
log.warning("guard_nn_modules is turned off using justknobs killswitch")
@ -2508,7 +2586,7 @@ class CheckFunctionManager:
else:
has_value = True
value = builder.get(guard.name)
is_global = is_from_global_source(guard.originating_source)
is_global = get_global_source_name(guard.originating_source) is not None
guard_fn = guard.create_fn
if isinstance(guard_fn, functools.partial):
guard_fn = guard.create_fn.func
@ -2558,7 +2636,7 @@ class CheckFunctionManager:
# TODO(anijain2305, ydwu4) - Skipping export because of following test
# python -s test/dynamo/test_export.py -k test_export_with_symbool_inputs
latency = 0.0
if not output_graph.export:
if not output_graph.export and self.guards_serialization_mode != "load":
if not self.guard_manager.check(output_graph.local_scope):
reasons = get_guard_fail_reason_helper(
self.guard_manager, # type: ignore[arg-type]
@ -2591,6 +2669,40 @@ class CheckFunctionManager:
# account for, we simply increment at the toplevel instead.
CompileEventLogger.increment_toplevel("guard_latency_us", int(latency))
self.guards_state: Optional[bytes] = None
if self.guards_serialization_mode == "save":
output_graph_guards_state = self.output_graph.dump_guards_state()
# Only serialize the global variables that are actually used in guards.
used_global_vars = set()
for guard in sorted_guards:
if name := get_global_source_name(guard.originating_source):
assert isinstance(name, str)
used_global_vars.add(name)
output_graph_guards_state = dataclasses.replace(
output_graph_guards_state,
global_scope={
k: v
for k, v in output_graph_guards_state.global_scope.items()
if k in used_global_vars
},
_guards=torch._guards.GuardsSet(
{
dataclasses.replace(
guard,
obj_weakref=None,
guarded_class_weakref=None,
create_fn=guard.inner_create_fn(),
)
for guard in sorted_guards
}
),
)
guards_state = GuardsState(
output_graph=output_graph_guards_state,
)
self.guards_state = pickle_guards_state(guards_state)
# TODO: don't do the string rep, do something more structured here
torch._logging.trace_structured(
"dynamo_cpp_guards_str",
@ -2757,9 +2869,7 @@ class CheckFunctionManager:
)
aotautograd_guards: list[GuardEnvExpr] = (
self.output_graph.tracing_context.guards_context.aotautograd_guards
if self.output_graph
else []
self.output_graph.aotautograd_guards if self.output_graph else []
)
# TODO(anijain2305) - There is a duplicate logic in Dynamo to find

View File

@ -283,7 +283,43 @@ class WrapperBackend:
Scope = dict[str, object]
class OutputGraph:
@dataclass
class OutputGraphGuardsState:
"""
A base class containing fields that are considered "persistent" when we
want to save all the important state for reconstructing guards in a different
process. Normally we don't need to add states here, but we may have to when
the information is needed to serialize the guards, so the fields here are
supposed to be serializable as a requirement.
"""
local_scope: Scope
global_scope: Scope
# This records the initial torch function mode stack for guarding
torch_function_mode_stack: list[torch.overrides.TorchFunctionMode]
guard_on_key_order: set[str]
# Map from graph input's `Source` to sizes / strides metadata
input_source_to_sizes_strides: dict[Source, dict[str, Any]]
export: bool = False
export_constraints: bool = False
_guards: Optional[torch._guards.GuardsSet] = None
_aotautograd_guards: Optional[list[torch._guards.GuardEnvExpr]] = None
@property
def shape_env(self):
raise AssertionError(f"shape_env shouldn't be accessed from {type(self)}")
@property
def guards(self):
return self._guards
@property
def aotautograd_guards(self):
return self._aotautograd_guards
class OutputGraph(OutputGraphGuardsState):
"""
Wrapper class to hold outputs of InstructionTranslator. Mainly the
generated fx.Graph.
@ -309,6 +345,13 @@ class OutputGraph:
f_code,
torch_function_mode_stack,
):
super().__init__(
local_scope,
global_scope,
torch_function_mode_stack,
guard_on_key_order=set(),
input_source_to_sizes_strides={},
)
self.tracers = [SubgraphTracer(self, is_export=export)]
# Map from graph input's `Source` to its `VariableTracker` to
# de-duplicate graph inputs by source and reuse the tracker
@ -316,8 +359,6 @@ class OutputGraph:
self.export = export
self.export_constraints = export_constraints
self.frame_state = frame_state
# Map from graph input's `Source` to sizes / strides metadata
self.input_source_to_sizes_strides: dict[Source, dict[str, Any]] = {}
self.cleanup_hooks: list[Callable[[], Any]] = []
# compile_id is an id number for the current torch.compile
self.compile_id: int = next(_compile_id_counter)
@ -400,8 +441,6 @@ class OutputGraph:
# Not checkpointed
self.compiler_fn: Optional[CompilerFn] = compiler_fn
self.global_scope: Scope = global_scope
self.local_scope = local_scope
self.root_tx = root_tx
# Given a source, what are the user stacks of all locations that
@ -423,8 +462,6 @@ class OutputGraph:
# This returns false if TF Overall (both mode and subclass) is disabled OR that TF Mode stack is empty
self.torch_function_mode_enabled = torch._C._is_torch_function_mode_enabled()
# This records the initial torch function mode stack for guarding
self.torch_function_mode_stack = torch_function_mode_stack
# Tracks if the output graph has a user defined allowed function in the
# graph. This is used later to determine if we should fallback to eager
@ -473,8 +510,6 @@ class OutputGraph:
self.install_builtins_dict_in_fglobals()
)
self.guard_on_key_order: set[str] = set()
def install_builtins_dict_in_fglobals(self):
# f_globals["__builtins__"] can be a dict or a module. This is an
# implemenation detail -
@ -544,6 +579,19 @@ class OutputGraph:
GlobalStateSource().make_guard(GuardBuilder.FUNCTORCH_STACK_MATCH)
)
def dump_guards_state(self):
return OutputGraphGuardsState(
local_scope=self.local_scope,
global_scope=self.global_scope,
torch_function_mode_stack=self.torch_function_mode_stack,
guard_on_key_order=self.guard_on_key_order,
input_source_to_sizes_strides=self.input_source_to_sizes_strides,
export=self.export,
export_constraints=self.export_constraints,
_guards=self.guards,
_aotautograd_guards=self.aotautograd_guards,
)
def synthetic_graph_input(self, fn, args):
"""
call fn(*args) before the graph runs and turn the result into a fake input.
@ -670,6 +718,10 @@ class OutputGraph:
def nn_modules(self) -> dict[str, Any]:
return self.tracing_context.module_context.nn_modules
@property
def aotautograd_guards(self):
return self.tracing_context.guards_context.aotautograd_guards
def save_global_state(self, out=None):
"""
Saves to out if it is provided. Else saves to the tracing context's global_state.

View File

@ -869,10 +869,16 @@ def is_from_local_source(source: Source, *, only_allow_input=False):
return True
def is_from_global_source(source: Source):
def is_from_global_source(source: Source) -> bool:
return get_global_source_name(source) is not None
def get_global_source_name(source: Source) -> Optional[str]:
if isinstance(source, ChainedSource):
return is_from_global_source(source.base)
return isinstance(source, GlobalSource)
return get_global_source_name(source.base)
if not isinstance(source, GlobalSource):
return None
return source.global_name
def is_from_nonlocal_source(source: Source):

View File

@ -410,7 +410,7 @@ torch._dynamo.guards._parse_guard_env_guards to avoid a RuntimeError.
"""
@dataclasses.dataclass
@dataclasses.dataclass(frozen=True)
class GuardEnvExpr:
pass
@ -421,7 +421,7 @@ input_pos_a and input_pos_b are input positions we have deduped.
"""
@dataclasses.dataclass
@dataclasses.dataclass(frozen=True)
class DuplicateInputs(GuardEnvExpr):
input_source_a: Source
input_source_b: Source
@ -444,7 +444,7 @@ overlapping with any other input, overlapping_sources represent tensors that eit
"""
@dataclasses.dataclass
@dataclasses.dataclass(frozen=True)
class StorageOverlap(GuardEnvExpr):
overlapping_sources: list[Source]
non_overlapping_sources: list[Source]

View File

@ -489,7 +489,12 @@ class FakeTensorConverter:
# If you specify the device, it MUST be a meta tensor.
def from_meta_and_device(
self, fake_mode: FakeTensorMode, t: Tensor, device: torch.device
self,
fake_mode: FakeTensorMode,
t: Tensor,
device: torch.device,
pytype: Optional[type[torch.Tensor]] = None,
dispatch_keys: Optional[torch.DispatchKeySet] = None,
) -> FakeTensor:
assert (
t.device.type == "meta"
@ -499,7 +504,9 @@ class FakeTensorConverter:
maybe_memo = self._get_memo(t)
if maybe_memo is not None:
return maybe_memo
out = FakeTensor(fake_mode, t, device)
out = FakeTensor(
fake_mode, t, device, pytype=pytype, dispatch_keys=dispatch_keys
)
self.set_tensor_memo(t, out)
return out
@ -651,6 +658,12 @@ class FakeTensor(Tensor):
# nested int.
nested_int_memo = SymNumberMemoDescriptor(is_nested_int=True)
# FakeTensor doesn't fully emulate the original tensor's Python type
# and dispatch key set, therefore sometimes we want to track them
# separately.
pytype: Optional[type[Tensor]]
dispatch_keys: Optional[torch.DispatchKeySet]
# Indicates to our torch_dispatch dispatching infra that
# this is an "infra" mode with lower dispatching precedence.
_mode_key = torch._C._TorchDispatchModeKey.FAKE
@ -700,6 +713,8 @@ class FakeTensor(Tensor):
device: torch.device,
constant: Optional[Tensor] = None,
real_tensor: Optional[Tensor] = None,
pytype: Optional[type[Tensor]] = None,
dispatch_keys: Optional[torch.DispatchKeySet] = None,
) -> Self:
self = Tensor._make_subclass(
cls,
@ -742,6 +757,8 @@ class FakeTensor(Tensor):
self.fake_device = device
self.fake_mode = fake_mode
self.constant = constant
self.pytype = pytype
self.dispatch_keys = dispatch_keys
assert not isinstance(real_tensor, FakeTensor)
self.real_tensor = real_tensor
self.nonzero_memo = None

View File

@ -88,10 +88,11 @@ TensorCheck::TensorCheck(
const LocalState& state,
PyTypeObject* pt,
const at::Tensor& v,
c10::DispatchKeySet dispatch_key_set,
std::vector<std::optional<c10::SymInt>> dynamic_dims_sizes,
std::vector<std::optional<c10::SymInt>> dynamic_dims_strides)
: pytype(pt),
dispatch_key_(state.apply(v.key_set()).raw_repr()),
dispatch_key_(state.apply(dispatch_key_set).raw_repr()),
dtype_(v.dtype().toScalarType()),
device_index_(v.device().index()),
requires_grad_(v.requires_grad()),
@ -376,6 +377,7 @@ static int TensorGuards_init(
state,
Py_TYPE(item),
std::move(tensor),
tensor.key_set(),
std::move(tensor_dims_size),
std::move(tensor_dims_stride));
}
@ -3477,7 +3479,9 @@ class TENSOR_MATCH : public LeafGuard {
py::object dynamic_dims_sizes_py,
py::object dynamic_dims_strides_py,
py::object tensor_name,
py::object verbose_code_parts)
py::object verbose_code_parts,
py::object pytype,
py::object dispatch_keys)
: LeafGuard(root_guard_manager, std::move(verbose_code_parts)),
_tensor_name(py::cast<std::string>(std::move(tensor_name))) {
root_guard_manager->set_init_local_state_flag();
@ -3486,6 +3490,10 @@ class TENSOR_MATCH : public LeafGuard {
PyErr_SetString(PyExc_TypeError, "expected Tensor()");
return;
}
if (!PyType_Check(pytype.ptr())) {
PyErr_SetString(PyExc_TypeError, "expected type object");
return;
}
auto tensor = THPVariable_Unpack(item);
std::vector<std::optional<c10::SymInt>> tensor_dims_size =
@ -3502,8 +3510,9 @@ class TENSOR_MATCH : public LeafGuard {
LocalState state;
_tensor_check = std::make_unique<TensorCheck>(
state,
Py_TYPE(item),
(PyTypeObject*)pytype.ptr(),
std::move(tensor),
dispatch_keys.cast<c10::DispatchKeySet>(),
std::move(tensor_dims_size),
std::move(tensor_dims_stride));
}
@ -5523,7 +5532,9 @@ PyObject* torch_c_dynamo_guards_init() {
py::object,
py::object,
py::str,
py::list>())
py::list,
py::type,
py::object>())
.def("__call__", &TENSOR_MATCH::check);
// NOLINTNEXTLINE(bugprone-unused-raii)
py::class_<RelationalGuard, LeafGuard, std::shared_ptr<RelationalGuard>>(
@ -5869,7 +5880,9 @@ PyObject* torch_c_dynamo_guards_init() {
py::object sizes,
py::object strides,
py::object tensor_name,
py::object verbose_code_parts) -> void {
py::object verbose_code_parts,
py::object pytype,
py::object dispatch_keys) -> void {
SKIP_IF_GUARD_ALREADY_PRESENT("TENSOR_MATCH");
self.add_leaf_guard(std::make_shared<TENSOR_MATCH>(
self.get_root(),
@ -5877,7 +5890,9 @@ PyObject* torch_c_dynamo_guards_init() {
std::move(sizes),
std::move(strides),
std::move(tensor_name),
std::move(verbose_code_parts)));
std::move(verbose_code_parts),
std::move(pytype),
std::move(dispatch_keys)));
})
// return by reference because GuardManager has the ownership of accessors

View File

@ -43,6 +43,7 @@ class TensorCheck {
const LocalState& state,
PyTypeObject* pt,
const at::Tensor& v,
c10::DispatchKeySet dispatch_key_set,
std::vector<std::optional<c10::SymInt>> dynamic_dims_sizes,
std::vector<std::optional<c10::SymInt>> dynamic_dims_strides);

View File

@ -771,7 +771,8 @@ void initDispatchBindings(PyObject* module) {
return self.add(k);
})
.def("has", &c10::DispatchKeySet::has)
.def("__repr__", [](c10::DispatchKeySet d) { return c10::toString(d); });
.def("__repr__", [](c10::DispatchKeySet d) { return c10::toString(d); })
.def_static("from_raw_repr", &c10::DispatchKeySet::from_raw_repr);
m.attr("_dispatch_autogradother_backends") =
py::cast(c10::autogradother_backends);