# Owner(s): ["module: dynamo"]
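"""
Tests for serializing dynamo guards: each test builds guards for a small
function with CheckFunctionManager in "save" mode, round-trips the pickled
guards state through "load" mode, and checks that the reloaded guard manager
accepts and rejects the same inputs as the original.
"""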

import dataclasses
import importlib
import pickle
import sys
import types
from unittest.mock import patch

import torch
import torch._dynamo.testing
import torch._inductor.config
import torch._inductor.test_case
import torch.onnx.operators
import torch.utils.cpp_extension
from torch._dynamo.bytecode_transformation import transform_code_object
from torch._dynamo.guards import CheckFunctionManager, CompileId
from torch._dynamo.symbolic_convert import (
    ExceptionStack,
    InstructionTranslator,
    SpeculationLog,
)
from torch._dynamo.utils import dynamo_timed, get_metrics_context
from torch._guards import compile_context, CompileContext, tracing
from torch.utils import _pytree as pytree


@dataclasses.dataclass
class _FrameState:
    f_locals: dict
    f_globals: dict
    f_code: types.CodeType
    f_builtins: dict


class GlobalModule(torch.nn.Module):
    def forward(self, x):
        return x + 1
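

# The three classes below follow the standard traceable wrapper-subclass
# pattern: torch.Tensor._make_wrapper_subclass builds the outer tensor, and
# __tensor_flatten__ / __tensor_unflatten__ expose the inner tensors plus any
# extra metadata so dynamo can guard on them. A rough sketch of the shape being
# round-tripped (illustrative only, not executed by the tests):
#
#     s = SubclassWithMeta(torch.randn(3), extra={"foo": 5, "bar": "hello"})
#     tensors, meta = s.__tensor_flatten__()  # (["a"], {"extra": {...}})
#     s2 = SubclassWithMeta.__tensor_unflatten__(
#         {"a": s.a}, meta, s.size(), s.stride()
#     )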


class SubclassWithMeta(torch.Tensor):
    @staticmethod
    def __new__(cls, a, extra, outer_size=None, outer_stride=None):
        if outer_size is None:
            outer_size = a.size()
        if outer_stride is None:
            outer_stride = a.stride()

        shape = outer_size
        kwargs = {}
        kwargs["strides"] = outer_stride
        kwargs["storage_offset"] = a.storage_offset()
        kwargs["device"] = a.device
        kwargs["layout"] = a.layout
        kwargs["requires_grad"] = a.requires_grad
        kwargs["dtype"] = a.dtype
        return torch.Tensor._make_wrapper_subclass(cls, shape, **kwargs)

    def __init__(self, a, extra, outer_size=None, outer_stride=None):
        self.a = a
        self.extra = extra

    @classmethod
    def __torch_dispatch__(cls, func, types, args, kwargs):
        if kwargs is None:
            kwargs = {}
        args_a = pytree.tree_map_only(SubclassWithMeta, lambda x: x.a, args)
        kwargs_a = pytree.tree_map_only(SubclassWithMeta, lambda x: x.a, kwargs)
        out_a = func(*args_a, **kwargs_a)
        if isinstance(out_a, torch.Tensor):
            assert isinstance(args[0], SubclassWithMeta)
            return SubclassWithMeta(out_a, extra=args[0].extra)
        return out_a

    def __tensor_flatten__(self):
        # store extra in meta
        return ["a"], {"extra": self.extra}

    @staticmethod
    def __tensor_unflatten__(inner_tensors, meta, outer_size, outer_stride):
        assert isinstance(meta, dict)
        a = inner_tensors["a"]
        # pull out extra from meta
        extra = meta["extra"]
        if type(a) is torch.Tensor:
            assert outer_size is not None
            assert outer_stride is not None
        return SubclassWithMeta(a, extra, outer_size, outer_stride)


class SubclassWithCustomMetadataGuard(torch.Tensor):
    @staticmethod
    def __new__(cls, a, extra, outer_size=None, outer_stride=None):
        if outer_size is None:
            outer_size = a.size()
        if outer_stride is None:
            outer_stride = a.stride()

        shape = outer_size
        kwargs = {}
        kwargs["strides"] = outer_stride
        kwargs["storage_offset"] = a.storage_offset()
        kwargs["device"] = a.device
        kwargs["layout"] = a.layout
        kwargs["requires_grad"] = a.requires_grad
        kwargs["dtype"] = a.dtype
        return torch.Tensor._make_wrapper_subclass(cls, shape, **kwargs)

    def __init__(self, a, extra, outer_size=None, outer_stride=None):
        self.a = a
        self.extra = extra

    @classmethod
    def __torch_dispatch__(cls, func, types, args, kwargs):
        if kwargs is None:
            kwargs = {}
        args_a = pytree.tree_map_only(
            SubclassWithCustomMetadataGuard, lambda x: x.a, args
        )
        kwargs_a = pytree.tree_map_only(
            SubclassWithCustomMetadataGuard, lambda x: x.a, kwargs
        )
        out_a = func(*args_a, **kwargs_a)
        if isinstance(out_a, torch.Tensor):
            assert isinstance(args[0], SubclassWithCustomMetadataGuard)
            return SubclassWithCustomMetadataGuard(out_a, extra=args[0].extra)
        return out_a

    @classmethod
    def __metadata_guard__(cls, meta1, meta2):
        # Define custom metadata guard logic that only looks at "bar" to determine
        # metadata equivalence. This is purposefully more lax than the default
        # guard behavior.
        return meta1["extra"]["bar"] == meta2["extra"]["bar"]

    def __tensor_flatten__(self):
        # store extra in meta
        return ["a"], {"extra": self.extra}

    @staticmethod
    def __tensor_unflatten__(inner_tensors, meta, outer_size, outer_stride):
        assert isinstance(meta, dict)
        a = inner_tensors["a"]
        # pull out extra from meta
        extra = meta["extra"]
        if type(a) is torch.Tensor:
            assert outer_size is not None
            assert outer_stride is not None
        return SubclassWithCustomMetadataGuard(a, extra, outer_size, outer_stride)


class SubclassWithSubclassInnerTensor(torch.Tensor):
    @staticmethod
    def __new__(cls, a, extra, outer_size=None, outer_stride=None):
        if outer_size is None:
            outer_size = a.size()
        if outer_stride is None:
            outer_stride = a.stride()

        shape = outer_size
        kwargs = {}
        kwargs["strides"] = outer_stride
        kwargs["storage_offset"] = a.storage_offset()
        kwargs["device"] = a.device
        kwargs["layout"] = a.layout
        kwargs["requires_grad"] = a.requires_grad
        kwargs["dtype"] = a.dtype
        return torch.Tensor._make_wrapper_subclass(cls, shape, **kwargs)

    def __init__(self, a, extra, outer_size=None, outer_stride=None):
        self.a = a
        self.inner_sub = SubclassWithMeta(a + 1, extra=extra)

    @classmethod
    def __torch_dispatch__(cls, func, types, args, kwargs):
        if kwargs is None:
            kwargs = {}
        args_a = pytree.tree_map_only(
            SubclassWithSubclassInnerTensor, lambda x: x.a, args
        )
        kwargs_a = pytree.tree_map_only(
            SubclassWithSubclassInnerTensor, lambda x: x.a, kwargs
        )
        out_a = func(*args_a, **kwargs_a)
        if isinstance(out_a, torch.Tensor):
            assert isinstance(args[0], SubclassWithSubclassInnerTensor)
            return SubclassWithSubclassInnerTensor(out_a, extra=args[0].inner_sub.extra)
        return out_a

    def __tensor_flatten__(self):
        return ["a", "inner_sub"], None

    @staticmethod
    def __tensor_unflatten__(inner_tensors, meta, outer_size, outer_stride):
        assert meta is None
        a = inner_tensors["a"]
        extra = inner_tensors["inner_sub"].extra
        if type(a) is torch.Tensor:
            assert outer_size is not None
            assert outer_stride is not None
        return SubclassWithSubclassInnerTensor(a, extra, outer_size, outer_stride)


# defines a custom __eq__() / __hash__() to be registered as a pytree constant type
class CustomConstantType:
    def __init__(self, a, b):
        self.a = a
        self.b = b

    def __eq__(self, other):
        # custom eq ignores b
        return self.a == other.a

    def __hash__(self):
        # custom hash ignores b
        return hash(self.a)


pytree.register_constant(CustomConstantType)
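
# Note: registering the type as a pytree constant means whole instances are
# treated as leaves and compared with the custom __eq__/__hash__ above; this
# is what lets test_equals_match below exercise an EQUALS_MATCH guard.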


class TestGuardSerialization(torch._inductor.test_case.TestCase):
    def _tracefunc(self, frame, event, arg):
        if event != "call":
            return

        if self._frame_state is not None:
            return

        self._frame_state = _FrameState(
            f_locals=dict(frame.f_locals),
            f_globals=dict(frame.f_globals),
            f_code=frame.f_code,
            f_builtins=frame.f_builtins,
        )

    def _test_serialization(self, guard_type, fn, *args, **kwargs):
        self._frame_state = None
        sys.settrace(self._tracefunc)
        if isinstance(fn, torch.nn.Module):
            fn = fn.forward
        try:
            fn(*args, **kwargs)
        finally:
            sys.settrace(None)

        assert self._frame_state is not None

        def guard_filter_fn(guards):
            ret = [
                g.guard_type == guard_type or guard_type in g.derived_guard_types
                for g in guards
            ]
            self.assertTrue(any(ret))
            return ret

        ref_gm = None
        loaded_gm = None

        def transform(instructions: list, code_options: dict[str, object]):
            """
            The goal here is not to reimplement dynamo, but just to have a
            simplified version to extract the state from symbolic convert.
            This will not work in all cases, but should work on the simple
            functions in this test file.
            """
            nonlocal ref_gm
            nonlocal loaded_gm

            tracer = InstructionTranslator(
                instructions,
                self._frame_state.f_code,
                self._frame_state.f_locals,
                self._frame_state.f_globals,
                self._frame_state.f_builtins,
                fn.__closure__ or (),
                [],  # TODO tf_mode_stack,
                code_options,
                torch._dynamo.lookup_backend("eager"),
                one_graph=False,
                export=False,
                export_constraints=None,
                frame_state=None,
                speculation_log=SpeculationLog(),
                exn_vt_stack=ExceptionStack(),
                distributed_state=None,
            )
            with compile_context(CompileContext(CompileId(0, 0))), tracing(
                tracer.output.tracing_context
            ), tracer.set_current_tx(), get_metrics_context(), dynamo_timed(""):
                tracer.run()

            check_fn_manager = CheckFunctionManager(
                self._frame_state.f_code,
                tracer.output,
                guard_filter_fn=guard_filter_fn,
                guards_serialization_mode="save",
            )
            ref_gm = check_fn_manager.guard_manager
            guards_state = check_fn_manager.guards_state
            self.assertIsNotNone(guards_state)
            guards_state = pickle.loads(guards_state)

            check_fn_manager = CheckFunctionManager(
                self._frame_state.f_code,
                guards_state.output_graph,
                guards_serialization_mode="load",
            )
            loaded_gm = check_fn_manager.guard_manager

        try:
            transform_code_object(self._frame_state.f_code, transform)
        finally:
            self._frame_state = None

        self.assertIsNotNone(ref_gm)
        self.assertIsNotNone(loaded_gm)
        return ref_gm, loaded_gm

    def _test_check_fn(self, ref, loaded, inputs, expected):
        self.assertIsInstance(inputs, dict)
        self.assertEqual(ref.check(inputs), expected)
        self.assertEqual(ref.check(inputs), loaded.check(inputs))
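
    # Each test below follows the same pattern as a sketch:
    #
    #     ref, loaded = self._test_serialization("SOME_GUARD", fn, *example_args)
    #     self._test_check_fn(ref, loaded, {"arg_name": value}, expected)
    #
    # i.e. build guards once, round-trip them through pickle, then assert that
    # the reloaded guard manager accepts/rejects the same inputs as the original.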

    def test_tensor_match(self):
        def f(x: torch.Tensor):
            return x + 1

        ref, loaded = self._test_serialization(
            "TENSOR_MATCH", f, torch.ones(2, dtype=torch.float32)
        )
        self._test_check_fn(
            ref, loaded, {"x": torch.randn(2, dtype=torch.float32)}, True
        )
        self._test_check_fn(
            ref, loaded, {"x": torch.randn(3, dtype=torch.float32)}, False
        )
        self._test_check_fn(
            ref, loaded, {"x": torch.randn(2, dtype=torch.float64)}, False
        )
        self._test_check_fn(ref, loaded, {"x": None}, False)

    def test_not_present_in_generic_dict(self):
        class Module(torch.nn.Module):
            def forward(self, x: torch.Tensor):
                return x + 1

        m = Module()

        def fn(x):
            return m(x)

        ref, loaded = self._test_serialization(
            "NOT_PRESENT_IN_GENERIC_DICT", fn, torch.ones(2, dtype=torch.float32)
        )
        self._test_check_fn(ref, loaded, {"m": m}, True)

        m.forward = types.MethodType(lambda x: x + 2, m)
        self._test_check_fn(ref, loaded, {"m": m}, False)

    def test_hasattr_serialization(self):
        class Module(torch.nn.Module):
            def __init__(self):
                super().__init__()
                self.a = 1

            def forward(self, x: torch.Tensor):
                if hasattr(self, "a"):
                    return x + self.a
                else:
                    return x + 2

        m = Module()

        def fn(x):
            return m(x)

        ref, loaded = self._test_serialization("HASATTR", fn, torch.randn(3))
        self._test_check_fn(ref, loaded, {"m": m}, True)
        delattr(m, "a")
        self._test_check_fn(ref, loaded, {"m": m}, False)

    def test_type_match(self):
        class LocalModule(torch.nn.Module):
            def forward(self, x: torch.Tensor):
                return x + 1

        m = LocalModule()

        def fn(m, x):
            return m(x)

        with self.assertRaisesRegex(
            TypeError, "Please define the class at global scope"
        ):
            self._test_serialization("TYPE_MATCH", fn, m, torch.randn(3))

        m = GlobalModule()
        ref, loaded = self._test_serialization("TYPE_MATCH", fn, m, torch.randn(3))
        self._test_check_fn(ref, loaded, {"m": m}, True)
        self._test_check_fn(ref, loaded, {"m": GlobalModule()}, True)
        self._test_check_fn(ref, loaded, {"m": torch.nn.Module()}, False)

    def test_tensor_subclass_metadata_match(self):
        class LocalSubclass(torch.Tensor):
            @staticmethod
            def __new__(cls, a, outer_size=None, outer_stride=None):
                if outer_size is None:
                    outer_size = a.size()
                if outer_stride is None:
                    outer_stride = a.stride()

                shape = outer_size
                kwargs = {}
                kwargs["strides"] = outer_stride
                kwargs["storage_offset"] = a.storage_offset()
                kwargs["device"] = a.device
                kwargs["layout"] = a.layout
                kwargs["requires_grad"] = a.requires_grad
                kwargs["dtype"] = a.dtype
                return torch.Tensor._make_wrapper_subclass(cls, shape, **kwargs)

            def __init__(self, a, outer_size=None, outer_stride=None):
                self.a = a

            @classmethod
            def __torch_dispatch__(cls, func, types, args, kwargs):
                if kwargs is None:
                    kwargs = {}
                args_a = pytree.tree_map_only(LocalSubclass, lambda x: x.a, args)
                kwargs_a = pytree.tree_map_only(LocalSubclass, lambda x: x.a, kwargs)
                out_a = func(*args_a, **kwargs_a)
                if isinstance(out_a, torch.Tensor):
                    return LocalSubclass(out_a)
                return out_a

            def __tensor_flatten__(self):
                return ["a"], None

            @staticmethod
            def __tensor_unflatten__(inner_tensors, meta, outer_size, outer_stride):
                assert meta is None
                a = inner_tensors["a"]
                if type(a) is torch.Tensor:
                    assert outer_size is not None
                    assert outer_stride is not None
                return LocalSubclass(a, outer_size, outer_stride)

        def fn(x):
            return x * 2

        # === example subclass defined locally (error) ===
        local_sub = LocalSubclass(torch.randn(3))
        with self.assertRaisesRegex(
            RuntimeError, "Please define the class at global scope"
        ):
            self._test_serialization("TENSOR_SUBCLASS_METADATA_MATCH", fn, local_sub)

        # === example subclass with None extra metadata ===
        from torch.testing._internal.two_tensor import TwoTensor

        tt = TwoTensor(torch.randn(3), torch.randn(3))
        ref, loaded = self._test_serialization("TENSOR_SUBCLASS_METADATA_MATCH", fn, tt)
        self._test_check_fn(ref, loaded, {"x": tt}, True)
        self._test_check_fn(ref, loaded, {"x": torch.ones_like(tt)}, True)

        # used below for convenience; returned func accepts some metadata and whether the
        # guard is expected to pass for the given subclass type
        def _get_meta_test_check_fn(ref, loaded, subclass_type):
            def _f(meta, expected, ref=ref, loaded=loaded, subclass_type=subclass_type):
                self._test_check_fn(
                    ref,
                    loaded,
                    {"x": subclass_type(torch.randn(3), extra=meta)},
                    expected,
                )

            return _f

        # === example subclass with extra metadata ===
        extra_meta = {
            "foo": 5,
            "bar": "hello",
        }
        sub = SubclassWithMeta(torch.randn(3), extra=extra_meta)
        ref, loaded = self._test_serialization(
            "TENSOR_SUBCLASS_METADATA_MATCH", fn, sub
        )
        self._test_check_fn(ref, loaded, {"x": sub}, True)
        check_with_meta = _get_meta_test_check_fn(ref, loaded, SubclassWithMeta)
        check_with_meta(dict(extra_meta), True)
        # different "foo"
        check_with_meta({"foo": 6, "bar": "hello"}, False)
        # different "bar"
        check_with_meta({"foo": 5, "bar": "world"}, False)

        # === example subclass with custom metadata guard logic ===
        sub = SubclassWithCustomMetadataGuard(torch.randn(3), extra=extra_meta)
        ref, loaded = self._test_serialization(
            "TENSOR_SUBCLASS_METADATA_MATCH", fn, sub
        )
        self._test_check_fn(ref, loaded, {"x": sub}, True)
        check_with_meta = _get_meta_test_check_fn(
            ref, loaded, SubclassWithCustomMetadataGuard
        )
        check_with_meta(dict(extra_meta), True)
        # different "foo"; custom logic says this is okay
        check_with_meta({"foo": 6, "bar": "hello"}, True)
        # different "bar"
        check_with_meta({"foo": 5, "bar": "world"}, False)

        # === example subclass with subclass inner tensor ===
        sub = SubclassWithSubclassInnerTensor(torch.randn(3), extra=extra_meta)
        ref, loaded = self._test_serialization(
            "TENSOR_SUBCLASS_METADATA_MATCH", fn, sub
        )
        self._test_check_fn(ref, loaded, {"x": sub}, True)
        check_with_meta = _get_meta_test_check_fn(
            ref, loaded, SubclassWithSubclassInnerTensor
        )
        check_with_meta(dict(extra_meta), True)
        # different "foo"
        check_with_meta({"foo": 6, "bar": "hello"}, False)
        # different "bar"
        check_with_meta({"foo": 5, "bar": "world"}, False)

    def test_equals_match(self):
        def fn(x, y):
            # CustomConstantType is registered as a pytree constant so this should
            # result in an EQUALS_MATCH guard.
            if x in y:
                return torch.zeros(3)
            return torch.ones(3)

        x = CustomConstantType(4, 5)
        y = [CustomConstantType(2, 3), CustomConstantType(4, 5)]
        ref, loaded = self._test_serialization("EQUALS_MATCH", fn, x, y)
        self._test_check_fn(ref, loaded, {"x": x, "y": y}, True)
        # custom __eq__ says that CustomConstantType(4, 5) == CustomConstantType(4, 9)
        self._test_check_fn(
            ref,
            loaded,
            {
                "x": CustomConstantType(4, 5),
                "y": [CustomConstantType(2, 3), CustomConstantType(4, 9)],
            },
            True,
        )
        self._test_check_fn(ref, loaded, {"x": x, "y": []}, False)
        self._test_check_fn(
            ref,
            loaded,
            {
                "x": x,
                "y": [CustomConstantType(2, 3), CustomConstantType(6, 7)],
            },
            False,
        )

    def test_constant_match(self):
        # === bool constant ===
        def fn(x, y):
            if y:
                return x + 1
            return x + 2

        x = torch.randn(3)
        y = True

        ref, loaded = self._test_serialization("CONSTANT_MATCH", fn, x, y)
        self._test_check_fn(ref, loaded, {"x": x, "y": y}, True)
        self._test_check_fn(ref, loaded, {"x": torch.randn(3), "y": True}, True)
        self._test_check_fn(ref, loaded, {"x": torch.randn(4), "y": True}, True)
        # guard should fail for different y value
        self._test_check_fn(ref, loaded, {"x": torch.randn(3), "y": False}, False)

        # === None constant ===
        def fn(x, y):
            if y is None:
                return x + 1
            return x + 2

        x = torch.randn(3)
        y = None

        ref, loaded = self._test_serialization("CONSTANT_MATCH", fn, x, y)
        self._test_check_fn(ref, loaded, {"x": x, "y": y}, True)
        self._test_check_fn(ref, loaded, {"x": torch.randn(3), "y": None}, True)
        self._test_check_fn(ref, loaded, {"x": torch.randn(4), "y": None}, True)
        # guard should fail for non-None y value
        self._test_check_fn(ref, loaded, {"x": torch.randn(3), "y": 5}, False)
        self._test_check_fn(ref, loaded, {"x": torch.randn(3), "y": True}, False)

        # === int constant ===
        def fn(x, y):
            return x + y

        x = torch.randn(3)
        y = 5

        ref, loaded = self._test_serialization("CONSTANT_MATCH", fn, x, y)
        self._test_check_fn(ref, loaded, {"x": x, "y": y}, True)
        self._test_check_fn(ref, loaded, {"x": torch.randn(3), "y": 5}, True)
        self._test_check_fn(ref, loaded, {"x": torch.randn(4), "y": 5}, True)
        # guard should fail for different y value
        self._test_check_fn(ref, loaded, {"x": torch.randn(3), "y": 6}, False)
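
    # Several tests below cover guard types that intentionally refuse to
    # serialize (NN_MODULE, FUNCTION_MATCH, DICT_VERSION, ID_MATCH,
    # DUPLICATE_INPUT, WEAKREF_ALIVE). Most of these reduce to identity checks
    # (ID_MATCH) or other process-local object state, which cannot survive a
    # save/load round trip, so saving raises a RuntimeError instead.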

    def test_nn_module(self):
        def fn(m, x):
            return m(x)

        m = GlobalModule()
        x = torch.randn(3)

        # config setting controls whether the NN_MODULE guard is installed
        with patch("torch._dynamo.config.inline_inbuilt_nn_modules", False):
            # we don't support NN_MODULE because it adds an ID_MATCH guard, and we don't
            # support that in serialization
            with self.assertRaisesRegex(
                RuntimeError, "NN_MODULE guard cannot be serialized."
            ):
                self._test_serialization("NN_MODULE", fn, m, x)

    def test_function_match(self):
        def fn(x):
            # usage of this context manager installs a FUNCTION_MATCH guard
            with torch.no_grad():
                y = x * 2
            return y

        x = torch.randn(3)

        # we don't support FUNCTION_MATCH because it adds an ID_MATCH guard, and we don't
        # support that in serialization
        with self.assertRaisesRegex(
            RuntimeError, "FUNCTION_MATCH guard cannot be serialized."
        ):
            self._test_serialization("FUNCTION_MATCH", fn, x)

    def test_dict_version(self):
        def fn(x):
            return pytree.tree_leaves(x)[0] + 1

        with self.assertRaisesRegex(
            RuntimeError, "DICT_VERSION guard cannot be serialized."
        ):
            self._test_serialization("DICT_VERSION", fn, {"t": torch.randn(3)})

    def test_dict_contains(self):
        def fn(x):
            if x.__contains__("t"):
                return x["t"] + 1
            else:
                return torch.ones(3)

        ref, loaded = self._test_serialization(
            "DICT_CONTAINS", fn, {"t": torch.randn(3)}
        )

        self._test_check_fn(ref, loaded, {"x": {"t": torch.randn(3)}}, True)
        self._test_check_fn(ref, loaded, {"x": {}}, False)
        self._test_check_fn(
            ref, loaded, {"x": {"t": torch.randn(3), "d": torch.randn(3)}}, True
        )

    def test_bool_match(self):
        def fn(x, b):
            if b:
                return x + 1
            else:
                return x + 2

        ref, loaded = self._test_serialization("BOOL_MATCH", fn, torch.randn(3), True)

        self._test_check_fn(ref, loaded, {"x": torch.randn(3), "b": True}, True)
        self._test_check_fn(ref, loaded, {"x": torch.randn(3), "b": False}, False)
        self._test_check_fn(ref, loaded, {"x": torch.randn(3), "b": None}, False)

    def test_none_match(self):
        def fn(x, b):
            if b is None:
                return x + 1
            else:
                return x + 2

        ref, loaded = self._test_serialization("NONE_MATCH", fn, torch.randn(3), None)

        self._test_check_fn(ref, loaded, {"x": torch.randn(3), "b": None}, True)
        self._test_check_fn(ref, loaded, {"x": torch.randn(3), "b": False}, False)
        self._test_check_fn(ref, loaded, {"x": torch.randn(3), "b": True}, False)

    def test_id_match(self):
        def fn(x):
            return x + id(x)

        with self.assertRaisesRegex(
            RuntimeError, "ID_MATCH guard cannot be serialized."
        ):
            self._test_serialization("ID_MATCH", fn, torch.randn(3))

    def test_dispatch_key_set_match(self):
        def fn(x, dks):
            if dks.has("CPU"):
                return torch.sin(x + 1)
            else:
                return torch.sin(x - 1)

        x = torch.randn(3)
        dks = torch._C._dispatch_keys(x)
        ref, loaded = self._test_serialization("DISPATCH_KEY_SET_MATCH", fn, x, dks)

        self._test_check_fn(ref, loaded, {"x": x, "dks": dks}, True)

        x = torch.randn(3, device="meta")
        dks = torch._C._dispatch_keys(x)
        self._test_check_fn(ref, loaded, {"x": x, "dks": dks}, False)

    def test_name_match(self):
        def fn(x, y):
            return torch.cond(x, lambda x: y + 1, lambda x: y - 1, (y,))

        x = torch.tensor(True)
        y = torch.randn(3)
        ref, loaded = self._test_serialization("NAME_MATCH", fn, x, y)

        self._test_check_fn(ref, loaded, {"x": x, "y": y}, True)

        op = importlib.import_module("torch._higher_order_ops.cond").cond_op
        prev, op.__name__ = op.__name__, ""
        try:
            self._test_check_fn(ref, loaded, {"x": x, "y": y}, False)
        finally:
            op.__name__ = prev

    def test_dual_level(self):
        def fn(x):
            with torch.autograd.forward_ad.dual_level():
                return x + 1

        x = torch.randn(3)
        ref, loaded = self._test_serialization("DUAL_LEVEL", fn, x)

        self._test_check_fn(ref, loaded, {"x": x}, True)
        with torch.autograd.forward_ad.dual_level():
            self._test_check_fn(ref, loaded, {"x": x}, False)

    def test_functorch_stack_match(self):
        # Test when functorch stack is empty.
        def fn(x):
            return torch.func.jvp(torch.sin, (x,), (x,))

        x = torch.randn(3, 4)
        ref, loaded = self._test_serialization("FUNCTORCH_STACK_MATCH", fn, x)

        self._test_check_fn(ref, loaded, {"x": x}, True)
        with torch._functorch.vmap.vmap_increment_nesting(2, "error"):
            self._test_check_fn(ref, loaded, {"x": x}, False)

        def fn(x):
            def g(x):
                return torch.vmap(torch.func.grad(torch.sin))(x)

            return torch.vmap(g)(x)

        x = torch.randn(4, 5)
        ref, loaded = self._test_serialization("FUNCTORCH_STACK_MATCH", fn, x)
        self._test_check_fn(ref, loaded, {"x": x}, True)
        with torch._functorch.eager_transforms.grad_increment_nesting():
            self._test_check_fn(ref, loaded, {"x": x}, False)

        # Test when there are more than 0 functorch layers.
        # Simulate the case where torch.compile is nested inside eager transforms.

        # Case 1: vmap
        def fn(x):
            return x.sum()

        ref = loaded = None

        def run(x):
            nonlocal ref, loaded
            # Turn off automatic dynamic shapes so that functionalization
            # doesn't produce extra SymInts to serialize.
            with torch._dynamo.config.patch(automatic_dynamic_shapes=False):
                ref, loaded = self._test_serialization("FUNCTORCH_STACK_MATCH", fn, x)
            return fn(x)

        torch.vmap(run)(x)

        self._test_check_fn(ref, loaded, {"x": x}, False)
        with torch._functorch.vmap.vmap_increment_nesting(1, "error"):
            self._test_check_fn(ref, loaded, {"x": x}, True)
            with torch._functorch.vmap.vmap_increment_nesting(1, "error"):
                self._test_check_fn(ref, loaded, {"x": x}, False)

        with torch._functorch.eager_transforms.grad_increment_nesting():
            self._test_check_fn(ref, loaded, {"x": x}, False)

        # Case 2: grad
        x = torch.randn(3, 2)
        ref = loaded = None
        torch.func.grad(run)(x)
        self._test_check_fn(ref, loaded, {"x": x}, False)
        with torch._functorch.eager_transforms.grad_increment_nesting():
            self._test_check_fn(ref, loaded, {"x": x}, True)
            with torch._functorch.eager_transforms.grad_increment_nesting():
                self._test_check_fn(ref, loaded, {"x": x}, False)

        with torch._functorch.vmap.vmap_increment_nesting(1, "error"):
            self._test_check_fn(ref, loaded, {"x": x}, False)

        # Case 3: jvp + vmap
        x = torch.randn(3, 4)
        ref = loaded = None

        def fn(x):
            return torch.func.jvp(torch.sin, (x,), (x,))

        torch.func.jvp(torch.vmap(run), (x,), (x,))
        self._test_check_fn(ref, loaded, {"x": x}, False)

        with torch._functorch.eager_transforms.jvp_increment_nesting():
            with torch._functorch.vmap.vmap_increment_nesting(1, "error"):
                self._test_check_fn(ref, loaded, {"x": x}, True)

        with torch._functorch.vmap.vmap_increment_nesting(1, "error"):
            with torch._functorch.eager_transforms.jvp_increment_nesting():
                self._test_check_fn(ref, loaded, {"x": x}, False)

        # Case 4: functionalize
        x = torch.randn(3, 2)
        ref = loaded = None
        torch.func.functionalize(run)(x)
        self._test_check_fn(ref, loaded, {"x": x}, False)

        torch._C._functorch._func_increment_nesting(True)
        try:
            self._test_check_fn(ref, loaded, {"x": x}, True)
        finally:
            torch._C._functorch._func_decrement_nesting()

        with torch._functorch.eager_transforms.jvp_increment_nesting():
            self._test_check_fn(ref, loaded, {"x": x}, False)

        # Case 5: vmap + grad
        def fn(x):
            return x.sum()

        x = torch.randn(3, 2)
        ref = loaded = None
        torch.vmap(torch.func.grad(run))(x)
        self._test_check_fn(ref, loaded, {"x": x}, False)
        with torch._functorch.vmap.vmap_increment_nesting(1, "error"):
            with torch._functorch.eager_transforms.grad_increment_nesting():
                self._test_check_fn(ref, loaded, {"x": x}, True)

        with torch._functorch.eager_transforms.grad_increment_nesting():
            with torch._functorch.vmap.vmap_increment_nesting(1, "error"):
                self._test_check_fn(ref, loaded, {"x": x}, False)

        with torch._functorch.vmap.vmap_increment_nesting(1, "error"):
            self._test_check_fn(ref, loaded, {"x": x}, False)

        with torch._functorch.eager_transforms.grad_increment_nesting():
            self._test_check_fn(ref, loaded, {"x": x}, False)

    def test_duplicate_input(self):
        def fn(x, x_):
            return x + x_

        x = torch.randn(3, 2)
        with self.assertRaisesRegex(
            RuntimeError, "DUPLICATE_INPUT guard cannot be serialized"
        ):
            self._test_serialization("DUPLICATE_INPUT", fn, x, x)

    def test_weakref_alive(self):
        mod = torch.nn.Linear(10, 10, bias=False)
        for p in mod.parameters():
            p.grad = torch.rand_like(p)

        opt = torch.optim.SGD(mod.parameters(), lr=0.1)

        def fn():
            params = []
            opt._init_group(opt.param_groups[0], params, [], [])
            return params[0].sum()

        with self.assertRaisesRegex(
            RuntimeError, "WEAKREF_ALIVE guard cannot be serialized"
        ):
            with torch.set_grad_enabled(False):
                self._test_serialization("WEAKREF_ALIVE", fn)

    def test_mapping_keys_check(self):
        def fn(mp):
            return mp["a"] + 1

        mp = types.MappingProxyType({"a": torch.randn(3, 2), "b": torch.randn(3, 2)})
        ref, loaded = self._test_serialization("MAPPING_KEYS_CHECK", fn, mp)
        self._test_check_fn(ref, loaded, {"mp": mp}, True)
        self._test_check_fn(
            ref,
            loaded,
            {
                "mp": types.MappingProxyType(
                    {"b": torch.randn(3, 2), "a": torch.randn(3, 2)}
                )
            },
            False,
        )
        self._test_check_fn(
            ref, loaded, {"mp": types.MappingProxyType({"a": torch.randn(3, 2)})}, False
        )

    def test_dict_keys_match(self):
        def fn(x):
            ret = 1
            for k in x:
                ret += x[k]
            return ret

        x = {"a": torch.randn(3, 2), "b": torch.randn(3, 2)}
        ref, loaded = self._test_serialization("DICT_KEYS_MATCH", fn, x)
        self._test_check_fn(ref, loaded, {"x": x}, True)
        self._test_check_fn(
            ref,
            loaded,
            {"x": {"b": torch.randn(3, 2), "a": torch.randn(3, 2)}},
            False,
        )
        self._test_check_fn(ref, loaded, {"x": {"a": torch.randn(3, 2)}}, False)


if __name__ == "__main__":
    from torch._dynamo.test_case import run_tests

    run_tests()